# Source code for alien.selection.selector

"""Base class for selecting samples and helper functions."""
from abc import ABCMeta, abstractmethod
from numbers import Number

import numpy as np
import scipy.stats

from ..classes import final, override
from ..data import DictDataset
from ..decorators import get_defaults_from_self
from ..sample_generation import SampleGenerator
from ..stats import augment_ensemble
from ..utils import concatenate, isint


class SampleSelector(metaclass=ABCMeta):
    """
    Abstract base class for selection strategies.

    :param model: An instance of models.RegressionModel. Will be used to determine
        prediction covariances for proposed batches.
    :param samples: The sample pool to select from. Can be a numpy-style addressable
        array (with first dimension indexing samples, and other dimensions indexing
        features)---note that :class:`alien.data.Dataset` serves this purpose---or
        an instance of :class:`sample_generation.SampleGenerator`, in which case the
        num_samples parameter is in effect.
    :param num_samples: If a `SampleGenerator` has been provided via the 'samples'
        parameter, then at the start of a call to :meth:`.select`, `num_samples`
        samples will be drawn from the `SampleGenerator`, or as many samples as the
        `SampleGenerator` can provide, whichever is less. Defaults to Inf, i.e.,
        draws as many samples as available.
    :param labelled_samples: Some selection strategies need to know the
        previously-labelled samples.
    :param batch_size: Size of the batch to select.
    :param prior: Specifies a "prior probability" for each sample. Each selector may
        use this prior as it sees fit, but generally, samples with low prior are
        de-emphasized in the selection process. This is a convenient way of
        introducing factors other than uncertainty into the ranking. `prior` may be
        an array of numbers (of size num_samples), or a function (applied to the
        samples), or one of the following:

        'prediction': calculates a prior from the quantile of the predicted
        performance (not the uncertainties). `prior_scale` sets the power this
        quantile is raised to.

        Defaults to the constant value 1.
    :param prior_scale: The prior will be raised to this power before applying it to
        the samples. Defaults to 1.
    :param prefilter: Reduces the incoming sample pool before applying batch
        selection. If `a` is the single-sample acquisition function, then
        `prefilter = True` selects a subset of the provided samples maximizing `a`.
        If 0 < prefilter < 1, takes this fraction of the sample pool. If
        prefilter >= 1, takes this many samples.

        Some of the selectors are limited in how many samples they can consider for
        the final, batch-selection problem. For example, :class:`CovarianceSelector`
        computes and stores the size-N^2 covariance matrix for the whole sample
        pool; therefore, because of memory constraints it should work with at most
        around 10,000 samples. In such cases, there is often a cheaper prefiltering
        operation available. Eg., :class:`CovarianceSelector` prefilters only with
        the variance, rather than the full covariance.

        A practical strategy in such cases is to take a sample pool about 5 times
        as big as the selector can handle for the final computation, then narrow
        down to only the top 20% individual scores before batch selection.
        Narrowing to much less than 20% risks reducing diversity too much and
        changing what would ultimately be the selected batch.
    :param random_seed: A random seed for deterministic behaviour.
    :param return_indices: If True, :meth:`.select` will return the indices of the
        selection (from within the given sample pool). If False, :meth:`.select`
        will return the actual selected samples. Defaults to `False`.
    """

    def __init__(
        self,
        model=None,
        batch_size=1,
        samples=None,
        num_samples=None,
        labelled_samples=None,
        X_key="X",
        prior=None,
        prior_scale=1,
        return_indices=False,
        verbose=1,
    ):
        super().__init__()
        self.model = model
        self.labelled_samples = labelled_samples
        self.samples = samples
        self.num_samples = num_samples
        self.X_key = X_key
        self.batch_size = batch_size

        # Resolve the `prior` argument into a callable, self.prior_func
        if prior == "prediction":
            self.prior_func = self.prediction_prior
        elif prior is None:
            self.prior_func = lambda X: 1  # constant (no-op) prior
        else:
            self.prior_func = prior
        if prior_scale != 1:
            # BUGFIX: capture the unscaled prior in a local variable. The original
            # lambda referred to `self.prior_func`, which — after this assignment —
            # is the lambda itself, causing infinite recursion on the first call.
            base_prior = self.prior_func
            self.prior_func = lambda X: np.power(base_prior(X), prior_scale)

        # Cache of the most recent prediction, keyed on the identity (not the
        # contents) of the input X. See model_predict.
        self._last_pred = None
        self._last_std = None
        self._last_X = None

        self.return_indices = return_indices
        self.verbose = verbose

    def prediction_prior(self, samples):
        """Return each sample's quantile (normalized rank in (0, 1]) of the
        model's predicted value; used when the constructor got prior='prediction'.
        """
        # TODO: Maybe this should use self.samples if samples is not given
        pred = self.model_predict(samples)
        return scipy.stats.rankdata(pred) / len(np.ravel(pred))

    def model_predict(self, X, return_std_dev=False):
        """Predict with self.model, reusing the cached result if called again
        with the *same object* X (identity comparison, not equality).

        `return_std_dev` is accepted for signature compatibility with
        subclasses (e.g. UncertaintySelector) but is ignored here.
        """
        if X is not self._last_X:
            self._last_pred = self.model.predict(X)
            self._last_X = X
        return self._last_pred

    @final
    @get_defaults_from_self
    def select(
        self,
        batch_size=None,
        samples=None,
        num_samples=None,
        prior=None,
        X_key=None,
        fixed_samples=None,
        fixed_prior=None,
        return_indices=None,
        tail_call=None,
        **kwargs,
    ):
        """
        Selects a batch from the provided samples, and returns it (or its indices).

        All of the arguments to :meth:`.select` are optional. If you have provided
        `samples` and other necessary parameters to the constructor already, then
        you may omit them here. However, some parameters here are *not* in the
        constructor: `fixed_samples` and `fixed_prior`.

        Args:
            batch_size (int, optional): The size of the batch to select.
            samples (ArrayLike, optional): The sample pool to select from. Can be a
                numpy-style addressable array (with first dimension indexing
                samples, and other dimensions indexing features)---note that
                :class:`alien.data.Dataset` serves this purpose---or an instance of
                :class:`sample_generation.SampleGenerator`, in which case the
                num_samples parameter is in effect.
            num_samples: If a `SampleGenerator` has been provided via the 'samples'
                parameter, then at the start of a call to :meth:`.select`,
                `num_samples` samples will be drawn from the `SampleGenerator`, or
                as many samples as the `SampleGenerator` can provide, whichever is
                less. Defaults to Inf, i.e., draws as many samples as available.
            prior: A "prior probability" for each sample. May be an array of
                numbers, a function, or the string `'prediction'`. A more detailed
                explanation is above, in the class definition. Defaults to the
                constant value 1.
            prior_scale: The prior will be raised to this power before applying it
                to the samples. Defaults to 1.
            prefilter: Reduces the incoming sample pool before applying batch
                selection. If 0 < prefilter < 1, we use this fraction of the sample
                pool. If prefilter >= 1, we use this many samples. A more detailed
                explanation is above, in the class definition.
            return_indices: If True, :meth:`.select` will return the indices of the
                selection (from within the given sample pool). If False,
                :meth:`.select` will return the actual selected samples. Defaults
                to `False`.
            X_key: The key used to extract the X values from `samples`. I.e.,
                `X = samples[X_key]`. This is only in effect if you pass an
                explicit value to `X_key`, or if `samples` is a
                :class:`DictDataset` with key `'X'`. By default, `X = samples`.
            fixed_samples (ArrayLike, optional): This parameter is for passing in
                those samples which have been previously selected for labeling,
                but which haven't been labeled yet. (Eg., you've previously sent
                off a batch to the laboratory pipeline for testing, but you need
                to select the next batch for the pipeline before the results are
                in.) Some selection strategies (eg., :class:`CovarianceSelector`
                and :class:`BAITSelector`) will use this information to avoid
                redundancy between the newly selected batch and the
                `fixed_samples`.
            fixed_prior: If you provide an explicit (i.e., array-like) prior for
                `samples`, then you must also provide a prior for `fixed_samples`.

        Returns:
            The selected batch, either as a sub-array of `samples`, or as an array
            of indices into `samples` (if `return_indices` is set to True).
        """
        # Invalidate the prediction cache for this fresh selection pass
        self._last_X = None

        # If 'samples' is a SampleGenerator, generate a pool of samples
        if isinstance(samples, SampleGenerator):
            samples = samples.generate_samples(
                float("inf") if num_samples is None else num_samples
            )

        # If `samples` has a key or attribute `X_key`, pull this X out of samples
        if isinstance(samples, DictDataset) and hasattr(samples, X_key):
            kwargs["full_samples"] = samples
            X = getattr(samples, X_key)
            if fixed_samples is not None:
                kwargs["full_fixed_samples"] = fixed_samples
                fixed_X = getattr(fixed_samples, X_key)
            else:
                fixed_X = None
        else:
            X, fixed_X = samples, fixed_samples

        # prior is unspecified, get it by calling self.prior_func
        if prior is None:
            if self.prior_func is not None:
                if fixed_X is not None:
                    # Score fixed and candidate samples together, then split.
                    prior = np.asarray(self.prior_func(concatenate(fixed_X, X)))
                    if prior.ndim == 0:
                        # BUGFIX: a constant prior (e.g. the default, which
                        # returns the scalar 1) cannot be sliced; broadcast it
                        # to full length before splitting.
                        prior = np.full(len(fixed_X) + len(X), prior)
                    fixed_prior, prior = prior[: len(fixed_X)], prior[len(fixed_X) :]
                else:
                    prior = self.prior_func(X)
            else:
                prior = 1
                fixed_prior = 1
        if isinstance(prior, Number):
            prior = np.asarray(prior).reshape((1,))

        # Dispatch to the subclass's _select (or to an explicit tail_call,
        # which subclasses use to interpose prefiltering)
        indices = (self._select if tail_call is None else tail_call)(
            batch_size=batch_size,
            samples=X,
            fixed_samples=fixed_X,
            prior=prior,
            fixed_prior=fixed_prior,
            **kwargs,
        )
        return indices if return_indices else samples[indices]

    @abstractmethod
    def _select(self, batch_size=None, samples=None, prior=None, **kwargs):
        """Batch-selection strategy; implemented by subclasses.

        Must return indices into `samples`."""
        pass
class UncertaintySelector(SampleSelector):
    """Selector base that scores samples by predictive uncertainty and supports
    cheap prefiltering of the sample pool before batch selection."""

    def __init__(self, prefilter=None, **kwargs):
        super().__init__(**kwargs)
        # See the class docstring of SampleSelector for the meaning of prefilter:
        # a fraction (0 < prefilter < 1) or a count (prefilter >= 1) of samples
        # to keep before batch selection. None or 1 disables prefiltering.
        self.prefilter = prefilter

    def model_predict(self, X, return_std_dev=False):
        """Predict mean (and standard deviation) with self.model, caching both
        for repeated calls with the same object X.

        Returns `(pred, std)` if `return_std_dev` is True, otherwise just `pred`.
        """
        if X is not self._last_X:
            self._last_pred, self._last_std = self.model.predict(X, return_std_dev=True)
            self._last_X = X
        # BUGFIX: the original expression
        #     return self._last_pred, self._last_std if return_std_dev else self._last_pred
        # parsed as a 2-tuple whose second element was conditional, so the
        # method returned (pred, pred) even when return_std_dev was False,
        # breaking callers that expect a bare prediction (e.g. prediction_prior).
        return (self._last_pred, self._last_std) if return_std_dev else self._last_pred

    def get_prefilter(self, X=None, k=None, prior=1, score=None, return_indices=True):
        """Return the top-k samples (or their indices) by prior-weighted score.

        :param X: sample pool; required unless `score` is given.
        :param k: fraction (0 < k < 1) or count (k >= 1) of samples to keep;
            defaults to self.prefilter. None or 1 means "keep everything".
        :param prior: per-sample weights multiplied into the score.
        :param score: precomputed scores; if omitted, the model's predicted
            standard deviation is used.
        :param return_indices: if True return indices into X, else the samples.
        """
        if k is None:
            k = self.prefilter
        if k is None or k == 1:
            # No prefiltering: keep the whole pool
            return np.arange(len(X)) if return_indices else X
        if score is None:
            if X is None:
                raise ValueError("Must provide either 'X' or 'score' to the method 'prefilter'")
            _, score = self.model_predict(X, return_std_dev=True)
        if 0 < k < 1:
            # BUGFIX: clamp to at least 1; if the fraction rounded to 0, the
            # slice below would be [-0:], which keeps *all* samples.
            k = max(int(round(k * len(score))), 1)
        else:
            k = min(k, len(score))
        # Adjust score by prior. BUGFIX: out-of-place multiply — the original
        # `score *= prior` mutated the cached self._last_std array in place.
        score = score * prior
        indices = np.argsort(score)[-int(k) :]
        return indices if return_indices else X[indices]

    @final
    @override
    def select(self, *args, **kwargs):
        # Route the base-class select through _select_uncertain, which applies
        # prefiltering before delegating to the subclass's _select.
        return super().select(*args, tail_call=self._select_uncertain, **kwargs)

    @get_defaults_from_self
    def _select_uncertain(
        self,
        batch_size=None,
        samples=None,
        prior=None,
        prefilter=None,
        **kwargs,
    ):
        """Optionally prefilter the pool, run _select on the survivors, and map
        the resulting indices back into the original pool."""
        if prefilter not in {None, 1}:
            pre_indices = self.get_prefilter(
                samples,
                prefilter,
                prior=prior,
                return_indices=True,
            )
            samples = samples[pre_indices]
            prior = prior[pre_indices]
        indices = self._select(
            batch_size=batch_size,
            samples=samples,
            prior=prior,
            **kwargs,
        )
        # Translate indices-into-the-filtered-pool back to the full pool
        return pre_indices[indices] if prefilter not in {None, 1} else indices


def optimize_batch(
    scoring_fn,
    batch_size,
    samples,
    n_fixed=0,
    scoring_opt_step=None,
    n_tuples=None,
    n_best=20,
    n_rounds=10,
    random_seed=None,
    callback=None,
    verbose=1,
):
    """Greedy coordinate-ascent search for a batch maximizing `scoring_fn`.

    Generates `n_tuples` random initial batches (biased toward high singleton
    scores), keeps the `n_best`, then repeatedly re-optimizes each batch one
    coordinate at a time for up to `n_rounds` rounds, stopping early if a round
    changes nothing.

    :param scoring_fn: maps an array of batches (…, batch_size) to one scalar
        score per batch.
    :param batch_size: number of samples to select.
    :param samples: the pool to draw from; an int n is shorthand for arange(n)
        (i.e. the pool is a set of indices).
    :param n_fixed: the first n_fixed entries of `samples` are pinned into
        every batch (previously-committed selections) and excluded from the
        returned batch.
    :param scoring_opt_step: optional cheaper scorer for the coordinate-update
        step; see _scoring_step.
    :param callback: optional callable(round, coordinate, best_tuples), invoked
        before each coordinate update and once after the loop.
    :returns: the best batch found, without the fixed prefix.
    """
    samples, n_tuples = _init_samples(samples, n_tuples=n_tuples)
    n_samples = len(samples)
    rng = np.random.default_rng(random_seed)

    # Build distribution using singleton (size-1 batch) scores
    if verbose:
        print("Computing singleton scores...")
    single = scoring_fn(samples[:, None, ...])

    if batch_size == 1:
        # Trivial case: the best singleton is the best batch
        return samples[[np.argmax(single.flatten())]]

    tuples = _generate_initial_batches(samples, n_tuples, batch_size, single, rng, verbose=verbose)

    # If we have fixed samples, put them in at the front of every batch
    if n_fixed > 0:
        fixed_batch = np.broadcast_to(samples[:n_fixed], (n_tuples, n_fixed))
        tuples = np.concatenate([fixed_batch, tuples], axis=-1)
        batch_size += n_fixed

    scores, n_best, best_tuples = _init_scores(
        tuples, n_samples, scoring_fn, batch_size, n_best=n_best, verbose=verbose
    )

    if verbose:
        print("Greedy optimization of best batches...")
    # Greedy optimization of best_tuples
    for round_ in range(n_rounds):
        if verbose:
            print(f" Optimization round {round_+1} - ", end="")
        # Tracks which samples have been used in each batch so far during
        # this round, so a batch never repeats a sample.
        # NOTE(review): only samples chosen during this round (plus the fixed
        # prefix) are masked; entries at not-yet-visited coordinates are not —
        # confirm duplicates cannot survive a full round.
        used_mask = np.zeros((n_samples, n_best), dtype=bool)
        used_mask[:n_fixed] = True
        n_changed = 0
        for i in range(n_fixed, batch_size):
            _run_callback(callback, round_, i, best_tuples)
            scores = _scoring_step(
                i,
                best_tuples,
                samples,
                n_samples,
                n_best,
                batch_size,
                scoring_fn,
                scoring_opt_step,
            )
            assert scores.shape == (n_samples, n_best)
            # Rule out batches which repeat a sample
            scores[used_mask] = -np.inf
            # Find max score over n_samples for each batch
            next_steps = scores.argmax(axis=0)
            assert next_steps.shape == (n_best,)
            # Mask these samples as used for each batch
            used_mask[next_steps, np.arange(n_best)] = True
            n_changed += (best_tuples[:, i] != samples[next_steps]).sum(axis=-1)
            best_tuples[:, i] = samples[next_steps]
        frac_changed = n_changed / (n_best * batch_size)
        if verbose:
            print(f"Changed {(frac_changed*100):.1f} %")
        if n_changed == 0:
            # Converged: a full round left every batch unchanged
            break
    _run_callback(callback, round_, i, best_tuples)

    if verbose:
        print("Done...\n")
    scores = scoring_fn(best_tuples)
    # Return the single best batch, dropping the fixed prefix
    return best_tuples[np.argmax(scores)][n_fixed:]


def precomputed_ensemble_score(model, samples, fn, multiple=1.0, augment_ensemble_size=-1):
    """Build a scoring function over *index* batches from precomputed ensemble
    predictions.

    :param fn: scoring function. It should take predicted values of shape
        (..., batch_size, ensemble_size). I.e., the first axes count separate
        batches. 'fn' should consume the last two axes, and return a set of
        scalar scores, one for each batch.
    :param multiple: passed through to model.predict_ensemble.
    :param augment_ensemble_size: if positive and larger than the predicted
        ensemble, the ensemble is statistically augmented to this size.
    :returns: score(indices) -> per-batch scalar scores, where `indices` has
        shape (n_batches, batch_size) and indexes into `samples`' predictions.
    """
    preds = model.predict_ensemble(samples, multiple=multiple)
    if augment_ensemble_size and augment_ensemble_size > preds.shape[-1]:
        preds = augment_ensemble(preds, augment_ensemble_size, rng=model.rng)
    # Drop references so the closure below only keeps `preds` alive
    del model
    del samples

    def score(indices):
        """
        Returns a set of scalar scores, one for each batch.

        :param indices: indices are of shape (n_batches, batch_size)
        """
        return fn(preds[indices])

    return score


def _init_samples(samples, n_tuples=None):
    """Helper function to initialize samples in optimize_batch.

    An int `samples` becomes arange(samples); `n_tuples` defaults to the pool
    size."""
    if isint(samples):
        samples = np.arange(samples)
    n_samples = len(samples)
    if n_tuples is None:
        n_tuples = n_samples
    return samples, n_tuples


def _init_scores(
    tuples,
    n_samples,
    scoring_fn,
    batch_size,
    n_best=20,
    verbose=1,
):
    """Helper function to initialize scores in optimize_batch: score every
    initial batch and keep the n_best highest-scoring ones."""
    if verbose:
        print("Scoring initial batches...")
    scores = scoring_fn(tuples)  # scoring fn should eat 2nd axis (batch_size)
    # NOTE(review): this caps n_best by the pool size, not by len(tuples);
    # confirm n_tuples >= n_best is guaranteed by callers.
    n_best = min(n_best, n_samples)
    best_tuples = tuples[np.argsort(scores)[-n_best:]]
    assert best_tuples.shape == (n_best, batch_size)
    return scores, n_best, best_tuples


def _generate_initial_batches(samples, n_tuples, batch_size, single, rng, verbose=False):
    """Helper function to generate initial batches in optimize_batch.

    Generates a collection of tuples. Probability density is a linear function
    over the ranked sample list, so higher-ranked (by singleton score) samples
    are somewhat favoured. Sampling is without replacement within each batch.
    """
    if verbose:
        print("Generating initial batches...")
    tuples = np.empty((n_tuples, batch_size), dtype=int)
    ranked_scores = scipy.stats.rankdata(single)
    ranked_scores /= np.sum(ranked_scores)
    for i in range(n_tuples):
        selected_sample = rng.choice(samples, batch_size, replace=False, p=ranked_scores)
        tuples[i] = selected_sample
    assert tuples.shape == (n_tuples, batch_size)
    return tuples


def _scoring_step(
    i, best_tuples, samples, n_samples, n_best, batch_size, scoring_fn, scoring_opt_step
):
    """Helper function to perform a scoring step in optimize_batch: score every
    counterfactual batch obtained by substituting each pool sample at
    coordinate i, returning an (n_samples, n_best) score array."""
    if scoring_opt_step is None:
        # Broadcast the current best batches across the pool, then scan all
        # samples at position i.
        best_tuples_alternates = np.broadcast_to(
            best_tuples, (n_samples, n_best, batch_size)
        ).copy()
        assert best_tuples_alternates.shape == (n_samples, n_best, batch_size)
        best_tuples_alternates[:, :, i] = samples[
            :, None
        ]  # at one place in the batch, scan over all samples
        scores = scoring_fn(best_tuples_alternates)  # scoring_fn eats axis -1, i.e., batch_size
    else:
        # We pass to the scoring function only the fixed samples of each batch
        # (plus the whole sample space to test for the variable sample).
        # Scoring function is responsible for generating whole space of
        # counterfactual batches, and evaluating them.
        scores = scoring_opt_step(np.delete(best_tuples, i, axis=-1), samples)
    return scores


def _run_callback(callback, round_, i, best_tuples):
    """Helper function to run the callback (if any) for a given iteration."""
    if callback:
        callback(round_, i, best_tuples)