"""
Module with dataset (sub-)classes for storing data.
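
A quick usage sketch (assuming Numpy data; `from_data` picks a suitable subclass)::

    >>> import numpy as np
    >>> ds = Dataset.from_data({'X': np.zeros((4, 2)), 'y': np.arange(4)})
    >>> len(ds)
    4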
"""
# TODO in this module:
# - join function for DictDataset
# - check join implementation for TupleDataset
# - concatenate function for TupleDataset
# - numpy warning in concatenate
# - specify exceptions
# - Dataset.from_data parameters align with TeachableDataset.from_data
# - other smaller todos throughout
import sys
import warnings
from abc import ABCMeta, abstractmethod
from collections.abc import MutableSequence
from typing import Any, Optional, Union
import numpy as np
from numpy.random import BitGenerator, Generator, SeedSequence
from numpy.typing import ArrayLike
if "torch" in sys.modules:
import torch
from ..utils import add_slice, reshape, isint, update_copy
class Dataset(metaclass=ABCMeta):
"""
Abstract interface to a readable dataset.
"""
def __new__(cls, *args, **kwargs):
if cls == Dataset:
return Dataset.from_data(*args, **kwargs)
else:
return super().__new__(cls)
def __init__(self, *, has_Xy=None, bdim=1, **kwargs):
super().__init__(**kwargs)
self.has_Xy = has_Xy
self.bdim = bdim
@abstractmethod
def __getitem__(self, index):
pass
@abstractmethod
def __len__(self):
pass
    @abstractmethod
    def find(self, value, first=True):
        """
        Finds instances of `value` in this dataset.
        If `first` is True, returns the index of the first
        occurrence (or None if not found); otherwise returns
        an iterable of indices of all occurrences.
        """
    def __iter__(self):
        "Default iterator implementation"
        return (self[i] for i in range(len(self)))
    @staticmethod
def from_data(*args, **kwargs):
"""
        Returns a Dataset built from the given data and other args.
        Arguments and functionality are exactly like those of
        `TeachableDataset.from_data`; in fact, at present, this method
        just calls `TeachableDataset.from_data`.
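
        A minimal usage sketch (assuming a Numpy array as input)::

            >>> ds = Dataset.from_data(np.arange(8).reshape(4, 2))
            >>> len(ds)
            4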
"""
# TODO: pylint doesn't like that the parent class uses *args, **kwargs.
# Need to figure out a general way that doesn't break this.
dataset = TeachableDataset.from_data(*args, **kwargs)
return dataset
@property
def X(self):
"""Return features."""
self.check_Xy()
if self.bdim == 1:
return Dataset.from_data(self[:, :-1], recursive=False)
else:
i = (slice(None),) * self.bdim + (slice(None, -1),)
return Dataset.from_data(self[i], recursive=False)
@property
def y(self):
"""Return targets."""
self.check_Xy()
if self.bdim == 1:
return Dataset.from_data(self[:, -1], recursive=False)
else:
i = (slice(None),) * self.bdim + (-1,)
return Dataset.from_data(self[i], recursive=False)
    def check_Xy(self):
        """Warn if this dataset doesn't store separate `X`/`y` columns."""
        if not self.has_Xy:
            warnings.warn("Dataset doesn't store separate `X` or `y` columns.")
@property
@abstractmethod
def shape(self):
"""Abstract method for returning shape."""
@property
def ndim(self):
"""Returns: int: number of dimensions"""
return len(self.shape)
    @property
    def batch_shape(self):
        """Shape of the batch dimensions."""
        return self.shape[: self.bdim]
    @property
    def feature_shape(self):
        """Shape of the feature (non-batch) dimensions."""
        return self.shape[self.bdim :]
    def reshape(self, *shape, index=None, bdim=None):
raise NotImplementedError
class TeachableDataset(Dataset):
"""
Abstract interface to a teachable dataset.
"""
    @abstractmethod
def append(self, x: Any):
"""
Appends a single sample to the end of the dataset.
"""
    def extend(self, X: ArrayLike):
"""
Appends a batch of samples to the end of the dataset.
"""
# This is the default implementation of extend.
# Subclasses may accomplish this faster
for val in X:
self.append(val)
    @staticmethod
def from_data(
data=None,
shuffle: Optional[Union[bool, str]] = False,
random_seed: Optional[Union[int, ArrayLike, SeedSequence, BitGenerator, Generator]] = None,
recursive: bool = True,
convert_sequences: bool = True,
**kwargs,
):
"""
Creates a TeachableDataset with given data.
:param data: the initial data of the dataset
Can be:
* another TeachableDataset
            * a Python mutable sequence (e.g., a list) or
anything that implements the interface
* a Numpy array
* a Pytorch tensor
* a dictionary or tuple whose values are one of the above types
* a Pandas DataFrame
:param shuffle: if this evaluates to True, data will be wrapped in a shuffle,
exposing the ShuffledDataset interface.
Can be:
* anything evaluating to False
* 'identity' (initial shuffle is the identity)
* 'random' (initial shuffle is random)
:param random_seed: a random seed to pass to Numpy's shuffle algorithm.
If None (the default), Numpy gets entropy from the OS.
        :param recursive: if True, data like MutableSequences or TeachableDatasets
            that already expose the needed interface will still be wrapped;
            if False, such data is returned as-is, with no new object created.
        :param convert_sequences: if True (the default), mutable sequences
            (e.g., lists) are converted to Numpy arrays before wrapping.
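
        A minimal usage sketch (assuming Numpy-compatible inputs)::

            >>> ds = TeachableDataset.from_data([[0, 1], [2, 3]])
            >>> ds.shape
            (2, 2)
            >>> shuffled = TeachableDataset.from_data(np.arange(10), shuffle='random', random_seed=0)
            >>> len(shuffled)
            10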
"""
if shuffle:
return ShuffledDataset(
TeachableDataset.from_data(
data, recursive=False, convert_sequences=convert_sequences, **kwargs
),
shuffle=shuffle,
random_seed=random_seed,
)
elif data is None or isinstance(data, dict):
return DictDataset(data, convert_sequences=convert_sequences, **kwargs)
elif convert_sequences and isinstance(data, MutableSequence):
return NumpyDataset(np.asarray(data), **kwargs)
        elif isinstance(data, (TeachableDataset, MutableSequence)):
return TeachableWrapperDataset(data, **kwargs) if recursive else data
elif isinstance(data, np.ndarray):
return NumpyDataset(data, **kwargs)
elif isinstance(data, tuple):
return TupleDataset(data, convert_sequences=convert_sequences, **kwargs)
elif "torch" in str(type(data)):
return TorchDataset(data, **kwargs)
elif "DataFrame" in str(type(data)):
return DictDataset({k: data[k].values for k in data.columns})
elif "deepchem" in str(type(data)):
return TeachableDataset.from_deepchem(data)
else:
warnings.warn("Passing an unknown data format into TeachableDataset.")
return TeachableWrapperDataset(data)
    @staticmethod
    def from_deepchem(data):
        """Wrap a `deepchem.data.Dataset` in a `DeepChemDataset`."""
try:
# pylint: disable=import-outside-toplevel
import deepchem
assert isinstance(data, deepchem.data.Dataset)
from .deepchem import DeepChemDataset
return DeepChemDataset(data)
except Exception as exc:
raise NotImplementedError(
"We thought this was a DeepChem dataset, but apparently not!"
) from exc
    def get_shuffle(self, shuffle="random", random_seed=None):
"""Return a shuffled version of self
Args:
            shuffle (str, optional): The initial shuffle - `'identity'` or `'random'`. Defaults to `'random'`.
random_seed (int, optional): A random seed for the shuffle. Defaults to None.
Returns:
ShuffledDataset: A shuffled version of `self`
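
        A minimal sketch (assuming a Numpy-backed dataset)::

            >>> ds = TeachableDataset.from_data(np.arange(5))
            >>> shuffled = ds.get_shuffle(random_seed=0)
            >>> sorted(int(x) for x in shuffled)
            [0, 1, 2, 3, 4]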
"""
return ShuffledDataset(self, shuffle=shuffle, random_seed=random_seed)
class TeachableWrapperDataset(TeachableDataset):
"""
Wraps another dataset-like object.
Functions as an abstract base class for wrapping specific data types.
Also functions concretely as the default wrapper for MutableSequences,
other TeachableDatasets, and anything else which exposes a suitable
interface.
"""
def __init__(self, data, **kwargs):
super().__init__(**kwargs)
self.data = data
    def append(self, x):
val = self.data.append(x)
if val is not None:
self.data = val
    def extend(self, X):
try:
val = self.data.extend(X)
if val is not None:
self.data = val
except AttributeError:
super().extend(X)
    def find(self, value: Any, first: bool = True):
        # Concrete stub (so subclasses aren't abstract); generic wrapped data
        # may not support searching.
        raise NotImplementedError
    def reshape_features(self, *shape, index=None):
        """Reshape the feature (non-batch) dimensions to `shape`."""
        return self.reshape(*shape, index=add_slice(index, self.bdim), bdim=self.bdim)
    def reshape_batch(self, *shape, index=None):
        """Reshape the batch dimensions to `shape`."""
        if index is None:
            index = slice(0, self.bdim)
            bdim = len(shape)
        elif isinstance(index, slice):
            assert index.step is None or index.step == 1
            index = slice(index.start, min(index.stop, self.bdim))
            bdim = self.bdim + len(shape) - (index.stop - index.start)
        else:
            raise TypeError("`index` must be None or a slice.")
        return self.reshape(*shape, index=index, bdim=bdim)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
if isinstance(index, tuple):
i0, *i = index
if i0 == ...:
try:
# Assume we can push '...' onto the rows
return [row[(..., *i)] for row in self.data]
except LookupError:
# Apparently not
return [row[i] for row in self.data]
try:
# Assume i0 is an integer, so self[i0] will be
# a single row
return self[int(i0)][i]
except (ValueError, TypeError):
# Apparently not
return [row[i] for row in self[i0]]
try:
return self.data[index]
except LookupError:
return [self.data[i] for i in index]
def _ignore__iter__(self):
try:
return iter(self.data)
except TypeError:
return super().__iter__()
@property
def shape(self):
return self.data.shape
class ShuffledDataset(TeachableWrapperDataset):
"""
Presents a shuffle of an existing dataset (or MutableSequence)
Added data goes at the end and isn't shuffled (until reshuffle() is called).
:param data: the existing dataset to wrap
:param shuffle: determines the initial shuffle state: 'random' or 'identity'
:param random_seed: random seed to pass to the numpy shuffle algorithm.
If None, get a source of randomness from the OS.
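
    A minimal sketch (assuming a Numpy array as the wrapped data)::

        >>> sd = ShuffledDataset(np.arange(3), shuffle='identity')
        >>> [int(x) for x in sd]
        [0, 1, 2]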
"""
def __init__(
self,
data,
shuffle="random",
random_seed: Optional[Union[int, ArrayLike, SeedSequence, BitGenerator, Generator]] = None,
recursive=False,
bdim=1,
):
assert bdim == 1, "ShuffledDataset is only possible with one batch dimension."
super().__init__(data)
self.rng = np.random.default_rng(random_seed)
if (not recursive) and isinstance(data, ShuffledDataset):
self.data = data[data.shuffle]
if isinstance(shuffle, np.ndarray):
assert len(shuffle) == len(self.data), "Supplied shuffle must be same length as data!"
self.shuffle = shuffle
elif shuffle == "identity" or not shuffle:
self.shuffle = np.arange(len(self.data))
else: # shuffle == 'random' OR any True-valued
self.shuffle = np.arange(len(self.data))
self.reshuffle()
    def reshuffle(self):
        """Reshuffles self with self.rng."""
        # TODO: decide whether reshuffle should take its own random_seed, or
        # keep always using the generator supplied at construction.
        self.rng.shuffle(self.shuffle)
    def extend_shuffle(self):
        """Extend self.shuffle with [len(self.shuffle), ..., len(self.data) - 1]."""
len_shuffle, len_data = len(self.shuffle), len(self.data)
if len_shuffle < len_data:
self.shuffle = np.append(self.shuffle, np.arange(len_shuffle, len_data))
def __getitem__(self, index):
self.extend_shuffle()
if isinstance(index, tuple):
i0, *i = index
return self.data[(self.shuffle[i0], *i)]
return self.data[self.shuffle[index]]
    def find(self, value: Any, first: bool = True):
        """Return index(es) of `value` in self.
        Args:
            value (Any): value to look for
            first (bool, optional): whether to return the first instance of value or all of them. Defaults to True.
        Returns:
            The index of the first occurrence (or None if not found), or an
            array of indices of all occurrences.
        """
i = self.data.find(value, first)
if first:
return i if i is None else self.shuffle[i]
else:
return i if len(i) == 0 else self.shuffle[i]
def __iter__(self):
self.extend_shuffle()
return iter(TeachableDataset.from_data(self.data[self.shuffle]))
def __array__(self, dtype=None):
"Converts to a Numpy array"
return np.array(self.data, dtype=dtype)[self.shuffle]
@property
def X(self):
X = ShuffledDataset(self.data.X, shuffle=self.shuffle)
X.rng = None
return X
@property
def y(self):
y = ShuffledDataset(self.data.y, shuffle=self.shuffle)
y.rng = None
return y
def compute_bdim(old_shape, old_bdim, new_shape):
    """
    Infer how many leading axes of `new_shape` form the batch: returns the
    smallest `bdim` such that `prod(new_shape[:bdim])` equals the original
    batch size `prod(old_shape[:old_bdim])`; raises ValueError if no prefix
    of `new_shape` has exactly that total size.
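
    A small worked example (the original batch size here is 6)::

        >>> compute_bdim((6, 3), 1, (2, 3, 3))
        2
    """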
    b_size = np.prod(old_shape[:old_bdim])
    size = 1
    for bdim, d in enumerate(new_shape):
        size *= d
        if size == b_size:
            return bdim + 1
        elif size > b_size:
            raise ValueError(
                "New shape must have initial axes with total size equal to the original batch size."
            )
    raise ValueError(
        "New shape must have initial axes with total size equal to the original batch size."
    )
class ArrayDataset(TeachableWrapperDataset):
"""
Abstract base class for datasets based on numpy, pytorch,
or other similarly-interfaced arrays.
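
    A minimal indexing sketch (assuming a Numpy-backed subclass)::

        >>> ds = NumpyDataset(np.zeros((4, 3)))
        >>> ds[0].shape  # an integer index consumes the batch dimension
        (3,)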
"""
def __getitem__(self, index):
bdim = self.bdim
if isint(index):
bdim -= 1
elif isinstance(index, tuple):
for i in index[: self.bdim]:
bdim -= isint(i)
if bdim > 0:
return self.__class__(self.data[index], bdim=bdim)
return self.data[index]
def __setitem__(self, index, value):
self.data[index] = value
    def append(self, x):
self.extend(np.array(x)[None, ...])
    def find(self, value, first=True):
matches = self.data == value
# remove extra dimensions
for _ in range(matches.ndim - self.bdim):
matches = np.all(np.array(matches), axis=-1)
index = np.argwhere(matches)[:, 0]
if first:
# take only the first match:
index = None if len(index) == 0 else index[0]
return index
def __array__(self, dtype=None):
return np.asarray(self.data, dtype=dtype)
    def reshape(self, *shape, index=None, bdim=None):
if index is not None:
assert index.step is None or index.step == 1
shape = self.shape[: index.start] + shape + self.shape[index.stop :]
if bdim is None:
bdim = compute_bdim(self.shape, self.bdim, shape)
return self.__class__(reshape(self.data, shape), bdim=bdim)
class NumpyDataset(ArrayDataset):
"""Dataset with Numpy array as data."""
    def extend(self, X):
self.data = np.append(self.data, np.asarray(X), axis=0)
def __array__(self, dtype=None):
return self.data if dtype is None else self.data.astype(dtype, copy=False)
class TorchDataset(ArrayDataset):
"""Dataset with torch.tensor as data."""
    def extend(self, X):
        # pylint: disable=import-outside-toplevel
        import torch
        if isinstance(X, Dataset):
            X = X.data
        # as_tensor avoids a copy (and a UserWarning) when X is already a tensor
        self.data = torch.cat((self.data, torch.as_tensor(X)), dim=0)
def __array__(self, dtype=None):
data = self.data.numpy(force=True)
return data if dtype is None else data.astype(dtype, copy=False)
class DictDataset(TeachableWrapperDataset):
"""
Contains a dictionary whose values are datasets.
For indexing purposes, the first `self.bdim` axes (i.e., the
batch dimensions) index into the first axes of the constituent
datasets, whereas the dictionary key "dimension" occurs right after
the batch dimensions. Since there is usually exactly one batch
dimension, this means you can index like
>>> dataset[:20, 'X']
which will return the first 20 rows of the `'X'` constituent dataset,
whereas
>>> dataset[:20]
will take the first 20 rows of each constituent dataset, and package
them into a new `DictDataset` with the same keys.
"""
def __init__(self,
data={}, # NOSONAR
convert_sequences=True,
bdim=1,
has_Xy=None,
**kw_data
):
data = update_copy(data, kw_data) # NOSONAR
super().__init__(None, bdim=bdim,
has_Xy=bool({'X','x','y'} & set(data)) if has_Xy is None else has_Xy)
self.data = {
k: TeachableDataset.from_data(
d, recursive=False, convert_sequences=convert_sequences, bdim=bdim
)
for k, d in data.items()
}
    def append(self, x):
for key in self.data.keys():
self.data[key].append(x[key])
    def extend(self, X):
if isinstance(X, DictDataset):
X = X.data
for key in self.data.keys():
self.data[key].extend(X[key])
    def reshape(self, *shape, index=None, bdim=None):
if bdim is None:
if index is not None:
assert isinstance(index, slice)
new_shape = self.shape[: index.start] + shape + self.shape[index.stop :]
else:
new_shape = shape
bdim = compute_bdim(self.shape, self.bdim, new_shape)
if shape[bdim] != len(self.data):
raise ValueError("When reshaping a DictDataset, the first non-batch dimension must equal the number of keys.")
shape = shape[:bdim] + shape[bdim+1:]
return self.__class__(
{k: reshape(v, shape, index) for k, v in self.data.items()}, bdim=bdim
)
def __getitem__(self, index):
if isinstance(index, tuple) and len(index) > self.bdim:
# i is the indices into each dataset in the dictionary
i = index[: self.bdim] + index[self.bdim + 1 :]
# k is the dict key(s)
k = index[self.bdim]
if k == slice(None, None):
k = self.data.keys()
elif not isinstance(k, MutableSequence):
# single dict key, so return its value
return self.data[k][i]
else:
i = index
k = self.data.keys()
sub_data = {key: self.data[key][i] for key in k}
bdim = getattr(next(iter(sub_data.values())), "bdim", 0)
if bdim == 0: # batch is fully-indexed, so we return a dict
return sub_data
else: # some batch indices remain, so return a DictDataset
return self.__class__(sub_data, bdim=bdim)
def __setitem__(self, index, value):
raise NotImplementedError
def __iter__(self):
for i in np.ndindex(self.shape[:self.bdim]):
yield {k: v[i] for k, v in self.data.items()}
def __len__(self):
return len(next(iter(self.data.values())))
def __setattr__(self, name, value):
if name in {"data", "bdim", "has_Xy"} or name[:2] == "__":
object.__setattr__(self, name, value)
else:
self.data[name] = value
    def __getattr__(self, name):
        try:
            return self.data[name]
        except (IndexError, TypeError, KeyError) as exc:
            raise AttributeError(name) from exc
    def find(self, value, first=True):
        """Find rows where every key of `value` matches the corresponding
        constituent dataset; see `Dataset.find`."""
indices = tuple(self.data[k].find(value[k], first=False) for k in value.keys())
while len(indices) > 1:
indices = (
np.intersect1d(indices[0], indices[1], assume_unique=True),
*(indices[2:]),
)
index = indices[0]
if first:
index = None if len(index) == 0 else index[0]
return index
@property
def X(self):
self.check_Xy()
return self.data["X"]
@property
def y(self):
self.check_Xy()
return self.data["y"]
@property
def shape(self):
inner_shape = next(iter(self.data.values())).shape
return inner_shape[: self.bdim] + (len(self.data),) + inner_shape[self.bdim :]
@property
def ndim(self):
return next(iter(self.data.values())).ndim + 1
class TupleDataset(TeachableWrapperDataset):
    """
    Dataset whose `self.data` is a tuple of constituent datasets.
    Indexing follows the same convention as `DictDataset`: the first
    `self.bdim` axes index into the constituent datasets, and the tuple
    "dimension" comes right after the batch dimensions.
def __init__(self, data, convert_sequences=True, bdim=1):
super().__init__(None, bdim=bdim)
self.data = tuple(
TeachableDataset.from_data(
d, recursive=False, convert_sequences=convert_sequences, bdim=bdim
)
for d in data
)
    def append(self, x):
for data_n, x_n in zip(self.data, x):
data_n.append(x_n)
    def extend(self, X):
for data_n, x_n in zip(self.data, X):
data_n.extend(x_n)
    def reshape(self, *shape, index=None, bdim=None):
if bdim is None:
self_shape = self.data[0].shape
if index is not None:
assert isinstance(index, slice)
new_shape = self_shape[: index.start] + shape + self_shape[index.stop :]
else:
new_shape = shape
bdim = compute_bdim(self_shape, self.bdim, new_shape)
        if shape[bdim] != len(self.data):
            raise ValueError("When reshaping a TupleDataset, the first non-batch dimension must equal the number of constituent datasets.")
shape = shape[:bdim] + shape[bdim+1:]
return self.__class__(tuple(reshape(v, shape, index) for v in self.data), bdim=bdim)
def __getitem__(self, index):
# Case 1: indexing multiple axes
if isinstance(index, tuple) and len(index) > self.bdim:
# i is the indices into each dataset in the tuple
i = index[: self.bdim] + index[self.bdim + 1 :]
            # k is the index (or indices) into the tuple
            k = index[self.bdim]
if isint(k):
# returning a single dataset in the tuple
return self.data[k][i]
elif isinstance(k, slice):
# select a slice of the tuple
sub_data = tuple(d[i] for d in self.data[k])
            else:
                # selecting multiple constituents of the tuple
                sub_data = tuple(self.data[key][i] for key in k)
else:
sub_data = tuple(d[index] for d in self.data)
bdim = getattr(sub_data[0], "bdim", 0)
if bdim == 0: # batch is fully-indexed, so we return a tuple
return sub_data
else: # some batch indices remain, so return a TupleDataset
return self.__class__(sub_data, bdim=bdim)
    def __iter__(self):
        return zip(*self.data)
def __len__(self):
return len(self.data[0])
    def __array__(self, dtype=None):
        arrays = [np.asarray(X_n, dtype=dtype) for X_n in self.data]
        # insert singleton axes so all constituents have the same ndim
        max_dim = max(a.ndim for a in arrays)
        for i, arr in enumerate(arrays):
            while arr.ndim < max_dim:
                arr = np.expand_dims(arr, 1)
            arrays[i] = arr
        try:
            # stack along a new tuple axis, matching the `shape` property
            return np.stack(arrays, axis=1)
        except ValueError:
            # constituent feature shapes differ; concatenate along axis 1 instead
            return np.concatenate(arrays, axis=1)
    def find(self, value, first=True):
        """Find rows where every element of `value` matches the corresponding
        constituent dataset; see `Dataset.find`."""
indices = tuple(d_n.find(v_n, first=False) for d_n, v_n in zip(self.data, value))
while len(indices) > 1:
indices = (
np.intersect1d(indices[0], indices[1], assume_unique=True),
*(indices[2:]),
)
index = indices[0]
if first:
index = None if len(index) == 0 else index[0]
return index
@property
def tuple(self):
"""Getter for self.data."""
return self.data
@property
def shape(self):
inner_shape = self.data[0].shape
        return inner_shape[: self.bdim] + (len(self.data),) + inner_shape[self.bdim :]
@property
def X(self):
self.check_Xy()
X = self.data[:-1]
return TupleDataset(X) if len(X) > 1 else X[0]
@property
def y(self):
self.check_Xy()
return self.data[-1]