From 6bb7510a50af4b736df296620fa58a77fea978e2 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Wed, 7 Nov 2018 22:12:41 +0100
Subject: [PATCH 1/2] fixing pre-processing bug - averaging loss for gradient accumulation - no_grad on evaluation

---
 run_classifier.py | 21 +++++++++------------
 run_squad.py      | 45 +++++++++++++++------------------------------
 2 files changed, 24 insertions(+), 42 deletions(-)

diff --git a/run_classifier.py b/run_classifier.py
index b5290afd129221..c19c6f9ac071cc 100644
--- a/run_classifier.py
+++ b/run_classifier.py
@@ -458,7 +458,6 @@ def main():
         raise ValueError("Task not found: %s" % (task_name))
 
     processor = processors[task_name]()
-
     label_list = processor.get_labels()
 
     tokenizer = tokenization.FullTokenizer(
@@ -518,20 +517,18 @@ def main():
         for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
             tr_loss = 0
             nb_tr_examples, nb_tr_steps = 0, 0
-            for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(tqdm(train_dataloader, desc="Iteration")):
-                input_ids = input_ids.to(device)
-                input_mask = input_mask.to(device)
-                segment_ids = segment_ids.to(device)
-                label_ids = label_ids.to(device)
-
+            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
+                batch = tuple(t.to(device) for t in batch)
+                input_ids, input_mask, segment_ids, label_ids = batch
                 loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
                 if n_gpu > 1:
                     loss = loss.mean() # mean() to average on multi-gpu.
+                if args.gradient_accumulation_steps > 1:
+                    loss = loss / args.gradient_accumulation_steps
+                loss.backward()
                 tr_loss += loss.item()
                 nb_tr_examples += input_ids.size(0)
                 nb_tr_steps += 1
-                loss.backward()
-
                 if (step + 1) % args.gradient_accumulation_steps == 0:
                     optimizer.step()    # We have accumulated enought gradients
                     model.zero_grad()
@@ -579,13 +576,13 @@ def main():
             nb_eval_examples += input_ids.size(0)
             nb_eval_steps += 1
 
-        eval_loss = eval_loss / nb_eval_steps #len(eval_dataloader)
-        eval_accuracy = eval_accuracy / nb_eval_examples #len(eval_dataloader)
+        eval_loss = eval_loss / nb_eval_steps
+        eval_accuracy = eval_accuracy / nb_eval_examples
 
         result = {'eval_loss': eval_loss,
                   'eval_accuracy': eval_accuracy,
                   'global_step': global_step,
-                  'loss': tr_loss/nb_tr_steps}#'loss': loss.item()}
+                  'loss': tr_loss/nb_tr_steps}
 
         output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
         with open(output_eval_file, "w") as writer:
diff --git a/run_squad.py b/run_squad.py
index 8a69e057e58271..a25893e1d9abd3 100644
--- a/run_squad.py
+++ b/run_squad.py
@@ -743,7 +743,7 @@ def main():
                         type=int,
                         default=1,
                         help="Number of updates steps to accumualte before performing a backward/update pass.")
-    
+
     args = parser.parse_args()
 
     if args.local_rank == -1 or args.no_cuda:
@@ -857,20 +857,13 @@ def main():
         model.train()
         for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
             for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
+                batch = tuple(t.to(device) for t in batch)
                 input_ids, input_mask, segment_ids, start_positions, end_positions = batch
-                input_ids = input_ids.to(device)
-                input_mask = input_mask.to(device)
-                segment_ids = segment_ids.to(device)
-                start_positions = start_positions.to(device)
-                end_positions = start_positions.to(device)
-
-                start_positions = start_positions.view(-1, 1)
-                end_positions = end_positions.view(-1, 1)
-
                 loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
                 if n_gpu > 1:
                     loss = loss.mean() # mean() to average on multi-gpu.
-
+                if args.gradient_accumulation_steps > 1:
+                    loss = loss / args.gradient_accumulation_steps
                 loss.backward()
                 if (step + 1) % args.gradient_accumulation_steps == 0:
                     optimizer.step()    # We have accumulated enought gradients
@@ -908,30 +901,22 @@ def main():
         model.eval()
         all_results = []
         logger.info("Start evaluating")
-        for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, desc="Evaluating"):
+        for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
            if len(all_results) % 1000 == 0:
                 logger.info("Processing example: %d" % (len(all_results)))
-
             input_ids = input_ids.to(device)
             input_mask = input_mask.to(device)
             segment_ids = segment_ids.to(device)
-
-            start_logits, end_logits = model(input_ids, segment_ids, input_mask)
-
-            unique_id = [int(eval_features[e.item()].unique_id) for e in example_index]
-            start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits]
-            end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits]
-            for idx, i in enumerate(unique_id):
-                s = [float(x) for x in start_logits[idx]]
-                e = [float(x) for x in end_logits[idx]]
-                all_results.append(
-                    RawResult(
-                        unique_id=i,
-                        start_logits=s,
-                        end_logits=e
-                    )
-                )
-
+            with torch.no_grad():
+                batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
+            for i, example_index in enumerate(example_indices):
+                start_logits = batch_start_logits[i].detach().cpu().tolist()
+                end_logits = batch_end_logits[i].detach().cpu().tolist()
+                eval_feature = eval_features[example_index.item()]
+                unique_id = int(eval_feature.unique_id)
+                all_results.append(RawResult(unique_id=unique_id,
+                                             start_logits=start_logits,
+                                             end_logits=end_logits))
         output_prediction_file = os.path.join(args.output_dir, "predictions.json")
         output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
         write_predictions(eval_examples, eval_features, all_results,

From dbc318a4c605374f6663098ffa8701a626f2b23a Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Wed, 7 Nov 2018 22:22:55 +0100
Subject: [PATCH 2/2] cleaning up - speeding up a bit multi-gpu

---
 modeling.py       | 2 +-
 run_classifier.py | 7 ++++---
 run_squad.py      | 4 ++--
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/modeling.py b/modeling.py
index c467e8266efa82..860cb939a4f855 100644
--- a/modeling.py
+++ b/modeling.py
@@ -467,6 +467,6 @@ def forward(self, input_ids, token_type_ids, attention_mask, start_positions=Non
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            return total_loss, (start_logits, end_logits)
+            return total_loss
         else:
             return start_logits, end_logits
diff --git a/run_classifier.py b/run_classifier.py
index c19c6f9ac071cc..41c7459bd35a31 100644
--- a/run_classifier.py
+++ b/run_classifier.py
@@ -514,13 +514,13 @@ def main():
         train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
 
         model.train()
-        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
+        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
             tr_loss = 0
             nb_tr_examples, nb_tr_steps = 0, 0
             for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                 batch = tuple(t.to(device) for t in batch)
                 input_ids, input_mask, segment_ids, label_ids = batch
-                loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
+                loss = model(input_ids, segment_ids, input_mask, label_ids)
                 if n_gpu > 1:
                     loss = loss.mean() # mean() to average on multi-gpu.
                 if args.gradient_accumulation_steps > 1:
@@ -564,7 +564,8 @@ def main():
             segment_ids = segment_ids.to(device)
             label_ids = label_ids.to(device)
 
-            tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
+            with torch.no_grad():
+                tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
 
             logits = logits.detach().cpu().numpy()
             label_ids = label_ids.to('cpu').numpy()
diff --git a/run_squad.py b/run_squad.py
index a25893e1d9abd3..78dff7dea5b50b 100644
--- a/run_squad.py
+++ b/run_squad.py
@@ -855,11 +855,11 @@ def main():
         train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
 
         model.train()
-        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
+        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
             for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                 batch = tuple(t.to(device) for t in batch)
                 input_ids, input_mask, segment_ids, start_positions, end_positions = batch
-                loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
+                loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
                 if n_gpu > 1:
                     loss = loss.mean() # mean() to average on multi-gpu.
                 if args.gradient_accumulation_steps > 1:
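
For reference, the two training/evaluation patterns these patches introduce (dividing the loss by the number of accumulation steps before calling backward(), and wrapping the evaluation forward pass in torch.no_grad()) can be exercised in isolation with the minimal sketch below. The model, data, and hyper-parameter names are illustrative placeholders, not code from this repository.

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # Illustrative stand-ins for the BERT model, optimizer and data used in the scripts.
    gradient_accumulation_steps = 4   # mirrors args.gradient_accumulation_steps (assumed value)
    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fct = nn.CrossEntropyLoss()
    dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
    dataloader = DataLoader(dataset, batch_size=8)

    # Training: scale the per-step loss so the accumulated gradient equals the
    # gradient of the mean loss over the larger effective batch.
    model.train()
    for step, (inputs, labels) in enumerate(dataloader):
        loss = loss_fct(model(inputs), labels)
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward()                      # gradients accumulate across steps
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            model.zero_grad()

    # Evaluation: torch.no_grad() skips building the autograd graph,
    # which saves memory and speeds up the forward pass.
    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            logits = model(inputs)

Without the loss scaling, accumulating gradients over N small batches sums N full-size losses, so the effective learning rate would be N times larger than with a single big batch; the division restores the average, which is what both patched scripts now do.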