-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
213 lines (173 loc) · 5.5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import matplotlib.pyplot as plt
import numpy as np
def vector_length(v):
    """Calculate the Euclidean length of a vector.

    Args:
        v (array_like): Input vector; anything ``np.linalg.norm`` accepts
            (NumPy arrays, lists, tuples).

    Returns:
        float: Length of the vector, as a built-in Python ``float``.

    Example:
        >>> import numpy as np
        >>> v = np.array([3, 4])
        >>> vector_length(v)
        5.0
    """
    # np.linalg.norm returns a NumPy scalar (np.float64) whose repr is
    # "np.float64(5.0)" under NumPy 2; converting keeps the documented
    # ``float`` return type and the doctest above valid.
    return float(np.linalg.norm(v))
def unit_vector(vector):
    """Calculate the unit vector pointing in the same direction as ``vector``.

    Args:
        vector (array_like): Input vector. Lists and tuples are accepted and
            converted with ``np.asarray``.

    Returns:
        np.ndarray: Array with the same shape as ``vector`` and Euclidean
        length 1.

    Raises:
        ValueError: If ``vector`` has zero length. A zero vector has no
            direction; dividing by its length would silently yield NaNs.

    Example:
        >>> import numpy as np
        >>> v = np.array([3, 4])
        >>> unit_vector(v)
        array([0.6, 0.8])
    """
    arr = np.asarray(vector)
    length = np.linalg.norm(arr)
    if length == 0:
        raise ValueError("cannot normalize a zero-length vector")
    return arr / length
def angle_between(v1, v2):
    """Calculate the angle between two vectors.

    Uses the formula ``2 * atan2(|u - w|, |u + w|)`` on the unit vectors
    ``u`` and ``w``. Unlike ``arccos(dot(u, w))``, it stays accurate for
    nearly parallel and nearly anti-parallel vectors. ``arccos`` is
    ill-conditioned near +/-1, so tiny angles collapse to 0.

    Args:
        v1 (array_like): First vector.
        v2 (array_like): Second vector (same length as ``v1``).

    Returns:
        float: Angle between the vectors in radians, in ``[0, pi]``.

    Raises:
        ValueError: If either vector has zero length (the angle is undefined).

    Example:
        >>> import numpy as np
        >>> v1 = np.array([1, 0])
        >>> v2 = np.array([0, 1])
        >>> round(angle_between(v1, v2), 2)
        1.57
    """
    a = np.asarray(v1, dtype=float)
    b = np.asarray(v2, dtype=float)
    len_a = np.linalg.norm(a)
    len_b = np.linalg.norm(b)
    if len_a == 0 or len_b == 0:
        raise ValueError("angle is undefined for a zero-length vector")
    u = a / len_a
    w = b / len_b
    # |u - w| and |u + w| are the legs of a right triangle whose half-angle
    # is the angle sought; atan2 of the pair is well-conditioned everywhere.
    return float(2.0 * np.arctan2(np.linalg.norm(u - w), np.linalg.norm(u + w)))
def create_3d_plot(figsize=(12, 10)):
    """Build a new matplotlib figure containing a single 3D axes.

    Args:
        figsize (tuple): ``(width, height)`` forwarded to ``plt.figure``.

    Returns:
        tuple: ``(fig, ax)`` - the new Figure and its 3D Axes.
    """
    figure = plt.figure(figsize=figsize)
    axes_3d = figure.add_subplot(1, 1, 1, projection="3d")
    return figure, axes_3d
def plot_vector(ax, start, vector, color="r", label=None):
    """Draw ``vector`` as an arrow anchored at ``start`` on a 3D axes.

    The arrow is drawn with ``ax.quiver`` using a small arrowhead
    (``arrow_length_ratio=0.1``) and a shaft line width of 2.

    Args:
        ax (Axes): Matplotlib 3D axes object.
        start (np.ndarray): Starting point of the arrow; components 0-2 are used.
        vector (np.ndarray): Arrow components; components 0-2 are used.
        color (str): Color of the arrow.
        label (str): Legend label for the arrow.
    """
    x0, y0, z0 = start[0], start[1], start[2]
    dx, dy, dz = vector[0], vector[1], vector[2]
    ax.quiver(
        x0, y0, z0,
        dx, dy, dz,
        color=color,
        label=label,
        arrow_length_ratio=0.1,  # keeps the arrowhead small
        pivot="tail",  # the arrow's tail sits at (x0, y0, z0)
        linewidth=2,  # shaft thickness
    )
def plot_line(
    ax,
    start,
    direction,
    t_range=(-2, 2),
    num_points=100,
    color="b",
    linestyle="-",
    label=None,
):
    """Plot the parametric line ``start + t * direction`` in 3D space.

    Args:
        ax (Axes): Matplotlib 3D axes object.
        start (array_like): A point on the line (at least 3 components).
            Lists and tuples are accepted as well as NumPy arrays.
        direction (array_like): Direction vector of the line (same length
            as ``start``).
        t_range (tuple): ``(t_min, t_max)`` range of the parameter t.
        num_points (int): Number of evenly spaced samples of t to plot.
        color (str): Color of the line.
        linestyle (str): Style of the line.
        label (str): Label for the line.
    """
    # np.asarray lets callers pass plain sequences; the original indexing
    # with np.newaxis only worked on ndarrays.
    origin = np.asarray(start)
    heading = np.asarray(direction)
    t = np.linspace(t_range[0], t_range[1], num_points)
    # Broadcast to shape (len(start), num_points): row i holds coordinate i
    # of every sampled point.
    line_points = origin[:, np.newaxis] + heading[:, np.newaxis] * t
    ax.plot(
        line_points[0],
        line_points[1],
        line_points[2],
        color=color,
        linestyle=linestyle,
        label=label,
    )
def plot_point(ax, point, color="r", size=100, label=None):
    """Mark a single point on a 3D axes using ``ax.scatter``.

    Args:
        ax (Axes): Matplotlib 3D axes object.
        point (np.ndarray): Coordinates of the point; every component is
            passed positionally to ``ax.scatter``.
        color (str): Marker color.
        size (float): Marker size, passed as ``s``.
        label (str): Legend label for the point.
    """
    coords = tuple(point)
    ax.scatter(*coords, color=color, s=size, label=label)
def plot_shortest_distance(ax, point1, point2, color="r", linestyle="--", label=None):
    """Draw the straight segment joining two points in 3D space.

    Args:
        ax (Axes): Matplotlib 3D axes object.
        point1 (np.ndarray): First endpoint; components 0-2 are used.
        point2 (np.ndarray): Second endpoint; components 0-2 are used.
        color (str): Color of the segment.
        linestyle (str): Style of the segment.
        label (str): Legend label for the segment.
    """
    # One [start, end] pair per axis: x values, then y values, then z values.
    xs, ys, zs = ([point1[axis], point2[axis]] for axis in range(3))
    ax.plot(xs, ys, zs, color=color, linestyle=linestyle, label=label)
def add_text_3d(ax, position, text, fontsize=10, ha="center", va="center", bbox=None):
    """Place a text annotation at ``position`` on a 3D axes.

    Args:
        ax (Axes): Matplotlib 3D axes object.
        position (np.ndarray): Anchor coordinates, passed positionally to
            ``ax.text`` ahead of the text itself.
        text (str): Text to display.
        fontsize (int): Font size of the text.
        ha (str): Horizontal alignment.
        va (str): Vertical alignment.
        bbox (dict): Bounding box properties, or None for no box.
    """
    text_style = {"fontsize": fontsize, "ha": ha, "va": va, "bbox": bbox}
    ax.text(*position, text, **text_style)
def set_plot_limits(ax, points, scale=1.2):
    """Give all three axes the same symmetric range that encloses ``points``.

    The half-width of the range is the largest absolute coordinate found in
    ``points``, multiplied by ``scale``. Every axis is set to
    ``[-half_width, half_width]``.

    Args:
        ax (Axes): Matplotlib 3D axes object.
        points (np.ndarray): Array of points (any shape ``np.abs`` accepts).
        scale (float): Padding factor applied to the largest coordinate.
    """
    half_width = np.max(np.abs(points)) * scale
    for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_limits([-half_width, half_width])
def finalize_plot(ax, title, xlabel="X", ylabel="Y", zlabel="Z"):
    """Finalize the 3D plot by setting axis labels, the title, and a legend.

    A legend is added only when at least one artist on ``ax`` carries a
    label.

    Args:
        ax (Axes): Matplotlib 3D axes object.
        title (str): Title of the plot.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        zlabel (str): Label for the z-axis.
    """
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlabel(zlabel)
    ax.set_title(title)
    # ax.legend() with no labelled artists makes matplotlib warn
    # ("No artists with labels found to put in legend"); it uses the same
    # handle lookup as get_legend_handles_labels, so check that first.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()