Skip to content

Commit

Permalink
tests: Update decorator tests to reflect changes in implementation
Browse files Browse the repository at this point in the history
Since we don't return lists anymore, we cannot check length or use __getitem__ calls.

Also, name checks are suspended for the time being, since the user should specify their
own names without us interfering / berating them for it.
  • Loading branch information
nicholasjng committed Jan 20, 2025
1 parent 955c68e commit 0b6f9e6
Showing 1 changed file with 16 additions and 32 deletions.
48 changes: 16 additions & 32 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

import nnbench
import nnbench.types
from nnbench import benchmark, parametrize, product
Expand Down Expand Up @@ -29,40 +27,26 @@ def test_parametrize():
def parametrized_benchmark(param: int) -> int:
return param

assert len(parametrized_benchmark) == 2
assert has_expected_args(parametrized_benchmark[0].fn, {"param": 1})
assert parametrized_benchmark[0].fn(**parametrized_benchmark[0].params) == 1
assert has_expected_args(parametrized_benchmark[1].fn, {"param": 2})
assert parametrized_benchmark[1].fn(**parametrized_benchmark[1].params) == 2


def test_parametrize_with_duplicate_parameters():
with pytest.warns(UserWarning, match="duplicate"):

@parametrize([{"param": 1}, {"param": 1}])
def parametrized_benchmark(param: int) -> int:
return param
all_benchmarks = list(parametrized_benchmark)
assert len(all_benchmarks) == 2
assert has_expected_args(all_benchmarks[0].fn, {"param": 1})
assert all_benchmarks[0].fn(**all_benchmarks[0].params) == 1
assert has_expected_args(all_benchmarks[1].fn, {"param": 2})
assert all_benchmarks[1].fn(**all_benchmarks[1].params) == 2


def test_product():
@product(iter1=[1, 2], iter2=["a", "b"])
def product_benchmark(iter1: int, iter2: str) -> tuple[int, str]:
return iter1, iter2

assert len(product_benchmark) == 4
assert has_expected_args(product_benchmark[0].fn, {"iter1": 1, "iter2": "a"})
assert product_benchmark[0].fn(**product_benchmark[0].params) == (1, "a")
assert has_expected_args(product_benchmark[1].fn, {"iter1": 1, "iter2": "b"})
assert product_benchmark[1].fn(**product_benchmark[1].params) == (1, "b")
assert has_expected_args(product_benchmark[2].fn, {"iter1": 2, "iter2": "a"})
assert product_benchmark[2].fn(**product_benchmark[2].params) == (2, "a")
assert has_expected_args(product_benchmark[3].fn, {"iter1": 2, "iter2": "b"})
assert product_benchmark[3].fn(**product_benchmark[3].params) == (2, "b")


def test_product_with_duplicate_parameters():
with pytest.warns(UserWarning, match="duplicate"):

@product(iter=[1, 1])
def product_benchmark(iter: int) -> int:
return iter
all_benchmarks = list(product_benchmark)
assert len(all_benchmarks) == 4
assert has_expected_args(all_benchmarks[0].fn, {"iter1": 1, "iter2": "a"})
assert all_benchmarks[0].fn(**all_benchmarks[0].params) == (1, "a")
assert has_expected_args(all_benchmarks[1].fn, {"iter1": 1, "iter2": "b"})
assert all_benchmarks[1].fn(**all_benchmarks[1].params) == (1, "b")
assert has_expected_args(all_benchmarks[2].fn, {"iter1": 2, "iter2": "a"})
assert all_benchmarks[2].fn(**all_benchmarks[2].params) == (2, "a")
assert has_expected_args(all_benchmarks[3].fn, {"iter1": 2, "iter2": "b"})
assert all_benchmarks[3].fn(**all_benchmarks[3].params) == (2, "b")

0 comments on commit 0b6f9e6

Please sign in to comment.