From 037434ad066974c262a92628c8098494ce05f72b Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 4 Nov 2020 07:41:08 +0900 Subject: [PATCH] Update tests for SemSegment --- tests/models/test_vision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index d32c09b9cc..2c4e675989 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -67,10 +67,10 @@ def train_dataloader(self): dm = DummyDataModule() - model = SemSegment(datamodule=dm, num_classes=19) + model = SemSegment(num_classes=19) trainer = pl.Trainer(fast_dev_run=True, max_epochs=1) - trainer.fit(model) + trainer.fit(model, dm) loss = trainer.progress_bar_dict['loss'] assert float(loss) > 0