@@ -496,10 +496,9 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
 
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
-    is_quantized: bool = False,  # change to hf_quantizer (default is none)
+    is_quantized: bool = False,
     map_location: Optional[Union[str, torch.device]] = "cpu",
     weights_only: bool = True,
-    hf_quantizer: Optional[HfQuantizer] = None,
 ):
     """
     Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
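For reference, a minimal sketch of how the slimmed-down signature is called after this change; the checkpoint path is illustrative, and `hf_quantizer` is no longer threaded through here (quantizer handling moves to `load_shard_file`, see below):

```python
from transformers.modeling_utils import load_state_dict

# Load a single checkpoint file onto CPU; no quantizer argument anymore.
state_dict = load_state_dict(
    "model.safetensors",  # illustrative path
    is_quantized=False,
    map_location="cpu",
    weights_only=True,
)
```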
@@ -596,7 +595,7 @@ def set_initialized_submodules(model, state_dict_keys):
     return not_initialized_submodules
 
 
-def _end_ptr (tensor: torch.Tensor) -> int:
+def _end_ptr(tensor: torch.Tensor) -> int:
     # extract the end of the pointer if the tensor is a slice of a bigger tensor
     if tensor.nelement():
         stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
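The helper's pointer arithmetic can be checked in isolation; a small sketch of the same computation on a slice that shares storage with a larger tensor:

```python
import torch

base = torch.arange(10, dtype=torch.float32)  # one contiguous storage
view = base[2:6]                              # slice sharing that storage

# End pointer = address of the last element plus one element's width,
# mirroring `tensor.view(-1)[-1].data_ptr() + tensor.element_size()`.
end = view.view(-1)[-1].data_ptr() + view.element_size()
assert end == base.data_ptr() + 6 * base.element_size()
```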
@@ -728,7 +727,6 @@ def _load_state_dict_into_meta_model(
     keep_in_fp32_regex: Optional[re.Pattern] = None,
     unexpected_keys: Optional[list[str]] = None,  # passing `unexpected` for cleanup from quantization items
     device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
-    metadata: Optional[dict] = None,
 ) -> tuple[Optional[dict], Optional[dict]]:
     """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
     device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
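As the docstring says, the incoming state dict lives on the meta device, which carries shapes and dtypes without allocating storage; a quick illustration:

```python
import torch

# A "meta" tensor has shape and dtype but no backing memory, so an entire
# checkpoint can be described without allocating anything.
empty_param = torch.empty(4096, 4096, device="meta")
print(empty_param.shape, empty_param.dtype)  # torch.Size([4096, 4096]) torch.float32
print(empty_param.device)                    # meta
```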
@@ -746,16 +744,13 @@ def _load_state_dict_into_meta_model(
     is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
         QuantizationMethod.HQQ,
         QuantizationMethod.BITS_AND_BYTES,
-        QuantizationMethod.TORCHAO
+        QuantizationMethod.TORCHAO,
     }
     is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_ao
     file_pointer = None
     if is_meta_state_dict:
         file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
 
-    if hf_quantizer and hasattr(hf_quantizer, "transform_state_dict") and metadata:
-        state_dict = hf_quantizer.transform_state_dict(state_dict, metadata)
-
     for param_name, empty_param in state_dict.items():
         if param_name not in expected_keys:  # when loading from ckpt, we skip param if doesnt exist in modeling
             continue
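The `safe_open` handle created above enables lazy, per-tensor loading directly onto the target device, which is what makes the meta-state-dict path cheap for plain safetensors shards. A standalone sketch (the file name is illustrative):

```python
from safetensors import safe_open

# Tensors are materialized one at a time, directly on the requested device,
# instead of deserializing the whole shard up front.
with safe_open("model-00001-of-00002.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        tensor = f.get_tensor(name)  # loaded on demand
```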
@@ -787,8 +782,7 @@ def _load_state_dict_into_meta_model(
                         device_map=device_map,
                     )
                 )
-            ):
-                # In this case, the param is already on the correct device!
+            ):  # In this case, the param is already on the correct device!
                 shard_and_distribute_module(
                     model,
                     param,
@@ -938,7 +932,7 @@ def load_shard_file(args):
     # If shard_file is "", we use the existing state_dict instead of loading it
     if shard_file != "":
         state_dict = load_state_dict(
-            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only, hf_quantizer=hf_quantizer
+            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
         )
 
     # Fix the key names
@@ -948,6 +942,9 @@ def load_shard_file(args):
         with safe_open(shard_file, framework="pt") as f:
             metadata = f.metadata()
 
+    if hf_quantizer:
+        state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata)
+
     error_msgs = []
 
     if is_deepspeed_zero3_enabled() and not is_quantized:
@@ -970,7 +967,6 @@ def load_shard_file(args):
             keep_in_fp32_regex=keep_in_fp32_regex,
             unexpected_keys=unexpected_keys,
             device_mesh=device_mesh,
-            metadata=metadata,
         )
 
     return error_msgs, disk_offload_index, cpu_offload_index
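Put together, the new flow in `load_shard_file` reads the shard's safetensors metadata and hands it to the quantizer before dispatch, replacing the old `transform_state_dict` hook removed above. A condensed sketch of just that path (the function and argument names are stand-ins):

```python
from safetensors import safe_open

def _sketch_load_shard(state_dict, shard_file, hf_quantizer=None):
    # Read the shard-level metadata stored alongside the tensors.
    metadata = None
    if shard_file.endswith(".safetensors"):
        with safe_open(shard_file, framework="pt") as f:
            metadata = f.metadata()
    # The quantizer folds metadata-derived entries back into the state dict.
    if hf_quantizer:
        state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata)
    return state_dict
```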
@@ -3994,11 +3990,11 @@ def save_pretrained(
             and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
         )
 
-        # if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
-        #     raise ValueError(
-        #         f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
-        #         " the logger on the traceback to understand the reason why the quantized model is not serializable."
-        #     )
+        if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
+            raise ValueError(
+                f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+                " the logger on the traceback to understand the reason why the quantized model is not serializable."
+            )
 
         if "save_config" in kwargs:
             warnings.warn(
@@ -4029,10 +4025,8 @@ def save_pretrained(
 
         metadata = {}
         if hf_quantizer is not None:
-            state_dict = hf_quantizer.get_state_dict(self)
-            metadata = {}
-            if isinstance(state_dict, tuple):
-                state_dict, metadata = state_dict
+            state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
+        metadata["format"] = "pt"
 
         # Only save the model itself if we are using distributed training
         model_to_save = unwrap_model(self)
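`save_pretrained` now assumes a single quantizer entry point that returns both pieces at once, instead of probing for a tuple. A hypothetical quantizer satisfying that contract (only the method name comes from the diff; the class is illustrative):

```python
class SketchQuantizer:
    # Hypothetical implementation of the contract save_pretrained relies on:
    # one call, always returning (state_dict, metadata).
    def get_state_dict_and_metadata(self, model, safe_serialization=False):
        state_dict = model.state_dict()
        metadata = {}  # quantizer-specific, str-valued entries go here
        return state_dict, metadata
```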
@@ -4180,8 +4174,7 @@ def save_pretrained(
             else:
                 ptrs[id_tensor_storage(tensor)].append(name)
 
-        # shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-        shared_ptrs = {}
+        shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
 
         # Recursively descend to find tied weight keys
         _tied_weights_keys = _get_tied_weight_keys(self)
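Restoring the comprehension matters because tied weights share one storage; grouping names by storage pointer is how duplicates are found before serialization. A simplified sketch using raw `data_ptr()` in place of transformers' `id_tensor_storage`:

```python
from collections import defaultdict

import torch

a = torch.zeros(4)
state = {"embed.weight": a, "lm_head.weight": a}  # tied: same storage

ptrs = defaultdict(list)
for name, tensor in state.items():
    ptrs[tensor.data_ptr()].append(name)

# With the hardcoded shared_ptrs = {}, these tied names went undetected.
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
print(shared_ptrs)  # one entry listing both tied names
```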
@@ -4312,7 +4305,6 @@ def save_pretrained(
             if safe_serialization:
                 # At some point we will need to deal better with save_function (used for TPU and other distributed
                 # joyfulness), but for now this enough.
-                metadata["format"] = "pt"
                 safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
             else:
                 save_function(shard, os.path.join(save_directory, shard_file))
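Since `metadata["format"] = "pt"` is now set once upstream, every shard is saved with the quantizer metadata plus the format tag. A small sketch of what ends up on disk (file name illustrative); note that safetensors metadata must be a flat str-to-str mapping:

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

shard = {"weight": torch.zeros(2, 2)}
metadata = {"format": "pt"}  # str -> str only
save_file(shard, "shard.safetensors", metadata=metadata)

with safe_open("shard.safetensors", framework="pt") as f:
    print(f.metadata())  # {'format': 'pt'}
```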
@@ -4808,7 +4800,6 @@ def from_pretrained(
 
         if distributed_config is not None:
             tp_plan = "auto"
-
         # Not used anymore -- remove them from the kwargs
         _ = kwargs.pop("resume_download", None)
         _ = kwargs.pop("mirror", None)
@@ -4960,7 +4951,6 @@ def from_pretrained(
                 "Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
                 "requires `accelerate`. You can install it with `pip install accelerate`"
             )
-
         # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
         if load_in_4bit or load_in_8bit:
             if quantization_config is not None:
@@ -5030,7 +5020,6 @@ def from_pretrained(
                 "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
                 f"{transformers_explicit_filename}"
             )
-
         hf_quantizer, config, dtype, device_map = get_hf_quantizer(
             config, quantization_config, dtype, from_tf, from_flax, device_map, weights_only, user_agent
         )
@@ -5103,6 +5092,7 @@ def from_pretrained(
         )
 
         from_pt = not (from_tf | from_flax)
+
         if from_pt:
             if gguf_file:
                 from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
@@ -5121,7 +5111,6 @@ def from_pretrained(
             )
 
         config.name_or_path = pretrained_model_name_or_path
-
         model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
         with ContextManagers(model_init_context):
@@ -5449,7 +5438,7 @@ def _load_pretrained_model(
         is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
             QuantizationMethod.HQQ,
             QuantizationMethod.BITS_AND_BYTES,
-            QuantizationMethod.TORCHAO
+            QuantizationMethod.TORCHAO,
         }
 
         # Get all the keys of the state dicts that we have to initialize the model
@@ -5568,7 +5557,6 @@ def _load_pretrained_model(
         if sharded_metadata is None:
             weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
         else:
-            # weight file full path
             folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
             # Fix the weight map keys according to the key mapping
             weight_map = {