diff --git a/CHANGES.rst b/CHANGES.rst index e9e3e0301..4be84b0cb 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -16,8 +16,8 @@ Version 0.0.6 (unreleased) • Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to ``scico.flax.save_variables`` and ``scico.flax.load_variables`` respectively. -• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.28. -• Support ``flax`` versions between 0.8.0 and 0.8.3 (inclusive). +• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.30. +• Support ``flax`` versions 0.8.0 to 0.8.3. diff --git a/examples/scripts/ct_astra_datagen_foam2.py b/examples/scripts/ct_astra_datagen_foam2.py index 76c19da57..4e6fb97c8 100644 --- a/examples/scripts/ct_astra_datagen_foam2.py +++ b/examples/scripts/ct_astra_datagen_foam2.py @@ -13,8 +13,14 @@ generated using filtered back projection (FBP). """ +# isort: off import numpy as np +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + from scico import plot from scico.flax.examples import load_ct_data diff --git a/examples/scripts/ct_astra_modl_train_foam2.py b/examples/scripts/ct_astra_modl_train_foam2.py index dc25ebb38..be218d5eb 100644 --- a/examples/scripts/ct_astra_modl_train_foam2.py +++ b/examples/scripts/ct_astra_modl_train_foam2.py @@ -40,12 +40,18 @@ reconstructed images. """ +# isort: off import os from functools import partial from time import time import numpy as np +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable diff --git a/examples/scripts/ct_astra_odp_train_foam2.py b/examples/scripts/ct_astra_odp_train_foam2.py index 8c5d9ad61..4a8355e36 100644 --- a/examples/scripts/ct_astra_odp_train_foam2.py +++ b/examples/scripts/ct_astra_odp_train_foam2.py @@ -44,12 +44,21 @@ term. The output of the final stage is the set of reconstructed images. """ +# isort: off import os from functools import partial from time import time import numpy as np +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + +# Set an arbitrary processor count (only applies if GPU is not available). +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable @@ -60,11 +69,7 @@ from scico.flax.train.traversals import clip_positive, construct_traversal from scico.linop.xray.astra import XRayTransform2D -""" -Prepare parallel processing. Set an arbitrary processor count (only -applies if GPU is not available). -""" -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + platform = jax.lib.xla_bridge.get_backend().platform print("Platform: ", platform) diff --git a/examples/scripts/ct_astra_unet_train_foam2.py b/examples/scripts/ct_astra_unet_train_foam2.py index 72e82e81d..bae623b3b 100644 --- a/examples/scripts/ct_astra_unet_train_foam2.py +++ b/examples/scripts/ct_astra_unet_train_foam2.py @@ -13,22 +13,29 @@ by :cite:`jin-2017-unet`. """ +# isort: off import os from time import time +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + +# Set an arbitrary processor count (only applies if GPU is not available). +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax +import numpy as np + from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax from scico import metric, plot from scico.flax.examples import load_ct_data -""" -Prepare parallel processing. Set an arbitrary processor count (only -applies if GPU is not available). -""" -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + platform = jax.lib.xla_bridge.get_backend().platform print("Platform: ", platform) @@ -190,7 +197,7 @@ hist = stats_object.history(transpose=True) fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5)) plot.plot( - jax.numpy.vstack((hist.Train_Loss, hist.Eval_Loss)).T, + np.vstack((hist.Train_Loss, hist.Eval_Loss)).T, x=hist.Epoch, ptyp="semilogy", title="Loss function", @@ -201,7 +208,7 @@ ax=ax[0], ) plot.plot( - jax.numpy.vstack((hist.Train_SNR, hist.Eval_SNR)).T, + np.vstack((hist.Train_SNR, hist.Eval_SNR)).T, x=hist.Epoch, title="Metric", xlbl="Epoch", diff --git a/examples/scripts/deconv_datagen_foam1.py b/examples/scripts/deconv_datagen_foam1.py index 99914e537..f80bee5fd 100644 --- a/examples/scripts/deconv_datagen_foam1.py +++ b/examples/scripts/deconv_datagen_foam1.py @@ -12,10 +12,17 @@ training neural network models for deconvolution (deblurring). Foam phantoms from xdesign are used to generate the clean images. """ + +# isort: off import numpy as np +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + from scico import plot -from scico.flax.examples import load_foam1_blur_data +from scico.flax.examples import load_blur_data """ Read data from cache or generate if not available. @@ -29,7 +36,7 @@ nimg = train_nimg + test_nimg output_size = 256 # image size -train_ds, test_ds = load_foam1_blur_data( +train_ds, test_ds = load_blur_data( train_nimg, test_nimg, output_size, diff --git a/examples/scripts/deconv_modl_train_foam1.py b/examples/scripts/deconv_modl_train_foam1.py index 2916d7a1f..69b19d939 100644 --- a/examples/scripts/deconv_modl_train_foam1.py +++ b/examples/scripts/deconv_modl_train_foam1.py @@ -41,27 +41,32 @@ images. """ +# isort: off import os from functools import partial from time import time import numpy as np +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + +# Set an arbitrary processor count (only applies if GPU is not available). +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax from scico import metric, plot -from scico.flax.examples import load_foam1_blur_data +from scico.flax.examples import load_blur_data from scico.flax.train.traversals import clip_positive, construct_traversal from scico.linop import CircularConvolve -""" -Prepare parallel processing. Set an arbitrary processor count (only -applies if GPU is not available). -""" -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + platform = jax.lib.xla_bridge.get_backend().platform print("Platform: ", platform) @@ -87,7 +92,7 @@ test_nimg = 64 # number of testing images nimg = train_nimg + test_nimg -train_ds, test_ds = load_foam1_blur_data( +train_ds, test_ds = load_blur_data( train_nimg, test_nimg, output_size, diff --git a/examples/scripts/deconv_odp_train_foam1.py b/examples/scripts/deconv_odp_train_foam1.py index ffe852e7a..9887fe894 100644 --- a/examples/scripts/deconv_odp_train_foam1.py +++ b/examples/scripts/deconv_odp_train_foam1.py @@ -49,27 +49,32 @@ set of deblurred images. """ +# isort: off import os from functools import partial from time import time import numpy as np +import logging +import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + +# Set an arbitrary processor count (only applies if GPU is not available). +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable from scico import flax as sflax from scico import metric, plot -from scico.flax.examples import load_foam1_blur_data +from scico.flax.examples import load_blur_data from scico.flax.train.traversals import clip_positive, construct_traversal from scico.linop import CircularConvolve -""" -Prepare parallel processing. Set an arbitrary processor count (only -applies if GPU is not available). -""" -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + platform = jax.lib.xla_bridge.get_backend().platform print("Platform: ", platform) @@ -95,7 +100,7 @@ test_nimg = 64 # number of testing images nimg = train_nimg + test_nimg -train_ds, test_ds = load_foam1_blur_data( +train_ds, test_ds = load_blur_data( train_nimg, test_nimg, output_size, diff --git a/examples/scripts/denoise_dncnn_train_bsds.py b/examples/scripts/denoise_dncnn_train_bsds.py index a55df76d0..ac9fcb755 100644 --- a/examples/scripts/denoise_dncnn_train_bsds.py +++ b/examples/scripts/denoise_dncnn_train_bsds.py @@ -13,11 +13,15 @@ with additive Gaussian noise. """ +# isort: off import os from time import time import numpy as np +# Set an arbitrary processor count (only applies if GPU is not available). +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax from mpl_toolkits.axes_grid1 import make_axes_locatable @@ -26,11 +30,7 @@ from scico import metric, plot from scico.flax.examples import load_image_data -""" -Prepare parallel processing. Set an arbitrary processor count (only -applies if GPU is not available). -""" -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + platform = jax.lib.xla_bridge.get_backend().platform print("Platform: ", platform) diff --git a/requirements.txt b/requirements.txt index 1946b4295..cb78a3136 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,8 +4,8 @@ scipy>=1.6.0 imageio>=2.17 tifffile matplotlib -jaxlib>=0.4.3,<=0.4.28 -jax>=0.4.3,<=0.4.28 +jaxlib>=0.4.3,<=0.4.30 +jax>=0.4.3,<=0.4.30 orbax-checkpoint<=0.5.7 flax>=0.8.0,<=0.8.3 pyabel>=0.9.0 diff --git a/scico/flax/examples/__init__.py b/scico/flax/examples/__init__.py index 5a8d0d450..8ca7b1821 100644 --- a/scico/flax/examples/__init__.py +++ b/scico/flax/examples/__init__.py @@ -8,11 +8,11 @@ """Data utility functions used by Flax example scripts.""" from .data_preprocessing import PaddedCircularConvolve, build_blur_kernel -from .examples import load_ct_data, load_foam1_blur_data, load_image_data +from .examples import load_blur_data, load_ct_data, load_image_data __all__ = [ "load_ct_data", - "load_foam1_blur_data", + "load_blur_data", "load_image_data", "PaddedCircularConvolve", "build_blur_kernel", diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index a339435ec..ca7d3c073 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -7,33 +7,40 @@ """Functionality to generate training data for Flax example scripts. -Computation is distributed via ray (if available) or JAX or to reduce -processing time. +Computation is distributed via ray to reduce processing time. """ -import os -import warnings +from functools import partial from time import time from typing import Callable, List, Tuple, Union import numpy as np -try: - import ray # noqa: F401 -except ImportError: - have_ray = False -else: - have_ray = True - try: import xdesign # noqa: F401 except ImportError: have_xdesign = False + + # pylint: disable=missing-class-docstring + class UnitCircle: + pass + + # pylint: enable=missing-class-docstring else: have_xdesign = True + from xdesign import ( # type: ignore + Foam, + SimpleMaterial, + UnitCircle, + discrete_phantom, + ) -if have_xdesign: - from xdesign import Foam, SimpleMaterial, UnitCircle, discrete_phantom +try: + import ray # noqa: F401 +except ImportError: + have_ray = False +else: + have_ray = True import jax import jax.numpy as jnp @@ -47,62 +54,54 @@ have_astra = False else: have_astra = True - -if have_astra: from scico.linop.xray.astra import XRayTransform2D -# Arbitrary process count: only applies if GPU is not available. -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" - - -if have_xdesign: - - class Foam2(UnitCircle): - """Foam-like material with two attenuations. - - Define functionality to generate phantom with structure similar - to foam with two different attenuation properties.""" - - def __init__( - self, - size_range: Union[float, List[float]] = [0.05, 0.01], - gap: float = 0, - porosity: float = 1, - attn1: float = 1.0, - attn2: float = 10.0, - ): - """Foam-like structure with two different attenuations. - Circles for material 1 are more sparse than for material 2 - by design. - - Args: - size_range: The radius, or range of radius, of the - circles to be added. Default: [0.05, 0.01]. - gap: Minimum distance between circle boundaries. - Default: 0. - porosity: Target porosity. Must be a value between - [0, 1]. Default: 1. - attn1: Mass attenuation parameter for material 1. - Default: 1. - attn2: Mass attenuation parameter for material 2. - Default: 10. - """ - super(Foam2, self).__init__(radius=0.5, material=SimpleMaterial(attn1)) - if porosity < 0 or porosity > 1: - raise ValueError("Porosity must be in the range [0,1).") - self.sprinkle( - 300, size_range, gap, material=SimpleMaterial(attn2), max_density=porosity / 2.0 - ) + self.sprinkle( - 300, size_range, gap, material=SimpleMaterial(20), max_density=porosity - ) - - -def generate_foam2_images(seed: float, size: int, ndata: int) -> Array: - """Generate batch of foam2 structures. +class Foam2(UnitCircle): + """Foam-like material with two attenuations. + + Define functionality to generate phantom with structure similar + to foam with two different attenuation properties.""" + + def __init__( + self, + size_range: Union[float, List[float]] = [0.05, 0.01], + gap: float = 0, + porosity: float = 1, + attn1: float = 1.0, + attn2: float = 10.0, + ): + """Foam-like structure with two different attenuations. + Circles for material 1 are more sparse than for material 2 + by design. + + Args: + size_range: The radius, or range of radius, of the + circles to be added. Default: [0.05, 0.01]. + gap: Minimum distance between circle boundaries. + Default: 0. + porosity: Target porosity. Must be a value between + [0, 1]. Default: 1. + attn1: Mass attenuation parameter for material 1. + Default: 1. + attn2: Mass attenuation parameter for material 2. + Default: 10. + """ + if porosity < 0 or porosity > 1: + raise ValueError("Porosity must be in the range [0,1).") + super().__init__(radius=0.5, material=SimpleMaterial(attn1)) # type: ignore + self.sprinkle( # type: ignore + 300, size_range, gap, material=SimpleMaterial(attn2), max_density=porosity / 2.0 + ) + self.sprinkle( # type: ignore + 300, size_range, gap, material=SimpleMaterial(20), max_density=porosity + ) - Generate batch of images with :class:`Foam2` structure - (foam-like material with two different attenuations). + +def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray: + """Generate batch of xdesign foam-like structures. + + Generate batch of images with `xdesign` foam-like structure, which + uses one attenuation. Args: seed: Seed for data generation. @@ -114,23 +113,20 @@ def generate_foam2_images(seed: float, size: int, ndata: int) -> Array: """ if not have_xdesign: raise RuntimeError("Package xdesign is required for use of this function.") - - # np.random.seed(seed) - saux = jnp.zeros((ndata, size, size, 1)) + np.random.seed(seed) + saux = np.zeros((ndata, size, size, 1), dtype=np.float32) for i in range(ndata): - foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) - saux = saux.at[i, ..., 0].set(discrete_phantom(foam, size=size)) - # normalize - saux = saux / jnp.max(saux, axis=(1, 2), keepdims=True) + foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) + saux[i, ..., 0] = discrete_phantom(foam, size=size) return saux -def generate_foam1_images(seed: float, size: int, ndata: int) -> Array: - """Generate batch of xdesign foam-like structures. +def generate_foam2_images(seed: float, size: int, ndata: int) -> np.ndarray: + """Generate batch of foam2 structures. - Generate batch of images with `xdesign` foam-like structure, which - uses one attenuation. + Generate batch of images with :class:`Foam2` structure + (foam-like material with two different attenuations). Args: seed: Seed for data generation. @@ -142,12 +138,13 @@ def generate_foam1_images(seed: float, size: int, ndata: int) -> Array: """ if not have_xdesign: raise RuntimeError("Package xdesign is required for use of this function.") - - # np.random.seed(seed) - saux = jnp.zeros((ndata, size, size, 1)) + np.random.seed(seed) + saux = np.zeros((ndata, size, size, 1), dtype=np.float32) for i in range(ndata): - foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) - saux = saux.at[i, ..., 0].set(discrete_phantom(foam, size=size)) + foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1) + saux[i, ..., 0] = discrete_phantom(foam, size=size) + # normalize + saux /= np.max(saux, axis=(1, 2), keepdims=True) return saux @@ -180,7 +177,11 @@ def batched_f(f_: Callable, vr: Array) -> Array: evaluation preserves the batch axis. """ nproc = jax.device_count() - res = jax.pmap(lambda i: vector_f(f_, vr[i]))(jnp.arange(nproc)) + if vr.shape[0] != nproc: + vrr = vr.reshape((nproc, -1, *vr.shape[:1])) + else: + vrr = vr + res = jax.pmap(partial(vector_f, f_))(vrr) return res @@ -191,8 +192,7 @@ def generate_ct_data( imgfunc: Callable = generate_foam2_images, seed: int = 1234, verbose: bool = False, - prefer_ray: bool = True, -) -> Tuple[Array, ...]: +) -> Tuple[Array, Array, Array]: """Generate batch of computed tomography (CT) data. Generate batch of CT data for training of machine learning network @@ -205,9 +205,6 @@ def generate_ct_data( imgfunc: Function for generating input images (e.g. foams). seed: Seed for data generation. verbose: Flag indicating whether to print status messages. - Default: ``False``. - prefer_ray: Use ray for distributed processing if available. - Default: ``True``. Returns: tuple: A tuple (img, sino, fbp) containing: @@ -216,19 +213,16 @@ def generate_ct_data( - **sino** : (:class:`jax.Array`): Corresponding sinograms. - **fbp** : (:class:`jax.Array`) Corresponding filtered back projections. """ - if not have_astra: - raise RuntimeError("Package astra is required for use of this function.") + if not (have_ray and have_xdesign and have_astra): + raise RuntimeError( + "Packages ray, xdesign, and astra are required for use of this function." + ) # Generate input data. - if have_ray and prefer_ray: - start_time = time() - img = ray_distributed_data_generation(imgfunc, size, nimg, seed) - time_dtgen = time() - start_time - else: - start_time = time() - img = distributed_data_generation(imgfunc, size, nimg, False) - time_dtgen = time() - start_time - # Clip to [0,1] range. + start_time = time() + img = distributed_data_generation(imgfunc, size, nimg, seed) + time_dtgen = time() - start_time + # clip to [0,1] range img = jnp.clip(img, 0, 1) nproc = jax.device_count() @@ -236,13 +230,13 @@ def generate_ct_data( # Configure a CT projection operator to generate synthetic measurements. angles = np.linspace(0, jnp.pi, nproj) # evenly spaced projection angles gt_sh = (size, size) - detector_spacing = 1 + detector_spacing = 1.0 A = XRayTransform2D(gt_sh, size, detector_spacing, angles) # X-ray transform operator # Compute sinograms in parallel. start_time = time() if nproc > 1: - # Shard array + # shard array imgshd = img.reshape((nproc, -1, size, size, 1)) sinoshd = batched_f(A, imgshd) sino = sinoshd.reshape((-1, nproj, size, 1)) @@ -281,11 +275,10 @@ def generate_blur_data( size: int, blur_kernel: Array, noise_sigma: float, - imgfunc: Callable, + imgfunc: Callable = generate_foam1_images, seed: int = 4321, verbose: bool = False, - prefer_ray: bool = True, -) -> Tuple[Array, ...]: +) -> Tuple[Array, Array]: """Generate batch of blurred data. Generate batch of blurred data for training of machine learning @@ -299,9 +292,6 @@ def generate_blur_data( imgfunc: Function to generate foams. seed: Seed for data generation. verbose: Flag indicating whether to print status messages. - Default: ``False``. - prefer_ray: Use ray for distributed processing if available. - Default: ``True``. Returns: tuple: A tuple (img, blurn) containing: @@ -309,14 +299,11 @@ def generate_blur_data( - **img** : Generated foam images. - **blurn** : Corresponding blurred and noisy images. """ - if have_ray and prefer_ray: - start_time = time() - img = ray_distributed_data_generation(imgfunc, size, nimg, seed) - time_dtgen = time() - start_time - else: - start_time = time() - img = distributed_data_generation(imgfunc, size, nimg, False) - time_dtgen = time() - start_time + if not (have_ray and have_xdesign): + raise RuntimeError("Packages ray and xdesign are required for use of this function.") + start_time = time() + img = distributed_data_generation(imgfunc, size, nimg, seed) + time_dtgen = time() - start_time # Clip to [0,1] range. img = jnp.clip(img, 0, 1) @@ -356,76 +343,56 @@ def generate_blur_data( def distributed_data_generation( - imgenf: Callable, size: int, nimg: int, sharded: bool = True -) -> Array: - """Data generation distributed among processes using jax. - - Args: - imagenf: Function for batch-data generation. - size: Size of image to generate. - ndata: Number of images to generate. - sharded: Flag to indicate if data is to be returned as the - chunks generated by each process or consolidated. - Default: ``True``. - - Returns: - Array of generated data. - """ - nproc = jax.device_count() - seeds = jnp.arange(nproc) - if nproc > 1 and nimg % nproc > 0: - raise ValueError("Number of images to generate must be divisible by the number of devices") - - ndata_per_proc = int(nimg // nproc) - - idx = np.arange(nproc) - imgs = jax.vmap(imgenf, (0, None, None))(idx, size, ndata_per_proc) - - # imgs = jax.pmap(imgenf, static_broadcasted_argnums=(1, 2))(seeds, size, ndata_per_proc) - - if not sharded: - imgs = imgs.reshape((-1, size, size, 1)) - - return imgs - - -def ray_distributed_data_generation( imgenf: Callable, size: int, nimg: int, seedg: float = 123 -) -> Array: +) -> np.ndarray: """Data generation distributed among processes using ray. + *Warning:* callable `imgenf` should not make use of any jax functions + to avoid the risk of errors when running with GPU devices, in which + case jax is initialized to expect the availability of GPUs, which are + then not available within the `ray.remote` function due to the absence + of any declared GPUs as a `num_gpus` parameter of `@ray.remote`. + Args: imagenf: Function for batch-data generation. size: Size of image to generate. ndata: Number of images to generate. - seedg: Base seed for data generation. Default: 123. + seedg: Base seed for data generation. Returns: Array of generated data. """ if not have_ray: raise RuntimeError("Package ray is required for use of this function.") + if not ray.is_initialized(): + raise RuntimeError("Ray must be initialized via ray.init() before calling this function.") - @ray.remote - def data_gen(seed, size, ndata, imgf): - return imgf(seed, size, ndata) - - # Use half of available CPU resources. + # Use half of available CPU resources ar = ray.available_resources() - if "CPU" not in ar: - warnings.warn("No CPU key in ray.available_resources() output") - nproc = max(int(ar.get("CPU", "1")) // 2, 1) - # nproc = max(int(ar["CPU"]) // 2, 1) + nproc = max(int(ar.get("CPU", 1)) // 2, 1) + + # Attempt to avoid ray/jax conflicts. This solution is a nasty hack that + # can severely limit parallel execution (since ray will ensure that only + # as many actors as available GPUs are created), and is expected to be + # rather brittle. + if "GPU" in ar: + num_gpus = 1 + nproc = min(nproc, int(ar.get("GPU"))) + else: + num_gpus = 0 + if nproc > nimg: nproc = nimg - if nproc > 1 and nimg % nproc > 0: - raise ValueError( - f"Number of images to generate ({nimg}) " - f"must be divisible by the number of available devices ({nproc})" - ) + if nimg % nproc > 0: + # Increase nimg to be a multiple of nproc if it isn't already + nimg = (nimg // nproc + 1) * nproc ndata_per_proc = int(nimg // nproc) + @ray.remote(num_gpus=num_gpus) + def data_gen(seed, size, ndata, imgf): + return imgf(seed, size, ndata) + ray_return = ray.get( [data_gen.remote(seed + seedg, size, ndata_per_proc, imgenf) for seed in range(nproc)] ) diff --git a/scico/flax/examples/examples.py b/scico/flax/examples/examples.py index 0bf428a33..82c710aae 100644 --- a/scico/flax/examples/examples.py +++ b/scico/flax/examples/examples.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -16,7 +16,7 @@ from scico.numpy import Array from scico.typing import Shape -from .data_generation import generate_blur_data, generate_ct_data, generate_foam1_images +from .data_generation import generate_blur_data, generate_ct_data from .data_preprocessing import ConfigImageSetDict, build_image_dataset, get_bsds_data from .typed_dict import CTDataSetDict @@ -49,7 +49,6 @@ def load_ct_data( nproj: int, cache_path: Optional[str] = None, verbose: bool = False, - prefer_ray: bool = True, ) -> Tuple[CTDataSetDict, ...]: # pragma: no cover """ Load or generate CT data. @@ -77,8 +76,6 @@ def load_ct_data( Default: ``None``. verbose: Flag indicating whether to print status messages. Default: ``False``. - prefer_ray: Use ray for distributed processing if available. - Default: ``True``. Returns: tuple: A tuple (trdt, ttdt) containing: @@ -146,7 +143,6 @@ def load_ct_data( size, nproj, verbose=verbose, - prefer_ray=prefer_ray, ) # Separate training and testing partitions. trdt = {"img": img[:train_nimg], "sino": sino[:train_nimg], "fbp": fbp[:train_nimg]} @@ -178,7 +174,7 @@ def load_ct_data( return trdt, ttdt -def load_foam1_blur_data( +def load_blur_data( train_nimg: int, test_nimg: int, size: int, @@ -186,7 +182,6 @@ def load_foam1_blur_data( noise_sigma: float, cache_path: Optional[str] = None, verbose: bool = False, - prefer_ray: bool = True, ) -> Tuple[DataSetDict, ...]: # pragma: no cover """Load or generate blurred data based on xdesign foam structures. @@ -214,8 +209,6 @@ def load_foam1_blur_data( Default: ``None``. verbose: Flag indicating whether to print status messages. Default: ``False``. - prefer_ray: Use ray for distributed processing if available. - Default: ``True``. Returns: tuple: A tuple (train_ds, test_ds) containing: @@ -295,9 +288,7 @@ def load_foam1_blur_data( size, blur_kernel, noise_sigma, - imgfunc=generate_foam1_images, verbose=verbose, - prefer_ray=prefer_ray, ) # Separate training and testing partitions. train_ds = {"image": blrn[:train_nimg], "label": img[:train_nimg]} @@ -588,7 +579,7 @@ def print_input_path(path_display: str): # pragma: no cover Args: path_display: Path for loading data. """ - print(f"{'Data read from path':26s}{':':4s}{path_display}") + print(f"Data read from path: {path_display}") def print_output_path(path_display: str): # pragma: no cover @@ -597,7 +588,7 @@ def print_output_path(path_display: str): # pragma: no cover Args: path_display: Path for storing data. """ - print(f"{'Storing data in path':26s}{':':4s}{path_display}") + print(f"Storing data in path: {path_display}") def print_data_range(idstring: str, data: Array): # pragma: no cover @@ -607,11 +598,7 @@ def print_data_range(idstring: str, data: Array): # pragma: no cover idstring: Data descriptive string. data: Array to compute min and max. """ - print( - f"{'Data range --':10s}{idstring}{'--':5s}{':':5s}" - f"{'Min:':6s}{data.min():>5.2f}" - f"{', Max:':6s}{data.max():>5.2f}" - ) + print(f"Data range --{idstring}-- Min: {data.min():>5.2f} " f"Max: {data.max():>5.2f}") def print_data_size(idstring: str, size: int): # pragma: no cover @@ -621,7 +608,7 @@ def print_data_size(idstring: str, size: int): # pragma: no cover idstring: Data descriptive string. size: Integer representing size of a set. """ - print(f"{'Set --':3s}{idstring}{'--':12s}{':':4s}{'Size:':8s}{size}") + print(f"Set --{idstring}-- size: {size}") def print_info( @@ -656,9 +643,8 @@ def print_data_warning(idstring: str, requested: int, available: int): # pragma available: Size of data set available. """ print( - f"{'Not enough images sampled in ':10s}{idstring}" - f"{' file':6s}{'Requested :':14s}{requested}" - f"{' Available :':14s}{available}" + f"Not enough images sampled in {idstring} file. " + f"Requested: {requested} Available: {available}" ) @@ -677,10 +663,9 @@ def runtime_error_scalar( available: Parameter value available in data. """ raise RuntimeError( - f"{'Requested parameter --':15s}{type}{'-- :':7s}{requested}" - f"{' does not match parameter read from '}" - f"{idstring}{' file :':10s}{available}." - f"\nDelete cache and check data source." + f"Requested value of parameter --{type}-- does not match value " + f"read from {idstring} file. Requested: {requested} Available: " + f"{available}.\nDelete cache and check data source." ) @@ -697,8 +682,7 @@ def runtime_error_array(type: str, idstring: str, maxdiff: float): entries. """ raise RuntimeError( - f"{'Requested parameter --':15s}{type}{'--'}" - f"{' does not match parameter read from '}" - f"{idstring}{' file'}. Maximum array difference: {maxdiff:>5.3f}." - f"\nDelete cache and check data source." + f"Requested value of parameter --{type}-- does not match value " + f"read from {idstring} file. Maximum array difference: " + f"{maxdiff:>5.3f}.\nDelete cache and check data source." ) diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index ee720c7ca..645f5ca49 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -97,16 +97,16 @@ def __init__( self.nx = np.array(im_shape) if dx is None: - dx = np.full((2,), np.sqrt(2) / 2) + dx = np.full(2, np.sqrt(2) / 2) if is_scalar_equiv(dx): - dx = dx * np.ones(2) + dx = np.full(2, dx) self.dx = dx # check projected pixel width assumption Pdx = np.stack((dx[0] * jnp.cos(angles), dx[1] * jnp.sin(angles))) Pdiag1 = np.abs(Pdx[0] + Pdx[1]) Pdiag2 = np.abs(Pdx[0] - Pdx[1]) - max_width = np.max(np.maximum(Pdiag1, Pdiag2)) + max_width: float = np.max(np.maximum(Pdiag1, Pdiag2)) if max_width > 1: warn( @@ -134,7 +134,7 @@ def project(self, im): def back_project(self, y): """Compute X-ray back projection""" - return _back_project(y, self.x0, self.dx, tuple(self.nx), self.y0, self.angles) + return _back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) @partial(jax.jit, static_argnames=["ny"]) @@ -195,7 +195,7 @@ def _back_project(y, x0, dx, nx, y0, angles): return HTy -@partial(jax.jit, static_argnames=["nx", "y0"]) +@partial(jax.jit, static_argnames=["nx"]) @partial(jax.vmap, in_axes=(None, None, None, 0, None)) def _calc_weights(x0, dx, nx, angle, y0): """ diff --git a/scico/test/flax/test_examples_flax.py b/scico/test/flax/test_examples_flax.py index aafa2359b..72c084dd2 100644 --- a/scico/test/flax/test_examples_flax.py +++ b/scico/test/flax/test_examples_flax.py @@ -3,8 +3,6 @@ import numpy as np -import jax - import pytest from scico import random @@ -12,10 +10,11 @@ distributed_data_generation, generate_blur_data, generate_ct_data, + generate_foam1_images, + generate_foam2_images, have_astra, have_ray, have_xdesign, - ray_distributed_data_generation, ) from scico.flax.examples.data_preprocessing import ( CenterCrop, @@ -41,84 +40,53 @@ @pytest.mark.skipif(not have_xdesign, reason="xdesign package not installed") -def test_foam2_gen(): - seed = 4321 +def test_foam1_gen(): + seed = 4444 N = 32 ndata = 2 - from scico.flax.examples.data_generation import generate_foam2_images - dt = generate_foam2_images(seed, N, ndata) + dt = generate_foam1_images(seed, N, ndata) assert dt.shape == (ndata, N, N, 1) @pytest.mark.skipif(not have_xdesign, reason="xdesign package not installed") -def test_foam_gen(): - seed = 4444 +def test_foam2_gen(): + seed = 4321 N = 32 ndata = 2 - from scico.flax.examples.data_generation import generate_foam1_images - dt = generate_foam1_images(seed, N, ndata) + dt = generate_foam2_images(seed, N, ndata) assert dt.shape == (ndata, N, N, 1) -def fake_data_gen(seed, N, ndata): - dt, key = random.randn((ndata, N, N, 1), seed=seed) - return dt - - -def test_distdatagen(): - N = 16 - nimg = 8 - dt = distributed_data_generation(fake_data_gen, N, nimg) - assert dt.ndim == 5 - assert dt.shape[0] * dt.shape[1] == nimg - assert dt.shape[2:] == (N, N, 1) - - -def test_distdatagen_flag(): - N = 16 - nimg = 8 - dt = distributed_data_generation(fake_data_gen, N, nimg, False) - assert dt.ndim == 4 - assert dt.shape == (nimg, N, N, 1) - - -@pytest.mark.skipif( - jax.device_count() == 1, reason="no processes for checking failure of distributed computing" -) -def test_distdatagen_exception(): - N = 16 - nimg = 15 - with pytest.raises(ValueError): - distributed_data_generation(fake_data_gen, N, nimg) - - @pytest.mark.skipif(not have_ray, reason="ray package not installed") -def test_ray_distdatagen(): +def test_distdatagen(): N = 16 nimg = 8 def random_data_gen(seed, N, ndata): - dt, key = random.randn((ndata, N, N, 1), seed=seed) + np.random.seed(seed) + dt = np.random.randn(ndata, N, N, 1) return dt - dt = ray_distributed_data_generation(random_data_gen, N, nimg) + dt = distributed_data_generation(random_data_gen, N, nimg) assert dt.ndim == 4 assert dt.shape == (nimg, N, N, 1) -@pytest.mark.skipif(not have_astra, reason="astra package not installed") +@pytest.mark.skipif( + not have_astra or not have_ray or not have_xdesign, + reason="astra, ray, or xdesign package not installed", +) def test_ct_data_generation(): N = 32 nimg = 8 nproj = 45 def random_img_gen(seed, size, ndata): - key = jax.random.PRNGKey(seed) - key, subkey = jax.random.split(key) + np.random.seed(seed) shape = (ndata, size, size, 1) - return jax.random.normal(subkey, shape) + return np.random.randn(*shape) img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen) assert img.shape == (nimg, N, N, 1) @@ -126,24 +94,7 @@ def random_img_gen(seed, size, ndata): assert fbp.shape == (nimg, N, N, 1) -@pytest.mark.skipif(not have_astra, reason="astra package not installed") -def test_ct_data_generation_jax(): - N = 32 - nimg = 8 - nproj = 45 - - def random_img_gen(seed, size, ndata): - key = jax.random.PRNGKey(seed) - key, subkey = jax.random.split(key) - shape = (ndata, size, size, 1) - return jax.random.normal(subkey, shape) - - img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen, prefer_ray=False) - assert img.shape == (nimg, N, N, 1) - assert sino.shape == (nimg, nproj, N, 1) - assert fbp.shape == (nimg, N, N, 1) - - +@pytest.mark.skipif(not have_ray or not have_xdesign, reason="ray or xdesign package not installed") def test_blur_data_generation(): N = 32 nimg = 8 @@ -152,28 +103,10 @@ def test_blur_data_generation(): def random_img_gen(seed, size, ndata): np.random.seed(seed) - return np.random.randn(ndata, size, size, 1) - - img, blurn = generate_blur_data(nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen) - assert img.shape == (nimg, N, N, 1) - assert blurn.shape == (nimg, N, N, 1) - - -def test_blur_data_generation_jax(): - N = 32 - nimg = 8 - n = 3 # convolution kernel size - blur_kernel = np.ones((n, n)) / (n * n) - - def random_img_gen(seed, size, ndata): - key = jax.random.PRNGKey(seed) - key, subkey = jax.random.split(key) shape = (ndata, size, size, 1) - return jax.random.normal(subkey, shape) + return np.random.randn(*shape) - img, blurn = generate_blur_data( - nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen, prefer_ray=False - ) + img, blurn = generate_blur_data(nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen) assert img.shape == (nimg, N, N, 1) assert blurn.shape == (nimg, N, N, 1)