From 50a23d272cd09ec9024609b5ec1ffdb9d4115c46 Mon Sep 17 00:00:00 2001 From: Mojtaba Date: Wed, 30 Nov 2022 12:53:23 -0800 Subject: [PATCH 1/2] cat or concat --- projects/cringe/cringe_loss.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/projects/cringe/cringe_loss.py b/projects/cringe/cringe_loss.py index 51f9cb3df93..4afcea99df6 100644 --- a/projects/cringe/cringe_loss.py +++ b/projects/cringe/cringe_loss.py @@ -88,9 +88,15 @@ def __call__(self, x, y, classifier_labels=None, **kwargs): # concatenate the logits of the preds with the actual label's logits x_target = x[torch.arange(x.shape[0]), y] - x_ct = torch.concat( - [x_target.unsqueeze(1), sample_preds_values.unsqueeze(1)], -1 - ) + if hasattr(torch, 'concat'): + # torch > 1.10 + x_ct = torch.concat( + [x_target.unsqueeze(1), sample_preds_values.unsqueeze(1)], -1 + ) + else: + x_ct = torch.cat( + [x_target.unsqueeze(1), sample_preds_values.unsqueeze(1)], -1 + ) # get the y's for the x_ct (the correct label is index 0 if # the target is positive and index 1 if the target is negative) y_ct = torch.abs(torch.abs(classifier_labels) - 1).type(y.dtype).to(x_ct.device) From f374a5491e1606cdf56758bf7d77d0ab767fc0c4 Mon Sep 17 00:00:00 2001 From: Mojtaba Date: Thu, 1 Dec 2022 11:01:59 -0800 Subject: [PATCH 2/2] back to cat --- projects/cringe/cringe_loss.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/projects/cringe/cringe_loss.py b/projects/cringe/cringe_loss.py index 4afcea99df6..0e0916a0080 100644 --- a/projects/cringe/cringe_loss.py +++ b/projects/cringe/cringe_loss.py @@ -88,15 +88,7 @@ def __call__(self, x, y, classifier_labels=None, **kwargs): # concatenate the logits of the preds with the actual label's logits x_target = x[torch.arange(x.shape[0]), y] - if hasattr(torch, 'concat'): - # torch > 1.10 - x_ct = torch.concat( - [x_target.unsqueeze(1), sample_preds_values.unsqueeze(1)], -1 - ) - else: - x_ct = torch.cat( - [x_target.unsqueeze(1), sample_preds_values.unsqueeze(1)], -1 - ) + x_ct = torch.cat([x_target.unsqueeze(1), sample_preds_values.unsqueeze(1)], -1) # get the y's for the x_ct (the correct label is index 0 if # the target is positive and index 1 if the target is negative) y_ct = torch.abs(torch.abs(classifier_labels) - 1).type(y.dtype).to(x_ct.device)