Skip to content

Commit 58f6807

Browse files
committed
Update
[ghstack-poisoned]
1 parent 675f0d9 commit 58f6807

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

test/llm/test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import torch
1414
from tensordict import lazy_stack, set_list_to_stack, TensorDict
1515

16-
from torchrl import torchrl_logger
16+
from torchrl import logger as torchrl_logger
1717

1818
from torchrl.data import (
1919
History,

test/llm/test_envs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def test_python_interpreter_single_batch(self):
375375
"```python\n"
376376
"print(1 + 1)\n"
377377
"```<|im_end|>\n"
378-
" <|im_start|>user\n"
378+
" <|im_start|>tool\n"
379379
"<tool_response>\n"
380380
"Code block 1 executed successfully:\n"
381381
"2\n"
@@ -395,7 +395,7 @@ def test_python_interpreter_single_batch(self):
395395
content="Here is a python code to execute:\n```python\nprint(1 + 1)\n```",
396396
),
397397
History(
398-
role="user",
398+
role="tool",
399399
content="<tool_response>\nCode block 1 executed successfully:\n2\n\n</tool_response>",
400400
tool_responses=["Code block 1 executed successfully:\n2\n"],
401401
),
@@ -478,7 +478,7 @@ def test_python_interpreter_persistent(self):
478478
"```python\n"
479479
"a=1\n"
480480
"```<|im_end|>\n"
481-
" <|im_start|>user\n"
481+
" <|im_start|>tool\n"
482482
"<tool_response>\n"
483483
"Code block 1 executed successfully:\n"
484484
"\n"
@@ -489,7 +489,7 @@ def test_python_interpreter_persistent(self):
489489
"a+=1\n"
490490
"assert a == 2\n"
491491
"```<|im_end|>\n"
492-
" <|im_start|>user\n"
492+
" <|im_start|>tool\n"
493493
"<tool_response>\n"
494494
"Code block 1 executed successfully:\n"
495495
"\n"

torchrl/envs/llm/transforms/tools.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,10 @@ def _step(
530530

531531
procs = []
532532
# Iterate over env batch-size
533-
for i, t in enumerate(local_history.content):
533+
content = local_history.content
534+
if isinstance(content, str):
535+
content = [content]
536+
for i, t in enumerate(content):
534537
results = self._process_llm_response(t, i)
535538
if len(results) == 0:
536539
procs.append(None)

0 commit comments

Comments
 (0)