
Implementing Fieldset.from_xarray_dataset() method for integration with xarray Datasets #476

Merged
erikvansebille merged 25 commits into master from fieldset_from_ds_xarray on Oct 31, 2018

Conversation

@erikvansebille (Member) commented Oct 17, 2018

This PR implements a new method, FieldSet.from_xarray_dataset() (initially named FieldSet.from_ds(), see the title change below), to run Parcels directly on xarray Dataset objects. Following the suggestion by @rabernat, this fixes #467

An example script would be:

from parcels import FieldSet, AdvectionRK4, ParticleSet, JITParticle
from parcels import plotTrajectoriesFile
import numpy as np
import xarray as xr
import gcsfs
from datetime import timedelta as delta


def periodicBC(particle, fieldset, time, dt):
    if particle.lon < 0:
        particle.lon += 360
    elif particle.lon > 360:
        particle.lon -= 360


# create xarray Dataset of AVISO geostrophic velocities using gcsfs
fname = "pangeo-data/dataset-duacs-rep-global-merged-allsat-phy-l4-v3-alt"
gcsmap = gcsfs.mapping.GCSMap(fname)
ds = xr.open_zarr(gcsmap)

# convert Dataset to Parcels FieldSet
variables = {'U': 'ugos', 'V': 'vgos'}
dimensions = {'lon': 'longitude', 'lat': 'latitude', 'time': 'time'}
fset = FieldSet.from_xarray_dataset(ds, variables, dimensions)
fset.add_periodic_halo(zonal=True, meridional=False, halosize=2)

# Create ParticleSet and output file
lons, lats = np.meshgrid(np.arange(60, 80, 2), np.arange(-50, -30, 2))
pset = ParticleSet(fset, JITParticle, lon=lons, lat=lats)
ofile = pset.ParticleFile('aviso_particles.nc', outputdt=delta(days=2))

# Advect ParticleSet with RK4 and periodic boundary conditions
pset.execute(AdvectionRK4+pset.Kernel(periodicBC), runtime=delta(days=30),
             dt=delta(hours=1), output_file=ofile)

plotTrajectoriesFile('aviso_particles.nc')

The key statement here is fset = FieldSet.from_xarray_dataset(ds, variables, dimensions). Note that users still need to specify which variables in the Dataset hold the U and V Fields, and how the dimensions are named.

@delandmeterp (Contributor) left a comment

Is the computeTimeChunk() function used by the new functions?
It is quite a big change. I'll also test a big NEMO run on that branch to see whether it passes smoothly.

else:
-    data[tindex, 0, :, :] = filebuffer.data[ti, :, :]
+    data[tindex, 0, :, :] = filebuffer.data
@delandmeterp (Contributor):

We can do that?

@erikvansebille (Member Author):

Yes, because I introduced a filebuffer.ti attribute, and now select the correct ti in the data property of NetcdfFileBuffer (e.g. https://github.com/OceanParcels/parcels/pull/476/files#diff-a68ef84b67d3b404f48daf2ccc3ca441R1257)
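
(For illustration, a runnable toy of that pattern — the class and attribute names here are hypothetical stand-ins, not the PR's actual code:)

import numpy as np

class TinyBuffer:
    # stand-in for NetcdfFileBuffer: the buffer remembers which time
    # index it wraps, so .data already returns just that snapshot
    def __init__(self, array, ti=None):
        self.array = array
        self.ti = ti

    @property
    def data(self):
        d = self.array
        if self.ti is not None:
            d = d[self.ti]       # select the stored time index
        return np.asarray(d)     # materialize only this snapshot

buf = TinyBuffer(np.arange(24).reshape(2, 3, 4), ti=1)
print(buf.data.shape)            # (3, 4): the time axis is gone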

@erikvansebille (Member Author):

It is quite a big change. I'll also test a big NEMO run on that branch to see whether it passes smoothly.

I agree, so we should test thoroughly that this doesn't break old functionality and that it works for different types of Datasets. It would be good to get @rabernat's feedback on the latter!

@rabernat:

This looks very useful! I really appreciate your quick response to my suggestion.

I would be happy to try to help review this. The challenge for me is that it's quite a big PR, and I don't understand the internals of parcels very well. If you could point me towards specific parts where you would like some input on what xarray might be doing, I would be happy to share some feedback.

The key question from my point of view, which I can't completely figure out from reading your code, is when exactly xarray's various forms of lazy array-like things (including dask arrays as well as the other array wrappers used by xarray) are coerced to actual numpy arrays. How does this work? How does parcels decide which data along the space and time axes to load?

Have you actually tried the code in your comment above? If not, would you like me to give this a spin on some other xarray datasets?

@rabernat left a comment

Just a few random comments.

@@ -331,6 +331,46 @@ def from_parcels(cls, basename, uvar='vozocrtx', vvar='vomecrty', indices=None,
dimensions=dimensions, allow_time_extrapolation=allow_time_extrapolation,
time_periodic=time_periodic, full_load=full_load, **kwargs)

@classmethod
def from_ds(cls, ds, variables, dimensions, indices=None, mesh='spherical', allow_time_extrapolation=None,
time_periodic=False, full_load=False, **kwargs):


I would probably prefer the name from_xarray_dataset, which is more explicit.

:param indices: Optional dictionary of indices for each dimension
to read from file(s), to allow for reading of subset of data.
Default is to read the full extent of each dimension.
:param extra_fields: Extra fields to read beyond U and V

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mesh keyword does not appear to be documented in the docstring.

pset = ParticleSet(fieldset, JITParticle, 0, 0)

pset.execute(AdvectionRK4, dt=1)
assert pset[0].lon == 4.5 and pset[0].lat == 10


This is a good test. It makes it quite clear how to use an xarray dataset with parcels.

One way to address the comment about laziness would be to create a special dask-backed array that raises an error when it is computed.

@erikvansebille (Member Author):

Thanks! So how would we create a "special dask-backed array that raises an error when it is computed"? Do you have a suggestion for that?
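
(One possible approach, sketched here purely for illustration — this assumes dask's map_blocks API and is not part of this PR:)

import numpy as np
import dask.array as da

def _raise_on_compute(block):
    # any attempt to materialize a chunk lands here
    raise RuntimeError("lazy array was computed eagerly!")

lazy = da.zeros((4, 4), chunks=(2, 2))
# meta is passed explicitly so dask does not call the function just to
# infer the output type while building the graph
trap = lazy.map_blocks(_raise_on_compute, meta=np.array((), dtype=lazy.dtype))

graph = trap + 1       # fine: this only extends the task graph
try:
    graph.compute()    # chunks are materialized here -> RuntimeError
except RuntimeError as err:
    print("caught:", err)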



ptype = {'scipy': ScipyParticle, 'jit': JITParticle}


-def set_globcurrent_fieldset(filename=None, indices=None, full_load=False):
+def set_globcurrent_fieldset(filename=None, indices=None, full_load=False, from_ds=False):


I believe an explicit use_xarray would probably be better than from_ds here.

@erikvansebille (Member Author):

The key question from my point of view, which I can't completely figure out from reading your code, is when exactly xarray's various forms of lazy array-like things (including dask arrays as well as the other array wrappers used by xarray) are coerced to actual numpy arrays. How does this work? How does parcels decide which data along the space and time axes to load?

The loading happens in the Field.computeTimeChunk() method (https://github.com/OceanParcels/parcels/blob/fieldset_from_ds_xarray/parcels/field.py#L880) and the NetcdfFileBuffer.data property (https://github.com/OceanParcels/parcels/blob/fieldset_from_ds_xarray/parcels/field.py#L1243).
Note that we only load the snapshots that are needed, and that if users give an indices dictionary with subsets of lon, lat, and depth indices, we will only read those in (see the example below).
I'm not sure, though, whether we need any extra .load() calls etc. for dask. I'm completely new to all that.
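
(Continuing the example above — the index ranges here are purely illustrative:)

# read only a lon/lat subregion instead of the full domain
indices = {'lon': range(0, 100), 'lat': range(200, 300)}
fset = FieldSet.from_xarray_dataset(ds, variables, dimensions, indices=indices)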

Have you actually tried the code in your comment above? If not, would you like me to give this a spin on some other xarray datasets?

Yes, I have tested the code in my own notebooks here, but it's quite slow, probably mostly due to the latency of the connection to your server. Could you also try running it on a server closer to where the data is stored, and report how fast it goes?

Also changing optional argument in pytest `from_ds` to `use_xarray`,
and adding `mesh` param to from_xarray_dataset() docstring
@erikvansebille erikvansebille changed the title Implementing Fieldset.from_ds() method for integration with xarray Datasets Implementing Fieldset.from_xarray_dataset() method for integration with xarray Datasets Oct 17, 2018
@rabernat:

Sorry to pester with questions, but I'm trying to understand the design of parcels better in order to understand how xarray / dask will interact with it. A key routine seems to be computeTimeChunk

https://github.com/OceanParcels/parcels/blob/1020717d363426c49a2425d445b298d65f910adb/parcels/field.py#L880-L898

I'm confused because this function takes data as an input, modifies it in place, and then returns it. If you're going to modify it in place, why return anything at all?

When you call this function, e.g.

https://github.com/OceanParcels/parcels/blob/119f6df8e1353ba88be677d57f4487e366e85518/parcels/fieldset.py#L459

I'm pretty sure it would work exactly the same if you just called it without using the return value at all:

f.computeTimeChunk(data, f.loaded_time_indices[0])

For example

import numpy as np
def modify_data(data_in):
    data_in[0] = 1
data = np.zeros(2)
modify_data(data)
np.testing.assert_array_equal(data, [1, 0])

How often is computeTimeChunk called? Once per timestep for the whole set of particles? Or once per timestep per particle?

@@ -864,19 +880,20 @@ def advancetime(self, field_new, advanceForward):
def computeTimeChunk(self, data, tindex):
g = self.grid
with NetcdfFileBuffer(self.dataFiles[g.ti+tindex], self.dimensions, self.indices, self.netcdf_engine) as filebuffer:


What is self.dataFiles in the case of an xarray Dataset? What happens if there are no data files, like in your test below?

@erikvansebille (Member Author):

These self.dataFiles refer to the Dataset (ds) itself. Filebuffer.dataFiles is set in field.py:
https://github.com/OceanParcels/parcels/blob/1020717d363426c49a2425d445b298d65f910adb/parcels/field.py#L208-L211

And the dataset is set at the top of the file as:
https://github.com/OceanParcels/parcels/blob/1020717d363426c49a2425d445b298d65f910adb/parcels/field.py#L158-L162

This may not be the cleanest way to do it (and the variable names may be confusing, as they originally referred to actual NetCDF files), but this was a simple way to implement xarray handling too.
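
(Schematically, the idea is something like the following hypothetical sketch — not the PR's exact code:)

import xarray as xr

def make_datafiles(data, n_snapshots, filenames=None):
    # for an xarray-backed Field, every "data file" entry is simply
    # the Dataset itself; otherwise it is the per-snapshot file list
    if isinstance(data, xr.Dataset):
        return [data] * n_snapshots
    return list(filenames)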

@erikvansebille (Member Author):

I'm confused because this function takes data as an input, modifies it in place, and then returns it. If you're going to modify it in place, why return anything at all?

You're right, I just tested and it also works without the return data at the end of Field.computeTimeChunk(). An oversight, I guess; I'll remove it for clarity.

How often is computeTimeChunk called? Once per timestep for the whole set of particles? Or once per timestep per particle?

It's called once per timestep for the whole set of particles. The particle loop itself runs in C (the JIT compilation), and we exit that loop to, for example, call computeTimeChunk, but also to write output, etc.
https://github.com/OceanParcels/parcels/blob/1020717d363426c49a2425d445b298d65f910adb/parcels/particleset.py#L384
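
(Schematically, the control flow is something like this runnable toy — all names here are hypothetical, not parcels' actual code:)

from datetime import timedelta as delta

def compute_time_chunk(field, time):
    print("loading %s snapshot covering t=%s" % (field, time))

def jit_particle_loop(time, dt):
    print("advancing all particles from t=%s by %s" % (time, dt))

time, end_time, dt = delta(0), delta(hours=6), delta(hours=2)
while time < end_time:
    for field in ("U", "V"):
        compute_time_chunk(field, time)  # once per timestep, for all particles
    jit_particle_loop(time, dt)          # the C/JIT loop over every particle
    time += dt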

Also testing both forward in time and backward in time
# Conflicts:
#	environment_py2_linux.yml
#	environment_py2_osx.yml
#	environment_py3_osx.yml
#	parcels/field.py
@delandmeterp (Contributor):

I've checked the NEMO runs, which pass fine as well.

@erikvansebille erikvansebille merged commit 676343e into master Oct 31, 2018
@erikvansebille erikvansebille deleted the fieldset_from_ds_xarray branch October 31, 2018 08:52