From 3d0e333c97152534b35601ddebf0d68b82c5a143 Mon Sep 17 00:00:00 2001 From: momchil Date: Tue, 22 Mar 2022 18:40:47 -0700 Subject: [PATCH 1/2] Reorganizing of ModeSolver and ModeSolverData --- tests/test_components.py | 2 +- tests/test_plugins.py | 6 +- tidy3d/__init__.py | 2 +- tidy3d/components/__init__.py | 2 +- tidy3d/components/data.py | 53 +------ tidy3d/components/monitor.py | 20 ++- tidy3d/plugins/mode/mode_solver.py | 224 +++++++++++++++++++++++------ 7 files changed, 203 insertions(+), 106 deletions(-) diff --git a/tests/test_components.py b/tests/test_components.py index 7dbb49cc85..217e777ef9 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -768,7 +768,7 @@ def test_monitor_plane(): with pytest.raises(ValidationError) as e_info: ModeMonitor(size=size, freqs=freqs, modes=[]) with pytest.raises(ValidationError) as e_info: - ModeSolverMonitor(size=size, freqs=freqs, modes=[]) + ModeFieldMonitor(size=size, freqs=freqs, modes=[]) with pytest.raises(ValidationError) as e_info: FluxMonitor(size=size, freqs=freqs, modes=[]) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 6f2898f363..efe5e5bc5b 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -74,7 +74,6 @@ def test_mode_solver(): size=(2, 2, 2), grid_size=(0.1, 0.1, 0.1), structures=[waveguide], run_time=1e-12 ) plane = td.Box(center=(0, 0, 0), size=(0, 1, 1)) - ms = ModeSolver(simulation=simulation, plane=plane) mode_spec = td.ModeSpec( num_modes=3, target_neff=2.0, @@ -82,7 +81,10 @@ def test_mode_solver(): bend_axis=0, num_pml=(10, 10), ) - modes = ms.solve(mode_spec=mode_spec, freqs=[td.constants.C_0 / 1.0]) + ms = ModeSolver( + simulation=simulation, plane=plane, mode_spec=mode_spec, freqs=[td.constants.C_0 / 1.0] + ) + modes = ms.solve() def _test_coeffs(): diff --git a/tidy3d/__init__.py b/tidy3d/__init__.py index 0928ffe890..6b27fd0574 100644 --- a/tidy3d/__init__.py +++ b/tidy3d/__init__.py @@ -33,7 +33,7 @@ # monitors from .components import FieldMonitor, FieldTimeMonitor, FluxMonitor, FluxTimeMonitor -from .components import ModeMonitor, ModeSolverMonitor +from .components import ModeMonitor, ModeFieldMonitor # simulation from .components import Simulation diff --git a/tidy3d/components/__init__.py b/tidy3d/components/__init__.py index 52a76e209d..4f16a349bc 100644 --- a/tidy3d/components/__init__.py +++ b/tidy3d/components/__init__.py @@ -29,7 +29,7 @@ # monitor from .monitor import FreqMonitor, TimeMonitor, FieldMonitor, FieldTimeMonitor from .monitor import Monitor, FluxMonitor, FluxTimeMonitor, ModeMonitor -from .monitor import ModeSolverMonitor +from .monitor import ModeFieldMonitor # simulation from .simulation import Simulation diff --git a/tidy3d/components/data.py b/tidy3d/components/data.py index 31a1f9f159..b3aad4a314 100644 --- a/tidy3d/components/data.py +++ b/tidy3d/components/data.py @@ -826,48 +826,6 @@ def sel_mode_index(self, mode_index): return FieldData(data_dict=data_dict) -class ModeSolverData(CollectionData): - """Stores a collection of mode field profiles and mode effective indexes from the mode solver. - - Parameters - ---------- - data_dict : Dict[str, :class:`AbstractModeData`] - Mapping of "n_complex" to :class:`ModeIndexData`, and "fields" to :class:`ModeFieldData`. - """ - - data_dict: Dict[str, Union[AbstractModeData, AbstractFieldData]] - type: Literal["ModeSolverData"] = "ModeSolverData" - - @property - def fields(self): - """Get field data.""" - return self.data_dict.get("fields") - - @property - def n_complex(self): - """Get complex effective indexes.""" - scalar_data = self.data_dict.get("n_complex") - if scalar_data: - return scalar_data.data - return None - - @property - def n_eff(self): - """Get real part of effective index.""" - scalar_data = self.data_dict.get("n_complex") - if scalar_data: - return scalar_data.n_eff - return None - - @property - def k_eff(self): - """Get imaginary part of effective index.""" - scalar_data = self.data_dict.get("n_complex") - if scalar_data: - return scalar_data.k_eff - return None - - # maps MonitorData.type string to the actual type, for MonitorData.from_file() DATA_TYPE_MAP = { "ScalarFieldData": ScalarFieldData, @@ -881,7 +839,6 @@ def k_eff(self): "ModeData": ModeData, "ModeFieldData": ModeFieldData, "ScalarModeFieldData": ScalarModeFieldData, - "ModeSolverData": ModeSolverData, } @@ -977,8 +934,6 @@ def at_centers(self, field_monitor_name: str) -> xr.Dataset: # get the data self.ensure_monitor_exists(field_monitor_name) field_monitor_data = self.monitor_data.get(field_monitor_name) - if isinstance(field_monitor_data, ModeSolverData): - field_monitor_data = field_monitor_data.fields self.ensure_field_monitor(field_monitor_data) # get the monitor, discretize, and get center locations @@ -1034,7 +989,7 @@ def plot_field( time: float = None if monitor is a :class:`FieldTimeMonitor`, specifies the time (sec) to plot the field. mode_index: int = None - if monitor is a :class:`ModeSolverMonitor`, specifies which mode index to plot. + if monitor is a :class:`ModeFieldMonitor`, specifies which mode index to plot. eps_alpha : float = 0.2 Opacity of the structure permittivity. Must be between 0 and 1 (inclusive). @@ -1055,10 +1010,10 @@ def plot_field( # get the monitor data self.ensure_monitor_exists(field_monitor_name) monitor_data = self.monitor_data.get(field_monitor_name) - if isinstance(monitor_data, ModeSolverData): + if isinstance(monitor_data, ModeFieldData): if mode_index is None: - raise DataError("'mode_index' must be supplied to plot a ModeSolverMonitor.") - monitor_data = monitor_data.fields.sel_mode_index(mode_index=mode_index) + raise DataError("'mode_index' must be supplied to plot a ModeFieldMonitor.") + monitor_data = monitor_data.sel_mode_index(mode_index=mode_index) self.ensure_field_monitor(monitor_data) # get the field data component diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index a1c8161995..1e10bf1014 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -225,7 +225,7 @@ class FieldMonitor(AbstractFieldMonitor, FreqMonitor): ... name='steady_state_monitor') """ - _data_type: Literal["ScalarFieldData"] = pydantic.Field("ScalarFieldData") + _data_type: Literal["FieldData"] = pydantic.Field("FieldData") def storage_size(self, num_cells: int, tmesh: Array) -> int: # stores 1 complex number per grid cell, per frequency, per field @@ -316,7 +316,7 @@ class FieldTimeMonitor(AbstractFieldMonitor, TimeMonitor): ... name='movie_monitor') """ - _data_type: Literal["ScalarFieldTimeData"] = pydantic.Field("ScalarFieldTimeData") + _data_type: Literal["FieldTimeData"] = pydantic.Field("FieldTimeData") def storage_size(self, num_cells: int, tmesh: Array) -> int: # stores 1 real number per grid cell, per time step, per field @@ -386,14 +386,14 @@ def storage_size(self, num_cells: int, tmesh: int) -> int: return 3 * BYTES_COMPLEX * len(self.freqs) * self.mode_spec.num_modes -class ModeSolverMonitor(AbstractModeMonitor): - """:class:`Monitor` that stores the mode data (field profiles and effective index) - returned by the mode solver in the monitor plane. +class ModeFieldMonitor(AbstractModeMonitor): + """:class:`Monitor` that stores the mode field profiles returned by the mode solver in the + monitor plane. Example ------- >>> mode_spec = ModeSpec(num_modes=3) - >>> monitor = ModeSolverMonitor( + >>> monitor = ModeFieldMonitor( ... center=(1,2,3), ... size=(2,2,0), ... freqs=[200e12, 210e12], @@ -401,17 +401,15 @@ class ModeSolverMonitor(AbstractModeMonitor): ... name='mode_monitor') """ - _data_type: Literal["ModeSolverData"] = pydantic.Field("ModeSolverData") + _data_type: Literal["ModeFieldData"] = pydantic.Field("ModeFieldData") def storage_size(self, num_cells: int, tmesh: int) -> int: # fields store 6 complex numbers per grid cell, per frequency, per mode. field_size = 6 * BYTES_COMPLEX * num_cells * len(self.freqs) * self.mode_spec.num_modes - # effective index stores 1 complex number per frequency per mode - neff_size = BYTES_COMPLEX * len(self.freqs) * self.mode_spec.num_modes - return field_size + neff_size + return field_size # types of monitors that are accepted by simulation MonitorType = Union[ - FieldMonitor, FieldTimeMonitor, FluxMonitor, FluxTimeMonitor, ModeMonitor, ModeSolverMonitor + FieldMonitor, FieldTimeMonitor, FluxMonitor, FluxTimeMonitor, ModeMonitor, ModeFieldMonitor ] diff --git a/tidy3d/plugins/mode/mode_solver.py b/tidy3d/plugins/mode/mode_solver.py index 996fc69996..02a8ff2ed0 100644 --- a/tidy3d/plugins/mode/mode_solver.py +++ b/tidy3d/plugins/mode/mode_solver.py @@ -1,9 +1,11 @@ -"""Turn Mode Specifications into Mode profiles +"""Solve for modes in a 2D cross-sectional plane in a simulation, assuming translational +invariance along a given propagation axis. """ -from typing import List, Tuple +from typing import List, Tuple, Union, Dict import logging +import h5py import numpy as np import pydantic @@ -14,7 +16,7 @@ from ...components import ModeMonitor from ...components.source import ModeSource, SourceTime from ...components.types import Direction, Array -from ...components.data import ModeIndexData, ModeFieldData, ScalarModeFieldData, ModeSolverData +from ...components.data import Tidy3dData, ModeIndexData, ModeFieldData, ScalarModeFieldData from ...log import ValidationError from .solver import compute_modes @@ -25,6 +27,136 @@ FIELD_DECAY_CUTOFF = 1e-2 +class ModeSolverData(Tidy3dBaseModel): + """Holds data associated with :class:`.ModeSolver`. + + Parameters + ---------- + mode_solver : :class:`.ModeSolver` + Original mode solver instance. + data_dict : Dict[str, :class:`.AbstractModeData`] + Mapping of "n_complex" to :class:`.ModeIndexData`, and "fields" to :class:`.ModeFieldData`. + """ + + simulation: Simulation + plane: Box + mode_spec: ModeSpec + data_dict: Dict[str, Union[ModeFieldData, ModeIndexData]] + + @property + def fields(self): + """Get field data.""" + return self.data_dict.get("fields") + + @property + def n_complex(self): + """Get complex effective indexes.""" + scalar_data = self.data_dict.get("n_complex") + if scalar_data: + return scalar_data.data + return None + + @property + def n_eff(self): + """Get real part of effective index.""" + scalar_data = self.data_dict.get("n_complex") + if scalar_data: + return scalar_data.n_eff + return None + + @property + def k_eff(self): + """Get imaginary part of effective index.""" + scalar_data = self.data_dict.get("n_complex") + if scalar_data: + return scalar_data.k_eff + return None + + def add_to_handle(self, handle: Union[h5py.File, h5py.Group]): + """Export to an hdf5 handle, which can be a file or a group. + + Parameters + ---------- + handle : Union[hdf5.File, hdf5.Group] + Handle to write the ModeSolverData to. + """ + + # save pydantic models as string + json_dict = { + "simulation": self.simulation, + "plane": self.plane, + "mode_spec": self.mode_spec, + } + for name, obj in json_dict.items(): + Tidy3dData.save_string(handle, name, obj.json()) + + # make groups for mode fields and index data + for name, data in self.data_dict.items(): + data_grp = handle.create_group(name) + data.add_to_group(data_grp) + + @classmethod + def load_from_handle(cls, handle: Union[h5py.File, h5py.Group]): + """Load from an hdf5 handle, which can be a file or a group. + + Parameters + ---------- + handle : Union[hdf5.File, hdf5.Group] + Handle to load the ModeSolverData from. + """ + + # construct pydantic models from string + json_dict = { + "simulation": Simulation, + "plane": Box, + "mode_spec": ModeSpec, + } + obj_dict = {} + for name, obj in json_dict.items(): + json_string = Tidy3dData.load_string(handle, name) + obj_dict[name] = obj.parse_raw(json_string) + + # load fields and effective index data + data_dict = { + "fields": ModeFieldData.load_from_group(handle["fields"]), + "n_complex": ModeIndexData.load_from_group(handle["n_complex"]), + } + return cls(data_dict=data_dict, **obj_dict) + + def to_file(self, fname: str) -> None: + """Export :class:`.ModeSolverData` to single hdf5 file. + + Parameters + ---------- + fname : str + Path to .hdf5 data file (including filename). + """ + + with h5py.File(fname, "a") as f_handle: + self.add_to_handle(f_handle) + + @classmethod + def from_file(cls, fname: str): + """Load :class:`.ModeSolverData` from .hdf5 file. + + Parameters + ---------- + fname : str + Path to .hdf5 data file (including filename). + + Returns + ------- + :class:`.ModeSolverData` + A :class:`.ModeSolverData` instance. + """ + + # read from file at fname + with h5py.File(fname, "r") as f_handle: + mode_solver = cls.load_from_handle(f_handle) + + return mode_solver + + class ModeSolver(Tidy3dBaseModel): """Interface for solving electromagnetic eigenmodes in a 2D plane with translational invariance in the third dimension. @@ -38,6 +170,16 @@ class ModeSolver(Tidy3dBaseModel): ..., title="Plane", description="Cross-sectional plane in which the mode will be computed." ) + mode_spec: ModeSpec = pydantic.Field( + ..., + title="Mode specification", + description="Container with specifications about the modes to be solved for.", + ) + + freqs: List[float] = pydantic.Field( + ..., title="Frequencies", description="A list of frequencies at which to solve." + ) + @pydantic.validator("plane", always=True) def is_plane(cls, val): """Raise validation error if not planar.""" @@ -55,20 +197,14 @@ def plane_sym(self): """Potentially smaller plane if symmetries present in the simulation.""" return self.simulation.min_sym_box(self.plane) - def solve(self, mode_spec: ModeSpec, freqs: List[float]) -> ModeSolverData: - """Solves for modal profile and effective index of ``Mode`` object. - - Parameters - ---------- - mode_spec : :class:`ModeSpec` - ``ModeSpec`` object containing specifications of the mode solver. - freqs : List[float] - List of frequencies to solve at (Hz). + def solve(self) -> ModeSolverData: + """Finds the modal profile and effective index of the modes. Returns ------- ModeSolverData - ``ModeSolverData`` object containing the effective index and mode fields for all modes. + :class:`.ModeSolverData` object containing the effective index and mode fields for all + modes. """ normal_axis = self.normal_axis @@ -90,19 +226,19 @@ def solve(self, mode_spec: ModeSpec, freqs: List[float]) -> ModeSolverData: # Compute and store the modes at all frequencies fields = {"Ex": [], "Ey": [], "Ez": [], "Hx": [], "Hy": [], "Hz": []} n_complex = [] - for ifreq, freq in enumerate(freqs): + for ifreq, freq in enumerate(self.freqs): # Compute the modes mode_fields, n_comp = compute_modes( eps_cross=self.solver_eps(freq), coords=solver_coords, freq=freq, - mode_spec=mode_spec, + mode_spec=self.mode_spec, symmetry=solver_symmetry, ) n_complex.append(n_comp) fields_freq = {"Ex": [], "Ey": [], "Ez": [], "Hx": [], "Hy": [], "Hz": []} - for mode_index in range(mode_spec.num_modes): + for mode_index in range(self.mode_spec.num_modes): # Get E and H fields at the current mode_index ((Ex, Ey, Ez), (Hx, Hy, Hz)) = self.process_fields(mode_fields, ifreq, mode_index) @@ -123,8 +259,8 @@ def solve(self, mode_spec: ModeSpec, freqs: List[float]) -> ModeSolverData: x=xyz_coords[0], y=xyz_coords[1], z=xyz_coords[2], - f=freqs, - mode_index=np.arange(mode_spec.num_modes), + f=self.freqs, + mode_index=np.arange(self.mode_spec.num_modes), values=np.stack(field, axis=-2), ) @@ -132,11 +268,16 @@ def solve(self, mode_spec: ModeSpec, freqs: List[float]) -> ModeSolverData: plane_grid, self.simulation.center, self.simulation.symmetry ) index_data = ModeIndexData( - f=freqs, - mode_index=np.arange(mode_spec.num_modes), + f=self.freqs, + mode_index=np.arange(self.mode_spec.num_modes), values=np.stack(n_complex, axis=0), ) - mode_info = ModeSolverData(data_dict={"fields": field_data, "n_complex": index_data}) + mode_info = ModeSolverData( + simulation=self.simulation, + plane=self.plane, + mode_spec=self.mode_spec, + data_dict={"fields": field_data, "n_complex": index_data}, + ) return mode_info @@ -216,18 +357,15 @@ def process_fields( def to_source( self, - mode_spec: ModeSpec, source_time: SourceTime, direction: Direction, mode_index: int = 0, ) -> ModeSource: - """Creates :class:`ModeSource` from a ModeSolver instance + additional specifications. + """Creates :class:`.ModeSource` from a ModeSolver instance + additional specifications. Parameters ---------- - mode_spec : :class:`ModeSpec` - :class:`ModeSpec` object containing specifications of mode. - source_time: :class:`SourceTime` + source_time: :class:`.SourceTime` Specification of the source time-dependence. direction : Direction Whether source will inject in ``"+"`` or ``"-"`` direction relative to plane normal. @@ -236,37 +374,41 @@ def to_source( Returns ------- - ModeSource - Modal source containing specification in ``mode``. + :class:`.ModeSource` + Mode source with specifications taken from the ModeSolver instance and the method + inputs. """ - center = self.plane.center - size = self.plane.size return ModeSource( - center=center, - size=size, + center=self.plane.center, + size=self.plane.size, source_time=source_time, - mode_spec=mode_spec, + mode_spec=self.mode_spec, mode_index=mode_index, direction=direction, ) - def to_monitor(self, mode_spec: ModeSpec, freqs: List[float], name: str) -> ModeMonitor: + def to_monitor(self, freqs: List[float], name: str) -> ModeMonitor: """Creates :class:`ModeMonitor` from a ModeSolver instance + additional specifications. Parameters ---------- - mode_spec : :class:`ModeSpec` - :class:`ModeSpec` object containing specifications of mode. freqs : List[float] Frequencies to include in Monitor (Hz). name : str Required name of monitor. + Returns ------- - ModeMonitor - Monitor that measures modes specified by ``mode_spec`` on ``plane`` at ``freqs``. + :class:`.ModeMonitor` + Mode monitor with specifications taken from the ModeSolver instance and the method + inputs. """ - center = self.plane.center - size = self.plane.size - return ModeMonitor(center=center, size=size, freqs=freqs, mode_spec=mode_spec, name=name) + + return ModeMonitor( + center=self.plane.size, + size=self.plane.center, + freqs=freqs, + mode_spec=self.mode_spec, + name=name, + ) From 818de502441958bd6f509e04c901d691e3fd267e Mon Sep 17 00:00:00 2001 From: momchil Date: Wed, 23 Mar 2022 16:04:06 -0700 Subject: [PATCH 2/2] Making AbstractSimulationData for ModeSolverData to inherit from Some shared code with SimulationData w.r.t. plotting fields SimulationData.plot_field now doesn't need x/y/z/ for a 2D monitor --- tidy3d/components/data.py | 179 ++++++++++++++++++++--------- tidy3d/plugins/mode/mode_solver.py | 109 ++++++++++++++++-- tidy3d/web/config.py | 2 +- 3 files changed, 224 insertions(+), 66 deletions(-) diff --git a/tidy3d/components/data.py b/tidy3d/components/data.py index b3aad4a314..7c0ef9f0fd 100644 --- a/tidy3d/components/data.py +++ b/tidy3d/components/data.py @@ -9,11 +9,12 @@ import numpy as np import h5py -from .types import Numpy, Direction, Array, numpy_encoding, Literal, Ax, Coordinate, Symmetry +from .types import Numpy, Direction, Array, numpy_encoding, Literal, Ax, Coordinate, Symmetry, Axis from .base import Tidy3dBaseModel from .simulation import Simulation from .grid import YeeGrid from .mode import ModeSpec +from .monitor import PlanarMonitor from .viz import add_ax_if_none, equal_aspect from ..log import log, DataError @@ -842,7 +843,106 @@ def sel_mode_index(self, mode_index): } -class SimulationData(Tidy3dBaseModel): +class AbstractSimulationData(Tidy3dBaseModel, ABC): + """Abstract class to store a simulation and some data associated with it.""" + + simulation: Simulation + + @equal_aspect + @add_ax_if_none + # pylint:disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements + def plot_field_array( + self, + field_data: xr.DataArray, + axis: Axis, + position: float, + val: Literal["real", "imag", "abs"] = "real", + freq: float = None, + eps_alpha: float = 0.2, + robust: bool = True, + ax: Ax = None, + **patch_kwargs, + ) -> Ax: + """Plot the field data for a monitor with simulation plot overlayed. + + Parameters + ---------- + field_data: xr.DataArray + DataArray with the field data to plot. + axis: Axis + Axis normal to the plotting plane. + position: float + Position along the axis. + val : Literal['real', 'imag', 'abs'] = 'real' + Which part of the field to plot. + freq: float = None + Frequency at which the permittivity is evaluated at (if dispersive). + By default, chooses permittivity as frequency goes to infinity. + eps_alpha : float = 0.2 + Opacity of the structure permittivity. + Must be between 0 and 1 (inclusive). + robust : bool = True + If specified, uses the 2nd and 98th percentiles of the data to compute the color limits. + This helps in visualizing the field patterns especially in the presence of a source. + ax : matplotlib.axes._subplots.Axes = None + matplotlib axes to plot on, if not specified, one is created. + **patch_kwargs + Optional keyword arguments passed to ``add_artist(patch, **patch_kwargs)``. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. + """ + + # select the cross section data + axis_label = "xyz"[axis] + interp_kwarg = {axis_label: position} + + if len(field_data.coords[axis_label]) > 1: + try: + field_data = field_data.interp(**interp_kwarg) + + except Exception as e: + raise DataError(f"Could not interpolate data at {axis_label}={position}.") from e + + # select the field value + if val not in ("real", "imag", "abs"): + raise DataError(f"'val' must be one of ``{'real', 'imag', 'abs'}``, given {val}") + + if val == "real": + field_data = field_data.real + elif val == "imag": + field_data = field_data.imag + elif val == "abs": + field_data = abs(field_data) + + if val == "abs": + cmap = "magma" + else: + cmap = "RdBu" + + # plot the field + xy_coord_labels = list("xyz") + xy_coord_labels.pop(axis) + x_coord_label, y_coord_label = xy_coord_labels # pylint:disable=unbalanced-tuple-unpacking + field_data.plot(ax=ax, x=x_coord_label, y=y_coord_label, robust=robust, cmap=cmap) + + # plot the simulation epsilon + ax = self.simulation.plot_structures_eps( + freq=freq, cbar=False, alpha=eps_alpha, ax=ax, **{axis_label: position}, **patch_kwargs + ) + + # set the limits based on the xarray coordinates min and max + x_coord_values = field_data.coords[x_coord_label] + y_coord_values = field_data.coords[y_coord_label] + ax.set_xlim(min(x_coord_values), max(x_coord_values)) + ax.set_ylim(min(y_coord_values), max(y_coord_values)) + + return ax + + +class SimulationData(AbstractSimulationData): """Holds :class:`Monitor` data associated with :class:`Simulation`. Parameters @@ -859,7 +959,6 @@ class SimulationData(Tidy3dBaseModel): A boolean flag denoting whether the data has been normalized by the spectrum of a source. """ - simulation: Simulation monitor_data: Dict[str, Tidy3dData] log_string: str = None diverged: bool = False @@ -898,9 +997,8 @@ def __getitem__(self, monitor_name: str) -> Union[Tidy3dDataArray, xr.Dataset]: a collection data instance is returned. Otherwise, if it is a MonitorData instance, the xarray representation is returned. """ + self.ensure_monitor_exists(monitor_name) monitor_data = self.monitor_data.get(monitor_name) - if not monitor_data: - raise DataError(f"monitor '{monitor_name}' not found") if isinstance(monitor_data, MonitorData): return monitor_data.data return monitor_data @@ -932,7 +1030,6 @@ def at_centers(self, field_monitor_name: str) -> xr.Dataset: """ # get the data - self.ensure_monitor_exists(field_monitor_name) field_monitor_data = self.monitor_data.get(field_monitor_name) self.ensure_field_monitor(field_monitor_data) @@ -945,8 +1042,6 @@ def at_centers(self, field_monitor_name: str) -> xr.Dataset: field_dataset = field_monitor_data.colocate(x=centers.x, y=centers.y, z=centers.z) return field_dataset - @equal_aspect - @add_ax_if_none # pylint:disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements def plot_field( self, @@ -1008,13 +1103,12 @@ def plot_field( """ # get the monitor data - self.ensure_monitor_exists(field_monitor_name) monitor_data = self.monitor_data.get(field_monitor_name) + self.ensure_field_monitor(monitor_data) if isinstance(monitor_data, ModeFieldData): if mode_index is None: raise DataError("'mode_index' must be supplied to plot a ModeFieldMonitor.") monitor_data = monitor_data.sel_mode_index(mode_index=mode_index) - self.ensure_field_monitor(monitor_data) # get the field data component if field_name == "int": @@ -1023,6 +1117,7 @@ def plot_field( for field in ("Ex", "Ey", "Ez"): field_data = monitor_data[field] xr_data += abs(field_data) ** 2 + val = "abs" else: monitor_data.ensure_member_exists(field_name) xr_data = monitor_data.data_dict.get(field_name).data @@ -1039,54 +1134,32 @@ def plot_field( else: raise DataError("Field data has neither time nor frequency data, something went wrong.") - # select the cross section data - axis, pos = self.simulation.parse_xyz_kwargs(x=x, y=y, z=z) - axis_label = "xyz"[axis] - interp_kwarg = {axis_label: pos} - - if len(field_data.coords[axis_label]) > 1: + if x is None and y is None and z is None: + """If a planar monitor, infer x/y/z based on the plane position and normal.""" + monitor = self.simulation.get_monitor_by_name(field_monitor_name) try: - field_data = field_data.interp(**interp_kwarg) - + axis = monitor.geometry.size.index(0.0) + position = monitor.geometry.center[axis] except Exception as e: - raise DataError(f"Could not interpolate data at {axis_label}={pos}.") from e - - # select the field value - if val not in ("real", "imag", "abs"): - raise DataError(f"'val' must be one of ``{'real', 'imag', 'abs'}``, given {val}") - - if field_name != "int": - if val == "real": - field_data = field_data.real - elif val == "imag": - field_data = field_data.imag - elif val == "abs": - field_data = abs(field_data) - - if val == "abs" or field_name == "int": - cmap = "magma" + raise ValueError( + "If none of 'x', 'y' or 'z' is specified, monitor must have a " + "zero-sized dimension" + ) from e else: - cmap = "RdBu" - - # plot the field - xy_coord_labels = list("xyz") - xy_coord_labels.pop(axis) - x_coord_label, y_coord_label = xy_coord_labels # pylint:disable=unbalanced-tuple-unpacking - field_data.plot(ax=ax, x=x_coord_label, y=y_coord_label, robust=robust, cmap=cmap) - - # plot the simulation epsilon - ax = self.simulation.plot_structures_eps( - freq=freq, cbar=False, x=x, y=y, z=z, alpha=eps_alpha, ax=ax, **patch_kwargs + axis, position = self.simulation.parse_xyz_kwargs(x=x, y=y, z=z) + + return self.plot_field_array( + field_data=field_data, + axis=axis, + position=position, + val=val, + freq=freq, + eps_alpha=eps_alpha, + robust=robust, + ax=ax, + **patch_kwargs, ) - # set the limits based on the xarray coordinates min and max - x_coord_values = field_data.coords[x_coord_label] - y_coord_values = field_data.coords[y_coord_label] - ax.set_xlim(min(x_coord_values), max(x_coord_values)) - ax.set_ylim(min(y_coord_values), max(y_coord_values)) - - return ax - def normalize(self, normalize_index: int = 0): """Return a copy of the :class:`.SimulationData` object with data normalized by source. diff --git a/tidy3d/plugins/mode/mode_solver.py b/tidy3d/plugins/mode/mode_solver.py index 02a8ff2ed0..21bb8248f8 100644 --- a/tidy3d/plugins/mode/mode_solver.py +++ b/tidy3d/plugins/mode/mode_solver.py @@ -15,9 +15,10 @@ from ...components import ModeSpec from ...components import ModeMonitor from ...components.source import ModeSource, SourceTime -from ...components.types import Direction, Array +from ...components.types import Direction, Array, Ax, Literal, ArrayLike from ...components.data import Tidy3dData, ModeIndexData, ModeFieldData, ScalarModeFieldData -from ...log import ValidationError +from ...components.data import AbstractSimulationData +from ...log import ValidationError, DataError from .solver import compute_modes @@ -27,18 +28,19 @@ FIELD_DECAY_CUTOFF = 1e-2 -class ModeSolverData(Tidy3dBaseModel): +class ModeSolverData(AbstractSimulationData): """Holds data associated with :class:`.ModeSolver`. Parameters ---------- - mode_solver : :class:`.ModeSolver` - Original mode solver instance. - data_dict : Dict[str, :class:`.AbstractModeData`] + plane : :class:`.Box` + Cross-sectional plane in which the modes were be computed. + mode_spec : :class:`.ModeSpec` + Container with specifications about the modes. + data_dict : Dict[str, Union[ModeFieldData, ModeIndexData]] Mapping of "n_complex" to :class:`.ModeIndexData`, and "fields" to :class:`.ModeFieldData`. """ - simulation: Simulation plane: Box mode_spec: ModeSpec data_dict: Dict[str, Union[ModeFieldData, ModeIndexData]] @@ -156,6 +158,87 @@ def from_file(cls, fname: str): return mode_solver + # pylint:disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements + def plot_field( + self, + field_name: str, + val: Literal["real", "imag", "abs"] = "real", + freq: float = None, + mode_index: int = None, + eps_alpha: float = 0.2, + robust: bool = True, + ax: Ax = None, + **patch_kwargs, + ) -> Ax: + """Plot the field data for a monitor with simulation plot overlayed. + + Parameters + ---------- + field_name : str + Name of `field` to plot (eg. 'Ex'). + Also accepts `'int'` to plot intensity. + val : Literal['real', 'imag', 'abs'] = 'real' + Which part of the field to plot. + If ``field_name='int'``, this has no effect. + freq: float = None + Specifies the frequency (Hz) to plot. + Also sets the frequency at which the permittivity is evaluated at (if dispersive). + mode_index: int = None + Specifies which mode index to plot. + eps_alpha : float = 0.2 + Opacity of the structure permittivity. + Must be between 0 and 1 (inclusive). + robust : bool = True + If specified, uses the 2nd and 98th percentiles of the data to compute the color limits. + This helps in visualizing the field patterns especially in the presence of a source. + ax : matplotlib.axes._subplots.Axes = None + matplotlib axes to plot on, if not specified, one is created. + **patch_kwargs + Optional keyword arguments passed to ``add_artist(patch, **patch_kwargs)``. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. + """ + + if mode_index >= self.mode_spec.num_modes: + raise DataError("``mode_index`` larger than ``mode_spec.num_modes``.") + mode_fields = self.fields.sel_mode_index(mode_index=mode_index) + + # get the field data component + if field_name == "int": + xr_data = 0.0 + for field in ("Ex", "Ey", "Ez"): + mode_fields = mode_fields[field] + xr_data += abs(mode_fields) ** 2 + val = "abs" + else: + xr_data = mode_fields.data_dict.get(field_name).data + + field_data = xr_data.sel(f=freq, method="nearest") + + axis = self.plane.size.index(0.0) + position = self.plane.center[axis] + + ax = self.plot_field_array( + field_data=field_data, + axis=axis, + position=position, + val=val, + freq=freq, + eps_alpha=eps_alpha, + robust=robust, + ax=ax, + **patch_kwargs, + ) + + n_eff = self.n_eff.isel(mode_index=mode_index).sel(f=freq, method="nearest") + title = f"f={float(field_data.f):1.2e}, n_eff={float(n_eff):1.4f}" + ax.set_title(title) + + return ax + class ModeSolver(Tidy3dBaseModel): """Interface for solving electromagnetic eigenmodes in a 2D plane with translational @@ -176,7 +259,7 @@ class ModeSolver(Tidy3dBaseModel): description="Container with specifications about the modes to be solved for.", ) - freqs: List[float] = pydantic.Field( + freqs: Union[List[float], ArrayLike] = pydantic.Field( ..., title="Frequencies", description="A list of frequencies at which to solve." ) @@ -361,7 +444,8 @@ def to_source( direction: Direction, mode_index: int = 0, ) -> ModeSource: - """Creates :class:`.ModeSource` from a ModeSolver instance + additional specifications. + """Creates :class:`.ModeSource` from a :class:`.ModeSolver` instance plus additional + specifications. Parameters ---------- @@ -389,7 +473,8 @@ def to_source( ) def to_monitor(self, freqs: List[float], name: str) -> ModeMonitor: - """Creates :class:`ModeMonitor` from a ModeSolver instance + additional specifications. + """Creates :class:`ModeMonitor` from a :class:`.ModeSolver` instance plus additional + specifications. Parameters ---------- @@ -406,8 +491,8 @@ def to_monitor(self, freqs: List[float], name: str) -> ModeMonitor: """ return ModeMonitor( - center=self.plane.size, - size=self.plane.center, + center=self.plane.center, + size=self.plane.size, freqs=freqs, mode_spec=self.mode_spec, name=name, diff --git a/tidy3d/web/config.py b/tidy3d/web/config.py index aa8fcfa2d4..259a970669 100644 --- a/tidy3d/web/config.py +++ b/tidy3d/web/config.py @@ -4,7 +4,7 @@ from dataclasses import dataclass -SOLVER_VERSION = "multinode-22.1.6" +SOLVER_VERSION = "release-22.1.6" @dataclass