75 changes: 75 additions & 0 deletions examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
@@ -0,0 +1,75 @@
project: "Trinity-RFT-RAFT-alfworld"
name: "qwen2.5-7B-RAFT-alfworld"
mode: both
checkpoint_root_dir: /PATH/TO/CHECKPOINT/RAFT_ALFWORLD/
algorithm:
  algorithm_type: raft
  repeat_times: 1
model:
  model_path: /PATH/TO/MODEL/
  max_response_tokens: 4096
  max_model_len: 20480
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 30
  batch_size: 80
  max_retry_times: 1
  max_retry_interval: 1
  explorer_input:
    taskset:
      name: alfworld-train
      storage_type: file
      path: '/PATH/TO/ALFWORLD_DATA/'
      split: 'train'
      format:
        prompt_key: 'game_file'
        response_key: 'task_desc'
      rollout_args:
        temperature: 0.85
        top_k: 25
        top_p: 0.95
    eval_tasksets:
      - name: alfworld-eval
        storage_type: file
        path: '/PATH/TO/ALFWORLD_DATA/'
        split: 'test'
        format:
          prompt_key: 'game_file'
          response_key: 'task_desc'
        rollout_args:
          temperature: 0.85
          top_k: 25
          top_p: 0.95
    default_workflow_type: 'RAFT_alfworld_workflow'
  trainer_input:
    experience_buffer:
      name: RAFT_buffer
      storage_type: queue
      path: 'sqlite:///RAFT_alfworld.db'
explorer:
  eval_interval: 30
  runner_num: 100
  max_timeout: 3000  # Increased timeout for alfworld environment interactions
  max_retry_times: 2
  rollout_model:
    engine_type: vllm_async
    engine_num: 4
    tensor_parallel_size: 1
    enable_prefix_caching: true
    enforce_eager: false
    dtype: bfloat16
    gpu_memory_utilization: 0.86
    seed: 42
synchronizer:
  sync_style: dynamic_by_explorer
  sync_method: 'nccl'
  sync_interval: 4
  sync_timeout: 1200
trainer:
  trainer_type: 'verl'
  trainer_config_path: 'examples/RAFT_alfworld/train_alfworld.yaml'
  save_interval: 100000
monitor:
  monitor_type: wandb
75 changes: 75 additions & 0 deletions examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml
@@ -0,0 +1,75 @@
project: "Trinity-RFT-RAFT-reflect-alfworld"
name: "qwen2.5-7B-RAFT-reflect-alfworld"
mode: both
checkpoint_root_dir: /PATH/TO/CHECKPOINT/RAFT_REFLECT_ALFWORLD/
algorithm:
  algorithm_type: raft
  repeat_times: 1
model:
  model_path: /PATH/TO/MODEL/
  max_response_tokens: 4096
  max_model_len: 20480
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 30
  batch_size: 80
  max_retry_times: 1
  max_retry_interval: 1
  explorer_input:
    taskset:
      name: alfworld-train
      storage_type: file
      path: '/PATH/TO/ALFWORLD_DATA/'
      split: 'train'
      format:
        prompt_key: 'game_file'
        response_key: 'task_desc'
      rollout_args:
        temperature: 0.85
        top_k: 25
        top_p: 0.95
    eval_tasksets:
      - name: alfworld-eval
        storage_type: file
        path: '/PATH/TO/ALFWORLD_DATA/'
        split: 'test'
        format:
          prompt_key: 'game_file'
          response_key: 'task_desc'
        rollout_args:
          temperature: 0.85
          top_k: 25
          top_p: 0.95
    default_workflow_type: 'RAFT_reflect_alfworld_workflow'
  trainer_input:
    experience_buffer:
      name: RAFT_buffer
      storage_type: queue
      path: 'sqlite:///RAFT_reflect_alfworld.db'
explorer:
  eval_interval: 30
  runner_num: 100
  max_timeout: 3000  # Increased timeout for alfworld environment interactions
  max_retry_times: 2
  rollout_model:
    engine_type: vllm_async
    engine_num: 4
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: false
    dtype: bfloat16
    gpu_memory_utilization: 0.86
    seed: 42
synchronizer:
  sync_style: dynamic_by_explorer
  sync_method: 'nccl'
  sync_interval: 4
  sync_timeout: 1200
trainer:
  trainer_type: 'verl'
  trainer_config_path: 'examples/RAFT_alfworld/train_alfworld.yaml'
  save_interval: 100000
monitor:
  monitor_type: wandb
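Apart from the project/run names and the checkpoint directory, the reflection config above differs from the standard one in only three functional keys (derived by comparing the two files):

```diff
-    default_workflow_type: 'RAFT_alfworld_workflow'
+    default_workflow_type: 'RAFT_reflect_alfworld_workflow'
-      path: 'sqlite:///RAFT_alfworld.db'
+      path: 'sqlite:///RAFT_reflect_alfworld.db'
-    enable_prefix_caching: true
+    enable_prefix_caching: false
```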
33 changes: 33 additions & 0 deletions examples/RAFT_alfworld/README.md
@@ -0,0 +1,33 @@
# RAFT on ALFWorld Dataset

This example demonstrates RAFT (reward-ranked fine-tuning) on the ALFWorld dataset, in both a standard and a reflection-enhanced variant.

![RAFT ALFWorld Reward Curve](../../docs/sphinx_doc/assets/RAFT_alfworld_reward_curve.png)

## Variants

### Standard RAFT
The config files are located in [`RAFT_alfworld_7B.yaml`](RAFT_alfworld_7B.yaml) and [`train_alfworld.yaml`](train_alfworld.yaml).

### RAFT with Reflection
The config files are located in [`RAFT_reflect_alfworld_7B.yaml`](RAFT_reflect_alfworld_7B.yaml) and [`train_alfworld.yaml`](train_alfworld.yaml).

## Setup

### Data Preparation
To prepare the ALFWorld dataset, run:
```bash
python examples/grpo_alfworld/get_alfworld_data.py
```

### Configuration
Before running, update the following placeholder paths in the YAML files (a shell sketch for doing this follows the list):
- `model.model_path`: Replace with your model path (e.g., `/PATH/TO/MODEL/`)
- `buffer.explorer_input.taskset.path`: Replace with your ALFWorld dataset path
- `buffer.explorer_input.eval_tasksets[0].path`: Replace with your ALFWorld dataset path
- `checkpoint_root_dir`: Replace with your desired checkpoint directory
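A minimal shell sketch for filling in the placeholders and launching, assuming a `trinity run` entry point (the paths below are illustrative; check the repository's top-level README for the exact launch command):

```bash
# Substitute the placeholder paths in the standard config (example values).
sed -i 's#/PATH/TO/MODEL/#/models/Qwen2.5-7B-Instruct/#' \
    examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
sed -i 's#/PATH/TO/ALFWORLD_DATA/#/data/alfworld/#g' \
    examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
sed -i 's#/PATH/TO/CHECKPOINT/RAFT_ALFWORLD/#/ckpts/raft_alfworld/#' \
    examples/RAFT_alfworld/RAFT_alfworld_7B.yaml

# Launch (assumed CLI; adjust to the repo's documented entry point).
trinity run --config examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
```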

## Implementation
The workflow implementations are located in:
- Standard RAFT: [`trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py`](../../trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py)
- RAFT with Reflection: [`trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py`](../../trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py)
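For orientation, the following is a schematic sketch of the RAFT loop these workflows implement: roll out episodes, keep only high-reward trajectories, and emit them as SFT-style experiences. It is not the repository's actual code; `run_episode`, `Experience`, and the reward threshold are illustrative.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class Experience:
    prompt: str
    response: str
    reward: float


def raft_collect(
    tasks: List[str],
    run_episode: Callable[[str], Tuple[str, float]],
    reward_threshold: float = 1.0,
) -> List[Experience]:
    """Roll out one episode per task and keep only successful trajectories."""
    kept: List[Experience] = []
    for task in tasks:
        response, reward = run_episode(task)  # interact with the ALFWorld env
        if reward >= reward_threshold:  # RAFT filter: keep successes only
            kept.append(Experience(task, response, reward))
    # The trainer consumes these with the SFT loss ("policy_loss_fn: sft").
    return kept
```

The reflection variant differs only in how rollouts are produced, not in this filter-then-fine-tune structure.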
45 changes: 45 additions & 0 deletions examples/RAFT_alfworld/train_alfworld.yaml
@@ -0,0 +1,45 @@
actor_rollout_ref:
  hybrid_engine: True
  model:
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: True
  actor:
    strategy: fsdp
    ppo_mini_batch_size: 16
    ppo_micro_batch_size_per_gpu: 1
    use_dynamic_bsz: True
    ppo_max_token_len_per_gpu: 20000  # Adjusted for alfworld longer sequences
    grad_clip: 1.0
    clip_ratio: 0.2
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1
    optim:
      lr: 1e-6
      lr_warmup_steps_ratio: 0.
      total_training_steps: -1  # Will be overridden by program
    fsdp_config:
      wrap_policy:
        min_num_params: 0
      param_offload: False
      optimizer_offload: False
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        min_num_params: 0
    log_prob_micro_batch_size_per_gpu: 8
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}

trainer:
  balance_batch: True
  resume_mode: disable
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: True
  del_local_ckpt_after_load: False
  val_before_train: False
6 changes: 3 additions & 3 deletions tests/utils/eval_utils_test.py
@@ -32,7 +32,7 @@ def test_extract_answer(self):
         ]
 
         for i, (input_str, expected_output, description) in enumerate(test_cases):
-            with self.subTest(f"Case {i+1}: {description}"):
+            with self.subTest(f"Case {i + 1}: {description}"):
                 actual_output = extract_answer(input_str)
                 self.assertEqual(
                     actual_output,
@@ -58,7 +58,7 @@ def test_verify_math_answer(self):
         ]
 
         for i, (response, ground_truth, expected_correct, description) in enumerate(test_cases):
-            with self.subTest(f"Case {i+1}: {description}"):
+            with self.subTest(f"Case {i + 1}: {description}"):
                 accuracy, details = verify_math_answer(response, ground_truth)
                 is_correct = accuracy == 1.0
                 self.assertEqual(
@@ -88,7 +88,7 @@ def test_is_equiv(self):
         ]
 
         for i, (str1, str2, expected_output, description) in enumerate(test_cases):
-            with self.subTest(f"Case {i+1}: {description}"):
+            with self.subTest(f"Case {i + 1}: {description}"):
                 actual_output = is_equiv(str1, str2)
                 self.assertEqual(
                     actual_output,
23 changes: 23 additions & 0 deletions trinity/algorithm/algorithm.py
@@ -196,3 +196,26 @@ def default_config(cls) -> Dict:
             "advantage_fn": "grpo",
             "sample_strategy": "mix",
         }
+
+
+@ALGORITHM_TYPE.register_module("raft")
+class RAFTAlgorithm(AlgorithmType):
+    """RAFT Algorithm.
+    This algorithm is conceptually similar to Supervised Fine-Tuning (SFT)
+    but is designed to work with the `ExperienceModel` schema from rollouts.
+    """
+
+    use_critic: bool = False
+    use_reference: bool = False
+    compute_advantage_in_trainer: bool = False
+    can_balance_batch: bool = True
+    schema: type = ExperienceModel
+
+    @classmethod
+    def default_config(cls) -> Dict:
+        return {
+            "sample_strategy": "default",
+            "policy_loss_fn": "sft",
+            "kl_loss_fn": "none",
+            "entropy_loss_fn": "none",
+        }
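The registered name `raft` is what the example configs select; with the defaults above, no critic or reference model is instantiated and rollout experiences are trained with the SFT loss. From `RAFT_alfworld_7B.yaml`:

```yaml
algorithm:
  algorithm_type: raft
  repeat_times: 1
```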
4 changes: 4 additions & 0 deletions trinity/common/workflows/__init__.py
@@ -4,6 +4,8 @@
 from .customized_toolcall_workflows import ToolCallWorkflow
 from .envs.agentscope.agentscope_react_workflow import AgentScopeReactV2MathWorkflow
 from .envs.alfworld.alfworld_workflow import AlfworldWorkflow, StepWiseAlfworldWorkflow
+from .envs.alfworld.RAFT_alfworld_workflow import RAFTAlfworldWorkflow
+from .envs.alfworld.RAFT_reflect_alfworld_workflow import RAFTReflectAlfworldWorkflow
 from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
 from .envs.webshop.webshop_workflow import WebShopWorkflow
 from .eval_workflow import MathEvalWorkflow
@@ -19,6 +21,8 @@
     "WebShopWorkflow",
     "AlfworldWorkflow",
     "StepWiseAlfworldWorkflow",
+    "RAFTAlfworldWorkflow",
+    "RAFTReflectAlfworldWorkflow",
     "SciWorldWorkflow",
     "MathBoxedWorkflow",
     "MathRMWorkflow",