From ef6e371dba6bec48e9dc4e883d913ef2abb55006 Mon Sep 17 00:00:00 2001
From: Jacky Lee <39754370+jla524@users.noreply.github.com>
Date: Thu, 21 Mar 2024 08:09:35 -0700
Subject: [PATCH] Add support for `torch_dtype` in the run_mlm example (#29776)

feat: add support for torch_dtype

Co-authored-by: Jacky Lee
---
 examples/pytorch/language-modeling/run_mlm.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py
index 825592a001bb48..474596c4f44893 100755
--- a/examples/pytorch/language-modeling/run_mlm.py
+++ b/examples/pytorch/language-modeling/run_mlm.py
@@ -32,6 +32,7 @@
 
 import datasets
 import evaluate
+import torch
 from datasets import load_dataset
 
 import transformers
@@ -133,6 +134,16 @@ class ModelArguments:
             )
         },
     )
+    torch_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
+                "dtype will be automatically derived from the model's weights."
+            ),
+            "choices": ["auto", "bfloat16", "float16", "float32"],
+        },
+    )
     low_cpu_mem_usage: bool = field(
         default=False,
         metadata={
@@ -425,6 +436,11 @@ def main():
         )
 
     if model_args.model_name_or_path:
+        torch_dtype = (
+            model_args.torch_dtype
+            if model_args.torch_dtype in ["auto", None]
+            else getattr(torch, model_args.torch_dtype)
+        )
         model = AutoModelForMaskedLM.from_pretrained(
             model_args.model_name_or_path,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -433,6 +449,7 @@ def main():
             revision=model_args.model_revision,
             token=model_args.token,
             trust_remote_code=model_args.trust_remote_code,
+            torch_dtype=torch_dtype,
             low_cpu_mem_usage=model_args.low_cpu_mem_usage,
         )
     else: