
Conversation

@albertvillanova (Member)

Fix handling of f_divergence_type in DPO.

This PR improves the handling of the f_divergence_type configuration in the DPO trainer by consistently using the FDivergenceType enum internally, while allowing flexibility in input types. It also adds new tests to ensure correct normalization and serialization of the f_divergence_type field.

Note that currently:

  • The expected type of the provided f_divergence_type in DPOConfig is a str, since the trainer logic compares its value against the string representations (.value) of the FDivergenceType enum:
    if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
    if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
  • However, if no value is explicitly passed, the config defaults to FDivergenceType.REVERSE_KL, which is an FDivergenceType enum member:
    f_divergence_type: FDivergenceType = field(
        default=FDivergenceType.REVERSE_KL,
        ...
    )

This inconsistency can lead to confusion and potential type mismatches during usage. The proposed changes aim to standardize the handling of f_divergence_type, ensuring consistent type normalization and comparison throughout the codebase.

Follow-up to:

Changes

Config normalization and input handling:

  • Updated the f_divergence_type field in DPOConfig to accept both FDivergenceType enum members and strings, improving flexibility for users. (trl/trainer/dpo_config.py)
  • Added normalization in the __post_init__ method of DPOConfig to always convert f_divergence_type to an FDivergenceType enum member, ensuring consistent internal usage; see the sketch after this list. (trl/trainer/dpo_config.py)
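
A minimal sketch of the normalization, assuming the enum values used by TRL; the field metadata and surrounding code in the merged PR may differ:

    from dataclasses import dataclass, field
    from enum import Enum
    from typing import Union

    class FDivergenceType(Enum):
        REVERSE_KL = "reverse_kl"
        JS_DIVERGENCE = "js_divergence"
        ALPHA_DIVERGENCE = "alpha_divergence"

    @dataclass
    class DPOConfig:
        # Accepts either an enum member or its string value.
        f_divergence_type: Union[FDivergenceType, str] = field(
            default=FDivergenceType.REVERSE_KL,
        )

        def __post_init__(self):
            # FDivergenceType(x) returns x unchanged when x is already a member,
            # and looks the member up by value when x is a string such as "js_divergence".
            self.f_divergence_type = FDivergenceType(self.f_divergence_type)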

Loss function logic update:

  • Modified checks in the dpo_loss function to compare f_divergence_type directly with enum members instead of their string values, leveraging the normalized config; a short illustrative excerpt follows below.
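
Continuing the sketch above, the comparisons reduce to direct enum checks; this hypothetical helper stands in for the relevant branch of dpo_loss, not the actual loss code:

    def f_divergence_branch(config: DPOConfig) -> str:
        # No .value needed: __post_init__ guarantees an FDivergenceType member.
        if config.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE:
            return "alpha"
        elif config.f_divergence_type == FDivergenceType.JS_DIVERGENCE:
            return "js"
        return "reverse_kl"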

Testing improvements:

  • Added a new test class TestDPOConfig to verify normalization and serialization of f_divergence_type, including parameterized tests for both enum and string inputs; see the sketch below. (tests/test_dpo_trainer.py)
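
A hedged sketch of such a parameterized test; the test name and the import location of FDivergenceType are assumptions, not copied from the merged tests:

    import pytest

    from trl import DPOConfig
    from trl.trainer.dpo_config import FDivergenceType  # import path assumed

    @pytest.mark.parametrize("value", [FDivergenceType.JS_DIVERGENCE, "js_divergence"])
    def test_f_divergence_type_is_normalized(value, tmp_path):
        config = DPOConfig(output_dir=str(tmp_path), f_divergence_type=value)
        # Enum and string inputs should normalize to the same enum member.
        assert config.f_divergence_type == FDivergenceType.JS_DIVERGENCE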

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec (Member)

I was worried that it wouldn't work with the CLI, but it seems to work:

trl dpo \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name anthropic/hh-rlhf \
    --f_divergence_type reverse_kl

qgallouedec and others added 3 commits September 30, 2025 17:51
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
@albertvillanova albertvillanova merged commit 5a4021f into huggingface:main Oct 1, 2025
9 of 10 checks passed