DPO issues #1089

Open
peregilk opened this issue Dec 9, 2024 · 1 comment

peregilk commented Dec 9, 2024

@rdyro I am testing out the DPO branch, and I am currently facing these two issues with my DPO training:

  • Evaluation does not run with HF datasets. Setting eval_interval: 20 in base.yml simply causes training to stop after evaluation finishes.

  • Training fails with per_device_batch_size < 1 (it runs with per_device_batch_size=1).

The following command reproduces the error when training a Llama 3.1 70B on a TPU v5-256:

# Notes: base_output_directory must be changed to a valid bucket; per_device_batch_size=0.5
# fails while 1 works; the original tokenizer is gated, so an open reference tokenizer is
# used instead; max_target_length=128 is far too short, but does not OOM when running
# with per_device_batch_size=1.
python3 MaxText/train.py MaxText/configs/dpo.yml \
    base_output_directory='gs://mybucket' \
    per_device_batch_size=0.5 \
    tokenizer_path='north/llama3.1-8b-instruct-reference' \
    max_target_length=128 \
    load_parameters_path='gs://maxtext-public-test/nb-llama-3.1-70B-sft/checkpoints/900/items' \
    steps=181 \
    checkpoint_period=180 \
    run_name='llama3.1-70B_dpo_helpfulandharmless_test1' \
    model_name='llama3.1-70b' \
    enable_checkpointing=True \
    async_checkpointing=True \
    dataset_type='hf' \
    hf_path='json' \
    hf_train_files='gs://maxtext-public-test/hh-rlhf-helpful-and-harmless/train*.jsonl' \
    remat_policy='minimal' \
    attention='flash' \
    warmup_steps_fraction=0.1 \
    hf_eval_split='' \
    hf_eval_files='gs://maxtext-public-test/hh-rlhf-helpful-and-harmless/test*.jsonl' \
    eval_steps=1 \
    allow_split_physical_axes=True \
    ici_tensor_parallelism=8 \
    use_dpo=True \
    dpo_reference_params_path=''

Both the model and the training set are open. I am hosting them temporarily in a public bucket, which will be shut down once you have had a chance to replicate this.

It seems to be falling back to non-DPO training, but I have not been able to figure out where or why. I did, however, notice that global_batch_size_to_train_on and global_batch_size_to_load differ when per_device_batch_size < 1. Could this mismatch produce a shape mismatch between the chosen and rejected pairs?
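To make the suspected failure mode concrete, here is a toy sketch. It is not MaxText code: the split-in-half pairing convention and the batch sizes 128/256 are illustrative assumptions only.

import jax.numpy as jnp

def chosen_minus_rejected(per_example_logps):
    # Assumed convention: split [2*B] per-example log-probs into chosen/rejected halves.
    chosen, rejected = jnp.split(per_example_logps, 2, axis=0)
    return chosen - rejected

# Batch built from global_batch_size_to_train_on: pairs line up as expected.
train_logps = jnp.zeros((128,))
print(chosen_minus_rejected(train_logps).shape)   # (64,)

# Batch built from global_batch_size_to_load instead: if the two values disagree
# (as observed when per_device_batch_size < 1), every downstream shape derived from
# the train-on size no longer matches.
load_logps = jnp.zeros((256,))
print(chosen_minus_rejected(load_logps).shape)    # (128,), not the expected (64,)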

peregilk mentioned this issue in a pull request on Dec 9, 2024 (since merged).
peregilk (Author) commented

@rdyro Did you have a chance to test if you could reproduce this error?

rdyro self-assigned this issue on Jan 14, 2025.