Skip to content

Commit

Permalink
fix mess
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Feb 6, 2023
1 parent 59d5ede commit c6dd4b9
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
"is_clearml_available",
"is_comet_available",
"is_neptune_available",
"is_optimum_available",
"is_optuna_available",
"is_ray_available",
"is_ray_tune_available",
Expand Down Expand Up @@ -3583,6 +3584,7 @@
is_clearml_available,
is_comet_available,
is_neptune_available,
is_optimum_available,
is_optuna_available,
is_ray_available,
is_ray_tune_available,
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def is_tensorboard_available():
return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None


def is_optimum_available():
    # True when the Hugging Face `optimum` package can be located on the import path.
    spec = importlib.util.find_spec("optimum")
    return spec is not None


def is_optuna_available():
    # Report whether the `optuna` hyperparameter-search library is installed.
    return importlib.util.find_spec("optuna") is not None

Expand Down
51 changes: 51 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, GenerationMixin
from .integrations import is_optimum_available
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
Expand Down Expand Up @@ -3014,6 +3015,56 @@ def register_for_auto_class(cls, auto_class="AutoModel"):

cls._auto_class = auto_class

def to_bettertransformer(self) -> "PreTrainedModel":
    """
    Converts the model so it runs on [PyTorch's native attention
    implementation](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html), which is
    integrated into Transformers via the [Optimum
    library](https://huggingface.co/docs/optimum/bettertransformer/overview). Only a subset of all Transformers
    models are supported.
    PyTorch's attention fastpath allows to speed up inference through kernel fusions and the use of [nested
    tensors](https://pytorch.org/docs/stable/nested.html). Detailed benchmarks can be found in [this blog
    post](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2).
    Returns:
        [`PreTrainedModel`]: The model converted to BetterTransformer.
    """
    # The conversion itself lives in `optimum`; fail early with a clear message when it is absent.
    if not is_optimum_available():
        raise ImportError("The package `optimum` is required to use Better Transformer.")

    from optimum.version import __version__ as optimum_version

    # BetterTransformer support requires optimum >= 1.7.0.
    installed = version.parse(optimum_version)
    if installed < version.parse("1.7.0"):
        raise ImportError(
            f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found."
        )

    from optimum.bettertransformer import BetterTransformer

    return BetterTransformer.transform(self)

def reverse_bettertransformer(self):
    """
    Undoes the transformation applied by [`~to_bettertransformer`], restoring the original modeling — for example
    so that the model can be saved.
    Returns:
        [`PreTrainedModel`]: The model converted back to the original modeling.
    """
    # Reversal is implemented in `optimum`, so its presence and version must be checked first.
    if not is_optimum_available():
        raise ImportError("The package `optimum` is required to use Better Transformer.")

    from optimum.version import __version__ as optimum_version

    # The reverse API only exists from optimum 1.7.0 onward.
    required = version.parse("1.7.0")
    if version.parse(optimum_version) < required:
        raise ImportError(
            f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found."
        )

    from optimum.bettertransformer import BetterTransformer

    return BetterTransformer.reverse(self)


PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:
Expand Down

0 comments on commit c6dd4b9

Please sign in to comment.