From 1498eb9888d55d76385b45e074f26703cc5049f3 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 22 Jun 2021 18:26:05 +0530 Subject: [PATCH] add FlaxAutoModelForImageClassification in main init (#12298) --- docs/source/model_doc/auto.rst | 7 +++++++ src/transformers/__init__.py | 4 ++++ src/transformers/models/auto/__init__.py | 2 ++ src/transformers/models/auto/modeling_flax_auto.py | 4 ++-- src/transformers/utils/dummy_flax_objects.py | 12 ++++++++++++ 5 files changed, 27 insertions(+), 2 deletions(-) diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst index 69f67d7f56ff20..7ccfbdf87d5771 100644 --- a/docs/source/model_doc/auto.rst +++ b/docs/source/model_doc/auto.rst @@ -266,3 +266,10 @@ FlaxAutoModelForNextSentencePrediction .. autoclass:: transformers.FlaxAutoModelForNextSentencePrediction :members: + + +FlaxAutoModelForImageClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAutoModelForImageClassification + :members: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index dad079d40e1c0b..0d702227807059 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1509,6 +1509,7 @@ _import_structure["models.auto"].extend( [ "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", + "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_MASKED_LM_MAPPING", "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", @@ -1520,6 +1521,7 @@ "FLAX_MODEL_MAPPING", "FlaxAutoModel", "FlaxAutoModelForCausalLM", + "FlaxAutoModelForImageClassification", "FlaxAutoModelForMaskedLM", "FlaxAutoModelForMultipleChoice", "FlaxAutoModelForNextSentencePrediction", @@ -2848,6 +2850,7 @@ from .modeling_flax_utils import FlaxPreTrainedModel from .models.auto import ( FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, + FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, @@ -2859,6 +2862,7 @@ FLAX_MODEL_MAPPING, FlaxAutoModel, FlaxAutoModelForCausalLM, + FlaxAutoModelForImageClassification, FlaxAutoModelForMaskedLM, FlaxAutoModelForMultipleChoice, FlaxAutoModelForNextSentencePrediction, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index d483b271b8734c..f0e16ca27dc78f 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -87,6 +87,7 @@ if is_flax_available(): _import_structure["modeling_flax_auto"] = [ "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", + "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_MASKED_LM_MAPPING", "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", @@ -175,6 +176,7 @@ if is_flax_available(): from .modeling_flax_auto import ( FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, + FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index be03814c3be7b9..dd3d3cd8092493 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -115,7 +115,7 @@ ] ) -FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING = OrderedDict( +FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict( [ # Model for Image-classsification (ViTConfig, FlaxViTForImageClassification), @@ -188,7 +188,7 @@ FlaxAutoModelForImageClassification = auto_class_factory( "FlaxAutoModelForImageClassification", - FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING, + FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, head_doc="image classification modeling", ) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 7ad7ee76b6cd15..0eea12143b48f2 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -79,6 +79,9 @@ def from_pretrained(cls, *args, **kwargs): FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None +FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None + + FLAX_MODEL_FOR_MASKED_LM_MAPPING = None @@ -124,6 +127,15 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) +class FlaxAutoModelForImageClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxAutoModelForMaskedLM: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"])