banhxeo.data.torch module
- class banhxeo.data.torch.TorchTextDataset(text_dataset: BaseTextDataset, config: TorchDatasetConfig)
Bases: Dataset
A PyTorch Dataset wrapper for BaseTextDataset.
This class handles the transformation of raw text samples from a BaseTextDataset into tokenized and numericalized PyTorch tensors, ready for model consumption. It applies tokenization, vocabulary mapping, and any specified text transformations.
- Variables:
text_dataset (BaseTextDataset) – The underlying raw text dataset.
config (TorchDatasetConfig) – Configuration specifying tokenizer, vocabulary, text processing, and label handling.
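The processing described above (tokenize, map tokens to vocabulary ids, return tensors) can be sketched with a stand-in dataset. This is only an illustration under stated assumptions, not banhxeo's implementation: ToyTextDataset, the whitespace tokenizer, the vocab dictionary, max_len, pad_id, unk_id, and the "input_ids"/"label" keys are all hypothetical names chosen for the example.

```python
import torch
from torch.utils.data import Dataset


class ToyTextDataset(Dataset):
    """Hypothetical stand-in for illustration; not part of banhxeo."""

    def __init__(self, samples, vocab, max_len=6, pad_id=0, unk_id=1):
        self.samples = samples    # list of (text, label) pairs
        self.vocab = vocab        # token -> integer id
        self.max_len = max_len
        self.pad_id = pad_id
        self.unk_id = unk_id

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        text, label = self.samples[idx]
        # Tokenization (assumed here: lowercase + whitespace split).
        tokens = text.lower().split()
        # Vocabulary mapping; unknown tokens fall back to unk_id.
        ids = [self.vocab.get(tok, self.unk_id) for tok in tokens]
        # Truncate, then pad, so every sample has the same length.
        ids = ids[: self.max_len]
        ids += [self.pad_id] * (self.max_len - len(ids))
        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "label": torch.tensor(label, dtype=torch.long),
        }


vocab = {"<pad>": 0, "<unk>": 1, "good": 2, "movie": 3, "bad": 4}
dataset = ToyTextDataset([("Good movie", 1), ("Bad plot", 0)], vocab)
sample = dataset[0]
```

Returning a dictionary of fixed-length tensors from __getitem__ is what lets a DataLoader batch the samples without a custom collate function.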
- __init__(text_dataset: BaseTextDataset, config: TorchDatasetConfig)
Initializes the TorchTextDataset.
- Parameters:
text_dataset – The instance of BaseTextDataset containing raw data.
config – A TorchDatasetConfig object that defines how to process the raw data into tensors.
- to_loader(batch_size: int, num_workers: int, shuffle: bool = True, **kwargs)
Creates a PyTorch DataLoader for this dataset.
- Parameters:
batch_size – Number of samples per batch.
num_workers – Number of subprocesses to use for data loading.
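shuffle – Whether to reshuffle the data at every epoch. Because shuffle defaults to True, pass shuffle=False explicitly when supplying a sampler; torch.utils.data.DataLoader treats a sampler and shuffle=True as mutually exclusive.
collate_fn – Custom function that merges a list of samples into a mini-batch of tensors. If None, PyTorch's default collate function is used. The default works when __getitem__ returns a dictionary of tensors, as this class does, provided the tensors under each key have the same shape across the samples in a batch.
sampler – Defines the strategy for drawing samples from the dataset. If specified, shuffle must be False.
**kwargs – Additional arguments passed directly to torch.utils.data.DataLoader (e.g., pin_memory, drop_last). Given the signature above, collate_fn and sampler are also supplied as keyword arguments through **kwargs.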
- Returns:
A PyTorch DataLoader instance.
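The batching behaviour these parameters describe can be sketched with torch.utils.data.DataLoader directly, since the internals of to_loader are not shown here. FixedLengthDataset is a hypothetical stand-in that, like this class, returns a dictionary of tensors from __getitem__. The example uses num_workers=0 so that it runs as a plain script without a multiprocessing guard.

```python
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler


class FixedLengthDataset(Dataset):
    """Hypothetical stand-in returning a dict of equal-shaped tensors."""

    def __len__(self):
        return 5

    def __getitem__(self, idx):
        return {
            "input_ids": torch.full((4,), idx, dtype=torch.long),
            "label": torch.tensor(idx % 2, dtype=torch.long),
        }


dataset = FixedLengthDataset()

# With no collate_fn given, the default collate function stacks each
# dictionary entry along a new leading batch dimension.
loader = DataLoader(
    dataset, batch_size=2, shuffle=False, num_workers=0, drop_last=True
)
batches = list(loader)
first = batches[0]

# When a sampler is supplied, shuffle is set to False explicitly.
sampled = DataLoader(
    dataset, batch_size=2, sampler=SequentialSampler(dataset), shuffle=False
)
sampled_batches = list(sampled)
```

With drop_last=True the incomplete final batch is discarded, so five samples yield two batches of two; without it the sampler-based loader yields a third batch holding the single remaining sample.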