Source code for alien.models.pytorch.pytorch

from math import sqrt

import numpy as np
from numpy.typing import ArrayLike

from ...decorators import flatten_batch, get_defaults_from_self, get_Xy
from ..laplace import LinearizableLaplaceRegressor
from ..mc_dropout import MCDropoutRegressor
from ...config import INIT_SEED_INCREMENT
from .last_layer import LastLayerPytorchLinearization
from .training_limits import default_limit
from .utils import as_tensor, dropout_forward, submodules, pl_argnames
from ...utils import dict_pop, shift_seed

# torch is imported inside the functions and methods that use it, to avoid the import when not used
# pylint: disable=import-outside-toplevel


def init_weights(module):
    """
    (Re-)initialize the parameters of a single module in place.

    Intended to be passed to ``torch.nn.Module.apply``.

    If the module provides its own ``reset_parameters`` method, that method
    is called and nothing else happens. Otherwise a non-``None`` ``weight``
    is filled according to its number of dimensions:

    * 2 or more dimensions: ``kaiming_uniform_`` with ``a=0``,
      ``mode="fan_in"``, ``nonlinearity="leaky_relu"``
    * exactly 1 dimension: uniform on ``[-b, b]`` with
      ``b = sqrt(6 / weight.shape[0])``
    * 0 dimensions: uniform on ``[-sqrt(2), sqrt(2)]``

    and finally :func:`init_bias` is called for the module.

    :param module: the module whose parameters are initialized
    """
    # Modules that know how to reset themselves take precedence.
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()
        return

    # torch is only needed on this fallback path
    import torch

    weight = getattr(module, "weight", None)
    if weight is not None:
        n_dims = weight.dim()
        if n_dims >= 2:
            torch.nn.init.kaiming_uniform_(
                weight, a=0, mode="fan_in", nonlinearity="leaky_relu"
            )
        elif n_dims == 1:
            limit = sqrt(6 / weight.shape[0])
            torch.nn.init.uniform_(weight, -limit, limit)
        else:
            torch.nn.init.uniform_(weight, -sqrt(2), sqrt(2))

    init_bias(module, torch)
def init_bias(module, torch):
    """
    Initialize ``module.bias`` in place with a uniform distribution.

    Does nothing if the module has no ``bias`` attribute or it is ``None``.

    The bias is drawn uniformly from ``[-bound, bound]`` where

    * ``bound = 1 / sqrt(fan_in)`` (or ``0`` if ``fan_in`` is 0) when the
      module has a weight with at least 2 dimensions;
    * ``bound = sqrt(2)`` otherwise, including when the module has no
      ``weight`` at all (previously this case raised ``AttributeError``).

    :param module: the module whose bias is initialized
    :param torch: the imported ``torch`` package, passed in by the caller so
        this function does not import it itself
    """
    bias = getattr(module, "bias", None)
    if bias is None:
        return

    # A module may carry a bias without a (non-None) weight; look it up
    # safely instead of assuming it exists.
    weight = getattr(module, "weight", None)
    if weight is not None and weight.dim() >= 2:
        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(weight)
        bound = 1 / sqrt(fan_in) if fan_in > 0 else 0
    else:
        bound = sqrt(2)

    torch.nn.init.uniform_(bias, -bound, bound)
class PytorchRegressor(
    LastLayerPytorchLinearization, MCDropoutRegressor, LinearizableLaplaceRegressor
):
    """
    Regressor built around a ``torch.nn.Module``.

    :param model: a ``torch.nn.Module`` or ``None``
    :param X: passed on to the parent constructor
    :param y: passed on to the parent constructor
    :param trainer: Specifies how the model will be trained. May be:

        'model' --- calls self.model.fit

        'lightning' --- trains with pytorch-lightning

        trainer --- calls trainer.fit

        None --- chooses from the above in order, if available
    :param batch_size: stored on the instance as ``self.batch_size``
    :param training_limit: stored on the instance and forwarded to the
        lightning trainer when one is created
    :param collate_fn: stored on the instance and forwarded to the
        lightning trainer when one is created
    :param random_seed: passed on to the parent constructor and to
        :meth:`get_lightning_trainer`
    :param kwargs: entries named in ``pl_argnames`` are extracted with
        ``dict_pop`` and forwarded to the lightning trainer; the remaining
        keyword arguments go to the parent constructor
    """

    def __init__(
        self,
        model=None,
        X=None,
        y=None,
        trainer=None,
        batch_size=64,
        training_limit=default_limit,
        collate_fn=None,
        random_seed=None,
        **kwargs,
    ):
        # imports occur inside __init__ to avoid import when not used.
        # The module-level binding is kept for backward compatibility; the
        # methods below import torch themselves and do not depend on it.
        global torch  # pylint: disable=global-statement
        import torch

        assert (
            isinstance(model, torch.nn.Module) or model is None
        ), f"model is of type {type(model)}. Should be torch.nn.Module or None."

        self.model = model
        self.batch_size = batch_size
        self.training_limit = training_limit
        self.collate_fn = collate_fn

        # Split off the pytorch-lightning arguments before the remaining
        # kwargs are handed to the parent constructors.
        pl_kwargs = dict_pop(kwargs, *pl_argnames)
        super().__init__(X=X, y=y, random_seed=random_seed, **kwargs)

        # if no trainer is provided, choose one based on what's available
        if trainer is None:
            if hasattr(model, "fit"):
                trainer = "model"
            else:
                try:
                    import pytorch_lightning  # pylint: disable=unused-import

                    trainer = "lightning"
                except ImportError:
                    # Neither option is available: self.trainer stays None.
                    pass

        if trainer == "model":
            self.trainer = model
        elif trainer == "lightning":
            self.trainer = self.get_lightning_trainer(
                random_seed=random_seed,
                collate_fn=collate_fn,
                training_limit=training_limit,
                **pl_kwargs,
            )
        else:
            # A user-supplied trainer object (or None).
            self.trainer = trainer

    def fix_dropouts(self):
        """
        Patch every ``torch.nn.Dropout`` submodule of the model.

        Each dropout layer found by ``submodules`` (skipping
        ``self.nodropout_layers``) gets its ``forward`` replaced by
        ``dropout_forward`` bound to that layer, and is appended to
        ``self.dropouts``.
        """
        import torch

        for _name, module in submodules(self.model, skip=self.nodropout_layers):
            if isinstance(module, torch.nn.Dropout):
                # bind dropout_forward as an instance-level method
                module.forward = dropout_forward.__get__(module)
                self.dropouts.append(module)

    def get_lightning_trainer(self, random_seed=None, **kwargs):
        """
        Return a LightningTrainer object from current model.

        :param random_seed: if not ``None``, pytorch-lightning is seeded with
            a shifted version of this seed and ``deterministic=True`` is
            passed to the trainer
        :param kwargs: forwarded to ``LightningTrainer``
        """
        from .lightning import LightningTrainer

        if random_seed is not None:
            from pytorch_lightning import seed_everything

            seed_everything(shift_seed(random_seed, 31523), workers=True)
            kwargs["deterministic"] = True

        return LightningTrainer(self.model, **kwargs)

    @get_defaults_from_self
    def fit_model(self, X=None, y=None, reinitialize=None, init_seed=None, **kwargs):
        """
        Train the model by calling ``self.trainer.fit(X, y, **kwargs)``.

        ``reinitialize`` and ``init_seed`` are accepted but not used in this
        method body.
        """
        self.trainer.fit(X, y, **kwargs)

    @get_defaults_from_self
    def initialize(self, init_seed=None, sample_input=None):
        """
        Re-initialize the model's parameters with :func:`init_weights`.

        :param init_seed: if not ``None``, torch is seeded with it and
            ``self.init_seed`` is advanced via ``shift_seed`` by
            ``INIT_SEED_INCREMENT``
        :param sample_input: accepted but not used in this method body
        """
        import torch

        if init_seed is not None:
            torch.manual_seed(init_seed)
            self.init_seed = shift_seed(init_seed, INIT_SEED_INCREMENT)

        self.model.apply(init_weights)

    def _to_model_input(self, X, convert_dtype):
        """
        Best-effort conversion of ``X`` to a tensor of the model's dtype.

        If the conversion raises ``ValueError``, ``TypeError`` or
        ``RuntimeError``, the value reached so far is returned and passed to
        the model as is.
        """
        try:
            X = as_tensor(X)
            if convert_dtype:
                X = X.type(self.dtype)
        except (ValueError, TypeError, RuntimeError):
            # deliberate best effort: let the model deal with the raw input
            pass
        return X

    @flatten_batch
    def predict(self, X, return_std_dev=False, convert_dtype=True):
        """
        Run the model on ``X`` in eval mode, without gradient tracking.

        :param X: model input; converted to a tensor when possible
        :param return_std_dev: accepted but not used in this method body
        :param convert_dtype: if true, cast the input to ``self.dtype``
        :return: the output of ``self.model(X)``
        """
        import torch

        X = self._to_model_input(X, convert_dtype)
        with torch.no_grad():
            self.model.eval()
            return self.model(X)

    @flatten_batch
    def predict_samples(self, X, n=1, multiple=1.0, convert_dtype=True):
        """
        Draw ``n`` forward passes of the model and stack them along dim 1.

        :param X: model input; converted to a tensor when possible
        :param n: number of forward passes
        :param multiple: must equal 1
        :param convert_dtype: if true, cast the input to ``self.dtype``
        :return: ``torch.stack`` of the ``n`` squeezed model outputs, ``dim=1``
        """
        # Imported locally, like in the other methods, so this does not rely
        # on __init__ having populated the module-level `torch` name.
        import torch

        assert multiple == 1

        self.model.eval()
        # Put the tracked dropout layers back into training mode after the
        # model-wide eval(); presumably so dropout_forward keeps sampling.
        for dropout in self.dropouts:
            dropout.training = 1

        X = self._to_model_input(X, convert_dtype)
        with torch.no_grad():
            # NOTE(review): squeeze() removes every size-1 dimension, which
            # includes a batch dimension of size 1 -- confirm flatten_batch
            # guarantees the stack along dim=1 stays valid.
            return torch.stack([self.model(X).squeeze() for _ in range(n)], dim=1)

    @property
    def dtype(self):
        """
        Dtype of the model's first parameter, cached in ``self._dtype``.

        Falls back to ``torch.get_default_dtype()`` when the model has no
        parameters (instead of leaking a ``StopIteration``).
        """
        if getattr(self, "_dtype", None) is None:
            first_param = next(iter(self.model.parameters()), None)
            if first_param is None:
                import torch

                self._dtype = torch.get_default_dtype()
            else:
                self._dtype = first_param.dtype
        return self._dtype