Pad each batch, not the whole dataset #30
base: master
Conversation
    return train_loader, valid_loader, train_sampler, valid_sampler


def make_data_lists(args, personachat, tokenizer):
docstring
@@ -86,36 +139,20 @@ def get_data_loaders(args, tokenizer):
         persona = dialog["personality"].copy()
         for _ in range(args.personality_permutations):
             for utterance in dialog["utterances"]:
-                history = utterance["history"][-(2*args.max_history+1):]
+                candidate_instances = defaultdict(list)
+                history = utterance["history"][-(2 * args.max_history + 1):]
could add assert len(utterance['candidates']) >= num_candidates
@@ -72,11 +73,63 @@ def build_input_from_segments(persona, history, reply, tokenizer, lm_labels=False):
     return instance, sequence  # TODO: second arg is never used, delete it


 def pad_and_tensorize(batch_dict, padding):
this and `ChatDataset` should be easy to unit test
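A unit test along the lines suggested above might look like the sketch below. The real function's signature is `pad_and_tensorize(batch_dict, padding)`; the minimal padding helper here (`pad_to_max`) is an assumed stand-in, not code from the PR, and the final `torch.tensor(...)` conversion is left out so the test needs only the standard library.

```python
# Hypothetical unit-test sketch for the batch-padding logic.
# pad_to_max is an assumed helper, not the PR's actual implementation.

def pad_to_max(sequences, padding):
    """Right-pad every sequence to the length of the longest one in the batch."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [padding] * (max_len - len(seq)) for seq in sequences]

def test_pad_to_max():
    batch = {"input_ids": [[1, 2, 3], [4, 5]], "lm_labels": [[7], [8, 9, 10]]}
    padded = {name: pad_to_max(seqs, padding=0) for name, seqs in batch.items()}
    # Every field is padded to its own batch maximum, not a global maximum.
    assert padded["input_ids"] == [[1, 2, 3], [4, 5, 0]]
    assert padded["lm_labels"] == [[7, 0, 0], [8, 9, 10]]

test_pad_to_max()
```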
    valid_dataset = ChatDataset(datasets['valid'], pad_id)

    logger.info("Build train and validation dataloaders")
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
(maybe) put this in ChatDataset.to_loader(self, args, shuffle) -> sampler, loader
at some point might also want to document which tensors are 3D
             for input_name, input_array in instance.items():
-                datasets[dataset_name][input_name].append(input_array)
+                candidate_instances[input_name].append(input_array)
+            for k in candidate_instances.keys():
`.items()` will save some chars
train.py
             for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
-                lm_labels = bool(j == num_candidates-1)
+                lm_labels = bool(j == num_candidates - 1)
                 instance, _ = build_input_from_segments(persona, history, candidate, tokenizer, lm_labels)
better varname?
Previously, each sequence was padded to the length of the longest sequence in the dataset.
In this PR, each batch is padded to the length of the longest sequence in the batch. This results in a 30% speedup with negligible impact on metrics.
Code Changes
- `ChatDataset` yields example dicts like `{'input_ids': [[hist + cand1], .. [hist + cand_n]],}` for the `PADDED_INPUTS`, and `mc_token_ids` and `mc_labels` in the same format as previously.
- `ChatDataset().collate_fn(examples: list)` turns a list of example dicts into the list of 5 tensors by batching them and padding them.
- `get_dataloaders` does much less.
- `convai_evaluation.py` still calls the old `pad_dataset`.
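The per-batch idea can be illustrated with a minimal collate sketch (names here are hypothetical, not the PR's code): each batch is padded only to its own longest sequence, so batches of short dialogs stay short instead of inheriting the dataset-wide maximum length, which is where the speedup comes from.

```python
# Illustrative sketch of per-batch padding in a collate function.
# A real DataLoader collate_fn would wrap the padded lists in torch.tensor(...);
# that step is omitted so this snippet needs only the standard library.

def collate(examples, pad_id):
    """Turn a list of example dicts into one batch dict, padded per field."""
    batch = {}
    for name in examples[0]:
        seqs = [ex[name] for ex in examples]
        max_len = max(len(s) for s in seqs)  # longest sequence in THIS batch
        batch[name] = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    return batch

examples = [{"input_ids": [5, 6]}, {"input_ids": [7, 8, 9]}]
batch = collate(examples, pad_id=0)
assert batch["input_ids"] == [[5, 6, 0], [7, 8, 9]]
```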
1 Epoch Sanity Check
Before Change: 85 minutes
Validation: {'accuracy': 0.7483655941545956,
'average_accuracy': 0.7483655941545956,
'average_nll': 2.6815188920676687,
'average_ppl': 14.607263311061963,
'nll': 2.6815188920676687}
After Change: 60 minutes
Validation: {'accuracy': 0.7466991411357519,
'average_accuracy': 0.7466991411357519,
'average_nll': 2.6821035040007972,
'average_ppl': 14.615805388160778,
'nll': 2.6821035040007972}
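As a quick consistency check on the numbers above, `average_ppl` is `exp(average_nll)` in both runs, so the two metrics agree and the before/after comparison rests on essentially identical validation loss:

```python
import math

# Reported (nll, ppl) pairs from the 1-epoch sanity check above.
before_nll, before_ppl = 2.6815188920676687, 14.607263311061963
after_nll, after_ppl = 2.6821035040007972, 14.615805388160778

# Perplexity is the exponential of the average negative log-likelihood.
assert math.isclose(math.exp(before_nll), before_ppl, rel_tol=1e-6)
assert math.isclose(math.exp(after_nll), after_ppl, rel_tol=1e-6)
```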
Command: