2626from torchtune .modules .peft import (
2727 DoRALinear ,
2828 get_adapter_params ,
29+ get_adapter_state_dict ,
2930 get_lora_module_names ,
3031 get_merged_lora_ckpt ,
3132 load_dora_magnitudes ,
@@ -452,8 +453,7 @@ def _setup_model(
452453 with training .set_default_dtype (self ._dtype ), torch .device ("meta" ):
453454 model = config .instantiate (cfg_model )
454455
455- self .adapter_params = get_adapter_params (model )
456- set_trainable_params (model , self .adapter_params )
456+ set_trainable_params (model , get_adapter_params (model ))
457457
458458 if self ._compile :
459459 training .compile_model (model , verbose = self ._is_rank_zero )
@@ -664,11 +664,14 @@ def save_checkpoint(
664664
665665 # To prevent GPU memory from spiking during checkpoint save,
666666 # we consolidate the full model and optim state dicts on CPU for rank 0
667- cpu_state_dict = training .get_full_model_state_dict (
668- self ._model ,
667+ state_dict = self ._model .state_dict ()
668+ if self ._save_adapter_weights_only :
669+ state_dict = get_adapter_state_dict (state_dict , device = None )
670+
671+ cpu_state_dict = training .gather_cpu_state_dict (
672+ state_dict ,
669673 self ._is_rank_zero ,
670674 device = self ._device ,
671- trainable_only = self ._save_adapter_weights_only ,
672675 )
673676 if self ._is_rank_zero :
674677 log .info (
@@ -694,22 +697,22 @@ def save_checkpoint(
694697 # to be sent to the checkpointer and ultimately written to file
695698 if self ._is_rank_zero :
696699 start = time .perf_counter ()
697- # Filter out the adapter keys and weights from the model state dict. These will
698- # be saved separately
699- adapter_key_filter = lambda x : x in self .adapter_params
700- adapter_state_dict = {
701- k : v for k , v in cpu_state_dict .items () if adapter_key_filter (k )
702- }
703- checkpoint_dict .update ({training .ADAPTER_KEY : adapter_state_dict })
704700
705- # merge the adapter weights and base weights to create the model checkpoint
706- if not self ._save_adapter_weights_only :
701+ if self ._save_adapter_weights_only :
702+ adapter_state_dict = cpu_state_dict
703+ else :
704+ # Filter out the adapter keys and weights from the model state dict. These will
705+ # be saved separately
706+ adapter_state_dict = get_adapter_state_dict (cpu_state_dict )
707+
708+ # merge the adapter weights and base weights to create the model checkpoint
707709 merged_state_dict = get_merged_lora_ckpt (
708710 cpu_state_dict ,
709711 rank = self ._lora_rank ,
710712 alpha = self ._lora_alpha ,
711713 )
712714 checkpoint_dict .update ({training .MODEL_KEY : merged_state_dict })
715+ checkpoint_dict .update ({training .ADAPTER_KEY : adapter_state_dict })
713716
714717 # if training is in-progress, checkpoint the optimizer state and recipe state
715718 # as well.
0 commit comments