fix trainer slow tests related to hyperparam search (#24011)
* fix trainer slow tests

* commit 2
pacman100 authored Jun 5, 2023
1 parent 3c31089 commit 460b844
Showing 1 changed file with 30 additions and 26 deletions.
src/transformers/trainer.py: 56 changes (30 additions & 26 deletions)
@@ -339,31 +339,7 @@ def __init__(
         self.hp_name = None
         self.is_in_train = False
 
-        # create accelerator object
-        self.accelerator = Accelerator(
-            deepspeed_plugin=self.args.deepspeed_plugin,
-            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
-        )
-
-        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
-        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
-        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
-
-        # post accelerator creation setup
-        if self.is_fsdp_enabled:
-            fsdp_plugin = self.accelerator.state.fsdp_plugin
-            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
-            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
-
-        if self.is_deepspeed_enabled:
-            if getattr(self.args, "hf_deepspeed_config", None) is None:
-                from transformers.deepspeed import HfTrainerDeepSpeedConfig
-
-                ds_plugin = self.accelerator.state.deepspeed_plugin
-
-                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
-                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
-                ds_plugin.hf_ds_config.trainer_config_process(self.args)
+        self.create_accelerator_and_postprocess()
 
         # memory metrics - must set up as early as possible
         self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
@@ -1343,7 +1319,8 @@ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
 
             self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
             self.args.hf_deepspeed_config.trainer_config_process(self.args)
-            self.accelerator.state.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
+            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
+            self.create_accelerator_and_postprocess()
 
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
         if self.hp_search_backend is None or trial is None:
@@ -3924,3 +3901,30 @@ def _add_sm_patterns_to_gitignore(self) -> None:
         if not self.repo.is_repo_clean():
             self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
             self.repo.git_push()
+
+    def create_accelerator_and_postprocess(self):
+        # create accelerator object
+        self.accelerator = Accelerator(
+            deepspeed_plugin=self.args.deepspeed_plugin,
+            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+        )
+
+        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+
+        # post accelerator creation setup
+        if self.is_fsdp_enabled:
+            fsdp_plugin = self.accelerator.state.fsdp_plugin
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
+            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
+
+        if self.is_deepspeed_enabled:
+            if getattr(self.args, "hf_deepspeed_config", None) is None:
+                from transformers.deepspeed import HfTrainerDeepSpeedConfig
+
+                ds_plugin = self.accelerator.state.deepspeed_plugin
+
+                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+                ds_plugin.hf_ds_config.trainer_config_process(self.args)
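The code path this commit touches: Trainer.hyperparameter_search calls _hp_search_setup for every trial, and with DeepSpeed the trial's parameters are folded into a fresh DeepSpeedPlugin, after which the accelerator is rebuilt through the new create_accelerator_and_postprocess helper. Below is a minimal sketch (not part of this commit) of a script that exercises that path; the checkpoint name, dataset, ds_config.json path, output directory, and search-space values are illustrative placeholders, and the script assumes it is launched with `deepspeed` or `accelerate launch` so the DeepSpeed engine can initialize.

# Minimal sketch (not from this commit): run hyperparameter search with DeepSpeed
# so each trial goes through _hp_search_setup and, after this fix, re-creates the
# Accelerator via create_accelerator_and_postprocess(). Checkpoint, dataset, and
# ds_config.json are placeholders; launch with `deepspeed sketch.py` or
# `accelerate launch sketch.py`.
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

model_name = "distilbert-base-uncased"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)

raw = load_dataset("glue", "mrpc")
encoded = raw.map(
    lambda ex: tokenizer(ex["sentence1"], ex["sentence2"], truncation=True),
    batched=True,
)

def model_init():
    # hyperparameter_search requires model_init so every trial starts from scratch
    return AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

args = TrainingArguments(
    output_dir="hp_search_ds",
    deepspeed="ds_config.json",  # placeholder DeepSpeed config; enables the deepspeed branch
    per_device_train_batch_size=8,
    num_train_epochs=1,
)

trainer = Trainer(
    model_init=model_init,
    args=args,
    train_dataset=encoded["train"],
    eval_dataset=encoded["validation"],
    tokenizer=tokenizer,
)

def hp_space(trial):
    # Optuna-style space; _hp_search_setup copies these values onto args for each trial
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True),
        "per_device_train_batch_size": trial.suggest_categorical(
            "per_device_train_batch_size", [8, 16]
        ),
    }

best_run = trainer.hyperparameter_search(hp_space=hp_space, backend="optuna", n_trials=4)
print(best_run.run_id, best_run.objective, best_run.hyperparameters)

As the second hunk shows, the old code swapped deepspeed_plugin directly on the already-built accelerator.state (and read self.hf_deepspeed_config rather than self.args.hf_deepspeed_config); the new code stores the rebuilt plugin on args and re-runs create_accelerator_and_postprocess() so each trial's DeepSpeed settings are applied to a freshly constructed Accelerator.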
