Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/unittests #1950

Merged
merged 17 commits into from
Aug 17, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add pytest conftest for migrating from unittest to pytest
dennisbader committed Aug 10, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 5c0742d3b844c6b5b83f0df382fa2b27de75b915
20 changes: 0 additions & 20 deletions darts/tests/base_test_class2.py

This file was deleted.

37 changes: 37 additions & 0 deletions darts/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import logging
import shutil
import tempfile

import pytest


@pytest.fixture(scope="session", autouse=True)
def set_up_tests(request):
logging.disable(logging.CRITICAL)

def tear_down_tests():
try:
shutil.rmtree(".darts")
except FileNotFoundError:
pass

request.addfinalizer(tear_down_tests)


@pytest.fixture(scope="module")
def tmpdir_module():
"""Sets up a temporary directory that will be dunped after the test module (script) finished."""
temp_work_dir = tempfile.mkdtemp(prefix="darts")
yield temp_work_dir
shutil.rmtree(temp_work_dir)
dennisbader marked this conversation as resolved.
Show resolved Hide resolved


@pytest.fixture(scope="session")
def tfm_kwargs():
return {
"pl_trainer_kwargs": {
"accelerator": "cpu",
"enable_progress_bar": False,
"enable_model_summary": False,
}
}
2 changes: 2 additions & 0 deletions darts/tests/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# prevent PyTorch Lightning from using GPU (M1 system compatibility)
tfm_kwargs = {"pl_trainer_kwargs": {"accelerator": "cpu"}}
32 changes: 32 additions & 0 deletions darts/tests/test_examples/test_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os.path

from darts.datasets import AirPassengersDataset
from darts.models import TFTModel


def test_add():
assert 1 + 2 == 3


def test_save(tmpdir_module, tfm_kwargs):
s = AirPassengersDataset().load()
m = TFTModel(
input_chunk_length=12,
output_chunk_length=6,
add_relative_index=True,
**tfm_kwargs
)
m.fit(s, epochs=1)
m.save(os.path.join(tmpdir_module, "tft1.pt"))


def test_save2(tmpdir_module, tfm_kwargs):
s = AirPassengersDataset().load()
m = TFTModel(
input_chunk_length=12,
output_chunk_length=6,
add_relative_index=True,
**tfm_kwargs
)
m.fit(s, epochs=1)
m.save(os.path.join(tmpdir_module, "tft2.pt"))
40 changes: 40 additions & 0 deletions darts/tests/test_examples/test_example_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os.path

import pytest

from darts.datasets import AirPassengersDataset
from darts.models import TFTModel


def test_add():
assert 1 + 2 == 3


def test_save(tmpdir_module, tfm_kwargs):
s = AirPassengersDataset().load()
m = TFTModel(
input_chunk_length=12,
output_chunk_length=6,
add_relative_index=True,
**tfm_kwargs
)
m.fit(s, epochs=1, verbose=False)
m.save(os.path.join(tmpdir_module, "tft1.pt"))


@pytest.mark.parametrize(
"model_config", [(0, {"full_attention": False}), (1, {"full_attention": True})]
)
def test_save2(tmpdir_module, tfm_kwargs, model_config):
idx, idx_model_kwargs = model_config
assert idx, "blabla error message"

s = AirPassengersDataset().load()
m = TFTModel(
input_chunk_length=12,
output_chunk_length=6,
add_relative_index=True,
**tfm_kwargs
)
m.fit(s, epochs=1, verbose=False)
m.save(os.path.join(tmpdir_module, "tft2.pt"))