
Pad each batch, not the whole dataset #30

Open · wants to merge 16 commits into master
Conversation

@sshleifer (Contributor) commented Sep 23, 2019

Previously, each sequence was padded to the length of the longest sequence in the dataset.
In this PR, each batch is padded to the length of the longest sequence in the batch. This results in a 30% speedup with negligible impact on metrics.

Code Changes

  • ChatDataset yields example dicts like {'input_ids': [[hist + cand_1], ..., [hist + cand_n]], ...} for each of the PADDED_INPUTS, with mc_token_ids and mc_labels in the same format as before.
  • ChatDataset().collate_fn(examples: list) turns a list of example dicts into the list of 5 tensors by batching and padding them (see the sketch after this list).
  • As a result, get_data_loaders does much less work.
  • To support this, the data format changes in the step where we build the lists of examples.
  • convai_evaluation.py still calls the old pad_dataset.
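For reference, a minimal sketch of per-batch padding in a collate_fn. The field names follow the description above, but the free-function signature (the PR stores pad_id on the dataset instead), the label pad value, and the reshape are illustrative assumptions, not the PR's exact code:

import torch

def collate_fn(examples, pad_id):
    batch = {}
    # Sequence fields (the PADDED_INPUTS): one row per (example, candidate).
    for name in ("input_ids", "token_type_ids", "lm_labels"):
        rows = [cand for ex in examples for cand in ex[name]]
        max_len = max(len(row) for row in rows)  # longest sequence in this batch
        pad = pad_id if name != "lm_labels" else -1  # assumed label pad value
        padded = [row + [pad] * (max_len - len(row)) for row in rows]
        # Reshape to (batch, n_candidates, seq_len): these are the 3D tensors.
        batch[name] = torch.tensor(padded).view(len(examples), -1, max_len)
    # Per-candidate scalars need no padding.
    for name in ("mc_token_ids", "mc_labels"):
        batch[name] = torch.tensor([ex[name] for ex in examples])
    return batch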

1 Epoch Sanity Check

Before Change: 85 minutes
Validation: {'accuracy': 0.7483655941545956,
'average_accuracy': 0.7483655941545956,
'average_nll': 2.6815188920676687,
'average_ppl': 14.607263311061963,
'nll': 2.6815188920676687}

After Change: 60 minutes
Validation: {'accuracy': 0.7466991411357519,
'average_accuracy': 0.7466991411357519,
'average_nll': 2.6821035040007972,
'average_ppl': 14.615805388160778,
'nll': 2.6821035040007972}

Command:

python train.py --model_checkpoint openai-gpt --dataset_cache dataset_cache --fp16 O1 --n_epochs 1 --train_batch_size 4

@sshleifer changed the title from “(WIP) Pad each batch, not the whole dataset” to “Pad each batch, not the whole dataset” on Sep 29, 2019
return train_loader, valid_loader, train_sampler, valid_sampler


def make_data_lists(args, personachat, tokenizer):
Review comment (@sshleifer, author): docstring
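For example, something along these lines (a guess at the contract from the PR description; the exact return shape is an assumption):

def make_data_lists(args, personachat, tokenizer):
    """Build per-split lists of example dicts from the tokenized PersonaChat data.

    Returns something like {'train': [example, ...], 'valid': [...]}, where each
    example holds one history plus its candidate replies, ready for ChatDataset.
    """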

@@ -86,36 +139,20 @@ def get_data_loaders(args, tokenizer):
         persona = dialog["personality"].copy()
         for _ in range(args.personality_permutations):
             for utterance in dialog["utterances"]:
-                history = utterance["history"][-(2*args.max_history+1):]
+                candidate_instances = defaultdict(list)
+                history = utterance["history"][-(2 * args.max_history + 1):]
Review comment (@sshleifer, author): could add assert len(utterance['candidates']) >= num_candidates

@@ -72,11 +73,63 @@ def build_input_from_segments(persona, history, reply, tokenizer, lm_labels=Fals
     return instance, sequence  # TODO: second arg is never used, delete it


+def pad_and_tensorize(batch_dict, padding):
Review comment (@sshleifer, author): this and ChatDataset should be easy to unit test
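For instance, a test along these lines; it assumes pad_and_tensorize right-pads each list in batch_dict to the batch max and returns tensors keyed like its input, which is an assumption about its behavior:

from train import pad_and_tensorize  # or wherever the PR defines it

def test_pad_and_tensorize():
    batch_dict = {"input_ids": [[1, 2, 3], [4, 5]]}
    out = pad_and_tensorize(batch_dict, padding=0)
    assert out["input_ids"].shape == (2, 3)           # padded to the batch max
    assert out["input_ids"][1].tolist() == [4, 5, 0]  # right-padded with 0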

valid_dataset = ChatDataset(datasets['valid'], pad_id)

logger.info("Build train and validation dataloaders")
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
Review comment (@sshleifer, author): (maybe) put this in ChatDataset.to_loader(self, args, shuffle) -> sampler, loader
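A sketch of that refactor, written as a free function here rather than a method; args.valid_batch_size and reusing the dataset's collate_fn are assumptions:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def to_loader(dataset, args, shuffle):
    # Hypothetical helper: let the dataset build its own sampler and loader.
    sampler = DistributedSampler(dataset) if args.distributed else None
    batch_size = args.train_batch_size if shuffle else args.valid_batch_size  # assumed flags
    loader = DataLoader(dataset, sampler=sampler, batch_size=batch_size,
                        shuffle=shuffle and sampler is None,
                        collate_fn=dataset.collate_fn)
    return sampler, loader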

Review comment (@sshleifer, author): at some point might also want to document which tensors are 3D

     for input_name, input_array in instance.items():
-        datasets[dataset_name][input_name].append(input_array)
+        candidate_instances[input_name].append(input_array)
+    for k in candidate_instances.keys():
Review comment (@sshleifer, author): .items() will save some chars

train.py (Outdated)
     for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
-        lm_labels = bool(j == num_candidates-1)
+        lm_labels = bool(j == num_candidates - 1)
         instance, _ = build_input_from_segments(persona, history, candidate, tokenizer, lm_labels)
Review comment (@sshleifer, author): better varname?
