diff --git a/src/gnn_tracking/utils/lightning.py b/src/gnn_tracking/utils/lightning.py index fcf145f1..621aaa4d 100644 --- a/src/gnn_tracking/utils/lightning.py +++ b/src/gnn_tracking/utils/lightning.py @@ -132,10 +132,10 @@ def get_model( freeze: Whether to freeze the model """ if not chkpt_path: - return + return None lm = get_lightning_module(class_path, chkpt_path, freeze=freeze) if lm is None: - return + return None return lm.model diff --git a/src/gnn_tracking/utils/timing.py b/src/gnn_tracking/utils/timing.py index ec12c08e..883cdcc9 100644 --- a/src/gnn_tracking/utils/timing.py +++ b/src/gnn_tracking/utils/timing.py @@ -29,4 +29,4 @@ def timing(name="Codeblock", logger=None): try: yield finally: - logger.info(f"{name} took {t():.2f} seconds") + logger.info("%s took %f seconds", name, t()) diff --git a/tests/test_tcn_training.py b/tests/test_tcn_training.py index aa99f463..27d88361 100644 --- a/tests/test_tcn_training.py +++ b/tests/test_tcn_training.py @@ -83,7 +83,7 @@ def __post_init__(self): @pytest.mark.parametrize("t", _test_train_test_cases) -def test_train(tmp_path, built_graphs, t: TestTrainCase) -> None: +def test_train(built_graphs, t: TestTrainCase) -> None: fix_seeds() _, graph_builder = built_graphs g = graph_builder.data_list[0]