File tree Expand file tree Collapse file tree 3 files changed +9
-6
lines changed
torchrl/envs/llm/transforms Expand file tree Collapse file tree 3 files changed +9
-6
lines changed Original file line number Diff line number Diff line change 1313import torch
1414from tensordict import lazy_stack , set_list_to_stack , TensorDict
1515
16- from torchrl import torchrl_logger
16+ from torchrl import logger as torchrl_logger
1717
1818from torchrl .data import (
1919 History ,
Original file line number Diff line number Diff line change @@ -375,7 +375,7 @@ def test_python_interpreter_single_batch(self):
375375 "```python\n "
376376 "print(1 + 1)\n "
377377 "```<|im_end|>\n "
378- " <|im_start|>user \n "
378+ " <|im_start|>tool \n "
379379 "<tool_response>\n "
380380 "Code block 1 executed successfully:\n "
381381 "2\n "
@@ -395,7 +395,7 @@ def test_python_interpreter_single_batch(self):
395395 content = "Here is a python code to execute:\n ```python\n print(1 + 1)\n ```" ,
396396 ),
397397 History (
398- role = "user " ,
398+ role = "tool " ,
399399 content = "<tool_response>\n Code block 1 executed successfully:\n 2\n \n </tool_response>" ,
400400 tool_responses = ["Code block 1 executed successfully:\n 2\n " ],
401401 ),
@@ -478,7 +478,7 @@ def test_python_interpreter_persistent(self):
478478 "```python\n "
479479 "a=1\n "
480480 "```<|im_end|>\n "
481- " <|im_start|>user \n "
481+ " <|im_start|>tool \n "
482482 "<tool_response>\n "
483483 "Code block 1 executed successfully:\n "
484484 "\n "
@@ -489,7 +489,7 @@ def test_python_interpreter_persistent(self):
489489 "a+=1\n "
490490 "assert a == 2\n "
491491 "```<|im_end|>\n "
492- " <|im_start|>user \n "
492+ " <|im_start|>tool \n "
493493 "<tool_response>\n "
494494 "Code block 1 executed successfully:\n "
495495 "\n "
Original file line number Diff line number Diff line change @@ -530,7 +530,10 @@ def _step(
530530
531531 procs = []
532532 # Iterate over env batch-size
533- for i , t in enumerate (local_history .content ):
533+ content = local_history .content
534+ if isinstance (content , str ):
535+ content = [content ]
536+ for i , t in enumerate (content ):
534537 results = self ._process_llm_response (t , i )
535538 if len (results ) == 0 :
536539 procs .append (None )
You can’t perform that action at this time.
0 commit comments