Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VertexField to support field data at vertices #511

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6c010cf
feat(Field): add `VertexField` to support data at vertices
swapneelap Nov 22, 2023
3058da8
feat: [WIP] vertex field and cell field
lang-m Nov 22, 2023
6bbaffd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2023
41a478a
Move old _as_array to cell_field
lang-m Nov 23, 2023
2af630f
Revert change in html representation to avoid test failures
lang-m Nov 23, 2023
391eddf
Move line to cell_field
lang-m Nov 23, 2023
5417f61
fix: direct creation of subclasses
lang-m Nov 23, 2023
8939836
Implement __call__, line and _as_array for uniform fields
lang-m Nov 23, 2023
e85861d
Merge branch 'master' into vertexfield-subclass
swapneelap Nov 27, 2023
64dcdd2
fix(VertexField,CellField): export VertexField and CellField
swapneelap Nov 27, 2023
8f53395
feat(VertexField): add `to_xarray` method
swapneelap Nov 27, 2023
b170a50
feat(VertexField): get `hv` plot to work
swapneelap Nov 27, 2023
fd7a379
dev: test the `VectorField` functionality
swapneelap Nov 27, 2023
7890e67
feat(VertexField): add functionality to create field using a callable
swapneelap Nov 29, 2023
7ef01db
dev: test `VertexField` functionality with callable
swapneelap Nov 29, 2023
43da142
feat(VertexField): add `_as_array` dispatch to initiate from a field
swapneelap Nov 29, 2023
68265be
Experimental support for reading vti files
lang-m Nov 29, 2023
81139bf
Fix: vti support not added properly
lang-m Nov 29, 2023
97acbe1
remove redundant import
swapneelap Nov 30, 2023
229c23e
remove redundant import
swapneelap Nov 30, 2023
7887bdf
refactor(Field): convert `mpl` to an abstract method
swapneelap Nov 30, 2023
b7ae609
refactor(Field): convert xarray methods to abstract methods
swapneelap Nov 30, 2023
f7c2a2f
refactor(Field): convert `_hv_key_dims` to abstract method
swapneelap Nov 30, 2023
bbcf435
fix(Field): correct abstract method decorators
swapneelap Nov 30, 2023
72d21c8
fix(CellField): remove extra imports
swapneelap Nov 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
796 changes: 796 additions & 0 deletions dev/test-vertex-field.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion discretisedfield/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import matplotlib.pyplot as plt
import pytest

from . import tools
from . import cell_field, tools, vertex_field
from .cell_field import CellField
swapneelap marked this conversation as resolved.
Show resolved Hide resolved
from .field import Field
from .field_rotator import FieldRotator
from .interact import interact
from .line import Line
from .mesh import Mesh
from .operators import integrate
from .region import Region
from .vertex_field import VertexField
swapneelap marked this conversation as resolved.
Show resolved Hide resolved

# Enable default plotting style.
plt.style.use(pathlib.Path(__file__).parent / "plotting" / "plotting-style.mplstyle")
Expand Down
147 changes: 147 additions & 0 deletions discretisedfield/cell_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import collections
import functools
import numbers

import numpy as np

import discretisedfield as df

from .field import Field


class CellField(Field):
def __call__(self, point):
return self.array[self.mesh.point2index(point)]

# diff, integrate depending on how we calculate those for the VertexField

def line(self, p1, p2, n=100):
points = list(self.mesh.line(p1=p1, p2=p2, n=n))
values = [self(p) for p in points]

return df.Line(
points=points,
values=values,
point_columns=self.mesh.region.dims,
value_columns=[f"v{dim}" for dim in self.vdims]
if self.vdims is not None
else "v",
) # TODO scalar fields have no vdim

def __getitem__(self, item):
submesh = self.mesh[item]

index_min = self.mesh.point2index(
submesh.index2point((0,) * submesh.region.ndim)
)
index_max = np.add(index_min, submesh.n)
slices = [slice(i, j) for i, j in zip(index_min, index_max)]
return self.__class__(
submesh,
nvdim=self.nvdim,
value=self.array[tuple(slices)],
vdims=self.vdims,
unit=self.unit,
valid=self.valid[tuple(slices)],
vdim_mapping=self.vdim_mapping,
)

@functools.singledispatchmethod
def _as_array(self, val, mesh, nvdim, dtype):
raise TypeError(f"Unsupported type {type(val)}.")

# to avoid str being interpreted as iterable
@_as_array.register(str)
def _(self, val, mesh, nvdim, dtype):
raise TypeError(f"Unsupported type {type(val)}.")

@_as_array.register(numbers.Complex)
@_as_array.register(collections.abc.Iterable)
def _(self, val, mesh, nvdim, dtype):
if isinstance(val, numbers.Complex) and nvdim > 1 and val != 0:
raise ValueError(
f"Wrong dimension 1 provided for value; expected dimension is {nvdim}"
)

if isinstance(val, collections.abc.Iterable):
if nvdim == 1 and np.array_equal(np.shape(val), mesh.n):
return np.expand_dims(val, axis=-1)
elif np.shape(val)[-1] != nvdim:
raise ValueError(
f"Wrong dimension {len(val)} provided for value; expected dimension"
f" is {nvdim}."
)
dtype = dtype or max(np.asarray(val).dtype, np.float64)
return np.full((*mesh.n, nvdim), val, dtype=dtype)

@_as_array.register(collections.abc.Callable)
def _(self, val, mesh, nvdim, dtype):
# will only be called on user input
# dtype must be specified by the user for complex values
array = np.empty((*mesh.n, nvdim), dtype=dtype)
for index, point in zip(mesh.indices, mesh):
# Conversion to array and reshaping is required for numpy >= 1.24
# and for certain inputs, e.g. a tuple of numpy arrays which can e.g. occur
# for 1d vector fields.
array[index] = np.asarray(val(point)).reshape(nvdim)
return array

@_as_array.register(dict)
def _(self, val, mesh, nvdim, dtype):
# will only be called on user input
# dtype must be specified by the user for complex values
dtype = dtype or np.float64
fill_value = (
val["default"]
if "default" in val and not callable(val["default"])
else np.nan
)
array = np.full((*mesh.n, nvdim), fill_value, dtype=dtype)

for subregion in reversed(mesh.subregions.keys()):
# subregions can overlap, first subregion takes precedence
try:
submesh = mesh[subregion]
subval = val[subregion]
except KeyError:
continue # subregion not in val when implicitly set via "default"
else:
slices = mesh.region2slices(submesh.region)
array[slices] = self._as_array(subval, submesh, nvdim, dtype)

if np.any(np.isnan(array)):
# not all subregion keys specified and 'default' is missing or callable
if "default" not in val:
raise KeyError(
"Key 'default' required if not all subregion keys are specified."
)
subval = val["default"]
for idx in np.argwhere(np.isnan(array[..., 0])):
# only spatial indices required -> array[..., 0]
# conversion to array and reshaping similar to "callable" implementation
array[idx] = np.asarray(subval(mesh.index2point(idx))).reshape(nvdim)

return array


# We cannot register to self inside the class
@CellField._as_array.register(CellField)
def _(self, val, mesh, nvdim, dtype):
if mesh.region not in val.mesh.region:
raise ValueError(
f"{val.mesh.region} of the provided field does not "
f"contain {mesh.region} of the field that is being "
"created."
)
value = (
val.to_xarray()
.sel(
**{dim: getattr(mesh.cells, dim) for dim in mesh.region.dims},
method="nearest",
)
.data
)
if nvdim == 1:
# xarray dataarrays for scalar data are three dimensional
return value.reshape(*mesh.n, -1)
return value
Loading
Loading