@@ -496,9 +496,10 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
 
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
-    is_quantized: bool = False,
+    is_quantized: bool = False,  # TODO: replace with hf_quantizer (default None)
     map_location: Optional[Union[str, torch.device]] = "cpu",
     weights_only: bool = True,
+    hf_quantizer: Optional[HfQuantizer] = None,
 ):
     """
     Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
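
For orientation, a minimal sketch of the new call shape once this patch is applied (the checkpoint path is illustrative, and the `hf_quantizer` keyword only exists with this change):

```python
from transformers.modeling_utils import load_state_dict

# Illustrative path; pass a real HfQuantizer instance for quantized checkpoints.
state_dict = load_state_dict(
    "model-00001-of-00002.safetensors",
    map_location="cpu",
    weights_only=True,
    hf_quantizer=None,  # None keeps the pre-patch behaviour
)
```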
@@ -595,7 +596,7 @@ def set_initialized_submodules(model, state_dict_keys):
     return not_initialized_submodules
 
 
-def _end_ptr(tensor: torch.Tensor) -> int:
+def _end_ptr(tensor: torch.Tensor) -> int:
     # extract the end of the pointer if the tensor is a slice of a bigger tensor
     if tensor.nelement():
         stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
@@ -727,6 +728,7 @@ def _load_state_dict_into_meta_model(
     keep_in_fp32_regex: Optional[re.Pattern] = None,
     unexpected_keys: Optional[list[str]] = None,  # passing `unexpected` for cleanup from quantization items
     device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
+    metadata: Optional[dict] = None,
 ) -> tuple[Optional[dict], Optional[dict]]:
     """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
     device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
@@ -741,15 +743,19 @@ def _load_state_dict_into_meta_model(
         device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
 
     is_quantized = hf_quantizer is not None
-    is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
+    is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
         QuantizationMethod.HQQ,
         QuantizationMethod.BITS_AND_BYTES,
+        QuantizationMethod.TORCHAO,
     }
-    is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
+    is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_ao
     file_pointer = None
     if is_meta_state_dict:
         file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
 
+    if hf_quantizer and hasattr(hf_quantizer, "transform_state_dict") and metadata:
+        state_dict = hf_quantizer.transform_state_dict(state_dict, metadata)
+
     for param_name, empty_param in state_dict.items():
         if param_name not in expected_keys:  # when loading from ckpt, we skip param if doesnt exist in modeling
             continue
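
A minimal sketch of the hook being called here, assuming a hypothetical quantizer that stashed per-tensor scales in the safetensors header at save time; only the method name `transform_state_dict` comes from the diff, everything else is invented for illustration:

```python
import json

import torch


class ToyQuantizer:
    """Hypothetical; real quantizers subclass HfQuantizer."""

    def transform_state_dict(self, state_dict: dict, metadata: dict) -> dict:
        # Assumption: save time recorded {tensor name: scale} as JSON under a
        # single metadata key, so loading can rebuild the expected tensors.
        scales = json.loads(metadata.get("toy_scales", "{}"))
        for name, scale in scales.items():
            if name in state_dict:
                state_dict[name] = state_dict[name].to(torch.float32) * scale
        return state_dict
```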
@@ -781,7 +787,8 @@ def _load_state_dict_into_meta_model(
                         device_map=device_map,
                     )
                 )
-            ):  # In this case, the param is already on the correct device!
+            ):
+                # In this case, the param is already on the correct device!
                 shard_and_distribute_module(
                     model,
                     param,
@@ -887,7 +894,7 @@ def load_shard_file(args):
         shard_file,
         state_dict,
         disk_only_shard_files,
-        is_hqq_or_bnb,
+        is_hqq_or_bnb_or_ao,
         is_quantized,
         device_map,
         hf_quantizer,
@@ -913,7 +920,7 @@ def load_shard_file(args):
     map_location = "cpu"
     if (
         shard_file.endswith(".safetensors")
-        and not is_hqq_or_bnb
+        and not is_hqq_or_bnb_or_ao
         and not (is_deepspeed_zero3_enabled() and not is_quantized)
     ):
         map_location = "meta"
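
The gate above, restated as a standalone predicate for readability (names mirror the locals in `load_shard_file`; this is a sketch, not library API):

```python
def can_defer_to_meta(shard_file: str, is_hqq_or_bnb_or_ao: bool,
                      deepspeed_zero3: bool, is_quantized: bool) -> bool:
    # safetensors shards can be opened lazily, so the first pass can stay on
    # the meta device -- unless HQQ/bnb/torchao quantization or an
    # unquantized DeepSpeed ZeRO-3 run needs materialized tensors up front.
    return (
        shard_file.endswith(".safetensors")
        and not is_hqq_or_bnb_or_ao
        and not (deepspeed_zero3 and not is_quantized)
    )
```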
@@ -931,11 +938,15 @@ def load_shard_file(args):
     # If shard_file is "", we use the existing state_dict instead of loading it
     if shard_file != "":
         state_dict = load_state_dict(
-            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
+            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only, hf_quantizer=hf_quantizer
         )
 
     # Fix the key names
     state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
+    metadata = None
+    if shard_file.endswith(".safetensors") and is_safetensors_available():
+        with safe_open(shard_file, framework="pt") as f:
+            metadata = f.metadata()
 
     error_msgs = []
 
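
The metadata read added above uses the stock safetensors API; a self-contained sketch of what it returns (the header's `__metadata__` map is string-to-string, or `None` when absent — the path is illustrative):

```python
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    metadata = f.metadata()  # dict[str, str] from the file header, or None
    if metadata is not None:
        print(metadata.get("format"))  # "pt" for transformers-saved checkpoints
```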
@@ -959,6 +970,7 @@ def load_shard_file(args):
                 keep_in_fp32_regex=keep_in_fp32_regex,
                 unexpected_keys=unexpected_keys,
                 device_mesh=device_mesh,
+                metadata=metadata,
             )
 
     return error_msgs, disk_offload_index, cpu_offload_index
@@ -3975,11 +3987,11 @@ def save_pretrained(
             and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
         )
 
-        if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
-            raise ValueError(
-                f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
-                " the logger on the traceback to understand the reason why the quantized model is not serializable."
-            )
+        # if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
+        #     raise ValueError(
+        #         f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+        #         " the logger on the traceback to understand the reason why the quantized model is not serializable."
+        #     )
 
         if "save_config" in kwargs:
             warnings.warn(
@@ -4010,6 +4022,10 @@ def save_pretrained(
 
         if hf_quantizer is not None:
             state_dict = hf_quantizer.get_state_dict(self)
+            metadata = {}
+            if isinstance(state_dict, tuple):
+                state_dict, metadata = state_dict
+
         # Only save the model itself if we are using distributed training
         model_to_save = unwrap_model(self)
         # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
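
A hedged sketch of the tuple contract handled above: a hypothetical quantizer whose `get_state_dict` returns `(state_dict, metadata)` so `save_pretrained` can forward the metadata into the safetensors header. safetensors only accepts string values, hence the JSON encoding:

```python
import json


class ToyQuantizer:
    """Hypothetical; only the return shape matters here."""

    def get_state_dict(self, model):
        state_dict = model.state_dict()
        # Assumption: record one scale per tensor; JSON-encode the map under
        # a single key because safetensors metadata must be str -> str.
        scales = {name: 1.0 for name in state_dict}
        return state_dict, {"toy_scales": json.dumps(scales)}
```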
@@ -4155,7 +4171,8 @@ def save_pretrained(
                 else:
                     ptrs[id_tensor_storage(tensor)].append(name)
 
-            shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+            # shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+            shared_ptrs = {}
 
             # Recursively descend to find tied weight keys
             _tied_weights_keys = _get_tied_weight_keys(self)
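
For reference, what the short-circuited comprehension used to catch: `id_tensor_storage` (a real transformers helper) groups parameter names whose tensors alias the same storage, which is how tied weights get deduplicated before safe serialization. A toy reproduction:

```python
from collections import defaultdict

import torch
from transformers.pytorch_utils import id_tensor_storage

weight = torch.zeros(4, 4)
tied = weight  # e.g. lm_head.weight tied to the input embeddings
ptrs = defaultdict(list)
for name, tensor in {"embed.weight": weight, "lm_head.weight": tied}.items():
    ptrs[id_tensor_storage(tensor)].append(name)

shared = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
print(shared)  # one entry listing both names; `shared_ptrs = {}` skips all this
```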
@@ -4286,7 +4303,8 @@ def save_pretrained(
                 if safe_serialization:
                     # At some point we will need to deal better with save_function (used for TPU and other distributed
                     # joyfulness), but for now this is enough.
-                    safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
+                    metadata["format"] = "pt"
+                    safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
                 else:
                     save_function(shard, os.path.join(save_directory, shard_file))
 
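
The `safe_save_file` call relies on safetensors carrying arbitrary string metadata in the file header (`safe_save_file` is transformers' alias for `safetensors.torch.save_file`). A self-contained round-trip with toy tensors, merging `"format": "pt"` exactly as done above:

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

metadata = {"toy_scales": "{}"}  # whatever the quantizer returned (strings only)
metadata["format"] = "pt"
save_file({"w": torch.zeros(2, 2)}, "shard.safetensors", metadata=metadata)

with safe_open("shard.safetensors", framework="pt") as f:
    assert f.metadata()["format"] == "pt"  # survives the round trip
```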
@@ -5077,7 +5095,6 @@ def from_pretrained(
             )
 
         from_pt = not (from_tf | from_flax)
-
         if from_pt:
             if gguf_file:
                 from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
@@ -5096,6 +5113,7 @@ def from_pretrained(
             )
 
         config.name_or_path = pretrained_model_name_or_path
+
         model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
         with ContextManagers(model_init_context):
@@ -5427,9 +5445,10 @@ def _load_pretrained_model(
             QuantizationMethod.HQQ,
             QuantizationMethod.QUARK,
         }
-        is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
+        is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
             QuantizationMethod.HQQ,
             QuantizationMethod.BITS_AND_BYTES,
+            QuantizationMethod.TORCHAO,
         }
 
         # Get all the keys of the state dicts that we have to initialize the model
@@ -5548,6 +5567,7 @@ def _load_pretrained_model(
         if sharded_metadata is None:
             weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
         else:
+            # folder containing the weight files
             folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
             # Fix the weight map keys according to the key mapping
             weight_map = {
@@ -5602,7 +5622,7 @@ def _load_pretrained_model(
                 shard_file,
                 state_dict,
                 disk_only_shard_files,
-                is_hqq_or_bnb,
+                is_hqq_or_bnb_or_ao,
                 is_quantized,
                 device_map,
                 hf_quantizer,