1111from pydantic import validate_call
1212from rich import print
1313from rich .logging import RichHandler
14- from sotopia . generation_utils .output_parsers import (
14+ from tiny_chat . generator .output_parsers import (
1515 EnvResponse ,
1616 OutputParser ,
1717 OutputType ,
1818 PydanticOutputParser ,
1919 ScriptOutputParser ,
2020 StrOutputParser ,
2121)
22- from sotopia .utils import format_docstring
22+ from tiny_chat .utils import format_docstring
2323
2424from tiny_chat .messages import ActionType , AgentAction , ScriptBackground
25- from tiny_chat .messages . message_classes import (
25+ from tiny_chat .messages import (
2626 ScriptInteraction ,
2727 ScriptInteractionReturnType ,
2828)
29- from tiny_chat .profile import EnvironmentProfile , RelationshipProfile
29+ from tiny_chat .profile import BaseEnvironmentProfile , BaseRelationshipProfile
3030
3131# Configure logger
3232log = logging .getLogger ('sotopia.generation' )
@@ -187,7 +187,7 @@ async def agenerate_env_profile(
187187 temperature : float = 0.7 ,
188188 bad_output_process_model : str | None = None ,
189189 use_fixed_model_version : bool = True ,
190- ) -> EnvironmentProfile :
190+ ) -> BaseEnvironmentProfile :
191191 """
192192 Using langchain to generate the background
193193 """
@@ -204,7 +204,7 @@ async def agenerate_env_profile(
204204 inspiration_prompt = inspiration_prompt ,
205205 examples = examples ,
206206 ),
207- output_parser = PydanticOutputParser (pydantic_object = EnvironmentProfile ),
207+ output_parser = PydanticOutputParser (pydantic_object = BaseEnvironmentProfile ),
208208 temperature = temperature ,
209209 bad_output_process_model = bad_output_process_model ,
210210 use_fixed_model_version = use_fixed_model_version ,
@@ -217,7 +217,7 @@ async def agenerate_relationship_profile(
217217 agents_profiles : list [str ],
218218 bad_output_process_model : str | None = None ,
219219 use_fixed_model_version : bool = True ,
220- ) -> tuple [RelationshipProfile , str ]:
220+ ) -> tuple [BaseRelationshipProfile , str ]:
221221 """
222222 Using langchain to generate the background
223223 """
@@ -232,7 +232,7 @@ async def agenerate_relationship_profile(
232232 input_values = dict (
233233 agent_profile = agent_profile ,
234234 ),
235- output_parser = PydanticOutputParser (pydantic_object = RelationshipProfile ),
235+ output_parser = PydanticOutputParser (pydantic_object = BaseRelationshipProfile ),
236236 bad_output_process_model = bad_output_process_model ,
237237 use_fixed_model_version = use_fixed_model_version ,
238238 )
0 commit comments