Resolve #535 (#541)
* Change np random generation to jax np

* Bump supported jaxlib/jax versions

* Avoid doctest errors resulting from unimportable astra or svmbir

* Avoid doctest failure due to changes in numpy 2.0

* Address FutureWarning in 2D x-ray

* Update submodule

* Proposed changes to #541 (#543)

* Add type annotation

* Remove jax distributed data generation option

* Clean up

* Extend docs

* Add additional test for exception state

* Tracer conversion error fix from Cristina

* Omitted import

* Clean up

* Consistent phrasing

* Clean up some f-strings

* Add missing ray init

* Set dtype

* Fix indentation error

* Update module docstring

* Experimental solution to ray/jax failure

* Bug fix

* Improve docstring

* Implement hack to resolve jax/ray conflict

* Debug attempt

* New solution attempt

* Debug attempt

* Return to earlier approach

* Extend comment

* Clean up and improve function logic

* Address some problems

* Clean up

* Rename function for consistency with related functions

* Bug fix

* Clean up

* Bug fix

* Address pylint complaint

* Revert unworkable structure

* Error message fix

* Address mypy errors

---------

Co-authored-by: Brendt Wohlberg <brendt@ieee.org>
Co-authored-by: Michael-T-McCann <michael.thompson.mccann@gmail.com>
Co-authored-by: Brendt Wohlberg <bwohlberg@users.noreply.github.com>
4 people authored Jul 24, 2024
1 parent 1ffcbcf commit 0dff98d
Showing 15 changed files with 252 additions and 327 deletions.
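
The recurring change across the example scripts below is an import-order workaround: Ray must be initialized before jax is imported (ray-project/ray#44087), and XLA-related environment variables must likewise be set before the jax import. A minimal, self-contained sketch of the pattern as adopted by the updated scripts follows; it is illustrative rather than a verbatim excerpt, and the device count only takes effect on hosts without a GPU.

# isort: off
import logging
import os

import ray

# Initialize Ray before jax is imported; importing jax first can break Ray
# worker startup (ray-project/ray#44087).
ray.init(logging_level=logging.ERROR)

# Set XLA flags before jax is imported so that they are picked up when the
# backend is initialized; this emulates 8 devices on a CPU-only host.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax  # imported last, after Ray and XLA configuration

print("Platform:", jax.lib.xla_bridge.get_backend().platform)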
4 changes: 2 additions & 2 deletions CHANGES.rst
@@ -16,8 +16,8 @@ Version 0.0.6 (unreleased)
• Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to
``scico.flax.save_variables`` and ``scico.flax.load_variables``
respectively.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.28.
• Support ``flax`` versions between 0.8.0 and 0.8.3 (inclusive).
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.30.
• Support ``flax`` versions 0.8.0 to 0.8.3.



6 changes: 6 additions & 0 deletions examples/scripts/ct_astra_datagen_foam2.py
@@ -13,8 +13,14 @@
generated using filtered back projection (FBP).
"""

# isort: off
import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

from scico import plot
from scico.flax.examples import load_ct_data

6 changes: 6 additions & 0 deletions examples/scripts/ct_astra_modl_train_foam2.py
@@ -40,12 +40,18 @@
reconstructed images.
"""

# isort: off
import os
from functools import partial
from time import time

import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable
15 changes: 10 additions & 5 deletions examples/scripts/ct_astra_odp_train_foam2.py
@@ -44,12 +44,21 @@
term. The output of the final stage is the set of reconstructed images.
"""

# isort: off
import os
from functools import partial
from time import time

import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable
@@ -60,11 +69,7 @@
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop.xray.astra import XRayTransform2D

"""
Prepare parallel processing. Set an arbitrary processor count (only
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

platform = jax.lib.xla_bridge.get_backend().platform
print("Platform: ", platform)

21 changes: 14 additions & 7 deletions examples/scripts/ct_astra_unet_train_foam2.py
@@ -13,22 +13,29 @@
by :cite:`jin-2017-unet`.
"""

# isort: off
import os
from time import time

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

import numpy as np

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_ct_data

"""
Prepare parallel processing. Set an arbitrary processor count (only
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

platform = jax.lib.xla_bridge.get_backend().platform
print("Platform: ", platform)

@@ -190,7 +197,7 @@
hist = stats_object.history(transpose=True)
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
plot.plot(
jax.numpy.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
np.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
x=hist.Epoch,
ptyp="semilogy",
title="Loss function",
@@ -201,7 +208,7 @@
ax=ax[0],
)
plot.plot(
jax.numpy.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
np.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
x=hist.Epoch,
title="Metric",
xlbl="Epoch",
11 changes: 9 additions & 2 deletions examples/scripts/deconv_datagen_foam1.py
@@ -12,10 +12,17 @@
training neural network models for deconvolution (deblurring). Foam
phantoms from xdesign are used to generate the clean images.
"""

# isort: off
import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

from scico import plot
from scico.flax.examples import load_foam1_blur_data
from scico.flax.examples import load_blur_data

"""
Read data from cache or generate if not available.
@@ -29,7 +36,7 @@
nimg = train_nimg + test_nimg
output_size = 256 # image size

train_ds, test_ds = load_foam1_blur_data(
train_ds, test_ds = load_blur_data(
train_nimg,
test_nimg,
output_size,
19 changes: 12 additions & 7 deletions examples/scripts/deconv_modl_train_foam1.py
@@ -41,27 +41,32 @@
images.
"""

# isort: off
import os
from functools import partial
from time import time

import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_foam1_blur_data
from scico.flax.examples import load_blur_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop import CircularConvolve

"""
Prepare parallel processing. Set an arbitrary processor count (only
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

platform = jax.lib.xla_bridge.get_backend().platform
print("Platform: ", platform)

@@ -87,7 +92,7 @@
test_nimg = 64 # number of testing images
nimg = train_nimg + test_nimg

train_ds, test_ds = load_foam1_blur_data(
train_ds, test_ds = load_blur_data(
train_nimg,
test_nimg,
output_size,
19 changes: 12 additions & 7 deletions examples/scripts/deconv_odp_train_foam1.py
@@ -49,27 +49,32 @@
set of deblurred images.
"""

# isort: off
import os
from functools import partial
from time import time

import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_foam1_blur_data
from scico.flax.examples import load_blur_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop import CircularConvolve

"""
Prepare parallel processing. Set an arbitrary processor count (only
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

platform = jax.lib.xla_bridge.get_backend().platform
print("Platform: ", platform)

@@ -95,7 +100,7 @@
test_nimg = 64 # number of testing images
nimg = train_nimg + test_nimg

train_ds, test_ds = load_foam1_blur_data(
train_ds, test_ds = load_blur_data(
train_nimg,
test_nimg,
output_size,
10 changes: 5 additions & 5 deletions examples/scripts/denoise_dncnn_train_bsds.py
@@ -13,11 +13,15 @@
with additive Gaussian noise.
"""

# isort: off
import os
from time import time

import numpy as np

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable
@@ -26,11 +30,7 @@
from scico import metric, plot
from scico.flax.examples import load_image_data

"""
Prepare parallel processing. Set an arbitrary processor count (only
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

platform = jax.lib.xla_bridge.get_backend().platform
print("Platform: ", platform)

4 changes: 2 additions & 2 deletions requirements.txt
@@ -4,8 +4,8 @@ scipy>=1.6.0
imageio>=2.17
tifffile
matplotlib
jaxlib>=0.4.3,<=0.4.28
jax>=0.4.3,<=0.4.28
jaxlib>=0.4.3,<=0.4.30
jax>=0.4.3,<=0.4.30
orbax-checkpoint<=0.5.7
flax>=0.8.0,<=0.8.3
pyabel>=0.9.0
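
For downstream environments, a quick runtime check that installed packages fall inside the newly supported ranges can be useful; this is an illustrative sketch, and packaging is not a scico requirement listed here.

from packaging.version import Version

import flax
import jax
import jaxlib

# Supported ranges taken from requirements.txt after this commit.
assert Version("0.4.3") <= Version(jax.__version__) <= Version("0.4.30")
assert Version("0.4.3") <= Version(jaxlib.__version__) <= Version("0.4.30")
assert Version("0.8.0") <= Version(flax.__version__) <= Version("0.8.3")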
4 changes: 2 additions & 2 deletions scico/flax/examples/__init__.py
@@ -8,11 +8,11 @@
"""Data utility functions used by Flax example scripts."""

from .data_preprocessing import PaddedCircularConvolve, build_blur_kernel
from .examples import load_ct_data, load_foam1_blur_data, load_image_data
from .examples import load_blur_data, load_ct_data, load_image_data

__all__ = [
"load_ct_data",
"load_foam1_blur_data",
"load_blur_data",
"load_image_data",
"PaddedCircularConvolve",
"build_blur_kernel",