fix constant_with_warmup not being so constant or warming up #919

Merged
3 commits merged on Sep 1, 2024
17 changes: 13 additions & 4 deletions helpers/training/trainer.py
@@ -163,7 +163,9 @@ def _config_to_obj(self, config):

def parse_arguments(self, args=None):
self.config = load_config(args)
-        report_to = None if self.config.report_to.lower() == "none" else self.config.report_to
+        report_to = (
+            None if self.config.report_to.lower() == "none" else self.config.report_to
+        )
self.accelerator = Accelerator(
gradient_accumulation_steps=self.config.gradient_accumulation_steps,
mixed_precision=(
@@ -769,8 +771,12 @@ def init_trainable_peft_adapter(self):

if self.config.init_lora is not None:
from lycoris import create_lycoris_from_weights

self.lycoris_wrapped_network = create_lycoris_from_weights(
-                multiplier, self.config.init_lora, model_for_lycoris_wrap, weights_sd=None,
+                multiplier,
+                self.config.init_lora,
+                model_for_lycoris_wrap,
+                weights_sd=None,
**self.lycoris_config,
)[0]
else:
@@ -1253,7 +1259,7 @@ def init_resume_checkpoint(self):
self.accelerator.load_state(os.path.join(self.config.output_dir, path))
try:
                if (
-                    "constant" in self.config.lr_scheduler
+                    "constant" == self.config.lr_scheduler
and not self.config.is_schedulefree
):
for g in self.optimizer.param_groups:
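A quick note on the condition change above: Python's `in` on strings is a substring test, so the old check also matched the `constant_with_warmup` scheduler name, and the resume path then overwrote its learning rate as if it were a plain constant schedule, clobbering the warmup. A minimal sketch of the difference, with illustrative scheduler names (not taken from this PR):

```python
# Substring test vs. exact match on scheduler names (illustrative values).
for name in ["constant", "constant_with_warmup", "cosine"]:
    old_match = "constant" in name   # True for "constant_with_warmup" as well
    new_match = "constant" == name   # True only for the plain constant schedule
    print(name, old_match, new_match)

# constant True True
# constant_with_warmup True False
# cosine False False
```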
@@ -1283,7 +1289,7 @@ def init_resume_checkpoint(self):
)
if hasattr(self.lr_scheduler, "last_step"):
self.lr_scheduler.last_step = self.state["global_resume_step"]
-            logger.info(f"Resuming from global_step {self.state['global_resume_step']}.")
+            logger.info(f"Resuming from global_step {self.state['global_resume_step']}).")

# Log the current state of each data backend.
for _, backend in StateTracker.get_data_backends().items():
@@ -1300,6 +1306,9 @@
)
self.state["current_epoch"] = self.state["first_epoch"]
StateTracker.set_epoch(self.state["current_epoch"])
+            if hasattr(self.lr_scheduler, "last_epoch"):
+                # epoch here represents batch steps, not actual epochs.
+                self.lr_scheduler.last_epoch = self.state["global_resume_step"]

if self.state["current_epoch"] > self.config.num_train_epochs + 1:
logger.info(
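For context on the `last_epoch` assignment added above: schedulers built on `torch.optim.lr_scheduler.LambdaLR`, such as the warmup schedules commonly used with diffusers/transformers, track their position in `last_epoch`, which in practice counts optimizer steps rather than epochs. If that counter is not moved to the resume step, the warmup multiplier is recomputed from the wrong position after a restart. A minimal sketch of the idea, using a hand-rolled warmup lambda and hypothetical step counts rather than the project's actual scheduler setup:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

warmup_steps = 100  # hypothetical warmup length

def lr_lambda(current_step: int) -> float:
    # Linear ramp to 1.0, then hold flat: the "constant_with_warmup" shape.
    if current_step < warmup_steps:
        return current_step / max(1, warmup_steps)
    return 1.0

lr_scheduler = LambdaLR(optimizer, lr_lambda)

# On resume, align the scheduler's step counter with the trainer's resume step,
# mirroring the hasattr(...) guard in the diff above.
global_resume_step = 250  # hypothetical value; the trainer reads it from checkpoint state
if hasattr(lr_scheduler, "last_epoch"):
    lr_scheduler.last_epoch = global_resume_step  # counts batch steps, not epochs
```

Left at its default, the counter would put the next `step()` call back near the start of the warmup ramp instead of the flat region the run had already reached.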