use pytest to skip torch tests #2307

Merged · 2 commits · Apr 9, 2024
2,591 changes: 1,290 additions & 1,301 deletions darts/tests/datasets/test_datasets.py

Large diffs are not rendered by default.

855 changes: 418 additions & 437 deletions darts/tests/explainability/test_tft_explainer.py

Large diffs are not rendered by default.

31 changes: 16 additions & 15 deletions darts/tests/models/components/glu_variants.py
@@ -1,26 +1,27 @@
+import pytest
+
 from darts.logging import get_logger
 
 logger = get_logger(__name__)
 
 try:
     import torch
 
-    TORCH_AVAILABLE = True
-except ImportError:
-    logger.warning("Torch not available. Loss tests will be skipped.")
-    TORCH_AVAILABLE = False
-
-
-if TORCH_AVAILABLE:
     from darts.models.components import glu_variants
     from darts.models.components.glu_variants import GLU_FFN
+except ImportError:
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+
 
-    class TestFFN:
-        def test_ffn(self):
-            for FeedForward_network in GLU_FFN:
-                self.feed_forward_block = getattr(glu_variants, FeedForward_network)(
-                    d_model=4, d_ff=16, dropout=0.1
-                )
+class TestFFN:
+    def test_ffn(self):
+        for FeedForward_network in GLU_FFN:
+            self.feed_forward_block = getattr(glu_variants, FeedForward_network)(
+                d_model=4, d_ff=16, dropout=0.1
+            )
 
-                inputs = torch.zeros(1, 4, 4)
-                self.feed_forward_block(x=inputs)
+            inputs = torch.zeros(1, 4, 4)
+            self.feed_forward_block(x=inputs)
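The change is the same in both files shown: the module-level TORCH_AVAILABLE flag is replaced by pytest's module-level skip. Under the old pattern, a missing torch meant the test classes were never defined at all, so pytest collected zero tests from the module and the only trace was a log warning; with pytest.skip(..., allow_module_level=True), the skip is raised at collection time and the module shows up as skipped in the test report. A minimal sketch of the pattern, using a placeholder dependency name rather than a real darts import:

import pytest

try:
    import some_optional_dep  # placeholder for an optional dependency such as torch
except ImportError:
    # Raising at import time (i.e. during collection) marks every test in
    # this module as skipped, instead of silently defining no tests at all.
    pytest.skip(
        f"some_optional_dep not available. {__name__} tests will be skipped.",
        allow_module_level=True,
    )


def test_uses_the_dependency():
    assert some_optional_dep is not None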
57 changes: 28 additions & 29 deletions darts/tests/models/components/test_layer_norm_variants.py
@@ -8,46 +8,45 @@
 try:
     import torch
 
-    TORCH_AVAILABLE = True
-except ImportError:
-    logger.warning("Torch not available. Loss tests will be skipped.")
-    TORCH_AVAILABLE = False
-
-
-if TORCH_AVAILABLE:
     from darts.models.components.layer_norm_variants import (
         LayerNorm,
         LayerNormNoBias,
         RINorm,
         RMSNorm,
     )
+except ImportError:
+    pytest.skip(
+        f"Torch not available. {__name__} tests will be skipped.",
+        allow_module_level=True,
+    )
+
 
-    class TestLayerNormVariants:
-        def test_lnv(self):
-            for layer_norm in [RMSNorm, LayerNorm, LayerNormNoBias]:
-                ln = layer_norm(4)
-                inputs = torch.zeros(1, 4, 4)
-                ln(inputs)
+class TestLayerNormVariants:
+    def test_lnv(self):
+        for layer_norm in [RMSNorm, LayerNorm, LayerNormNoBias]:
+            ln = layer_norm(4)
+            inputs = torch.zeros(1, 4, 4)
+            ln(inputs)
 
-        def test_rin(self):
+    def test_rin(self):
 
-            np.random.seed(42)
-            torch.manual_seed(42)
+        np.random.seed(42)
+        torch.manual_seed(42)
 
-            x = torch.randn(3, 4, 7)
-            affine_options = [True, False]
+        x = torch.randn(3, 4, 7)
+        affine_options = [True, False]
 
-            # test with and without affine and correct input dim
-            for affine in affine_options:
+        # test with and without affine and correct input dim
+        for affine in affine_options:
 
-                rin = RINorm(input_dim=7, affine=affine)
-                x_norm = rin(x)
+            rin = RINorm(input_dim=7, affine=affine)
+            x_norm = rin(x)
 
-                # expand dims to simulate probabilistic forecasting
-                x_denorm = rin.inverse(x_norm.view(x_norm.shape + (1,))).squeeze(-1)
-                assert torch.all(torch.isclose(x, x_denorm)).item()
+            # expand dims to simulate probabilistic forecasting
+            x_denorm = rin.inverse(x_norm.view(x_norm.shape + (1,))).squeeze(-1)
+            assert torch.all(torch.isclose(x, x_denorm)).item()
 
-            # try invalid input_dim
-            rin = RINorm(input_dim=3, affine=True)
-            with pytest.raises(RuntimeError):
-                x_norm = rin(x)
+        # try invalid input_dim
+        rin = RINorm(input_dim=3, affine=True)
+        with pytest.raises(RuntimeError):
+            x_norm = rin(x)
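As an aside, for the simple case where only the import itself needs guarding, pytest's built-in importorskip achieves the same collection-time skip in one line. The try/except form used in this PR is still needed above, since several darts symbols are imported under the same guard, but a sketch of the alternative:

import pytest

# Skips the whole module at collection time if torch cannot be imported;
# otherwise returns the imported module object.
torch = pytest.importorskip("torch")


def test_tensor():
    assert torch.zeros(1, 4, 4).shape == (1, 4, 4)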