diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index ef0445e9007d24..c4af2da08c3bc1 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -113,7 +113,7 @@ def __init__(self, hparams, **kwargs): Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset ) self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams - assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1" + assert self.eval_beams >= 0, f"got self.eval_beams={self.eval_beams}. Need an integer >= 0" self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric def freeze_embeds(self): @@ -159,14 +159,14 @@ def _step(self, batch: dict) -> Tuple: loss, nll_loss = label_smoothed_nll_loss( lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id ) - return (loss,) + return (loss,), lm_logits @property def pad(self) -> int: return self.tokenizer.pad_token_id def training_step(self, batch, batch_idx) -> Dict: - loss_tensors = self._step(batch) + loss_tensors, logits = self._step(batch) logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} # tokens per batch @@ -209,18 +209,21 @@ def calc_generative_metrics(self, preds, target) -> Dict: def _generative_step(self, batch: dict) -> dict: t0 = time.time() - generated_ids = self.model.generate( - batch["input_ids"], - attention_mask=batch["attention_mask"], - use_cache=True, - decoder_start_token_id=self.decoder_start_token_id, - num_beams=self.eval_beams, - ) + loss_tensors, logits = self._step(batch) + if self.eval_beams == 0: + generated_ids = torch.argmax(logits.detach(), axis=-1) + else: + generated_ids = self.model.generate( + batch["input_ids"], + attention_mask=batch["attention_mask"], + use_cache=True, + decoder_start_token_id=self.decoder_start_token_id, + num_beams=self.eval_beams, + ) gen_time = (time.time() - t0) / batch["input_ids"].shape[0] preds: List[str] = self.ids_to_clean_text(generated_ids) target: List[str] = self.ids_to_clean_text(batch["labels"]) - loss_tensors = self._step(batch) - base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} + base_metrics = {name: loss.detach() for name, loss in zip(self.loss_names, loss_tensors)} rouge: Dict = self.calc_generative_metrics(preds, target) summ_len = np.mean(lmap(len, generated_ids)) base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge) @@ -317,7 +320,7 @@ def add_model_specific_args(parser, root_dir): parser.add_argument("--label_smoothing", type=float, default=0.0, required=False) parser.add_argument("--src_lang", type=str, default="", required=False) parser.add_argument("--tgt_lang", type=str, default="", required=False) - parser.add_argument("--eval_beams", type=int, default=None, required=False) + parser.add_argument("--eval_beams", type=int, default=None, required=False, help="# beams to use. 0 corresponds to not using beam search.") parser.add_argument( "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None] )