diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py
index 825592a001bb48..474596c4f44893 100755
--- a/examples/pytorch/language-modeling/run_mlm.py
+++ b/examples/pytorch/language-modeling/run_mlm.py
@@ -32,6 +32,7 @@
 import datasets
 import evaluate
+import torch
 from datasets import load_dataset
 
 import transformers
@@ -133,6 +134,16 @@ class ModelArguments:
             )
         },
     )
+    torch_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
+                "dtype will be automatically derived from the model's weights."
+            ),
+            "choices": ["auto", "bfloat16", "float16", "float32"],
+        },
+    )
     low_cpu_mem_usage: bool = field(
         default=False,
         metadata={
@@ -425,6 +436,11 @@ def main():
         )
 
     if model_args.model_name_or_path:
+        torch_dtype = (
+            model_args.torch_dtype
+            if model_args.torch_dtype in ["auto", None]
+            else getattr(torch, model_args.torch_dtype)
+        )
         model = AutoModelForMaskedLM.from_pretrained(
             model_args.model_name_or_path,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -433,6 +449,7 @@
             revision=model_args.model_revision,
             token=model_args.token,
             trust_remote_code=model_args.trust_remote_code,
+            torch_dtype=torch_dtype,
             low_cpu_mem_usage=model_args.low_cpu_mem_usage,
         )
     else:
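
For context, here is a minimal standalone sketch of the dtype-resolution pattern this diff adds. The helper name `resolve_torch_dtype` and the `bert-base-uncased` checkpoint are illustrative placeholders, not part of the change itself:

```python
import torch
from transformers import AutoModelForMaskedLM


def resolve_torch_dtype(name):
    """Map the CLI string to a value `from_pretrained` accepts.

    "auto" and None are passed through unchanged; any other allowed
    choice (e.g. "bfloat16") is looked up as an attribute of the torch
    module, yielding the corresponding torch.dtype object.
    """
    return name if name in ["auto", None] else getattr(torch, name)


# Illustrative usage; the checkpoint name is a placeholder.
model = AutoModelForMaskedLM.from_pretrained(
    "bert-base-uncased",
    torch_dtype=resolve_torch_dtype("bfloat16"),  # weights load as torch.bfloat16
)
```

With the change applied, the same behavior is reached from the command line via the new flag, e.g. `python run_mlm.py --torch_dtype bfloat16 ...` alongside the script's other arguments; passing `--torch_dtype auto` instead derives the dtype from the model's weights, and omitting the flag keeps the default loading behavior.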