Skip to content

Commit 81d0d48

Browse files
committed
Further reorganization of angled sources; fixes to plotting
1 parent 694095a commit 81d0d48

File tree

9 files changed

+143
-125
lines changed

9 files changed

+143
-125
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ nlopt
77
matplotlib
88
shapely==1.7.1
99
descartes
10-
pydantic
10+
pydantic>=1.9.0
1111
PyYAML
1212
boto3
1313
requests

tidy3d/components/geometry.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .types import Bound, Size, Coordinate, Axis, Coordinate2D, tidynumpy, Array
1515
from .types import Vertices, Ax, Shapely
1616
from .viz import add_ax_if_none, equal_aspect
17-
from .viz import PLOT_BUFFER, ARROW_LENGTH_FACTOR, ARROW_WIDTH_FACTOR
17+
from .viz import PLOT_BUFFER, ARROW_LENGTH_FACTOR, ARROW_WIDTH_FACTOR, MAX_ARROW_WIDTH_FACTOR
1818
from ..log import Tidy3dKeyError, SetupError, ValidationError
1919
from ..constants import MICROMETER, LARGE_NUMBER
2020

@@ -405,15 +405,18 @@ def rotate_points(points: Array[float], axis: Coordinate, angle: float) -> Array
405405
rot_mat = np.zeros((3, 3))
406406
cos = np.cos(angle)
407407
sin = np.sin(angle)
408-
rot_mat[0, 0] = cos + ux ** 2 * (1 - cos)
408+
rot_mat[0, 0] = cos + ux**2 * (1 - cos)
409409
rot_mat[0, 1] = ux * uy * (1 - cos) - uz * sin
410410
rot_mat[0, 2] = ux * uz * (1 - cos) + uy * sin
411411
rot_mat[1, 0] = uy * ux * (1 - cos) + uz * sin
412-
rot_mat[1, 1] = cos + uy ** 2 * (1 - cos)
412+
rot_mat[1, 1] = cos + uy**2 * (1 - cos)
413413
rot_mat[1, 2] = uy * uz * (1 - cos) - ux * sin
414414
rot_mat[2, 0] = uz * ux * (1 - cos) - uy * sin
415415
rot_mat[2, 1] = uz * uy * (1 - cos) + ux * sin
416-
rot_mat[2, 2] = cos + uz ** 2 * (1 - cos)
416+
rot_mat[2, 2] = cos + uz**2 * (1 - cos)
417+
418+
if len(points.shape) == 1:
419+
return rot_mat @ points
417420

418421
return np.einsum("ij,jp...->ip...", rot_mat, points)
419422

@@ -423,14 +426,11 @@ def reflect_points(
423426
polar_axis: Axis,
424427
angle_theta: float,
425428
angle_phi: float,
426-
plane_point: Coordinate,
427429
) -> Array[float]:
428-
"""Reflect a set of points in 3D at a plane defined by a point on the plane and an axis
429-
normal to the plane.
430-
all ``points``, array of shape (3, ...) at a plane passing through a given
431-
``plane_point``, normal to an axis defined in polar coordinates (theta, phi) w.r.t. the
430+
"""Reflect a set of points in 3D at a plane passing through the coordinate origin defined
431+
and normal to a given axis defined in polar coordinates (theta, phi) w.r.t. the
432432
``polar_axis`` which can be 0, 1, or 2.
433-
433+
434434
Parameters
435435
----------
436436
points : Array[float]
@@ -441,20 +441,15 @@ def reflect_points(
441441
Polar angle w.r.t. the polar axis.
442442
angle_phi : float
443443
Azimuth angle around the polar axis.
444-
plane_point : Coordinate
445-
A point lying on the reflection plane.
446444
"""
447445

448-
# Offset coordinates such that ``plane_point`` is at the origin
449-
points_new = points - plane_point
450-
451446
# Rotate such that the plane normal is along the polar_axis
452447
axis_theta, axis_phi = [0, 0, 0], [0, 0, 0]
453448
axis_phi[polar_axis] = 1
454449
plane_axes = [0, 1, 2]
455450
plane_axes.pop(polar_axis)
456451
axis_theta[plane_axes[1]] = 1
457-
points_new = self.rotate_points(points_new, axis_phi, -angle_phi)
452+
points_new = self.rotate_points(points, axis_phi, -angle_phi)
458453
points_new = self.rotate_points(points_new, axis_theta, -angle_theta)
459454

460455
# Flip the ``polar_axis`` coordinate of the points, which is now normal to the plane
@@ -617,7 +612,7 @@ def _intersect_dist(self, position, z0) -> float:
617612
dz = np.abs(z0 - position)
618613
if dz > self.radius:
619614
return None
620-
return 2 * np.sqrt(self.radius ** 2 - dz ** 2)
615+
return 2 * np.sqrt(self.radius**2 - dz**2)
621616

622617

623618
""" importable geometries """
@@ -786,7 +781,7 @@ def _plot_arrow( # pylint:disable=too-many-arguments, too-many-locals
786781

787782
plot_axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z)
788783
arrow_axis = [component == 0 for component in direction]
789-
arrow_length = self._arrow_length(ax, length_factor)
784+
arrow_length, arrow_width = self._arrow_dims(ax, length_factor, width_factor)
790785

791786
# only add arrow if the plotting plane is perpendicular to the source
792787
if arrow_axis.count(0.0) > 1 or arrow_axis.index(0.0) != plot_axis:
@@ -800,7 +795,7 @@ def add_arrow(sign=1.0):
800795
y=y0,
801796
dx=sign * arrow_length * dx,
802797
dy=sign * arrow_length * dy,
803-
width=width_factor * arrow_length,
798+
width=arrow_width,
804799
color=color,
805800
alpha=alpha,
806801
zorder=np.inf,
@@ -812,17 +807,28 @@ def add_arrow(sign=1.0):
812807

813808
return ax
814809

815-
def _arrow_length(self, ax: Ax, length_factor: float = ARROW_LENGTH_FACTOR) -> float:
816-
"""Length of arrow is the minimum size of the axes times the length factor."""
810+
def _arrow_dims(
811+
self,
812+
ax: Ax,
813+
length_factor: float = ARROW_LENGTH_FACTOR,
814+
width_factor: float = ARROW_WIDTH_FACTOR,
815+
) -> Tuple[float, float]:
816+
"""Length and width of arrow based on axes size and length and width factors."""
817817

818818
# get the sizes of the matplotlib axes
819819
xmin, xmax = ax.get_xlim()
820820
ymin, ymax = ax.get_ylim()
821821
ax_width = xmax - xmin
822822
ax_height = ymax - ymin
823823

824-
# apply length factor to the minimum size to get arrow width
825-
return length_factor * min(ax_width, ax_height)
824+
# apply length factor to the minimum size to get arrow length
825+
arrow_length = length_factor * min(ax_width, ax_height)
826+
827+
# constrain arrow width by the maximum size and the max arrow width factor
828+
arrow_width = width_factor * arrow_length
829+
arrow_width = min(arrow_width, MAX_ARROW_WIDTH_FACTOR * max(ax_width, ax_height))
830+
831+
return arrow_length, arrow_width
826832

827833

828834
class Sphere(Circular):
@@ -854,7 +860,7 @@ def inside(self, x, y, z) -> bool:
854860
dist_x = np.abs(x - x0)
855861
dist_y = np.abs(y - y0)
856862
dist_z = np.abs(z - z0)
857-
return (dist_x ** 2 + dist_y ** 2 + dist_z ** 2) <= (self.radius ** 2)
863+
return (dist_x**2 + dist_y**2 + dist_z**2) <= (self.radius**2)
858864

859865
def intersections(self, x: float = None, y: float = None, z: float = None):
860866
"""Returns shapely geometry at plane specified by one non None value of x,y,z.
@@ -976,7 +982,7 @@ def inside(self, x, y, z) -> bool:
976982
dist_x = np.abs(x - x0)
977983
dist_y = np.abs(y - y0)
978984
dist_z = np.abs(z - z0)
979-
inside_radius = (dist_x ** 2 + dist_y ** 2) <= (self.radius ** 2)
985+
inside_radius = (dist_x**2 + dist_y**2) <= (self.radius**2)
980986
inside_height = dist_z < (self.length / 2)
981987
return inside_radius * inside_height
982988

tidy3d/components/mode.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from typing import Tuple
44

55
import pydantic as pd
6+
import numpy as np
67

7-
from ..constants import MICROMETER, RADIAN
8+
from ..constants import MICROMETER, RADIAN, GLANCING_CUTOFF
89
from .base import Tidy3dBaseModel
910
from .types import Axis2D, Literal
1011
from ..log import SetupError
@@ -83,3 +84,13 @@ def bend_axis_given(cls, val, values):
8384
if val is None and values.get("bend_radius") is not None:
8485
raise SetupError("bend_axis must also be defined if bend_radius is defined.")
8586
return val
87+
88+
@pd.validator("angle_theta", allow_reuse=True, always=True)
89+
def glancing_incidence(cls, val):
90+
"""Warn if close to glancing incidence."""
91+
if np.abs(np.pi / 2 - val) < GLANCING_CUTOFF:
92+
raise SetupError(
93+
"Mode propagation axis too close to glancing angle for accurate injection. "
94+
"For best results, switch the injection axis."
95+
)
96+
return val

tidy3d/components/monitor.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,34 @@ class AbstractModeMonitor(PlanarMonitor, FreqMonitor):
183183
description="Parameters to feed to mode solver which determine modes measured by monitor.",
184184
)
185185

186+
def plot(
187+
self, x: float = None, y: float = None, z: float = None, ax: Ax = None, **kwargs
188+
) -> Ax:
189+
190+
# call the monitor.plot() function first
191+
ax = super().plot(x=x, y=y, z=z, ax=ax, **kwargs)
192+
193+
# and then add an arrow using the direction comuputed from `_dir_arrow`.
194+
ax = self._plot_arrow(
195+
x=x,
196+
y=y,
197+
z=z,
198+
ax=ax,
199+
direction=self._dir_arrow,
200+
color=ARROW_COLOR_MONITOR,
201+
alpha=ARROW_ALPHA,
202+
both_dirs=True,
203+
)
204+
return ax
205+
206+
@property
207+
def _dir_arrow(self) -> Tuple[float, float, float]:
208+
"""Source direction normal vector in cartesian coordinates."""
209+
dx = np.cos(self.mode_spec.angle_phi) * np.sin(self.mode_spec.angle_theta)
210+
dy = np.sin(self.mode_spec.angle_phi) * np.sin(self.mode_spec.angle_theta)
211+
dz = np.cos(self.mode_spec.angle_theta)
212+
return self.unpop_axis(dz, (dx, dy), axis=self.size.index(0.0))
213+
186214

187215
class FieldMonitor(AbstractFieldMonitor, FreqMonitor):
188216
""":class:`Monitor` that records electromagnetic fields in the frequency domain.
@@ -357,34 +385,6 @@ def storage_size(self, num_cells: int, tmesh: int) -> int:
357385
# stores 3 complex numbers per frequency, per mode.
358386
return 3 * BYTES_COMPLEX * len(self.freqs) * self.mode_spec.num_modes
359387

360-
def plot(
361-
self, x: float = None, y: float = None, z: float = None, ax: Ax = None, **kwargs
362-
) -> Ax:
363-
364-
# call the monitor.plot() function first
365-
ax = super().plot(x=x, y=y, z=z, ax=ax, **kwargs)
366-
367-
# and then add an arrow using the direction comuputed from `_dir_arrow`.
368-
ax = self._plot_arrow(
369-
x=x,
370-
y=y,
371-
z=z,
372-
ax=ax,
373-
direction=self._dir_arrow,
374-
color=ARROW_COLOR_MONITOR,
375-
alpha=ARROW_ALPHA,
376-
both_dirs=True,
377-
)
378-
return ax
379-
380-
@property
381-
def _dir_arrow(self) -> Tuple[float, float, float]:
382-
"""Source direction normal vector in cartesian coordinates."""
383-
normal = [0.0, 0.0, 0.0]
384-
normal_axis = self.size.index(0.0)
385-
normal[normal_axis] = 1.0
386-
return tuple(normal)
387-
388388

389389
class ModeSolverMonitor(AbstractModeMonitor):
390390
""":class:`Monitor` that stores the mode data (field profiles and effective index)

0 commit comments

Comments
 (0)