[FlaxRoberta] Add FlaxRobertaModels & adapt run_mlm_flax.py (huggingface#11470)

* add flax roberta

* make style

* correct initialization

* modify model to save weights

* fix copied from

* fix copied from

* correct some more code

* add more roberta models

* Apply suggestions from code review

* merge from master

* finish

* finish docs

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
2 people authored and Iwontbecreative committed Jul 15, 2021
1 parent 9369104 commit 005bb22
Showing 11 changed files with 696 additions and 48 deletions.
35 changes: 35 additions & 0 deletions docs/source/model_doc/roberta.rst
@@ -166,3 +166,38 @@ FlaxRobertaModel

.. autoclass:: transformers.FlaxRobertaModel
:members: __call__


FlaxRobertaForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxRobertaForMaskedLM
:members: __call__


FlaxRobertaForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxRobertaForSequenceClassification
:members: __call__


FlaxRobertaForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxRobertaForMultipleChoice
:members: __call__


FlaxRobertaForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxRobertaForTokenClassification
:members: __call__


FlaxRobertaForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxRobertaForQuestionAnswering
:members: __call__
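For orientation (not part of this diff): a minimal usage sketch of one of the newly documented classes, assuming transformers is installed with Flax support and that Flax weights (or convertible PyTorch weights, via from_pt=True) exist for the roberta-base checkpoint.

from transformers import AutoTokenizer, FlaxRobertaForMaskedLM

# Load the tokenizer and the new Flax masked-LM head (roberta-base is an assumed,
# publicly available checkpoint).
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = FlaxRobertaForMaskedLM.from_pretrained("roberta-base")

# Flax models accept NumPy arrays directly.
inputs = tokenizer("The capital of France is <mask>.", return_tensors="np")
outputs = model(**inputs)
logits = outputs.logits  # shape (batch_size, sequence_length, vocab_size)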
116 changes: 87 additions & 29 deletions examples/flax/language-modeling/run_mlm_flax.py
@@ -45,7 +45,7 @@
MODEL_FOR_MASKED_LM_MAPPING,
AutoConfig,
AutoTokenizer,
FlaxBertForMaskedLM,
FlaxAutoModelForMaskedLM,
HfArgumentParser,
PreTrainedTokenizerBase,
TensorType,
@@ -105,6 +105,12 @@ class ModelArguments:
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
},
)


@dataclass
@@ -162,6 +168,10 @@ class DataTrainingArguments:
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
},
)
line_by_line: bool = field(
default=False,
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
)

def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -537,27 +547,76 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
column_names = datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]

padding = "max_length" if data_args.pad_to_max_length else False

def tokenize_function(examples):
# Remove empty lines
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
return tokenizer(
examples,
return_special_tokens_mask=True,
padding=padding,
truncation=True,
max_length=data_args.max_seq_length,
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

if data_args.line_by_line:
# When using line_by_line, we just tokenize each nonempty line.
padding = "max_length" if data_args.pad_to_max_length else False

def tokenize_function(examples):
# Remove empty lines
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
return tokenizer(
examples,
return_special_tokens_mask=True,
padding=padding,
truncation=True,
max_length=max_seq_length,
)

tokenized_datasets = datasets.map(
tokenize_function,
input_columns=[text_column_name],
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)

tokenized_datasets = datasets.map(
tokenize_function,
input_columns=[text_column_name],
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
else:
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
# efficient when it receives the `special_tokens_mask`.
def tokenize_function(examples):
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)

tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)

# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder; we could add padding instead of dropping if the model supported it. You can
# customize this part to your needs.
total_length = (total_length // max_seq_length) * max_seq_length
# Split by chunks of max_len.
result = {
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
for k, t in concatenated_examples.items()
}
return result

# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
# might be slower to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

tokenized_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)

# Enable tensorboard only on the master node
if has_tensorboard and jax.host_id() == 0:
@@ -571,13 +630,7 @@ def tokenize_function(examples):
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

model = FlaxBertForMaskedLM.from_pretrained(
"bert-base-cased",
dtype=jnp.float32,
input_shape=(training_args.train_batch_size, config.max_position_embeddings),
seed=training_args.seed,
dropout_rate=0.1,
)
model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))

# Setup optimizer
optimizer = Adam(
@@ -602,8 +655,8 @@ def tokenize_function(examples):

# Store some constants
nb_epochs = int(training_args.num_train_epochs)
batch_size = int(training_args.train_batch_size)
eval_batch_size = int(training_args.eval_batch_size)
batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()

epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
for epoch in epochs:
@@ -657,3 +710,8 @@ def tokenize_function(examples):
if has_tensorboard and jax.host_id() == 0:
for name, value in eval_summary.items():
summary_writer.scalar(name, value, epoch)

# save last checkpoint
if jax.host_id() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], optimizer.target))
model.save_pretrained(training_args.output_dir, params=params)
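To illustrate the grouping path added above (a standalone sketch, not part of the script): with line_by_line disabled, group_texts concatenates the tokenized texts and re-splits them into chunks of max_seq_length, dropping the remainder.

# Standalone illustration of the group_texts logic with max_seq_length = 4.
max_seq_length = 4
examples = {"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10]]}

# Concatenate all texts, then keep only a multiple of max_seq_length.
concatenated = {k: sum(examples[k], []) for k in examples.keys()}
total_length = (len(concatenated["input_ids"]) // max_seq_length) * max_seq_length  # 8
result = {
    k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
    for k, t in concatenated.items()
}
# result["input_ids"] == [[1, 2, 3, 4], [5, 6, 7, 8]]  (tokens 9 and 10 are dropped)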
22 changes: 20 additions & 2 deletions src/transformers/__init__.py
@@ -1403,7 +1403,17 @@
"FlaxBertPreTrainedModel",
]
)
_import_structure["models.roberta"].append("FlaxRobertaModel")
_import_structure["models.roberta"].extend(
[
"FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering",
"FlaxRobertaForSequenceClassification",
"FlaxRobertaForTokenClassification",
"FlaxRobertaModel",
"FlaxRobertaPreTrainedModel",
]
)
else:
from .utils import dummy_flax_objects

@@ -2575,7 +2585,15 @@
FlaxBertModel,
FlaxBertPreTrainedModel,
)
from .models.roberta import FlaxRobertaModel
from .models.roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
FlaxRobertaPreTrainedModel,
)
else:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.
4 changes: 2 additions & 2 deletions src/transformers/file_utils.py
@@ -1608,9 +1608,9 @@ def is_tensor(x):

if is_flax_available():
import jaxlib.xla_extension as jax_xla
from jax.interpreters.partial_eval import DynamicJaxprTracer
from jax.core import Tracer

if isinstance(x, (jax_xla.DeviceArray, DynamicJaxprTracer)):
if isinstance(x, (jax_xla.DeviceArray, Tracer)):
return True

return isinstance(x, np.ndarray)
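Background sketch (assumes only that jax is installed): widening the check from DynamicJaxprTracer to the jax.core.Tracer base class means is_tensor also recognizes the abstract values JAX passes around while tracing, for example inside jax.jit.

import jax
import jax.numpy as jnp
from jax.core import Tracer

x = jnp.ones((2, 2))
print(isinstance(x, Tracer))  # False: a concrete DeviceArray outside of tracing

@jax.jit
def double(y):
    # During tracing, y is a DynamicJaxprTracer, which is a subclass of Tracer.
    print(isinstance(y, Tracer))  # True, printed once at trace time
    return y * 2

double(x)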
5 changes: 3 additions & 2 deletions src/transformers/modeling_flax_utils.py
@@ -388,7 +388,7 @@ def from_pretrained(

return model

def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub=False, **kwargs):
def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
:func:`~transformers.FlaxPreTrainedModel.from_pretrained` class method
@@ -416,7 +416,8 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub=False, **kwargs):
# save model
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
with open(output_model_file, "wb") as f:
model_bytes = to_bytes(self.params)
params = params if params is not None else self.params
model_bytes = to_bytes(params)
f.write(model_bytes)

logger.info(f"Model weights saved in {output_model_file}")
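Usage sketch of the new params argument, mirroring the training-script change above (assumes a model and a pmap-replicated flax.optim optimizer as in run_mlm_flax.py; the output path is hypothetical).

import jax

# Pull one replica of the parameters back to host memory, then save that copy.
params = jax.device_get(jax.tree_map(lambda x: x[0], optimizer.target))
model.save_pretrained("./output_dir", params=params)
# Without the argument, save_pretrained falls back to self.params as before.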
15 changes: 14 additions & 1 deletion src/transformers/models/auto/modeling_flax_auto.py
@@ -28,7 +28,14 @@
FlaxBertForTokenClassification,
FlaxBertModel,
)
from ..roberta.modeling_flax_roberta import FlaxRobertaModel
from ..roberta.modeling_flax_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
)
from .auto_factory import auto_class_factory
from .configuration_auto import BertConfig, RobertaConfig

@@ -47,41 +54,47 @@
FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
# Model for pre-training mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForPreTraining),
]
)

FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
[
# Model for Masked LM mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForMaskedLM),
]
)

FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Sequence Classification mapping
(RobertaConfig, FlaxRobertaForSequenceClassification),
(BertConfig, FlaxBertForSequenceClassification),
]
)

FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
[
# Model for Question Answering mapping
(RobertaConfig, FlaxRobertaForQuestionAnswering),
(BertConfig, FlaxBertForQuestionAnswering),
]
)

FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Token Classification mapping
(RobertaConfig, FlaxRobertaForTokenClassification),
(BertConfig, FlaxBertForTokenClassification),
]
)

FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
[
# Model for Multiple Choice mapping
(RobertaConfig, FlaxRobertaForMultipleChoice),
(BertConfig, FlaxBertForMultipleChoice),
]
)
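Dispatch sketch (not part of the diff; assumes the roberta-base config is reachable): with the new mapping entries, the Flax auto classes resolve a RobertaConfig to the corresponding Flax RoBERTa head.

import jax.numpy as jnp
from transformers import AutoConfig, FlaxAutoModelForMaskedLM

config = AutoConfig.from_pretrained("roberta-base")
# from_config builds a freshly initialized model; seed and dtype are forwarded to the Flax class.
model = FlaxAutoModelForMaskedLM.from_config(config, seed=0, dtype=jnp.float32)
print(type(model).__name__)  # FlaxRobertaForMaskedLM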
20 changes: 18 additions & 2 deletions src/transformers/models/roberta/__init__.py
@@ -61,7 +61,15 @@
]

if is_flax_available():
_import_structure["modeling_flax_roberta"] = ["FlaxRobertaModel"]
_import_structure["modeling_flax_roberta"] = [
"FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering",
"FlaxRobertaForSequenceClassification",
"FlaxRobertaForTokenClassification",
"FlaxRobertaModel",
"FlaxRobertaPreTrainedModel",
]


if TYPE_CHECKING:
@@ -97,7 +105,15 @@
)

if is_flax_available():
from .modeling_flax_roberta import FlaxRobertaModel
from .modeling_flax_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
FlaxRobertaPreTrainedModel,
)

else:
import importlib
