Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion skyrl-gym/skyrl_gym/envs/base_text_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class BaseTextEnvStepOutput(TypedDict):
observations: ConversationType # OpenAI API Messages Format
reward: float
reward: Optional[float] # None if intermediate steps have no reward
done: bool
metadata: Dict[str, Any]
postprocessed_action: Optional[str] = None
Expand Down
2 changes: 1 addition & 1 deletion skyrl-gym/skyrl_gym/envs/search/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _get_reward(self, action: str, done: bool) -> float:
return compute_score(chat_history_str, self.ground_truth)
else:
# No reward for intermediate steps for Search tasks
return 0
return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

With this change, the function can now return None. Please update the function's return type hint on line 45 from float to Optional[float] to reflect this. Optional is already imported in this file.


def _is_done(self, action: str) -> bool:
if self.turns >= self.max_turns:
Expand Down
2 changes: 1 addition & 1 deletion skyrl-gym/skyrl_gym/envs/sql/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _get_reward(self, action: str, done: bool) -> float:
return compute_score_single(chat_history_str, self.gold_sql, self.db_file)
else:
# No reward for intermediate steps for SQL tasks
return 0
return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This function now returns None for intermediate steps. The return type hint on line 78 should be updated from float to Optional[float] to match this change. You'll need to add Optional to your imports from the typing module (e.g., from typing import Optional).


def _is_done(self, action: str) -> bool:
if self.turns >= self.max_turns:
Expand Down
20 changes: 10 additions & 10 deletions skyrl-gym/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_successful_search(search_env, mock_search_api):

# Verify result structure
assert not result["done"]
assert result["reward"] == 0.0 # No reward for intermediate steps
assert result["reward"] is None # No reward for intermediate steps
assert len(result["observations"]) == 1
assert result["observations"][0]["role"] == "user"

Expand All @@ -207,7 +207,7 @@ def test_search_with_no_results(search_env, mock_search_api):

# Verify result structure
assert not result["done"]
assert result["reward"] == 0.0
assert result["reward"] is None
assert len(result["observations"]) == 1

# Verify observation contains no results message
Expand All @@ -223,7 +223,7 @@ def test_search_timeout_error(search_env, mock_search_api):

# Verify result structure
assert not result["done"]
assert result["reward"] == 0.0
assert result["reward"] is None
assert len(result["observations"]) == 1

# Verify observation contains error message
Expand All @@ -240,7 +240,7 @@ def test_search_server_error(search_env, mock_search_api):

# Verify result structure
assert not result["done"]
assert result["reward"] == 0.0
assert result["reward"] is None
assert len(result["observations"]) == 1

# Verify observation contains error message
Expand Down Expand Up @@ -276,15 +276,15 @@ def test_invalid_search_parsing(search_env, mock_search_api):
# Incorrect answer
("<answer>Nicolas Sarkozy</answer>", {"target": "Emmanuel Macron"}, 0.0, True),
# Search action (not done)
("<search>Who is the president?</search>", {"target": "Emmanuel Macron"}, 0.0, False),
("<search>Who is the president?</search>", {"target": "Emmanuel Macron"}, None, False),
# Answer with extra whitespace
("<answer> Emmanuel Macron </answer>", {"target": "Emmanuel Macron"}, 1.0, True),
# Case insensitive match
("<answer>emmanuel macron</answer>", {"target": "Emmanuel Macron"}, 1.0, True),
# Answer without articles
("<answer>Emmanuel Macron</answer>", {"target": "The Emmanuel Macron"}, 1.0, True),
# No answer tag
("Just text without answer tag", {"target": "Emmanuel Macron"}, 0.0, False),
("Just text without answer tag", {"target": "Emmanuel Macron"}, None, False),
],
)
def test_reward_computation(action, ground_truth, expected_reward, expected_done):
Expand Down Expand Up @@ -321,7 +321,7 @@ def test_successful_search_and_answer(mock_search_api):
# Step 1: Search
result1 = env.step("<search>Who is the president of France?</search>")
assert not result1["done"]
assert result1["reward"] == 0.0
assert result1["reward"] is None
assert len(result1["observations"]) == 1
assert "Emmanuel Macron" in result1["observations"][0]["content"]

Expand Down Expand Up @@ -350,7 +350,7 @@ def test_max_turns_reached(mock_search_api):
# Step 1: Search
result1 = env.step("<search>Who is the president of France?</search>")
assert not result1["done"]
assert result1["reward"] == 0.0
assert result1["reward"] is None

# Step 2: Another search (should terminate due to max turns)
result2 = env.step("<search>More info about France?</search>")
Expand All @@ -371,7 +371,7 @@ def test_search_then_wrong_answer(mock_search_api):
# Step 1: Search
result1 = env.step("<search>Who is the president of France?</search>")
assert not result1["done"]
assert result1["reward"] == 0.0
assert result1["reward"] is None

# Step 2: Wrong answer
result2 = env.step("<answer>Nicolas Sarkozy</answer>")
Expand Down Expand Up @@ -471,6 +471,6 @@ def test_tool_execution_exception(search_env):

# Should handle the exception gracefully
assert not result["done"]
assert result["reward"] == 0.0
assert result["reward"] is None
assert len(result["observations"]) == 1
assert "Tool execution failed" in result["observations"][0]["content"]
2 changes: 1 addition & 1 deletion skyrl-gym/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_compute_score(mock_db_file, mock_sqlite_connection, step_1_output, step
reward = output["reward"]

# intermediate step reward is None (previously 0)
assert reward == 0.0
assert reward is None
# check reminder message
assert reminder_text in obs1[0]["content"]
if "<sql>" not in step_1_output:
Expand Down
2 changes: 1 addition & 1 deletion skyrl-train/docs/tutorials/new_env.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ SkyRL-Gym includes a simple text-in/text-out environment interface for LLM tasks
Returns:
BaseTextEnvStepOutput containing:
- observations: New messages from the environment
- reward: Float reward for the action
- reward: Optional[Float] reward for the action, None if intermediate steps have no reward
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with Python's type hinting syntax, it would be clearer to use Optional[float] instead of Optional[Float].

Suggested change
- reward: Optional[Float] reward for the action, None if intermediate steps have no reward
- reward: Optional[float] reward for the action, None if intermediate steps have no reward

- done: Whether the episode is finished
- metadata: Additional info (optional)
"""
Expand Down
2 changes: 2 additions & 0 deletions skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ async def agent_loop(

if len(input_ids) > max_input_length:
stop_reason = "length"
if per_step_rewards[-1][0] is None:
per_step_rewards[-1] = (0.0, per_step_rewards[-1][1])
break

await self._run_in_executor_if_available(env.close)
Expand Down