# Source code for alien.benchmarks.uncertainty_metrics

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats


def KL_divergence(preds, std_devs, y, noise=None, normalize=True):
    """
    Computes the KL-divergence from the predicted distribution (given by
    `preds` and `std_devs`, assuming normal distributions) to the
    ground-truth distribution (all of the probability mass on the true `y`
    values). In other words, this tells you how much information would be
    gained by learning the true values. Averaged over sample points. Lower
    is generally better.

    Not only does this penalize poorly calibrated uncertainties (i.e., those
    where the actual error distribution has a different standard deviation),
    it also penalizes uncertainties which are not as specific as they could
    be, i.e., which fail to discriminate between certain and uncertain
    predictions.

    :param preds: predicted values
    :param std_devs: predicted uncertainties, given as standard deviations
    :param y: true values
    :param noise: if given, added to `std_devs` as a noise floor
    :param normalize: if True (the default), normalizes with respect to the RMSE
    """
    if noise is not None:
        std_devs = std_devs + noise
    # Drop points with zero predicted uncertainty, whose log-density diverges.
    nonzeros = std_devs != 0
    preds, std_devs, y = preds[nonzeros], std_devs[nonzeros], y[nonzeros]
    if normalize:
        scale = np.sqrt(np.mean(np.square(preds - y)))
        preds = preds / scale
        std_devs = std_devs / scale
        y = y / scale
    return -np.mean(np.log2(stats.norm.pdf(y, loc=preds, scale=std_devs)))
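
# Illustrative sketch, not part of the original module: a quick check of
# KL_divergence on synthetic data. Calibrated uncertainties should score
# lower than overconfident ones. The helper name `_demo_kl_divergence` and
# the synthetic data are assumptions for demonstration only.
def _demo_kl_divergence(seed=0):
    rng = np.random.default_rng(seed)
    y = rng.normal(size=1000)
    preds = y + rng.normal(scale=0.5, size=1000)  # true error std-dev is 0.5
    calibrated = KL_divergence(preds, np.full(1000, 0.5), y)
    overconfident = KL_divergence(preds, np.full(1000, 0.05), y)
    return calibrated, overconfident  # expect calibrated < overconfident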

def best_multiple(preds, std_devs, y, noise=None, max_precision=5):
    """
    Does a simple binary search to find which multiple of `std_devs` gives
    the lowest KL-divergence score. To converge, assumes there is only one
    local minimum. (I expect this to be true, but I will have to check.)

    Arguments `preds`, `std_devs`, `y` and `noise` are as in
    :func:`KL_divergence`, except this will compute the KL-divergence for
    multiples of `std_devs`.

    :param max_precision: the number of interval splits the score must be
        adjacent to before returning the value. Defaults to 5.
    """

    def KL_score(m, preds, std_devs, y, noise=None):
        return KL_divergence(preds, m * std_devs, y, noise=noise)

    return binary_optimize(
        KL_score,
        preds,
        std_devs,
        y,
        noise=noise,
        mode="min",
        max_precision=max_precision,
    )
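
# Illustrative sketch, an assumption rather than part of the original module:
# if the predicted uncertainties are all too small by a constant factor,
# best_multiple should recover a corrective multiple near that factor.
def _demo_best_multiple(seed=0):
    rng = np.random.default_rng(seed)
    y = rng.normal(size=1000)
    preds = y + rng.normal(scale=1.0, size=1000)  # true error std-dev is 1.0
    std_devs = np.full(1000, 0.25)  # uncertainties underestimated 4x
    return best_multiple(preds, std_devs, y)  # expect a value near 4.0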

def binary_optimize(
    fn, *args, mode="max", start=1.0, max_precision=5, max_iterations=50, **kwargs
):
    """
    Does a simple binary search to find which scalar value gives the best
    (max/min) value of `fn`. Optimization converges to a local max/min.

    Each search iteration starts with the previously-explored value with the
    best score and looks on either side of it. If the best score is at the
    beginning of the current list of values, it looks at half this value on
    the low side; if at the end of the list, it looks at twice this value on
    the high side. If the best value is somewhere in the middle, it divides
    the interval on either side in half.

    :param max_precision: the number of interval splits the score must be
        adjacent to before returning the value. Defaults to 5.
    """
    opt = {"max": np.argmax, "min": np.argmin}[mode]
    vals = [start]
    scores = [fn(start, *args, **kwargs)]
    precision = [0]
    for _ in range(max_iterations):
        i = opt(scores)
        if precision[i] >= max_precision:
            break
        # look below
        v = search_lower(i, vals, precision, max_precision=max_precision)
        if v is None:
            break
        vals.insert(i, v)
        scores.insert(i, fn(v, *args, **kwargs))
        i += 1  # best value has shifted because of the insert
        # look above
        v = search_upper(i, vals, precision, max_precision=max_precision)
        if v is None:
            break
        vals.insert(i + 1, v)
        scores.insert(i + 1, fn(v, *args, **kwargs))
    for prec, val, score in zip(precision, vals, scores):
        print(f"{prec:02d} - {val:03.2f} : {score:.3f}")
    return vals[opt(scores)]
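
# Illustrative sketch, assumed rather than taken from the original module:
# binary_optimize on a smooth one-dimensional function. The minimum of
# (x - 3)**2 + 1 is at x = 3, so the search should return a value close to 3.
def _demo_binary_optimize():
    return binary_optimize(lambda x: (x - 3) ** 2 + 1, mode="min", start=1.0)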

def search_lower(i, vals, precision, max_precision=5):
    """Propose a new value below ``vals[i]``, updating `precision` in place."""
    v = None
    if i == 0:
        # Below the explored range: halve the smallest value.
        v = vals[i] / 2
        precision.insert(i, 0)
    else:
        if precision[i - 1] >= max_precision:
            return v
        # Split the interval below the current best value.
        v = (vals[i - 1] + vals[i]) / 2
        precision.insert(i, max(precision[i - 1 : i + 1]) + 1)
    return v

def search_upper(i, vals, precision, max_precision=5):
    """Propose a new value above ``vals[i]``, updating `precision` in place."""
    v = None
    if i == len(vals) - 1:
        # Above the explored range: double the largest value.
        v = vals[i] * 2
        precision.insert(i + 1, 0)
    else:
        if precision[i + 1] >= max_precision:
            return v
        # Split the interval above the current best value.
        v = (vals[i] + vals[i + 1]) / 2
        precision.insert(i + 1, max(precision[i : i + 2]) + 1)
    return v
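
# Illustrative sketch, assumed rather than part of the original module: one
# step of the interval-splitting helpers. Both mutate `precision` in place;
# the caller (binary_optimize) performs the matching inserts into `vals`
# and `scores`.
def _demo_search_helpers():
    vals, precision = [1.0, 2.0], [0, 0]
    lower = search_lower(1, vals, precision)  # midpoint of (1.0, 2.0) -> 1.5
    upper = search_upper(1, vals, precision)  # above the range -> 2.0 * 2 = 4.0
    return lower, upper, precision  # precision has grown to [0, 1, 0, 0]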

def plot_errors(preds, std_devs, y, noise=None, show=True, axes=None, **kwargs):
    """
    Scatter-plots prediction errors against their rank when sorted by
    predicted uncertainty, with a shaded band at +/- one predicted standard
    deviation. Extra keyword arguments are passed to ``axes.scatter``.
    Note that `noise` is accepted for API symmetry but currently unused.
    """
    if axes is None:
        axes = plt.gca()
    errs = y - preds
    # Sort points by predicted uncertainty so the band widens monotonically.
    sort = np.argsort(std_devs)
    std_devs, errs = std_devs[sort], errs[sort]
    X = np.arange(len(errs))
    axes.plot(X, np.zeros(len(errs)), color="green")
    axes.fill_between(X, -std_devs, std_devs, alpha=0.3)
    if "alpha" not in kwargs:
        kwargs["alpha"] = 0.35
    axes.scatter(X, errs, **kwargs)
    if show:
        plt.show()
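
# Illustrative sketch, assumed rather than part of the original module:
# render the error plot for synthetic heteroscedastic data without blocking
# on plt.show(), by passing show=False and an explicit axes.
def _demo_plot_errors(seed=0):
    rng = np.random.default_rng(seed)
    std_devs = rng.uniform(0.1, 2.0, size=200)
    y = rng.normal(size=200)
    preds = y + rng.normal(scale=std_devs)  # errors drawn at the predicted scale
    fig, ax = plt.subplots()
    plot_errors(preds, std_devs, y, show=False, axes=ax)
    return fig, ax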