Commit
Imraj-Singh committed Jul 25, 2024
2 parents 7f5509e + 6a360e6 commit 4e47b7e
Showing 3 changed files with 285 additions and 0 deletions.
144 changes: 144 additions & 0 deletions adam.py
@@ -0,0 +1,144 @@

#
# Classes implementing an Adam algorithm in sirf.STIR
#
# Adapted from the BSREM skeleton at
# https://github.com/SyneRBI/SIRF-Contribs/blob/master/src/Python/sirf/contrib/BSREM/BSREM.py

import numpy as np

import sirf.STIR as STIR
from cil.optimisation.algorithms import Algorithm


class AdamSkeleton(Algorithm):
    ''' Skeleton of an Adam algorithm on subset objectives, derived from the BSREM skeleton.
    In each update step, the gradient of one subset objective is computed and fed into the
    usual Adam moment estimates; the bias-corrected Adam step is then (optionally) filtered
    with an update_filter and added to the previous iterate, followed by a non-negativity
    constraint.
    ``initial_step_size`` and ``relaxation_eta`` are kept for interface compatibility with the
    BSREM skeleton (relaxed step size ``initial_step_size`` / (1 + ``relaxation_eta`` * ``epoch()``));
    the Adam update itself uses the fixed step size ``self.alpha``.
    '''
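    # For reference, a sketch of the (ascent-form) Adam recursion that ``update`` below
    # implements, with g_t the current subset gradient and t = self.iteration + 1:
    #   m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
    #   v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t**2
    #   m_hat_t = m_t / (1 - beta1**t),   v_hat_t = v_t / (1 - beta2**t)
    #   x_t     = max(x_{t-1} + alpha * m_hat_t / (sqrt(v_hat_t) + eps_adam), 0)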
    def __init__(self, data, initial, initial_step_size, relaxation_eta,
                 update_filter=STIR.TruncateToCylinderProcessor(), **kwargs):
        '''
        Arguments:
        ``data``: list of items as returned by `partitioner`
        ``initial``: initial estimate
        ``initial_step_size``, ``relaxation_eta``: step-size constants
        ``update_filter`` is applied on the (additive) update term, i.e. before adding to the previous iterate.
        Set the filter to `None` if you don't want any.
        '''
        super().__init__(**kwargs)
        self.x = initial.copy()
        self.data = data
        self.num_subsets = len(data)
        self.initial_step_size = initial_step_size
        self.relaxation_eta = relaxation_eta
        # compute small number to add to image in preconditioner;
        # don't make it too small as otherwise the algorithm cannot recover from zeroes.
        self.eps = initial.max() / 1e3
        self.average_sensitivity = initial.get_uniform_copy(0)
        for s in range(len(data)):
            self.average_sensitivity += self.subset_sensitivity(s) / self.num_subsets
        # add a small number to avoid division by zero in the preconditioner
        self.average_sensitivity += self.average_sensitivity.max() / 1e4
        self.subset = 0
        self.update_filter = update_filter
        self.configured = True

        # Adam hyper-parameters (Kingma & Ba defaults)
        self.alpha = 1e-3
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.eps_adam = 1e-8

        # first and second moment estimates and their bias-corrected versions,
        # initialised to zero images with the same geometry as the initial estimate
        self.m = initial.get_uniform_copy(0)
        self.m_hat = initial.get_uniform_copy(0)
        self.v = initial.get_uniform_copy(0)
        self.v_hat = initial.get_uniform_copy(0)

    def subset_sensitivity(self, subset_num):
        raise NotImplementedError

    def subset_gradient(self, x, subset_num):
        raise NotImplementedError

    def epoch(self):
        return self.iteration // self.num_subsets

    def update(self):
        # gradient of the current subset objective at the current iterate
        g = self.subset_gradient(self.x, self.subset)
        # EM-type preconditioning (currently disabled):
        #g = (self.x + self.eps) * g / self.average_sensitivity

        # first and second moment estimates (note: g is squared in place)
        self.m = self.beta1 * self.m + (1 - self.beta1) * g
        g.power(2, out=g)
        self.v = self.beta2 * self.v + (1 - self.beta2) * g
        # bias correction
        self.m_hat = self.m.clone() / (1 - self.beta1 ** (self.iteration + 1))
        self.v_hat = self.v.clone() / (1 - self.beta2 ** (self.iteration + 1))
        self.v_hat.sqrt(out=self.v_hat)

        # Adam ascent step, optionally filtered before it is applied
        self.x_update = self.alpha * self.m_hat / (self.v_hat + self.eps_adam)
        if self.update_filter is not None:
            self.update_filter.apply(self.x_update)

        self.x += self.x_update
        # threshold to non-negative
        self.x.maximum(0, out=self.x)
        self.subset = (self.subset + 1) % self.num_subsets

        # optional geometric step-size decay (currently disabled):
        #self.alpha = self.alpha * 0.98

    def update_objective(self):
        # required for current CIL (needs to set self.loss)
        self.loss.append(self.objective_function(self.x))

    def objective_function(self, x):
        ''' value of objective function summed over all subsets '''
        v = 0
        for s in range(len(self.data)):
            v += self.subset_objective(x, s)
        return v

    def subset_objective(self, x, subset_num):
        ''' value of objective function for one subset '''
        raise NotImplementedError

class Adam1(AdamSkeleton):
    ''' Adam implementation using sirf.STIR objective functions '''
    def __init__(self, data, obj_funs, initial, initial_step_size=1, relaxation_eta=0, **kwargs):
        '''
        construct Algorithm with lists of data and objective functions, initial estimate,
        initial step size, step-size relaxation (per epoch) and optionally Algorithm parameters
        '''
        self.obj_funs = obj_funs
        super().__init__(data, initial, initial_step_size, relaxation_eta, **kwargs)

    def subset_sensitivity(self, subset_num):
        ''' Compute sensitivity for a particular subset '''
        self.obj_funs[subset_num].set_up(self.x)
        # note: the sirf.STIR Poisson likelihood uses `get_subset_sensitivity(0)` for the
        # whole sensitivity if there are no subsets in that likelihood
        return self.obj_funs[subset_num].get_subset_sensitivity(0)

    def subset_gradient(self, x, subset_num):
        ''' Compute gradient at x for a particular subset '''
        return self.obj_funs[subset_num].gradient(x)

    def subset_objective(self, x, subset_num):
        ''' value of objective function for one subset '''
        return self.obj_funs[subset_num](x)
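
A minimal usage sketch (an assumption, not part of this commit: it presumes data already loaded as in time_it.py below, and a petric-style partition as in main_ADAM.py):

from sirf.contrib.partitioner import partitioner
from adam import Adam1

# partition the measured data into 7 subsets and build matching objective functions
data_sub, acq_models, obj_funs = partitioner.data_partition(
    acquired_data, additive_term, mult_factors, 7, initial_image=OSEM_image)
algo = Adam1(data_sub, obj_funs, initial=OSEM_image, update_objective_interval=10)
algo.run(70)  # 10 epochs of 7 subset updates each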
54 changes: 54 additions & 0 deletions main_ADAM.py
@@ -0,0 +1,54 @@
"""Main file to modify for submissions.
Once renamed or symlinked as `main.py`, it will be used by `petric.py` as follows:
>>> from main import Submission, submission_callbacks
>>> from petric import data, metrics
>>> algorithm = Submission(data)
>>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks)
"""
from cil.optimisation.algorithms import Algorithm
from cil.optimisation.utilities import callbacks
from petric import Dataset
from sirf.contrib.partitioner import partitioner

from adam import Adam1

assert issubclass(Adam1, Algorithm)


class MaxIteration(callbacks.Callback):
    """
    The organisers try to `Submission(data).run(inf)`, i.e. for infinite iterations (until timeout).
    This callback forces stopping after `max_iteration` instead.
    """
    def __init__(self, max_iteration: int, verbose: int = 1):
        super().__init__(verbose)
        self.max_iteration = max_iteration

    def __call__(self, algorithm: Algorithm):
        if algorithm.iteration >= self.max_iteration:
            raise StopIteration

class Submission(Adam1):
    # note that `issubclass(Adam1, Algorithm) == True`
    def __init__(self, data: Dataset, num_subsets: int = 7, update_objective_interval: int = 10):
        """
        Initialisation function, setting up data & (hyper)parameters.
        NB: in practice, `num_subsets` should likely be determined from the data.
        This is just an example. Try to modify and improve it!
        """
        data_sub, acq_models, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
                                                                    data.mult_factors, num_subsets,
                                                                    initial_image=data.OSEM_image)
        # WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM-style implementations)
        data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs))
        data.prior.set_up(data.OSEM_image)
        for f in obj_funs:  # add prior evenly to every objective function
            f.set_prior(data.prior)

        super().__init__(data_sub, obj_funs, initial=data.OSEM_image, initial_step_size=.3, relaxation_eta=.01,
                         update_objective_interval=update_objective_interval)


submission_callbacks = [MaxIteration(660)]
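
Spelled out, the organisers' usage from the module docstring above, plus one assumed extra callback (`callbacks.TextProgressCallback`, assuming a recent CIL version provides it):

import numpy as np
from cil.optimisation.utilities import callbacks
from main import Submission, submission_callbacks
from petric import data, metrics  # as in the module docstring above

algorithm = Submission(data)
# MaxIteration(660) in submission_callbacks stops the otherwise infinite run;
# the progress callback just reports iterations
algorithm.run(np.inf, callbacks=metrics + submission_callbacks + [callbacks.TextProgressCallback()])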
87 changes: 87 additions & 0 deletions time_it.py
@@ -0,0 +1,87 @@

import time
from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import sirf.STIR as STIR
from sirf.contrib.partitioner import partitioner




@dataclass
class Dataset:
    acquired_data: STIR.AcquisitionData
    additive_term: STIR.AcquisitionData
    mult_factors: STIR.AcquisitionData
    OSEM_image: STIR.ImageData
    prior: STIR.RelativeDifferencePrior
    kappa: STIR.ImageData
    reference_image: STIR.ImageData | None
    whole_object_mask: STIR.ImageData | None
    background_mask: STIR.ImageData | None
    voi_masks: dict[str, STIR.ImageData]

datasets = ["NeuroLF_Hoffman_Dataset", "Siemens_mMR_NEMA_IQ", "Siemens_Vision600_thorax"]

outdir = "timing"

sirf_verbosity = 0

outdir = Path(outdir)
STIR.set_verbosity(sirf_verbosity) # set to higher value to diagnose problems
STIR.AcquisitionData.set_storage_scheme('memory') # needed for get_subsets()
_ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt'))

num_tries = 10
for dataset in datasets:
    print("Timing information for:", dataset)
    srcdir = Path("/mnt/share/petric") / dataset

    acquired_data = STIR.AcquisitionData(str(srcdir / 'prompts.hs'))
    additive_term = STIR.AcquisitionData(str(srcdir / 'additive_term.hs'))
    mult_factors = STIR.AcquisitionData(str(srcdir / 'mult_factors.hs'))
    OSEM_image = STIR.ImageData(str(srcdir / 'OSEM_image.hv'))

    print("OSEM:", OSEM_image.shape)
    print("acquired_data:", acquired_data.shape)

    n_subs = [1, 2, 4, 8, 16, 32, 64]
    ave_forward = []
    ave_backward = []
    ave_priorgrad = []
    ave_prior = []

    for n_sub in n_subs:
        data_sub, acq_models, obj_funs = partitioner.data_partition(acquired_data, additive_term,
                                                                    mult_factors, n_sub,
                                                                    initial_image=OSEM_image)

        y = data_sub[0].copy()
        x = OSEM_image.copy()

        # time the forward projection of one subset model; scaling by n_sub estimates
        # the cost of a full pass over the data
        t1 = time.time()
        for _ in range(num_tries):
            acq_models[0].forward(OSEM_image, out=y)
        ave_forward.append(n_sub * (time.time() - t1) / num_tries)
        print("FORWARD for {} sub is: {}".format(n_sub, ave_forward[-1]))

        # same for the adjoint (back projection)
        t1 = time.time()
        for _ in range(num_tries):
            acq_models[0].adjoint(y, out=x)
        ave_backward.append(n_sub * (time.time() - t1) / num_tries)
        print("ADJOINT for {} sub is: {}".format(n_sub, ave_backward[-1]))

