Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ def test_monitor_plane():
with pytest.raises(ValidationError) as e_info:
ModeMonitor(size=size, freqs=freqs, modes=[])
with pytest.raises(ValidationError) as e_info:
ModeSolverMonitor(size=size, freqs=freqs, modes=[])
ModeFieldMonitor(size=size, freqs=freqs, modes=[])
with pytest.raises(ValidationError) as e_info:
FluxMonitor(size=size, freqs=freqs, modes=[])

Expand Down
6 changes: 4 additions & 2 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,17 @@ def test_mode_solver():
size=(2, 2, 2), grid_size=(0.1, 0.1, 0.1), structures=[waveguide], run_time=1e-12
)
plane = td.Box(center=(0, 0, 0), size=(0, 1, 1))
ms = ModeSolver(simulation=simulation, plane=plane)
mode_spec = td.ModeSpec(
num_modes=3,
target_neff=2.0,
bend_radius=3.0,
bend_axis=0,
num_pml=(10, 10),
)
modes = ms.solve(mode_spec=mode_spec, freqs=[td.constants.C_0 / 1.0])
ms = ModeSolver(
simulation=simulation, plane=plane, mode_spec=mode_spec, freqs=[td.constants.C_0 / 1.0]
)
modes = ms.solve()


def _test_coeffs():
Expand Down
2 changes: 1 addition & 1 deletion tidy3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

# monitors
from .components import FieldMonitor, FieldTimeMonitor, FluxMonitor, FluxTimeMonitor
from .components import ModeMonitor, ModeSolverMonitor
from .components import ModeMonitor, ModeFieldMonitor

# simulation
from .components import Simulation
Expand Down
2 changes: 1 addition & 1 deletion tidy3d/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# monitor
from .monitor import FreqMonitor, TimeMonitor, FieldMonitor, FieldTimeMonitor
from .monitor import Monitor, FluxMonitor, FluxTimeMonitor, ModeMonitor
from .monitor import ModeSolverMonitor
from .monitor import ModeFieldMonitor

# simulation
from .simulation import Simulation
Expand Down
232 changes: 130 additions & 102 deletions tidy3d/components/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import numpy as np
import h5py

from .types import Numpy, Direction, Array, numpy_encoding, Literal, Ax, Coordinate, Symmetry
from .types import Numpy, Direction, Array, numpy_encoding, Literal, Ax, Coordinate, Symmetry, Axis
from .base import Tidy3dBaseModel
from .simulation import Simulation
from .grid import YeeGrid
from .mode import ModeSpec
from .monitor import PlanarMonitor
from .viz import add_ax_if_none, equal_aspect
from ..log import log, DataError

Expand Down Expand Up @@ -826,48 +827,6 @@ def sel_mode_index(self, mode_index):
return FieldData(data_dict=data_dict)


class ModeSolverData(CollectionData):
"""Stores a collection of mode field profiles and mode effective indexes from the mode solver.

Parameters
----------
data_dict : Dict[str, :class:`AbstractModeData`]
Mapping of "n_complex" to :class:`ModeIndexData`, and "fields" to :class:`ModeFieldData`.
"""

data_dict: Dict[str, Union[AbstractModeData, AbstractFieldData]]
type: Literal["ModeSolverData"] = "ModeSolverData"

@property
def fields(self):
"""Get field data."""
return self.data_dict.get("fields")

@property
def n_complex(self):
"""Get complex effective indexes."""
scalar_data = self.data_dict.get("n_complex")
if scalar_data:
return scalar_data.data
return None

@property
def n_eff(self):
"""Get real part of effective index."""
scalar_data = self.data_dict.get("n_complex")
if scalar_data:
return scalar_data.n_eff
return None

@property
def k_eff(self):
"""Get imaginary part of effective index."""
scalar_data = self.data_dict.get("n_complex")
if scalar_data:
return scalar_data.k_eff
return None


# maps MonitorData.type string to the actual type, for MonitorData.from_file()
DATA_TYPE_MAP = {
"ScalarFieldData": ScalarFieldData,
Expand All @@ -881,11 +840,109 @@ def k_eff(self):
"ModeData": ModeData,
"ModeFieldData": ModeFieldData,
"ScalarModeFieldData": ScalarModeFieldData,
"ModeSolverData": ModeSolverData,
}


class SimulationData(Tidy3dBaseModel):
class AbstractSimulationData(Tidy3dBaseModel, ABC):
"""Abstract class to store a simulation and some data associated with it."""

simulation: Simulation

@equal_aspect
@add_ax_if_none
# pylint:disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements
def plot_field_array(
self,
field_data: xr.DataArray,
axis: Axis,
position: float,
val: Literal["real", "imag", "abs"] = "real",
freq: float = None,
eps_alpha: float = 0.2,
robust: bool = True,
ax: Ax = None,
**patch_kwargs,
) -> Ax:
"""Plot the field data for a monitor with simulation plot overlayed.

Parameters
----------
field_data: xr.DataArray
DataArray with the field data to plot.
axis: Axis
Axis normal to the plotting plane.
position: float
Position along the axis.
val : Literal['real', 'imag', 'abs'] = 'real'
Which part of the field to plot.
freq: float = None
Frequency at which the permittivity is evaluated at (if dispersive).
By default, chooses permittivity as frequency goes to infinity.
eps_alpha : float = 0.2
Opacity of the structure permittivity.
Must be between 0 and 1 (inclusive).
robust : bool = True
If specified, uses the 2nd and 98th percentiles of the data to compute the color limits.
This helps in visualizing the field patterns especially in the presence of a source.
ax : matplotlib.axes._subplots.Axes = None
matplotlib axes to plot on, if not specified, one is created.
**patch_kwargs
Optional keyword arguments passed to ``add_artist(patch, **patch_kwargs)``.

Returns
-------
matplotlib.axes._subplots.Axes
The supplied or created matplotlib axes.
"""

# select the cross section data
axis_label = "xyz"[axis]
interp_kwarg = {axis_label: position}

if len(field_data.coords[axis_label]) > 1:
try:
field_data = field_data.interp(**interp_kwarg)

except Exception as e:
raise DataError(f"Could not interpolate data at {axis_label}={position}.") from e

# select the field value
if val not in ("real", "imag", "abs"):
raise DataError(f"'val' must be one of ``{'real', 'imag', 'abs'}``, given {val}")

if val == "real":
field_data = field_data.real
elif val == "imag":
field_data = field_data.imag
elif val == "abs":
field_data = abs(field_data)

if val == "abs":
cmap = "magma"
else:
cmap = "RdBu"

# plot the field
xy_coord_labels = list("xyz")
xy_coord_labels.pop(axis)
x_coord_label, y_coord_label = xy_coord_labels # pylint:disable=unbalanced-tuple-unpacking
field_data.plot(ax=ax, x=x_coord_label, y=y_coord_label, robust=robust, cmap=cmap)

# plot the simulation epsilon
ax = self.simulation.plot_structures_eps(
freq=freq, cbar=False, alpha=eps_alpha, ax=ax, **{axis_label: position}, **patch_kwargs
)

# set the limits based on the xarray coordinates min and max
x_coord_values = field_data.coords[x_coord_label]
y_coord_values = field_data.coords[y_coord_label]
ax.set_xlim(min(x_coord_values), max(x_coord_values))
ax.set_ylim(min(y_coord_values), max(y_coord_values))

return ax


class SimulationData(AbstractSimulationData):
"""Holds :class:`Monitor` data associated with :class:`Simulation`.

Parameters
Expand All @@ -902,7 +959,6 @@ class SimulationData(Tidy3dBaseModel):
A boolean flag denoting whether the data has been normalized by the spectrum of a source.
"""

simulation: Simulation
monitor_data: Dict[str, Tidy3dData]
log_string: str = None
diverged: bool = False
Expand Down Expand Up @@ -941,9 +997,8 @@ def __getitem__(self, monitor_name: str) -> Union[Tidy3dDataArray, xr.Dataset]:
a collection data instance is returned.
Otherwise, if it is a MonitorData instance, the xarray representation is returned.
"""
self.ensure_monitor_exists(monitor_name)
monitor_data = self.monitor_data.get(monitor_name)
if not monitor_data:
raise DataError(f"monitor '{monitor_name}' not found")
if isinstance(monitor_data, MonitorData):
return monitor_data.data
return monitor_data
Expand Down Expand Up @@ -975,10 +1030,7 @@ def at_centers(self, field_monitor_name: str) -> xr.Dataset:
"""

# get the data
self.ensure_monitor_exists(field_monitor_name)
field_monitor_data = self.monitor_data.get(field_monitor_name)
if isinstance(field_monitor_data, ModeSolverData):
field_monitor_data = field_monitor_data.fields
self.ensure_field_monitor(field_monitor_data)

# get the monitor, discretize, and get center locations
Expand All @@ -990,8 +1042,6 @@ def at_centers(self, field_monitor_name: str) -> xr.Dataset:
field_dataset = field_monitor_data.colocate(x=centers.x, y=centers.y, z=centers.z)
return field_dataset

@equal_aspect
@add_ax_if_none
# pylint:disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements
def plot_field(
self,
Expand Down Expand Up @@ -1034,7 +1084,7 @@ def plot_field(
time: float = None
if monitor is a :class:`FieldTimeMonitor`, specifies the time (sec) to plot the field.
mode_index: int = None
if monitor is a :class:`ModeSolverMonitor`, specifies which mode index to plot.
if monitor is a :class:`ModeFieldMonitor`, specifies which mode index to plot.
eps_alpha : float = 0.2
Opacity of the structure permittivity.
Must be between 0 and 1 (inclusive).
Expand All @@ -1053,13 +1103,12 @@ def plot_field(
"""

# get the monitor data
self.ensure_monitor_exists(field_monitor_name)
monitor_data = self.monitor_data.get(field_monitor_name)
if isinstance(monitor_data, ModeSolverData):
if mode_index is None:
raise DataError("'mode_index' must be supplied to plot a ModeSolverMonitor.")
monitor_data = monitor_data.fields.sel_mode_index(mode_index=mode_index)
self.ensure_field_monitor(monitor_data)
if isinstance(monitor_data, ModeFieldData):
if mode_index is None:
raise DataError("'mode_index' must be supplied to plot a ModeFieldMonitor.")
monitor_data = monitor_data.sel_mode_index(mode_index=mode_index)

# get the field data component
if field_name == "int":
Expand All @@ -1068,6 +1117,7 @@ def plot_field(
for field in ("Ex", "Ey", "Ez"):
field_data = monitor_data[field]
xr_data += abs(field_data) ** 2
val = "abs"
else:
monitor_data.ensure_member_exists(field_name)
xr_data = monitor_data.data_dict.get(field_name).data
Expand All @@ -1084,54 +1134,32 @@ def plot_field(
else:
raise DataError("Field data has neither time nor frequency data, something went wrong.")

# select the cross section data
axis, pos = self.simulation.parse_xyz_kwargs(x=x, y=y, z=z)
axis_label = "xyz"[axis]
interp_kwarg = {axis_label: pos}

if len(field_data.coords[axis_label]) > 1:
if x is None and y is None and z is None:
"""If a planar monitor, infer x/y/z based on the plane position and normal."""
monitor = self.simulation.get_monitor_by_name(field_monitor_name)
try:
field_data = field_data.interp(**interp_kwarg)

axis = monitor.geometry.size.index(0.0)
position = monitor.geometry.center[axis]
except Exception as e:
raise DataError(f"Could not interpolate data at {axis_label}={pos}.") from e

# select the field value
if val not in ("real", "imag", "abs"):
raise DataError(f"'val' must be one of ``{'real', 'imag', 'abs'}``, given {val}")

if field_name != "int":
if val == "real":
field_data = field_data.real
elif val == "imag":
field_data = field_data.imag
elif val == "abs":
field_data = abs(field_data)

if val == "abs" or field_name == "int":
cmap = "magma"
raise ValueError(
"If none of 'x', 'y' or 'z' is specified, monitor must have a "
"zero-sized dimension"
) from e
else:
cmap = "RdBu"

# plot the field
xy_coord_labels = list("xyz")
xy_coord_labels.pop(axis)
x_coord_label, y_coord_label = xy_coord_labels # pylint:disable=unbalanced-tuple-unpacking
field_data.plot(ax=ax, x=x_coord_label, y=y_coord_label, robust=robust, cmap=cmap)

# plot the simulation epsilon
ax = self.simulation.plot_structures_eps(
freq=freq, cbar=False, x=x, y=y, z=z, alpha=eps_alpha, ax=ax, **patch_kwargs
axis, position = self.simulation.parse_xyz_kwargs(x=x, y=y, z=z)

return self.plot_field_array(
field_data=field_data,
axis=axis,
position=position,
val=val,
freq=freq,
eps_alpha=eps_alpha,
robust=robust,
ax=ax,
**patch_kwargs,
)

# set the limits based on the xarray coordinates min and max
x_coord_values = field_data.coords[x_coord_label]
y_coord_values = field_data.coords[y_coord_label]
ax.set_xlim(min(x_coord_values), max(x_coord_values))
ax.set_ylim(min(y_coord_values), max(y_coord_values))

return ax

def normalize(self, normalize_index: int = 0):
"""Return a copy of the :class:`.SimulationData` object with data normalized by source.

Expand Down
Loading