[Flax BERT/Roberta] few small fixes #11558

Merged (2 commits) on May 3, 2021
15 changes: 6 additions & 9 deletions src/transformers/models/bert/modeling_flax_bert.py
@@ -25,7 +25,6 @@
from flax.core.frozen_dict import FrozenDict
from flax.linen import dot_product_attention
from jax import lax
-from jax.random import PRNGKey

from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
@@ -92,9 +91,9 @@ class FlaxBertForPreTrainingOutput(ModelOutput):
generic methods the library implements for all its model (such as downloading, saving and converting weights from
PyTorch models)

-This model is also a Flax Linen `flax.nn.Module
-<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
-Module and refer to the Flax documentation for all matter related to general usage and behavior.
+This model is also a Flax Linen `flax.linen.Module
+<https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+and refer to the Flax documentation for all matter related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

@@ -106,8 +105,8 @@ class FlaxBertForPreTrainingOutput(ModelOutput):
Parameters:
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
-configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-weights.
+configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
@@ -173,15 +172,13 @@ def setup(self):
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-batch_size, sequence_length = input_ids.shape
# Embed
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
position_embeds = self.position_embeddings(position_ids.astype("i4"))
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

# Sum all embeddings
hidden_states = inputs_embeds + token_type_embeddings + position_embeds
-# hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))

# Layer Norm
hidden_states = self.LayerNorm(hidden_states)
@@ -571,7 +568,7 @@ def __call__(
token_type_ids=None,
position_ids=None,
params: dict = None,
-dropout_rng: PRNGKey = None,
+dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
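Note (not part of the diff): the two user-facing pieces touched above are the corrected `FlaxPreTrainedModel.from_pretrained` reference and the fully qualified `jax.random.PRNGKey` annotation on `dropout_rng`. A minimal sketch of how those pieces are typically used together; the checkpoint name and the train-mode call pattern are illustrative assumptions drawn from the surrounding signatures, not something this PR adds:

```python
import jax
from transformers import BertTokenizerFast, FlaxBertModel

# FlaxBertModel inherits from FlaxPreTrainedModel, which provides from_pretrained.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello, Flax BERT!", return_tensors="np")

# With train=True, dropout is active and expects an explicit jax.random.PRNGKey.
dropout_rng = jax.random.PRNGKey(0)
outputs = model(**inputs, train=True, dropout_rng=dropout_rng)

print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```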
12 changes: 5 additions & 7 deletions src/transformers/models/roberta/modeling_flax_roberta.py
@@ -59,9 +59,9 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):
generic methods the library implements for all its model (such as downloading, saving and converting weights from
PyTorch models)

-This model is also a Flax Linen `flax.nn.Module
-<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
-Module and refer to the Flax documentation for all matter related to general usage and behavior.
+This model is also a Flax Linen `flax.linen.Module
+<https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
+and refer to the Flax documentation for all matter related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

@@ -73,8 +73,8 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):
Parameters:
config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the
-configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-weights.
+configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+model weights.
"""

ROBERTA_INPUTS_DOCSTRING = r"""
@@ -140,15 +140,13 @@ def setup(self):
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
-batch_size, sequence_length = input_ids.shape
# Embed
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
position_embeds = self.position_embeddings(position_ids.astype("i4"))
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

# Sum all embeddings
hidden_states = inputs_embeds + token_type_embeddings + position_embeds
-# hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))

# Layer Norm
hidden_states = self.LayerNorm(hidden_states)
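For context on the lines removed from both embedding modules: each `nn.Embed` lookup already returns an array of shape `(batch_size, sequence_length, hidden_size)`, and elementwise addition preserves that shape, so the unused `batch_size, sequence_length` unpacking and the commented-out reshape were dead code. A toy sketch (made-up sizes and class name, not the library implementation) illustrating this:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyEmbeddings(nn.Module):
    vocab_size: int = 100
    type_vocab_size: int = 2
    max_positions: int = 16
    hidden_size: int = 8

    def setup(self):
        self.word_embeddings = nn.Embed(self.vocab_size, self.hidden_size)
        self.position_embeddings = nn.Embed(self.max_positions, self.hidden_size)
        self.token_type_embeddings = nn.Embed(self.type_vocab_size, self.hidden_size)

    def __call__(self, input_ids, token_type_ids, position_ids):
        # Each lookup is (batch, seq_len, hidden); summing keeps that shape,
        # so no reshape is needed afterwards.
        hidden_states = (
            self.word_embeddings(input_ids.astype("i4"))
            + self.token_type_embeddings(token_type_ids.astype("i4"))
            + self.position_embeddings(position_ids.astype("i4"))
        )
        return hidden_states


batch, seq_len = 2, 5
input_ids = jnp.ones((batch, seq_len), dtype=jnp.int32)
token_type_ids = jnp.zeros_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(seq_len), (batch, seq_len))

module = ToyEmbeddings()
variables = module.init(jax.random.PRNGKey(0), input_ids, token_type_ids, position_ids)
out = module.apply(variables, input_ids, token_type_ids, position_ids)
print(out.shape)  # (2, 5, 8)
```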