From 5769fb124f52a9361c1306ab8c664673a79c608e Mon Sep 17 00:00:00 2001 From: momchil Date: Fri, 8 Apr 2022 12:46:00 -0700 Subject: [PATCH 1/5] Symmetry data now unpacked in sim_data.__getitem__ --- tidy3d/components/data.py | 16 +++++++++++-- tidy3d/components/grid.py | 47 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/tidy3d/components/data.py b/tidy3d/components/data.py index 30b9db8f7e..eb8e88a385 100644 --- a/tidy3d/components/data.py +++ b/tidy3d/components/data.py @@ -13,7 +13,7 @@ from .types import Numpy, Direction, Array, numpy_encoding, Literal, Ax, Coordinate, Symmetry, Axis from .base import Tidy3dBaseModel from .simulation import Simulation -from .grid import YeeGrid +from .grid import Grid, Coords from .viz import add_ax_if_none, equal_aspect from ..log import DataError, log @@ -401,7 +401,7 @@ def colocate(self, x, y, z) -> xr.Dataset: return xr.Dataset(centered_data_dict) # pylint:disable=too-many-locals - def apply_syms(self, new_grid: YeeGrid, sym_center: Coordinate, symmetry: Symmetry): + def apply_syms(self, new_grid: Grid, sym_center: Coordinate, symmetry: Symmetry): """Create a new AbstractFieldData subclass by interpolating on the supplied ``new_grid``, using symmetries as defined by ``sym_center`` and ``symmetry``.""" @@ -1018,6 +1018,18 @@ def __getitem__(self, monitor_name: str) -> Union[Tidy3dDataArray, xr.Dataset]: monitor_data = self.monitor_data.get(monitor_name) if isinstance(monitor_data, MonitorData): return monitor_data.data + if isinstance(monitor_data, AbstractFieldData): + # Unwrap symmetries + monitor = self.simulation.get_monitor_by_name(monitor_name) + sim = self.simulation + span_inds = sim.grid.discretize_inds(monitor.geometry, extend=True) + boundary_dict = {} + for idim, dim in enumerate(["x", "y", "z"]): + ind_beg, ind_end = span_inds[idim] + boundary_dict[dim] = sim.grid.periodic_subspace(idim, ind_beg, ind_end + 1) + mnt_grid = Grid(boundaries=Coords(**boundary_dict)) + return monitor_data.apply_syms(mnt_grid, sim.center, sim.symmetry) + return monitor_data def ensure_monitor_exists(self, monitor_name: str) -> None: diff --git a/tidy3d/components/grid.py b/tidy3d/components/grid.py index 60c6fb11b8..7ab0a119cf 100644 --- a/tidy3d/components/grid.py +++ b/tidy3d/components/grid.py @@ -319,13 +319,16 @@ def _yee_h(self, axis: Axis): return Coords(**yee_coords) - def discretize_inds(self, box: Box) -> List[Tuple[int, int]]: + def discretize_inds(self, box: Box, extend: bool = False) -> List[Tuple[int, int]]: """Start and stopping indexes for the cells that intersect with a :class:`Box`. Parameters ---------- box : :class:`Box` Rectangular geometry within simulation to discretize. + extend : bool = False + If ``True``, ensure that the returned indexes extend sufficiently in very direction to + be able to interpolate any field component at any point within the Box. Returns ------- @@ -340,8 +343,8 @@ def discretize_inds(self, box: Box) -> List[Tuple[int, int]]: inds_list = [] # for each dimension - for axis_label, pt_min, pt_max in zip("xyz", pts_min, pts_max): - bound_coords = boundaries.dict()[axis_label] + for axis, (pt_min, pt_max) in enumerate(zip(pts_min, pts_max)): + bound_coords = boundaries.to_list[axis] assert pt_min <= pt_max, "min point was greater than max point" # index of smallest coord greater than than pt_max @@ -352,7 +355,45 @@ def discretize_inds(self, box: Box) -> List[Tuple[int, int]]: inds_leq_pt_min = np.where(bound_coords <= pt_min)[0] ind_min = 0 if len(inds_leq_pt_min) == 0 else inds_leq_pt_min[-1] + if extend: + # If the box bounds on the left side are to the left of the closest grid center, + # we need an extra pixel to be able to interpolate the center components. + if box.bounds[0][axis] < self.centers.to_list[axis][ind_min]: + ind_min -= 1 + + # We always need an extra pixel on the right for the surface components. + ind_max += 1 + # store indexes inds_list.append((ind_min, ind_max)) return inds_list + + def periodic_subspace(self, axis: Axis, ind_beg: int = 0, ind_end: int = 0) -> Coords1D: + """Pick a subspace of 1D coords within ``range(ind_beg, ind_end)``. If any indexes lie + outside of the coords array, periodic padding is used, where the zeroth and last element + of coords are identified.""" + + coords = self.boundaries.to_list[axis] + padded_coords = coords + num_coords = coords.size + num_cells = num_coords - 1 + coords_width = coords[-1] - coords[0] + + # Pad on the left if needed + if ind_beg < 0: + num_pad = int(np.ceil(-ind_beg / num_cells)) + coords_pad = coords[:-1, None] + (coords_width * np.arange(-num_pad, 0))[None, :] + coords_pad = coords_pad.T.ravel() + padded_coords = np.concatenate([coords_pad, padded_coords]) + ind_beg += num_pad * num_cells + ind_end += num_pad * num_cells + + # Pad on the right if needed + if ind_end >= padded_coords.size: + num_pad = int(np.ceil((ind_end - padded_coords.size) / num_cells)) + coords_pad = coords[1:, None] + (coords_width * np.arange(1, num_pad + 1))[None, :] + coords_pad = coords_pad.T.ravel() + padded_coords = np.concatenate([padded_coords, coords_pad]) + + return padded_coords[ind_beg:ind_end] From 460bff1072c90ed600e36c0ce8cf106f93dedb7b Mon Sep 17 00:00:00 2001 From: momchil Date: Fri, 8 Apr 2022 16:48:57 -0700 Subject: [PATCH 2/5] Reorganizing data.py classes and adding PermittivityData --- tidy3d/components/data.py | 456 ++++++++++++++++++----------- tidy3d/components/grid.py | 21 +- tidy3d/plugins/mode/mode_solver.py | 2 +- tidy3d/plugins/plotly/data.py | 4 +- 4 files changed, 305 insertions(+), 178 deletions(-) diff --git a/tidy3d/components/data.py b/tidy3d/components/data.py index eb8e88a385..eed0eee5f5 100644 --- a/tidy3d/components/data.py +++ b/tidy3d/components/data.py @@ -302,61 +302,14 @@ def ensure_member_exists(self, member_name: str): raise DataError(f"member_name '{member_name}' not found.") -""" Subclasses of MonitorData and CollectionData """ +""" Abstract subclasses of MonitorData and CollectionData """ -class AbstractFieldData(CollectionData, ABC): - """Sores a collection of EM fields either in freq or time domain.""" - - """ Get the standard EM components from the dict using convenient "dot" syntax.""" - - @property - def Ex(self): - """Get Ex component of field using '.Ex' syntax.""" - scalar_data = self.data_dict.get("Ex") - if scalar_data: - return scalar_data.data - return None +class SpatialCollectionData(CollectionData, ABC): + """Sores a collection of scalar data defined over x, y, z (among other) coords.""" - @property - def Ey(self): - """Get Ey component of field using '.Ey' syntax.""" - scalar_data = self.data_dict.get("Ey") - if scalar_data: - return scalar_data.data - return None - - @property - def Ez(self): - """Get Ez component of field using '.Ez' syntax.""" - scalar_data = self.data_dict.get("Ez") - if scalar_data: - return scalar_data.data - return None - - @property - def Hx(self): - """Get Hx component of field using '.Hx' syntax.""" - scalar_data = self.data_dict.get("Hx") - if scalar_data: - return scalar_data.data - return None - - @property - def Hy(self): - """Get Hy component of field using '.Hy' syntax.""" - scalar_data = self.data_dict.get("Hy") - if scalar_data: - return scalar_data.data - return None - - @property - def Hz(self): - """Get Hz component of field using '.Hz' syntax.""" - scalar_data = self.data_dict.get("Hz") - if scalar_data: - return scalar_data.data - return None + # Defines how data components are affected by a positive symmetry along each of the axes + _sym_dict = {} def colocate(self, x, y, z) -> xr.Dataset: """colocate all of the data at a set of x, y, z coordinates. @@ -397,39 +350,45 @@ def colocate(self, x, y, z) -> xr.Dataset: coord_vals = coord_val_map[coord_name] centered_data_array = centered_data_array.interp(**{coord_name: coord_vals}) centered_data_dict[field_name] = centered_data_array - # import pdb; pdb.set_trace() return xr.Dataset(centered_data_dict) # pylint:disable=too-many-locals - def apply_syms(self, new_grid: Grid, sym_center: Coordinate, symmetry: Symmetry): - """Create a new AbstractFieldData subclass by interpolating on the supplied ``new_grid``, - using symmetries as defined by ``sym_center`` and ``symmetry``.""" + def apply_syms(self, grid_dict: Dict[str, Coords], sym_center: Coordinate, symmetry: Symmetry): + """Create a new SpatialCollectionData subclass by interpolating on the supplied + ``new_grid``, using symmetries as defined by ``sym_center`` and ``symmetry``. + + Parameters + ---------- + grid_dict : Dict[str, Coords] + Mapping of the data labels in the SpatialCollectionData to new coordinates on which + to be interpolated, using symmetries to expand beyond the stored domain. + sym_center : Coordinate + Position of the symmetry planes in x, y, and z. + symmetry : Symmetry + Eigenvalues of the symmetry operation in x, y, and z. + + Returns + ------- + SpatialCollectionData + A new SpatialCollectionData with the expanded fields. For data labels that are not + keys in the grid_dict, the data is returned unmodified. + """ new_data_dict = {} - yee_grid_dict = new_grid.yee.grid_dict - # Defines how field components are affected by a positive symmetry along each of the axes - component_sym_dict = { - "Ex": [-1, 1, 1], - "Ey": [1, -1, 1], - "Ez": [1, 1, -1], - "Hx": [1, -1, -1], - "Hy": [-1, 1, -1], - "Hz": [-1, -1, 1], - } - for field, scalar_data in self.data_dict.items(): + for data_key, scalar_data in self.data_dict.items(): new_data = scalar_data.data - # Get new grid locations - yee_coords = yee_grid_dict[field].to_list - # Apply symmetries - zipped = zip("xyz", yee_coords, sym_center, symmetry) - for dim, (dim_name, coords, center, sym) in enumerate(zipped): - # There shouldn't be anything to do if there's no symmetry on this axis - if sym == 0: + zipped = zip("xyz", sym_center, symmetry) + for dim, (dim_name, center, sym) in enumerate(zipped): + # Continue if no symmetry or the data key is not in the supplied grid_dict + if sym == 0 or not grid_dict.get(data_key): continue + # Get new grid locations + coords = grid_dict[data_key].to_list[dim] + # Get indexes of coords that lie on the left of the symmetry center flip_inds = np.where(coords < center)[0] @@ -442,14 +401,77 @@ def apply_syms(self, new_grid: Grid, sym_center: Coordinate, symmetry: Symmetry) new_data = new_data.interp({dim_name: coords_interp}, kwargs={"fill_value": 0.0}) new_data = new_data.assign_coords({dim_name: coords}) - # Apply the correct +/-1 for the field component - new_data[{dim_name: flip_inds}] *= sym * component_sym_dict[field][dim] + # Apply the correct +/-1 for the data_key component + new_data[{dim_name: flip_inds}] *= sym * self._sym_dict[data_key][dim] - new_data_dict[field] = type(scalar_data)(values=new_data.values, **new_data.coords) + new_data_dict[data_key] = type(scalar_data)(values=new_data.values, **new_data.coords) return type(self)(data_dict=new_data_dict) +class AbstractFieldData(SpatialCollectionData, ABC): + """Sores a collection of EM fields either in freq or time domain.""" + + _sym_dict = { + "Ex": (-1, 1, 1), + "Ey": (1, -1, 1), + "Ez": (1, 1, -1), + "Hx": (1, -1, -1), + "Hy": (-1, 1, -1), + "Hz": (-1, -1, 1), + } + + """ Get the standard EM components from the dict using convenient "dot" syntax.""" + + @property + def Ex(self): + """Get Ex component of field using '.Ex' syntax.""" + scalar_data = self.data_dict.get("Ex") + if scalar_data: + return scalar_data.data + return None + + @property + def Ey(self): + """Get Ey component of field using '.Ey' syntax.""" + scalar_data = self.data_dict.get("Ey") + if scalar_data: + return scalar_data.data + return None + + @property + def Ez(self): + """Get Ez component of field using '.Ez' syntax.""" + scalar_data = self.data_dict.get("Ez") + if scalar_data: + return scalar_data.data + return None + + @property + def Hx(self): + """Get Hx component of field using '.Hx' syntax.""" + scalar_data = self.data_dict.get("Hx") + if scalar_data: + return scalar_data.data + return None + + @property + def Hy(self): + """Get Hy component of field using '.Hy' syntax.""" + scalar_data = self.data_dict.get("Hy") + if scalar_data: + return scalar_data.data + return None + + @property + def Hz(self): + """Get Hz component of field using '.Hz' syntax.""" + scalar_data = self.data_dict.get("Hz") + if scalar_data: + return scalar_data.data + return None + + class FreqData(MonitorData, ABC): """Stores frequency-domain data using an ``f`` dimension for frequency in Hz.""" @@ -466,8 +488,8 @@ class TimeData(MonitorData, ABC): t: Array[float] -class AbstractScalarFieldData(MonitorData, ABC): - """Stores a single, scalar field as a function of spatial coordinates x,y,z.""" +class ScalarSpatialData(MonitorData, ABC): + """Stores a single, scalar variable as a function of spatial coordinates x, y, z.""" x: Array[float] y: Array[float] @@ -478,14 +500,20 @@ class PlanarData(MonitorData, ABC): """Stores data that must be found via a planar monitor.""" +class AbstractModeData(PlanarData, FreqData, ABC): + """Abstract class for mode data as a function of frequency and mode index.""" + + mode_index: Array[int] + + class AbstractFluxData(PlanarData, ABC): """Stores electromagnetic flux through a plane.""" -""" usable monitors """ +""" Usable individual data containers for CollectionData monitors """ -class ScalarFieldData(AbstractScalarFieldData, FreqData): +class ScalarFieldData(ScalarSpatialData, FreqData): """Stores a single scalar field in frequency-domain. Parameters @@ -522,7 +550,7 @@ def normalize(self, source_freq_amps: Array[complex]) -> None: self.values /= source_freq_amps # pylint: disable=no-member -class ScalarFieldTimeData(AbstractScalarFieldData, TimeData): +class ScalarFieldTimeData(ScalarSpatialData, TimeData): """stores a single scalar field in time domain Parameters @@ -555,13 +583,21 @@ class ScalarFieldTimeData(AbstractScalarFieldData, TimeData): _dims = ("x", "y", "z", "t") -class FieldData(AbstractFieldData): - """Stores a collection of scalar fields in the frequency domain from a :class:`FieldMonitor`. +class ScalarPermittivityData(ScalarSpatialData, FreqData): + """Stores a single scalar permittivity distribution in frequency-domain. Parameters ---------- - data_dict : Dict[str, :class:`ScalarFieldData`] - Mapping of field name (eg. 'Ex') to its scalar field data. + x : numpy.ndarray + Data coordinates in x direction (um). + y : numpy.ndarray + Data coordinates in y direction (um). + z : numpy.ndarray + Data coordinates in z direction (um). + f : numpy.ndarray + Frequency coordinates (Hz). + values : numpy.ndarray + Complex-valued array of shape ``(len(x), len(y), len(z), len(f))`` storing eps values. Example ------- @@ -570,93 +606,24 @@ class FieldData(AbstractFieldData): >>> y = np.linspace(-2, 2, 20) >>> z = np.linspace(0, 0, 1) >>> values = (1+1j) * np.random.random((len(x), len(y), len(z), len(f))) - >>> field = ScalarFieldData(values=values, x=x, y=y, z=z, f=f) - >>> data = FieldData(data_dict={'Ex': field, 'Ey': field}) - """ - - data_dict: Dict[str, ScalarFieldData] - type: Literal["FieldData"] = "FieldData" - - -class FieldTimeData(AbstractFieldData): - """Stores a collection of scalar fields in the time domain from a :class:`FieldTimeMonitor`. - - Parameters - ---------- - data_dict : Dict[str, :class:`ScalarFieldTimeData`] - Mapping of field name to its scalar field data. - - Example - ------- - >>> t = np.linspace(0, 1e-12, 1001) - >>> x = np.linspace(-1, 1, 10) - >>> y = np.linspace(-2, 2, 20) - >>> z = np.linspace(0, 0, 1) - >>> values = np.random.random((len(x), len(y), len(z), len(t))) - >>> field = ScalarFieldTimeData(values=values, x=x, y=y, z=z, t=t) - >>> data = FieldTimeData(data_dict={'Ex': field, 'Ey': field}) - """ - - data_dict: Dict[str, ScalarFieldTimeData] - type: Literal["FieldTimeData"] = "FieldTimeData" - - -class FluxData(AbstractFluxData, FreqData): - """Stores frequency-domain power flux data from a :class:`FluxMonitor`. - - Parameters - ---------- - f : numpy.ndarray - Frequency coordinates (Hz). - values : numpy.ndarray - Complex-valued array of shape ``(len(f),)`` storing field values. - - Example - ------- - >>> f = np.linspace(2e14, 3e14, 1001) - >>> values = np.random.random((len(f),)) - >>> data = FluxData(values=values, f=f) + >>> data = ScalarPermittivityData(values=values, x=x, y=y, z=z, f=f) """ - values: Array[float] - data_attrs: Dict[str, str] = {"units": "W", "long_name": "flux"} - type: Literal["FluxData"] = "FluxData" + values: Array[complex] + data_attrs: Dict[str, str] = None + type: Literal["ScalarPermittivityData"] = "ScalarPermittivityData" - _dims = ("f",) + _dims = ("x", "y", "z", "f") def normalize(self, source_freq_amps: Array[complex]) -> None: - """normalize the values by the amplitude of the source.""" - self.values /= abs(source_freq_amps) ** 2 # pylint: disable=no-member + pass -class FluxTimeData(AbstractFluxData, TimeData): - """Stores time-domain power flux data from a :class:`FluxTimeMonitor`. - - Parameters - ---------- - t : numpy.ndarray - Time coordinates (sec). - values : numpy.ndarray - Real-valued array of shape ``(len(t),)`` storing field values. - - Example - ------- - >>> t = np.linspace(0, 1e-12, 1001) - >>> values = np.random.random((len(t),)) - >>> data = FluxTimeData(values=values, t=t) - """ - - values: Array[float] - data_attrs: Dict[str, str] = {"units": "W", "long_name": "flux"} - type: Literal["FluxTimeData"] = "FluxTimeData" - - _dims = ("t",) - - -class AbstractModeData(PlanarData, FreqData, ABC): - """Abstract class for mode data as a function of frequency and mode index.""" +class ScalarModeFieldData(ScalarFieldData, AbstractModeData): + """Like ScalarFieldData but with extra dimension ``mode_index``.""" - mode_index: Array[int] + type: Literal["ScalarModeFieldData"] = "ScalarModeFieldData" + _dims = ("x", "y", "z", "f", "mode_index") class ModeAmpsData(AbstractModeData): @@ -743,6 +710,158 @@ def k_eff(self): return _k_eff +""" Usable monitor/collection data """ + + +class FieldData(AbstractFieldData): + """Stores a collection of scalar fields in the frequency domain from a :class:`FieldMonitor`. + + Parameters + ---------- + data_dict : Dict[str, :class:`ScalarFieldData`] + Mapping of field name (eg. 'Ex') to its scalar field data. + + Example + ------- + >>> f = np.linspace(1e14, 2e14, 1001) + >>> x = np.linspace(-1, 1, 10) + >>> y = np.linspace(-2, 2, 20) + >>> z = np.linspace(0, 0, 1) + >>> values = (1+1j) * np.random.random((len(x), len(y), len(z), len(f))) + >>> field = ScalarFieldData(values=values, x=x, y=y, z=z, f=f) + >>> data = FieldData(data_dict={'Ex': field, 'Ey': field}) + """ + + data_dict: Dict[str, ScalarFieldData] + type: Literal["FieldData"] = "FieldData" + + +class FieldTimeData(AbstractFieldData): + """Stores a collection of scalar fields in the time domain from a :class:`FieldTimeMonitor`. + + Parameters + ---------- + data_dict : Dict[str, :class:`ScalarFieldTimeData`] + Mapping of field name to its scalar field data. + + Example + ------- + >>> t = np.linspace(0, 1e-12, 1001) + >>> x = np.linspace(-1, 1, 10) + >>> y = np.linspace(-2, 2, 20) + >>> z = np.linspace(0, 0, 1) + >>> values = np.random.random((len(x), len(y), len(z), len(t))) + >>> field = ScalarFieldTimeData(values=values, x=x, y=y, z=z, t=t) + >>> data = FieldTimeData(data_dict={'Ex': field, 'Ey': field}) + """ + + data_dict: Dict[str, ScalarFieldTimeData] + type: Literal["FieldTimeData"] = "FieldTimeData" + + +class PermittivityData(SpatialCollectionData): + """Sores a collection of permittivity components over spatial coordinates and frequency + from a :class:`PermittivityMonitor`. + + Parameters + ---------- + data_dict : Dict[str, :class:`ScalarPermittivityData`] + Mapping of component name to its scalar data. + + Example + ------- + >>> f = np.linspace(1e14, 2e14, 1001) + >>> x = np.linspace(-1, 1, 10) + >>> y = np.linspace(-2, 2, 20) + >>> z = np.linspace(0, 0, 1) + >>> values = (1+1j) * np.random.random((len(x), len(y), len(z), len(f))) + >>> eps = ScalarPermittivityData(values=values, x=x, y=y, z=z, f=f) + >>> data = PermittivityData(data_dict={'eps_xx': eps, 'eps_yy': eps, 'eps_zz': eps}) + """ + + data_dict: Dict[str, ScalarPermittivityData] + type: Literal["PermittivityData"] = "PermittivityData" + _sym_dict = {"eps_xx": (1, 1, 1), "eps_yy": (1, 1, 1), "eps_zz": (1, 1, 1)} + + """ Get the permittivity components from the dict using convenient "dot" syntax.""" + + @property + def eps_xx(self): + """Get eps_xx component.""" + scalar_data = self.data_dict.get("eps_xx") + if scalar_data: + return scalar_data.data + return None + + @property + def eps_yy(self): + """Get eps_yy component.""" + scalar_data = self.data_dict.get("eps_yy") + if scalar_data: + return scalar_data.data + return None + + @property + def eps_zz(self): + """Get eps_zz component.""" + scalar_data = self.data_dict.get("eps_zz") + if scalar_data: + return scalar_data.data + return None + + +class FluxData(AbstractFluxData, FreqData): + """Stores frequency-domain power flux data from a :class:`FluxMonitor`. + + Parameters + ---------- + f : numpy.ndarray + Frequency coordinates (Hz). + values : numpy.ndarray + Complex-valued array of shape ``(len(f),)`` storing field values. + + Example + ------- + >>> f = np.linspace(2e14, 3e14, 1001) + >>> values = np.random.random((len(f),)) + >>> data = FluxData(values=values, f=f) + """ + + values: Array[float] + data_attrs: Dict[str, str] = {"units": "W", "long_name": "flux"} + type: Literal["FluxData"] = "FluxData" + + _dims = ("f",) + + def normalize(self, source_freq_amps: Array[complex]) -> None: + """normalize the values by the amplitude of the source.""" + self.values /= abs(source_freq_amps) ** 2 # pylint: disable=no-member + + +class FluxTimeData(AbstractFluxData, TimeData): + """Stores time-domain power flux data from a :class:`FluxTimeMonitor`. + + Parameters + ---------- + t : numpy.ndarray + Time coordinates (sec). + values : numpy.ndarray + Real-valued array of shape ``(len(t),)`` storing field values. + + Example + ------- + >>> t = np.linspace(0, 1e-12, 1001) + >>> values = np.random.random((len(t),)) + >>> data = FluxTimeData(values=values, t=t) + """ + + values: Array[float] + data_attrs: Dict[str, str] = {"units": "W", "long_name": "flux"} + type: Literal["FluxTimeData"] = "FluxTimeData" + + _dims = ("t",) + + class ModeData(CollectionData): """Stores a collection of mode decomposition amplitudes and mode effective indexes for all modes in a :class:`.ModeMonitor`. @@ -799,13 +918,6 @@ def k_eff(self): return None -class ScalarModeFieldData(ScalarFieldData, AbstractModeData): - """Like ScalarFieldData but with extra dimension ``mode_index``.""" - - type: Literal["ScalarModeFieldData"] = "ScalarModeFieldData" - _dims = ("x", "y", "z", "f", "mode_index") - - class ModeFieldData(AbstractFieldData): """Like FieldData but with extra dimension ``mode_index``.""" diff --git a/tidy3d/components/grid.py b/tidy3d/components/grid.py index 7ab0a119cf..f625eb9515 100644 --- a/tidy3d/components/grid.py +++ b/tidy3d/components/grid.py @@ -370,9 +370,24 @@ def discretize_inds(self, box: Box, extend: bool = False) -> List[Tuple[int, int return inds_list def periodic_subspace(self, axis: Axis, ind_beg: int = 0, ind_end: int = 0) -> Coords1D: - """Pick a subspace of 1D coords within ``range(ind_beg, ind_end)``. If any indexes lie - outside of the coords array, periodic padding is used, where the zeroth and last element - of coords are identified.""" + """Pick a subspace of 1D boundaries within ``range(ind_beg, ind_end)``. If any indexes lie + outside of the grid boundaries array, periodic padding is used, where the zeroth and last + element of the boundaries are identified. + + Parameters + ---------- + axis : Axis + Axis along which to pick the subspace. + ind_beg : int, optional + Starting index for the subspace. + ind_end : int, optional + Ending index for the subspace. + + Returns + ------- + Coords1D + The subspace of the grid along ``axis``. + """ coords = self.boundaries.to_list[axis] padded_coords = coords diff --git a/tidy3d/plugins/mode/mode_solver.py b/tidy3d/plugins/mode/mode_solver.py index 4490815019..753ce91892 100644 --- a/tidy3d/plugins/mode/mode_solver.py +++ b/tidy3d/plugins/mode/mode_solver.py @@ -349,7 +349,7 @@ def solve(self) -> ModeSolverData: ) field_data = ModeFieldData(data_dict=data_dict).apply_syms( - plane_grid, self.simulation.center, self.simulation.symmetry + plane_grid.yee.grid_dict, self.simulation.center, self.simulation.symmetry ) index_data = ModeIndexData( f=self.freqs, diff --git a/tidy3d/plugins/plotly/data.py b/tidy3d/plugins/plotly/data.py index 7e65a85ead..d899f591f3 100644 --- a/tidy3d/plugins/plotly/data.py +++ b/tidy3d/plugins/plotly/data.py @@ -12,7 +12,7 @@ from .utils import PlotlyFig from ...components.data import FluxData, FluxTimeData, FieldData, FieldTimeData -from ...components.data import ModeFieldData, ModeData, AbstractScalarFieldData +from ...components.data import ModeFieldData, ModeData, ScalarSpatialData from ...components.geometry import Geometry from ...components.types import Axis, Direction from ...log import Tidy3dKeyError, log @@ -491,7 +491,7 @@ def inital_field_val(self): return field_vals[0] @property - def scalar_field_data(self) -> AbstractScalarFieldData: + def scalar_field_data(self) -> ScalarSpatialData: """The current scalar field monitor data.""" if self.field_val is None: self.field_val = self.inital_field_val From a27bb4e91906f2c7cc28c790f0678475d63ba87e Mon Sep 17 00:00:00 2001 From: momchil Date: Fri, 8 Apr 2022 18:13:15 -0700 Subject: [PATCH 3/5] New PermittivityMonitor --- tidy3d/__init__.py | 2 +- tidy3d/components/__init__.py | 2 +- tidy3d/components/data.py | 13 +++++++++++-- tidy3d/components/monitor.py | 28 +++++++++++++++++++++++++++- 4 files changed, 40 insertions(+), 5 deletions(-) diff --git a/tidy3d/__init__.py b/tidy3d/__init__.py index 6b27fd0574..74cd5729c9 100644 --- a/tidy3d/__init__.py +++ b/tidy3d/__init__.py @@ -33,7 +33,7 @@ # monitors from .components import FieldMonitor, FieldTimeMonitor, FluxMonitor, FluxTimeMonitor -from .components import ModeMonitor, ModeFieldMonitor +from .components import ModeMonitor, ModeFieldMonitor, PermittivityMonitor # simulation from .components import Simulation diff --git a/tidy3d/components/__init__.py b/tidy3d/components/__init__.py index 4f16a349bc..41f7e82dcb 100644 --- a/tidy3d/components/__init__.py +++ b/tidy3d/components/__init__.py @@ -29,7 +29,7 @@ # monitor from .monitor import FreqMonitor, TimeMonitor, FieldMonitor, FieldTimeMonitor from .monitor import Monitor, FluxMonitor, FluxTimeMonitor, ModeMonitor -from .monitor import ModeFieldMonitor +from .monitor import ModeFieldMonitor, PermittivityMonitor # simulation from .simulation import Simulation diff --git a/tidy3d/components/data.py b/tidy3d/components/data.py index eed0eee5f5..c24ae4da52 100644 --- a/tidy3d/components/data.py +++ b/tidy3d/components/data.py @@ -946,6 +946,7 @@ def sel_mode_index(self, mode_index): "ScalarFieldTimeData": ScalarFieldTimeData, "FieldData": FieldData, "FieldTimeData": FieldTimeData, + "PermittivityData": PermittivityData, "FluxData": FluxData, "FluxTimeData": FluxTimeData, "ModeAmpsData": ModeAmpsData, @@ -1130,7 +1131,7 @@ def __getitem__(self, monitor_name: str) -> Union[Tidy3dDataArray, xr.Dataset]: monitor_data = self.monitor_data.get(monitor_name) if isinstance(monitor_data, MonitorData): return monitor_data.data - if isinstance(monitor_data, AbstractFieldData): + if isinstance(monitor_data, SpatialCollectionData): # Unwrap symmetries monitor = self.simulation.get_monitor_by_name(monitor_name) sim = self.simulation @@ -1140,7 +1141,15 @@ def __getitem__(self, monitor_name: str) -> Union[Tidy3dDataArray, xr.Dataset]: ind_beg, ind_end = span_inds[idim] boundary_dict[dim] = sim.grid.periodic_subspace(idim, ind_beg, ind_end + 1) mnt_grid = Grid(boundaries=Coords(**boundary_dict)) - return monitor_data.apply_syms(mnt_grid, sim.center, sim.symmetry) + mnt_grid_dict = mnt_grid.yee.grid_dict + if isinstance(monitor_data, PermittivityData): + # Define monitor grid keys where the permittivity lives + mnt_grid_dict = { + "eps_xx": mnt_grid_dict["Ex"], + "eps_yy": mnt_grid_dict["Ey"], + "eps_zz": mnt_grid_dict["Ez"], + } + return monitor_data.apply_syms(mnt_grid_dict, sim.center, sim.symmetry) return monitor_data diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index 3b396f9361..bfd5d5008e 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -319,6 +319,26 @@ def storage_size(self, num_cells: int, tmesh: Array) -> int: return BYTES_REAL * num_steps * num_cells * len(self.fields) +class PermittivityMonitor(FreqMonitor): + """:class:`Monitor` that records the diagonal components of the complex-valued relative + permittivity tensor in the frequency domain. + + Example + ------- + >>> monitor = PermittivityMonitor( + ... center=(1,2,3), + ... size=(2,2,2), + ... freqs=[250e12, 300e12], + ... name='eps_monitor') + """ + + _data_type: Literal["PermittivityData"] = pydantic.Field("PermittivityData") + + def storage_size(self, num_cells: int, tmesh: Array) -> int: + # stores 3 complex number per grid cell, per frequency + return BYTES_COMPLEX * num_cells * len(self.freqs) * 3 + + class FluxMonitor(AbstractFluxMonitor, FreqMonitor): """:class:`Monitor` that records power flux through a plane in the frequency domain. @@ -406,5 +426,11 @@ def storage_size(self, num_cells: int, tmesh: int) -> int: # types of monitors that are accepted by simulation MonitorType = Union[ - FieldMonitor, FieldTimeMonitor, FluxMonitor, FluxTimeMonitor, ModeMonitor, ModeFieldMonitor + FieldMonitor, + FieldTimeMonitor, + PermittivityMonitor, + FluxMonitor, + FluxTimeMonitor, + ModeMonitor, + ModeFieldMonitor, ] From af579cf06a7f29815b5c3a7bf671459f9d1f9ada Mon Sep 17 00:00:00 2001 From: momchil Date: Mon, 11 Apr 2022 13:27:00 -0700 Subject: [PATCH 4/5] Last cleanup, pdating changelog, bumping patch version --- CHANGELOG.md | 2 ++ tidy3d/components/data.py | 12 +++++++----- tidy3d/components/monitor.py | 4 +++- tidy3d/version.py | 2 +- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f9fd71d905..8cbe12045a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,10 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - App / GUI for visualizing contents of `SimulationData` in `tidy3d.plugings.plotly`. +- New `PermittivityMonitor` and `PermittivityData` to store the complex relative permittivity as used in the simulation. ### Changed - Faster plotting for matplotlib and plotly. - `SimulationData` normalization keeps track of source index and can be normalized when loading directly from .hdf5 file. +- Monitor data with symmetries now store the minimum required data to file and expands the symmetries on the fly. ## [1.2.1] - 2022-3-30 diff --git a/tidy3d/components/data.py b/tidy3d/components/data.py index c24ae4da52..3c4fa8d59c 100644 --- a/tidy3d/components/data.py +++ b/tidy3d/components/data.py @@ -401,8 +401,10 @@ def apply_syms(self, grid_dict: Dict[str, Coords], sym_center: Coordinate, symme new_data = new_data.interp({dim_name: coords_interp}, kwargs={"fill_value": 0.0}) new_data = new_data.assign_coords({dim_name: coords}) - # Apply the correct +/-1 for the data_key component - new_data[{dim_name: flip_inds}] *= sym * self._sym_dict[data_key][dim] + sym_eval = self._sym_dict.get(data_key) + if sym_eval: + # Apply the correct +/-1 for the data_key component + new_data[{dim_name: flip_inds}] *= sym * sym_eval[dim] new_data_dict[data_key] = type(scalar_data)(values=new_data.values, **new_data.coords) @@ -781,7 +783,6 @@ class PermittivityData(SpatialCollectionData): data_dict: Dict[str, ScalarPermittivityData] type: Literal["PermittivityData"] = "PermittivityData" - _sym_dict = {"eps_xx": (1, 1, 1), "eps_yy": (1, 1, 1), "eps_zz": (1, 1, 1)} """ Get the permittivity components from the dict using convenient "dot" syntax.""" @@ -882,7 +883,7 @@ class ModeData(CollectionData): >>> data = ModeData(data_dict={'n_complex': index_data, 'amps': amps_data}) """ - data_dict: Dict[str, AbstractModeData] + data_dict: Dict[str, Union[ModeAmpsData, ModeIndexData]] type: Literal["ModeData"] = "ModeData" @property @@ -944,6 +945,8 @@ def sel_mode_index(self, mode_index): DATA_TYPE_MAP = { "ScalarFieldData": ScalarFieldData, "ScalarFieldTimeData": ScalarFieldTimeData, + "ScalarPermittivityData": ScalarPermittivityData, + "ScalarModeFieldData": ScalarModeFieldData, "FieldData": FieldData, "FieldTimeData": FieldTimeData, "PermittivityData": PermittivityData, @@ -953,7 +956,6 @@ def sel_mode_index(self, mode_index): "ModeIndexData": ModeIndexData, "ModeData": ModeData, "ModeFieldData": ModeFieldData, - "ScalarModeFieldData": ScalarModeFieldData, } diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index bfd5d5008e..2356f8f03a 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -321,7 +321,9 @@ def storage_size(self, num_cells: int, tmesh: Array) -> int: class PermittivityMonitor(FreqMonitor): """:class:`Monitor` that records the diagonal components of the complex-valued relative - permittivity tensor in the frequency domain. + permittivity tensor in the frequency domain. The recorded data has the same shape as a + :class:`.FieldMonitor` of the same geometry: the permittivity values are saved at the + Yee grid locations, and can be interpolated to any point inside the monitor. Example ------- diff --git a/tidy3d/version.py b/tidy3d/version.py index 70f4131c1f..0fb6c747ce 100644 --- a/tidy3d/version.py +++ b/tidy3d/version.py @@ -1,3 +1,3 @@ """Defines the front end version of tidy3d""" -__version__ = "1.2.1" +__version__ = "1.2.2" From 09e1d60072920913f841290ccd30e7e82f7d77dd Mon Sep 17 00:00:00 2001 From: momchil Date: Mon, 11 Apr 2022 16:31:25 -0700 Subject: [PATCH 5/5] Tyler comments; putting symmetry attributes in SpatialCollectionData --- tidy3d/components/data.py | 125 ++++++++++++++++++----------- tidy3d/components/grid.py | 8 +- tidy3d/plugins/mode/mode_solver.py | 59 ++++++++------ 3 files changed, 118 insertions(+), 74 deletions(-) diff --git a/tidy3d/components/data.py b/tidy3d/components/data.py index 3c4fa8d59c..690d3a16c5 100644 --- a/tidy3d/components/data.py +++ b/tidy3d/components/data.py @@ -2,7 +2,7 @@ """Classes for Storing Monitor and Simulation Data.""" from abc import ABC, abstractmethod -from typing import Dict, List, Union, Optional +from typing import Dict, List, Union, Optional, Tuple import logging import xarray as xr @@ -13,6 +13,7 @@ from .types import Numpy, Direction, Array, numpy_encoding, Literal, Ax, Coordinate, Symmetry, Axis from .base import Tidy3dBaseModel from .simulation import Simulation +from .monitor import Monitor from .grid import Grid, Coords from .viz import add_ax_if_none, equal_aspect from ..log import DataError, log @@ -72,6 +73,11 @@ class Config: # pylint: disable=too-few-public-methods def add_to_group(self, hdf5_grp): """Add data contents to an hdf5 group.""" + @property + @abstractmethod + def sim_data_getitem(self): + """What gets returned by sim_data['monitor_data_name']""" + @classmethod @abstractmethod def load_from_group(cls, hdf5_grp): @@ -208,6 +214,11 @@ def load_from_group(cls, hdf5_grp): return cls(**kwargs) + @property + def sim_data_getitem(self) -> Tidy3dDataArray: + """What gets returned by sim_data['monitor_data_name']""" + return self.data + class CollectionData(Tidy3dData): """Abstract base class. @@ -301,6 +312,11 @@ def ensure_member_exists(self, member_name: str): if member_name not in self.data_dict: raise DataError(f"member_name '{member_name}' not found.") + @property + def sim_data_getitem(self) -> Tidy3dData: + """What gets returned by sim_data['monitor_data_name']""" + return self + """ Abstract subclasses of MonitorData and CollectionData """ @@ -308,8 +324,24 @@ def ensure_member_exists(self, member_name: str): class SpatialCollectionData(CollectionData, ABC): """Sores a collection of scalar data defined over x, y, z (among other) coords.""" - # Defines how data components are affected by a positive symmetry along each of the axes - _sym_dict = {} + """ Attributes storing details about any symmetries that can be used to expand the data. """ + + # Position of the symmetry planes in x, y, and z. + symmetry_center: Coordinate = None + # Eigenvalues of the symmetry under reflection in x, y, and z. + symmetry: Tuple[Symmetry, Symmetry, Symmetry] = (0, 0, 0) + + """Grid after the symmetries (if any) are expanded. The dictionary keys must correspond to + the data keys in the ``data_dict`` for the expanded grid to be invoked.""" + expanded_grid: Dict[str, Coords] = {} + + """Dictionary of the form ``{data_key: Symmetry}``, defining how data components are affected + by a positive symmetry along each of the axes. If the name of a given data in the ``data_dict`` + is not in this dictionary, then in the presence of symmetry the data is just unwrapped with a + positive symmetry value in each direction. If the data name is in the dictionary, for each axis + the corresponding ``_sym_dict`` value times the ``self.symmetry`` eigenvalue is used. + """ + _sym_dict: Dict[str, Symmetry] = {} def colocate(self, x, y, z) -> xr.Dataset: """colocate all of the data at a set of x, y, z coordinates. @@ -352,26 +384,17 @@ def colocate(self, x, y, z) -> xr.Dataset: centered_data_dict[field_name] = centered_data_array return xr.Dataset(centered_data_dict) - # pylint:disable=too-many-locals - def apply_syms(self, grid_dict: Dict[str, Coords], sym_center: Coordinate, symmetry: Symmetry): - """Create a new SpatialCollectionData subclass by interpolating on the supplied - ``new_grid``, using symmetries as defined by ``sym_center`` and ``symmetry``. - - Parameters - ---------- - grid_dict : Dict[str, Coords] - Mapping of the data labels in the SpatialCollectionData to new coordinates on which - to be interpolated, using symmetries to expand beyond the stored domain. - sym_center : Coordinate - Position of the symmetry planes in x, y, and z. - symmetry : Symmetry - Eigenvalues of the symmetry operation in x, y, and z. + @property + def expand_syms(self) -> Tidy3dData: + """Create a new :class:`SpatialCollectionData` subclass by interpolating on the + stored ``expanded_grid` using the stored symmetry information. Returns ------- - SpatialCollectionData - A new SpatialCollectionData with the expanded fields. For data labels that are not - keys in the grid_dict, the data is returned unmodified. + :class:`SpatialCollectionData` + A new data object with the expanded fields. The data is only modified for data keys + found in the ``self.expanded_grid`` dict, and along dimensions where ``self.symmetry`` + is non-zero. """ new_data_dict = {} @@ -380,14 +403,14 @@ def apply_syms(self, grid_dict: Dict[str, Coords], sym_center: Coordinate, symme new_data = scalar_data.data # Apply symmetries - zipped = zip("xyz", sym_center, symmetry) + zipped = zip("xyz", self.symmetry_center, self.symmetry) for dim, (dim_name, center, sym) in enumerate(zipped): - # Continue if no symmetry or the data key is not in the supplied grid_dict - if sym == 0 or not grid_dict.get(data_key): + # Continue if no symmetry or the data key is not in the expanded grid + if sym == 0 or self.expanded_grid.get(data_key) is None: continue # Get new grid locations - coords = grid_dict[data_key].to_list[dim] + coords = self.expanded_grid[data_key].to_list[dim] # Get indexes of coords that lie on the left of the symmetry center flip_inds = np.where(coords < center)[0] @@ -402,7 +425,7 @@ def apply_syms(self, grid_dict: Dict[str, Coords], sym_center: Coordinate, symme new_data = new_data.assign_coords({dim_name: coords}) sym_eval = self._sym_dict.get(data_key) - if sym_eval: + if sym_eval is not None: # Apply the correct +/-1 for the data_key component new_data[{dim_name: flip_inds}] *= sym * sym_eval[dim] @@ -410,6 +433,24 @@ def apply_syms(self, grid_dict: Dict[str, Coords], sym_center: Coordinate, symme return type(self)(data_dict=new_data_dict) + @property + def sim_data_getitem(self) -> Tidy3dData: + """What gets returned by sim_data['monitor_data_name']""" + return self.expand_syms + + def set_symmetry_attrs(self, simulation: Simulation, monitor_name: str): + """Set the collection data attributes related to symmetries.""" + monitor = simulation.get_monitor_by_name(monitor_name) + span_inds = simulation.grid.discretize_inds(monitor.geometry, extend=True) + boundary_dict = {} + for idim, dim in enumerate(["x", "y", "z"]): + ind_beg, ind_end = span_inds[idim] + boundary_dict[dim] = simulation.grid.periodic_subspace(idim, ind_beg, ind_end + 1) + mnt_grid = Grid(boundaries=Coords(**boundary_dict)) + self.expanded_grid = mnt_grid.yee.grid_dict + self.symmetry = simulation.symmetry + self.symmetry_center = simulation.center + class AbstractFieldData(SpatialCollectionData, ABC): """Sores a collection of EM fields either in freq or time domain.""" @@ -810,6 +851,16 @@ def eps_zz(self): return scalar_data.data return None + def set_symmetry_attrs(self, simulation: Simulation, monitor_name: str): + """Set the collection data attributes related to symmetries.""" + super().set_symmetry_attrs(simulation, monitor_name) + # Redefine the expanded grid for epsilon rather than for fields. + self.expanded_grid = { + "eps_xx": self.expanded_grid["Ex"], + "eps_yy": self.expanded_grid["Ey"], + "eps_zz": self.expanded_grid["Ez"], + } + class FluxData(AbstractFluxData, FreqData): """Stores frequency-domain power flux data from a :class:`FluxMonitor`. @@ -1131,29 +1182,9 @@ def __getitem__(self, monitor_name: str) -> Union[Tidy3dDataArray, xr.Dataset]: """ self.ensure_monitor_exists(monitor_name) monitor_data = self.monitor_data.get(monitor_name) - if isinstance(monitor_data, MonitorData): - return monitor_data.data if isinstance(monitor_data, SpatialCollectionData): - # Unwrap symmetries - monitor = self.simulation.get_monitor_by_name(monitor_name) - sim = self.simulation - span_inds = sim.grid.discretize_inds(monitor.geometry, extend=True) - boundary_dict = {} - for idim, dim in enumerate(["x", "y", "z"]): - ind_beg, ind_end = span_inds[idim] - boundary_dict[dim] = sim.grid.periodic_subspace(idim, ind_beg, ind_end + 1) - mnt_grid = Grid(boundaries=Coords(**boundary_dict)) - mnt_grid_dict = mnt_grid.yee.grid_dict - if isinstance(monitor_data, PermittivityData): - # Define monitor grid keys where the permittivity lives - mnt_grid_dict = { - "eps_xx": mnt_grid_dict["Ex"], - "eps_yy": mnt_grid_dict["Ey"], - "eps_zz": mnt_grid_dict["Ez"], - } - return monitor_data.apply_syms(mnt_grid_dict, sim.center, sim.symmetry) - - return monitor_data + monitor_data.set_symmetry_attrs(self.simulation, monitor_name) + return monitor_data.sim_data_getitem def ensure_monitor_exists(self, monitor_name: str) -> None: """Raise exception if monitor isn't in the simulation data""" diff --git a/tidy3d/components/grid.py b/tidy3d/components/grid.py index f625eb9515..7509ba2fbb 100644 --- a/tidy3d/components/grid.py +++ b/tidy3d/components/grid.py @@ -328,7 +328,7 @@ def discretize_inds(self, box: Box, extend: bool = False) -> List[Tuple[int, int Rectangular geometry within simulation to discretize. extend : bool = False If ``True``, ensure that the returned indexes extend sufficiently in very direction to - be able to interpolate any field component at any point within the Box. + be able to interpolate any field component at any point within the ``box``. Returns ------- @@ -377,10 +377,10 @@ def periodic_subspace(self, axis: Axis, ind_beg: int = 0, ind_end: int = 0) -> C Parameters ---------- axis : Axis - Axis along which to pick the subspace. - ind_beg : int, optional + Axis index along which to pick the subspace. + ind_beg : int = 0 Starting index for the subspace. - ind_end : int, optional + ind_end : int = 0 Ending index for the subspace. Returns diff --git a/tidy3d/plugins/mode/mode_solver.py b/tidy3d/plugins/mode/mode_solver.py index 753ce91892..6940a25f93 100644 --- a/tidy3d/plugins/mode/mode_solver.py +++ b/tidy3d/plugins/mode/mode_solver.py @@ -310,7 +310,7 @@ def solve(self) -> ModeSolverData: # Compute and store the modes at all frequencies fields = {"Ex": [], "Ey": [], "Ez": [], "Hx": [], "Hy": [], "Hz": []} n_complex = [] - for ifreq, freq in enumerate(self.freqs): + for freq in self.freqs: # Compute the modes mode_fields, n_comp = compute_modes( eps_cross=self.solver_eps(freq), @@ -324,7 +324,7 @@ def solve(self) -> ModeSolverData: fields_freq = {"Ex": [], "Ey": [], "Ez": [], "Hx": [], "Hy": [], "Hz": []} for mode_index in range(self.mode_spec.num_modes): # Get E and H fields at the current mode_index - ((Ex, Ey, Ez), (Hx, Hy, Hz)) = self.process_fields(mode_fields, ifreq, mode_index) + ((Ex, Ey, Ez), (Hx, Hy, Hz)) = self.process_fields(mode_fields, mode_index) # note: back in original coordinates fields_mode = {"Ex": Ex, "Ey": Ey, "Ez": Ez, "Hx": Hx, "Hy": Hy, "Hz": Hz} @@ -347,10 +347,14 @@ def solve(self) -> ModeSolverData: mode_index=np.arange(self.mode_spec.num_modes), values=np.stack(field, axis=-2), ) - - field_data = ModeFieldData(data_dict=data_dict).apply_syms( - plane_grid.yee.grid_dict, self.simulation.center, self.simulation.symmetry + field_data = ModeFieldData( + data_dict=data_dict, + expanded_grid=plane_grid.yee.grid_dict, + symmetry_center=self.simulation.center, + symmetry=self.simulation.symmetry, ) + field_data = field_data.expand_syms + self.field_decay_warning(field_data) index_data = ModeIndexData( f=self.freqs, mode_index=np.arange(self.mode_spec.num_modes), @@ -399,29 +403,12 @@ def rotate_field_coords(self, field): f_rot = np.stack(self.plane.unpop_axis(f_z, (f_x, f_y), axis=self.normal_axis), axis=0) return f_rot - # pylint:disable=too-many-locals - def process_fields( - self, mode_fields: Array[complex], freq_index: int, mode_index: int - ) -> Tuple[FIELD, FIELD]: + def process_fields(self, mode_fields: Array[complex], mode_index: int) -> Tuple[FIELD, FIELD]: """Transform solver fields to simulation axes, set gauge, and check decay at boundaries.""" # Separate E and H fields (in solver coordinates) E, H = mode_fields[..., mode_index] - # Warn if not decayed at edges - e_edge = 0 - if E.shape[1] > 1: - e_edge = np.sum(np.abs(E[:, 0, :]) ** 2 + np.abs(E[:, -1, :]) ** 2) - if E.shape[2] > 1: - e_edge += np.sum(np.abs(E[:, :, 0]) ** 2 + np.abs(E[:, :, -1]) ** 2) - e_norm = np.sum(np.abs(E) ** 2) - - if e_edge / e_norm > FIELD_DECAY_CUTOFF: - logging.warning( - f"Mode field at frequency index {freq_index}, mode index {mode_index} does not " - "decay at the plane boundaries." - ) - # Set gauge to highest-amplitude in-plane E being real and positive ind_max = np.argmax(np.abs(E[:2])) phi = np.angle(E[:2].ravel()[ind_max]) @@ -440,6 +427,32 @@ def process_fields( return ((Ex, Ey, Ez), (Hx, Hy, Hz)) + def field_decay_warning(self, field_data): + """Warn if any of the modes do not decay at the edges.""" + _, plane_dims = self.plane.pop_axis(["x", "y", "z"], axis=self.normal_axis) + field_sizes = field_data.Ex.sizes + for freq_index in range(field_sizes["f"]): + for mode_index in range(field_sizes["mode_index"]): + e_edge, e_norm = 0, 0 + # Sum up the total field intensity + for E in (field_data.Ex, field_data.Ey, field_data.Ez): + e_norm += np.sum(np.abs(E[{"f": freq_index, "mode_index": mode_index}]) ** 2) + # Sum up the field intensity at the edges + if field_sizes[plane_dims[0]] > 1: + for E in (field_data.Ex, field_data.Ey, field_data.Ez): + isel = {plane_dims[0]: [0, -1], "f": freq_index, "mode_index": mode_index} + e_edge += np.sum(np.abs(E[isel]) ** 2) + if field_sizes[plane_dims[1]] > 1: + for E in (field_data.Ex, field_data.Ey, field_data.Ez): + isel = {plane_dims[1]: [0, -1], "f": freq_index, "mode_index": mode_index} + e_edge += np.sum(np.abs(E[isel]) ** 2) + # Warn if needed + if e_edge / e_norm > FIELD_DECAY_CUTOFF: + logging.warning( + f"Mode field at frequency index {freq_index}, mode index {mode_index} does " + "not decay at the plane boundaries." + ) + def to_source( self, source_time: SourceTime,