diff --git a/docs/examples/documentation_stuck_particles.ipynb b/docs/examples/documentation_stuck_particles.ipynb
index e21cad50dd..681748e7e8 100644
--- a/docs/examples/documentation_stuck_particles.ipynb
+++ b/docs/examples/documentation_stuck_particles.ipynb
@@ -82,7 +82,6 @@
" FieldSet,\n",
" JITParticle,\n",
" ParticleSet,\n",
- " Variable,\n",
" download_example_dataset,\n",
")"
]
@@ -1222,8 +1221,7 @@
},
"outputs": [],
"source": [
- "class LandParticle(JITParticle):\n",
- " on_land = Variable(\"on_land\")\n",
+ "LandParticle = JITParticle.add_variable(\"on_land\")\n",
"\n",
"\n",
"def Sample_land(particle, fieldset, time):\n",
diff --git a/docs/examples/documentation_unstuck_Agrid.ipynb b/docs/examples/documentation_unstuck_Agrid.ipynb
index 7aa21def6f..dc63f8bab9 100644
--- a/docs/examples/documentation_unstuck_Agrid.ipynb
+++ b/docs/examples/documentation_unstuck_Agrid.ipynb
@@ -809,10 +809,13 @@
"metadata": {},
"outputs": [],
"source": [
- "class DisplacementParticle(JITParticle):\n",
- " dU = Variable(\"dU\")\n",
- " dV = Variable(\"dV\")\n",
- " d2s = Variable(\"d2s\", initial=1e3)\n",
+ "DisplacementParticle = JITParticle.add_variables(\n",
+ " [\n",
+ " Variable(\"dU\"),\n",
+ " Variable(\"dV\"),\n",
+ " Variable(\"d2s\", initial=1e3),\n",
+ " ]\n",
+ ")\n",
"\n",
"\n",
"def set_displacement(particle, fieldset, time):\n",
diff --git a/docs/examples/example_globcurrent.py b/docs/examples/example_globcurrent.py
index 4f53d6af0a..a705bf193f 100755
--- a/docs/examples/example_globcurrent.py
+++ b/docs/examples/example_globcurrent.py
@@ -13,7 +13,6 @@
ParticleSet,
ScipyParticle,
TimeExtrapolationError,
- Variable,
download_example_dataset,
)
@@ -99,8 +98,7 @@ def test_globcurrent_time_periodic(mode, rundays):
for deferred_load in [True, False]:
fieldset = set_globcurrent_fieldset(time_periodic=delta(days=365), deferred_load=deferred_load)
- class MyParticle(ptype[mode]):
- sample_var = Variable('sample_var', initial=0.)
+ MyParticle = ptype[mode].add_variable('sample_var', initial=0.)
pset = ParticleSet(fieldset, pclass=MyParticle, lon=25, lat=-35, time=fieldset.U.grid.time[0])
@@ -194,8 +192,7 @@ def test_globcurrent_startparticles_between_time_arrays(mode, dt, with_starttime
fieldset.add_field(Field.from_netcdf(fnamesFeb, ('P', 'eastward_eulerian_current_velocity'),
{'lat': 'lat', 'lon': 'lon', 'time': 'time'}))
- class MyParticle(ptype[mode]):
- sample_var = Variable('sample_var', initial=0.)
+ MyParticle = ptype[mode].add_variable('sample_var', initial=0.)
def SampleP(particle, fieldset, time):
particle.sample_var += fieldset.P[time, particle.depth, particle.lat, particle.lon]
diff --git a/docs/examples/example_peninsula.py b/docs/examples/example_peninsula.py
index 1d01845840..a96479d025 100644
--- a/docs/examples/example_peninsula.py
+++ b/docs/examples/example_peninsula.py
@@ -130,11 +130,8 @@ def peninsula_example(fieldset, outfile, npart, mode='jit', degree=1,
# First, we define a custom Particle class to which we add a
# custom variable, the initial stream function value p.
# We determine the particle base class according to mode.
- class MyParticle(ptype[mode]):
- # JIT compilation requires a-priori knowledge of the particle
- # data structure, so we define additional variables here.
- p = Variable('p', dtype=np.float32, initial=0.)
- p_start = Variable('p_start', dtype=np.float32, initial=0)
+ MyParticle = ptype[mode].add_variable([Variable('p', dtype=np.float32, initial=0.),
+ Variable('p_start', dtype=np.float32, initial=0)])
# Initialise particles
if fieldset.U.grid.mesh == 'flat':
diff --git a/docs/examples/example_stommel.py b/docs/examples/example_stommel.py
index 47f9b69a7c..052987638c 100755
--- a/docs/examples/example_stommel.py
+++ b/docs/examples/example_stommel.py
@@ -105,11 +105,11 @@ def stommel_example(npart=1, mode='jit', verbose=False, method=AdvectionRK4, gri
dt = delta(hours=1)
outputdt = delta(days=5)
- class MyParticle(ParticleClass):
- p = Variable('p', dtype=np.float32, initial=0.)
- p_start = Variable('p_start', dtype=np.float32, initial=0.)
- next_dt = Variable('next_dt', dtype=np.float64, initial=dt.total_seconds())
- age = Variable('age', dtype=np.float32, initial=0.)
+ extra_vars = [Variable('p', dtype=np.float32, initial=0.),
+ Variable('p_start', dtype=np.float32, initial=0.),
+ Variable('next_dt', dtype=np.float64, initial=dt.total_seconds()),
+ Variable('age', dtype=np.float32, initial=0.)]
+ MyParticle = ParticleClass.add_variables(extra_vars)
if custom_partition_function:
pset = ParticleSet.from_line(fieldset, size=npart, pclass=MyParticle, repeatdt=repeatdt,
diff --git a/docs/examples/parcels_tutorial.ipynb b/docs/examples/parcels_tutorial.ipynb
index cbd6fcc7e5..78d7eb52e8 100644
--- a/docs/examples/parcels_tutorial.ipynb
+++ b/docs/examples/parcels_tutorial.ipynb
@@ -228,7 +228,7 @@
"output_type": "stream",
"text": [
"INFO: Output files are stored in EddyParticles.zarr.\n",
- "100%|██████████| 518400.0/518400.0 [00:02<00:00, 187445.76it/s]\n"
+ "100%|██████████| 518400.0/518400.0 [00:03<00:00, 172443.78it/s]\n"
]
}
],
@@ -569,42 +569,42 @@
"\n",
"\n",
"
\n",
- "
![]()
\n",
+ "
![]()
\n",
"
\n",
- "
\n",
+ " oninput=\"anim39fc9ac9ecc54c368c0d0e047c7673e6.set_frame(parseInt(this.value));\">\n",
"
\n",
- "
\n",
- "
\n",
"
\n",
"
\n",
@@ -614,9 +614,9 @@
" /* Instantiate the Animation class. */\n",
" /* The IDs given should match those used in the template above. */\n",
" (function() {\n",
- " var img_id = \"_anim_img8bfbea7abfae4cbf885a957f3f62a93b\";\n",
- " var slider_id = \"_anim_slider8bfbea7abfae4cbf885a957f3f62a93b\";\n",
- " var loop_select_id = \"_anim_loop_select8bfbea7abfae4cbf885a957f3f62a93b\";\n",
+ " var img_id = \"_anim_img39fc9ac9ecc54c368c0d0e047c7673e6\";\n",
+ " var slider_id = \"_anim_slider39fc9ac9ecc54c368c0d0e047c7673e6\";\n",
+ " var loop_select_id = \"_anim_loop_select39fc9ac9ecc54c368c0d0e047c7673e6\";\n",
" var frames = new Array(29);\n",
" \n",
" frames[0] = \"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAArwAAAH0CAYAAADfWf7fAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\\\n",
@@ -13156,7 +13156,7 @@
" /* set a timeout to make sure all the above elements are created before\n",
" the object is initialized. */\n",
" setTimeout(function() {\n",
- " anim8bfbea7abfae4cbf885a957f3f62a93b = new Animation(frames, img_id, slider_id, 100.0,\n",
+ " anim39fc9ac9ecc54c368c0d0e047c7673e6 = new Animation(frames, img_id, slider_id, 100.0,\n",
" loop_select_id);\n",
" }, 0);\n",
" })()\n",
@@ -13201,7 +13201,7 @@
"output_type": "stream",
"text": [
"INFO: Output files are stored in EddyParticles_Bwd.zarr.\n",
- "100%|██████████| 518400.0/518400.0 [00:02<00:00, 188464.76it/s]\n"
+ "100%|██████████| 518400.0/518400.0 [00:02<00:00, 176426.86it/s]\n"
]
}
],
@@ -13323,7 +13323,7 @@
"output_type": "stream",
"text": [
"INFO: Output files are stored in EddyParticles_WestVel.zarr.\n",
- "100%|██████████| 172800.0/172800.0 [00:00<00:00, 180626.58it/s]\n"
+ "100%|██████████| 172800.0/172800.0 [00:00<00:00, 179532.85it/s]\n"
]
}
],
@@ -13503,7 +13503,7 @@
"output_type": "stream",
"text": [
"INFO: Output files are stored in GlobCurrentParticles.zarr.\n",
- "100%|██████████| 864000.0/864000.0 [00:00<00:00, 1104294.63it/s]\n"
+ "100%|██████████| 864000.0/864000.0 [00:00<00:00, 1072517.72it/s]\n"
]
}
],
@@ -13592,7 +13592,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Now define a new `Particle` class that has an extra `Variable`: the pressure. We initialise this by sampling the `fieldset.P` field.\n"
+ "Now define a new `Particle` class that has an extra `Variable`: the pressure. This `particle.p` can be used to store the values of the `fieldset.P` field at the particle locations.\n"
]
},
{
@@ -13601,11 +13601,7 @@
"metadata": {},
"outputs": [],
"source": [
- "class SampleParticle(JITParticle):\n",
- " \"\"\"Define a new particle class with variable 'p'\n",
- " initialised by sampling the pressure\"\"\"\n",
- "\n",
- " p = Variable(\"p\")"
+ "SampleParticle = JITParticle.add_variable(\"p\")"
]
},
{
@@ -13688,7 +13684,7 @@
"output_type": "stream",
"text": [
"INFO: Output files are stored in PeninsulaPressure.zarr.\n",
- "100%|██████████| 72000.0/72000.0 [00:00<00:00, 157840.98it/s]\n"
+ "100%|██████████| 72000.0/72000.0 [00:00<00:00, 143717.52it/s]\n"
]
}
],
@@ -13768,7 +13764,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "First, we need to create a new `Particle` class that includes three extra variables. The `distance` variable will be written to output, but the auxiliary variables `prev_lon` and `prev_lat` won't be written to output (can be controlled using the `to_write` keyword)\n"
+ "First, we need to add three extra variables to the Particle Class. The `distance` variable will be written to output, but the auxiliary variables `prev_lon` and `prev_lat` won't be written to output (can be controlled using the `to_write` keyword)."
]
},
{
@@ -13777,19 +13773,13 @@
"metadata": {},
"outputs": [],
"source": [
- "class DistParticle(JITParticle):\n",
- " \"\"\"Define a new particle class that contains three extra variables\"\"\"\n",
- "\n",
- " # the distance travelled by the particle\n",
- " distance = Variable(\"distance\", initial=0.0, dtype=np.float32)\n",
+ "extra_vars = [\n",
+ " Variable(\"distance\", initial=0.0, dtype=np.float32),\n",
+ " Variable(\"prev_lon\", dtype=np.float32, to_write=False, initial=attrgetter(\"lon\")),\n",
+ " Variable(\"prev_lat\", dtype=np.float32, to_write=False, initial=attrgetter(\"lat\")),\n",
+ "]\n",
"\n",
- " # the previous longitude and latitude of the particle\n",
- " prev_lon = Variable(\n",
- " \"prev_lon\", dtype=np.float32, to_write=False, initial=attrgetter(\"lon\")\n",
- " )\n",
- " prev_lat = Variable(\n",
- " \"prev_lat\", dtype=np.float32, to_write=False, initial=attrgetter(\"lat\")\n",
- " )"
+ "DistParticle = JITParticle.add_variables(extra_vars)"
]
},
{
@@ -13879,7 +13869,7 @@
"output_type": "stream",
"text": [
"INFO: Output files are stored in GlobCurrentParticles_Dist.zarr.\n",
- "100%|██████████| 518400.0/518400.0 [00:03<00:00, 147724.13it/s]\n"
+ "100%|██████████| 518400.0/518400.0 [00:03<00:00, 136275.28it/s]\n"
]
}
],
@@ -13936,7 +13926,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.9"
+ "version": "3.11.6"
}
},
"nbformat": 4,
diff --git a/docs/examples/tutorial_Argofloats.ipynb b/docs/examples/tutorial_Argofloats.ipynb
index 3a458c7359..f672e604b8 100644
--- a/docs/examples/tutorial_Argofloats.ipynb
+++ b/docs/examples/tutorial_Argofloats.ipynb
@@ -128,19 +128,21 @@
"\n",
"\n",
"# Define a new Particle type including extra Variables\n",
- "class ArgoParticle(JITParticle):\n",
- " # Phase of cycle:\n",
- " # init_descend=0,\n",
- " # drift=1,\n",
- " # profile_descend=2,\n",
- " # profile_ascend=3,\n",
- " # transmit=4\n",
- " cycle_phase = Variable(\"cycle_phase\", dtype=np.int32, initial=0.0)\n",
- " cycle_age = Variable(\"cycle_age\", dtype=np.float32, initial=0.0)\n",
- " drift_age = Variable(\"drift_age\", dtype=np.float32, initial=0.0)\n",
- " # if fieldset has temperature\n",
- " # temp = Variable('temp', dtype=np.float32, initial=np.nan)\n",
- "\n",
+ "ArgoParticle = JITParticle.add_variables(\n",
+ " [\n",
+ " # Phase of cycle:\n",
+ " # init_descend=0,\n",
+ " # drift=1,\n",
+ " # profile_descend=2,\n",
+ " # profile_ascend=3,\n",
+ " # transmit=4\n",
+ " Variable(\"cycle_phase\", dtype=np.int32, initial=0.0),\n",
+ " Variable(\"cycle_age\", dtype=np.float32, initial=0.0),\n",
+ " Variable(\"drift_age\", dtype=np.float32, initial=0.0),\n",
+ " # if fieldset has temperature\n",
+ " # Variable('temp', dtype=np.float32, initial=np.nan),\n",
+ " ]\n",
+ ")\n",
"\n",
"# Initiate one Argo float in the Agulhas Current\n",
"pset = ParticleSet(\n",
diff --git a/docs/examples/tutorial_NestedFields.ipynb b/docs/examples/tutorial_NestedFields.ipynb
index e9b1983c03..664e01365b 100644
--- a/docs/examples/tutorial_NestedFields.ipynb
+++ b/docs/examples/tutorial_NestedFields.ipynb
@@ -226,16 +226,11 @@
}
],
"source": [
- "from parcels import Variable\n",
- "\n",
- "\n",
"def SampleNestedFieldIndex(particle, fieldset, time):\n",
" particle.f = fieldset.F[time, particle.depth, particle.lat, particle.lon]\n",
"\n",
"\n",
- "class SampleParticle(JITParticle):\n",
- " f = Variable(\"f\", dtype=np.int32)\n",
- "\n",
+ "SampleParticle = JITParticle.add_variable(\"f\", dtype=np.int32)\n",
"\n",
"pset = ParticleSet(fieldset, pclass=SampleParticle, lon=[1000], lat=[500])\n",
"pset.execute(SampleNestedFieldIndex, runtime=1)\n",
diff --git a/docs/examples/tutorial_analyticaladvection.ipynb b/docs/examples/tutorial_analyticaladvection.ipynb
index 8c25d8461e..b3b67496f3 100644
--- a/docs/examples/tutorial_analyticaladvection.ipynb
+++ b/docs/examples/tutorial_analyticaladvection.ipynb
@@ -154,9 +154,12 @@
" particle.radius = fieldset.R[time, particle.depth, particle.lat, particle.lon]\n",
"\n",
"\n",
- "class MyParticle(ScipyParticle):\n",
- " radius = Variable(\"radius\", dtype=np.float32, initial=0.0)\n",
- " radius_start = Variable(\"radius_start\", dtype=np.float32, initial=0.0)\n",
+ "MyParticle = ScipyParticle.add_variables(\n",
+ " [\n",
+ " Variable(\"radius\", dtype=np.float32, initial=0.0),\n",
+ " Variable(\"radius_start\", dtype=np.float32, initial=0.0),\n",
+ " ]\n",
+ ")\n",
"\n",
"\n",
"pset = ParticleSet(fieldsetRR, pclass=MyParticle, lon=0, lat=4e3, time=0)\n",
diff --git a/docs/examples/tutorial_delaystart.ipynb b/docs/examples/tutorial_delaystart.ipynb
index c4f3cc7259..68d0bce87e 100644
--- a/docs/examples/tutorial_delaystart.ipynb
+++ b/docs/examples/tutorial_delaystart.ipynb
@@ -32416,10 +32416,13 @@
}
],
"source": [
- "class GrowingParticle(JITParticle):\n",
- " mass = Variable(\"mass\", initial=0)\n",
- " splittime = Variable(\"splittime\", initial=-1)\n",
- " splitmass = Variable(\"splitmass\", initial=0)\n",
+ "GrowingParticle = JITParticle.add_variables(\n",
+ " [\n",
+ " Variable(\"mass\", initial=0),\n",
+ " Variable(\"splittime\", initial=-1),\n",
+ " Variable(\"splitmass\", initial=0),\n",
+ " ]\n",
+ ")\n",
"\n",
"\n",
"def GrowParticles(particle, fieldset, time):\n",
diff --git a/docs/examples/tutorial_interaction.ipynb b/docs/examples/tutorial_interaction.ipynb
index 23faff0997..2d6d09f814 100644
--- a/docs/examples/tutorial_interaction.ipynb
+++ b/docs/examples/tutorial_interaction.ipynb
@@ -152,8 +152,9 @@
"\n",
"# Create custom particle class with extra variable that indicates\n",
"# whether the interaction kernel should be executed on this particle.\n",
- "class InteractingParticle(ScipyParticle):\n",
- " attractor = Variable(\"attractor\", dtype=np.bool_, to_write=\"once\")\n",
+ "InteractingParticle = ScipyParticle.add_variable(\n",
+ " \"attractor\", dtype=np.bool_, to_write=\"once\"\n",
+ ")\n",
"\n",
"\n",
"attractor = [\n",
@@ -30838,10 +30839,12 @@
"\n",
"# Create custom InteractionParticle class\n",
"# with extra variables nearest_neighbor and mass\n",
- "class MergeParticle(ScipyInteractionParticle):\n",
- " nearest_neighbor = Variable(\"nearest_neighbor\", dtype=np.int64, to_write=False)\n",
- " mass = Variable(\"mass\", initial=1, dtype=np.float32)\n",
- "\n",
+ "MergeParticle = ScipyInteractionParticle.add_variables(\n",
+ " [\n",
+ " Variable(\"nearest_neighbor\", dtype=np.int64, to_write=False),\n",
+ " Variable(\"mass\", initial=1, dtype=np.float32),\n",
+ " ]\n",
+ ")\n",
"\n",
"pset = ParticleSet(\n",
" fieldset=fieldset,\n",
diff --git a/docs/examples/tutorial_interpolation.ipynb b/docs/examples/tutorial_interpolation.ipynb
index b23666a0c7..421bfc830e 100644
--- a/docs/examples/tutorial_interpolation.ipynb
+++ b/docs/examples/tutorial_interpolation.ipynb
@@ -28,7 +28,7 @@
"import numpy as np\n",
"from matplotlib import cm\n",
"\n",
- "from parcels import FieldSet, JITParticle, ParticleSet, Variable"
+ "from parcels import FieldSet, JITParticle, ParticleSet"
]
},
{
@@ -74,8 +74,7 @@
"metadata": {},
"outputs": [],
"source": [
- "class SampleParticle(JITParticle):\n",
- " p = Variable(\"p\", dtype=np.float32)\n",
+ "SampleParticle = JITParticle.add_variable(\"p\", dtype=np.float32)\n",
"\n",
"\n",
"def SampleP(particle, fieldset, time):\n",
diff --git a/docs/examples/tutorial_parcels_structure.ipynb b/docs/examples/tutorial_parcels_structure.ipynb
index 94ac9d9c0c..1b98374be1 100644
--- a/docs/examples/tutorial_parcels_structure.ipynb
+++ b/docs/examples/tutorial_parcels_structure.ipynb
@@ -201,12 +201,8 @@
"source": [
"from parcels import JITParticle, ParticleSet, Variable\n",
"\n",
- "# Define a new particle class\n",
- "\n",
- "\n",
- "class AgeParticle(JITParticle): # It is a JIT particle\n",
- " age = Variable(\"age\", initial=0) # Variable 'age' is added with initial value 0.\n",
- "\n",
+ "# Define a new particleclass with Variable 'age' with initial value 0.\n",
+ "AgeParticle = JITParticle.add_variable(Variable(\"age\", initial=0))\n",
"\n",
"pset = ParticleSet(\n",
" fieldset=fieldset, # the fields that the particleset uses\n",
diff --git a/docs/examples/tutorial_particle_field_interaction.ipynb b/docs/examples/tutorial_particle_field_interaction.ipynb
index 5f8dfa9cd1..143837ba91 100644
--- a/docs/examples/tutorial_particle_field_interaction.ipynb
+++ b/docs/examples/tutorial_particle_field_interaction.ipynb
@@ -56,7 +56,6 @@
" FieldSet,\n",
" ParticleSet,\n",
" ScipyParticle,\n",
- " Variable,\n",
" download_example_dataset,\n",
")"
]
@@ -160,10 +159,7 @@
"metadata": {},
"outputs": [],
"source": [
- "class VectorParticle(ScipyParticle):\n",
- " \"\"\"initialise particle concentration c with a non-zero value\"\"\"\n",
- "\n",
- " c = Variable(\"c\", dtype=np.float32, initial=100.0)\n",
+ "VectorParticle = ScipyParticle.add_variable(\"c\", dtype=np.float32, initial=100.0)\n",
"\n",
"\n",
"def Interaction(particle, fieldset, time):\n",
diff --git a/docs/examples/tutorial_sampling.ipynb b/docs/examples/tutorial_sampling.ipynb
index d702feac23..b2889f0869 100644
--- a/docs/examples/tutorial_sampling.ipynb
+++ b/docs/examples/tutorial_sampling.ipynb
@@ -125,9 +125,7 @@
"metadata": {},
"outputs": [],
"source": [
- "class SampleParticle(JITParticle): # Define a new particle class\n",
- " temperature = Variable(\"temperature\")\n",
- "\n",
+ "SampleParticle = JITParticle.add_variable(\"temperature\")\n",
"\n",
"pset = ParticleSet(\n",
" fieldset=fieldset, pclass=SampleParticle, lon=lon, lat=lat, time=time\n",
@@ -307,9 +305,12 @@
"metadata": {},
"outputs": [],
"source": [
- "class SampleParticle(JITParticle):\n",
- " U = Variable(\"U\", dtype=np.float32, initial=np.nan)\n",
- " V = Variable(\"V\", dtype=np.float32, initial=np.nan)\n",
+ "SampleParticle = JITParticle.add_variables(\n",
+ " [\n",
+ " Variable(\"U\", dtype=np.float32, initial=np.nan),\n",
+ " Variable(\"V\", dtype=np.float32, initial=np.nan),\n",
+ " ]\n",
+ ")\n",
"\n",
"\n",
"def SampleVel_correct(particle, fieldset, time):\n",
@@ -340,12 +341,7 @@
"metadata": {},
"outputs": [],
"source": [
- "class SampleParticleOnce(JITParticle):\n",
- " \"\"\"Define a new particle class with Variable 'temperature'\n",
- " initially zero and only written once\"\"\"\n",
- "\n",
- " temperature = Variable(\"temperature\", initial=0, to_write=\"once\")\n",
- "\n",
+ "SampleParticleOnce = JITParticle.add_variable(\"temperature\", initial=0, to_write=\"once\")\n",
"\n",
"pset = ParticleSet(\n",
" fieldset=fieldset, pclass=SampleParticleOnce, lon=lon, lat=lat, time=time\n",
diff --git a/parcels/particle.py b/parcels/particle.py
index 0e4765f343..769dda2aba 100644
--- a/parcels/particle.py
+++ b/parcels/particle.py
@@ -127,43 +127,7 @@ def supported_dtypes(self):
return [np.int32, np.uint32, np.int64, np.uint64, np.float32, np.double, np.float64, c_void_p]
-class _Particle:
- """Private base class for all particle types."""
-
- lastID = 0 # class-level variable keeping track of last Particle ID used
-
- def __init__(self):
- ptype = self.getPType()
- # Explicit initialisation of all particle variables
- for v in ptype.variables:
- if isinstance(v.initial, attrgetter):
- initial = v.initial(self)
- else:
- initial = v.initial
- # Enforce type of initial value
- if v.dtype != c_void_p:
- setattr(self, v.name, v.dtype(initial))
-
- # Placeholder for explicit error handling
- self.exception = None
-
- def __del__(self):
- pass # superclass is 'object', and object itself has no destructor, hence 'pass'
-
- @classmethod
- def getPType(cls):
- return ParticleType(cls)
-
- @classmethod
- def getInitialValue(cls, ptype, name):
- return next((v.initial for v in ptype.variables if v.name is name), None)
-
- @classmethod
- def setLastID(cls, offset):
- _Particle.lastID = offset
-
-
-class ScipyParticle(_Particle):
+class ScipyParticle:
"""Class encapsulating the basic attributes of a particle, to be executed in SciPy mode.
Parameters
@@ -198,6 +162,8 @@ class ScipyParticle(_Particle):
dt = Variable('dt', dtype=np.float64, to_write=False)
state = Variable('state', dtype=np.int32, initial=StatusCode.Evaluate, to_write=False)
+ lastID = 0 # class-level variable keeping track of last Particle ID used
+
def __init__(self, lon, lat, pid, fieldset=None, ngrids=None, depth=0., time=0., cptr=None):
# Enforce default values through Variable descriptor
@@ -210,14 +176,23 @@ def __init__(self, lon, lat, pid, fieldset=None, ngrids=None, depth=0., time=0.,
type(self).time.initial = time
type(self).time_nextloop.initial = time
type(self).id.initial = pid
- _Particle.lastID = max(_Particle.lastID, pid)
+ type(self).lastID = max(type(self).lastID, pid)
type(self).obs_written.initial = 0
type(self).dt.initial = None
- super().__init__()
+ ptype = self.getPType()
+ # Explicit initialisation of all particle variables
+ for v in ptype.variables:
+ if isinstance(v.initial, attrgetter):
+ initial = v.initial(self)
+ else:
+ initial = v.initial
+ # Enforce type of initial value
+ if v.dtype != c_void_p:
+ setattr(self, v.name, v.dtype(initial))
def __del__(self):
- super().__del__()
+ pass # superclass is 'object', and object itself has no destructor, hence 'pass'
def __repr__(self):
time_string = 'not_yet_set' if self.time is None or np.isnan(self.time) else f"{self.time:f}"
@@ -229,6 +204,54 @@ def __repr__(self):
str += f"{var}={getattr(self, var):f}, "
return str + f"time={time_string})"
+ @classmethod
+ def add_variable(cls, var, *args, **kwargs):
+ """Add a new variable to the Particle class
+
+ Parameters
+ ----------
+ var : str, Variable or list of Variables
+ Variable object to be added. Can be the name of the Variable,
+ a Variable object, or a list of Variable objects
+ """
+ if isinstance(var, list):
+ return cls.add_variables(var)
+ if not isinstance(var, Variable):
+ if len(args) > 0 and 'dtype' not in kwargs:
+ kwargs['dtype'] = args[0]
+ if len(args) > 1 and 'initial' not in kwargs:
+ kwargs['initial'] = args[1]
+ if len(args) > 2 and 'to_write' not in kwargs:
+ kwargs['to_write'] = args[2]
+ dtype = kwargs.pop('dtype', np.float32)
+ initial = kwargs.pop('initial', 0)
+ to_write = kwargs.pop('to_write', True)
+ var = Variable(var, dtype=dtype, initial=initial, to_write=to_write)
+
+ class NewParticle(cls):
+ pass
+
+ setattr(NewParticle, var.name, var)
+ return NewParticle
+
+ @classmethod
+ def add_variables(cls, variables):
+ """Add multiple new variables to the Particle class
+
+ Parameters
+ ----------
+ variables : list of Variable
+ Variable objects to be added. Has to be a list of Variable objects
+ """
+ NewParticle = cls
+ for var in variables:
+ NewParticle = NewParticle.add_variable(var)
+ return NewParticle
+
+ @classmethod
+ def getPType(cls):
+ return ParticleType(cls)
+
@classmethod
def set_lonlatdepth_dtype(cls, dtype):
cls.lon.dtype = dtype
@@ -238,10 +261,14 @@ def set_lonlatdepth_dtype(cls, dtype):
cls.lat_nextloop.dtype = dtype
cls.depth_nextloop.dtype = dtype
+ @classmethod
+ def setLastID(cls, offset):
+ ScipyParticle.lastID = offset
+
-class ScipyInteractionParticle(ScipyParticle):
- vert_dist = Variable("vert_dist", dtype=np.float32)
- horiz_dist = Variable("horiz_dist", dtype=np.float32)
+ScipyInteractionParticle = ScipyParticle.add_variables([
+ Variable("vert_dist", dtype=np.float32),
+ Variable("horiz_dist", dtype=np.float32)])
class JITParticle(ScipyParticle):
diff --git a/tests/test_advection.py b/tests/test_advection.py
index 9d68ec2b28..7e11338452 100644
--- a/tests/test_advection.py
+++ b/tests/test_advection.py
@@ -19,7 +19,6 @@
ParticleSet,
ScipyParticle,
StatusCode,
- Variable,
)
ptype = {'scipy': ScipyParticle, 'jit': JITParticle}
@@ -301,8 +300,7 @@ def test_stationary_eddy(fieldset_stationary, mode, method, rtol, diffField, npa
dt = delta(minutes=3).total_seconds()
endtime = delta(hours=6).total_seconds()
- class RK45Particles(ptype[mode]):
- next_dt = Variable('next_dt', dtype=np.float32, initial=dt)
+ RK45Particles = ptype[mode].add_variable('next_dt', dtype=np.float32, initial=dt)
pclass = RK45Particles if method == 'RK45' else ptype[mode]
pset = ParticleSet(fieldset, pclass=pclass, lon=lon, lat=lat)
@@ -395,8 +393,7 @@ def test_moving_eddy(fieldset_moving, mode, method, rtol, diffField, npart=1):
dt = delta(minutes=3).total_seconds()
endtime = delta(hours=6).total_seconds()
- class RK45Particles(ptype[mode]):
- next_dt = Variable('next_dt', dtype=np.float32, initial=dt)
+ RK45Particles = ptype[mode].add_variable('next_dt', dtype=np.float32, initial=dt)
pclass = RK45Particles if method == 'RK45' else ptype[mode]
pset = ParticleSet(fieldset, pclass=pclass, lon=lon, lat=lat)
@@ -461,8 +458,7 @@ def test_decaying_eddy(fieldset_decaying, mode, method, rtol, diffField, npart=1
dt = delta(minutes=3).total_seconds()
endtime = delta(hours=6).total_seconds()
- class RK45Particles(ptype[mode]):
- next_dt = Variable('next_dt', dtype=np.float32, initial=dt)
+ RK45Particles = ptype[mode].add_variable('next_dt', dtype=np.float32, initial=dt)
pclass = RK45Particles if method == 'RK45' else ptype[mode]
pset = ParticleSet(fieldset, pclass=pclass, lon=lon, lat=lat)
diff --git a/tests/test_diffusion.py b/tests/test_diffusion.py
index dda28ed343..02af182534 100644
--- a/tests/test_diffusion.py
+++ b/tests/test_diffusion.py
@@ -15,7 +15,6 @@
ParticleSet,
RectilinearZGrid,
ScipyParticle,
- Variable,
)
ptype = {'scipy': ScipyParticle, 'jit': JITParticle}
@@ -132,8 +131,7 @@ def test_randomvonmises(mode, mu, kappa, npart=10000):
# Set random seed
ParcelsRandom.seed(1234)
- class AngleParticle(ptype[mode]):
- angle = Variable('angle')
+ AngleParticle = ptype[mode].add_variable('angle')
pset = ParticleSet(fieldset=fieldset, pclass=AngleParticle, lon=np.zeros(npart), lat=np.zeros(npart), depth=np.zeros(npart))
def vonmises(particle, fieldset, time):
diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py
index a5ba1bf49c..2f17eb09a2 100644
--- a/tests/test_fieldset.py
+++ b/tests/test_fieldset.py
@@ -834,13 +834,14 @@ def sampleTemp(particle, fieldset, time):
# test if we can sample a non-timevarying field too
particle.d = fieldset.D[0, 0, particle.lat, particle.lon]
- class MyParticle(ptype[mode]):
- temp = Variable('temp', dtype=np.float32, initial=20.)
- u1 = Variable('u1', dtype=np.float32, initial=0.)
- u2 = Variable('u2', dtype=np.float32, initial=0.)
- v1 = Variable('v1', dtype=np.float32, initial=0.)
- v2 = Variable('v2', dtype=np.float32, initial=0.)
- d = Variable('d', dtype=np.float32, initial=0.)
+ MyParticle = ptype[mode].add_variables([
+ Variable('temp', dtype=np.float32, initial=20.),
+ Variable('u1', dtype=np.float32, initial=0.),
+ Variable('u2', dtype=np.float32, initial=0.),
+ Variable('v1', dtype=np.float32, initial=0.),
+ Variable('v2', dtype=np.float32, initial=0.),
+ Variable('d', dtype=np.float32, initial=0.),
+ ])
pset = ParticleSet.from_list(fieldset, pclass=MyParticle, lon=[0.5], lat=[0.5], depth=[0.5])
pset.execute(AdvectionRK4_3D + pset.Kernel(sampleTemp), runtime=delta(hours=51), dt=delta(hours=dt_sign*1))
@@ -944,10 +945,10 @@ def test_fieldset_initialisation_kernel_dask(time2, tmpdir, filename='test_parce
def SampleField(particle, fieldset, time):
particle.u_kernel, particle.v_kernel = fieldset.UV[time, particle.depth, particle.lat, particle.lon]
- class SampleParticle(JITParticle):
- u_kernel = Variable('u_kernel', dtype=np.float32, initial=0.)
- v_kernel = Variable('v_kernel', dtype=np.float32, initial=0.)
- u_scipy = Variable('u_scipy', dtype=np.float32, initial=0.)
+ SampleParticle = JITParticle.add_variables([
+ Variable('u_kernel', dtype=np.float32, initial=0.),
+ Variable('v_kernel', dtype=np.float32, initial=0.),
+ Variable('u_scipy', dtype=np.float32, initial=0.)])
pset = ParticleSet(fieldset, pclass=SampleParticle, time=[0, time2], lon=[0.5, 0.5], lat=[0.5, 0.5], depth=[0.5, 0.5])
@@ -1084,8 +1085,7 @@ def test_deferredload_simplefield(mode, direction, time_extrapolation, tmpdir, t
fieldset = FieldSet.from_netcdf(filename, {'U': 'U', 'V': 'V'}, {'lon': 'x', 'lat': 'y', 'time': 't'},
deferred_load=True, mesh='flat', allow_time_extrapolation=time_extrapolation)
- class SamplingParticle(ptype[mode]):
- p = Variable('p')
+ SamplingParticle = ptype[mode].add_variable("p")
pset = ParticleSet(fieldset, SamplingParticle, lon=0.5, lat=0.5)
def SampleU(particle, fieldset, time):
diff --git a/tests/test_fieldset_sampling.py b/tests/test_fieldset_sampling.py
index 68aaca45f0..8ba473dc60 100644
--- a/tests/test_fieldset_sampling.py
+++ b/tests/test_fieldset_sampling.py
@@ -23,11 +23,10 @@
def pclass(mode):
- class SampleParticle(ptype[mode]):
- u = Variable('u', dtype=np.float32)
- v = Variable('v', dtype=np.float32)
- p = Variable('p', dtype=np.float32)
- return SampleParticle
+ return ptype[mode].add_variables([
+ Variable('u', dtype=np.float32),
+ Variable('v', dtype=np.float32),
+ Variable('p', dtype=np.float32)])
def k_sample_uv():
@@ -628,8 +627,7 @@ def test_sampling_multigrids_non_vectorfield_from_file(mode, npart, tmpdir, chs,
assert fieldset.U.grid is fieldset.V.grid
assert fieldset.U.grid is not fieldset.B.grid
- class TestParticle(ptype[mode]):
- sample_var = Variable('sample_var', initial=0.)
+ TestParticle = ptype[mode].add_variable('sample_var', initial=0.)
pset = ParticleSet.from_line(fieldset, pclass=TestParticle, start=[0.3, 0.3], finish=[0.7, 0.7], size=npart)
@@ -672,8 +670,7 @@ def test_sampling_multigrids_non_vectorfield(mode, npart):
assert fieldset.U.grid is fieldset.V.grid
assert fieldset.U.grid is not fieldset.B.grid
- class TestParticle(ptype[mode]):
- sample_var = Variable('sample_var', initial=0.)
+ TestParticle = ptype[mode].add_variable('sample_var', initial=0.)
pset = ParticleSet.from_line(fieldset, pclass=TestParticle, start=[0.3, 0.3], finish=[0.7, 0.7], size=npart)
diff --git a/tests/test_grids.py b/tests/test_grids.py
index a317cbbce8..76a11e0ea8 100644
--- a/tests/test_grids.py
+++ b/tests/test_grids.py
@@ -84,9 +84,9 @@ def sampleTemp(particle, fieldset, time):
particle.temp0 = fieldset.temp0[time+particle.dt, particle.depth, particle.lat, particle.lon]
particle.temp1 = fieldset.temp1[time+particle.dt, particle.depth, particle.lat, particle.lon]
- class MyParticle(ptype[mode]):
- temp0 = Variable('temp0', dtype=np.float32, initial=20.)
- temp1 = Variable('temp1', dtype=np.float32, initial=20.)
+ MyParticle = ptype[mode].add_variables([
+ Variable('temp0', dtype=np.float32, initial=20.),
+ Variable('temp1', dtype=np.float32, initial=20.)])
pset = ParticleSet.from_list(fieldset, MyParticle, lon=[3001], lat=[5001], repeatdt=1)
@@ -224,8 +224,7 @@ def bath_func(lon):
def sampleTemp(particle, fieldset, time):
particle.temp = fieldset.temp[time, particle.depth, particle.lat, particle.lon]
- class MyParticle(ptype[mode]):
- temp = Variable('temp', dtype=np.float32, initial=20.)
+ MyParticle = ptype[mode].add_variable('temp', dtype=np.float32, initial=20.)
lon = 400
lat = 0
@@ -309,8 +308,7 @@ def bath_func(lon):
rel_depth_field = Field('relDepth', rel_depth_data, grid=grid)
fieldset = FieldSet(u_field, v_field, fields={'relDepth': rel_depth_field})
- class MyParticle(ptype[mode]):
- relDepth = Variable('relDepth', dtype=np.float32, initial=20.)
+ MyParticle = ptype[mode].add_variable('relDepth', dtype=np.float32, initial=20.)
def moveEast(particle, fieldset, time):
particle_dlon += 5 * particle.dt # noqa
@@ -352,8 +350,7 @@ def sampleSpeed(particle, fieldset, time):
u, v = fieldset.UV[time, particle.depth, particle.lat, particle.lon]
particle.speed = math.sqrt(u*u+v*v)
- class MyParticle(ptype[mode]):
- speed = Variable('speed', dtype=np.float32, initial=0.)
+ MyParticle = ptype[mode].add_variable('speed', dtype=np.float32, initial=0.)
pset = ParticleSet.from_list(fieldset, MyParticle, lon=[400, -200], lat=[600, 600])
pset.execute(pset.Kernel(sampleSpeed), runtime=1)
@@ -380,9 +377,9 @@ def test_nemo_grid(mode):
def sampleVel(particle, fieldset, time):
(particle.zonal, particle.meridional) = fieldset.UV[time, particle.depth, particle.lat, particle.lon]
- class MyParticle(ptype[mode]):
- zonal = Variable('zonal', dtype=np.float32, initial=0.)
- meridional = Variable('meridional', dtype=np.float32, initial=0.)
+ MyParticle = ptype[mode].add_variables([
+ Variable('zonal', dtype=np.float32, initial=0.),
+ Variable('meridional', dtype=np.float32, initial=0.)])
lonp = 175.5
latp = 81.5
@@ -437,9 +434,9 @@ def test_cgrid_uniform_2dvel(mode, time):
def sampleVel(particle, fieldset, time):
(particle.zonal, particle.meridional) = fieldset.UV[time, particle.depth, particle.lat, particle.lon]
- class MyParticle(ptype[mode]):
- zonal = Variable('zonal', dtype=np.float32, initial=0.)
- meridional = Variable('meridional', dtype=np.float32, initial=0.)
+ MyParticle = ptype[mode].add_variables([
+ Variable('zonal', dtype=np.float32, initial=0.),
+ Variable('meridional', dtype=np.float32, initial=0.)])
pset = ParticleSet.from_list(fieldset, MyParticle, lon=.7, lat=.3)
pset.execute(pset.Kernel(sampleVel), runtime=1)
@@ -497,10 +494,10 @@ def test_cgrid_uniform_3dvel(mode, vert_mode, time):
def sampleVel(particle, fieldset, time):
(particle.zonal, particle.meridional, particle.vertical) = fieldset.UVW[time, particle.depth, particle.lat, particle.lon]
- class MyParticle(ptype[mode]):
- zonal = Variable('zonal', dtype=np.float32, initial=0.)
- meridional = Variable('meridional', dtype=np.float32, initial=0.)
- vertical = Variable('vertical', dtype=np.float32, initial=0.)
+ MyParticle = ptype[mode].add_variables([
+ Variable('zonal', dtype=np.float32, initial=0.),
+ Variable('meridional', dtype=np.float32, initial=0.),
+ Variable('vertical', dtype=np.float32, initial=0.)])
pset = ParticleSet.from_list(fieldset, MyParticle, lon=.7, lat=.3, depth=.2)
pset.execute(pset.Kernel(sampleVel), runtime=1)
@@ -554,10 +551,10 @@ def test_cgrid_uniform_3dvel_spherical(mode, vert_mode, time):
def sampleVel(particle, fieldset, time):
(particle.zonal, particle.meridional, particle.vertical) = fieldset.UVW[time, particle.depth, particle.lat, particle.lon]
- class MyParticle(ptype[mode]):
- zonal = Variable('zonal', dtype=np.float32, initial=0.)
- meridional = Variable('meridional', dtype=np.float32, initial=0.)
- vertical = Variable('vertical', dtype=np.float32, initial=0.)
+ MyParticle = ptype[mode].add_variables([
+ Variable('zonal', dtype=np.float32, initial=0.),
+ Variable('meridional', dtype=np.float32, initial=0.),
+ Variable('vertical', dtype=np.float32, initial=0.)])
lonp = 179.8
latp = 81.35
@@ -601,12 +598,12 @@ def OutBoundsError(particle, fieldset, time):
particle_ddepth -= 3 # noqa
particle.state = StatusCode.Success
- class MyParticle(ptype[mode]):
- zonal = Variable('zonal', dtype=np.float32, initial=0.)
- meridional = Variable('meridional', dtype=np.float32, initial=0.)
- vert = Variable('vert', dtype=np.float32, initial=0.)
- tracer = Variable('tracer', dtype=np.float32, initial=0.)
- out_of_bounds = Variable('out_of_bounds', dtype=np.float32, initial=0.)
+ MyParticle = ptype[mode].add_variables([
+ Variable('zonal', dtype=np.float32, initial=0.),
+ Variable('meridional', dtype=np.float32, initial=0.),
+ Variable('vert', dtype=np.float32, initial=0.),
+ Variable('tracer', dtype=np.float32, initial=0.),
+ Variable('out_of_bounds', dtype=np.float32, initial=0.)])
pset = ParticleSet.from_list(fieldset, MyParticle, lon=[3, 5, 1], lat=[3, 5, 1], depth=[3, 7, 11])
pset.execute(pset.Kernel(sampleVel) + OutBoundsError, runtime=1)
@@ -699,9 +696,9 @@ def UpdateR(particle, fieldset, time):
particle.radius_start = fieldset.R[time, particle.depth, particle.lat, particle.lon]
particle.radius = fieldset.R[time, particle.depth, particle.lat, particle.lon]
- class MyParticle(ptype[mode]):
- radius = Variable('radius', dtype=np.float32, initial=0.)
- radius_start = Variable('radius_start', dtype=np.float32, initial=0.)
+ MyParticle = ptype[mode].add_variables([
+ Variable('radius', dtype=np.float32, initial=0.),
+ Variable('radius_start', dtype=np.float32, initial=0.)])
pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=4e3, time=0)
@@ -777,9 +774,9 @@ def UpdateR(particle, fieldset, time):
particle.radius_start = fieldset.R[time, particle.depth, particle.lat, particle.lon]
particle.radius = fieldset.R[time, particle.depth, particle.lat, particle.lon]
- class MyParticle(ptype[mode]):
- radius = Variable('radius', dtype=np.float32, initial=0.)
- radius_start = Variable('radius_start', dtype=np.float32, initial=0.)
+ MyParticle = ptype[mode].add_variables([
+ Variable('radius', dtype=np.float32, initial=0.),
+ Variable('radius_start', dtype=np.float32, initial=0.)])
pset = ParticleSet(fieldset, pclass=MyParticle, depth=4e3, lon=0, lat=0, time=0)
@@ -856,9 +853,9 @@ def UpdateR(particle, fieldset, time):
particle.radius_start = fieldset.R[time, particle.depth, particle.lat, particle.lon]
particle.radius = fieldset.R[time, particle.depth, particle.lat, particle.lon]
- class MyParticle(ptype[mode]):
- radius = Variable('radius', dtype=np.float32, initial=0.)
- radius_start = Variable('radius_start', dtype=np.float32, initial=0.)
+ MyParticle = ptype[mode].add_variables([
+ Variable('radius', dtype=np.float32, initial=0.),
+ Variable('radius_start', dtype=np.float32, initial=0.)])
pset = ParticleSet(fieldset, pclass=MyParticle, depth=-9.995e3, lon=0, lat=0, time=0)
@@ -924,10 +921,10 @@ def VelocityInterpolator(particle, fieldset, time):
particle.Vvel = fieldset.V[time, particle.depth, particle.lat, particle.lon]
particle.Wvel = fieldset.W[time, particle.depth, particle.lat, particle.lon]
- class myParticle(ptype[mode]):
- Uvel = Variable("Uvel", dtype=np.float32, initial=0.0)
- Vvel = Variable("Vvel", dtype=np.float32, initial=0.0)
- Wvel = Variable("Wvel", dtype=np.float32, initial=0.0)
+ myParticle = ptype[mode].add_variables([
+ Variable("Uvel", dtype=np.float32, initial=0.0),
+ Variable("Vvel", dtype=np.float32, initial=0.0),
+ Variable("Wvel", dtype=np.float32, initial=0.0)])
for pointtype in ["U", "V", "W"]:
if gridindexingtype == "pop":
diff --git a/tests/test_interaction.py b/tests/test_interaction.py
index 2b1f6f2867..11cc8029b5 100644
--- a/tests/test_interaction.py
+++ b/tests/test_interaction.py
@@ -52,11 +52,6 @@ def fieldset(xdim=20, ydim=20, mesh='spherical'):
return FieldSet.from_data(data, dimensions, mesh=mesh)
-class MergeParticle(ScipyInteractionParticle):
- nearest_neighbor = Variable('nearest_neighbor', dtype=np.int64, to_write=False)
- mass = Variable('mass', initial=1, dtype=np.float32)
-
-
@pytest.fixture(name="fieldset")
def fieldset_fixture(xdim=20, ydim=20):
return fieldset(xdim=xdim, ydim=ydim)
@@ -133,6 +128,9 @@ def test_neighbor_merge(fieldset):
lats = [0.0, 0.0, 0.0, 0.0]
# Distance in meters R_earth*0.2 degrees
interaction_distance = 6371000*5.5*np.pi/180
+ MergeParticle = ScipyInteractionParticle.add_variables([
+ Variable('nearest_neighbor', dtype=np.int64, to_write=False),
+ Variable('mass', initial=1, dtype=np.float32)])
pset = ParticleSet(fieldset, pclass=MergeParticle, lon=lons, lat=lats,
interaction_distance=interaction_distance)
pyfunc_inter = (pset.InteractionKernel(NearestNeighborWithinRange)
@@ -144,16 +142,13 @@ def test_neighbor_merge(fieldset):
assert len(pset) == 1
-class AttractingParticle(ScipyInteractionParticle):
- attractor = Variable('attractor', dtype=np.bool_, to_write='once')
-
-
@pytest.mark.parametrize('mode', ['scipy'])
def test_asymmetric_attraction(fieldset, mode):
lons = [0.0, 0.1, 0.2]
lats = [0.0, 0.0, 0.0]
# Distance in meters R_earth*0.2 degrees
interaction_distance = 6371000*5.5*np.pi/180
+ AttractingParticle = ScipyInteractionParticle.add_variable('attractor', dtype=np.bool_, to_write='once')
pset = ParticleSet(fieldset, pclass=AttractingParticle, lon=lons, lat=lats,
interaction_distance=interaction_distance,
attractor=[True, False, False])
diff --git a/tests/test_kernel_execution.py b/tests/test_kernel_execution.py
index 7644d1d63a..836e46864d 100644
--- a/tests/test_kernel_execution.py
+++ b/tests/test_kernel_execution.py
@@ -12,7 +12,6 @@
ParticleSet,
ScipyParticle,
StatusCode,
- Variable,
)
ptype = {'scipy': ScipyParticle, 'jit': JITParticle}
@@ -51,8 +50,7 @@ def MoveLon_Update_dlon(particle, fieldset, time):
def SampleP(particle, fieldset, time):
particle.p = fieldset.U[time, particle.depth, particle.lat, particle.lon]
- class SampleParticle(ptype[mode]):
- p = Variable('p', dtype=np.float32, initial=0.)
+ SampleParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0.)
MoveLon = MoveLon_Update_dlon if kernel_type == 'update_dlon' else MoveLon_Update_Lon
diff --git a/tests/test_kernel_language.py b/tests/test_kernel_language.py
index 4168cdb713..75f96a3e72 100644
--- a/tests/test_kernel_language.py
+++ b/tests/test_kernel_language.py
@@ -55,8 +55,7 @@ def fieldset_fixture(xdim=20, ydim=20):
])
def test_expression_int(mode, name, expr, result, npart=10):
"""Test basic arithmetic expressions."""
- class TestParticle(ptype[mode]):
- p = Variable('p', dtype=np.float32)
+ TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0)
pset = ParticleSet(fieldset(), pclass=TestParticle,
lon=np.linspace(0., 1., npart),
lat=np.zeros(npart) + 0.5)
@@ -74,8 +73,7 @@ class TestParticle(ptype[mode]):
])
def test_expression_float(mode, name, expr, result, npart=10):
"""Test basic arithmetic expressions."""
- class TestParticle(ptype[mode]):
- p = Variable('p', dtype=np.float32)
+ TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0)
pset = ParticleSet(fieldset(), pclass=TestParticle,
lon=np.linspace(0., 1., npart),
lat=np.zeros(npart) + 0.5)
@@ -99,8 +97,7 @@ class TestParticle(ptype[mode]):
])
def test_expression_bool(mode, name, expr, result, npart=10):
"""Test basic arithmetic expressions."""
- class TestParticle(ptype[mode]):
- p = Variable('p', dtype=np.float32)
+ TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0)
pset = ParticleSet(fieldset(), pclass=TestParticle,
lon=np.linspace(0., 1., npart),
lat=np.zeros(npart) + 0.5)
@@ -114,8 +111,7 @@ class TestParticle(ptype[mode]):
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
def test_while_if_break(mode):
"""Test while, if and break commands."""
- class TestParticle(ptype[mode]):
- p = Variable('p', dtype=np.float32, initial=0.)
+ TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0)
pset = ParticleSet(fieldset(), pclass=TestParticle, lon=[0], lat=[0])
def kernel(particle, fieldset, time):
@@ -132,9 +128,9 @@ def kernel(particle, fieldset, time):
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
def test_nested_if(mode):
"""Test nested if commands."""
- class TestParticle(ptype[mode]):
- p0 = Variable('p0', dtype=np.int32, initial=0)
- p1 = Variable('p1', dtype=np.int32, initial=1)
+ TestParticle = ptype[mode].add_variables([
+ Variable('p0', dtype=np.int32, initial=0),
+ Variable('p1', dtype=np.int32, initial=1)])
pset = ParticleSet(fieldset(), pclass=TestParticle, lon=0, lat=0)
def kernel(particle, fieldset, time):
@@ -150,8 +146,7 @@ def kernel(particle, fieldset, time):
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
def test_pass(mode):
"""Test pass commands."""
- class TestParticle(ptype[mode]):
- p = Variable('p', dtype=np.int32, initial=0)
+ TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0)
pset = ParticleSet(fieldset(), pclass=TestParticle, lon=0, lat=0)
def kernel(particle, fieldset, time):
@@ -209,8 +204,7 @@ def kernel_abs(particle, fieldset, time):
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
def test_if_withfield(fieldset, mode):
"""Test combination of if and Field sampling commands."""
- class TestParticle(ptype[mode]):
- p = Variable('p', dtype=np.float32, initial=0.)
+ TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0)
pset = ParticleSet(fieldset, pclass=TestParticle, lon=[0], lat=[0])
def kernel(particle, fieldset, time):
@@ -242,8 +236,7 @@ def kernel(particle, fieldset, time):
@pytest.mark.parametrize('mode', ['scipy'])
def test_print(fieldset, mode, capfd):
"""Test print statements."""
- class TestParticle(ptype[mode]):
- p = Variable('p', dtype=np.float32, initial=0.)
+ TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0)
pset = ParticleSet(fieldset, pclass=TestParticle, lon=[0.5], lat=[0.5])
def kernel(particle, fieldset, time):
@@ -301,8 +294,7 @@ def random_series(npart, rngfunc, rngargs, mode):
])
def test_random_float(mode, rngfunc, rngargs, npart=10):
"""Test basic random number generation."""
- class TestParticle(ptype[mode]):
- p = Variable('p', dtype=np.float32 if rngfunc == 'randint' else np.float32)
+ TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0)
pset = ParticleSet(fieldset(), pclass=TestParticle,
lon=np.linspace(0., 1., npart),
lat=np.zeros(npart) + 0.5)
@@ -317,9 +309,7 @@ class TestParticle(ptype[mode]):
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
@pytest.mark.parametrize('concat', [False, True])
def test_random_kernel_concat(fieldset, mode, concat):
- class TestParticle(ptype[mode]):
- p = Variable('p', dtype=np.float32)
-
+ TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0)
pset = ParticleSet(fieldset, pclass=TestParticle, lon=0, lat=0)
def RandomKernel(particle, fieldset, time):
@@ -374,8 +364,7 @@ def pykernel(particle, fieldset, time):
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
def test_dt_modif_by_kernel(mode):
- class TestParticle(ptype[mode]):
- age = Variable('age', dtype=np.float32)
+ TestParticle = ptype[mode].add_variable('age', dtype=np.float32, initial=0)
pset = ParticleSet(fieldset(), pclass=TestParticle, lon=[0.5], lat=[0])
def modif_dt(particle, fieldset, time):
@@ -428,8 +417,7 @@ def generate_fieldset(xdim=2, ydim=2, zdim=2, tdim=1):
data, dimensions = generate_fieldset()
fieldset = FieldSet.from_data(data, dimensions)
- class DensParticle(ptype[mode]):
- density = Variable('density', dtype=np.float32)
+ DensParticle = ptype[mode].add_variable('density', dtype=np.float32)
pset = ParticleSet(fieldset, pclass=DensParticle, lon=5, lat=5, depth=1000)
@@ -446,22 +434,20 @@ def test_EOSseawaterproperties_kernels(mode):
dimensions={'lat': 0, 'lon': 0, 'depth': 0})
fieldset.add_constant('refpressure', float(0))
- class PoTempParticle(ptype[mode]):
- potemp = Variable('potemp', dtype=np.float32)
- pressure = Variable('pressure', dtype=np.float32, initial=10000)
+ PoTempParticle = ptype[mode].add_variables([
+ Variable('potemp', dtype=np.float32),
+ Variable('pressure', dtype=np.float32, initial=10000)])
pset = ParticleSet(fieldset, pclass=PoTempParticle, lon=5, lat=5, depth=1000)
pset.execute(PtempFromTemp, runtime=1)
assert np.allclose(pset[0].potemp, 36.89073)
- class TempParticle(ptype[mode]):
- temp = Variable('temp', dtype=np.float32)
- pressure = Variable('pressure', dtype=np.float32, initial=10000)
+ TempParticle = ptype[mode].add_variables([
+ Variable('temp', dtype=np.float32),
+ Variable('pressure', dtype=np.float32, initial=10000)])
pset = ParticleSet(fieldset, pclass=TempParticle, lon=5, lat=5, depth=1000)
pset.execute(TempFromPtemp, runtime=1)
assert np.allclose(pset[0].temp, 40)
- class TPressureParticle(ptype[mode]):
- pressure = Variable('pressure', dtype=np.float32)
pset = ParticleSet(fieldset, pclass=TempParticle, lon=5, lat=30, depth=7321.45)
pset.execute(PressureFromLatDepth, runtime=1)
assert np.allclose(pset[0].pressure, 7500, atol=1e-2)
@@ -491,8 +477,7 @@ def generate_fieldset(p, xdim=2, ydim=2, zdim=2, tdim=1):
data, dimensions = generate_fieldset(pressure)
fieldset = FieldSet.from_data(data, dimensions)
- class DensParticle(ptype[mode]):
- density = Variable('density', dtype=np.float32)
+ DensParticle = ptype[mode].add_variable('density', dtype=np.float32)
pset = ParticleSet(fieldset, pclass=DensParticle, lon=5, lat=5, depth=1000)
diff --git a/tests/test_particlefile.py b/tests/test_particlefile.py
index 8977acb0c8..f87ddb9c87 100644
--- a/tests/test_particlefile.py
+++ b/tests/test_particlefile.py
@@ -158,14 +158,12 @@ def Update_lon(particle, fieldset, time):
def test_write_dtypes_pfile(fieldset, mode, tmpdir):
filepath = tmpdir.join("pfile_dtypes.zarr")
- dtypes = ['float32', 'float64', 'int32', 'uint32', 'int64', 'uint64']
+ dtypes = [np.float32, np.float64, np.int32, np.uint32, np.int64, np.uint64]
if mode == 'scipy':
- dtypes.extend(['bool_', 'int8', 'uint8', 'int16', 'uint16'])
+ dtypes.extend([np.bool_, np.int8, np.uint8, np.int16, np.uint16])
- class MyParticle(ptype[mode]):
- for d in dtypes:
- # need an exec() here because we need to dynamically set the variable name
- exec(f'v_{d} = Variable("v_{d}", dtype=np.{d}, initial=0.)')
+ extra_vars = [Variable(f'v_{d.__name__}', dtype=d, initial=0.) for d in dtypes]
+ MyParticle = ptype[mode].add_variables(extra_vars)
pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=0)
pfile = pset.ParticleFile(name=filepath, outputdt=1)
@@ -173,7 +171,7 @@ class MyParticle(ptype[mode]):
ds = xr.open_zarr(filepath, mask_and_scale=False) # Note masking issue at https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float
for d in dtypes:
- assert ds[f'v_{d}'].dtype == d
+ assert ds[f'v_{d.__name__}'].dtype == d
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
@@ -185,9 +183,9 @@ def Update_v(particle, fieldset, time):
particle.v_once += 1.
particle.age += particle.dt
- class MyParticle(ptype[mode]):
- v_once = Variable('v_once', dtype=np.float64, initial=0., to_write='once')
- age = Variable('age', dtype=np.float32, initial=0.)
+ MyParticle = ptype[mode].add_variables([
+ Variable('v_once', dtype=np.float64, initial=0., to_write='once'),
+ Variable('age', dtype=np.float32, initial=0.)])
lon = np.linspace(0, 1, npart)
lat = np.linspace(1, 0, npart)
time = np.arange(0, npart/10., 0.1, dtype=np.float64)
@@ -211,9 +209,9 @@ def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, rep
fieldset.maxvar = maxvar
pset = None
- class MyParticle(ptype[mode]):
- sample_var = Variable('sample_var', initial=0.)
- v_once = Variable('v_once', dtype=np.float64, initial=0., to_write='once')
+ MyParticle = ptype[mode].add_variables([
+ Variable('sample_var', initial=0.),
+ Variable('v_once', dtype=np.float64, initial=0., to_write='once')])
if type == 'repeatdt':
pset = ParticleSet(fieldset, lon=[0], lat=[0], pclass=MyParticle, repeatdt=repeatdt)
@@ -289,10 +287,10 @@ def test_write_xiyi(fieldset, mode, tmpdir):
fieldset.add_field(Field(name='P', data=np.zeros((2, 20)), lon=np.linspace(0, 1, 20), lat=[0, 2]))
dt = 3600
- class XiYiParticle(ptype[mode]):
- pxi0 = Variable('pxi0', dtype=np.int32, initial=0.)
- pxi1 = Variable('pxi1', dtype=np.int32, initial=0.)
- pyi = Variable('pyi', dtype=np.int32, initial=0.)
+ XiYiParticle = ptype[mode].add_variables([
+ Variable('pxi0', dtype=np.int32, initial=0.),
+ Variable('pxi1', dtype=np.int32, initial=0.),
+ Variable('pyi', dtype=np.int32, initial=0.)])
def Get_XiYi(particle, fieldset, time):
"""Kernel to sample the grid indices of the particle.
diff --git a/tests/test_particles.py b/tests/test_particles.py
index 615e132934..c63d27cc08 100644
--- a/tests/test_particles.py
+++ b/tests/test_particles.py
@@ -30,8 +30,7 @@ def fieldset_fixture(xdim=100, ydim=100):
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
def test_print(fieldset, mode):
- class TestParticle(ptype[mode]):
- p = Variable('p', to_write=True)
+ TestParticle = ptype[mode].add_variable('p', to_write=True)
pset = ParticleSet(fieldset, pclass=TestParticle, lon=[0, 1], lat=[0, 1])
print(pset)
@@ -39,11 +38,10 @@ class TestParticle(ptype[mode]):
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
def test_variable_init(fieldset, mode, npart=10):
"""Test that checks correct initialisation of custom variables."""
- class TestParticle(ptype[mode]):
- p_float = Variable('p_float', dtype=np.float32, initial=10.)
- p_double = Variable('p_double', dtype=np.float64, initial=11.)
- p_int = Variable('p_int', dtype=np.int32, initial=12.)
-
+ extra_vars = [Variable('p_float', dtype=np.float32, initial=10.),
+ Variable('p_double', dtype=np.float64, initial=11.)]
+ TestParticle = ptype[mode].add_variables(extra_vars)
+ TestParticle = TestParticle.add_variable('p_int', np.int32, initial=12.)
pset = ParticleSet(fieldset, pclass=TestParticle,
lon=np.linspace(0, 1, npart),
lat=np.linspace(1, 0, npart))
@@ -62,8 +60,7 @@ def addOne(particle, fieldset, time):
@pytest.mark.parametrize('type', ['np.int8', 'mp.float', 'np.int16'])
def test_variable_unsupported_dtypes(fieldset, mode, type):
"""Test that checks errors thrown for unsupported dtypes in JIT mode."""
- class TestParticle(ptype[mode]):
- p = Variable('p', dtype=type, initial=10.)
+ TestParticle = ptype[mode].add_variable('p', dtype=type, initial=10.)
error_thrown = False
try:
ParticleSet(fieldset, pclass=TestParticle, lon=[0], lat=[0])
@@ -76,8 +73,7 @@ class TestParticle(ptype[mode]):
def test_variable_special_names(fieldset, mode):
"""Test that checks errors thrown for special names."""
for vars in ['z', 'lon']:
- class TestParticle(ptype[mode]):
- tmp = Variable(vars, dtype=np.float32, initial=10.)
+ TestParticle = ptype[mode].add_variable(vars, dtype=np.float32, initial=10.)
error_thrown = False
try:
ParticleSet(fieldset, pclass=TestParticle, lon=[0], lat=[0])
@@ -92,14 +88,11 @@ def test_variable_init_relative(fieldset, mode, coord_type, npart=10):
"""Test that checks relative initialisation of custom variables."""
lonlat_type = np.float64 if coord_type == 'double' else np.float32
- class TestParticle(ptype[mode]):
- p_base = Variable('p_base', dtype=lonlat_type, initial=10.)
- p_relative = Variable('p_relative', dtype=lonlat_type,
- initial=attrgetter('p_base'))
- p_lon = Variable('p_lon', dtype=lonlat_type,
- initial=attrgetter('lon'))
- p_lat = Variable('p_lat', dtype=lonlat_type,
- initial=attrgetter('lat'))
+ TestParticle = ptype[mode].add_variables([
+ Variable('p_base', dtype=lonlat_type, initial=10.),
+ Variable('p_relative', dtype=lonlat_type, initial=attrgetter('p_base')),
+ Variable('p_lon', dtype=lonlat_type, initial=attrgetter('lon')),
+ Variable('p_lat', dtype=lonlat_type, initial=attrgetter('lat'))])
lon = np.linspace(0, 1, npart, dtype=lonlat_type)
lat = np.linspace(1, 0, npart, dtype=lonlat_type)
diff --git a/tests/test_particlesets.py b/tests/test_particlesets.py
index 54074a52fa..9e1d64fdf6 100644
--- a/tests/test_particlesets.py
+++ b/tests/test_particlesets.py
@@ -57,8 +57,7 @@ def test_pset_create_list_with_customvariable(fieldset, mode, npart=100):
lon = np.linspace(0, 1, npart, dtype=np.float32)
lat = np.linspace(1, 0, npart, dtype=np.float32)
- class MyParticle(ptype[mode]):
- v = Variable('v')
+ MyParticle = ptype[mode].add_variable("v")
v_vals = np.arange(npart)
pset = ParticleSet.from_list(fieldset, lon=lon, lat=lat, v=v_vals, pclass=MyParticle)
@@ -74,10 +73,9 @@ def test_pset_create_fromparticlefile(fieldset, mode, restart, tmpdir):
lon = np.linspace(0, 1, 10, dtype=np.float32)
lat = np.linspace(1, 0, 10, dtype=np.float32)
- class TestParticle(ptype[mode]):
- p = Variable('p', np.float32, initial=0.33)
- p2 = Variable('p2', np.float32, initial=1, to_write=False)
- p3 = Variable('p3', np.float32, to_write='once')
+ TestParticle = ptype[mode].add_variable('p', np.float32, initial=0.33)
+ TestParticle = TestParticle.add_variable('p2', np.float32, initial=1, to_write=False)
+ TestParticle = TestParticle.add_variable('p3', np.float64, to_write='once')
pset = ParticleSet(fieldset, lon=lon, lat=lat, depth=[4]*len(lon), pclass=TestParticle, p3=np.arange(len(lon)))
pfile = pset.ParticleFile(filename, outputdt=1)
@@ -99,6 +97,7 @@ def Kernel(particle, fieldset, time):
assert np.allclose([p.id for p in pset], [p.id for p in pset_new])
pset_new.execute(Kernel, runtime=2, dt=1)
assert len(pset_new) == 3*len(pset)
+ assert pset[0].p3.dtype == np.float64
@pytest.mark.parametrize('mode', ['scipy'])
@@ -200,8 +199,7 @@ def IncrLon(particle, fieldset, time):
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
def test_pset_repeatdt_custominit(fieldset, mode):
- class MyParticle(ptype[mode]):
- sample_var = Variable('sample_var')
+ MyParticle = ptype[mode].add_variable('sample_var')
pset = ParticleSet(fieldset, lon=0, lat=0, pclass=MyParticle, repeatdt=1, sample_var=5)
@@ -236,9 +234,9 @@ def test_pset_access(fieldset, mode, npart=100):
@pytest.mark.parametrize('mode', ['scipy', 'jit'])
def test_pset_custom_ptype(fieldset, mode, npart=100):
- class TestParticle(ptype[mode]):
- p = Variable('p', np.float32, initial=0.33)
- n = Variable('n', np.int32, initial=2)
+
+ TestParticle = ptype[mode].add_variable([Variable('p', np.float32, initial=0.33),
+ Variable('n', np.int32, initial=2)])
pset = ParticleSet(fieldset, pclass=TestParticle,
lon=np.linspace(0, 1, npart),
@@ -421,8 +419,7 @@ def test_from_field_exact_val(staggered_grid):
FMask = Field('mask', mask, lon, lat, interp_method='cgrid_tracer')
fieldset.add_field(FMask)
- class SampleParticle(ptype['scipy']):
- mask = Variable('mask', initial=0)
+ SampleParticle = ptype['scipy'].add_variable('mask', initial=0)
def SampleMask(particle, fieldset, time):
particle.mask = fieldset.mask[particle]