9
9
import numpy as np
10
10
import h5py
11
11
12
- from .types import Numpy , Direction , Array , numpy_encoding , Literal , Ax , Coordinate , Symmetry
12
+ from .types import Numpy , Direction , Array , numpy_encoding , Literal , Ax , Coordinate , Symmetry , Axis
13
13
from .base import Tidy3dBaseModel
14
14
from .simulation import Simulation
15
15
from .grid import YeeGrid
16
16
from .mode import ModeSpec
17
+ from .monitor import PlanarMonitor
17
18
from .viz import add_ax_if_none , equal_aspect
18
19
from ..log import log , DataError
19
20
@@ -842,7 +843,106 @@ def sel_mode_index(self, mode_index):
842
843
}
843
844
844
845
845
- class SimulationData (Tidy3dBaseModel ):
846
+ class AbstractSimulationData (Tidy3dBaseModel , ABC ):
847
+ """Abstract class to store a simulation and some data associated with it."""
848
+
849
+ simulation : Simulation
850
+
851
+ @equal_aspect
852
+ @add_ax_if_none
853
+ # pylint:disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements
854
+ def plot_field_array (
855
+ self ,
856
+ field_data : xr .DataArray ,
857
+ axis : Axis ,
858
+ position : float ,
859
+ val : Literal ["real" , "imag" , "abs" ] = "real" ,
860
+ freq : float = None ,
861
+ eps_alpha : float = 0.2 ,
862
+ robust : bool = True ,
863
+ ax : Ax = None ,
864
+ ** patch_kwargs ,
865
+ ) -> Ax :
866
+ """Plot the field data for a monitor with simulation plot overlayed.
867
+
868
+ Parameters
869
+ ----------
870
+ field_data: xr.DataArray
871
+ DataArray with the field data to plot.
872
+ axis: Axis
873
+ Axis normal to the plotting plane.
874
+ position: float
875
+ Position along the axis.
876
+ val : Literal['real', 'imag', 'abs'] = 'real'
877
+ Which part of the field to plot.
878
+ freq: float = None
879
+ Frequency at which the permittivity is evaluated at (if dispersive).
880
+ By default, chooses permittivity as frequency goes to infinity.
881
+ eps_alpha : float = 0.2
882
+ Opacity of the structure permittivity.
883
+ Must be between 0 and 1 (inclusive).
884
+ robust : bool = True
885
+ If specified, uses the 2nd and 98th percentiles of the data to compute the color limits.
886
+ This helps in visualizing the field patterns especially in the presence of a source.
887
+ ax : matplotlib.axes._subplots.Axes = None
888
+ matplotlib axes to plot on, if not specified, one is created.
889
+ **patch_kwargs
890
+ Optional keyword arguments passed to ``add_artist(patch, **patch_kwargs)``.
891
+
892
+ Returns
893
+ -------
894
+ matplotlib.axes._subplots.Axes
895
+ The supplied or created matplotlib axes.
896
+ """
897
+
898
+ # select the cross section data
899
+ axis_label = "xyz" [axis ]
900
+ interp_kwarg = {axis_label : position }
901
+
902
+ if len (field_data .coords [axis_label ]) > 1 :
903
+ try :
904
+ field_data = field_data .interp (** interp_kwarg )
905
+
906
+ except Exception as e :
907
+ raise DataError (f"Could not interpolate data at { axis_label } ={ position } ." ) from e
908
+
909
+ # select the field value
910
+ if val not in ("real" , "imag" , "abs" ):
911
+ raise DataError (f"'val' must be one of ``{ 'real' , 'imag' , 'abs' } ``, given { val } " )
912
+
913
+ if val == "real" :
914
+ field_data = field_data .real
915
+ elif val == "imag" :
916
+ field_data = field_data .imag
917
+ elif val == "abs" :
918
+ field_data = abs (field_data )
919
+
920
+ if val == "abs" :
921
+ cmap = "magma"
922
+ else :
923
+ cmap = "RdBu"
924
+
925
+ # plot the field
926
+ xy_coord_labels = list ("xyz" )
927
+ xy_coord_labels .pop (axis )
928
+ x_coord_label , y_coord_label = xy_coord_labels # pylint:disable=unbalanced-tuple-unpacking
929
+ field_data .plot (ax = ax , x = x_coord_label , y = y_coord_label , robust = robust , cmap = cmap )
930
+
931
+ # plot the simulation epsilon
932
+ ax = self .simulation .plot_structures_eps (
933
+ freq = freq , cbar = False , alpha = eps_alpha , ax = ax , ** {axis_label : position }, ** patch_kwargs
934
+ )
935
+
936
+ # set the limits based on the xarray coordinates min and max
937
+ x_coord_values = field_data .coords [x_coord_label ]
938
+ y_coord_values = field_data .coords [y_coord_label ]
939
+ ax .set_xlim (min (x_coord_values ), max (x_coord_values ))
940
+ ax .set_ylim (min (y_coord_values ), max (y_coord_values ))
941
+
942
+ return ax
943
+
944
+
945
+ class SimulationData (AbstractSimulationData ):
846
946
"""Holds :class:`Monitor` data associated with :class:`Simulation`.
847
947
848
948
Parameters
@@ -859,7 +959,6 @@ class SimulationData(Tidy3dBaseModel):
859
959
A boolean flag denoting whether the data has been normalized by the spectrum of a source.
860
960
"""
861
961
862
- simulation : Simulation
863
962
monitor_data : Dict [str , Tidy3dData ]
864
963
log_string : str = None
865
964
diverged : bool = False
@@ -898,9 +997,8 @@ def __getitem__(self, monitor_name: str) -> Union[Tidy3dDataArray, xr.Dataset]:
898
997
a collection data instance is returned.
899
998
Otherwise, if it is a MonitorData instance, the xarray representation is returned.
900
999
"""
1000
+ self .ensure_monitor_exists (monitor_name )
901
1001
monitor_data = self .monitor_data .get (monitor_name )
902
- if not monitor_data :
903
- raise DataError (f"monitor '{ monitor_name } ' not found" )
904
1002
if isinstance (monitor_data , MonitorData ):
905
1003
return monitor_data .data
906
1004
return monitor_data
@@ -932,7 +1030,6 @@ def at_centers(self, field_monitor_name: str) -> xr.Dataset:
932
1030
"""
933
1031
934
1032
# get the data
935
- self .ensure_monitor_exists (field_monitor_name )
936
1033
field_monitor_data = self .monitor_data .get (field_monitor_name )
937
1034
self .ensure_field_monitor (field_monitor_data )
938
1035
@@ -945,8 +1042,6 @@ def at_centers(self, field_monitor_name: str) -> xr.Dataset:
945
1042
field_dataset = field_monitor_data .colocate (x = centers .x , y = centers .y , z = centers .z )
946
1043
return field_dataset
947
1044
948
- @equal_aspect
949
- @add_ax_if_none
950
1045
# pylint:disable=too-many-arguments, too-many-locals, too-many-branches, too-many-statements
951
1046
def plot_field (
952
1047
self ,
@@ -1008,13 +1103,12 @@ def plot_field(
1008
1103
"""
1009
1104
1010
1105
# get the monitor data
1011
- self .ensure_monitor_exists (field_monitor_name )
1012
1106
monitor_data = self .monitor_data .get (field_monitor_name )
1107
+ self .ensure_field_monitor (monitor_data )
1013
1108
if isinstance (monitor_data , ModeFieldData ):
1014
1109
if mode_index is None :
1015
1110
raise DataError ("'mode_index' must be supplied to plot a ModeFieldMonitor." )
1016
1111
monitor_data = monitor_data .sel_mode_index (mode_index = mode_index )
1017
- self .ensure_field_monitor (monitor_data )
1018
1112
1019
1113
# get the field data component
1020
1114
if field_name == "int" :
@@ -1023,6 +1117,7 @@ def plot_field(
1023
1117
for field in ("Ex" , "Ey" , "Ez" ):
1024
1118
field_data = monitor_data [field ]
1025
1119
xr_data += abs (field_data ) ** 2
1120
+ val = "abs"
1026
1121
else :
1027
1122
monitor_data .ensure_member_exists (field_name )
1028
1123
xr_data = monitor_data .data_dict .get (field_name ).data
@@ -1039,54 +1134,32 @@ def plot_field(
1039
1134
else :
1040
1135
raise DataError ("Field data has neither time nor frequency data, something went wrong." )
1041
1136
1042
- # select the cross section data
1043
- axis , pos = self .simulation .parse_xyz_kwargs (x = x , y = y , z = z )
1044
- axis_label = "xyz" [axis ]
1045
- interp_kwarg = {axis_label : pos }
1046
-
1047
- if len (field_data .coords [axis_label ]) > 1 :
1137
+ if x is None and y is None and z is None :
1138
+ """If a planar monitor, infer x/y/z based on the plane position and normal."""
1139
+ monitor = self .simulation .get_monitor_by_name (field_monitor_name )
1048
1140
try :
1049
- field_data = field_data . interp ( ** interp_kwarg )
1050
-
1141
+ axis = monitor . geometry . size . index ( 0.0 )
1142
+ position = monitor . geometry . center [ axis ]
1051
1143
except Exception as e :
1052
- raise DataError (f"Could not interpolate data at { axis_label } ={ pos } ." ) from e
1053
-
1054
- # select the field value
1055
- if val not in ("real" , "imag" , "abs" ):
1056
- raise DataError (f"'val' must be one of ``{ 'real' , 'imag' , 'abs' } ``, given { val } " )
1057
-
1058
- if field_name != "int" :
1059
- if val == "real" :
1060
- field_data = field_data .real
1061
- elif val == "imag" :
1062
- field_data = field_data .imag
1063
- elif val == "abs" :
1064
- field_data = abs (field_data )
1065
-
1066
- if val == "abs" or field_name == "int" :
1067
- cmap = "magma"
1144
+ raise ValueError (
1145
+ "If none of 'x', 'y' or 'z' is specified, monitor must have a "
1146
+ "zero-sized dimension"
1147
+ ) from e
1068
1148
else :
1069
- cmap = "RdBu"
1070
-
1071
- # plot the field
1072
- xy_coord_labels = list ("xyz" )
1073
- xy_coord_labels .pop (axis )
1074
- x_coord_label , y_coord_label = xy_coord_labels # pylint:disable=unbalanced-tuple-unpacking
1075
- field_data .plot (ax = ax , x = x_coord_label , y = y_coord_label , robust = robust , cmap = cmap )
1076
-
1077
- # plot the simulation epsilon
1078
- ax = self .simulation .plot_structures_eps (
1079
- freq = freq , cbar = False , x = x , y = y , z = z , alpha = eps_alpha , ax = ax , ** patch_kwargs
1149
+ axis , position = self .simulation .parse_xyz_kwargs (x = x , y = y , z = z )
1150
+
1151
+ return self .plot_field_array (
1152
+ field_data = field_data ,
1153
+ axis = axis ,
1154
+ position = position ,
1155
+ val = val ,
1156
+ freq = freq ,
1157
+ eps_alpha = eps_alpha ,
1158
+ robust = robust ,
1159
+ ax = ax ,
1160
+ ** patch_kwargs ,
1080
1161
)
1081
1162
1082
- # set the limits based on the xarray coordinates min and max
1083
- x_coord_values = field_data .coords [x_coord_label ]
1084
- y_coord_values = field_data .coords [y_coord_label ]
1085
- ax .set_xlim (min (x_coord_values ), max (x_coord_values ))
1086
- ax .set_ylim (min (y_coord_values ), max (y_coord_values ))
1087
-
1088
- return ax
1089
-
1090
1163
def normalize (self , normalize_index : int = 0 ):
1091
1164
"""Return a copy of the :class:`.SimulationData` object with data normalized by source.
1092
1165
0 commit comments