Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
Release Notes
=============

.. Upcoming Version
.. ----------------
Upcoming Version
----------------

* Fix the `slice_size` argument in the `solve` function. The argument was not properly passed to the `to_file` function.
* Fix the slicing of constraints in case the term dimension is larger than the leading constraint coordinate dimension.

Version 0.4.0
--------------
Expand Down
8 changes: 7 additions & 1 deletion linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,11 @@ def iterate_slices(
if slice_dims is None:
slice_dims = list(getattr(ds, "coord_dims", ds.dims))

if not set(slice_dims).issubset(ds.dims):
raise ValueError(
"Invalid slice dimensions. Must be a subset of the dataset dimensions."
)

# Calculate the total number of elements in the dataset
size = np.prod([ds.sizes[dim] for dim in ds.dims], dtype=int)

Expand All @@ -517,7 +522,8 @@ def iterate_slices(
n_slices = max(size // slice_size, 1)

# leading dimension (the dimension with the largest size)
leading_dim = max(ds.sizes, key=ds.sizes.get) # type: ignore
sizes = {dim: ds.sizes[dim] for dim in slice_dims}
leading_dim = max(sizes, key=sizes.get) # type: ignore
size_of_leading_dim = ds.sizes[leading_dim]

if size_of_leading_dim < n_slices:
Expand Down
4 changes: 3 additions & 1 deletion linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,9 @@ def solve(
env=env,
)
else:
problem_fn = self.to_file(to_path(problem_fn), io_api)
problem_fn = self.to_file(
to_path(problem_fn), io_api, slice_size=slice_size
)
result = solver.solve_problem_from_file(
problem_fn=to_path(problem_fn),
solution_fn=to_path(solution_fn),
Expand Down
28 changes: 15 additions & 13 deletions test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,11 +475,11 @@ def test_iterate_slices_basic():

def test_iterate_slices_with_exclude_dims():
ds = xr.Dataset(
{"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002
coords={"x": np.arange(10), "y": np.arange(10)},
{"var": (("x", "y"), np.random.rand(10, 20))}, # noqa: NPY002
coords={"x": np.arange(10), "y": np.arange(20)},
)
slices = list(iterate_slices(ds, slice_size=20, slice_dims=["x"]))
assert len(slices) == 5
assert len(slices) == 10
for s in slices:
assert isinstance(s, xr.Dataset)
assert set(s.dims) == set(ds.dims)
Expand All @@ -499,11 +499,13 @@ def test_iterate_slices_large_max_size():

def test_iterate_slices_small_max_size():
ds = xr.Dataset(
{"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002
coords={"x": np.arange(10), "y": np.arange(10)},
{"var": (("x", "y"), np.random.rand(10, 20))}, # noqa: NPY002
coords={"x": np.arange(10), "y": np.arange(20)},
)
slices = list(iterate_slices(ds, slice_size=8, slice_dims=[]))
assert len(slices) == 10
slices = list(iterate_slices(ds, slice_size=8, slice_dims=["x"]))
assert (
len(slices) == 10
) # goes to the smallest slice possible which is 1 for the x dimension
for s in slices:
assert isinstance(s, xr.Dataset)
assert set(s.dims) == set(ds.dims)
Expand All @@ -520,16 +522,16 @@ def test_iterate_slices_slice_size_none():
assert ds.equals(s)


def test_iterate_slices_no_slice_dims():
def test_iterate_slices_invalid_slice_dims():
ds = xr.Dataset(
{"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002
coords={"x": np.arange(10), "y": np.arange(10)},
)
slices = list(iterate_slices(ds, slice_size=50, slice_dims=[]))
assert len(slices) == 2
for s in slices:
assert isinstance(s, xr.Dataset)
assert set(s.dims) == set(ds.dims)
with pytest.raises(ValueError):
list(iterate_slices(ds, slice_size=50, slice_dims=[]))

with pytest.raises(ValueError):
list(iterate_slices(ds, slice_size=50, slice_dims=["z"]))


def test_get_dims_with_index_levels():
Expand Down