From fa5de91818b9399409fd4b0eab74a3f31abdda25 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 4 Nov 2020 07:40:44 +0900 Subject: [PATCH 1/3] Decouple dms from SemSegment --- pl_bolts/models/vision/segmentation.py | 27 +++++++++++--------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/pl_bolts/models/vision/segmentation.py b/pl_bolts/models/vision/segmentation.py index d44c632356..e89692a157 100644 --- a/pl_bolts/models/vision/segmentation.py +++ b/pl_bolts/models/vision/segmentation.py @@ -9,13 +9,12 @@ class SemSegment(pl.LightningModule): def __init__( - self, - datamodule: pl.LightningDataModule = None, - lr: float = 0.01, - num_classes: int = 19, - num_layers: int = 5, - features_start: int = 64, - bilinear: bool = False + self, + lr: float = 0.01, + num_classes: int = 19, + num_layers: int = 5, + features_start: int = 64, + bilinear: bool = False ): """ Basic model for semantic segmentation. Uses UNet architecture by default. @@ -29,7 +28,6 @@ def __init__( - `Annika Brundyn `_ Args: - datamodule: LightningDataModule num_layers: number of layers in each side of U-net (default 5) features_start: number of features in first layer (default 64) bilinear: whether to use bilinear interpolation (True) or transposed convolutions (default) for upsampling. @@ -37,9 +35,6 @@ def __init__( """ super().__init__() - assert datamodule - self.datamodule = datamodule - self.num_classes = num_classes self.num_layers = num_layers self.features_start = features_start @@ -84,7 +79,6 @@ def configure_optimizers(self): @staticmethod def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser], add_help=False) - parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") parser.add_argument("--lr", type=float, default=0.01, help="adam: learning rate") parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net") parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer") @@ -100,23 +94,24 @@ def cli_main(): pl.seed_everything(1234) parser = ArgumentParser() - # trainer args parser = pl.Trainer.add_argparse_args(parser) - # model args parser = SemSegment.add_model_specific_args(parser) + # datamodule args + parser = KittiDataModule.add_argparse_args(parser) + args = parser.parse_args() # data dm = KittiDataModule(args.data_dir).from_argparse_args(args) # model - model = SemSegment(**args.__dict__, datamodule=dm) + model = SemSegment(**args.__dict__) # train trainer = pl.Trainer().from_argparse_args(args) - trainer.fit(model) + trainer.fit(model, dm) if __name__ == '__main__': From 58d944b9fe10121a8709c3b91f423140d4fd6bac Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 4 Nov 2020 07:41:08 +0900 Subject: [PATCH 2/3] Update tests for SemSegment --- tests/models/test_vision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index 055dafb5c2..90971df32f 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -66,10 +66,10 @@ def train_dataloader(self): dm = DummyDataModule() - model = SemSegment(datamodule=dm, num_classes=19) + model = SemSegment(num_classes=19) trainer = pl.Trainer(fast_dev_run=True, max_epochs=1) - trainer.fit(model) + trainer.fit(model, dm) loss = trainer.progress_bar_dict['loss'] assert float(loss) > 0 From eff91501ad4ce70e5dd410acf5f7b86c655bd391 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 7 Nov 2020 11:09:23 +0900 Subject: [PATCH 3/3] Push an empty commit to rerun ci