banhxeo.train.trainer module
- class banhxeo.train.trainer.Trainer(model: NeuralLanguageModel, config: TrainerConfig, train_dataset: TorchTextDataset, eval_dataset: TorchTextDataset, train_step_fn: Callable[[Trainer, NeuralLanguageModel, Dict[str, Tensor]], Dict[str, Any]], eval_step_fn: Callable[[Trainer, NeuralLanguageModel, Dict[str, Tensor]], Dict[str, Any]] | None = None, device: device | str | None = None, callbacks: List[TrainerCallback] | None = None, collate_fn: Callable[[List[Dict[str, Tensor]]], Any] | None = None, **kwargs)
Bases: object
Handles the training and evaluation loop for neural language models.
The Trainer orchestrates the training process, including data loading, model optimization, logging, checkpointing, and evaluation, according to a provided TrainerConfig. It supports custom training and evaluation logic via train_step_fn and eval_step_fn callables.
- Variables:
model (NeuralLanguageModel) – The model to be trained.
config (TrainerConfig) – Configuration for the training process.
train_dataset (TorchTextDataset) – The dataset for training.
eval_dataset (Optional[TorchTextDataset]) – The dataset for evaluation.
optimizer (Optional[torch.optim.Optimizer]) – The optimizer instance.
loss_fn (Optional[torch.nn.modules.loss._Loss]) – The loss function instance. Named loss_fn rather than loss to avoid confusion with computed loss values.
scheduler (Optional[torch.optim.lr_scheduler._LRScheduler]) – The LR scheduler.
train_step_fn (TrainStepCallable) – User-defined function for a single training step.
eval_step_fn (Optional[EvalStepCallable]) – User-defined function for a single eval step.
callbacks (List[TrainerCallback]) – List of callbacks to customize training.
collate_fn (Optional[Callable]) – Custom collate function for DataLoaders.
device (torch.device) – The device (CPU or GPU) where training will occur.
global_step (int) – Total number of training steps (optimizer updates or micro-batches) performed.
current_epoch (int) – The current training epoch (1-indexed).
total_train_steps (int) – The total number of training steps planned across all epochs.
best_metric (Optional[float]) – The best evaluation metric value achieved so far; used to save the best model when a callback implements that logic.
- __init__(model: NeuralLanguageModel, config: TrainerConfig, train_dataset: TorchTextDataset, eval_dataset: TorchTextDataset, train_step_fn: Callable[[Trainer, NeuralLanguageModel, Dict[str, Tensor]], Dict[str, Any]], eval_step_fn: Callable[[Trainer, NeuralLanguageModel, Dict[str, Tensor]], Dict[str, Any]] | None = None, device: device | str | None = None, callbacks: List[TrainerCallback] | None = None, collate_fn: Callable[[List[Dict[str, Tensor]]], Any] | None = None, **kwargs)
Initializes the Trainer.
- Parameters:
model – The NeuralLanguageModel to train.
config – The TrainerConfig specifying training parameters.
train_dataset – The training TorchTextDataset.
eval_dataset – Optional evaluation TorchTextDataset.
train_step_fn – A callable that executes a single training step. It’s responsible for the forward pass, loss calculation, backward pass, and optimizer step (including gradient accumulation if desired). See TrainStepCallable for signature.
eval_step_fn – An optional callable that executes a single evaluation step. See EvalStepCallable for signature.
device – The device to train on (‘cuda’, ‘mps’, ‘cpu’, or a torch.device object). If None, a GPU is used when available; otherwise, the CPU.
callbacks – An optional list of TrainerCallback instances. ProgressCallback is added by default.
collate_fn – An optional custom collate function for the DataLoaders.
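A train_step_fn receives the Trainer, the model, and one collated batch, and returns a dictionary. The sketch below is one possible implementation under stated assumptions, not banhxeo's own: it assumes PyTorch, batch keys "inputs" and "labels", a hypothetical ACCUM_STEPS constant for gradient accumulation, and that trainer.global_step counts the micro-batches completed before the call. A SimpleNamespace stand-in exposing the documented optimizer, loss_fn, and global_step attributes replaces a real Trainer so the example runs on its own.

```python
from types import SimpleNamespace
from typing import Any, Dict

import torch
from torch import nn

# Hypothetical: micro-batches to accumulate before each optimizer update.
ACCUM_STEPS = 2


def train_step_fn(trainer: Any, model: nn.Module, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    """Forward pass, loss, backward pass, and an optimizer step every ACCUM_STEPS calls."""
    model.train()
    # Assumed batch layout: "inputs" features and integer class "labels".
    logits = model(batch["inputs"])
    loss = trainer.loss_fn(logits, batch["labels"])
    # Scale so the accumulated gradient equals the mean over ACCUM_STEPS micro-batches.
    (loss / ACCUM_STEPS).backward()
    # Assumption: global_step counts micro-batches finished before this call.
    stepped = (trainer.global_step + 1) % ACCUM_STEPS == 0
    if stepped:
        trainer.optimizer.step()
        trainer.optimizer.zero_grad()
    return {"loss": loss.item(), "optimizer_stepped": stepped}


# Usage with a stand-in object instead of a real Trainer.
torch.manual_seed(0)
model = nn.Linear(4, 3)
stub_trainer = SimpleNamespace(
    optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
    loss_fn=nn.CrossEntropyLoss(),
    global_step=0,
)
batch = {"inputs": torch.randn(8, 4), "labels": torch.randint(0, 3, (8,))}
outputs = []
for _ in range(2):
    outputs.append(train_step_fn(stub_trainer, model, batch))
    stub_trainer.global_step += 1  # the real Trainer may manage this counter itself
```

Scaling the loss by ACCUM_STEPS keeps the effective learning rate independent of how many micro-batches are accumulated.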
- get_callbacks_output() → Dict[str, Any]
Collects and returns outputs from all callbacks that provide them.
- Returns:
A dictionary where keys are callback names and values are their outputs (obtained via callback.get_output()).
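The collection step can be pictured with the framework-free sketch below. BaseCallback and LossHistoryCallback are hypothetical stand-ins, not banhxeo classes; the sketch also assumes that "callbacks that provide them" means callbacks whose get_output() returns something other than None, and that the class name is used when a callback has no explicit name.

```python
from typing import Any, Dict, List, Optional


class BaseCallback:
    """Hypothetical stand-in for a callback base class."""

    name: Optional[str] = None

    def get_output(self) -> Any:
        return None  # by default, a callback contributes no output


class LossHistoryCallback(BaseCallback):
    """Hypothetical callback that records losses."""

    name = "loss_history"

    def __init__(self) -> None:
        self.losses: List[float] = []

    def get_output(self) -> Any:
        return list(self.losses)


def get_callbacks_output(callbacks: List[BaseCallback]) -> Dict[str, Any]:
    """Map each callback's name to its output, skipping callbacks that return None."""
    results: Dict[str, Any] = {}
    for cb in callbacks:
        output = cb.get_output()
        if output is not None:
            # Fall back to the class name when no explicit name is set.
            results[cb.name or type(cb).__name__] = output
    return results


history = LossHistoryCallback()
history.losses.extend([2.0, 1.5])
outputs = get_callbacks_output([BaseCallback(), history])
```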
- get_train_dataloader() → DataLoader
Creates and returns the DataLoader for the training set.
- get_eval_dataloader() → DataLoader | None
Creates and returns the DataLoader for the evaluation set, if available.
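A collate_fn passed to the Trainer turns a list of per-sample dictionaries into one batch for the DataLoader. The sketch below assumes PyTorch, sample keys "input_ids" (variable-length 1-D tensors) and "label", and padding id 0; these names are illustrative assumptions, and how the Trainer chooses batch size or shuffling is not covered here.

```python
from typing import Dict, List

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_collate(samples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Right-pad "input_ids" to the longest sample and build a matching attention mask."""
    input_ids = [s["input_ids"] for s in samples]
    padded = pad_sequence(input_ids, batch_first=True, padding_value=0)  # assumed pad id 0
    lengths = torch.tensor([len(x) for x in input_ids])
    # 1 for real tokens, 0 for padding positions.
    attention_mask = (torch.arange(padded.size(1))[None, :] < lengths[:, None]).long()
    labels = torch.stack([s["label"] for s in samples])
    return {"input_ids": padded, "attention_mask": attention_mask, "labels": labels}


# A plain list works as a map-style dataset for DataLoader.
samples = [
    {"input_ids": torch.tensor([5, 6, 7]), "label": torch.tensor(1)},
    {"input_ids": torch.tensor([8]), "label": torch.tensor(0)},
]
loader = DataLoader(samples, batch_size=2, collate_fn=pad_collate)
batch = next(iter(loader))
```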
- train()
Runs the main training loop.
The loop iterates over epochs and batches, calling train_step_fn for each batch. It handles optimizer creation, logging, callbacks, and optional evaluation during training.
- Returns:
A dictionary containing training results, potentially including outputs from callbacks.
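The bookkeeping described for train() and for the global_step, current_epoch, and total_train_steps attributes can be sketched without any framework. MiniTrainer is hypothetical: it models one step per batch (no gradient accumulation), assumes evaluation runs at the end of each epoch, and omits optimizer creation, logging, and callbacks. It does keep the documented train_step_fn(trainer, model, batch) calling convention.

```python
from typing import Any, Callable, Dict, List, Optional


class MiniTrainer:
    """Hypothetical, framework-free sketch of the loop structure of Trainer.train()."""

    def __init__(
        self,
        model: Any,
        num_epochs: int,
        batches: List[Dict[str, Any]],
        train_step_fn: Callable[["MiniTrainer", Any, Dict[str, Any]], Dict[str, Any]],
        eval_fn: Optional[Callable[["MiniTrainer"], Dict[str, Any]]] = None,
    ) -> None:
        self.model = model
        self.num_epochs = num_epochs
        self.batches = batches
        self.train_step_fn = train_step_fn
        self.eval_fn = eval_fn
        self.global_step = 0
        self.current_epoch = 0
        # One step per batch; gradient accumulation is not modeled.
        self.total_train_steps = num_epochs * len(batches)

    def train(self) -> Dict[str, Any]:
        losses: List[float] = []
        eval_history: List[Dict[str, Any]] = []
        for epoch in range(1, self.num_epochs + 1):
            self.current_epoch = epoch  # 1-indexed, as documented
            for batch in self.batches:
                step_output = self.train_step_fn(self, self.model, batch)
                self.global_step += 1
                losses.append(step_output["loss"])
            if self.eval_fn is not None:  # assumed end-of-epoch evaluation
                eval_history.append(self.eval_fn(self))
        return {"global_step": self.global_step, "losses": losses, "eval_history": eval_history}


def toy_step(trainer: MiniTrainer, model: Any, batch: Dict[str, Any]) -> Dict[str, Any]:
    # A fake "loss" that shrinks as epochs advance.
    return {"loss": batch["x"] / trainer.current_epoch}


trainer = MiniTrainer(
    model=None,
    num_epochs=2,
    batches=[{"x": 1.0}, {"x": 3.0}],
    train_step_fn=toy_step,
    eval_fn=lambda t: {"epoch": t.current_epoch},
)
results = trainer.train()
```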
- evaluate()
Runs the evaluation loop on the eval_dataset.
If eval_dataset or eval_step_fn is not provided, evaluation is skipped. Sets the model to evaluation mode (model.eval()) and disables gradient calculations.
- Returns:
A dictionary containing evaluation metrics, or None if evaluation is skipped. Keys typically include “avg_eval_loss” and other metrics returned by eval_step_fn.
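The skip-and-aggregate behavior of evaluate() might look like the sketch below. It is an assumption-laden illustration, not the library's code: it assumes eval_step_fn returns a "loss" key plus other numeric metrics, that every metric is averaged over batches, and that the loss average is reported as "avg_eval_loss". A real implementation would also call model.eval() and wrap the loop in torch.no_grad(), which this framework-free version leaves out.

```python
from typing import Any, Callable, Dict, List, Optional

EvalStep = Callable[[Any, Any, Dict[str, Any]], Dict[str, float]]


def evaluate(
    trainer: Any,
    model: Any,
    eval_batches: Optional[List[Dict[str, Any]]],
    eval_step_fn: Optional[EvalStep],
) -> Optional[Dict[str, float]]:
    """Average per-batch metrics; return None when evaluation cannot run."""
    # Skip when data or step function is missing (an empty list is also skipped,
    # which avoids dividing by zero below).
    if not eval_batches or eval_step_fn is None:
        return None
    totals: Dict[str, float] = {}
    for batch in eval_batches:
        for key, value in eval_step_fn(trainer, model, batch).items():
            totals[key] = totals.get(key, 0.0) + float(value)
    n = len(eval_batches)
    metrics = {key: total / n for key, total in totals.items() if key != "loss"}
    if "loss" in totals:
        metrics["avg_eval_loss"] = totals["loss"] / n
    return metrics


def toy_eval_step(trainer: Any, model: Any, batch: Dict[str, Any]) -> Dict[str, float]:
    return {"loss": batch["y"], "accuracy": 0.5}


eval_batches = [{"y": 1.0}, {"y": 3.0}]
metrics = evaluate(None, None, eval_batches, toy_eval_step)
```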