Reward verification and evaluation fixes #55
Conversation
You can return any value here: the model generates a group of completions for one prompt, these completions are rewarded, and the rewards are then normalised. If all completions get the same reward (because the gold for their common prompt is not parsable here), the normalisation outputs a tensor of 0s, so there is no gradient.
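As a minimal sketch (not the exact TRL implementation; the group size and epsilon value are made up), this is why identical rewards within a group give no learning signal under GRPO's group-wise normalisation:

import torch

# Every completion for this prompt received the same reward.
rewards = torch.tensor([1.0, 1.0, 1.0, 1.0])

# GRPO normalises rewards per group of completions sharing one prompt.
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
print(advantages)  # tensor([0., 0., 0., 0.]) -> no gradient from this prompt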
Nice!!
I approve the GRPO part
One comment about abstracting away how the reward is calculated from the internals of the loop:
        reward = float(verify(answer, parse(sol)))
    except Exception:  # if it fails for any reason, return 0.0
        reward = 0.0
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
I'd move l46-l73 to another method named get_reward
def get_reward(content: str, solution: str) -> float:
    """
    Calculate reward by comparing the parsed content with the ground truth solution.

    Args:
        content (str): The model's output content to evaluate
        solution (str): The ground truth solution to compare against

    Returns:
        float: 1.0 if the content matches the solution, 0.0 otherwise
    """
    gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
    if len(gold_parsed) == 0:
        # Gold not parseable: give every completion the same reward so the
        # prompt contributes no gradient after group normalisation.
        return 1.0
    answer_parsed = parse(
        content,
        extraction_config=[
            LatexExtractionConfig(
                normalization_config=NormalizationConfig(
                    nits=False,
                    malformed_operators=False,
                    basic_latex=True,
                    equations=True,
                    boxed=True,
                    units=True,
                ),
                boxed_match_priority=0,
                try_extract_without_anchor=False,
            )
        ],
        extraction_mode="first_match",
    )
    return float(verify(answer_parsed, gold_parsed))
Then the loop on line 45 changes to:
def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        reward = get_reward(content, sol)
        rewards.append(reward)
    return rewards
This improves readability.
Think it's fine for now, let's move it once we add stuff for code and other rewards.
Thanks for the fix on AIME and improvement to GRPO! Both changes LGTM - feel free to merge with @qgallouedec's suggestion
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
1. The answer from the model must be clearly preceded by "answer is xxx" or be in a \boxed env (with the \boxed env tried first); see the sketch after this list.
2. No normalisation of malformed operators or nits is applied to the resulting LaTeX, meaning the model has to output correct LaTeX.
3. The gold must be in LaTeX format; this ensures that the extracted gold is exactly what we want to compare against.
4. Removed the try/except; it's handled in math-verify now.
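For illustration only (not code from this PR; the sample strings are made up and the config values mirror the snippet above), a minimal sketch of how math-verify behaves under these constraints:

from math_verify import LatexExtractionConfig, parse, verify

# The gold must be LaTeX so the extracted target is exactly what we compare against.
gold = parse(r"$\frac{1}{2}$", extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])

# The \boxed answer is tried first; a bare number in prose would not be extracted
# because try_extract_without_anchor is disabled.
completion = r"The probability is therefore \boxed{\frac{1}{2}}."
answer = parse(
    completion,
    extraction_mode="first_match",
    extraction_config=[LatexExtractionConfig(boxed_match_priority=0, try_extract_without_anchor=False)],
)

# No try/except needed: math-verify handles parse failures internally
# (parse simply returns an empty list when nothing can be extracted).
print(float(verify(answer, gold)))  # expected to print 1.0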
cc @qgallouedec, I am not sure whether returning 1 if the gold is not found is the correct way to do this. It assumes that if all rewards are 1, no gradient change is enforced; is that correct?