Skip to content

Commit aee7ea9

Browse files
committed
Making AbstractSimulationData for ModeSolverData to inherit from
Some shared code with SimulationData w.r.t. plotting fields SimulationData.plot_field now doesn't need x/y/z/ for a 2D monitor
1 parent c996af1 commit aee7ea9

File tree

3 files changed

+224
-66
lines changed

3 files changed

+224
-66
lines changed

tidy3d/components/data.py

Lines changed: 126 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
import numpy as np
1010
import h5py
1111

12-
from .types import Numpy, Direction, Array, numpy_encoding, Literal, Ax, Coordinate, Symmetry
12+
from .types import Numpy, Direction, Array, numpy_encoding, Literal, Ax, Coordinate, Symmetry, Axis
1313
from .base import Tidy3dBaseModel
1414
from .simulation import Simulation
1515
from .grid import YeeGrid
1616
from .mode import ModeSpec
17+
from .monitor import PlanarMonitor
1718
from .viz import add_ax_if_none, equal_aspect
1819
from ..log import log, DataError
1920

@@ -842,7 +843,106 @@ def sel_mode_index(self, mode_index):
842843
}
843844

844845

845-
class SimulationData(Tidy3dBaseModel):
846+
class AbstractSimulationData(Tidy3dBaseModel, ABC):
847+
"""Abstract class to store a simulation and some data associated with it."""
848+
849+
simulation: Simulation
850+
851+
@equal_aspect
852+
@add_ax_if_none
853+
# pylint:disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements
854+
def plot_field_array(
855+
self,
856+
field_data: xr.DataArray,
857+
axis: Axis,
858+
position: float,
859+
val: Literal["real", "imag", "abs"] = "real",
860+
freq: float = None,
861+
eps_alpha: float = 0.2,
862+
robust: bool = True,
863+
ax: Ax = None,
864+
**patch_kwargs,
865+
) -> Ax:
866+
"""Plot the field data for a monitor with simulation plot overlayed.
867+
868+
Parameters
869+
----------
870+
field_data: xr.DataArray
871+
DataArray with the field data to plot.
872+
axis: Axis
873+
Axis normal to the plotting plane.
874+
position: float
875+
Position along the axis.
876+
val : Literal['real', 'imag', 'abs'] = 'real'
877+
Which part of the field to plot.
878+
freq: float = None
879+
Frequency at which the permittivity is evaluated at (if dispersive).
880+
By default, chooses permittivity as frequency goes to infinity.
881+
eps_alpha : float = 0.2
882+
Opacity of the structure permittivity.
883+
Must be between 0 and 1 (inclusive).
884+
robust : bool = True
885+
If specified, uses the 2nd and 98th percentiles of the data to compute the color limits.
886+
This helps in visualizing the field patterns especially in the presence of a source.
887+
ax : matplotlib.axes._subplots.Axes = None
888+
matplotlib axes to plot on, if not specified, one is created.
889+
**patch_kwargs
890+
Optional keyword arguments passed to ``add_artist(patch, **patch_kwargs)``.
891+
892+
Returns
893+
-------
894+
matplotlib.axes._subplots.Axes
895+
The supplied or created matplotlib axes.
896+
"""
897+
898+
# select the cross section data
899+
axis_label = "xyz"[axis]
900+
interp_kwarg = {axis_label: position}
901+
902+
if len(field_data.coords[axis_label]) > 1:
903+
try:
904+
field_data = field_data.interp(**interp_kwarg)
905+
906+
except Exception as e:
907+
raise DataError(f"Could not interpolate data at {axis_label}={position}.") from e
908+
909+
# select the field value
910+
if val not in ("real", "imag", "abs"):
911+
raise DataError(f"'val' must be one of ``{'real', 'imag', 'abs'}``, given {val}")
912+
913+
if val == "real":
914+
field_data = field_data.real
915+
elif val == "imag":
916+
field_data = field_data.imag
917+
elif val == "abs":
918+
field_data = abs(field_data)
919+
920+
if val == "abs":
921+
cmap = "magma"
922+
else:
923+
cmap = "RdBu"
924+
925+
# plot the field
926+
xy_coord_labels = list("xyz")
927+
xy_coord_labels.pop(axis)
928+
x_coord_label, y_coord_label = xy_coord_labels # pylint:disable=unbalanced-tuple-unpacking
929+
field_data.plot(ax=ax, x=x_coord_label, y=y_coord_label, robust=robust, cmap=cmap)
930+
931+
# plot the simulation epsilon
932+
ax = self.simulation.plot_structures_eps(
933+
freq=freq, cbar=False, alpha=eps_alpha, ax=ax, **{axis_label: position}, **patch_kwargs
934+
)
935+
936+
# set the limits based on the xarray coordinates min and max
937+
x_coord_values = field_data.coords[x_coord_label]
938+
y_coord_values = field_data.coords[y_coord_label]
939+
ax.set_xlim(min(x_coord_values), max(x_coord_values))
940+
ax.set_ylim(min(y_coord_values), max(y_coord_values))
941+
942+
return ax
943+
944+
945+
class SimulationData(AbstractSimulationData):
846946
"""Holds :class:`Monitor` data associated with :class:`Simulation`.
847947
848948
Parameters
@@ -859,7 +959,6 @@ class SimulationData(Tidy3dBaseModel):
859959
A boolean flag denoting whether the data has been normalized by the spectrum of a source.
860960
"""
861961

862-
simulation: Simulation
863962
monitor_data: Dict[str, Tidy3dData]
864963
log_string: str = None
865964
diverged: bool = False
@@ -898,9 +997,8 @@ def __getitem__(self, monitor_name: str) -> Union[Tidy3dDataArray, xr.Dataset]:
898997
a collection data instance is returned.
899998
Otherwise, if it is a MonitorData instance, the xarray representation is returned.
900999
"""
1000+
self.ensure_monitor_exists(monitor_name)
9011001
monitor_data = self.monitor_data.get(monitor_name)
902-
if not monitor_data:
903-
raise DataError(f"monitor '{monitor_name}' not found")
9041002
if isinstance(monitor_data, MonitorData):
9051003
return monitor_data.data
9061004
return monitor_data
@@ -932,7 +1030,6 @@ def at_centers(self, field_monitor_name: str) -> xr.Dataset:
9321030
"""
9331031

9341032
# get the data
935-
self.ensure_monitor_exists(field_monitor_name)
9361033
field_monitor_data = self.monitor_data.get(field_monitor_name)
9371034
self.ensure_field_monitor(field_monitor_data)
9381035

@@ -945,8 +1042,6 @@ def at_centers(self, field_monitor_name: str) -> xr.Dataset:
9451042
field_dataset = field_monitor_data.colocate(x=centers.x, y=centers.y, z=centers.z)
9461043
return field_dataset
9471044

948-
@equal_aspect
949-
@add_ax_if_none
9501045
# pylint:disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements
9511046
def plot_field(
9521047
self,
@@ -1008,13 +1103,12 @@ def plot_field(
10081103
"""
10091104

10101105
# get the monitor data
1011-
self.ensure_monitor_exists(field_monitor_name)
10121106
monitor_data = self.monitor_data.get(field_monitor_name)
1107+
self.ensure_field_monitor(monitor_data)
10131108
if isinstance(monitor_data, ModeFieldData):
10141109
if mode_index is None:
10151110
raise DataError("'mode_index' must be supplied to plot a ModeFieldMonitor.")
10161111
monitor_data = monitor_data.sel_mode_index(mode_index=mode_index)
1017-
self.ensure_field_monitor(monitor_data)
10181112

10191113
# get the field data component
10201114
if field_name == "int":
@@ -1023,6 +1117,7 @@ def plot_field(
10231117
for field in ("Ex", "Ey", "Ez"):
10241118
field_data = monitor_data[field]
10251119
xr_data += abs(field_data) ** 2
1120+
val = "abs"
10261121
else:
10271122
monitor_data.ensure_member_exists(field_name)
10281123
xr_data = monitor_data.data_dict.get(field_name).data
@@ -1039,54 +1134,32 @@ def plot_field(
10391134
else:
10401135
raise DataError("Field data has neither time nor frequency data, something went wrong.")
10411136

1042-
# select the cross section data
1043-
axis, pos = self.simulation.parse_xyz_kwargs(x=x, y=y, z=z)
1044-
axis_label = "xyz"[axis]
1045-
interp_kwarg = {axis_label: pos}
1046-
1047-
if len(field_data.coords[axis_label]) > 1:
1137+
if x is None and y is None and z is None:
1138+
"""If a planar monitor, infer x/y/z based on the plane position and normal."""
1139+
monitor = self.simulation.get_monitor_by_name(field_monitor_name)
10481140
try:
1049-
field_data = field_data.interp(**interp_kwarg)
1050-
1141+
axis = monitor.geometry.size.index(0.0)
1142+
position = monitor.geometry.center[axis]
10511143
except Exception as e:
1052-
raise DataError(f"Could not interpolate data at {axis_label}={pos}.") from e
1053-
1054-
# select the field value
1055-
if val not in ("real", "imag", "abs"):
1056-
raise DataError(f"'val' must be one of ``{'real', 'imag', 'abs'}``, given {val}")
1057-
1058-
if field_name != "int":
1059-
if val == "real":
1060-
field_data = field_data.real
1061-
elif val == "imag":
1062-
field_data = field_data.imag
1063-
elif val == "abs":
1064-
field_data = abs(field_data)
1065-
1066-
if val == "abs" or field_name == "int":
1067-
cmap = "magma"
1144+
raise ValueError(
1145+
"If none of 'x', 'y' or 'z' is specified, monitor must have a "
1146+
"zero-sized dimension"
1147+
) from e
10681148
else:
1069-
cmap = "RdBu"
1070-
1071-
# plot the field
1072-
xy_coord_labels = list("xyz")
1073-
xy_coord_labels.pop(axis)
1074-
x_coord_label, y_coord_label = xy_coord_labels # pylint:disable=unbalanced-tuple-unpacking
1075-
field_data.plot(ax=ax, x=x_coord_label, y=y_coord_label, robust=robust, cmap=cmap)
1076-
1077-
# plot the simulation epsilon
1078-
ax = self.simulation.plot_structures_eps(
1079-
freq=freq, cbar=False, x=x, y=y, z=z, alpha=eps_alpha, ax=ax, **patch_kwargs
1149+
axis, position = self.simulation.parse_xyz_kwargs(x=x, y=y, z=z)
1150+
1151+
return self.plot_field_array(
1152+
field_data=field_data,
1153+
axis=axis,
1154+
position=position,
1155+
val=val,
1156+
freq=freq,
1157+
eps_alpha=eps_alpha,
1158+
robust=robust,
1159+
ax=ax,
1160+
**patch_kwargs,
10801161
)
10811162

1082-
# set the limits based on the xarray coordinates min and max
1083-
x_coord_values = field_data.coords[x_coord_label]
1084-
y_coord_values = field_data.coords[y_coord_label]
1085-
ax.set_xlim(min(x_coord_values), max(x_coord_values))
1086-
ax.set_ylim(min(y_coord_values), max(y_coord_values))
1087-
1088-
return ax
1089-
10901163
def normalize(self, normalize_index: int = 0):
10911164
"""Return a copy of the :class:`.SimulationData` object with data normalized by source.
10921165

0 commit comments

Comments
 (0)