Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,15 +289,19 @@ async def agent_loop(
if retokenize_chat_history:
reward_out = per_step_rewards[-1][0]
else:
# Build token-level rewards placed at assistant turn boundaries
token_level_rewards: List[float] = [0.0] * len(response_ids)
for step_reward, idx in per_step_rewards:
if step_reward is None:
continue
if idx >= len(response_ids):
break
token_level_rewards[idx] += step_reward
reward_out = token_level_rewards
if all(reward is None for reward, _ in per_step_rewards[:-1]):
# If all rewards besides the last one are None (i.e. per-trajectory reward), we keep it as a float
reward_out = per_step_rewards[-1][0]
else:
# Otherwise build token-level rewards placed at assistant turn boundaries
token_level_rewards: List[float] = [0.0] * len(response_ids)
for step_reward, idx in per_step_rewards:
if step_reward is None:
continue
if idx >= len(response_ids):
break
token_level_rewards[idx] += step_reward
reward_out = token_level_rewards

return AgentLoopOutput(
response_ids=response_ids,
Expand Down
23 changes: 16 additions & 7 deletions skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,8 +965,9 @@ def close(self):

@pytest.mark.asyncio
@patch("skyrl_gym.make")
@pytest.mark.parametrize("is_per_turn_reward", [True, False])
async def test_agent_loop_token_level_rewards_multi_turn_conversation_format(
mock_make, mock_tokenizer, mock_llm, mock_env_cfg
mock_make, mock_tokenizer, mock_llm, mock_env_cfg, is_per_turn_reward
):
"""use_conversation_multi_turn=True; verify rewards placed at ends of assistant segments before observations."""
mock_tokenizer.eos_token_id = 4
Expand Down Expand Up @@ -1007,7 +1008,10 @@ def step(self, action):
self.turns += 1
if self.turns == 1:
return BaseTextEnvStepOutput(
observations=[{"role": "user", "content": "obs1"}], reward=0.5, done=False, metadata={}
observations=[{"role": "user", "content": "obs1"}],
reward=0.5 if is_per_turn_reward else None,
done=False,
metadata={},
)
else:
return BaseTextEnvStepOutput(observations=[], reward=0.25, done=True, metadata={})
Expand Down Expand Up @@ -1046,11 +1050,16 @@ def close(self):

# Response ids layout: step1 assistant (4 incl. eos) + obs(2) + step2 assistant(4 incl. eos) = 10
assert len(out.response_ids) == 10
# Rewards at indices: 3 (end of step1 assistant), 9 (end of step2 assistant)
expected = [0.0] * 10
expected[3] = 0.5
expected[9] = 0.25
assert isinstance(out.reward, list)
if is_per_turn_reward:
# Rewards at indices: 3 (end of step1 assistant), 9 (end of step2 assistant)
expected = [0.0] * 10
expected[3] = 0.5
expected[9] = 0.25
assert isinstance(out.reward, list)
else:
# Per-trajectory reward is a single float
expected = 0.25
assert isinstance(out.reward, float)
assert out.reward == expected
assert out.stop_reason == "stop"

Expand Down