diff --git a/projects/cringe/cringe_loss.py b/projects/cringe/cringe_loss.py
index 0e0916a0080..dab1634724d 100644
--- a/projects/cringe/cringe_loss.py
+++ b/projects/cringe/cringe_loss.py
@@ -34,6 +34,7 @@ def __init__(
         num_pos_predictions=1,
         detach_positives_during_ct=False,
         train_ct_on_positive_examples=False,
+        train_ce_on_positive_examples=True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -41,6 +42,7 @@ def __init__(
         self.num_pos_predictions = num_pos_predictions
         self.detach_positives_during_ct = detach_positives_during_ct
         self.train_ct_on_positive_examples = train_ct_on_positive_examples
+        self.train_ce_on_positive_examples = train_ce_on_positive_examples
 
     def __call__(self, x, y, classifier_labels=None, **kwargs):
         if classifier_labels is None:
@@ -48,6 +50,9 @@ def __call__(self, x, y, classifier_labels=None, **kwargs):
 
         # turn no-class provided label (-1) into positive label (1)
         classifier_labels_ce = torch.abs(classifier_labels)
+        if not self.train_ce_on_positive_examples:
+            # only train CE on no-class labels
+            classifier_labels_ce = classifier_labels.eq(-1)
 
         if self.train_ct_on_positive_examples:
             # no-class (-1 to 0), positive (1 to 1), negative (0 to 1)
@@ -55,6 +60,7 @@ def __call__(self, x, y, classifier_labels=None, **kwargs):
         else:
             # no-class (-1 to 0), positive (1 to 0), negative (0 to 1)
             classifier_labels_ct = torch.abs(torch.abs(classifier_labels) - 1)
+        classifier_labels_ct = classifier_labels_ct.bool()
 
         ce_loss = super().__call__(x, y, **kwargs)
         # multiply with classifier labels to not train with negative feedback (0)
@@ -63,7 +69,7 @@ def __call__(self, x, y, classifier_labels=None, **kwargs):
         # compute the contrastive loss part for the negative labels
         # first, get the positives as the top predictions != target
         preds = torch.topk(x, k=self.num_pos_predictions + 1, axis=-1)
-        y_rep = y.unsqueeze(1).repeat(1, self.num_pos_predictions + 1)
+        y_rep = y.unsqueeze(-1).repeat(1, self.num_pos_predictions + 1)
         logits = preds.values - (preds.indices == y_rep) * 1e10
 
         # if the positive is not in the first k predictions, mask out
@@ -104,7 +110,7 @@ def __call__(self, x, y, classifier_labels=None, **kwargs):
 
         loss = ce_loss + self.ct_loss_weight * ct_loss
 
-        return loss, ce_loss, ct_loss
+        return loss, ce_loss, ct_loss, classifier_labels_ce, classifier_labels_ct
 
 
 class ContrastiveTransformerGeneratorAgent(TransformerGeneratorAgent):
@@ -137,12 +143,18 @@ def add_cmdline_args(
             default=False,
         )
         parser.add_argument(
-            '--train-ct-on-positive_examples',
+            '--train-ct-on-positive-examples',
             type=bool,
             help='If true, we train with the positive examples in the contrastive loss'
             ' (with the negatives being the top-k sampled from the model).',
             default=False,
         )
+        parser.add_argument(
+            '--train-ce-on-positive-examples',
+            type=bool,
+            help='If true, we train with the positive examples in the cross entropy loss.',
+            default=True,
+        )
         super().add_cmdline_args(parser, partial_opt=partial_opt)
         return agent
 
@@ -153,6 +165,9 @@ def build_criterion(self):
             detach_positives_during_ct=self.opt['ct_detach_positives'],
             ignore_index=self.NULL_IDX,
             train_ct_on_positive_examples=self.opt['train_ct_on_positive_examples'],
+            train_ce_on_positive_examples=self.opt.get(
+                'train_ce_on_positive_examples', True
+            ),
             reduction='none',
         )
 
@@ -210,7 +225,7 @@ def compute_loss(self, batch, return_output=False):
         model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
         scores, preds, *_ = model_output
         score_view = scores.reshape(-1, scores.size(-1))
-        (loss, ce_loss, ct_loss,) = self.criterion(
+        (loss, ce_loss, ct_loss, ce_mask, ct_mask) = self.criterion(
             score_view,
             batch.label_vec.view(-1),
             batch.classifier_label.repeat(1, scores.shape[1])
@@ -225,8 +240,15 @@ def loss_reshape(loss):
         ce_loss = loss_reshape(ce_loss)
         ct_loss = loss_reshape(ct_loss)
         notnull = batch.label_vec.ne(self.NULL_IDX)
-        target_tokens = notnull.long().sum(dim=-1)
-        correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
+        ce_mask = torch.logical_and(notnull, ce_mask.view(-1, batch.label_vec.size(-1)))
+        ct_mask = torch.logical_and(notnull, ct_mask.view(-1, batch.label_vec.size(-1)))
+        # number of tokens in each example for cross entropy or cringe loss.
+        metric_notnull = torch.logical_or(ce_mask, ct_mask)
+        target_tokens = metric_notnull.long().sum(dim=-1)
+        ce_target_tokens = ce_mask.long().sum(dim=-1)
+        ct_target_tokens = ct_mask.long().sum(dim=-1)
+
+        correct = ((batch.label_vec == preds) * metric_notnull).sum(dim=-1)
 
         pos_labels = (torch.abs(batch.classifier_label) == 1).view(-1)
         neg_labels = (torch.abs(batch.classifier_label) == 0).view(-1)
@@ -238,20 +260,30 @@ def loss_reshape(loss):
         self.record_local_metric(
             'ce_loss',
             [
-                metric if metric > 0.0 else None
-                for metric in AverageMetric.many(ce_loss, target_tokens)
-            ],
+                metric if ce_token_cnt > 0 else None
+                for ce_token_cnt, metric in zip(
+                    ce_target_tokens, AverageMetric.many(ce_loss, target_tokens)
+                )
+            ],  # type: ignore
         )
         self.record_local_metric(
             'ct_loss',
             [
-                metric if metric > 0.0 else None
-                for metric in AverageMetric.many(ct_loss, target_tokens)
-            ],
+                metric if ct_token_cnt > 0 else None
+                for ct_token_cnt, metric in zip(
+                    ct_target_tokens, AverageMetric.many(ct_loss, target_tokens)
+                )
+            ],  # type: ignore
         )
         # token-wise accuracy
         self.record_local_metric(
-            'token_acc', AverageMetric.many(correct, target_tokens)
+            'token_acc',
+            [
+                metric if per_target_token > 0 else None
+                for per_target_token, metric in zip(
+                    target_tokens, AverageMetric.many(correct, target_tokens)
+                )
+            ],  # type: ignore
         )
         self.record_local_metric(
             'token_acc_pos',
@@ -269,7 +301,22 @@ def loss_reshape(loss):
         )
         # perplexity
         self.record_local_metric(
-            'ppl_debug', PPLMetric.many(ce_loss + ct_loss, target_tokens)
+            'ppl_debug',
+            [
+                metric if per_target_token > 0 else None
+                for per_target_token, metric in zip(
+                    target_tokens, PPLMetric.many(ce_loss + ct_loss, target_tokens)
+                )
+            ],  # type: ignore
+        )
+        self.record_local_metric(
+            'ppl_ce',
+            [
+                metric if ce_token_cnt > 0 else None
+                for ce_token_cnt, metric in zip(
+                    ce_target_tokens, PPLMetric.many(ce_loss, ce_target_tokens)
+                )
+            ],  # type: ignore
         )
         self.record_local_metric(
             'ppl_pos',
@@ -290,6 +337,17 @@ def loss_reshape(loss):
             ],
         )
 
+        # record sample size
+        self.record_local_metric(
+            'ce_target_tokens', AverageMetric.many(ce_target_tokens)
+        )
+        self.record_local_metric(
+            'ct_target_tokens', AverageMetric.many(ct_target_tokens)
+        )
+        self.record_local_metric(
+            'total_target_tokens', AverageMetric.many(target_tokens)
+        )
+
         # actually do backwards loss
         loss = loss.sum()
         loss /= target_tokens.sum()  # average loss per token
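Note on the label semantics: classifier labels take values -1 (no label provided), 0 (negative feedback), and 1 (positive feedback), and the two masks in `__call__` decide which tokens train with cross entropy versus the contrastive (CRINGE) term. Below is a minimal standalone sketch of that mapping after this patch; the function name `build_masks` and the boolean formulation are illustrative, not part of the patch, but the mapping follows the inline comments in `__call__`.

```python
import torch

def build_masks(
    classifier_labels: torch.Tensor,
    train_ce_on_positive_examples: bool = True,
    train_ct_on_positive_examples: bool = False,
):
    # CE mask: which tokens contribute to the cross-entropy term
    if train_ce_on_positive_examples:
        # positives (1) and unlabeled (-1) train with CE; negatives (0) do not
        ce_mask = classifier_labels.ne(0)
    else:
        # only train CE on no-class labels
        ce_mask = classifier_labels.eq(-1)
    # CT mask: which tokens contribute to the contrastive term
    if train_ct_on_positive_examples:
        # no-class (-1 to 0), positive (1 to 1), negative (0 to 1)
        ct_mask = classifier_labels.ne(-1)
    else:
        # no-class (-1 to 0), positive (1 to 0), negative (0 to 1)
        ct_mask = classifier_labels.eq(0)
    return ce_mask, ct_mask

labels = torch.tensor([-1, 0, 1])
print(build_masks(labels))  # CE on {-1, 1}, CT on {0}
print(build_masks(labels, train_ce_on_positive_examples=False))  # CE on {-1} only
```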
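On the `y_rep` change: for the flattened 1-D targets used here, `y.unsqueeze(1)` and `y.unsqueeze(-1)` produce the same `[num_tokens, 1]` shape; `-1` simply states the intent (append a candidate axis) explicitly. A sketch of the surrounding positive-candidate selection, assuming `x` is `[num_tokens, vocab_size]` logits and `y` is `[num_tokens]` target ids; the helper name is illustrative, not part of the patch:

```python
import torch

def non_target_topk_logits(x, y, num_pos_predictions=1):
    # fetch one extra candidate so num_pos_predictions remain usable even
    # when the true target appears among the top predictions
    preds = torch.topk(x, k=num_pos_predictions + 1, dim=-1)
    # append a candidate axis to the targets and repeat across the k+1 slots
    y_rep = y.unsqueeze(-1).repeat(1, num_pos_predictions + 1)
    # subtract a huge constant wherever a candidate equals the target,
    # so it can never be sampled as a "positive" to contrast against
    return preds.values - (preds.indices == y_rep) * 1e10

x = torch.randn(5, 100)
y = torch.randint(100, (5,))
print(non_target_topk_logits(x, y).shape)  # torch.Size([5, 2])
```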
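The metric changes all follow one pattern: each loss component now carries its own per-example token count, and a metric is reported as `None` when that count is zero, so the example drops out of aggregation instead of contributing a spurious value (the old `metric > 0.0` test also discarded genuinely zero-valued metrics). A hedged sketch of the pattern, assuming ParlAI's `AverageMetric` from `parlai.core.metrics`; the helper name is illustrative:

```python
from typing import List, Optional

from parlai.core.metrics import AverageMetric

def metrics_or_none(values, denominators, token_counts) -> List[Optional[AverageMetric]]:
    # report a metric only for examples that contributed at least one token;
    # None entries are skipped when local metrics are aggregated
    return [
        metric if token_cnt > 0 else None
        for token_cnt, metric in zip(
            token_counts, AverageMetric.many(values, denominators)
        )
    ]
```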