Skip to content

Commit

Permalink
add FlaxAutoModelForImageClassification in main init (#12298)
Browse files Browse the repository at this point in the history
  • Loading branch information
patil-suraj authored Jun 22, 2021
1 parent 2affeb2 commit 1498eb9
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 2 deletions.
7 changes: 7 additions & 0 deletions docs/source/model_doc/auto.rst
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,10 @@ FlaxAutoModelForNextSentencePrediction

.. autoclass:: transformers.FlaxAutoModelForNextSentencePrediction
:members:


FlaxAutoModelForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAutoModelForImageClassification
:members:
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,7 @@
_import_structure["models.auto"].extend(
[
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
Expand All @@ -1520,6 +1521,7 @@
"FLAX_MODEL_MAPPING",
"FlaxAutoModel",
"FlaxAutoModelForCausalLM",
"FlaxAutoModelForImageClassification",
"FlaxAutoModelForMaskedLM",
"FlaxAutoModelForMultipleChoice",
"FlaxAutoModelForNextSentencePrediction",
Expand Down Expand Up @@ -2848,6 +2850,7 @@
from .modeling_flax_utils import FlaxPreTrainedModel
from .models.auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
Expand All @@ -2859,6 +2862,7 @@
FLAX_MODEL_MAPPING,
FlaxAutoModel,
FlaxAutoModelForCausalLM,
FlaxAutoModelForImageClassification,
FlaxAutoModelForMaskedLM,
FlaxAutoModelForMultipleChoice,
FlaxAutoModelForNextSentencePrediction,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
if is_flax_available():
_import_structure["modeling_flax_auto"] = [
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
Expand Down Expand Up @@ -175,6 +176,7 @@
if is_flax_available():
from .modeling_flax_auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/auto/modeling_flax_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
]
)

FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING = OrderedDict(
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Image-classsification
(ViTConfig, FlaxViTForImageClassification),
Expand Down Expand Up @@ -188,7 +188,7 @@

FlaxAutoModelForImageClassification = auto_class_factory(
"FlaxAutoModelForImageClassification",
FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
head_doc="image classification modeling",
)

Expand Down
12 changes: 12 additions & 0 deletions src/transformers/utils/dummy_flax_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def from_pretrained(cls, *args, **kwargs):
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None


FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None


FLAX_MODEL_FOR_MASKED_LM_MAPPING = None


Expand Down Expand Up @@ -124,6 +127,15 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxAutoModelForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])


class FlaxAutoModelForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
Expand Down

0 comments on commit 1498eb9

Please sign in to comment.