From bb98e7ce58c796b3627c7ac673787437adca085b Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 2 May 2024 10:24:47 +0200 Subject: [PATCH] Fix for Neuron (#30259) --- .../models/cohere/modeling_cohere.py | 7 +- .../models/gemma/modeling_gemma.py | 7 +- .../models/llama/modeling_llama.py | 7 +- src/transformers/models/olmo/modeling_olmo.py | 7 +- src/transformers/training_args.py | 4 +- src/transformers/utils/fx.py | 286 +++++++++++++----- tests/test_modeling_common.py | 21 -- 7 files changed, 240 insertions(+), 99 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 950d45ea867a30..b4e5c0ee92a208 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1021,8 +1021,11 @@ def _update_causal_mask( causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) elif attention_mask.dim() == 4: # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 6077259d0b0fac..98ac4757548223 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1007,8 +1007,11 @@ def _update_causal_mask( causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) elif attention_mask.dim() == 4: # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 2b8e8f6d0958dd..9560eb6e105c0a 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1099,8 +1099,11 @@ def _update_causal_mask( causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) elif attention_mask.dim() == 4: # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 83637536a12531..7aa4843e56cb48 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1082,8 +1082,11 @@ def _update_causal_mask( causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) elif attention_mask.dim() == 4: # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 338bb116dddece..f45c1ba7762085 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -84,12 +84,12 @@ if os.environ.get("TORCHELASTIC_RUN_ID"): if is_optimum_neuron_available(): logger.info( - "Make sure that you are performing the training with the TrainiumTrainer from optimum[neuron], this " + "Make sure that you are performing the training with the NeuronTrainer from optimum[neuron], this " "will fail otherwise." ) else: logger.warning( - "Please use the TrainiumTrainer from optimum[neuron] instead of the Transformers library to perform " + "Please use the NeuronTrainer from optimum[neuron] instead of the Transformers library to perform " "training on AWS Trainium instances. More information here: " "https://github.com/huggingface/optimum-neuron" ) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index ab4f823c2fc8aa..0faf7e0d6ea956 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -15,22 +15,28 @@ import builtins import collections +import contextlib import functools import inspect import math import operator import os import random +import sys import warnings -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union import torch +import torch.utils._pytree as pytree from torch import nn -from torch.fx import Graph, GraphModule, Proxy, Tracer +from torch.fx import Graph, GraphModule, Node, Proxy, Tracer from torch.fx._compatibility import compatibility +from torch.fx._symbolic_trace import is_fx_tracing from torch.fx.proxy import ParameterProxy -from .. import PretrainedConfig, PreTrainedModel, logging +from .. import logging +from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache +from ..modeling_utils import PretrainedConfig, PreTrainedModel from ..models.auto import get_values from ..models.auto.modeling_auto import ( MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, @@ -55,7 +61,7 @@ MODEL_MAPPING_NAMES, ) from ..pytorch_utils import is_torch_greater_or_equal_than_2_0 -from ..utils import ( +from .import_utils import ( ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, get_torch_version, @@ -192,6 +198,8 @@ def _generate_supported_model_class_names( ] _SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS))) +_CURRENT_TRACER = None + def torch_nn_embedding(self, input): return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype) @@ -701,6 +709,92 @@ class MetaDeviceAttribute(HFAttribute): pass +class HFCacheProxy(HFProxy): + """ + Proxy that represents an instance of `transformers.cache_utils.Cache`. + """ + + @property + def __class__(self): + return ProxyableCache + + +def create_wrapper( + function: Callable, + op_type: Union[Literal["call_function"], Literal["call_method"], Literal["get_attr"]], + proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None, +) -> Callable: + @functools.wraps(function) + def wrapper(*args, **kwargs): + if not is_fx_tracing(): + return function(*args, **kwargs) + + found_proxies = [] + + def check_proxy(a): + if isinstance(a, Proxy): + found_proxies.append(a) + + torch.fx.node.map_aggregate(args, check_proxy) + torch.fx.node.map_aggregate(kwargs, check_proxy) + + if len(found_proxies) > 0: + tracer = found_proxies[0].tracer + if op_type == "call_function": + target = function + elif op_type == "call_method": + target = function.__name__ + elif op_type == "get_attr": + target = function.__name__ + else: + raise ValueError(f"op_type {op_type} not supported.") + return tracer.create_proxy(op_type, target, args, kwargs, proxy_factory_fn=proxy_factory_fn) + else: + return function(*args, **kwargs) + + return wrapper + + +class HFProxyableClassMeta(type): + """ + Metaclass that creates a class with its main methods wrapped to be proxyable. + """ + + def __new__( + cls, + name: str, + bases: Tuple[Type, ...], + attrs: Dict[str, Any], + proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None, + ): + cls = super().__new__(cls, name, bases, attrs) + for attr_name in dir(cls): + attr = getattr(cls, attr_name, None) + if attr is None: + continue + if attr_name == "__init__": + op_type = "call_function" + elif attr_name.startswith("__"): + op_type = None + elif inspect.ismethod(attr): + op_type = "call_function" + elif inspect.isfunction(attr): + op_type = "call_method" + else: + op_type = None + if op_type is not None: + setattr(cls, attr_name, create_wrapper(attr, op_type, proxy_factory_fn=proxy_factory_fn)) + return cls + + +def gen_constructor_wrapper(target: Callable) -> Tuple[Callable, Callable]: + """ + Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on. + """ + wrapper = create_wrapper(target, "call_function") + return wrapper, target + + def _proxies_to_metas(v): """Returns the underlying metadata for HFProxies, and behaves like the identity for the others.""" if isinstance(v, MetaDeviceAttribute): @@ -712,25 +806,24 @@ def _proxies_to_metas(v): return v -def _gen_constructor_wrapper(target): - @functools.wraps(target) - def wrapper(*args, **kwargs): - proxy = None +def cache_proxy_factory_fn(n: Node) -> HFCacheProxy: + global _CURRENT_TRACER + if not isinstance(_CURRENT_TRACER, HFTracer): + raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.") + return HFCacheProxy(n, _CURRENT_TRACER) - def check_has_proxy(v): - if isinstance(v, Proxy): - nonlocal proxy - proxy = v - torch.fx.node.map_aggregate(args, check_has_proxy) - torch.fx.node.map_aggregate(kwargs, check_has_proxy) - - if proxy is not None: - return proxy.tracer.create_proxy("call_function", target, args, kwargs) - else: - return target(*args, **kwargs) - - return wrapper, target +# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`. +ProxyableCache = HFProxyableClassMeta("ProxyableCache", (Cache,), {}, proxy_factory_fn=cache_proxy_factory_fn) +ProxyableDynamicCache = HFProxyableClassMeta( + "ProxyableDynamicCache", (DynamicCache,), {}, proxy_factory_fn=cache_proxy_factory_fn +) +ProxyableSinkCache = HFProxyableClassMeta( + "ProxyableSinkCache", (SinkCache,), {}, proxy_factory_fn=cache_proxy_factory_fn +) +ProxyableStaticCache = HFProxyableClassMeta( + "ProxyableStaticCache", (StaticCache,), {}, proxy_factory_fn=cache_proxy_factory_fn +) def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None): @@ -764,6 +857,13 @@ class HFTracer(Tracer): "finfo", "tril", ] + _CLASSES_TO_PATCH = { + Cache: ProxyableCache, + DynamicCache: ProxyableDynamicCache, + SinkCache: ProxyableSinkCache, + StaticCache: ProxyableStaticCache, + } + supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel) def __init__(self, autowrap_modules=(math,), autowrap_functions=()): @@ -776,7 +876,7 @@ def __init__(self, autowrap_modules=(math,), autowrap_functions=()): ) def _generate_dummy_input( - self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str] + self, model: "PreTrainedModel", input_name: str, shape: List[int], input_names: List[str] ) -> Dict[str, torch.Tensor]: """Generates dummy input for model inference recording.""" # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored @@ -951,6 +1051,11 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas) kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas) + should_install_metadata = True + + self._disable_module_getattr = True + self._disable_call_module = True + if kind == "call_function": meta_target = _MANUAL_META_OVERRIDES.get(target, target) meta_out = meta_target(*args_metas, **kwargs_metas) @@ -963,39 +1068,36 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr elif kind == "call_module": if not hasattr(self, "orig_forward"): raise AttributeError(f"{self} does not have an attribute called orig_forward") - self._disable_module_getattr = True - try: - mod = self.root.get_submodule(target) - mod_type = type(mod) - if mod_type in _MANUAL_META_OVERRIDES: - meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas) - else: - meta_out = self.orig_forward(*args_metas, **kwargs_metas) - finally: - self._disable_module_getattr = False + mod = self.root.get_submodule(target) + mod_type = type(mod) + if mod_type in _MANUAL_META_OVERRIDES: + meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas) + else: + meta_out = self.orig_forward(*args_metas, **kwargs_metas) elif kind == "get_attr": - self._disable_module_getattr = True - try: - attr_itr = self.root - atoms = target.split(".") - for atom in atoms: - attr_itr = getattr(attr_itr, atom) - if isinstance(attr_itr, torch.Tensor): - meta_out = attr_itr.to(device="meta") - else: - meta_out = attr_itr - finally: - self._disable_module_getattr = False + attr_itr = self.root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + if isinstance(attr_itr, torch.Tensor): + meta_out = attr_itr.to(device="meta") + else: + meta_out = attr_itr else: - return rv + should_install_metadata = False + + if should_install_metadata: + if not isinstance(rv, Proxy): + raise ValueError("Don't support composite output yet") + rv.install_metadata(meta_out) - if not isinstance(rv, Proxy): - raise ValueError("Don't support composite output yet") - rv.install_metadata(meta_out) except Exception as e: if _IS_IN_DEBUG_MODE: warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}") + self._disable_module_getattr = False + self._disable_call_module = False + return rv # Replaced by .getattr from PyTorch 1.13 @@ -1041,12 +1143,51 @@ def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any return self._module_getattr(attr, attr_val, parameter_proxy_cache) def call_module(self, m, forward, args, kwargs): + if getattr(self, "_disable_call_module", False): + return forward(*args, **kwargs) self.orig_forward = forward return super().call_module(m, forward, args, kwargs) def proxy(self, node): return HFProxy(node, self) + @contextlib.contextmanager + def patch_for_tracing(self, root: Union[torch.nn.Module, Callable[..., Any]]): + # Patching torch functions + self.patched_torch_methods = { + target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH + } + self.orig_fns = set() + + for name, (wrapper, orig) in self.patched_torch_methods.items(): + setattr(torch, name, wrapper) + self.orig_fns.add(orig) + + # Patching classes + patched = [] + module_of_model = inspect.getmodule(root) + for name, mod in sys.modules.items(): + if module_of_model is not None and mod is not module_of_model: + continue + if not name.startswith("transformers"): + continue + for orig_cls, patched_cls in self._CLASSES_TO_PATCH.items(): + for attr_name, attr in mod.__dict__.items(): + if attr is orig_cls: + patched.append((mod, attr_name, orig_cls)) + setattr(mod, attr_name, patched_cls) + + yield + + # Restoring patched functions and classes. + for name, (_, orig) in self.patched_torch_methods.items(): + setattr(torch, name, orig) + self.patched_torch_methods = {} + self.orig_fns = set() + + for mod, attr_name, orig_cls in patched: + setattr(mod, attr_name, orig_cls) + def trace( self, root: Union[torch.nn.Module, Callable[..., Any]], @@ -1125,28 +1266,25 @@ def trace( " transformers.PreTrainedModel." ) - concrete_metas = { - input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_ - for input_name, input_ in inputs.items() - } + def to_meta(value): + if isinstance(value, torch.Tensor): + return value.to("meta") + return value + + concrete_metas = pytree.tree_map(to_meta, inputs) + for param in sig.parameters.values(): if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names: concrete_metas[f"**{param.name}"] = {} self.meta_args = concrete_metas - self.patched_torch_methods = { - target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH - } - self.orig_fns = set() - for name, (wrapper, orig) in self.patched_torch_methods.items(): - setattr(torch, name, wrapper) - self.orig_fns.add(orig) - - try: - self.graph = super().trace(root, concrete_args=concrete_args) - finally: - for name, (_, orig) in self.patched_torch_methods.items(): - setattr(torch, name, orig) + global _CURRENT_TRACER + _CURRENT_TRACER = self + with self.patch_for_tracing(root): + try: + self.graph = super().trace(root, concrete_args=concrete_args) + finally: + _CURRENT_TRACER = None # This is necessary because concrete args are added as input to the traced module since # https://github.com/pytorch/pytorch/pull/55888. @@ -1256,11 +1394,11 @@ def get_concrete_args(model: nn.Module, input_names: List[str]): return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names} -def is_model_supported(model: PreTrainedModel): +def is_model_supported(model: "PreTrainedModel"): return model.__class__.__name__ in _SUPPORTED_MODELS -def check_if_model_is_supported(model: PreTrainedModel): +def check_if_model_is_supported(model: "PreTrainedModel"): if not is_model_supported(model): supported_model_names = ", ".join(_SUPPORTED_MODELS) raise NotImplementedError( @@ -1269,7 +1407,7 @@ def check_if_model_is_supported(model: PreTrainedModel): def symbolic_trace( - model: PreTrainedModel, + model: "PreTrainedModel", input_names: Optional[List[str]] = None, disable_check: bool = False, tracer_cls: Type[HFTracer] = HFTracer, @@ -1307,6 +1445,18 @@ def symbolic_trace( if not disable_check: check_if_model_is_supported(model) + if "past_key_values" in input_names and not getattr(model.config, "use_cache", False): + logger.warning( + "`past_key_values` were specified as input names, but model.config.use_cache = False, this might lead to " + "unexpected behavior." + ) + if "past_key_values" not in input_names and getattr(model.config, "use_cache", False): + logger.warning( + "`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting " + "model.config.use_cache = False." + ) + model.config.use_cache = False + # Tracing. tracer = tracer_cls() traced_graph = tracer.trace(model, concrete_args=concrete_args) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 71cb28d7548555..c38c7f66d1e5d8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -18,7 +18,6 @@ import inspect import os import os.path -import pickle import random import re import tempfile @@ -1279,26 +1278,6 @@ def flatten_output(output): f"traced {i}th output doesn't match model {i}th output for {model_class}", ) - # Test that the model can be serialized and restored properly - with tempfile.TemporaryDirectory() as tmp_dir_name: - pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") - try: - with open(pkl_file_name, "wb") as f: - pickle.dump(traced_model, f) - with open(pkl_file_name, "rb") as f: - loaded = pickle.load(f) - except Exception as e: - self.fail(f"Couldn't serialize / deserialize the traced model: {e}") - - loaded_output = loaded(**filtered_inputs) - loaded_output = flatten_output(loaded_output) - - for i in range(num_outputs): - self.assertTrue( - torch.allclose(model_output[i], loaded_output[i]), - f"serialized model {i}th output doesn't match model {i}th output for {model_class}", - ) - # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. # (Even with this call, there are still memory leak by ~0.04MB) self.clear_torch_jit_class_registry()