From 010b6c98ac68c3a1541a66caec51180c7d2d31e8 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Sat, 7 Oct 2023 12:24:37 -0300 Subject: [PATCH] Implement JIT version of sampling particle.xi on specific grid number --- parcels/compilation/codegenerator.py | 27 +++++++++++++++++---------- tests/test_particlefile.py | 23 ++++++++++++++++------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/parcels/compilation/codegenerator.py b/parcels/compilation/codegenerator.py index b3b63a4e6..0eba772b1 100644 --- a/parcels/compilation/codegenerator.py +++ b/parcels/compilation/codegenerator.py @@ -181,6 +181,14 @@ def __init__(self, obj, attr): self.attr = attr +class ParticleXiYiZiTiAttributeNode(IntrinsicNode): + def __init__(self, obj, attr): + logger.warning_once(f"Be careful when sampling particle.{attr}, as this is updated in the kernel loop. " + "Best to place the sampling statement before advection.") + self.obj = obj.ccode + self.attr = attr + + class ParticleNode(IntrinsicNode): attr_node_class = None @@ -191,13 +199,15 @@ def __init__(self, obj): self.attr_node_class = attr_node_class def __getattr__(self, attr): + if attr in ['xi', 'yi', 'zi', 'ti']: + return ParticleXiYiZiTiAttributeNode(self, attr) if attr in [v.name for v in self.obj.variables]: return self.attr_node_class(self, attr) elif attr in ['delete']: return self.attr_node_class(self, 'state') else: - raise AttributeError(f"Particle type {self.obj} does not define attribute '{attr}.\n" - f"Please add '{attr}' to {self.obj}.users_vars or define an appropriate sub-class.") + raise AttributeError(f"Particle type {self.obj.name} does not define attribute '{attr}. " + f"Please add '{attr}' as a Variable in {self.obj.name}.") class IntrinsicTransformer(ast.NodeTransformer): @@ -314,14 +324,6 @@ def visit_AugAssign(self, node): def visit_Assign(self, node): node.targets = [self.visit(t) for t in node.targets] - if (isinstance(node.targets[0], ParticleAttributeNode) and hasattr(node.value, 'value') - and hasattr(node.value.value, 'attr') and node.value.value.attr in ['xi', 'yi', 'zi']): - node.value = node.value.value - ngridstr = [] - if self.fieldset.gridset.size > 1: - ngridstr = "Also be careful that particle.xi is not well-defined in JIT mode when using multiple grids." - logger.warning_once(f"Be careful when sampling particle.{node.value.attr}, as this is updated in the kernel loop. " - f"Best to place the sampling statement before advection. {ngridstr}") node.value = self.visit(node.value) stmts = [node] @@ -603,6 +605,9 @@ def visit_Assign(self, node): tmp_node = tmp_node.elts[0] node.ccode = c.Initializer(decl, node.value.ccode) self.array_vars += [node.targets[0].id] + elif isinstance(node.value, ParticleXiYiZiTiAttributeNode): + raise RuntimeError(f"Add index of the grid when using particle.{node.value.attr} " + f"(e.g. particle.{node.value.attr}[0]).") else: node.ccode = c.Assign(node.targets[0].ccode, node.value.ccode) @@ -654,6 +659,8 @@ def visit_Subscript(self, node): self.visit(node.slice) if isinstance(node.value, FieldNode) or isinstance(node.value, VectorFieldNode): node.ccode = node.value.__getitem__(node.slice.ccode).ccode + elif isinstance(node.value, ParticleXiYiZiTiAttributeNode): + node.ccode = f"{node.value.obj}->{node.value.attr}[pnum, {node.slice.ccode}]" elif isinstance(node.value, IntrinsicNode): raise NotImplementedError(f"Subscript not implemented for object type {type(node.value).__name__}") else: diff --git a/tests/test_particlefile.py b/tests/test_particlefile.py index f5600f43f..e4d708592 100644 --- a/tests/test_particlefile.py +++ b/tests/test_particlefile.py @@ -267,11 +267,12 @@ def Update_lon(particle, fieldset, time): def test_write_xiyi(fieldset, mode, tmpdir): outfilepath = tmpdir.join("pfile_xiyi.zarr") fieldset.U.data[:] = 1 # set a non-zero zonal velocity - fieldset.add_field(Field(name='P', data=np.zeros((2, 2)), lon=[0, 1], lat=[0, 2])) + fieldset.add_field(Field(name='P', data=np.zeros((2, 20)), lon=np.linspace(0, 1, 20), lat=[0, 2])) dt = 3600 class XiYiParticle(ptype[mode]): - pxi = Variable('pxi', dtype=np.int32, initial=0.) + pxi0 = Variable('pxi0', dtype=np.int32, initial=0.) + pxi1 = Variable('pxi1', dtype=np.int32, initial=0.) pyi = Variable('pyi', dtype=np.int32, initial=0.) def Get_XiYi(particle, fieldset, time): @@ -280,21 +281,29 @@ def Get_XiYi(particle, fieldset, time): and that the first outputted value is zero. Be careful when using multiple grids, as the index may be different for the grids. """ - particle.pxi = particle.xi[0] + particle.pxi0 = particle.xi[0] + particle.pxi1 = particle.xi[1] particle.pyi = particle.yi[0] + def SampleP(particle, fieldset, time): + if time > 5*3600: + tmp = fieldset.P[particle] # noqa + pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0], lat=[0.2], lonlatdepth_dtype=np.float64) pfile = pset.ParticleFile(name=outfilepath, outputdt=dt) - pset.execute([Get_XiYi, AdvectionRK4], endtime=10*dt, dt=dt, output_file=pfile) + pset.execute([Get_XiYi, SampleP, AdvectionRK4], endtime=10*dt, dt=dt, output_file=pfile) ds = xr.open_zarr(outfilepath) - pxi = ds['pxi'][:].values[0].astype(np.int32) + pxi0 = ds['pxi0'][:].values[0].astype(np.int32) + pxi1 = ds['pxi1'][:].values[0].astype(np.int32) lons = ds['lon'][:].values[0] pyi = ds['pyi'][:].values[0].astype(np.int32) lats = ds['lat'][:].values[0] - assert (pxi[0] == 0) and (pxi[-1] == 11) # check that particle has moved - for xi, lon in zip(pxi[1:], lons[1:]): + assert (pxi0[0] == 0) and (pxi0[-1] == 11) # check that particle has moved + assert np.all(pxi1[:7] == 0) # check that particle has not been sampled on grid 1 until time 6 + assert np.all(pxi1[7:] > 0) # check that particle has not been sampled on grid 1 after time 6 + for xi, lon in zip(pxi0[1:], lons[1:]): assert fieldset.U.grid.lon[xi] <= lon < fieldset.U.grid.lon[xi+1] for yi, lat in zip(pyi[1:], lats[1:]): assert fieldset.U.grid.lat[yi] <= lat < fieldset.U.grid.lat[yi+1]