-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/SyneRBI/PETRIC-UCL-EWS
- Loading branch information
Showing
3 changed files
with
285 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
|
||
# | ||
# Classes implementing the BSREM+PnP algorithm in sirf.STIR | ||
# | ||
# BSREM from https://github.com/SyneRBI/SIRF-Contribs/blob/master/src/Python/sirf/contrib/BSREM/BSREM.py | ||
|
||
import numpy | ||
import sirf.STIR as STIR | ||
from sirf.Utilities import examples_data_path | ||
|
||
from cil.optimisation.algorithms import Algorithm | ||
|
||
import time | ||
import numpy as np | ||
|
||
|
||
class AdamSkeleton(Algorithm): | ||
''' Main implementation of a modified BSREM algorithm | ||
This essentially implements constrained preconditioned gradient ascent | ||
with an EM-type preconditioner. | ||
In each update step, the gradient of a subset is computed, multiplied by a step_size and a EM-type preconditioner. | ||
Before adding this to the previous iterate, an update_filter can be applied. | ||
Step-size uses relaxation: ``initial_step_size`` / (1 + ``relaxation_eta`` * ``epoch()``) | ||
''' | ||
def __init__(self, data, initial, initial_step_size, relaxation_eta, | ||
update_filter=STIR.TruncateToCylinderProcessor(), **kwargs): | ||
''' | ||
Arguments: | ||
``data``: list of items as returned by `partitioner` | ||
``initial``: initial estimate | ||
``initial_step_size``, ``relaxation_eta``: step-size constants | ||
``update_filter`` is applied on the (additive) update term, i.e. before adding to the previous iterate. | ||
Set the filter to `None` if you don't want any. | ||
''' | ||
super().__init__(**kwargs) | ||
self.x = initial.copy() | ||
self.data = data | ||
self.num_subsets = len(data) | ||
self.initial_step_size = initial_step_size | ||
self.relaxation_eta = relaxation_eta | ||
# compute small number to add to image in preconditioner | ||
# don't make it too small as otherwise the algorithm cannot recover from zeroes. | ||
self.eps = initial.max()/1e3 | ||
self.average_sensitivity = initial.get_uniform_copy(0) | ||
for s in range(len(data)): | ||
self.average_sensitivity += self.subset_sensitivity(s)/self.num_subsets | ||
# add a small number to avoid division by zero in the preconditioner | ||
self.average_sensitivity += self.average_sensitivity.max()/1e4 | ||
self.subset = 0 | ||
self.update_filter = update_filter | ||
self.configured = True | ||
|
||
self.alpha = 1e-3 | ||
self.beta1 = 0.9 | ||
self.beta2 = 0.999 | ||
self.eps_adam = 1e-8 | ||
|
||
self.m = initial.copy() | ||
self.m.fill(np.zeros_like(initial.as_array())) | ||
|
||
self.m_hat = initial.copy() | ||
self.m_hat.fill(np.zeros_like(initial.as_array())) | ||
|
||
self.v = initial.copy() | ||
self.v.fill(np.zeros_like(initial.as_array())) | ||
|
||
self.v_hat = initial.copy() | ||
self.v_hat.fill(np.zeros_like(initial.as_array())) | ||
|
||
|
||
def subset_sensitivity(self, subset_num): | ||
raise NotImplementedError | ||
|
||
def subset_gradient(self, x, subset_num): | ||
raise NotImplementedError | ||
|
||
def epoch(self): | ||
return self.iteration // self.num_subsets | ||
|
||
def update(self): | ||
g = self.subset_gradient(self.x, self.subset) | ||
#g = (self.x + self.eps) * g / self.average_sensitivity | ||
|
||
self.m = self.beta1 * self.m + (1 - self.beta1) * g | ||
g.power(2, out=g) | ||
self.v = self.beta2 * self.v + (1 - self.beta2) * g | ||
self.m_hat = self.m.clone() / (1 - self.beta1 ** (self.iteration+1)) | ||
self.v_hat = self.v.clone() / (1 - self.beta2 ** (self.iteration+1)) | ||
self.v_hat.sqrt(out=self.v_hat) | ||
|
||
self.x_update = self.alpha * self.m_hat / (self.v_hat + self.eps_adam) | ||
if self.update_filter is not None: | ||
self.update_filter.apply(self.x_update) | ||
|
||
self.x += self.x_update | ||
# threshold to non-negative | ||
|
||
self.x.maximum(0, out=self.x) | ||
self.subset = (self.subset + 1) % self.num_subsets | ||
|
||
#self.alpha = self.alpha * 0.98 | ||
|
||
def update_objective(self): | ||
# required for current CIL (needs to set self.loss) | ||
self.loss.append(self.objective_function(self.x)) | ||
|
||
def objective_function(self, x): | ||
''' value of objective function summed over all subsets ''' | ||
v = 0 | ||
for s in range(len(self.data)): | ||
v += self.subset_objective(x, s) | ||
return v | ||
|
||
def subset_objective(self, x, subset_num): | ||
''' value of objective function for one subset ''' | ||
raise NotImplementedError | ||
|
||
class Adam1(AdamSkeleton): | ||
''' BSREM implementation using sirf.STIR objective functions''' | ||
def __init__(self, data, obj_funs, initial, initial_step_size=1, relaxation_eta=0, **kwargs): | ||
''' | ||
construct Algorithm with lists of data and, objective functions, initial estimate, initial step size, | ||
step-size relaxation (per epoch) and optionally Algorithm parameters | ||
''' | ||
self.obj_funs = obj_funs | ||
super().__init__(data, initial, initial_step_size, relaxation_eta, **kwargs) | ||
|
||
def subset_sensitivity(self, subset_num): | ||
''' Compute sensitivity for a particular subset''' | ||
self.obj_funs[subset_num].set_up(self.x) | ||
# note: sirf.STIR Poisson likelihood uses `get_subset_sensitivity(0) for the whole | ||
# sensitivity if there are no subsets in that likelihood | ||
return self.obj_funs[subset_num].get_subset_sensitivity(0) | ||
|
||
def subset_gradient(self, x, subset_num): | ||
''' Compute gradient at x for a particular subset''' | ||
return self.obj_funs[subset_num].gradient(x) | ||
|
||
def subset_objective(self, x, subset_num): | ||
''' value of objective function for one subset ''' | ||
return self.obj_funs[subset_num](x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
"""Main file to modify for submissions. | ||
Once renamed or symlinked as `main.py`, it will be used by `petric.py` as follows: | ||
>>> from main import Submission, submission_callbacks | ||
>>> from petric import data, metrics | ||
>>> algorithm = Submission(data) | ||
>>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks) | ||
""" | ||
from cil.optimisation.algorithms import Algorithm | ||
from cil.optimisation.utilities import callbacks | ||
from petric import Dataset | ||
from sirf.contrib.partitioner import partitioner | ||
|
||
from adam import Adam1 | ||
|
||
assert issubclass(Adam1, Algorithm) | ||
|
||
|
||
class MaxIteration(callbacks.Callback): | ||
""" | ||
The organisers try to `Submission(data).run(inf)` i.e. for infinite iterations (until timeout). | ||
This callback forces stopping after `max_iteration` instead. | ||
""" | ||
def __init__(self, max_iteration: int, verbose: int = 1): | ||
super().__init__(verbose) | ||
self.max_iteration = max_iteration | ||
|
||
def __call__(self, algorithm: Algorithm): | ||
if algorithm.iteration >= self.max_iteration: | ||
raise StopIteration | ||
|
||
class Submission(Adam1): | ||
# note that `issubclass(BSREM1, Algorithm) == True` | ||
def __init__(self, data: Dataset, num_subsets: int = 7, update_objective_interval: int = 10): | ||
""" | ||
Initialisation function, setting up data & (hyper)parameters. | ||
NB: in practice, `num_subsets` should likely be determined from the data. | ||
This is just an example. Try to modify and improve it! | ||
""" | ||
data_sub, acq_models, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term, | ||
data.mult_factors, num_subsets, | ||
initial_image=data.OSEM_image) | ||
# WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations) | ||
data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs)) | ||
data.prior.set_up(data.OSEM_image) | ||
for f in obj_funs: # add prior evenly to every objective function | ||
f.set_prior(data.prior) | ||
|
||
super().__init__(data_sub, obj_funs, initial=data.OSEM_image, initial_step_size=.3, relaxation_eta=.01, | ||
update_objective_interval=update_objective_interval) | ||
|
||
|
||
submission_callbacks = [MaxIteration(660)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
|
||
import numpy as np | ||
from dataclasses import dataclass | ||
|
||
from pathlib import Path | ||
import matplotlib.pyplot as plt | ||
from dataclasses import dataclass | ||
from time import time | ||
|
||
import numpy as np | ||
import time | ||
|
||
import sirf.STIR as STIR | ||
from sirf.contrib.partitioner import partitioner | ||
|
||
|
||
|
||
|
||
@dataclass | ||
class Dataset: | ||
acquired_data: STIR.AcquisitionData | ||
additive_term: STIR.AcquisitionData | ||
mult_factors: STIR.AcquisitionData | ||
OSEM_image: STIR.ImageData | ||
prior: STIR.RelativeDifferencePrior | ||
kappa: STIR.ImageData | ||
reference_image: STIR.ImageData | None | ||
whole_object_mask: STIR.ImageData | None | ||
background_mask: STIR.ImageData | None | ||
voi_masks: dict[str, STIR.ImageData] | ||
|
||
datasets = ["NeuroLF_Hoffman_Dataset", "Siemens_mMR_NEMA_IQ", "Siemens_Vision600_thorax"] | ||
|
||
outdir = "timing" | ||
|
||
sirf_verbosity = 0 | ||
|
||
outdir = Path(outdir) | ||
STIR.set_verbosity(sirf_verbosity) # set to higher value to diagnose problems | ||
STIR.AcquisitionData.set_storage_scheme('memory') # needed for get_subsets() | ||
_ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt')) | ||
|
||
num_tries = 10 | ||
for dataset in datasets: | ||
print("Timing information for: ", dataset) | ||
srcdir = Path("/mnt/share/petric" + "/" + dataset) | ||
|
||
acquired_data = STIR.AcquisitionData(str(srcdir / 'prompts.hs')) | ||
additive_term = STIR.AcquisitionData(str(srcdir / 'additive_term.hs')) | ||
mult_factors = STIR.AcquisitionData(str(srcdir / 'mult_factors.hs')) | ||
OSEM_image = STIR.ImageData(str(srcdir / 'OSEM_image.hv')) | ||
|
||
print("OSEM: ", OSEM_image.shape) | ||
print("acquired_data: ", acquired_data.shape) | ||
|
||
n_subs = [1, 2, 4, 8, 16, 32, 64] | ||
ave_forward = [] | ||
ave_backward = [] | ||
ave_priorgrad = [] | ||
ave_prior = [] | ||
|
||
for k, n_sub in enumerate(n_subs): | ||
|
||
data_sub, acq_models, obj_funs = partitioner.data_partition(acquired_data, additive_term, | ||
mult_factors, n_sub, | ||
initial_image=OSEM_image) | ||
|
||
|
||
y = data_sub[0].copy() | ||
x = OSEM_image.copy() | ||
|
||
t1 = time.time() | ||
for i in range(num_tries): | ||
acq_models[0].forward(OSEM_image, out=y) | ||
|
||
ave_forward.append(n_sub*(time.time() - t1)/num_tries) | ||
print("FORWARD for {} sub is: {}".format(n_sub, ave_forward[-1])) | ||
|
||
t1 = time.time() | ||
for i in range(num_tries): | ||
|
||
acq_models[0].adjoint(y, out=x) | ||
|
||
ave_backward.append(n_sub*(time.time() - t1)/num_tries) | ||
print("ADJOINT for {} sub is: {}".format(n_sub, ave_backward[-1])) | ||
|
||
|