Merge pull request #1040 from bghira/multinode/trainer-state-filenames
multi-node training fixes for state tracker
bghira authored Oct 11, 2024
2 parents 5db547c + f7ec503 commit d65763e
Showing 3 changed files with 54 additions and 30 deletions.
4 changes: 2 additions & 2 deletions documentation/DISTRIBUTED.md
@@ -62,7 +62,7 @@ sudo systemctl status nfs-kernel-server

---

#### **On the Slave Machine:**
#### On the slave nodes that send optimiser and other states

**1. Install NFS Client Packages**

@@ -276,4 +276,4 @@ Lower batch sizes, lower resolution, and enabling torch compile can bring the sp
- This is a very high-cost operation, and high batch sizes might slow you down more than you want, requiring scaling up the count of GPUs in the cluster. A careful balance of budgeting should be considered.
- (DeepSpeed) Validations might need to be disabled when training with DeepSpeed ZeRO 3
- (DeepSpeed) Model saving ends up creating weird sharded copies when saving with ZeRO level 3, but levels 1 and 2 function as expected
- (DeepSpeed) The use of DeepSpeed's CPU-based optimisers becomes required as it handles sharding and offload of the optim states.
- (DeepSpeed) The use of DeepSpeed's CPU-based optimisers becomes required as it handles sharding and offload of the optim states.
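The DeepSpeed notes above correspond to the guard this commit adds in `trainer.py`: when the configured ZeRO stage is 3, validation is skipped entirely. The sketch below illustrates that check under the assumption that training was launched through `accelerate` with a DeepSpeed config, so `accelerator.state.deepspeed_plugin` is populated; `should_run_validations` is a hypothetical helper used only for illustration, not part of the repository.

```python
def should_run_validations(accelerator) -> bool:
    # Sketch only: mirrors the stage check added in init_validations().
    plugin = getattr(accelerator.state, "deepspeed_plugin", None)
    if plugin is None:
        return True  # not using DeepSpeed; validations are unaffected
    stage = plugin.deepspeed_config.get("zero_optimization", {}).get("stage")
    # ZeRO stage 3 shards model parameters across ranks, so a standalone
    # validation pipeline cannot be assembled on a single process.
    return stage != 3
```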
19 changes: 12 additions & 7 deletions helpers/training/save_hooks.py
@@ -1,5 +1,6 @@
from diffusers.training_utils import EMAModel, _set_state_dict_into_text_encoder
from helpers.training.wrappers import unwrap_model
from helpers.training.multi_process import _get_rank as get_rank
from diffusers.utils import (
convert_state_dict_to_diffusers,
convert_unet_state_dict_to_peft,
@@ -182,6 +183,11 @@ def __init__(
self.ema_model_cls = SD3Transformer2DModel
elif self.args.model_family == "pixart_sigma":
self.ema_model_cls = PixArtTransformer2DModel
self.training_state_path = "training_state.json"
if self.accelerator is not None:
rank = get_rank()
if rank > 0:
self.training_state_path = f"training_state-rank{rank}.json"

def _save_lora(self, models, weights, output_dir):
# for SDXL/others, there are only two options here. Either are just the unet attn processor layers
@@ -324,11 +330,11 @@ def _save_full_model(self, models, weights, output_dir):

def save_model_hook(self, models, weights, output_dir):
# Write "training_state.json" to the output directory containing the training state
if not self.accelerator.is_main_process:
return
StateTracker.save_training_state(
os.path.join(output_dir, "training_state.json")
os.path.join(output_dir, self.training_state_path)
)
if not self.accelerator.is_main_process:
return
if "lora" in self.args.model_type and self.args.lora_type == "standard":
self._save_lora(models=models, weights=weights, output_dir=output_dir)
return
@@ -485,12 +491,11 @@ def _load_full_model(self, models, input_dir):

def load_model_hook(self, models, input_dir):
# Check the checkpoint dir for a "training_state.json" file to load
training_state_path = os.path.join(input_dir, "training_state.json")
if os.path.exists(training_state_path):
StateTracker.load_training_state(training_state_path)
if os.path.exists(self.training_state_path):
StateTracker.load_training_state(self.training_state_path)
else:
logger.warning(
f"Could not find training_state.json in checkpoint dir {input_dir}"
f"Could not find {self.training_state_path} in checkpoint dir {input_dir}"
)
if "lora" in self.args.model_type and self.args.lora_type == "standard":
self._load_lora(models=models, input_dir=input_dir)
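The per-rank filename scheme introduced in `__init__` above reduces to a small rule. A minimal sketch, assuming `get_rank()` returns 0 on the main process and a positive integer on every other rank; `rank_state_filename` and `save_training_state` are hypothetical helpers for illustration, not code from this repository.

```python
import json
import os

def rank_state_filename(rank: int) -> str:
    # Rank 0 keeps the original filename so existing checkpoints remain
    # loadable; every other rank writes its state under a suffixed name.
    return "training_state.json" if rank == 0 else f"training_state-rank{rank}.json"

def save_training_state(output_dir: str, rank: int, state: dict) -> str:
    # Every rank writes its own file, so no rank overwrites another's state.
    path = os.path.join(output_dir, rank_state_filename(rank))
    with open(path, "w") as f:
        json.dump(state, f)
    return path
```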
61 changes: 40 additions & 21 deletions helpers/training/trainer.py
@@ -19,6 +19,7 @@
from helpers import log_format # noqa
from helpers.configuration.loader import load_config
from helpers.caching.memory import reclaim_memory
from helpers.training.multi_process import _get_rank as get_rank
from helpers.training.validation import Validation, prepare_validation_prompt_list
from helpers.training.state_tracker import StateTracker
from helpers.training.schedulers import load_scheduler_from_args
@@ -1341,6 +1342,14 @@ def init_unload_vae(self):
)

def init_validations(self):
if (
self.accelerator.state.deepspeed_plugin.deepspeed_config[
"zero_optimization"
].get("stage")
== 3
):
logger.error("Cannot run validations with DeepSpeed ZeRO stage 3.")
return
self.validation = Validation(
accelerator=self.accelerator,
unet=self.unet,
@@ -1363,22 +1372,21 @@ def init_validations(self):
vae=self.vae,
controlnet=self.controlnet if self.config.controlnet else None,
)
if not self.config.train_text_encoder:
if not self.config.train_text_encoder and self.validation is not None:
self.validation.clear_text_encoders()
self.init_benchmark_base_model()
self.accelerator.wait_for_everyone()

def init_benchmark_base_model(self):
if self.config.disable_benchmark or self.validation.benchmark_exists(
"base_model"
if (
self.config.disable_benchmark
or self.validation is None
or self.validation.benchmark_exists("base_model")
):
# if we've disabled it or the benchmark exists, we will not do it again.
# deepspeed zero3 can't do validations at all.
return
if (
not self.accelerator.is_main_process
and not self.config.use_deepspeed_optimizer
):
# on deepspeed, every process has to enter. otherwise, only the main process does.
if not self.accelerator.is_main_process:
return
logger.info(
"Benchmarking base model for comparison. Supply `--disable_benchmark: true` to disable this behaviour."
@@ -1464,7 +1472,9 @@ def init_resume_checkpoint(self, lr_scheduler):
if "sampler" in backend:
backend["sampler"].load_states(
state_path=os.path.join(
self.config.output_dir, path, "training_state.json"
self.config.output_dir,
path,
f"training_state-{get_rank()}.json",
),
)
self.state["global_resume_step"] = self.state["global_step"] = (
@@ -1779,7 +1789,8 @@ def train(self):
# Just in Case.
self.mark_optimizer_eval()
# normal run-of-the-mill validation on startup.
self.validation.run_validations(validation_type="base_model", step=0)
if self.validation is not None:
self.validation.run_validations(validation_type="base_model", step=0)

self.mark_optimizer_train()

@@ -2642,7 +2653,8 @@ def train(self):
logger.debug(f"Backend: {backend}")
backend["sampler"].save_state(
state_path=os.path.join(
save_path, "training_state.json"
save_path,
self.model_hooks.training_state_path,
),
)

@@ -2663,9 +2675,10 @@

progress_bar.set_postfix(**logs)
self.mark_optimizer_eval()
self.validation.run_validations(
validation_type="intermediary", step=step
)
if self.validation is not None:
self.validation.run_validations(
validation_type="intermediary", step=step
)
self.mark_optimizer_train()
if (
self.config.push_to_hub
@@ -2677,7 +2690,11 @@
if self.accelerator.is_main_process:
try:
self.hub_manager.upload_latest_checkpoint(
validation_images=self.validation.validation_images,
validation_images=(
getattr(self.validation, "validation_images")
if self.validation is not None
else None
),
webhook_handler=self.webhook_handler,
)
except Exception as e:
@@ -2706,14 +2723,16 @@

# Create the pipeline using the trained modules and save it.
self.accelerator.wait_for_everyone()
validation_images = None
if self.accelerator.is_main_process:
self.mark_optimizer_eval()
validation_images = self.validation.run_validations(
validation_type="final",
step=self.state["global_step"],
force_evaluation=True,
skip_execution=True,
).validation_images
if self.validation is not None:
validation_images = self.validation.run_validations(
validation_type="final",
step=self.state["global_step"],
force_evaluation=True,
skip_execution=True,
).validation_images
if self.unet is not None:
self.unet = unwrap_model(self.accelerator, self.unet)
if self.transformer is not None:
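Both changed modules import `_get_rank` from `helpers.training.multi_process`, which is not shown in this diff. The sketch below is one plausible implementation, offered purely as an assumption about how the rank could be obtained, not as the repository's actual code.

```python
import os

import torch.distributed as dist

def _get_rank() -> int:
    # Prefer the initialised process group, fall back to the RANK variable
    # set by torchrun/accelerate launchers, and default to 0 (single process).
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return int(os.environ.get("RANK", 0))
```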
