Skip to content

Commit

Permalink
Decouple dms from SemSegment
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Nov 3, 2020
1 parent d4e6096 commit 9809624
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions pl_bolts/models/vision/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@

class SemSegment(pl.LightningModule):
def __init__(
self,
datamodule: pl.LightningDataModule = None,
lr: float = 0.01,
num_classes: int = 19,
num_layers: int = 5,
features_start: int = 64,
bilinear: bool = False
self,
lr: float = 0.01,
num_classes: int = 19,
num_layers: int = 5,
features_start: int = 64,
bilinear: bool = False
):
"""
Basic model for semantic segmentation. Uses UNet architecture by default.
Expand All @@ -29,17 +28,13 @@ def __init__(
- `Annika Brundyn <https://github.com/annikabrundyn>`_
Args:
datamodule: LightningDataModule
num_layers: number of layers in each side of U-net (default 5)
features_start: number of features in first layer (default 64)
bilinear: whether to use bilinear interpolation (True) or transposed convolutions (default) for upsampling.
lr: learning (default 0.01)
"""
super().__init__()

assert datamodule
self.datamodule = datamodule

self.num_classes = num_classes
self.num_layers = num_layers
self.features_start = features_start
Expand Down Expand Up @@ -84,7 +79,6 @@ def configure_optimizers(self):
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.01, help="adam: learning rate")
parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer")
Expand All @@ -100,23 +94,24 @@ def cli_main():
pl.seed_everything(1234)

parser = ArgumentParser()

# trainer args
parser = pl.Trainer.add_argparse_args(parser)

# model args
parser = SemSegment.add_model_specific_args(parser)
# datamodule args
parser = KittiDataModule.add_argparse_args(parser)

args = parser.parse_args()

# data
dm = KittiDataModule(args.data_dir).from_argparse_args(args)

# model
model = SemSegment(**args.__dict__, datamodule=dm)
model = SemSegment(**args.__dict__)

# train
trainer = pl.Trainer().from_argparse_args(args)
trainer.fit(model)
trainer.fit(model, dm)


if __name__ == '__main__':
Expand Down

0 comments on commit 9809624

Please sign in to comment.