Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Schuenke <37338697+schuenke@users.noreply.github.com>
  • Loading branch information
fzimmermann89 and schuenke authored Dec 17, 2024
1 parent 368500c commit ab9612e
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions src/mrpro/data/KTrajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,29 +92,33 @@ def from_tensor(
axes_order
Order of the axes in the tensor. Our convention usually is zyx order.
repeat_detection_tolerance
detects if broadcasting can be used, i.e. if dimensions are repeated.
Set to None to disable.
Tolerance for detecting repeated dimensions (broadcasting).
If trajectory points differ by less than this value, they are considered
identical. Set to None to disable this feature.
grid_detection_tolerance
tolerance to detect if trajectory points are on integer grid positions
encoding_matrix
if an encoding matrix is supplied, the trajectory is rescaled to fit
within the matrix. Otherwise, it is left as-is.
Tolerance for detecting whether trajectory points align with integer
grid positions. This tolerance is applied after rescaling if
`scaling_matrix` is provided.
scaling_matrix
If a scaling matrix is provided, the trajectory is rescaled to fit within
the dimensions of the matrix. If not provided, the trajectory remains unchanged.
"""
ks = tensor.unbind(dim=stack_dim)
kz, ky, kx = (ks[axes_order.index(axis)] for axis in 'zyx')

def normalize(k: torch.Tensor, encoding_size: int) -> torch.Tensor:
def rescale(k: torch.Tensor, size: float) -> torch.Tensor:
max_abs_range = 2 * k.abs().max()
if encoding_size == 1 or max_abs_range < 1e-6:
if size < 2 or max_abs_range < 1e-6:
# a single encoding point should be at zero
# avoid division by zero
return torch.zeros_like(k)
return k * (encoding_size / max_abs_range)
return k * (size / max_abs_range)

if encoding_matrix is not None:
kz = normalize(kz, encoding_matrix.z)
ky = normalize(ky, encoding_matrix.y)
kx = normalize(kx, encoding_matrix.x)
if scaling_matrix is not None:
kz = rescale(kz, scaling_matrix.z)
ky = rescale(ky, scaling_matrix.y)
kx = rescale(kx, scaling_matrix.x)

return cls(
kz,
Expand Down

0 comments on commit ab9612e

Please sign in to comment.