Backend/pytorch arrays #1660

Open
wants to merge 9 commits into base: backend/pytorch-arrays
71 changes: 71 additions & 0 deletions benchmark/dev_test.py
@@ -0,0 +1,71 @@
import torch
import numpy as np

import odl

from odl.contrib.torch.new_operator import OperatorModule

import matplotlib.pyplot as plt

if __name__ == '__main__':
device_name = 'cuda:0'
### Define input tensor
dimension = 3
n_points = 64
space = odl.uniform_discr(
[-20 for _ in range(dimension)],
[ 20 for _ in range(dimension)],
[n_points for _ in range(dimension)],
impl='pytorch', torch_device=device_name
)

odl_phantom = odl.phantom.shepp_logan(space, modified=True)
phantom: torch.Tensor = odl_phantom.asarray().unsqueeze(0).unsqueeze(0).to(device_name)
plt.matshow(phantom[0,0,32].detach().cpu())
plt.savefig('phantom')
plt.close()

# <!> enforce float32 conversion, rather than float64
phantom = phantom.to(dtype=torch.float32)
# <!> make tensor contiguous from creation
phantom = phantom.contiguous()
# <!> for the example, phantom.requires_grad == True
phantom.requires_grad_()
# Make a 3d single-axis parallel beam geometry with flat detector
# Angles: uniformly spaced, n = 32, min = 0, max = 2 * pi
angle_partition = odl.uniform_partition(0, 2 * np.pi, 32)
detector_partition = odl.uniform_partition([-30] * 2, [30] * 2, [100] * 2)
geometry = odl.tomo.Parallel3dAxisGeometry(angle_partition, detector_partition)

# Ray transform (= forward projection).
ray_trafo = odl.tomo.RayTransform(space, geometry, impl='astra_cuda_pytorch')

forward_module = OperatorModule(ray_trafo)
backward_module = OperatorModule(ray_trafo.adjoint)
sinogram: torch.Tensor = forward_module(phantom)  # type: ignore

x = torch.zeros(
size=phantom.size(),
device=device_name,
requires_grad=True
)

optimiser = torch.optim.Adam(  # type: ignore
[x],
lr=1e-3
)

noisy_data = forward_module(phantom)
mse_loss = torch.nn.MSELoss()

for _ in range(100):
optimiser.zero_grad()
current_data = forward_module(x)
loss = mse_loss(current_data, noisy_data)
loss.backward()
optimiser.step()

plt.matshow(x[0,0,32].detach().cpu())
plt.savefig('optimised')
plt.close()

95 changes: 95 additions & 0 deletions benchmark/main.py
@@ -0,0 +1,95 @@
"""This module benchmarks a function with parameters defined in a json file metadata"""
import argparse
from pathlib import Path
from datetime import datetime
import time
import sys

import json
import pandas as pd

N_CALLS = 1
MAX_ITERATIONS = 100

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--metadata_name', required=True)
args = parser.parse_args()

metadata_name = args.metadata_name

### unpack benchmark metadata
try:
with open(f'metadata/{metadata_name}.json', mode='r', encoding="utf-8") as json_file:
metadata_dict = json.load(json_file)
except FileNotFoundError:
sys.exit(f'No file at metadata/{metadata_name}.json')

### unpack variables
benchmark_dict = {}
for key in ['backend', 'script_name', 'parameters']:
try:
benchmark_dict[key] = metadata_dict[key]
except KeyError:
sys.exit(f'No "{key}" key in metadata/{metadata_name}.json')

### load backend module
if benchmark_dict['backend'] == 'odl':
import scripts.odl_scripts as sc
DEVICE = 'cpu'

elif benchmark_dict['backend'] == 'torch':
import scripts.torch_scripts as sc
try:
DEVICE = metadata_dict['parameters']['device_name']
except KeyError:
DEVICE = 'cpu'

else:
raise NotImplementedError(f'Backend {benchmark_dict["backend"]} not supported, only "odl" and "torch"')

try:
function = getattr(sc, benchmark_dict['script_name'])
except AttributeError:
sys.exit(f'Script {benchmark_dict["script_name"]} not implemented for backend {benchmark_dict["backend"]}')


report_dict = {
"dimension" : [],
"n_points" : [],
"time" : [],
"error" : []
}

for dimension in benchmark_dict["parameters"]['dimensions']:
for n_points in benchmark_dict["parameters"]['n_points']:
print(
f"Benchmarking {benchmark_dict['script_name']} "
f"for dimension {dimension} and {n_points} points"
)
for _ in range(N_CALLS):
start = time.time()
error = function(
benchmark_dict["parameters"],
dimension, n_points,
MAX_ITERATIONS
)
end = time.time()
report_dict['dimension'].append(dimension)
report_dict['n_points'].append(n_points)
report_dict['time'].append(end - start)
report_dict['error'].append(error)

report_df = pd.DataFrame.from_dict(report_dict)
report_df['device'] = DEVICE
report_df['backend'] = benchmark_dict['backend']
report_df['max_iterations'] = MAX_ITERATIONS
report_df['timestamp'] = pd.Timestamp(datetime.now(), tz=None)
result_file_path = f'results/{metadata_name}.csv'
if Path(result_file_path).is_file():
report_df = pd.concat([
pd.read_csv(result_file_path), report_df
])
report_df.to_csv(result_file_path, index=False)
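
Note: the scripts.odl_scripts and scripts.torch_scripts modules referenced above are not part of this diff. From the call site in main.py (function(benchmark_dict["parameters"], dimension, n_points, MAX_ITERATIONS), with the return value logged as "error"), each entry point such as mri_mlem_adam is assumed to follow the interface sketched below; the body is only a placeholder, not the actual implementation.

def mri_mlem_adam(parameters: dict,
                  dimension: int,
                  n_points: int,
                  max_iterations: int) -> float:
    # `parameters` is the "parameters" block of the metadata JSON
    # (subsampling, learning_rate, beta1, beta2, eps, and device_name
    # for the torch backend).
    lr = parameters['learning_rate']
    betas = (parameters['beta1'], parameters['beta2'])
    # ... build the phantom and forward operator for `dimension` and
    # `n_points`, then run at most `max_iterations` Adam steps with
    # these settings ...
    return 0.0  # placeholder; main.py records this scalar as "error"
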
13 changes: 13 additions & 0 deletions benchmark/metadata/mri_mlem_odl_adam.json
@@ -0,0 +1,13 @@
{
"backend":"odl",
"script_name":"mri_mlem_adam",
"parameters":{
"n_points" : [512],
"dimensions" : [2],
"subsampling" : 0.5,
"learning_rate": 0.001,
"beta1": 0.9,
"beta2": 0.999,
"eps" : 1e-8
}
}
14 changes: 14 additions & 0 deletions benchmark/metadata/mri_mlem_torch_adam.json
@@ -0,0 +1,14 @@
{
"backend":"torch",
"script_name":"mri_mlem_adam",
"parameters":{
"n_points" : [512],
"dimensions" : [2],
"subsampling" : 0.5,
"device_name":"cuda:0",
"learning_rate": 0.001,
"beta1": 0.9,
"beta2": 0.999,
"eps" : 1e-8
}
}
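
With either metadata file in place, the driver above would be invoked as, for example,

python main.py --metadata_name mri_mlem_torch_adam

assuming the working directory is benchmark/, since main.py resolves metadata/{metadata_name}.json and results/{metadata_name}.csv relative to it; each invocation appends one row per call to results/mri_mlem_torch_adam.csv.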