diff --git a/docs/source/bettertransformer/overview.mdx b/docs/source/bettertransformer/overview.mdx index 9b1ba9e130..1f4bf80e49 100644 --- a/docs/source/bettertransformer/overview.mdx +++ b/docs/source/bettertransformer/overview.mdx @@ -43,6 +43,7 @@ The list of supported model below: - [M2M100](https://arxiv.org/abs/2010.11125) - [RemBERT](https://arxiv.org/abs/2010.12821) - [RoBERTa](https://arxiv.org/abs/1907.11692) +- [RoCBert](https://aclanthology.org/2022.acl-long.65.pdf) - [Splinter](https://arxiv.org/abs/2101.00438) - [Tapas](https://arxiv.org/abs/2211.06550) - [ViLT](https://arxiv.org/abs/2102.03334) diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py index f34766a2e4..483460bda3 100644 --- a/optimum/bettertransformer/models/__init__.py +++ b/optimum/bettertransformer/models/__init__.py @@ -49,6 +49,7 @@ class BetterTransformerManager: "mbart": ("MBartEncoderLayer", MBartEncoderLayerBetterTransformer), "rembert": ("RemBertLayer", BertLayerBetterTransformer), "roberta": ("RobertaLayer", BertLayerBetterTransformer), + "roc_bert": ("RoCBertLayer", BertLayerBetterTransformer), "splinter": ("SplinterLayer", BertLayerBetterTransformer), "tapas": ("TapasLayer", BertLayerBetterTransformer), "vilt": ("ViltLayer", ViltLayerBetterTransformer), diff --git a/tests/bettertransformer/test_bettertransformer_encoder.py b/tests/bettertransformer/test_bettertransformer_encoder.py index 3a12e9bceb..6481a8aa5c 100644 --- a/tests/bettertransformer/test_bettertransformer_encoder.py +++ b/tests/bettertransformer/test_bettertransformer_encoder.py @@ -255,6 +255,14 @@ def test_accelerate_compatibility_single_gpu_without_keeping(self): self.check_accelerate_compatibility_cpu_gpu(keep_original_model=False, max_memory=max_memory) +class BetterTransformersRoCBertTest(BetterTransformersEncoderTest): + all_models_to_test = ["hf-internal-testing/tiny-random-RoCBertModel"] + + # unrelated issue with torch.amp.autocast with rocbert (expected scalar type BFloat16 but found Float) + def test_raise_autocast(self): + pass + + class BetterTransformersEncoderDecoderTest(BetterTransformersTestMixin, unittest.TestCase): r""" Full testing suite of the `BetterTransformers` integration into Hugging Face