From 9c9de38b54f7a0babac68060b01b950dc6c81e10 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Mon, 4 Nov 2024 07:47:34 -0500 Subject: [PATCH] Update trainer for easier handling of accumulate, compile fixes, and proper reporting (#34511) * Update trainer for easier handling of accumulate + proper reporting * test * Fixup tests * Full fix * Fix style * rm comment * Fix tests * Minimize test + remove py 311 check * Unused import * Forward contrib credits from discussions * Fix reported metrics * Refactor, good as it's going to get * rm pad tok id check * object detection and audio are being annoying * Fin * Fin x2 --------- Co-authored-by: Gyanateet Dutta --- src/transformers/modeling_utils.py | 3 +- src/transformers/trainer.py | 73 ++++++++++++++++-------------- tests/trainer/test_trainer.py | 43 +++++++++++++----- 3 files changed, 71 insertions(+), 48 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8481fa7df9cd96..2ef4c3615c9fa2 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -28,7 +28,7 @@ import warnings from contextlib import contextmanager from dataclasses import dataclass -from functools import lru_cache, partial, wraps +from functools import partial, wraps from threading import Thread from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from zipfile import is_zipfile @@ -5014,7 +5014,6 @@ def _is_quantized_training_enabled(self): return self.hf_quantizer.is_trainable @property - @lru_cache def loss_function(self): if getattr(self.config, "loss_type", None) is not None: loss_type = self.config.loss_type diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 30caa2de260cb7..d41b7181be6334 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -233,7 +233,6 @@ from accelerate.utils import ( DistributedDataParallelKwargs, DistributedType, - GradientAccumulationPlugin, load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, @@ -601,8 +600,10 @@ def __init__( if not _is_peft_model(unwrapped_model) else unwrapped_model.get_base_model().forward ) - - self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters + forward_params = inspect.signature(model_forward).parameters + self.model_accepts_loss_kwargs = ( + "loss_kwargs" in forward_params and forward_params["loss_kwargs"].kind == inspect.Parameter.VAR_KEYWORD + ) self.neftune_noise_alpha = args.neftune_noise_alpha @@ -2444,7 +2445,7 @@ def _inner_training_loop( update_step += 1 num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) - for inputs in batch_samples: + for i, inputs in enumerate(batch_samples): step += 1 do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch # Since we perform prefetching, we need to manually set sync_gradients @@ -2484,7 +2485,13 @@ def _inner_training_loop( if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - with self.accelerator.accumulate(model): + # We explicitly want to avoid relying on `accelerator.accumulate` for generation training + context = ( + functools.partial(self.accelerator.no_sync, model=model) + if i == len(batch_samples) - 1 + else contextlib.nullcontext + ) + with context(): tr_loss_step = self.training_step(model, inputs, num_items_in_batch) if ( @@ -3636,15 +3643,11 @@ def training_step( with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: - if num_items_in_batch is not None: - if self.compute_loss_func or self.model_accepts_loss_kwargs: - loss *= self.args.gradient_accumulation_steps - # Average tokens across devices is orthogonal to gradient accumulation - if self.args.average_tokens_across_devices: - loss *= self.args.world_size self.accelerator.backward(loss, **kwargs) - - return loss.detach() / self.args.gradient_accumulation_steps + # Finally we need to normalize the loss for reporting + if num_items_in_batch is None: + return loss.detach() / self.args.gradient_accumulation_steps + return loss.detach() def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ @@ -3656,9 +3659,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N labels = inputs.pop("labels") else: labels = None - if self.args.average_tokens_across_devices and num_items_in_batch is not None: - num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device) - num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu()) if self.model_accepts_loss_kwargs: loss_kwargs = {} if num_items_in_batch is not None: @@ -3692,6 +3692,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # We don't use .loss here since the model may return tuples instead of ModelOutput. loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: + loss *= self.accelerator.num_processes + return (loss, outputs) if return_outputs else loss def is_local_process_zero(self) -> bool: @@ -4946,24 +4949,21 @@ def _add_sm_patterns_to_gitignore(self) -> None: self.repo.git_push() def create_accelerator_and_postprocess(self): + # We explicitly don't rely on the `Accelerator` to do gradient accumulation grad_acc_kwargs = {} if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None: grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs # check if num_steps is attempted to be passed in gradient_accumulation_kwargs - if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1: - # raise because we do not know which setting is intended. - raise ValueError( - "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" - "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." - ) - elif "num_steps" not in grad_acc_kwargs: - # take the gradient_accumulation_steps setting from TrainingArguments. - grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps - - grad_acc_kwargs["sync_with_dataloader"] = False - - gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) + if "num_steps" in grad_acc_kwargs: + if self.args.gradient_accumulation_steps > 1: + # raise because we do not know which setting is intended. + raise ValueError( + "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" + "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." + ) + else: + self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"] accelerator_config = self.args.accelerator_config.to_dict() @@ -4994,7 +4994,6 @@ def create_accelerator_and_postprocess(self): args = { "deepspeed_plugin": self.args.deepspeed_plugin, - "gradient_accumulation_plugin": gradient_accumulation_plugin, } if is_accelerate_available("0.28.0"): args["dataloader_config"] = dataloader_config @@ -5090,12 +5089,18 @@ def get_batch_samples(self, epoch_iterator, num_batches): batch_samples += [next(epoch_iterator)] except StopIteration: break + + # Keep default behavior the same + if not self.model_accepts_loss_kwargs: + return batch_samples, None + if len(batch_samples) > 0 and "labels" in batch_samples[0]: # For now we don't support object detection try: - num_items_in_batch = sum( - [data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples] - ) - except TypeError: + num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) + except (TypeError, AttributeError): pass + + if self.args.average_tokens_across_devices: + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() return batch_samples, num_items_in_batch diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b6fe807fa4961a..5658372fa71308 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -272,6 +272,19 @@ def __getitem__(self, i): return {"input_ids": self.x, "labels": self.x} +class SequenceClassificationDataset: + def __init__(self, length=64, vocab_size=100, num_labels=5): + self.length = length + self.sequences = [torch.randint(0, vocab_size, (64,)).tolist() for _ in range(length)] + self.labels = torch.randint(0, num_labels, (length,)).tolist() + + def __len__(self): + return self.length + + def __getitem__(self, i): + return {"input_ids": self.sequences[i], "label": self.labels[i]} + + class DynamicShapesDataset: def __init__(self, length=64, seed=42, batch_size=8): self.length = length @@ -1144,6 +1157,23 @@ def test_number_of_steps_in_training_with_ipex(self): train_output = trainer.train() self.assertEqual(train_output.global_step, 10) + def test_torch_compile_loss_func_compatibility(self): + config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) + tiny_llama = LlamaForCausalLM(config) + + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmp_dir: + args = TrainingArguments( + tmp_dir, + per_device_train_batch_size=2, + torch_compile=True, + max_steps=1, # compile happens on the first step + ) + trainer = Trainer(model=tiny_llama, args=args, train_dataset=train_dataset) # noqa + trainer.train() + @require_peft @require_bitsandbytes def test_bnb_compile(self): @@ -3676,9 +3706,6 @@ def test_accelerator_config_from_dict(self): self.assertEqual(trainer.accelerator.even_batches, False) self.assertEqual(trainer.accelerator.use_seedable_sampler, True) - if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE: - self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True) - def test_accelerator_config_from_yaml(self): # Checks that accelerator kwargs can be passed through # and the accelerator is initialized respectively @@ -3691,8 +3718,6 @@ def test_accelerator_config_from_yaml(self): "even_batches": False, "use_seedable_sampler": False, } - if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE: - accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True} json.dump(accelerator_config, f) config = RegressionModelConfig(a=1.5, b=2.5) model = RegressionPreTrainedModel(config) @@ -3706,9 +3731,6 @@ def test_accelerator_config_from_yaml(self): self.assertEqual(trainer.accelerator.even_batches, False) self.assertEqual(trainer.accelerator.use_seedable_sampler, False) - if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE: - self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True) - def test_accelerator_config_from_dataclass(self): # Checks that accelerator kwargs can be passed through # and the accelerator is initialized respectively @@ -3754,10 +3776,7 @@ def test_accelerate_config_from_dataclass_grad_accum(self): with tempfile.TemporaryDirectory() as tmp_dir: args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config) trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) - self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10) - self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False) - self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False) - self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True) + self.assertEqual(trainer.args.gradient_accumulation_steps, 10) def test_accelerator_config_from_partial(self): # Checks that accelerator kwargs can be passed through