From e5f8560483c83a4be4978d47e54bde5f10f9c50c Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Tue, 22 Jun 2021 16:17:09 -0400 Subject: [PATCH] FlaxBartPretrainedModel -> FlaxBartPreTrainedModel --- src/transformers/models/bart/modeling_flax_bart.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index 2ca5a1f05a7..a72d6bf9ec1 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -911,7 +911,7 @@ def __call__( ) -class FlaxBartPretrainedModel(FlaxPreTrainedModel): +class FlaxBartPreTrainedModel(FlaxPreTrainedModel): config_class = BartConfig base_model_prefix: str = "model" module_class: nn.Module = None @@ -1232,7 +1232,7 @@ def __call__( "The bare Bart Model transformer outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING, ) -class FlaxBartModel(FlaxBartPretrainedModel): +class FlaxBartModel(FlaxBartPreTrainedModel): config: BartConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation module_class = FlaxBartModule @@ -1318,7 +1318,7 @@ def __call__( @add_start_docstrings( "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING ) -class FlaxBartForConditionalGeneration(FlaxBartPretrainedModel): +class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel): module_class = FlaxBartForConditionalGenerationModule dtype: jnp.dtype = jnp.float32 @@ -1623,7 +1623,7 @@ def __call__( """, BART_START_DOCSTRING, ) -class FlaxBartForSequenceClassification(FlaxBartPretrainedModel): +class FlaxBartForSequenceClassification(FlaxBartPreTrainedModel): module_class = FlaxBartForSequenceClassificationModule dtype = jnp.float32 @@ -1710,7 +1710,7 @@ def __call__( """, BART_START_DOCSTRING, ) -class FlaxBartForQuestionAnswering(FlaxBartPretrainedModel): +class FlaxBartForQuestionAnswering(FlaxBartPreTrainedModel): module_class = FlaxBartForQuestionAnsweringModule dtype = jnp.float32