Skip to content

Commit

Permalink
Merge pull request #123 from SyneRBI/penalisation_factor
Browse files Browse the repository at this point in the history
Utility to get a penalisation factor for a similar dataset
  • Loading branch information
casperdcl authored Oct 4, 2024
2 parents ecedd0d + 52272ff commit bbb048f
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 3 deletions.
86 changes: 86 additions & 0 deletions SIRF_data_preparation/get_penalisation_factor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#!/usr/bin/env python
"""Find penalisation factor for one dataset based on another
Usage:
get_penalisation_factor.py [--help | options]
Options:
-h, --help
--dataset=<name> dataset name (required)
--ref_dataset=<name> reference dataset name (required)
-w, --write_penalisation_factor write in data/<dataset>/penalisation_factor.txt
"""

# Copyright 2024 University College London
# Licence: Apache-2.0

import sys
from pathlib import Path

import numpy as np
from docopt import docopt

# %% imports
import sirf.STIR
from petric import Dataset, get_data
from SIRF_data_preparation.data_utilities import the_data_path

# %%
__version__ = "0.1.0"

write_penalisation_factor = False
if "ipykernel" not in sys.argv[0]: # clunky way to be able to set variables from within jupyter/VScode without docopt
args = docopt(__doc__, argv=None, version=__version__)

# logging.basicConfig(level=logging.INFO)

dataset = args["--dataset"]
ref_dataset = args["--ref_dataset"]
if dataset is None or ref_dataset is None:
print("Need to set the --dataset arguments")
exit(1)
if args["--write_penalisation_factor"] is not None:
write_penalisation_factor = False
else: # set it by hand, e.g.
ref_dataset = "NeuroLF_Hoffman_Dataset"
dataset = "Siemens_mMR_NEMA_IQ"
write_penalisation_factor = True


# %%
def VOImean(im: sirf.STIR.ImageData, background_mask: sirf.STIR.ImageData) -> float:
background_indices = np.where(background_mask.as_array())
return np.mean(im.as_array()[background_indices]) / len(background_indices)


def backgroundVOImean(dataset: Path) -> float:
im = sirf.STIR.ImageData(str(dataset / "OSEM_image.hv"))
VOI = sirf.STIR.ImageData(str(dataset / "PETRIC" / "VOI_background.hv"))
return VOImean(im, VOI)


def get_penalisation_factor(ref_data: Dataset, cur_data: Dataset) -> float:
ref_mean = VOImean(ref_data.OSEM_image, ref_data.background_mask)
cur_mean = VOImean(cur_data.OSEM_image, cur_data.background_mask)
print(f"ref_mean={ref_mean}, cur_mean={cur_mean}, c/r={cur_mean / ref_mean}, r/c={ref_mean / cur_mean}")
ref_penalisation_factor = ref_data.prior.get_penalisation_factor()
penalisation_factor = ref_penalisation_factor * cur_mean / ref_mean
print(f"penalisation_factor={penalisation_factor}")
return penalisation_factor


# %%
refdir = Path(the_data_path(ref_dataset))
curdir = Path(the_data_path(dataset))
# %%
ref_data = get_data(refdir, outdir=None)
cur_data = get_data(curdir, outdir=None)
# %%
penalisation_factor = get_penalisation_factor(ref_data, cur_data)

# %%
if write_penalisation_factor:
filename = curdir / "penalisation_factor.txt"
print(f"Writing it to {filename}")
with open(filename, "w") as file:
file.write(str(penalisation_factor))
7 changes: 4 additions & 3 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,15 @@ class Dataset:
def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
"""
Load data from `srcdir`, constructs prior and return as a `Dataset`.
Also redirects sirf.STIR log output to `outdir`.
Also redirects sirf.STIR log output to `outdir`, unless that's set to None
"""
srcdir = Path(srcdir)
outdir = Path(outdir)
STIR.set_verbosity(sirf_verbosity) # set to higher value to diagnose problems
STIR.AcquisitionData.set_storage_scheme('memory') # needed for get_subsets()

_ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt'))
if outdir is not None:
outdir = Path(outdir)
_ = STIR.MessageRedirector(str(outdir / 'info.txt'), str(outdir / 'warnings.txt'), str(outdir / 'errors.txt'))
acquired_data = STIR.AcquisitionData(str(srcdir / 'prompts.hs'))
additive_term = STIR.AcquisitionData(str(srcdir / 'additive_term.hs'))
mult_factors = STIR.AcquisitionData(str(srcdir / 'mult_factors.hs'))
Expand Down

0 comments on commit bbb048f

Please sign in to comment.