Base Hyperparams Trainer
Base class for TRL hyperparameter trainers.
This module provides a common base class for all TRL hyperparameter trainer implementations, encapsulating shared functionality across SFT, GRPO, and DPO training approaches.
Reviewers
Muhammad Afif Al Hawari (muhammad.a.a.hawari@gdplabs.id)
References
NONE
BaseTRLHyperparamsTrainer
Bases: ABC
Base class for TRL hyperparameter trainers.
Provides common functionality for extracting and processing training metrics across all trainer types (SFT, GRPO, DPO). This class encapsulates shared methods for metrics extraction, GPU statistics retrieval, and training results processing.
Subclasses must implement the following abstract methods:
- create_training_arguments: Create trainer-specific config
- create_trainer: Create and configure trainer instance
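A minimal sketch of how this abstract-method contract is enforced. It uses a stand-in base class instead of importing the real one, because the module path is not given here. `ExampleBase`, `IncompleteTrainer`, and `CompleteTrainer` are hypothetical names. When `abstractmethod` is combined with `staticmethod`, the Python `abc` documentation requires `abstractmethod` to be the innermost decorator.

```python
from abc import ABC, abstractmethod
from typing import Any


class ExampleBase(ABC):
    """Hypothetical stand-in mirroring the abstract interface of BaseTRLHyperparamsTrainer."""

    @staticmethod
    @abstractmethod  # abstractmethod must be the innermost decorator
    def create_training_arguments(*args: Any, **kwargs: Any) -> Any:
        ...

    @staticmethod
    @abstractmethod
    def create_trainer(*args: Any, **kwargs: Any) -> Any:
        ...


class IncompleteTrainer(ExampleBase):
    # Implements only one of the two abstract methods, so it stays abstract.
    @staticmethod
    def create_training_arguments(*args: Any, **kwargs: Any) -> Any:
        return dict(kwargs)


class CompleteTrainer(IncompleteTrainer):
    # Supplies the remaining abstract method, so instances can be created.
    @staticmethod
    def create_trainer(*args: Any, **kwargs: Any) -> Any:
        return ("trainer", args, kwargs)


try:
    IncompleteTrainer()
except TypeError as exc:
    print("cannot instantiate:", type(exc).__name__)  # cannot instantiate: TypeError

print(CompleteTrainer.create_training_arguments(lr=1e-5))  # {'lr': 1e-05}
```

Because both methods are static, they can be called directly on the subclass without creating an instance. The abstract check only runs when someone tries to instantiate the class.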
create_trainer(*args, **kwargs)
abstractmethod
staticmethod
Create and configure a trainer instance.
Subclasses must implement this method to instantiate the appropriate trainer class with the provided configuration and datasets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | Any | Positional arguments (trainer-specific). | `()` |
| `**kwargs` | Any | Keyword arguments (trainer-specific). | `{}` |
Returns:
| Type | Description |
|---|---|
| Any | Configured trainer instance. |
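For illustration only, a subclass might implement `create_trainer` by passing a config object and the datasets into a trainer class. The sketch below assumes hypothetical `ToyConfig` and `ToyTrainer` classes in place of a real TRL trainer such as `DPOTrainer`, so it runs without TRL installed. The parameter names (`model`, `training_args`, `train_dataset`, `eval_dataset`) are assumptions and not the documented signature. In real code the method would live on a subclass of BaseTRLHyperparamsTrainer.

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass
class ToyConfig:
    # Hypothetical stand-in for a TRL config object such as DPOConfig.
    learning_rate: float = 5e-5
    num_train_epochs: int = 1


class ToyTrainer:
    # Hypothetical stand-in for a TRL trainer class.
    def __init__(self, model: Any, args: ToyConfig,
                 train_dataset: Sequence[Any],
                 eval_dataset: Optional[Sequence[Any]] = None) -> None:
        self.model = model
        self.args = args
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset


class ToySFTHyperparamsTrainer:
    # Hypothetical subclass; only the factory method is shown.
    @staticmethod
    def create_trainer(model: Any, training_args: ToyConfig,
                       train_dataset: Sequence[Any],
                       eval_dataset: Optional[Sequence[Any]] = None) -> ToyTrainer:
        # Static factory: everything the trainer needs is passed in explicitly.
        return ToyTrainer(model=model, args=training_args,
                          train_dataset=train_dataset, eval_dataset=eval_dataset)


trainer = ToySFTHyperparamsTrainer.create_trainer(
    "tiny-model", ToyConfig(learning_rate=1e-4), ["a", "b"], ["c"])
print(trainer.args.learning_rate)  # 0.0001
```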
create_training_arguments(*args, **kwargs)
abstractmethod
staticmethod
Create training configuration for the trainer.
Subclasses must implement this method to create trainer-specific configuration objects (e.g., DPOConfig, GRPOConfig, HfTrainingArguments).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | Any | Positional arguments (trainer-specific). | `()` |
| `**kwargs` | Any | Keyword arguments (trainer-specific). | `{}` |
Returns:
| Type | Description |
|---|---|
| Any | Trainer-specific configuration object. |
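One plausible pattern for a hyperparameter trainer is for `create_training_arguments` to build the config from a dictionary of hyperparameters and keep only the keys that the config class accepts. This page does not confirm that pattern, so treat it as an assumption. The sketch uses a hypothetical `ToyDPOConfig` dataclass in place of the real `DPOConfig`. It relies on `dataclasses.fields` to list the accepted field names.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class ToyDPOConfig:
    # Hypothetical stand-in; the real DPOConfig has many more fields.
    output_dir: str = "outputs"
    learning_rate: float = 5e-7
    beta: float = 0.1


def create_training_arguments(hyperparams: Dict[str, Any]) -> ToyDPOConfig:
    # Keep only keys that are declared fields of the config dataclass.
    allowed = {f.name for f in fields(ToyDPOConfig)}
    accepted = {k: v for k, v in hyperparams.items() if k in allowed}
    return ToyDPOConfig(**accepted)


config = create_training_arguments({"learning_rate": 1e-6, "beta": 0.2, "unused_key": 3})
print(config)  # ToyDPOConfig(output_dir='outputs', learning_rate=1e-06, beta=0.2)
```

Silently dropping unknown keys keeps a shared hyperparameter dictionary usable across SFT, GRPO, and DPO configs. The trade-off is that it can hide typos in key names. A stricter variant would raise an error on any unaccepted key.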
process_training_results(trainer, training_stats_obj)
staticmethod
Process training results to extract metrics.
Processes the trainer and training statistics objects to extract and aggregate training metrics, GPU statistics, and final loss values. Attempts to retrieve loss values from log history first, then falls back to the metrics object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `trainer` | Any | The trainer instance used for training. Must have a | required |
| `training_stats_obj` | Any | Training statistics object returned by the trainer. Must have a | required |
Returns:
| Type | Description |
|---|---|
| Dict[str, Any] | Dictionary containing processed training results with keys: `training_metrics` (extracted training metrics), `gpu_stats` (GPU statistics), `train_loss` (final training loss value), and `val_loss` (final validation/evaluation loss value). |
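The documented behavior can be sketched as follows. The method reads the losses from the log history first, falls back to the metrics object, and returns the four keys listed above. This sketch assumes a Hugging Face-style trainer with the following properties:

- `state.log_history` is a list of dicts containing `loss` and `eval_loss` entries.
- The training result has a `metrics` dict that may hold `train_loss` and `eval_loss`.

`SimpleNamespace` objects stand in for the real trainer and result. `get_gpu_stats` is a hypothetical placeholder for the class's GPU-statistics helper, whose output is not documented here.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Optional


def last_value(log_history: List[Dict[str, Any]], key: str) -> Optional[float]:
    # Walk backwards so the most recently logged value wins.
    for entry in reversed(log_history):
        if key in entry:
            return entry[key]
    return None


def get_gpu_stats() -> Dict[str, Any]:
    # Hypothetical placeholder for the real GPU-statistics helper.
    return {}


def process_training_results(trainer: Any, training_stats_obj: Any) -> Dict[str, Any]:
    metrics = dict(getattr(training_stats_obj, "metrics", {}) or {})
    log_history = getattr(trainer.state, "log_history", [])

    # Prefer the log history; fall back to the metrics object.
    train_loss = last_value(log_history, "loss")
    if train_loss is None:
        train_loss = metrics.get("train_loss")

    val_loss = last_value(log_history, "eval_loss")
    if val_loss is None:
        val_loss = metrics.get("eval_loss")

    return {
        "training_metrics": metrics,
        "gpu_stats": get_gpu_stats(),
        "train_loss": train_loss,
        "val_loss": val_loss,
    }


trainer = SimpleNamespace(state=SimpleNamespace(log_history=[
    {"loss": 0.9, "step": 10},
    {"eval_loss": 0.8, "step": 10},
    {"loss": 0.7, "step": 20},
]))
stats = SimpleNamespace(metrics={"train_loss": 0.8, "train_runtime": 12.3})
results = process_training_results(trainer, stats)
print(results["train_loss"], results["val_loss"])  # 0.7 0.8
```

With an empty log history, the same call would return `train_loss` from `metrics`. It would return `val_loss` as `None` because no evaluation loss was recorded anywhere.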