banhxeo.train.trainer module

class banhxeo.train.trainer.Trainer(model: NeuralLanguageModel, config: TrainerConfig, train_dataset: TorchTextDataset, eval_dataset: TorchTextDataset, train_step_fn: Callable[[Trainer, NeuralLanguageModel, Dict[str, Tensor]], Dict[str, Any]], eval_step_fn: Callable[[Trainer, NeuralLanguageModel, Dict[str, Tensor]], Dict[str, Any]] | None = None, device: device | str | None = None, callbacks: List[TrainerCallback] | None = None, collate_fn: Callable[[List[Dict[str, Tensor]]], Any] | None = None, **kwargs)

Bases: object

Handles the training and evaluation loop for neural language models.

The Trainer orchestrates the training process, including data loading, model optimization, logging, checkpointing, and evaluation, according to a provided TrainerConfig. It supports custom training and evaluation logic via train_step_fn and eval_step_fn callables.

Variables:
  • model (NeuralLanguageModel) – The model to be trained.

  • config (TrainerConfig) – Configuration for the training process.

  • train_dataset (TorchTextDataset) – The dataset for training.

  • eval_dataset (Optional[TorchTextDataset]) – The dataset for evaluation.

  • optimizer (Optional[torch.optim.Optimizer]) – The optimizer instance.

  • loss_fn (Optional[torch.nn.modules.loss._Loss]) – The loss function instance. It is named loss_fn rather than loss to avoid confusion with computed loss values.

  • scheduler (Optional[torch.optim.lr_scheduler._LRScheduler]) – The learning-rate scheduler instance.

  • train_step_fn (TrainStepCallable) – User-defined function for a single training step.

  • eval_step_fn (Optional[EvalStepCallable]) – User-defined function for a single eval step.

  • callbacks (List[TrainerCallback]) – List of callbacks to customize training.

  • collate_fn (Optional[Callable]) – Custom collate function for DataLoaders.

  • device (torch.device) – The device (CPU or GPU) where training will occur.

  • global_step (int) – Total number of training steps (optimizer updates or micro-batches) performed.

  • current_epoch (int) – The current training epoch (1-indexed).

  • total_train_steps (int) – The total number of training steps planned across all epochs.

  • best_metric (Optional[float]) – The best evaluation metric value achieved so far. It is intended for saving the best model, which requires a callback that implements that logic.
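The bookkeeping attributes above relate to each other roughly as follows. This is a minimal sketch, not banhxeo's implementation: it assumes that one step is one micro-batch and that the training DataLoader keeps its last partial batch. The helpers steps_per_epoch, planned_total_steps and update_best_metric are hypothetical names introduced only for illustration.

```python
import math
from typing import Optional, Tuple


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    # A DataLoader that keeps the last partial batch (drop_last=False)
    # yields ceil(num_examples / batch_size) batches per epoch.
    return math.ceil(num_examples / batch_size)


def planned_total_steps(num_examples: int, batch_size: int, num_epochs: int) -> int:
    # Assumption: one step per micro-batch, repeated for every epoch.
    return num_epochs * steps_per_epoch(num_examples, batch_size)


def update_best_metric(
    best: Optional[float], value: float, greater_is_better: bool = False
) -> Tuple[Optional[float], bool]:
    # Returns the new best value and whether it improved (e.g. to trigger a save).
    # Lower is better by default, as for a loss; pass greater_is_better=True for accuracy.
    improved = best is None or (value > best if greater_is_better else value < best)
    return (value if improved else best), improved
```

For example, 1,000 training examples at batch size 32 for 3 epochs give 32 steps per epoch and 96 planned steps in total.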

__init__(model: NeuralLanguageModel, config: TrainerConfig, train_dataset: TorchTextDataset, eval_dataset: TorchTextDataset, train_step_fn: Callable[[Trainer, NeuralLanguageModel, Dict[str, Tensor]], Dict[str, Any]], eval_step_fn: Callable[[Trainer, NeuralLanguageModel, Dict[str, Tensor]], Dict[str, Any]] | None = None, device: device | str | None = None, callbacks: List[TrainerCallback] | None = None, collate_fn: Callable[[List[Dict[str, Tensor]]], Any] | None = None, **kwargs)

Initializes the Trainer.

Parameters:
  • model – The NeuralLanguageModel to train.

  • config – The TrainerConfig specifying training parameters.

  • train_dataset – The training TorchTextDataset.

  • eval_dataset – Optional evaluation TorchTextDataset.

  • train_step_fn – A callable that executes a single training step. It is responsible for the forward pass, loss calculation, backward pass, and optimizer step (including gradient accumulation, if desired). See TrainStepCallable for the signature.

  • eval_step_fn – An optional callable that executes a single evaluation step. See EvalStepCallable for the signature.

  • device – The device to train on (‘cuda’, ‘mps’, ‘cpu’, or a torch.device object). If None, a GPU is used when available; otherwise, training falls back to the CPU.

  • callbacks – An optional list of TrainerCallback instances. ProgressCallback is added by default.

  • collate_fn – An optional custom collate function for the DataLoaders.
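A train_step_fn receives the trainer, the model and one batch, and returns a dictionary of metrics. The sketch below shows one possible shape with gradient accumulation. It relies only on the documented trainer.optimizer and trainer.loss_fn attributes plus standard PyTorch tensor methods (.backward(), .item()), so the function itself needs no import. The batch keys "input_ids" and "labels", the ACCUM_STEPS constant and the function name example_train_step are hypothetical, and it assumes trainer.global_step counts micro-batches completed before the call.

```python
from typing import Any, Dict

ACCUM_STEPS = 4  # hypothetical: number of micro-batches per optimizer update


def example_train_step(trainer: Any, model: Any, batch: Dict[str, Any]) -> Dict[str, Any]:
    # Forward pass; assumes hypothetical "input_ids"/"labels" keys and that the model
    # output can be passed straight to trainer.loss_fn. For token-level cross-entropy
    # you would typically flatten (batch, seq_len, vocab) logits first.
    logits = model(batch["input_ids"])
    loss = trainer.loss_fn(logits, batch["labels"])
    # Scale the loss so gradients summed over ACCUM_STEPS micro-batches are averaged.
    (loss / ACCUM_STEPS).backward()
    # Update the weights once every ACCUM_STEPS micro-batches.
    # Assumption: trainer.global_step counts micro-batches completed before this call.
    if (trainer.global_step + 1) % ACCUM_STEPS == 0:
        trainer.optimizer.step()
        trainer.optimizer.zero_grad()
    # Report the unscaled loss as a plain Python float.
    return {"loss": loss.item()}
```

It would then be passed positionally or by keyword, e.g. Trainer(model, config, train_ds, eval_ds, train_step_fn=example_train_step), where train_ds and eval_ds stand for your own datasets.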

get_callbacks_output() → Dict[str, Any]

Collects and returns outputs from all callbacks that provide them.

Returns:

A dictionary where keys are callback names and values are their outputs (obtained via callback.get_output()).
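The collection step can be pictured as follows. This is a hedged sketch rather than banhxeo's code: it assumes a callback "provides" output when it exposes a callable get_output(), and that a callback's class name serves as its key. The function collect_callback_outputs and the LossHistory and SilentCallback classes are hypothetical.

```python
from typing import Any, Dict, List


def collect_callback_outputs(callbacks: List[Any]) -> Dict[str, Any]:
    outputs: Dict[str, Any] = {}
    for cb in callbacks:
        # Only callbacks that expose a callable get_output() contribute.
        get_output = getattr(cb, "get_output", None)
        if callable(get_output):
            # Assumption: the callback's class name is used as its key.
            outputs[type(cb).__name__] = get_output()
    return outputs


class LossHistory:
    """Hypothetical callback that records losses and exposes them via get_output()."""

    def __init__(self) -> None:
        self.losses: List[float] = []

    def get_output(self) -> List[float]:
        return list(self.losses)


class SilentCallback:
    """Hypothetical callback without get_output(); it is skipped."""


history = LossHistory()
history.losses.extend([2.0, 1.5])
print(collect_callback_outputs([history, SilentCallback()]))  # prints {'LossHistory': [2.0, 1.5]}
```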

get_train_dataloader() → DataLoader

Creates and returns the DataLoader for the training set.

get_eval_dataloader() → DataLoader | None

Creates and returns the DataLoader for the evaluation set, if available.
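Both DataLoaders use the collate_fn given to the Trainer, when one is supplied. For variable-length text, a common choice is a padding collate function. The sketch below works on plain Python lists so that it runs without PyTorch; the keys "input_ids" and "attention_mask", the PAD_ID value and the name pad_collate are hypothetical and must be matched to your own dataset and tokenizer.

```python
from typing import Dict, List

PAD_ID = 0  # hypothetical padding token id


def pad_collate(examples: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
    # Pad every sequence in the batch to the length of the longest one.
    max_len = max(len(ex["input_ids"]) for ex in examples)
    input_ids: List[List[int]] = []
    attention_mask: List[List[int]] = []
    for ex in examples:
        ids = list(ex["input_ids"])
        n_pad = max_len - len(ids)
        input_ids.append(ids + [PAD_ID] * n_pad)
        # 1 marks real tokens, 0 marks padding.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_collate([{"input_ids": [5, 6, 7]}, {"input_ids": [8]}])
print(batch["input_ids"])  # prints [[5, 6, 7], [8, 0, 0]]
```

With PyTorch installed, you would usually wrap the padded lists with torch.tensor(...) before returning them, and then pass the function as Trainer(..., collate_fn=pad_collate).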

train()

Runs the main training loop.

The loop iterates over epochs and batches, calling train_step_fn for each batch. It handles optimizer creation, logging, callbacks, and optional evaluation during training.

Returns:

A dictionary containing training results, potentially including outputs from callbacks.
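The control flow of the loop can be sketched in a few lines. This is a simplified illustration under stated assumptions, not banhxeo's train(): optimizer creation, logging and callbacks are omitted, evaluation is assumed to run once at the end of each epoch, and LoopState, train_loop and the "eval_epoch_N" result keys are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

StepFn = Callable[[Any, Any, Dict[str, Any]], Dict[str, Any]]


@dataclass
class LoopState:
    """Hypothetical stand-in for the Trainer's bookkeeping attributes."""

    global_step: int = 0
    current_epoch: int = 0
    step_metrics: List[Dict[str, Any]] = field(default_factory=list)


def train_loop(
    state: LoopState,
    model: Any,
    batches: List[Dict[str, Any]],
    num_epochs: int,
    step_fn: StepFn,
    eval_fn: Optional[Callable[[], Optional[Dict[str, Any]]]] = None,
) -> Dict[str, Any]:
    results: Dict[str, Any] = {}
    for epoch in range(1, num_epochs + 1):  # 1-indexed, like current_epoch
        state.current_epoch = epoch
        for batch in batches:
            # Same argument order as train_step_fn: (trainer, model, batch).
            metrics = step_fn(state, model, batch)
            state.step_metrics.append(metrics)
            state.global_step += 1
        if eval_fn is not None:
            # Assumption: evaluate once at the end of every epoch.
            results[f"eval_epoch_{epoch}"] = eval_fn()
    results["global_step"] = state.global_step
    return results


state = LoopState()
out = train_loop(
    state,
    model=None,
    batches=[{"x": 1}, {"x": 2}, {"x": 3}],
    num_epochs=2,
    step_fn=lambda trainer, model, batch: {"loss": float(batch["x"])},
    eval_fn=lambda: {"avg_eval_loss": 0.5},
)
print(out["global_step"], state.current_epoch)  # prints 6 2
```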

evaluate()

Runs the evaluation loop on the eval_dataset.

If eval_dataset or eval_step_fn is not provided, evaluation is skipped. Sets the model to evaluation mode (model.eval()) and disables gradient calculations.

Returns:

A dictionary containing evaluation metrics, or None if evaluation is skipped. Keys typically include “avg_eval_loss” and other metrics returned by eval_step_fn.
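The skip rule and the metric aggregation can be sketched as below. It is a hedged illustration, not the library's evaluate(): it assumes eval_step_fn reports its loss under the key "loss", averages every metric with equal weight per batch (rather than per example), and treats an empty evaluation set as skipped. The name run_evaluation is hypothetical.

```python
from typing import Any, Callable, Dict, Iterable, Optional


def run_evaluation(
    trainer: Any,
    model: Any,
    batches: Optional[Iterable[Dict[str, Any]]],
    eval_step_fn: Optional[Callable[[Any, Any, Dict[str, Any]], Dict[str, Any]]],
) -> Optional[Dict[str, float]]:
    # Mirror the documented behavior: no data or no step function means no evaluation.
    if batches is None or eval_step_fn is None:
        return None
    # The real method also calls model.eval() and disables gradients
    # (for example with torch.no_grad()); both are omitted so this runs without PyTorch.
    totals: Dict[str, float] = {}
    num_batches = 0
    for batch in batches:
        for key, value in eval_step_fn(trainer, model, batch).items():
            totals[key] = totals.get(key, 0.0) + float(value)
        num_batches += 1
    if num_batches == 0:
        return None  # assumption: an empty evaluation set is treated as skipped
    # Average each metric over batches; "loss" is reported as "avg_eval_loss".
    return {
        ("avg_eval_loss" if key == "loss" else key): total / num_batches
        for key, total in totals.items()
    }
```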

save_model(output_dir: str | Path) → None

Saves the model, optimizer, scheduler, and trainer configuration.

Parameters:

output_dir – The directory where components will be saved. It will be created if it doesn’t exist.
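The saving step can be pictured as writing one file per component into a directory that is created on demand. The sketch below is an assumption, not banhxeo's on-disk format: save_components, the file names trainer_config.json and "<name>.pt", and the JSON-serializable config dictionary are all hypothetical. The serializer is passed in as save_fn, so torch.save can be plugged in when PyTorch is available, while the sketch itself runs without it.

```python
import json
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union


def save_components(
    output_dir: Union[str, Path],
    config: Dict[str, Any],
    state_dicts: Dict[str, Optional[Any]],
    save_fn: Callable[[Any, Path], Any],
) -> Path:
    out = Path(output_dir)
    # Create the directory, including missing parents, if it does not exist yet.
    out.mkdir(parents=True, exist_ok=True)
    # Hypothetical file name; the config is assumed to be JSON-serializable.
    (out / "trainer_config.json").write_text(json.dumps(config, indent=2, sort_keys=True))
    for name, state in state_dicts.items():
        if state is not None:  # e.g. skip the scheduler when none was configured
            save_fn(state, out / f"{name}.pt")
    return out
```

With PyTorch, a call might look like save_components(path, config_dict, {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "scheduler": None}, torch.save), where config_dict is a plain-dictionary view of your TrainerConfig.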