Skip to content

Commit 2c873ac

Browse files
committed
Feat: readd manually changes on align_chunks from pydata#10516
1 parent e287ea5 commit 2c873ac

File tree

3 files changed

+60
-53
lines changed

3 files changed

+60
-53
lines changed

xarray/backends/chunks.py

Lines changed: 58 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,18 @@
44

55

66
def align_nd_chunks(
7-
nd_var_chunks: tuple[tuple[int, ...], ...],
7+
nd_v_chunks: tuple[tuple[int, ...], ...],
88
nd_backend_chunks: tuple[tuple[int, ...], ...],
99
) -> tuple[tuple[int, ...], ...]:
10-
if len(nd_backend_chunks) != len(nd_var_chunks):
10+
if len(nd_backend_chunks) != len(nd_v_chunks):
1111
raise ValueError(
1212
"The number of dimensions on the backend and the variable must be the same."
1313
)
1414

1515
nd_aligned_chunks: list[tuple[int, ...]] = []
16-
for backend_chunks, var_chunks in zip(
17-
nd_backend_chunks, nd_var_chunks, strict=True
18-
):
16+
for backend_chunks, v_chunks in zip(nd_backend_chunks, nd_v_chunks, strict=True):
1917
# Validate that they have the same number of elements
20-
if sum(backend_chunks) != sum(var_chunks):
18+
if sum(backend_chunks) != sum(v_chunks):
2119
raise ValueError(
2220
"The number of elements in the backend does not "
2321
"match the number of elements in the variable. "
@@ -42,39 +40,39 @@ def align_nd_chunks(
4240
nd_aligned_chunks.append(backend_chunks)
4341
continue
4442

45-
if len(var_chunks) == 1:
46-
nd_aligned_chunks.append(var_chunks)
43+
if len(v_chunks) == 1:
44+
nd_aligned_chunks.append(v_chunks)
4745
continue
4846

4947
# Size of the chunk on the backend
5048
fixed_chunk = max(backend_chunks)
5149

5250
# The ideal size of the chunks is the maximum of the two; this would avoid
5351
# that we use more memory than expected
54-
max_chunk = max(fixed_chunk, max(var_chunks))
52+
max_chunk = max(fixed_chunk, *v_chunks)
5553

5654
# The algorithm assumes that the chunks on this array are aligned except the last one
5755
# because it can be considered a partial one
5856
aligned_chunks: list[int] = []
5957

6058
# For simplicity of the algorithm, let's transform the Array chunks in such a way that
6159
# we remove the partial chunks. To achieve this, we add artificial data to the borders
62-
t_var_chunks = list(var_chunks)
63-
t_var_chunks[0] += fixed_chunk - backend_chunks[0]
64-
t_var_chunks[-1] += fixed_chunk - backend_chunks[-1]
60+
t_v_chunks = list(v_chunks)
61+
t_v_chunks[0] += fixed_chunk - backend_chunks[0]
62+
t_v_chunks[-1] += fixed_chunk - backend_chunks[-1]
6563

6664
# The unfilled_size is the amount of space that has not been filled on the last
6765
# processed chunk; this is equivalent to the amount of data that would need to be
6866
# added to a partial Zarr chunk to fill it up to the fixed_chunk size
6967
unfilled_size = 0
7068

71-
for var_chunk in t_var_chunks:
69+
for v_chunk in t_v_chunks:
7270
# Ideally, we should try to preserve the original Dask chunks, but this is only
7371
# possible if the last processed chunk was aligned (unfilled_size == 0)
74-
ideal_chunk = var_chunk
72+
ideal_chunk = v_chunk
7573
if unfilled_size:
7674
# If that scenario is not possible, the best option is to merge the chunks
77-
ideal_chunk = var_chunk + aligned_chunks[-1]
75+
ideal_chunk = v_chunk + aligned_chunks[-1]
7876

7977
while ideal_chunk:
8078
if not unfilled_size:
@@ -105,27 +103,27 @@ def align_nd_chunks(
105103
border_size = fixed_chunk - backend_chunks[::order][0]
106104
aligned_chunks = aligned_chunks[::order]
107105
aligned_chunks[0] -= border_size
108-
t_var_chunks = t_var_chunks[::order]
109-
t_var_chunks[0] -= border_size
106+
t_v_chunks = t_v_chunks[::order]
107+
t_v_chunks[0] -= border_size
110108
if (
111109
len(aligned_chunks) >= 2
112110
and aligned_chunks[0] + aligned_chunks[1] <= max_chunk
113-
and aligned_chunks[0] != t_var_chunks[0]
111+
and aligned_chunks[0] != t_v_chunks[0]
114112
):
115113
# The artificial data added to the border can introduce inefficient chunks
116114
# on the borders, for that reason, we will check if we can merge them or not
117115
# Example:
118116
# backend_chunks = [6, 6, 1]
119-
# var_chunks = [6, 7]
120-
# t_var_chunks = [6, 12]
121-
# The ideal output should preserve the same var_chunks, but the previous loop
117+
# v_chunks = [6, 7]
118+
# t_v_chunks = [6, 12]
119+
# The ideal output should preserve the same v_chunks, but the previous loop
122120
# is going to produce aligned_chunks = [6, 6, 6]
123121
# And after removing the artificial data, we will end up with aligned_chunks = [6, 6, 1]
124122
# which is not ideal and can be merged into a single chunk
125123
aligned_chunks[1] += aligned_chunks[0]
126124
aligned_chunks = aligned_chunks[1:]
127125

128-
t_var_chunks = t_var_chunks[::order]
126+
t_v_chunks = t_v_chunks[::order]
129127
aligned_chunks = aligned_chunks[::order]
130128

131129
nd_aligned_chunks.append(tuple(aligned_chunks))
@@ -141,9 +139,14 @@ def build_grid_chunks(
141139
if region is None:
142140
region = slice(0, size)
143141

144-
region_start = region.start if region.start else 0
142+
region_start = region.start or 0
145143
# Generate the zarr chunks inside the region of this dim
146144
chunks_on_region = [chunk_size - (region_start % chunk_size)]
145+
if chunks_on_region[0] >= size:
146+
# This is useful for the scenarios where the chunk_size are bigger
147+
# than the variable chunks, which can happens when the user specifies
148+
# the enc_chunks manually.
149+
return (size,)
147150
chunks_on_region.extend([chunk_size] * ((size - chunks_on_region[0]) // chunk_size))
148151
if (size - chunks_on_region[0]) % chunk_size != 0:
149152
chunks_on_region.append((size - chunks_on_region[0]) % chunk_size)
@@ -155,45 +158,45 @@ def grid_rechunk(
155158
enc_chunks: tuple[int, ...],
156159
region: tuple[slice, ...],
157160
) -> Variable:
158-
nd_var_chunks = v.chunks
159-
if not nd_var_chunks:
161+
nd_v_chunks = v.chunks
162+
if not nd_v_chunks:
160163
return v
161164

162165
nd_grid_chunks = tuple(
163166
build_grid_chunks(
164-
sum(var_chunks),
167+
v_size,
165168
region=interval,
166169
chunk_size=chunk_size,
167170
)
168-
for var_chunks, chunk_size, interval in zip(
169-
nd_var_chunks, enc_chunks, region, strict=True
171+
for v_size, chunk_size, interval in zip(
172+
v.shape, enc_chunks, region, strict=True
170173
)
171174
)
172175

173176
nd_aligned_chunks = align_nd_chunks(
174-
nd_var_chunks=nd_var_chunks,
177+
nd_v_chunks=nd_v_chunks,
175178
nd_backend_chunks=nd_grid_chunks,
176179
)
177180
v = v.chunk(dict(zip(v.dims, nd_aligned_chunks, strict=True)))
178181
return v
179182

180183

181184
def validate_grid_chunks_alignment(
182-
nd_var_chunks: tuple[tuple[int, ...], ...] | None,
185+
nd_v_chunks: tuple[tuple[int, ...], ...] | None,
183186
enc_chunks: tuple[int, ...],
184187
backend_shape: tuple[int, ...],
185188
region: tuple[slice, ...],
186189
allow_partial_chunks: bool,
187190
name: str,
188191
):
189-
if nd_var_chunks is None:
192+
if nd_v_chunks is None:
190193
return
191194
base_error = (
192195
"Specified Zarr chunks encoding['chunks']={enc_chunks!r} for "
193196
"variable named {name!r} would overlap multiple Dask chunks. "
194-
"Check the chunk at position {var_chunk_pos}, which has a size of "
195-
"{var_chunk_size} on dimension {dim_i}. It is unaligned with "
196-
"backend chunks of size {chunk_size} in region {region}. "
197+
"Please check the Dask chunks at position {v_chunk_pos} and "
198+
"{v_chunk_pos_next}, on axis {axis}, they are overlapped "
199+
"on the same Zarr chunk in the region {region}. "
197200
"Writing this array in parallel with Dask could lead to corrupted data. "
198201
"To resolve this issue, consider one of the following options: "
199202
"- Rechunk the array using `chunk()`. "
@@ -202,58 +205,61 @@ def validate_grid_chunks_alignment(
202205
"- Enable automatic chunks alignment with `align_chunks=True`."
203206
)
204207

205-
for dim_i, chunk_size, var_chunks, interval, size in zip(
208+
for axis, chunk_size, v_chunks, interval, size in zip(
206209
range(len(enc_chunks)),
207210
enc_chunks,
208-
nd_var_chunks,
211+
nd_v_chunks,
209212
region,
210213
backend_shape,
211214
strict=True,
212215
):
213-
for i, chunk in enumerate(var_chunks[1:-1]):
216+
for i, chunk in enumerate(v_chunks[1:-1]):
214217
if chunk % chunk_size:
215218
raise ValueError(
216219
base_error.format(
217-
var_chunk_pos=i + 1,
218-
var_chunk_size=chunk,
220+
v_chunk_pos=i + 1,
221+
v_chunk_pos_next=i + 2,
222+
v_chunk_size=chunk,
223+
axis=axis,
219224
name=name,
220-
dim_i=dim_i,
221225
chunk_size=chunk_size,
222226
region=interval,
223227
enc_chunks=enc_chunks,
224228
)
225229
)
226230

227-
interval_start = interval.start if interval.start else 0
231+
interval_start = interval.start or 0
228232

229-
if len(var_chunks) > 1:
233+
if len(v_chunks) > 1:
230234
# The first border size is the amount of data that needs to be updated on the
231235
# first chunk taking into account the region slice.
232236
first_border_size = chunk_size
233237
if allow_partial_chunks:
234238
first_border_size = chunk_size - interval_start % chunk_size
235239

236-
if (var_chunks[0] - first_border_size) % chunk_size:
240+
if (v_chunks[0] - first_border_size) % chunk_size:
237241
raise ValueError(
238242
base_error.format(
239-
var_chunk_pos=0,
240-
var_chunk_size=var_chunks[0],
243+
v_chunk_pos=0,
244+
v_chunk_pos_next=0,
245+
v_chunk_size=v_chunks[0],
246+
axis=axis,
241247
name=name,
242-
dim_i=dim_i,
243248
chunk_size=chunk_size,
244249
region=interval,
245250
enc_chunks=enc_chunks,
246251
)
247252
)
248253

249254
if not allow_partial_chunks:
250-
region_stop = interval.stop if interval.stop else size
255+
region_stop = interval.stop or size
251256

252257
error_on_last_chunk = base_error.format(
253-
var_chunk_pos=len(var_chunks) - 1,
254-
var_chunk_size=var_chunks[-1],
258+
v_chunk_pos=len(v_chunks) - 1,
259+
v_chunk_pos_next=len(v_chunks) - 1,
260+
v_chunk_size=v_chunks[-1],
261+
axis=axis,
255262
name=name,
256-
dim_i=dim_i,
257263
chunk_size=chunk_size,
258264
region=interval,
259265
enc_chunks=enc_chunks,
@@ -267,7 +273,7 @@ def validate_grid_chunks_alignment(
267273
# If the region is covering the last chunk then check
268274
# if the reminder with the default chunk size
269275
# is equal to the size of the last chunk
270-
if var_chunks[-1] % chunk_size != size % chunk_size:
276+
if v_chunks[-1] % chunk_size != size % chunk_size:
271277
raise ValueError(error_on_last_chunk)
272-
elif var_chunks[-1] % chunk_size:
278+
elif v_chunks[-1] % chunk_size:
273279
raise ValueError(error_on_last_chunk)

xarray/backends/zarr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1186,7 +1186,7 @@ def set_variables(
11861186
# threads
11871187
shape = zarr_shape if zarr_shape else v.shape
11881188
validate_grid_chunks_alignment(
1189-
nd_var_chunks=v.chunks,
1189+
nd_v_chunks=v.chunks,
11901190
enc_chunks=encoding["chunks"],
11911191
region=region,
11921192
allow_partial_chunks=self._mode != "r+",

xarray/core/dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,6 +2303,7 @@ def to_zarr(
23032303
append_dim=append_dim,
23042304
region=region,
23052305
safe_chunks=safe_chunks,
2306+
align_chunks=align_chunks,
23062307
zarr_version=zarr_version,
23072308
zarr_format=zarr_format,
23082309
write_empty_chunks=write_empty_chunks,

0 commit comments

Comments
 (0)