diff --git a/conftest.py b/conftest.py index ef5203804433..ee012215e070 100644 --- a/conftest.py +++ b/conftest.py @@ -46,10 +46,6 @@ "test_keep_in_fp32_modules", "test_gradient_checkpointing_backward_compatibility", "test_gradient_checkpointing_enable_disable", - "test_save_load_fast_init_from_base", - "test_fast_init_context_manager", - "test_fast_init_tied_embeddings", - "test_save_load_fast_init_to_base", "test_torch_save_load", "test_initialization", "test_forward_signature", diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py index 696e284b748c..1700301db51f 100644 --- a/src/transformers/integrations/deepspeed.py +++ b/src/transformers/integrations/deepspeed.py @@ -303,7 +303,7 @@ def deepspeed_config(): return None -def _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_to_params_buffers=False): +def _load_state_dict_into_zero3_model(model_to_load, state_dict): """ Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers` tensor parallelism API. @@ -346,10 +346,7 @@ def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=Fals if child is not None: load(child, state_dict, prefix + name + ".", assign_to_params_buffers) - load(model_to_load, state_dict, assign_to_params_buffers=assign_to_params_buffers) - # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so - # it's safe to delete it. - del state_dict + load(model_to_load, state_dict, assign_to_params_buffers=False) return error_msgs diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 98aa3206a7a3..0d1126ba549f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -73,7 +73,6 @@ from .quantizers.quantizers_utils import get_module_from_name from .safetensors_conversion import auto_conversion from .utils import ( - ACCELERATE_MIN_VERSION, ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME, CONFIG_NAME, @@ -137,7 +136,6 @@ load_offloaded_weights, offload_weight, save_offload_index, - set_module_tensor_to_device, ) accelerate_version = version.parse(importlib.metadata.version("accelerate")) @@ -208,32 +206,29 @@ def is_local_dist_rank_0(): @contextmanager -def no_init_weights(_enable=True): +def no_init_weights(): """ Context manager to globally disable weight initialization to speed up loading large models. - - TODO(Patrick): Delete safety argument `_enable=True` at next major version. . 
""" global _init_weights old_init_weights = _init_weights - if _enable: - _init_weights = False + _init_weights = False - def _skip_init(*args, **kwargs): - pass + def _skip_init(*args, **kwargs): + pass + + # Save the original initialization functions + for name, init_func in TORCH_INIT_FUNCTIONS.items(): + setattr(torch.nn.init, name, _skip_init) - # # Save the original initialization functions - for name, init_func in TORCH_INIT_FUNCTIONS.items(): - setattr(torch.nn.init, name, _skip_init) try: yield finally: _init_weights = old_init_weights - if _enable: - # # Restore the original initialization functions - for name, init_func in TORCH_INIT_FUNCTIONS.items(): - setattr(torch.nn.init, name, init_func) + # Restore the original initialization functions + for name, init_func in TORCH_INIT_FUNCTIONS.items(): + setattr(torch.nn.init, name, init_func) @contextmanager @@ -404,37 +399,6 @@ def dtype_byte_size(dtype): return bit_size // 8 -def check_support_param_buffer_assignment(model_to_load, state_dict): - """ - Checks if `model_to_load` supports param buffer assignment (such - as when loading in empty weights) by first checking - if the model explicitly disables it, then by ensuring that the state dict keys - are a subset of the model's parameters. - - Note: We fully disable this if we are using `deepspeed` - """ - if len(state_dict) == 0: - return False - - if is_deepspeed_zero3_enabled(): - return False - - # Some models explicitly do not support param buffer assignment - if not getattr(model_to_load, "_supports_param_buffer_assignment", True): - logger.debug( - f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower" - ) - return False - - # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype - first_key = next(iter(model_to_load.state_dict().keys())) - if first_key in state_dict: - return state_dict[first_key].dtype == model_to_load.state_dict()[first_key].dtype - - # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`) - return False - - def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): """ This is the same as @@ -750,6 +714,13 @@ def _infer_parameter_dtype( return old_param is not None and old_param.is_contiguous(), casting_dtype +def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor): + """Cast a single parameter `param_name` into the `model`, with value `tensor`.""" + module, param_type = get_module_from_name(model, param_name) + # This will check potential shape mismatch if skipped before + module.load_state_dict({param_type: tensor}, strict=False, assign=True) + + @torch.no_grad() def _load_state_dict_into_meta_model( model: "PreTrainedModel", @@ -857,17 +828,12 @@ def _load_state_dict_into_meta_model( ): if is_fsdp_enabled(): param_device = "cpu" if is_local_dist_rank_0() else "meta" - module, param_type = get_module_from_name(model, param_name) # avoid tied weights if param.data_ptr() in data_ptrs: param = param.clone() - module.load_state_dict( - {param_type: param.to(param_device)}, - strict=False, - assign=True, - ) + _load_parameter_into_model(model, param_name, param.to(param_device)) # Add `data_ptr` of `model.state_dict()[param_name]` to avoid tied weights data_ptrs.add(model.state_dict()[param_name].data_ptr()) @@ -1397,18 +1363,7 @@ def _find_missing_and_unexpected_keys( if has_inv_freq_buffers: unexpected_keys = [k for k in unexpected_keys if 
"rotary_emb.inv_freq" not in k] - if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): - ptrs = collections.defaultdict(list) - for name, tensor in model.state_dict().items(): - id_tensor = id_tensor_storage(tensor) - ptrs[id_tensor].append(name) - - # These are all the pointers of shared tensors. - tied_params = [names for _, names in ptrs.items() if len(names) > 1] - else: - # id function doesn't work for meta tensor so we need this function - tied_params = find_tied_parameters(model) - + tied_params = find_tied_parameters(model) for group in tied_params: missing_in_group = [k for k in missing_keys if k in group] if len(missing_in_group) > 0 and len(missing_in_group) < len(group): @@ -1430,29 +1385,59 @@ def _find_missing_and_unexpected_keys( def _find_mismatched_keys( - model_to_load: "PreTrainedModel", - state_dict: Dict, + model: "PreTrainedModel", + state_dict: Optional[Dict], + checkpoint_files: Optional[List[str]], ignore_mismatched_sizes: bool, - prefix: str, -) -> List: - """Find mismatch of shapes between the model parameters and the loaded state dict, and optionally remove the - problematic keys from `state_dict` if `ignore_mismatched_sizes` is `True`.""" + keys_to_rename_mapping: Dict[str, str], + is_quantized: bool, + weights_only: bool, +) -> Tuple[List[str], List[Tuple[int, int]]]: + """ + Find potential shape mismatch between the different state dicts and the model parameters, but only if `ignore_mismatched_sizes` + is True. Otherwise, return immediately and any shape mismatch that may exist will be raised later on. This avoids checking + every parameter in advance, as shape mismatch are extremely rare in practice. If we want to ignore them however, we do + need to check in advance as we need to know which parameters we need to move back from meta to cpu, and initialize + correctly. Indeed, as our model initialization takes place at the module level, and not the weight level, in the + case of a sharded checkpoint we cannot correctly initialize the weights according to `model._init_weights()` if we perform + this check on each state dict at loading time (after the first loaded checkpoint, there are no way to initialize only the + mismatched weights if any, without overwriting the previously loaded weights as well because all the module will be + initialized, not only the weights that are mismatched). + """ + + # An error will be raised later on anyway if there is a mismatch - this avoids running the rest of this function + # if there are no mismatch (which is almost always the case) + if not ignore_mismatched_sizes: + return [], [] + + if state_dict is not None: + checkpoint_files = [""] + + model_state_dict = model.state_dict() mismatched_keys = [] - if ignore_mismatched_sizes: - model_state_dict = model_to_load.state_dict() - state_dict_keys = list(state_dict.keys()) - for key in state_dict_keys: - if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape: - if state_dict[key].shape[-1] == 1 and state_dict[key].numel() * 2 == model_state_dict[key].numel(): - # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences. - # Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights. 
- pass - else: - # Add prefix if we removed it before, to add the correct state dict key to the warnings - key_with_prefix = prefix + key - mismatched_keys.append((key_with_prefix, state_dict[key].shape, model_state_dict[key].shape)) - del state_dict[key] - return mismatched_keys + mismatched_shapes = [] + for shard_file in checkpoint_files: + # If shard_file is "", we use the existing state_dict instead of loading it + if shard_file != "": + state_dict = load_state_dict( + shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only + ) + + # Fix the key names + new_state_dict = {keys_to_rename_mapping[k]: v for k, v in state_dict.items() if k in keys_to_rename_mapping} + + for key in new_state_dict.keys(): + if key in model_state_dict and new_state_dict[key].shape != model_state_dict[key].shape: + # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences. + # Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights. + if not ( + new_state_dict[key].shape[-1] == 1 + and new_state_dict[key].numel() * 2 == model_state_dict[key].numel() + ): + mismatched_keys.append(key) + mismatched_shapes.append((new_state_dict[key].shape, model_state_dict[key].shape)) + + return mismatched_keys, mismatched_shapes class PipelineParallel(Enum): @@ -3773,13 +3758,9 @@ def float(self, *args): @classmethod def get_init_context( cls: Type[SpecificPreTrainedModelType], - _fast_init=True, is_quantized=None, _is_ds_init_called=None, - low_cpu_mem_usage=True, ): - init_contexts = [no_init_weights(_enable=_fast_init)] - if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called: import deepspeed @@ -3787,13 +3768,10 @@ def get_init_context( init_contexts = [ deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state(), - ] + init_contexts - elif low_cpu_mem_usage: - if not is_accelerate_available(): - raise ImportError( - f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" - ) - init_contexts.append(init_empty_weights()) + no_init_weights(), + ] + else: + init_contexts = [no_init_weights(), init_empty_weights()] if is_deepspeed_zero3_enabled() and is_quantized: init_contexts.append(set_quantized_state()) @@ -3829,10 +3807,6 @@ def from_pretrained( The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those weights are discarded. - If model weights are the same precision as the base model (and is a supported model), weights will be lazily loaded - in using the `meta` device and brought into memory once an input is passed through that layer regardless of - `low_cpu_mem_usage`. - Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: @@ -3910,31 +3884,12 @@ def from_pretrained( To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. - - _fast_init(`bool`, *optional*, defaults to `True`): - Whether or not to disable fast initialization. - - - - One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ < - 4.6.0` for seeded model initialization. This argument will be removed at the next major version. See - [pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information. - attn_implementation (`str`, *optional*): The attention implementation to use in the model (if relevant). 
Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. > Parameters for big model inference - low_cpu_mem_usage(`bool`, *optional*): - Tries not to use more than 1x model size in CPU memory (including peak memory) while loading the model. - Generally should be combined with a `device_map` (such as `"auto"`) for best results. - This is an experimental feature and a subject to change at any moment. - - If the model weights are in the same precision as the model loaded in, `low_cpu_mem_usage` (without - `device_map`) is redundant and will not provide any benefit in regards to CPU memory usage. However, - this should still be enabled if you are passing in a `device_map`. - torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model under a specific `dtype`. The different options are: @@ -4045,37 +4000,16 @@ def from_pretrained( >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower) >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True) ``` - - * `low_cpu_mem_usage` algorithm: - - This is an experimental function that loads the model using ~1x model size CPU memory - - Here is how it works: - - 1. save which state_dict keys we have - 2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory - 3. after the model has been instantiated switch to the meta device all params/buffers that - are going to be replaced from the loaded state_dict - 4. load state_dict 2nd time - 5. 
replace the params/buffers from the state_dict - - Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors - """ state_dict = kwargs.pop("state_dict", None) from_tf = kwargs.pop("from_tf", False) from_flax = kwargs.pop("from_flax", False) - _ = kwargs.pop("resume_download", None) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) use_auth_token = kwargs.pop("use_auth_token", None) - _ = kwargs.pop("trust_remote_code", None) - _ = kwargs.pop("mirror", None) from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) - _fast_init = kwargs.pop("_fast_init", True) torch_dtype = kwargs.pop("torch_dtype", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None) device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) @@ -4094,6 +4028,12 @@ def from_pretrained( gguf_file = kwargs.pop("gguf_file", None) tp_plan = kwargs.pop("tp_plan", None) key_mapping = kwargs.pop("key_mapping", None) + # Not used anymore -- remove them from the kwargs + _ = kwargs.pop("resume_download", None) + _ = kwargs.pop("trust_remote_code", None) + _ = kwargs.pop("mirror", None) + _ = kwargs.pop("_fast_init", True) + _ = kwargs.pop("low_cpu_mem_usage", None) if state_dict is not None and (pretrained_model_name_or_path is not None or gguf_file is not None): raise ValueError( @@ -4156,9 +4096,6 @@ def from_pretrained( world_size = torch.distributed.get_world_size() device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,)) - if is_fsdp_enabled(): - low_cpu_mem_usage = True - if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", @@ -4240,20 +4177,8 @@ def from_pretrained( device_map = {"": device_map} if device_map is not None: - if low_cpu_mem_usage is None: - low_cpu_mem_usage = True - elif not low_cpu_mem_usage: - raise ValueError("Passing along a `device_map` or a `tp_plan` requires `low_cpu_mem_usage=True`") - - if low_cpu_mem_usage: if is_deepspeed_zero3_enabled(): - raise ValueError( - "DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`." - ) - elif not is_accelerate_available(): - raise ImportError( - f"Using `low_cpu_mem_usage=True`, a `device_map` or a `tp_plan` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" - ) + raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.") # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation. 
if load_in_4bit or load_in_8bit: @@ -4355,10 +4280,6 @@ def from_pretrained( user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value else: user_agent["quant"] = hf_quantizer.quantization_config.quant_method - # Force-set to `True` for more mem efficiency - if low_cpu_mem_usage is None: - low_cpu_mem_usage = True - logger.warning("`low_cpu_mem_usage` was None, now default to True since model is quantized.") if gguf_file is not None and hf_quantizer is not None: raise ValueError( @@ -4438,8 +4359,6 @@ def from_pretrained( state_dict = load_gguf_checkpoint(checkpoint_files[0], return_tensors=True, model_to_load=dummy_model)[ "tensors" ] - # Force it if is not already the case - low_cpu_mem_usage = True # Find the correct dtype based on current state config, torch_dtype, dtype_orig = _get_torch_dtype( @@ -4449,7 +4368,7 @@ def from_pretrained( config.name_or_path = pretrained_model_name_or_path # Instantiate model. - model_init_context = cls.get_init_context(_fast_init, is_quantized, _is_ds_init_called, low_cpu_mem_usage) + model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called) config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. if not getattr(config, "_attn_implementation_autoset", False): @@ -4480,8 +4399,6 @@ def from_pretrained( if model._keep_in_fp32_modules is not None and ( torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False) ): - # Only the path with `low_cpu_mem_usage` will check every param for the correct dtype - low_cpu_mem_usage = True # We need to match exact layers, so we add either `.` on each side, or start/end of string keep_in_fp32_regex = re.compile( "|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules]) @@ -4526,7 +4443,6 @@ def from_pretrained( pretrained_model_name_or_path, ignore_mismatched_sizes=ignore_mismatched_sizes, sharded_metadata=sharded_metadata, - low_cpu_mem_usage=low_cpu_mem_usage, device_map=device_map, disk_offload_folder=offload_folder, offload_state_dict=offload_state_dict, @@ -4536,7 +4452,6 @@ def from_pretrained( device_mesh=device_mesh, key_mapping=key_mapping, weights_only=weights_only, - _fast_init=_fast_init, ) # make sure token embedding weights are still tied if needed @@ -4735,7 +4650,6 @@ def _load_pretrained_model( pretrained_model_name_or_path: Optional[str], ignore_mismatched_sizes: bool = False, sharded_metadata: Optional[Dict] = None, - low_cpu_mem_usage: bool = False, device_map: Optional[Dict] = None, disk_offload_folder: Optional[str] = None, offload_state_dict: Optional[bool] = None, @@ -4745,7 +4659,6 @@ def _load_pretrained_model( device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None, key_mapping: Optional[Dict[str, str]] = None, weights_only: bool = True, - _fast_init: bool = True, ): # Useful flags is_quantized = hf_quantizer is not None @@ -4787,20 +4700,28 @@ def _load_pretrained_model( hf_quantizer, device_map, ) + # Find all the keys with shape mismatch (if we ignore the mismatch, the weights need to be newly initialized the + # same way as missing keys) + mismatched_keys, mismatched_shapes = _find_mismatched_keys( + model, + state_dict, + checkpoint_files, + ignore_mismatched_sizes, + key_renaming_mapping, + is_quantized, + weights_only, + ) + + # We need to update both the mapping and the list of checkpoint keys to remove the mismatched ones + key_renaming_mapping = {k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys} + checkpoint_keys = 
list(key_renaming_mapping.values()) - # Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as - # they are not in the loaded state dict) - if low_cpu_mem_usage: - model._move_missing_keys_from_meta_to_cpu(missing_keys, unexpected_keys, dtype, hf_quantizer) - # In this case we also need to move everything back - if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: - for key, param in model.state_dict().items(): - if param.device == torch.device("meta"): - set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype)) + # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when + # loading the weights as they are not in the loaded state dict) + model._move_missing_keys_from_meta_to_cpu(missing_keys + mismatched_keys, unexpected_keys, dtype, hf_quantizer) - # correctly initialize the missing keys if it was skipped before - if _fast_init or low_cpu_mem_usage: - model._initialize_missing_keys(checkpoint_keys, ignore_mismatched_sizes, is_quantized) + # correctly initialize the missing (and potentially mismatched) keys + model._initialize_missing_keys(checkpoint_keys, ignore_mismatched_sizes, is_quantized) # Set some modules to fp32 if needed if keep_in_fp32_regex is not None: @@ -4907,7 +4828,6 @@ def _load_pretrained_model( caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4) error_msgs = [] - mismatched_keys = [] # Iterate on all the shards to load the weights for shard_file in checkpoint_files: # Skip the load for shards that only contain disk-offloaded weights @@ -4915,16 +4835,15 @@ def _load_pretrained_model( continue map_location = "cpu" - if low_cpu_mem_usage: - if shard_file.endswith(".safetensors"): - map_location = "meta" - elif ( - device_map is not None - and hf_quantizer is not None - and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO - and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] - ): - map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) + if shard_file.endswith(".safetensors"): + map_location = "meta" + elif ( + device_map is not None + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] + ): + map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) # If shard_file is "", we use the existing state_dict instead of loading it if shard_file != "": @@ -4935,41 +4854,27 @@ def _load_pretrained_model( # Fix the key names state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} - # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not - # matching the weights in the model. 
- mismatched_keys += _find_mismatched_keys( - model_to_load, - state_dict, - ignore_mismatched_sizes, - prefix if loading_base_model_from_task_state_dict else "", - ) - - if low_cpu_mem_usage: - # Skip it with fsdp on ranks other than 0 - if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): - disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( - model_to_load, - state_dict, - shard_file, - expected_keys, - reverse_key_renaming_mapping, - device_map=device_map, - disk_offload_folder=disk_offload_folder, - disk_offload_index=disk_offload_index, - cpu_offload_folder=cpu_offload_folder, - cpu_offload_index=cpu_offload_index, - hf_quantizer=hf_quantizer, - is_safetensors=is_offloaded_safetensors, - keep_in_fp32_regex=keep_in_fp32_regex, - unexpected_keys=unexpected_keys, - device_mesh=device_mesh, - ) - else: - assign_params = check_support_param_buffer_assignment(model_to_load, state_dict) - if is_deepspeed_zero3_enabled(): - error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params) - else: - model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params) + if is_deepspeed_zero3_enabled(): + error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict) + # Skip it with fsdp on ranks other than 0 + elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): + disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( + model_to_load, + state_dict, + shard_file, + expected_keys, + reverse_key_renaming_mapping, + device_map=device_map, + disk_offload_folder=disk_offload_folder, + disk_offload_index=disk_offload_index, + cpu_offload_folder=cpu_offload_folder, + cpu_offload_index=cpu_offload_index, + hf_quantizer=hf_quantizer, + is_safetensors=is_offloaded_safetensors, + keep_in_fp32_regex=keep_in_fp32_regex, + unexpected_keys=unexpected_keys, + device_mesh=device_mesh, + ) # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop del state_dict @@ -5068,7 +4973,7 @@ def _load_pretrained_model( mismatched_warning = "\n".join( [ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, shape1, shape2 in mismatched_keys + for key, (shape1, shape2) in zip(mismatched_keys, mismatched_shapes) ] ) logger.warning( @@ -5323,19 +5228,26 @@ def _move_missing_keys_from_meta_to_cpu( """ is_quantized = hf_quantizer is not None + # In this case we need to move everything back + if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: + # We only do it for the parameters, as the buffers are not initialized on the meta device by default + for key, param in self.named_parameters(): + value = torch.empty_like(param, dtype=dtype, device="cpu") + _load_parameter_into_model(self, key, value) + return + model_state_dict = self.state_dict() for key in missing_keys: param = model_state_dict[key] + # Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them if param.device == torch.device("meta"): - # upcast in fp32 if any - target_dtype = dtype - value = torch.empty(*param.size(), dtype=target_dtype) + value = torch.empty_like(param, dtype=dtype, device="cpu") if ( not is_quantized or (getattr(hf_quantizer, "requires_parameters_quantization", False)) or not hf_quantizer.check_quantized_param(self, param_value=value, param_name=key, state_dict={}) ): - set_module_tensor_to_device(self, key, "cpu", value) + 
_load_parameter_into_model(self, key, value) else: hf_quantizer.create_quantized_param(self, value, key, "cpu", model_state_dict, unexpected_keys) diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index e9e029cf53f3..1aa24ca08a5a 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -401,7 +401,6 @@ class ASTPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flash_attn_2 = True - # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): @@ -415,6 +414,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, ASTEmbeddings): + module.cls_token.data.zero_() + module.position_embeddings.data.zero_() + module.distillation_token.data.zero_() AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 1697ae3c5c46..aaca155b0c8c 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -361,22 +361,20 @@ class AutoformerSinusoidalPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: super().__init__(num_positions, embedding_dim) - @staticmethod - def _init_weight(out: nn.Parameter) -> nn.Parameter: + def _init_weight(self): """ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in the 2nd half of the vector. 
[dim // 2:] """ - n_pos, dim = out.shape + n_pos, dim = self.weight.shape position_enc = np.array( [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] ) - out.requires_grad = False # set early to avoid an error in pytorch-1.8+ + out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False) sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - out.detach_() - return out + self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: @@ -903,7 +901,7 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() elif isinstance(module, AutoformerSinusoidalPositionalEmbedding): - module.weight = module._init_weight(module.weight) + module._init_weight() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index b4b116bdfb0f..264bce993a34 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -770,6 +770,18 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, BeitEmbeddings): + module.cls_token.data.zero_() + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, BeitRelativePositionBias): + module.relative_position_bias_table.data.zero_() + elif isinstance(module, BeitLayer): + if module.lambda_1 is not None: + module.lambda_1.data.fill_(self.config.layer_scale_init_value) + module.lambda_2.data.fill_(self.config.layer_scale_init_value) BEIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 8e48263c9300..d7a26500ccfa 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -848,6 +848,8 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, BertLMPredictionHead): + module.bias.data.zero_() @dataclass diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index e44ef805531e..b69590ae21a5 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -715,7 +715,7 @@ class CamembertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->CamembertLMHead def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): @@ -731,6 +731,8 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, CamembertLMHead): + module.bias.data.zero_() CAMEMBERT_INPUTS_DOCSTRING = r""" diff --git 
a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index 155f466ac4ae..8eeb98b089ff 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -286,9 +286,12 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): + elif isinstance(module, (nn.LayerNorm, ConvNextLayerNorm)): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, ConvNextLayer): + if module.layer_scale_parameter is not None: + module.layer_scale_parameter.data.fill_(self.config.layer_scale_init_value) CONVNEXT_START_DOCSTRING = r""" diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index b779dfbe415f..98e5ba15513e 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -287,7 +287,6 @@ def forward( ) -# Copied from transformers.models.convnext.modeling_convnext.ConvNextPreTrainedModel with ConvNext->ConvNextV2, convnext->convnextv2 class ConvNextV2PreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -307,9 +306,12 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): + elif isinstance(module, (nn.LayerNorm, ConvNextV2LayerNorm)): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, ConvNextV2GRN): + module.weight.data.zero_() + module.bias.data.zero_() CONVNEXTV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index c86495cbbe21..12a407c51ad6 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -784,6 +784,18 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, Data2VecVisionEmbeddings): + module.cls_token.data.zero_() + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, Data2VecVisionRelativePositionBias): + module.relative_position_bias_table.data.zero_() + elif isinstance(module, Data2VecVisionLayer): + if module.lambda_1 is not None: + module.lambda_1.data.fill_(self.config.layer_scale_init_value) + module.lambda_2.data.fill_(self.config.layer_scale_init_value) DATA2VEC_VISION_START_DOCSTRING = r""" diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 8db75abc0a77..007129c5bd6e 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1850,30 +1850,20 @@ def __init__(self, config: DeformableDetrConfig): num_layers=3, ) - prior_prob = 0.01 - bias_value = -math.log((1 - prior_prob) / prior_prob) - self.class_embed.bias.data = torch.ones(config.num_labels) * 
bias_value - nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) - nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) - # if two-stage, the last class_embed and bbox_embed is for region proposal generation num_pred = (config.decoder_layers + 1) if config.two_stage else config.decoder_layers if config.with_box_refine: self.class_embed = _get_clones(self.class_embed, num_pred) self.bbox_embed = _get_clones(self.bbox_embed, num_pred) - nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) # hack implementation for iterative bounding box refinement self.model.decoder.bbox_embed = self.bbox_embed else: - nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) self.model.decoder.bbox_embed = None if config.two_stage: # hack implementation for two-stage self.model.decoder.class_embed = self.class_embed - for box_embed in self.bbox_embed: - nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 0c41ae1f7fe5..b8cae05f1ff6 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -487,6 +487,12 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, DeiTEmbeddings): + module.cls_token.data.zero_() + module.position_embeddings.data.zero_() + module.distillation_token.data.zero_() + if module.mask_token is not None: + module.mask_token.data.zero_() DEIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py index a383e2937f6f..18ab8db6d05a 100755 --- a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py @@ -44,7 +44,6 @@ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional self.offset = 2 self.embedding_dim = embedding_dim self.padding_idx = padding_idx - self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) @@ -399,6 +398,11 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Speech2Text2SinusoidalPositionalEmbedding): + weight = module.get_embedding(*module.weight.shape, module.padding_idx) + weight = nn.Parameter(weight, requires_grad=False) + weight.detach_() + module.weight = weight SPEECH_TO_TEXT_2_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py index 922d5fab9be9..6ad8a14a7329 100644 --- a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py @@ -516,12 +516,12 @@ def _init_weights(self, module: 
Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No mean=0.0, std=self.config.initializer_range, ).to(module.position_embeddings.dtype) - module.cls_token.data = nn.init.trunc_normal_( module.cls_token.data.to(torch.float32), mean=0.0, std=self.config.initializer_range, ).to(module.cls_token.dtype) + module.mask_token.data.zero_() VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 2e11d3a76c6c..7ed5a4ec6cb7 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -548,6 +548,11 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) + if self.config.use_mask_token: + module.mask_token.data.zero_() + elif isinstance(module, Dinov2LayerScale): + module.lambda1.data.fill_(self.config.layerscale_value) + DINOV2_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index c7c48dadb73f..449bfb9b91cd 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -556,6 +556,11 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) + module.mask_token.data.zero_() + module.register_tokens.data.zero_() + elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821 + module.lambda1.data.fill_(self.config.layerscale_value) + _EXPECTED_OUTPUT_SHAPE = [1, 257, 768] diff --git a/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py index cbd316c421b0..59777e215789 100644 --- a/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional +from typing import Optional, Union import torch import torch.utils.checkpoint @@ -277,7 +277,36 @@ class Dinov2WithRegistersEncoder(Dinov2Encoder): class Dinov2WithRegistersPreTrainedModel(Dinov2PreTrainedModel): - pass + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Dinov2WithRegistersEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + module.mask_token.data.zero_() + module.register_tokens.data.zero_() + elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821 + module.lambda1.data.fill_(self.config.layerscale_value) class Dinov2WithRegistersModel(Dinov2Model): diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 1434ae415045..929d73088475 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -869,6 +869,13 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, DonutSwinEmbeddings): + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, DonutSwinSelfAttention): + module.relative_position_bias_table.data.zero_() SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 98b9782ad335..c9bbaa171675 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -855,6 +855,9 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + if isinstance(module, (DPTViTEmbeddings, DPTViTHybridEmbeddings)): + module.cls_token.data.zero_() + module.position_embeddings.data.zero_() DPT_START_DOCSTRING = r""" diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index f2138ac0f683..3921ca50790f 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -670,7 +670,6 @@ class ElectraPreTrainedModel(PreTrainedModel): base_model_prefix = "electra" supports_gradient_checkpointing = True - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py 
b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 2cba3bd8a237..415fd058e45d 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -264,6 +264,8 @@ def __init__( self.tie_weights() def tie_weights(self): + self.encoder.tie_weights() + self.decoder.tie_weights() # tie encoder & decoder if needed if self.config.tie_encoder_decoder: # tie encoder and decoder base model @@ -279,6 +281,12 @@ def tie_weights(self): # Leading to issues on subsequent calls by different tests or subsequent calls. self._dynamic_tied_weights_keys = tied_weights + def _init_weights(self, module): + if module in self.encoder.modules(): + self.encoder._init_weights(module) + elif module in self.decoder.modules(): + self.decoder._init_weights(module) + def get_encoder(self): return self.encoder @@ -385,14 +393,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): return model - # At the moment fast initialization is not supported for composite models - if kwargs.get("_fast_init", False): - logger.warning( - "Fast initialization is currently not supported for EncoderDecoderModel. " - "Falling back to slow initialization..." - ) - kwargs["_fast_init"] = False - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) @classmethod diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index a7d07904e06a..6f90d8d052f9 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -681,7 +681,7 @@ class EsmPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"] - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): @@ -697,6 +697,8 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, EsmLMHead): + module.bias.data.zero_() ESM_START_DOCSTRING = r""" diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index ca08cad4d283..330f7c4e7baf 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -874,6 +874,18 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, FlavaMaskedPredictionHead): + module.bias.data.zero_() + elif isinstance(module, FlavaImageEmbeddings): + module.cls_token.data.zero_() + module.position_embeddings.data.zero_() + if module.mask_token is not None: + module.mask_token.data.zero_() + elif isinstance(module, FlavaMultimodalModel): + if module.use_cls_token: + module.cls_token.data.zero_() + elif isinstance(module, FlavaModel): + module.logit_scale.data.fill_(self.config.logit_scale_init_value) @add_start_docstrings( diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 687654a22da3..a5cf2981b14a 100644 --- 
a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -621,7 +621,6 @@ def forward( ) -# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->FocalNet,swin->focalnet class FocalNetPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -645,6 +644,13 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, FocalNetEmbeddings): + if module.mask_token is not None: + module.mask_token.data.zero_() + elif isinstance(module, FocalNetLayer): + if self.config.use_layerscale: + module.gamma_1.data.fill_(self.config.layerscale_value) + module.gamma_2.data.fill_(self.config.layerscale_value) FOCALNET_START_DOCSTRING = r""" diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index fa5ec7fdda26..8bbdf195501b 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -351,7 +351,10 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() elif isinstance(module, SinusoidalPositionalEmbedding): - pass + weight = module.get_embedding(*module.weight.shape, module.padding_idx) + weight = nn.Parameter(weight, requires_grad=False) + weight.detach_() + module.weight = weight elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: @@ -1302,17 +1305,13 @@ class SinusoidalPositionalEmbedding(nn.Embedding): """ def __init__(self, num_positions, embedding_dim, padding_idx): - self.make_weight(num_positions, embedding_dim, padding_idx) + super().__init__(num_positions, embedding_dim, padding_idx) def make_weight(self, num_positions, embedding_dim, padding_idx): weight = self.get_embedding(num_positions, embedding_dim, padding_idx) - if not hasattr(self, "weight"): - # in ___init__ - super().__init__(num_positions, embedding_dim, padding_idx, _weight=weight) - else: - # in forward put the weights on the correct dtype and device of the param - weight = weight.to(dtype=self.weight.dtype, device=self.weight.device) - self.weight = nn.Parameter(weight) + # in forward put the weights on the correct dtype and device of the param + weight = weight.to(dtype=self.weight.dtype, device=self.weight.device) + self.weight = nn.Parameter(weight) self.weight.detach_() self.weight.requires_grad = False diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py index b753db265493..8842b10cc15b 100755 --- a/src/transformers/models/glpn/modeling_glpn.py +++ b/src/transformers/models/glpn/modeling_glpn.py @@ -437,7 +437,7 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): module.bias.data.zero_() module.weight.data.fill_(1.0) diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index e0ee50a0044a..c48310e4256b 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -187,6 +187,8 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, 
nn.LayerNorm]) -> No mean=0.0, std=self.config.initializer_range, ).to(module.position_embeddings.dtype) + if module.mask_token is not None: + module.mask_token.data.zero_() def eager_attention_forward( diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py index 447347a4eca8..2cf0fe32bf5f 100644 --- a/src/transformers/models/ijepa/modular_ijepa.py +++ b/src/transformers/models/ijepa/modular_ijepa.py @@ -129,6 +129,8 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No mean=0.0, std=self.config.initializer_range, ).to(module.position_embeddings.dtype) + if module.mask_token is not None: + module.mask_token.data.zero_() _EXPECTED_OUTPUT_SHAPE = [1, 256, 1280] diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index b35863251f11..3f37662459e8 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -234,22 +234,20 @@ class InformerSinusoidalPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: super().__init__(num_positions, embedding_dim) - @staticmethod - def _init_weight(out: nn.Parameter) -> nn.Parameter: + def _init_weight(self): """ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in the 2nd half of the vector. [dim // 2:] """ - n_pos, dim = out.shape + n_pos, dim = self.weight.shape position_enc = np.array( [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] ) - out.requires_grad = False # set early to avoid an error in pytorch-1.8+ + out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False) sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - out.detach_() - return out + self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: @@ -887,7 +885,7 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() elif isinstance(module, InformerSinusoidalPositionalEmbedding): - module.weight = module._init_weight(module.weight) + module._init_weight() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 56a3776bde2d..8c31521a3f6d 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -635,6 +635,8 @@ def _init_weights(self, module): elif isinstance(module, LayoutLMLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, LayoutLMLMPredictionHead): + module.bias.data.zero_() LAYOUTLM_START_DOCSTRING = r""" diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 015e43fd4ecc..8cb9cbdf959d 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -48,6 +48,10 @@ import detectron2 from detectron2.modeling import META_ARCH_REGISTRY + # This is 
needed as otherwise their override will break sequential loading by overwriting the buffer over and over. See + # https://github.com/facebookresearch/detectron2/blob/9604f5995cc628619f0e4fd913453b4d7d61db3f/detectron2/layers/batch_norm.py#L83-L86 + detectron2.layers.batch_norm.FrozenBatchNorm2d._load_from_state_dict = torch.nn.Module._load_from_state_dict + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "microsoft/layoutlmv2-base-uncased" @@ -510,6 +514,10 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, LayoutLMv2SelfAttention): + if self.config.fast_qkv: + module.q_bias.data.zero_() + module.v_bias.data.zero_() elif isinstance(module, LayoutLMv2Model): if hasattr(module, "visual_segment_embedding"): module.visual_segment_embedding.data.normal_(mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 9183bb90240c..8c79ae42f0e5 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -374,6 +374,10 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, LayoutLMv3Model): + if self.config.visual_embed: + module.cls_token.data.zero_() + module.pos_embed.data.zero_() class LayoutLMv3SelfAttention(nn.Module): diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 09865489572e..b3d057fd648f 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -579,7 +579,6 @@ class LiltPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = [] - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py index 36dae0ee1d7e..1b8fb938a2ea 100644 --- a/src/transformers/models/lxmert/modeling_lxmert.py +++ b/src/transformers/models/lxmert/modeling_lxmert.py @@ -790,6 +790,8 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, LxmertLMPredictionHead): + module.bias.data.zero_() LXMERT_START_DOCSTRING = r""" @@ -1072,6 +1074,9 @@ def __init__(self, config): } self.visual_losses = visual_losses + def _tie_weights(self): + self.cls.predictions.decoder.weight = self.lxmert.embeddings.word_embeddings.weight + def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index b71c6464485f..6d69e21213d7 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -74,22 +74,20 @@ class MarianSinusoidalPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: super().__init__(num_positions, embedding_dim) - @staticmethod - def _init_weight(out: nn.Parameter) -> nn.Parameter: + def 
_init_weight(self): """ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in the 2nd half of the vector. [dim // 2:] """ - n_pos, dim = out.shape + n_pos, dim = self.weight.shape position_enc = np.array( [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] ) - out.requires_grad = False # set early to avoid an error in pytorch-1.8+ + out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False) sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - out.detach_() - return out + self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: @@ -467,7 +465,7 @@ def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalP if module.bias is not None: module.bias.data.zero_() elif isinstance(module, MarianSinusoidalPositionalEmbedding): - module.weight = module._init_weight(module.weight) + module._init_weight() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 166e63b84b99..f47483d9d861 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -731,6 +731,8 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, MarkupLMLMPredictionHead): + module.bias.data.zero_() @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 2597d2a03e62..4a8d0b002c6b 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -766,7 +766,6 @@ def forward( ) -# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->MaskFormerSwin, swin->model class MaskFormerSwinPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -790,6 +789,11 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, MaskFormerSwinEmbeddings): + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, MaskFormerSwinSelfAttention): + module.relative_position_bias_table.data.zero_() class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index add0eeba6544..82ac64c9a40f 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -718,6 +718,11 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, 
MusicgenSinusoidalPositionalEmbedding): + weights = module.get_embedding(*module.weights.shape) + weights = nn.Parameter(weights, requires_grad=False) + weights.detach_() + module.weights = weights MUSICGEN_START_DOCSTRING = r""" @@ -1805,27 +1810,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): return self.decoder.set_output_embeddings(new_embeddings) - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Example: - - ```python - >>> from transformers import MusicgenForConditionalGeneration - - >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") - ```""" - - # At the moment fast initialization is not supported for composite models - if kwargs.get("_fast_init", False): - logger.warning( - "Fast initialization is currently not supported for MusicgenForConditionalGeneration. " - "Falling back to slow initialization..." - ) - kwargs["_fast_init"] = False - - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - @classmethod def from_sub_models_pretrained( cls, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 08af10d1a3d8..a43ecaa04c36 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -677,6 +677,11 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MusicgenMelodySinusoidalPositionalEmbedding): + weights = module.get_embedding(*module.weights.shape) + weights = nn.Parameter(weights, requires_grad=False) + weights.detach_() + module.weights = weights MUSICGEN_MELODY_START_DOCSTRING = r""" diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 32e105053bcb..11bb17cd28ab 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -75,22 +75,20 @@ class PegasusSinusoidalPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: super().__init__(num_positions, embedding_dim) - @staticmethod - def _init_weight(out: nn.Parameter) -> nn.Parameter: + def _init_weight(self): """ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in the 2nd half of the vector. 
[dim // 2:] """ - n_pos, dim = out.shape + n_pos, dim = self.weight.shape position_enc = np.array( [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] ) - out.requires_grad = False # set early to avoid an error in pytorch-1.8+ + out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False) sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - out.detach_() - return out + self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: @@ -466,7 +464,7 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() elif isinstance(module, PegasusSinusoidalPositionalEmbedding): - module.weight = module._init_weight(module.weight) + module._init_weight() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: @@ -662,7 +660,7 @@ def resize_position_embeddings(self, new_num_position_embeddings: int): self.config.d_model, self.padding_idx, ) - self.embed_positions.weight = self.embed_positions._init_weight(self.embed_positions.weight) + self.embed_positions._init_weight() self.embed_positions.to(self.device) def get_position_embeddings(self) -> nn.Embedding: @@ -866,7 +864,7 @@ def resize_position_embeddings(self, new_num_position_embeddings: int): self.config.d_model, self.padding_idx, ) - self.embed_positions.weight = self.embed_positions._init_weight(self.embed_positions.weight) + self.embed_positions._init_weight() self.embed_positions.to(self.device) def get_position_embeddings(self) -> nn.Embedding: diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 9d6664e1eb7f..e84cd45453ed 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -235,13 +235,6 @@ class RagPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - @classmethod - def from_pretrained(cls, *args, **kwargs): - # At the moment fast initialization is not supported - # for composite models - kwargs["_fast_init"] = False - return super().from_pretrained(*args, **kwargs) - @classmethod def from_pretrained_question_encoder_generator( cls, diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 0425b8d1978d..f2dfa19a6a50 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -698,7 +698,7 @@ class RobertaPreTrainedModel(PreTrainedModel): _no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention", "RobertaSdpaSelfAttention"] _supports_sdpa = True - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->RobertaLMHead def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): @@ -714,6 +714,8 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, RobertaLMHead): + module.bias.data.zero_() ROBERTA_START_DOCSTRING = r""" diff --git 
a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index e8c5156d3cc5..6b0c40b222c1 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -581,7 +581,7 @@ class RobertaPreLayerNormPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["RobertaPreLayerNormEmbeddings", "RobertaPreLayerNormSelfAttention"] - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->RobertaPreLayerNormLMHead def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): @@ -597,6 +597,8 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, RobertaPreLayerNormLMHead): + module.bias.data.zero_() ROBERTA_PRELAYERNORM_START_DOCSTRING = r""" diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 437a149d6057..445f6edb1c89 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -60,22 +60,20 @@ class RoFormerSinusoidalPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: super().__init__(num_positions, embedding_dim) - @staticmethod - def _init_weight(out: nn.Parameter) -> nn.Parameter: + def _init_weight(self): """ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in the 2nd half of the vector. 
[dim // 2:] """ - n_pos, dim = out.shape + n_pos, dim = self.weight.shape position_enc = np.array( [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] ) - out.requires_grad = False # set early to avoid an error in pytorch-1.8+ + out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False) sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - out.detach_() - return out + self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: @@ -693,7 +691,7 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() elif isinstance(module, RoFormerSinusoidalPositionalEmbedding): - module.weight = module._init_weight(module.weight) + module._init_weight() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index 50e0e86ddee0..dd36b23ba66a 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -655,9 +655,10 @@ def forward(self, masks): class SamPromptEncoder(nn.Module): - def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding): + def __init__(self, config: SamPromptEncoderConfig): super().__init__() - self.shared_embedding = shared_patch_embedding + self.shared_embedding = SamPositionalEmbedding(config.vision_config) + config = config.prompt_encoder_config self.mask_embed = SamMaskEmbedding(config) self.no_mask_embed = nn.Embedding(1, config.hidden_size) @@ -1198,6 +1199,13 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (SamLayerNorm, nn.LayerNorm)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, SamVisionAttention): + if module.use_rel_pos: + module.rel_pos_h.data.zero_() + module.rel_pos_w.data.zero_() SAM_START_DOCSTRING = r""" @@ -1348,17 +1356,24 @@ def forward( ) class SamModel(SamPreTrainedModel): _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + # need to be ignored, as it's a buffer and will not be correctly detected as tied weight + _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] def __init__(self, config): super().__init__(config) self.shared_image_embedding = SamPositionalEmbedding(config.vision_config) self.vision_encoder = SamVisionEncoder(config.vision_config) - self.prompt_encoder = SamPromptEncoder(config.prompt_encoder_config, self.shared_image_embedding) + self.prompt_encoder = SamPromptEncoder(config) self.mask_decoder = SamMaskDecoder(config.mask_decoder_config) self.post_init() + def _tie_weights(self): + self.prompt_encoder.shared_embedding.positional_embedding.data = ( + self.shared_image_embedding.positional_embedding.data + ) + def get_input_embeddings(self): return self.vision_encoder.get_input_embeddings() diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index 36efb00e67b0..3e303b1bff0e 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ 
b/src/transformers/models/segformer/modeling_segformer.py @@ -463,7 +463,7 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): module.bias.data.zero_() module.weight.data.fill_(1.0) diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 4375f56a87e6..425c3f7d5b35 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -277,17 +277,6 @@ def freeze_feature_encoder(self): """ self.encoder.freeze_feature_encoder() - @classmethod - def from_pretrained(cls, *args, **kwargs): - # At the moment fast initialization is not supported for composite models - if kwargs.get("_fast_init", False): - logger.warning( - "Fast initialization is currently not supported for SpeechEncoderDecoderModel. " - "Falling back to slow initialization..." - ) - kwargs["_fast_init"] = False - return super().from_pretrained(*args, **kwargs) - @classmethod def from_encoder_decoder_pretrained( cls, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 64406745a426..295a427e6a72 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -520,7 +520,6 @@ class SplinterPreTrainedModel(PreTrainedModel): base_model_prefix = "splinter" supports_gradient_checkpointing = True - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index a4262491366a..46dff663d171 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -943,6 +943,13 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, SwinEmbeddings): + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, SwinSelfAttention): + module.relative_position_bias_table.data.zero_() SWIN_START_DOCSTRING = r""" diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 626f883ac803..46e0a1ca9ad7 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -976,7 +976,6 @@ def forward( ) -# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->Swinv2,swin->swinv2 class Swinv2PreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -1000,6 +999,13 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, Swinv2Embeddings): + if module.mask_token is not None: + module.mask_token.data.zero_() + if 
module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, Swinv2SelfAttention): + module.logit_scale.data.fill_(math.log(10)) SWINV2_START_DOCSTRING = r""" diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 53d9a9d6baeb..5a2450b9a86f 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -719,7 +719,7 @@ class TapasPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_param_buffer_assignment = False - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with Bert->Tapas def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): @@ -735,6 +735,8 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, TapasLMPredictionHead): + module.bias.data.zero_() TAPAS_START_DOCSTRING = r""" diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index ba4a376e80fd..9a87d19d1602 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -234,22 +234,20 @@ class TimeSeriesSinusoidalPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: super().__init__(num_positions, embedding_dim) - @staticmethod - def _init_weight(out: nn.Parameter) -> nn.Parameter: + def _init_weight(self): """ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in the 2nd half of the vector. 
[dim // 2:] """ - n_pos, dim = out.shape + n_pos, dim = self.weight.shape position_enc = np.array( [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] ) - out.requires_grad = False # set early to avoid an error in pytorch-1.8+ + out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False) sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - out.detach_() - return out + self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: @@ -640,7 +638,7 @@ def _init_weights(self, module): if module.bias is not None: module.bias.data.zero_() elif isinstance(module, TimeSeriesSinusoidalPositionalEmbedding): - module.weight = module._init_weight(module.weight) + module._init_weight() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 9b18306713d2..7451973b5b65 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -367,14 +367,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): return model - # At the moment fast initialization is not supported for composite models - if kwargs.get("_fast_init", False): - logger.warning( - "Fast initialization is currently not supported for VisionEncoderDecoderModel. " - "Falling back to slow initialization..." - ) - kwargs["_fast_init"] = False - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) @classmethod diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py index a5d3cad6016e..3e770f6935d8 100755 --- a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +++ b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py @@ -407,13 +407,6 @@ def forward( vision_model_output=vision_outputs, ) - @classmethod - def from_pretrained(cls, *args, **kwargs): - # At the moment fast initialization is not supported - # for composite models - kwargs["_fast_init"] = False - return super().from_pretrained(*args, **kwargs) - @classmethod def from_vision_text_pretrained( cls, diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 158a3e3e6d51..d757aeaf28b1 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -497,6 +497,9 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No std=self.config.initializer_range, ).to(module.cls_token.dtype) + if module.mask_token is not None: + module.mask_token.data.zero_() + VIT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index c002c41ca068..e4f6a868acca 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -208,7 +208,6 @@ def __init__(self, config): ) self.patch_size = config.patch_size self.config = config - self.initialize_weights() def initialize_weights(self): # initialize (and freeze) position embeddings by sin-cos embedding @@ -660,6 +659,11 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, ViTMAEEmbeddings): + module.initialize_weights() + elif isinstance(module, ViTMAEDecoder): + module.mask_token.data.zero_() + module.decoder_pos_embed.data.zero_() VIT_MAE_START_DOCSTRING = r""" diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 8f25438ef9e4..fb5a3d56ba60 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -480,6 +480,11 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, ViTMSNEmbeddings): + module.cls_token.data.zero_() + module.position_embeddings.data.zero_() + if module.mask_token is not None: + module.mask_token.data.zero_() VIT_MSN_START_DOCSTRING = r""" diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 238a723dfab3..669106239a06 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -485,8 +485,9 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) - elif isinstance(module, nn.Parameter): - module.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, VivitEmbeddings): + module.cls_token.data.zero_() + module.position_embeddings.data.zero_() VIVIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 07800804c1bf..1fe5823c2066 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -700,7 +700,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel): _no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention", "XLMRobertaSdpaSelfAttention"] _supports_sdpa = True - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->XLMRobertaLMHead def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): @@ -716,6 +716,8 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, XLMRobertaLMHead): + module.bias.data.zero_() XLM_ROBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 014480ecd82e..ad43c7903f4f 100644 --- 
a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -692,7 +692,7 @@ class XLMRobertaXLPreTrainedModel(PreTrainedModel): _no_split_modules = ["XLMRobertaXLEmbeddings", "XLMRobertaXLLayer"] _supports_sdpa = True - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->XLMRobertaXLLMHead def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): @@ -708,6 +708,8 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, XLMRobertaXLLMHead): + module.bias.data.zero_() XLM_ROBERTA_XL_START_DOCSTRING = r""" diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index a3bde4c2b59d..21aad7188e0e 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -645,7 +645,7 @@ class XmodPreTrainedModel(PreTrainedModel): base_model_prefix = "roberta" supports_gradient_checkpointing = True - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->XmodLMHead def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): @@ -661,6 +661,8 @@ def _init_weights(self, module): elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) + elif isinstance(module, XmodLMHead): + module.bias.data.zero_() def set_default_language(self, language: str): """ diff --git a/src/transformers/quantizers/quantizers_utils.py b/src/transformers/quantizers/quantizers_utils.py index 6ae287bf251b..bbed8317049f 100644 --- a/src/transformers/quantizers/quantizers_utils.py +++ b/src/transformers/quantizers/quantizers_utils.py @@ -16,11 +16,6 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]: if "." 
in tensor_name: - splits = tensor_name.split(".") - for split in splits[:-1]: - new_module = getattr(module, split) - if new_module is None: - raise ValueError(f"{module} has no attribute {split}.") - module = new_module - tensor_name = splits[-1] + module_name, tensor_name = tensor_name.rsplit(".", 1) + module = module.get_submodule(module_name) return module, tensor_name diff --git a/tests/models/align/test_modeling_align.py b/tests/models/align/test_modeling_align.py index 8619fe678506..08d4ea221050 100644 --- a/tests/models/align/test_modeling_align.py +++ b/tests/models/align/test_modeling_align.py @@ -380,14 +380,6 @@ def test_inputs_embeds(self): def test_inputs_embeds_matches_input_ids(self): pass - @unittest.skip(reason="AlignTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="AlignTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "kakaobrain/align-base" diff --git a/tests/models/altclip/test_modeling_altclip.py b/tests/models/altclip/test_modeling_altclip.py index 81190d85f852..a111181b699d 100755 --- a/tests/models/altclip/test_modeling_altclip.py +++ b/tests/models/altclip/test_modeling_altclip.py @@ -198,14 +198,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="AltCLIPVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="AltCLIPVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip(reason="AltCLIPVisionModel use the same cv backbone with CLIP model.") def test_model_from_pretrained(self): pass @@ -357,14 +349,6 @@ def test_hidden_states_output(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="AltCLIPTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="AltCLIPTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "BAAI/AltCLIP" diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index 475baaa5ffc7..50c5cd37d089 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -1535,7 +1535,3 @@ def test_decoder_model_attn_mask_past(self): @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return - - @unittest.skip - def test_save_load_fast_init_from_base(self): - pass diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index 3dac349fb4fd..1bf594d3c999 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -213,14 +213,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="BlipVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="BlipVisionModel has no base class and is not available in MODEL_MAPPING") 
- def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "Salesforce/blip-vqa-base" @@ -361,14 +353,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "Salesforce/blip-vqa-base" diff --git a/tests/models/blip/test_modeling_blip_text.py b/tests/models/blip/test_modeling_blip_text.py index 0be5d72002a9..d6614c6a2f3f 100644 --- a/tests/models/blip/test_modeling_blip_text.py +++ b/tests/models/blip/test_modeling_blip_text.py @@ -165,14 +165,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "Salesforce/blip-vqa-base" diff --git a/tests/models/blip/test_modeling_tf_blip.py b/tests/models/blip/test_modeling_tf_blip.py index f8a73d19991f..ed427af7ee31 100644 --- a/tests/models/blip/test_modeling_tf_blip.py +++ b/tests/models/blip/test_modeling_tf_blip.py @@ -179,14 +179,6 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) - @unittest.skip(reason="BlipVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="BlipVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "Salesforce/blip-vqa-base" @@ -307,14 +299,6 @@ def test_model(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "Salesforce/blip-vqa-base" diff --git a/tests/models/blip/test_modeling_tf_blip_text.py b/tests/models/blip/test_modeling_tf_blip_text.py index 6636ee3e216d..082473dfd507 100644 --- a/tests/models/blip/test_modeling_tf_blip_text.py +++ b/tests/models/blip/test_modeling_tf_blip_text.py @@ -163,14 +163,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "Salesforce/blip-vqa-base" diff --git a/tests/models/blip_2/test_modeling_blip_2.py 
b/tests/models/blip_2/test_modeling_blip_2.py index b55ec4a23c05..a360cb98a4ba 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -220,14 +220,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Blip2VisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Blip2VisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "Salesforce/blip2-opt-2.7b" @@ -509,14 +501,6 @@ def test_retain_grad_hidden_states_attentions(self): def test_model_get_set_embeddings(self): pass - @unittest.skip(reason="There's no base Blip2Model") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="There's no base Blip2Model") - def test_save_load_fast_init_to_base(self): - pass - @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): """ @@ -954,14 +938,6 @@ def test_retain_grad_hidden_states_attentions(self): def test_model_get_set_embeddings(self): pass - @unittest.skip(reason="There's no base Blip2Model") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="There's no base Blip2Model") - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip(reason="Does not work on the tiny model as we keep hitting edge cases.") def test_cpu_offload(self): pass @@ -1245,14 +1221,6 @@ def test_retain_grad_hidden_states_attentions(self): def test_model_common_attributes(self): pass - @unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Blip2TextModelWithProjection has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -1420,14 +1388,6 @@ def test_model_common_attributes(self): x = model.get_output_embeddings() self.assertTrue(x is None or isinstance(x, nn.Linear)) - @unittest.skip(reason="Blip2VisionModelWithProjection has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Blip2VisionModelWithProjection has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/bridgetower/test_modeling_bridgetower.py b/tests/models/bridgetower/test_modeling_bridgetower.py index 66d0d82b6d75..bc6ff0d6e47e 100644 --- a/tests/models/bridgetower/test_modeling_bridgetower.py +++ b/tests/models/bridgetower/test_modeling_bridgetower.py @@ -14,11 +14,8 @@ # limitations under the License. 
"""Testing suite for the PyTorch BridgeTower model.""" -import tempfile import unittest -import numpy as np - from transformers import ( BridgeTowerConfig, BridgeTowerTextConfig, @@ -359,39 +356,6 @@ def test_model_from_pretrained(self): model = BridgeTowerModel.from_pretrained(model_name) self.assertIsNotNone(model) - @slow - def test_save_load_fast_init_from_base(self): - # Override as it is a slow test on this model - super().test_save_load_fast_init_from_base() - - # Override as extracting meaningful tensor from output is different for BridgeTower - def test_save_load(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**input_dict) - - out_2 = self.extract_output(outputs, model_class.__name__) - out_2 = out_2.cpu().numpy() - out_2[np.isnan(out_2)] = 0 - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname) - model.to(torch_device) - with torch.no_grad(): - after_outputs = model(**input_dict) - - # Make sure we don't have nans - out_1 = self.extract_output(after_outputs, model_class.__name__) - out_1 = out_1.cpu().numpy() - out_1[np.isnan(out_1)] = 0 - max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) - # Override this as `hidden states output` is different for BridgeTower def test_hidden_states_output(self): def check_hidden_states_output(inputs_dict, config, model_class): diff --git a/tests/models/chinese_clip/test_modeling_chinese_clip.py b/tests/models/chinese_clip/test_modeling_chinese_clip.py index bc14d80524c4..959304288365 100644 --- a/tests/models/chinese_clip/test_modeling_chinese_clip.py +++ b/tests/models/chinese_clip/test_modeling_chinese_clip.py @@ -408,14 +408,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="ChineseCLIPTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="ChineseCLIPTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @require_torch class ChineseCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase): @@ -488,14 +480,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="ChineseCLIPVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="ChineseCLIPVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "OFA-Sys/chinese-clip-vit-base-patch16" diff --git a/tests/models/clap/test_modeling_clap.py b/tests/models/clap/test_modeling_clap.py index c8250648c670..21281ced3e89 100644 --- a/tests/models/clap/test_modeling_clap.py +++ b/tests/models/clap/test_modeling_clap.py @@ -263,14 +263,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="ClapAudioModel has no base class and is not available in MODEL_MAPPING") - def 
test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="ClapAudioModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "laion/clap-htsat-fused" @@ -432,14 +424,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="ClapTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="ClapTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "laion/clap-htsat-fused" diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index 66f741e6f4b2..5600e67a70ca 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -446,14 +446,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="CLIPVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="CLIPVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "openai/clip-vit-base-patch32" @@ -628,14 +620,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="CLIPTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="CLIPTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "openai/clip-vit-base-patch32" diff --git a/tests/models/clipseg/test_modeling_clipseg.py b/tests/models/clipseg/test_modeling_clipseg.py index a116b82f5f4e..85115499267c 100644 --- a/tests/models/clipseg/test_modeling_clipseg.py +++ b/tests/models/clipseg/test_modeling_clipseg.py @@ -202,14 +202,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="CLIPSegVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="CLIPSegVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "CIDAS/clipseg-rd64-refined" @@ -345,14 +337,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="CLIPSegTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="CLIPSegTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "CIDAS/clipseg-rd64-refined" diff --git a/tests/models/depth_anything/test_modeling_depth_anything.py 
b/tests/models/depth_anything/test_modeling_depth_anything.py index 63d57d671706..b9d259b1c712 100644 --- a/tests/models/depth_anything/test_modeling_depth_anything.py +++ b/tests/models/depth_anything/test_modeling_depth_anything.py @@ -181,14 +181,6 @@ def test_training_gradient_checkpointing(self): def test_model_get_set_embeddings(self): pass - @unittest.skip(reason="Depth Anything with AutoBackbone does not have a base model") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Depth Anything with AutoBackbone does not have a base model") - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip( reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" ) diff --git a/tests/models/diffllama/test_modeling_diffllama.py b/tests/models/diffllama/test_modeling_diffllama.py index 9864b713a59d..aba1844e3434 100644 --- a/tests/models/diffllama/test_modeling_diffllama.py +++ b/tests/models/diffllama/test_modeling_diffllama.py @@ -391,10 +391,6 @@ def test_diffllama_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="DiffLlama buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/dpt/test_modeling_dpt_auto_backbone.py b/tests/models/dpt/test_modeling_dpt_auto_backbone.py index 4c0527687ce7..e4c40ca80909 100644 --- a/tests/models/dpt/test_modeling_dpt_auto_backbone.py +++ b/tests/models/dpt/test_modeling_dpt_auto_backbone.py @@ -221,14 +221,6 @@ def test_initialization(self): def test_model_get_set_embeddings(self): pass - @unittest.skip(reason="DPT with AutoBackbone does not have a base model") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="DPT with AutoBackbone does not have a base model") - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip( reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" ) diff --git a/tests/models/esm/test_modeling_esmfold.py b/tests/models/esm/test_modeling_esmfold.py index 7450f0295f77..03b1981dc87d 100644 --- a/tests/models/esm/test_modeling_esmfold.py +++ b/tests/models/esm/test_modeling_esmfold.py @@ -241,10 +241,6 @@ def test_retain_grad_hidden_states_attentions(self): def test_model_outputs_equivalence(self): pass - @unittest.skip(reason="This test doesn't work for ESMFold and doesn't test core functionality") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip(reason="ESMFold does not support input chunking.") def test_feed_forward_chunking(self): pass diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py index a6cf75a72285..3e3321a3141c 100644 --- a/tests/models/flava/test_modeling_flava.py +++ b/tests/models/flava/test_modeling_flava.py @@ -320,16 +320,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="FlavaImageModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - # skip this 
test as FlavaImageModel has no base class and is - # not available in MODEL_MAPPING - @unittest.skip(reason="FlavaImageModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "facebook/flava-full" @@ -486,14 +476,6 @@ def test_inputs_embeds(self): # FLAVA does not use inputs_embeds pass - @unittest.skip(reason="FlavaTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="FlavaTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "facebook/flava-full" @@ -650,14 +632,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="FlavaMultimodalModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="FlavaMultimodalModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "facebook/flava-full" @@ -785,14 +759,6 @@ def test_inputs_embeds(self): def test_model_outputs_equivalence(self): pass - @unittest.skip(reason="FlavaImageCodebook has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="FlavaImageCodebook has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "facebook/flava-full" diff --git a/tests/models/fsmt/test_modeling_fsmt.py b/tests/models/fsmt/test_modeling_fsmt.py index f9bec05743ca..4119769676a9 100644 --- a/tests/models/fsmt/test_modeling_fsmt.py +++ b/tests/models/fsmt/test_modeling_fsmt.py @@ -548,6 +548,7 @@ def test_basic(self): emb1 = SinusoidalPositionalEmbedding(num_positions=6, embedding_dim=6, padding_idx=self.padding_idx).to( torch_device ) + emb1.make_weight(*emb1.weight.shape, emb1.padding_idx) emb = emb1(input_ids) desired_weights = torch.tensor( [ @@ -562,10 +563,16 @@ def test_basic(self): def test_odd_embed_dim(self): # odd embedding_dim is allowed - SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=self.padding_idx).to(torch_device) + test = SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=self.padding_idx).to( + torch_device + ) + test.make_weight(*test.weight.shape, test.padding_idx) # odd num_embeddings is allowed - SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=self.padding_idx).to(torch_device) + test = SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=self.padding_idx).to( + torch_device + ) + test.make_weight(*test.weight.shape, test.padding_idx) @unittest.skip(reason="different from marian (needs more research)") def test_positional_emb_weights_against_marian(self): @@ -579,6 +586,7 @@ def test_positional_emb_weights_against_marian(self): emb1 = SinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512, padding_idx=self.padding_idx).to( torch_device ) + emb1.make_weight(*emb1.weight.shape, emb1.padding_idx) weights = emb1.weights.data[:3, :5] # XXX: only the 1st and 3rd lines match - this is testing against # verbatim copy of 
SinusoidalPositionalEmbedding from fairseq diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 0b4abb85e051..20247bd68c0c 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -406,10 +406,6 @@ def test_Gemma_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Gemma buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip(reason="Gemma uses GQA on all models so the KV cache is a non standard format") def test_past_key_values_format(self): pass diff --git a/tests/models/git/test_modeling_git.py b/tests/models/git/test_modeling_git.py index a2427dfbc158..03f7c8285bfc 100644 --- a/tests/models/git/test_modeling_git.py +++ b/tests/models/git/test_modeling_git.py @@ -187,14 +187,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="GitVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="GitVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "microsoft/git-base" diff --git a/tests/models/granite/test_modeling_granite.py b/tests/models/granite/test_modeling_granite.py index 826cda3f67c8..7819d61025e0 100644 --- a/tests/models/granite/test_modeling_granite.py +++ b/tests/models/granite/test_modeling_granite.py @@ -314,10 +314,6 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) - @unittest.skip("Granite buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @parameterized.expand([("linear",), ("dynamic",)]) def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/granitemoe/test_modeling_granitemoe.py b/tests/models/granitemoe/test_modeling_granitemoe.py index cd2470827b90..927cb2d3655e 100644 --- a/tests/models/granitemoe/test_modeling_granitemoe.py +++ b/tests/models/granitemoe/test_modeling_granitemoe.py @@ -313,10 +313,6 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) - @unittest.skip("GraniteMoe buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @parameterized.expand([("linear",), ("dynamic",)]) def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/granitemoeshared/test_modeling_granitemoeshared.py b/tests/models/granitemoeshared/test_modeling_granitemoeshared.py index 28787870d8d5..646911f6f79b 100644 --- a/tests/models/granitemoeshared/test_modeling_granitemoeshared.py +++ b/tests/models/granitemoeshared/test_modeling_granitemoeshared.py @@ -316,10 +316,6 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) - @unittest.skip("GraniteMoeShared buffers 
include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @parameterized.expand([("linear",), ("dynamic",)]) def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/groupvit/test_modeling_groupvit.py b/tests/models/groupvit/test_modeling_groupvit.py index 4e836f827e9b..3c48743a590e 100644 --- a/tests/models/groupvit/test_modeling_groupvit.py +++ b/tests/models/groupvit/test_modeling_groupvit.py @@ -274,14 +274,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="GroupViTVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="GroupViTVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - # override since the attention mask from GroupViT is not used to compute loss, thus no grad def test_retain_grad_hidden_states_attentions(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -476,14 +468,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="GroupViTTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="GroupViTTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "nvidia/groupvit-gcc-yfcc" diff --git a/tests/models/informer/test_modeling_informer.py b/tests/models/informer/test_modeling_informer.py index b5e35490f6f6..6b5a36e58434 100644 --- a/tests/models/informer/test_modeling_informer.py +++ b/tests/models/informer/test_modeling_informer.py @@ -171,7 +171,7 @@ def check_encoder_decoder_model_standalone(self, config, inputs_dict): embed_positions = InformerSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model ).to(torch_device) - embed_positions.weight = embed_positions._init_weight(embed_positions.weight) + embed_positions._init_weight() self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight)) self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight)) diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index e9d325460d55..f7c13dd09d98 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -216,14 +216,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="InstructBlipVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="InstructBlipVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "Salesforce/instructblip-flan-t5-xl" @@ -522,14 +514,6 @@ def test_retain_grad_hidden_states_attentions(self): def test_model_get_set_embeddings(self): pass 
- @unittest.skip(reason="There's no base InstructBlipModel") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="There's no base InstructBlipModel") - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip( "InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present" ) diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py index 27ed2d42e7ce..e5cc00d92c6a 100644 --- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py @@ -224,14 +224,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="InstructBlipVideoVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="InstructBlipVideoVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "Salesforce/instructblip-vicuna-7b" @@ -538,14 +530,6 @@ def test_retain_grad_hidden_states_attentions(self): def test_model_common_attributes(self): pass - @unittest.skip(reason="There's no base InstructBlipVideoModel") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="There's no base InstructBlipVideoModel") - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip( "InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present" ) diff --git a/tests/models/jetmoe/test_modeling_jetmoe.py b/tests/models/jetmoe/test_modeling_jetmoe.py index 4538ad56108b..aab93c553ef8 100644 --- a/tests/models/jetmoe/test_modeling_jetmoe.py +++ b/tests/models/jetmoe/test_modeling_jetmoe.py @@ -362,10 +362,6 @@ def test_jetmoe_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - @unittest.skip(reason="JetMoe buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip(reason="JetMoe uses MoA on all models so the KV cache is a non standard format") def test_past_key_values_format(self): pass diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 319187f11346..b072105a3fa9 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -384,10 +384,6 @@ def test_llama_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Llama buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/maskformer/test_modeling_maskformer_swin.py b/tests/models/maskformer/test_modeling_maskformer_swin.py index 502660b191ef..6125c2854f79 100644 --- a/tests/models/maskformer/test_modeling_maskformer_swin.py +++ 
b/tests/models/maskformer/test_modeling_maskformer_swin.py @@ -235,10 +235,6 @@ def test_model_get_set_embeddings(self): def test_attention_outputs(self): pass - @unittest.skip(reason="MaskFormerSwin is only used as an internal backbone") - def test_save_load_fast_init_to_base(self): - pass - def check_hidden_states_output(self, inputs_dict, config, model_class, image_size): model = model_class(config) model.to(torch_device) diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index c682fe86b4d3..8f8b757bac96 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -406,10 +406,6 @@ def test_Mistral_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Mistral buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip(reason="Mistral uses GQA on all models so the KV cache is a non standard format") def test_past_key_values_format(self): pass diff --git a/tests/models/mistral/test_modeling_tf_mistral.py b/tests/models/mistral/test_modeling_tf_mistral.py index dd4eff6ba908..a45935e9784c 100644 --- a/tests/models/mistral/test_modeling_tf_mistral.py +++ b/tests/models/mistral/test_modeling_tf_mistral.py @@ -325,10 +325,6 @@ def test_Mistral_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - @unittest.skip("Mistral buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip("Mistral uses GQA on all models so the KV cache is a non standard format") def test_past_key_values_format(self): pass diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 481f94425c46..08d44a94710e 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -405,10 +405,6 @@ def test_Mixtral_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Mixtral buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip(reason="Mixtral uses GQA on all models so the KV cache is a non standard format") def test_past_key_values_format(self): pass diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index c14292b093f2..9d49e1b6c5b6 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -970,12 +970,6 @@ def test_greedy_generate_stereo_outputs(self): super().test_greedy_generate_dict_outputs() self.model_tester.audio_channels = original_audio_channels - @unittest.skip( - reason="MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composite model" - ) - def test_save_load_fast_init_from_base(self): - pass - @require_flash_attn @require_torch_gpu @mark.flash_attn_test diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 0fcfa254afa2..a979fb8f6646 100644 --- 
a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -960,12 +960,6 @@ def test_greedy_generate_stereo_outputs(self): super().test_greedy_generate_dict_outputs() self.model_tester.audio_channels = original_audio_channels - @unittest.skip( - reason="MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composite model" - ) - def test_save_load_fast_init_from_base(self): - pass - @require_flash_attn @require_torch_gpu @mark.flash_attn_test diff --git a/tests/models/olmo/test_modeling_olmo.py b/tests/models/olmo/test_modeling_olmo.py index a96eb9111359..720486968f02 100644 --- a/tests/models/olmo/test_modeling_olmo.py +++ b/tests/models/olmo/test_modeling_olmo.py @@ -310,10 +310,6 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) - @unittest.skip(reason="OLMo buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @parameterized.expand([("linear",), ("dynamic",)]) def test_model_rope_scaling(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/olmo2/test_modeling_olmo2.py b/tests/models/olmo2/test_modeling_olmo2.py index 51496188f9fc..9ed55eb38763 100644 --- a/tests/models/olmo2/test_modeling_olmo2.py +++ b/tests/models/olmo2/test_modeling_olmo2.py @@ -309,10 +309,6 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) - @unittest.skip(reason="OLMo2 buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @parameterized.expand([("linear",), ("dynamic",)]) def test_model_rope_scaling(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/olmoe/test_modeling_olmoe.py b/tests/models/olmoe/test_modeling_olmoe.py index 07d904699faa..24461a7c40f7 100644 --- a/tests/models/olmoe/test_modeling_olmoe.py +++ b/tests/models/olmoe/test_modeling_olmoe.py @@ -323,10 +323,6 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) - @unittest.skip(reason="OLMoE buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @parameterized.expand([("linear",), ("dynamic",)]) def test_model_rope_scaling(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/owlv2/test_modeling_owlv2.py b/tests/models/owlv2/test_modeling_owlv2.py index dff1cbe8c00c..c5446297b998 100644 --- a/tests/models/owlv2/test_modeling_owlv2.py +++ b/tests/models/owlv2/test_modeling_owlv2.py @@ -207,14 +207,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Owlv2VisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Owlv2VisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "google/owlv2-base-patch16-ensemble" @@ -355,14 +347,6 @@ def 
test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="Owlv2TextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Owlv2TextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "google/owlv2-base-patch16-ensemble" @@ -689,10 +673,6 @@ def test_initialization(self): def test_forward_signature(self): pass - @unittest.skip(reason="Test_save_load_fast_init_from_base is tested in individual model tests") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip(reason="OwlV2 does not support training yet") def test_training(self): pass diff --git a/tests/models/owlvit/test_modeling_owlvit.py b/tests/models/owlvit/test_modeling_owlvit.py index 1ad85cb37919..aadc8f1f9394 100644 --- a/tests/models/owlvit/test_modeling_owlvit.py +++ b/tests/models/owlvit/test_modeling_owlvit.py @@ -205,14 +205,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="OwlViTVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="OwlViTVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "google/owlvit-base-patch32" @@ -351,14 +343,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="OwlViTTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="OwlViTTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "google/owlvit-base-patch32" @@ -682,10 +666,6 @@ def test_initialization(self): def test_forward_signature(self): pass - @unittest.skip(reason="Test_save_load_fast_init_from_base is tested in individual model tests") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip(reason="OWL-ViT does not support training yet") def test_training(self): pass diff --git a/tests/models/perceiver/test_modeling_perceiver.py b/tests/models/perceiver/test_modeling_perceiver.py index e6bcb930ec61..2cd4719083c2 100644 --- a/tests/models/perceiver/test_modeling_perceiver.py +++ b/tests/models/perceiver/test_modeling_perceiver.py @@ -812,14 +812,6 @@ def test_problem_types(self): def test_multi_gpu_data_parallel_forward(self): pass - @unittest.skip(reason="Perceiver models don't have a typical head like is the case with BERT") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Perceiver models don't have a typical head like is the case with BERT") - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip(reason="Perceiver doesn't support resize_token_embeddings") def test_resize_tokens_embeddings(self): pass diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 4867e38acb68..2fa6f0ca6707 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ 
b/tests/models/persimmon/test_modeling_persimmon.py @@ -379,11 +379,6 @@ def test_persimmon_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Persimmon buffers include complex numbers, which breaks this test") - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base - def test_save_load_fast_init_from_base(self): - pass - @parameterized.expand([("linear",), ("dynamic",)]) # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Persimmon def test_model_rope_scaling_from_config(self, scaling_type): diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index dd6846ac4c0e..1517ae9dbd2f 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -212,14 +212,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_retain_grad_hidden_states_attentions(self): pass - @unittest.skip(reason="Pix2StructVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Pix2StructVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "google/pix2struct-textcaps-base" @@ -361,14 +353,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="Pix2StructTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Pix2StructTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "google/pix2struct-textcaps-base" diff --git a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py index 77cb96ccea02..0dc6838a9dfb 100644 --- a/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py +++ b/tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py @@ -184,14 +184,6 @@ def test_training_gradient_checkpointing(self): def test_model_get_set_embeddings(self): pass - @unittest.skip(reason="Prompt Depth Anything with AutoBackbone does not have a base model") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Prompt Depth Anything with AutoBackbone does not have a base model") - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip( reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" ) diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index 2757ba30a8e8..3be1aef60fd1 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -417,10 +417,6 @@ def test_Qwen2_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Qwen2 buffers include complex numbers, which breaks this test") - def 
test_save_load_fast_init_from_base(self): - pass - @unittest.skip(reason="Qwen2 uses GQA on all models so the KV cache is a non standard format") def test_past_key_values_format(self): pass diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index c1e2daee81e7..bdddb3708f4e 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -445,10 +445,6 @@ def test_Qwen2Moe_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Qwen2Moe buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip(reason="Qwen2Moe uses GQA on all models so the KV cache is a non standard format") def test_past_key_values_format(self): pass diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py index 672bf51b4d60..1f9843f91f17 100644 --- a/tests/models/qwen3/test_modeling_qwen3.py +++ b/tests/models/qwen3/test_modeling_qwen3.py @@ -420,10 +420,6 @@ def test_Qwen3_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Qwen3 buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - # Ignore copy def test_past_key_values_format(self): super().test_past_key_values_format() diff --git a/tests/models/qwen3_moe/test_modeling_qwen3_moe.py b/tests/models/qwen3_moe/test_modeling_qwen3_moe.py index 3c237b7ae149..9938d20b857e 100644 --- a/tests/models/qwen3_moe/test_modeling_qwen3_moe.py +++ b/tests/models/qwen3_moe/test_modeling_qwen3_moe.py @@ -439,10 +439,6 @@ def test_Qwen3Moe_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Qwen3Moe buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - # Ignore copy def test_past_key_values_format(self): super().test_past_key_values_format() diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py index a7a8a74653df..0ca4a0dd6e8e 100644 --- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py +++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py @@ -337,10 +337,6 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) - @unittest.skip(reason="Fast init from base not tested for RecurrentGemma") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip(reason="RecurrentGemma does not return pkv") def test_past_key_values_format(self): pass diff --git a/tests/models/roformer/test_modeling_roformer.py b/tests/models/roformer/test_modeling_roformer.py index fbcc2361289b..a592ea01caea 100644 --- a/tests/models/roformer/test_modeling_roformer.py +++ b/tests/models/roformer/test_modeling_roformer.py @@ -534,7 +534,7 @@ class RoFormerSinusoidalPositionalEmbeddingTest(unittest.TestCase): def test_basic(self): input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device) emb1 = RoFormerSinusoidalPositionalEmbedding(num_positions=6, embedding_dim=6).to(torch_device) - emb1.weight = emb1._init_weight(emb1.weight) + 
emb1._init_weight() emb = emb1(input_ids.shape) desired_weights = torch.tensor( [[0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 1.0000], [0.8415, 0.0464, 0.0022, 0.5403, 0.9989, 1.0000]] @@ -553,7 +553,7 @@ def test_positional_emb_weights_against_roformer(self): ] ).to(torch_device) emb1 = RoFormerSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512).to(torch_device) - emb1.weight = emb1._init_weight(emb1.weight) + emb1._init_weight() weights = emb1.weight.data[:3, :5].to(torch_device) self.assertTrue( @@ -575,7 +575,7 @@ def test_apply_rotary_position_embeddings(self): -torch.arange(2 * 12 * 16 * 64, dtype=torch.float, device=torch_device).reshape(2, 12, 16, 64) / 100 ).to(torch_device) embed_positions = RoFormerSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=64).to(torch_device) - embed_positions.weight = embed_positions._init_weight(embed_positions.weight) + embed_positions._init_weight() sinusoidal_pos = embed_positions([2, 16, 768])[None, None, :, :] query_layer, key_layer = RoFormerSelfAttention.apply_rotary_position_embeddings( diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index 0f19b29d9026..2a17fad33447 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -647,14 +647,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip(reason="SamModel does not support training") def test_retain_grad_hidden_states_attentions(self): pass diff --git a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py index c30ebcc87fcd..d7f015804474 100644 --- a/tests/models/seamless_m4t/test_modeling_seamless_m4t.py +++ b/tests/models/seamless_m4t/test_modeling_seamless_m4t.py @@ -429,20 +429,10 @@ def test_inputs_embeds_matches_input_ids(self): def test_model_weights_reload_no_missing_tied_weights(self): pass - @unittest.skip( - reason="SeamlessM4TModel is base class but has actually a bigger architecture than seamlessM4T task-specific models." - ) - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip(reason="SeamlessM4TModel can takes input_ids or input_features") def test_forward_signature(self): pass - @unittest.skip(reason="SeamlessM4T has no base model") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip( reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" ) @@ -684,16 +674,6 @@ def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) - @unittest.skip( - reason="SeamlessM4TModel is base class but has actually a bigger architecture than seamlessM4T task-specific models." 
- ) - def test_save_load_fast_init_to_base(self): - pass - - @unittest.skip(reason="SeamlessM4T has no base model") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip( reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" ) diff --git a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py index 2387e5f25ff1..2342f5502c97 100644 --- a/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py +++ b/tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py @@ -445,20 +445,10 @@ def test_inputs_embeds_matches_input_ids(self): def test_model_weights_reload_no_missing_tied_weights(self): pass - @unittest.skip( - reason="SeamlessM4Tv2Model is base class but has actually a bigger architecture than seamlessM4T task-specific models." - ) - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip(reason="SeamlessM4Tv2Model can takes input_ids or input_features") def test_forward_signature(self): pass - @unittest.skip(reason="SeamlessM4Tv2 has no base model") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip( reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" ) @@ -687,16 +677,6 @@ def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) - @unittest.skip( - reason="SeamlessM4Tv2Model is base class but has actually a bigger architecture than seamlessM4T task-specific models." - ) - def test_save_load_fast_init_to_base(self): - pass - - @unittest.skip(reason="SeamlessM4Tv2 has no base model") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip( reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" ) diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py index 3dec33018476..93268bbc8e43 100644 --- a/tests/models/siglip/test_modeling_siglip.py +++ b/tests/models/siglip/test_modeling_siglip.py @@ -399,14 +399,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="SiglipVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="SiglipVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") def test_initialization(self): pass @@ -563,16 +555,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="SiglipTextModel has no base class and is not available in MODEL_MAPPING") - # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_save_load_fast_init_from_base - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="SiglipTextModel has no base class and is not available in MODEL_MAPPING") - # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_save_load_fast_init_to_base - def 
test_save_load_fast_init_to_base(self): - pass - @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation") def test_initialization(self): pass diff --git a/tests/models/siglip2/test_modeling_siglip2.py b/tests/models/siglip2/test_modeling_siglip2.py index dea49ececa9b..f5959edb5fe2 100644 --- a/tests/models/siglip2/test_modeling_siglip2.py +++ b/tests/models/siglip2/test_modeling_siglip2.py @@ -487,14 +487,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Siglip2VisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Siglip2VisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation") def test_initialization(self): pass @@ -646,14 +638,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="Siglip2TextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Siglip2TextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip(reason="Siglip2 uses the same initialization scheme as the Flax original implementation") def test_initialization(self): pass diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index 126edf62816e..3b7ae7fa0488 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -1895,14 +1895,6 @@ def test_model_outputs_equivalence(self): def test_retain_grad_hidden_states_attentions(self): pass - @unittest.skip(reason="Fails on automapping of SpeechT5HifiGanConfig") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="Fails on automapping of SpeechT5HifiGanConfig") - def test_save_load_fast_init_to_base(self): - pass - def test_batched_inputs_outputs(self): config, inputs = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py index 9f06697a1948..5b3626a25876 100644 --- a/tests/models/starcoder2/test_modeling_starcoder2.py +++ b/tests/models/starcoder2/test_modeling_starcoder2.py @@ -395,10 +395,6 @@ def test_Starcoder2_token_classification_model(self): (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Starcoder2 buffers include complex numbers, which breaks this test") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip(reason="Starcoder2 uses GQA on all models so the KV cache is a non standard format") def test_past_key_values_format(self): pass diff --git a/tests/models/trocr/test_modeling_trocr.py b/tests/models/trocr/test_modeling_trocr.py index 26654546f648..9af22b1d1bf6 100644 --- a/tests/models/trocr/test_modeling_trocr.py +++ b/tests/models/trocr/test_modeling_trocr.py @@ -173,14 +173,6 @@ def setUp(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="trocr has no base model") - def test_save_load_fast_init_from_base(self): - pass - - 
@unittest.skip(reason="trocr has no base model") - def test_save_load_fast_init_to_base(self): - pass - def test_config(self): self.config_tester.run_common_tests() diff --git a/tests/models/upernet/test_modeling_upernet.py b/tests/models/upernet/test_modeling_upernet.py index 1b337460f8d5..d731ca9588ae 100644 --- a/tests/models/upernet/test_modeling_upernet.py +++ b/tests/models/upernet/test_modeling_upernet.py @@ -184,14 +184,6 @@ def test_inputs_embeds(self): def test_model_get_set_embeddings(self): pass - @unittest.skip(reason="UperNet does not have a base model") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="UperNet does not have a base model") - def test_save_load_fast_init_to_base(self): - pass - @require_torch_multi_gpu @unittest.skip(reason="UperNet has some layers using `add_module` which doesn't work well with `nn.DataParallel`") def test_multi_gpu_data_parallel_forward(self): diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index 177ddc269d49..a66bb44e8248 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -34,7 +34,7 @@ from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -250,20 +250,6 @@ def test_save_load(self): def test_determinism(self): pass - @unittest.skip( - reason="""ViTMAE returns a random mask + ids_restore in each forward pass. See test_save_load - to get deterministic results.""" - ) - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip( - reason="""ViTMAE returns a random mask + ids_restore in each forward pass. See test_save_load - to get deterministic results.""" - ) - def test_save_load_fast_init_to_base(self): - pass - @unittest.skip(reason="""ViTMAE returns a random mask + ids_restore in each forward pass. 
See test_save_load""") def test_model_outputs_equivalence(self): pass @@ -335,6 +321,23 @@ def test_flash_attn_2_inference_equivalence(self): def test_flash_attn_2_inference_equivalence_right_padding(self): pass + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + # This is an exception in the module; it's initialized with xavier_uniform without using initializer_range + if name.endswith("patch_embeddings.projection.weight"): + continue + if param.requires_grad: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index a05e31bf9329..f4a1f8e5068e 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -3684,10 +3684,6 @@ def test_generate_without_input_ids(self): def test_retain_grad_hidden_states_attentions(self): return - @unittest.skip(reason="The model doesn't support fast init from base") - def test_save_load_fast_init_from_base(self): - pass - @unittest.skip( "Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test" ) diff --git a/tests/models/x_clip/test_modeling_x_clip.py b/tests/models/x_clip/test_modeling_x_clip.py index ac402d2ff9ca..5a121d77439c 100644 --- a/tests/models/x_clip/test_modeling_x_clip.py +++ b/tests/models/x_clip/test_modeling_x_clip.py @@ -206,14 +206,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="XCLIPVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="XCLIPVisionModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "microsoft/xclip-base-patch32" @@ -446,14 +438,6 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_inputs_embeds(self): pass - @unittest.skip(reason="XCLIPTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="XCLIPTextModel has no base class and is not available in MODEL_MAPPING") - def test_save_load_fast_init_to_base(self): - pass - @slow def test_model_from_pretrained(self): model_name = "microsoft/xclip-base-patch32" diff --git a/tests/models/zoedepth/test_modeling_zoedepth.py b/tests/models/zoedepth/test_modeling_zoedepth.py index 342ae269d39d..e9ffae7f5c60 100644 --- a/tests/models/zoedepth/test_modeling_zoedepth.py +++ b/tests/models/zoedepth/test_modeling_zoedepth.py @@ -174,14 +174,6 @@ def test_for_depth_estimation(self): def test_model_common_attributes(self): pass - @unittest.skip(reason="ZoeDepth with AutoBackbone does not have a base model") - def test_save_load_fast_init_from_base(self): - pass - - @unittest.skip(reason="ZoeDepth with AutoBackbone does not have a base model") - def
test_save_load_fast_init_to_base(self): - pass - @unittest.skip(reason="ZoeDepth does not support training yet") def test_training(self): pass diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 76bbf2766fd9..e1c64964d3f4 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -505,60 +505,6 @@ def test_peft_gradient_checkpointing_enable_disable(self): m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False" ) - @is_flaky(description="low likelihood of failure, reason not yet discovered") - def test_save_load_fast_init_from_base(self): - for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - if config.__class__ not in MODEL_MAPPING: - self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING") - - base_class = MODEL_MAPPING[config.__class__] - - if isinstance(base_class, tuple): - base_class = base_class[0] - - if model_class == base_class: - continue - - # make a copy of model class to not break future tests - # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class - class CopyClass(model_class): - pass - - model_class_copy = CopyClass - - # make sure that all keys are expected for test - model_class_copy._keys_to_ignore_on_load_missing = [] - - # make init deterministic, but make sure that - # non-initialized weights throw errors nevertheless - model_class_copy._init_weights = _mock_init_weights - model_class_copy.init_weights = _mock_all_init_weights - - model = base_class(config) - state_dict = model.state_dict() - - # this will often delete a single weight of a multi-weight module - # to test an edge case - random_key_to_del = random.choice(list(state_dict.keys())) - del state_dict[random_key_to_del] - - # check that certain keys didn't get saved with the model - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin")) - - model_fast_init = model_class_copy.from_pretrained(tmpdirname) - model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False) - # Before we test anything - - for key in model_fast_init.state_dict().keys(): - if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor): - max_diff = (model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]).sum().item() - else: - max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - @slow @require_accelerate @mark.accelerate_tests @@ -640,62 +586,6 @@ def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path): self.assertEqual(tied_params1, tied_params2) - def test_save_load_fast_init_to_base(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - if config.__class__ not in MODEL_MAPPING: - self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING") - - base_class = MODEL_MAPPING[config.__class__] - - if isinstance(base_class, tuple): - base_class = base_class[0] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - # make a copy of model class to not break future tests - # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class - class CopyClass(base_class): - pass - - base_class_copy = CopyClass - - # make sure that all keys are expected for test - 
base_class_copy._keys_to_ignore_on_load_missing = [] - - # make init deterministic, but make sure that - # non-initialized weights throw errors nevertheless - base_class_copy._init_weights = _mock_init_weights - base_class_copy.init_weights = _mock_all_init_weights - - model = model_class(config) - state_dict = model.state_dict() - - # this will often delete a single weight of a multi-weight module - # to test an edge case - random_key_to_del = random.choice(list(state_dict.keys())) - del state_dict[random_key_to_del] - - # check that certain keys didn't get saved with the model - with tempfile.TemporaryDirectory() as tmpdirname: - model.config.save_pretrained(tmpdirname) - torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin")) - - model_fast_init = base_class_copy.from_pretrained(tmpdirname) - model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False) - - for key in model_fast_init.state_dict().keys(): - if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor): - max_diff = torch.max( - model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key] - ).item() - else: - max_diff = torch.max( - torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]) - ).item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - def test_torch_save_load(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if config.__class__ not in MODEL_MAPPING: @@ -3189,7 +3079,7 @@ def _init_weights(self, module): # not to init. the weights during the creation: to match the logic in `from_pretrained`, so we can keep the # same sequence of random ops in the execution path to allow us to compare `target_model` and `new_model` below # for `linear` part. - with ContextManagers([no_init_weights(True)]): + with ContextManagers([no_init_weights()]): target_model = MyClass(config=config) target_model.apply(target_model._initialize_weights)