banhxeo.data.torch module

class banhxeo.data.torch.TorchTextDataset(text_dataset: BaseTextDataset, config: TorchDatasetConfig)

Bases: Dataset

A PyTorch Dataset wrapper for BaseTextDataset.

This class handles the transformation of raw text samples from a BaseTextDataset into tokenized and numericalized PyTorch tensors, ready for model consumption. It applies tokenization, vocabulary mapping, and any specified text transformations.

Variables:
  • text_dataset (BaseTextDataset) – The underlying raw text dataset.

  • config (TorchDatasetConfig) – Configuration specifying tokenizer, vocabulary, text processing, and label handling.
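The tokenize-then-numericalize step described above can be pictured with a small, library-free sketch. Everything in it is an assumption made for illustration: the helper names (build_vocab, encode), the whitespace tokenizer, the special tokens, and the output keys input_ids and attention_mask are hypothetical and are not taken from banhxeo. Plain Python lists stand in for the tensors the real class would return.

```python
from typing import Dict, List

# Hypothetical special tokens; banhxeo's vocabulary may use different ones.
PAD, UNK = "<pad>", "<unk>"


def build_vocab(texts: List[str]) -> Dict[str, int]:
    """Map each whitespace token to an integer id, reserving 0 and 1."""
    vocab = {PAD: 0, UNK: 1}
    for text in texts:
        for token in text.lower().split():
            vocab.setdefault(token, len(vocab))
    return vocab


def encode(text: str, vocab: Dict[str, int], max_length: int) -> Dict[str, List[int]]:
    """Tokenize, map tokens to ids, then truncate or pad to max_length."""
    tokens = text.lower().split()
    ids = [vocab.get(tok, vocab[UNK]) for tok in tokens][:max_length]
    padding = max_length - len(ids)
    return {
        "input_ids": ids + [vocab[PAD]] * padding,
        "attention_mask": [1] * len(ids) + [0] * padding,
    }


raw_samples = ["the cat sat", "the dog barked loudly today"]
vocab = build_vocab(raw_samples)
encoded = [encode(sample, vocab, max_length=4) for sample in raw_samples]
```

Padding every sample to the same length is what would later allow samples to be stacked into a single batch; how TorchDatasetConfig actually controls length and padding is not shown in this section.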

__init__(text_dataset: BaseTextDataset, config: TorchDatasetConfig)

Initializes the TorchTextDataset.

Parameters:
  • text_dataset – The instance of BaseTextDataset containing raw data.

  • config – A TorchDatasetConfig object that defines how to process the raw data into tensors.

to_loader(batch_size: int, num_workers: int, shuffle: bool = True, **kwargs)

Creates a PyTorch DataLoader for this dataset.

Parameters:
  • batch_size – Number of samples per batch.

  • num_workers – Number of subprocesses to use for data loading.

  • shuffle – Whether to reshuffle the data at every epoch. If a sampler is provided, shuffle must be False (otherwise it is ignored).

  • collate_fn – Custom function that merges a list of samples into a mini-batch of tensors. If None, PyTorch's default collate function is used. The default works well when __getitem__ returns a dictionary of tensors, as this class does.

  • sampler – Defines the strategy to draw samples from the dataset. If specified, shuffle must be False.

  • **kwargs – Additional arguments passed directly to torch.utils.data.DataLoader (e.g., pin_memory, drop_last).

Returns:

A PyTorch DataLoader instance.
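The batching behaviour described in the parameters above can be sketched with torch.utils.data.DataLoader directly. This is not to_loader itself, whose internals are not shown here; ToyDictDataset and its keys (input_ids, label) are hypothetical stand-ins for a dataset whose __getitem__ returns a dictionary of tensors.

```python
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler


class ToyDictDataset(Dataset):
    """Hypothetical stand-in that returns a dict of tensors per sample."""

    def __init__(self, n_samples: int = 6, seq_len: int = 4):
        # Sample i holds the ids [i * seq_len, ..., i * seq_len + seq_len - 1].
        self.input_ids = torch.arange(n_samples * seq_len).reshape(n_samples, seq_len)
        self.labels = torch.arange(n_samples) % 2

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> dict:
        return {"input_ids": self.input_ids[idx], "label": self.labels[idx]}


dataset = ToyDictDataset()

# With no collate_fn, the default collate stacks each dict entry
# along a new leading batch dimension.
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, drop_last=False)
batches = list(loader)

# A sampler decides the order of samples, so shuffle stays False here.
sampled_loader = DataLoader(
    dataset, batch_size=4, sampler=SequentialSampler(dataset), shuffle=False
)
first = next(iter(sampled_loader))
```

With six samples and batch_size=4, the loader yields one full batch and one partial batch of two; passing drop_last=True would discard the partial one. Default collation relies on every sample's tensors having the same shape, which is why this toy dataset uses a fixed seq_len.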