Flax/Jax documentation #8331
@@ -22,7 +22,7 @@
 import jax.numpy as jnp

 from .configuration_bert import BertConfig
-from .file_utils import add_start_docstrings
+from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from .modeling_flax_utils import FlaxPreTrainedModel, gelu
 from .utils import logging
@@ -35,13 +35,19 @@
 BERT_START_DOCSTRING = r"""

-This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
-methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
-pruning heads etc.)
+This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
+generic methods the library implements for all its model (such as downloading, saving and converting weights from
+PyTorch models)

-This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
-subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
-general usage and behavior.
+This model is also a Flax Linen `flax.nn.Module
+<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
+Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+Finally, this model supports inherent JAX features such as `Just-In-Time (JIT) compilation
+<https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__, `Automatic Differentiation
+<https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__, `Vectorization
+<https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__, and `Parallelization
+<https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__.

 Parameters:
 config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
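For readers unfamiliar with the JAX transformations the new docstring advertises, here is a generic sketch (not code from this PR; the toy function below is made up) of what JIT compilation, automatic differentiation, and vectorization look like in practice:

```python
import jax
import jax.numpy as jnp

def score(x):
    # Toy scalar-valued function standing in for a model forward pass.
    return jnp.tanh(x).sum()

fast_score = jax.jit(score)      # Just-In-Time (JIT) compilation
grad_score = jax.grad(score)     # Automatic differentiation
batched_score = jax.vmap(score)  # Vectorization over the leading axis

x = jnp.ones((4, 3))
print(fast_score(x))             # compiled scalar result
print(grad_score(x).shape)       # (4, 3): gradient w.r.t. the input
print(batched_score(x).shape)    # (4,): one score per row
```

(`jax.pmap` behaves analogously but needs multiple devices, so it is left out of the sketch.)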
@@ -52,50 +58,32 @@
 BERT_INPUTS_DOCSTRING = r"""
 Args:
-input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
+input_ids (:obj:`Numpy array` of shape :obj:`({0})`):

Review comment (suggested change): :obj:`Numpy array` -> :obj:`numpy.ndarray`.

 Indices of input sequence tokens in the vocabulary.

 Indices can be obtained using :class:`~transformers.BertTokenizer`. See
-:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
+:func:`transformers.PreTrainedTokenizer.__call__` and :func:`transformers.PreTrainedTokenizer.encode` for

Review comment: No reason to change that line.

 details.

 `What are input IDs? <../glossary.html#input-ids>`__
-attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+attention_mask (:obj:`Numpy array` of shape :obj:`({0})`, `optional`):

Review comment (suggested change): :obj:`Numpy array` -> :obj:`numpy.ndarray`.

 Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

 - 1 for tokens that are **not masked**,
 - 0 for tokens that are **masked**.

 `What are attention masks? <../glossary.html#attention-mask>`__
-token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+token_type_ids (:obj:`Numpy array` of shape :obj:`({0})`, `optional`):

Review comment (suggested change): :obj:`Numpy array` -> :obj:`numpy.ndarray`.

 Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
 1]``:

 - 0 corresponds to a `sentence A` token,
 - 1 corresponds to a `sentence B` token.

-`What are token type IDs? <../glossary.html#token-type-ids>`_
-position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+`What are token type IDs? <../glossary.html#token-type-ids>`__
+position_ids (:obj:`Numpy array` of shape :obj:`({0})`, `optional`):

Review comment (suggested change): :obj:`Numpy array` -> :obj:`numpy.ndarray`.

 Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
 config.max_position_embeddings - 1]``.

 `What are position IDs? <../glossary.html#position-ids>`_
-head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
-
-- 1 indicates the head is **not masked**,
-- 0 indicates the head is **masked**.
-
-inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
-Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
-This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
-vectors than the model's internal embedding lookup matrix.
 output_attentions (:obj:`bool`, `optional`):
 Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
 tensors for more detail.
 output_hidden_states (:obj:`bool`, `optional`):
 Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
 more detail.
 return_dict (:obj:`bool`, `optional`):
 Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
 """
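As a hedged usage sketch of the inputs documented above (the checkpoint name is only an example, and `return_tensors="np"` is the standard tokenizer option for getting numpy arrays rather than anything introduced by this PR):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")  # example checkpoint
encoded = tokenizer(
    ["A first sentence.", "A second, longer sentence."],
    padding=True,
    return_tensors="np",  # plain numpy arrays, matching the docstring above
)
input_ids = encoded["input_ids"]            # shape (batch_size, sequence_length)
attention_mask = encoded["attention_mask"]  # 1 = not masked, 0 = padding
token_type_ids = encoded["token_type_ids"]  # 0 = sentence A tokens
```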
@@ -291,7 +279,7 @@ class FlaxBertModule(nn.Module):
 intermediate_size: int

 @nn.compact
-def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
+def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):

 # Embedding
 embeddings = FlaxBertEmbeddings(
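The reordering matters because, with `@nn.compact`, the positional order of `__call__` is exactly the order callers must use with `module.apply`. A minimal, self-contained Flax Linen sketch (the module and its fields are illustrative, not taken from the PR):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

class TinyModule(nn.Module):
    features: int

    @nn.compact
    def __call__(self, input_ids, attention_mask):
        # Embed the ids, then zero out padding positions via the mask.
        x = nn.Embed(num_embeddings=100, features=self.features)(input_ids)
        return x * attention_mask[..., None]

module = TinyModule(features=8)
ids = jnp.ones((2, 5), dtype="i4")
mask = jnp.ones((2, 5), dtype="i4")
params = module.init(jax.random.PRNGKey(0), ids, mask)
out = module.apply(params, ids, mask)  # same positional order as __call__
```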
@@ -410,7 +398,8 @@ def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs):
 def module(self) -> nn.Module:
     return self._module

-def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
+@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):
     if token_type_ids is None:
         token_type_ids = jnp.ones_like(input_ids)
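`add_start_docstrings_to_model_forward` comes from `.file_utils`; roughly speaking (an assumption about the mechanism, not the library's actual implementation), a docstring-injecting decorator of this shape prepends the formatted inputs docstring to the wrapped method:

```python
# Illustration only; the real decorator lives in transformers.file_utils.
def add_docstrings_sketch(docstr):
    def decorator(fn):
        fn.__doc__ = docstr + (fn.__doc__ or "")
        return fn
    return decorator

class Example:
    @add_docstrings_sketch("Args:\n    input_ids: token indices.\n")
    def __call__(self, input_ids):
        """Extra, method-specific notes."""
        return input_ids

print(Example.__call__.__doc__)  # the args section first, then the method's own notes
```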
@@ -423,7 +412,7 @@ def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_
 return self.model.apply(
     {"params": self.params},
     jnp.array(input_ids, dtype="i4"),
+    jnp.array(attention_mask, dtype="i4"),
     jnp.array(token_type_ids, dtype="i4"),
     jnp.array(position_ids, dtype="i4"),
-    jnp.array(attention_mask, dtype="i4"),
 )
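Because the wrapper's `__call__` now takes `attention_mask`, `token_type_ids` and `position_ids` as optional keyword arguments, call sites that pass them by keyword are unaffected by the reordering. A hypothetical caller (assumes `FlaxBertModel` is the public entry point and that the checkpoint ships Flax weights):

```python
import numpy as np
from transformers import FlaxBertModel  # assumed entry point

model = FlaxBertModel.from_pretrained("bert-base-cased")  # example checkpoint

input_ids = np.array([[101, 2023, 2003, 102]], dtype=np.int32)  # toy token ids
attention_mask = np.ones_like(input_ids)

# Keyword arguments keep this call independent of the new positional order
# (input_ids, attention_mask, token_type_ids, position_ids).
outputs = model(input_ids, attention_mask=attention_mask)
```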
@@ -21,7 +21,7 @@
 import jax.numpy as jnp

 from .configuration_roberta import RobertaConfig
-from .file_utils import add_start_docstrings
+from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from .modeling_flax_utils import FlaxPreTrainedModel, gelu
 from .utils import logging
@@ -34,13 +34,19 @@
 ROBERTA_START_DOCSTRING = r"""

-This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
-methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
-pruning heads etc.)
+This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
+generic methods the library implements for all its model (such as downloading, saving and converting weights from
+PyTorch models)

-This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
-subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
-general usage and behavior.
+This model is also a Flax Linen `flax.nn.Module
+<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
+Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+Finally, this model supports inherent JAX features such as `Just-In-Time (JIT) compilation
+<https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__, `Automatic Differentiation
+<https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__, `Vectorization
+<https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__, and `Parallelization
+<https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__.

 Parameters:
 config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
@@ -51,50 +57,32 @@
 ROBERTA_INPUTS_DOCSTRING = r"""
 Args:
-input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
+input_ids (:obj:`Numpy array` of shape :obj:`({0})`):

Review comment: Same in this file: Numpy array -> numpy.ndarray in all docstrings.

 Indices of input sequence tokens in the vocabulary.

-Indices can be obtained using :class:`~transformers.RobertaTokenizer`. See
-:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
+Indices can be obtained using :class:`~transformers.BertTokenizer`. See
+:func:`transformers.PreTrainedTokenizer.__call__` and :func:`transformers.PreTrainedTokenizer.encode` for
 details.

 `What are input IDs? <../glossary.html#input-ids>`__
-attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+attention_mask (:obj:`Numpy array` of shape :obj:`({0})`, `optional`):
 Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

 - 1 for tokens that are **not masked**,
 - 0 for tokens that are **masked**.

 `What are attention masks? <../glossary.html#attention-mask>`__
-token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+token_type_ids (:obj:`Numpy array` of shape :obj:`({0})`, `optional`):
 Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
 1]``:

 - 0 corresponds to a `sentence A` token,
 - 1 corresponds to a `sentence B` token.

-`What are token type IDs? <../glossary.html#token-type-ids>`_
-position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+`What are token type IDs? <../glossary.html#token-type-ids>`__
+position_ids (:obj:`Numpy array` of shape :obj:`({0})`, `optional`):
 Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
 config.max_position_embeddings - 1]``.

 `What are position IDs? <../glossary.html#position-ids>`_
-head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
-
-- 1 indicates the head is **not masked**,
-- 0 indicates the head is **masked**.
-
-inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
-Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
-This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
-vectors than the model's internal embedding lookup matrix.
 output_attentions (:obj:`bool`, `optional`):
 Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
 tensors for more detail.
 output_hidden_states (:obj:`bool`, `optional`):
 Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
 more detail.
 return_dict (:obj:`bool`, `optional`):
 Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
 """
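The mask and segment conventions documented above can also be spelled out by hand with numpy; a small sketch (the token id values are illustrative only):

```python
import numpy as np

# One padded sequence of length 6: four real tokens followed by two pad tokens.
input_ids = np.array([[0, 31414, 232, 2, 1, 1]], dtype=np.int32)

# attention_mask: 1 for tokens that are **not masked**, 0 for padding tokens.
attention_mask = np.array([[1, 1, 1, 1, 0, 0]], dtype=np.int32)

# token_type_ids: 0 for `sentence A` tokens, 1 for `sentence B` tokens
# (RoBERTa models typically use a single segment, hence all zeros here).
token_type_ids = np.zeros_like(input_ids)

# position_ids: indices in [0, config.max_position_embeddings - 1].
position_ids = np.arange(input_ids.shape[1], dtype=np.int32)[None, :]
```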
@@ -302,7 +290,7 @@ class FlaxRobertaModule(nn.Module):
 intermediate_size: int

 @nn.compact
-def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
+def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):

 # Embedding
 embeddings = FlaxRobertaEmbeddings(
@@ -421,7 +409,8 @@ def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, **kwargs):
 def module(self) -> nn.Module:
     return self._module

-def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
+@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+def __call__(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None):
     if token_type_ids is None:
         token_type_ids = jnp.ones_like(input_ids)
@@ -436,7 +425,7 @@ def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_
 return self.model.apply(
     {"params": self.params},
     jnp.array(input_ids, dtype="i4"),
+    jnp.array(attention_mask, dtype="i4"),
     jnp.array(token_type_ids, dtype="i4"),
     jnp.array(position_ids, dtype="i4"),
-    jnp.array(attention_mask, dtype="i4"),
 )
Review comment: Would this be nicer in a list?
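One way to read this suggestion: build the casted positional inputs once, in a list or comprehension, instead of repeating `jnp.array(..., dtype="i4")` four times. A minimal sketch of that refactor (not code from the PR; `self.model` and `self.params` are the names used in the surrounding diff):

```python
import jax.numpy as jnp

def cast_to_i4(*arrays):
    # Cast every input to an int32 ("i4") JAX array, preserving order.
    return [jnp.array(a, dtype="i4") for a in arrays]

# Inside __call__, the return statement could then become:
#
#     return self.model.apply(
#         {"params": self.params},
#         *cast_to_i4(input_ids, attention_mask, token_type_ids, position_ids),
#     )
```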