From 8296b1a7a586f4adfbb45fa7b5687436c50ffc17 Mon Sep 17 00:00:00 2001
From: LysandreJik
Date: Thu, 9 Apr 2020 11:01:48 -0400
Subject: [PATCH] Using configuration for `xla_device`

---
 examples/run_glue_tpu.py                 | 14 ++++++--------
 src/transformers/configuration_utils.py  |  3 +++
 src/transformers/modeling_utils.py       |  8 +++-----
 3 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/examples/run_glue_tpu.py b/examples/run_glue_tpu.py
index 893d5fc1c2e6c4..3e3e54e90e429b 100644
--- a/examples/run_glue_tpu.py
+++ b/examples/run_glue_tpu.py
@@ -162,7 +162,7 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
                     # Barrier to wait for saving checkpoint.
                     xm.rendezvous("mid_training_checkpoint")
                     # model.save_pretrained needs to be called by all ordinals
-                    model.save_pretrained(output_dir, xla_device=True)
+                    model.save_pretrained(output_dir)
 
             model.train()
             inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
@@ -416,14 +416,12 @@ def main(args):
         args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
         do_lower_case=args.do_lower_case,
         cache_dir=args.cache_dir if args.cache_dir else None,
-        xla_device=True,
     )
     model = model_class.from_pretrained(
         args.model_name_or_path,
         from_tf=bool(".ckpt" in args.model_name_or_path),
         config=config,
         cache_dir=args.cache_dir if args.cache_dir else None,
-        xla_device=True,
     )
 
     if xm.is_master_ordinal():
@@ -457,17 +455,17 @@ def main(args):
         xm.rendezvous("post_training_checkpoint")
         # model.save_pretrained needs to be called by all ordinals
-        model.save_pretrained(args.output_dir, xla_device=True)
+        model.save_pretrained(args.output_dir)
 
         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir, xla_device=True)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, xla_device=True)
+        model = model_class.from_pretrained(args.output_dir)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
         model.to(args.device)
 
     # Evaluation
     results = {}
     if args.do_eval:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case, xla_device=True)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
             checkpoints = list(
@@ -479,7 +477,7 @@ def main(args):
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
             prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
 
-            model = model_class.from_pretrained(checkpoint, xla_device=True)
+            model = model_class.from_pretrained(checkpoint)
             model.to(args.device)
             result = evaluate(args, model, tokenizer, prefix=prefix, disable_logging=disable_logging)
             result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 67477b76d0c32d..d438960f303da8 100644
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -102,6 +102,9 @@ def __init__(self, **kwargs):
         # task specific arguments
         self.task_specific_params = kwargs.pop("task_specific_params", None)
 
+        # TPU arguments
+        self.xla_device = kwargs.pop("xla_device", None)
+
         # Additional attributes without default values
         for key, value in kwargs.items():
             try:
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6ce9a2f2cfd07d..c4e5c576560c4c 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -320,13 +320,12 @@ def prune_heads(self, heads_to_prune):
         self.base_model._prune_heads(heads_to_prune)
 
-    def save_pretrained(self, save_directory, xla_device=False):
+    def save_pretrained(self, save_directory):
         """ Save a model and its configuration file to a directory, so that it
             can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
 
             Arguments:
                 save_directory: directory to which to save.
-                xla_device: True if saving after training on TPU/XLA.
         """
         assert os.path.isdir(
             save_directory
@@ -341,7 +340,7 @@ def save_pretrained(self, save_directory, xla_device=False):
         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
 
-        if xla_device:
+        if hasattr(self.config, "xla_device") and self.config.xla_device:
             import torch_xla.core.xla_model as xm
 
             if xm.is_master_ordinal():
@@ -435,7 +434,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
-        xla_device = kwargs.pop("xla_device", False)
 
         # Load config if we don't provide a configuration
         if not isinstance(config, PretrainedConfig):
@@ -640,7 +638,7 @@ def load(module: nn.Module, prefix=""):
             }
             return model, loading_info
 
-        if xla_device:
+        if hasattr(config, "xla_device") and config.xla_device:
             import torch_xla.core.xla_model as xm
 
             model = xm.send_cpu_data_to_device(model, xm.xla_device())
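
Below is a minimal usage sketch of the configuration-driven flow introduced by this patch. It is not part of the patch itself: it assumes a TPU host with torch_xla installed, and the model name and output directory are illustrative only.

    # Minimal sketch (assumes torch_xla on a TPU host; names below are illustrative).
    from transformers import BertConfig, BertForSequenceClassification

    # xla_device is now a configuration attribute instead of a kwarg on
    # save_pretrained()/from_pretrained().
    config = BertConfig.from_pretrained("bert-base-uncased", xla_device=True)
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=config)

    # save_pretrained() reads model.config.xla_device and takes the XLA-aware
    # saving path; no extra argument is needed.
    model.save_pretrained("./tpu_checkpoint")

    # Reloading picks xla_device up from the saved config.json and moves the
    # weights to the XLA device via xm.send_cpu_data_to_device().
    model = BertForSequenceClassification.from_pretrained("./tpu_checkpoint")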