diff --git a/examples/scripts/ct_projector_comparison_3d.py b/examples/scripts/ct_projector_comparison_3d.py index 2752b9b9..96988282 100644 --- a/examples/scripts/ct_projector_comparison_3d.py +++ b/examples/scripts/ct_projector_comparison_3d.py @@ -132,8 +132,8 @@ Convert ASTRA geometry to SCICO and project. """ -P_from_astra = scico.linop.xray.astra.convert_to_scico_geometry(H_astra.vol_geom, H_astra.proj_geom) -H_scico_from_astra = XRayTransform(XRayTransform3D(in_shape, P_from_astra, out_shape)) +P_from_astra = scico.linop.xray.astra._astra_to_scico_geometry(H_astra.vol_geom, H_astra.proj_geom) +H_scico_from_astra = XRayTransform3D(in_shape, P_from_astra, out_shape) y_scico_from_astra = H_scico_from_astra @ x HTy_scico_from_astra = H_scico_from_astra.T @ y_scico_from_astra diff --git a/scico/linop/xray/__init__.py b/scico/linop/xray/__init__.py index 4b57161b..75c66368 100644 --- a/scico/linop/xray/__init__.py +++ b/scico/linop/xray/__init__.py @@ -15,12 +15,16 @@ transform that is the appropriate mathematical model for beam attenuation based imaging in three or more dimensions. -SCICO includes its own integrated 2D X-ray transform, and also provides -interfaces to those implemented in the +SCICO includes its own integrated 2D and 3D X-ray transforms, and also +provides interfaces to those implemented in the `ASTRA toolbox `_ -and the `svmbir `_ package. Each of -these transforms uses a different convention for view angle directions, -as illustrated in the figure below. +and the `svmbir `_ package. + + +**2D Transforms** + +The SCICO, ASTRA, and svmbir transforms use different conventions for +view angle directions, as illustrated in the figure below. .. plot:: pyfigures/xray_2d_geom.py :align: center @@ -43,8 +47,24 @@ \theta_{\text{svmbir}} &= 2 \pi - \theta_{\text{scico}} \;. \end{aligned} + +**3D Transforms** + +There are more significant differences in the interfaces for the 3D SCICO +and ASTRA transforms. The SCICO 3D transform :class:`.xray.XRayTransform3D` +defines the projection geometry in terms of a set of projection matrices, +while the geometry for the ASTRA 3D transform +:class:`.astra.XRayTransform3D` may either be specified in terms of a set +of view angles, or via a more general set of vectors specifying projection +direction and detector orientation. A number of support functions are +provided for convering between these conventions. + +Note that the SCICO transform is implemented in JAX and can be run on +both CPU and GPU devices, while the ASTRA transform is implemented in +CUDA, and can only be run on GPU devices. """ +import sys from ._xray import XRayTransform2D, XRayTransform3D @@ -52,3 +72,8 @@ "XRayTransform2D", "XRayTransform3D", ] + + +# Imported items in __all__ appear to originate in top-level xray module +for name in __all__: + getattr(sys.modules[__name__], name).__module__ = __name__ diff --git a/scico/linop/xray/astra.py b/scico/linop/xray/astra.py index ab75f27d..18f89237 100644 --- a/scico/linop/xray/astra.py +++ b/scico/linop/xray/astra.py @@ -20,6 +20,7 @@ from typing import List, Optional, Sequence, Tuple, Union import numpy as np +import numpy.typing import jax from jax.typing import ArrayLike @@ -43,7 +44,10 @@ collections.Iterable = collections.abc.Iterable # type: ignore from scico.linop import LinearOperator -from scico.typing import Shape +from scico.typing import Shape, TypeAlias + +VolumeGeometry: TypeAlias = dict +ProjectionGeometry: TypeAlias = dict def set_astra_gpu_index(idx: Union[int, Sequence[int]]): @@ -55,84 +59,9 @@ def set_astra_gpu_index(idx: Union[int, Sequence[int]]): astra.set_gpu_index(idx) -def convert_from_scico_geometry( - in_shape: Shape, matrices: ArrayLike, det_shape: Shape -) -> ArrayLike: - """ - Convert SCICO projection matrices into ASTRA "parallel3d_vec" vectors. - - For 3D arrays, - in Astra, the dimensions go (slices, rows, columns) and (z, y, x); - in SCICO, the dimensions go (x, y, z). - - In Astra, the x-grid (recon) is centered on the origin and the y-grid (projection) can move. - In SCICO, the x-grid origin is the center of x[0, 0, 0], the y-grid origin is the center - of y[0, 0]. - - See https://astra-toolbox.com/docs/geom3d.html#projection-geometries parallel3d_vec. - - Args: - in_shape: Shape of input image. - matrices: (num_angles, 2, 4) array of homogeneous projection matrices. - det_shape: Shape of detector. - - Returns: - (num_angles, 12) vector array in the ASTRA "parallel3d_vec" convention. - - """ - # ray is perpendicular to projection axes - ray = np.cross(matrices[:, 0, :3], matrices[:, 1, :3]) - # detector center comes from lifting the center index to 3D - y_center = (np.array(det_shape) - 1) / 2 - x_center = ( - np.einsum("...mn,n->...m", matrices[..., :3], (np.array(in_shape) - 1) / 2) - + matrices[..., 3] - ) - d = np.einsum("...mn,...m->...n", matrices[..., :3], y_center - x_center) # (V, 2, 3) x (V, 2) - u = matrices[:, 1, :3] - v = matrices[:, 0, :3] - - # handle different axis conventions - ray = ray[:, [2, 1, 0]] - d = d[:, [2, 1, 0]] - u = u[:, [2, 1, 0]] - v = v[:, [2, 1, 0]] - - vectors = np.concatenate((ray, d, u, v), axis=1) # (v, 12) - return vectors - - -def convert_to_scico_geometry(vol_geom, proj_geom): - """ - Convert ASTRA volume and projection geometry into a SCICO X-ray - projection matrix, assuming "parallel3d_vec" format. - - The approach is to find locate 3 points in the volume domain, - deduce the corresponding projection locations, and, then, solve a - linear system to determine the affine relationship between them. - - Args: - vol_geom: ASTRA volume geometry object. - proj_geom: ASTRA projection geometry object. - - Returns: - (num_angles, 2, 4) array of homogeneous projection matrices. - - """ - x_volume = np.concatenate((np.zeros((1, 3)), np.eye(3)), axis=0) # (4, 3) - x_dets = _project_coords(x_volume, vol_geom, proj_geom) # (1, 4, 2) - - x_volume_aug = np.concatenate((x_volume, np.ones((4, 1))), axis=1) # (4, 4) - matrices = [] - for x_det in x_dets: - M = np.linalg.solve(x_volume_aug, x_det).T - np.testing.assert_allclose(M @ x_volume_aug[0], x_det[0]) - matrices.append(M) - - return np.stack(matrices) - - -def _project_coords(x_volume, vol_geom, proj_geom) -> ArrayLike: +def _project_coords( + x_volume: np.ndarray, vol_geom: VolumeGeometry, proj_geom: ProjectionGeometry +) -> np.ndarray: det_shape = (proj_geom["DetectorRowCount"], proj_geom["DetectorColCount"]) x_world = volume_coords_to_world_coords(x_volume, vol_geom=vol_geom) x_dets = [] @@ -144,10 +73,18 @@ def _project_coords(x_volume, vol_geom, proj_geom) -> ArrayLike: return np.stack(x_dets) -def project_world_coordinates(x, ray, d, u, v, det_shape): - """ - Project world coordinates along ray into the basis described by u - and v with center d. +def project_world_coordinates( + x: np.ndarray, + ray: np.typing.ArrayLike, + d: np.typing.ArrayLike, + u: np.typing.ArrayLike, + v: np.typing.ArrayLike, + det_shape: Sequence[int], +) -> np.ndarray: + """Project world coordinates along ray into the specified basis. + + Project world coordinates along `ray` into the basis described by `u` + and `v` with center `d`. Args: x: (..., 3) vector(s) of world coordinates. @@ -173,8 +110,9 @@ def project_world_coordinates(x, ray, d, u, v, det_shape): return ind_ij -def volume_coords_to_world_coords(idx, vol_geom): - """ +def volume_coords_to_world_coords(idx: np.ndarray, vol_geom: VolumeGeometry): + """Convert a volume coordinate into a world coordinate. + Convert a volume coordinate into a world coordinate using ASTRA conventions. @@ -192,7 +130,7 @@ def volume_coords_to_world_coords(idx, vol_geom): return _volume_index_to_astra_world_3d(idx, vol_geom) -def _volume_index_to_astra_world_2d(idx, vol_geom): +def _volume_index_to_astra_world_2d(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray: coord = idx[..., [2, 1]] # x:col, y:row, nx = np.array( # (x, y) order ( @@ -211,7 +149,7 @@ def _volume_index_to_astra_world_2d(idx, vol_geom): return (coord - center_coord) * dx -def _volume_index_to_astra_world_3d(idx, vol_geom): +def _volume_index_to_astra_world_3d(idx: np.ndarray, vol_geom: VolumeGeometry) -> np.ndarray: coord = idx[..., [2, 1, 0]] # x:col, y:row, z:slice nx = np.array( # (x, y, z) order ( @@ -398,6 +336,113 @@ def f(sino): return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), sino) +def convert_from_scico_geometry( + in_shape: Shape, matrices: ArrayLike, det_shape: Shape +) -> np.ndarray: + """Convert SCICO projection matrices into ASTRA "parallel3d_vec" vectors. + + For 3D arrays, + in ASTRA, the dimensions go (slices, rows, columns) and (z, y, x); + in SCICO, the dimensions go (x, y, z). + + In ASTRA, the x-grid (recon) is centered on the origin and the y-grid (projection) can move. + In SCICO, the x-grid origin is the center of x[0, 0, 0], the y-grid origin is the center + of y[0, 0]. + + See section "parallel3d_vec" in the + `astra documentation `__. + + Args: + in_shape: Shape of input image. + matrices: (num_angles, 2, 4) array of homogeneous projection matrices. + det_shape: Shape of detector. + + Returns: + (num_angles, 12) vector array in the ASTRA "parallel3d_vec" convention. + """ + # ray is perpendicular to projection axes + ray = np.cross(matrices[:, 0, :3], matrices[:, 1, :3]) + # detector center comes from lifting the center index to 3D + y_center = (np.array(det_shape) - 1) / 2 + x_center = ( + np.einsum("...mn,n->...m", matrices[..., :3], (np.array(in_shape) - 1) / 2) + + matrices[..., 3] + ) + d = np.einsum("...mn,...m->...n", matrices[..., :3], y_center - x_center) # (V, 2, 3) x (V, 2) + u = matrices[:, 1, :3] + v = matrices[:, 0, :3] + + # handle different axis conventions + ray = ray[:, [2, 1, 0]] + d = d[:, [2, 1, 0]] + u = u[:, [2, 1, 0]] + v = v[:, [2, 1, 0]] + + vectors = np.concatenate((ray, d, u, v), axis=1) # (v, 12) + return vectors + + +def _astra_to_scico_geometry(vol_geom: VolumeGeometry, proj_geom: ProjectionGeometry) -> np.ndarray: + """Convert ASTRA geometry objects into a SCICO projection matrix. + + Convert ASTRA volume and projection geometry into a SCICO X-ray + projection matrix, assuming "parallel3d_vec" format. + + The approach is to locate 3 points in the volume domain, + deduce the corresponding projection locations, and, then, solve a + linear system to determine the affine relationship between them. + + Args: + vol_geom: ASTRA volume geometry object. + proj_geom: ASTRA projection geometry object. + + Returns: + (num_angles, 2, 4) array of homogeneous projection matrices. + + """ + x_volume = np.concatenate((np.zeros((1, 3)), np.eye(3)), axis=0) # (4, 3) + x_dets = _project_coords(x_volume, vol_geom, proj_geom) # (1, 4, 2) + + x_volume_aug = np.concatenate((x_volume, np.ones((4, 1))), axis=1) # (4, 4) + matrices = [] + for x_det in x_dets: + M = np.linalg.solve(x_volume_aug, x_det).T + np.testing.assert_allclose(M @ x_volume_aug[0], x_det[0]) + matrices.append(M) + + return np.stack(matrices) + + +def convert_to_scico_geometry( + input_shape: Shape, + det_count: Tuple[int, int], + det_spacing: Optional[Tuple[float, float]] = None, + angles: Optional[np.ndarray] = None, + vectors: Optional[np.ndarray] = None, +) -> np.ndarray: + """Convert ASTRA geometry specificiation to a SCICO projection matrix. + + Convert ASTRA volume and projection geometry into a SCICO X-ray + projection matrix, assuming "parallel3d_vec" format. + + The approach is to locate 3 points in the volume domain, + deduce the corresponding projection locations, and, then, solve a + linear system to determine the affine relationship between them. + + Args: + vol_geom: ASTRA volume geometry object. + proj_geom: ASTRA projection geometry object. + + Returns: + (num_angles, 2, 4) array of homogeneous projection matrices. + + """ + vol_geom, proj_geom = XRayTransform3D.create_astra_geometry( + input_shape, det_count, det_spacing=det_spacing, angles=angles, vectors=vectors + ) + return _astra_to_scico_geometry(vol_geom, proj_geom) + + class XRayTransform3D(LinearOperator): # pragma: no cover r"""3D parallel beam X-ray transform based on the ASTRA toolbox. @@ -531,32 +576,25 @@ def __init__( raise ValueError("Expected det_count to be a tuple with 2 elements.") if angles is not None: Nview = angles.size - self.angles: np.ndarray = np.array(angles) + self.angles: Optional[np.ndarray] = np.array(angles) + self.vectors: Optional[np.ndarray] = None else: assert vectors is not None Nview = vectors.shape[0] - self.vectors: np.ndarray = np.array(vectors) + self.vectors = np.array(vectors) + self.angles = None output_shape: Shape = (det_count[0], Nview, det_count[1]) self.det_count = det_count assert isinstance(det_count, (list, tuple)) - if angles is not None: - assert det_spacing is not None - self.proj_geom = astra.create_proj_geom( - "parallel3d", - det_spacing[0], - det_spacing[1], - det_count[0], - det_count[1], - self.angles, - ) - else: - self.proj_geom = astra.create_proj_geom( - "parallel3d_vec", det_count[0], det_count[1], self.vectors - ) - self.input_shape: tuple = input_shape - self.vol_geom = astra.create_vol_geom(input_shape[1], input_shape[2], input_shape[0]) + self.vol_geom, self.proj_geom = self.create_astra_geometry( + input_shape, + det_count, + det_spacing=det_spacing, + angles=self.angles, + vectors=self.vectors, + ) # Wrap our non-jax function to indicate we will supply fwd/rev mode functions self._eval = jax.custom_vjp(self._proj) @@ -573,6 +611,53 @@ def __init__( jit=False, ) + @staticmethod + def create_astra_geometry( + input_shape: Shape, + det_count: Tuple[int, int], + det_spacing: Optional[Tuple[float, float]] = None, + angles: Optional[np.ndarray] = None, + vectors: Optional[np.ndarray] = None, + ) -> Tuple[VolumeGeometry, ProjectionGeometry]: + """Create ASTRA 3D geometry objects. + + Keyword arguments `det_spacing` and `angles` should be specified + to use the "parallel3d" geometry, and keyword argument `vectors` + should be specified to use the "parallel3d_vec" geometry. These + options are mutually exclusive. + + Args: + input_shape: Shape of the input array. + det_count: Number of detector elements. See the + `astra documentation `__ + for more information. + det_spacing: Spacing between detector elements. See the + `astra documentation `__ + for more information. + angles: Array of projection angles in radians. + vectors: Array of geometry specification vectors. + + Returns: + A tuple `(vol_geom, proj_geom)` of ASTRA volume geometry and + projection geometry objects. + """ + vol_geom = astra.create_vol_geom(input_shape[1], input_shape[2], input_shape[0]) + if angles is not None: + assert det_spacing is not None + proj_geom = astra.create_proj_geom( + "parallel3d", + det_spacing[0], + det_spacing[1], + det_count[0], + det_count[1], + angles, + ) + else: + proj_geom = astra.create_proj_geom( + "parallel3d_vec", det_count[0], det_count[1], vectors + ) + return vol_geom, proj_geom + def _proj(self, x: jax.Array) -> jax.Array: # apply the forward projector and generate a sinogram diff --git a/scico/test/linop/xray/test_astra.py b/scico/test/linop/xray/test_astra.py index e575fac0..c94314fd 100644 --- a/scico/test/linop/xray/test_astra.py +++ b/scico/test/linop/xray/test_astra.py @@ -260,7 +260,7 @@ def test_convert_to_scico_geometry(test_geometry): Basic regression test, `test_project_coords` tests the logic. """ vol_geom, proj_geom = test_geometry - matrices_truth = scico.linop.xray.astra.convert_to_scico_geometry(vol_geom, proj_geom) + matrices_truth = scico.linop.xray.astra._astra_to_scico_geometry(vol_geom, proj_geom) truth = np.array([[[0.0, 1.0, 0.0, -2.0], [0.0, 0.0, 1.0, -1.0]]]) np.testing.assert_allclose(matrices_truth, truth)