@rdyro I am testing out the DPO branch, and I am currently facing these two issues with my DPO training:
1. Evaluation is not running with HF datasets. Setting eval_interval: 20 in base.yml simply causes training to stop after finishing evaluation.
2. Unable to run with per_device_batch_size < 1 (it works with 1 but fails with 0.5, as in the command below).
The following command reproduces the error when training a Llama 3.1 70B on a TPU v5-256:
# base_output_directory: change to a valid bucket
# per_device_batch_size=0.5: works with 1
# tokenizer_path: the original tokenizer is gated; this one is open
# max_target_length=128: way too short, but does not give OOM when training with per_device_batch_size=1
python3 MaxText/train.py MaxText/configs/dpo.yml \
base_output_directory='gs://mybucket' \
per_device_batch_size=0.5 \
tokenizer_path='north/llama3.1-8b-instruct-reference' \
max_target_length=128 \
load_parameters_path='gs://maxtext-public-test/nb-llama-3.1-70B-sft/checkpoints/900/items' \
steps=181 \
checkpoint_period=180 \
run_name='llama3.1-70B_dpo_helpfulandharmless_test1' \
model_name='llama3.1-70b' \
enable_checkpointing=True \
async_checkpointing=True \
dataset_type='hf' \
hf_path='json' \
hf_train_files='gs://maxtext-public-test/hh-rlhf-helpful-and-harmless/train*.jsonl' \
remat_policy='minimal' \
attention='flash' \
warmup_steps_fraction=0.1 \
hf_eval_split='' \
hf_eval_files='gs://maxtext-public-test/hh-rlhf-helpful-and-harmless/test*.jsonl' \
eval_steps=1 \
allow_split_physical_axes=True \
ici_tensor_parallelism=8 \
use_dpo=True \
dpo_reference_params_path=''
Both the model and the training set are open; I am hosting them temporarily in a public bucket, which will be shut down once you have tried replicating this.
It seems like it is defaulting to non-DPO training, but I have been unable to figure out where and why. I did, however, notice that global_batch_size_to_train_on and global_batch_size_to_load differ when per_device_batch_size < 1. Maybe this mismatch creates a shape mismatch between the chosen and rejected pairs?
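Here is a minimal back-of-the-envelope sketch of what I mean (not MaxText code; the ceil-based rounding rule for global_batch_size_to_load is only my assumption about why the two values could diverge):

import math

# Hypothetical sketch: how a fractional per_device_batch_size could make the
# two global batch sizes diverge if the loader rounds up to a whole example
# per device while the train step uses the exact product.
num_devices = 256                 # TPU v5-256
per_device_batch_size = 0.5       # the failing setting from the command above

global_batch_size_to_train_on = int(num_devices * per_device_batch_size)    # 128
global_batch_size_to_load = num_devices * math.ceil(per_device_batch_size)  # 256 (assumed rounding)

print(global_batch_size_to_train_on, global_batch_size_to_load)

# If DPO stacks chosen and rejected sequences along the batch axis and later
# splits that axis back into the two halves, a 128-vs-256 mismatch between the
# loaded and trained batch would surface as exactly this kind of shape error.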