Skip to content

Commit

Permalink
dev
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhassell committed Sep 20, 2023
1 parent ebd47b4 commit 094a60a
Show file tree
Hide file tree
Showing 9 changed files with 3,085 additions and 1,552 deletions.
4 changes: 2 additions & 2 deletions Changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ version 3.11.0
* Fix for `cf.aggregate` failures when a datum or coordinate
conversion parameter has an array value
(https://github.com/NCAS-CMS/cf-python/issues/230)
* Allow for regridding using a destination field featuring size 1 dimension(s)
(https://github.com/NCAS-CMS/cf-python/issues/250)
* Allow for regridding using a destination field featuring size 1
dimension(s) (https://github.com/NCAS-CMS/cf-python/issues/250)
* Fix bug that sometimes caused `cf.Field.autocyclic` to fail when
setting a construct that is cyclic and has a defined period
* Fix bug that sometimes caused a failure when reading PP extra data
Expand Down
2 changes: 1 addition & 1 deletion cf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@

__Conventions__ = "CF-1.11"
__date__ = "2023-??-??"
__version__ = "3.16.0"
__version__ = "3.17.0"

_requires = (
"numpy",
Expand Down
19 changes: 13 additions & 6 deletions cf/data/dask_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,18 +287,25 @@ def regrid(
a = a.reshape(non_regrid_shape + tuple(dst_shape))

n_dst_axes = len(dst_shape)

if n_src_axes == n_dst_axes:
pass
elif n_src_axes == 1 and n_dst_axes > 1:
# E.g. UGRID regridded to regular lat-lon; changes
# 'axis_order' from [0,2,1] to [0,3,1,2]
r = axis_order[-1]
axis_order = [i + n_dst_axes - 1 if i > r else i for i in axis_order]
axis_order[-1:] = range(r, r + n_dst_axes)
elif n_dst_axes == 1 and n_src_axes > 1:
raxis = axis_order[-1]
axis_order = [
i if i <= raxis else i + n_dst_axes - 1 for i in axis_order
]
axis_order[-1:] = range(raxis, raxis + n_dst_axes)
elif n_src_axes == 2 and n_dst_axes == 1:
# E.g. regular lat-lon regridded to UGRID; changes
# 'axis_order' from [0,3,2,1] to [0,2,1]
pass # TODOUGRID
# 'axis_order' from [0,2,4,5,1,3] to [0,2,3,4,1], or
# [0,2,4,5,3,1] to [0,1,3,4,2]
raxis0, raxis = axis_order[-2:]
print(axis_order)
axis_order = [i if i <= raxis else i - 1 for i in axis_order[:-1]]
print(axis_order)
else:
raise ValueError("TODOUGRID")

Expand Down
25 changes: 18 additions & 7 deletions cf/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3734,6 +3734,7 @@ def _regrid(
from .dask_regrid import regrid, regrid_weights

shape = self.shape
ndim = self.ndim
src_shape = tuple(shape[i] for i in regrid_axes)
if src_shape != operator.src_shape:
raise ValueError(
Expand All @@ -3755,12 +3756,17 @@ def _regrid(

# Define the regridded chunksizes
regridded_chunks = [] # The 'chunks' parameter to `map_blocks`
drop_axis = [] # The 'drop_axis' parameter to `map_blocks`
new_axis = [] # The 'new_axis' parameter to `map_blocks`
n = 0
for i, c in enumerate(dx.chunks):
if i in regridded_sizes:
sizes = regridded_sizes[i]
n_sizes = len(sizes)
if not n_sizes:
drop_axis.append(i)
continue

regridded_chunks.extend(sizes)
if n_sizes > 1:
new_axis.extend(range(n + 1, n + n_sizes))
Expand All @@ -3770,25 +3776,29 @@ def _regrid(

n += 1

# Update the axis identifiers.
#
# This is necessary when regridding changes the number of data
# dimensions (e.g. as happens when regridding a mesh topology
# axis to/from separate lat and lon axes).
if new_axis:
# Update the axis identifiers.
#
# This is necessary when regridding changes the number of
# data dimensions (e.g. as happens when regridding a mesh
# topology axis to separate lat and lon axes).
axes = list(self._axes)
for i in new_axis:
axes.insert(i, new_axis_identifier(tuple(axes)))

self._axes = tuple(axes)
self._axes = axes
elif drop_axis:
axes = self._axes
axes = [axes[i] for i in range(ndim) if i not in drop_axis]
self._axes = axes

# Set the output data type
if method in ("nearest_dtos", "nearest_stod"):
dst_dtype = dx.dtype
else:
dst_dtype = float

non_regrid_axes = [i for i in range(self.ndim) if i not in regrid_axes]
non_regrid_axes = [i for i in range(ndim) if i not in regrid_axes]

src_mask = operator.src_mask
if src_mask is not None:
Expand All @@ -3813,6 +3823,7 @@ def _regrid(
weights_dst_mask=weights_dst_mask,
ref_src_mask=src_mask,
chunks=regridded_chunks,
drop_axis=drop_axis,
new_axis=new_axis,
meta=np.array((), dtype=dst_dtype),
)
Expand Down
Loading

0 comments on commit 094a60a

Please sign in to comment.