"""Helper functions for Pytorch models."""
from collections.abc import Iterable
import numpy as np
from collections.abc import Iterable
from ...utils import is_one
# pylint: disable=import-outside-toplevel
[docs]def as_tensor(x):
"""Return the tensor version of x (i.e. itself it it already is ArrayLike, or x.data)"""
import torch
if isinstance(x, torch.Tensor):
return x
if isinstance(x, np.ndarray):
return torch.from_numpy(x)
from ...data import NumpyDataset, TorchDataset
if isinstance(x, TorchDataset):
return x.data
if isinstance(x, NumpyDataset):
return torch.from_numpy(x.data)
if hasattr(x, "__array__"):
return torch.from_numpy(x.__array__())
return torch.asarray(x)
[docs]def dropout_forward(self, x):
import torch
import torch.nn.functional as F
if is_one(self.training):
ones = torch.ones(x.shape[1:], dtype=x.dtype, device=x.device, requires_grad=True)
return x * F.dropout(ones, self.p, True, self.inplace)
return F.dropout(x, self.p, self.training, self.inplace)
[docs]def submodules(module, include_names=True, skip=frozenset()):
"""
Iterator through submodules of `module` (paired with their names,
if `include_names` is True). A submodule is returned only once, on its
first occurence in a depth-first traversal.
Any modules in `skip` (given either as the actual module, or its name)
will be skipped, along with all their submodules.
Args:
module (torch.nn.module): The module, whose submodules we will
iterate through.
include_names (bool): True, the iterator yields pairs `(name, submodule)`,
otherwise it yields just `submodule`. (The returned name is the name it's
indexed as the first time it occurs in the tree. If the submodule is not
named, its name will return as `None`)
skip: A collection of modules to skip. Can contain
either modules themselves, and/or their names.
"""
skip = set(skip)
for name, m in module.named_children():
if m not in skip and name not in skip:
skip.add(m)
yield (name, m) if include_names else m
for sub in submodules(
m, include_names=include_names, skip=skip
):
yield sub
pl_argnames = [
'accelerator',
'strategy',
'devices',
'num_nodes',
'precision',
'logger',
'callbacks',
'fast_dev_run',
'overfit_batches',
'val_check_interval',
'check_val_every_n_epoch',
'num_sanity_val_steps',
'log_every_n_steps',
'enable_checkpointing',
'enable_progress_bar',
'enable_model_summary',
'accumulate_grad_batches',
'gradient_clip_val',
'gradient_clip_algorithm',
'benchmark',
'inference_mode',
'use_distributed_sampler',
'profiler',
'detect_anomaly',
'barebones',
'plugins',
'sync_batchnorm',
'reload_dataloaders_every_n_epochs',
'default_root_dir',
]