Skip to content

Commit

Permalink
Decouple DataModules from Models - SemSegment (#332)
Browse files Browse the repository at this point in the history
* Decouple dms from SemSegment

* Update tests for SemSegment

* Push an empty commit to rerun ci
  • Loading branch information
akihironitta authored Nov 9, 2020
1 parent 5f08376 commit 6895350
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 18 deletions.
27 changes: 11 additions & 16 deletions pl_bolts/models/vision/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@

class SemSegment(pl.LightningModule):
def __init__(
self,
datamodule: pl.LightningDataModule = None,
lr: float = 0.01,
num_classes: int = 19,
num_layers: int = 5,
features_start: int = 64,
bilinear: bool = False
self,
lr: float = 0.01,
num_classes: int = 19,
num_layers: int = 5,
features_start: int = 64,
bilinear: bool = False
):
"""
Basic model for semantic segmentation. Uses UNet architecture by default.
Expand All @@ -29,17 +28,13 @@ def __init__(
- `Annika Brundyn <https://github.com/annikabrundyn>`_
Args:
datamodule: LightningDataModule
num_layers: number of layers in each side of U-net (default 5)
features_start: number of features in first layer (default 64)
bilinear: whether to use bilinear interpolation (True) or transposed convolutions (default) for upsampling.
lr: learning (default 0.01)
"""
super().__init__()

assert datamodule
self.datamodule = datamodule

self.num_classes = num_classes
self.num_layers = num_layers
self.features_start = features_start
Expand Down Expand Up @@ -84,7 +79,6 @@ def configure_optimizers(self):
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.01, help="adam: learning rate")
parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer")
Expand All @@ -100,23 +94,24 @@ def cli_main():
pl.seed_everything(1234)

parser = ArgumentParser()

# trainer args
parser = pl.Trainer.add_argparse_args(parser)

# model args
parser = SemSegment.add_model_specific_args(parser)
# datamodule args
parser = KittiDataModule.add_argparse_args(parser)

args = parser.parse_args()

# data
dm = KittiDataModule(args.data_dir).from_argparse_args(args)

# model
model = SemSegment(**args.__dict__, datamodule=dm)
model = SemSegment(**args.__dict__)

# train
trainer = pl.Trainer().from_argparse_args(args)
trainer.fit(model)
trainer.fit(model, dm)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def train_dataloader(self):

dm = DummyDataModule()

model = SemSegment(datamodule=dm, num_classes=19)
model = SemSegment(num_classes=19)

trainer = pl.Trainer(fast_dev_run=True, max_epochs=1)
trainer.fit(model)
trainer.fit(model, dm)
loss = trainer.progress_bar_dict['loss']

assert float(loss) > 0

0 comments on commit 6895350

Please sign in to comment.