diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f3fd3e232a0a25..4b0ff838c270cb 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -336,7 +336,7 @@ def __init__(
         self.place_model_on_device = args.place_model_on_device
         if (
             self.is_model_parallel
-            or (args.deepspeed and args.do_train)
+            or args.deepspeed
             or (args.fp16_full_eval and not args.do_train)
             or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
         ):
@@ -954,8 +954,15 @@ def train(
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()

+        args = self.args
+
         self.is_in_train = True

+        # do_train is not a reliable argument, as it might not be set and .train() still called, so
+        # the following is a workaround:
+        if args.fp16_full_eval and not args.do_train:
+            self.model = self.model.to(args.device)
+
         if "model_path" in kwargs:
             resume_from_checkpoint = kwargs.pop("model_path")
             warnings.warn(
@@ -972,7 +979,7 @@ def train(
         model_reloaded = False
         if self.model_init is not None:
             # Seed must be set before instantiating the model when using model_init.
-            set_seed(self.args.seed)
+            set_seed(args.seed)
             self.model = self.call_model_init(trial)
             model_reloaded = True
             # Reinitializes optimizer and scheduler
@@ -980,9 +987,9 @@ def train(

         # Load potential model checkpoint
         if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
-            resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
+            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
             if resume_from_checkpoint is None:
-                raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
+                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

         if resume_from_checkpoint is not None:
             if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
@@ -1003,7 +1010,7 @@ def train(
         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
             if self.place_model_on_device:
-                self.model = self.model.to(self.args.device)
+                self.model = self.model.to(args.device)
             self.model_wrapped = self.model

         # Keeping track whether we can can len() on the dataset or not
@@ -1017,24 +1024,24 @@ def train(
         # number of training steps per epoch: num_update_steps_per_epoch
         # total number of training steps to execute: max_steps
         if train_dataset_is_sized:
-            num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
+            num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
             num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
-            if self.args.max_steps > 0:
-                max_steps = self.args.max_steps
-                num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
-                    self.args.max_steps % num_update_steps_per_epoch > 0
+            if args.max_steps > 0:
+                max_steps = args.max_steps
+                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
+                    args.max_steps % num_update_steps_per_epoch > 0
                 )
             else:
-                max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
-                num_train_epochs = math.ceil(self.args.num_train_epochs)
+                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+                num_train_epochs = math.ceil(args.num_train_epochs)
         else:
             # see __init__. max_steps is set when the dataset has no __len__
-            max_steps = self.args.max_steps
-            num_train_epochs = int(self.args.num_train_epochs)
+            max_steps = args.max_steps
+            num_train_epochs = int(args.num_train_epochs)
             num_update_steps_per_epoch = max_steps

         delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
-        if self.args.deepspeed:
+        if args.deepspeed:
             deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
                 self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
             )
@@ -1068,24 +1075,22 @@ def train(
         # Train!
         if is_torch_tpu_available():
             world_size = xm.xrt_world_size()
-        elif self.args.local_rank != -1:
+        elif args.local_rank != -1:
             world_size = dist.get_world_size()
         else:
             world_size = 1

-        total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps * world_size
+        total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
         num_examples = (
-            self.num_examples(train_dataloader)
-            if train_dataset_is_sized
-            else total_train_batch_size * self.args.max_steps
+            self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
         )

         logger.info("***** Running training *****")
         logger.info(f"  Num examples = {num_examples}")
         logger.info(f"  Num Epochs = {num_train_epochs}")
-        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
+        logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
         logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
-        logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
+        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
         logger.info(f"  Total optimization steps = {max_steps}")

         self.state.epoch = 0
@@ -1099,16 +1104,16 @@ def train(
         ):
             self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
-            if not self.args.ignore_data_skip:
+            if not args.ignore_data_skip:
                 steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-                steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
+                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
             else:
                 steps_trained_in_current_epoch = 0

             logger.info("  Continuing training from checkpoint, will skip to saved global_step")
             logger.info(f"  Continuing training from epoch {epochs_trained}")
             logger.info(f"  Continuing training from global step {self.state.global_step}")
-            if not self.args.ignore_data_skip:
+            if not args.ignore_data_skip:
                 logger.info(
                     f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
                     "batches in the first epoch."
@@ -1129,17 +1134,17 @@ def train(
         self.state.is_world_process_zero = self.is_world_process_zero()

         # tr_loss is a tensor to avoid synchronization of TPUs through .item()
-        tr_loss = torch.tensor(0.0).to(self.args.device)
+        tr_loss = torch.tensor(0.0).to(args.device)
         # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
         self._total_flos = self.state.total_flos
         model.zero_grad()

-        self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
+        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
-        if not self.args.ignore_data_skip:
+        if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
                 # We just need to begin an iteration to create the randomization of the sampler.
                 for _ in train_dataloader:
@@ -1152,23 +1157,19 @@ def train(
                 train_dataloader.dataset.set_epoch(epoch)

             if is_torch_tpu_available():
-                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
-                    self.args.device
-                )
+                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
                 epoch_iterator = parallel_loader
             else:
                 epoch_iterator = train_dataloader

             # Reset the past mems state at the beginning of each epoch if necessary.
-            if self.args.past_index >= 0:
+            if args.past_index >= 0:
                 self._past = None

             steps_in_epoch = (
-                len(epoch_iterator)
-                if train_dataset_is_sized
-                else self.args.max_steps * self.args.gradient_accumulation_steps
+                len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps
             )
-            self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
+            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

             for step, inputs in enumerate(epoch_iterator):

@@ -1177,13 +1178,13 @@ def train(
                     steps_trained_in_current_epoch -= 1
                     continue

-                if step % self.args.gradient_accumulation_steps == 0:
-                    self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
+                if step % args.gradient_accumulation_steps == 0:
+                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                 if (
-                    ((step + 1) % self.args.gradient_accumulation_steps != 0)
-                    and self.args.local_rank != -1
-                    and self.args._no_sync_in_gradient_accumulation
+                    ((step + 1) % args.gradient_accumulation_steps != 0)
+                    and args.local_rank != -1
+                    and args._no_sync_in_gradient_accumulation
                 ):
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
@@ -1196,13 +1197,13 @@ def train(
                 if self.deepspeed:
                     self.deepspeed.step()

-                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
+                if (step + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
-                    steps_in_epoch <= self.args.gradient_accumulation_steps
+                    steps_in_epoch <= args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
                     # Gradient clipping
-                    if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0 and not self.deepspeed:
+                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
                         # deepspeed does its own clipping

                         if self.use_amp:
@@ -1211,15 +1212,15 @@ def train(

                         if hasattr(self.optimizer, "clip_grad_norm"):
                             # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
-                            self.optimizer.clip_grad_norm(self.args.max_grad_norm)
+                            self.optimizer.clip_grad_norm(args.max_grad_norm)
                         elif hasattr(model, "clip_grad_norm_"):
                             # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
-                            model.clip_grad_norm_(self.args.max_grad_norm)
+                            model.clip_grad_norm_(args.max_grad_norm)
                         else:
                             # Revert to normal clipping otherwise, handling Apex or full precision
                             torch.nn.utils.clip_grad_norm_(
                                 amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
-                                self.args.max_grad_norm,
+                                args.max_grad_norm,
                             )

                     # Optimizer step
@@ -1243,17 +1244,17 @@ def train(
                     model.zero_grad()
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1) / steps_in_epoch
-                    self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
+                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                     self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     break

-            self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
+            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
             self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

-            if self.args.tpu_metrics_debug or self.args.debug:
+            if args.tpu_metrics_debug or args.debug:
                 if is_torch_tpu_available():
                     # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                     xm.master_print(met.metrics_report())
@@ -1265,16 +1266,16 @@ def train(
             if self.control.should_training_stop:
                 break

-        if self.args.past_index and hasattr(self, "_past"):
+        if args.past_index and hasattr(self, "_past"):
             # Clean the state at the end of training
             delattr(self, "_past")

         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
-        if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
+        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
             # Wait for everyone to get here so we are sur the model has been saved by process 0.
             if is_torch_tpu_available():
                 xm.rendezvous("load_best_model_at_end")
-            elif self.args.local_rank != -1:
+            elif args.local_rank != -1:
                 dist.barrier()

             logger.info(
@@ -1283,7 +1284,7 @@ def train(
             if isinstance(self.model, PreTrainedModel):
                 self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
                 if self.place_model_on_device:
-                    self.model = self.model.to(self.args.device)
+                    self.model = self.model.to(args.device)
             else:
                 state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
                 self.model.load_state_dict(state_dict)
@@ -1299,7 +1300,7 @@ def train(
         metrics["total_flos"] = self.state.total_flos
         self.log(metrics)

-        self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
+        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

         # add remaining tr_loss
         self._total_loss_scalar += tr_loss.item()
@@ -1952,7 +1953,7 @@ def evaluation_loop(
         model = self._wrap_model(self.model, training=False)

         # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
-        # ``train`` is running, half it first and then put on device
+        # ``train`` is running, halve it first and then put on device
         if not self.is_in_train and self.args.fp16_full_eval:
             model = model.half().to(self.args.device)

@@ -2288,7 +2289,7 @@ def prediction_loop(
         model = self._wrap_model(self.model, training=False)

         # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
-        # ``train`` is running, half it first and then put on device
+        # ``train`` is running, halve it first and then put on device
         if not self.is_in_train and self.args.fp16_full_eval:
            model = model.half().to(self.args.device)
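
Below is a minimal usage sketch of the eval-only fp16 path this diff targets. It is illustrative, not part of the patch: it assumes a CUDA GPU, a transformers version carrying the change above, and a placeholder checkpoint plus a dummy dataset. With fp16_full_eval=True and do_train=False, __init__ now leaves the fp32 weights off the device so evaluation can halve the model first; if .train() is called anyway, the new guard at the top of train() moves the model back onto the device.

# Usage sketch (illustrative, not part of the diff).
import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


class DummyEvalSet(torch.utils.data.Dataset):
    """16 fake three-token examples, just enough to drive the evaluation loop."""

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return {
            "input_ids": torch.tensor([101, 2000 + idx, 102]),
            "attention_mask": torch.tensor([1, 1, 1]),
            "labels": torch.tensor(idx % 2),
        }


model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

eval_args = TrainingArguments(
    output_dir="tmp_fp16_full_eval",
    do_train=False,              # eval-only run
    fp16_full_eval=True,         # run evaluation with the whole model in fp16
    per_device_eval_batch_size=8,
)

# With the patch, place_model_on_device is False for fp16_full_eval + no do_train,
# so Trainer.__init__ no longer moves the fp32 weights to the GPU eagerly.
trainer = Trainer(model=model, args=eval_args, eval_dataset=DummyEvalSet())

# evaluation_loop()/prediction_loop() halve the model first and only then move it
# to args.device, roughly halving peak GPU memory versus moving fp32 weights first.
print(trainer.evaluate())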