diff --git a/trinity/algorithm/add_strategy/__init__.py b/trinity/algorithm/add_strategy/__init__.py index d1bbc84e1c..51df3ed540 100644 --- a/trinity/algorithm/add_strategy/__init__.py +++ b/trinity/algorithm/add_strategy/__init__.py @@ -1,11 +1,15 @@ from trinity.algorithm.add_strategy.add_strategy import ( ADD_STRATEGY, AddStrategy, + GRPOAddStrategy, + OPMDAddStrategy, RewardVarianceAddStrategy, ) __all__ = [ "ADD_STRATEGY", "AddStrategy", + "GRPOAddStrategy", + "OPMDAddStrategy", "RewardVarianceAddStrategy", ] diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py index c2762ae43c..ea90ec63ec 100644 --- a/trinity/common/workflows/customized_math_workflows.py +++ b/trinity/common/workflows/customized_math_workflows.py @@ -75,7 +75,7 @@ def run(self) -> List[Experience]: else: responses = self.model.generate([prompt_text], **self.rollout_args) - for response in responses: + for run_id, response in enumerate(responses): reward_dict = self.reward_fn( # type: ignore [misc] response=response.response_text, # type: ignore [arg-type] truth=self.truth, @@ -89,6 +89,7 @@ def run(self) -> List[Experience]: response.metrics.update(reward_dict) reward = sum(reward_dict.values()) response.reward = reward + response.eid.run = run_id if not self.use_base: logger.debug( diff --git a/trinity/common/workflows/customized_toolcall_workflows.py b/trinity/common/workflows/customized_toolcall_workflows.py index a3a8a83554..17894b79cc 100644 --- a/trinity/common/workflows/customized_toolcall_workflows.py +++ b/trinity/common/workflows/customized_toolcall_workflows.py @@ -247,7 +247,7 @@ def run(self) -> List[Experience]: logger.debug("start chat") responses = self.model.chat(messages, **self.rollout_args) - for i, response in enumerate(responses): + for run_id, response in enumerate(responses): reward = 0.0 if self.raw_task is not None: @@ -267,5 +267,5 @@ def run(self) -> List[Experience]: f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" ) response.reward = reward - response.eid.run = i + response.eid.run = run_id return responses diff --git a/trinity/common/workflows/math_rm_workflow.py b/trinity/common/workflows/math_rm_workflow.py index f8e1d41720..682ef00cc1 100644 --- a/trinity/common/workflows/math_rm_workflow.py +++ b/trinity/common/workflows/math_rm_workflow.py @@ -36,7 +36,7 @@ def run(self) -> List[Experience]: logger.debug("start chat") responses = self.model.chat(messages, **self.rollout_args) - for response in responses: + for run_id, response in enumerate(responses): reward_dict = self.reward_fn( # type: ignore response, messages, @@ -48,6 +48,7 @@ def run(self) -> List[Experience]: response.metrics.update(reward_dict) reward = sum(reward_dict.values()) response.reward = reward + response.eid.run = run_id logger.debug( f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index ceabdc771a..54301737e5 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -215,7 +215,7 @@ def run(self) -> List[Experience]: logger.debug("start chat") responses = self.model.chat(messages, **self.rollout_args) - for response in responses: + for run_id, response in enumerate(responses): reward_dict = self.reward_fn( # type: ignore [misc] response=response.response_text, # type: ignore [arg-type] truth=self.truth, @@ -226,6 +226,7 @@ def run(self) -> List[Experience]: response.metrics.update(reward_dict) reward = sum(reward_dict.values()) response.reward = reward + response.eid.run = run_id logger.debug( f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"