diff --git a/docs/examples/tutorial_stommel_uxarray.ipynb b/docs/examples/tutorial_stommel_uxarray.ipynb
index d5bc4317cf..5b703550b2 100644
--- a/docs/examples/tutorial_stommel_uxarray.ipynb
+++ b/docs/examples/tutorial_stommel_uxarray.ipynb
@@ -5,7 +5,19 @@
   "metadata": {},
   "source": [
    "# Stommel Gyre on Unstructured Grid\n",
-    "This tutorial walks through creating a UXArray dataset using the Stommel Gyre analytical solution for a closed rectangular domain on a beta-plane"
+    "This tutorial walks through a simple example of using Parcels for particle advection on an unstructured grid. The purpose of this tutorial is to introduce you to the new way fields and fieldsets can be instantiated in Parcels using UXArray DataArrays and UXArray grids.\n",
+    "\n",
+    "We focus on a simple example, using constant-in-time velocity and pressure fields for the classic barotropic Stommel Gyre. This example dataset is included in Parcels' new `parcels._datasets` module. This module provides example XArray and UXArray datasets that are compatible with Parcels and mimic the way many general circulation model outputs are represented in (U)XArray."
   ]
  },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "## Loading the example dataset\n",
+   "Creating a particle simulation starts with defining a dataset that contains the fields that will be used to influence particle attributes, such as position, through kernels. In this example, we focus on advection. Because of this, the dataset we're using will provide velocity fields for our simulation.\n",
+   "\n",
+   "Parcels now includes pre-canned example datasets to demonstrate the schema of XArray and UXArray datasets that are compatible with Parcels. For unstructured grid datasets, you can use the `parcels._datasets.unstructured.generic.datasets` dictionary to see which datasets are available for unstructured grids."
+  ]
+ },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -14,260 +26,248 @@
  "metadata": {},
  "outputs": [],
  "source": [
-    "def stommel_fieldset_uxarray(xdim=200, ydim=200):\n",
-    "    \"\"\"Simulate a periodic current along a western boundary, with significantly\n",
-    "    larger velocities along the western edge than the rest of the region\n",
-    "\n",
-    "    The original test description can be found in: N. Fabbroni, 2009,\n",
-    "    Numerical Simulation of Passive tracers dispersion in the sea,\n",
-    "    Ph.D. dissertation, University of Bologna\n",
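+    "# Parcels ships these small example datasets in the (private) parcels._datasets module; they mimic common model-output layouts\n",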
-    "    http://amsdottorato.unibo.it/1733/1/Fabbroni_Nicoletta_Tesi.pdf\n",
-    "    \"\"\"\n",
-    "    import math\n",
-    "\n",
-    "    import numpy as np\n",
-    "    import pandas as pd\n",
-    "    import uxarray as ux\n",
-    "\n",
-    "    a = b = 66666 * 1e3\n",
-    "    scalefac = 0.00025  # to scale for physically meaningful velocities\n",
-    "\n",
-    "    # Coordinates of the test fieldset\n",
-    "    # Crowd points to the west edge of the domain\n",
-    "    # using a polyonmial map on x-direction\n",
-    "    x = np.linspace(0, 1, xdim, dtype=np.float32)\n",
-    "    lon, lat = np.meshgrid(a * x, np.linspace(0, b, ydim, dtype=np.float32))\n",
-    "    points = (lon.flatten() / 1111111.111111111, lat.flatten() / 1111111.111111111)\n",
-    "\n",
-    "    # Create the grid\n",
-    "    uxgrid = ux.Grid.from_points(points, method=\"regional_delaunay\")\n",
-    "    uxgrid.construct_face_centers()\n",
-    "\n",
-    "    # Define arrays U (zonal), V (meridional) and P (sea surface height)\n",
-    "    U = np.zeros((1, 1, lat.size), dtype=np.float32)\n",
-    "    V = np.zeros((1, 1, lat.size), dtype=np.float32)\n",
-    "    P = np.zeros((1, 1, lat.size), dtype=np.float32)\n",
+    "from parcels._datasets.unstructured.generic import datasets as datasets_unstructured"
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "datasets_unstructured.keys()"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "In this example, we'll be using the `stommel_gyre_delaunay` example dataset. This dataset is created by generating a Delaunay triangulation of a uniform grid of points in a square domain $(x, y) \in [0^\circ, 60^\circ] \times [0^\circ, 60^\circ]$. There is a single vertical layer that is 1000 m thick. This layer is defined by the layer surfaces $z_f = 0$ and $z_f = 1000$."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "ds = datasets_unstructured[\"stommel_gyre_delaunay\"]\n",
+   "ds"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "print(\"lon\", ds.uxgrid.face_lon.min().data[()], ds.uxgrid.face_lon.max().data[()])\n",
+   "print(\"lat\", ds.uxgrid.face_lat.min().data[()], ds.uxgrid.face_lat.max().data[()])"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "In the dataset, we have the following dimensions:\n",
    "\n",
-    "    beta = 2e-11\n",
-    "    r = 1 / (11.6 * 86400)\n",
-    "    es = r / (beta * a)\n",
    "\n",
+   "* `time: 1` - The number of time levels that the variables in this dataset are defined at.\n",
+   "* `nz1: 1` - The number of vertical layers. The `nz1` dimension is associated with the `nz1` coordinate that defines the vertical position of the center of each vertical layer. The `nz1` coordinate consists of non-negative values that are assumed to increase with the `nz1` dimension index.\n",
+   "* `n_face: 721` - The number of 2-d unstructured grid faces in the `UXArray.grid`.\n",
+   "* `nz: 2` - The number of vertical layer interfaces. The `nz` dimension is associated with the `nz` coordinate that defines the vertical positions of the interfaces of each vertical layer. The `nz` coordinate consists of non-negative values that are assumed to increase with the `nz` dimension index. Note that the number of layer interfaces is always the number of layers plus one.\n",
+   "* `n_node: 400` - The number of corner node vertices in the grid.\n",
    "\n",
-    "    i = 0\n",
-    "    for x, y in zip(lon.flatten(), lat.flatten()):\n",
-    "        xi = x / a\n",
-    "        yi = y / b\n",
-    "        P[0, 0, i] = (\n",
-    "            (1 - math.exp(-xi / es) - xi) * math.pi * np.sin(math.pi * yi) * scalefac\n",
-    "        )\n",
-    "        U[0, 0, i] = (\n",
-    "            -(1 - math.exp(-xi / es) - xi)\n",
-    "            * math.pi**2\n",
-    "            * np.cos(math.pi * yi)\n",
-    "            * scalefac\n",
-    "        )\n",
-    "        V[0, 0, i] = (\n",
-    "            (math.exp(-xi / es) / es - 1) * math.pi * np.sin(math.pi * yi) * scalefac\n",
-    "        )\n",
-    "        i += 1\n",
+   "Whenever you are building a UXArray dataset for use in Parcels, it's important to keep in mind that these dimensions and coordinates are assumed to exist for your dataset. Further, it is highly recommended that you use UXArray, when possible, to load unstructured general circulation model data. This ensures that other characteristics, such as the counterclockwise ordering of vertices for each element, are defined properly for use in Parcels."
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
+   "## Defining a Grid, Fields, and Vector Fields\n",
+   "\n",
+   "A `UXArray.Dataset` consists of multiple `UXArray.UxDataArray`s and a `UXArray.UxGrid`. Parcels views general circulation model data through the `Field` and `VectorField` classes. A `Field` is defined by its `name`, `data`, `grid`, and `interp_method`. A `VectorField` can be constructed from two or three `Field`s. The `Field.data` attribute can be either an `XArray.DataArray` or a `UXArray.UxDataArray` object. The `Field.grid` attribute is of type `Parcels.XGrid` or `Parcels.UXGrid`. Lastly, the `interp_method` is a dynamic function that can be set at runtime to define the interpolation procedure for the `Field`. This gives you the flexibility to use one of the pre-defined interpolation methods included with Parcels v4, or to create your own interpolator.\n",
+   "\n",
+   "The first step to creating a `Field` (or `VectorField`) is to define the Grid. For an unstructured grid, we will create a `Parcels.UXGrid` object, which requires a `UxArray.grid` and the vertical layer interface positions."
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
+   "from parcels.uxgrid import UxGrid\n",
    "\n",
-    "    u = ux.UxDataArray(\n",
-    "        data=U,\n",
-    "        name=\"u\",\n",
-    "        uxgrid=uxgrid,\n",
-    "        dims=[\"time\", \"nz1\", \"n_node\"],\n",
-    "        coords=dict(\n",
-    "            time=([\"time\"], pd.to_datetime([\"2000-01-01\"])),\n",
-    "            nz1=([\"nz1\"], [0]),\n",
-    "        ),\n",
-    "        attrs=dict(\n",
-    "            description=\"zonal velocity\",\n",
-    "            units=\"m/s\",\n",
-    "            location=\"node\",\n",
-    "            mesh=\"delaunay\",\n",
-    "        ),\n",
-    "    )\n",
-    "    v = ux.UxDataArray(\n",
-    "        data=V,\n",
-    "        name=\"v\",\n",
-    "        uxgrid=uxgrid,\n",
-    "        dims=[\"time\", \"nz1\", \"n_node\"],\n",
-    "        coords=dict(\n",
-    "            time=([\"time\"], pd.to_datetime([\"2000-01-01\"])),\n",
-    "            nz1=([\"nz1\"], [0]),\n",
-    "        ),\n",
-    "        attrs=dict(\n",
-    "            description=\"meridional velocity\",\n",
-    "            units=\"m/s\",\n",
-    "            location=\"node\",\n",
-    "            mesh=\"delaunay\",\n",
-    "        ),\n",
-    "    )\n",
-    "    p = ux.UxDataArray(\n",
-    "        data=P,\n",
-    "        name=\"p\",\n",
-    "        uxgrid=uxgrid,\n",
-    "        dims=[\"time\", \"nz1\", \"n_node\"],\n",
-    "        coords=dict(\n",
-    "            time=([\"time\"], pd.to_datetime([\"2000-01-01\"])),\n",
-    "            nz1=([\"nz1\"], [0]),\n",
-    "        ),\n",
-    "        attrs=dict(\n",
-    "            description=\"pressure\",\n",
-    "            units=\"N/m^2\",\n",
-    "            location=\"node\",\n",
-    "            mesh=\"delaunay\",\n",
-    "        ),\n",
-    "    )\n",
-    "\n",
-    "    return ux.UxDataset({\"u\": u, \"v\": v, \"p\": p}, uxgrid=uxgrid)\n",
+   "grid = UxGrid(grid=ds.uxgrid, z=ds.coords[\"nz\"])\n",
+   "# You can view the uxgrid object with the following command:\n",
+   "grid.uxgrid"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
+   "With the `UxGrid` object defined, we can now define our `Field` objects, provided we can align a suitable interpolator with that `Field`. Aligning an interpolator requires you to be cognizant of the location that each `DataArray` is associated with. Since Parcels v4 provides flexibility to customize your interpolation scheme, care must be taken when pairing an interpolation scheme with a field. On unstructured grids, data is typically registered to \"nodes\", \"faces\", or \"edges\". For example, with FESOM2 data, the `u` and `v` velocity components are face registered while the vertical velocity component `w` is node registered.\n",
+   "\n",
+   "In Parcels, grid searching is conducted with respect to the faces. In other words, when a grid index `ei` is provided to an interpolation method, it refers to the face index `fi` at vertical layer `zi` (when unraveled). Within the interpolation method, the `field.grid.uxgrid.face_node_connectivity` attribute can be used to obtain the node indices that surround the face. Using these connectivity tables is necessary for properly indexing node registered data.\n",
+   "\n",
+   "For the example Stommel gyre dataset in this tutorial, the `u` and `v` velocity components are face registered (similar to FESOM). Parcels includes a nearest neighbor interpolator for face registered unstructured grid data through `Parcels.application_kernels.interpolation.UXPiecewiseConstantFace`. Below, we create the `Field`s `U` and `V` and associate them with the `UxGrid` we created previously and this interpolation method. Setting the `mesh_type` to `\"spherical\"` is a legacy feature from Parcels v3 that enables unit conversion from `m/s` to `deg/s`; this is needed in this case since the grid locations are defined in units of degrees."
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
+   "from parcels.application_kernels.interpolation import UXPiecewiseConstantFace\n",
+   "from parcels.field import Field\n",
    "\n",
-    "uxds = stommel_fieldset_uxarray(50, 50)\n",
-    "\n",
-    "uxds.uxgrid.plot(\n",
-    "    line_width=0.5,\n",
-    "    height=500,\n",
-    "    width=1000,\n",
-    "    title=\"Regional Delaunay Regions\",\n",
+   "U = Field(\n",
+   "    name=\"U\",\n",
+   "    data=ds.U,\n",
+   "    grid=grid,\n",
+   "    mesh_type=\"spherical\",\n",
+   "    interp_method=UXPiecewiseConstantFace,\n",
+   ")\n",
+   "V = Field(\n",
+   "    name=\"V\",\n",
+   "    data=ds.V,\n",
+   "    grid=grid,\n",
+   "    mesh_type=\"spherical\",\n",
+   "    interp_method=UXPiecewiseConstantFace,\n",
+   ")\n",
+   "P = Field(\n",
+   "    name=\"P\",\n",
+   "    data=ds.p,\n",
+   "    grid=grid,\n",
+   "    mesh_type=\"spherical\",\n",
+   "    interp_method=UXPiecewiseConstantFace,\n",
    ")"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
+   "Now that we've defined the `U` and `V` fields, we can define a `VectorField`. The `VectorField` is created in a similar manner, except that it is initialized with `Field` objects. You can optionally define an `interp_method` on the `VectorField`. When this is done, the `VectorField.interp_method` is used for interpolation; otherwise, evaluation of the `VectorField` is done component-wise using the `interp_method` associated with each component separately."
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
-    "def stommel_fieldset_xarray(xdim=200, ydim=200, grid_type=\"A\"):\n",
-    "    \"\"\"Simulate a periodic current along a western boundary, with significantly\n",
-    "    larger velocities along the western edge than the rest of the region\n",
-    "\n",
-    "    The original test description can be found in: N. Fabbroni, 2009,\n",
-    "    Numerical Simulation of Passive tracers dispersion in the sea,\n",
-    "    Ph.D. dissertation, University of Bologna\n",
-    "    http://amsdottorato.unibo.it/1733/1/Fabbroni_Nicoletta_Tesi.pdf\n",
-    "    \"\"\"\n",
-    "    import math\n",
-    "\n",
-    "    import numpy as np\n",
-    "    import pandas as pd\n",
-    "    import xarray as xr\n",
+   "from parcels.field import VectorField\n",
    "\n",
-    "    a = b = 10000 * 1e3\n",
-    "    scalefac = 0.05  # to scale for physically meaningful velocities\n",
-    "    dx, dy = a / xdim, b / ydim\n",
-    "\n",
-    "    # Coordinates of the test fieldset (on A-grid in deg)\n",
-    "    lon = np.linspace(0, a, xdim, dtype=np.float32)\n",
-    "    lat = np.linspace(0, b, ydim, dtype=np.float32)\n",
-    "\n",
-    "    # Define arrays U (zonal), V (meridional) and P (sea surface height)\n",
-    "    U = np.zeros((1, 1, lat.size, lon.size), dtype=np.float32)\n",
-    "    V = np.zeros((1, 1, lat.size, lon.size), dtype=np.float32)\n",
-    "    P = np.zeros((1, 1, lat.size, lon.size), dtype=np.float32)\n",
-    "\n",
-    "    beta = 2e-11\n",
-    "    r = 1 / (11.6 * 86400)\n",
-    "    es = r / (beta * a)\n",
+   "UV = VectorField(name=\"UV\", U=U, V=V)"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
+   "## Defining the FieldSet\n",
+   "With all of the fields that we want for this simulation defined, we can now create the `FieldSet`. As the name suggests, the `FieldSet` is the set of all `Field`s that will be used for a particle simulation. A `FieldSet` is initialized with a list of `Field` objects."
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
+   "from parcels.fieldset import FieldSet\n",
    "\n",
-    "    for j in range(lat.size):\n",
-    "        for i in range(lon.size):\n",
-    "            xi = lon[i] / a\n",
-    "            yi = lat[j] / b\n",
-    "            P[..., j, i] = (\n",
-    "                (1 - math.exp(-xi / es) - xi)\n",
-    "                * math.pi\n",
-    "                * np.sin(math.pi * yi)\n",
-    "                * scalefac\n",
-    "            )\n",
-    "            if grid_type == \"A\":\n",
-    "                U[..., j, i] = (\n",
-    "                    -(1 - math.exp(-xi / es) - xi)\n",
-    "                    * math.pi**2\n",
-    "                    * np.cos(math.pi * yi)\n",
-    "                    * scalefac\n",
-    "                )\n",
-    "                V[..., j, i] = (\n",
-    "                    (math.exp(-xi / es) / es - 1)\n",
-    "                    * math.pi\n",
-    "                    * np.sin(math.pi * yi)\n",
-    "                    * scalefac\n",
-    "                )\n",
+   "fieldset = FieldSet([UV, UV.U, UV.V, P])"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
+   "## Setting your own custom interpolator\n",
+   "You may be wondering how to set your own custom interpolator. In Parcels v4, this is as simple as defining a function that matches a specific API. The API you need to match is defined in the `field.py` module by the `Field._interp_template` and `VectorField._interp_template` methods. Specifically:\n",
+   "\n",
+   "```python\n",
+   "def _interp_template(\n",
+   "    self,  # Field or VectorField\n",
+   "    ti: int,  # Time index\n",
+   "    ei: int,  # Flat grid index\n",
+   "    bcoords: np.ndarray,  # Barycentric coordinates relative to the cell vertices\n",
+   "    tau: np.float32 | np.float64,  # Time interpolation weight\n",
+   "    t: np.float32 | np.float64,  # Current simulation time\n",
+   "    z: np.float32 | np.float64,  # Current particle depth\n",
+   "    y: np.float32 | np.float64,  # Current particle y-position\n",
+   "    x: np.float32 | np.float64,  # Current particle x-position\n",
+   ") -> np.float32 | np.float64  # For `Field`, returns a float value.\n",
+   "```\n",
+   "\n",
+   "So long as your function matches this API, you can set `Field.interp_method` to your own function."
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
+   "import numpy as np\n",
    "\n",
-    "    time = pd.to_datetime([\"2000-01-01\"])\n",
-    "    z = [0]\n",
-    "    if grid_type == \"C\":\n",
-    "        V[..., :, 1:] = (P[..., :, 1:] - P[..., :, 0:-1]) / dx * a\n",
-    "        U[..., 1:, :] = -(P[..., 1:, :] - P[..., 0:-1, :]) / dy * b\n",
-    "        u_dims = [\"time\", \"nz1\", \"face_lat\", \"node_lon\"]\n",
-    "        u_lat = lat\n",
-    "        u_lon = lon - dx * 0.5\n",
-    "        u_location = \"x_edge\"\n",
-    "        v_dims = [\"time\", \"nz1\", \"node_lat\", \"face_lon\"]\n",
-    "        v_lat = lat - dy * 0.5\n",
-    "        v_lon = lon\n",
-    "        v_location = \"y_edge\"\n",
-    "        p_dims = [\"time\", \"nz1\", \"face_lat\", \"face_lon\"]\n",
-    "        p_lat = lat\n",
-    "        p_lon = lon\n",
-    "        p_location = \"face\"\n",
    "\n",
+   "\n",
+   "def my_custom_interpolator(\n",
+   "    self,\n",
+   "    ti: int,\n",
+   "    ei: int,\n",
+   "    bcoords: np.ndarray,\n",
+   "    tau: np.float32 | np.float64,\n",
+   "    t: np.float32 | np.float64,\n",
+   "    z: np.float32 | np.float64,\n",
+   "    y: np.float32 | np.float64,\n",
+   "    x: np.float32 | np.float64,\n",
+   ") -> np.float32 | np.float64:\n",
+   "    \"\"\"Custom interpolation method for the P field.\n",
+   "    This method interpolates the value at a face by averaging the values of its neighboring faces.\n",
+   "    While this may be nonsense, it demonstrates how to create a custom interpolation method.\"\"\"\n",
+   "\n",
+   "    zi, fi = self.grid.unravel_index(ei)\n",
+   "    neighbors = self.grid.uxgrid.face_face_connectivity[fi]\n",
+   "    f_at_neighbors = self.data.values[ti, zi, neighbors]\n",
+   "    # Interpolate using the average of the neighboring face values\n",
+   "    if len(f_at_neighbors) > 0:\n",
+   "        return np.mean(f_at_neighbors)\n",
+   "    # If no neighbors, return the value at the face itself\n",
    "    else:\n",
-    "        u_dims = [\"time\", \"nz1\", \"node_lat\", \"node_lon\"]\n",
-    "        v_dims = [\"time\", \"nz1\", \"node_lat\", \"node_lon\"]\n",
-    "        p_dims = [\"time\", \"nz1\", \"node_lat\", \"node_lon\"]\n",
-    "        u_lat = lat\n",
-    "        u_lon = lon\n",
-    "        v_lat = lat\n",
-    "        v_lon = lon\n",
-    "        u_location = \"node\"\n",
-    "        v_location = \"node\"\n",
-    "        p_lat = lat\n",
-    "        p_lon = lon\n",
-    "        p_location = \"node\"\n",
+   "        return self.data.values[ti, zi, fi]\n",
    "\n",
-    "    u = xr.DataArray(\n",
-    "        data=U,\n",
-    "        name=\"u\",\n",
-    "        dims=u_dims,\n",
-    "        coords=[time, z, u_lat, u_lon],\n",
-    "        attrs=dict(\n",
-    "            description=\"zonal velocity\",\n",
-    "            units=\"m/s\",\n",
-    "            location=u_location,\n",
-    "            mesh=f\"Arakawa-{grid_type}\",\n",
-    "        ),\n",
-    "    )\n",
-    "    v = xr.DataArray(\n",
-    "        data=V,\n",
-    "        name=\"v\",\n",
-    "        dims=v_dims,\n",
-    "        coords=[time, z, v_lat, v_lon],\n",
-    "        attrs=dict(\n",
-    "            description=\"meridional velocity\",\n",
-    "            units=\"m/s\",\n",
-    "            location=v_location,\n",
-    "            mesh=f\"Arakawa-{grid_type}\",\n",
-    "        ),\n",
-    "    )\n",
-    "    p = xr.DataArray(\n",
-    "        data=P,\n",
-    "        name=\"p\",\n",
-    "        dims=p_dims,\n",
-    "        coords=[time, z, p_lat, p_lon],\n",
-    "        attrs=dict(\n",
-    "            description=\"pressure\",\n",
-    "            units=\"N/m^2\",\n",
-    "            location=p_location,\n",
-    "            mesh=f\"Arakawa-{grid_type}\",\n",
-    "        ),\n",
-    "    )\n",
-    "\n",
-    "    return xr.Dataset({\"u\": u, \"v\": v, \"p\": p})\n",
    "\n",
-    "\n",
-    "ds_arakawa_a = stommel_fieldset_xarray(50, 50, \"A\")\n",
-    "ds_arakawa_c = stommel_fieldset_xarray(50, 50, \"C\")"
+   "\n",
+   "# Assign the custom interpolator to the P field\n",
+   "P.interp_method = my_custom_interpolator"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
+   "## Understanding the context inside an interpolator method\n",
+   "Providing the `Field` object as an input to an interpolator exposes you to a wealth of useful information and methods for building complex interpolators. In particular, the `Field.grid` attribute gives you access to connectivity tables and metric terms that you may find useful for constructing an interpolator. For context, the `Parcels.UXGrid` class is built on top of the `Parcels.BaseGrid` class (much like its structured grid `Parcels.XGrid` counterpart). The `Parcels.UXGrid` class combines a `UXArray.grid` object with the vertical layer interfaces, which provides sufficient information to define the API that the `BaseGrid` class demands. This includes:\n",
+   "\n",
+   "* `search` - A method for returning a flat grid index `ei` for a position `(x,y,z)`\n",
+   "* `ravel_index` - A method for converting a face index `fi` and a vertical layer index `zi` into a single flat grid index `ei`\n",
+   "* `unravel_index` - A method for converting a single flat grid index `ei` into a face index `fi` and a vertical layer index `zi`\n",
+   "\n",
+   "The `ravel`/`unravel` methods are a necessity for most interpolators.\n",
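+   "\n",
+   "For example, here is a minimal sketch of how these pieces fit together inside an interpolator (it reuses the same attributes as the custom interpolator above; `ei` and `ti` are the indices passed in by Parcels):\n",
+   "\n",
+   "```python\n",
+   "zi, fi = field.grid.unravel_index(ei)  # split the flat index into vertical layer and face indices\n",
+   "nodes = field.grid.uxgrid.face_node_connectivity[fi]  # corner nodes of this face\n",
+   "values_at_nodes = field.data.values[ti, zi, nodes]  # gather node-registered data for this face\n",
+   "```\n",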
+   "\n",
+   "For unstructured grids, the `Field.grid.uxgrid` attribute gives you access to all of the attributes associated with a `UxArray.grid` object (see https://uxarray.readthedocs.io/en/latest/api.html#grid for more details)."
  ]
 },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "## Running the forward integration"
+  ]
+ },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -276,7 +276,7 @@
  "metadata": {},
  "outputs": [],
  "source": [
-    "ds_arakawa_a"
+   "from parcels import AdvectionEE, AdvectionRK4"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -285,7 +285,9 @@
  "metadata": {},
  "outputs": [],
  "source": [
-    "ds_arakawa_a[\"u\"].attrs"
+   "from datetime import datetime, timedelta\n",
+   "\n",
+   "from parcels import Particle, ParticleSet"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -294,7 +296,20 @@
  "metadata": {},
  "outputs": [],
  "source": [
-    "ds_arakawa_c"
+   "num_particles = 2\n",
+   "\n",
+   "pset = ParticleSet(\n",
+   "    fieldset,\n",
+   "    lon=np.random.uniform(3.0, 57.0, size=(num_particles,)),\n",
+   "    lat=np.random.uniform(3.0, 57.0, size=(num_particles,)),\n",
+   "    depth=50.0 * np.ones(shape=(num_particles,)),\n",
+   "    time=0.0\n",
+   "    * np.ones(\n",
+   "        shape=(num_particles,)\n",
+   "    ),  # set the release time explicitly; otherwise initialization appears to take a very long time\n",
+   "    pclass=Particle,\n",
+   ")\n",
+   "print(len(pset), \"particles created\")"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -303,17 +318,40 @@
  "metadata": {},
  "outputs": [],
  "source": [
-    "import numpy as np\n",
-    "\n",
-    "min_length_scale = 1111111.111111111 * np.sqrt(np.min(uxds.uxgrid.face_areas))\n",
-    "print(min_length_scale)\n",
-    "\n",
-    "max_v = np.sqrt(uxds[\"u\"] ** 2 + uxds[\"v\"] ** 2).max()\n",
-    "print(max_v)\n",
-    "\n",
-    "cfl = 0.1\n",
-    "dt = cfl * min_length_scale / max_v\n",
-    "print(dt)"
+   "from tqdm import tqdm\n",
+   "\n",
+   "from parcels import FieldOutOfBoundError\n",
+   "\n",
+   "# for capturing positions\n",
+   "_lon = [pset.lon]\n",
+   "_lat = [pset.lat]\n",
+   "\n",
+   "# output / sub-experiment time stepping\n",
+   "output_dt = timedelta(minutes=10)\n",
+   "endtime = output_dt\n",
+   "\n",
+   "# run 40 x 10-minute sub-experiments and capture positions after each iteration\n",
+   "for num_output in tqdm(range(40)):\n",
+   "    # if one particle errors, let's stop all of them\n",
+   "    try:\n",
+   "        pset.execute(\n",
+   "            endtime=endtime,\n",
+   "            dt=timedelta(seconds=60),\n",
+   "            pyfunc=AdvectionEE,\n",
+   "            verbose_progress=False,\n",
+   "        )\n",
+   "    except FieldOutOfBoundError:\n",
+   "        print(\"out of bounds, stopping (all particles)\")\n",
+   "        break\n",
+   "\n",
+   "    # on to the next sub experiment\n",
+   "    endtime += output_dt\n",
+   "    _lon.append(pset.lon)\n",
+   "    _lat.append(pset.lat)\n",
+   "\n",
+   "# merge captured positions\n",
+   "lon = np.vstack(_lon)\n",
+   "lat = np.vstack(_lat)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -322,27 +360,15 @@
  "metadata": {},
  "outputs": [],
  "source": [
-    "from datetime import timedelta\n",
-    "\n",
-    "import numpy as np\n",
-    "import uxarray as ux\n",
-    "\n",
-    "from parcels import Particle, ParticleSet, UxAdvectionEuler, UXFieldSet\n",
+   "from matplotlib import pyplot as plt\n",
    "\n",
-    "npart = 10\n",
-    "fieldset = UXFieldSet(uxds)\n",
-    "# pset = ParticleSet(\n",
-    "#     fieldset,\n",
-    "#     pclass=Particle,\n",
-    "#     lon=np.linspace(1, 59, npart),\n",
-    "#     lat=np.zeros(npart)+30)\n",
-    "# pset.execute(UxAdvectionEuler, runtime=timedelta(hours=24), dt=timedelta(seconds=dt))"
+   "plt.plot(lon, lat, \"-\");"
  ]
 }
 ],
 "metadata": {
  "kernelspec": {
-   "display_name": "parcels",
+   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
@@ -356,9 +382,9 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.13.2"
+   "version": "3.12.11"
  }
 },
"nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/parcels/kernel.py b/parcels/kernel.py index be2be29a02..287afce4c9 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -16,7 +16,6 @@ AdvectionRK4_3D_CROCO, AdvectionRK45, ) -from parcels.field import VectorField from parcels.grid import GridType from parcels.tools.statuscodes import ( StatusCode, @@ -313,12 +312,6 @@ def execute(self, pset, endtime, dt): stacklevel=2, ) - if pset.fieldset is not None: - for f in self.fieldset.fields.values(): - if isinstance(f, VectorField): - continue - f.data = np.array(f.data) - if not self._positionupdate_kernels_added: self.add_positionupdate_kernels() self._positionupdate_kernels_added = True diff --git a/parcels/particleset.py b/parcels/particleset.py index a037cc3868..2ea9c41fa7 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -1,32 +1,22 @@ import sys import warnings from collections.abc import Iterable -from copy import copy from datetime import date, datetime, timedelta -import cftime import numpy as np import xarray as xr from scipy.spatial import KDTree from tqdm import tqdm -from parcels._compat import MPI from parcels._core.utils.time import TimeInterval from parcels._reprs import particleset_repr from parcels.application_kernels.advection import AdvectionRK4 from parcels.grid import GridType from parcels.interaction.interactionkernel import InteractionKernel -from parcels.interaction.neighborsearch import ( - BruteFlatNeighborSearch, - BruteSphericalNeighborSearch, - HashSphericalNeighborSearch, - KDTreeFlatNeighborSearch, -) from parcels.kernel import Kernel -from parcels.particle import Particle, Variable +from parcels.particle import Particle from parcels.particledata import ParticleData, ParticleDataIterator from parcels.particlefile import ParticleFile -from parcels.tools._helpers import timedelta_to_float from parcels.tools.converters import _get_cftime_calendars, convert_to_flat_array from parcels.tools.loggers import logger from parcels.tools.statuscodes import StatusCode @@ -92,8 +82,6 @@ def __init__( repeatdt=None, lonlatdepth_dtype=None, pid_orig=None, - interaction_distance=None, - periodic_domain_zonal=None, **kwargs, ): self.particledata = None @@ -109,59 +97,42 @@ def __init__( self.fieldset = fieldset self._pclass = pclass - # ==== first: create a new subclass of the pclass that includes the required variables ==== # - # ==== see dynamic-instantiation trick here: https://www.python-course.eu/python3_classes_and_type.php ==== # - class_name = pclass.__name__ - array_class = None - if class_name not in dir(): - - def ArrayClass_init(self, *args, **kwargs): - fieldset = kwargs.get("fieldset", None) - ngrids = kwargs.get("ngrids", None) - if type(self).ngrids.initial < 0: - numgrids = ngrids - if numgrids is None and fieldset is not None: - numgrids = len(fieldset.gridset) - assert numgrids is not None, "Neither fieldsets nor number of grids are specified - exiting." 
-                    type(self).ngrids.initial = numgrids
-                self.ngrids = type(self).ngrids.initial
-                if self.ngrids >= 0:
-                    self.ei = np.zeros(self.ngrids, dtype=np.int32)
-                super(type(self), self).__init__(*args, **kwargs)
-
-            array_class_vdict = {
-                "ngrids": Variable("ngrids", dtype=np.int32, to_write=False, initial=-1),
-                "ei": Variable("ei", dtype=np.int32, to_write=False),
-                "__init__": ArrayClass_init,
-            }
-            array_class = type(class_name, (pclass,), array_class_vdict)
-        else:
-            array_class = locals()[class_name]
-        # ==== dynamic re-classing completed ==== #
-        _pclass = array_class
+        if repeatdt:
+            raise NotImplementedError("ParticleSet.repeatdt is not implemented yet in v4")
 
         lon = np.empty(shape=0) if lon is None else convert_to_flat_array(lon)
         lat = np.empty(shape=0) if lat is None else convert_to_flat_array(lat)
+        time = np.empty(shape=0) if time is None else convert_to_flat_array(time)
 
         if isinstance(pid_orig, (type(None), bool)):
             pid_orig = np.arange(lon.size)
 
         if depth is None:
-            mindepth = self.fieldset.dimrange("depth")[0]
+            mindepth = 0
+            for field in self.fieldset.fields.values():
+                if field.grid.depth is not None:
+                    mindepth = min(mindepth, field.grid.depth[0])
             depth = np.ones(lon.size) * mindepth
         else:
             depth = convert_to_flat_array(depth)
         assert lon.size == lat.size and lon.size == depth.size, "lon, lat, depth don't all have the same lenghts"
 
-        time = convert_to_flat_array(time)
-        time = np.repeat(time, lon.size) if time.size == 1 else time
+        if time.size > 0:
+            time = np.repeat(time, lon.size) if time.size == 1 else time
+
+            if type(time[0]) in [np.datetime64, np.timedelta64]:
+                pass  # already in the right format
+            elif type(time[0]) in [datetime, date]:
+                time = np.array([np.datetime64(t) for t in time])
+            elif type(time[0]) in [timedelta]:
+                time = np.array([np.timedelta64(t) for t in time])
+            else:
+                raise NotImplementedError("particle time must be a datetime, timedelta, or date object")
+
+            time = np.array([self.time_origin.reltime(t) if _convert_to_reltime(t) else t for t in time])
+
+            assert lon.size == time.size, "time and positions (lon, lat, depth) do not have the same lengths."
 
-        if time.size > 0 and type(time[0]) in [datetime, date]:
-            time = np.array([np.datetime64(t) for t in time])
-        if time.size > 0 and isinstance(time[0], np.timedelta64) and not self.time_origin:
-            raise NotImplementedError("If fieldset.time_origin is not a date, time of a particle must be a double")
-        time = np.array([self.time_origin.reltime(t) if _convert_to_reltime(t) else t for t in time])
-        assert lon.size == time.size, "time and positions (lon, lat, depth) do not have the same lengths."
         if fieldset.time_interval:
             _warn_particle_times_outside_fieldset_time_bounds(time, fieldset.time_interval)
@@ -179,98 +150,18 @@ def ArrayClass_init(self, *args, **kwargs):
                 lon.size == kwargs[kwvar].size
             ), f"{kwvar} and positions (lon, lat, depth) don't have the same lengths."
 
-        self.repeatdt = timedelta_to_float(repeatdt) if repeatdt is not None else None
-
-        if self.repeatdt:
-            if self.repeatdt <= 0:
-                raise ValueError("Repeatdt should be > 0")
-            if time[0] and not np.allclose(time, time[0]):
-                raise ValueError("All Particle.time should be the same when repeatdt is not None")
-            self._repeatpclass = pclass
-            self._repeatkwargs = kwargs
-            self._repeatkwargs.pop("partition_function", None)
-
-        ngrids = len(fieldset.gridset)
-
-        # Variables used for interaction kernels.
-        inter_dist_horiz = None
-        inter_dist_vert = None
-        # The _dirty_neighbor attribute keeps track of whether
-        # the neighbor search structure needs to be rebuilt.
-        # If indices change (for example adding/deleting a particle)
-        # The NS structure needs to be rebuilt and _dirty_neighbor should be
-        # set to true. Since the NS structure isn't immediately initialized,
-        # it is set to True here.
-        self._dirty_neighbor = True
-
         self.particledata = ParticleData(
-            _pclass,
+            self._pclass,
             lon=lon,
             lat=lat,
             depth=depth,
             time=time,
             lonlatdepth_dtype=lonlatdepth_dtype,
             pid_orig=pid_orig,
-            ngrid=ngrids,
+            ngrid=len(fieldset.gridset),
             **kwargs,
         )
 
-        # Initialize neighbor search data structure (used for interaction).
-        if interaction_distance is not None:
-            meshes = [g.mesh for g in fieldset.gridset]
-            # Assert all grids have the same mesh type
-            assert np.all(np.array(meshes) == meshes[0])
-            mesh_type = meshes[0]
-            if mesh_type == "spherical":
-                if len(self) < 1000:
-                    interaction_class = BruteSphericalNeighborSearch
-                else:
-                    interaction_class = HashSphericalNeighborSearch
-            elif mesh_type == "flat":
-                if len(self) < 1000:
-                    interaction_class = BruteFlatNeighborSearch
-                else:
-                    interaction_class = KDTreeFlatNeighborSearch
-            else:
-                assert False, "Interaction is only possible on 'flat' and 'spherical' meshes"
-            try:
-                if len(interaction_distance) == 2:
-                    inter_dist_vert, inter_dist_horiz = interaction_distance
-                else:
-                    inter_dist_vert = interaction_distance[0]
-                    inter_dist_horiz = interaction_distance[0]
-            except TypeError:
-                inter_dist_vert = interaction_distance
-                inter_dist_horiz = interaction_distance
-            self._neighbor_tree = interaction_class(
-                inter_dist_vert=inter_dist_vert,
-                inter_dist_horiz=inter_dist_horiz,
-                periodic_domain_zonal=periodic_domain_zonal,
-            )
-        # End of neighbor search data structure initialization.
-
-        if self.repeatdt:
-            if len(time) > 0 and time[0] is None:
-                self._repeat_starttime = time[0]
-            else:
-                if self.particledata.data["time"][0] and not np.allclose(
-                    self.particledata.data["time"], self.particledata.data["time"][0]
-                ):
-                    raise ValueError("All Particle.time should be the same when repeatdt is not None")
-                self._repeat_starttime = copy(self.particledata.data["time"][0])
-            self._repeatlon = copy(self.particledata.data["lon"])
-            self._repeatlat = copy(self.particledata.data["lat"])
-            self._repeatdepth = copy(self.particledata.data["depth"])
-            for kwvar in kwargs:
-                if kwvar not in ["partition_function"]:
-                    self._repeatkwargs[kwvar] = copy(self.particledata.data[kwvar])
-
-        if self.repeatdt:
-            if MPI and self.particledata.pu_indicators is not None:
-                mpi_comm = MPI.COMM_WORLD
-                mpi_rank = mpi_comm.Get_rank()
-                self._repeatpid = pid_orig[self.particledata.pu_indicators == mpi_rank]
-
         self._kernel = None
 
     def __del__(self):
@@ -863,14 +754,11 @@ def set_variable_write_status(self, var, write_status):
     def execute(
         self,
         pyfunc=AdvectionRK4,
-        pyfunc_inter=None,
-        endtime=None,
-        runtime: float | timedelta | np.timedelta64 | None = None,
-        dt: float | timedelta | np.timedelta64 = 1.0,
+        endtime: timedelta | datetime | None = None,
+        runtime: timedelta | None = None,
+        dt: np.float64 | np.float32 | timedelta | None = None,
         output_file=None,
         verbose_progress=True,
-        postIterationCallbacks=None,
-        callbackdt: float | timedelta | np.timedelta64 | None = None,
     ):
         """Execute a given kernel function over the particle set for multiple timesteps.
 
@@ -883,13 +771,14 @@ def execute(
             Kernel function to execute.
             This can be the name of a defined Python function or a :class:`parcels.kernel.Kernel` object.
             Kernels can be concatenated using the + operator (Default value = AdvectionRK4)
-        endtime :
-            End time for the timestepping loop.
-            It is either a datetime object or a positive double. (Default value = None)
-        runtime :
-            Length of the timestepping loop. Use instead of endtime.
-            It is either a timedelta object or a positive double. (Default value = None)
-        dt :
+        endtime (datetime or timedelta):
+            End time for the timestepping loop. If a timedelta is provided, it is interpreted as the total simulation time. In this case,
+            the absolute end time is the start of the fieldset's time interval plus the timedelta.
+            If a datetime is provided, it is interpreted as the absolute end time of the simulation.
+        runtime (timedelta):
+            The duration of the simulation execution. Must be a timedelta object and is required to be set when the `fieldset.time_interval` is not defined.
+            If the `fieldset.time_interval` is defined and the runtime is provided, the end time will be the start of the fieldset's time interval plus the runtime.
+        dt (timedelta):
             Timestep interval (in seconds) to be passed to the kernel.
             It is either a timedelta object or a double.
             Use a negative value for a backward-in-time simulation. (Default value = 1 second)
         output_file :
             mod:`parcels.particlefile.ParticleFile` object for particle output (Default value = None)
         verbose_progress : bool
             Boolean for providing a progress bar for the kernel execution loop. (Default value = True)
-        postIterationCallbacks :
-            Optional, array of functions that are to be called after each iteration (post-process, non-Kernel) (Default value = None)
-        callbackdt :
-            Optional, in conjecture with 'postIterationCallbacks', timestep interval to (latest) interrupt the running kernel and invoke post-iteration callbacks from 'postIterationCallbacks' (Default value = None)
-        pyfunc_inter :
-            (Default value = None)
 
         Notes
         -----
@@ -919,197 +802,83 @@ def execute(
             self._kernel = pyfunc
         else:
             self._kernel = self.Kernel(pyfunc)
+
         if output_file:
             output_file.metadata["parcels_kernels"] = self._kernel.name
 
-        # Set up the interaction kernel(s) if not set and given.
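+        # Resolve start_time and end_time from (runtime, endtime) and the fieldset's time_interval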
-        if self._interaction_kernel is None and pyfunc_inter is not None:
-            if isinstance(pyfunc_inter, InteractionKernel):
-                self._interaction_kernel = pyfunc_inter
+        if self.fieldset.time_interval is None:
+            start_time = timedelta(seconds=0)  # For the execution loop, we need a start time as a timedelta object
+            if runtime is None:
+                raise ValueError("The runtime must be provided when the time_interval is not defined for a fieldset.")
             else:
-                self._interaction_kernel = self.InteractionKernel(pyfunc_inter)
-
-        # Convert all time variables to seconds
-        if isinstance(endtime, timedelta):
-            raise TypeError("endtime must be either a datetime or a double")
-        if isinstance(endtime, datetime):
-            endtime = np.datetime64(endtime)
-        elif isinstance(endtime, cftime.datetime):
-            endtime = self.time_origin.reltime(endtime)
-        if isinstance(endtime, np.datetime64):
-            if self.time_origin.calendar is None:
-                raise NotImplementedError("If fieldset.time_origin is not a date, execution endtime must be a double")
-            endtime = self.time_origin.reltime(endtime)
-
-        if runtime is not None:
-            runtime = timedelta_to_float(runtime)
-
-        dt = timedelta_to_float(dt)
-
-        if abs(dt) <= 1e-6:
-            raise ValueError("Time step dt is too small")
-        if (dt * 1e6) % 1 != 0:
-            raise ValueError("Output interval should not have finer precision than 1e-6 s")
-        outputdt = timedelta_to_float(output_file.outputdt) if output_file else np.inf
-
-        if callbackdt is not None:
-            callbackdt = timedelta_to_float(callbackdt)
-
-        assert runtime is None or runtime >= 0, "runtime must be positive"
-        assert outputdt is None or outputdt >= 0, "outputdt must be positive"
-
-        if runtime is not None and endtime is not None:
-            raise RuntimeError("Only one of (endtime, runtime) can be specified")
-
-        mintime, maxtime = self.fieldset.dimrange("time")  # TODO : change to fieldset.time_interval
-
-        default_release_time = mintime if dt >= 0 else maxtime
-        if np.any(np.isnan(self.particledata.data["time"])):
-            self.particledata.data["time"][np.isnan(self.particledata.data["time"])] = default_release_time
-            self.particledata.data["time_nextloop"][np.isnan(self.particledata.data["time_nextloop"])] = (
-                default_release_time
-            )
-        min_rt = np.min(self.particledata.data["time_nextloop"])
-        max_rt = np.max(self.particledata.data["time_nextloop"])
-
-        # Derive starttime and endtime from arguments or fieldset defaults
-        starttime = min_rt if dt >= 0 else max_rt
-        if self.repeatdt is not None and self._repeat_starttime is None:
-            self._repeat_starttime = starttime
-        if runtime is not None:
-            endtime = starttime + runtime * np.sign(dt)
-        elif endtime is None:
-            mintime, maxtime = self.fieldset.dimrange("time")
-            endtime = maxtime if dt >= 0 else mintime
-
-        if (abs(endtime - starttime) < 1e-5 or runtime == 0) and dt == 0:
-            raise RuntimeError(
-                "dt and runtime are zero, or endtime is equal to Particle.time. "
-                "ParticleSet.execute() will not do anything."
-            )
+                if isinstance(runtime, timedelta):
+                    end_time = runtime
+                else:
+                    raise ValueError("The runtime must be a timedelta object")
 
-        if np.isfinite(outputdt):
-            _warn_outputdt_release_desync(outputdt, starttime, self.particledata.data["time_nextloop"])
+        else:
+            start_time = self.fieldset.time_interval.left
+
+            if runtime is None:
+                if endtime is None:
+                    raise ValueError(
+                        "Must provide either runtime or endtime when time_interval is defined for a fieldset."
+                    )
+                # Ensure that the endtime uses the same type as the start_time
+                if isinstance(endtime, self.fieldset.time_interval.left.__class__):
+                    if endtime < self.fieldset.time_interval.left:
+                        raise ValueError("The endtime must be after the start time of the fieldset.time_interval")
+                    end_time = min(endtime, self.fieldset.time_interval.right)
+                else:
+                    raise ValueError("The endtime must be of the same type as the fieldset.time_interval start time.")
            else:
+                end_time = start_time + runtime
 
-        self.particledata._data["dt"][:] = dt
+        outputdt = output_file.outputdt if output_file else None
 
-        if callbackdt is None:
-            interupt_dts = [np.inf, outputdt]
-            if self.repeatdt is not None:
-                interupt_dts.append(self.repeatdt)
-            callbackdt = np.min(np.array(interupt_dts))
+        # dt must be converted to float to avoid "TypeError: float() argument must be a string or a real number, not 'datetime.timedelta'"
+        dt_seconds = dt / np.timedelta64(1, "s")
+        self.particledata._data["dt"][:] = dt_seconds
 
         # Set up pbar
         if output_file:
             logger.info(f"Output files are stored in {output_file.fname}.")
 
         if verbose_progress:
-            pbar = tqdm(total=abs(endtime - starttime), file=sys.stdout)
+            pbar = tqdm(total=(end_time - start_time).total_seconds(), file=sys.stdout)
 
-        # Set up variables for first iteration
-        if self.repeatdt:
-            next_prelease = self._repeat_starttime + (
-                abs(starttime - self._repeat_starttime) // self.repeatdt + 1
-            ) * self.repeatdt * np.sign(dt)
-        else:
-            next_prelease = np.inf if dt > 0 else -np.inf
         if output_file:
-            next_output = starttime + dt
+            next_output = outputdt
         else:
-            next_output = np.inf * np.sign(dt)
-        next_callback = starttime + callbackdt * np.sign(dt)
+            next_output = np.inf
 
         tol = 1e-12
-        time = starttime
-
-        while (time < endtime and dt > 0) or (time > endtime and dt < 0):
-            # Check if we can fast-forward to the next time needed for the particles
-            if dt > 0:
-                skip_kernel = True if min(self.time) > (time + dt) else False
-            else:
-                skip_kernel = True if max(self.time) < (time + dt) else False
-
-            time_at_startofloop = time
-
-            next_input = self.fieldset.computeTimeChunk(time, dt)
-            # Define next_time (the timestamp when the execution needs to be handed back to python)
-            if dt > 0:
-                next_time = min(next_prelease, next_input, next_output, next_callback, endtime)
-            else:
-                next_time = max(next_prelease, next_input, next_output, next_callback, endtime)
-
-            # If we don't perform interaction, only execute the normal kernel efficiently.
-            if self._interaction_kernel is None:
-                if not skip_kernel:
-                    res = self._kernel.execute(self, endtime=next_time, dt=dt)
-                    if res == StatusCode.StopAllExecution:
-                        return StatusCode.StopAllExecution
-            # Interaction: interleave the interaction and non-interaction kernel for each time step.
-            # E.g. Normal -> Inter -> Normal -> Inter if endtime-time == 2*dt
-            else:
-                cur_time = time
-                while (cur_time < next_time and dt > 0) or (cur_time > next_time and dt < 0):
-                    if dt > 0:
-                        cur_end_time = min(cur_time + dt, next_time)
-                    else:
-                        cur_end_time = max(cur_time + dt, next_time)
-                    self._kernel.execute(self, endtime=cur_end_time, dt=dt)
-                    self._interaction_kernel.execute(self, endtime=cur_end_time, dt=dt)
-                    cur_time += dt
-            # End of interaction specific code
-            time = next_time
-
-            # Check for empty ParticleSet
-            if np.isinf(next_prelease) and len(self) == 0:
+        time = start_time
+        while time <= end_time:
+            t0 = time
+            next_time = t0 + dt
+            # Kernel and particledata currently expect all time objects to be numpy floats.
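+            # (start_time was set above: timedelta(0) when the fieldset has no time_interval, otherwise time_interval.left)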
+            # When converting absolute times to floats, we do them all relative to the start time.
+            # TODO: To completely support datetime or timedelta objects, this really needs to be addressed in the kernels and particledata
+            next_time_float = (next_time - start_time) / np.timedelta64(1, "s")
+            res = self._kernel.execute(self, endtime=next_time_float, dt=dt_seconds)
+            if res == StatusCode.StopAllExecution:
                 return StatusCode.StopAllExecution
 
-            if abs(time - next_output) < tol:
-                for fld in self.fieldset.fields.values():
-                    if hasattr(fld, "to_write") and fld.to_write:
-                        if fld.grid.tdim > 1:
-                            raise RuntimeError(
-                                "Field writing during execution only works for Fields with one snapshot in time"
-                            )
-                        fldfilename = str(output_file.fname).replace(".zarr", f"_{fld.to_write:04d}")
-                        fld.write(fldfilename)
-                        fld.to_write += 1
-
-            if abs(time - next_output) < tol:
+            # TODO: Handle IO timing based on timedelta or datetime objects
+            if abs(next_time_float - next_output) < tol:
                 if output_file:
-                    output_file.write(self, time_at_startofloop)
+                    output_file.write(self, next_output)
                 if np.isfinite(outputdt):
-                    next_output += outputdt * np.sign(dt)
-
-            # ==== insert post-process here to also allow for memory clean-up via external func ==== #
-            if abs(time - next_callback) < tol:
-                if postIterationCallbacks is not None:
-                    for extFunc in postIterationCallbacks:
-                        extFunc()
-                next_callback += callbackdt * np.sign(dt)
-
-            if abs(time - next_prelease) < tol:
-                pset_new = self.__class__(
-                    fieldset=self.fieldset,
-                    time=time,
-                    lon=self._repeatlon,
-                    lat=self._repeatlat,
-                    depth=self._repeatdepth,
-                    pclass=self._repeatpclass,
-                    lonlatdepth_dtype=self.particledata.lonlatdepth_dtype,
-                    partition_function=False,
-                    pid_orig=self._repeatpid,
-                    **self._repeatkwargs,
-                )
-                for p in pset_new:
-                    p.dt = dt
-                self.add(pset_new)
-                next_prelease += self.repeatdt * np.sign(dt)
+                    next_output += outputdt
 
-            if time != endtime:
-                next_input = self.fieldset.computeTimeChunk(time, dt)
             if verbose_progress:
-                pbar.update(abs(time - time_at_startofloop))
+                pbar.update(dt.total_seconds())
+
+            time = next_time
 
         if verbose_progress:
             pbar.close()
@@ -1127,15 +896,15 @@ def _warn_outputdt_release_desync(outputdt: float, starttime: float, release_tim
     )
 
 
-def _warn_particle_times_outside_fieldset_time_bounds(release_times: np.ndarray, time: np.ndarray | TimeInterval):
+def _warn_particle_times_outside_fieldset_time_bounds(release_times: np.ndarray, time: TimeInterval):
     if np.any(release_times):
-        if np.any(release_times < time[0]):
+        if np.any(release_times < time.left):
             warnings.warn(
                 "Some particles are set to be released outside the FieldSet's executable time domain.",
                 ParticleSetWarning,
                 stacklevel=2,
             )
-        if np.any(release_times > time[-1]):
+        if np.any(release_times > time.right):
             warnings.warn(
                 "Some particles are set to be released after the fieldset's last time and the fields are not constant in time.",
                 ParticleSetWarning,
diff --git a/tests/v4/test_particleset.py b/tests/v4/test_particleset.py
new file mode 100644
index 0000000000..81d4c1a080
--- /dev/null
+++ b/tests/v4/test_particleset.py
@@ -0,0 +1,97 @@
+from datetime import timedelta
+
+import pytest
+
+from parcels import (
+    AdvectionEE,
+    Field,
+    FieldSet,
+    Particle,
+    ParticleSet,
+    UXPiecewiseConstantFace,
+    VectorField,
+)
+from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
+from parcels.uxgrid import UxGrid
+
+
+@pytest.mark.parametrize("verbose_progress", [True, False])
+def test_uxstommelgyre_pset_execute(verbose_progress):
+    ds = datasets_unstructured["stommel_gyre_delaunay"]
+    grid = UxGrid(grid=ds.uxgrid, z=ds.coords["nz"])
+    U = Field(
+        name="U",
+        data=ds.U,
+        grid=grid,
+        mesh_type="spherical",
+        interp_method=UXPiecewiseConstantFace,
+    )
+    V = Field(
+        name="V",
+        data=ds.V,
+        grid=grid,
+        mesh_type="spherical",
+        interp_method=UXPiecewiseConstantFace,
+    )
+    P = Field(
+        name="P",
+        data=ds.p,
+        grid=grid,
+        mesh_type="spherical",
+        interp_method=UXPiecewiseConstantFace,
+    )
+    UV = VectorField(name="UV", U=U, V=V)
+    fieldset = FieldSet([UV, UV.U, UV.V, P])
+    pset = ParticleSet(
+        fieldset,
+        lon=[30.0],
+        lat=[5.0],
+        depth=[50.0],
+        time=[timedelta(seconds=0.0)],
+        pclass=Particle,
+    )
+    pset.execute(
+        runtime=timedelta(minutes=10), dt=timedelta(seconds=60), pyfunc=AdvectionEE, verbose_progress=verbose_progress
+    )
+
+
+@pytest.mark.xfail(reason="Output file not implemented yet")
+def test_uxstommelgyre_pset_execute_output():
+    ds = datasets_unstructured["stommel_gyre_delaunay"]
+    grid = UxGrid(grid=ds.uxgrid, z=ds.coords["nz"])
+    U = Field(
+        name="U",
+        data=ds.U,
+        grid=grid,
+        mesh_type="spherical",
+        interp_method=UXPiecewiseConstantFace,
+    )
+    V = Field(
+        name="V",
+        data=ds.V,
+        grid=grid,
+        mesh_type="spherical",
+        interp_method=UXPiecewiseConstantFace,
+    )
+    P = Field(
+        name="P",
+        data=ds.p,
+        grid=grid,
+        mesh_type="spherical",
+        interp_method=UXPiecewiseConstantFace,
+    )
+    UV = VectorField(name="UV", U=U, V=V)
+    fieldset = FieldSet([UV, UV.U, UV.V, P])
+    pset = ParticleSet(
+        fieldset,
+        lon=[30.0],
+        lat=[5.0],
+        depth=[50.0],
+        time=[0.0],
+        pclass=Particle,
+    )
+    output_file = pset.ParticleFile(
+        name="stommel_uxarray_particles.zarr",  # the file name
+        outputdt=timedelta(minutes=5),  # the time step of the outputs
+    )
+    pset.execute(runtime=timedelta(minutes=10), dt=timedelta(seconds=60), pyfunc=AdvectionEE, output_file=output_file)
diff --git a/tests/v4/test_uxarray_fieldset.py b/tests/v4/test_uxarray_fieldset.py
index fbb2cc3840..f1ed7262fe 100644
--- a/tests/v4/test_uxarray_fieldset.py
+++ b/tests/v4/test_uxarray_fieldset.py
@@ -1,5 +1,3 @@
-from datetime import timedelta
-
 import pytest
 import uxarray as ux
 
@@ -83,9 +81,9 @@ def test_fesom_fieldset(ds_fesom_channel, uv_fesom_channel):
     assert (fieldset.V == ds_fesom_channel.V).all()
 
 
-@pytest.mark.skip(reason="ParticleSet.__init__ needs major refactoring")
 def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel):
     fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V])
+
     # Check that the fieldset has the expected properties
     assert (fieldset.U == ds_fesom_channel.U).all()
     assert (fieldset.V == ds_fesom_channel.V).all()
@@ -104,19 +102,6 @@ def test_set_interp_methods(ds_fesom_channel, uv_fesom_channel):
     fieldset.V.interp_method = UXPiecewiseConstantFace
 
 
-@pytest.mark.skip(reason="ParticleSet.__init__ needs major refactoring")
-def test_fesom_channel(ds_fesom_channel, uvw_fesom_channel):
-    fieldset = FieldSet([uvw_fesom_channel, uvw_fesom_channel.U, uvw_fesom_channel.V, uvw_fesom_channel.W])
-
-    # Check that the fieldset has the expected properties
-    assert (fieldset.U == ds_fesom_channel.U).all()
-    assert (fieldset.V == ds_fesom_channel.V).all()
-    assert (fieldset.W == ds_fesom_channel.W).all()
-
-    pset = ParticleSet(fieldset, pclass=Particle)
-    pset.execute(endtime=timedelta(days=1), dt=timedelta(hours=1))
-
-
 def test_fesom2_square_delaunay_uniform_z_coordinate_eval():
     """
     Test the evaluation of a fieldset with a FESOM2 square Delaunay grid and uniform z-coordinate.