jaxtomo: tomographic projectors in JAX

jaxtomo implements tomographic projectors with JAX.

The projectors are implemented purely in Python, which makes the code readable and hackable. Because JAX offers just-in-time compilation to GPU, they are reasonably fast. However, they don't use texture memory and are slower than optimized implementations such as torch-radon.
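To illustrate what a "pure Python" projector can look like, here is a minimal sketch of a ray-driven parallel-beam forward projector built on bilinear interpolation. This is not jaxtomo's code. The names `bilinear` and `fp_parallel` are hypothetical, and so are the geometry conventions (unit pixel size, rotation about the image centre). It uses NumPy so it runs anywhere. It avoids in-place updates, on the assumption that it could be translated to `jax.numpy` and wrapped in `jax.jit` largely line by line.

```python
import numpy as np

def bilinear(img, r, c):
    """Bilinearly interpolate img at fractional (row, col) positions; zero outside."""
    H, W = img.shape
    r0 = np.floor(r).astype(int)
    c0 = np.floor(c).astype(int)
    dr, dc = r - r0, c - c0
    out = np.zeros(np.shape(r))
    # Accumulate the four neighbouring pixels, each weighted by its bilinear weight
    for ri, wr in ((r0, 1 - dr), (r0 + 1, dr)):
        for ci, wc in ((c0, 1 - dc), (c0 + 1, dc)):
            inside = (ri >= 0) & (ri < H) & (ci >= 0) & (ci < W)
            vals = img[np.clip(ri, 0, H - 1), np.clip(ci, 0, W - 1)]
            out = out + np.where(inside, wr * wc * vals, 0.0)
    return out

def fp_parallel(img, angles, n_det=None, n_samples=None, dt=1.0):
    """Ray-driven parallel-beam forward projection of a square image.

    Returns a sinogram of shape (len(angles), n_det). Unit pixel size and
    rotation about the image centre are assumptions of this sketch.
    """
    n = img.shape[0]
    n_det = n if n_det is None else n_det
    n_samples = n if n_samples is None else n_samples
    centre = (n - 1) / 2
    s = np.arange(n_det) - (n_det - 1) / 2                  # detector coordinates
    t = (np.arange(n_samples) - (n_samples - 1) / 2) * dt   # positions along each ray
    th = np.asarray(angles)[:, None, None]
    S, T = s[None, :, None], t[None, None, :]
    # Point at distance t along the ray that hits detector bin s at angle theta
    x = S * np.cos(th) - T * np.sin(th)
    y = S * np.sin(th) + T * np.cos(th)
    # Riemann sum of the interpolated samples approximates each line integral
    return bilinear(img, y + centre, x + centre).sum(axis=-1) * dt

sino = fp_parallel(np.ones((33, 33)), np.linspace(0, np.pi, 180, endpoint=False))
print(sino.shape)  # (180, 33)
```

Interpolating every sample along every ray is simple and vectorizes well, but it touches four pixels per sample. That is one reason interpolation-based FP tends to be slower than BP.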

Disclaimer

This is a personal project and very much a work in progress. It is meant as a learning exercise for me, a pedagogical implementation for others (once I add some comments), and perhaps even a tool for implementing proof-of-concept pipelines.

Features

  • Parallel beam
  • Fan beam
  • Cone beam
  • Filtered backprojection (FBP)
  • ... all with a flat detector
  • Forward projection (FP) and backprojection (BP) registered as each other's transpose for autodiff with JAX
  • End-to-end SIR via autodiff
  • jax.pmap for multi-GPU speedup
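Registering FP and BP as each other's transpose only works if BP really is the exact adjoint of FP. The standard check is the dot-product test, ⟨Ax, y⟩ = ⟨x, Aᵀy⟩. The sketch below stands in for a projector pair. It uses a bilinear gather as the forward operator and its exact scatter-add transpose as the matching backprojection. The names `gather` and `scatter` are hypothetical, and this is not how jaxtomo registers its pair. In JAX, the scatter would typically be written with `img.at[idx].add(...)`, and `jax.linear_transpose` can derive a transpose automatically. The sketch is kept in NumPy so it runs without JAX.

```python
import numpy as np

def _corners(r, c):
    """Yield (row, col, weight) for the four bilinear neighbours of each (r, c)."""
    r0 = np.floor(r).astype(int)
    c0 = np.floor(c).astype(int)
    dr, dc = r - r0, c - c0
    for ri, wr in ((r0, 1 - dr), (r0 + 1, dr)):
        for ci, wc in ((c0, 1 - dc), (c0 + 1, dc)):
            yield ri, ci, wr * wc

def gather(img, r, c):
    """Forward operator A: bilinear samples of img at (r, c), zero outside."""
    H, W = img.shape
    out = np.zeros(np.shape(r))
    for ri, ci, w in _corners(r, c):
        inside = (ri >= 0) & (ri < H) & (ci >= 0) & (ci < W)
        vals = img[np.clip(ri, 0, H - 1), np.clip(ci, 0, W - 1)]
        out = out + np.where(inside, w * vals, 0.0)
    return out

def scatter(vals, r, c, shape):
    """Transpose A^T: add each value back to the same pixels with the same weights."""
    H, W = shape
    img = np.zeros(shape)
    for ri, ci, w in _corners(r, c):
        inside = (ri >= 0) & (ri < H) & (ci >= 0) & (ci < W)
        # np.add.at accumulates correctly when several samples hit the same pixel
        np.add.at(img, (ri[inside], ci[inside]), (w * vals)[inside])
    return img

# Dot-product test: <A x, y> must equal <x, A^T y> for a correct transpose pair
rng = np.random.default_rng(0)
x = rng.standard_normal((16, 16))
r = rng.uniform(-2.0, 18.0, size=500)   # some positions fall outside the image
c = rng.uniform(-2.0, 18.0, size=500)
y = rng.standard_normal(500)
lhs = np.dot(gather(x, r, c), y)
rhs = np.sum(x * scatter(y, r, c, x.shape))
print(np.isclose(lhs, rhs))  # True
```

For a pair that passes this test, the gradient of 0.5‖Ax − b‖² with respect to x is Aᵀ(Ax − b). This is the quantity reverse-mode autodiff needs when differentiating through FP in an end-to-end iterative reconstruction.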

Todo

  • Valid FBP for large fan/cone angles (at the moment we just apply a Ram-Lak filter followed by BP)
  • Other FP methods (Siddon, Footprint, ...)
  • Curved detector
  • Different voxel basis functions [1], [2]
  • Speed up bilinear interpolation and/or profile FP, which is rather slow.
    According to examples/timing.py, FP takes ~5x longer than BP.
  • Try JAX Pallas
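The first Todo item describes the current FBP path as a Ram-Lak filter followed by BP. The filtering step can be sketched as a ramp |f| applied to each detector row in the Fourier domain. The rows are zero-padded first to limit wrap-around from circular convolution. This is a generic sketch with a hypothetical function name, not jaxtomo's implementation. It omits the constant scaling factors of a full FBP and any fan/cone weighting.

```python
import numpy as np

def ramlak_filter(sino, du=1.0):
    """Apply a ramp (Ram-Lak) filter along the last (detector) axis.

    sino: array of shape (..., n_det); du: detector pixel spacing (assumed uniform).
    """
    n = sino.shape[-1]
    # Pad to the next power of two >= 2*n to reduce circular-convolution wrap-around
    pad = int(2 ** np.ceil(np.log2(2 * n)))
    ramp = np.abs(np.fft.fftfreq(pad, d=du))        # |f| sampled on the FFT grid
    spectrum = np.fft.fft(sino, n=pad, axis=-1)    # fft zero-pads each row to length pad
    filtered = np.real(np.fft.ifft(spectrum * ramp, axis=-1))
    return filtered[..., :n]                        # drop the padding

# The filtered impulse shows the ramp's characteristic positive peak with negative side lobes
impulse = np.zeros((1, 64))
impulse[0, 32] = 1.0
out = ramlak_filter(impulse)
print(out.shape)  # (1, 64)
```

The filtered sinogram would then be passed to the backprojector. Because the ramp zeroes the DC component, the negative side lobes cancel the blurring that plain backprojection introduces.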

Proof of concept

Fan FBP


Parallel, fan, and cone projector
