Source code for alien.benchmarks.oracle

from abc import abstractmethod

import numpy as np

from ..sample_generation.from_set import SetSampleGenerator

# TODO: Docstrings


class Oracle:
    """Interface for objects that supply labels for queried samples.

    Subclasses are expected to override :meth:`get_label`;
    :meth:`get_labels` is built on top of it.

    Note:
        This class does not derive from ``abc.ABC``, so the
        ``@abstractmethod`` marker on :meth:`get_label` is not enforced
        when the class is instantiated.
    """

    @abstractmethod
    def get_label(self, x, remove=False):
        """Return the label for the single sample ``x``.

        Args:
            x: The sample to look up.
            remove: Option whose meaning is defined by the overriding
                subclass. Defaults to ``False``.

        Returns:
            The base implementation does nothing and returns ``None``.
        """

    def get_labels(self, x, remove=False):
        """Return the labels for every sample in ``x`` as a NumPy array.

        :meth:`get_label` is called once per element of ``x``, in
        iteration order, with ``remove`` passed through unchanged.

        Args:
            x: Iterable of samples.
            remove: Forwarded to each :meth:`get_label` call. Defaults to
                ``False``.

        Returns:
            numpy.ndarray: The collected labels, one entry per sample.
        """
        return np.array([self.get_label(x_val, remove=remove) for x_val in x])
class SetOracle(SetSampleGenerator, Oracle):
    """Oracle that looks labels up in the data held by a SetSampleGenerator."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent initializer with ``Xy=True``.

        ``Xy=True`` is always supplied by this class, so passing ``Xy`` in
        ``kwargs`` raises ``TypeError`` (duplicate keyword argument).
        NOTE(review): presumably ``Xy=True`` tells SetSampleGenerator that
        the data carries both samples and labels -- confirm against
        SetSampleGenerator.
        """
        super().__init__(*args, Xy=True, **kwargs)

    def get_label(self, x, remove=False):
        """Return the label stored for sample ``x``.

        Args:
            x: The sample to look up via ``self.data.find``.
            remove: If true, the indices returned by the lookup are passed
                to ``self.remove_data_indices`` after the label is read.

        Returns:
            ``self.labels`` at the first index returned by the lookup.

        Raises:
            IndexError: Presumably when ``x`` is not found, i.e. the lookup
                result is an empty sequence -- the exact type depends on
                what ``find`` returns; confirm against its implementation.
        """
        # NOTE(review): first=False presumably asks for every matching
        # index rather than only the first one -- confirm against `find`.
        indices = self.data.find(x, first=False)
        # Read the label before removing anything: if removal also drops
        # or shifts label entries, indexing afterwards could return the
        # wrong label or run past the end.
        label = self.labels[indices[0]]
        if remove:
            self.remove_data_indices(indices)
        return label