diff --git a/dev/test-vertex-field.ipynb b/dev/test-vertex-field.ipynb new file mode 100644 index 000000000..4fe12c181 --- /dev/null +++ b/dev/test-vertex-field.ipynb @@ -0,0 +1,796 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e7940da2-2548-42d1-99ef-29996377a03f", + "metadata": {}, + "source": [ + "## Initiate `VertexField`" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "fecfb0d5-07cf-484b-a3ae-fe4a6931f230", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import discretisedfield as df" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f3b2589f-2452-49a8-9b8d-e6172f2f57cb", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "p1 = (-50, -50, -50)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4d15ed1f-c13e-488f-96e7-fb44a200969c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "p2 = (50, 50, 50)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a9df0e02-c2e4-4f22-8b34-d3bbcf6320d2", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "mesh = df.Mesh(p1=p1, p2=p2, n=(100, 100, 100))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "a68bb896-c389-4950-b75d-700f7ce2b158", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def Ms(point):\n", + " x, y, z = point\n", + " if x**2 + y**2 <= (20) ** 2:\n", + " return 1\n", + " else:\n", + " return 0" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e2c9f907-1f07-409d-9ef3-a951159d0d9d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "vertex_field = df.VertexField(mesh, nvdim=3, value=[0, 0, 1], norm=Ms)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2af02dd8-40f5-4ed4-9249-a256bccd3bd9", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": {}, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.holoviews_exec.v0+json": "", + "text/html": [ + "
\n", + "
\n", + "
\n", + "" + ], + "text/plain": [ + ":DynamicMap [z]\n", + " :Overlay\n", + " .Image.I :Image [x,y] (field)\n", + " .VectorField.I :VectorField [x,y] (angle,mag)" + ] + }, + "execution_count": 11, + "metadata": { + "application/vnd.holoviews_exec.v0+json": { + "id": "p1355" + } + }, + "output_type": "execute_result" + } + ], + "source": [ + "vertex_field.hv(kdims=[\"x\", \"y\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "7eec5343-4e75-4802-a978-ec2497968a06", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "vertex_field_xrr = vertex_field.to_xarray()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "445b5584-cdee-41eb-b4f6-066bdb5e59a6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'field' (x: 101, y: 101, z: 101)>\n",
+       "array([[[1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        ...,\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.]],\n",
+       "\n",
+       "       [[1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        ...,\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.]],\n",
+       "\n",
+       "       [[1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        ...,\n",
+       "...\n",
+       "        ...,\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.]],\n",
+       "\n",
+       "       [[1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        ...,\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.]],\n",
+       "\n",
+       "       [[1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        ...,\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.],\n",
+       "        [1., 1., 1., ..., 1., 1., 1.]]])\n",
+       "Coordinates:\n",
+       "  * x        (x) float64 0.0 1.0 2.0 3.0 4.0 5.0 ... 96.0 97.0 98.0 99.0 100.0\n",
+       "  * y        (y) float64 0.0 1.0 2.0 3.0 4.0 5.0 ... 96.0 97.0 98.0 99.0 100.0\n",
+       "  * z        (z) float64 0.0 1.0 2.0 3.0 4.0 5.0 ... 96.0 97.0 98.0 99.0 100.0\n",
+       "    vdims    <U1 'z'\n",
+       "Attributes:\n",
+       "    units:             None\n",
+       "    cell:              [1. 1. 1.]\n",
+       "    pmin:              [0 0 0]\n",
+       "    pmax:              [100 100 100]\n",
+       "    nvdim:             3\n",
+       "    tolerance_factor:  1e-12\n",
+       "    data_location:     vertex
" + ], + "text/plain": [ + "\n", + "array([[[1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " ...,\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.]],\n", + "\n", + " [[1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " ...,\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.]],\n", + "\n", + " [[1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " ...,\n", + "...\n", + " ...,\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.]],\n", + "\n", + " [[1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " ...,\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.]],\n", + "\n", + " [[1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " ...,\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.]]])\n", + "Coordinates:\n", + " * x (x) float64 0.0 1.0 2.0 3.0 4.0 5.0 ... 96.0 97.0 98.0 99.0 100.0\n", + " * y (y) float64 0.0 1.0 2.0 3.0 4.0 5.0 ... 96.0 97.0 98.0 99.0 100.0\n", + " * z (z) float64 0.0 1.0 2.0 3.0 4.0 5.0 ... 96.0 97.0 98.0 99.0 100.0\n", + " vdims 1. Values are named_tuples ``hv_key_dim(data, unit)`` that + contain the data (which has to fulfil len(data) > 1, typically as a numpy array + or list) and the unit of a string (empty string if there is no unit). + + """ + key_dims = { + dim: dfp.util.hv_key_dim(coords, unit) + for dim, unit in zip(self.mesh.region.dims, self.mesh.region.units) + if len(coords := getattr(self.mesh.cells, dim)) > 1 + } + if self.nvdim > 1: + key_dims["vdims"] = dfp.util.hv_key_dim(self.vdims, "") + return key_dims + + def line(self, p1, p2, n=100): + points = list(self.mesh.line(p1=p1, p2=p2, n=n)) + values = [self(p) for p in points] + + return df.Line( + points=points, + values=values, + point_columns=self.mesh.region.dims, + value_columns=[f"v{dim}" for dim in self.vdims] + if self.vdims is not None + else "v", + ) # TODO scalar fields have no vdim + + def __getitem__(self, item): + submesh = self.mesh[item] + + index_min = self.mesh.point2index( + submesh.index2point((0,) * submesh.region.ndim) + ) + index_max = np.add(index_min, submesh.n) + slices = [slice(i, j) for i, j in zip(index_min, index_max)] + return self.__class__( + submesh, + nvdim=self.nvdim, + value=self.array[tuple(slices)], + vdims=self.vdims, + unit=self.unit, + valid=self.valid[tuple(slices)], + vdim_mapping=self.vdim_mapping, + ) + + @property + def mpl(self): + """Plot interface, matplotlib based. + + This property provides access to the different plotting methods. It is + also callable to quickly generate plots. For more details and the + available methods refer to the documentation linked below. + + .. seealso:: + + :py:func:`~discretisedfield.plotting.Mpl.__call__` + :py:func:`~discretisedfield.plotting.Mpl.scalar` + :py:func:`~discretisedfield.plotting.Mpl.vector` + :py:func:`~discretisedfield.plotting.Mpl.lightness` + :py:func:`~discretisedfield.plotting.Mpl.contour` + + Examples + -------- + .. plot:: :context: close-figs + + 1. Visualising the field using ``matplotlib``. + + >>> import discretisedfield as df + ... + >>> p1 = (0, 0, 0) + >>> p2 = (100, 100, 100) + >>> n = (10, 10, 10) + >>> mesh = df.Mesh(p1=p1, p2=p2, n=n) + >>> field = df.Field(mesh, nvdim=3, value=(1, 2, 0)) + >>> field.sel(z=50).resample(n=(5, 5)).mpl() + + """ + return dfp.MplField(self) + + def to_xarray(self, name="field", unit=None): + """Field value as ``xarray.DataArray``. + + The function returns an ``xarray.DataArray`` with the dimensions + ``self.mesh.region.dims`` and ``vdims`` (only if ``field.nvdim > 1``). The + coordinates of the geometric dimensions are derived from ``self.mesh.points``, + and for vector field components from ``self.vdims``. Addtionally, + the values of ``self.mesh.cell``, ``self.mesh.region.pmin``, and + ``self.mesh.region.pmax`` are stored as ``cell``, ``pmin``, and ``pmax`` + attributes of the DataArray. The ``unit`` attribute of geometric + dimensions is set to the respective strings in ``self.mesh.region.units``. + + The name and unit of the field ``DataArray`` can be set by passing + ``name`` and ``unit``. If the type of value passed to any of the two + arguments is not ``str``, then a ``TypeError`` is raised. + + Parameters + ---------- + name : str, optional + + String to set name of the field ``DataArray``. + + unit : str, optional + + String to set units of the field ``DataArray``. + + Returns + ------- + xarray.DataArray + + Field values DataArray. + + Raises + ------ + TypeError + + If either ``name`` or ``unit`` argument is not a string. + + Examples + -------- + 1. Create a field + + >>> import discretisedfield as df + ... + >>> p1 = (0, 0, 0) + >>> p2 = (10, 10, 10) + >>> cell = (1, 1, 1) + >>> mesh = df.Mesh(p1=p1, p2=p2, cell=cell) + >>> field = df.Field(mesh=mesh, nvdim=3, value=(1, 0, 0), norm=1.) + ... + >>> field + Field(...) + + 2. Create `xarray.DataArray` from field + + >>> xa = field.to_xarray() + >>> xa + + ... + + 3. Select values of `x` component + + >>> xa.sel(vdims='x') + + ... + + """ + if not isinstance(name, str): + msg = "Name argument must be a string." + raise TypeError(msg) + + if unit is not None and not isinstance(unit, str): + msg = "Unit argument must be a string." + raise TypeError(msg) + + axes = self.mesh.region.dims + + data_array_coords = {axis: getattr(self.mesh.cells, axis) for axis in axes} + + geo_units_dict = dict(zip(axes, self.mesh.region.units)) + + if self.nvdim > 1: + data_array_dims = axes + ("vdims",) + if self.vdims is not None: + data_array_coords["vdims"] = self.vdims + field_array = self.array + else: + data_array_dims = axes + field_array = np.squeeze(self.array, axis=-1) + + data_array = xr.DataArray( + field_array, + dims=data_array_dims, + coords=data_array_coords, + name=name, + attrs=dict( + units=unit or self.unit, + cell=self.mesh.cell, + pmin=self.mesh.region.pmin, + pmax=self.mesh.region.pmax, + nvdim=self.nvdim, + tolerance_factor=self.mesh.region.tolerance_factor, + data_location="cell", + ), + ) + + # TODO save vdim_mapping + + for dim in geo_units_dict: + data_array[dim].attrs["units"] = geo_units_dict[dim] + + return data_array + + @classmethod + def from_xarray(cls, xa): + """Create ``discretisedfield.Field`` from ``xarray.DataArray`` + + The class method accepts an ``xarray.DataArray`` as an argument to + return a ``discretisedfield.Field`` object. The first n (or n-1) dimensions of + the DataArray are considered geometric dimensions of a scalar (or vector) field. + In case of a vector field, the last dimension must be named ``vdims``. The + DataArray attribute ``nvdim`` determines whether it is a scalar or a vector + field (i.e. ``nvdim = 1`` is a scalar field and ``nvdim >= 1`` is a vector + field). Hence, ``nvdim`` attribute must be present, greater than or equal to + one, and of an integer type. + + The DataArray coordinates corresponding to the geometric dimensions represent + the discretisation along the respective dimension and must have equally spaced + values. The coordinates of ``vdims`` represent the name of field components + (e.g. ['x', 'y', 'z'] for a 3D vector field). + + Additionally, it is expected to have ``cell``, ``p1``, and ``p2`` attributes for + creating the right mesh for the field; however, in the absence of these, the + coordinates of the geometric axes dimensions are utilized. It should be noted + that ``cell`` attribute is required if any of the geometric directions has only + a single cell. + + Parameters + ---------- + xa : xarray.DataArray + + DataArray to create Field. + + Returns + ------- + discretisedfield.Field + + Field created from DataArray. + + Raises + ------ + TypeError + + - If argument is not ``xarray.DataArray``. + - If ``nvdim`` attribute in not an integer. + + KeyError + + - If at least one of the geometric dimension coordinates has a single + value and ``cell`` attribute is missing. + - If ``nvdim`` attribute is absent. + + ValueError + + - If DataArray does not have a dimension ``vdims`` when attribute ``nvdim`` + is grater than one. + - If coordinates of geometrical dimensions are not equally spaced. + + Examples + -------- + 1. Create a DataArray + + >>> import xarray as xr + >>> import numpy as np + ... + >>> xa = xr.DataArray(np.ones((20, 20, 20, 3), dtype=float), + ... dims = ['x', 'y', 'z', 'vdims'], + ... coords = dict(x=np.arange(0, 20), + ... y=np.arange(0, 20), + ... z=np.arange(0, 20), + ... vdims=['x', 'y', 'z']), + ... name = 'mag', + ... attrs = dict(cell=[1., 1., 1.], + ... p1=[1., 1., 1.], + ... p2=[21., 21., 21.], + ... nvdim=3),) + >>> xa + + ... + + 2. Create Field from DataArray + + >>> import discretisedfield as df + ... + >>> field = df.Field.from_xarray(xa) + >>> field + Field(...) + >>> field.mean() + array([1., 1., 1.]) + + """ + if not isinstance(xa, xr.DataArray): + raise TypeError("Argument must be a xarray.DataArray.") + + if "nvdim" not in xa.attrs: + raise KeyError( + 'The DataArray must have an attribute "nvdim" to identify a scalar or' + " a vector field." + ) + + if xa.attrs["nvdim"] < 1: + raise ValueError('"nvdim" attribute must be greater or equal to 1.') + elif not isinstance(xa.attrs["nvdim"], int): + raise TypeError("The value of nvdim must be an integer.") + + if xa.attrs["nvdim"] > 1 and "vdims" not in xa.dims: + raise ValueError( + 'The DataArray must have a dimension "vdims" when "nvdim" attribute is' + " greater than 1." + ) + + dims_list = [dim for dim in xa.dims if dim != "vdims"] + + for i in dims_list: + if xa[i].values.size > 1 and not np.allclose( + np.diff(xa[i].values), np.diff(xa[i].values).mean() + ): + raise ValueError(f"Coordinates of {i} must be equally spaced.") + + try: + cell = xa.attrs["cell"] + except KeyError: + if any(len_ == 1 for len_ in xa.values.shape[:-1]): + raise KeyError( + "DataArray must have a 'cell' attribute if any " + "of the geometric directions has a single cell." + ) from None + cell = [np.diff(xa[i].values).mean() for i in dims_list] + + p1 = ( + xa.attrs["pmin"] + if "pmin" in xa.attrs + else [xa[i].values[0] - c / 2 for i, c in zip(dims_list, cell)] + ) + p2 = ( + xa.attrs["pmax"] + if "pmax" in xa.attrs + else [xa[i].values[-1] + c / 2 for i, c in zip(dims_list, cell)] + ) + + if any("units" not in xa[i].attrs for i in dims_list): + region = df.Region(p1=p1, p2=p2, dims=dims_list) + mesh = df.Mesh(region=region, cell=cell) + else: + region = df.Region( + p1=p1, p2=p2, dims=dims_list, units=[xa[i].units for i in dims_list] + ) + mesh = df.Mesh(region=region, cell=cell) + + if "tolerance_factor" in xa.attrs: + mesh.region.tolerance_factor = xa.attrs["tolerance_factor"] + + vdims = xa.vdims.values if "vdims" in xa.coords else None + nvdim = xa.attrs["nvdim"] + val = np.expand_dims(xa.values, axis=-1) if nvdim == 1 else xa.values + # print(val.shape) + # TODO load vdim_mapping + return cls( + mesh=mesh, nvdim=nvdim, value=val, vdims=vdims, dtype=xa.values.dtype + ) + + @functools.singledispatchmethod + def _as_array(self, val, mesh, nvdim, dtype): + raise TypeError(f"Unsupported type {type(val)}.") + + # to avoid str being interpreted as iterable + @_as_array.register(str) + def _(self, val, mesh, nvdim, dtype): + raise TypeError(f"Unsupported type {type(val)}.") + + @_as_array.register(numbers.Complex) + @_as_array.register(collections.abc.Iterable) + def _(self, val, mesh, nvdim, dtype): + if isinstance(val, numbers.Complex) and nvdim > 1 and val != 0: + raise ValueError( + f"Wrong dimension 1 provided for value; expected dimension is {nvdim}" + ) + + if isinstance(val, collections.abc.Iterable): + if nvdim == 1 and np.array_equal(np.shape(val), mesh.n): + return np.expand_dims(val, axis=-1) + elif np.shape(val)[-1] != nvdim: + raise ValueError( + f"Wrong dimension {len(val)} provided for value; expected dimension" + f" is {nvdim}." + ) + dtype = dtype or max(np.asarray(val).dtype, np.float64) + return np.full((*mesh.n, nvdim), val, dtype=dtype) + + @_as_array.register(collections.abc.Callable) + def _(self, val, mesh, nvdim, dtype): + # will only be called on user input + # dtype must be specified by the user for complex values + array = np.empty((*mesh.n, nvdim), dtype=dtype) + for index, point in zip(mesh.indices, mesh): + # Conversion to array and reshaping is required for numpy >= 1.24 + # and for certain inputs, e.g. a tuple of numpy arrays which can e.g. occur + # for 1d vector fields. + array[index] = np.asarray(val(point)).reshape(nvdim) + return array + + @_as_array.register(dict) + def _(self, val, mesh, nvdim, dtype): + # will only be called on user input + # dtype must be specified by the user for complex values + dtype = dtype or np.float64 + fill_value = ( + val["default"] + if "default" in val and not callable(val["default"]) + else np.nan + ) + array = np.full((*mesh.n, nvdim), fill_value, dtype=dtype) + + for subregion in reversed(mesh.subregions.keys()): + # subregions can overlap, first subregion takes precedence + try: + submesh = mesh[subregion] + subval = val[subregion] + except KeyError: + continue # subregion not in val when implicitly set via "default" + else: + slices = mesh.region2slices(submesh.region) + array[slices] = self._as_array(subval, submesh, nvdim, dtype) + + if np.any(np.isnan(array)): + # not all subregion keys specified and 'default' is missing or callable + if "default" not in val: + raise KeyError( + "Key 'default' required if not all subregion keys are specified." + ) + subval = val["default"] + for idx in np.argwhere(np.isnan(array[..., 0])): + # only spatial indices required -> array[..., 0] + # conversion to array and reshaping similar to "callable" implementation + array[idx] = np.asarray(subval(mesh.index2point(idx))).reshape(nvdim) + + return array + + +# We cannot register to self inside the class +@CellField._as_array.register(CellField) +def _(self, val, mesh, nvdim, dtype): + if mesh.region not in val.mesh.region: + raise ValueError( + f"{val.mesh.region} of the provided field does not " + f"contain {mesh.region} of the field that is being " + "created." + ) + value = ( + val.to_xarray() + .sel( + **{dim: getattr(mesh.cells, dim) for dim in mesh.region.dims}, + method="nearest", + ) + .data + ) + if nvdim == 1: + # xarray dataarrays for scalar data are three dimensional + return value.reshape(*mesh.n, -1) + return value diff --git a/discretisedfield/field.py b/discretisedfield/field.py index 6d2b47372..fecca3e30 100644 --- a/discretisedfield/field.py +++ b/discretisedfield/field.py @@ -1,10 +1,8 @@ -import collections -import functools +import abc import numbers import numpy as np import scipy.fft as spfft -import xarray as xr from vtkmodules.util import numpy_support as vns from vtkmodules.vtkCommonDataModel import vtkRectilinearGrid @@ -12,7 +10,6 @@ import discretisedfield.plotting as dfp import discretisedfield.util as dfu from discretisedfield.operators import _split_diff_combine -from discretisedfield.plotting.util import hv_key_dim from . import html from .io import _FieldIO @@ -157,12 +154,36 @@ class Field(_FieldIO): "write": "to_file", # method is in io.__init__ } + def __new__( + cls, + mesh, + nvdim=None, + value=0.0, + norm=None, + data_location="cell", + vdims=None, + dtype=None, + unit=None, + valid=True, + vdim_mapping=None, + **kwargs, + ): + if cls in [df.cell_field.CellField, df.vertex_field.VertexField]: + return super().__new__(cls) + elif data_location == "cell": + return super().__new__(df.cell_field.CellField) + elif data_location == "vertex": + return super().__new__(df.vertex_field.VertexField) + else: + raise ValueError(f"Unknown field data location: {data_location}") + def __init__( self, mesh, nvdim=None, value=0.0, norm=None, + data_location="cell", vdims=None, dtype=None, unit=None, @@ -849,6 +870,7 @@ def _repr_html_(self): """Show HTML-based representation in Jupyter notebook.""" return html.get_template("field").render(field=self) + @abc.abstractmethod def __call__(self, point): r"""Sample the field value at ``point``. @@ -886,7 +908,6 @@ def __call__(self, point): array([1., 3., 4.]) """ - return self.array[self.mesh.point2index(point)] def __getattr__(self, attr): """Extract the component of the vector field. @@ -2867,6 +2888,7 @@ def integrate(self, direction=None, cumulative=False): vdim_mapping=self.vdim_mapping, ) + @abc.abstractmethod def line(self, p1, p2, n=100): r"""Sample the field along the line. @@ -2915,17 +2937,6 @@ def line(self, p1, p2, n=100): >>> line = field.line(p1=(0, 0, 0), p2=(2, 0, 0), n=5) """ - points = list(self.mesh.line(p1=p1, p2=p2, n=n)) - values = [self(p) for p in points] - - return df.Line( - points=points, - values=values, - point_columns=self.mesh.region.dims, - value_columns=[f"v{dim}" for dim in self.vdims] - if self.vdims is not None - else "v", - ) # TODO scalar fields have no vdim def sel(self, *args, **kwargs): """Select a part of the field. @@ -3098,6 +3109,7 @@ def resample(self, n): vdim_mapping=self.vdim_mapping, ) + @abc.abstractmethod def __getitem__(self, item): """Extracts the field on a subregion. @@ -3168,22 +3180,6 @@ def __getitem__(self, item): (4, 4, 1, 1) """ - submesh = self.mesh[item] - - index_min = self.mesh.point2index( - submesh.index2point((0,) * submesh.region.ndim) - ) - index_max = np.add(index_min, submesh.n) - slices = [slice(i, j) for i, j in zip(index_min, index_max)] - return self.__class__( - submesh, - nvdim=self.nvdim, - value=self.array[tuple(slices)], - vdims=self.vdims, - unit=self.unit, - valid=self.valid[tuple(slices)], - vdim_mapping=self.vdim_mapping, - ) def angle(self, vector): r"""Angle between two vectors. @@ -3430,12 +3426,17 @@ def to_vtk(self): vns.numpy_to_vtk(np.fromiter(self.mesh.vertices.z, float)) ) - cell_data = rgrid.GetCellData() + if isinstance(self, df.cell_field.CellField): + vtk_data = rgrid.GetCellData() + elif isinstance(self, df.vertex_field.VertexField): + vtk_data = rgrid.GetPointData() + else: + assert False, f"Unknown field type {type(self)}." field_norm = vns.numpy_to_vtk( self.norm.array.transpose((2, 1, 0, 3)).reshape(-1) ) field_norm.SetName("norm") - cell_data.AddArray(field_norm) + vtk_data.AddArray(field_norm) if self.nvdim > 1: # For some visualisation packages it is an advantage to have direct # access to the individual field components, e.g. for colouring. @@ -3444,52 +3445,22 @@ def to_vtk(self): getattr(self, comp).array.transpose((2, 1, 0, 3)).reshape((-1)) ) component_array.SetName(f"{comp}-component") - cell_data.AddArray(component_array) + vtk_data.AddArray(component_array) field_array = vns.numpy_to_vtk( self.array.transpose((2, 1, 0, 3)).reshape((-1, self.nvdim)) ) field_array.SetName("field") - cell_data.AddArray(field_array) + vtk_data.AddArray(field_array) if self.nvdim == 3: - cell_data.SetActiveVectors("field") + vtk_data.SetActiveVectors("field") elif self.nvdim == 1: - cell_data.SetActiveScalars("field") + vtk_data.SetActiveScalars("field") return rgrid - @property + @abc.abstractproperty def mpl(self): - """Plot interface, matplotlib based. - - This property provides access to the different plotting methods. It is - also callable to quickly generate plots. For more details and the - available methods refer to the documentation linked below. - - .. seealso:: - - :py:func:`~discretisedfield.plotting.Mpl.__call__` - :py:func:`~discretisedfield.plotting.Mpl.scalar` - :py:func:`~discretisedfield.plotting.Mpl.vector` - :py:func:`~discretisedfield.plotting.Mpl.lightness` - :py:func:`~discretisedfield.plotting.Mpl.contour` - - Examples - -------- - .. plot:: :context: close-figs - - 1. Visualising the field using ``matplotlib``. - - >>> import discretisedfield as df - ... - >>> p1 = (0, 0, 0) - >>> p2 = (100, 100, 100) - >>> n = (10, 10, 10) - >>> mesh = df.Mesh(p1=p1, p2=p2, n=n) - >>> field = df.Field(mesh, nvdim=3, value=(1, 2, 0)) - >>> field.sel(z=50).resample(n=(5, 5)).mpl() - - """ - return dfp.MplField(self) + pass @property def k3d(self): @@ -3552,24 +3523,9 @@ def _hv_vdims_guess(self, kdims): # the hv class expects two valid vdims or None return None if None in vdims else vdims - @property + @abc.abstractproperty def _hv_key_dims(self): - """Dict of key dimensions of the field. - - Keys are the field dimensions (domain and vector space, e.g. x, y, z, vdims) - that have length > 1. Values are named_tuples ``hv_key_dim(data, unit)`` that - contain the data (which has to fulfil len(data) > 1, typically as a numpy array - or list) and the unit of a string (empty string if there is no unit). - - """ - key_dims = { - dim: hv_key_dim(coords, unit) - for dim, unit in zip(self.mesh.region.dims, self.mesh.region.units) - if len(coords := getattr(self.mesh.cells, dim)) > 1 - } - if self.nvdim > 1: - key_dims["vdims"] = hv_key_dim(self.vdims, "") - return key_dims + pass def fftn(self, **kwargs): """Performs an N-dimensional discrete Fast Fourier Transform (FFT) @@ -3988,372 +3944,14 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): except Exception: raise NotImplementedError() + @abc.abstractmethod def to_xarray(self, name="field", unit=None): - """Field value as ``xarray.DataArray``. + pass - The function returns an ``xarray.DataArray`` with the dimensions - ``self.mesh.region.dims`` and ``vdims`` (only if ``field.nvdim > 1``). The - coordinates of the geometric dimensions are derived from ``self.mesh.points``, - and for vector field components from ``self.vdims``. Addtionally, - the values of ``self.mesh.cell``, ``self.mesh.region.pmin``, and - ``self.mesh.region.pmax`` are stored as ``cell``, ``pmin``, and ``pmax`` - attributes of the DataArray. The ``unit`` attribute of geometric - dimensions is set to the respective strings in ``self.mesh.region.units``. - - The name and unit of the field ``DataArray`` can be set by passing - ``name`` and ``unit``. If the type of value passed to any of the two - arguments is not ``str``, then a ``TypeError`` is raised. - - Parameters - ---------- - name : str, optional - - String to set name of the field ``DataArray``. - - unit : str, optional - - String to set units of the field ``DataArray``. - - Returns - ------- - xarray.DataArray - - Field values DataArray. - - Raises - ------ - TypeError - - If either ``name`` or ``unit`` argument is not a string. - - Examples - -------- - 1. Create a field - - >>> import discretisedfield as df - ... - >>> p1 = (0, 0, 0) - >>> p2 = (10, 10, 10) - >>> cell = (1, 1, 1) - >>> mesh = df.Mesh(p1=p1, p2=p2, cell=cell) - >>> field = df.Field(mesh=mesh, nvdim=3, value=(1, 0, 0), norm=1.) - ... - >>> field - Field(...) - - 2. Create `xarray.DataArray` from field - - >>> xa = field.to_xarray() - >>> xa - - ... - - 3. Select values of `x` component - - >>> xa.sel(vdims='x') - - ... - - """ - if not isinstance(name, str): - msg = "Name argument must be a string." - raise TypeError(msg) - - if unit is not None and not isinstance(unit, str): - msg = "Unit argument must be a string." - raise TypeError(msg) - - axes = self.mesh.region.dims - - data_array_coords = {axis: getattr(self.mesh.cells, axis) for axis in axes} - - geo_units_dict = dict(zip(axes, self.mesh.region.units)) - - if self.nvdim > 1: - data_array_dims = axes + ("vdims",) - if self.vdims is not None: - data_array_coords["vdims"] = self.vdims - field_array = self.array - else: - data_array_dims = axes - field_array = np.squeeze(self.array, axis=-1) - - data_array = xr.DataArray( - field_array, - dims=data_array_dims, - coords=data_array_coords, - name=name, - attrs=dict( - units=unit or self.unit, - cell=self.mesh.cell, - pmin=self.mesh.region.pmin, - pmax=self.mesh.region.pmax, - nvdim=self.nvdim, - tolerance_factor=self.mesh.region.tolerance_factor, - ), - ) - - # TODO save vdim_mapping - - for dim in geo_units_dict: - data_array[dim].attrs["units"] = geo_units_dict[dim] - - return data_array - - @classmethod + @abc.abstractclassmethod def from_xarray(cls, xa): - """Create ``discretisedfield.Field`` from ``xarray.DataArray`` - - The class method accepts an ``xarray.DataArray`` as an argument to - return a ``discretisedfield.Field`` object. The first n (or n-1) dimensions of - the DataArray are considered geometric dimensions of a scalar (or vector) field. - In case of a vector field, the last dimension must be named ``vdims``. The - DataArray attribute ``nvdim`` determines whether it is a scalar or a vector - field (i.e. ``nvdim = 1`` is a scalar field and ``nvdim >= 1`` is a vector - field). Hence, ``nvdim`` attribute must be present, greater than or equal to - one, and of an integer type. - - The DataArray coordinates corresponding to the geometric dimensions represent - the discretisation along the respective dimension and must have equally spaced - values. The coordinates of ``vdims`` represent the name of field components - (e.g. ['x', 'y', 'z'] for a 3D vector field). - - Additionally, it is expected to have ``cell``, ``p1``, and ``p2`` attributes for - creating the right mesh for the field; however, in the absence of these, the - coordinates of the geometric axes dimensions are utilized. It should be noted - that ``cell`` attribute is required if any of the geometric directions has only - a single cell. - - Parameters - ---------- - xa : xarray.DataArray - - DataArray to create Field. - - Returns - ------- - discretisedfield.Field - - Field created from DataArray. - - Raises - ------ - TypeError - - - If argument is not ``xarray.DataArray``. - - If ``nvdim`` attribute in not an integer. - - KeyError - - - If at least one of the geometric dimension coordinates has a single - value and ``cell`` attribute is missing. - - If ``nvdim`` attribute is absent. - - ValueError - - - If DataArray does not have a dimension ``vdims`` when attribute ``nvdim`` - is grater than one. - - If coordinates of geometrical dimensions are not equally spaced. - - Examples - -------- - 1. Create a DataArray - - >>> import xarray as xr - >>> import numpy as np - ... - >>> xa = xr.DataArray(np.ones((20, 20, 20, 3), dtype=float), - ... dims = ['x', 'y', 'z', 'vdims'], - ... coords = dict(x=np.arange(0, 20), - ... y=np.arange(0, 20), - ... z=np.arange(0, 20), - ... vdims=['x', 'y', 'z']), - ... name = 'mag', - ... attrs = dict(cell=[1., 1., 1.], - ... p1=[1., 1., 1.], - ... p2=[21., 21., 21.], - ... nvdim=3),) - >>> xa - - ... - - 2. Create Field from DataArray - - >>> import discretisedfield as df - ... - >>> field = df.Field.from_xarray(xa) - >>> field - Field(...) - >>> field.mean() - array([1., 1., 1.]) + pass - """ - if not isinstance(xa, xr.DataArray): - raise TypeError("Argument must be a xarray.DataArray.") - - if "nvdim" not in xa.attrs: - raise KeyError( - 'The DataArray must have an attribute "nvdim" to identify a scalar or' - " a vector field." - ) - - if xa.attrs["nvdim"] < 1: - raise ValueError('"nvdim" attribute must be greater or equal to 1.') - elif not isinstance(xa.attrs["nvdim"], int): - raise TypeError("The value of nvdim must be an integer.") - - if xa.attrs["nvdim"] > 1 and "vdims" not in xa.dims: - raise ValueError( - 'The DataArray must have a dimension "vdims" when "nvdim" attribute is' - " greater than 1." - ) - - dims_list = [dim for dim in xa.dims if dim != "vdims"] - - for i in dims_list: - if xa[i].values.size > 1 and not np.allclose( - np.diff(xa[i].values), np.diff(xa[i].values).mean() - ): - raise ValueError(f"Coordinates of {i} must be equally spaced.") - - try: - cell = xa.attrs["cell"] - except KeyError: - if any(len_ == 1 for len_ in xa.values.shape[:-1]): - raise KeyError( - "DataArray must have a 'cell' attribute if any " - "of the geometric directions has a single cell." - ) from None - cell = [np.diff(xa[i].values).mean() for i in dims_list] - - p1 = ( - xa.attrs["pmin"] - if "pmin" in xa.attrs - else [xa[i].values[0] - c / 2 for i, c in zip(dims_list, cell)] - ) - p2 = ( - xa.attrs["pmax"] - if "pmax" in xa.attrs - else [xa[i].values[-1] + c / 2 for i, c in zip(dims_list, cell)] - ) - - if any("units" not in xa[i].attrs for i in dims_list): - region = df.Region(p1=p1, p2=p2, dims=dims_list) - mesh = df.Mesh(region=region, cell=cell) - else: - region = df.Region( - p1=p1, p2=p2, dims=dims_list, units=[xa[i].units for i in dims_list] - ) - mesh = df.Mesh(region=region, cell=cell) - - if "tolerance_factor" in xa.attrs: - mesh.region.tolerance_factor = xa.attrs["tolerance_factor"] - - vdims = xa.vdims.values if "vdims" in xa.coords else None - nvdim = xa.attrs["nvdim"] - val = np.expand_dims(xa.values, axis=-1) if nvdim == 1 else xa.values - # print(val.shape) - # TODO load vdim_mapping - return cls( - mesh=mesh, nvdim=nvdim, value=val, vdims=vdims, dtype=xa.values.dtype - ) - - @functools.singledispatchmethod + @abc.abstractmethod def _as_array(self, val, mesh, nvdim, dtype): - raise TypeError(f"Unsupported type {type(val)}.") - - # to avoid str being interpreted as iterable - @_as_array.register(str) - def _(self, val, mesh, nvdim, dtype): - raise TypeError(f"Unsupported type {type(val)}.") - - @_as_array.register(numbers.Complex) - @_as_array.register(collections.abc.Iterable) - def _(self, val, mesh, nvdim, dtype): - if isinstance(val, numbers.Complex) and nvdim > 1 and val != 0: - raise ValueError( - f"Wrong dimension 1 provided for value; expected dimension is {nvdim}" - ) - - if isinstance(val, collections.abc.Iterable): - if nvdim == 1 and np.array_equal(np.shape(val), mesh.n): - return np.expand_dims(val, axis=-1) - elif np.shape(val)[-1] != nvdim: - raise ValueError( - f"Wrong dimension {len(val)} provided for value; expected dimension" - f" is {nvdim}." - ) - dtype = dtype or max(np.asarray(val).dtype, np.float64) - return np.full((*mesh.n, nvdim), val, dtype=dtype) - - @_as_array.register(collections.abc.Callable) - def _(self, val, mesh, nvdim, dtype): - # will only be called on user input - # dtype must be specified by the user for complex values - array = np.empty((*mesh.n, nvdim), dtype=dtype) - for index, point in zip(mesh.indices, mesh): - # Conversion to array and reshaping is required for numpy >= 1.24 - # and for certain inputs, e.g. a tuple of numpy arrays which can e.g. occur - # for 1d vector fields. - array[index] = np.asarray(val(point)).reshape(nvdim) - return array - - @_as_array.register(dict) - def _(self, val, mesh, nvdim, dtype): - # will only be called on user input - # dtype must be specified by the user for complex values - dtype = dtype or np.float64 - fill_value = ( - val["default"] - if "default" in val and not callable(val["default"]) - else np.nan - ) - array = np.full((*mesh.n, nvdim), fill_value, dtype=dtype) - - for subregion in reversed(mesh.subregions.keys()): - # subregions can overlap, first subregion takes precedence - try: - submesh = mesh[subregion] - subval = val[subregion] - except KeyError: - continue # subregion not in val when implicitly set via "default" - else: - slices = mesh.region2slices(submesh.region) - array[slices] = self._as_array(subval, submesh, nvdim, dtype) - - if np.any(np.isnan(array)): - # not all subregion keys specified and 'default' is missing or callable - if "default" not in val: - raise KeyError( - "Key 'default' required if not all subregion keys are specified." - ) - subval = val["default"] - for idx in np.argwhere(np.isnan(array[..., 0])): - # only spatial indices required -> array[..., 0] - # conversion to array and reshaping similar to "callable" implementation - array[idx] = np.asarray(subval(mesh.index2point(idx))).reshape(nvdim) - - return array - - -# We cannot register to self (or df.Field) inside the class -@Field._as_array.register(Field) -def _(self, val, mesh, nvdim, dtype): - if mesh.region not in val.mesh.region: - raise ValueError( - f"{val.mesh.region} of the provided field does not " - f"contain {mesh.region} of the field that is being " - "created." - ) - value = ( - val.to_xarray() - .sel( - **{dim: getattr(mesh.cells, dim) for dim in mesh.region.dims}, - method="nearest", - ) - .data - ) - if nvdim == 1: - # xarray dataarrays for scalar data are three dimensional - return value.reshape(*mesh.n, -1) - return value + """Convert val into a numpy array for the given mesh.""" diff --git a/discretisedfield/io/__init__.py b/discretisedfield/io/__init__.py index f54989382..9fdb1f289 100644 --- a/discretisedfield/io/__init__.py +++ b/discretisedfield/io/__init__.py @@ -14,6 +14,7 @@ from .hdf5 import _FieldIO_HDF5, _MeshIO_HDF5, _RegionIO_HDF5 from .ovf import _FieldIO_OVF +from .vti import _FieldIO_VTI from .vtk import _FieldIO_VTK @@ -57,11 +58,16 @@ def _subregion_filename(filename): return f"{str(filename)}.subregions.json" -class _FieldIO(_FieldIO_HDF5, _FieldIO_OVF, _FieldIO_VTK): +class _FieldIO(_FieldIO_HDF5, _FieldIO_OVF, _FieldIO_VTK, _FieldIO_VTI): __slots__ = [] def to_file( - self, filename, representation="bin8", extend_scalar=False, save_subregions=True + self, + filename, + representation="bin8", + extend_scalar=False, + save_subregions=True, + array_name=None, ): """Write the field to OVF, HDF5, or VTK file. @@ -179,6 +185,12 @@ def to_file( representation=representation, save_subregions=save_subregions, ) + elif filename.suffix == ".vti": + self._to_vti( + filename, + array_name=array_name, + save_subregions=save_subregions, + ) elif filename.suffix in [".hdf5", ".h5"]: self._to_hdf5(filename) else: @@ -274,6 +286,8 @@ def from_file(cls, filename): return cls._from_ovf(filename) elif filename.suffix == ".vtk": return cls._from_vtk(filename) + elif filename.suffix == ".vti": + return cls._from_vti(filename) elif filename.suffix in [".hdf5", ".h5"]: return cls._from_hdf5(filename) else: diff --git a/discretisedfield/io/vti.py b/discretisedfield/io/vti.py new file mode 100644 index 000000000..47c7a84ec --- /dev/null +++ b/discretisedfield/io/vti.py @@ -0,0 +1,56 @@ +import contextlib + +import numpy as np +import pyvista as pv + +import discretisedfield as df + + +class _FieldIO_VTI: + __slots__ = [] + + def _to_vti(self, filename, array_name, save_subregions=True): + grid = pv.ImageData( + dimensions=self.mesh.n + 1, + spacing=self.mesh.cell, + origin=self.mesh.region.pmin, + ) + if isinstance(self, df.cell_field.CellField): + grid.cell_data.set_array( + self.array.reshape(-1, self.nvdim, order="F"), array_name + ) + elif isinstance(self, df.vertex_field.VertexField): + grid.point_data.set_array( + self.array.reshape(-1, self.nvdim, order="F"), array_name + ) + else: + assert False, "This should never happen" + + if save_subregions and self.mesh.subregions: + self.mesh.save_subregions(filename) + + grid.save(filename) + + @classmethod + def _from_vti(cls, filename): + data: pv.core.grid.ImageData = pv.read(filename) + + p1 = data.bounds[::2] + p2 = data.bounds[1::2] + cell = data.spacing + mesh = df.Mesh(p1=p1, p2=p2, cell=cell) + + field_name = data.array_names[0] + value = data[field_name] + nvdim = value.shape[-1] + + value = value.reshape((*data.dimensions, nvdim), order="F") + if np.array_equal(mesh.n, value.shape[:-1]): + data_location = "cell" + else: + data_location = "vertex" + + with contextlib.suppress(FileNotFoundError): + mesh.load_subregions(filename) + + return cls(mesh, nvdim=nvdim, value=value, data_location=data_location) diff --git a/discretisedfield/vertex_field.py b/discretisedfield/vertex_field.py new file mode 100644 index 000000000..66a4ffd6f --- /dev/null +++ b/discretisedfield/vertex_field.py @@ -0,0 +1,271 @@ +import collections +import functools +import numbers +from itertools import product + +import numpy as np +import xarray as xr + +import discretisedfield as df +import discretisedfield.util as dfu +from discretisedfield.plotting.util import hv_key_dim + +from .field import Field + + +class VertexField(Field): + def __call__(self, point): + """TODO Returns nearest node for now.""" + if point not in self.mesh.region: + raise ValueError(f"{point=} not in '{self.mesh.region}'.") + + vertices = self.mesh.vertices + index = tuple(np.argmin(point[i] - vertices[i]) for i in range(self.nvdim)) + + return self.array[index] + + def diff(self, direction, order=1, restrict2valid=True): + """Maybe this is slighly wrong and we should ask Claas about this.""" + super().diff(direction, order=order, restrict2valid=restrict2valid) + + def integrate(self, direction=None, cumulative=False): + """Maybe this is slighly wrong and we should ask Claas about this.""" + super().integrate(direction=direction, cumulative=cumulative) + + def line(self, p1, p2, n): + def mesh_cell_line(p1, p2, n): + if p1 not in self.mesh.region or p2 not in self.mesh.region: + msg = f"Point {p1=} or point {p2=} is outside the mesh region." + raise ValueError(msg) + + dl = np.subtract(p2, p1) / n + for i in range(n): + yield dfu.array2tuple(np.add(p1, i * dl)) + + points = list(mesh_cell_line(p1=p1, p2=p2, n=n)) + values = [self(p) for p in points] + + return df.Line( + points=points, + values=values, + point_columns=self.mesh.region.dims, + value_columns=[f"v{dim}" for dim in self.vdims] + if self.vdims is not None + else "v", + ) # TODO scalar fields have no vdim + + def __getitem__(self, item): + raise NotImplementedError + + def mpl(self): + pass # @Swapneel + + @property + def _hv_key_dims(self): + """Dict of key dimensions of the field. + + Keys are the field dimensions (domain and vector space, e.g. x, y, z, vdims) + that have length > 1. Values are named_tuples ``hv_key_dim(data, unit)`` that + contain the data (which has to fulfil len(data) > 1, typically as a numpy array + or list) and the unit of a string (empty string if there is no unit). + + """ + key_dims = { + dim: hv_key_dim(coords, unit) + for dim, unit in zip(self.mesh.region.dims, self.mesh.region.units) + if len(coords := getattr(self.mesh.vertices, dim)) > 1 + } + if self.nvdim > 1: + key_dims["vdims"] = hv_key_dim(self.vdims, "") + return key_dims + + # def hv(self): + # pass # @Swapneel + # + # NOTE: We are ignoring all the FFTs for now. + + def to_xarray(self, name="field", unit=None): + """VertexField value as ``xarray.DataArray``. + + The method returns an ``xarray.DataArray`` with the dimensions + ``self.mesh.region.dims`` and ``vdims`` (only if ``field.nvdim > 1``). The + coordinates of the geometric dimensions are derived from ``self.mesh.vertices`` + and for vector field components from ``self.vdims``. Additionally, + the values of ``self.mesh.cell``, ``self.mesh.region.pmin``, and + ``self.mesh.region.pmax`` are stored as ``cell``, ``pmin``, and ``pmax`` + attributes of the DataArray. The ``unit`` attribute of geometric + dimensions is set to the respective strings in ``self.mesh.region.units``. + + The name and unit of the ``DataArray`` can be set by passing ``name`` and + ``unit`` respectively. If the type of value passed to any of the two + arguments is not ``str``, then a ``TypeError`` is raised. + + Parameters + ---------- + name : str, optional + + String to set name of the field ``DataArray``. + + unit : str, optional + + String to set units of the field ``DataArray``. + + Returns + ------- + xarray.DataArray + + VertexField values DataArray. + + Raises + ------ + TypeError + + If either ``name`` or ``unit`` argument is not a string. + + Examples + -------- + 1. Create a field + + >>> import discretisedfield as df + ... + >>> p1 = (0, 0, 0) + >>> p2 = (10, 10, 10) + >>> cell = (1, 1, 1) + >>> mesh = df.Mesh(p1=p1, p2=p2, cell=cell) + >>> field = df.VertexField(mesh=mesh, nvdim=3, value=(1, 0, 0), norm=1.) + ... + >>> field + Field(...) + + 2. Create `xarray.DataArray` from field + + >>> xa = field.to_xarray() + >>> xa + + ... + + 3. Select values of `x` component + + >>> xa.sel(vdims='x') + + ... + + """ + if not isinstance(name, str): + msg = "Name argument must be a string." + raise TypeError(msg) + + if unit is not None and not isinstance(unit, str): + msg = "Unit argument must be a string." + raise TypeError(msg) + + axes = self.mesh.region.dims + + data_array_coords = {axis: getattr(self.mesh.vertices, axis) for axis in axes} + + geo_units_dict = dict(zip(axes, self.mesh.region.units)) + + if self.nvdim > 1: + data_array_dims = axes + ("vdims",) + if self.vdims is not None: + data_array_coords["vdims"] = self.vdims + field_array = self.array + else: + data_array_dims = axes + field_array = np.squeeze(self.array, axis=-1) + + data_array = xr.DataArray( + field_array, + dims=data_array_dims, + coords=data_array_coords, + name=name, + attrs=dict( + units=unit or self.unit, + cell=self.mesh.cell, + pmin=self.mesh.region.pmin, + pmax=self.mesh.region.pmax, + nvdim=self.nvdim, + tolerance_factor=self.mesh.region.tolerance_factor, + data_location="vertex", + ), + ) + + # TODO save vdim_mapping + + for dim in geo_units_dict: + data_array[dim].attrs["units"] = geo_units_dict[dim] + + return data_array + + @classmethod + def from_xarray(cls, xa): + raise NotImplementedError + + @functools.singledispatchmethod + def _as_array(self, val, mesh, nvdim, dtype): + raise TypeError(f"Unsupported type {type(val)}.") + + # to avoid str being interpreted as iterable + @_as_array.register(str) + def _(self, val, mesh, nvdim, dtype): + raise TypeError(f"Unsupported type {type(val)}.") + + @_as_array.register(numbers.Complex) + @_as_array.register(collections.abc.Iterable) + def _(self, val, mesh, nvdim, dtype): + if isinstance(val, numbers.Complex) and nvdim > 1 and val != 0: + raise ValueError( + f"Wrong dimension 1 provided for value; expected dimension is {nvdim}" + ) + + if isinstance(val, collections.abc.Iterable): + if nvdim == 1 and np.array_equal(np.shape(val), mesh.n + 1): + return np.expand_dims(val, axis=-1) + elif np.shape(val)[-1] != nvdim: + raise ValueError( + f"Wrong dimension {len(val)} provided for value; expected dimension" + f" is {nvdim}." + ) + dtype = dtype or max(np.asarray(val).dtype, np.float64) + return np.full((*(mesh.n + 1), nvdim), val, dtype=dtype) + + @_as_array.register(collections.abc.Callable) + def _(self, val, mesh, nvdim, dtype): + # will only be called on user input + # dtype must be specified by the user for complex values + n_vertices = [i + 1 for i in mesh.n] + array = np.empty((*n_vertices, nvdim), dtype=dtype) + for index, point in zip( + product(*(range(vertices) for vertices in n_vertices)), + product(*(getattr(mesh.vertices, dim) for dim in mesh.region.dims)), + ): + # Conversion to array and reshaping is required for numpy >= 1.24 + # and for certain inputs, e.g. a tuple of numpy arrays which can e.g. occur + # for 1d vector fields. + array[index] = np.asarray(val(point)).reshape(nvdim) + return array + + +# We cannot register to self inside the class +@VertexField._as_array.register(VertexField) +def _(self, val, mesh, nvdim, dtype): + if mesh.region not in val.mesh.region: + raise ValueError( + f"{val.mesh.region} of the provided field does not " + f"contain {mesh.region} of the field that is being " + "created." + ) + value = ( + val.to_xarray() + .sel( + **{dim: getattr(mesh.vertices, dim) for dim in mesh.region.dims}, + method="nearest", + ) + .data + ) + if nvdim == 1: + # xarray dataarrays for scalar data are three dimensional + return value.reshape(*(mesh.n + 1), -1) + return value + + # TODO: reimplement the remaining _as_array functions. @Swapneel