|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward |
| 15 | +from trl.rewards import accuracy_reward, get_soft_overlong_punishment, reasoning_accuracy_reward, think_format_reward |
16 | 16 |
|
17 | 17 | from .testing_utils import TrlTestCase, require_math_latex |
18 | 18 |
|
@@ -117,27 +117,76 @@ def test_accuracy_reward_unparseable_gold(self): |
117 | 117 | """Test accuracy_reward with an unparseable gold solution.""" |
118 | 118 | completion = [ |
119 | 119 | [{"content": "Answer is forty two."}], |
120 | | - [{"content": "Some other content."}], |
121 | | - [{"content": r"Answer is \boxed{42}."}], |
122 | | - [{"content": r"Answer is \boxed{\mathbf{42}}."}], # Make response bold |
123 | | - [{"content": r"Answer is \boxed{\textbf{42}}."}], # Different latex command for bold |
124 | | - [{"content": r"Answer is \boxed{42}."}], |
125 | | - [{"content": r"Answer is \boxed{42.3456}."}], |
| 120 | + [{"content": r"Some other content. \boxed{43}."}], |
126 | 121 | ] |
127 | 122 | solution = [ |
128 | 123 | "Answer is forty two.", |
129 | 124 | "Answer is forty three.", |
130 | | - "Answer is 42.0", # Decimal point |
131 | | - "Answer is 42 43 okay?", # Extra space |
132 | | - "Answer is 42", |
133 | | - r"Answer is \n\boxed{42}", # Newline in gold solution |
134 | | - "Answer is 42.34560", # Extra trailing zero |
135 | 125 | ] |
136 | 126 | rewards = accuracy_reward(completion, solution) |
137 | | - assert rewards[0] == 1.0 # Should revert to exact text match |
| 127 | + assert rewards[0] is None |
| 128 | + assert rewards[1] is None |
| 129 | + |
| 130 | + |
| 131 | +class TestReasoningAccuracyReward: |
| 132 | + @require_math_latex |
| 133 | + def test_correct_answer_yields_unit_reward(self): |
| 134 | + completions = [ |
| 135 | + [{"content": r"<think> Reasoning content </think> \boxed{\frac{63}{400}}"}], |
| 136 | + [{"content": r"Reasoning content </think> \boxed{\frac{63}{400}}"}], |
| 137 | + ] |
| 138 | + solutions = [r"\frac{63}{400}", r"\frac{63}{400}"] |
| 139 | + rewards = reasoning_accuracy_reward(completions, solutions) |
| 140 | + assert rewards[0] == 1.0 |
| 141 | + assert rewards[1] == 1.0 |
| 142 | + |
| 143 | + @require_math_latex |
| 144 | + def test_correct_answer_with_custom_tags_yields_unit_reward(self): |
| 145 | + completions = [ |
| 146 | + [{"content": r"<REASONING_START> Reasoning content </REASONING_END> \boxed{\frac{63}{400}}"}], |
| 147 | + ] |
| 148 | + solutions = [ |
| 149 | + r"\frac{63}{400}", |
| 150 | + ] |
| 151 | + rewards = reasoning_accuracy_reward(completions, solutions, reasoning_delimiters=["</REASONING_END>"]) |
| 152 | + assert rewards[0] == 1.0 |
| 153 | + |
| 154 | + @require_math_latex |
| 155 | + def test_incorrect_answer_yields_zero_reward(self): |
| 156 | + completion = [[{"content": r"<think> Reasoning content </think> \boxed{\frac{64}{400}}"}]] |
| 157 | + solution = [r"\frac{63}{400}"] |
| 158 | + rewards = reasoning_accuracy_reward(completion, solution) |
| 159 | + assert rewards[0] == 0.0 |
| 160 | + |
| 161 | + @require_math_latex |
| 162 | + def test_correct_answer_in_reasoning_yields_zero_reward(self): |
| 163 | + completions = [ |
| 164 | + [{"content": r"<think> My answer is \boxed{42} </think> Some other text."}], |
| 165 | + [{"content": r"<think> The answer is \boxed{42} </think> Here's a wrong answer: \boxed{43}."}], |
| 166 | + ] |
| 167 | + solutions = [r"\boxed{42}", r"\boxed{42}"] |
| 168 | + rewards = reasoning_accuracy_reward(completions, solutions) |
| 169 | + assert rewards[0] == 0.0 |
| 170 | + assert rewards[1] == 0.0 |
| 171 | + |
| 172 | + @require_math_latex |
| 173 | + def test_incomplete_reasoning_yields_zero_reward(self): |
| 174 | + completions = [ |
| 175 | + [{"content": r"<think> Incomplete reasoning without closing tag"}], |
| 176 | + [{"content": r"Correct answer \frac{63}{400} but completely missing reasoning content"}], |
| 177 | + ] |
| 178 | + solutions = [r"\frac{63}{400}", r"\frac{63}{400}"] |
| 179 | + rewards = reasoning_accuracy_reward(completions, solutions) |
| 180 | + assert rewards[0] == 0.0 |
138 | 181 | assert rewards[1] == 0.0 |
139 | | - assert rewards[2] == 1.0 |
140 | | - assert rewards[3] == 1.0 |
141 | | - assert rewards[4] == 1.0 |
142 | | - assert rewards[5] == 1.0 |
143 | | - assert rewards[6] == 1.0 # Should ignore trailing zeros |
| 182 | + |
| 183 | + @require_math_latex |
| 184 | + def test_unparseable_gold_solution_yields_none_reward(self): |
| 185 | + completions = [ |
| 186 | + [{"content": r"<think> Reasoning content </think> \boxed{42}"}], |
| 187 | + ] |
| 188 | + solutions = [ |
| 189 | + "forty two", |
| 190 | + ] |
| 191 | + rewards = reasoning_accuracy_reward(completions, solutions) |
| 192 | + assert rewards[0] is None |
0 commit comments