[FlaxRoberta] Add FlaxRobertaModels & adapt run_mlm_flax.py #11470

Merged
35 changes: 35 additions & 0 deletions docs/source/model_doc/roberta.rst
@@ -166,3 +166,38 @@ FlaxRobertaModel

.. autoclass:: transformers.FlaxRobertaModel
:members: __call__


FlaxRobertaForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxRobertaForMaskedLM
:members: __call__


FlaxRobertaForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxRobertaForSequenceClassification
:members: __call__


FlaxRobertaForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxRobertaForMultipleChoice
:members: __call__


FlaxRobertaForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxRobertaForTokenClassification
:members: __call__


FlaxRobertaForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxRobertaForQuestionAnswering
:members: __call__
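
A minimal usage sketch for one of the newly documented heads (not part of this diff; assumes the public roberta-base checkpoint and that Flax weights can be loaded for it, otherwise pass `from_pt=True`):

from transformers import FlaxRobertaForMaskedLM, RobertaTokenizerFast

# Hedged sketch: masked-token prediction with the new Flax head.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
model = FlaxRobertaForMaskedLM.from_pretrained("roberta-base")

inputs = tokenizer("The capital of France is <mask>.", return_tensors="jax")
logits = model(**inputs).logits

# Most likely token at the <mask> position.
mask_pos = int((inputs["input_ids"][0] == tokenizer.mask_token_id).argmax())
predicted_id = int(logits[0, mask_pos].argmax(-1))
print(tokenizer.decode([predicted_id]))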
116 changes: 87 additions & 29 deletions examples/flax/language-modeling/run_mlm_flax.py
@@ -45,7 +45,7 @@
MODEL_FOR_MASKED_LM_MAPPING,
AutoConfig,
AutoTokenizer,
FlaxBertForMaskedLM,
FlaxAutoModelForMaskedLM,
HfArgumentParser,
PreTrainedTokenizerBase,
TensorType,
@@ -105,6 +105,12 @@ class ModelArguments:
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
},
)
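
The new `dtype` field is just a string at parse time; further down in this diff it is resolved to a `jax.numpy` dtype with `getattr(jnp, model_args.dtype)` and handed to the model. A small illustration of that mapping:

import jax.numpy as jnp

dtype_name = "bfloat16"           # one of float32, float16, bfloat16
dtype = getattr(jnp, dtype_name)  # -> jnp.bfloat16
# the model is then created as FlaxAutoModelForMaskedLM.from_config(config, seed=seed, dtype=dtype)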


@dataclass
@@ -162,6 +168,10 @@ class DataTrainingArguments:
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
},
)
line_by_line: bool = field(
default=False,
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
)

def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -537,27 +547,76 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
column_names = datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]

padding = "max_length" if data_args.pad_to_max_length else False

def tokenize_function(examples):
# Remove empty lines
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
return tokenizer(
examples,
return_special_tokens_mask=True,
padding=padding,
truncation=True,
max_length=data_args.max_seq_length,
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

if data_args.line_by_line:
# When using line_by_line, we just tokenize each nonempty line.
padding = "max_length" if data_args.pad_to_max_length else False

def tokenize_function(examples):
# Remove empty lines
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
return tokenizer(
examples,
return_special_tokens_mask=True,
padding=padding,
truncation=True,
max_length=max_seq_length,
)

tokenized_datasets = datasets.map(
tokenize_function,
input_columns=[text_column_name],
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)

tokenized_datasets = datasets.map(
tokenize_function,
input_columns=[text_column_name],
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
else:
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
# efficient when it receives the `special_tokens_mask`.
def tokenize_function(examples):
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)

tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)

# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
            # We drop the small remainder; we could add padding instead if the model supported it. You can
            # customize this part to your needs.
total_length = (total_length // max_seq_length) * max_seq_length
# Split by chunks of max_len.
result = {
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
for k, t in concatenated_examples.items()
}
return result

# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
# might be slower to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

tokenized_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)
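
As a quick sanity check of `group_texts`, here is a toy run with an assumed `max_seq_length` of 4: ten concatenated tokens yield two chunks of four, and the two-token remainder is dropped.

# Toy illustration of the chunking performed by group_texts (max_seq_length = 4 is an assumption).
max_seq_length = 4
examples = {"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9, 10]]}

concatenated = {k: sum(examples[k], []) for k in examples.keys()}
total_length = (len(concatenated["input_ids"]) // max_seq_length) * max_seq_length
chunks = {
    k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
    for k, t in concatenated.items()
}
print(chunks)  # {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]} -- tokens 9 and 10 are dropped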

# Enable tensorboard only on the master node
if has_tensorboard and jax.host_id() == 0:
@@ -571,13 +630,7 @@ def tokenize_function(examples):
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

model = FlaxBertForMaskedLM.from_pretrained(
"bert-base-cased",
dtype=jnp.float32,
input_shape=(training_args.train_batch_size, config.max_position_embeddings),
seed=training_args.seed,
dropout_rate=0.1,
)
model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
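
Unlike the previous hard-coded `FlaxBertForMaskedLM.from_pretrained("bert-base-cased", ...)`, `from_config` initializes fresh, randomly seeded weights for whichever architecture the config describes, which is what a from-scratch MLM pretraining run needs; the weight dtype comes from the new `--dtype` flag. A standalone sketch of the equivalent call (model name and dtype are assumptions):

import jax.numpy as jnp
from transformers import AutoConfig, FlaxAutoModelForMaskedLM

config = AutoConfig.from_pretrained("roberta-base")
model = FlaxAutoModelForMaskedLM.from_config(config, seed=0, dtype=jnp.bfloat16)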

# Setup optimizer
optimizer = Adam(
@@ -602,8 +655,8 @@ def tokenize_function(examples):

# Store some constant
nb_epochs = int(training_args.num_train_epochs)
batch_size = int(training_args.train_batch_size)
eval_batch_size = int(training_args.eval_batch_size)
batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
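
The effective batch sizes are now derived per device, so the same flags scale with the hardware. A worked example with assumed numbers:

# 16 examples per device on a host with 8 accelerator cores (e.g. a TPU v3-8).
per_device_train_batch_size = 16
device_count = 8  # what jax.device_count() would report on such a host
batch_size = per_device_train_batch_size * device_count  # 128 examples per training step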

epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
for epoch in epochs:
@@ -657,3 +710,8 @@ def tokenize_function(examples):
if has_tensorboard and jax.host_id() == 0:
for name, value in eval_summary.items():
summary_writer.scalar(name, value, epoch)

# save last checkpoint
if jax.host_id() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], optimizer.target))
model.save_pretrained(training_args.output_dir, params=params)
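
`optimizer.target` is replicated across devices by `pmap`, so every leaf carries a leading device axis; indexing with `[0]` and calling `jax.device_get` yields a single host-side copy to serialize. A sketch of the same step written with flax's helper (equivalent under that assumption):

import jax
from flax import jax_utils

params = jax.device_get(jax_utils.unreplicate(optimizer.target))
model.save_pretrained(training_args.output_dir, params=params)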
22 changes: 20 additions & 2 deletions src/transformers/__init__.py
@@ -1403,7 +1403,17 @@
"FlaxBertPreTrainedModel",
]
)
_import_structure["models.roberta"].append("FlaxRobertaModel")
_import_structure["models.roberta"].extend(
[
"FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering",
"FlaxRobertaForSequenceClassification",
"FlaxRobertaForTokenClassification",
"FlaxRobertaModel",
"FlaxRobertaPreTrainedModel",
]
)
else:
from .utils import dummy_flax_objects

@@ -2575,7 +2585,15 @@
FlaxBertModel,
FlaxBertPreTrainedModel,
)
from .models.roberta import FlaxRobertaModel
from .models.roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
FlaxRobertaPreTrainedModel,
)
else:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.
4 changes: 2 additions & 2 deletions src/transformers/file_utils.py
@@ -1608,9 +1608,9 @@ def is_tensor(x):

if is_flax_available():
import jaxlib.xla_extension as jax_xla
from jax.interpreters.partial_eval import DynamicJaxprTracer
from jax.core import Tracer

if isinstance(x, (jax_xla.DeviceArray, DynamicJaxprTracer)):
if isinstance(x, (jax_xla.DeviceArray, Tracer)):
return True

return isinstance(x, np.ndarray)
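
Switching from `DynamicJaxprTracer` to the base class `jax.core.Tracer` makes the check cover every kind of tracer (`jit`, `grad`, `vmap`, ...), not only the jaxpr one. A quick illustration:

# Inside a jit-traced function the argument is a Tracer, not a concrete array.
import jax
import jax.numpy as jnp
from jax.core import Tracer

@jax.jit
def double(x):
    print(isinstance(x, Tracer))  # prints True while tracing
    return x * 2

double(jnp.ones(3))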
5 changes: 3 additions & 2 deletions src/transformers/modeling_flax_utils.py
@@ -388,7 +388,7 @@ def from_pretrained(

return model

def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub=False, **kwargs):
def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
"""
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.FlaxPreTrainedModel.from_pretrained` class method
@@ -416,7 +416,8 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub=F
# save model
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
with open(output_model_file, "wb") as f:
model_bytes = to_bytes(self.params)
Contributor (author) commented: This is required to allow weight saving in distributed mode - see run_mlm_flax.py.

Collaborator commented: Like the state_dict argument in the PyTorch equivalent, completely fine by me!

params = params if params is not None else self.params
model_bytes = to_bytes(params)
f.write(model_bytes)

logger.info(f"Model weights saved in {output_model_file}")
15 changes: 14 additions & 1 deletion src/transformers/models/auto/modeling_flax_auto.py
@@ -28,7 +28,14 @@
FlaxBertForTokenClassification,
FlaxBertModel,
)
from ..roberta.modeling_flax_roberta import FlaxRobertaModel
from ..roberta.modeling_flax_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
)
from .auto_factory import auto_class_factory
from .configuration_auto import BertConfig, RobertaConfig

@@ -47,41 +54,47 @@
FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
# Model for pre-training mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForPreTraining),
]
)

FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
[
# Model for Masked LM mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForMaskedLM),
]
)

FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Sequence Classification mapping
(RobertaConfig, FlaxRobertaForSequenceClassification),
(BertConfig, FlaxBertForSequenceClassification),
]
)

FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
[
# Model for Question Answering mapping
(RobertaConfig, FlaxRobertaForQuestionAnswering),
(BertConfig, FlaxBertForQuestionAnswering),
]
)

FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Token Classification mapping
(RobertaConfig, FlaxRobertaForTokenClassification),
(BertConfig, FlaxBertForTokenClassification),
]
)

FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
[
# Model for Multiple Choice mapping
(RobertaConfig, FlaxRobertaForMultipleChoice),
(BertConfig, FlaxBertForMultipleChoice),
]
)
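
With the Roberta entries registered above, a `RobertaConfig` now resolves to the corresponding Flax head in each auto mapping; a quick check (import path as in this file):

from transformers import RobertaConfig
from transformers.models.auto.modeling_flax_auto import FLAX_MODEL_FOR_MASKED_LM_MAPPING

print(FLAX_MODEL_FOR_MASKED_LM_MAPPING[RobertaConfig].__name__)  # FlaxRobertaForMaskedLM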
20 changes: 18 additions & 2 deletions src/transformers/models/roberta/__init__.py
@@ -61,7 +61,15 @@
]

if is_flax_available():
_import_structure["modeling_flax_roberta"] = ["FlaxRobertaModel"]
_import_structure["modeling_flax_roberta"] = [
"FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering",
"FlaxRobertaForSequenceClassification",
"FlaxRobertaForTokenClassification",
"FlaxRobertaModel",
"FlaxRobertaPreTrainedModel",
]


if TYPE_CHECKING:
@@ -97,7 +105,15 @@
)

if is_flax_available():
from .modeling_flax_roberta import FlaxRobertaModel
        from .modeling_flax_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
FlaxRobertaPreTrainedModel,
)

else:
import importlib