Skip to content

Commit c417ab4

Browse files
committed
TST: atleast_nd: full tests
1 parent 550331d commit c417ab4

File tree

3 files changed

+68
-4
lines changed

3 files changed

+68
-4
lines changed

pixi.lock

+3-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+6
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@ test = [
3333
"pytest >=6",
3434
"pytest-cov >=3",
3535
"array-api-strict",
36+
"numpy",
3637
]
3738
dev = [
3839
"pytest >=6",
3940
"pytest-cov >=3",
4041
"array-api-strict",
42+
"numpy",
4143
"pylint",
4244
]
4345
docs = [
@@ -84,6 +86,7 @@ lint = { depends-on = ["pre-commit", "pylint"] }
8486
pytest = ">=6"
8587
pytest-cov = ">=3"
8688
array-api-strict = "*"
89+
numpy = "*"
8790

8891
[tool.pixi.feature.test.tasks]
8992
test = { cmd = "pytest" }
@@ -92,6 +95,8 @@ test-ci = { cmd = "pytest -ra --cov --cov-report=xml --cov-report=term --duratio
9295
[tool.pixi.feature.dev.dependencies]
9396
pytest = ">=6"
9497
pytest-cov = ">=3"
98+
array-api-strict = "*"
99+
numpy = "*"
95100
pylint = "*"
96101

97102
[tool.pixi.feature.docs.dependencies]
@@ -155,6 +160,7 @@ enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
155160
warn_unreachable = true
156161
disallow_untyped_defs = false
157162
disallow_incomplete_defs = false
163+
ignore_missing_imports = true
158164

159165
[[tool.mypy.overrides]]
160166
module = "array_api_extra.*"

tests/test_funcs.py

+59-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,68 @@
11
from __future__ import annotations
22

3-
import array_api_strict as xp # type: ignore[import-not-found]
3+
import array_api_strict as xp
4+
from numpy.testing import assert_array_equal
45

56
from array_api_extra import atleast_nd
67

78

89
class TestAtLeastND:
9-
def test_1d_to_2d(self):
10+
def test_0D(self):
11+
x = xp.asarray(1)
12+
13+
y = atleast_nd(x, ndim=0, xp=xp)
14+
assert_array_equal(y, x)
15+
16+
y = atleast_nd(x, ndim=1, xp=xp)
17+
assert_array_equal(y, xp.ones((1,)))
18+
19+
y = atleast_nd(x, ndim=5, xp=xp)
20+
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1)))
21+
22+
def test_1D(self):
1023
x = xp.asarray([0, 1])
24+
25+
y = atleast_nd(x, ndim=0, xp=xp)
26+
assert_array_equal(y, x)
27+
28+
y = atleast_nd(x, ndim=1, xp=xp)
29+
assert_array_equal(y, x)
30+
31+
y = atleast_nd(x, ndim=2, xp=xp)
32+
assert_array_equal(y, xp.asarray([[0, 1]]))
33+
34+
y = atleast_nd(x, ndim=5, xp=xp)
35+
assert_array_equal(y, xp.reshape(xp.arange(2), (1, 1, 1, 1, 2)))
36+
37+
def test_2D(self):
38+
x = xp.asarray([[3]])
39+
40+
y = atleast_nd(x, ndim=0, xp=xp)
41+
assert_array_equal(y, x)
42+
1143
y = atleast_nd(x, ndim=2, xp=xp)
12-
assert y.ndim == 2
44+
assert_array_equal(y, x)
45+
46+
y = atleast_nd(x, ndim=3, xp=xp)
47+
assert_array_equal(y, 3 * xp.ones((1, 1, 1)))
48+
49+
y = atleast_nd(x, ndim=5, xp=xp)
50+
assert_array_equal(y, 3 * xp.ones((1, 1, 1, 1, 1)))
51+
52+
def test_5D(self):
53+
x = xp.ones((1, 1, 1, 1, 1))
54+
55+
y = atleast_nd(x, ndim=0, xp=xp)
56+
assert_array_equal(y, x)
57+
58+
y = atleast_nd(x, ndim=4, xp=xp)
59+
assert_array_equal(y, x)
60+
61+
y = atleast_nd(x, ndim=5, xp=xp)
62+
assert_array_equal(y, x)
63+
64+
y = atleast_nd(x, ndim=6, xp=xp)
65+
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1)))
66+
67+
y = atleast_nd(x, ndim=9, xp=xp)
68+
assert_array_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))

0 commit comments

Comments
 (0)