Skip to content

Commit

Permalink
prune Results usage in notebooks (#3911)
Browse files Browse the repository at this point in the history
* notebooks

* notebooks
  • Loading branch information
Borda authored Oct 6, 2020
1 parent c510a7f commit 4722cc0
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 27 deletions.
13 changes: 6 additions & 7 deletions notebooks/01-mnist-hello-world.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
" def training_step(self, batch, batch_nb):\n",
" x, y = batch\n",
" loss = F.cross_entropy(self(x), y)\n",
" return pl.TrainResult(loss)\n",
" return loss\n",
"\n",
" def configure_optimizers(self):\n",
" return torch.optim.Adam(self.parameters(), lr=0.02)"
Expand Down Expand Up @@ -250,20 +250,19 @@
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" return pl.TrainResult(loss)\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" preds = torch.argmax(logits, dim=1)\n",
" acc = accuracy(preds, y)\n",
" result = pl.EvalResult(checkpoint_on=loss)\n",
"\n",
" # Calling result.log will surface up scalars for you in TensorBoard\n",
" result.log('val_loss', loss, prog_bar=True)\n",
" result.log('val_acc', acc, prog_bar=True)\n",
" return result\n",
" # Calling self.log will surface up scalars for you in TensorBoard\n",
" self.log('val_loss', loss, prog_bar=True)\n",
" self.log('val_acc', acc, prog_bar=True)\n",
" return loss\n",
"\n",
" def test_step(self, batch, batch_idx):\n",
" # Here we just reuse the validation_step for testing\n",
Expand Down
18 changes: 8 additions & 10 deletions notebooks/02-datamodules.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,17 @@
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" return pl.TrainResult(loss)\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" preds = torch.argmax(logits, dim=1)\n",
" acc = accuracy(preds, y)\n",
" result = pl.EvalResult(checkpoint_on=loss)\n",
" result.log('val_loss', loss, prog_bar=True)\n",
" result.log('val_acc', acc, prog_bar=True)\n",
" return result\n",
" self.log('val_loss', loss, prog_bar=True)\n",
" self.log('val_acc', acc, prog_bar=True)\n",
" return loss\n",
"\n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
Expand Down Expand Up @@ -394,7 +393,7 @@
" x, y = batch\n",
" logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n",
" return pl.TrainResult(loss)\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
"\n",
Expand All @@ -403,10 +402,9 @@
" loss = F.nll_loss(logits, y)\n",
" preds = torch.argmax(logits, dim=1)\n",
" acc = accuracy(preds, y)\n",
" result = pl.EvalResult(checkpoint_on=loss)\n",
" result.log('val_loss', loss, prog_bar=True)\n",
" result.log('val_acc', acc, prog_bar=True)\n",
" return result\n",
" self.log('val_loss', loss, prog_bar=True)\n",
" self.log('val_acc', acc, prog_bar=True)\n",
" return loss\n",
"\n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
Expand Down
17 changes: 7 additions & 10 deletions notebooks/04-transformers-text-classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@
" def training_step(self, batch, batch_idx):\n",
" outputs = self(**batch)\n",
" loss = outputs[0]\n",
" return pl.TrainResult(loss)\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx, dataloader_idx=0):\n",
" outputs = self(**batch)\n",
Expand All @@ -322,20 +322,17 @@
" preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n",
" labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n",
" loss = torch.stack([x['loss'] for x in output]).mean()\n",
" if i == 0:\n",
" result = pl.EvalResult(checkpoint_on=loss)\n",
" result.log(f'val_loss_{split}', loss, prog_bar=True)\n",
" self.log(f'val_loss_{split}', loss, prog_bar=True)\n",
" split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()}\n",
" result.log_dict(split_metrics, prog_bar=True)\n",
" return result\n",
" self.log_dict(split_metrics, prog_bar=True)\n",
" return loss\n",
"\n",
" preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n",
" labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n",
" loss = torch.stack([x['loss'] for x in outputs]).mean()\n",
" result = pl.EvalResult(checkpoint_on=loss)\n",
" result.log('val_loss', loss, prog_bar=True)\n",
" result.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n",
" return result\n",
" self.log('val_loss', loss, prog_bar=True)\n",
" self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n",
" return loss\n",
"\n",
" def setup(self, stage):\n",
" if stage == 'fit':\n",
Expand Down

0 comments on commit 4722cc0

Please sign in to comment.