"""Source code for alien.benchmarks.retrospective."""

import os
import pickle

import numpy as np

from ..data import DictDataset
from ..utils import shift_seed
from .oracle import SetOracle


def run_experiments(
    X,
    y,
    model,
    runs_dir,
    overwrite_old_runs=True,
    n_initial=None,
    batch_size=20,
    num_samples=float("inf"),
    selector=None,  # 'covariance', 'random', 'expected improvement'/'ei', 'greedy'
    selector_args=None,
    fit_args=None,
    n_runs=10,
    ids=None,
    save_ids=True,
    random_seed=1,
    split_seed=421,
    test_size=0.2,
    timestamps=None,
    stop_samples=None,
    stop_rmse=None,
    stop_frac=None,  # suggest .85
    # The following arguments are a bit odd, and best avoided
    peek_score=0,  # 0 if no peeking
    test_samples_x=None,
    test_samples_y=None,
):
    """
    Run a series of retrospective active-learning experiments, storing the
    metrics of each run in its own subdirectory of `runs_dir`.

    :param runs_dir: directory to store the runs and results of this training
        (each run in separate subdirectories)
    :param n_initial: number of samples to randomly select for initial training
        data. Defaults to `batch_size`.
    :param batch_size: number of samples selected for batch
    :param num_samples: number of samples (drawn from the sample pool X) to
        select from. Default is inf, which takes all of the samples available
        in X.
    :param selector: the selector to use for batch selection, either given by
        one of the strings 'covariance', 'random',
        'expected improvement'/'ei', 'greedy', or passed as an actual
        SampleSelector instance. Defaults to 'covariance'.
    :param selector_args: a dictionary passed as kwargs to the selector
        constructor. The following constructor arguments are already
        automatically included, and don't need to be included in this
        dictionary: model, labelled_samples, samples, num_samples, batch_size
    :param fit_args: a dictionary passed as kwargs each time model.fit(...) is
        called. Typically, this is model- or framework-specific; so, eg.,
        different arguments would be appropriate for pytorch models, DeepChem
        models, etc.
    :param n_runs: the number of overall runs (each starting from a random
        initial selection) to do (for averaging)
    :param random_seed: random seed for most RNG generation
    :param split_seed: random seed for shuffling and splitting of data
    :param test_size: the size of the test/validation set to take from X,y.
        If test_size >= 1, then takes that many samples. If 0 < test_size < 1,
        takes that fraction of the dataset size.
    :param stop_samples: if this is not None, stops an experiment run when this
        many samples are labelled. Defaults to None
    :param stop_rmse: if this is not None, stops an experiment run when this
        RMSE has been reached. Defaults to None
    :param stop_frac: if this is not None, stops an experiment run when the
        RMSE has moved this fraction of the way from the RMSE after the first
        round to the RMSE trained on the whole dataset. We suggest something
        like .85, if you want to use this feature. Defaults to None
    """
    selector, selector_args, fit_args = _args_init(selector, selector_args, fit_args)
    data = _data_init(X, y, ids=ids, timestamps=timestamps)
    data_train, test_samples = _get_train_test(
        data, split_seed, test_size, test_samples_x=test_samples_x, test_samples_y=test_samples_y
    )

    # A fractional stop_samples is interpreted as a fraction of the training set.
    if stop_samples is not None and 0 < stop_samples < 1:
        stop_samples = len(data_train) * stop_samples

    if stop_frac is not None:
        # stop_frac needs the "best achievable" RMSE as a reference point, so
        # fit once on the whole training set up front.
        print("Fitting model to whole dataset, to determine stopping limit... ", flush=True)
        model.fit(data_train.X, data_train.y, **fit_args)
        # BUG FIX: the residuals must be squared before averaging. The
        # previous code computed sqrt(mean(pred - actual)), which is not an
        # RMSE and yields NaN whenever the mean residual is negative.
        residuals = np.array(model.predict(test_samples.X.data, return_std_dev=False)) - np.array(
            test_samples.y.data
        )
        final_rmse = np.sqrt(np.mean(residuals**2))
        print("Done.")
    else:
        final_rmse = None

    if n_initial is None:
        n_initial = batch_size

    n_0 = _overwrite(runs_dir, overwrite_old_runs)

    # These standard constructor arguments are always supplied to the selector.
    selector_args = selector_args.copy()
    selector_args.update({"model": model, "batch_size": batch_size, "num_samples": num_samples})
    selector = _get_selector(selector, selector_args)

    first_seed = random_seed
    for n in range(n_0, n_0 + n_runs):
        run_dir = os.path.join(runs_dir, f"run_{n:02d}")
        os.makedirs(run_dir, exist_ok=True)
        # Derive a distinct, reproducible seed for each run.
        random_seed = shift_seed(first_seed, n * 5)
        _run_retro_iter(
            n,
            run_dir,
            model,
            data_train,
            test_samples,
            batch_size=batch_size,
            n_initial=n_initial,
            selector=selector,
            fit_args=fit_args,
            save_ids=save_ids,
            random_seed=random_seed,
            stop_samples=stop_samples,
            stop_rmse=stop_rmse,
            stop_frac=stop_frac,
            final_rmse=final_rmse,
            peek_score=peek_score,
        )
def _run_retro_iter(  # NOSONAR
    n,
    run_dir,
    model,
    unlabelled_samples,
    test_samples,
    batch_size,
    n_initial=None,
    selector=None,
    fit_args=None,
    save_ids=True,
    random_seed=None,
    stop_samples=None,
    stop_rmse=None,
    stop_frac=None,  # suggest .85
    final_rmse=None,
    peek_score=0,  # 0 if no peeking
):
    """Helper function to run an iteration (one full run) of an experiment.

    Repeatedly fits `model` on the labelled set, records metrics under
    `run_dir`, then selects and "labels" the next batch from the pool, until
    one of the stopping conditions in `_check_retro_break` fires.
    """
    # BUG FIX: fit_args previously defaulted to a shared mutable dict ({}).
    fit_args = {} if fit_args is None else fit_args
    rng = np.random.default_rng(random_seed)
    model.initialize(init_seed=random_seed, sample_input=test_samples.X[:1])
    sample_pool = SetOracle(data=unlabelled_samples, shuffle=False, random_seed=shift_seed(random_seed, 1))

    # Seed the labelled set with a random initial selection from the pool.
    initial_indices = rng.choice(len(sample_pool), size=n_initial, replace=False)
    labelled_samples = unlabelled_samples[initial_indices]
    sample_pool.remove_samples(labelled_samples)
    batch_ids = [list(initial_indices)]
    if save_ids:
        _save_ids(batch_ids, run_dir)

    model.data = labelled_samples
    selector.labelled_samples = labelled_samples
    selector.samples = sample_pool

    # Imported here rather than at module level — presumably to avoid a
    # circular import with .metrics; TODO confirm.
    from .metrics import RMSE, Scatter, TopScore

    scatter = Scatter(model=model, test=test_samples, name="test scatter", errs=None)
    top_scores = TopScore(name="top scores", file_path=f"{run_dir}/top_scores.pickle")
    rmse = RMSE(name="RMSE", file_path=f"{run_dir}/RMSE.pickle", scatter=scatter)

    round_ = 0
    first_rmse = None
    while True:
        print(f"\n\n### Round {round_} (Run {n}) ###")
        print(f" Model: {model}")
        print(f" Selector: {selector}")
        print(f" {len(labelled_samples)} labelled samples")
        print(f" {len(sample_pool)} samples remaining")
        print("\nTraining model...")
        model.fit(**fit_args)

        print("\nComputing metrics...")
        labels = labelled_samples.y
        if peek_score:
            # "Peek" at the pool's labels when computing top scores.
            # NOTE(review): the original also computed
            # np.argsort(model.predict(samples))[:-peek_score] into an unused
            # variable; all pool labels are appended here, exactly as before.
            # Confirm whether only the top-peek_score predictions were intended.
            samples = sample_pool.generate_samples()
            labels = np.concatenate((labels, samples.y))
        top_scores.compute(x=len(labelled_samples), labels=labels, average_over=5)
        top_scores.save()
        scatter.compute(samples=len(labelled_samples))
        scatter.save(f"{run_dir}/scatter_{round_:02d}.pickle")
        rmse.compute()
        rmse.save()
        if round_ == 0:
            first_rmse = rmse.y[0]

        if _check_retro_break(
            round_,
            first_rmse,
            rmse,
            final_rmse,
            sample_pool,
            labelled_samples,
            stop_samples,
            stop_frac,
            stop_rmse,
        ):
            break

        print("\nSelecting batch...")
        if len(sample_pool) <= batch_size:
            # Fewer samples remain than a full batch: take all of them.
            batch = sample_pool.generate_samples()
        else:
            batch = selector.select()

        print("\nLearning batch...")
        labelled_samples.extend(batch)
        sample_pool.remove_samples(batch)
        batch_ids.append(list(batch.ids.data))
        if save_ids:
            _save_ids(batch_ids, run_dir)
        round_ += 1


def _save_ids(ids, dir='', filename=None):  # `dir` name kept for backward compatibility
    """Write per-round batch ids as a JSON list-of-lists (one list per round)."""
    with open(f"{dir}/batch_ids.json" if filename is None else filename, "w") as f:
        f.write(
            "[\n ["
            + "],\n [".join(", ".join(str(i) for i in b) for b in ids)
            + "]\n]\n"
        )


def _get_selector(selector, selector_args):
    """Helper function to spawn a selector.

    `selector` is either a name string or an already-constructed selector
    instance, which is returned unchanged.
    """
    # BUG FIX: run_experiments documents that an actual SampleSelector
    # instance may be passed; previously any non-string selector fell through
    # to the ValueError below.
    if not isinstance(selector, str):
        return selector
    if selector in {"expected improvement", "ei"}:
        from ..selection.expected_improvement import ExpectedImprovementSelector

        return ExpectedImprovementSelector(**selector_args)
    if selector == "covariance":
        from ..selection.covariance import CovarianceSelector

        return CovarianceSelector(**selector_args)
    if selector in {"bait", "BAIT"}:
        from ..selection.bait import BAITSelector

        return BAITSelector(**selector_args)
    if selector == "random":
        from ..selection.random import RandomSelector

        return RandomSelector(**selector_args)
    if selector == "greedy":
        from ..selection.greedy import GreedySelector

        return GreedySelector(**selector_args)
    if selector == "kmeans":
        from ..selection.kmeans import KmeansSelector

        return KmeansSelector(**selector_args)
    if selector == "timestamp":
        from ..selection import TimestampSelector

        return TimestampSelector(**selector_args)
    raise ValueError(f"'{selector}' is an invalid choice of selector.")


def _get_train_test(data, split_seed, test_size, test_samples_x=None, test_samples_y=None):
    """Helper function to get train-test split of the data.

    If explicit test samples are given, the whole of `data` is used for
    training; otherwise a reproducible shuffled split is taken.
    """
    if test_samples_x is None:
        shuffle = np.arange(len(data.X))
        np.random.default_rng(split_seed).shuffle(shuffle)
        # test_size >= 1 is an absolute count; a fraction otherwise.
        test_size = test_size if test_size >= 1 else int(len(data.X) * test_size)
        test_indices, train_indices = shuffle[:test_size], shuffle[test_size:]
        data_train = data[train_indices]
        test_samples = DictDataset(
            {
                "X": data.X.data[test_indices],
                "y": data.y.data[test_indices],
                "ids": data.ids.data[test_indices],
            }
        )
    else:
        data_train = data
        test_samples = DictDataset({"X": test_samples_x, "y": test_samples_y})
    return data_train, test_samples


def _data_init(X, y, ids=None, timestamps=None):
    """Helper function to initialize the data object from samples."""
    if ids is None:
        ids = np.arange(len(X))
    data_dict = {
        "X": X,
        "y": y,
        "ids": ids,
    }
    if timestamps is not None:
        data_dict["t"] = timestamps
    return DictDataset(data_dict)


def _args_init(selector, selector_args, fit_args):
    """Helper function to initialize optional args in run_experiments."""
    if selector is None:
        selector = "covariance"
    if selector_args is None:
        selector_args = {}
    if fit_args is None:
        fit_args = {}
    return selector, selector_args, fit_args


def _overwrite(runs_dir: str, overwrite_old_runs: bool) -> int:
    """Prepare `runs_dir` and return the index of the first run to write.

    When overwriting, the whole directory tree is removed and recreated and
    numbering restarts at 0; otherwise numbering continues after the existing
    `run_NN` subdirectories.
    """
    os.makedirs(runs_dir, exist_ok=True)
    n_0 = 0
    if overwrite_old_runs:
        import shutil

        shutil.rmtree(runs_dir)
        os.makedirs(runs_dir)
    else:
        while f"run_{n_0:02d}" in os.listdir(runs_dir):
            n_0 += 1
    return n_0


def _check_retro_break(
    round_,
    first_rmse,
    rmse,
    final_rmse,
    sample_pool,
    labelled_samples,
    stop_samples,
    stop_frac,
    stop_rmse,
):
    """Helper function to check whether to break the experiment loop."""
    if len(sample_pool) == 0:
        print("Ending run because we've selected the whole dataset")
        return True
    if stop_samples is not None and len(labelled_samples) >= stop_samples:
        print(f"Ending run because we've reached limit of {stop_samples} samples.")
        return True
    if stop_frac is not None and (first_rmse - rmse.y[-1]) >= stop_frac * (
        first_rmse - final_rmse
    ):
        print("Ending run because RMSE is close enough to final RMSE.")
        return True
    if stop_rmse is not None and rmse.y[-1] <= stop_rmse:
        print("Ending run because RMSE is better than threshold.")
        return True
    return False


# TODO: a command-line interface (argument parser) was planned here but has
# been deprioritized.