banhxeo.train.callbacks module
- class banhxeo.train.callbacks.TrainerCallback
Bases: object
Abstract base class for trainer callbacks.
Callbacks allow custom actions to be performed at various stages of the training and evaluation process. Subclasses should override specific on_* methods to implement their desired behavior.
- Variables:
name (str) – A unique name for the callback, often used for logging or retrieving callback-specific outputs.
- name: str = 'base_callback'
- get_output() → Any
Returns any data accumulated or generated by the callback.
This method can be overridden by subclasses to provide access to callback-specific results (e.g., a list of accuracies, paths to saved models).
- Returns:
Any data the callback wishes to expose. Defaults to None.
- on_train_end(trainer: Trainer, logs: Dict[str, float] | None = None) → None
Called at the end of the train method.
- Parameters:
trainer – The Trainer instance.
logs – A dictionary of logs collected at the end of training, such as final average loss.
- on_epoch_begin(trainer: Trainer, epoch: int) → None
Called at the beginning of each training epoch.
- Parameters:
trainer – The Trainer instance.
epoch – The current epoch number (1-indexed).
- on_epoch_end(trainer: Trainer, epoch: int, logs: Dict[str, float] | None = None) → None
Called at the end of each training epoch.
- Parameters:
trainer – The Trainer instance.
epoch – The current epoch number.
logs – A dictionary of logs for the epoch, such as average epoch loss.
- on_step_begin(trainer: Trainer, global_step: int, batch_idx: int) → None
Called at the beginning of each training step (batch processing).
- Parameters:
trainer – The Trainer instance.
global_step – The total number of training steps performed so far. Steps are counted as optimizer updates if the trainer handles gradient accumulation, or as micro-batches if train_step_fn handles it.
batch_idx – The index of the current batch within the current epoch.
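The two counting conventions for global_step diverge as soon as more than one micro-batch feeds each optimizer update. The following is a minimal sketch of that difference, not banhxeo's implementation: count_steps, accum_steps, and count_optimizer_updates are hypothetical names, and the sketch assumes that batch_idx restarts at 0 every epoch.

```python
from typing import Iterator, Tuple


def count_steps(
    num_epochs: int,
    batches_per_epoch: int,
    accum_steps: int = 1,
    count_optimizer_updates: bool = True,
) -> Iterator[Tuple[int, int, int]]:
    """Yield (epoch, batch_idx, global_step) after each processed batch.

    Hypothetical helper, not part of banhxeo. Assumes epochs are 1-indexed,
    batch_idx restarts at 0 every epoch, and an optimizer update happens
    after every `accum_steps` micro-batches.
    """
    micro_batches = 0
    for epoch in range(1, num_epochs + 1):
        for batch_idx in range(batches_per_epoch):
            micro_batches += 1
            if count_optimizer_updates:
                # Convention 1: one step per completed optimizer update.
                global_step = micro_batches // accum_steps
            else:
                # Convention 2: one step per micro-batch.
                global_step = micro_batches
            yield epoch, batch_idx, global_step


by_updates = list(count_steps(2, 4, accum_steps=2, count_optimizer_updates=True))
by_micro = list(count_steps(2, 4, accum_steps=2, count_optimizer_updates=False))
print(by_updates[-1])  # (2, 3, 4)
print(by_micro[-1])    # (2, 3, 8)
```

With two micro-batches per update, the same final batch is reported as step 4 under the first convention and step 8 under the second, so a callback that acts every N global steps should know which convention is in effect.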
- on_step_end(trainer: Trainer, global_step: int, batch_idx: int, logs: Dict[str, Any] | None = None) → None
Called at the end of each training step.
- Parameters:
trainer – The Trainer instance.
global_step – The current global step number.
batch_idx – The index of the current batch.
logs – A dictionary of logs from the training step, typically including “loss” and any other metrics returned by train_step_fn.
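As an illustration of the subclassing pattern described for TrainerCallback, here is a minimal sketch of a custom callback that records one loss value per epoch. To keep it runnable without the library, it defines a stand-in base class that copies only the documented signatures. EpochLossHistory and the "loss" key are assumptions made for this example and are not part of banhxeo.

```python
from typing import Any, Dict, List, Optional


class TrainerCallback:
    """Stand-in for banhxeo.train.callbacks.TrainerCallback (assumption).

    Only the hooks used below are reproduced, with the documented signatures.
    """

    name: str = "base_callback"

    def get_output(self) -> Any:
        return None

    def on_epoch_end(
        self, trainer: Any, epoch: int, logs: Optional[Dict[str, float]] = None
    ) -> None:
        pass


class EpochLossHistory(TrainerCallback):
    """Hypothetical callback that records one loss value per epoch."""

    name = "epoch_loss_history"

    def __init__(self) -> None:
        self.history: List[float] = []

    def on_epoch_end(self, trainer, epoch, logs=None):
        # "loss" is an assumed key; the docs only say logs may hold
        # values such as the average epoch loss.
        if logs and "loss" in logs:
            self.history.append(logs["loss"])

    def get_output(self) -> List[float]:
        return self.history


# Drive the hooks by hand; a real Trainer would call them itself.
cb = EpochLossHistory()
cb.on_epoch_end(trainer=None, epoch=1, logs={"loss": 0.9})
cb.on_epoch_end(trainer=None, epoch=2, logs={"loss": 0.6})
cb.on_epoch_end(trainer=None, epoch=3, logs=None)
print(cb.get_output())  # [0.9, 0.6]
```

In real use the subclass would inherit from the library's TrainerCallback rather than the stand-in. How callbacks are registered with a Trainer is not covered in this section.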
- class banhxeo.train.callbacks.ProgressCallback
Bases: TrainerCallback
A callback that displays training progress using tqdm.
- name: str = 'progress'
- on_epoch_begin(trainer: Trainer, epoch: int) → None
Updates the progress bar description for the new epoch.
- class banhxeo.train.callbacks.AccuracyCallback(log_step: int = 100)
Bases: TrainerCallback
A callback to compute and log accuracy during training and evaluation.
Assumes that the logs dictionary passed to on_step_end (from train_step_fn) contains “correct” (the number of correctly classified samples in the batch) and “total” (the total number of samples in the batch). The metrics passed to on_evaluate are expected to contain the same keys.
- Variables:
log_step (int) – Frequency (in global steps) at which to log training accuracy.
correct (int) – Accumulated count of correct predictions since last log.
total (int) – Accumulated count of total predictions since last log.
accs (Dict[int, float]) – Dictionary storing training accuracies at logged steps. Key is global_step, value is accuracy percentage.
- name: str = 'accuracy'
- __init__(log_step: int = 100) → None
Initializes the AccuracyCallback.
- Parameters:
log_step – Log training accuracy every log_step global steps.
- on_step_end(trainer: Trainer, global_step: int, batch_idx: int, logs: Dict[str, Any] | None = None) → None
Accumulates correct/total counts from step logs and logs training accuracy.
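The accumulate-then-log behavior can be sketched as below. This is a hypothetical re-creation based only on the documented attributes, not the library's code. It assumes that accuracy is recorded whenever global_step is a multiple of log_step and that the counters are reset after each recording, which is one reading of "since last log".

```python
from typing import Any, Dict, Optional


class AccuracySketch:
    """Hypothetical re-creation of the documented behavior, not library code."""

    def __init__(self, log_step: int = 100) -> None:
        self.log_step = log_step
        self.correct = 0  # correct predictions since the last log
        self.total = 0    # total predictions since the last log
        self.accs: Dict[int, float] = {}  # global_step -> accuracy in percent

    def on_step_end(
        self,
        trainer: Any,
        global_step: int,
        batch_idx: int,
        logs: Optional[Dict[str, Any]] = None,
    ) -> None:
        if not logs:
            return
        self.correct += logs.get("correct", 0)
        self.total += logs.get("total", 0)
        # Assumed trigger: every `log_step` global steps, once data exists.
        if global_step % self.log_step == 0 and self.total > 0:
            self.accs[global_step] = 100.0 * self.correct / self.total
            # Assumed reset, so each entry covers only its own window.
            self.correct = 0
            self.total = 0


acc = AccuracySketch(log_step=2)
step_logs = [
    {"correct": 3, "total": 8},
    {"correct": 5, "total": 8},
    {"correct": 8, "total": 8},
    {"correct": 4, "total": 8},
]
for step, logs in enumerate(step_logs, start=1):
    acc.on_step_end(trainer=None, global_step=step, batch_idx=step - 1, logs=logs)
print(acc.accs)  # {2: 50.0, 4: 75.0}
```

A train_step_fn feeding this callback would need to return the "correct" and "total" counts alongside the loss, as the class description requires.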
- class banhxeo.train.callbacks.CheckpointCallback(save_epoch: bool = True)
Bases: TrainerCallback
A callback that saves model checkpoints during training.
Saves checkpoints at specified step intervals (TrainerConfig.save_steps) and/or at the end of each epoch.
- name: str = 'checkpoint'
- __init__(save_epoch: bool = True)
Initializes the CheckpointCallback.
- Parameters:
save_epoch – Whether to save a checkpoint at the end of each epoch. Defaults to True.
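The "and/or" combination of step-based and epoch-based saving can be read as two independent triggers. The sketch below expresses that reading as a plain function. should_save and at_epoch_end are hypothetical names, and treating a save_steps of None or 0 as "step-based saving disabled" is an assumption. The library's actual rule may differ.

```python
from typing import Optional


def should_save(
    save_epoch: bool,
    save_steps: Optional[int],
    global_step: int,
    at_epoch_end: bool,
) -> bool:
    """Hypothetical decision rule; banhxeo's actual logic may differ."""
    # Step-based trigger: every `save_steps` global steps, if configured.
    step_hit = bool(save_steps) and global_step > 0 and global_step % save_steps == 0
    # Epoch-based trigger: only when enabled and at an epoch boundary.
    epoch_hit = save_epoch and at_epoch_end
    return step_hit or epoch_hit


print(should_save(True, None, 7, at_epoch_end=True))       # True
print(should_save(False, None, 7, at_epoch_end=True))      # False
print(should_save(False, 500, 1000, at_epoch_end=False))   # True
print(should_save(False, 500, 750, at_epoch_end=False))    # False
```

Under this reading, passing save_epoch=False leaves TrainerConfig.save_steps as the only source of checkpoints.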