From 0d08e6d6e6326e6f8b7ff398449fde82c2f9ff83 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Mon, 16 Nov 2020 15:11:16 +0000
Subject: [PATCH 1/2] Ensure sync across val/test step when using DDP

---
 pl_bolts/models/self_supervised/ssl_finetuner.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pl_bolts/models/self_supervised/ssl_finetuner.py b/pl_bolts/models/self_supervised/ssl_finetuner.py
index 94768257a7..d939601638 100644
--- a/pl_bolts/models/self_supervised/ssl_finetuner.py
+++ b/pl_bolts/models/self_supervised/ssl_finetuner.py
@@ -105,8 +105,8 @@ def validation_step(self, batch, batch_idx):
         loss, logits, y = self.shared_step(batch)
         acc = self.val_acc(logits, y)
 
-        self.log('val_loss', loss, prog_bar=True)
-        self.log('val_acc', self.val_acc)
+        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
+        self.log('val_acc', self.val_acc, sync_dist=True)
 
         return loss
 
@@ -114,8 +114,8 @@ def test_step(self, batch, batch_idx):
         loss, logits, y = self.shared_step(batch)
         acc = self.test_acc(logits, y)
 
-        self.log('test_loss', loss)
-        self.log('test_acc', self.test_acc)
+        self.log('test_loss', loss, sync_dist=True)
+        self.log('test_acc', self.test_acc, sync_dist=True)
 
         return loss

From df0c174e6317c6c20b666d4c875d8f4d5337e345 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Mon, 16 Nov 2020 16:44:02 +0000
Subject: [PATCH 2/2] Remove sync_dist from class metrics as they are
 automatically reduced

---
 pl_bolts/models/self_supervised/ssl_finetuner.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pl_bolts/models/self_supervised/ssl_finetuner.py b/pl_bolts/models/self_supervised/ssl_finetuner.py
index d939601638..80dd6c94ac 100644
--- a/pl_bolts/models/self_supervised/ssl_finetuner.py
+++ b/pl_bolts/models/self_supervised/ssl_finetuner.py
@@ -106,7 +106,7 @@ def validation_step(self, batch, batch_idx):
         acc = self.val_acc(logits, y)
 
         self.log('val_loss', loss, prog_bar=True, sync_dist=True)
-        self.log('val_acc', self.val_acc, sync_dist=True)
+        self.log('val_acc', self.val_acc)
 
         return loss
 
@@ -115,7 +115,7 @@ def test_step(self, batch, batch_idx):
         acc = self.test_acc(logits, y)
 
         self.log('test_loss', loss, sync_dist=True)
-        self.log('test_acc', self.test_acc, sync_dist=True)
+        self.log('test_acc', self.test_acc)
 
         return loss