diff --git a/benchmark/dev_test.py b/benchmark/dev_test.py new file mode 100644 index 00000000000..71348ae6cfd --- /dev/null +++ b/benchmark/dev_test.py @@ -0,0 +1,71 @@ +import torch +import numpy as np + +import odl + +from odl.contrib.torch.new_operator import OperatorModule + +import matplotlib.pyplot as plt + +if __name__ == '__main__': + device_name = 'cuda:0' + ### Define input tensor + dimension = 3 + n_points = 64 + space = odl.uniform_discr( + [-20 for _ in range(dimension)], + [ 20 for _ in range(dimension)], + [n_points for _ in range(dimension)], + impl='pytorch', torch_device=device_name + ) + + odl_phantom = odl.phantom.shepp_logan(space, modified=True) + phantom : torch.Tensor = odl_phantom.asarray().unsqueeze(0).unsqueeze(0).to(device_name) + plt.matshow(phantom[0,0,32].detach().cpu()) + plt.savefig('phantom') + plt.close() + + # enforce float32 conversion, rather than float64 + phantom = phantom.to(dtype=torch.float32) + # make tensor contiguous from creation + phantom = phantom.contiguous() + # for the example, input_tensor.requires_grad == True + phantom.requires_grad_() + # Make a 3d single-axis parallel beam geometry with flat detector + # Angles: uniformly spaced, n = 180, min = 0, max = pi + angle_partition = odl.uniform_partition(0, 2 * np.pi, 32) + detector_partition = odl.uniform_partition([-30] * 2, [30] * 2, [100] * 2) + geometry = odl.tomo.Parallel3dAxisGeometry(angle_partition, detector_partition) + + # Ray transform (= forward projection). + ray_trafo = odl.tomo.RayTransform(space, geometry, impl='astra_cuda_pytorch') + + forward_module = OperatorModule(ray_trafo) + backward_module = OperatorModule(ray_trafo.adjoint) + sinogram :torch.Tensor = forward_module(phantom) #type:ignore + + x = torch.zeros( + size = phantom.size(), + device = device_name, + requires_grad=True + ) + + optimiser = torch.optim.Adam( #type:ignore + [x], + lr = 1e-3 + ) + + noisy_data = forward_module(phantom) + mse_loss =torch.nn.MSELoss() + + for _ in range(100): + optimiser.zero_grad() + current_data = forward_module(x) + loss = mse_loss(current_data, noisy_data) + loss.mean().backward() + optimiser.step() + + plt.matshow(x[0,0,32].detach().cpu()) + plt.savefig('optimised') + plt.close() + diff --git a/benchmark/main.py b/benchmark/main.py new file mode 100644 index 00000000000..f4de8f9b6b3 --- /dev/null +++ b/benchmark/main.py @@ -0,0 +1,95 @@ +"""This module benchmarks a function with parameters defined in a json file metadata""" +import argparse +from pathlib import Path +from datetime import datetime +import time +import sys + +import json +import pandas as pd + +N_CALLS = 1 +MAX_ITERATIONS = 100 + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--metadata_name', required = True) + args = parser.parse_args() + + metadata_name = args.metadata_name + + ### unpack benchmark metadata + try: + with open(f'metadata/{metadata_name}.json', mode ='r', encoding="utf-8") as json_file: + metadata_dict = json.load(json_file) + except FileNotFoundError: + sys.exit(f'No file at metadata/{metadata_name}.json') + + ### unpack variables + benchmark_dict = {} + for key in ['backend', 'script_name', 'parameters']: + try: + benchmark_dict[key] = metadata_dict[key] + except ValueError: + sys.exit(f'No "{key}" key in the metadata_dict') + + ### load backend module + if benchmark_dict['backend'] == 'odl': + import scripts.odl_scripts as sc + DEVICE = 'cpu' + + elif benchmark_dict['backend'] == 'torch': + import scripts.torch_scripts as sc + try: + DEVICE = metadata_dict['parameters']['device_name'] + except ValueError: + DEVICE = 'cpu' + + else: + raise NotImplementedError(f'''Backend {benchmark_dict["backend"]} not supported, only + "odl" and "torch"''') + + try: + function = getattr(sc, benchmark_dict['script_name']) + except AttributeError: + sys.exit(f'''Script {benchmark_dict["script_name"]} not implemented for backend + {benchmark_dict["backend"]}''') + + + report_dict = { + "dimension" : [], + "n_points" : [], + "time" : [], + "error" : [] + } + + for dimension in benchmark_dict["parameters"]['dimensions']: + for n_points in benchmark_dict["parameters"]['n_points']: + print( + f"""Benchmarking {benchmark_dict['script_name']} + for dimension {dimension} and {n_points} points""" + ) + for call in range(N_CALLS): + start = time.time() + error = function( + benchmark_dict["parameters"], + dimension, n_points, + MAX_ITERATIONS + ) + end = time.time() + report_dict['dimension'].append(dimension) + report_dict['n_points'].append(n_points) + report_dict['time'].append(end - start) + report_dict['error'].append(error) + + report_df = pd.DataFrame.from_dict(report_dict) + report_df['device'] = DEVICE + report_df['backend'] = benchmark_dict['backend'] + report_df['max_iterations'] = MAX_ITERATIONS + report_df['timestamp'] = pd.Timestamp(datetime.now(), tz=None) + result_file_path = f'results/{metadata_name}.csv' + if Path(result_file_path).is_file(): + report_df = pd.concat([ + pd.read_csv(result_file_path), report_df + ]) + report_df.to_csv(f'results/{metadata_name}.csv', index = False) diff --git a/benchmark/metadata/mri_mlem_odl_adam.json b/benchmark/metadata/mri_mlem_odl_adam.json new file mode 100644 index 00000000000..e1293758e4f --- /dev/null +++ b/benchmark/metadata/mri_mlem_odl_adam.json @@ -0,0 +1,13 @@ +{ + "backend":"odl", + "script_name":"mri_mlem_adam", + "parameters":{ + "n_points" : [512], + "dimensions" : [2], + "subsampling" : 0.5, + "learning_rate": 0.001, + "beta1": 0.9, + "beta2": 0.999, + "eps" : 1e-8 + } +} \ No newline at end of file diff --git a/benchmark/metadata/mri_mlem_torch_adam.json b/benchmark/metadata/mri_mlem_torch_adam.json new file mode 100644 index 00000000000..80598cc47b1 --- /dev/null +++ b/benchmark/metadata/mri_mlem_torch_adam.json @@ -0,0 +1,14 @@ +{ + "backend":"torch", + "script_name":"mri_mlem_adam", + "parameters":{ + "n_points" : [512], + "dimensions" : [2], + "subsampling" : 0.5, + "device_name":"cuda:0", + "learning_rate": 0.001, + "beta1": 0.9, + "beta2": 0.999, + "eps" : 1e-8 + } +} \ No newline at end of file diff --git a/benchmark/ray_trafo_test.py b/benchmark/ray_trafo_test.py new file mode 100644 index 00000000000..e512a1b44e3 --- /dev/null +++ b/benchmark/ray_trafo_test.py @@ -0,0 +1,832 @@ +# Copyright 2014-2019 The ODL contributors +# +# This file is part of ODL. +# +# This Source Code Form is subject to the terms of the Mozilla Public License, +# v. 2.0. If a copy of the MPL was not distributed with this file, You can +# obtain one at https://mozilla.org/MPL/2.0/. + +"""Tests for the Ray transform.""" + +from __future__ import division + +import numpy as np +import pytest +from packaging.version import parse as parse_version +from functools import partial + +import odl +from odl.tomo.backends import ASTRA_VERSION +from odl.tomo.util.testutils import ( + skip_if_no_astra, skip_if_no_astra_cuda, skip_if_no_skimage) +from odl.util.testutils import all_almost_equal, simple_fixture + +# --- pytest fixtures --- # + + +ray_trafo_impl = simple_fixture( + name='ray_trafo_impl', + params=[ + pytest.param('astra_cuda_link', marks=skip_if_no_astra), + # pytest.param('astra_cpu', marks=skip_if_no_astra), + # pytest.param('astra_cuda', marks=skip_if_no_astra_cuda), + # pytest.param('skimage', marks=skip_if_no_skimage) + ] +) + +reco_space_impl = simple_fixture( + name='reco_space_impl', + params=[ + pytest.param('numpy'), + pytest.param('pytorch'), + ] +) + +geometry_params = [ + # 'par2d', + 'par3d', + # 'cone2d', + 'cone3d', + 'helical' + ] +geometry_ids = [" geometry='{}' ".format(p) for p in geometry_params] + + +@pytest.fixture(scope='module', ids=geometry_ids, params=geometry_params) +def geometry(request): + geom = request.param + m = 100 + n_angles = 100 + + if geom == 'par2d': + apart = odl.uniform_partition(0, np.pi, n_angles) + dpart = odl.uniform_partition(-30, 30, m) + return odl.tomo.Parallel2dGeometry(apart, dpart) + elif geom == 'par3d': + apart = odl.uniform_partition(0, np.pi, n_angles) + dpart = odl.uniform_partition([-30, -30], [30, 30], (m, m)) + return odl.tomo.Parallel3dAxisGeometry(apart, dpart) + elif geom == 'cone2d': + apart = odl.uniform_partition(0, 2 * np.pi, n_angles) + dpart = odl.uniform_partition(-30, 30, m) + return odl.tomo.FanBeamGeometry(apart, dpart, src_radius=200, + det_radius=100) + elif geom == 'cone3d': + apart = odl.uniform_partition(0, 2 * np.pi, n_angles) + dpart = odl.uniform_partition([-60, -60], [60, 60], (m, m)) + return odl.tomo.ConeBeamGeometry(apart, dpart, + src_radius=200, det_radius=100) + elif geom == 'helical': + apart = odl.uniform_partition(0, 8 * 2 * np.pi, n_angles) + dpart = odl.uniform_partition([-30, -3], [30, 3], (m, m)) + return odl.tomo.ConeBeamGeometry(apart, dpart, pitch=5, + src_radius=200, det_radius=100) + else: + raise ValueError('geom not valid') + + +geometry_type = simple_fixture( + 'geometry_type', + [ + # 'par2d', + 'par3d', + # 'cone2d', + 'cone3d' + ] +) + +projectors = [] +projectors.extend( + (pytest.param(proj_cfg, marks=skip_if_no_astra) + for proj_cfg in [ + # 'par2d astra_cpu numpy uniform', + # 'par2d astra_cpu numpy nonuniform', + # 'par2d astra_cpu numpy random', + # 'cone2d astra_cpu numpy uniform', + # 'cone2d astra_cpu numpy nonuniform', + # 'cone2d astra_cpu numpy random' + ]) +) +projectors.extend( + (pytest.param(proj_cfg, marks=skip_if_no_astra_cuda) + for proj_cfg in [ + 'par2d astra_cuda numpy uniform', + 'par2d astra_cuda numpy half_uniform', + 'par2d astra_cuda numpy nonuniform', # eroor + 'par2d astra_cuda numpy random', + 'cone2d astra_cuda numpy uniform', + 'cone2d astra_cuda numpy nonuniform', + 'cone2d astra_cuda numpy random', + 'par3d astra_cuda numpy uniform', + 'par3d astra_cuda numpy nonuniform', # fails + 'par3d astra_cuda numpy random', + 'cone3d astra_cuda numpy uniform', # fails + 'cone3d astra_cuda numpy nonuniform', # fails + 'cone3d astra_cuda numpy random', + 'helical astra_cuda numpy uniform', + + 'par2d astra_cuda pytorch uniform', + 'par2d astra_cuda pytorch half_uniform', + 'par2d astra_cuda pytorch nonuniform', + 'par2d astra_cuda pytorch random', + 'cone2d astra_cuda pytorch uniform', + 'cone2d astra_cuda pytorch nonuniform', + 'cone2d astra_cuda pytorch random', + 'par3d astra_cuda pytorch uniform', + 'par3d astra_cuda pytorch nonuniform', + 'par3d astra_cuda pytorch random', + 'cone3d astra_cuda pytorch uniform', + 'cone3d astra_cuda pytorch nonuniform', + 'cone3d astra_cuda pytorch random', + 'helical astra_cuda pytorch uniform' + + # Only 3D tests so far + 'par3d astra_cuda_pytorch pytorch uniform', + 'par3d astra_cuda_pytorch pytorch nonuniform', + 'par3d astra_cuda_pytorch pytorch random', + 'cone3d astra_cuda_pytorch pytorch uniform', + 'cone3d astra_cuda_pytorch pytorch nonuniform', + 'cone3d astra_cuda_pytorch pytorch random', + 'helical astra_cuda_pytorch pytorch uniform', + + 'par3d astra_cuda_pytorch numpy uniform', + 'par3d astra_cuda_pytorch numpy nonuniform', + 'par3d astra_cuda_pytorch numpy random', + 'cone3d astra_cuda_pytorch numpy uniform', + 'cone3d astra_cuda_pytorch numpy nonuniform', + 'cone3d astra_cuda_pytorch numpy random', + 'helical astra_cuda_pytorch numpy uniform' + ]) +) +projectors.extend( + (pytest.param(proj_cfg, marks=skip_if_no_skimage) + for proj_cfg in ['par2d skimage numpy uniform', + 'par2d skimage numpy half_uniform']) +) + +projector_ids = [ + " geom='{}' - ray_trafo_impl='{}' - reco_space_impl='{}' - angles='{}' ".format(*p.values[0].split()) + for p in projectors +] + + +@pytest.fixture(scope='module', params=projectors, ids=projector_ids) +def projector(request): + n = 100 + m = 100 + n_angles = 100 + dtype = 'float32' + + geom, ray_trafo_impl, reco_space_impl, angle = request.param.split() + + if angle == 'uniform': + apart = odl.uniform_partition(0, 2 * np.pi, n_angles) + elif angle == 'half_uniform': + apart = odl.uniform_partition(0, np.pi, n_angles) + elif angle == 'random': + # Linearly spaced with random noise + min_pt = 2 * (2.0 * np.pi) / n_angles + max_pt = (2.0 * np.pi) - 2 * (2.0 * np.pi) / n_angles + points = np.linspace(min_pt, max_pt, n_angles) + points += np.random.rand(n_angles) * (max_pt - min_pt) / (5 * n_angles) + apart = odl.nonuniform_partition(points) + elif angle == 'nonuniform': + # Angles spaced quadratically + min_pt = 2 * (2.0 * np.pi) / n_angles + max_pt = (2.0 * np.pi) - 2 * (2.0 * np.pi) / n_angles + points = np.linspace(min_pt ** 0.5, max_pt ** 0.5, n_angles) ** 2 + apart = odl.nonuniform_partition(points) + else: + raise ValueError('angle not valid') + + if geom == 'par2d': + # Reconstruction space + reco_space = odl.uniform_discr([-20] * 2, [20] * 2, [n] * 2, + dtype=dtype, impl=reco_space_impl) + # Geometry + dpart = odl.uniform_partition(-30, 30, m) + geom = odl.tomo.Parallel2dGeometry(apart, dpart) + # Ray transform + return odl.tomo.RayTransform(reco_space, geom, impl=ray_trafo_impl) + + elif geom == 'par3d': + # Reconstruction space + reco_space = odl.uniform_discr([-20] * 3, [20] * 3, [n] * 3, + dtype=dtype, impl=reco_space_impl) + + # Geometry + dpart = odl.uniform_partition([-30] * 2, [30] * 2, [m] * 2) + geom = odl.tomo.Parallel3dAxisGeometry(apart, dpart) + # Ray transform + return odl.tomo.RayTransform(reco_space, geom, impl=ray_trafo_impl) + + elif geom == 'cone2d': + # Reconstruction space + reco_space = odl.uniform_discr([-20] * 2, [20] * 2, [n] * 2, + dtype=dtype, impl=reco_space_impl) + # Geometry + dpart = odl.uniform_partition(-30, 30, m) + geom = odl.tomo.FanBeamGeometry(apart, dpart, src_radius=200, + det_radius=100) + # Ray transform + return odl.tomo.RayTransform(reco_space, geom, impl=ray_trafo_impl) + + elif geom == 'cone3d': + # Reconstruction space + reco_space = odl.uniform_discr([-20] * 3, [20] * 3, [n] * 3, + dtype=dtype, impl=reco_space_impl) + # Geometry + dpart = odl.uniform_partition([-60] * 2, [60] * 2, [m] * 2) + geom = odl.tomo.ConeBeamGeometry(apart, dpart, + src_radius=200, det_radius=100) + # Ray transform + return odl.tomo.RayTransform(reco_space, geom, impl=ray_trafo_impl) + + elif geom == 'helical': + # Reconstruction space + reco_space = odl.uniform_discr([-20, -20, 0], [20, 20, 40], + [n] * 3, dtype=dtype, impl=reco_space_impl) + # Geometry, overwriting angle partition + apart = odl.uniform_partition(0, 8 * 2 * np.pi, n_angles) + dpart = odl.uniform_partition([-30, -3], [30, 3], [m] * 2) + geom = odl.tomo.ConeBeamGeometry(apart, dpart, pitch=5, + src_radius=200, det_radius=100) + # Ray transform + return odl.tomo.RayTransform(reco_space, geom, impl=ray_trafo_impl) + else: + raise ValueError('geom not valid') + + +@pytest.fixture(scope='module', + params=[ + True, + False + ], + ids=[ + ' in-place ', + ' out-of-place ' + ]) +def in_place(request): + return request.param + + +# --- RayTransform tests --- # + + +def test_projector(projector, in_place): + """Test Ray transform forward projection.""" + # TODO: this needs to be improved + # Accept 10% errors + rtol = 1e-1 + + # Create Shepp-Logan phantom + vol = projector.domain.one() + + # Calculate projection + if in_place: + proj = projector.range.zero() + projector(vol, out=proj) + else: + proj = projector(vol) + + # We expect maximum value to be along diagonal + expected_max = projector.domain.partition.extent[0] * np.sqrt(2) + assert proj.max() == pytest.approx(expected_max, rel=rtol) + + +def test_adjoint(projector): + """Test Ray transform backward projection.""" + # Relative tolerance, still rather high due to imperfectly matched + # adjoint in the cone beam case + if ( + parse_version(ASTRA_VERSION) < parse_version('1.8rc1') + and isinstance(projector.geometry, odl.tomo.ConeBeamGeometry) + ): + rtol = 0.1 + else: + rtol = 0.05 + + # Create Shepp-Logan phantom + vol = odl.phantom.shepp_logan(projector.domain, modified=True) + + # Calculate projection + proj = projector(vol) + backproj = projector.adjoint(proj) + + # Verified the identity = + result_AxAx = proj.inner(proj) + result_xAtAx = backproj.inner(vol) + assert result_AxAx == pytest.approx(result_xAtAx, rel=rtol) + + +def test_adjoint_of_adjoint(projector): + """Test Ray transform adjoint of adjoint.""" + + # Create Shepp-Logan phantom + vol = odl.phantom.shepp_logan(projector.domain, modified=True) + + # Calculate projection + proj = projector(vol) + proj_adj_adj = projector.adjoint.adjoint(vol) + + # Verify A(x) == (A^*)^*(x) + assert all_almost_equal(proj, proj_adj_adj) + + # Calculate adjoints + proj_adj = projector.adjoint(proj) + proj_adj_adj_adj = projector.adjoint.adjoint.adjoint(proj) + + # Verify A^*(y) == ((A^*)^*)^*(x) + assert all_almost_equal(proj_adj, proj_adj_adj_adj) + + +def test_angles(projector): + """Test Ray transform angle conventions.""" + + # Smoothed line/hyperplane with offset + vol = projector.domain.element( + lambda x: np.exp(-(2 * x[0] - 10 + x[1]) ** 2)) + + # Create projection + result = projector(vol).asarray() + + # Find the angle where the projection has a maximum (along the line). + # TODO: center of mass would be more robust + axes = 1 if projector.domain.ndim == 2 else (1, 2) + ind_angle = np.argmax(np.max(result, axis=axes)) + # Restrict to [0, 2 * pi) for helical + maximum_angle = np.fmod(projector.geometry.angles[ind_angle], 2 * np.pi) + + # Verify correct maximum angle. The line is defined by the equation + # x1 = 10 - 2 * x0, i.e. the slope -2. Thus the angle arctan(1/2) should + # give the maximum projection values. + expected = np.arctan2(1, 2) + assert np.fmod(maximum_angle, np.pi) == pytest.approx(expected, abs=0.1) + + # Find the pixel where the projection has a maximum at that angle + axes = () if projector.domain.ndim == 2 else 1 + ind_pixel = np.argmax(np.max(result[ind_angle], axis=axes)) + max_pixel = projector.geometry.det_partition[ind_pixel, ...].mid_pt[0] + + # The line is at distance 2 * sqrt(5) from the origin, which translates + # to the same distance from the detector midpoint, with positive sign + # if the angle is smaller than pi and negative sign otherwise. + expected = 2 * np.sqrt(5) if maximum_angle < np.pi else -2 * np.sqrt(5) + + # We need to scale with the magnification factor if applicable + if isinstance(projector.geometry, odl.tomo.DivergentBeamGeometry): + src_to_det = ( + projector.geometry.src_radius #type:ignore + + projector.geometry.det_radius #type:ignore + ) + magnification = src_to_det / projector.geometry.src_radius #type:ignore + expected *= magnification + + assert max_pixel == pytest.approx(expected, abs=0.2) + + +def test_complex(ray_trafo_impl): + """Test transform of complex input for parallel 2d geometry.""" + space_c = odl.uniform_discr([-1, -1], [1, 1], (10, 10), dtype='complex64') + space_r = space_c.real_space + geom = odl.tomo.parallel_beam_geometry(space_c) + ray_trafo_c = odl.tomo.RayTransform(space_c, geom, impl=ray_trafo_impl) + ray_trafo_r = odl.tomo.RayTransform(space_r, geom, impl=ray_trafo_impl) + vol = odl.phantom.shepp_logan(space_c) + vol.imag = odl.phantom.cuboid(space_r) + + data = ray_trafo_c(vol) + true_data_re = ray_trafo_r(vol.real) + true_data_im = ray_trafo_r(vol.imag) + + assert all_almost_equal(data.real, true_data_re) + assert all_almost_equal(data.imag, true_data_im) + + # test adjoint for complex data + backproj_r = ray_trafo_r.adjoint + backproj_c = ray_trafo_c.adjoint + true_vol_re = backproj_r(data.real) + true_vol_im = backproj_r(data.imag) + backproj_vol = backproj_c(data) + + assert all_almost_equal(backproj_vol.real, true_vol_re) + assert all_almost_equal(backproj_vol.imag, true_vol_im) + + +def test_anisotropic_voxels(geometry): + """Test projection and backprojection with anisotropic voxels.""" + ndim = geometry.ndim + shape = [10] * (ndim - 1) + [5] + space = odl.uniform_discr([-1] * ndim, [1] * ndim, shape=shape, + dtype='float32') + + # If no implementation is available, skip + if ndim == 2 and not odl.tomo.ASTRA_AVAILABLE: + pytest.skip(msg='ASTRA not available, skipping 2d test') #type:ignore + elif ndim == 3 and not odl.tomo.ASTRA_CUDA_AVAILABLE: + pytest.skip(msg='ASTRA_CUDA not available, skipping 3d test') #type:ignore + + ray_trafo = odl.tomo.RayTransform(space, geometry) + vol_one = ray_trafo.domain.one() #type:ignore + data_one = ray_trafo.range.one() #type:ignore + + if ndim == 2: + # Should raise + with pytest.raises(NotImplementedError): + ray_trafo(vol_one) + with pytest.raises(NotImplementedError): + ray_trafo.adjoint(data_one) + elif ndim == 3: + # Just check that this doesn't crash and computes something nonzero + data = ray_trafo(vol_one) + backproj = ray_trafo.adjoint(data_one) + assert data.norm() > 0 + assert backproj.norm() > 0 + else: + assert False + + +def test_shifted_volume(geometry_type): + """Check that geometry shifts are handled correctly. + + We forward project a square/cube of all ones and check that the + correct portion of the detector gets nonzero values. In the default + setup, at angle 0, the source (if existing) is at (0, -s[, 0]), and + the detector at (0, +d[, 0]) with the positive x axis as (first) + detector axis. Thus, when shifting enough in the negative x direction, + the object should be visible at the left half of the detector only. + A shift in y should not influence the result (much). + + At +90 degrees, a shift in the negative y direction should have the same + effect. + """ + apart = odl.nonuniform_partition([0, np.pi / 2, np.pi, 3 * np.pi / 2]) + if geometry_type == 'par2d' and odl.tomo.ASTRA_AVAILABLE: + ndim = 2 + dpart = odl.uniform_partition(-30, 30, 30) + geometry = odl.tomo.Parallel2dGeometry(apart, dpart) + elif geometry_type == 'par3d' and odl.tomo.ASTRA_CUDA_AVAILABLE: + ndim = 3 + dpart = odl.uniform_partition([-30, -30], [30, 30], (30, 30)) + geometry = odl.tomo.Parallel3dAxisGeometry(apart, dpart) + if geometry_type == 'cone2d' and odl.tomo.ASTRA_AVAILABLE: + ndim = 2 + dpart = odl.uniform_partition(-30, 30, 30) + geometry = odl.tomo.FanBeamGeometry(apart, dpart, + src_radius=200, det_radius=100) + elif geometry_type == 'cone3d' and odl.tomo.ASTRA_CUDA_AVAILABLE: + ndim = 3 + dpart = odl.uniform_partition([-30, -30], [30, 30], (30, 30)) + geometry = odl.tomo.ConeBeamGeometry(apart, dpart, + src_radius=200, det_radius=100) + else: + pytest.skip('no projector available for geometry type') + + min_pt = np.array([-5.0] * ndim) + max_pt = np.array([5.0] * ndim) + shift_len = 6 # enough to move the projection to one side of the detector + + # Shift along axis 0 + shift = np.zeros(ndim) + shift[0] = -shift_len + + # Generate 4 projections with 90 degrees increment + space = odl.uniform_discr(min_pt + shift, max_pt + shift, [10] * ndim) + ray_trafo = odl.tomo.RayTransform(space, geometry) + proj = ray_trafo(space.one()) + + # Check that the object is projected to the correct place. With the + # chosen setup, at least one ray should go through a substantial + # part of the volume, yielding a value around 10 (=side length). + + # 0 degrees: All on the left + assert np.max(proj[0, :15]) > 5 + assert np.max(proj[0, 15:]) == 0 + + # 90 degrees: Left and right + assert np.max(proj[1, :15]) > 5 + assert np.max(proj[1, 15:]) > 5 + + # 180 degrees: All on the right + assert np.max(proj[2, :15]) == 0 + assert np.max(proj[2, 15:]) > 5 + + # 270 degrees: Left and right + assert np.max(proj[3, :15]) > 5 + assert np.max(proj[3, 15:]) > 5 + + # Do the same for axis 1 + shift = np.zeros(ndim) + shift[1] = -shift_len + + space = odl.uniform_discr(min_pt + shift, max_pt + shift, [10] * ndim) + ray_trafo = odl.tomo.RayTransform(space, geometry) + proj = ray_trafo(space.one()) + + # 0 degrees: Left and right + assert np.max(proj[0, :15]) > 5 + assert np.max(proj[0, 15:]) > 5 + + # 90 degrees: All on the left + assert np.max(proj[1, :15]) > 5 + assert np.max(proj[1, 15:]) == 0 + + # 180 degrees: Left and right + assert np.max(proj[2, :15]) > 5 + assert np.max(proj[2, 15:]) > 5 + + # 270 degrees: All on the right + assert np.max(proj[3, :15]) == 0 + assert np.max(proj[3, 15:]) > 5 + + +def test_detector_shifts_2d(): + """Check that detector shifts are handled correctly. + + We forward project a cubic phantom and check that ray transform + and back-projection with and without detector shifts are + numerically close (the error depends on domain discretization). + """ + + if not odl.tomo.ASTRA_AVAILABLE: + pytest.skip(msg='ASTRA not available, skipping 2d test') #type:ignore + + d = 10 + space = odl.uniform_discr([-1] * 2, [1] * 2, [d] * 2) + phantom = odl.phantom.cuboid(space, [-1 / 3] * 2, [1 / 3] * 2) + + full_angle = 2 * np.pi + n_angles = 2 * 10 + src_rad = 2 + det_rad = 2 + apart = odl.uniform_partition(0, full_angle, n_angles) + dpart = odl.uniform_partition(-4, 4, 8 * d) + geom = odl.tomo.FanBeamGeometry(apart, dpart, src_rad, det_rad) + k = 3 + shift = k * dpart.cell_sides[0] + geom_shift = odl.tomo.FanBeamGeometry( + apart, dpart, src_rad, det_rad, + det_shift_func=lambda angle: [0.0, shift] + ) + + assert all_almost_equal(geom.angles, geom_shift.angles) + angles = geom.angles + assert all_almost_equal(geom.src_position(angles), + geom_shift.src_position(angles)) + assert all_almost_equal(geom.det_axis(angles), + geom_shift.det_axis(angles)) + assert all_almost_equal(geom.det_refpoint(angles), + geom_shift.det_refpoint(angles) + + shift * geom_shift.det_axis(angles)) + + # check ray transform + op = odl.tomo.RayTransform(space, geom) + op_shift = odl.tomo.RayTransform(space, geom_shift) + y = op(phantom).asarray() + y_shift = op_shift(phantom).asarray() + # projection on the shifted detector is shifted regular projection + data_error = np.max(np.abs(y[:, :-k] - y_shift[:, k:])) + assert data_error < space.cell_volume + + # check back-projection + im = op.adjoint(y).asarray() + im_shift = op_shift.adjoint(y_shift).asarray() + error = np.abs(im_shift - im) + rel_error = np.max(error[im > 0] / im[im > 0]) + assert rel_error < space.cell_volume + + +def test_source_shifts_2d(): + """Check that source shifts are handled correctly. + + We forward project a Shepp-Logan phantom and check that reconstruction + with flying focal spot is equal to a sum of reconstructions with two + geometries which mimic ffs by using initial angular offsets and + detector shifts + """ + + if not odl.tomo.ASTRA_AVAILABLE: + pytest.skip(msg='ASTRA required but not available') #type:ignore + + d = 10 + space = odl.uniform_discr([-1] * 2, [1] * 2, [d] * 2) + phantom = odl.phantom.cuboid(space, [-1 / 3] * 2, [1 / 3] * 2) + + full_angle = 2 * np.pi + n_angles = 2 * 10 + src_rad = 2 + det_rad = 2 + apart = odl.uniform_partition(0, full_angle, n_angles) + dpart = odl.uniform_partition(-4, 4, 8 * d) + # Source positions with flying focal spot should correspond to + # source positions of 2 geometries with different starting positions + shift1 = np.array([0.0, -0.3]) + shift2 = np.array([0.0, 0.3]) + init = np.array([1, 0], dtype=np.float32) + det_init = np.array([0, -1], dtype=np.float32) + + ffs = partial(odl.tomo.flying_focal_spot, + apart=apart, + shifts=[shift1, shift2]) + geom_ffs = odl.tomo.FanBeamGeometry(apart, dpart, + src_rad, det_rad, + src_to_det_init=init, + det_axis_init=det_init, + src_shift_func=ffs, + det_shift_func=ffs) + # angles must be shifted to match discretization of apart + ang1 = -full_angle / (n_angles * 2) + apart1 = odl.uniform_partition(ang1, full_angle + ang1, n_angles // 2) + ang2 = full_angle / (n_angles * 2) + apart2 = odl.uniform_partition(ang2, full_angle + ang2, n_angles // 2) + + init1 = init + np.array([0, shift1[1]]) / (src_rad + shift1[0]) + init2 = init + np.array([0, shift2[1]]) / (src_rad + shift2[0]) + # radius also changes when a shift is applied + src_rad1 = np.linalg.norm(np.array([src_rad, 0]) + shift1) + src_rad2 = np.linalg.norm(np.array([src_rad, 0]) + shift2) + det_rad1 = np.linalg.norm( + np.array([det_rad, shift1[1] / src_rad * det_rad])) + det_rad2 = np.linalg.norm( + np.array([det_rad, shift2[1] / src_rad * det_rad])) + geom1 = odl.tomo.FanBeamGeometry(apart1, dpart, + src_rad1, det_rad1, + src_to_det_init=init1, + det_axis_init=det_init) + geom2 = odl.tomo.FanBeamGeometry(apart2, dpart, + src_rad2, det_rad2, + src_to_det_init=init2, + det_axis_init=det_init) + + # check ray transform + op_ffs = odl.tomo.RayTransform(space, geom_ffs) + op1 = odl.tomo.RayTransform(space, geom1) + op2 = odl.tomo.RayTransform(space, geom2) + y_ffs = op_ffs(phantom) + y1 = op1(phantom).asarray() + y2 = op2(phantom).asarray() + assert all_almost_equal(y_ffs[::2], y1) + assert all_almost_equal(y_ffs[1::2], y2) + + # check back-projection + im = op_ffs.adjoint(y_ffs).asarray() + im1 = op1.adjoint(y1).asarray() + im2 = op2.adjoint(y2).asarray() + im_combined = (im1 + im2) / 2 + rel_error = np.abs((im - im_combined)[im > 0] / im[im > 0]) + assert np.max(rel_error) < 1e-6 + + +def test_detector_shifts_3d(): + """Check that detector shifts are handled correctly. + + We forward project a cubic phantom and check that ray transform + and back-projection with and without detector shifts are + numerically close (the error depends on domain discretization). + """ + if not odl.tomo.ASTRA_CUDA_AVAILABLE: + pytest.skip(msg='ASTRA CUDA required but not available') #type:ignore + + d = 100 + space = odl.uniform_discr([-1] * 3, [1] * 3, [d] * 3) + phantom = odl.phantom.cuboid(space, [-1 / 3] * 3, [1 / 3] * 3) + + full_angle = 2 * np.pi + n_angles = 2 * 100 + src_rad = 2 + det_rad = 2 + apart = odl.uniform_partition(0, full_angle, n_angles) + dpart = odl.uniform_partition([-4] * 2, [4] * 2, [8 * d] * 2) + geom = odl.tomo.ConeBeamGeometry(apart, dpart, src_rad, det_rad) + k = 3 + l = 2 + shift = np.array([0, k, l]) * dpart.cell_sides[0] + geom_shift = odl.tomo.ConeBeamGeometry(apart, dpart, src_rad, det_rad, + det_shift_func=lambda angle: shift) + + angles = geom.angles + + assert all_almost_equal(angles, geom_shift.angles) + assert all_almost_equal(geom.src_position(angles), + geom_shift.src_position(angles)) + assert all_almost_equal(geom.det_axes(angles), + geom_shift.det_axes(angles)) + assert all_almost_equal(geom.det_refpoint(angles), + geom_shift.det_refpoint(angles) + + geom_shift.det_axes(angles)[:, 0] * shift[1] + - geom_shift.det_axes(angles)[:, 1] * shift[2]) + + # check forward pass + op = odl.tomo.RayTransform(space, geom) + op_shift = odl.tomo.RayTransform(space, geom_shift) + y = op(phantom).asarray() + y_shift = op_shift(phantom).asarray() + data_error = np.max(np.abs(y[:, :-k, l:] - y_shift[:, k:, :-l])) + assert data_error < 1e-3 + + # check back-projection + im = op.adjoint(y).asarray() + im_shift = op_shift.adjoint(y_shift).asarray() + error = np.max(np.abs(im_shift - im)) + assert error < 1e-3 + + +def test_source_shifts_3d(): + """Check that source shifts are handled correctly. + + We forward project a Shepp-Logan phantom and check that reconstruction + with flying focal spot is equal to a sum of reconstructions with two + geometries which mimic ffs by using initial angular offsets and + detector shifts + """ + if not odl.tomo.ASTRA_CUDA_AVAILABLE: + pytest.skip(msg='ASTRA_CUDA not available, skipping 3d test') #type:ignore + + d = 10 + space = odl.uniform_discr([-1] * 3, [1] * 3, [d] * 3) + phantom = odl.phantom.cuboid(space, [-1 / 3] * 3, [1 / 3] * 3) + + full_angle = 2 * np.pi + n_angles = 2 * 10 + apart = odl.uniform_partition(0, full_angle, n_angles) + dpart = odl.uniform_partition([-4] * 2, [4] * 2, [8 * d] * 2) + src_rad = 2 + det_rad = 2 + pitch = 0.2 + # Source positions with flying focal spot should correspond to + # source positions of 2 geometries with different starting positions + shift1 = np.array([0.0, -0.2, 0.1]) + shift2 = np.array([0.0, 0.2, -0.1]) + init = np.array([1, 0, 0], dtype=np.float32) + det_init = np.array([[0, -1, 0], [0, 0, 1]], dtype=np.float32) + ffs = partial(odl.tomo.flying_focal_spot, + apart=apart, + shifts=[shift1, shift2]) + geom_ffs = odl.tomo.ConeBeamGeometry(apart, dpart, + src_rad, det_rad, + src_to_det_init=init, + det_axes_init=det_init, + src_shift_func=ffs, + det_shift_func=ffs, + pitch=pitch) #type:ignore + # angles must be shifted to match discretization of apart + ang1 = -full_angle / (n_angles * 2) + apart1 = odl.uniform_partition(ang1, full_angle + ang1, n_angles // 2) + ang2 = full_angle / (n_angles * 2) + apart2 = odl.uniform_partition(ang2, full_angle + ang2, n_angles // 2) + + init1 = init + np.array([0, shift1[1], 0]) / (src_rad + shift1[0]) + init2 = init + np.array([0, shift2[1], 0]) / (src_rad + shift2[0]) + # radius also changes when a shift is applied + src_rad1 = np.linalg.norm(np.array([src_rad + shift1[0], shift1[1], 0])) + src_rad2 = np.linalg.norm(np.array([src_rad + shift2[0], shift2[1], 0])) + det_rad1 = np.linalg.norm( + np.array([det_rad, det_rad / src_rad * shift1[1], 0])) + det_rad2 = np.linalg.norm( + np.array([det_rad, det_rad / src_rad * shift2[1], 0])) + geom1 = odl.tomo.ConeBeamGeometry(apart1, dpart, src_rad1, det_rad1, + src_to_det_init=init1, + det_axes_init=det_init, + offset_along_axis=shift1[2], + pitch=pitch) #type:ignore + geom2 = odl.tomo.ConeBeamGeometry(apart2, dpart, src_rad2, det_rad2, + src_to_det_init=init2, + det_axes_init=det_init, + offset_along_axis=shift2[2], + pitch=pitch) #type:ignore + + assert all_almost_equal(geom_ffs.src_position(geom_ffs.angles)[::2], + geom1.src_position(geom1.angles)) + assert all_almost_equal(geom_ffs.src_position(geom_ffs.angles)[1::2], + geom2.src_position(geom2.angles)) + + assert all_almost_equal(geom_ffs.det_refpoint(geom_ffs.angles)[::2], + geom1.det_refpoint(geom1.angles)) + assert all_almost_equal(geom_ffs.det_refpoint(geom_ffs.angles)[1::2], + geom2.det_refpoint(geom2.angles)) + + assert all_almost_equal(geom_ffs.det_axes(geom_ffs.angles)[::2], + geom1.det_axes(geom1.angles)) + assert all_almost_equal(geom_ffs.det_axes(geom_ffs.angles)[1::2], + geom2.det_axes(geom2.angles)) + + op_ffs = odl.tomo.RayTransform(space, geom_ffs) + op1 = odl.tomo.RayTransform(space, geom1) + op2 = odl.tomo.RayTransform(space, geom2) + y_ffs = op_ffs(phantom) + y1 = op1(phantom) + y2 = op2(phantom) + assert all_almost_equal(np.mean(y_ffs[::2], axis=(1, 2)), + np.mean(y1, axis=(1, 2))) + assert all_almost_equal(np.mean(y_ffs[1::2], axis=(1, 2)), + np.mean(y2, axis=(1, 2))) + im = op_ffs.adjoint(y_ffs).asarray() + im_combined = (op1.adjoint(y1).asarray() + op2.adjoint(y2).asarray()) + # the scaling is a bit off for older versions of astra + im_combined = im_combined / np.sum(im_combined) * np.sum(im) + rel_error = np.abs((im - im_combined)[im > 0] / im[im > 0]) + assert np.max(rel_error) < 1e-6 + + +if __name__ == '__main__': + odl.util.test_file(__file__) diff --git a/benchmark/results/mri_mlem_odl_adam.csv b/benchmark/results/mri_mlem_odl_adam.csv new file mode 100644 index 00000000000..0fe9ca81772 --- /dev/null +++ b/benchmark/results/mri_mlem_odl_adam.csv @@ -0,0 +1,4 @@ +dimension,n_points,time,error,device,backend,max_iterations,timestamp +2,512,1.57490420341,111.12852552,cpu,odl,100,2024-10-10 09:38:25.539154 +2,512,1.35907077789,108.845603735,cpu,odl,100,2024-10-10 09:38:33.879930 +2,512,1.29878950119,109.469230706,cpu,odl,100,2024-10-10 09:38:39.865889 diff --git a/benchmark/results/mri_mlem_torch_adam.csv b/benchmark/results/mri_mlem_torch_adam.csv new file mode 100644 index 00000000000..573de636c42 --- /dev/null +++ b/benchmark/results/mri_mlem_torch_adam.csv @@ -0,0 +1,2 @@ +dimension,n_points,time,error,device,backend,max_iterations,timestamp +2,512,0.104027271271,114.579210652,cuda:0,torch,100,2024-10-10 09:42:45.245383 diff --git a/benchmark/scripts/__init__.py b/benchmark/scripts/__init__.py new file mode 100644 index 00000000000..4ebc76c3bfb --- /dev/null +++ b/benchmark/scripts/__init__.py @@ -0,0 +1,2 @@ +from . import torch_scripts +from . import odl_scripts \ No newline at end of file diff --git a/benchmark/scripts/odl_scripts.py b/benchmark/scripts/odl_scripts.py new file mode 100644 index 00000000000..225f090bd5e --- /dev/null +++ b/benchmark/scripts/odl_scripts.py @@ -0,0 +1,46 @@ +from typing import Dict + +import numpy as np +import odl + +def mri_mlem_adam( + parameters:Dict, + dimension : int, + n_points : int, + max_iterations : int + ): + subsampling : float = parameters['subsampling'] + learning_rate: float = parameters['learning_rate'] + beta1: float = parameters['beta1'] + beta2: float = parameters['beta2'] + eps: float = parameters['eps'] + # Create a space + space = odl.uniform_discr( + [0 for _ in range(dimension)], + [n_points for _ in range(dimension)], + [n_points for _ in range(dimension)] + ) + # Create MRI operator. First fourier transform, then subsample + ft = odl.trafos.FourierTransform(space) + sampling_points = np.random.rand(*ft.range.shape) < subsampling #type:ignore + sampling_mask = ft.range.element(sampling_points) + mri_op = sampling_mask * ft + + # Create noisy MRI data + phantom = odl.phantom.shepp_logan(space, modified=True) + noisy_data = mri_op(phantom) + odl.phantom.white_noise(mri_op.range) * 0.1 #type:ignore + + g = odl.solvers.L2Norm(mri_op.range).translated(noisy_data) * mri_op + + # Solve + x = mri_op.domain.zero() + odl.solvers.adam( + g, x, + maxiter=max_iterations, + learning_rate=learning_rate, + beta1=beta1, + beta2=beta2, + eps=eps) + + ### Return the data; compare it against target (l2 norm) + return np.linalg.norm(phantom - x.data) diff --git a/benchmark/scripts/torch_scripts.py b/benchmark/scripts/torch_scripts.py new file mode 100644 index 00000000000..84ba575e0d3 --- /dev/null +++ b/benchmark/scripts/torch_scripts.py @@ -0,0 +1,76 @@ +from typing import Dict + +import numpy as np +import torch +import torch.random + +import odl + + +def complex_mse_loss(output:torch.Tensor, target:torch.Tensor): + return (0.5*(output - target)**2).mean(dtype=torch.complex64) + +def mri_mlem_adam( + parameters:Dict, + dimension : int, + n_points : int, + max_iterations : int + ): + subsampling : float = parameters['subsampling'] + device_name : str = parameters['device_name'] + learning_rate: float = parameters['learning_rate'] + beta1:float = parameters['beta1'] + beta2:float = parameters['beta2'] + eps = parameters['eps'] + + space = odl.uniform_discr( + [0 for _ in range(dimension)], + [n_points for _ in range(dimension)], + [n_points for _ in range(dimension)] + ) + + phantom = odl.phantom.shepp_logan(space, modified=True) + phantom = torch.from_numpy(phantom.asarray()).unsqueeze(0).unsqueeze(0).to(device_name) + + x = torch.zeros( + size = phantom.size(), + device = device_name, + requires_grad=True + ) + + optimiser = torch.optim.Adam( + [x], + lr = learning_rate, + betas= (beta1, beta2), + eps = eps + ) + + class FwdOp(torch.nn.Module): + def __init__( + self, + phantom:torch.Tensor, + subsampling:float, + device + ): + super(FwdOp, self).__init__() + self.sampling_mask = torch.rand(phantom.size(), device=device) < subsampling + + def forward(self, input_tensor:torch.Tensor): + return self.sampling_mask * torch.fft.fftn(input_tensor) + + mri_op = FwdOp(phantom, subsampling, device_name) + + noisy_data = mri_op(phantom) + torch.normal( + mean=torch.zeros(phantom.size()), + std=torch.ones(phantom.size())).to(device_name) * 0.1 + + for _ in range(max_iterations): + optimiser.zero_grad() + current_data = mri_op(x) + loss = complex_mse_loss(current_data, noisy_data) + loss.mean().backward() + optimiser.step() + + return np.linalg.norm( + phantom.detach().cpu().numpy() - x.data.detach().cpu().numpy() + ) \ No newline at end of file diff --git a/examples/tomo/test_ray_trafo_3d.py b/examples/tomo/test_ray_trafo_3d.py new file mode 100644 index 00000000000..93a6ad1b000 --- /dev/null +++ b/examples/tomo/test_ray_trafo_3d.py @@ -0,0 +1,37 @@ +"""Example using the ray transform with 2d parallel beam geometry.""" + +import numpy as np +import odl + +import matplotlib.pyplot as plt +plt.gray() + +# Reconstruction space: discretized functions on the rectangle +# [-20, 20]^2 with 300 samples per dimension. +reco_space = odl.uniform_discr( + min_pt=[-20, -20, -20], max_pt=[20, 20, 20], shape=[300, 300, 300], + dtype='float32', impl='pytorch', torch_device='cpu') + +# Make a 3d single-axis parallel beam geometry with flat detector +# Angles: uniformly spaced, n = 180, min = 0, max = pi +angle_partition = odl.uniform_partition(0, np.pi, 180) +# Detector: uniformly sampled, n = (512, 512), min = (-30, -30), max = (30, 30) +detector_partition = odl.uniform_partition([-30, -30], [30, 30], [512, 512]) +geometry = odl.tomo.Parallel3dAxisGeometry(angle_partition, detector_partition) + +# Ray transform (= forward projection). +ray_trafo = odl.tomo.RayTransform(reco_space, geometry, impl='astra_cuda_link') + +# Create a discrete Shepp-Logan phantom (modified version) +phantom = odl.phantom.shepp_logan(reco_space, modified=True) + +# Create projection data by calling the ray transform on the phantom +proj_data = ray_trafo(phantom) +plt.matshow(proj_data[0,:,:]) +plt.savefig('test_ray_trafo_3d_sinogram', bbox_inches='tight') +plt.close() + +rec_data = ray_trafo.adjoint(proj_data) +plt.matshow(rec_data[150,:,:]) +plt.savefig('test_ray_trafo_3d_reconstruction', bbox_inches='tight') +plt.close() \ No newline at end of file diff --git a/odl/contrib/torch/new_operator.py b/odl/contrib/torch/new_operator.py new file mode 100644 index 00000000000..c2ed44321b1 --- /dev/null +++ b/odl/contrib/torch/new_operator.py @@ -0,0 +1,74 @@ +from __future__ import division + +import warnings + +from torch.autograd import Function +import itertools + +import numpy as np +import torch +from packaging.version import parse as parse_version + +from odl import Operator + +if parse_version(torch.__version__) < parse_version('0.4'): + warnings.warn("This interface is designed to work with Pytorch >= 0.4", + RuntimeWarning, stacklevel=2) + +__all__ = ('OperatorFunction', 'OperatorModule') + +class OperatorFunction(Function): + @staticmethod + def forward( + ctx, + input_tensor:torch.Tensor, + operator:Operator, + ): + assert len(input_tensor.size()) == 5 + extra_dims = input_tensor.size()[:2] + if input_tensor.requires_grad: + ctx.operator = operator + ctx.extra_dims = extra_dims + + output = input_tensor.new_empty(extra_dims + operator.range.shape, dtype=torch.float32) # type:ignore + + for subspace in itertools.product(*[range(dim_size) for dim_size in extra_dims]): + output[subspace] = operator(input_tensor[subspace]).data # type:ignore + return output + + @staticmethod + def backward(ctx, grad_output): + operator = ctx.operator + grad_input = grad_output.new_empty(ctx.extra_dims + operator.domain.shape, dtype=torch.float32) # type:ignore + + for subspace in itertools.product(*[range(dim_size) for dim_size in ctx.extra_dims]): + grad_input[subspace] = operator.adjoint(grad_output[subspace]).data + + return grad_input, None + +class OperatorModule(torch.nn.Module): + + def __init__(self, operator:Operator): + """Initialize a new instance.""" + super(OperatorModule, self).__init__() + self.operator = operator + + def forward(self, input_tensor:torch.Tensor): + return OperatorFunction.apply( + input_tensor, + self.operator + ) + + def __repr__(self): + """Return ``repr(self)``.""" + op_name = self.operator.__class__.__name__ + op_in_shape = self.operator.domain.shape #type:ignore + if len(op_in_shape) == 1: + op_in_shape = op_in_shape[0] + op_out_shape = self.operator.range.shape #type:ignore + if len(op_out_shape) == 1: + op_out_shape = op_out_shape[0] + + return '{}({}) ({} -> {})'.format( + self.__class__.__name__, op_name, op_in_shape, op_out_shape + ) \ No newline at end of file diff --git a/odl/operator/operator.py b/odl/operator/operator.py index 3a1a21b39ea..ccd09fbc8c7 100644 --- a/odl/operator/operator.py +++ b/odl/operator/operator.py @@ -675,6 +675,7 @@ def __call__(self, x, out=None, **kwargs): if out is not None: # In-place evaluation if out not in self.range: + ### This is a bizarre check, which seems to expect numpy arrays raise OpRangeError('`out` {!r} not an element of the range ' '{!r} of {!r}' ''.format(out, self.range, self)) @@ -684,6 +685,7 @@ def __call__(self, x, out=None, **kwargs): 'when range is a field') result = self._call_in_place(x, out=out, **kwargs) + ### This is a bizarre check, which seems to expect numpy arrays if result is not None and result is not out: raise ValueError('`op` returned a different value than `out`. ' 'With in-place evaluation, the operator can ' diff --git a/odl/space/pytorch_tensors.py b/odl/space/pytorch_tensors.py index 48740e1c8b4..fc19d677bc2 100644 --- a/odl/space/pytorch_tensors.py +++ b/odl/space/pytorch_tensors.py @@ -24,7 +24,7 @@ from odl.space.weighting import ( ArrayWeighting, ConstWeighting, CustomDist, CustomInner, CustomNorm, Weighting) -from odl.util.utility import _CORRESPONDING_PYTORCH_DTYPES +from odl.util.utility import ArrayOnPytorchManager, _CORRESPONDING_PYTORCH_DTYPES from odl.util import ( dtype_str, is_floating_dtype, is_numeric_dtype, is_real_dtype, nullcontext, signature_string, writable_array) @@ -422,11 +422,16 @@ def element(self, inp=None, data_ptr=None, order=None): if order is not None and str(order).upper() not in ('C'): raise ValueError(f"Only row-major order supported ('C'), not '{order}'.") - if inp is None and data_ptr is None: - arr = torch.empty(self.shape, dtype=self._torch_dtype, device=self._torch_device) - + def wrapped_array(arr): + if arr.shape != self.shape: + raise ValueError('shape of `inp` not equal to space shape: ' + '{} != {}'.format(arr.shape, self.shape)) return self.element_type(self, arr) + if inp is None and data_ptr is None: + return wrapped_array(torch.empty( + self.shape, dtype=self._torch_dtype, device=self._torch_device)) + elif inp is None and data_ptr is not None: if order is None: raise ValueError('`order` cannot be None for element ' @@ -437,20 +442,16 @@ def element(self, inp=None, data_ptr=None, order=None): as_numpy_array = np.ctypeslib.as_array(as_ctype_array) arr = as_numpy_array.view(dtype=self._torch_dtype) arr = arr.reshape(self.shape, order=order) - return self.element_type(self, torch.Tensor(arr)) + return wrapped_array(torch.tensor( + arr, dtype=self._torch_dtype, device=self._torch_device)) elif inp is not None and data_ptr is None: if inp in self and order is None: # Short-circuit for space elements and no enforced ordering return inp - # TODO avoid copy when it's not necessary - arr = torch.tensor(inp, dtype=self._torch_dtype, device=self._torch_device) - - if arr.shape != self.shape: - raise ValueError('shape of `inp` not equal to space shape: ' - '{} != {}'.format(arr.shape, self.shape)) - return self.element_type(self, arr) + return wrapped_array(ArrayOnPytorchManager(device=self._torch_device) + .as_compatible_array(inp, dtype=self._torch_dtype)) else: raise TypeError('cannot provide both `inp` and `data_ptr`') @@ -844,8 +845,8 @@ def __repr__(self): optargs = [] optmod = '' - inner_str = signature_string(posargs, optargs, mod=['', optmod]) - weight_str = self.weighting.repr_part + inner_str = signature_string(posargs, optargs, mod=['', optmod]) # type:ignore + weight_str = self.weighting.repr_part # type:ignore if weight_str: inner_str += ', ' + weight_str @@ -1418,7 +1419,7 @@ def __long__(self): This method is only useful in Python 2. """ - return long(self.data) + return long(self.data) # type:ignore def __float__(self): """Return ``float(self)``.""" diff --git a/odl/tomo/backends/astra_binders.py b/odl/tomo/backends/astra_binders.py new file mode 100644 index 00000000000..f6e8c5032a5 --- /dev/null +++ b/odl/tomo/backends/astra_binders.py @@ -0,0 +1,156 @@ +############################################################################### +# This code was taken from tomosipo and adapted to ODL API # +# Please check https://github.com/ahendriksen/tomosipo # +############################################################################### +"""ASTRA conversion and projection code + +There are two geometry conversion methods: + +- from_astra +- to_astra + +An important method is `create_astra_projector`, which creates an ASTRA +projector from a pair of geometries. + +Moreover, there is projection code that is centered around the following +ASTRA APIs: + +1. astra.experimental.direct_FPBP3D (modern) +2. astra.experimental.do_composite (legacy) + +The first is used in modern tomosipo code: it takes an existing ASTRA projector +and a link to a numpy or gpu array. + +The second is a legacy approach that is kept for debugging and testing purposes. +It takes multiple Data objects describing volumes (both data and geometry) and +projection geometries (both data and geometry). On this basis, it creates a +projector and passes it to ASTRA, which performs an all-to-all (back)projection. + +""" +try: + import astra.experimental as experimental + + ASTRA_BINDERS_AVAILABLE = True +except ImportError: + ASTRA_BINDERS_AVAILABLE = False + +__all__ = ( + 'ASTRA_BINDERS_AVAILABLE', +) +from odl.tomo.backends import links + +############################################################################### +# Direct ASTRA projection (modern) # +############################################################################### +def direct_project( + projector, + vol_link, + proj_link, + forward=None, + additive=False, +): + """Project forward or backward + + Forward/back projects a volume onto a projection dataset. + + :param projector: ?? + It is possible to provide a pre-generated ASTRA projector. Use + `ts.Operator.astra_projector` to generate an astra projector. + :param vol_link: TODO + :param proj_link: TODO + :param forward: bool + True for forward project, False for backproject. + :param additive: bool + If True, add projection data to existing data. Otherwise + overwrite data. + :returns: + :rtype: + + """ + if forward is None: + raise ValueError("project must be given a forward argument (True/False).") + + # These constants have become the default. See: + # https://github.com/astra-toolbox/astra-toolbox/commit/4d673b3cdb6d27d430087758a8081e4a10267595 + MODE_SET = 1 + MODE_ADD = 0 + + if not links.are_compatible(vol_link, proj_link): + raise ValueError( + "Cannot perform ASTRA projection on volume and projection data, because they are not compatible. " + "Usually, this indicates that the data are located on different computing devices. " + ) + + # If necessary, the link may adjust the current state of the + # program temporarily to ensure ASTRA runs correctly. For torch + # tensors, this may entail changing the currently active GPU + with vol_link.context(): + experimental.direct_FPBP3D( #type:ignore + projector, + vol_link.linked_data, + proj_link.linked_data, + MODE_ADD if additive else MODE_SET, + "FP" if forward else "BP", + ) + +def direct_fp( + projector, + vol_data, + proj_data, + additive=False, +): + """Project forward or backward + + Forward/back projects a volume onto a projection dataset. + + :param projector: ?? + It is possible to provide a pre-generated ASTRA projector. Use + `ts.Operator.astra_projector` to generate an astra projector. + :param vol_data: TODO + :param proj_data: TODO + :param additive: bool + If True, add projection data to existing data. Otherwise + overwrite data. + :returns: + :rtype: + + """ + return direct_project( + projector, + vol_data, + proj_data, + forward=True, + additive=additive, + ) + + +def direct_bp( + projector, + vol_data, + proj_data, + additive=False, +): + """Project forward or backward + + Forward/back projects a volume onto a projection dataset. + + :param projector: ?? + It is possible to provide a pre-generated ASTRA projector. Use + `ts.Operator.astra_projector` to generate an astra projector. + :param vol_data: TODO + :param proj_data: TODO + :param additive: bool + If True, add projection data to existing data. Otherwise + overwrite data. + :returns: + :rtype: + + """ + return direct_project( + projector, + vol_data, + proj_data, + forward=False, + additive=additive, + ) + diff --git a/odl/tomo/backends/astra_cuda_link.py b/odl/tomo/backends/astra_cuda_link.py new file mode 100644 index 00000000000..6c159ca2e3c --- /dev/null +++ b/odl/tomo/backends/astra_cuda_link.py @@ -0,0 +1,436 @@ +# Copyright 2014-2020 The ODL contributors +# +# This file is part of ODL. +# +# This Source Code Form is subject to the terms of the Mozilla Public License, +# v. 2.0. If a copy of the MPL was not distributed with this file, You can +# obtain one at https://mozilla.org/MPL/2.0/. + +"""Backend for ASTRA using CUDA.""" + +from __future__ import absolute_import, division, print_function + +import warnings +from multiprocessing import Lock + +import numpy as np +from packaging.version import parse as parse_version + +from odl.discr import DiscretizedSpace +from odl.tomo.backends.astra_setup import ( + ASTRA_VERSION, astra_projection_geometry, + astra_projector, astra_supports, astra_versions_supporting, + astra_volume_geometry) +from odl.tomo.backends.astra_binders import ( + direct_fp, direct_bp +) +from odl.tomo.backends.util import _add_default_complex_impl +from odl.tomo.geometry import ( + ConeBeamGeometry, FanBeamGeometry, Geometry, Parallel2dGeometry, + Parallel3dAxisGeometry, Geometry) + +from odl.tomo.backends import links +from odl.discr.discr_space import DiscretizedSpaceElement +try: + import astra + ASTRA_CUDA_AVAILABLE = astra.astra.use_cuda() +except ImportError: + ASTRA_CUDA_AVAILABLE = False + +__all__ = ( + 'ASTRA_CUDA_AVAILABLE', +) + +def _to_link(array, shape): + return links.base.link(array, shape) + +class AstraCudaLinkImpl: + """`RayTransform` implementation for CUDA algorithms in ASTRA for PyTorch Tensors.""" + + algo_forward_id = None + algo_backward_id = None + vol_id = None + sino_id = None + proj_id = None + + def __init__( + self, + geometry:Geometry, + vol_space:DiscretizedSpace, + proj_space:DiscretizedSpace, + additive = False + ): + """Initialize a new instance. + + Parameters + ---------- + geometry : `Geometry` + Geometry defining the tomographic setup. + vol_space : `DiscretizedSpace` + Reconstruction space, the space of the images to be forward + projected. + proj_space : `DiscretizedSpace` + Projection space, the space of the result. + additive: `bool` (optional) + Specifies whether the operator should overwrite its range + (forward) and domain (transpose). When `additive=True`, + the operator adds instead of overwrites. The default is + `additive=False`. + """ + if not isinstance(geometry, Geometry): + raise TypeError( + '`geometry` must be a `Geometry` instance, got {!r}' + ''.format(geometry) + ) + if not isinstance(vol_space, DiscretizedSpace): + raise TypeError( + '`vol_space` must be a `DiscretizedSpace` instance, got {!r}' + ''.format(vol_space) + ) + if not isinstance(proj_space, DiscretizedSpace): + raise TypeError( + '`proj_space` must be a `DiscretizedSpace` instance, got {!r}' + ''.format(proj_space) + ) + + # Print a warning if the detector midpoint normal vector at any + # angle is perpendicular to the geometry axis in parallel 3d + # single-axis geometry -- this is broken in some ASTRA versions + if ( + isinstance(geometry, Parallel3dAxisGeometry) + and not astra_supports('par3d_det_mid_pt_perp_to_axis') + ): + req_ver = astra_versions_supporting( + 'par3d_det_mid_pt_perp_to_axis' + ) + axis = geometry.axis + mid_pt = geometry.det_params.mid_pt + for i, angle in enumerate(geometry.angles): + if abs( + np.dot(axis, geometry.det_to_src(angle, mid_pt)) + ) < 1e-4: + warnings.warn( + 'angle {}: detector midpoint normal {} is ' + 'perpendicular to the geometry axis {} in ' + '`Parallel3dAxisGeometry`; this is broken in ' + 'ASTRA {}, please upgrade to ASTRA {}' + ''.format(i, geometry.det_to_src(angle, mid_pt), + axis, ASTRA_VERSION, req_ver), + RuntimeWarning) + break + + self.geometry = geometry + self._vol_space = vol_space + self._proj_space = proj_space + self.additive = additive + self.create_ids() + + # ASTRA projectors are not thread-safe, thus we need to lock manually + self._mutex = Lock() + + @property + def vol_space(self): + return self._vol_space + + @property + def proj_space(self): + return self._proj_space + + def create_ids(self): + """Create ASTRA objects.""" + # Create input and output arrays + if self.geometry.motion_partition.ndim == 1: + motion_shape = self.geometry.motion_partition.shape + else: + # Need to flatten 2- or 3-dimensional angles into one axis + motion_shape = (np.prod(self.geometry.motion_partition.shape),) + + proj_shape = motion_shape + self.geometry.det_partition.shape + proj_ndim = len(proj_shape) + + if proj_ndim == 2: + astra_proj_shape = proj_shape + astra_vol_shape = self.vol_space.shape + elif proj_ndim == 3: + # The `u` and `v` axes of the projection data are swapped, + # see explanation in `astra_*_3d_geom_to_vec`. + astra_proj_shape = (proj_shape[1], proj_shape[0], proj_shape[2]) + astra_vol_shape = self.vol_space.shape + + self.astra_vol_shape = astra_vol_shape + self.astra_proj_shape = astra_proj_shape + + + # Create ASTRA data structures + self.vol_geom = astra_volume_geometry(self.vol_space) + self.proj_geom = astra_projection_geometry(self.geometry) + + # proj_type = 'cuda' if proj_ndim == 2 else 'cuda3d' + # As of now, things DO NOT work in 2d, soz + proj_type = 'cuda3d' + self.proj_id = astra_projector( + proj_type, self.vol_geom, self.proj_geom, proj_ndim + ) + + self.forward_scaling = astra_cuda_fp_scaling_factor( + self.geometry) + + self.backward_scaling = astra_cuda_bp_scaling_factor( + self.proj_space, self.vol_space, self.geometry + ) + + def _call_forward_real( + self, + volume:DiscretizedSpaceElement, + out=None + ): + ### TODO: put that in the __init__ + if volume.impl == 'numpy': + transpose_tuple = (1,0,2) + elif volume.impl == 'pytorch': + transpose_tuple = (1,0) + else: + raise NotImplementedError('Not implemented for another backend') + vlink = _to_link(volume.data, self.astra_vol_shape) + if out is not None: + raise NotImplementedError('Not implemented for in-place calls') + # plink = _to_link(out.data.transpose(*transpose_tuple), self.astra_proj_shape) + + else: + if self.additive: + plink = vlink.new_zeros(self.astra_proj_shape) + else: + plink = vlink.new_empty(self.astra_proj_shape) + + direct_fp( + self.proj_id, + vlink, + plink, + additive=self.additive + ) + + if self.geometry.ndim == 2: + raise NotImplementedError + elif self.geometry.ndim == 3: + if out is not None: + raise NotImplementedError('Not implemented for in-place calls') + # return plink.data * self.forward_scaling + else: + return plink.data.transpose(*transpose_tuple) * self.forward_scaling + + def _call_backward_real( + self, + projection:DiscretizedSpaceElement, + out=None, + **kwargs + ): + ### TODO: put that in the __init__ + if projection.impl == 'numpy': + transpose_tuple = (1,0,2) + elif projection.impl == 'pytorch': + transpose_tuple = (1,0) + else: + raise NotImplementedError('Not implemented for another backend') + + plink = _to_link(projection.data.transpose(*transpose_tuple), self.astra_proj_shape) + + if out is not None: + raise NotImplementedError('Not implemented for in-place calls') + # vlink = _to_link(out.data, self.astra_vol_shape) + else: + if self.additive: + vlink = plink.new_zeros(self.astra_vol_shape) + else: + vlink = plink.new_empty(self.astra_vol_shape) + + direct_bp( + self.proj_id, + vlink, + plink, + additive=self.additive, + ) + if out is not None: + raise NotImplementedError('Not implemented for in-place calls') + # return vlink.data * self.backward_scaling + else: + return vlink.data * self.backward_scaling + + + @_add_default_complex_impl + def call_forward(self, x, out=None, **kwargs): + return self._call_forward_real(x, out, **kwargs) + + @_add_default_complex_impl + def call_backward(self, x, out=None, **kwargs): + return self._call_backward_real(x, out, **kwargs) + + def __del__(self): + """Delete ASTRA objects.""" + if self.geometry.ndim == 2: + adata, aproj = astra.data2d, astra.projector + else: + adata, aproj = astra.data3d, astra.projector3d + + if self.algo_forward_id is not None: + astra.algorithm.delete(self.algo_forward_id) + self.algo_forward_id = None + if self.algo_backward_id is not None: + astra.algorithm.delete(self.algo_backward_id) + self.algo_backward_id = None + if self.vol_id is not None: + adata.delete(self.vol_id) + self.vol_id = None + if self.sino_id is not None: + adata.delete(self.sino_id) + self.sino_id = None + if self.proj_id is not None: + aproj.delete(self.proj_id) + self.proj_id = None + +def astra_cuda_fp_scaling_factor(geometry): + if ( + isinstance(geometry, Parallel2dGeometry) + and parse_version(ASTRA_VERSION) < parse_version('1.9.9.dev') + ): + # parallel2d scales with pixel stride + return 1 / float(geometry.det_partition.cell_volume) + else: + return 1 + +def astra_cuda_bp_scaling_factor(proj_space, vol_space, geometry): + """Volume scaling accounting for differing adjoint definitions. + + ASTRA defines the adjoint operator in terms of a fully discrete + setting (transposed "projection matrix") without any relation to + physical dimensions, which makes a re-scaling necessary to + translate it to spaces with physical dimensions. + + Behavior of ASTRA changes slightly between versions, so we keep + track of it and adapt the scaling accordingly. + """ + # Angular integration weighting factor + # angle interval weight by approximate cell volume + angle_extent = geometry.motion_partition.extent + num_angles = geometry.motion_partition.shape + # TODO: this gives the wrong factor for Parallel3dEulerGeometry with + # 2 angles + scaling_factor = (angle_extent / num_angles).prod() + + # Correct in case of non-weighted spaces + proj_extent = float(proj_space.partition.extent.prod()) + proj_size = float(proj_space.partition.size) + proj_weighting = proj_extent / proj_size + + scaling_factor *= ( + proj_space.weighting.const / proj_weighting + ) + scaling_factor /= ( + vol_space.weighting.const / vol_space.cell_volume + ) + + if parse_version(ASTRA_VERSION) < parse_version('1.8rc1'): + # Scaling for the old, pre-1.8 behaviour + if isinstance(geometry, Parallel2dGeometry): + # Scales with 1 / cell_volume + scaling_factor *= float(vol_space.cell_volume) + elif (isinstance(geometry, FanBeamGeometry) + and geometry.det_curvature_radius is None): + # Scales with 1 / cell_volume + scaling_factor *= float(vol_space.cell_volume) + # Additional magnification correction + src_radius = geometry.src_radius + det_radius = geometry.det_radius + scaling_factor *= ((src_radius + det_radius) / src_radius) + elif isinstance(geometry, Parallel3dAxisGeometry): + # Scales with voxel stride + # In 1.7, only cubic voxels are supported + voxel_stride = vol_space.cell_sides[0] + scaling_factor /= float(voxel_stride) + elif (isinstance(geometry, ConeBeamGeometry) + and geometry.det_curvature_radius is None): + # Scales with 1 / cell_volume + # In 1.7, only cubic voxels are supported + voxel_stride = vol_space.cell_sides[0] + scaling_factor /= float(voxel_stride) + # Magnification correction + src_radius = geometry.src_radius + det_radius = geometry.det_radius + scaling_factor *= ((src_radius + det_radius) / src_radius) ** 2 + elif parse_version(ASTRA_VERSION) < parse_version('1.9.0dev'): + # Scaling for the 1.8.x releases + if isinstance(geometry, Parallel2dGeometry): + # Scales with 1 / cell_volume + scaling_factor *= float(vol_space.cell_volume) + elif (isinstance(geometry, FanBeamGeometry) + and geometry.det_curvature_radius is None): + # Scales with 1 / cell_volume + scaling_factor *= float(vol_space.cell_volume) + # Magnification correction + src_radius = geometry.src_radius + det_radius = geometry.det_radius + scaling_factor *= ((src_radius + det_radius) / src_radius) + elif isinstance(geometry, Parallel3dAxisGeometry): + # Scales with cell volume + # currently only square voxels are supported + scaling_factor /= vol_space.cell_volume + elif (isinstance(geometry, ConeBeamGeometry) + and geometry.det_curvature_radius is None): + # Scales with cell volume + scaling_factor /= vol_space.cell_volume + # Magnification correction (scaling = 1 / magnification ** 2) + src_radius = geometry.src_radius + det_radius = geometry.det_radius + scaling_factor *= ((src_radius + det_radius) / src_radius) ** 2 + + # Correction for scaled 1/r^2 factor in ASTRA's density weighting. + # This compensates for scaled voxels and pixels, as well as a + # missing factor src_radius ** 2 in the ASTRA BP with + # density weighting. + det_px_area = geometry.det_partition.cell_volume + scaling_factor *= ( + src_radius ** 2 * det_px_area ** 2 / vol_space.cell_volume ** 2 + ) + elif parse_version(ASTRA_VERSION) < parse_version('1.9.9.dev'): + # Scaling for intermediate dev releases between 1.8.3 and 1.9.9.dev + if isinstance(geometry, Parallel2dGeometry): + # Scales with 1 / cell_volume + scaling_factor *= float(vol_space.cell_volume) + elif (isinstance(geometry, FanBeamGeometry) + and geometry.det_curvature_radius is None): + # Scales with 1 / cell_volume + scaling_factor *= float(vol_space.cell_volume) + # Magnification correction + src_radius = geometry.src_radius + det_radius = geometry.det_radius + scaling_factor *= ((src_radius + det_radius) / src_radius) + elif isinstance(geometry, Parallel3dAxisGeometry): + # Scales with cell volume + # currently only square voxels are supported + scaling_factor /= vol_space.cell_volume + elif (isinstance(geometry, ConeBeamGeometry) + and geometry.det_curvature_radius is None): + # Scales with cell volume + scaling_factor /= vol_space.cell_volume + # Magnification correction (scaling = 1 / magnification ** 2) + src_radius = geometry.src_radius + det_radius = geometry.det_radius + scaling_factor *= ((src_radius + det_radius) / src_radius) ** 2 + + # Correction for scaled 1/r^2 factor in ASTRA's density weighting. + # This compensates for scaled voxels and pixels, as well as a + # missing factor src_radius ** 2 in the ASTRA BP with + # density weighting. + det_px_area = geometry.det_partition.cell_volume + scaling_factor *= (src_radius ** 2 * det_px_area ** 2) + else: + # Scaling for versions since 1.9.9.dev + scaling_factor /= float(vol_space.cell_volume) + scaling_factor *= float(geometry.det_partition.cell_volume) + + return scaling_factor + + +if __name__ == '__main__': + from odl.util.testutils import run_doctests + + run_doctests() diff --git a/odl/tomo/backends/links/__init__.py b/odl/tomo/backends/links/__init__.py new file mode 100644 index 00000000000..aa04ed554a9 --- /dev/null +++ b/odl/tomo/backends/links/__init__.py @@ -0,0 +1,27 @@ +from . import base +from . import numpy + +from .base import ( + are_compatible +) + + +def _is_package_available(package): + import importlib + + try: + importlib.import_module(package) + return True + except ModuleNotFoundError: + return False + + +if _is_package_available("astra"): + + # Import torch linking facility if torch is available. + if _is_package_available("torch"): + from . import torch + + # Import cupy linking facility if cupy is available. + if _is_package_available("cupy"): + from . import cupy diff --git a/odl/tomo/backends/links/base.py b/odl/tomo/backends/links/base.py new file mode 100644 index 00000000000..c18fcbc9e78 --- /dev/null +++ b/odl/tomo/backends/links/base.py @@ -0,0 +1,131 @@ +############################################################################### +# This code was taken from tomosipo and adapted to ODL API # +# Please check https://github.com/ahendriksen/tomosipo # +############################################################################### + +from contextlib import contextmanager +import warnings + +backends = [ + # numpy backend is imported by default; + # torch backend is only supported when the PyTorch package has been installed; + # cupy backend is only supported when the CuPy package has been installed. +] + +def link(arr, shape): + for backend in backends: + if backend.__accepts__(arr): + return backend(shape, arr) + raise ValueError(f"An initial_value of class {type(arr)} is not supported. ") + +def are_compatible(link_a, link_b): + a_compat_with_b = link_a.__compatible_with__(link_b) + if a_compat_with_b is True: + return True + elif a_compat_with_b == NotImplemented: + b_compat_with_a = link_b.__compatible_with__(link_a) + if b_compat_with_a is True: + return True + elif b_compat_with_a == NotImplemented: + warnings.warn( + f"Cannot determine if link of type {type(link_a)} is compatible with {type(link_b)}. " + "Continuing anyway." + ) + else: + return False + else: + return False + + +class Link(object): + """A General base class for link types""" + + def __init__(self, shape, initial_value): + self._shape = shape + super().__init__() + + ########################################################################### + # "Protocol" functions / methods # + ########################################################################### + @staticmethod + def __accepts__(initial_value): + """Determines if the link class can make use of the initial_value + + :param initial_value: + :returns: + :rtype: + + """ + raise NotImplementedError() + + def __compatible_with__(self, other): + """Can ASTRA project from this link to other link?""" + raise NotImplementedError() + + ########################################################################### + # Properties # + ########################################################################### + @property + def linked_data(self): + """Returns a numpy array or GPULink object + + :returns: + :rtype: + + """ + raise NotImplementedError() + + @property + def data(self): + """Returns the underlying data. + + Changes to the return value will be reflected in the astra + data. + """ + raise NotImplementedError() + + @data.setter + def data(self, val): + raise AttributeError( + "You cannot change which array backs a dataset.\n" + "To change the underlying data instead, use: \n" + " >>> x.data[:] = new_data\n" + ) + + @property + def shape(self): + return self._shape + + ########################################################################### + # Context manager # + ########################################################################### + @contextmanager + def context(self): + """Context-manager to manage ASTRA interactions + + This is a no-op for numpy data. + + This context-manager used, for example, for pytorch data on + GPU to make sure the current CUDA stream is set to the device + of the input data. + + :returns: + :rtype: + + """ + raise NotImplementedError() + + ########################################################################### + # New data creation # + ########################################################################### + def new_zeros(self, shape): + raise NotImplementedError() + + def new_full(self, shape, value): + raise NotImplementedError() + + def new_empty(self, shape): + raise NotImplementedError() + + def clone(self): + raise NotImplementedError() diff --git a/odl/tomo/backends/links/numpy.py b/odl/tomo/backends/links/numpy.py new file mode 100644 index 00000000000..298f3330b6c --- /dev/null +++ b/odl/tomo/backends/links/numpy.py @@ -0,0 +1,149 @@ +############################################################################### +# This code was taken from tomosipo and adapted to ODL API # +# Please check https://github.com/ahendriksen/tomosipo # +############################################################################### + +import numpy as np +import warnings +from contextlib import contextmanager +from .base import Link, backends + + +class NumpyLink(Link): + """Link implementation for numpy arrays""" + + def __init__(self, shape, initial_value): + super(NumpyLink, self).__init__(shape, initial_value) + + if initial_value is None: + self._data = np.zeros(shape, dtype=np.float32) + elif np.isscalar(initial_value): + self._data = np.zeros(shape, dtype=np.float32) + self._data[:] = initial_value + else: + initial_value = np.array(initial_value, copy=False) + if initial_value.shape != shape: + raise ValueError( + "Cannot link array. " + f"Expected array of shape {shape}. Got {initial_value.shape}" + ) + # Make contiguous: + if initial_value.dtype != np.float32: + warnings.warn( + f"The parameter initial_value is of type {initial_value.dtype}; expected `np.float32`. " + f"The type has been Automatically converted. " + f"Use `ts.link(x.astype(np.float32))' to inhibit this warning. " + ) + initial_value = initial_value.astype(np.float32) + if not ( + initial_value.flags["C_CONTIGUOUS"] and initial_value.flags["ALIGNED"] + ): + warnings.warn( + f"The parameter initial_value should be C_CONTIGUOUS and ALIGNED. " + f"It has been automatically made contiguous and aligned. " + f"Use `ts.link(np.ascontiguousarray(x))' to inhibit this warning. " + ) + initial_value = np.ascontiguousarray(initial_value) + self._data = initial_value + + ########################################################################### + # "Protocol" functions / methods # + ########################################################################### + @staticmethod + def __accepts__(initial_value): + # `NumpyLink' is the default backend, so it should accept + # an initial_value of `None'. + if initial_value is None: + return True + elif isinstance(initial_value, np.ndarray): + return True + elif np.isscalar(initial_value): + return True + else: + return False + + def __compatible_with__(self, other): + if isinstance(other, NumpyLink): + return True + else: + return NotImplemented + + ########################################################################### + # Properties # + ########################################################################### + @property + def linked_data(self): + """Returns a numpy array or GPULink object + + :returns: + :rtype: + + """ + return self._data + + @property + def data(self): + """Returns a shared numpy array with the underlying data. + + Changes to the return value will be reflected in the astra + data. + + If you want to avoid this, consider copying the data + immediately, using `np.copy` for instance. + + NOTE: if the underlying object is an Astra projection data + type, the order of the axes will be in (Y, num_angles, X) + order. + + :returns: np.array + :rtype: np.array + + """ + return self._data + + @data.setter + def data(self, val): + raise AttributeError( + "You cannot change which array backs a dataset.\n" + "To change the underlying data instead, use: \n" + " >>> x.data[:] = new_data\n" + ) + + ########################################################################### + # Context manager # + ########################################################################### + @contextmanager + def context(self): + """Context-manager to manage ASTRA interactions + + This is a no-op for numpy data. + + """ + yield + + ########################################################################### + # New data creation # + ########################################################################### + def new_zeros(self, shape): + return NumpyLink( + shape, + np.zeros(shape, dtype=self._data.dtype), + ) + + def new_full(self, shape, value): + return NumpyLink( + shape, + np.full(shape, value, dtype=self._data.dtype), + ) + + def new_empty(self, shape): + return NumpyLink( + shape, + np.empty(shape, dtype=self._data.dtype), + ) + + def clone(self): + return NumpyLink(self._data.shape, np.copy(self._data)) + + +backends.append(NumpyLink) diff --git a/odl/tomo/backends/links/torch.py b/odl/tomo/backends/links/torch.py new file mode 100644 index 00000000000..3a46f717b22 --- /dev/null +++ b/odl/tomo/backends/links/torch.py @@ -0,0 +1,180 @@ +############################################################################### +# This code was taken from tomosipo and adapted to ODL API # +# Please check https://github.com/ahendriksen/tomosipo # +############################################################################### +"""This module adds support for torch arrays as astra.data3d backends + +This module is not automatically imported by tomosipo, you must import +it manually as follows: + +>>> import tomosipo as ts +>>> import tomosipo.torch_support + +Now, you may use torch tensors as you would numpy arrays: + +>>> vg = ts.volume(shape=(10, 10, 10)) +>>> pg = ts.parallel(angles=10, shape=10) +>>> A = ts.operator(vg, pg) +>>> x = torch.zeros(A.domain_shape) +>>> A(x).shape == A.range_shape +True + +You can also directly apply the tomographic operator to data on the +GPU: + +>>> A(x.cuda()).is_cuda +True +""" +import astra +from .base import Link, backends +from .numpy import NumpyLink +from contextlib import contextmanager +import warnings +import torch + + +class TorchLink(Link): + """Link implementation for torch arrays""" + + def __init__(self, shape, initial_value): + super(TorchLink, self).__init__(shape, initial_value) + + if not isinstance(initial_value, torch.Tensor): + raise ValueError( + f"Expected initial_value to be a `torch.Tensor'. Got {initial_value.__class__}" + ) + + if initial_value.shape == torch.Size([]): + self._data = torch.zeros( + shape, dtype=torch.float32, device=initial_value.device + ) + self._data[:] = initial_value + else: + if shape != initial_value.shape: + raise ValueError( + f"Expected initial_value with shape {shape}. Got {initial_value.shape}" + ) + # Ensure float32 + if initial_value.dtype != torch.float32: + warnings.warn( + f"The parameter initial_value is of type {initial_value.dtype}; expected `torch.float32`. " + f"The type has been automatically converted. " + f"Use `ts.link(x.to(dtype=torch.float32))' to inhibit this warning. " + ) + initial_value = initial_value.to(dtype=torch.float32) + # Make contiguous: + if not initial_value.is_contiguous(): + warnings.warn( + f"The parameter initial_value should be contiguous. " + f"It has been automatically made contiguous. " + f"Use `ts.link(x.contiguous())' to inhibit this warning. " + ) + initial_value = initial_value.contiguous() + self._data = initial_value + + ########################################################################### + # "Protocol" functions / methods # + ########################################################################### + @staticmethod + def __accepts__(initial_value): + # only accept torch tensors + return isinstance(initial_value, torch.Tensor) + + def __compatible_with__(self, other): + dev_self = self._data.device + if isinstance(other, NumpyLink): + dev_other = torch.device("cpu") + elif isinstance(other, TorchLink): + dev_other = other._data.device + else: + return NotImplemented + + return dev_self == dev_other + + ########################################################################### + # Properties # + ########################################################################### + @property + def linked_data(self): + if self._data.is_cuda: + z, y, x = self._data.shape + pitch = x * 4 # we assume 4 byte float32 values + link = astra.data3d.GPULink(self._data.data_ptr(), x, y, z, pitch) + return link + else: + # The torch tensor may be part of the computation + # graph. It must be detached to obtain a numpy + # array. We assume that this function will only be + # called to feed the data into Astra, which should not + # modify it. So this should be fine. + return self._data.detach().numpy() + + @property + def data(self): + """Returns a shared array with the underlying data. + + Changes to the return value will be reflected in the astra + data. + + If you want to avoid this, consider copying the data + immediately, using `x.data.clone()` for instance. + + NOTE: if the underlying object is an Astra projection data + type, the order of the axes will be in (Y, num_angles, X) + order. + + :returns: torch.Tensor + :rtype: torch.Tensor + + """ + return self._data + + @data.setter + def data(self, val): + raise AttributeError( + "You cannot change which torch tensor backs a dataset.\n" + "To change the underlying data instead, use: \n" + " >>> vd.data[:] = new_data\n" + ) + + ########################################################################### + # Context manager # + ########################################################################### + @contextmanager + def context(self): + """Context-manager to manage ASTRA interactions + + This context-manager makes sure that the current CUDA + stream is set to the CUDA device of the current linked data. + + :returns: + :rtype: + + """ + if self._data.is_cuda: + with torch.cuda.device_of(self._data): + yield + else: + # no-op for cpu-stored data + yield + + ########################################################################### + # New data creation # + ########################################################################### + def new_zeros(self, shape): + return TorchLink(shape, self._data.new_zeros(shape)) + + def new_full(self, shape, value): + return TorchLink(shape, self._data.new_full(shape, value)) + + def new_empty(self, shape): + return TorchLink(shape, self._data.new_empty(shape)) + + def clone(self): + return TorchLink(self._data.shape, self._data.clone()) + + +# When the torch module is mock imported by the Sphinx documentation system, do +# not alter the observable behavior the linking backend. +if not hasattr(torch, "__sphinx_mock__"): + backends.append(TorchLink) diff --git a/odl/tomo/operators/ray_trafo.py b/odl/tomo/operators/ray_trafo.py index 64419fd43b6..2c628d4ce10 100644 --- a/odl/tomo/operators/ray_trafo.py +++ b/odl/tomo/operators/ray_trafo.py @@ -22,6 +22,7 @@ from odl.tomo.backends.astra_cpu import AstraCpuImpl from odl.tomo.backends.astra_cuda import AstraCudaImpl from odl.tomo.backends.skimage_radon import SkImageImpl +from odl.tomo.backends.astra_cuda_link import AstraCudaLinkImpl from odl.tomo.geometry import Geometry from odl.util import is_string @@ -34,6 +35,7 @@ RAY_TRAFO_IMPLS['astra_cpu'] = AstraCpuImpl if ASTRA_CUDA_AVAILABLE: RAY_TRAFO_IMPLS['astra_cuda'] = AstraCudaImpl + RAY_TRAFO_IMPLS['astra_cuda_link'] = AstraCudaLinkImpl __all__ = ('RayTransform',) @@ -60,6 +62,7 @@ def __init__(self, vol_space, geometry, **kwargs): - ``'astra_cuda'``: ASTRA toolbox, using CUDA, 2D or 3D - ``'astra_cpu'``: ASTRA toolbox using CPU, only 2D + - ``'astra_cuda_link'``: ASTRA toolbox, using CUDA and Link API - ``'skimage'``: scikit-image, only 2D parallel with square reconstruction space. @@ -124,7 +127,7 @@ def __init__(self, vol_space, geometry, **kwargs): proj_tspace = vol_space.tspace_type( geometry.partition.shape, - weighting=weighting, + weighting=weighting, #type:ignore dtype=dtype, ) @@ -290,7 +293,7 @@ def get_impl(self, use_cache=True): if not use_cache or self.__cached_impl is None: # Lazily (re)instantiate the backend self.__cached_impl = self._impl_type( - self.geometry, + self.geometry, #type:ignore vol_space=self.domain, proj_space=self.range) @@ -314,7 +317,8 @@ def _call(self, x, out=None, **kwargs): DiscretizedSpaceElement Result of the transform, an element of the range. """ - return self.get_impl(self.use_cache).call_forward(x, out, **kwargs) + return self.get_impl( + self.use_cache).call_forward(x, out, **kwargs) #type:ignore @property def geometry(self): @@ -362,7 +366,7 @@ def _call(self, x, out=None, **kwargs): """ return ray_trafo.get_impl( ray_trafo.use_cache - ).call_backward(x, out, **kwargs) + ).call_backward(x, out, **kwargs)#type:ignore @property def geometry(self): diff --git a/odl/util/utility.py b/odl/util/utility.py index f9df90ac2ee..8ffb049c0d8 100644 --- a/odl/util/utility.py +++ b/odl/util/utility.py @@ -59,8 +59,17 @@ REPR_PRECISION = 4 # For printing scalars and array entries -TYPE_MAP_R2C = {np.dtype(dtype): np.result_type(dtype, 1j) - for dtype in np.sctypes['float']} +### https://numpy.org/doc/stable/reference/arrays.scalars.html#floating-point-types +NP_FLOAT_TYPES = [ + np.half, + np.single, + np.double, + np.longdouble + ] + +TYPE_MAP_R2C = { + np.dtype(dtype): np.result_type(dtype, 1j) for dtype in NP_FLOAT_TYPES + } TYPE_MAP_C2R = {cdt: np.empty(0, dtype=cdt).real.dtype for rdt, cdt in TYPE_MAP_R2C.items()} @@ -606,7 +615,17 @@ class ArrayOnPytorchManager(ABC): def __init__(self, device): self._device = device def as_compatible_array(self, arr, **kwargs): - return torch.tensor(arr, device = self._device, **kwargs) + dtype = kwargs.get('dtype', None) + if isinstance(arr, torch.Tensor): + arr = arr.detach() + if dtype is not None and arr.dtype!=kwargs['dtype']: + arr = arr.type(dtype) + if self._device is not None and arr.device!=self._device: + return arr.to(self._device) + else: + return arr + else: + return torch.tensor(arr, device = self._device, **kwargs) def compatible_zeros(self, shape, **kwargs): return torch.zeros(shape, device = self._device, **kwargs) def compatible_ones(self, shape, **kwargs):