Source code for geogenie.utils.data

import torch
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """PyTorch Dataset carrying features, labels, sample weights and sample IDs.

    Inputs that are not already ``torch.Tensor`` objects are converted lazily
    (on first property access) using ``dtype``. Sample IDs are never converted
    and are stored exactly as given.

    Attributes:
        dtype (torch.dtype): Data type used when converting non-tensor inputs.
        tensors (tuple): ``(features, labels, sample_weights)`` as they were at
            construction time. Sample IDs are not part of this tuple, and the
            tuple is not refreshed when the property setters are used later.
    """

    def __init__(
        self,
        features,
        labels=None,
        sample_weights=None,
        sample_ids=None,
        dtype=torch.float32,
    ):
        """Initialize the dataset.

        Args:
            features (torch.Tensor or array-like): Input features.
            labels (torch.Tensor or array-like, optional): Labels corresponding
                to the features. Defaults to None (unlabeled dataset).
            sample_weights (torch.Tensor or array-like, optional): Weight for
                each sample. If None, a tensor of ones with one entry per
                sample is created (equal weighting). Defaults to None.
            sample_ids (list or array-like, optional): Unique identifiers for
                each sample. If None, ``sample_ids`` stays None and the
                positional index is reported as the ID by ``__getitem__``.
                Defaults to None.
            dtype (torch.dtype): Data type used for tensor conversion.
                Defaults to torch.float32.
        """
        self.dtype = dtype
        self._features = features
        self._labels = labels
        # A None value is replaced by equal weights inside the
        # ``sample_weights`` getter, which runs when ``tensors`` is built.
        self._sample_weights = sample_weights
        self._sample_ids = sample_ids
        # Reading the properties here converts the raw inputs to tensors.
        self.tensors = (self.features, self.labels, self.sample_weights)

    @property
    def features(self):
        """Get the features tensor, converting the stored value if needed."""
        if not isinstance(self._features, torch.Tensor):
            self._features = torch.tensor(self._features, dtype=self.dtype)
        return self._features

    @features.setter
    def features(self, value):
        """Set the features tensor, converting non-tensor input."""
        if not isinstance(value, torch.Tensor):
            value = torch.tensor(value, dtype=self.dtype)
        self._features = value

    @property
    def labels(self):
        """Get the labels tensor, or None when the dataset is unlabeled."""
        if self._labels is not None and not isinstance(self._labels, torch.Tensor):
            self._labels = torch.tensor(self._labels, dtype=self.dtype)
        return self._labels

    @labels.setter
    def labels(self, value):
        """Set the labels tensor; None marks the dataset as unlabeled."""
        if value is not None and not isinstance(value, torch.Tensor):
            value = torch.tensor(value, dtype=self.dtype)
        self._labels = value

    @property
    def sample_weights(self):
        """Get the sample weights tensor (all ones when none were supplied)."""
        if self._sample_weights is None:
            self._sample_weights = torch.ones(len(self.features), dtype=self.dtype)
        elif not isinstance(self._sample_weights, torch.Tensor):
            self._sample_weights = torch.tensor(
                self._sample_weights, dtype=self.dtype
            )
        return self._sample_weights

    @sample_weights.setter
    def sample_weights(self, value):
        """Set the sample weights; None resets them to equal weighting."""
        if value is None:
            value = torch.ones(len(self.features), dtype=self.dtype)
        elif not isinstance(value, torch.Tensor):
            value = torch.tensor(value, dtype=self.dtype)
        self._sample_weights = value

    @property
    def sample_ids(self):
        """Get the sample IDs as stored (None if never provided or set)."""
        return self._sample_ids

    @sample_ids.setter
    def sample_ids(self, value):
        """Set the sample IDs; None stores positional indices as the IDs."""
        if value is None:
            value = list(range(len(self.features)))
        self._sample_ids = value

    @property
    def n_features(self) -> int:
        """Return the number of feature columns (1 for a 1-D features tensor)."""
        feats = self.features
        return feats.shape[1] if feats.ndimension() > 1 else 1

    @property
    def n_labels(self) -> int:
        """Return the number of label columns.

        Returns 1 for a 1-D labels tensor and 0 when the dataset has no labels.
        """
        labels = self.labels
        if labels is None:
            # Unlabeled dataset: there are no label columns to count.
            return 0
        return labels.shape[1] if labels.ndimension() > 1 else 1

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.features)

    def _sample_id_at(self, idx):
        """Return the sample ID(s) for ``idx``, falling back to positions."""
        ids = self.sample_ids
        if ids is not None:
            return ids[idx]
        # Index a lazy range instead of materializing a full list of
        # positions on every access.
        picked = range(len(self.features))[idx]
        # Slicing a range yields a range; return a list, as list slicing would.
        return list(picked) if isinstance(picked, range) else picked

    def __getitem__(self, idx):
        """Retrieve the sample at the given index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: ``(feature, label, sample_weight, sample_id)`` when labels
            are present, otherwise ``(feature, sample_id)``. When no sample
            IDs were provided, the positional index is used as the ID.
        """
        # Features are indexed first so an out-of-range ``idx`` raises from
        # the features lookup before the IDs are touched.
        feature = self.features[idx]
        if self.labels is None:
            return feature, self._sample_id_at(idx)
        return (
            feature,
            self.labels[idx],
            self.sample_weights[idx],
            self._sample_id_at(idx),
        )