14
14
from .types import Bound , Size , Coordinate , Axis , Coordinate2D , tidynumpy , Array
15
15
from .types import Vertices , Ax , Shapely
16
16
from .viz import add_ax_if_none , equal_aspect
17
- from .viz import PLOT_BUFFER , ARROW_LENGTH_FACTOR , ARROW_WIDTH_FACTOR
17
+ from .viz import PLOT_BUFFER , ARROW_LENGTH_FACTOR , ARROW_WIDTH_FACTOR , MAX_ARROW_WIDTH_FACTOR
18
18
from ..log import Tidy3dKeyError , SetupError , ValidationError
19
19
from ..constants import MICROMETER , LARGE_NUMBER
20
20
@@ -405,15 +405,18 @@ def rotate_points(points: Array[float], axis: Coordinate, angle: float) -> Array
405
405
rot_mat = np .zeros ((3 , 3 ))
406
406
cos = np .cos (angle )
407
407
sin = np .sin (angle )
408
- rot_mat [0 , 0 ] = cos + ux ** 2 * (1 - cos )
408
+ rot_mat [0 , 0 ] = cos + ux ** 2 * (1 - cos )
409
409
rot_mat [0 , 1 ] = ux * uy * (1 - cos ) - uz * sin
410
410
rot_mat [0 , 2 ] = ux * uz * (1 - cos ) + uy * sin
411
411
rot_mat [1 , 0 ] = uy * ux * (1 - cos ) + uz * sin
412
- rot_mat [1 , 1 ] = cos + uy ** 2 * (1 - cos )
412
+ rot_mat [1 , 1 ] = cos + uy ** 2 * (1 - cos )
413
413
rot_mat [1 , 2 ] = uy * uz * (1 - cos ) - ux * sin
414
414
rot_mat [2 , 0 ] = uz * ux * (1 - cos ) - uy * sin
415
415
rot_mat [2 , 1 ] = uz * uy * (1 - cos ) + ux * sin
416
- rot_mat [2 , 2 ] = cos + uz ** 2 * (1 - cos )
416
+ rot_mat [2 , 2 ] = cos + uz ** 2 * (1 - cos )
417
+
418
+ if len (points .shape ) == 1 :
419
+ return rot_mat @ points
417
420
418
421
return np .einsum ("ij,jp...->ip..." , rot_mat , points )
419
422
@@ -423,14 +426,11 @@ def reflect_points(
423
426
polar_axis : Axis ,
424
427
angle_theta : float ,
425
428
angle_phi : float ,
426
- plane_point : Coordinate ,
427
429
) -> Array [float ]:
428
- """Reflect a set of points in 3D at a plane defined by a point on the plane and an axis
429
- normal to the plane.
430
- all ``points``, array of shape (3, ...) at a plane passing through a given
431
- ``plane_point``, normal to an axis defined in polar coordinates (theta, phi) w.r.t. the
430
+ """Reflect a set of points in 3D at a plane passing through the coordinate origin defined
431
+ and normal to a given axis defined in polar coordinates (theta, phi) w.r.t. the
432
432
``polar_axis`` which can be 0, 1, or 2.
433
-
433
+
434
434
Parameters
435
435
----------
436
436
points : Array[float]
@@ -441,20 +441,15 @@ def reflect_points(
441
441
Polar angle w.r.t. the polar axis.
442
442
angle_phi : float
443
443
Azimuth angle around the polar axis.
444
- plane_point : Coordinate
445
- A point lying on the reflection plane.
446
444
"""
447
445
448
- # Offset coordinates such that ``plane_point`` is at the origin
449
- points_new = points - plane_point
450
-
451
446
# Rotate such that the plane normal is along the polar_axis
452
447
axis_theta , axis_phi = [0 , 0 , 0 ], [0 , 0 , 0 ]
453
448
axis_phi [polar_axis ] = 1
454
449
plane_axes = [0 , 1 , 2 ]
455
450
plane_axes .pop (polar_axis )
456
451
axis_theta [plane_axes [1 ]] = 1
457
- points_new = self .rotate_points (points_new , axis_phi , - angle_phi )
452
+ points_new = self .rotate_points (points , axis_phi , - angle_phi )
458
453
points_new = self .rotate_points (points_new , axis_theta , - angle_theta )
459
454
460
455
# Flip the ``polar_axis`` coordinate of the points, which is now normal to the plane
@@ -617,7 +612,7 @@ def _intersect_dist(self, position, z0) -> float:
617
612
dz = np .abs (z0 - position )
618
613
if dz > self .radius :
619
614
return None
620
- return 2 * np .sqrt (self .radius ** 2 - dz ** 2 )
615
+ return 2 * np .sqrt (self .radius ** 2 - dz ** 2 )
621
616
622
617
623
618
""" importable geometries """
@@ -786,7 +781,7 @@ def _plot_arrow( # pylint:disable=too-many-arguments, too-many-locals
786
781
787
782
plot_axis , _ = self .parse_xyz_kwargs (x = x , y = y , z = z )
788
783
arrow_axis = [component == 0 for component in direction ]
789
- arrow_length = self ._arrow_length (ax , length_factor )
784
+ arrow_length , arrow_width = self ._arrow_dims (ax , length_factor , width_factor )
790
785
791
786
# only add arrow if the plotting plane is perpendicular to the source
792
787
if arrow_axis .count (0.0 ) > 1 or arrow_axis .index (0.0 ) != plot_axis :
@@ -800,7 +795,7 @@ def add_arrow(sign=1.0):
800
795
y = y0 ,
801
796
dx = sign * arrow_length * dx ,
802
797
dy = sign * arrow_length * dy ,
803
- width = width_factor * arrow_length ,
798
+ width = arrow_width ,
804
799
color = color ,
805
800
alpha = alpha ,
806
801
zorder = np .inf ,
@@ -812,17 +807,28 @@ def add_arrow(sign=1.0):
812
807
813
808
return ax
814
809
815
- def _arrow_length (self , ax : Ax , length_factor : float = ARROW_LENGTH_FACTOR ) -> float :
816
- """Length of arrow is the minimum size of the axes times the length factor."""
810
+ def _arrow_dims (
811
+ self ,
812
+ ax : Ax ,
813
+ length_factor : float = ARROW_LENGTH_FACTOR ,
814
+ width_factor : float = ARROW_WIDTH_FACTOR ,
815
+ ) -> Tuple [float , float ]:
816
+ """Length and width of arrow based on axes size and length and width factors."""
817
817
818
818
# get the sizes of the matplotlib axes
819
819
xmin , xmax = ax .get_xlim ()
820
820
ymin , ymax = ax .get_ylim ()
821
821
ax_width = xmax - xmin
822
822
ax_height = ymax - ymin
823
823
824
- # apply length factor to the minimum size to get arrow width
825
- return length_factor * min (ax_width , ax_height )
824
+ # apply length factor to the minimum size to get arrow length
825
+ arrow_length = length_factor * min (ax_width , ax_height )
826
+
827
+ # constrain arrow width by the maximum size and the max arrow width factor
828
+ arrow_width = width_factor * arrow_length
829
+ arrow_width = min (arrow_width , MAX_ARROW_WIDTH_FACTOR * max (ax_width , ax_height ))
830
+
831
+ return arrow_length , arrow_width
826
832
827
833
828
834
class Sphere (Circular ):
@@ -854,7 +860,7 @@ def inside(self, x, y, z) -> bool:
854
860
dist_x = np .abs (x - x0 )
855
861
dist_y = np .abs (y - y0 )
856
862
dist_z = np .abs (z - z0 )
857
- return (dist_x ** 2 + dist_y ** 2 + dist_z ** 2 ) <= (self .radius ** 2 )
863
+ return (dist_x ** 2 + dist_y ** 2 + dist_z ** 2 ) <= (self .radius ** 2 )
858
864
859
865
def intersections (self , x : float = None , y : float = None , z : float = None ):
860
866
"""Returns shapely geometry at plane specified by one non None value of x,y,z.
@@ -976,7 +982,7 @@ def inside(self, x, y, z) -> bool:
976
982
dist_x = np .abs (x - x0 )
977
983
dist_y = np .abs (y - y0 )
978
984
dist_z = np .abs (z - z0 )
979
- inside_radius = (dist_x ** 2 + dist_y ** 2 ) <= (self .radius ** 2 )
985
+ inside_radius = (dist_x ** 2 + dist_y ** 2 ) <= (self .radius ** 2 )
980
986
inside_height = dist_z < (self .length / 2 )
981
987
return inside_radius * inside_height
982
988
0 commit comments