Source code for alien.models.pytorch.training_limits

from dataclasses import dataclass
from functools import wraps
from inspect import signature
from typing import Optional

from ...utils import dict_pop
from ...config import default_training_epochs, default_training_samples

# Keyword names accepted by get_training_limit / StdLimit. The long names
# ("*_limit") are the user-facing spellings; the short names are the
# corresponding TrainingLimit field names. Both lists share one ordering.
limit_long_names = ["sample_limit", "batch_limit", "epoch_limit"]
limit_short_names = ["samples", "batches", "epochs"]

# (long, short) pairs in the order above.
limit_name_pairs = [
    (long_name, short_name)
    for long_name, short_name in zip(limit_long_names, limit_short_names)
]

# Translate a long keyword name into its TrainingLimit field name.
limit_long_to_short = {long_name: short_name for long_name, short_name in limit_name_pairs}


# decorator
def get_training_limit(fn):
    """
    *** Decorator ***

    Lets ``fn`` be called with any of the "naive" limit keywords
    ``'sample_limit'``, ``'batch_limit'``, ``'epoch_limit'``,
    ``'samples'``, ``'batches'``, ``'epochs'`` in place of a single
    ``training_limit`` argument.

    On every call the naive keywords are extracted from the keyword
    arguments before ``fn`` runs. Then:

    * if at least one naive keyword was given and the caller did not pass
      ``training_limit``, ``fn`` receives ``training_limit=StdLimit(...)``
      built from them;
    * if the caller passed ``training_limit`` explicitly, it is forwarded
      unchanged and the naive keywords are dropped;
    * if neither was given, ``fn`` is called without a ``training_limit``
      keyword, so its own default applies.

    ``fn`` must therefore accept a ``training_limit`` keyword (or
    ``**kwargs``).
    """

    @wraps(fn)
    def _with_training_limit(*args, **kwargs):
        # NOTE(review): presumably dict_pop removes the listed (long, short)
        # keys from ``kwargs`` and returns the found ones keyed by long name
        # (StdLimit is fed this dict) -- confirm against ...utils.dict_pop.
        naive_limits = dict_pop(kwargs, *limit_name_pairs)
        if naive_limits and "training_limit" not in kwargs:
            kwargs["training_limit"] = StdLimit(**naive_limits)
        return fn(*args, **kwargs)

    return _with_training_limit
@dataclass
class TrainingLimit:
    """
    Encapsulates the computation of training limits, which may depend on
    things like dataset length.

    Epoch-based fields (``min_epochs``, ``epochs``, ``max_epochs``) are only
    used when a dataset ``length`` is supplied to :meth:`sample_limit` /
    :meth:`batch_limit`; they are converted to samples as ``epochs * length``.
    """

    min_samples: int = 0  # lower bound on the sample count
    min_epochs: float = 0  # lower bound in epochs (needs ``length``)
    samples: Optional[int] = None  # target sample count
    epochs: Optional[float] = None  # target in epochs; beats ``samples`` when ``length`` is known
    batches: Optional[int] = None  # if set, batch_limit returns it verbatim
    max_samples: float = float("inf")  # upper bound on the sample count
    max_epochs: float = float("inf")  # upper bound in epochs (needs ``length``)

    def sample_limit(self, length=None):
        """
        Return the number of training samples implied by this limit.

        :param length: dataset length. When given, the epoch-based fields are
            converted to samples (``epochs * length``) and combined with the
            sample-based ones; when ``None`` they are ignored.
        :return: the sample count. It is not necessarily an ``int``: midpoints
            and epoch conversions can be fractional.

        Resolution order:

        1. lower/upper bounds are the tighter of the sample and epoch bounds;
        2. if the bounds contradict (lower > upper), return their midpoint;
        3. with no target and no finite upper bound, return the lower bound if
           positive, else a configured default (``default_training_epochs *
           length`` when ``length`` is truthy, otherwise
           ``default_training_samples``);
        4. with no target but a finite upper bound, return the midpoint;
        5. otherwise clamp the target into ``[lower, upper]``.
        """
        if length is None:
            lower, upper, target = self.min_samples, self.max_samples, self.samples
        else:
            lower = max(self.min_samples, self.min_epochs * length)
            upper = min(self.max_samples, self.max_epochs * length)
            # An epoch target, when present, takes precedence over samples.
            target = self.epochs * length if self.epochs is not None else self.samples

        if lower > upper:
            # Contradictory bounds: split the difference rather than fail.
            return 0.5 * (lower + upper)

        if target is None:
            if upper == float("inf"):
                if lower > 0:
                    return lower
                if length:
                    return default_training_epochs * length
                return default_training_samples
            return 0.5 * (lower + upper)

        if target < lower:
            return lower
        if target > upper:
            return upper
        return target

    def batch_limit(self, batch_size=None, length=None):
        """
        Return the number of training batches implied by this limit.

        :param batch_size: samples per batch; must be positive unless
            ``batches`` was set explicitly.
        :param length: dataset length, forwarded to :meth:`sample_limit`.
        :return: ``self.batches`` if it was set (including ``0``), otherwise
            ``sample_limit(length) // batch_size`` (a float if the sample
            limit is a float).
        :raises ValueError: if ``batches`` is unset and ``batch_size`` is
            missing or not positive.
        """
        # Compare against None so an explicit ``batches=0`` is honoured.
        if self.batches is not None:
            return self.batches
        if batch_size is None or batch_size <= 0:
            raise ValueError(
                "Must provide positive batch_size to method batch_limit(...)\n Or else pass batch_limit into the constructor."
            )
        return self.sample_limit(length=length) // batch_size
class StdLimit(TrainingLimit):
    """
    A TrainingLimit built from the "naive" limit keywords.

    Accepts ``sample_limit``/``samples``, ``batch_limit``/``batches`` and
    ``epoch_limit``/``epochs``. Long names are translated to the matching
    TrainingLimit field names; short names already are field names. Use
    TrainingLimit directly for the other fields (min/max bounds).
    """

    def __init__(self, **kwargs):
        # The message has always advertised both spellings; accept both.
        allowed = {*limit_long_names, *limit_short_names}
        assert set(kwargs).issubset(
            allowed
        ), f"For StdLimit, may only pass in\n{*limit_long_names, *limit_short_names} \nYou may want to try TrainingLimit."

        fields = {}
        for name, value in kwargs.items():
            field_name = limit_long_to_short.get(name, name)
            # e.g. epoch_limit=3 together with epochs=5: refuse to guess.
            if field_name in fields:
                raise TypeError(f"StdLimit got multiple values for {field_name!r}")
            fields[field_name] = value
        super().__init__(**fields)
# Default training limit: at least 10 epochs and at least 10,000 samples.
# min_samples is an int (matching its annotation) so that sample_limit() and
# batch_limit() yield integer counts on the default path rather than floats.
default_limit = TrainingLimit(min_samples=10_000, min_epochs=10)