Skip to content

Commit

Permalink
cover missing lines for meshgrid
Browse files Browse the repository at this point in the history
  • Loading branch information
Bchass committed Nov 20, 2024
1 parent e37a71f commit 44d5435
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
8 changes: 8 additions & 0 deletions tinynumpy/tests/test_tinynumpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,14 @@ def test_meshgrid():
assert (yy.shape == yy_shape_expected)
assert (zz.shape == zz_shape_expected)

# value error for indexing
with pytest.raises(ValueError):
xv, yv = tnp.meshgrid(x, y, indexing='xi')

# value error for len shapes < 1
with pytest.raises(ValueError):
xv, yv = tnp.meshgrid(x, indexing='xy')



def test_astype():
Expand Down
25 changes: 19 additions & 6 deletions tinynumpy/tinynumpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,11 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None):
return a

def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=None):
""" logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=None)
Return numbers spaced evenly on a log scale.
"""

start, stop = float(start), float(stop)
ra = stop - start
Expand All @@ -384,19 +389,26 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=Non


def meshgrid(*xi, copy=True, sparse=False, indexing='xy'):
""" meshgrid(*xi, copy=True, sparse=False, indexing='xy')
Return a tuple of coordinate matrices from coordinate vectors.
"""

ndim = len(xi)

if indexing not in {'xy', 'ij'}:
raise ValueError("Indexing must be 'xy' or 'ij'")

ndim = len(xi)
if ndim < 1:
raise ValueError("At least one input array is required")

# Adjust the order of inputs for 'xy' indexing
if indexing == 'xy' and ndim >= 2:
if indexing == 'xy' and ndim > 1:
xi = (xi[1], xi[0]) + xi[2:]

# Get the lengths of each input array
shapes = [len(arr) for arr in xi]
shapes = [len(x) for x in xi]

if len(shapes) < 2:
raise ValueError("At least two input arrays are required")

# Create the output grids
grids = []
Expand All @@ -405,6 +417,7 @@ def meshgrid(*xi, copy=True, sparse=False, indexing='xy'):
if i == 0:
# Repeat for columns (x-axis direction)
grid = [list(x) for _ in range(shapes[1])]
print(grid)
else:
# Repeat for rows (y-axis direction)
grid = [[x_val] * shapes[0] for x_val in x]
Expand Down

0 comments on commit 44d5435

Please sign in to comment.