
[Flax] Fix incomplete batches in example scripts #17863

Merged · 14 commits · Jul 27, 2022
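The Flax example scripts previously truncated each eval dataset to a multiple of the batch size, silently dropping the final incomplete batch. This PR threads a `drop_last` flag through the eval data loaders (defaulting to `True`, so training is unchanged), switches eval step counts from floor to ceiling division, and wraps each pmapped eval step in `flax.jax_utils.pad_shard_unpad`, which pads a short batch up to a device-divisible size before sharding it. The recurring pattern, repeated across all three scripts below:

metrics = pad_shard_unpad(p_eval_step, static_return=True)(
    state.params, batch, min_device_batch=per_device_eval_batch_size
)

Per flax's API, `pad_shard_unpad` treats the first positional argument as static by default (`static_argnums=(0,)`), so the already-replicated `state.params` passes through untouched.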
44 changes: 26 additions & 18 deletions examples/flax/language-modeling/run_clm_flax.py
@@ -43,7 +43,7 @@
import optax
import transformers
from flax import jax_utils, traverse_util
-from flax.jax_utils import unreplicate
+from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
@@ -264,20 +264,24 @@ def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))


-def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
"""
-Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
-Shuffle batches if `shuffle` is `True`.
+Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
+and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
"""
-steps_per_epoch = len(dataset) // batch_size

if shuffle:
batch_idx = jax.random.permutation(rng, len(dataset))
+batch_idx = np.asarray(batch_idx)
else:
-batch_idx = jnp.arange(len(dataset))
+batch_idx = np.arange(len(dataset))

-batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
-batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+if drop_last:
+steps_per_epoch = len(dataset) // batch_size
+batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
+batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+else:
+steps_per_epoch = math.ceil(len(dataset) / batch_size)
+batch_idx = np.array_split(batch_idx, steps_per_epoch)

for idx in batch_idx:
batch = dataset[idx]
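For intuition, a minimal sketch (not part of the diff) of what the new `drop_last=False` branch yields. Note that `np.array_split` produces near-equal sections, so more than just the last batch can fall short of `batch_size`, though none exceeds it:

import math
import numpy as np

idx = np.arange(10)                    # 10 samples, batch_size = 4
steps = math.ceil(10 / 4)              # 3 eval steps instead of 10 // 4 = 2
print(np.array_split(idx, steps))
# [array([0, 1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])] -> sizes 4, 3, 3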
@@ -621,7 +625,8 @@ def group_texts(examples):
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
-eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
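Keeping `per_device_eval_batch_size` as its own variable matters because it feeds `min_device_batch` below: padding the last batch up to a full per-device batch keeps the input shape of `p_eval_step` constant, so XLA compiles the eval step once instead of re-tracing it for an odd-sized final batch. A worked example with hypothetical sizes:

import math

devices, per_device = 8, 4                 # hypothetical: 8 local devices, 4 examples each
eval_batch_size = devices * per_device     # 32
leftover = 10                              # a final incomplete batch
rows = math.ceil(leftover / devices)       # 2 rows per device after padding to a multiple of 8
rows = max(rows, per_device)               # min_device_batch lifts it back to 4
print((devices, rows))                     # (8, 4) -- the same shape as every full batch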

@@ -764,13 +769,14 @@ def eval_step(params, batch):
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
# ======================== Evaluating ==============================
eval_metrics = []
-eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
-eval_steps = len(eval_dataset) // eval_batch_size
+eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
+eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
# Model forward
batch = next(eval_loader)
-batch = shard(batch)
-metrics = p_eval_step(state.params, batch)
+metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+state.params, batch, min_device_batch=per_device_eval_batch_size
+)
eval_metrics.append(metrics)

# normalize eval metrics
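Roughly, `pad_shard_unpad(fn, static_return=True)` pads each batch leaf up to a device-divisible size (at least `min_device_batch` rows per device), shards it, and calls the pmapped `fn`; `static_return=True` returns the output as-is, which fits here because `eval_step` already averages metrics across devices with `jax.lax.pmean`. A simplified sketch of the input-side handling (not the actual flax implementation, which also handles static arguments and un-padding of batched outputs):

import jax
import numpy as np
from flax.training.common_utils import shard

def pad_and_shard(batch, min_device_batch):
    # Pad every array in `batch` with zero rows, then shard across local devices.
    d = jax.local_device_count()
    n = len(next(iter(batch.values())))
    per_device = max(-(-n // d), min_device_batch)  # ceil(n / d), floored at min_device_batch
    pad = d * per_device - n
    padded = {k: np.concatenate([v, np.zeros((pad, *v.shape[1:]), v.dtype)]) for k, v in batch.items()}
    return shard(padded)  # each leaf becomes (d, per_device, ...)

One subtlety: the padding rows flow through `eval_step` too, so a per-batch mean can include them; over a long eval run this is a small approximation.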
@@ -806,12 +812,14 @@ def eval_step(params, batch):
# Eval after training
if training_args.do_eval:
eval_metrics = []
-eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
-eval_steps = len(eval_dataset) // eval_batch_size
+eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
+eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
# Model forward
-batch = shard(next(eval_loader))
-metrics = p_eval_step(state.params, batch)
+batch = next(eval_loader)
+metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+state.params, batch, min_device_batch=per_device_eval_batch_size
+)
eval_metrics.append(metrics)

# normalize eval metrics
40 changes: 25 additions & 15 deletions examples/flax/language-modeling/run_mlm_flax.py
@@ -43,6 +43,7 @@
import jax.numpy as jnp
import optax
from flax import jax_utils, traverse_util
+from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
@@ -326,15 +327,20 @@ def mask_tokens(
return inputs, labels


-def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
+"""Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
+the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
num_samples = len(samples_idx)
-samples_to_remove = num_samples % batch_size
-
-if samples_to_remove != 0:
-samples_idx = samples_idx[:-samples_to_remove]
-sections_split = num_samples // batch_size
-batch_idx = np.split(samples_idx, sections_split)
-return batch_idx
+if drop_last:
+samples_to_remove = num_samples % batch_size
+if samples_to_remove != 0:
+samples_idx = samples_idx[:-samples_to_remove]
+sections_split = num_samples // batch_size
+samples_idx = samples_idx.reshape((sections_split, batch_size))
+else:
+sections_split = math.ceil(num_samples / batch_size)
+samples_idx = np.array_split(samples_idx, sections_split)
+return samples_idx
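A quick check of both branches (illustrative only), with 10 sample indices and `batch_size=4`. The `reshape` in the `drop_last=True` branch is equivalent to the old `np.split`, just returning one 2-D index array instead of a list:

import math
import numpy as np

idx = np.arange(10)
print(idx[:8].reshape((2, 4)))                  # drop_last=True: [[0 1 2 3] [4 5 6 7]]; 8 and 9 dropped
print(np.array_split(idx, math.ceil(10 / 4)))   # drop_last=False: sections of sizes 4, 3, 3; nothing dropped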


def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -632,12 +638,14 @@ def group_texts(examples):
config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
+use_auth_token=True if model_args.use_auth_token else None,
)

# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
-eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+eval_batch_size = per_device_eval_batch_size * jax.device_count()

num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
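Training, by contrast, keeps the old behavior: `generate_batch_splits` still defaults to `drop_last=True` there, since the pmapped train step expects a fixed batch shape, so `num_train_steps` stays a floor division while eval steps become a ceiling. With hypothetical sizes:

import math

samples, batch, epochs = 1000, 64, 3    # hypothetical sizes
print(samples // batch * epochs)        # 45 train steps; a 40-sample remainder is dropped each epoch
print(math.ceil(samples / batch))       # 16 eval steps; the short final batch is padded instead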

@@ -796,16 +804,17 @@ def eval_step(params, batch):
num_eval_samples = len(tokenized_datasets["validation"])
# Avoid using jax.numpy here in case of TPU training
eval_samples_idx = np.arange(num_eval_samples)
-eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)

eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples, pad_to_multiple_of=16)

# Model forward
-model_inputs = shard(model_inputs.data)
-metrics = p_eval_step(state.params, model_inputs)
+metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+)
eval_metrics.append(metrics)

# normalize eval metrics
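Two independent paddings meet in this loop: the collator's `pad_to_multiple_of=16` pads the sequence axis, while `pad_shard_unpad` pads the batch axis. Illustrative shapes with hypothetical numbers:

import numpy as np

tokens = np.ones((10, 107), dtype=np.int32)         # 10 leftover examples, 107 tokens each
tokens = np.pad(tokens, ((0, 0), (0, -107 % 16)))   # collator: sequence axis -> (10, 112)
# pad_shard_unpad then pads the batch axis and shards:
# (10, 112) -> (devices, per_device_eval_batch_size, 112) before p_eval_step runs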
@@ -835,16 +844,17 @@ def eval_step(params, batch):
num_eval_samples = len(tokenized_datasets["validation"])
# Avoid using jax.numpy here in case of TPU training
eval_samples_idx = np.arange(num_eval_samples)
-eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)

eval_metrics = []
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples, pad_to_multiple_of=16)

# Model forward
-model_inputs = shard(model_inputs.data)
-metrics = p_eval_step(state.params, model_inputs)
+metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+)
eval_metrics.append(metrics)

# normalize eval metrics
43 changes: 28 additions & 15 deletions examples/flax/language-modeling/run_t5_mlm_flax.py
@@ -21,6 +21,7 @@
"""
import json
import logging
+import math
import os
import sys
import time
@@ -41,6 +42,7 @@
import jax.numpy as jnp
import optax
from flax import jax_utils, traverse_util
+from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
@@ -326,6 +328,7 @@ class FlaxDataCollatorForT5MLM:
decoder_start_token_id: int

def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:

# convert list to dict and tensorize input
batch = BatchEncoding(
{k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
@@ -394,6 +397,7 @@ def filter_input_ids(self, input_ids, sentinel_ids):
return input_ids

def random_spans_noise_mask(self, length):

"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

Noise mask consisting of random spans of noise tokens.
@@ -457,15 +461,20 @@ def _random_segmentation(num_items, num_segments):
return is_noise[:orig_length]


-def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
+"""Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
+the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
num_samples = len(samples_idx)
-samples_to_remove = num_samples % batch_size
-
-if samples_to_remove != 0:
-samples_idx = samples_idx[:-samples_to_remove]
-sections_split = num_samples // batch_size
-batch_idx = np.split(samples_idx, sections_split)
-return batch_idx
+if drop_last:
+samples_to_remove = num_samples % batch_size
+if samples_to_remove != 0:
+samples_idx = samples_idx[:-samples_to_remove]
+sections_split = num_samples // batch_size
+samples_idx = samples_idx.reshape((sections_split, batch_size))
+else:
+sections_split = math.ceil(num_samples / batch_size)
+samples_idx = np.array_split(samples_idx, sections_split)
+return samples_idx


def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -737,6 +746,7 @@ def group_texts(examples):
config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
+use_auth_token=True if model_args.use_auth_token else None,
)

# Data collator
@@ -754,7 +764,8 @@ def group_texts(examples):
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
-eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
+eval_batch_size = per_device_eval_batch_size * jax.device_count()

num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs

@@ -915,16 +926,17 @@ def eval_step(params, batch):
num_eval_samples = len(tokenized_datasets["validation"])
# Avoid using jax.numpy here in case of TPU training
eval_samples_idx = np.arange(num_eval_samples)
-eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)

eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples)

# Model forward
-model_inputs = shard(model_inputs.data)
-metrics = p_eval_step(state.params, model_inputs)
+metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+)
eval_metrics.append(metrics)

# get eval metrics
@@ -952,16 +964,17 @@ def eval_step(params, batch):
num_eval_samples = len(tokenized_datasets["validation"])
# Avoid using jax.numpy here in case of TPU training
eval_samples_idx = np.arange(num_eval_samples)
-eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)

eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples)

# Model forward
-model_inputs = shard(model_inputs.data)
-metrics = p_eval_step(state.params, model_inputs)
+metrics = pad_shard_unpad(p_eval_step, static_return=True)(
+state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
+)
eval_metrics.append(metrics)

# get eval metrics