Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[Cringe] edits to metric logging #5036

Merged
merged 1 commit into from
May 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 72 additions & 14 deletions projects/cringe/cringe_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,33 @@ def __init__(
num_pos_predictions=1,
detach_positives_during_ct=False,
train_ct_on_positive_examples=False,
train_ce_on_positive_examples=True,
**kwargs,
):
super().__init__(**kwargs)
self.ct_loss_weight = ct_loss_weight
self.num_pos_predictions = num_pos_predictions
self.detach_positives_during_ct = detach_positives_during_ct
self.train_ct_on_positive_examples = train_ct_on_positive_examples
self.train_ce_on_positive_examples = train_ce_on_positive_examples

def __call__(self, x, y, classifier_labels=None, **kwargs):
if classifier_labels is None:
classifier_labels = -torch.ones_like(y).to(y.device)

# turn no-class provided label (-1) into positive label (1)
classifier_labels_ce = torch.abs(classifier_labels)
if not self.train_ce_on_positive_examples:
# only train CE on no-class labels
classifier_labels_ce = classifier_labels.eq(-1)

if self.train_ct_on_positive_examples:
# no-class (-1 to 0), positive (1 to 1), negative (0 to 1)
classifier_labels_ct = torch.clamp(classifier_labels + 1, max=1)
else:
# no-class (-1 to 0), positive (1 to 0), negative (0 to 1)
classifier_labels_ct = torch.abs(torch.abs(classifier_labels) - 1)
classifier_labels_ct = classifier_labels_ct.bool()

ce_loss = super().__call__(x, y, **kwargs)
# multiply with classifier labels to not train with negative feedback (0)
Expand All @@ -63,7 +69,7 @@ def __call__(self, x, y, classifier_labels=None, **kwargs):
# compute the contrastive loss part for the negative labels
# first, get the positives as the top predictions != target
preds = torch.topk(x, k=self.num_pos_predictions + 1, axis=-1)
y_rep = y.unsqueeze(1).repeat(1, self.num_pos_predictions + 1)
y_rep = y.unsqueeze(-1).repeat(1, self.num_pos_predictions + 1)
logits = preds.values - (preds.indices == y_rep) * 1e10

# if the positive is not in the first k predictions, mask out
Expand Down Expand Up @@ -104,7 +110,7 @@ def __call__(self, x, y, classifier_labels=None, **kwargs):

loss = ce_loss + self.ct_loss_weight * ct_loss

return loss, ce_loss, ct_loss
return loss, ce_loss, ct_loss, classifier_labels_ce, classifier_labels_ct


class ContrastiveTransformerGeneratorAgent(TransformerGeneratorAgent):
Expand Down Expand Up @@ -137,12 +143,18 @@ def add_cmdline_args(
default=False,
)
parser.add_argument(
'--train-ct-on-positive_examples',
'--train-ct-on-positive-examples',
type=bool,
help='If true, we train with the positive examples in the contrastive loss'
' (with the negatives being the top-k sampled from the model).',
default=False,
)
parser.add_argument(
'--train-ce-on-positive-examples',
type=bool,
help='If true, we train with the positive examples in the cross entropy loss.',
default=True,
)
super().add_cmdline_args(parser, partial_opt=partial_opt)
return agent

Expand All @@ -153,6 +165,9 @@ def build_criterion(self):
detach_positives_during_ct=self.opt['ct_detach_positives'],
ignore_index=self.NULL_IDX,
train_ct_on_positive_examples=self.opt['train_ct_on_positive_examples'],
train_ce_on_positive_examples=self.opt.get(
'train_ce_on_positive_examples', True
),
reduction='none',
)

Expand Down Expand Up @@ -210,7 +225,7 @@ def compute_loss(self, batch, return_output=False):
model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
scores, preds, *_ = model_output
score_view = scores.reshape(-1, scores.size(-1))
(loss, ce_loss, ct_loss,) = self.criterion(
(loss, ce_loss, ct_loss, ce_mask, ct_mask) = self.criterion(
score_view,
batch.label_vec.view(-1),
batch.classifier_label.repeat(1, scores.shape[1])
Expand All @@ -225,8 +240,15 @@ def loss_reshape(loss):
ce_loss = loss_reshape(ce_loss)
ct_loss = loss_reshape(ct_loss)
notnull = batch.label_vec.ne(self.NULL_IDX)
target_tokens = notnull.long().sum(dim=-1)
correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
ce_mask = torch.logical_and(notnull, ce_mask.view(-1, batch.label_vec.size(-1)))
ct_mask = torch.logical_and(notnull, ct_mask.view(-1, batch.label_vec.size(-1)))
# number of tokens in each examples for cross entropy or cringe loss.
metric_notnull = torch.logical_or(ce_mask, ct_mask)
target_tokens = metric_notnull.long().sum(dim=-1)
ce_target_tokens = ce_mask.long().sum(dim=-1)
ct_target_tokens = ct_mask.long().sum(dim=-1)

correct = ((batch.label_vec == preds) * metric_notnull).sum(dim=-1)

pos_labels = (torch.abs(batch.classifier_label) == 1).view(-1)
neg_labels = (torch.abs(batch.classifier_label) == 0).view(-1)
Expand All @@ -238,20 +260,30 @@ def loss_reshape(loss):
self.record_local_metric(
'ce_loss',
[
metric if metric > 0.0 else None
for metric in AverageMetric.many(ce_loss, target_tokens)
],
metric if ce_token_cnt > 0 else None
for ce_token_cnt, metric in zip(
ce_target_tokens, AverageMetric.many(ce_loss, target_tokens)
)
], # type: ignore
)
self.record_local_metric(
'ct_loss',
[
metric if metric > 0.0 else None
for metric in AverageMetric.many(ct_loss, target_tokens)
],
metric if ct_token_cnt > 0 else None
for ct_token_cnt, metric in zip(
ct_target_tokens, AverageMetric.many(ct_loss, target_tokens)
)
], # type: ignore
)
# token-wise accuracy
self.record_local_metric(
'token_acc', AverageMetric.many(correct, target_tokens)
'token_acc',
[
metric if per_target_token > 0 else None
for per_target_token, metric in zip(
target_tokens, AverageMetric.many(correct, target_tokens)
)
], # type: ignore
)
self.record_local_metric(
'token_acc_pos',
Expand All @@ -269,7 +301,22 @@ def loss_reshape(loss):
)
# perplexity
self.record_local_metric(
'ppl_debug', PPLMetric.many(ce_loss + ct_loss, target_tokens)
'ppl_debug',
[
metric if per_target_token > 0 else None
for per_target_token, metric in zip(
target_tokens, PPLMetric.many(ce_loss + ct_loss, target_tokens)
)
], # type: ignore
)
self.record_local_metric(
'ppl_ce',
[
metric if ce_token_cnt > 0 else None
for ce_token_cnt, metric in zip(
ce_target_tokens, PPLMetric.many(ce_loss, ce_target_tokens)
)
], # type: ignore
)
self.record_local_metric(
'ppl_pos',
Expand All @@ -290,6 +337,17 @@ def loss_reshape(loss):
],
)

# record sample size
self.record_local_metric(
'ce_target_tokens', AverageMetric.many(ce_target_tokens)
)
self.record_local_metric(
'ct_target_tokens', AverageMetric.many(ct_target_tokens)
)
self.record_local_metric(
'total_target_tokens', AverageMetric.many(target_tokens)
)

# actually do backwards loss
loss = loss.sum()
loss /= target_tokens.sum() # average loss per token
Expand Down