diff --git a/src/axolotl/train.py b/src/axolotl/train.py index fa6dbceaf..5ed5837f2 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -117,6 +117,10 @@ def terminate_handler(_, __, model): LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") + if trainer.is_fsdp_enabled: + trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") + LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.") + if cfg.relora_steps: if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): model = model.merge_and_unload()