diff --git a/tests/buffer/reward_shaping_mapper_test.py b/tests/buffer/reward_shaping_mapper_test.py
new file mode 100644
index 0000000000..d44122f649
--- /dev/null
+++ b/tests/buffer/reward_shaping_mapper_test.py
@@ -0,0 +1,74 @@
+import unittest
+from copy import deepcopy
+from typing import List
+
+import torch
+
+from trinity.buffer.pipelines.experience_pipeline import ExperienceOperator
+from trinity.common.config import OperatorConfig
+from trinity.common.experience import EID, Experience
+
+
+def get_experiences(task_num: int, repeat_times: int = 1, step_num: int = 1) -> List[Experience]:
+    """Generate a list of experiences for testing."""
+    return [
+        Experience(
+            eid=EID(task=i, run=j, step=k),
+            tokens=torch.zeros((5,)),
+            prompt_length=4,
+            reward=j,
+            logprobs=torch.tensor([0.1]),
+            info={
+                "llm_quality_score": i,
+                "llm_difficulty_score": k,
+            },
+        )
+        for i in range(task_num)
+        for j in range(repeat_times)
+        for k in range(step_num)
+    ]
+
+
+class TestRewardShapingMapper(unittest.TestCase):
+    def test_basic_usage(self):
+        # test input cache
+        op_configs = [
+            OperatorConfig(
+                name="reward_shaping_mapper",
+                args={
+                    "reward_shaping_configs": [
+                        {
+                            "stats_key": "llm_quality_score",
+                            "op_type": "ADD",
+                            "weight": 1.0,
+                        },
+                        {
+                            "stats_key": "llm_difficulty_score",
+                            "op_type": "MUL",
+                            "weight": 0.5,
+                        },
+                    ]
+                },
+            )
+        ]
+        ops = ExperienceOperator.create_operators(op_configs)
+        self.assertEqual(len(ops), 1)
+
+        op = ops[0]
+        task_num = 8
+        repeat_times = 4
+        step_num = 2
+        experiences = get_experiences(
+            task_num=task_num, repeat_times=repeat_times, step_num=step_num
+        )
+        res_exps, metrics = op.process(deepcopy(experiences))
+        self.assertEqual(len(res_exps), task_num * repeat_times * step_num)
+        self.assertEqual(len(metrics), 0)
+
+        for prev_exp, res_exp in zip(experiences, res_exps):
+            self.assertAlmostEqual(
+                (prev_exp.reward + prev_exp.info["llm_quality_score"])
+                * 0.5
+                * prev_exp.info["llm_difficulty_score"],
+                res_exp.reward,
+            )
diff --git a/trinity/buffer/operators/__init__.py b/trinity/buffer/operators/__init__.py
index b99f484f11..e7e7b0d73e 100644
--- a/trinity/buffer/operators/__init__.py
+++ b/trinity/buffer/operators/__init__.py
@@ -3,10 +3,12 @@
     ExperienceOperator,
 )
 from trinity.buffer.operators.filters.reward_filter import RewardFilter, RewardSTDFilter
+from trinity.buffer.operators.mappers.reward_shaping_mapper import RewardShapingMapper
 
 __all__ = [
     "ExperienceOperator",
     "EXPERIENCE_OPERATORS",
     "RewardFilter",
     "RewardSTDFilter",
+    "RewardShapingMapper",
 ]
diff --git a/trinity/buffer/operators/mappers/__init__.py b/trinity/buffer/operators/mappers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/trinity/buffer/operators/mappers/reward_shaping_mapper.py b/trinity/buffer/operators/mappers/reward_shaping_mapper.py
new file mode 100644
index 0000000000..6e037d08c9
--- /dev/null
+++ b/trinity/buffer/operators/mappers/reward_shaping_mapper.py
@@ -0,0 +1,91 @@
+from typing import Dict, List, Optional, Tuple
+
+from trinity.buffer.operators import EXPERIENCE_OPERATORS, ExperienceOperator
+from trinity.common.constants import OpType
+from trinity.common.experience import Experience
+
+
+@EXPERIENCE_OPERATORS.register_module("reward_shaping_mapper")
+class RewardShapingMapper(ExperienceOperator):
+    """Re-shaping the existing rewards of experiences based on rules or other advanced methods.
+
+    Note:
+        This mapper assumes that the reward is already calculated and stored in the Experience object,
+        and the necessary stats are already calculated and stored in the Experience info field.
+    """
+
+    def __init__(self, reward_shaping_configs: Optional[List[Dict]] = None):
+        """Initializes the RewardShapingMapper.
+
+        Args:
+            reward_shaping_configs (list[dict], optional): A list of dictionaries containing reward shaping
+                configurations. Each dictionary should include the following keys:
+
+                - stats_key (str): The field key name of the target stats used to shape the reward.
+                - op_type (str): The type of operator to apply between the reward and the target stats.
+                    Should be one of {"ADD", "SUB", "MUL", "DIV"}.
+                - weight (float): The weight for the target stats.
+
+        Example:
+            [
+                {
+                    "stats_key": "llm_quality_score",
+                    "op_type": "ADD",
+                    "weight": 1.0,
+                }
+            ]
+        """
+        if reward_shaping_configs is None:
+            reward_shaping_configs = []
+        self.reward_shaping_configs = reward_shaping_configs
+
+    def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
+        res_exps = []
+        for exp in exps:
+            # skip experiences that don't have a reward
+            if exp.reward is None:
+                continue
+            res_exp = exp
+            for reward_shaping_config in self.reward_shaping_configs:
+                res_exp = self._reward_shaping_single(res_exp, reward_shaping_config)
+            res_exps.append(res_exp)
+        return res_exps, {}
+
+    def _reward_shaping_single(self, exp: Experience, reward_shaping_config: Dict) -> Experience:
+        """Re-shapes the existing reward of one experience based on the given reward_shaping_config.
+
+        Args:
+            exp (Experience): The experience object whose reward is to be reshaped.
+            reward_shaping_config (dict): A dictionary containing the reward shaping configuration.
+                It should include the following keys:
+                - stats_key (str): The field key name of the target stats used to shape the reward.
+                - op_type (str): The type of operator to apply between the reward and the target stats.
+                    Should be one of {"ADD", "SUB", "MUL", "DIV"}.
+                - weight (float): The weight for the target stats.
+
+        Returns:
+            Experience: The experience object with the reshaped reward.
+        """
+        tgt_stats = reward_shaping_config.get("stats_key", None)
+        op_type = OpType[reward_shaping_config.get("op_type", "ADD")]
+        weight = reward_shaping_config.get("weight", 1.0)
+        # if the target stats is not specified, skip this config and return the original experience
+        if tgt_stats is None:
+            return exp
+        exp_info = exp.info
+        if exp_info is None or len(exp_info) == 0:
+            return exp
+        # if the target stats does not exist in the exp info, skip this config and return the original experience
+        if tgt_stats not in exp_info:
+            return exp
+        if op_type == OpType.ADD:
+            exp.reward += weight * exp_info[tgt_stats]
+        elif op_type == OpType.MUL:
+            exp.reward *= weight * exp_info[tgt_stats]
+        elif op_type == OpType.SUB:
+            exp.reward -= weight * exp_info[tgt_stats]
+        elif op_type == OpType.DIV:
+            divisor = weight * exp_info[tgt_stats]
+            if divisor != 0:  # avoid division by zero; leave the reward unchanged in that case
+                exp.reward /= divisor
+        return exp
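
Reviewer note: below is a minimal standalone sketch (not part of the diff) of how the two shaping steps compose on a single experience. It reuses the Experience/EID constructor arguments from the test above and the config format documented in the mapper's docstring; the concrete field values are illustrative only.

    import torch

    from trinity.buffer.operators.mappers.reward_shaping_mapper import RewardShapingMapper
    from trinity.common.experience import EID, Experience

    # Illustrative values; constructor arguments mirror the test above.
    exp = Experience(
        eid=EID(task=0, run=0, step=0),
        tokens=torch.zeros((5,)),
        prompt_length=4,
        reward=2.0,
        logprobs=torch.tensor([0.1]),
        info={"llm_quality_score": 3.0, "llm_difficulty_score": 4.0},
    )

    mapper = RewardShapingMapper(
        reward_shaping_configs=[
            {"stats_key": "llm_quality_score", "op_type": "ADD", "weight": 1.0},
            {"stats_key": "llm_difficulty_score", "op_type": "MUL", "weight": 0.5},
        ]
    )

    shaped, metrics = mapper.process([exp])
    # ADD step: 2.0 + 1.0 * 3.0 = 5.0; MUL step: 5.0 * (0.5 * 4.0) = 10.0
    assert shaped[0].reward == 10.0 and metrics == {}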