banhxeo.train.callbacks module

class banhxeo.train.callbacks.TrainerCallback[source]

Bases: object

Abstract base class for trainer callbacks.

Callbacks allow custom actions to be performed at various stages of the training and evaluation process. Subclasses should override specific on_* methods to implement their desired behavior.

Variables:

name (str) – A unique name for the callback, often used for logging or retrieving callback-specific outputs.

name: str = 'base_callback'
get_output() → Any[source]

Returns any data accumulated or generated by the callback.

This method can be overridden by subclasses to provide access to callback-specific results (e.g., a list of accuracies, paths to saved models).

Returns:

Any data the callback wishes to expose. Defaults to None.

on_init_end(trainer: Trainer) → None[source]

Called at the end of the Trainer’s __init__ method.

on_train_begin(trainer: Trainer) → None[source]

Called at the beginning of the train method.

on_train_end(trainer: Trainer, logs: Dict[str, float] | None = None) → None[source]

Called at the end of the train method.

Parameters:
  • trainer – The Trainer instance.

  • logs – A dictionary of logs collected at the end of training, such as the final average loss.

on_epoch_begin(trainer: Trainer, epoch: int) → None[source]

Called at the beginning of each training epoch.

Parameters:
  • trainer – The Trainer instance.

  • epoch – The current epoch number (1-indexed).

on_epoch_end(trainer: Trainer, epoch: int, logs: Dict[str, float] | None = None) → None[source]

Called at the end of each training epoch.

Parameters:
  • trainer – The Trainer instance.

  • epoch – The current epoch number.

  • logs – A dictionary of logs for the epoch, such as average epoch loss.

on_step_begin(trainer: Trainer, global_step: int, batch_idx: int) → None[source]

Called at the beginning of each training step (batch processing).

Parameters:
  • trainer – The Trainer instance.

  • global_step – The total number of training steps performed so far. This counts optimizer updates when gradient accumulation is handled by the trainer, or micro-batches when it is handled by train_step_fn.

  • batch_idx – The index of the current batch within the current epoch.

on_step_end(trainer: Trainer, global_step: int, batch_idx: int, logs: Dict[str, Any] | None = None) → None[source]

Called at the end of each training step.

Parameters:
  • trainer – The Trainer instance.

  • global_step – The current global step number.

  • batch_idx – The index of the current batch.

  • logs – A dictionary of logs from the training step, typically including “loss” and any other metrics returned by train_step_fn.

on_evaluate(trainer: Trainer, metrics: Dict[str, float] | None = None) → None[source]

Called after an evaluation loop is completed.

Parameters:
  • trainer – The Trainer instance.

  • metrics – A dictionary of evaluation metrics (e.g., “eval_loss”, “eval_accuracy”).

on_save(trainer: Trainer, checkpoint_dir: Path) → None[source]

Called after a model checkpoint has been saved.

Parameters:
  • trainer – The Trainer instance.

  • checkpoint_dir – The path to the directory where the checkpoint was saved.
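A custom callback is written by subclassing TrainerCallback and overriding the hooks of interest. The sketch below illustrates the pattern with a hypothetical LossHistoryCallback that records the average loss reported at the end of each epoch; a minimal stand-in base class is defined inline so the example is self-contained (the real base lives in banhxeo.train.callbacks).

```python
# Stand-in for banhxeo.train.callbacks.TrainerCallback, defined inline so
# this sketch runs on its own. The real base class provides the same hooks.
class TrainerCallback:
    name = "base_callback"

    def on_epoch_end(self, trainer, epoch, logs=None):
        pass

    def get_output(self):
        return None


class LossHistoryCallback(TrainerCallback):
    """Hypothetical callback: records the epoch-end loss for later inspection."""

    name = "loss_history"

    def __init__(self):
        self.history = {}

    def on_epoch_end(self, trainer, epoch, logs=None):
        # Guard against missing logs, since the hook's logs argument is optional.
        if logs and "loss" in logs:
            self.history[epoch] = logs["loss"]

    def get_output(self):
        # Expose the accumulated history via the standard accessor.
        return self.history


cb = LossHistoryCallback()
cb.on_epoch_end(trainer=None, epoch=1, logs={"loss": 0.9})
cb.on_epoch_end(trainer=None, epoch=2, logs={"loss": 0.5})
print(cb.get_output())  # {1: 0.9, 2: 0.5}
```

After training, the recorded history would be retrieved through get_output(), the same accessor every callback inherits.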

class banhxeo.train.callbacks.ProgressCallback[source]

Bases: TrainerCallback

A callback that displays training progress using tqdm.

name: str = 'progress'
__init__() → None[source]

Initializes the ProgressCallback.

on_train_begin(trainer: Trainer) → None[source]

Initializes and displays the training progress bar.

on_epoch_begin(trainer: Trainer, epoch: int) → None[source]

Updates the progress bar description for the new epoch.

on_train_end(trainer: Trainer, logs: Dict[str, float] | None = None) → None[source]

Closes the training progress bar.

on_step_end(trainer: Trainer, global_step: int, batch_idx: int, logs: Dict[str, Any] | None = None) → None[source]

Updates the training progress bar and sets its postfix from the current logs.

class banhxeo.train.callbacks.AccuracyCallback(log_step: int = 100)[source]

Bases: TrainerCallback

A callback to compute and log accuracy during training and evaluation.

Assumes that the logs dictionary passed to on_step_end (from train_step_fn) contains “correct” (the number of correctly classified samples in the batch) and “total” (the total number of samples in the batch). The same keys are expected in the metrics dictionary passed to on_evaluate.

Variables:
  • log_step (int) – Frequency (in global steps) at which to log training accuracy.

  • correct (int) – Accumulated count of correct predictions since last log.

  • total (int) – Accumulated count of total predictions since last log.

  • accs (Dict[int, float]) – Dictionary storing training accuracies at logged steps. Key is global_step, value is accuracy percentage.

name: str = 'accuracy'
__init__(log_step: int = 100) → None[source]

Initializes the AccuracyCallback.

Parameters:

log_step – Log training accuracy every log_step global steps.

on_step_end(trainer: Trainer, global_step: int, batch_idx: int, logs: Dict[str, Any] | None = None) → None[source]

Accumulates correct/total counts from step logs and logs training accuracy.

on_evaluate(trainer: Trainer, metrics: Dict[str, float] | None = None) → None[source]

Logs evaluation accuracy from the metrics dictionary.

get_output() → Dict[int, float][source]

Returns the dictionary of training accuracies recorded at log_step intervals.
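The accumulation behaviour documented above can be sketched as follows. This is an illustrative reimplementation, not the library's actual code: the reset of the correct/total counters after each log is an assumption, and the class name AccuracySketch is invented for the example.

```python
class AccuracySketch:
    """Illustrative reimplementation of the documented AccuracyCallback logic."""

    def __init__(self, log_step=100):
        self.log_step = log_step
        self.correct = 0   # correct predictions accumulated since last log
        self.total = 0     # total predictions accumulated since last log
        self.accs = {}     # global_step -> accuracy percentage

    def on_step_end(self, trainer, global_step, batch_idx, logs=None):
        # train_step_fn is expected to report "correct" and "total" per batch.
        if logs and "correct" in logs and "total" in logs:
            self.correct += logs["correct"]
            self.total += logs["total"]
        if global_step % self.log_step == 0 and self.total:
            self.accs[global_step] = 100.0 * self.correct / self.total
            # Assumed behaviour: counters reset after each logged accuracy.
            self.correct = 0
            self.total = 0

    def get_output(self):
        return self.accs


cb = AccuracySketch(log_step=2)
cb.on_step_end(None, global_step=1, batch_idx=0, logs={"correct": 8, "total": 10})
cb.on_step_end(None, global_step=2, batch_idx=1, logs={"correct": 9, "total": 10})
print(cb.get_output())  # {2: 85.0}  -> (8 + 9) / (10 + 10) * 100
```

With log_step=2, the accuracy over the first two batches (17 of 20 samples) is logged at global step 2.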

class banhxeo.train.callbacks.CheckpointCallback(save_epoch: bool = True)[source]

Bases: TrainerCallback

A callback that saves model checkpoints during training.

Saves checkpoints at specified step intervals (TrainerConfig.save_steps) and/or at the end of each epoch.

name: str = 'checkpoint'
__init__(save_epoch: bool = True)[source]

Initializes the CheckpointCallback.

Parameters:

save_epoch – Whether to save a checkpoint at the end of each epoch (default: True).

on_step_end(trainer: Trainer, global_step: int, batch_idx: int, logs: Dict[str, Any] | None = None) → None[source]

Saves a checkpoint when the save_steps interval is reached.

on_epoch_end(trainer: Trainer, epoch: int, logs: Dict[str, float] | None = None) → None[source]

Saves a checkpoint at the end of an epoch.
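The two trigger conditions above can be sketched as follows. This is an illustrative stand-in, not the library's code: save_steps is taken as a constructor argument here purely for the example, whereas the real callback reads TrainerConfig.save_steps from the trainer, and the actual saving is replaced by recording which events would trigger a checkpoint.

```python
class CheckpointSketch:
    """Illustrative stand-in for the documented CheckpointCallback triggers."""

    def __init__(self, save_steps=None, save_epoch=True):
        self.save_steps = save_steps  # the real callback reads this from TrainerConfig
        self.save_epoch = save_epoch
        self.saved_at = []  # records (trigger, value) instead of writing files

    def on_step_end(self, trainer, global_step, batch_idx, logs=None):
        # Step-interval trigger: every save_steps global steps.
        if self.save_steps and global_step % self.save_steps == 0:
            self.saved_at.append(("step", global_step))

    def on_epoch_end(self, trainer, epoch, logs=None):
        # Epoch trigger: at the end of each epoch when save_epoch is True.
        if self.save_epoch:
            self.saved_at.append(("epoch", epoch))


cb = CheckpointSketch(save_steps=2)
for step in range(1, 5):
    cb.on_step_end(None, global_step=step, batch_idx=step - 1)
cb.on_epoch_end(None, epoch=1)
print(cb.saved_at)  # [('step', 2), ('step', 4), ('epoch', 1)]
```

Setting save_epoch=False would leave only the step-interval trigger active, and omitting save_steps would leave only the per-epoch saves.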