diff --git a/run_flax_speech_recognition_ctc.py b/run_flax_speech_recognition_ctc.py
index c53525c..f42e845 100644
--- a/run_flax_speech_recognition_ctc.py
+++ b/run_flax_speech_recognition_ctc.py
@@ -19,6 +19,7 @@
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
 import logging
+import math
 import os
 import re
 import sys
@@ -39,7 +40,7 @@
 import transformers
 import wandb as wandb
 from flax import core, jax_utils, struct, traverse_util
-from flax.jax_utils import unreplicate
+from flax.jax_utils import unreplicate, pad_shard_unpad
 from flax.training.common_utils import get_metrics, shard, shard_prng_key
 from huggingface_hub import Repository
 from models.configuration_wav2vec2 import Wav2Vec2Config
@@ -496,35 +497,20 @@ def get_grouped_indices(
     return megabatches
 
 
-def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
     """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
-    the batch size, the last incomplete batch is dropped."""
+    the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
     num_samples = len(samples_idx)
-    samples_to_remove = num_samples % batch_size
-
-    if samples_to_remove != 0:
-        samples_idx = samples_idx[:-samples_to_remove]
-    sections_split = num_samples // batch_size
-    return samples_idx.reshape((sections_split, batch_size))
-
-
-def data_loader(dataset, batch_size, rng, sampler, collator):
-    samples_idx = sampler(dataset, batch_size, rng)
-
-    num_samples = len(samples_idx)
-    samples_to_remove = num_samples % batch_size
-
-    if samples_to_remove != 0:
-        samples_idx = samples_idx[:-samples_to_remove]
-    sections_split = num_samples // batch_size
-
-    batch_idx = np.split(samples_idx, sections_split)
-
-    for idx in batch_idx:
-        samples = dataset[idx]
-        batch = collator(samples)
-        batch = shard(batch.data)
-        yield batch
+    if drop_last:
+        samples_to_remove = num_samples % batch_size
+        if samples_to_remove != 0:
+            samples_idx = samples_idx[:-samples_to_remove]
+        sections_split = num_samples // batch_size
+        samples_idx = samples_idx.reshape((sections_split, batch_size))
+    else:
+        sections_split = math.ceil(num_samples / batch_size)
+        samples_idx = np.array_split(samples_idx, sections_split)
+    return samples_idx
 
 
 def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -1114,6 +1100,7 @@ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
     gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
     batch_size_per_update = train_batch_size * gradient_accumulation_steps
+    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
     to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
@@ -1297,19 +1284,18 @@ def run_evaluation(step):
 
         # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
         eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
 
         for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
             samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
             batch = data_collator(samples)
-            batch = shard(batch.data)
             labels = batch["labels"]
 
-            metrics, pred_ids = p_eval_step(state.params, batch)
+            metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
             eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
             eval_metrics.append(metrics)
 
-            eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+            eval_labels.extend(labels)
 
         # normalize eval metrics
         eval_metrics = get_metrics(eval_metrics)
@@ -1424,19 +1410,18 @@ def save_checkpoint(step):
 
         # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
         eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
 
         for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
             samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
             batch = data_collator(samples)
-            batch = shard(batch.data)
             labels = batch["labels"]
 
-            metrics, pred_ids = p_eval_step(state.params, batch)
+            metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
             eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
             eval_metrics.append(metrics)
 
-            eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+            eval_labels.extend(labels)
 
         # normalize eval metrics
         eval_metrics = get_metrics(eval_metrics)
diff --git a/run_flax_speech_recognition_ctc_ngram.py b/run_flax_speech_recognition_ctc_ngram.py
index 6f757de..e0642ea 100644
--- a/run_flax_speech_recognition_ctc_ngram.py
+++ b/run_flax_speech_recognition_ctc_ngram.py
@@ -19,6 +19,7 @@
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
 import logging
+import math
 import os
 import re
 import sys
@@ -39,7 +40,7 @@
 import transformers
 import wandb as wandb
 from flax import core, jax_utils, struct, traverse_util
-from flax.jax_utils import unreplicate
+from flax.jax_utils import unreplicate, pad_shard_unpad
 from flax.training.common_utils import get_metrics, shard, shard_prng_key
 from huggingface_hub import Repository
 from models.configuration_wav2vec2 import Wav2Vec2Config
@@ -500,16 +501,20 @@ def get_grouped_indices(
     return megabatches
 
 
-def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
     """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
-    the batch size, the last incomplete batch is dropped."""
+    the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
     num_samples = len(samples_idx)
-    samples_to_remove = num_samples % batch_size
-
-    if samples_to_remove != 0:
-        samples_idx = samples_idx[:-samples_to_remove]
-    sections_split = num_samples // batch_size
-    return samples_idx.reshape((sections_split, batch_size))
+    if drop_last:
+        samples_to_remove = num_samples % batch_size
+        if samples_to_remove != 0:
+            samples_idx = samples_idx[:-samples_to_remove]
+        sections_split = num_samples // batch_size
+        samples_idx = samples_idx.reshape((sections_split, batch_size))
+    else:
+        sections_split = math.ceil(num_samples / batch_size)
+        samples_idx = np.array_split(samples_idx, sections_split)
+    return samples_idx
 
 
 def data_loader(dataset, batch_size, rng, sampler, collator):
@@ -1132,6 +1137,7 @@ def compute_metrics(logits: List[np.ndarray], label_ids: List[List[int]]):
     gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
     batch_size_per_update = train_batch_size * gradient_accumulation_steps
+    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
     to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
@@ -1313,19 +1319,18 @@ def run_evaluation(step):
 
         # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
         eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
 
         for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
             samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
             batch = data_collator(samples)
-            batch = shard(batch.data)
             labels = batch["labels"]
 
-            metrics, logits = p_eval_step(state.params, batch)
+            metrics, logits = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
             eval_preds.extend(jax.device_get(logits.reshape(-1, *logits.shape[-2:])))
             eval_metrics.append(metrics)
 
-            eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+            eval_labels.extend(labels)
 
         # normalize eval metrics
         eval_metrics = get_metrics(eval_metrics)
@@ -1437,19 +1442,18 @@ def save_checkpoint(step):
 
         # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
         eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
 
         for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
             samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
             batch = data_collator(samples)
-            batch = shard(batch.data)
             labels = batch["labels"]
 
-            metrics, logits = p_eval_step(state.params, batch)
+            metrics, logits = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
             eval_preds.extend(jax.device_get(logits.reshape(-1, *logits.shape[-2:])))
             eval_metrics.append(metrics)
 
-            eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+            eval_labels.extend(labels)
 
         # normalize eval metrics
         eval_metrics = get_metrics(eval_metrics)
diff --git a/run_flax_speech_recognition_seq2seq.py b/run_flax_speech_recognition_seq2seq.py
index c7ca156..395d0af 100644
--- a/run_flax_speech_recognition_seq2seq.py
+++ b/run_flax_speech_recognition_seq2seq.py
@@ -19,6 +19,7 @@
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
 import logging
+import math
 import os
 import re
 import sys
@@ -39,7 +40,7 @@
 import transformers
 import wandb as wandb
 from flax import core, jax_utils, struct, traverse_util
-from flax.jax_utils import unreplicate
+from flax.jax_utils import pad_shard_unpad, unreplicate
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from huggingface_hub import Repository
 from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
@@ -567,6 +568,7 @@ def get_grouped_indices(
     # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
     num_samples = len(lengths)
     indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
+    indices = np.asarray(indices)
 
     megabatch_size = mega_batch_mult * batch_size
     megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
@@ -585,35 +587,20 @@
     return megabatches
 
 
-def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
+def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
     """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
-    the batch size, the last incomplete batch is dropped."""
+    the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
     num_samples = len(samples_idx)
-    samples_to_remove = num_samples % batch_size
-
-    if samples_to_remove != 0:
-        samples_idx = samples_idx[:-samples_to_remove]
-    sections_split = num_samples // batch_size
-    return samples_idx.reshape((sections_split, batch_size))
-
-
-def data_loader(dataset, batch_size, rng, sampler, collator):
-    samples_idx = sampler(dataset, batch_size, rng)
-
-    num_samples = len(samples_idx)
-    samples_to_remove = num_samples % batch_size
-
-    if samples_to_remove != 0:
-        samples_idx = samples_idx[:-samples_to_remove]
-    sections_split = num_samples // batch_size
-
-    batch_idx = np.split(samples_idx, sections_split)
-
-    for idx in batch_idx:
-        samples = dataset[idx]
-        batch = collator(samples)
-        batch = shard(batch.data)
-        yield batch
+    if drop_last:
+        samples_to_remove = num_samples % batch_size
+        if samples_to_remove != 0:
+            samples_idx = samples_idx[:-samples_to_remove]
+        sections_split = num_samples // batch_size
+        samples_idx = samples_idx.reshape((sections_split, batch_size))
+    else:
+        sections_split = math.ceil(num_samples / batch_size)
+        samples_idx = np.array_split(samples_idx, sections_split)
+    return samples_idx
 
 
 def write_train_metric(summary_writer, train_metrics, train_time, step):
@@ -1124,6 +1111,7 @@ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
     gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
     batch_size_per_update = train_batch_size * gradient_accumulation_steps
+    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
     to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
@@ -1345,29 +1333,28 @@ def run_evaluation(step, final_step=False):
 
         # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
         eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
 
         for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
             samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
             batch = data_collator(samples)
             eval_ids.extend(batch.pop("input_ids"))
-            batch = shard(batch.data)
             labels = batch["labels"]
 
-            metrics = p_eval_step(state.params, batch)
+            metrics = pad_shard_unpad(p_eval_step, static_return=True)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
             eval_metrics.append(metrics)
 
             # generation
             if training_args.predict_with_generate:
                 if not final_step:
-                    generated_ids = p_generate_step(state.params, batch)
+                    generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch.data)
                     eval_preds.extend(
                         jax.device_get(
                             generated_ids.reshape(-1, gen_kwargs["num_beams"], gen_kwargs["max_length"])
                         )
                     )
                 else:
-                    generated_ids = p_final_generate_step(state.params, batch)
+                    generated_ids = pad_shard_unpad(p_final_generate_step)(state.params, batch.data)
                     eval_preds.extend(
                         jax.device_get(
                             generated_ids.reshape(
@@ -1375,7 +1362,7 @@ def run_evaluation(step, final_step=False):
                             )
                         )
                     )
-            eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+            eval_labels.extend(labels)
 
         # normalize eval metrics
         eval_metrics = get_metrics(eval_metrics)
@@ -1444,7 +1431,7 @@ def save_checkpoint(step):
 
     # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
     train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
-    train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
+    train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update, drop_last=True)
 
     # Gather the indices for creating the batch and do a training step
     for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
@@ -1504,27 +1491,26 @@ def save_checkpoint(step):
 
         # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
         pred_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
-        pred_batch_idx = generate_batch_splits(pred_samples_idx, eval_batch_size)
+        pred_batch_idx = generate_batch_splits(pred_samples_idx, eval_batch_size, drop_last=False)
 
         for i, batch_idx in enumerate(tqdm(pred_batch_idx, desc=f"Predicting {split}...", position=2)):
             samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
             batch = data_collator(samples)
             pred_ids.extend(batch.pop("input_ids"))
-            batch = shard(batch.data)
             labels = batch["labels"]
 
-            metrics = p_eval_step(state.params, batch)
+            metrics = pad_shard_unpad(p_eval_step, static_return=True)(state.params, batch.data)
             pred_metrics.append(metrics)
 
             # generation
             if training_args.predict_with_generate:
-                generated_ids = p_final_generate_step(state.params, batch)
+                generated_ids = pad_shard_unpad(p_final_generate_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
                 pred_generations.extend(
                     jax.device_get(
                         generated_ids.reshape(-1, final_gen_kwargs["num_beams"], final_gen_kwargs["max_length"])
final_gen_kwargs["max_length"]) ) ) - pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1]))) + pred_labels.extend(labels) # normalize eval metrics pred_metrics = get_metrics(pred_metrics)