Add type annotations for pl_bolts.models.regression (#595)
* added type annotations for regression

* changes to pass tests

* Remove regression from mypy list

* Add type Any to kwargs

* Fix typing

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
3 people authored Jun 16, 2021
1 parent 3636142 commit aad3d24
Showing 3 changed files with 36 additions and 38 deletions.
35 changes: 18 additions & 17 deletions pl_bolts/models/regression/linear_regression.py
@@ -1,8 +1,9 @@
from argparse import ArgumentParser
from typing import Any, Dict, List, Tuple, Type

import pytorch_lightning as pl
import torch
from torch import nn
from torch import nn, Tensor
from torch.nn import functional as F
from torch.optim import Adam
from torch.optim.optimizer import Optimizer
@@ -20,32 +21,32 @@ def __init__(
output_dim: int = 1,
bias: bool = True,
learning_rate: float = 1e-4,
optimizer: Optimizer = Adam,
optimizer: Type[Optimizer] = Adam,
l1_strength: float = 0.0,
l2_strength: float = 0.0,
**kwargs
):
**kwargs: Any,
) -> None:
"""
Args:
input_dim: number of dimensions of the input (1+)
output_dim: number of dimensions of the output (default=1)
output_dim: number of dimensions of the output (default: ``1``)
bias: If false, will not use $+b$
learning_rate: learning_rate for the optimizer
optimizer: the optimizer to use (default='Adam')
l1_strength: L1 regularization strength (default=None)
l2_strength: L2 regularization strength (default=None)
optimizer: the optimizer to use (default: ``Adam``)
l1_strength: L1 regularization strength (default: ``0.0``)
l2_strength: L2 regularization strength (default: ``0.0``)
"""
super().__init__()
self.save_hyperparameters()
self.optimizer = optimizer

self.linear = nn.Linear(in_features=self.hparams.input_dim, out_features=self.hparams.output_dim, bias=bias)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
y_hat = self.linear(x)
return y_hat

def training_step(self, batch, batch_idx):
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
x, y = batch

# flatten any input
@@ -71,34 +72,34 @@ def training_step(self, batch, batch_idx):
progress_bar_metrics = tensorboard_logs
return {'loss': loss, 'log': tensorboard_logs, 'progress_bar': progress_bar_metrics}

def validation_step(self, batch, batch_idx):
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
return {'val_loss': F.mse_loss(y_hat, y)}

def validation_epoch_end(self, outputs):
def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
tensorboard_logs = {'val_mse_loss': val_loss}
progress_bar_metrics = tensorboard_logs
return {'val_loss': val_loss, 'log': tensorboard_logs, 'progress_bar': progress_bar_metrics}

def test_step(self, batch, batch_idx):
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
x, y = batch
y_hat = self(x)
return {'test_loss': F.mse_loss(y_hat, y)}

def test_epoch_end(self, outputs):
def test_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
test_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
tensorboard_logs = {'test_mse_loss': test_loss}
progress_bar_metrics = tensorboard_logs
return {'test_loss': test_loss, 'log': tensorboard_logs, 'progress_bar': progress_bar_metrics}

def configure_optimizers(self):
def configure_optimizers(self) -> Optimizer:
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--learning_rate', type=float, default=0.0001)
parser.add_argument('--input_dim', type=int, default=None)
@@ -108,7 +109,7 @@ def add_model_specific_args(parent_parser):
return parser


def cli_main():
def cli_main() -> None:
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule
from pl_bolts.utils import _SKLEARN_AVAILABLE

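As context for the Type[Optimizer] annotation above, a minimal usage sketch (not part of the commit): the constructor receives the optimizer class, and configure_optimizers() instantiates it with the model's parameters. The import path and argument values below are assumptions based on this diff.

import torch
from torch.optim import SGD

from pl_bolts.models.regression import LinearRegression  # path assumed from the file above

# `optimizer` expects the class (Type[Optimizer]), not an instance;
# configure_optimizers() later calls SGD(self.parameters(), lr=learning_rate).
model = LinearRegression(input_dim=3, output_dim=1, learning_rate=1e-2, optimizer=SGD)

x = torch.rand(8, 3)   # forward() is annotated Tensor -> Tensor
y_hat = model(x)
print(y_hat.shape)     # expected: torch.Size([8, 1])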
34 changes: 17 additions & 17 deletions pl_bolts/models/regression/logistic_regression.py
@@ -1,8 +1,9 @@
from argparse import ArgumentParser
from typing import Any, Dict, List, Tuple, Type

import pytorch_lightning as pl
import torch
from torch import nn
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.functional import softmax
from torch.optim import Adam
@@ -21,33 +22,33 @@ def __init__(
num_classes: int,
bias: bool = True,
learning_rate: float = 1e-4,
optimizer: Optimizer = Adam,
optimizer: Type[Optimizer] = Adam,
l1_strength: float = 0.0,
l2_strength: float = 0.0,
**kwargs
):
**kwargs: Any,
) -> None:
"""
Args:
input_dim: number of dimensions of the input (at least 1)
num_classes: number of class labels (binary: 2, multi-class: >2)
bias: specifies if a constant or intercept should be fitted (equivalent to fit_intercept in sklearn)
learning_rate: learning_rate for the optimizer
optimizer: the optimizer to use (default='Adam')
l1_strength: L1 regularization strength (default=None)
l2_strength: L2 regularization strength (default=None)
optimizer: the optimizer to use (default: ``Adam``)
l1_strength: L1 regularization strength (default: ``0.0``)
l2_strength: L2 regularization strength (default: ``0.0``)
"""
super().__init__()
self.save_hyperparameters()
self.optimizer = optimizer

self.linear = nn.Linear(in_features=self.hparams.input_dim, out_features=self.hparams.num_classes, bias=bias)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.linear(x)
y_hat = softmax(x)
return y_hat

def training_step(self, batch, batch_idx):
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
x, y = batch

# flatten any input
@@ -74,39 +75,38 @@ def training_step(self, batch, batch_idx):
progress_bar_metrics = tensorboard_logs
return {'loss': loss, 'log': tensorboard_logs, 'progress_bar': progress_bar_metrics}

def validation_step(self, batch, batch_idx):
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x)
acc = accuracy(y_hat, y)
return {'val_loss': F.cross_entropy(y_hat, y), 'acc': acc}

def validation_epoch_end(self, outputs):
def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
acc = torch.stack([x['acc'] for x in outputs]).mean()
val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
tensorboard_logs = {'val_ce_loss': val_loss, 'val_acc': acc}
progress_bar_metrics = tensorboard_logs
return {'val_loss': val_loss, 'log': tensorboard_logs, 'progress_bar': progress_bar_metrics}

def test_step(self, batch, batch_idx):
x, y = batch
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
x = x.view(x.size(0), -1)
y_hat = self(x)
acc = accuracy(y_hat, y)
return {'test_loss': F.cross_entropy(y_hat, y), 'acc': acc}

def test_epoch_end(self, outputs):
def test_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
acc = torch.stack([x['acc'] for x in outputs]).mean()
test_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
tensorboard_logs = {'test_ce_loss': test_loss, 'test_acc': acc}
progress_bar_metrics = tensorboard_logs
return {'test_loss': test_loss, 'log': tensorboard_logs, 'progress_bar': progress_bar_metrics}

def configure_optimizers(self):
def configure_optimizers(self) -> Optimizer:
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--learning_rate', type=float, default=0.0001)
parser.add_argument('--input_dim', type=int, default=None)
@@ -116,7 +116,7 @@ def add_model_specific_args(parent_parser):
return parser


def cli_main():
def cli_main() -> None:
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule
from pl_bolts.utils import _SKLEARN_AVAILABLE

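Similarly, a minimal sketch (not part of the commit) exercising the newly annotated step signatures of LogisticRegression; the shapes, labels, and import path are assumptions.

import torch
from pl_bolts.models.regression import LogisticRegression  # path assumed from the file above

model = LogisticRegression(input_dim=4, num_classes=3)

# batch matches the annotated Tuple[Tensor, Tensor]: features and integer class labels
batch = (torch.rand(8, 4), torch.randint(0, 3, (8,)))
out = model.training_step(batch, batch_idx=0)  # annotated to return Dict[str, Tensor]
print(out['loss'])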
5 changes: 1 addition & 4 deletions setup.cfg
@@ -117,7 +117,7 @@ ignore_errors = True
[mypy-pl_bolts.metrics.*]
ignore_errors = True

[mypy-pl_bolts.models.*]
[mypy-pl_bolts.models.mnist_module]
ignore_errors = True

[mypy-pl_bolts.models.autoencoders.*]
@@ -129,9 +129,6 @@ ignore_errors = True
[mypy-pl_bolts.models.gans.*]
ignore_errors = True

[mypy-pl_bolts.models.regression.*]
ignore_errors = True

[mypy-pl_bolts.models.rl.*]
ignore_errors = True

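With the [mypy-pl_bolts.models.regression.*] ignore block removed, mypy now type-checks these modules. A hedged sketch of the kind of mistake it would now report (module path assumed, exact error text omitted):

from torch.optim import Adam
from pl_bolts.models.regression import LinearRegression

ok = LinearRegression(input_dim=3)  # fine: the default optimizer is the Adam class itself
# wrong = LinearRegression(input_dim=3, optimizer=Adam(ok.parameters()))
# mypy would flag the Adam instance as incompatible with the Type[Optimizer] annotation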
