# Trainer

> Training infrastructure for RL-based agent learning

The trainer module provides high-level APIs for training agents with reinforcement learning, using PPO-style policy optimization (GRPO is selected via `algorithm.adv_estimator`).

## AgentTrainer

Wrapper class that trains an agent, defined either as a workflow or as a plain rollout function, on one of three backends: `verl`, `fireworks`, or `tinker`.

```python theme={null}
from rllm.trainer import AgentTrainer
```

### Constructor

```python theme={null}
def __init__(
    self,
    workflow_class: type | None = None,
    workflow_args: dict[str, Any] | None = None,
    config: dict[str, Any] | list[str] | None = None,
    train_dataset: Dataset | None = None,
    val_dataset: Dataset | None = None,
    backend: Literal["verl", "fireworks", "tinker"] = "verl",
    agent_run_func: Callable | None = None
)
```

There are two ways to plug in your agent: pass either `workflow_class` (a `Workflow` subclass) **or** `agent_run_func` (a plain rollout function for the AgentSdk path), but not both.

<ParamField path="workflow_class" type="type | None">
  Workflow class to use for training (e.g., `SimpleWorkflow`, `MultiTurnWorkflow`).
</ParamField>

<ParamField path="workflow_args" type="dict | None">
  Keyword arguments passed to the workflow class (for example, `reward_function` for `SimpleWorkflow`).
</ParamField>

<ParamField path="config" type="dict | list[str] | None">
  Configuration overrides. Can be:

  * Dictionary with dot notation keys: `{"data.train_batch_size": 8}`
  * List of strings: `["data.train_batch_size=8", "trainer.total_epochs=3"]`
</ParamField>
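Both accepted forms carry the same information. As a minimal sketch (the helper `to_override_list` is hypothetical and not part of rLLM), a dot-notation dict can be flattened into Hydra-style `key=value` strings like this:

```python theme={null}
from typing import Any


def to_override_list(overrides: dict[str, Any]) -> list[str]:
    """Hypothetical helper: turn {"a.b": 1} into ["a.b=1"].

    Booleans are lowercased to match YAML/Hydra literals; every other
    value uses str(). Values containing spaces or special characters
    would need quoting, which this sketch does not handle.
    """
    items = []
    for key, value in overrides.items():
        if isinstance(value, bool):
            text = "true" if value else "false"
        else:
            text = str(value)
        items.append(f"{key}={text}")
    return items


print(to_override_list({"data.train_batch_size": 8, "trainer.total_epochs": 3}))
# ['data.train_batch_size=8', 'trainer.total_epochs=3']
```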

<ParamField path="train_dataset" type="Dataset | None">
  Training dataset.
</ParamField>

<ParamField path="val_dataset" type="Dataset | None">
  Validation dataset.
</ParamField>

<ParamField path="backend" type="Literal['verl', 'fireworks', 'tinker']" default="verl">
  Training backend:

  * `"verl"`: Standard distributed PPO via the verl framework
  * `"fireworks"`: Pipeline-based variant (workflow-only) for the Fireworks workflow API
  * `"tinker"`: Single-machine LoRA training via tinker (workflow-only)
</ParamField>

<ParamField path="agent_run_func" type="Callable | None">
  Plain rollout function — drives the AgentSdk path. Use this **or**
  `workflow_class`, not both.
</ParamField>
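The documented rules (use `workflow_class` or `agent_run_func`, not both; `fireworks` and `tinker` are workflow-only) can be summarized as a small check. This is a minimal sketch of those constraints, not rLLM's actual validation code: `check_trainer_args` is a hypothetical name, treating "neither argument passed" as an error is an assumption, and the real trainer may raise different exception types.

```python theme={null}
from typing import Callable, Optional

VALID_BACKENDS = {"verl", "fireworks", "tinker"}
WORKFLOW_ONLY_BACKENDS = {"fireworks", "tinker"}


def check_trainer_args(
    backend: str = "verl",
    workflow_class: Optional[type] = None,
    agent_run_func: Optional[Callable] = None,
) -> str:
    """Hypothetical check mirroring the documented AgentTrainer rules.

    Returns which path would be used: "workflow" or "agent_sdk".
    """
    if backend not in VALID_BACKENDS:
        raise ValueError(f"unknown backend: {backend!r}")
    # Assumption: exactly one of the two entry points must be provided.
    if (workflow_class is None) == (agent_run_func is None):
        raise ValueError("pass exactly one of workflow_class or agent_run_func")
    if agent_run_func is not None:
        # The AgentSdk path is not available on workflow-only backends.
        if backend in WORKFLOW_ONLY_BACKENDS:
            raise ValueError(f"backend {backend!r} is workflow-only")
        return "agent_sdk"
    return "workflow"
```

For example, `check_trainer_args("tinker", agent_run_func=my_fn)` raises, because `tinker` only supports the workflow path.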

<Note>
  The legacy `agent_class` + `env_class` parameters that drove the
  `AgentExecutionEngine` rollout have been removed. Port your agent to
  either a [Workflow](/api/workflows) or an
  [AgentFlow](/core-concepts/agentflow-evaluator) — see the
  [`cookbooks/`](https://github.com/rllm-org/rllm/tree/main/cookbooks)
  directory for examples.
</Note>

### Methods

#### train

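Starts training on the configured backend, using the workflow (or rollout function), datasets, and config passed to the constructor.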

```python theme={null}
trainer.train()
```

***

## Configuration

The trainer uses Hydra for configuration management. The default config lives at `rllm/trainer/config/agent_ppo_trainer.yaml`.

### Common Config Overrides

```python theme={null}
config = {
    # Data settings
    "data.train_batch_size": 512,
    "data.val_batch_size": 1024,
    
    # Training settings
    "trainer.total_epochs": 3,
    "trainer.total_training_steps": 1000,
    
    # PPO hyperparameters
    "algorithm.gamma": 1.0,
    "algorithm.lam": 0.95,
    "algorithm.kl_penalty": 0.001,
    
    # Model settings
    "actor_rollout_ref.model.path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    
    # GRPO settings
    "algorithm.adv_estimator": "grpo",
    "algorithm.num_samples_per_prompt": 4,
}
```
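For intuition about what `gamma`, `lam`, and the GRPO settings control, here is a minimal pure-Python sketch of two textbook advantage estimators: generalized advantage estimation (GAE), which uses `gamma` and `lam`, and a group-relative estimator in the style of GRPO, which normalizes each sample's reward against the other samples drawn for the same prompt. This is not verl's implementation: verl works on batched token-level tensors with masks, and its exact normalization (epsilon, sample vs. population standard deviation) may differ.

```python theme={null}
import statistics


def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    """Generalized advantage estimation for a single trajectory.

    rewards[t] is the reward received at step t and values[t] is the
    critic's estimate V(s_t). The state after the last step is treated
    as terminal, so its value is 0.
    """
    advantages = [0.0] * len(rewards)
    next_value = 0.0
    running = 0.0
    for t in reversed(range(len(rewards))):
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value - values[t]
        # Recursive form: A_t = delta_t + gamma * lam * A_{t+1}
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    return advantages


def group_relative_advantages(group_rewards, eps=1e-6):
    """GRPO-style advantages for the samples drawn for one prompt.

    Each reward is centered on the group mean and divided by the group
    standard deviation (population std in this sketch).
    """
    mean = statistics.fmean(group_rewards)
    std = statistics.pstdev(group_rewards)
    return [(r - mean) / (std + eps) for r in group_rewards]


# Sparse reward at the end of a 3-step trajectory, constant value estimates.
print([round(a, 5) for a in gae_advantages([0.0, 0.0, 1.0], [0.5, 0.5, 0.5])])
# [0.45125, 0.475, 0.5]

# Four samples for one prompt with binary correctness rewards.
print([round(a, 3) for a in group_relative_advantages([1.0, 0.0, 0.0, 1.0])])
# [1.0, -1.0, -1.0, 1.0]
```

With `lam` below 1, credit for the final reward decays geometrically toward earlier steps; in the group-relative case, a correct sample only earns a positive advantage relative to incorrect samples for the same prompt.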

***

## Example: Training with SimpleWorkflow

```python theme={null}
import hydra
from rllm.trainer import AgentTrainer
from rllm.workflows import SimpleWorkflow
from rllm.rewards import math_reward_fn
from rllm.data import DatasetRegistry

@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None
)
def main(config):
    # Load datasets
    train_dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
    val_dataset = DatasetRegistry.load_dataset("math500", "test")
    
    # Create trainer
    trainer = AgentTrainer(
        workflow_class=SimpleWorkflow,
        workflow_args={
            "reward_function": math_reward_fn,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        backend="verl"
    )
    
    # Start training
    trainer.train()

if __name__ == "__main__":
    main()
```
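As its name suggests, `math_reward_fn` grades math answers. To illustrate the idea only, here is a self-contained toy reward in that spirit: it extracts the last `\boxed{...}` answer from a response and compares it to the ground truth after stripping whitespace. The names `extract_boxed` and `toy_math_reward`, the `(response, ground_truth)` signature, and the float return value are all hypothetical; rLLM's `math_reward_fn` has its own signature and return type and performs much more thorough answer checking.

```python theme={null}
import re
from typing import Optional


def extract_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, or None.

    Counts brace depth so nested braces such as \\frac{1}{2} survive.
    """
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = start + len("\\boxed{")
    depth = 1
    for j in range(i, len(text)):
        if text[j] == "{":
            depth += 1
        elif text[j] == "}":
            depth -= 1
            if depth == 0:
                return text[i:j]
    return None  # unbalanced braces


def _strip_ws(s: str) -> str:
    return re.sub(r"\s+", "", s)


def toy_math_reward(response: str, ground_truth: str) -> float:
    """Hypothetical reward: 1.0 if the boxed answer matches, else 0.0."""
    answer = extract_boxed(response)
    if answer is None:
        return 0.0
    return 1.0 if _strip_ws(answer) == _strip_ws(ground_truth) else 0.0


print(toy_math_reward(r"So the area is \boxed{\frac{1}{2}}.", r"\frac{1}{2}"))  # 1.0
```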

***

## Example: Config Overrides

```python theme={null}
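import hydra
from rllm.trainer import AgentTrainer
from rllm.workflows import MultiTurnWorkflow
from rllm.data import DatasetRegistry

# Hypothetical module: import your own BaseAgent / BaseEnv subclasses here.
from my_project import MyAgent, MyEnv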

@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None
)
def main(config):
    # Override config values
    config_overrides = {
        "data.train_batch_size": 256,
        "data.val_batch_size": 512,
        "trainer.total_epochs": 5,
        "algorithm.gamma": 1.0,
        "algorithm.num_samples_per_prompt": 8,
        "actor_rollout_ref.model.path": "Qwen/Qwen3-4B",
    }
    
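    # Apply overrides: walk the dotted path, then set the leaf value.
    # Hydra configs are in struct mode, so an unknown key raises an error
    # instead of being created silently.
    for key, value in config_overrides.items():
        *parents, leaf = key.split(".")
        cfg = config
        for k in parents:
            cfg = getattr(cfg, k)
        setattr(cfg, leaf, value)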
    
    trainer = AgentTrainer(
        workflow_class=MultiTurnWorkflow,
        workflow_args={
            # Substitute your own BaseAgent / BaseEnv subclasses.
            "agent_cls": MyAgent,
            "env_cls": MyEnv,
            "max_steps": 5,
        },
        config=config,
        train_dataset=DatasetRegistry.load_dataset("mydata", "train"),
        val_dataset=DatasetRegistry.load_dataset("mydata", "val"),
        backend="verl",
    )
    
    trainer.train()

if __name__ == "__main__":
    main()
```
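The override loop above walks a dotted path and assigns the leaf value. The same idea on a plain nested `dict` looks like this. It is a stdlib-only sketch with a hypothetical helper name (`set_dotted`) that imitates Hydra's struct mode, where assigning to a key that is not already in the config raises an error rather than silently creating it.

```python theme={null}
from typing import Any


def set_dotted(config: dict, dotted_key: str, value: Any) -> None:
    """Set config["a"]["b"]["c"] = value for dotted_key "a.b.c".

    Like a Hydra config in struct mode, unknown keys raise KeyError
    instead of being created.
    """
    *parents, leaf = dotted_key.split(".")
    node = config
    for part in parents:
        if part not in node or not isinstance(node[part], dict):
            raise KeyError(f"{dotted_key!r}: no config group {part!r}")
        node = node[part]
    if leaf not in node:
        raise KeyError(f"{dotted_key!r}: unknown key {leaf!r}")
    node[leaf] = value


config = {"data": {"train_batch_size": 512}, "trainer": {"total_epochs": 3}}
set_dotted(config, "trainer.total_epochs", 5)
print(config["trainer"]["total_epochs"])  # 5
```

Failing loudly on unknown keys catches typos such as `trainer.total_epoch` before a long training run starts.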

***

## Running Training

Run training scripts with Hydra CLI overrides:

```bash theme={null}
# Basic training
python train.py

# Override config from command line
python train.py data.train_batch_size=512 trainer.total_epochs=5

# Use different model
python train.py actor_rollout_ref.model.path=Qwen/Qwen3-4B

# Adjust PPO hyperparameters
python train.py algorithm.gamma=1.0 algorithm.num_samples_per_prompt=8
```
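Each command-line argument above is a `key=value` override. As a rough, hypothetical sketch of how such strings map back onto the dot-notation dict form accepted by `AgentTrainer(config=...)`, the helper below splits on the first `=` and converts simple literals. Hydra's real override grammar is much richer (quoting, lists, `+key=` additions, `~key` deletions), so treat this only as an illustration.

```python theme={null}
def parse_override(arg: str) -> tuple[str, object]:
    """Split "a.b=value" and convert value to bool/int/float when possible."""
    key, sep, raw = arg.partition("=")
    if not sep or not key:
        raise ValueError(f"expected key=value, got {arg!r}")
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return key, lowered == "true"
    for cast in (int, float):
        try:
            return key, cast(raw)
        except ValueError:
            pass
    return key, raw  # fall back to the raw string


args = ["data.train_batch_size=512", "actor_rollout_ref.model.path=Qwen/Qwen3-4B"]
print(dict(parse_override(a) for a in args))
# {'data.train_batch_size': 512, 'actor_rollout_ref.model.path': 'Qwen/Qwen3-4B'}
```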
