Source code for alien.models.keras.keras

"""Model classes for wrapping Keras-based models into ALiEN format."""

import numpy as np

from ...decorators import flatten_batch, get_defaults_from_self, get_Xy
from ...utils import dict_pop, seed_all, shift_seed, update_copy
from ...config import default_training_epochs, INIT_SEED_INCREMENT
from ..mc_dropout import MCDropoutRegressor
from ..models import Model
from ...data import ArrayDataset

# pylint: disable=import-outside-toplevel


[docs]class KerasRegressor(MCDropoutRegressor): """Base Class for wrapped Keras regression models.""" def __init__(self, model=None, X=None, y=None, **kwargs): self.compile_kwargs = dict_pop( kwargs, "metrics", "loss_weights", "weighted_metrics", "run_eagerly", "steps_per_execution", "jit_compile", optimizer="adam", loss="mse", ) self.fit_kwargs = dict_pop( kwargs, "batch_size", "verbose", "callbacks", "validation_split", "validation_data", "shuffle", "class_weight", "sample_weight", "initial_epoch", "steps_per_epoch", "validation_steps", "validation_batch_size", "validation_freq", "max_queue_size", "workers", "use_multiprocessing", epochs=default_training_epochs, ) global tf import tensorflow as tf self.initial_weights = None self.compiled = False self.model = model super().__init__(X=X, y=y, **kwargs)
[docs] @get_defaults_from_self def initialize(self, init_seed=None, sample_input=None): if init_seed is not None: seed_all(init_seed) self.init_seed = shift_seed(init_seed, INIT_SEED_INCREMENT) self.save_initial_weights(sample_input) # Actually reinitializing the weights in keras is almost impossible, # so instead we shuffle the saved initial weights (within each tensor), # giving the same distribution self.set_weights( [ np.random.default_rng(init_seed).permutation(w.flat).reshape(w.shape) for w in self.initial_weights ] )
[docs] def get_weights(self): """Return model weights.""" weights = [w.copy() for w in self.model.get_weights()] return None if len(weights) == 0 else weights
[docs] @get_defaults_from_self def set_weights(self, initial_weights=None): """Set model weights according to initial_weights""" self.model.set_weights([w.copy() for w in initial_weights])
[docs] def save_initial_weights(self, sample_input=None): """Save initial weights in self.initial_weights object.""" if self.initial_weights is not None: return if sample_input is None: if self.X is None: raise ValueError( "Can't initialize weights for the first time without sample input." ) sample_input = self.X[:1] self.predict(sample_input) self.initial_weights = self.get_weights().copy()
[docs] @get_Xy @get_defaults_from_self def fit_model(self, X=None, y=None, **kwargs): if not self.compiled: self.model.compile(**self.compile_kwargs) self.compiled = True if not (isinstance(X, tf.Tensor) or isinstance(X, np.ndarray)): X, y = np.asarray(X), np.asarray(y) self.model.fit(x=X, y=y, **update_copy(self.fit_kwargs, **kwargs))
[docs] @flatten_batch def predict(self, X, **kwargs): #import tensorflow as tf if not (isinstance(X, tf.Tensor) or isinstance(X, np.ndarray)): X = np.asarray(X) preds = self.model(X, training=False, **kwargs) if preds.ndim == 2 and preds.shape[-1] == 1: preds = preds[:, 0] return preds
[docs] def fix_dropouts(self): #import tensorflow as tf from .utils import modify_dropout, subobjects for m in subobjects(self.model, skip=self.nodropout_layers): if modify_dropout(m): self.dropouts.append(m)
[docs] @flatten_batch def predict_samples(self, X, n=1, multiple=1.0): #import tensorflow as tf preds = [] seed = self.rng.integers(1e6) if not (isinstance(X, tf.Tensor) or isinstance(X, np.ndarray)): X = np.asarray(X) for i in range(n): # get output tf.random.set_seed(shift_seed(seed, i * 7)) out = self.model(X, training=1) # training=1 sets special dropout behaviour if out.ndim >= 2 and out.shape[-1] == 1: out = out[..., 0] preds.append(out) return tf.stack(preds, axis=1)