Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAX 3D X-ray CT projector #529

Merged
merged 40 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
4627595
Rough in 3D X-ray projector
Michael-T-McCann May 20, 2024
453a53e
Start on example
Michael-T-McCann May 21, 2024
c51370c
Rough in scico to astra conversion
Michael-T-McCann May 21, 2024
9493ff2
Work on example, still unfinished
Michael-T-McCann Jun 4, 2024
3625324
Work on example
Michael-T-McCann Jun 7, 2024
9a4ea59
Prepare to build example
Michael-T-McCann Jun 7, 2024
caef6eb
Add example to list of GPU-required
Michael-T-McCann Jun 7, 2024
3d0dd60
Modify example script titles
bwohlberg Jun 11, 2024
b3ce7e0
Update example script titles in examples README
bwohlberg Jun 11, 2024
ac7d34c
Rename example script
bwohlberg Jun 11, 2024
10be36e
Switch to using ContextTimer for more compact code
bwohlberg Jun 11, 2024
99d2131
Update change summary
bwohlberg Jun 11, 2024
1869259
Add missing __all__ entry
bwohlberg Jun 11, 2024
7bdcce5
Update copyright date
bwohlberg Jun 11, 2024
579df96
Refactor, improve docs
Michael-T-McCann Jun 14, 2024
f4fac55
Start to address memory problem
Michael-T-McCann Jun 21, 2024
46cde7e
Split up input to address memory bottleneck
Michael-T-McCann Jun 21, 2024
8661aa9
Rough in new conversion functions
Michael-T-McCann Jul 9, 2024
0a3aaa9
Point to correct data branch
Michael-T-McCann Jul 9, 2024
00d785c
Fix detector center transposition
Michael-T-McCann Jul 10, 2024
45daca1
Update submodule
bwohlberg Jul 19, 2024
28fe495
Merge branch 'main' into mike/3d_xray
bwohlberg Jul 19, 2024
b95e398
Start on back projector
Michael-T-McCann Jul 15, 2024
8a59ea4
Add manual back projector
Michael-T-McCann Jul 22, 2024
3b6a911
Update submodule
bwohlberg Jul 24, 2024
dc5dd5b
Merge branch 'main' into mike/3d_xray
bwohlberg Jul 24, 2024
beba882
Again correct unhashable argument (ndarray -> tuple for shapes)
Michael-T-McCann Jul 24, 2024
5df1c75
Merge branch 'main' into mike/3d_xray
bwohlberg Jul 30, 2024
4add502
Merge branch 'main' into mike/3d_xray
bwohlberg Aug 15, 2024
4351941
Work on comments
Michael-T-McCann Aug 26, 2024
854d3ad
Merge branch 'main' into mike/3d_xray
bwohlberg Sep 6, 2024
eee399f
Rename new projectors to XRayTransform2D and XRayTransform3D
Michael-T-McCann Sep 9, 2024
7acbaa4
Merge branch 'main' into mike/3d_xray
bwohlberg Sep 10, 2024
0b63130
Fix import errors in examples
bwohlberg Sep 10, 2024
6c01c4d
Changes to `linop.xray.astra` module (#551)
bwohlberg Sep 10, 2024
7afd4b6
Loosen tolerances to make tests pass
Michael-T-McCann Sep 11, 2024
9bb7b68
Remove unfinished example
Michael-T-McCann Sep 11, 2024
9da091d
Rerun examples that use the new projectors
Michael-T-McCann Sep 11, 2024
1b215e9
Update submodule
Michael-T-McCann Sep 11, 2024
bcba849
Add docstrings, typing
Michael-T-McCann Sep 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ venv.bak/
# Rope project settings
.ropeproject

# VS Code settings
.vscode/

# mkdocs documentation
/site

Expand Down
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Version 0.0.6 (unreleased)
----------------------------

• Significant changes to ``linop.xray.astra`` API.
• New integrated 3D X-ray transform via ``linop.xray.XRayTransform3D``.
• New functional ``functional.IsotropicTVNorm`` and faster implementation
of ``functional.AnisotropicTVNorm``.
• New linear operators ``linop.ProjectedGradient``, ``linop.PolarGradient``,
Expand Down
4 changes: 3 additions & 1 deletion docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ Computed Tomography
examples/ct_astra_modl_train_foam2
examples/ct_astra_odp_train_foam2
examples/ct_astra_unet_train_foam2
examples/ct_projector_comparison
examples/ct_projector_comparison_2d
examples/ct_projector_comparison_3d
examples/ct_multi_cs_tv_admm
examples/ct_multi_tv_admm

Deconvolution
Expand Down
4 changes: 4 additions & 0 deletions examples/scriptcheck.sh
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ for f in $SCRIPTPATH/scripts/*.py; do
printf "%s\n" skipped
continue
fi
if [ $SKIP_GPU -eq 1 ] && grep -q 'ct_projector_comparison_3d' <<< $f; then
printf "%s\n" skipped
continue
fi

# Create temporary copy of script with all algorithm maxiter values set
# to small number and final input statements commented out.
Expand Down
6 changes: 4 additions & 2 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ Computed Tomography
CT Training and Reconstructions with ODP
`ct_astra_unet_train_foam2.py <ct_astra_unet_train_foam2.py>`_
CT Training and Reconstructions with UNet
`ct_projector_comparison.py <ct_projector_comparison.py>`_
X-ray Transform Comparison
`ct_projector_comparison_2d.py <ct_projector_comparison_2d.py>`_
2D X-ray Transform Comparison
`ct_projector_comparison_3d.py <ct_projector_comparison_3d.py>`_
3D X-ray Transform Comparison
`ct_multi_tv_admm.py <ct_multi_tv_admm.py>`_
TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors)

Expand Down
6 changes: 2 additions & 4 deletions examples/scripts/ct_multi_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.linop.xray import Parallel2dProjector, XRayTransform, astra, svmbir
from scico.linop.xray import XRayTransform2D, astra, svmbir
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

Expand All @@ -54,9 +54,7 @@
"svmbir": svmbir.XRayTransform(
x_gt.shape, 2 * np.pi - angles, det_count, delta_pixel=1.0, delta_channel=det_spacing
), # svmbir
"scico": XRayTransform(
Parallel2dProjector((N, N), angles, det_count=det_count, dx=1 / det_spacing)
), # scico
"scico": XRayTransform2D((N, N), angles, det_count=det_count, dx=1 / det_spacing), # scico
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@


r"""
X-ray Transform Comparison
==========================
2D X-ray Transform Comparison
=============================

This example compares SCICO's native X-ray transform algorithm
This example compares SCICO's native 2D X-ray transform algorithm
to that of the ASTRA toolbox.
"""

Expand All @@ -22,7 +22,7 @@

import scico.linop.xray.astra as astra
from scico import plot
from scico.linop import Parallel2dProjector, XRayTransform
from scico.linop.xray import XRayTransform2D
from scico.util import Timer

"""
Expand All @@ -46,7 +46,7 @@

projectors = {}
timer.start("scico_init")
projectors["scico"] = XRayTransform(Parallel2dProjector((N, N), angles))
projectors["scico"] = XRayTransform2D((N, N), angles)
timer.stop("scico_init")

timer.start("astra_init")
Expand Down
200 changes: 200 additions & 0 deletions examples/scripts/ct_projector_comparison_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.


r"""
3D X-ray Transform Comparison
=============================

This example shows how to define a SCICO native 3D X-ray transform using
ASTRA toolbox conventions and vice versa.
"""

import numpy as np

import jax
import jax.numpy as jnp

import scico.linop.xray.astra as astra
from scico import plot
from scico.examples import create_block_phantom
from scico.linop.xray import XRayTransform3D
from scico.util import ContextTimer, Timer

"""
Create a ground truth image and set detector dimensions.
"""
N = 64
# use rectangular volume to check whether axes are handled correctly
in_shape = (N + 1, N + 2, N + 3)
x = create_block_phantom(in_shape)
x = jnp.array(x)

# use rectangular detector to check whether axes are handled correctly
out_shape = (N, N + 1)


"""
Set up SCICO projection.
"""
num_angles = 3


Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
rot_X = 90.0 - 16.0
rot_Y = np.linspace(0, 180, num_angles, endpoint=False)
angles = np.stack(np.broadcast_arrays(rot_X, rot_Y), axis=-1)
matrices = XRayTransform3D.matrices_from_euler_angles(
in_shape, out_shape, "XY", angles, degrees=True
)

"""
Specify geometry using SCICO conventions and project.
"""
num_repeats = 3

timer_scico = Timer()
with ContextTimer(timer_scico, "init"):
H_scico = XRayTransform3D(in_shape, matrices, out_shape)

with ContextTimer(timer_scico, "first_fwd"):
y_scico = H_scico @ x
jax.block_until_ready(y_scico)

with ContextTimer(timer_scico, "avg_fwd"):
for _ in range(num_repeats):
y_scico = H_scico @ x
jax.block_until_ready(y_scico)
timer_scico.td["avg_fwd"] /= num_repeats

with ContextTimer(timer_scico, "first_back"):
HTy_scico = H_scico.T @ y_scico

with ContextTimer(timer_scico, "avg_back"):
for _ in range(num_repeats):
HTy_scico = H_scico.T @ y_scico
jax.block_until_ready(HTy_scico)
timer_scico.td["avg_back"] /= num_repeats


"""
Convert SCICO geometry to ASTRA and project.
"""

vectors_from_scico = astra.convert_from_scico_geometry(in_shape, matrices, out_shape)

timer_astra = Timer()
with ContextTimer(timer_astra, "init"):
H_astra_from_scico = astra.XRayTransform3D(
input_shape=in_shape, det_count=out_shape, vectors=vectors_from_scico
)

with ContextTimer(timer_astra, "first_fwd"):
y_astra_from_scico = H_astra_from_scico @ x
jax.block_until_ready(y_astra_from_scico)

with ContextTimer(timer_astra, "avg_fwd"):
for _ in range(num_repeats):
y_astra_from_scico = H_astra_from_scico @ x
jax.block_until_ready(y_astra_from_scico)
timer_astra.td["avg_fwd"] /= num_repeats

with ContextTimer(timer_astra, "first_back"):
HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico

with ContextTimer(timer_astra, "avg_back"):
for _ in range(num_repeats):
HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico
jax.block_until_ready(HTy_astra_from_scico)
timer_astra.td["avg_back"] /= num_repeats


"""
Specify geometry with ASTRA conventions and project.
"""

angles = np.random.rand(num_angles) * 180 # random projection angles
det_spacing = [1.0, 1.0]
vectors = astra.angle_to_vector(det_spacing, angles)

H_astra = astra.XRayTransform3D(input_shape=in_shape, det_count=out_shape, vectors=vectors)

y_astra = H_astra @ x
HTy_astra = H_astra.T @ y_astra


"""
Convert ASTRA geometry to SCICO and project.
"""

P_from_astra = astra._astra_to_scico_geometry(H_astra.vol_geom, H_astra.proj_geom)
H_scico_from_astra = XRayTransform3D(in_shape, P_from_astra, out_shape)

y_scico_from_astra = H_scico_from_astra @ x
HTy_scico_from_astra = H_scico_from_astra.T @ y_scico_from_astra


"""
Print timing results.
"""
print(f"init astra {timer_astra.td['init']:.2e} s")
print(f"init scico {timer_scico.td['init']:.2e} s")
print("")
for tstr in ("first", "avg"):
for dstr in ("fwd", "back"):
for timer, pstr in zip((timer_astra, timer_scico), ("astra", "scico")):
print(f"{tstr:5s} {dstr:4s} {pstr} {timer.td[tstr + '_' + dstr]:.2e} s")
print()


"""
Show projections.
"""
fig, ax = plot.subplots(nrows=3, ncols=2, figsize=(8, 10))
plot.imview(y_scico[0], title="SCICO projections", cbar=None, fig=fig, ax=ax[0, 0])
plot.imview(y_scico[1], cbar=None, fig=fig, ax=ax[1, 0])
plot.imview(y_scico[2], cbar=None, fig=fig, ax=ax[2, 0])
plot.imview(y_astra_from_scico[:, 0], title="ASTRA projections", cbar=None, fig=fig, ax=ax[0, 1])
plot.imview(y_astra_from_scico[:, 1], cbar=None, fig=fig, ax=ax[1, 1])
plot.imview(y_astra_from_scico[:, 2], cbar=None, fig=fig, ax=ax[2, 1])
fig.suptitle("Using SCICO conventions")
fig.tight_layout()
fig.show()

fig, ax = plot.subplots(nrows=3, ncols=2, figsize=(8, 10))
plot.imview(y_scico_from_astra[0], title="SCICO projections", cbar=None, fig=fig, ax=ax[0, 0])
plot.imview(y_scico_from_astra[1], cbar=None, fig=fig, ax=ax[1, 0])
plot.imview(y_scico_from_astra[2], cbar=None, fig=fig, ax=ax[2, 0])
plot.imview(y_astra[:, 0], title="ASTRA projections", cbar=None, fig=fig, ax=ax[0, 1])
plot.imview(y_astra[:, 1], cbar=None, fig=fig, ax=ax[1, 1])
plot.imview(y_astra[:, 2], cbar=None, fig=fig, ax=ax[2, 1])
fig.suptitle("Using ASTRA conventions")
fig.tight_layout()
fig.show()


"""
Show back projections.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(8, 5))
plot.imview(HTy_scico[N // 2], title="SCICO back projection", cbar=None, fig=fig, ax=ax[0])
plot.imview(
HTy_astra_from_scico[N // 2], title="ASTRA back projection", cbar=None, fig=fig, ax=ax[1]
)
fig.suptitle("Using SCICO conventions")
fig.tight_layout()
fig.show()

fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(8, 5))
plot.imview(
HTy_scico_from_astra[N // 2], title="SCICO back projection", cbar=None, fig=fig, ax=ax[0]
)
plot.imview(HTy_astra[N // 2], title="ASTRA back projection", cbar=None, fig=fig, ax=ax[1])
fig.suptitle("Using ASTRA conventions")
fig.tight_layout()
fig.show()


input("\nWaiting for input to close figures and exit")
4 changes: 2 additions & 2 deletions examples/scripts/ct_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.linop.xray import Parallel2dProjector, XRayTransform
from scico.linop.xray import XRayTransform2D
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

Expand All @@ -46,7 +46,7 @@
"""
n_projection = 45 # number of projections
angles = np.linspace(0, np.pi, n_projection) + np.pi / 2.0 # evenly spaced projection angles
A = XRayTransform(Parallel2dProjector((N, N), angles)) # CT projection operator
A = XRayTransform2D((N, N), angles) # CT projection operator
y = A @ x_gt # sinogram


Expand Down
3 changes: 2 additions & 1 deletion examples/scripts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ Computed Tomography
- ct_astra_modl_train_foam2.py
- ct_astra_odp_train_foam2.py
- ct_astra_unet_train_foam2.py
- ct_projector_comparison.py
- ct_projector_comparison_2d.py
- ct_projector_comparison_3d.py
- ct_multi_tv_admm.py

Deconvolution
Expand Down
Loading
Loading