From 4acce7a43be3aa1d29d32c353ca2d15dd27bddf0 Mon Sep 17 00:00:00 2001
From: Kazem Faghih
Date: Sun, 30 Jul 2023 12:23:45 +0330
Subject: [PATCH 1/4] Add lr_scheduler checkpointing when deepspeed does not support the lr_scheduler type

---
 src/transformers/trainer.py       |  7 +++++++
 src/transformers/training_args.py | 12 ++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 319c36e7c874b2..695a1f9a6dd6f0 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1669,6 +1669,11 @@ def _inner_training_loop(
         # deepspeed ckpt loading
         if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
             deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
+            if self.args.deepspeed_force_lr_scheduler_checkpointing and self.model_wrapped.lr_scheduler is None:
+                if os.path.isfile(os.path.join(resume_from_checkpoint, SCHEDULER_NAME)):
+                    with warnings.catch_warnings(record=True) as caught_warnings:
+                        self.lr_scheduler.load_state_dict(torch.load(os.path.join(resume_from_checkpoint, SCHEDULER_NAME)))
+                    reissue_pt_warnings(caught_warnings)
 
         # Check if saved optimizer or scheduler states exist
         self._load_optimizer_and_scheduler(resume_from_checkpoint)
@@ -2290,6 +2295,8 @@ def _save_checkpoint(self, model, trial, metrics=None):
             # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
             # config `stage3_gather_16bit_weights_on_model_save` is True
             self.model_wrapped.save_checkpoint(output_dir)
+            if self.args.deepspeed_force_lr_scheduler_checkpointing and self.model_wrapped.lr_scheduler is None:
+                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
 
         # Save optimizer and scheduler
         if self.sharded_ddp == ShardedDDPOption.SIMPLE:
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 8189c22fe5a1c5..694cb630640891 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1209,6 +1209,18 @@ class TrainingArguments:
         },
     )
 
+    deepspeed_force_lr_scheduler_checkpointing: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Force saving and loading (checkpointing) of the lr_scheduler when DeepSpeed is enabled but does not "
+                "support the lr_scheduler type. "
+                "Use this to keep track of the lr_scheduler state when its type does not fall into DeepSpeed's "
+                "supported lr_scheduler categories."
+            )
+        },
+    )
+
     def __post_init__(self):
         # expand paths, if not os.makedirs("~/bar") will make directory
         # in the current directory instead of the actual home

From c0d12fc950be0bee4407aadb3e7cee5e7d42ef6a Mon Sep 17 00:00:00 2001
From: Kazem Faghih
Date: Mon, 31 Jul 2023 14:07:18 +0330
Subject: [PATCH 2/4] Fix missing state_dict keys when loading a LoRA model

---
 src/transformers/deepspeed.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py
index 7af2bedece84a7..7b1d6e853a11ff 100644
--- a/src/transformers/deepspeed.py
+++ b/src/transformers/deepspeed.py
@@ -381,7 +381,8 @@ def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
         logger.info(f"Attempting to resume from {checkpoint_path}")
         # this magically updates self.optimizer and self.lr_scheduler
         load_path, _ = deepspeed_engine.load_checkpoint(
-            checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True
+            checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True,
+            load_module_strict=False
         )
         if load_path is None:
             raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")

From fd47c333cc6c128e924eaf4d10bbc351b325cf58 Mon Sep 17 00:00:00 2001
From: thepowerfuldeez
Date: Tue, 13 Feb 2024 18:02:54 +0800
Subject: [PATCH 3/4] Merge external branch

---
 src/transformers/deepspeed.py              | 375 ---------------------
 src/transformers/integrations/deepspeed.py |   2 +-
 src/transformers/trainer.py                |   8 +-
 src/transformers/training_args.py          |   3 +-
 4 files changed, 5 insertions(+), 383 deletions(-)

diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py
index 086431c941c054..840d9cc2f55a16 100644
--- a/src/transformers/deepspeed.py
+++ b/src/transformers/deepspeed.py
@@ -38,378 +38,3 @@
     set_hf_deepspeed_config,
     unset_hf_deepspeed_config,
 )
-
-import importlib.util
-import weakref
-from functools import partialmethod
-
-from .dependency_versions_check import dep_version_check
-from .utils import is_accelerate_available, is_torch_available, logging
-
-
-if is_torch_available():
-    import torch
-
-logger = logging.get_logger(__name__)
-
-
-def is_deepspeed_available():
-    return importlib.util.find_spec("deepspeed") is not None
-
-
-if is_accelerate_available() and is_deepspeed_available():
-    from accelerate.utils.deepspeed import HfDeepSpeedConfig as DeepSpeedConfig
-else:
-    # Inherits from a dummy `object` if accelerate is not available, so that python succeeds to import this file.
-    # Deepspeed glue code will never inherit this dummy object as it checks if accelerate is available.
-    from builtins import object as DeepSpeedConfig
-
-
-class HfDeepSpeedConfig(DeepSpeedConfig):
-    """
-    This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.
-
-    A `weakref` of this object is stored in the module's globals to be able to access the config from areas where
-    things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore
-    it's important that this object remains alive while the program is still running.
-
-    [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration
-    with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic
-    the DeepSpeed configuration is not modified in any way.
-
-    Args:
-        config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict.
-
-    """
-
-    def __init__(self, config_file_or_dict):
-        # set global weakref object
-        set_hf_deepspeed_config(self)
-        dep_version_check("accelerate")
-        dep_version_check("deepspeed")
-        super().__init__(config_file_or_dict)
-
-
-class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
-    """
-    The `HfTrainerDeepSpeedConfig` object is meant to be created during `TrainingArguments` object creation and has the
-    same lifespan as the latter.
-    """
-
-    def __init__(self, config_file_or_dict):
-        super().__init__(config_file_or_dict)
-        self._dtype = None
-        self.mismatches = []
-
-    def dtype(self):
-        if self._dtype is None:
-            raise ValueError("trainer_config_process() wasn't called yet to tell dtype")
-        return self._dtype
-
-    def is_auto(self, ds_key_long):
-        val = self.get_value(ds_key_long)
-        if val is None:
-            return False
-        else:
-            return val == "auto"
-
-    def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
-        """
-        A utility method that massages the config file and can optionally verify that the values match.
-
-        1. Replace "auto" values with `TrainingArguments` value.
-
-        2. If it wasn't "auto" and `must_match` is true, then check that DS config matches Trainer
-        config values and if mismatched add the entry to `self.mismatched` - will assert during
-        `trainer_config_finalize` for one or more mismatches.
-
-        """
-        config, ds_key = self.find_config_node(ds_key_long)
-        if config is None:
-            return
-
-        if config.get(ds_key) == "auto":
-            config[ds_key] = hf_val
-            return
-
-        if not must_match:
-            return
-
-        ds_val = config.get(ds_key)
-        if ds_val is not None and ds_val != hf_val:
-            self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}")
-
-    fill_only = partialmethod(fill_match, must_match=False)
-
-    def trainer_config_process(self, args):
-        """
-        Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object
-        creation.
-        """
-        # DeepSpeed does:
-        # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
-        train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
-        self.fill_match(
-            "train_micro_batch_size_per_gpu", args.per_device_train_batch_size, "per_device_train_batch_size"
-        )
-        self.fill_match("gradient_accumulation_steps", args.gradient_accumulation_steps, "gradient_accumulation_steps")
-        self.fill_match("train_batch_size", train_batch_size, "train_batch_size (calculated)")
-        self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm")
-
-        self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate")
-        self.fill_match("optimizer.params.betas", [args.adam_beta1, args.adam_beta2], "adam_beta1+adam_beta2")
-        self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon")
-        self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay")
-
-        self.fill_only("scheduler.params.warmup_min_lr", 0)  # not a trainer arg
-        self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate")
-        # total_num_steps - will get set in trainer_config_finalize
-
-        # fp16
-        if args.fp16 or args.fp16_full_eval:
-            fp16_backend = "apex" if args.fp16_backend == "apex" else "amp"
-        else:
-            fp16_backend = None
-
-        if args.save_on_each_node:
-            # deepspeed uses shared storage by default. Let's override this setting if save_on_each_node == True
-            self.config["checkpoint"] = self.config.get("checkpoint", {})
-            self.config["checkpoint"]["use_node_local_storage"] = args.save_on_each_node
-
-        # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set
-        # any here unless the user did the work
-        self.fill_match(
-            "fp16.enabled",
-            ((args.fp16 or args.fp16_full_eval) and fp16_backend == "amp"),
-            "fp16|fp16_full_eval+fp16_backend(amp)",
-        )
-
-        # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
-        # ZeRO features
-        self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
-        self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")
-
-        self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval")
-
-        # deepspeed's default mode is fp16 unless there is a config that says differently
-        if self.is_true("bf16.enabled"):
-            self._dtype = torch.bfloat16
-        elif self.is_false("fp16.enabled"):
-            self._dtype = torch.float32
-        else:
-            self._dtype = torch.float16
-
-    def trainer_config_finalize(self, args, model, num_training_steps):
-        """
-        This stage is run after we have the model and know num_training_steps.
-
-        Now we can complete the configuration process.
-        """
-        # zero
-
-        # deal with config keys that use `auto` value and rely on model's hidden_size
-        hidden_size_based_keys = [
-            "zero_optimization.reduce_bucket_size",
-            "zero_optimization.stage3_prefetch_bucket_size",
-            "zero_optimization.stage3_param_persistence_threshold",
-        ]
-        hidden_size_auto_keys = [x for x in hidden_size_based_keys if self.is_auto(x)]
-
-        if len(hidden_size_auto_keys) > 0:
-            if hasattr(model.config, "hidden_size"):
-                hidden_size = model.config.hidden_size
-            elif hasattr(model.config, "hidden_sizes"):
-                # if there are many hidden sizes pick the largest one
-                hidden_size = max(model.config.hidden_sizes)
-            else:
-                raise ValueError(
-                    "The model's config file has neither `hidden_size` nor `hidden_sizes` entry, "
-                    "therefore it's not possible to automatically fill out the following `auto` entries "
-                    f"in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing "
-                    "`auto` values for these keys with an integer value of your choice."
-                )
-
-            self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
-            if self.is_zero3():
-                # automatically assign the optimal config values based on model config
-                self.fill_only("zero_optimization.stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size)
-                self.fill_only("zero_optimization.stage3_param_persistence_threshold", 10 * hidden_size)
-
-        # scheduler
-        self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)")
-        self.fill_match("scheduler.params.warmup_num_steps", args.get_warmup_steps(num_training_steps), "warmup_steps")
-
-        if len(self.mismatches) > 0:
-            mismatches = "\n".join(self.mismatches)
-            raise ValueError(
-                "Please correct the following DeepSpeed config values that mismatch TrainingArguments"
-                f" values:\n{mismatches}\nThe easiest method is to set these DeepSpeed config values to 'auto'."
-            )
-
-
-# keep the config object global to be able to access it anywhere during TrainingArguments life-cycle
-_hf_deepspeed_config_weak_ref = None
-
-
-def set_hf_deepspeed_config(hf_deepspeed_config_obj):
-    # this is a special weakref global object to allow us to get to Deepspeed config from APIs
-    # that don't have an easy way to get to the Deepspeed config outside of the Trainer domain.
-    global _hf_deepspeed_config_weak_ref
-    # will go away automatically when HfDeepSpeedConfig is destroyed (when TrainingArguments is destroyed)
-    _hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj)
-
-
-def unset_hf_deepspeed_config():
-    # useful for unit tests to ensure the global state doesn't leak - call from `tearDown` method
-    global _hf_deepspeed_config_weak_ref
-    _hf_deepspeed_config_weak_ref = None
-
-
-def is_deepspeed_zero3_enabled():
-    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
-        return _hf_deepspeed_config_weak_ref().is_zero3()
-    else:
-        return False
-
-
-def deepspeed_config():
-    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
-        return _hf_deepspeed_config_weak_ref().config
-    else:
-        return None
-
-
-def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
-    """
-    A convenience wrapper that deals with optimizer and lr scheduler configuration.
-    """
-    from accelerate.utils import DummyOptim, DummyScheduler
-
-    config = hf_deepspeed_config.config
-
-    # Optimizer + Scheduler
-    # Currently supported combos:
-    # 1. DS scheduler + DS optimizer: Yes
-    # 2. HF scheduler + HF optimizer: Yes
-    # 3. DS scheduler + HF optimizer: Yes
-    # 4. HF scheduler + DS optimizer: No
-    #
-    # Unless Offload is enabled in which case it's:
-    # 1. DS scheduler + DS optimizer: Yes
-    # 2. HF scheduler + HF optimizer: Mostly*
-    # 3. DS scheduler + HF optimizer: Mostly*
-    # 4. HF scheduler + DS optimizer: No
-    #
-    # Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)
-
-    optimizer = None
-    if "optimizer" in config:
-        if args.adafactor:
-            raise ValueError(
-                "--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. "
-                "Only one optimizer can be configured."
-            )
-        optimizer = DummyOptim(params=model_parameters)
-    else:
-        if hf_deepspeed_config.is_offload():
-            logger.info(
-                "Detected ZeRO Offload and non-DeepSpeed optimizers: This combination should work as long as the"
-                " custom optimizer has both CPU and GPU implementation (except LAMB)"
-            )
-
-        # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
-        # But trainer uses AdamW by default.
-        optimizer = trainer.create_optimizer()
-        # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer`
-        config["zero_allow_untested_optimizer"] = True
-
-    lr_scheduler = None
-    if "scheduler" in config:
-        lr_scheduler = DummyScheduler(optimizer)
-    else:
-        if isinstance(optimizer, DummyOptim):
-            raise ValueError(
-                "Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. "
-                "Please configure a scheduler in the DeepSpeed config."
-            )
-        lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
-
-    return optimizer, lr_scheduler
-
-
-def deepspeed_init(trainer, num_training_steps, inference=False):
-    """
-    Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
-
-    If `resume_from_checkpoint` was passed then an attempt to resume from a previously saved checkpoint will be made.
-
-    Args:
-        trainer: Trainer object
-        num_training_steps: per single gpu
-        resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load
-        inference: launch in inference mode (no optimizer and no lr scheduler)
-
-    Returns: optimizer, lr_scheduler
-
-    We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on:
-    https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it
-    can't resume from a checkpoint after it did some stepping https://github.com/microsoft/DeepSpeed/issues/1612
-
-    """
-    from deepspeed.utils import logger as ds_logger
-
-    model = trainer.model
-    args = trainer.args
-
-    hf_deepspeed_config = trainer.accelerator.state.deepspeed_plugin.hf_ds_config
-
-    # resume config update - some bits like `model` and `num_training_steps` only become available during train
-    hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)
-
-    # set the Deepspeed log level consistent with the Trainer
-    ds_logger.setLevel(args.get_process_log_level())
-
-    if inference:
-        # only Z3 makes sense for the inference
-        if not hf_deepspeed_config.is_zero3():
-            raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config")
-
-        # in case the training config is re-used for inference
-        hf_deepspeed_config.del_config_sub_tree("optimizer")
-        hf_deepspeed_config.del_config_sub_tree("lr_scheduler")
-        optimizer, lr_scheduler = None, None
-        model_parameters = None
-    else:
-        trainer.optimizer = None  # important for when deepspeed_init is used as re-init
-        model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
-        optimizer, lr_scheduler = deepspeed_optim_sched(
-            trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
-        )
-
-    # keep for quick debug:
-    # from pprint import pprint; pprint(config)
-
-    return optimizer, lr_scheduler
-
-
-def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
-    # it's possible that the user is trying to resume from model_path, which doesn't necessarily
-    # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
-    # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
-    # path contains what looks like a deepspeed checkpoint
-    import glob
-
-    deepspeed_checkpoint_dirs = sorted(glob.glob(f"{checkpoint_path}/global_step*"))
-
-    if len(deepspeed_checkpoint_dirs) > 0:
-        logger.info(f"Attempting to resume from {checkpoint_path}")
-        # this magically updates self.optimizer and self.lr_scheduler
-        load_path, _ = deepspeed_engine.load_checkpoint(
-            checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True,
-            load_module_strict=False
-        )
-        if load_path is None:
-            raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
-    else:
-        raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")
->>>>>>> ext/main
diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index b0db718dba016b..e08a096d07ccdd 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -414,7 +414,7 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
     return optimizer, lr_scheduler
 
 
-def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):
+def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=False):
     # it's possible that the user is trying to resume from model_path, which doesn't necessarily
     # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
     # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 0cccc07e7b0b5b..c0d02a069fdd18 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2416,7 +2416,6 @@ def _save_checkpoint(self, model, trial, metrics=None):
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
 
-<<<<<<< HEAD
         if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 0:
             logger.warning(
                 f"Checkpoint destination directory {output_dir} already exists and is non-empty. "
@@ -2426,15 +2425,12 @@ def _save_checkpoint(self, model, trial, metrics=None):
         else:
             staging_output_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}")
         self.save_model(staging_output_dir, _internal_call=True)
-=======
-        self.save_model(output_dir, _internal_call=True)
         if self.is_deepspeed_enabled:
             # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
             # config `stage3_gather_16bit_weights_on_model_save` is True
-            self.model_wrapped.save_checkpoint(output_dir)
+            self.model_wrapped.save_checkpoint(staging_output_dir)
             if self.args.deepspeed_force_lr_scheduler_checkpointing and self.model_wrapped.lr_scheduler is None:
-                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
->>>>>>> ext/main
+                torch.save(self.lr_scheduler.state_dict(), os.path.join(staging_output_dir, SCHEDULER_NAME))
 
         if not self.args.save_only_model:
             # Save optimizer and scheduler
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index cde15206604ba0..47aa4d11b472d6 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1289,7 +1289,6 @@ class TrainingArguments:
         },
     )
 
-<<<<<<< HEAD
     split_batches: Optional[bool] = field(
         default=False,
         metadata={
@@ -1315,6 +1314,8 @@ class TrainingArguments:
         default=None,
         metadata={
             "help": "Activates neftune noise embeddings into the model. NEFTune has been proven to drastically improve model performances for instrcution fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune. Only supported for `PreTrainedModel` and `PeftModel` classes."
+        },
+    )
 
     deepspeed_force_lr_scheduler_checkpointing: bool = field(
         default=False,

From 3d62791ccbfdc30ac36280e50e3dd470718c123b Mon Sep 17 00:00:00 2001
From: thepowerfuldeez
Date: Thu, 15 Feb 2024 18:53:10 +0800
Subject: [PATCH 4/4] revert load_module_strict in deepspeed integration and set to False in trainer

---
 src/transformers/integrations/deepspeed.py | 2 +-
 src/transformers/trainer.py                | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index e08a096d07ccdd..b0db718dba016b 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -414,7 +414,7 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
     return optimizer, lr_scheduler
 
 
-def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=False):
+def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):
     # it's possible that the user is trying to resume from model_path, which doesn't necessarily
     # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
    # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c0d02a069fdd18..0b2a88cb803dad 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1737,7 +1737,7 @@ def _inner_training_loop(
         # deepspeed ckpt loading
         if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
-            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
+            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint, load_module_strict=False)
             if self.args.deepspeed_force_lr_scheduler_checkpointing and self.model_wrapped.lr_scheduler is None:
                 if os.path.isfile(os.path.join(resume_from_checkpoint, SCHEDULER_NAME)):
                     with warnings.catch_warnings(record=True) as caught_warnings:
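
Note (illustration only, not part of the patches): a minimal sketch of how the flag introduced in PATCH 1/4 would be enabled from user code. The output directory, the DeepSpeed config path, and the scheduler choice below are placeholder assumptions; the sketch assumes deepspeed and accelerate are installed and that ds_config.json exists and contains no "scheduler" section, so the Trainer-side lr_scheduler is the one in use.

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",
        deepspeed="ds_config.json",  # placeholder path to a DeepSpeed config without a "scheduler" entry
        lr_scheduler_type="cosine_with_restarts",  # an lr_scheduler type DeepSpeed does not checkpoint itself
        deepspeed_force_lr_scheduler_checkpointing=True,  # new flag added by PATCH 1/4
        save_steps=500,
    )
    # With the flag set, _save_checkpoint() writes the scheduler state to SCHEDULER_NAME next to the
    # DeepSpeed checkpoint whenever self.model_wrapped.lr_scheduler is None, and _inner_training_loop()
    # reloads that state when training is resumed with resume_from_checkpoint.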