Source code for alien.models.pytorch.last_layer

"""Model wrapper for last layer linearization."""

# The code in this file is largely borrowed from the Laplace Redux
# implementation by Alex Immer:
#
# https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/feature_extractor.py
#
# That code is under the MIT License, which we provide a copy of here, in
# accordance with the license requirements. Note, the rest of the Active
# Learning SDK is *not* provided under the MIT license.
#
# MIT License
#
# Copyright (c) 2021 Alex Immer
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from numpy.typing import ArrayLike

from ...decorators import flatten_batch
from ..linear import LastLayerLinearizableRegressor

# pylint: disable=import-outside-toplevel


class LastLayerPytorchLinearization(LastLayerLinearizableRegressor):
    """Last layer linearization for Pytorch-based models.

    A forward hook on the model's final ``torch.nn.Linear`` layer captures
    the input to that layer (the penultimate-layer output) on every forward
    pass. The last layer is either set explicitly with
    :meth:`set_last_layer` or discovered automatically on the first call to
    :meth:`predict_with_embedding`.
    """

    def __init__(self, model=None, X=None, y=None, *args, **kwargs):
        # Module object of the final linear layer; ``None`` until known.
        self.last_layer = None
        # Maps layer name -> most recently captured (detached) layer input.
        self._features = {}
        # Name of the last layer as it appears in ``model.named_modules()``.
        self._last_layer_name = ""
        # Handle of the currently registered feature hook, kept so the hook
        # can be removed if the last layer is set again.
        self._hook_handle = None
        super().__init__(model=model, X=X, y=y, *args, **kwargs)

    @flatten_batch
    def predict_with_embedding(self, X):
        """Forward pass which returns the output of the penultimate layer
        along with the output of the last layer.

        If the last layer is not known yet, it will be determined when this
        function is called for the first time.

        :param X: one batch of data to use as input for the forward pass
        :return: tuple ``(out, features)`` where ``out`` is the result of the
            forward pass and ``features`` is the captured input of the last
            layer.
        """
        if self.last_layer is None:
            # First forward pass: discover the last layer. This also stores
            # the features of this very pass, so no second pass is needed.
            out = self.find_last_layer(X)
        else:
            # Last layer already known: the registered hook fills
            # ``self._features`` as a side effect of the forward pass.
            out = self.predict(X)
        features = self._features[self._last_layer_name]
        return out, features

    def last_layer_embedding(self, X):
        """Return only the last-layer input (embedding) for ``X``.

        :param X: one batch of data to use as input for the forward pass
        """
        return self.predict_with_embedding(X)[1]

    def linearization(self):
        """Return the ``(weight, bias)`` pair of the last layer.

        The last layer must already be known (see :meth:`set_last_layer` and
        :meth:`find_last_layer`); otherwise ``self.last_layer`` is ``None``
        and the attribute access fails.
        """
        return self.last_layer.weight, self.last_layer.bias

    def set_last_layer(self, last_layer_name: str) -> None:
        """Set the last layer of the model by its name.

        This sets the forward hook to get the output of the penultimate
        layer.

        :param last_layer_name: the name of the last layer (as it appears in
            ``model.named_modules()``).
        :raises KeyError: if the model has no module with that name.
        :raises ValueError: if the named module is not a ``torch.nn.Linear``.
        """
        import torch

        candidate = dict(self.model.named_modules())[last_layer_name]

        # Validate before touching any state, so that a failed call does not
        # leave the instance pointing at a non-linear "last layer".
        if not isinstance(candidate, torch.nn.Linear):
            raise ValueError("Use model with a linear last layer.")

        # Drop the hook from a previous call so hooks do not pile up.
        previous_handle = getattr(self, "_hook_handle", None)
        if previous_handle is not None:
            previous_handle.remove()

        self._last_layer_name = last_layer_name
        self.last_layer = candidate

        # Forward hook that extracts features in future forward passes.
        self._hook_handle = candidate.register_forward_hook(
            self._get_hook(last_layer_name)
        )

    def _get_hook(self, name: str):
        """Build a forward hook storing the layer input under ``name``."""

        def hook(_, inputs, __):
            # only accepts one input (expects linear layer)
            self._features[name] = inputs[0].detach()

        return hook

    def find_last_layer(self, X: ArrayLike):
        """Automatically determines the last layer of the model with one
        forward pass.

        It assumes that the last layer is the same for every forward pass
        and that it is an instance of `torch.nn.Linear`. Might not work with
        every architecture, but is tested with all PyTorch torchvision
        classification models (besides SqueezeNet, which has no linear last
        layer).

        :param X: batch of samples used to find last layer.
        :return: Returns the output of the forward pass, so as not to waste
            computation.
        :raises ValueError: if the last layer is already known, if the model
            has too few modules, if the detected layer is not linear, or if
            no executed module without children is found.
        """
        if self.last_layer is not None:
            raise ValueError("Last layer is already known.")

        modules = dict(self.model.named_modules())

        # check if model has more than one module
        # (there might be pathological exceptions)
        # Checked before any hook is registered so that raising here does
        # not leave temporary hooks attached to the model.
        if len(modules) <= 2:
            raise ValueError("The model only has one module.")

        # Insertion order of ``act_out`` records the order in which modules
        # were first executed during the forward pass.
        act_out = {}
        handles = {}

        def get_act_hook(name):
            def act_hook(_, inputs, __):
                # only accepts one input (expects linear layer)
                try:
                    act_out[name] = inputs[0].detach()
                except (IndexError, AttributeError):
                    act_out[name] = None
                # one-shot hook: remove itself after the first call
                handles[name].remove()

            return act_hook

        try:
            # set hooks for all modules
            for name, module in modules.items():
                handles[name] = module.register_forward_hook(get_act_hook(name))

            # forward pass to find execution order
            out = self.predict(X)
        finally:
            # Hooks on modules that never ran (or all hooks, if the forward
            # pass raised) would otherwise stay attached; removing a handle
            # a second time is harmless.
            for handle in handles.values():
                handle.remove()

        # find the last layer, store features, return output of forward pass
        for key in reversed(list(act_out)):
            layer = modules[key]
            if not list(layer.children()):
                self.set_last_layer(key)

                # save features from first forward pass
                self._features[key] = act_out[key]

                return out

        raise ValueError("Something went wrong (all modules have children).")