Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jun 4, 2020
1 parent ae288e0 commit 0662707
Showing 1 changed file with 14 additions and 21 deletions.
35 changes: 14 additions & 21 deletions tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,7 @@
from tests.base import EvalModelTemplate


class OmegaConfModel(EvalModelTemplate):
def __init__(self, ogc):
super().__init__()
self.ogc = ogc
self.size = ogc.list[0]


def test_class_nesting(tmpdir):
def test_class_nesting():

class MyModule(LightningModule):
def forward(self):
Expand Down Expand Up @@ -47,6 +40,12 @@ def test2(self):

@pytest.mark.xfail(sys.version_info >= (3, 8), reason='OmegaConf only for Python >= 3.8')
def test_omegaconf(tmpdir):
class OmegaConfModel(EvalModelTemplate):
def __init__(self, ogc):
super().__init__()
self.ogc = ogc
self.size = ogc.list[0]

conf = OmegaConf.create({"k": "v", "list": [15.4, {"a": "1", "b": "2"}]})
model = OmegaConfModel(conf)

Expand All @@ -64,37 +63,34 @@ class SubClassEvalModel(EvalModelTemplate):

def __init__(self, *args, subclass_arg=1200, **kwargs):
super().__init__(*args, **kwargs)
self.subclass_arg = subclass_arg
self.save_hyperparameters()


class SubSubClassEvalModel(SubClassEvalModel):
pass


class UnconventionalArgsEvalModel(EvalModelTemplate):
""" A model that has unconventional names for "self", "*args" and "**kwargs". """

def __init__(obj, *more_args, other_arg=300, **more_kwargs):
# intentionally named obj
super().__init__(*more_args, **more_kwargs)
obj.other_arg = other_arg
other_arg = 321
obj.save_hyperparameters()


class SubSubClassEvalModel(SubClassEvalModel):
pass
other_arg = 321


class AggSubClassEvalModel(SubClassEvalModel):

def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
super().__init__(*args, **kwargs)
self.my_loss = my_loss
self.save_hyperparameters()


class DictConfSubClassEvalModel(SubClassEvalModel):
def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param='something')), **kwargs):
super().__init__(*args, **kwargs)
self.dict_conf = dict_conf
self.save_hyperparameters()


@pytest.mark.parametrize("cls", [
Expand Down Expand Up @@ -187,21 +183,18 @@ def test_collect_init_arguments_with_local_vars(cls):
class NamespaceArgModel(EvalModelTemplate):
def __init__(self, hparams: Namespace):
super().__init__()
# manually
self.save_hyperparameters(hparams)


class DictArgModel(EvalModelTemplate):
def __init__(self, some_dict: dict):
super().__init__()
# manually
self.save_hyperparameters(some_dict)


class OmegaConfArgModel(EvalModelTemplate):
def __init__(self, conf: OmegaConf):
super().__init__()
# manually
self.save_hyperparameters(conf)


Expand Down Expand Up @@ -245,4 +238,4 @@ def __init__(self, arg1, arg2):
def test_single_config_models_fail(tmpdir, cls, config):
""" Test fail on passing unsupported config type. """
with pytest.raises(ValueError):
model = cls(**config)
_ = cls(**config)

0 comments on commit 0662707

Please sign in to comment.