
clemisch/jaxtomo

jaxtomo: tomographic projectors in JAX

jaxtomo implements tomographic projectors with JAX.

They are implemented purely in Python, which makes the code readable and hackable. Because JAX just-in-time compiles the code for the GPU, the projectors are reasonably fast. They do not use texture memory, however, and are slower than optimized implementations such as torch-radon.
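
To illustrate the "pure Python + JIT" approach, here is a minimal sketch of a ray-driven parallel-beam forward projector. It is not jaxtomo's code. The function name `parallel_fp`, the unit pixel and detector spacing, and the use of `jax.scipy.ndimage.map_coordinates` for bilinear interpolation are assumptions made for illustration.

```python
import functools
import math

import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


@functools.partial(jax.jit, static_argnames=("n_det",))
def parallel_fp(image, angles, n_det):
    """Hypothetical ray-driven parallel-beam FP of a square 2D image.

    Pixel and detector spacing are both 1, and image and detector are
    centred on the rotation axis. Returns shape (len(angles), n_det).
    """
    n = image.shape[0]
    n_t = int(math.ceil(n * math.sqrt(2)))   # samples per ray (covers the diagonal)
    c = (n - 1) / 2                          # array index of the rotation centre
    s = jnp.arange(n_det) - (n_det - 1) / 2  # detector bin centres
    t = jnp.arange(n_t) - (n_t - 1) / 2      # sample positions along each ray

    cos = jnp.cos(angles)[:, None, None]
    sin = jnp.sin(angles)[:, None, None]
    # World coordinates of every sample, shape (n_angles, n_det, n_t)
    x = s[None, :, None] * cos - t[None, None, :] * sin
    y = s[None, :, None] * sin + t[None, None, :] * cos

    # Bilinear interpolation (order=1); samples outside the image read as 0
    vals = map_coordinates(image, [y + c, x + c], order=1, mode="constant", cval=0.0)
    return vals.sum(axis=-1)  # Riemann sum of the line integral, step length 1
```

Because `n_det` is marked static, a new detector size triggers recompilation, while repeated calls with the same shapes reuse the compiled kernel.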

Disclaimer

This is a personal project and very work-in-progress. It is meant as a learning exercise for me, a pedagogical implementation for others (once I add some comments), and maybe even a tool for implementing proof-of-concept pipelines.

Features

  • Parallel beam
  • Fan beam
  • Cone beam
  • FBP
  • ... all with a flat detector
  • Forward projection (FP) and backprojection (BP) registered as each other's transpose for autodiff with JAX
  • End-to-end SIR via autodiff
  • jax.pmap for multi-GPU speedup
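
The autodiff features can be sketched as follows, assuming a linear FP such as the compact ray-driven `parallel_fp` redefined below so the block is self-contained. `jax.custom_vjp` makes `jax.grad` call a given BP whenever it differentiates through the FP; as a stand-in for a hand-written BP kernel, this sketch lets `jax.linear_transpose` derive the exact transpose of the FP. A plain least-squares gradient descent stands in for iterative reconstruction. All names (`make_operators`, `sir_gradient_descent`), the loss, and the step-size bound are illustrative assumptions, not jaxtomo's API or its SIR method.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def parallel_fp(image, angles, n_det):
    """Hypothetical ray-driven parallel-beam FP, bilinear sampling, step 1."""
    n = image.shape[0]
    n_t = int(math.ceil(n * math.sqrt(2)))
    c = (n - 1) / 2
    s = jnp.arange(n_det) - (n_det - 1) / 2
    t = jnp.arange(n_t) - (n_t - 1) / 2
    cos = jnp.cos(angles)[:, None, None]
    sin = jnp.sin(angles)[:, None, None]
    x = s[None, :, None] * cos - t[None, None, :] * sin
    y = s[None, :, None] * sin + t[None, None, :] * cos
    vals = map_coordinates(image, [y + c, x + c], order=1, mode="constant", cval=0.0)
    return vals.sum(axis=-1)


def make_operators(n, angles, n_det):
    """Return (fp, bp) for n x n images; jax.grad through fp calls bp."""
    fp_plain = lambda img: parallel_fp(img, angles, n_det)
    # Stand-in for a hand-written BP: the exact transpose of the linear FP
    bp_tuple = jax.linear_transpose(fp_plain, jnp.zeros((n, n), jnp.float32))

    def bp(sino):
        return bp_tuple(sino)[0]

    @jax.custom_vjp
    def fp(img):
        return fp_plain(img)

    def fp_fwd(img):
        return fp_plain(img), None  # FP is linear: no residuals are needed

    def fp_bwd(_, sino_cotangent):
        return (bp(sino_cotangent),)  # VJP of a linear operator is its transpose

    fp.defvjp(fp_fwd, fp_bwd)
    return fp, bp


def sir_gradient_descent(fp, bp, sino, n, n_iter=100):
    """Minimise 0.5 * ||fp(x) - sino||^2 by gradient descent via jax.grad."""
    loss = lambda x: 0.5 * jnp.sum((fp(x) - sino) ** 2)
    grad = jax.jit(jax.grad(loss))
    # For a nonnegative A: ||A||_2^2 <= max(A @ 1) * max(A.T @ 1), so this
    # step is at most 1 / L and the loss cannot increase between iterations
    lipschitz_bound = jnp.max(fp(jnp.ones((n, n), jnp.float32))) * jnp.max(
        bp(jnp.ones_like(sino))
    )
    step = 1.0 / lipschitz_bound
    x = jnp.zeros((n, n), jnp.float32)
    for _ in range(n_iter):
        x = x - step * grad(x)
    return x
```

Swapping in an independently written (possibly unmatched) backprojector only requires changing `fp_bwd`; comparing `jnp.vdot(fp(x), y)` with `jnp.vdot(x, bp(y))` for random `x`, `y` (a dot-product test) shows whether the pair is actually matched.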

Todo

  • Valid FBP for large fan/cone angles (currently we just apply a Ram-Lak filter followed by BP)
  • Other FP methods (Siddon, Footprint, ...)
  • Curved detector
  • Different voxel basis functions [1], [2]
  • Speed up bilinear interpolation and/or profile FP, as it is rather slow.
    According to examples/timing.py, FP takes ~5x longer than BP.
  • Try JAX Pallas
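
For parallel beam, ramp (Ram-Lak) filtering followed by BP, scaled by π / number of angles, is the standard FBP; fan- and cone-beam FBP additionally need geometry-dependent weighting, which this sketch does not include. It is a minimal illustration under stated assumptions, not jaxtomo's implementation: the names `ramlak_filter`, `parallel_bp`, and `fbp` are hypothetical, the ramp is applied in the frequency domain via `jnp.fft.rfftfreq` (rather than as a spatial-domain Ram-Lak kernel) with zero-padding, and the BP is pixel-driven with linear interpolation on the detector.

```python
import functools
import math

import jax
import jax.numpy as jnp


def ramlak_filter(sino):
    """Apply a ramp filter to each row of a (n_angles, n_det) sinogram."""
    n_det = sino.shape[-1]
    n_pad = 2 ** math.ceil(math.log2(2 * n_det))  # zero-pad to limit circular wrap-around
    padded = jnp.pad(sino, ((0, 0), (0, n_pad - n_det)))
    ramp = jnp.fft.rfftfreq(n_pad)                # |frequency| in cycles per detector bin
    filtered = jnp.fft.irfft(jnp.fft.rfft(padded, axis=-1) * ramp, n=n_pad, axis=-1)
    return filtered[:, :n_det]


def parallel_bp(sino, angles, n):
    """Pixel-driven parallel-beam BP onto an n x n grid (unit pixel and bin spacing)."""
    n_det = sino.shape[-1]
    coords = jnp.arange(n) - (n - 1) / 2
    x = coords[None, :]  # column index -> x
    y = coords[:, None]  # row index -> y
    det_pos = jnp.arange(n_det) - (n_det - 1) / 2

    def one_view(row, theta):
        s = x * jnp.cos(theta) + y * jnp.sin(theta)  # detector coordinate of each pixel
        # Linear interpolation on the detector; pixels that miss it get 0
        return jnp.interp(s.ravel(), det_pos, row, left=0.0, right=0.0).reshape(n, n)

    return jax.vmap(one_view)(sino, angles).sum(axis=0)


@functools.partial(jax.jit, static_argnames=("n",))
def fbp(sino, angles, n):
    """Ramp filter + BP; pi / n_angles approximates d(theta) over [0, pi)."""
    return parallel_bp(ramlak_filter(sino), angles, n) * (jnp.pi / angles.shape[0])
```

As a usage example, the analytic sinogram of a centred disk of radius R and value 1, p(s) = 2·sqrt(R² − s²), repeated for every angle, should reconstruct to roughly 1 inside the disk and roughly 0 far outside it.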

Proof of concept

Fan FBP

[Figure: fan-beam FBP result]

Parallel, fan, and cone projector

[Figure: parallel-, fan-, and cone-beam projector results]
