Skip to content

Commit

Permalink
refactor geometry function signatures for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
subagonsouth committed Sep 25, 2024
1 parent 74e0054 commit 632a3b7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
22 changes: 12 additions & 10 deletions imap_processing/spice/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,24 +166,24 @@ def get_spacecraft_spin_phase(
@typing.no_type_check
@ensure_spice
def frame_transform(
from_frame: SpiceFrame,
to_frame: SpiceFrame,
et: Union[float, npt.NDArray],
position: npt.NDArray,
from_frame: SpiceFrame,
to_frame: SpiceFrame,
) -> npt.NDArray:
"""
Transform an <x, y, z> vector between reference frames (rotation only).
Parameters
----------
from_frame : SpiceFrame
Reference frame of input vector(s).
to_frame : SpiceFrame
Reference frame of output vector(s).
et : float or npt.NDArray
Ephemeris time(s) corresponding to position(s).
position : npt.NDArray
<x, y, z> vector or array of vectors in reference frame `from_frame`.
from_frame : SpiceFrame
Reference frame of input vector(s).
to_frame : SpiceFrame
Reference frame of output vector(s).
Returns
-------
Expand Down Expand Up @@ -213,7 +213,7 @@ def frame_transform(

# rotate will have shape = (3, 3) or (n, 3, 3)
# position will have shape = (3,) or (n, 3)
rotate = get_rotation_matrix(from_frame, to_frame, et)
rotate = get_rotation_matrix(et, from_frame, to_frame)
# adding a dimension to position results in the following input and output
# shapes from matrix multiplication
# Single et/position: (3, 3),(3, 1) -> (3, 1)
Expand All @@ -224,7 +224,9 @@ def frame_transform(


def get_rotation_matrix(
from_frame: SpiceFrame, to_frame: SpiceFrame, et: Union[float, npt.NDArray]
et: Union[float, npt.NDArray],
from_frame: SpiceFrame,
to_frame: SpiceFrame,
) -> npt.NDArray:
"""
Get the rotation matrix/matrices that can be used to transform between frames.
Expand All @@ -236,12 +238,12 @@ def get_rotation_matrix(
Parameters
----------
et : float or npt.NDArray
Ephemeris time(s) for which to get the rotation matrices.
from_frame : SpiceFrame
Reference frame to transform from.
to_frame : SpiceFrame
Reference frame to transform to.
et : float or npt.NDArray
Ephemeris time(s) for which to get the rotation matrices.
Returns
-------
Expand Down
25 changes: 11 additions & 14 deletions imap_processing/tests/spice/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_frame_transform(furnish_kernels):
et_0 = spice.utc2et("2025-04-30T12:00:00.000")
position = np.arange(3) + 1
result_0 = frame_transform(
SpiceFrame.IMAP_ULTRA_45, SpiceFrame.IMAP_DPS, et_0, position
et_0, position, SpiceFrame.IMAP_ULTRA_45, SpiceFrame.IMAP_DPS
)
# compare against pure SPICE calculation
rotation_matrix = spice.pxform(
Expand All @@ -112,10 +112,7 @@ def test_frame_transform(furnish_kernels):
ets = np.array([et_0, et_0 + 10])
positions = np.array([[1, 1, 1], [1, 2, 3]])
vec_result = frame_transform(
SpiceFrame.IMAP_HI_90,
SpiceFrame.IMAP_DPS,
ets,
positions,
ets, positions, SpiceFrame.IMAP_HI_90, SpiceFrame.IMAP_DPS
)

assert vec_result.shape == (2, 3)
Expand All @@ -134,34 +131,34 @@ def test_frame_transform_exceptions():
ValueError, match="Position vectors with one dimension must have 3 elements."
):
frame_transform(
SpiceFrame.IMAP_SPACECRAFT, SpiceFrame.IMAP_CODICE, 0, np.arange(4)
0, np.arange(4), SpiceFrame.IMAP_SPACECRAFT, SpiceFrame.IMAP_CODICE
)
with pytest.raises(
ValueError,
match="Ephemeris time must be float when single position vector is provided.",
):
frame_transform(
SpiceFrame.ECLIPJ2000,
SpiceFrame.IMAP_HIT,
np.asarray(0),
np.array([1, 0, 0]),
SpiceFrame.ECLIPJ2000,
SpiceFrame.IMAP_HIT,
)
with pytest.raises(ValueError, match="Invalid position shape: "):
frame_transform(
SpiceFrame.ECLIPJ2000,
SpiceFrame.IMAP_HIT,
np.arange(2),
np.arange(4).reshape((2, 2)),
SpiceFrame.ECLIPJ2000,
SpiceFrame.IMAP_HIT,
)
with pytest.raises(
ValueError,
match="Mismatch in number of position vectors and Ephemeris times provided.",
):
frame_transform(
SpiceFrame.ECLIPJ2000,
SpiceFrame.IMAP_HIT,
np.arange(2),
np.arange(9).reshape((3, 3)),
SpiceFrame.ECLIPJ2000,
SpiceFrame.IMAP_HIT,
)


Expand All @@ -178,11 +175,11 @@ def test_get_rotation_matrix(furnish_kernels):
et = spice.utc2et("2025-09-30T12:00:00.000")
# test input of float
rotation = get_rotation_matrix(
SpiceFrame.IMAP_IDEX, SpiceFrame.IMAP_SPACECRAFT, et
et, SpiceFrame.IMAP_IDEX, SpiceFrame.IMAP_SPACECRAFT
)
assert rotation.shape == (3, 3)
# test array of et input
rotation = get_rotation_matrix(
SpiceFrame.IMAP_IDEX, SpiceFrame.IMAP_SPACECRAFT, np.arange(10) + et
np.arange(10) + et, SpiceFrame.IMAP_IDEX, SpiceFrame.IMAP_SPACECRAFT
)
assert rotation.shape == (10, 3, 3)

0 comments on commit 632a3b7

Please sign in to comment.