diff --git a/examples/tensorflow/language-modeling-tpu/prepare_tfrecord_shards.py b/examples/tensorflow/language-modeling-tpu/prepare_tfrecord_shards.py
index 93ab29b74201..a8bb7d37929f 100644
--- a/examples/tensorflow/language-modeling-tpu/prepare_tfrecord_shards.py
+++ b/examples/tensorflow/language-modeling-tpu/prepare_tfrecord_shards.py
@@ -33,6 +33,15 @@ def parse_args():
     parser = argparse.ArgumentParser(
         description="Prepare TFRecord shards from pre-tokenized samples of the wikitext dataset."
     )
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default="wikitext",
+        help="Name of the training dataset. Explore datasets at: hf.co/datasets.",
+    )
+    parser.add_argument(
+        "--dataset_config", type=str, default="wikitext-103-raw-v1", help="Configuration name of the dataset."
+    )
     parser.add_argument(
         "--tokenizer_name_or_path",
         type=str,
@@ -96,11 +105,11 @@ def get_serialized_examples(tokenized_data):
 
 
 def main(args):
-    wikitext = datasets.load_dataset("wikitext", "wikitext-103-raw-v1", split=args.split)
+    dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split=args.split)
 
     if args.limit is not None:
-        max_samples = min(len(wikitext), args.limit)
-        wikitext = wikitext.select(range(max_samples))
+        max_samples = min(len(dataset), args.limit)
+        dataset = dataset.select(range(max_samples))
         print(f"Limiting the dataset to {args.limit} entries.")
 
     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
@@ -119,7 +128,7 @@ def main(args):
 
     # Tokenize the whole dataset at once.
     tokenize_fn = tokenize_function(tokenizer)
-    wikitext_tokenized = wikitext.map(tokenize_fn, batched=True, num_proc=4, remove_columns=["text"])
+    dataset_tokenized = dataset.map(tokenize_fn, batched=True, num_proc=4, remove_columns=["text"])
 
     # We need to concatenate all our texts together, and then split the result
     # into chunks of a fixed size, which we will call block_size. To do this, we
@@ -144,14 +153,14 @@ def group_texts(examples):
         }
         return result
 
-    grouped_dataset = wikitext_tokenized.map(group_texts, batched=True, batch_size=1000, num_proc=4)
+    grouped_dataset = dataset_tokenized.map(group_texts, batched=True, batch_size=1000, num_proc=4)
 
     shard_count = 0
     total_records = 0
     for shard in range(0, len(grouped_dataset), args.shard_size):
         dataset_snapshot = grouped_dataset[shard : shard + args.shard_size]
         records_containing = len(dataset_snapshot["input_ids"])
-        filename = os.path.join(split_dir, f"wikitext-{shard_count}-{records_containing}.tfrecord")
+        filename = os.path.join(split_dir, f"dataset-{shard_count}-{records_containing}.tfrecord")
         serialized_examples = get_serialized_examples(dataset_snapshot)
 
         with tf.io.TFRecordWriter(filename) as out_file:
diff --git a/examples/tensorflow/language-modeling-tpu/train_unigram.py b/examples/tensorflow/language-modeling-tpu/train_unigram.py
index 65cd2c757728..ea8246a99f3b 100644
--- a/examples/tensorflow/language-modeling-tpu/train_unigram.py
+++ b/examples/tensorflow/language-modeling-tpu/train_unigram.py
@@ -69,16 +69,16 @@ def parse_args():
 
 
 def main(args):
-    wikitext = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")
+    dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")
 
     if args.limit is not None:
-        max_train_samples = min(len(wikitext), args.limit)
-        wikitext = wikitext.select(range(max_train_samples))
+        max_train_samples = min(len(dataset), args.limit)
+        dataset = dataset.select(range(max_train_samples))
         logger.info(f"Limiting the dataset to {args.limit} entries.")
 
     def batch_iterator():
-        for i in range(0, len(wikitext), args.batch_size):
-            yield wikitext[i : i + args.batch_size]["text"]
+        for i in range(0, len(dataset), args.batch_size):
+            yield dataset[i : i + args.batch_size]["text"]
 
     # Prepare the tokenizer.
     tokenizer = Tokenizer(Unigram())
@@ -111,7 +111,7 @@ def batch_iterator():
     if args.export_to_hub:
         logger.info("Exporting the trained tokenzier to Hub.")
         new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
-        new_tokenizer.push_to_hub("unigram-tokenizer-wikitext")
+        new_tokenizer.push_to_hub("unigram-tokenizer-dataset")
 
 
 if __name__ == "__main__":
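For context, a minimal runnable sketch (not part of the patch) of the loading pattern these new flags generalize; the "wikitext-2-raw-v1" config and the limit of 1000 are illustrative values only, not defaults from the scripts:

import datasets

# Any Hub dataset that exposes a "text" column can now be selected via
# --dataset_name / --dataset_config instead of the hard-coded wikitext call.
dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
dataset = dataset.select(range(min(len(dataset), 1000)))  # mirrors the --limit handling
print(dataset[0]["text"])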