Skip to content

Commit

Permalink
Add support for torch_dtype in the run_mlm example (#29776)
Browse files Browse the repository at this point in the history
feat: add support for torch_dtype

Co-authored-by: Jacky Lee <jackylee328@gmail.com>
  • Loading branch information
jla524 and jackylee328 authored Mar 21, 2024
1 parent 10d232e commit ef6e371
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions examples/pytorch/language-modeling/run_mlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import datasets
import evaluate
import torch
from datasets import load_dataset

import transformers
Expand Down Expand Up @@ -133,6 +134,16 @@ class ModelArguments:
)
},
)
torch_dtype: Optional[str] = field(
default=None,
metadata={
"help": (
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
"dtype will be automatically derived from the model's weights."
),
"choices": ["auto", "bfloat16", "float16", "float32"],
},
)
low_cpu_mem_usage: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -425,6 +436,11 @@ def main():
)

if model_args.model_name_or_path:
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
model = AutoModelForMaskedLM.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
Expand All @@ -433,6 +449,7 @@ def main():
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
torch_dtype=torch_dtype,
low_cpu_mem_usage=model_args.low_cpu_mem_usage,
)
else:
Expand Down

0 comments on commit ef6e371

Please sign in to comment.