diff --git a/docs/sphinx_doc/assets/RAFT_alfworld_reward_curve.png b/docs/sphinx_doc/assets/RAFT_alfworld_reward_curve.png
new file mode 100644
index 0000000000..c62e12e6be
Binary files /dev/null and b/docs/sphinx_doc/assets/RAFT_alfworld_reward_curve.png differ
diff --git a/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml b/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
new file mode 100644
index 0000000000..c195b255bc
--- /dev/null
+++ b/examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
@@ -0,0 +1,75 @@
+project: "Trinity-RFT-RAFT-alfworld"
+name: "qwen2.5-7B-RAFT-alfworld"
+mode: both
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/RAFT_ALFWORLD/
+algorithm:
+ algorithm_type: raft
+ repeat_times: 1
+model:
+ model_path: /PATH/TO/MODEL/
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 30
+ batch_size: 80
+ max_retry_times: 1
+ max_retry_interval: 1
+ explorer_input:
+ taskset:
+ name: alfworld-train
+ storage_type: file
+ path: '/PATH/TO/ALFWORLD_DATA/'
+ split: 'train'
+ format:
+ prompt_key: 'game_file'
+ response_key: 'task_desc'
+ rollout_args:
+ temperature: 0.85
+ top_k: 25
+ top_p: 0.95
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: '/PATH/TO/ALFWORLD_DATA/'
+ split: 'test'
+ format:
+ prompt_key: 'game_file'
+ response_key: 'task_desc'
+ rollout_args:
+ temperature: 0.85
+ top_k: 25
+ top_p: 0.95
+ default_workflow_type: 'RAFT_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: RAFT_buffer
+ storage_type: queue
+ path: 'sqlite:///RAFT_alfworld.db'
+explorer:
+ eval_interval: 30
+ runner_num: 100
+ max_timeout: 3000 # Increased timeout for alfworld environment interactions
+ max_retry_times: 2
+ rollout_model:
+ engine_type: vllm_async
+ engine_num: 4
+ tensor_parallel_size: 1
+ enable_prefix_caching: true
+ enforce_eager: false
+ dtype: bfloat16
+ gpu_memory_utilization: 0.86
+ seed: 42
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 4
+ sync_timeout: 1200
+trainer:
+ trainer_type: 'verl'
+ trainer_config_path: 'examples/RAFT_alfworld/train_alfworld.yaml'
+ save_interval: 100000
+monitor:
+ monitor_type: wandb
diff --git a/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml b/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml
new file mode 100644
index 0000000000..6dc17ae32c
--- /dev/null
+++ b/examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml
@@ -0,0 +1,75 @@
+project: "Trinity-RFT-RAFT-reflect-alfworld"
+name: "qwen2.5-7B-RAFT-reflect-alfworld"
+mode: both
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/RAFT_REFLECT_ALFWORLD/
+algorithm:
+ algorithm_type: raft
+ repeat_times: 1
+model:
+ model_path: /PATH/TO/MODEL/
+ max_response_tokens: 4096
+ max_model_len: 20480
+cluster:
+ node_num: 1
+ gpu_per_node: 8
+buffer:
+ total_epochs: 30
+ batch_size: 80
+ max_retry_times: 1
+ max_retry_interval: 1
+ explorer_input:
+ taskset:
+ name: alfworld-train
+ storage_type: file
+ path: '/PATH/TO/ALFWORLD_DATA/'
+ split: 'train'
+ format:
+ prompt_key: 'game_file'
+ response_key: 'task_desc'
+ rollout_args:
+ temperature: 0.85
+ top_k: 25
+ top_p: 0.95
+ eval_tasksets:
+ - name: alfworld-eval
+ storage_type: file
+ path: '/PATH/TO/ALFWORLD_DATA/'
+ split: 'test'
+ format:
+ prompt_key: 'game_file'
+ response_key: 'task_desc'
+ rollout_args:
+ temperature: 0.85
+ top_k: 25
+ top_p: 0.95
+ default_workflow_type: 'RAFT_reflect_alfworld_workflow'
+ trainer_input:
+ experience_buffer:
+ name: RAFT_buffer
+ storage_type: queue
+ path: 'sqlite:///RAFT_reflect_alfworld.db'
+explorer:
+ eval_interval: 30
+ runner_num: 100
+ max_timeout: 3000 # Increased timeout for alfworld environment interactions
+ max_retry_times: 2
+ rollout_model:
+ engine_type: vllm_async
+ engine_num: 4
+ tensor_parallel_size: 1
+ enable_prefix_caching: false
+ enforce_eager: false
+ dtype: bfloat16
+ gpu_memory_utilization: 0.86
+ seed: 42
+synchronizer:
+ sync_style: dynamic_by_explorer
+ sync_method: 'nccl'
+ sync_interval: 4
+ sync_timeout: 1200
+trainer:
+ trainer_type: 'verl'
+ trainer_config_path: 'examples/RAFT_alfworld/train_alfworld.yaml'
+ save_interval: 100000
+monitor:
+ monitor_type: wandb
diff --git a/examples/RAFT_alfworld/README.md b/examples/RAFT_alfworld/README.md
new file mode 100644
index 0000000000..3253cc25e4
--- /dev/null
+++ b/examples/RAFT_alfworld/README.md
@@ -0,0 +1,33 @@
+# RAFT on ALFWorld Dataset
+
+This example shows how to run RAFT on the ALFWorld environment, in both a standard and a reflection-enhanced variant. Both workflows keep only successful (or improved) rollouts and train on them with an SFT-style loss.
+
+
+
+## Variants
+
+### Standard RAFT
+The config files are located in [`RAFT_alfworld_7B.yaml`](RAFT_alfworld_7B.yaml) and [`train_alfworld.yaml`](train_alfworld.yaml).
+
+### RAFT with Reflection
+The config files are located in [`RAFT_reflect_alfworld_7B.yaml`](RAFT_reflect_alfworld_7B.yaml) and [`train_alfworld.yaml`](train_alfworld.yaml).
+
+## Setup
+
+### Data Preparation
+To prepare the ALFWorld dataset, run:
+```bash
+python examples/grpo_alfworld/get_alfworld_data.py
+```
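+
+Running the workflows also requires the `alfworld` package itself (the environment backend imported in `RAFT_utils.py`). If it is not installed yet, something along the following lines typically works; treat it as a sketch and follow the [ALFWorld repository](https://github.com/alfworld/alfworld) for the authoritative steps:
+```bash
+pip install alfworld
+alfworld-download
+```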
+
+### Configuration
+Before running, update the following paths in the YAML files (see the example snippet after this list):
+- `model.model_path`: replace with your model path (e.g., `/PATH/TO/MODEL/`)
+- `buffer.explorer_input.taskset.path`: replace with your ALFWorld dataset path
+- `buffer.explorer_input.eval_tasksets[0].path`: replace with your ALFWorld dataset path
+- `checkpoint_root_dir`: replace with your desired checkpoint directory
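+
+For reference, these fields sit at the following locations in [`RAFT_alfworld_7B.yaml`](RAFT_alfworld_7B.yaml), shown here with the placeholder values that need replacing (the reflection variant uses the same layout):
+```yaml
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/RAFT_ALFWORLD/
+model:
+  model_path: /PATH/TO/MODEL/
+buffer:
+  explorer_input:
+    taskset:
+      path: '/PATH/TO/ALFWORLD_DATA/'
+    eval_tasksets:
+      - name: alfworld-eval
+        path: '/PATH/TO/ALFWORLD_DATA/'
+```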
+
+## Implementation
+The workflow implementations are located in:
+- Standard RAFT: [`trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py`](../../trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py)
+- RAFT with Reflection: [`trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py`](../../trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py)
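+
+## Running
+With the paths configured, the runs can be launched through the Trinity-RFT entry point used by the other examples in this repository (shown here for illustration; adjust if your launch command differs):
+```bash
+# standard RAFT
+trinity run --config examples/RAFT_alfworld/RAFT_alfworld_7B.yaml
+# RAFT with reflection
+trinity run --config examples/RAFT_alfworld/RAFT_reflect_alfworld_7B.yaml
+```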
diff --git a/examples/RAFT_alfworld/train_alfworld.yaml b/examples/RAFT_alfworld/train_alfworld.yaml
new file mode 100644
index 0000000000..e586c4e73e
--- /dev/null
+++ b/examples/RAFT_alfworld/train_alfworld.yaml
@@ -0,0 +1,45 @@
+actor_rollout_ref:
+ hybrid_engine: True
+ model:
+ external_lib: null
+ override_config: { }
+ enable_gradient_checkpointing: True
+ use_remove_padding: True
+ actor:
+ strategy: fsdp
+ ppo_mini_batch_size: 16
+ ppo_micro_batch_size_per_gpu: 1
+ use_dynamic_bsz: True
+ ppo_max_token_len_per_gpu: 20000 # Adjusted for alfworld longer sequences
+ grad_clip: 1.0
+ clip_ratio: 0.2
+ ppo_epochs: 1
+ shuffle: False
+ ulysses_sequence_parallel_size: 1
+ optim:
+ lr: 1e-6
+ lr_warmup_steps_ratio: 0.
+ total_training_steps: -1 # Will be overridden by program
+ fsdp_config:
+ wrap_policy:
+ min_num_params: 0
+ param_offload: False
+ optimizer_offload: False
+ fsdp_size: -1
+ ref:
+ fsdp_config:
+ param_offload: False
+ wrap_policy:
+ min_num_params: 0
+ log_prob_micro_batch_size_per_gpu: 8
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}
+
+trainer:
+ balance_batch: True
+ resume_mode: disable
+ default_hdfs_dir: null
+ remove_previous_ckpt_in_save: True
+ del_local_ckpt_after_load: False
+ val_before_train: False
diff --git a/tests/utils/eval_utils_test.py b/tests/utils/eval_utils_test.py
index 8105b692ce..4cb137bf73 100644
--- a/tests/utils/eval_utils_test.py
+++ b/tests/utils/eval_utils_test.py
@@ -32,7 +32,7 @@ def test_extract_answer(self):
]
for i, (input_str, expected_output, description) in enumerate(test_cases):
- with self.subTest(f"Case {i+1}: {description}"):
+ with self.subTest(f"Case {i + 1}: {description}"):
actual_output = extract_answer(input_str)
self.assertEqual(
actual_output,
@@ -58,7 +58,7 @@ def test_verify_math_answer(self):
]
for i, (response, ground_truth, expected_correct, description) in enumerate(test_cases):
- with self.subTest(f"Case {i+1}: {description}"):
+ with self.subTest(f"Case {i + 1}: {description}"):
accuracy, details = verify_math_answer(response, ground_truth)
is_correct = accuracy == 1.0
self.assertEqual(
@@ -88,7 +88,7 @@ def test_is_equiv(self):
]
for i, (str1, str2, expected_output, description) in enumerate(test_cases):
- with self.subTest(f"Case {i+1}: {description}"):
+ with self.subTest(f"Case {i + 1}: {description}"):
actual_output = is_equiv(str1, str2)
self.assertEqual(
actual_output,
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index 640ceff1ef..e652ac8f4e 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -196,3 +196,26 @@ def default_config(cls) -> Dict:
"advantage_fn": "grpo",
"sample_strategy": "mix",
}
+
+
+@ALGORITHM_TYPE.register_module("raft")
+class RAFTAlgorithm(AlgorithmType):
+ """RAFT Algorithm.
+ This algorithm is conceptually similar to Supervised Fine-Tuning (SFT)
+ but is designed to work with `ExperienceModel` schema from rollouts.
+    but is designed to work with the `ExperienceModel` schema produced by rollouts.
+
+ use_critic: bool = False
+ use_reference: bool = False
+ compute_advantage_in_trainer: bool = False
+ can_balance_batch: bool = True
+ schema: type = ExperienceModel
+
+ @classmethod
+ def default_config(cls) -> Dict:
+ return {
+ "sample_strategy": "default",
+ "policy_loss_fn": "sft",
+ "kl_loss_fn": "none",
+ "entropy_loss_fn": "none",
+ }
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index ebafdb066c..062f9f5f09 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -4,6 +4,8 @@
from .customized_toolcall_workflows import ToolCallWorkflow
from .envs.agentscope.agentscope_react_workflow import AgentScopeReactV2MathWorkflow
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow, StepWiseAlfworldWorkflow
+from .envs.alfworld.RAFT_alfworld_workflow import RAFTAlfworldWorkflow
+from .envs.alfworld.RAFT_reflect_alfworld_workflow import RAFTReflectAlfworldWorkflow
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
from .envs.webshop.webshop_workflow import WebShopWorkflow
from .eval_workflow import MathEvalWorkflow
@@ -19,6 +21,8 @@
"WebShopWorkflow",
"AlfworldWorkflow",
"StepWiseAlfworldWorkflow",
+ "RAFTAlfworldWorkflow",
+ "RAFTReflectAlfworldWorkflow",
"SciWorldWorkflow",
"MathBoxedWorkflow",
"MathRMWorkflow",
diff --git a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py
new file mode 100644
index 0000000000..5b3cb8d786
--- /dev/null
+++ b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py
@@ -0,0 +1,225 @@
+# -*- coding: utf-8 -*-
+from datetime import datetime
+from typing import Dict, List, Optional
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.alfworld.RAFT_utils import (
+ create_alfworld_environment,
+ format_observation,
+ generate_default_empty_experience,
+ get_jinja_env,
+ parse_response,
+ process_messages_to_experience,
+ validate_trajectory_format,
+)
+from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+
+
+@WORKFLOWS.register_module("RAFT_alfworld_workflow")
+class RAFTAlfworldWorkflow(Workflow):
+ """
+    RAFT workflow for ALFWorld (single-attempt variant).
+
+ Process:
+ 1. First exploration with normal experience generation
+ 2. Generate SFT data from successful attempt
+ """
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+ # Initialize workflow parameters
+ self.temperature = getattr(task.rollout_args, "temperature", 1.0)
+ self.top_k = getattr(task.rollout_args, "top_k", 20)
+ self.top_p = getattr(task.rollout_args, "top_p", 0.95)
+ self.max_env_steps = 50
+ self.max_tokens = 4096
+ self.task = task
+ self.is_eval = task.is_eval
+
+ # Setup Jinja2 templates
+ self.jinja_env = get_jinja_env()
+ self.alfworld_system_template = self.jinja_env.get_template("alfworld_system.j2")
+
+ print(
+ f"Initializing RAFTAlfworldWorkflow with RAFT learning, temperature={self.temperature}"
+ )
+ self.reset(task)
+
+ def reset(self, task: Task):
+ """Reset the workflow with a new task"""
+ self.game_file_path = task.task_desc or task.raw_task.get("game_file", "")
+ self.is_eval = task.is_eval
+
+ def create_environment(self, game_file):
+ """Create alfworld environment"""
+ return create_alfworld_environment(game_file)
+
+ def run_single_rollout(
+ self, env
+ ) -> tuple[List[Dict[str, str]], float, bool, int, List[Dict[str, str]]]:
+ """Run a single rollout with RAFT-guided actions"""
+ observation, info = env.reset()
+ trajectory = []
+ parsed_steps = [] # Store parsed experience, think, action for each step
+ action_history = [] # Track last 3 actions for repetition detection
+
+ trajectory.append({"role": "system", "content": self.alfworld_system_template.render()})
+
+ # Track the last reward from environment
+ last_reward = 0.0
+
+ for step in range(self.max_env_steps):
+ trajectory.append({"role": "user", "content": format_observation(observation)})
+
+ # Get model response with RAFT guidance
+ responses = self.model.chat(
+ trajectory,
+ n=1,
+ temperature=self.temperature,
+ top_k=self.top_k,
+ top_p=self.top_p,
+ max_tokens=self.max_tokens,
+ )
+ response_text = responses[0].response_text.strip()
+ trajectory.append({"role": "assistant", "content": response_text})
+
+ # Parse the three components
+ parsed = parse_response(response_text)
+ experience_text, think_text, action_text = (
+ parsed["experience"],
+ parsed["think"],
+ parsed["action"],
+ )
+
+ # Store parsed step for SFT data construction
+ parsed_steps.append(
+ {
+ "observation": observation,
+ "experience": experience_text,
+ "think": think_text,
+ "action": action_text,
+ "full_response": response_text,
+ }
+ )
+
+ # Check for consecutive action repetition
+ action_history.append(action_text)
+ if len(action_history) > 3:
+ action_history.pop(0)
+
+ # If last 3 actions are the same, terminate with failure
+ if len(action_history) >= 3 and all(
+ action == action_history[0] for action in action_history
+ ):
+ print(f"Terminating due to 3 consecutive identical actions: {action_text}")
+ return trajectory, 0.0, False, step + 1, parsed_steps
+
+ # Execute action in environment
+ observation, reward, done, info = env.step(action_text)
+ last_reward = reward # Always track the latest reward from environment
+
+ if done:
+ return trajectory, reward, done, step + 1, parsed_steps
+
+ # If timeout, return the last reward from environment instead of fixed value
+ return trajectory, last_reward, False, self.max_env_steps, parsed_steps
+
+ def _execute_first_attempt(self) -> tuple:
+ """Execute the first attempt and return results"""
+ env = self.create_environment(self.game_file_path)
+
+ try:
+ trajectory, reward, done, steps, parsed_steps = self.run_single_rollout(env)
+ except Exception as e:
+ print(f"Single rollout failed: {e}")
+ env.close()
+ raise e
+
+ env.close()
+ success = done and reward >= 1
+ traj_format_valid = validate_trajectory_format(parsed_steps)
+
+ return trajectory, reward, done, steps, parsed_steps, success, traj_format_valid
+
+ def eval_alfworld(self) -> List[Experience]:
+ """Evaluate a single alfworld trajectory"""
+ env = self.create_environment(self.game_file_path)
+ try:
+ trajectory, reward, done, steps, parsed_steps = self.run_single_rollout(env)
+ except Exception as e:
+ print(f"Single rollout failed during eval: {e}")
+ env.close()
+ return [generate_default_empty_experience(f"Eval rollout failed: {str(e)}")]
+ env.close()
+
+ # Save eval data
+ task_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+ success = done and reward >= 1
+
+ # Convert trajectory to experience
+ experience = generate_default_empty_experience(
+ msg="Eval completed successfully",
+ info={"task_id": task_id, "success": success, "reward": reward, "steps": steps},
+ metrics={"success": float(success), "reward": float(reward), "steps": float(steps)},
+ )
+
+ return [experience]
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT alfworld workflow and return experiences"""
+
+ if self.is_eval:
+ return self.eval_alfworld()
+
+ # Execute first attempt
+ try:
+ (
+ trajectory,
+ reward,
+ done,
+ steps,
+ parsed_steps,
+ success,
+ traj_format_valid,
+ ) = self._execute_first_attempt()
+ except Exception as e:
+ return [generate_default_empty_experience(f"Training rollout failed: {str(e)}")]
+
+ print(f"Task result: done={done}, reward={reward:.3f}, steps={steps}, success={success}")
+
+ if reward >= 1 and traj_format_valid:
+ print("✅ Task completed successfully in the first attempt!")
+ experience = process_messages_to_experience(
+ self.model, trajectory, info={"success": success, "reward": reward, "steps": steps}
+ )
+ return [experience]
+ elif not traj_format_valid and reward >= 1:
+ print(
+ "❌ Task completed but trajectory format is invalid, skipping SFT data generation."
+ )
+ else:
+ print("❌ Task failed.")
+
+ experience = generate_default_empty_experience(
+            "No valid SFT data generated (task failed or trajectory format invalid)",
+ metrics={"success": float(success), "reward": float(reward), "steps": float(steps)},
+ )
+ return [experience]
+
+ def resettable(self) -> bool:
+ """Indicate that this workflow can be reset to avoid re-initialization"""
+ return True
+
+ def set_repeat_times(self, repeat_times, run_id_base):
+ self.repeat_times = repeat_times
+ self.run_id_base = run_id_base
diff --git a/trinity/common/workflows/envs/alfworld/RAFT_prompt/alfworld_system.j2 b/trinity/common/workflows/envs/alfworld/RAFT_prompt/alfworld_system.j2
new file mode 100644
index 0000000000..d60ba7f622
--- /dev/null
+++ b/trinity/common/workflows/envs/alfworld/RAFT_prompt/alfworld_system.j2
@@ -0,0 +1,29 @@
+You are an agent interacting with a virtual text-based environment.
+
+## Response Format:
+You MUST use this exact format for every response. All three tags are REQUIRED in sequential order:
+
+<experience>Working principles, strategies, common knowledge, and potential pitfalls relevant to the current task.</experience>\n\n
+<think>your reasoning process</think>\n\n
+<action>exactly one action command</action>
+
+## Action Commands:
+ look: look around your current location
+  inventory: check your current inventory (you can only have 1 item in your inventory)
+ go to (receptacle): move to a receptacle
+ open (receptacle): open a receptacle
+ close (receptacle): close a receptacle
+ take (object) from (receptacle): take an object from a receptacle
+ move (object) to (receptacle): place an object in or on a receptacle
+ examine (something): examine a receptacle or an object
+ use (object): use an object
+ heat (object) with (receptacle): heat an object using a receptacle
+ clean (object) with (receptacle): clean an object using a receptacle
+ cool (object) with (receptacle): cool an object using a receptacle
+ slice (object) with (object): slice an object using a sharp object
+
+For example, your output should look like this:
+<experience>In household tasks, I should start by exploring the environment to understand available objects and receptacles. Common pitfalls include forgetting to check inventory capacity and not examining objects before taking action.</experience>\n\n<think>To solve the task, I need first to ...</think>\n\n<action>go to cabinet 1</action>
+
+## Important Note:
+You must ensure that the <experience> section contains descriptions of working principles, strategies, common sense knowledge, and potential pitfalls that are universally applicable to this type of task, rather than generic statements, placeholder content, or overly specific behavioral guidelines.
diff --git a/trinity/common/workflows/envs/alfworld/RAFT_prompt/second_attempt_guidance.j2 b/trinity/common/workflows/envs/alfworld/RAFT_prompt/second_attempt_guidance.j2
new file mode 100644
index 0000000000..1945cf314b
--- /dev/null
+++ b/trinity/common/workflows/envs/alfworld/RAFT_prompt/second_attempt_guidance.j2
@@ -0,0 +1,3 @@
+{{ reward_feedback }}
+
+This is your second attempt. You need to perform better this time. Focus on updating a more thoughtful and comprehensive <experience> section that incorporates strategies, common pitfalls, and domain knowledge to help improve task performance and avoid previous mistakes. Learn from your first attempt and apply that knowledge to make better decisions. IMPORTANT: You must update your experience based on what you learned from the first attempt. Also, DO NOT repeat the same action consecutively as this is inefficient and useless - if you find yourself repeating actions, stop and try a different approach. In your responses, do not explicitly mention this is a second attempt or reference previous attempts. Instead, present your actions as if starting fresh, but incorporate general experience and strategies you can summarize for this type of task.
diff --git a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py
new file mode 100644
index 0000000000..8c7ff101a8
--- /dev/null
+++ b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py
@@ -0,0 +1,310 @@
+# -*- coding: utf-8 -*-
+import os
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.envs.alfworld.RAFT_alfworld_workflow import (
+ RAFTAlfworldWorkflow,
+)
+from trinity.common.workflows.envs.alfworld.RAFT_utils import (
+ create_alfworld_environment,
+ format_observation,
+ generate_default_empty_experience,
+ generate_reward_feedback,
+ parse_response,
+ process_messages_to_experience,
+ save_task_data,
+ validate_trajectory_format,
+)
+from trinity.common.workflows.workflow import WORKFLOWS, Task
+
+
+@WORKFLOWS.register_module("RAFT_reflect_alfworld_workflow")
+class RAFTReflectAlfworldWorkflow(RAFTAlfworldWorkflow):
+ """
+ RAFT workflow for alfworld using trajectory context.
+
+ Process:
+ 1. First exploration with normal experience generation
+ 2. If failed, re-explore with first trajectory as context
+ 3. Generate SFT data from successful attempt
+ """
+
+ def __init__(
+ self,
+ model: ModelWrapper,
+ task: Task,
+ auxiliary_models: Optional[List] = None,
+ ):
+ super().__init__(
+ model=model,
+ task=task,
+ auxiliary_models=auxiliary_models,
+ )
+
+ # Create data directories specific to reflect workflow
+ self.data_dir = "RAFT_reflect_alfworld_data"
+ self.sft_dir = os.path.join(self.data_dir, "sft_data")
+ self.non_sft_dir = os.path.join(self.data_dir, "non_sft_data")
+
+ os.makedirs(self.sft_dir, exist_ok=True)
+ os.makedirs(self.non_sft_dir, exist_ok=True)
+
+ # Setup additional template for second attempt
+ self.second_attempt_template = self.jinja_env.get_template("second_attempt_guidance.j2")
+
+ print(
+ f"Initializing RAFTReflectAlfworldWorkflow with RAFT learning, temperature={self.temperature}"
+ )
+
+ def construct_sft_data(
+ self,
+ first_trajectory: List[Dict[str, str]],
+ success: bool,
+ reward: float,
+ original_steps: int,
+ ) -> tuple[List[Dict[str, str]], Dict[str, Any], List[Dict[str, str]]]:
+ """Generate SFT training data using RAFT learning"""
+
+ # Always perform second attempt with first trajectory as context
+ (
+ new_trajectory,
+ new_reward,
+ new_success,
+ new_steps,
+ new_parsed_steps,
+ ) = self.re_explore_with_context(first_trajectory, reward, success, original_steps)
+
+ # Consider improvement if reward is higher OR same reward with fewer steps
+ reward_improved = new_reward > reward
+ efficiency_improved = new_steps < original_steps
+
+ return (
+ new_trajectory,
+ {
+ "new_reward": new_reward,
+ "new_steps": new_steps,
+ "reward_improved": reward_improved,
+ "efficiency_improved": efficiency_improved,
+ },
+ new_parsed_steps,
+ )
+
+ def re_explore_with_context(
+ self,
+ first_trajectory: List[Dict[str, str]],
+ original_reward: float,
+ original_success: bool,
+ original_steps: int,
+ ) -> tuple[List[Dict[str, str]], float, bool, int, List[Dict[str, str]]]:
+ """Re-explore with first trajectory as context"""
+
+ env = create_alfworld_environment(self.game_file_path)
+
+ observation, info = env.reset()
+
+ # Use first trajectory as context for generation
+ context_messages = first_trajectory.copy()
+
+ # Add reward feedback about first attempt
+ reward_feedback = generate_reward_feedback(
+ original_reward, original_steps, original_success, self.max_env_steps
+ )
+ context_messages.append(
+ {
+ "role": "system",
+ "content": self.second_attempt_template.render(reward_feedback=reward_feedback),
+ }
+ )
+
+ # Build clean SFT trajectory (like first trajectory format)
+ sft_trajectory = [{"role": "system", "content": self.alfworld_system_template.render()}]
+ parsed_steps = [] # Track parsed steps for quality analysis
+
+ for step in range(self.max_env_steps):
+ # Add to context for generation
+ context_messages.append({"role": "user", "content": format_observation(observation)})
+
+ # Add to clean SFT trajectory
+ sft_trajectory.append({"role": "user", "content": format_observation(observation)})
+
+ responses = self.model.chat(
+ context_messages,
+ n=1,
+ temperature=self.temperature,
+ top_k=self.top_k,
+ top_p=self.top_p,
+ max_tokens=self.max_tokens,
+ )
+
+ response_text = responses[0].response_text.strip()
+
+ # Parse components for quality analysis
+ parsed = parse_response(response_text)
+ experience_text, think_text, action_text = (
+ parsed["experience"],
+ parsed["think"],
+ parsed["action"],
+ )
+
+ parsed_steps.append(
+ {
+ "observation": observation,
+ "experience": experience_text,
+ "think": think_text,
+ "action": action_text,
+ "full_response": response_text,
+ }
+ )
+
+ # Add to both trajectories
+ context_messages.append({"role": "assistant", "content": response_text})
+ sft_trajectory.append({"role": "assistant", "content": response_text})
+
+ observation, reward, done, info = env.step(action_text)
+
+ if done:
+ env.close()
+ return sft_trajectory, reward, done and reward > 0, step + 1, parsed_steps
+
+ env.close()
+ return sft_trajectory, reward, False, self.max_env_steps, parsed_steps
+
+ def _handle_invalid_format_success(
+ self, success: bool, reward: float, steps: int
+ ) -> List[Experience]:
+ """Handle case where task succeeded but format is invalid"""
+ print("❌ Task completed but trajectory format is invalid, skipping SFT data generation.")
+ experience = generate_default_empty_experience(
+ "Experience conversion failed: Trajectory format invalid",
+ metrics={"success": float(success), "reward": float(reward), "steps": float(steps)},
+ )
+ return [experience]
+
+ def _execute_second_attempt(
+ self, trajectory: list, success: bool, reward: float, steps: int
+ ) -> tuple:
+ """Execute second attempt and return SFT data"""
+ try:
+ sft_messages, re_explore_info, new_parsed_steps = self.construct_sft_data(
+ trajectory, success, reward, steps
+ )
+ return sft_messages, re_explore_info, new_parsed_steps, None
+ except Exception as e:
+ print(f"SFT data construction failed: {e}")
+ return None, None, None, e
+
+ def _build_metrics(
+ self, reward: float, steps: int, new_parsed_steps: list, re_explore_info: dict
+ ) -> dict:
+ """Build metrics for tracking"""
+ return {
+ "reward": float(reward),
+ "steps": float(steps),
+ "trajectory_length": len(new_parsed_steps),
+ "second_reward": float(re_explore_info["new_reward"]),
+ "second_steps": float(re_explore_info["new_steps"]),
+ "improvement": 1.0 if re_explore_info["reward_improved"] else 0.0,
+ }
+
+ def _should_keep_for_sft(self, second_traj_format_valid: bool, re_explore_info: dict) -> bool:
+ """Determine if trajectory should be kept for SFT"""
+ return second_traj_format_valid and (
+ re_explore_info["reward_improved"]
+ or (re_explore_info["efficiency_improved"] and re_explore_info["new_reward"] >= 1.0)
+ )
+
+ def _generate_experience_from_sft(self, sft_messages: list, metrics: dict) -> Experience:
+ """Generate experience from SFT messages"""
+ return process_messages_to_experience(self.model, sft_messages, info=metrics)
+
+ def run(self) -> List[Experience]:
+ """Run the RAFT alfworld workflow and return experiences"""
+
+ if self.is_eval:
+ return self.eval_alfworld()
+
+ # Generate unique task ID using timestamp
+ task_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+
+ # Execute first attempt
+ try:
+ (
+ trajectory,
+ reward,
+ done,
+ steps,
+ parsed_steps,
+ success,
+ traj_format_valid,
+ ) = self._execute_first_attempt()
+ except Exception as e:
+ return [generate_default_empty_experience(f"Training rollout failed: {str(e)}")]
+
+ # Handle first attempt success cases
+ if reward >= 1 and traj_format_valid:
+ print("✅ Task completed successfully in the first attempt!")
+ experience = process_messages_to_experience(
+ self.model, trajectory, info={"success": success, "reward": reward, "steps": steps}
+ )
+ return [experience]
+ elif not traj_format_valid and reward >= 1:
+ return self._handle_invalid_format_success(success, reward, steps)
+
+ print(f"Task result: done={done}, reward={reward:.3f}, steps={steps}, success={success}")
+
+ # Execute second attempt
+ sft_messages, re_explore_info, new_parsed_steps, error = self._execute_second_attempt(
+ trajectory, success, reward, steps
+ )
+ if error:
+ default_experience = generate_default_empty_experience(
+ f"SFT data construction failed: {str(error)}",
+ )
+ return [default_experience]
+
+ # Validate second attempt and build metrics
+ second_success = re_explore_info["new_reward"] >= 1
+ second_traj_format_valid = validate_trajectory_format(new_parsed_steps)
+ metrics = self._build_metrics(reward, steps, new_parsed_steps, re_explore_info)
+
+ # Generate experience if conditions are met
+ experiences = []
+ kept_for_sft = self._should_keep_for_sft(second_traj_format_valid, re_explore_info)
+
+ if kept_for_sft:
+ experience = self._generate_experience_from_sft(sft_messages, metrics)
+ experiences.append(experience)
+ print(
+ f"✅ Generated good training data: orig={reward}, steps={steps}, new={re_explore_info['new_reward']}, new_steps={re_explore_info['new_steps']}"
+ )
+ else:
+ print(
+ f"❌ Filtered trajectory: orig={reward}, steps={steps}, new={re_explore_info['new_reward']}, new_steps={re_explore_info['new_steps']}, second_traj_format_valid: {second_traj_format_valid}"
+ )
+
+ # Save detailed task data
+ save_task_data(
+ game_file_path=self.game_file_path,
+ sft_dir=self.sft_dir,
+ non_sft_dir=self.non_sft_dir,
+ task_id=task_id,
+ first_trajectory=trajectory,
+ first_reward=reward,
+ first_steps=steps,
+ first_success=success,
+ second_trajectory=sft_messages,
+ second_reward=re_explore_info["new_reward"],
+ second_steps=re_explore_info["new_steps"],
+ second_success=second_success,
+ kept_for_sft=kept_for_sft,
+ training_data=sft_messages,
+ )
+
+ # Return default experience if no valid experience generated
+ if not experiences:
+ experiences.append(generate_default_empty_experience())
+
+ return experiences
diff --git a/trinity/common/workflows/envs/alfworld/RAFT_utils.py b/trinity/common/workflows/envs/alfworld/RAFT_utils.py
new file mode 100644
index 0000000000..46b6f356a6
--- /dev/null
+++ b/trinity/common/workflows/envs/alfworld/RAFT_utils.py
@@ -0,0 +1,196 @@
+# -*- coding: utf-8 -*-
+import json
+import os
+import re
+from datetime import datetime
+from typing import Dict, List, Optional
+
+import torch
+from jinja2 import Environment, FileSystemLoader
+
+from trinity.common.experience import Experience
+
+
+# Setup Jinja2 environment for prompt templates
+def get_jinja_env():
+ """Get Jinja2 environment for loading templates"""
+ prompt_dir = os.path.join(os.path.dirname(__file__), "RAFT_prompt")
+ return Environment(loader=FileSystemLoader(prompt_dir))
+
+
+def format_observation(observation: str):
+ """Format observation string with additional guidance"""
+ if "Nothing happens." in observation:
+        observation += " Please check if the action you took is valid or whether you have carefully followed the action format."
+ return "Observation: " + observation
+
+
+def parse_response(response):
+ """Parse all three components from response with a single regex"""
+ try:
+ # Use single regex to extract all three components at once
+        pattern = r"<experience>\s*(.*?)\s*</experience>.*?<think>\s*(.*?)\s*</think>.*?<action>\s*(.*?)\s*</action>"
+ match = re.search(pattern, response, re.DOTALL)
+
+ if match:
+ return {
+ "experience": match.group(1).strip(),
+ "think": match.group(2).strip(),
+ "action": match.group(3).strip(),
+ }
+ else:
+ return {"experience": "", "think": "", "action": ""}
+ except Exception as e:
+ print(f"Error parsing response: {e}")
+ return {"experience": "", "think": "", "action": ""}
+
+
+def validate_response_format(parsed: Dict[str, str]) -> bool:
+ """Validate if parsed response has valid content in all required fields"""
+ has_think = len(parsed["think"].strip()) > 0
+ has_experience = len(parsed["experience"].strip()) > 0
+ has_action = len(parsed["action"].strip()) > 0
+
+ return has_think and has_experience and has_action
+
+
+def validate_trajectory_format(parsed_steps: List[Dict[str, str]]) -> bool:
+ """Validate if all steps in a trajectory have valid format"""
+ for step in parsed_steps:
+ if not validate_response_format(step):
+ return False
+ return True
+
+
+def generate_default_empty_experience(
+ msg: str = "Unknown error", info=None, metrics=None
+) -> Experience:
+ """Generate a default empty experience when errors occur"""
+ if info is None:
+ info = {"error_reason": msg, "rollout_failed": True}
+ if metrics is None:
+ metrics = {"success": 0.0, "reward": 0.0}
+
+ return Experience(
+ tokens=torch.tensor([0], dtype=torch.long),
+ prompt_length=0,
+ action_mask=torch.tensor([False], dtype=torch.bool),
+ info=info,
+ metrics=metrics,
+ )
+
+
+def create_alfworld_environment(game_file):
+ """Create alfworld environment"""
+ try:
+ import textworld
+ import textworld.gym
+ from alfworld.agents.environment.alfred_tw_env import (
+ AlfredDemangler,
+ AlfredExpert,
+ AlfredExpertType,
+ )
+
+ expert = AlfredExpert(expert_type=AlfredExpertType.HANDCODED)
+ request_infos = textworld.EnvInfos(
+ description=True, inventory=True, admissible_commands=True
+ )
+
+ env_id = textworld.gym.register_game(
+ game_file, request_infos, wrappers=[AlfredDemangler(), expert]
+ )
+ env = textworld.gym.make(env_id)
+ return env
+
+ except Exception as e:
+        error_message = f"Failed to create the ALFWorld environment: {str(e)}. Please make sure the alfworld package is installed correctly, following the instructions at https://github.com/alfworld/alfworld"
+ raise ImportError(error_message)
+
+
+def process_messages_to_experience(model, messages, info=None) -> Experience:
+ """Convert messages to experience for training, with fallback to default empty experience"""
+ if info is None:
+ info = {}
+
+ try:
+ converted_experience = model.convert_messages_to_experience(messages)
+
+ metrics = {}
+ for k, v in info.items():
+ if isinstance(v, float) or isinstance(v, int):
+ metrics[k] = float(v)
+ converted_experience.info = info
+ converted_experience.metrics = metrics
+
+ return converted_experience
+ except Exception as e:
+ print(f"Failed to convert messages to experience: {e}")
+ return generate_default_empty_experience(
+ f"Experience conversion failed: {str(e)}",
+ info=info,
+ metrics={k: float(v) for k, v in info.items() if isinstance(v, (float, int))},
+ )
+
+
+def generate_reward_feedback(reward: float, steps: int, done: bool, max_env_steps: int) -> str:
+ """Generate natural language feedback about the attempt's performance"""
+ if done and reward >= 1:
+ return f"In your attempt, you successfully completed the task in {steps} steps with a reward of {reward:.3f}. Try to maintain this success while being more efficient."
+ elif done and reward < 1:
+ return f"In your attempt, you completed the task in {steps} steps but only achieved a reward of {reward:.3f}. You need to improve your performance to achieve full success."
+ elif not done and steps >= max_env_steps:
+ return f"In your attempt, you reached the maximum step limit of {max_env_steps} steps without completing the task (reward: {reward:.3f}). You need to be more efficient and focused to complete the task within the step limit."
+ else:
+ return f"In your attempt, you stopped after {steps} steps with a reward of {reward:.3f} without completing the task. You need to improve your strategy and persistence to achieve success."
+
+
+def save_task_data(
+ task_id: str,
+ game_file_path: str,
+ first_trajectory: List[Dict[str, str]],
+ first_reward: float,
+ first_steps: int,
+ first_success: bool,
+ second_trajectory: Optional[List[Dict[str, str]]],
+ second_reward: Optional[float],
+ second_steps: Optional[int],
+ second_success: Optional[bool],
+ kept_for_sft: bool,
+ sft_dir: str,
+ non_sft_dir: str,
+ training_data: Optional[List[Dict[str, str]]] = None,
+):
+ """Save detailed exploration data to individual task file in appropriate folder"""
+ task_data = {
+ "task_id": task_id,
+ "timestamp": datetime.now().isoformat(),
+ "game_file": game_file_path,
+ # First exploration
+ "first_exploration": {
+ "trajectory": first_trajectory,
+ "reward": first_reward,
+ "steps": first_steps,
+ "success": first_success,
+ },
+ # Second exploration
+ "second_exploration": {
+ "trajectory": second_trajectory,
+ "reward": second_reward,
+ "steps": second_steps,
+ "success": second_success,
+ },
+ # Training data (clean dialogue format for SFT)
+ "training_data": training_data if training_data is not None else "",
+ "kept_for_sft": kept_for_sft,
+ }
+
+ # Determine folder based on SFT data status (following webshop pattern)
+ if kept_for_sft:
+ target_dir = sft_dir
+ else:
+ target_dir = non_sft_dir
+
+ task_file_path = os.path.join(target_dir, f"{task_id}.json")
+
+ with open(task_file_path, "w", encoding="utf-8") as f:
+ json.dump(task_data, f, ensure_ascii=False, indent=2)