Skip to content

Commit

Permalink
prune SimpleModel (#5862)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Feb 8, 2021
1 parent 26bc754 commit 42812bb
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 106 deletions.
1 change: 0 additions & 1 deletion tests/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@
from tests.base.boring_model import BoringDataModule, BoringModel, RandomDataset # noqa: F401
from tests.base.datasets import TrialMNIST # noqa: F401
from tests.base.model_template import EvalModelTemplate, GenericEvalModelTemplate # noqa: F401
from tests.base.simple_model import SimpleModule # noqa: F401
98 changes: 0 additions & 98 deletions tests/base/simple_model.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/models/test_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_multi_cpu_model_ddp(tmpdir):
)

model = BoringModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False, min_acc=0.05)
tpipes.run_model_test(trainer_options, model, on_gpu=False)


def test_lbfgs_cpu_model(tmpdir):
Expand Down
8 changes: 4 additions & 4 deletions tests/trainer/flags/test_val_check_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
import pytest

from pytorch_lightning.trainer import Trainer
from tests.base import SimpleModule
from tests.base import BoringModel


@pytest.mark.parametrize('max_epochs', [1, 2, 3])
def test_val_check_interval_1(tmpdir, max_epochs):

class TestModel(SimpleModule):
class TestModel(BoringModel):

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -48,7 +48,7 @@ def on_validation_epoch_start(self) -> None:
@pytest.mark.parametrize('max_epochs', [1, 2, 3])
def test_val_check_interval_quarter(tmpdir, max_epochs):

class TestModel(SimpleModule):
class TestModel(BoringModel):

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -76,7 +76,7 @@ def on_validation_epoch_start(self) -> None:
@pytest.mark.parametrize('max_epochs', [1, 2, 3])
def test_val_check_interval_third(tmpdir, max_epochs):

class TestModel(SimpleModule):
class TestModel(BoringModel):

def __init__(self):
super().__init__()
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/logging_/test_eval_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import BoringModel, RandomDataset, SimpleModule
from tests.base import BoringModel, RandomDataset
from tests.base.deterministic_model import DeterministicModel


Expand Down Expand Up @@ -358,7 +358,7 @@ def test_epoch_end(self, outputs):

def test_monitor_val_epoch_end(tmpdir):
epoch_min_loss_override = 0
model = SimpleModule()
model = BoringModel()
checkpoint_callback = callbacks.ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="avg_val_loss")
trainer = Trainer(
max_epochs=epoch_min_loss_override + 2,
Expand Down

0 comments on commit 42812bb

Please sign in to comment.