166166else :
167167 IS_SAGEMAKER_MP_POST_1_10 = False
168168
169+ from torchao .prototype .safetensors .safetensors_utils import is_metadata_dict_torchao
170+
169171
170172logger = logging .get_logger (__name__ )
171173
@@ -496,10 +498,9 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
496498
497499def load_state_dict (
498500 checkpoint_file : Union [str , os .PathLike ],
499- is_quantized : bool = False , #change to hf_quantizer (default is none)
501+ is_quantized : bool = False ,
500502 map_location : Optional [Union [str , torch .device ]] = "cpu" ,
501503 weights_only : bool = True ,
502- hf_quantizer : Optional [HfQuantizer ] = None ,
503504):
504505 """
505506 Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
@@ -596,7 +597,7 @@ def set_initialized_submodules(model, state_dict_keys):
596597 return not_initialized_submodules
597598
598599
599- def _end_ptr (tensor : torch .Tensor ) -> int :
600+ def _end_ptr (tensor : torch .Tensor ) -> int :
600601 # extract the end of the pointer if the tensor is a slice of a bigger tensor
601602 if tensor .nelement ():
602603 stop = tensor .view (- 1 )[- 1 ].data_ptr () + tensor .element_size ()
@@ -728,7 +729,7 @@ def _load_state_dict_into_meta_model(
728729 keep_in_fp32_regex : Optional [re .Pattern ] = None ,
729730 unexpected_keys : Optional [list [str ]] = None , # passing `unexpected` for cleanup from quantization items
730731 device_mesh : Optional ["torch.distributed.device_mesh.DeviceMesh" ] = None ,
731- metadata : Optional [dict ] = None
732+ metadata : Optional [dict ] = None ,
732733) -> tuple [Optional [dict ], Optional [dict ]]:
733734 """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
734735 device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
@@ -746,14 +747,13 @@ def _load_state_dict_into_meta_model(
746747 is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer .quantization_config .quant_method in {
747748 QuantizationMethod .HQQ ,
748749 QuantizationMethod .BITS_AND_BYTES ,
749- QuantizationMethod .TORCHAO
750+ QuantizationMethod .TORCHAO ,
750751 }
751752 is_meta_state_dict = shard_file .endswith (".safetensors" ) and not is_hqq_or_bnb_or_ao
752753 file_pointer = None
753754 if is_meta_state_dict :
754755 file_pointer = safe_open (shard_file , framework = "pt" , device = tensor_device )
755-
756- if hf_quantizer and hasattr (hf_quantizer , "transform_state_dict" ) and metadata :
756+ if hf_quantizer and hasattr (hf_quantizer , "transform_state_dict" ) and is_metadata_dict_torchao (metadata ):
757757 state_dict = hf_quantizer .transform_state_dict (state_dict , metadata )
758758
759759 for param_name , empty_param in state_dict .items ():
@@ -787,8 +787,7 @@ def _load_state_dict_into_meta_model(
787787 device_map = device_map ,
788788 )
789789 )
790- ):
791- # In this case, the param is already on the correct device!
790+ ): # In this case, the param is already on the correct device!
792791 shard_and_distribute_module (
793792 model ,
794793 param ,
@@ -938,7 +937,7 @@ def load_shard_file(args):
938937 # If shard_file is "", we use the existing state_dict instead of loading it
939938 if shard_file != "" :
940939 state_dict = load_state_dict (
941- shard_file , is_quantized = is_quantized , map_location = map_location , weights_only = weights_only , hf_quantizer = hf_quantizer
940+ shard_file , is_quantized = is_quantized , map_location = map_location , weights_only = weights_only
942941 )
943942
944943 # Fix the key names
@@ -3987,11 +3986,11 @@ def save_pretrained(
39873986 and hf_quantizer .is_serializable (safe_serialization = safe_serialization )
39883987 )
39893988
3990- # if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
3991- # raise ValueError(
3992- # f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
3993- # " the logger on the traceback to understand the reason why the quantized model is not serializable."
3994- # )
3989+ if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable :
3990+ raise ValueError (
3991+ f"The model is quantized with { hf_quantizer .quantization_config .quant_method } and is not serializable - check out the warnings from"
3992+ " the logger on the traceback to understand the reason why the quantized model is not serializable."
3993+ )
39953994
39963995 if "save_config" in kwargs :
39973996 warnings .warn (
@@ -4020,9 +4019,9 @@ def save_pretrained(
40204019 repo_id = self ._create_repo (repo_id , ** kwargs )
40214020 files_timestamps = self ._get_files_timestamps (save_directory )
40224021
4022+ metadata = {}
40234023 if hf_quantizer is not None :
40244024 state_dict = hf_quantizer .get_state_dict (self )
4025- metadata = {}
40264025 if isinstance (state_dict , tuple ):
40274026 state_dict , metadata = state_dict
40284027
@@ -4171,8 +4170,7 @@ def save_pretrained(
41714170 else :
41724171 ptrs [id_tensor_storage (tensor )].append (name )
41734172
4174- # shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
4175- shared_ptrs = {}
4173+ shared_ptrs = {ptr : names for ptr , names in ptrs .items () if len (names ) > 1 }
41764174
41774175 # Recursively descend to find tied weight keys
41784176 _tied_weights_keys = _get_tied_weight_keys (self )
@@ -5095,6 +5093,7 @@ def from_pretrained(
50955093 )
50965094
50975095 from_pt = not (from_tf | from_flax )
5096+
50985097 if from_pt :
50995098 if gguf_file :
51005099 from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
@@ -5113,7 +5112,6 @@ def from_pretrained(
51135112 )
51145113
51155114 config .name_or_path = pretrained_model_name_or_path
5116-
51175115 model_init_context = cls .get_init_context (is_quantized , _is_ds_init_called )
51185116 config = copy .deepcopy (config ) # We do not want to modify the config inplace in from_pretrained.
51195117 with ContextManagers (model_init_context ):
@@ -5448,7 +5446,7 @@ def _load_pretrained_model(
54485446 is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer .quantization_config .quant_method in {
54495447 QuantizationMethod .HQQ ,
54505448 QuantizationMethod .BITS_AND_BYTES ,
5451- QuantizationMethod .TORCHAO
5449+ QuantizationMethod .TORCHAO ,
54525450 }
54535451
54545452 # Get all the keys of the state dicts that we have to initialize the model
@@ -5567,7 +5565,6 @@ def _load_pretrained_model(
55675565 if sharded_metadata is None :
55685566 weight_map = dict .fromkeys (checkpoint_keys , checkpoint_files [0 ])
55695567 else :
5570- # weight file full path
55715568 folder = os .path .sep .join (checkpoint_files [0 ].split (os .path .sep )[:- 1 ])
55725569 # Fix the weight map keys according to the key mapping
55735570 weight_map = {
0 commit comments