This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

cringe logging (#5036)
Jing authored May 18, 2023
1 parent 14ba31b commit dea6377
Showing 1 changed file with 72 additions and 14 deletions.
86 changes: 72 additions & 14 deletions projects/cringe/cringe_loss.py
@@ -34,27 +34,33 @@ def __init__(
num_pos_predictions=1,
detach_positives_during_ct=False,
train_ct_on_positive_examples=False,
train_ce_on_positive_examples=True,
**kwargs,
):
super().__init__(**kwargs)
self.ct_loss_weight = ct_loss_weight
self.num_pos_predictions = num_pos_predictions
self.detach_positives_during_ct = detach_positives_during_ct
self.train_ct_on_positive_examples = train_ct_on_positive_examples
self.train_ce_on_positive_examples = train_ce_on_positive_examples

def __call__(self, x, y, classifier_labels=None, **kwargs):
if classifier_labels is None:
classifier_labels = -torch.ones_like(y).to(y.device)

# turn no-class provided label (-1) into positive label (1)
classifier_labels_ce = torch.abs(classifier_labels)
if not self.train_ce_on_positive_examples:
# only train CE on no-class labels
classifier_labels_ce = classifier_labels.eq(-1)

if self.train_ct_on_positive_examples:
# no-class (-1 to 0), positive (1 to 1), negative (0 to 1)
classifier_labels_ct = torch.clamp(classifier_labels + 1, max=1)
else:
# no-class (-1 to 0), positive (1 to 0), negative (0 to 1)
classifier_labels_ct = torch.abs(torch.abs(classifier_labels) - 1)
classifier_labels_ct = classifier_labels_ct.bool()

ce_loss = super().__call__(x, y, **kwargs)
# multiply with classifier labels to not train with negative feedback (0)
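For reference, a minimal standalone sketch (not part of this commit) of how the three classifier label values map onto the two masks above, with -1 meaning no label provided, 0 negative feedback, and 1 positive feedback:

import torch

classifier_labels = torch.tensor([-1, 0, 1])
# CE mask with train_ce_on_positive_examples=True: abs() keeps unlabeled and positive tokens
ce_mask = torch.abs(classifier_labels)                          # tensor([1, 0, 1])
# CT mask with train_ct_on_positive_examples=False: only negative-feedback tokens
ct_mask = torch.abs(torch.abs(classifier_labels) - 1).bool()    # tensor([False,  True, False])
# CT mask with train_ct_on_positive_examples=True: negatives and positives
ct_mask_pos = torch.clamp(classifier_labels + 1, max=1).bool()  # tensor([False,  True,  True])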
@@ -63,7 +69,7 @@ def __call__(self, x, y, classifier_labels=None, **kwargs):
# compute the contrastive loss part for the negative labels
# first, get the positives as the top predictions != target
preds = torch.topk(x, k=self.num_pos_predictions + 1, axis=-1)
y_rep = y.unsqueeze(1).repeat(1, self.num_pos_predictions + 1)
y_rep = y.unsqueeze(-1).repeat(1, self.num_pos_predictions + 1)
logits = preds.values - (preds.indices == y_rep) * 1e10

# if the positive is not in the first k predictions, mask out
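A small self-contained sketch (toy shapes, made-up vocabulary of five tokens) of the candidate selection above: the top-k scores are taken per position, and wherever a top-k index equals the gold token its score is pushed down by 1e10 so the gold token is never treated as a contrastive positive:

import torch

num_pos_predictions = 1
x = torch.randn(3, 5)                        # (num_tokens, vocab_size) logits
y = torch.tensor([2, 0, 4])                  # gold token ids per position
preds = torch.topk(x, k=num_pos_predictions + 1, dim=-1)
y_rep = y.unsqueeze(-1).repeat(1, num_pos_predictions + 1)
# mask out the gold token so it cannot be chosen as a positive candidate
logits = preds.values - (preds.indices == y_rep) * 1e10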
@@ -104,7 +110,7 @@ def __call__(self, x, y, classifier_labels=None, **kwargs):

loss = ce_loss + self.ct_loss_weight * ct_loss

return loss, ce_loss, ct_loss
return loss, ce_loss, ct_loss, classifier_labels_ce, classifier_labels_ct


class ContrastiveTransformerGeneratorAgent(TransformerGeneratorAgent):
@@ -137,12 +143,18 @@ def add_cmdline_args(
default=False,
)
parser.add_argument(
'--train-ct-on-positive_examples',
'--train-ct-on-positive-examples',
type=bool,
help='If true, we train with the positive examples in the contrastive loss'
' (with the negatives being the top-k sampled from the model).',
default=False,
)
parser.add_argument(
'--train-ce-on-positive-examples',
type=bool,
help='If true, we train with the positive examples in the cross entropy loss.',
default=True,
)
super().add_cmdline_args(parser, partial_opt=partial_opt)
return agent

@@ -153,6 +165,9 @@ def build_criterion(self):
detach_positives_during_ct=self.opt['ct_detach_positives'],
ignore_index=self.NULL_IDX,
train_ct_on_positive_examples=self.opt['train_ct_on_positive_examples'],
train_ce_on_positive_examples=self.opt.get(
'train_ce_on_positive_examples', True
),
reduction='none',
)
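
The .get() with a default keeps opt files created before this change working; a minimal illustration of the fallback behaviour (the opt contents here are invented):

opt = {'train_ct_on_positive_examples': False}   # an older opt without the new key
opt.get('train_ce_on_positive_examples', True)   # -> True, i.e. the previous behaviour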

@@ -210,7 +225,7 @@ def compute_loss(self, batch, return_output=False):
model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
scores, preds, *_ = model_output
score_view = scores.reshape(-1, scores.size(-1))
(loss, ce_loss, ct_loss,) = self.criterion(
(loss, ce_loss, ct_loss, ce_mask, ct_mask) = self.criterion(
score_view,
batch.label_vec.view(-1),
batch.classifier_label.repeat(1, scores.shape[1])
@@ -225,8 +240,15 @@ def loss_reshape(loss):
ce_loss = loss_reshape(ce_loss)
ct_loss = loss_reshape(ct_loss)
notnull = batch.label_vec.ne(self.NULL_IDX)
target_tokens = notnull.long().sum(dim=-1)
correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
ce_mask = torch.logical_and(notnull, ce_mask.view(-1, batch.label_vec.size(-1)))
ct_mask = torch.logical_and(notnull, ct_mask.view(-1, batch.label_vec.size(-1)))
# number of tokens in each example for the cross entropy or cringe loss.
metric_notnull = torch.logical_or(ce_mask, ct_mask)
target_tokens = metric_notnull.long().sum(dim=-1)
ce_target_tokens = ce_mask.long().sum(dim=-1)
ct_target_tokens = ct_mask.long().sum(dim=-1)

correct = ((batch.label_vec == preds) * metric_notnull).sum(dim=-1)

pos_labels = (torch.abs(batch.classifier_label) == 1).view(-1)
neg_labels = (torch.abs(batch.classifier_label) == 0).view(-1)
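A toy walk-through (one example, four positions, values invented) of how these masks turn into the per-example token counts used below; padding is removed first, and each loss only counts the tokens its own mask keeps:

import torch

notnull = torch.tensor([[True, True, True, False]])     # last position is padding
ce_mask = torch.tensor([[True, True, True, True]])      # CE trains on every labeled position here
ct_mask = torch.tensor([[False, False, False, False]])  # no negative feedback in this example
ce_mask = torch.logical_and(notnull, ce_mask)            # [[True, True, True, False]]
ct_mask = torch.logical_and(notnull, ct_mask)            # all False
metric_notnull = torch.logical_or(ce_mask, ct_mask)
target_tokens = metric_notnull.long().sum(dim=-1)        # tensor([3])
ce_target_tokens = ce_mask.long().sum(dim=-1)            # tensor([3])
ct_target_tokens = ct_mask.long().sum(dim=-1)            # tensor([0])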
@@ -238,20 +260,30 @@ def loss_reshape(loss):
self.record_local_metric(
'ce_loss',
[
metric if metric > 0.0 else None
for metric in AverageMetric.many(ce_loss, target_tokens)
],
metric if ce_token_cnt > 0 else None
for ce_token_cnt, metric in zip(
ce_target_tokens, AverageMetric.many(ce_loss, target_tokens)
)
], # type: ignore
)
self.record_local_metric(
'ct_loss',
[
metric if metric > 0.0 else None
for metric in AverageMetric.many(ct_loss, target_tokens)
],
metric if ct_token_cnt > 0 else None
for ct_token_cnt, metric in zip(
ct_target_tokens, AverageMetric.many(ct_loss, target_tokens)
)
], # type: ignore
)
# token-wise accuracy
self.record_local_metric(
'token_acc', AverageMetric.many(correct, target_tokens)
'token_acc',
[
metric if per_target_token > 0 else None
for per_target_token, metric in zip(
target_tokens, AverageMetric.many(correct, target_tokens)
)
], # type: ignore
)
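The None guard (metric if token count > 0 else None) keeps examples that have no eligible tokens for a given loss from dragging its average toward zero; the intent is that None entries are skipped when the local metrics are aggregated. Roughly, the pattern is (numbers invented):

summed_losses = [1.5, 0.0, 2.0]     # e.g. per-example ce_loss sums
token_counts = [3, 0, 4]            # e.g. per-example ce_target_tokens
reported = [
    m / c if c > 0 else None        # report nothing for examples with zero eligible tokens
    for c, m in zip(token_counts, summed_losses)
]
# -> [0.5, None, 0.5]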
self.record_local_metric(
'token_acc_pos',
@@ -269,7 +301,22 @@ def loss_reshape(loss):
)
# perplexity
self.record_local_metric(
'ppl_debug', PPLMetric.many(ce_loss + ct_loss, target_tokens)
'ppl_debug',
[
metric if per_target_token > 0 else None
for per_target_token, metric in zip(
target_tokens, PPLMetric.many(ce_loss + ct_loss, target_tokens)
)
], # type: ignore
)
self.record_local_metric(
'ppl_ce',
[
metric if ce_token_cnt > 0 else None
for ce_token_cnt, metric in zip(
ce_target_tokens, PPLMetric.many(ce_loss, ce_target_tokens)
)
], # type: ignore
)
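Here PPLMetric turns a summed loss and a token count into the exponential of the average per-token loss, so ppl_ce above is perplexity measured only over the tokens kept by the cross-entropy mask. Numerically (made-up values):

import math

ce_loss_sum = 6.0          # summed token-level CE loss for one example
ce_target_tokens = 3
ppl_ce = math.exp(ce_loss_sum / ce_target_tokens)   # exp(2.0) ≈ 7.389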
self.record_local_metric(
'ppl_pos',
@@ -290,6 +337,17 @@ def loss_reshape(loss):
],
)

# record sample size
self.record_local_metric(
'ce_target_tokens', AverageMetric.many(ce_target_tokens)
)
self.record_local_metric(
'ct_target_tokens', AverageMetric.many(ct_target_tokens)
)
self.record_local_metric(
'total_target_tokens', AverageMetric.many(target_tokens)
)

# actually do backwards loss
loss = loss.sum()
loss /= target_tokens.sum() # average loss per token
