diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
index b8d0f312216d11..6bc35994f34abe 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -41,7 +41,7 @@ body:

         Integrations:

-        - deepspeed: HF Trainer: @stas00, Accelerate: @pacman100
+        - deepspeed: HF Trainer/Accelerate: @pacman100
         - ray/raytune: @richardliaw, @amogkam
         - Big Model Inference: @sgugger @muellerzr

diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index e38d5ac9242ea1..4c2ca5752e2987 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -55,7 +55,7 @@ Library:

 Integrations:

-- deepspeed: HF Trainer: @stas00, Accelerate: @pacman100
+- deepspeed: HF Trainer/Accelerate: @pacman100
 - ray/raytune: @richardliaw, @amogkam

 Documentation: @sgugger, @stevhliu and @MKhalusova
diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py
index 3a36db5f3ea54b..7af2bedece84a7 100644
--- a/src/transformers/deepspeed.py
+++ b/src/transformers/deepspeed.py
@@ -17,7 +17,6 @@
 import importlib.util
 import weakref
-from copy import deepcopy
 from functools import partialmethod

 from .dependency_versions_check import dep_version_check
@@ -256,10 +255,12 @@ def deepspeed_config():
         return None


-def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps):
+def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
     """
     A convenience wrapper that deals with optimizer and lr scheduler configuration.
     """
+    from accelerate.utils import DummyOptim, DummyScheduler
+
     config = hf_deepspeed_config.config

     # Optimizer + Scheduler
@@ -267,13 +268,13 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
     # 1. DS scheduler + DS optimizer: Yes
     # 2. HF scheduler + HF optimizer: Yes
     # 3. DS scheduler + HF optimizer: Yes
-    # 4. HF scheduler + DS optimizer: Yes
+    # 4. HF scheduler + DS optimizer: No
     #
     # Unless Offload is enabled in which case it's:
     # 1. DS scheduler + DS optimizer: Yes
     # 2. HF scheduler + HF optimizer: Mostly*
     # 3. DS scheduler + HF optimizer: Mostly*
-    # 4. HF scheduler + DS optimizer: Yes
+    # 4. HF scheduler + DS optimizer: No
     #
     # Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)
@@ -284,6 +285,7 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
                 "--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. "
                 "Only one optimizer can be configured."
             )
+        optimizer = DummyOptim(params=model_parameters)
     else:
         if hf_deepspeed_config.is_offload():
             logger.info(
@@ -297,21 +299,21 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
         # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer`
         config["zero_allow_untested_optimizer"] = True

-    def _lr_scheduler_callable(optimizer):
-        return trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
-
     lr_scheduler = None
-    if "scheduler" not in config:
-        if optimizer is None:
-            # Optimizer is not available, so use callable to defer lr_scheduler creation to DS init
-            lr_scheduler = _lr_scheduler_callable
-        else:
-            lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
+    if "scheduler" in config:
+        lr_scheduler = DummyScheduler(optimizer)
+    else:
+        if isinstance(optimizer, DummyOptim):
+            raise ValueError(
+                "Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. "
+                "Please configure a scheduler in the DeepSpeed config."
+            )
+        lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

     return optimizer, lr_scheduler


-def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
+def deepspeed_init(trainer, num_training_steps, inference=False):
     """
     Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.

@@ -323,28 +325,22 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
         resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load
         inference: launch in inference mode (no optimizer and no lr scheduler)

-    Returns: model, optimizer, lr_scheduler
+    Returns: optimizer, lr_scheduler

     We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on:
     https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it
     can't resume from a checkpoint after it did some stepping https://github.com/microsoft/DeepSpeed/issues/1612

     """
-    import deepspeed
     from deepspeed.utils import logger as ds_logger

     model = trainer.model
     args = trainer.args

-    if hasattr(trainer, "hf_deepspeed_config_orig"):
-        hf_deepspeed_config = deepcopy(trainer.hf_deepspeed_config_orig)
-    else:
-        hf_deepspeed_config = args.hf_deepspeed_config
-        trainer.hf_deepspeed_config_orig = deepcopy(args.hf_deepspeed_config)
+    hf_deepspeed_config = trainer.accelerator.state.deepspeed_plugin.hf_ds_config

     # resume config update - some bits like `model` and `num_training_steps` only become available during train
     hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)
-    config = hf_deepspeed_config.config

     # set the Deepspeed log level consistent with the Trainer
     ds_logger.setLevel(args.get_process_log_level())
@@ -361,40 +357,33 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
         model_parameters = None
     else:
         trainer.optimizer = None  # important for when deepspeed_init is used as re-init
-        optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps)
         model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
+        optimizer, lr_scheduler = deepspeed_optim_sched(
+            trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
+        )

     # keep for quick debug:
     # from pprint import pprint; pprint(config)

-    kwargs = {
-        "model": model,
-        "model_parameters": model_parameters,
-        "config_params": config,
-        "optimizer": optimizer,
-        "lr_scheduler": lr_scheduler,
-    }
-
-    deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
-
-    if resume_from_checkpoint is not None:
-        # it's possible that the user is trying to resume from model_path, which doesn't necessarily
-        # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
-        # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
-        # path contains what looks like a deepspeed checkpoint
-        import glob
-
-        deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/global_step*"))
-
-        if len(deepspeed_checkpoint_dirs) > 0:
-            logger.info(f"Attempting to resume from {resume_from_checkpoint}")
-            # this magically updates self.optimizer and self.lr_scheduler
-            load_path, _ = deepspeed_engine.load_checkpoint(
-                resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
-            )
-            if load_path is None:
-                raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
-        else:
-            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+    return optimizer, lr_scheduler
+

-    return deepspeed_engine, optimizer, lr_scheduler
+def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
+    # it's possible that the user is trying to resume from model_path, which doesn't necessarily
+    # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
+    # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
+    # path contains what looks like a deepspeed checkpoint
+    import glob
+
+    deepspeed_checkpoint_dirs = sorted(glob.glob(f"{checkpoint_path}/global_step*"))
+
+    if len(deepspeed_checkpoint_dirs) > 0:
+        logger.info(f"Attempting to resume from {checkpoint_path}")
+        # this magically updates self.optimizer and self.lr_scheduler
+        load_path, _ = deepspeed_engine.load_checkpoint(
+            checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True
+        )
+        if load_path is None:
+            raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
+    else:
+        raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 9e8f51bd934566..40d5d4c022cbe2 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -112,6 +112,10 @@
 )


+if is_accelerate_available():
+    from accelerate.state import AcceleratorState, PartialState
+
+
 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
 DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
 DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
@@ -1331,6 +1335,9 @@ def tearDown(self):
         for path in self.teardown_tmp_dirs:
             shutil.rmtree(path, ignore_errors=True)
         self.teardown_tmp_dirs = []
+        if is_accelerate_available():
+            AcceleratorState._reset_state()
+            PartialState._reset_state()


 def mockenv(**kwargs):
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 1531a65e4695cc..6289dd7f72163c 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -67,7 +67,7 @@
 from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
+from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
 from .dependency_versions_check import dep_version_check
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
@@ -337,18 +337,34 @@ def __init__(
         # Seed must be set before instantiating the model when using model
         enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
         self.hp_name = None
-        self.deepspeed = None
         self.is_in_train = False

         # create accelerator object
-        self.accelerator = Accelerator()
+        self.accelerator = Accelerator(
+            deepspeed_plugin=self.args.deepspeed_plugin,
+            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+        )
+
+        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

         # post accelerator creation setup
-        if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
+        if self.is_fsdp_enabled:
             fsdp_plugin = self.accelerator.state.fsdp_plugin
             fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
             fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)

+        if self.is_deepspeed_enabled:
+            if getattr(self.args, "hf_deepspeed_config", None) is None:
+                from transformers.deepspeed import HfTrainerDeepSpeedConfig
+
+                ds_plugin = self.accelerator.state.deepspeed_plugin
+
+                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+                ds_plugin.hf_ds_config.trainer_config_process(self.args)
+
         # memory metrics - must set up as early as possible
         self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
         self._memory_tracker.start()
@@ -420,7 +436,7 @@ def __init__(
         # Setup Sharded DDP training
         self.sharded_ddp = None
         if len(args.sharded_ddp) > 0:
-            if args.deepspeed:
+            if self.is_deepspeed_enabled:
                 raise ValueError(
                     "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                 )
@@ -446,7 +462,7 @@

         self.fsdp = None
         if len(args.fsdp) > 0:
-            if args.deepspeed:
+            if self.is_deepspeed_enabled:
                 raise ValueError(
                     "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                 )
@@ -494,10 +510,11 @@
         self.place_model_on_device = args.place_model_on_device
         if (
             self.is_model_parallel
-            or args.deepspeed
+            or self.is_deepspeed_enabled
             or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
             or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
             or (self.fsdp is not None)
+            or self.is_fsdp_enabled
         ):
             self.place_model_on_device = False
@@ -541,7 +558,7 @@
                 " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                 " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
             )
-        if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
+        if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
             self.optimizer is not None or self.lr_scheduler is not None
         ):
             raise RuntimeError(
@@ -634,7 +651,7 @@
             logger.info(f"Using {args.half_precision_backend} half precision backend")

         self.do_grad_scaling = False
-        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
+        if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
             # deepspeed and SageMaker Model Parallel manage their own half precision
             if self.sharded_ddp is not None:
                 if args.half_precision_backend == "cuda_amp":
@@ -1316,12 +1333,17 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
             logger.info(f"SigOpt Assignments: {trial.assignments}")
         if self.hp_search_backend == HPSearchBackend.WANDB:
             logger.info(f"W&B Sweep parameters: {trial}")
-        if self.args.deepspeed:
+        if self.is_deepspeed_enabled:
+            if self.args.deepspeed is None:
+                raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
             # Rebuild the deepspeed config to reflect the updated training parameters
+            from accelerate.utils import DeepSpeedPlugin
+
             from transformers.deepspeed import HfTrainerDeepSpeedConfig

             self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
             self.args.hf_deepspeed_config.trainer_config_process(self.args)
+            self.accelerator.state.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)

     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
         if self.hp_search_backend is None or trial is None:
@@ -1440,10 +1462,6 @@ def _wrap_model(self, model, training=True, dataloader=None):
                 return self.model_wrapped
             return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

-        # already initialized its own DDP and AMP
-        if self.deepspeed:
-            return self.deepspeed
-
        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
         if unwrap_model(model) is not model:
             return model
@@ -1628,7 +1646,7 @@ def train(
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

-        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and args.deepspeed is None:
+        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
             self._load_from_checkpoint(resume_from_checkpoint)

         # If model was re-initialized, put it on the right device and update self.model_wrapped
@@ -1717,16 +1735,11 @@ def _inner_training_loop(
             or is_sagemaker_mp_enabled()
             or self.fsdp is not None
         )
-        if args.deepspeed:
-            deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
-                self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
-            )
-            self.model = deepspeed_engine.module
-            self.model_wrapped = deepspeed_engine
-            self.deepspeed = deepspeed_engine
-            self.optimizer = optimizer
-            self.lr_scheduler = lr_scheduler
-        elif not delay_optimizer_creation:
+
+        if self.is_deepspeed_enabled:
+            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
+
+        if not delay_optimizer_creation:
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)

         self.state = TrainerState()
@@ -1755,6 +1768,27 @@
                 self.model, self.optimizer, self.lr_scheduler
             )

+        if self.is_fsdp_enabled:
+            self.model = model
+
+        # for the rest of this function `model` is the outside model, whether it was wrapped or not
+        if model is not self.model:
+            self.model_wrapped = model
+
+        # backward compatibility
+        if self.is_deepspeed_enabled:
+            self.deepspeed = self.model_wrapped
+
+        # deepspeed ckpt loading
+        if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
+            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
+
+        # prepare using `accelerator` prepare
+        if use_accelerator_prepare:
+            model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+                self.model, self.optimizer, self.lr_scheduler
+            )
+
         if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
             self.model = model
@@ -1921,16 +1955,7 @@
                 if step % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

-                if (
-                    (total_batched_samples % args.gradient_accumulation_steps != 0)
-                    and args.parallel_mode == ParallelMode.DISTRIBUTED
-                    and args._no_sync_in_gradient_accumulation
-                    and hasattr(model, "no_sync")
-                ):
-                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
-                    with model.no_sync():
-                        tr_loss_step = self.training_step(model, inputs)
-                else:
+                with self.accelerator.accumulate(model):
                     tr_loss_step = self.training_step(model, inputs)

                 if (
@@ -1945,17 +1970,16 @@
                 self.current_flos += float(self.floating_point_ops(inputs))

-                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
-                if self.deepspeed:
-                    self.deepspeed.step()
-
+                # should this be under the accumulate context manager?
+                # the `or` condition of `steps_in_epoch <= args.gradient_accumulation_steps` is not covered
+                # in accelerate
                 if total_batched_samples % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
                     # Gradient clipping
-                    if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
+                    if args.max_grad_norm is not None and args.max_grad_norm > 0:
                         # deepspeed does its own clipping

                         if self.do_grad_scaling:
@@ -1988,9 +2012,7 @@
                     # Optimizer step
                     optimizer_was_run = True
-                    if self.deepspeed:
-                        pass  # called outside the loop
-                    elif is_torch_tpu_available():
+                    if is_torch_tpu_available():
                         if self.do_grad_scaling:
                             self.scaler.step(self.optimizer)
                             self.scaler.update()
@@ -2005,7 +2027,7 @@
                     else:
                         self.optimizer.step()

-                    if optimizer_was_run and not self.deepspeed:
+                    if optimizer_was_run:
                         # Delay optimizer scheduling until metrics are generated
                         if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                             self.lr_scheduler.step()
@@ -2159,6 +2181,8 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
                 load_result = model.load_state_dict(state_dict, strict=True)
                 # release memory
                 del state_dict
+        elif self.is_fsdp_enabled:
+            self.accelerator.state.fsdp_plugin.load_model(self.accelerator, model, resume_from_checkpoint)
         else:
             # We load the model state dict on the CPU to avoid an OOM error.
             if self.args.save_safetensors and os.path.isfile(safe_weights_file):
@@ -2186,23 +2210,8 @@ def _load_best_model(self):
         best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
-            if self.deepspeed:
-                if self.model_wrapped is not None:
-                    # this removes the pre-hooks from the previous engine
-                    self.model_wrapped.destroy()
-                    self.model_wrapped = None
-
-                # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
-                deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
-                    self,
-                    num_training_steps=self.args.max_steps,
-                    resume_from_checkpoint=self.state.best_model_checkpoint,
-                )
-                self.model = deepspeed_engine.module
-                self.model_wrapped = deepspeed_engine
-                self.deepspeed = deepspeed_engine
-                self.optimizer = optimizer
-                self.lr_scheduler = lr_scheduler
+            if self.is_deepspeed_enabled:
+                deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
             else:
                 if is_sagemaker_mp_enabled():
                     if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
@@ -2224,6 +2233,10 @@
                         state_dict["_smp_is_partial"] = False
                         load_result = model.load_state_dict(state_dict, strict=True)
+                elif self.is_fsdp_enabled:
+                    self.accelerator.state.fsdp_plugin.load_model(
+                        self.accelerator, model, self.state.best_model_checkpoint
+                    )
                 else:
                     if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False):
                         # If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly.
@@ -2381,10 +2394,10 @@ def _save_checkpoint(self, model, trial, metrics=None):
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
         self.save_model(output_dir, _internal_call=True)
-        if self.deepspeed:
+        if self.is_deepspeed_enabled:
             # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
             # config `stage3_gather_16bit_weights_on_model_save` is True
-            self.deepspeed.save_checkpoint(output_dir)
+            self.model_wrapped.save_checkpoint(output_dir)

         # Save optimizer and scheduler
         if self.sharded_ddp == ShardedDDPOption.SIMPLE:
@@ -2418,7 +2431,7 @@
             reissue_pt_warnings(caught_warnings)
             if self.do_grad_scaling:
                 torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
-        elif self.args.should_save and not self.deepspeed:
+        elif self.args.should_save and not self.is_deepspeed_enabled:
             # deepspeed.save_checkpoint above saves model/optim/sched
             if self.fsdp:
                 torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
@@ -2488,7 +2501,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         if checkpoint is None:
             return

-        if self.deepspeed:
+        if self.is_deepspeed_enabled:
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
             return
@@ -2675,11 +2688,11 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor,
             return type(data)(self._prepare_input(v) for v in data)
         elif isinstance(data, torch.Tensor):
             kwargs = {"device": self.args.device}
-            if self.deepspeed and (torch.is_floating_point(data) or torch.is_complex(data)):
+            if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
                 # NLP models inputs are int/uint and those get adjusted to the right dtype of the
                 # embedding. Other models such as wav2vec2's inputs are already float and thus
                 # may need special handling to match the dtypes of the model
-                kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()})
+                kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
             return data.to(**kwargs)
         return data
@@ -2755,22 +2768,15 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         if self.args.n_gpu > 1:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training

-        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
-            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
-            loss = loss / self.args.gradient_accumulation_steps
-
         if self.do_grad_scaling:
             self.scaler.scale(loss).backward()
         elif self.use_apex:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
-        elif self.deepspeed:
-            # loss gets scaled under gradient_accumulation_steps in deepspeed
-            loss = self.deepspeed.backward(loss)
         else:
             self.accelerator.backward(loss)

-        return loss.detach()
+        return loss.detach() / self.args.gradient_accumulation_steps

     def compute_loss(self, model, inputs, return_outputs=False):
         """
@@ -2848,16 +2854,16 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
             ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
             or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
             or self.fsdp is not None
-            or getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+            or self.is_fsdp_enabled
         ):
-            if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
+            if self.is_fsdp_enabled:
                 self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
             else:
                 state_dict = self.model.state_dict()

                 if self.args.should_save:
                     self._save(output_dir, state_dict=state_dict)
-        elif self.deepspeed:
+        elif self.is_deepspeed_enabled:
             # this takes care of everything as long as we aren't under zero3
             if self.args.should_save:
                 self._save(output_dir)
@@ -2876,13 +2882,13 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
             # now save the real model if stage3_gather_16bit_weights_on_model_save=True
             # if false it will not be saved.
             # This must be called on all ranks
-            if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
+            if not self.model_wrapped.save_16bit_model(output_dir, WEIGHTS_NAME):
                 logger.warning(
                     "deepspeed.save_16bit_model didn't save the model, since"
                     " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                     " zero_to_fp32.py to recover weights"
                 )
-                self.deepspeed.save_checkpoint(output_dir)
+                self.model_wrapped.save_checkpoint(output_dir)

         elif self.args.should_save:
             self._save(output_dir)
@@ -3162,15 +3168,10 @@ def evaluation_loop(
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

         # if eval is called w/o train init deepspeed here
-        if args.deepspeed and not self.deepspeed:
-            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
-            # from the checkpoint eventually
-            deepspeed_engine, _, _ = deepspeed_init(
-                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
-            )
-            self.model = deepspeed_engine.module
-            self.model_wrapped = deepspeed_engine
-            self.deepspeed = deepspeed_engine
+        if self.is_deepspeed_enabled and self.model_wrapped is self.model:
+            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
+            model = self.accelerator.prepare(self.model)
+            self.model_wrapped = self.deepspeed = model

         model = self._wrap_model(self.model, training=False, dataloader=dataloader)
@@ -3762,18 +3763,10 @@ def prediction_loop(
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

         # if eval is called w/o train init deepspeed here
-        if args.deepspeed and not self.deepspeed:
-            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
-            # from the checkpoint eventually
-            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
-            self.model = deepspeed_engine.module
-            self.model_wrapped = deepspeed_engine
-            self.deepspeed = deepspeed_engine
-            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
-            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
-            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
-            deepspeed_engine.optimizer.optimizer = None
-            deepspeed_engine.lr_scheduler = None
+        if self.is_deepspeed_enabled and self.model_wrapped is self.model:
+            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
+            model = self.accelerator.prepare(self.model)
+            self.model_wrapped = self.deepspeed = model

         model = self._wrap_model(self.model, training=False, dataloader=dataloader)
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index dee1dce0f6f741..011f3162a635c5 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -838,7 +838,7 @@ def __len__(self):


 def _get_learning_rate(self):
-    if self.deepspeed:
+    if self.is_deepspeed_enabled:
         # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
         # not run for the first few dozen steps while loss scale is too large, and thus during
         # that time `get_last_lr` will fail if called during that warm up stage, so work around it:
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 403d437843ec18..d213b08d6de792 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -64,7 +64,7 @@
     import torch.distributed as dist

 if is_accelerate_available():
-    from accelerate import PartialState
+    from accelerate.state import AcceleratorState, PartialState
     from accelerate.utils import DistributedType

 if is_torch_tpu_available(check_device=False):
@@ -1550,6 +1550,7 @@ def __post_init__(self):
         if isinstance(self.debug, str):
             self.debug = [DebugOption(s) for s in self.debug.split()]

+        self.deepspeed_plugin = None
         if self.deepspeed:
             # - must be run very last in arg parsing, since it will use a lot of these settings.
             # - must be run before the model is created.
@@ -1562,6 +1563,12 @@
             self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
             self.hf_deepspeed_config.trainer_config_process(self)

+            # Accelerate DeepSpeed Plugin
+            from accelerate.utils import DeepSpeedPlugin
+
+            os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+            self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
+
         if self.push_to_hub_token is not None:
             warnings.warn(
                 "`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
@@ -1660,6 +1667,8 @@ def ddp_timeout_delta(self) -> timedelta:
     def _setup_devices(self) -> "torch.device":
         requires_backends(self, ["torch"])
         logger.info("PyTorch: setting up devices")
+        AcceleratorState._reset_state()
+        PartialState._reset_state()
         if not is_sagemaker_mp_enabled() and not is_accelerate_available(check_partial_state=True):
             raise ImportError(
                 "Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 446952ef7a0422..c460bc9c150877 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -365,16 +365,19 @@ def test_ds_scheduler_hf_optimizer(self):
             self.assertNotEqual(new_a, a)

     def test_hf_scheduler_ds_optimizer(self):
-        a = 0
         with mockenv_context(**self.dist_env_1_gpu):
             ds_config_zero2_dict = self.get_config_dict(ZERO2)
             del ds_config_zero2_dict["scheduler"]  # force default HF Trainer scheduler
             ds_config_zero2_dict["zero_optimization"]["offload_optimizer"]["device"] = "none"
             ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
             trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_zero2_dict)
-            trainer.train()
-            new_a = trainer.model.a.item()
-            self.assertNotEqual(new_a, a)
+            with self.assertRaises(Exception) as context:
+                trainer.train()
+            self.assertIn(
+                "Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. "
+                "Please configure a scheduler in the DeepSpeed config.",
+                str(context.exception),
+            )

     @require_deepspeed_aio
     def test_stage3_nvme_offload(self):
@@ -751,6 +754,8 @@ def test_config_object(self):
             config = deepspeed_config()
             self.assertTrue(bool(config), "Deepspeed config should be accessible")

+            # with accelerate integration below line is additionally required for this test to pass
+            trainer.accelerator.state._reset_state()
             del trainer
             # now weakref should gc the global and we shouldn't get anything here
             config = deepspeed_config()
@@ -783,8 +788,8 @@ def test_load_best_model(self, stage, dtype):

         with mockenv_context(**self.dist_env_1_gpu):
             args_dict = {
-                "per_gpu_train_batch_size": 1,
-                "per_gpu_eval_batch_size": 1,
+                "per_device_train_batch_size": 1,
+                "per_device_eval_batch_size": 1,
                 "gradient_accumulation_steps": 1,
                 "learning_rate": 1e-4,
                 "num_train_epochs": 1,
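
For context on the behavior change exercised by `test_hf_scheduler_ds_optimizer` above: after this patch, a DeepSpeed config that defines an `optimizer` block but no `scheduler` block (combination 4 in the table inside `deepspeed_optim_sched`) is rejected up front instead of silently pairing a DeepSpeed optimizer with an HF scheduler. A minimal sketch of how that surfaces to a user, assuming a single-GPU ZeRO-2 run; the model, dataset, and output directory names are placeholders, not part of the patch:

    from transformers import Trainer, TrainingArguments

    ds_config = {
        "zero_optimization": {"stage": 2},
        "optimizer": {"type": "AdamW", "params": {"lr": "auto"}},
        # intentionally no "scheduler" block -> HF scheduler + DS optimizer (combination 4)
        "fp16": {"enabled": True},
    }

    # TrainingArguments.__post_init__ now also builds an accelerate DeepSpeedPlugin from this config
    args = TrainingArguments(output_dir="tmp_ds", fp16=True, deepspeed=ds_config)
    trainer = Trainer(model=my_model, args=args, train_dataset=my_train_dataset)  # placeholders

    # Expected with this patch: deepspeed_optim_sched raises
    #   ValueError: Found `optimizer` configured in the DeepSpeed config, but no `scheduler`.
    #   Please configure a scheduler in the DeepSpeed config.
    trainer.train()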