[Examples/TensorFlow] minor refactoring to allow compatible datasets to work #22879

Merged
1 commit, merged on Apr 20, 2023
@@ -33,6 +33,15 @@ def parse_args():
     parser = argparse.ArgumentParser(
         description="Prepare TFRecord shards from pre-tokenized samples of the wikitext dataset."
     )
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default="wikitext",
+        help="Name of the training dataset. Explore datasets at: hf.co/datasets.",
+    )
+    parser.add_argument(
+        "--dataset_config", type=str, default="wikitext-103-raw-v1", help="Configuration name of the dataset."
+    )
     parser.add_argument(
         "--tokenizer_name_or_path",
         type=str,
@@ -96,11 +105,11 @@ def get_serialized_examples(tokenized_data):
 
 
 def main(args):
-    wikitext = datasets.load_dataset("wikitext", "wikitext-103-raw-v1", split=args.split)
+    dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split=args.split)
 
     if args.limit is not None:
-        max_samples = min(len(wikitext), args.limit)
-        wikitext = wikitext.select(range(max_samples))
+        max_samples = min(len(dataset), args.limit)
+        dataset = dataset.select(range(max_samples))
         print(f"Limiting the dataset to {args.limit} entries.")
 
     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
@@ -119,7 +128,7 @@ def main(args):
 
     # Tokenize the whole dataset at once.
     tokenize_fn = tokenize_function(tokenizer)
-    wikitext_tokenized = wikitext.map(tokenize_fn, batched=True, num_proc=4, remove_columns=["text"])
+    dataset_tokenized = dataset.map(tokenize_fn, batched=True, num_proc=4, remove_columns=["text"])
 
     # We need to concatenate all our texts together, and then split the result
     # into chunks of a fixed size, which we will call block_size. To do this, we
@@ -144,14 +153,14 @@ def group_texts(examples):
         }
         return result
 
-    grouped_dataset = wikitext_tokenized.map(group_texts, batched=True, batch_size=1000, num_proc=4)
+    grouped_dataset = dataset_tokenized.map(group_texts, batched=True, batch_size=1000, num_proc=4)
 
     shard_count = 0
     total_records = 0
     for shard in range(0, len(grouped_dataset), args.shard_size):
         dataset_snapshot = grouped_dataset[shard : shard + args.shard_size]
         records_containing = len(dataset_snapshot["input_ids"])
-        filename = os.path.join(split_dir, f"wikitext-{shard_count}-{records_containing}.tfrecord")
+        filename = os.path.join(split_dir, f"dataset-{shard_count}-{records_containing}.tfrecord")
         serialized_examples = get_serialized_examples(dataset_snapshot)
 
         with tf.io.TFRecordWriter(filename) as out_file:
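For context rather than as part of the diff: the TFRecord preparation script above now loads whatever dataset the new --dataset_name and --dataset_config flags point at, tokenizes its "text" column, and packs the tokens into fixed-size blocks. Below is a minimal sketch of that flow, assuming a dataset that exposes a "text" column; the dataset name, tokenizer, and block size are placeholders, not values taken from this PR.

```python
# Sketch of the load -> tokenize -> group-into-blocks flow the refactor enables.
# "wikitext-2-raw-v1" and "gpt2" are placeholder choices, not the script's defaults.
import datasets
from transformers import AutoTokenizer

dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
dataset = dataset.select(range(min(1000, len(dataset))))  # mirrors the --limit option

tokenizer = AutoTokenizer.from_pretrained("gpt2")


def tokenize_fn(examples):
    return tokenizer(examples["text"])


# remove_columns=["text"] is why a "compatible" dataset must expose a text column.
tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])


def group_texts(examples, block_size=128):
    # Concatenate every field, then slice it into fixed-size blocks (dropping the
    # remainder): the same packing idea as the script's group_texts helper.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    return {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }


grouped = tokenized.map(group_texts, batched=True, batch_size=1000)
print(grouped)
```

Invoked as a script, the same selection happens through --dataset_name and --dataset_config, so any Hub dataset with a text column can be sharded without editing the code.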
examples/tensorflow/language-modeling-tpu/train_unigram.py (12 changes: 6 additions & 6 deletions)
@@ -69,16 +69,16 @@ def parse_args():
 
 
 def main(args):
-    wikitext = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")
+    dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")
 
     if args.limit is not None:
-        max_train_samples = min(len(wikitext), args.limit)
-        wikitext = wikitext.select(range(max_train_samples))
+        max_train_samples = min(len(dataset), args.limit)
+        dataset = dataset.select(range(max_train_samples))
         logger.info(f"Limiting the dataset to {args.limit} entries.")
 
     def batch_iterator():
-        for i in range(0, len(wikitext), args.batch_size):
-            yield wikitext[i : i + args.batch_size]["text"]
+        for i in range(0, len(dataset), args.batch_size):
+            yield dataset[i : i + args.batch_size]["text"]
 
     # Prepare the tokenizer.
     tokenizer = Tokenizer(Unigram())
@@ -111,7 +111,7 @@ def batch_iterator():
     if args.export_to_hub:
         logger.info("Exporting the trained tokenizer to Hub.")
         new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
-        new_tokenizer.push_to_hub("unigram-tokenizer-wikitext")
+        new_tokenizer.push_to_hub("unigram-tokenizer-dataset")
 
 
 if __name__ == "__main__":
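Also for context: train_unigram.py feeds the loaded dataset into a Unigram trainer through the batch_iterator shown above. Here is a minimal sketch of that pattern, assuming an illustrative vocabulary size and a plain whitespace pre-tokenizer rather than the script's actual settings.

```python
# Sketch of streaming text batches from a datasets.Dataset into a Unigram trainer.
# Vocab size, special tokens, and the pre-tokenizer below are illustrative only.
import datasets
from tokenizers import Tokenizer, pre_tokenizers, trainers
from tokenizers.models import Unigram

dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="train")


def batch_iterator(batch_size=1000):
    # Yield lists of raw strings, batch by batch, as train_unigram.py does.
    for i in range(0, len(dataset), batch_size):
        yield dataset[i : i + batch_size]["text"]


tokenizer = Tokenizer(Unigram())
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
trainer = trainers.UnigramTrainer(vocab_size=8000, special_tokens=["[UNK]"], unk_token="[UNK]")
tokenizer.train_from_iterator(batch_iterator(), trainer=trainer, length=len(dataset))
print(tokenizer.get_vocab_size())
```

Passing length is optional; it only lets train_from_iterator show an accurate progress bar while it consumes the iterator.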