Skip to content

Commit

Permalink
Implement JIT version of sampling particle.xi on specific grid number
Browse files Browse the repository at this point in the history
  • Loading branch information
erikvansebille committed Oct 7, 2023
1 parent 71246ea commit 010b6c9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
27 changes: 17 additions & 10 deletions parcels/compilation/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,14 @@ def __init__(self, obj, attr):
self.attr = attr


class ParticleXiYiZiTiAttributeNode(IntrinsicNode):
def __init__(self, obj, attr):
logger.warning_once(f"Be careful when sampling particle.{attr}, as this is updated in the kernel loop. "
"Best to place the sampling statement before advection.")
self.obj = obj.ccode
self.attr = attr


class ParticleNode(IntrinsicNode):
attr_node_class = None

Expand All @@ -191,13 +199,15 @@ def __init__(self, obj):
self.attr_node_class = attr_node_class

def __getattr__(self, attr):
if attr in ['xi', 'yi', 'zi', 'ti']:
return ParticleXiYiZiTiAttributeNode(self, attr)
if attr in [v.name for v in self.obj.variables]:
return self.attr_node_class(self, attr)
elif attr in ['delete']:
return self.attr_node_class(self, 'state')
else:
raise AttributeError(f"Particle type {self.obj} does not define attribute '{attr}.\n"
f"Please add '{attr}' to {self.obj}.users_vars or define an appropriate sub-class.")
raise AttributeError(f"Particle type {self.obj.name} does not define attribute '{attr}. "

Check warning on line 209 in parcels/compilation/codegenerator.py

View check run for this annotation

Codecov / codecov/patch

parcels/compilation/codegenerator.py#L209

Added line #L209 was not covered by tests
f"Please add '{attr}' as a Variable in {self.obj.name}.")


class IntrinsicTransformer(ast.NodeTransformer):
Expand Down Expand Up @@ -314,14 +324,6 @@ def visit_AugAssign(self, node):

def visit_Assign(self, node):
node.targets = [self.visit(t) for t in node.targets]
if (isinstance(node.targets[0], ParticleAttributeNode) and hasattr(node.value, 'value')
and hasattr(node.value.value, 'attr') and node.value.value.attr in ['xi', 'yi', 'zi']):
node.value = node.value.value
ngridstr = []
if self.fieldset.gridset.size > 1:
ngridstr = "Also be careful that particle.xi is not well-defined in JIT mode when using multiple grids."
logger.warning_once(f"Be careful when sampling particle.{node.value.attr}, as this is updated in the kernel loop. "
f"Best to place the sampling statement before advection. {ngridstr}")
node.value = self.visit(node.value)
stmts = [node]

Expand Down Expand Up @@ -603,6 +605,9 @@ def visit_Assign(self, node):
tmp_node = tmp_node.elts[0]
node.ccode = c.Initializer(decl, node.value.ccode)
self.array_vars += [node.targets[0].id]
elif isinstance(node.value, ParticleXiYiZiTiAttributeNode):
raise RuntimeError(f"Add index of the grid when using particle.{node.value.attr} "

Check warning on line 609 in parcels/compilation/codegenerator.py

View check run for this annotation

Codecov / codecov/patch

parcels/compilation/codegenerator.py#L609

Added line #L609 was not covered by tests
f"(e.g. particle.{node.value.attr}[0]).")
else:
node.ccode = c.Assign(node.targets[0].ccode, node.value.ccode)

Expand Down Expand Up @@ -654,6 +659,8 @@ def visit_Subscript(self, node):
self.visit(node.slice)
if isinstance(node.value, FieldNode) or isinstance(node.value, VectorFieldNode):
node.ccode = node.value.__getitem__(node.slice.ccode).ccode
elif isinstance(node.value, ParticleXiYiZiTiAttributeNode):
node.ccode = f"{node.value.obj}->{node.value.attr}[pnum, {node.slice.ccode}]"
elif isinstance(node.value, IntrinsicNode):
raise NotImplementedError(f"Subscript not implemented for object type {type(node.value).__name__}")
else:
Expand Down
23 changes: 16 additions & 7 deletions tests/test_particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,12 @@ def Update_lon(particle, fieldset, time):
def test_write_xiyi(fieldset, mode, tmpdir):
outfilepath = tmpdir.join("pfile_xiyi.zarr")
fieldset.U.data[:] = 1 # set a non-zero zonal velocity
fieldset.add_field(Field(name='P', data=np.zeros((2, 2)), lon=[0, 1], lat=[0, 2]))
fieldset.add_field(Field(name='P', data=np.zeros((2, 20)), lon=np.linspace(0, 1, 20), lat=[0, 2]))
dt = 3600

class XiYiParticle(ptype[mode]):
pxi = Variable('pxi', dtype=np.int32, initial=0.)
pxi0 = Variable('pxi0', dtype=np.int32, initial=0.)
pxi1 = Variable('pxi1', dtype=np.int32, initial=0.)
pyi = Variable('pyi', dtype=np.int32, initial=0.)

def Get_XiYi(particle, fieldset, time):
Expand All @@ -280,21 +281,29 @@ def Get_XiYi(particle, fieldset, time):
and that the first outputted value is zero.
Be careful when using multiple grids, as the index may be different for the grids.
"""
particle.pxi = particle.xi[0]
particle.pxi0 = particle.xi[0]
particle.pxi1 = particle.xi[1]
particle.pyi = particle.yi[0]

Check warning on line 286 in tests/test_particlefile.py

View check run for this annotation

Codecov / codecov/patch

tests/test_particlefile.py#L284-L286

Added lines #L284 - L286 were not covered by tests

def SampleP(particle, fieldset, time):
if time > 5*3600:
tmp = fieldset.P[particle] # noqa

Check warning on line 290 in tests/test_particlefile.py

View check run for this annotation

Codecov / codecov/patch

tests/test_particlefile.py#L289-L290

Added lines #L289 - L290 were not covered by tests

pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0], lat=[0.2], lonlatdepth_dtype=np.float64)
pfile = pset.ParticleFile(name=outfilepath, outputdt=dt)
pset.execute([Get_XiYi, AdvectionRK4], endtime=10*dt, dt=dt, output_file=pfile)
pset.execute([Get_XiYi, SampleP, AdvectionRK4], endtime=10*dt, dt=dt, output_file=pfile)

ds = xr.open_zarr(outfilepath)
pxi = ds['pxi'][:].values[0].astype(np.int32)
pxi0 = ds['pxi0'][:].values[0].astype(np.int32)
pxi1 = ds['pxi1'][:].values[0].astype(np.int32)
lons = ds['lon'][:].values[0]
pyi = ds['pyi'][:].values[0].astype(np.int32)
lats = ds['lat'][:].values[0]

assert (pxi[0] == 0) and (pxi[-1] == 11) # check that particle has moved
for xi, lon in zip(pxi[1:], lons[1:]):
assert (pxi0[0] == 0) and (pxi0[-1] == 11) # check that particle has moved
assert np.all(pxi1[:7] == 0) # check that particle has not been sampled on grid 1 until time 6
assert np.all(pxi1[7:] > 0) # check that particle has not been sampled on grid 1 after time 6
for xi, lon in zip(pxi0[1:], lons[1:]):
assert fieldset.U.grid.lon[xi] <= lon < fieldset.U.grid.lon[xi+1]
for yi, lat in zip(pyi[1:], lats[1:]):
assert fieldset.U.grid.lat[yi] <= lat < fieldset.U.grid.lat[yi+1]
Expand Down

0 comments on commit 010b6c9

Please sign in to comment.