
Base Hyperparams Trainer

Base class for TRL hyperparameter trainers.

This module provides a common base class for all TRL hyperparameter trainer implementations, encapsulating shared functionality across SFT, GRPO, and DPO training approaches.

Authors

Alfan Dinda Rahmawan (alfan.d.rahmawan@gdplabs.id)

Reviewers

Muhammad Afif Al Hawari (muhammad.a.a.hawari@gdplabs.id)

References

NONE

BaseTRLHyperparamsTrainer

Bases: ABC

Base class for TRL hyperparameter trainers.

Provides common functionality for extracting and processing training metrics across all trainer types (SFT, GRPO, DPO). This class encapsulates shared methods for metrics extraction, GPU statistics retrieval, and training results processing.

Subclasses must implement the following abstract methods:
  • create_training_arguments: Create the trainer-specific configuration object
  • create_trainer: Create and configure the trainer instance
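A minimal sketch of the subclassing contract, assuming only what is documented here: the base derives from ABC and declares both hooks as abstract static methods. `MiniBaseTrainer`, `IncompleteTrainer`, `MySFTHyperparamsTrainer`, `can_instantiate`, and the dictionary return values are hypothetical stand-ins, not the library's real classes. Note that `abstractmethod` must be the innermost decorator when combined with `staticmethod`.

```python
from abc import ABC, abstractmethod
from typing import Any


class MiniBaseTrainer(ABC):
    """Hypothetical stand-in mirroring the documented abstract interface."""

    @staticmethod
    @abstractmethod
    def create_training_arguments(*args: Any, **kwargs: Any) -> Any:
        """Return a trainer-specific configuration object."""

    @staticmethod
    @abstractmethod
    def create_trainer(*args: Any, **kwargs: Any) -> Any:
        """Return a configured trainer instance."""


class IncompleteTrainer(MiniBaseTrainer):
    # Only one abstract method is overridden, so the class is still abstract.
    @staticmethod
    def create_training_arguments(*args: Any, **kwargs: Any) -> Any:
        return dict(kwargs)


class MySFTHyperparamsTrainer(MiniBaseTrainer):
    # Hypothetical concrete subclass: both abstract methods are overridden.
    @staticmethod
    def create_training_arguments(*args: Any, **kwargs: Any) -> Any:
        return {"kind": "sft-config", **kwargs}

    @staticmethod
    def create_trainer(*args: Any, **kwargs: Any) -> Any:
        return {"kind": "sft-trainer", **kwargs}


def can_instantiate(cls: type) -> bool:
    """Return True if the ABC machinery allows creating an instance of cls."""
    try:
        cls()
    except TypeError:
        return False
    return True
```

Because both hooks are static, they can be called directly on the subclass (for example `MySFTHyperparamsTrainer.create_training_arguments(learning_rate=1e-5)`); the ABC check only takes effect when something tries to instantiate a class that still has unimplemented hooks.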

create_trainer(*args, **kwargs) abstractmethod staticmethod

Create and configure a trainer instance.

Subclasses must implement this method to instantiate the appropriate trainer class with the provided configuration and datasets.

Parameters:

  • *args (Any): Positional arguments (trainer-specific). Default: ()
  • **kwargs (Any): Keyword arguments (trainer-specific). Default: {}

Returns:

  • Any: Configured trainer instance.
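Since the base signature is a generic `*args, **kwargs`, a subclass can narrow it to explicit, trainer-specific parameters. The sketch below assumes a model, a configuration object, and datasets are the inputs; `StubTrainer`, `DemoHyperparamsTrainer`, and the empty-dataset check are hypothetical illustrations, not the library's actual trainer classes or behavior. In real code the class would subclass BaseTRLHyperparamsTrainer and return a TRL trainer.

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass
class StubTrainer:
    """Hypothetical stand-in for a TRL trainer (SFT, DPO, or GRPO)."""

    model: Any
    args: Any
    train_dataset: Sequence[Any]
    eval_dataset: Optional[Sequence[Any]] = None


class DemoHyperparamsTrainer:
    # The base class is omitted so this sketch runs without the library installed.

    @staticmethod
    def create_trainer(
        model: Any,
        training_args: Any,
        train_dataset: Sequence[Any],
        eval_dataset: Optional[Sequence[Any]] = None,
    ) -> StubTrainer:
        """Instantiate the trainer with explicit, trainer-specific arguments."""
        if len(train_dataset) == 0:
            # Design choice of this sketch: fail fast rather than at training time.
            raise ValueError("train_dataset must not be empty")
        return StubTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
```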

create_training_arguments(*args, **kwargs) abstractmethod staticmethod

Create training configuration for the trainer.

Subclasses must implement this method to create trainer-specific configuration objects (e.g., DPOConfig, GRPOConfig, HfTrainingArguments).

Parameters:

  • *args (Any): Positional arguments (trainer-specific). Default: ()
  • **kwargs (Any): Keyword arguments (trainer-specific). Default: {}

Returns:

  • Any: Trainer-specific configuration object.
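One way a subclass might build its configuration is to start from defaults and apply hyperparameter overrides, rejecting unknown field names. This is a sketch under stated assumptions: `DemoConfig` is a hypothetical dataclass standing in for DPOConfig, GRPOConfig, or HfTrainingArguments, and the field names and defaults are illustrative only.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class DemoConfig:
    """Hypothetical config standing in for DPOConfig / GRPOConfig / HfTrainingArguments."""

    output_dir: str = "outputs"
    learning_rate: float = 5e-5
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 8


class DemoHyperparamsTrainer:
    @staticmethod
    def create_training_arguments(overrides: Optional[Dict[str, Any]] = None) -> DemoConfig:
        """Build a config from the defaults plus hyperparameter overrides."""
        overrides = dict(overrides or {})
        known = {f.name for f in dataclasses.fields(DemoConfig)}
        unknown = sorted(set(overrides) - known)
        if unknown:
            # Reject typos such as "learning_rat" instead of silently ignoring them.
            raise ValueError(f"unknown config fields: {unknown}")
        # replace() returns a new frozen instance with only the overridden fields changed.
        return dataclasses.replace(DemoConfig(), **overrides)
```

Validating field names up front keeps a mistyped hyperparameter in a search space from producing a run that silently used the default value.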

process_training_results(trainer, training_stats_obj) staticmethod

Process training results to extract metrics.

Processes the trainer and training statistics objects to extract and aggregate training metrics, GPU statistics, and final loss values. Loss values are read from the trainer's log history first; if they are not found there, the method falls back to the metrics object.

Parameters:

  • trainer (Any, required): The trainer instance used for training. Must have a state attribute with log_history (a list of log dicts).
  • training_stats_obj (Any, required): Training statistics object returned by the trainer (e.g., the output of trainer.train()). Must have a metrics attribute.

Returns:

  • Dict[str, Any]: Dictionary containing the processed training results, with keys:
      • training_metrics: Extracted training metrics
      • gpu_stats: GPU statistics
      • train_loss: Final training loss value
      • val_loss: Final validation/evaluation loss value
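The lookup order described above (log history first, then the metrics object) can be sketched with duck-typed inputs that satisfy the documented attributes. Assumptions: the log keys `loss` and `eval_loss` and the metrics key `train_loss` follow Hugging Face Trainer logging conventions and may not match what the real implementation reads; `get_gpu_stats` and `_last_value` are hypothetical helpers. Only the four returned keys come from this page.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Optional


def _last_value(log_history: List[Dict[str, Any]], key: str) -> Optional[float]:
    """Return the most recent value logged under `key`, or None if absent."""
    for entry in reversed(log_history):
        if key in entry:
            return entry[key]
    return None


def get_gpu_stats() -> Dict[str, Any]:
    """Hypothetical GPU helper; returns an empty dict when CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return {}
    if not torch.cuda.is_available():
        return {}
    return {
        "device_name": torch.cuda.get_device_name(0),
        "max_memory_reserved_gb": round(torch.cuda.max_memory_reserved() / 1024**3, 3),
    }


def process_training_results(trainer: Any, training_stats_obj: Any) -> Dict[str, Any]:
    """Sketch of the documented lookup order: log history first, then metrics."""
    log_history = trainer.state.log_history
    metrics = dict(training_stats_obj.metrics)

    train_loss = _last_value(log_history, "loss")
    if train_loss is None:
        train_loss = metrics.get("train_loss")  # fallback to the metrics object

    val_loss = _last_value(log_history, "eval_loss")
    if val_loss is None:
        val_loss = metrics.get("eval_loss")

    return {
        "training_metrics": metrics,
        "gpu_stats": get_gpu_stats(),
        "train_loss": train_loss,
        "val_loss": val_loss,
    }


# Usage with duck-typed stand-ins that expose the documented attributes.
demo_trainer = SimpleNamespace(
    state=SimpleNamespace(
        log_history=[
            {"loss": 1.2, "step": 10},
            {"eval_loss": 1.1, "step": 10},
            {"loss": 0.9, "step": 20},
        ]
    )
)
demo_stats = SimpleNamespace(metrics={"train_loss": 1.05, "train_runtime": 12.3})
demo_result = process_training_results(demo_trainer, demo_stats)
# demo_result["train_loss"] → 0.9 (latest "loss" entry); demo_result["val_loss"] → 1.1
```

Scanning the log history from the end picks the most recent logged value; when the history has no matching entry, the metrics object supplies the value instead.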