from collections.abc import Mapping, MutableSet, MutableSequence, Collection
from typing import List, Union
import numpy as np
from numpy.typing import ArrayLike
# pylint: disable=import-error
[docs]def seed_all(seed):
import random
random.seed(seed)
np.random.seed(seed)
try:
import torch
torch.manual_seed(seed)
except ImportError:
pass
try:
import tensorflow as tf
tf.random.set_seed(seed)
except ImportError:
pass
[docs]def as_list(x):
if isinstance(x, list):
return x
if isinstance(x, Collection) and not isinstance(x, str):
return list(x)
if x is None:
return []
return [x]
[docs]def match(target, pool, fn=lambda x, y: x == y):
for p in pool:
if fn(target, p):
return p
return None
[docs]def isint(i):
"""Check whether i can be cast to an integer
Args:
i (_type_): _description_
Returns:
_type_: _description_
"""
try:
return int(i) == i
except (ValueError, TypeError):
return False
[docs]def dict_get(d, *keys, _pop=False, **kwargs):
out = {}
d_get = d.pop if _pop else d.get
for k in keys:
if isinstance(k, list) or isinstance(k, tuple):
for kk in reversed(k):
if kk in d:
out[k[0]] = d_get(kk)
elif k in d:
out[k] = d_get(k)
for k, v in kwargs.items():
if k in d:
out[k] = d_get(k)
elif k not in out:
out[k] = v
return out
[docs]def dict_pop(d, *keys, **kwargs):
return dict_get(d, *keys, _pop=True, **kwargs)
[docs]def std_keys(d, *groups, **defaults):
"""
Standardizes keys in a dictionary.
Each positional arg, `g`, apart from the first one (`d`) should be an
iterable of keys. Then `d` is processed so that all keys in `g`
are replaced with the first key in `g`. If more than one key in `g`
is in `d`, then all keys in `g` except the first (in `d`) will be removed.
"""
out = d.copy()
for g in groups:
for h in reversed(g):
if h in d:
del out[h]
out[g[0]] = d[h]
for k, v in defaults.items():
if k not in out:
out[k] = v
return out
[docs]def update_copy(d1, d2=None, **kwargs):
d = d1.copy()
if isinstance(d2, dict):
d.update(d2)
d.update(kwargs)
return d
ERROR = "f6b1c433450cb749f45844dd60ab3b27"
[docs]def any_get(s, elements, default=ERROR):
assert isinstance(elements, Collection) and not isinstance(elements, str)
assert len(elements) > 0
if isinstance(s, Mapping):
try:
return s[any_get(s.keys(), elements, default=ERROR)]
except KeyError:
pass
else:
for e in elements:
if e in s:
return e
if default==ERROR:
raise KeyError(f'None of the given keys are in the {type(s)}.')
else:
return default
[docs]def any_pop(s, keys, default=ERROR):
try:
if isinstance(s, Mapping):
k = any_get(s.keys(), keys, default=ERROR)
return s.pop(k)
else:
k = any_get(s, keys, default=ERROR)
s.remove(k)
return k
except KeyError:
if default==ERROR:
raise KeyError(f'None of the given keys are in the {type(s)}.')
else:
return default
[docs]def alias(argname):
"""Aliases an argument"""
# TODO: write argument aliasing tool
#raise NotImplementedError
[docs]def multisearch(a, query, one_to_one=True):
"""
Finds the indices of multiple query values. Searches over the
first axis of 'a', with further axes corresponding to 'feature
space', i.e., the shape of the search terms. Return type depends
on one_to_one.
:param a: the array to search
:param query: an array of query values
:one_to_one: if True, validates that each search term
appears exactly once, and returns a corresponding array
of indices. If False, returns a 2D array with 2nd axis of
length 2, with each pair of the form (query_index, array_index)
"""
red_axes = tuple(range(-a.ndim + 1, 0)) # 'feature space' axes
hits = np.all(a == query[:, None, ...], axis=red_axes)
args = np.argwhere(hits)
if one_to_one:
assert np.all(
args[:, 0] == np.arange(len(query))
), "Search results are not one-to-one with search queries!"
return args[:, 1]
else:
return args
[docs]def ufunc(f):
def wrapped_f(*args, **kwargs):
try:
return wrapped_f._f(*args, **kwargs)
except TypeError:
wrapped_f._f = np.frompyfunc(f, len(args), 1)
return wrapped_f._f(*args, **kwargs)
wrapped_f._f = f
return wrapped_f
[docs]def shift_seed(seed, shift):
try:
return seed + shift
except TypeError:
return seed
[docs]def ranges(*args):
"""
Takes arguments ([start,] stop, step).
Returns a list of pairs consisting of
(start_i, stop_i), which together divide
up range(start, stop) into chunks of size
step (plus final chunk).
"""
if len(args) == 2:
args = (0,) + args
stop = args[-2]
edges = list(range(*args)) + [stop]
return list(zip(edges[:-1], edges[1:]))
[docs]class chunks:
"""
Takes arguments (seq, [[start,] stop,] step).
Returns an iterator which iterates over chunks
of seq, of size step.
"""
def __init__(self, seq, *args):
if len(args) == 1:
try:
args = (len(seq),) + args
except TypeError:
seq = list(seq)
args = (len(seq),) + args
self.seq = seq
self.range_iter = iter(ranges(*args))
self.args = args
self.step = args[-1]
def __len__(self):
return (len(self.seq) + self.step - 1) // self.step
def __iter__(self):
self.range_iter = iter(ranges(*self.args))
return self
def __next__(self):
start, stop = next(self.range_iter)
return self.seq[start:stop]
[docs]def frac_enum(seq, start_zero=True):
"""
Much like enumerate, returns an iterator yielding ordered pairs
(t, x) where x is an element of seq and t is the fraction of the
way through the sequence.
if start_zero==True, the fraction starts at 0 and ends just
short of 1. Otherwise, starts just over 0 and ends at 1.
"""
fracs, step = np.linspace(0, 1, len(seq), endpoint=False, retstep=True)
if not start_zero:
fracs += step
return zip(fracs, seq)
[docs]def add_slice(s: slice, i: int) -> slice:
"""
Returns a 'shifted' slice `slice(s.start + i, s.stop + i, s.step)`,
unless `s` is `None`, in which case it returns a slice representing
the whole window, minus a bit at the start (if `i > 0`) or the end
(if `i < 0`).
"""
if s is None:
s = slice(None)
if s.start is None:
start = i if i > 0 else None
else:
start = s.start + i
if s.stop is None:
stop = i if i < 0 else None
else:
stop = s.stop + i
return slice(start, stop, s.step)
[docs]def reshape(x, shape, index=None):
if isinstance(index, slice):
assert index.step == 1, f"index step must be 1, but yours is {index.step}"
shape = x.shape[: index.start] + shape + x.shape[index.stop :]
try:
return x.reshape(*shape)
except AttributeError:
pass
if "tensorflow" in str(type(x)):
import tensorflow as tf
return tf.reshape(x, shape)
raise TypeError(f"Can't reshape tensors of type {type(x)}")
[docs]def flatten(a, dims):
if dims < 0:
return reshape(a, (1 - dims) * (1,) + a.shape)
return reshape(a, (-1,) + a.shape[dims:])
[docs]def diagonal(x, dims=2, degree=1, bdim=0):
bshape = x.shape[:bdim]
mshape = x.shape[bdim : bdim + degree]
mlen = np.prod(mshape)
fshape = x.shape[bdim + dims * degree :]
x = reshape(x, bshape + dims * (mlen,) + fshape)
for _ in range(dims - 1):
x = np.diagonal(x, axis1=bdim, axis2=bdim + 1)
return x.reshape(bshape + mshape + fshape)
[docs]def concatenate(*args):
"""
Concatenates a series of datasets, or one of the supported
datatypes, along axis 0 (the samples axis)
"""
args = [a for a in args if a is not None and a != []]
if len(args) == 0:
return None
from .data.dataset import TeachableDataset, TeachableWrapperDataset
if all(isinstance(a, TeachableWrapperDataset) for a in args):
return TeachableDataset.from_data(concatenate(*(a.data for a in args)))
elif all(isinstance(a, np.ndarray) for a in args):
return np.concatenate(args, axis=0)
elif all("torch" in str(type(a)) for a in args):
import torch
return torch.cat(args, dim=0)
elif all(isinstance(a, tuple) for a in args):
return tuple(concatenate(*vals) for vals in zip(*args))
elif all(isinstance(a, dict) for a in args):
return {k: concatenate(*(d[k] for d in args)) for k in args[0].keys()}
elif all(isinstance(a, MutableSequence) for a in args):
return sum(args, [])
else:
raise TypeError("Unsupported types for concatenate function.")
[docs]def join(*args, make_ds=False):
"""
Concatenates a series of datasets along axis 1 (the first
feature axis). Datasets must have same length, and if they are
numpy or torch arrays, they must have the same shape in
dimensions >= 2.
"""
# pylint: disable=import-outside-toplevel
from .data.dataset import TeachableDataset, TeachableWrapperDataset
if len(args) == 1 and is_iterable(args[0]):
args = args[0]
assert all(len(a) == len(args[0]) for a in args[1:])
args = [a.data if isinstance(a, TeachableWrapperDataset) and (make_ds := True) else a for a in args]
if any(type(a) == tuple for a in args):
args_unpacked = []
for a in args:
if isinstance(a, tuple):
args_unpacked += [*a]
else:
args_unpacked += [a]
data = tuple(args_unpacked)
elif any(isinstance(a, np.ndarray) for a in args):
args = [(a[...,None] if a.ndim < 2 else a) for a in args]
data = np.concatenate(args, axis=1)
elif all("torch" in str(type(a)) for a in args):
import torch
args = [(a[...,None] if a.ndim < 2 else a) for a in args]
data = torch.cat(args, dim=1)
elif all(isinstance(a, dict) for a in args):
data = {}
for a in args:
data.update(a)
return TeachableDataset.from_data(data) if make_ds else data
[docs]def is_iterable(x):
try:
iter(x)
return True
except RuntimeError:
return False
[docs]def as_numpy(data):
if isinstance(data, np.ndarray):
return data
from .data import ArrayDataset
t = str(type(data))
if "torch" in t:
return data.cpu().detach().numpy()
elif "tensorflow" in t:
return data.numpy()
elif isinstance(data, ArrayDataset):
return data.__array__()
else:
return np.asarray(data)
[docs]def is_one(x):
return x == 1 and type(x) == int
[docs]def zip_dict(*dicts):
"""Similar behavior of zip(*) for dictionaries.
Assumes that all dicts have the same keys.
>>> zip_dict({'a': 1}, {'a': 2})
{'a': (1, 2)}
Returns:
_type_: _description_
"""
return {k: (d[k] for d in dicts) for k in dicts[0].keys()}
[docs]def axes_except(X, non_axes):
if isint(non_axes):
non_axes = (non_axes,)
non_axes = tuple(a % X.ndim for a in non_axes)
return tuple(a for a in range(X.ndim) if a not in non_axes)
[docs]def sum_except(X, non_axes):
return X.sum(axis=axes_except(X, non_axes))
[docs]def no_default():
raise NotImplementedError("no_default is used as a unique reference, and should not be called")
[docs]class SelfDict(dict):
"""
Subclass of dict class which allows you to refer to,
eg., d['attr'] as d.attr and vice versa.
You can also index with lists, where
d[['a', 'b']] == {'a': d['a'], 'b': d['b']}
Similarly with pop()
"""
def __init__(self, *args, default=no_default, **kwargs):
super().__init__(*args)
self.update(kwargs)
self.__dict__ = self
for k, v in self.items():
if isinstance(v, Mapping):
self[k] = SelfDict(v)
def __setitem__(self, key, value):
if isinstance(value, Mapping):
super().__setitem__(key, SelfDict(value))
else:
super().__setitem__(key, value)
def __setattr__(self, key, value):
if key.startswith("__"):
super().__setattr__(key, value)
else:
self[key] = value
def __getitem__(self, key):
s = super()
if isinstance(key, list):
# if self.__default == no_default:
return SelfDict({k: s.__getitem__(k) for k in key if k in s.__iter__()})
# else:
# return SelfDict({k:(s[k] if k in s else self.__default) for k in key})
else:
# if key not in self and self.__default != no_default:
# return self.__default
return s.__getitem__(key)
[docs] def pop(self, key, default=no_default):
s = super()
# default == self.__default if default == no_default else default
if isinstance(key, list):
if default == no_default:
return SelfDict({k: s.pop(k) for k in key if k in s.__iter__()})
else:
return SelfDict({k: s.pop(k, default) for k in key})
else:
return s.pop(key) if default == no_default else s.pop(key, default)
[docs]class CachedDict:
def __init__(self, get, *init_keys, **d2):
self.cache = {}
self._get = get.__get__(self)
if isinstance(init_keys[0], Mapping):
d1, init_keys = init_keys[0], init_keys[1:]
else:
d1 = {}
self.cache.update(d1)
self.cache.update(d2)
if len(init_keys) == 1 and (
isinstance(init_keys[0], MutableSet) or
isinstance(init_keys[0], MutableSequence)):
init_keys = init_keys[0]
[self(k) for k in init_keys]
for m in ['keys', 'values', 'items', 'get', 'pop', '__iter__', '__contains__']:
self.__dict__[m] = getattr(self.cache, m)
def __call__(self, key):
if key in self.cache:
return self.cache[key]
value = self._get(key)
self.cache[key] = value
return value
__getitem__ = __getattr__ = __call__
[docs]def dot_last(a, b):
return (a * b).sum(axis=-1)