Source code for beamds.beam.similarity.core

from typing import Any

import numpy as np
from dataclasses import dataclass

from ..data import BeamData
from ..processor import Processor
from ..type.utils import is_beam_data
from ..utils import as_scipy_csr, as_scipy_coo, as_numpy, as_tensor


@dataclass
class Similarities:
    """Result record for a similarity search.

    This is the declared return type of ``BeamSimilarity.search``. The
    dataclass itself performs no validation or conversion; it only groups
    the values below. Only ``index`` and ``distance`` are required.
    """

    # Identifiers of the retrieved items (presumably one row per query;
    # the exact layout is decided by the concrete search implementation).
    index: Any
    # Distances/scores aligned with ``index`` (presumably same layout).
    distance: Any
    # Optional sparse representation of the scores; ``None`` when unused.
    sparse_scores: Any = None
    # Optional name of the metric that produced ``distance``.
    metric: str = None
    # Optional name of the model/backend that produced the result.
    model: str = None
class BeamSimilarity(Processor):
    """Base class for similarity indices that keep track of item identifiers.

    The class itself only manages the bookkeeping array ``self.index`` (one
    identifier per stored item) and a few flags describing it. The actual
    storage/search operations (``add``, ``search``, ``train``,
    ``remove_ids``, ``reconstruct``, ``reconstruct_n``) raise
    ``NotImplementedError`` here and must be provided by subclasses.

    Attributes:
        metric: Metric name, read back from ``self.hparams.metric``.
        index: NumPy array of identifiers of the stored items.
        _is_trained: Whether the index has been trained (``False`` after
            ``reset``).
        _is_range_index: ``True`` while every identifier was auto-generated
            as a running ``0..n-1`` range, ``False`` once an explicit index
            was supplied, ``None`` while the index is empty.
        _is_numeric_index: ``True`` when ``self.index`` has an integer,
            unsigned or float dtype, ``False`` otherwise, ``None`` while
            the index is empty.
    """

    def __init__(self, *args, metric=None, **kwargs):
        """Initialise the processor and start from an empty index.

        Args:
            *args: Forwarded to ``Processor.__init__``.
            metric: Metric name, forwarded to ``Processor.__init__``.
            **kwargs: Forwarded to ``Processor.__init__``.
        """
        super().__init__(*args, metric=metric, **kwargs)
        # The effective metric is taken from hparams, not from the raw
        # argument (presumably hparams may override it -- confirm in Processor).
        self.metric = self.hparams.metric
        self.index = None
        self._is_trained = None
        self._is_range_index = None
        self._is_numeric_index = None
        self.reset()

    @property
    def is_trained(self):
        """Whether the index is currently marked as trained."""
        return self._is_trained

    def reset(self):
        """Drop all stored identifiers and mark the index as untrained."""
        self.index = np.array([])
        self._is_trained = False
        # The flags describe the content of ``self.index``; an empty index
        # must not inherit them from whatever was stored before the reset.
        self._is_range_index = None
        self._is_numeric_index = None

    @staticmethod
    def extract_data_and_index(x, index=None, convert_to='numpy'):
        """Split ``x`` into values and index and convert the values.

        Args:
            x: Input data. When ``is_beam_data(x)`` is true, its ``.values``
                and ``.index`` are used and the ``index`` argument is ignored.
            index: Optional identifiers for the rows of ``x``.
            convert_to: Target container for the values: ``'numpy'``,
                ``'tensor'``, ``'scipy_csr'``, ``'scipy_coo'`` or ``None``
                to leave the values untouched.

        Returns:
            Tuple ``(values, index)`` where ``index`` went through
            ``as_numpy``.

        Raises:
            ValueError: If ``convert_to`` is not one of the supported values.
        """
        if is_beam_data(x):
            index = x.index
            x = x.values

        if convert_to is None:
            pass
        elif convert_to == 'numpy':
            x = as_numpy(x)
        elif convert_to == 'tensor':
            x = as_tensor(x)
        elif convert_to == 'scipy_csr':
            x = as_scipy_csr(x)
        elif convert_to == 'scipy_coo':
            x = as_scipy_coo(x)
        else:
            raise ValueError(f"Unknown conversion: {convert_to}")

        # NOTE(review): ``index`` may still be None here; what as_numpy(None)
        # returns is not visible in this module -- confirm it stays None-like
        # for callers that later test ``index is None``.
        return x, as_numpy(index)

    @property
    def metric_type(self):
        """Alias for ``self.metric``."""
        return self.metric

    def add(self, x, index=None, **kwargs):
        """Add items to the index. Must be implemented by subclasses."""
        raise NotImplementedError

    def search(self, x, k=1) -> Similarities:
        """Search for the ``k`` closest items. Must be implemented by subclasses."""
        raise NotImplementedError

    def train(self, x):
        """Train the index on ``x``. Must be implemented by subclasses."""
        raise NotImplementedError

    def remove_ids(self, ids):
        """Remove items by identifier. Must be implemented by subclasses."""
        raise NotImplementedError

    def reconstruct(self, id0):
        """Reconstruct a stored item. Must be implemented by subclasses."""
        raise NotImplementedError

    def reconstruct_n(self, id0, id1):
        """Reconstruct several stored items. Must be implemented by subclasses."""
        raise NotImplementedError

    @property
    def ntotal(self):
        """Number of identifiers currently stored (0 when there is no index)."""
        return 0 if self.index is None else len(self.index)

    def __len__(self):
        return self.ntotal

    def get_index(self, index):
        """Return the stored identifiers at the given positions of ``self.index``."""
        return self.index[as_numpy(index)]

    @staticmethod
    def _num_items(x):
        """Return the number of items in ``x`` via ``len`` or ``shape[0]``."""
        try:
            return len(x)
        except TypeError:
            # Objects without __len__ (e.g. some sparse containers) still
            # expose their row count through ``shape``.
            return x.shape[0]

    def add_index(self, x, index=None):
        """Register identifiers for newly added items.

        Args:
            x: The items being added; only their count is used, and only
                when ``index`` is ``None``.
            index: Optional explicit identifiers. When omitted, identifiers
                continue as a running range: ``0..n-1`` for an empty index,
                otherwise starting at ``self.index.max() + 1``.

        Returns:
            The identifiers assigned to the new items (as a NumPy array).
        """
        explicit = index is not None
        if explicit:
            index = as_numpy(index)

        if self.index is None or not len(self.index):
            if not explicit:
                index = np.arange(self._num_items(x))
            # A fresh index is a range index exactly when it was generated.
            self._is_range_index = not explicit
            self.index = index
        else:
            if explicit:
                # Mixing caller-supplied ids into the index means positions
                # and ids no longer coincide.
                self._is_range_index = False
            else:
                index = np.arange(self._num_items(x)) + self.index.max() + 1
            self.index = np.concatenate([self.index, index])

        # Derive the flag from the stored array so it is also reset to False
        # for non-numeric ids instead of keeping a stale value.
        self._is_numeric_index = self.index.dtype.kind in 'iuf'
        return index