Fix handling of f_divergence_type in DPO #4171
Merged
+25
−5
Fix handling of `f_divergence_type` in DPO.

This PR improves the handling of the `f_divergence_type` configuration in the DPO trainer by consistently using the `FDivergenceType` enum internally, while allowing flexibility in input types. It also adds new tests to ensure correct normalization and serialization of the `f_divergence_type` field.

Note that currently:
- `f_divergence_type` in `DPOConfig` is a `str`, since the trainer logic compares its value against the string representations (`.value`) of the `FDivergenceType` enum (`trl/trainer/dpo_trainer.py`, line 1068 and line 1090 in 864e593).
- Its default value is `FDivergenceType.REVERSE_KL`, of `FDivergenceType` enum type (`trl/trainer/dpo_config.py`, lines 399 to 400 in 864e593).
This inconsistency can lead to confusion and potential type mismatches during usage. The proposed changes standardize the handling of `f_divergence_type`, ensuring consistent type normalization and comparison throughout the codebase.
Follow-up to:
Changes

Config normalization and input handling:
- Changed the `f_divergence_type` field in `DPOConfig` to accept both `FDivergenceType` enum members and strings, improving flexibility for users. (`trl/trainer/dpo_config.py`)
- Updated the `__post_init__` method of `DPOConfig` to always convert `f_divergence_type` to an `FDivergenceType` enum member, ensuring consistent internal usage. (`trl/trainer/dpo_config.py`)

Loss function logic update:
- Updated the `dpo_loss` function to compare `f_divergence_type` directly with enum members instead of their string values, leveraging the normalized config.

Testing improvements:
- Added tests in `TestDPOConfig` to verify normalization and serialization of `f_divergence_type`, including parameterized tests for both enum and string inputs. (`tests/test_dpo_trainer.py`)
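The loss-logic change above can be illustrated with a hypothetical branch selector (`select_divergence_branch` is not TRL's API; the enum values are illustrative). The point is that once the config guarantees an enum member, comparisons no longer need `.value`:

```python
from enum import Enum


class FDivergenceType(Enum):
    # Illustrative values; the real enum lives in TRL
    REVERSE_KL = "reverse_kl"
    JS_DIVERGENCE = "js_divergence"
    ALPHA_DIVERGENCE = "alpha_divergence"


def select_divergence_branch(f_divergence_type: FDivergenceType) -> str:
    # Before: comparisons against string values, e.g.
    #   if f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value: ...
    # After: compare directly against enum members, relying on the
    # config having normalized the field in __post_init__
    if f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE:
        return "alpha"
    elif f_divergence_type == FDivergenceType.JS_DIVERGENCE:
        return "js"
    return "reverse_kl"


assert select_divergence_branch(FDivergenceType.JS_DIVERGENCE) == "js"
```

Comparing enum members also plays well with serialization: writing out `f_divergence_type.value` yields the plain string, and round-tripping it through `FDivergenceType(...)` recovers the same member, which is what the new parameterized tests exercise.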