[Examples/TensorFlow] minor refactoring to allow compatible datasets to work (#22879)

minor refactoring to allow compatible datasets to work.
sayakpaul authored Apr 20, 2023
1 parent 10dd3a7 commit 4116d1e
Showing 2 changed files with 21 additions and 12 deletions.
21 changes: 15 additions & 6 deletions examples/tensorflow/language-modeling-tpu/prepare_tfrecord_shards.py
@@ -33,6 +33,15 @@ def parse_args():
     parser = argparse.ArgumentParser(
         description="Prepare TFRecord shards from pre-tokenized samples of the wikitext dataset."
     )
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default="wikitext",
+        help="Name of the training. Explore datasets at: hf.co/datasets.",
+    )
+    parser.add_argument(
+        "--dataset_config", type=str, default="wikitext-103-raw-v1", help="Configuration name of the dataset."
+    )
     parser.add_argument(
         "--tokenizer_name_or_path",
         type=str,

@@ -96,11 +105,11 @@ def get_serialized_examples(tokenized_data):


 def main(args):
-    wikitext = datasets.load_dataset("wikitext", "wikitext-103-raw-v1", split=args.split)
+    dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split=args.split)

     if args.limit is not None:
-        max_samples = min(len(wikitext), args.limit)
-        wikitext = wikitext.select(range(max_samples))
+        max_samples = min(len(dataset), args.limit)
+        dataset = dataset.select(range(max_samples))
         print(f"Limiting the dataset to {args.limit} entries.")

     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)

@@ -119,7 +128,7 @@ def main(args):

     # Tokenize the whole dataset at once.
     tokenize_fn = tokenize_function(tokenizer)
-    wikitext_tokenized = wikitext.map(tokenize_fn, batched=True, num_proc=4, remove_columns=["text"])
+    dataset_tokenized = dataset.map(tokenize_fn, batched=True, num_proc=4, remove_columns=["text"])

     # We need to concatenate all our texts together, and then split the result
     # into chunks of a fixed size, which we will call block_size. To do this, we

@@ -144,14 +153,14 @@ def group_texts(examples):
         }
         return result

-    grouped_dataset = wikitext_tokenized.map(group_texts, batched=True, batch_size=1000, num_proc=4)
+    grouped_dataset = dataset_tokenized.map(group_texts, batched=True, batch_size=1000, num_proc=4)

     shard_count = 0
     total_records = 0
     for shard in range(0, len(grouped_dataset), args.shard_size):
         dataset_snapshot = grouped_dataset[shard : shard + args.shard_size]
         records_containing = len(dataset_snapshot["input_ids"])
-        filename = os.path.join(split_dir, f"wikitext-{shard_count}-{records_containing}.tfrecord")
+        filename = os.path.join(split_dir, f"dataset-{shard_count}-{records_containing}.tfrecord")
         serialized_examples = get_serialized_examples(dataset_snapshot)

         with tf.io.TFRecordWriter(filename) as out_file:
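
The hunks above replace the script's hardcoded wikitext loading with the new --dataset_name and --dataset_config arguments, so any Hugging Face dataset that exposes a "text" column should be compatible: the tokenization map drops exactly that column. A minimal sketch of the loading-and-tokenizing pattern the script now follows is given below; it is not part of the commit, and the dataset config and tokenizer checkpoint are illustrative assumptions.

import datasets
from transformers import AutoTokenizer

# Illustrative choices, not taken from the commit: any dataset with a "text" column should work.
dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
tokenizer = AutoTokenizer.from_pretrained("gpt2")


def tokenize_fn(examples):
    # Batched tokenization over the "text" column, mirroring the script's map() call.
    return tokenizer(examples["text"])


# remove_columns=["text"] is why a "text" column is the compatibility requirement.
tokenized = dataset.map(tokenize_fn, batched=True, num_proc=4, remove_columns=["text"])
print(tokenized)

The actual script additionally groups the tokenized examples into fixed-size blocks and serializes each shard with tf.io.TFRecordWriter, as the last hunk shows.
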
12 changes: 6 additions & 6 deletions examples/tensorflow/language-modeling-tpu/train_unigram.py
@@ -69,16 +69,16 @@ def parse_args():


 def main(args):
-    wikitext = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")
+    dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")

     if args.limit is not None:
-        max_train_samples = min(len(wikitext), args.limit)
-        wikitext = wikitext.select(range(max_train_samples))
+        max_train_samples = min(len(dataset), args.limit)
+        dataset = dataset.select(range(max_train_samples))
         logger.info(f"Limiting the dataset to {args.limit} entries.")

     def batch_iterator():
-        for i in range(0, len(wikitext), args.batch_size):
-            yield wikitext[i : i + args.batch_size]["text"]
+        for i in range(0, len(dataset), args.batch_size):
+            yield dataset[i : i + args.batch_size]["text"]

     # Prepare the tokenizer.
     tokenizer = Tokenizer(Unigram())

@@ -111,7 +111,7 @@ def batch_iterator():
     if args.export_to_hub:
         logger.info("Exporting the trained tokenzier to Hub.")
         new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
-        new_tokenizer.push_to_hub("unigram-tokenizer-wikitext")
+        new_tokenizer.push_to_hub("unigram-tokenizer-dataset")


 if __name__ == "__main__":
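
In train_unigram.py, the same rename (wikitext to dataset) means batch_iterator now streams text from whichever dataset was requested on the command line. A minimal sketch of how such an iterator typically feeds Unigram tokenizer training with the tokenizers library follows; the dataset config, batch size, vocabulary size, and special tokens below are assumptions rather than values taken from the script.

import datasets
from tokenizers import Tokenizer, trainers
from tokenizers.models import Unigram

# Assumed dataset/config and batch size; any dataset with a "text" column works here as well.
dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
batch_size = 1000


def batch_iterator():
    # Yield the "text" column in batches, as in the refactored script.
    for i in range(0, len(dataset), batch_size):
        yield dataset[i : i + batch_size]["text"]


tokenizer = Tokenizer(Unigram())
# Vocabulary size and special tokens are illustrative.
trainer = trainers.UnigramTrainer(vocab_size=8000, special_tokens=["<unk>"], unk_token="<unk>")
tokenizer.train_from_iterator(batch_iterator(), trainer=trainer, length=len(dataset))
print(tokenizer.get_vocab_size())

The script itself then wraps the trained tokenizer in AlbertTokenizerFast and, when --export_to_hub is passed, pushes it to the Hub under the renamed repository id shown in the final hunk.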
