Training Reproducibility #142

Open
davedgd opened this issue Dec 20, 2024 · 5 comments

Comments

@davedgd

davedgd commented Dec 20, 2024

Thank you for putting together this awesome model!

I'm currently evaluating it on various benchmarks against other models (e.g., deberta-v3), fine-tuning with AutoModelForSequenceClassification, and I'm running into an issue with model reproducibility. Specifically, using existing code that finds optimal hyperparameters and then refits the final model, I'm unable to make my ModernBERT model reproducible, even though the same code works fine for other model types (e.g., deberta-v3, roberta).

The refitting step uses the Hugging Face-recommended model_init approach for reproducibility:

from transformers import Trainer

# final_args, data_collator, model_init, the tokenized datasets, and
# compute_metrics are all defined earlier in my setup
final_trainer = Trainer(
    args = final_args,
    data_collator = data_collator,
    model_init = model_init,
    train_dataset = dataset_dict_tokenized['train'],
    eval_dataset = dataset_dict_tokenized['val'],
    compute_metrics = compute_metrics
)

However, while this works for all the other models, it doesn't seem to work for ModernBERT. Initially I thought this might be a flash-attn-related issue, but it happens regardless of whether flash-attn is enabled or not.
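
For reference, this is roughly how I was toggling flash-attn on and off while testing (the model name and label count are simplified stand-ins for my actual setup):

from transformers import AutoModelForSequenceClassification

# with flash-attn (requires the flash-attn package to be installed)
model_fa = AutoModelForSequenceClassification.from_pretrained(
    "answerdotai/ModernBERT-base",
    num_labels = 2,
    attn_implementation = "flash_attention_2"
)

# without flash-attn (PyTorch's built-in attention)
model_no_fa = AutoModelForSequenceClassification.from_pretrained(
    "answerdotai/ModernBERT-base",
    num_labels = 2,
    attn_implementation = "sdpa"
)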

Any advice here is much appreciated!

@davedgd

davedgd commented Dec 20, 2024

Ah, I solved my own issue. Unlike the other models, ModernBERT must be loaded with torch_dtype set explicitly, e.g.:

import torch
from transformers import AutoModelForSequenceClassification

def model_init():
    return AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path,
        num_labels = 2,
        torch_dtype = torch.bfloat16  # must be set for reproducibility
    )

I noticed a warning about this from flash-attn that helped diagnose the issue -- sorry for the hasty issue post. This is resolved!
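
For anyone who hits the same thing, here's a rough sketch of the check I used to confirm the fix (same final_args, collator, and datasets as above, details simplified): refitting twice now gives identical eval results.

def refit_and_evaluate():
    trainer = Trainer(
        args = final_args,
        data_collator = data_collator,
        model_init = model_init,
        train_dataset = dataset_dict_tokenized['train'],
        eval_dataset = dataset_dict_tokenized['val'],
        compute_metrics = compute_metrics
    )
    trainer.train()
    return trainer.evaluate()

run_1, run_2 = refit_and_evaluate(), refit_and_evaluate()
print(run_1['eval_loss'] == run_2['eval_loss'])  # True once torch_dtype is set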

@davedgd davedgd closed this as completed Dec 20, 2024
@umarbutler

In their model README, they don't load the model with torch_dtype = torch.bfloat16. Presumably, setting that would cause it to be loaded fully in bfloat16? Their paper doesn't mention that the model is a full bfloat16 model either.

The README does, however, load the fill-mask pipeline with torch_dtype=torch.bfloat16, but perhaps that is for AMP?
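
For reference, the two README patterns I mean look roughly like this (paraphrased, so treat it as approximate):

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

# 1) plain from_pretrained with no torch_dtype -- weights keep their default dtype
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
model = AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-base")

# 2) fill-mask pipeline created with torch_dtype=torch.bfloat16 -- presumably
#    forwarded to from_pretrained, i.e. the whole model in bfloat16?
pipe = pipeline(
    "fill-mask",
    model = "answerdotai/ModernBERT-base",
    torch_dtype = torch.bfloat16
)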

@davedgd

davedgd commented Dec 21, 2024

I went with that for the same reason you pointed out (i.e., the README), as well as because of a message from flash-attn during model load that suggested the same. That said, if there's a better way, it would certainly be good to know.

@umarbutler

I just checked, and it does seem like the pipeline would be loading ModernBERT in full bfloat16, which is inconsistent with the previous example 😆

Now, it's possible they're advising full bfloat16 only for inference. I've raised an issue on their model page seeking clarity on this: https://huggingface.co/answerdotai/ModernBERT-base/discussions/7

Would you mind reopening this issue until we get a conclusive answer?

Re: the flash-attn warning, AFAIK that message will appear for any model. You can't use flash-attn without float16 or bfloat16, but you have the option of AMP or full half precision. The question is what was used to train ModernBERT.

And unfortunately the answer can in fact have some impact on reproducibility and accuracy.

My guess is that it was AMP and not full bfloat16, because AFAIK it is rare to see full bfloat16 used in training: AMP is the norm, and full bfloat16 is considered unstable for training.
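
To make the distinction concrete, this is roughly what I mean by the two options, using the standard transformers knobs (a sketch only, not a claim about what the ModernBERT authors actually did):

import torch
from transformers import AutoModelForSequenceClassification, TrainingArguments

# (a) AMP: master weights stay in float32; the Trainer autocasts the
#     forward/backward pass to bfloat16 via bf16=True
model_amp = AutoModelForSequenceClassification.from_pretrained(
    "answerdotai/ModernBERT-base", num_labels = 2
)
amp_args = TrainingArguments(output_dir = "out-amp", bf16 = True)

# (b) full bfloat16: every parameter is cast to bfloat16 at load time, so the
#     optimizer also runs on bfloat16 weights
model_bf16 = AutoModelForSequenceClassification.from_pretrained(
    "answerdotai/ModernBERT-base", num_labels = 2, torch_dtype = torch.bfloat16
)
bf16_args = TrainingArguments(output_dir = "out-bf16")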

@davedgd davedgd reopened this Dec 21, 2024
@davedgd

davedgd commented Dec 21, 2024

Good call, and thanks for doing the due diligence: reopened!
