diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py
index 2ca5a1f05a7..a72d6bf9ec1 100644
--- a/src/transformers/models/bart/modeling_flax_bart.py
+++ b/src/transformers/models/bart/modeling_flax_bart.py
@@ -911,7 +911,7 @@ def __call__(
         )
 
 
-class FlaxBartPretrainedModel(FlaxPreTrainedModel):
+class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
     config_class = BartConfig
     base_model_prefix: str = "model"
     module_class: nn.Module = None
@@ -1232,7 +1232,7 @@ def __call__(
     "The bare Bart Model transformer outputting raw hidden-states without any specific head on top.",
     BART_START_DOCSTRING,
 )
-class FlaxBartModel(FlaxBartPretrainedModel):
+class FlaxBartModel(FlaxBartPreTrainedModel):
     config: BartConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
     module_class = FlaxBartModule
@@ -1318,7 +1318,7 @@ def __call__(
 @add_start_docstrings(
     "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
 )
-class FlaxBartForConditionalGeneration(FlaxBartPretrainedModel):
+class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
     module_class = FlaxBartForConditionalGenerationModule
     dtype: jnp.dtype = jnp.float32
 
@@ -1623,7 +1623,7 @@ def __call__(
     """,
     BART_START_DOCSTRING,
 )
-class FlaxBartForSequenceClassification(FlaxBartPretrainedModel):
+class FlaxBartForSequenceClassification(FlaxBartPreTrainedModel):
     module_class = FlaxBartForSequenceClassificationModule
     dtype = jnp.float32
 
@@ -1710,7 +1710,7 @@ def __call__(
     """,
     BART_START_DOCSTRING,
 )
-class FlaxBartForQuestionAnswering(FlaxBartPretrainedModel):
+class FlaxBartForQuestionAnswering(FlaxBartPreTrainedModel):
     module_class = FlaxBartForQuestionAnsweringModule
     dtype = jnp.float32
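
Note: `FlaxBartPretrainedModel` was importable from the public module under its old spelling, so this rename can break downstream code that references the old name. A minimal backward-compatibility sketch, assuming the alias lives in the same module as the renamed class (the diff itself does not ship such an alias, so this is an assumption, not part of the change):

import warnings


class FlaxBartPretrainedModel(FlaxBartPreTrainedModel):
    # Hypothetical deprecated alias for the old (misspelled) class name;
    # not part of this diff. Forwards construction to the renamed class
    # and emits a FutureWarning so callers know to migrate.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "FlaxBartPretrainedModel is deprecated; use FlaxBartPreTrainedModel instead.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)

Because the alias subclasses the renamed class, existing isinstance checks and subclass definitions against the old name would keep working while the warning nudges users toward the corrected spelling.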