Flax/Jax documentation #8331

Merged · 10 commits · Nov 11, 2020
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -281,7 +281,7 @@ jobs:
- v0.4-build_doc-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[tf,torch,sentencepiece,docs]
- run: pip install ."[all, docs]"
- save_cache:
key: v0.4-build_doc-{{ checksum "setup.py" }}
paths:
7 changes: 7 additions & 0 deletions docs/source/model_doc/bert.rst
@@ -188,3 +188,10 @@ TFBertForQuestionAnswering

.. autoclass:: transformers.TFBertForQuestionAnswering
:members: call


FlaxBertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxBertModel
:members: __call__
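
A minimal usage sketch for the newly documented class follows; the checkpoint name and the availability of Flax weights for it are assumptions, not taken from this PR:

from transformers import BertTokenizer, FlaxBertModel

# "bert-base-cased" is an assumed checkpoint name with published Flax weights.
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = FlaxBertModel.from_pretrained("bert-base-cased")

# Tokenize to plain numpy arrays; the Flax model converts them to jnp arrays internally.
inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])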
7 changes: 7 additions & 0 deletions docs/source/model_doc/roberta.rst
@@ -146,3 +146,10 @@ TFRobertaForQuestionAnswering

.. autoclass:: transformers.TFRobertaForQuestionAnswering
:members: call


FlaxRobertaModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxRobertaModel
:members: __call__
58 changes: 24 additions & 34 deletions src/transformers/modeling_flax_bert.py
@@ -22,7 +22,7 @@
import jax.numpy as jnp

from .configuration_bert import BertConfig
from .file_utils import add_start_docstrings
from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from .modeling_flax_utils import FlaxPreTrainedModel, gelu
from .utils import logging

@@ -35,13 +35,20 @@

BERT_START_DOCSTRING = r"""

This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
pruning heads etc.)
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading, saving and converting weights from
PyTorch models)

This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
general usage and behavior.
This model is also a Flax Linen `flax.nn.Module
<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
Module and refer to the Flax documentation for all matter related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__

Parameters:
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
@@ -52,50 +59,32 @@

BERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.

Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
:meth:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for
details.

`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.

`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
token_type_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
1]``:

- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.

`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
`What are token type IDs? <../glossary.html#token-type-ids>`__
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.

`What are position IDs? <../glossary.html#position-ids>`_
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
@@ -291,7 +280,7 @@ class FlaxBertModule(nn.Module):
intermediate_size: int

@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):

# Embedding
embeddings = FlaxBertEmbeddings(
@@ -410,7 +399,8 @@ def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs):
def module(self) -> nn.Module:
return self._module

def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)

@@ -423,7 +413,7 @@ def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_
return self.model.apply(
{"params": self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
)
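
The reorder above makes attention_mask the second positional argument of the public __call__ (in line with the PyTorch and TensorFlow BERT models), and self.model.apply now receives the arrays in the order FlaxBertModule.__call__ declares them. A short sketch of the resulting calling convention; the checkpoint name is an assumption:

import jax.numpy as jnp
from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("bert-base-cased")  # assumed checkpoint name

input_ids = jnp.ones((1, 6), dtype="i4")
attention_mask = jnp.ones((1, 6), dtype="i4")

# Positional and keyword calls now agree on the argument order.
out_positional = model(input_ids, attention_mask)
out_keyword = model(input_ids, attention_mask=attention_mask)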
60 changes: 25 additions & 35 deletions src/transformers/modeling_flax_roberta.py
@@ -21,7 +21,7 @@
import jax.numpy as jnp

from .configuration_roberta import RobertaConfig
from .file_utils import add_start_docstrings
from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from .modeling_flax_utils import FlaxPreTrainedModel, gelu
from .utils import logging

@@ -34,13 +34,20 @@

ROBERTA_START_DOCSTRING = r"""

This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
pruning heads etc.)
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading, saving and converting weights from
PyTorch models)

This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
general usage and behavior.
This model is also a Flax Linen `flax.nn.Module
<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
Module and refer to the Flax documentation for all matter related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__

Parameters:
config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
@@ -51,50 +58,32 @@

ROBERTA_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.

Indices can be obtained using :class:`~transformers.RobertaTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for
details.

`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.

`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
token_type_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
1]``:

- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.

`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
`What are token type IDs? <../glossary.html#token-type-ids>`__
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.

`What are position IDs? <../glossary.html#position-ids>`_
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
@@ -302,7 +291,7 @@ class FlaxRobertaModule(nn.Module):
intermediate_size: int

@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):

# Embedding
embeddings = FlaxRobertaEmbeddings(
@@ -421,7 +410,8 @@ def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, **kwargs):
def module(self) -> nn.Module:
return self._module

def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def __call__(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None):
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)

@@ -436,7 +426,7 @@ def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_
return self.model.apply(
{"params": self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
)