Skip to content

Commit

Permalink
Merge pull request #459 from richardbeare/BuildTemplateFix
Browse files Browse the repository at this point in the history
updated build_template with affine shape correction
  • Loading branch information
ntustison authored May 12, 2023
2 parents 0be998a + 58db724 commit 719bee7
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 12 deletions.
5 changes: 5 additions & 0 deletions ants/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ pybind11_add_module(ResampleImage antscore/ResampleImage.cxx WRAP_ResampleImage.
pybind11_add_module(ThresholdImage antscore/ThresholdImage.cxx WRAP_ThresholdImage.cxx)
pybind11_add_module(TileImages antscore/TileImages.cxx WRAP_TileImages.cxx)

pybind11_add_module(AverageAffineTransform antscore/AverageAffineTransform.cxx WRAP_AverageAffineTransform.cxx)
pybind11_add_module(AverageAffineTransformNoRigid antscore/AverageAffineTransformNoRigid.cxx WRAP_AverageAffineTransformNoRigid.cxx)

## CONTRIB ##
pybind11_add_module(antsImageAugment CONTRIB_antsImageAugment.cxx)

Expand Down Expand Up @@ -167,6 +170,8 @@ target_link_libraries(N4BiasFieldCorrection PRIVATE ${ITK_LIBRARIES} antsUtiliti
target_link_libraries(ResampleImage PRIVATE ${ITK_LIBRARIES} antsUtilities)
target_link_libraries(ThresholdImage PRIVATE ${ITK_LIBRARIES} antsUtilities)
target_link_libraries(TileImages PRIVATE ${ITK_LIBRARIES} antsUtilities)
target_link_libraries(AverageAffineTransform PRIVATE ${ITK_LIBRARIES} antsUtilities registrationUtilities)
target_link_libraries(AverageAffineTransformNoRigid PRIVATE ${ITK_LIBRARIES} antsUtilities registrationUtilities)

## CONTRIB ##
target_link_libraries(antsImageAugment PRIVATE ${ITK_LIBRARIES})
Expand Down
4 changes: 4 additions & 0 deletions ants/lib/LOCAL_antsTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#include "itkWindowedSincInterpolateImageFunction.h"
#include "itkLabelImageGaussianInterpolateImageFunction.h"
#include "itkTransformFileWriter.h"
#include "itkTransformFactory.h"

#include "itkMacro.h"
#include "itkImage.h"
Expand All @@ -62,6 +63,7 @@
#include "antscore/antsUtilities.h"
#include "itkAffineTransform.h"
#include "LOCAL_antsImage.h"
#include "register_transforms.h"

namespace py = pybind11;

Expand Down Expand Up @@ -377,6 +379,8 @@ py::capsule composeTransforms( std::vector<void *> tformlist,
template <typename TransformBaseType, class PrecisionType, unsigned int Dimension>
py::capsule readTransform( std::string filename, unsigned int dimension, std::string precision )
{
register_transforms();

typedef typename TransformBaseType::Pointer TransformBasePointerType;
typedef typename itk::CompositeTransform<PrecisionType, Dimension> CompositeTransformType;

Expand Down
5 changes: 4 additions & 1 deletion ants/lib/LOCAL_readTransform.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@

#include "LOCAL_readTransform.h"

namespace py = pybind11;
#include "register_transforms.h"

namespace py = pybind11;

unsigned int getTransformDimensionFromFile( std::string filename )
{
register_transforms();
typedef itk::TransformFileReader TransformReaderType1;
typedef typename TransformReaderType1::Pointer TransformReaderType;
TransformReaderType reader = itk::TransformFileReader::New();
Expand All @@ -71,6 +73,7 @@ unsigned int getTransformDimensionFromFile( std::string filename )

std::string getTransformNameFromFile( std::string filename )
{
register_transforms();
typedef itk::TransformFileReader TransformReaderType1;
typedef typename TransformReaderType1::Pointer TransformReaderType;
TransformReaderType reader = itk::TransformFileReader::New();
Expand Down
17 changes: 17 additions & 0 deletions ants/lib/WRAP_AverageAffineTransform.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "antscore/AverageAffineTransform.h"
#include "antscore/AverageAffineTransformNoRigid.h"

namespace py = pybind11;

int AverageAffineTransform( std::vector<std::string> instring )
{
return ants::AverageAffineTransform(instring, NULL);
}

PYBIND11_MODULE(AverageAffineTransform, m)
{
m.def("AverageAffineTransform", &AverageAffineTransform);
}
16 changes: 16 additions & 0 deletions ants/lib/WRAP_AverageAffineTransformNoRigid.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "antscore/AverageAffineTransformNoRigid.h"

namespace py = pybind11;

int AverageAffineTransformNoRigid( std::vector<std::string> instring )
{
return ants::AverageAffineTransformNoRigid(instring, NULL);
}

PYBIND11_MODULE(AverageAffineTransformNoRigid, m)
{
m.def("AverageAffineTransformNoRigid", &AverageAffineTransformNoRigid);
}
3 changes: 2 additions & 1 deletion ants/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@
from .ThresholdImage import *
from .integrateVelocityField import *
from .TileImages import *

from .AverageAffineTransform import *
from .AverageAffineTransformNoRigid import *

## CONTRIB ##
# NOTE: contrib contains code which is experimental
Expand Down
22 changes: 22 additions & 0 deletions ants/lib/register_transforms.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef ANTS_REGISTER_TRANSFORM_H_
#define ANTS_REGISTER_TRANSFORM_H_

#include "itkTransform.h"
#include "itkTransformFactory.h"

void register_transforms()
{
using MatrixOffsetTransformTypeA = itk::MatrixOffsetTransformBase<double, 3, 3>;
itk::TransformFactory<MatrixOffsetTransformTypeA>::RegisterTransform();

using MatrixOffsetTransformTypeB = itk::MatrixOffsetTransformBase<float, 3, 3>;
itk::TransformFactory<MatrixOffsetTransformTypeB>::RegisterTransform();

using MatrixOffsetTransformTypeC = itk::MatrixOffsetTransformBase<double, 2, 2>;
itk::TransformFactory<MatrixOffsetTransformTypeC>::RegisterTransform();

using MatrixOffsetTransformTypeD = itk::MatrixOffsetTransformBase<float, 2, 2>;
itk::TransformFactory<MatrixOffsetTransformTypeC>::RegisterTransform();
}

#endif
47 changes: 37 additions & 10 deletions ants/registration/build_template.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
__all__ = ["build_template"]

import numpy as np

import os
from tempfile import mktemp

from .reflect_image import reflect_image
from .interface import registration
from .apply_transforms import apply_transforms
from .resample_image import resample_image_to_target
from ..core import ants_image_io as iio
from ..core import ants_transform_io as tio
from .. import utils


def build_template(
initial_template=None,
image_list=None,
iterations=3,
gradient_step=0.2,
blending_weight=0.75,
weights=None,
useNoRigid=False,
**kwargs
):
"""
Expand Down Expand Up @@ -46,6 +47,9 @@ def build_template(
weights : vector
weight for each input image
useNoRigid : boolean
equivalent of -y in the script. Template update
step will not use the rigid component if this is True.
kwargs : keyword args
extra arguments passed to ants registration
Expand Down Expand Up @@ -79,22 +83,45 @@ def build_template(

xavg = initial_template.clone()
for i in range(iterations):
affinelist = []
for k in range(len(image_list)):
w1 = registration(
xavg, image_list[k], type_of_transform=type_of_transform, **kwargs
)
L = len(w1["fwdtransforms"])
# affine is the last one
affinelist.append(w1["fwdtransforms"][L-1])

if k == 0:
wavg = iio.image_read(w1["fwdtransforms"][0]) * weights[k]
if L == 2:
wavg = iio.image_read(w1["fwdtransforms"][0]) * weights[k]
xavgNew = w1["warpedmovout"] * weights[k]
else:
wavg = wavg + iio.image_read(w1["fwdtransforms"][0]) * weights[k]
if L == 2:
wavg = wavg + iio.image_read(w1["fwdtransforms"][0]) * weights[k]
xavgNew = xavgNew + w1["warpedmovout"] * weights[k]
print(wavg.abs().mean())
wscl = (-1.0) * gradient_step
wavg = wavg * wscl
wavgfn = mktemp(suffix=".nii.gz")
iio.image_write(wavg, wavgfn)
xavg = apply_transforms(xavgNew, xavgNew, wavgfn)

if useNoRigid:
avgaffine = utils.average_affine_transform_no_rigid(affinelist)
else:
avgaffine = utils.average_affine_transform(affinelist)
afffn = mktemp(suffix=".mat")
tio.write_transform(avgaffine, afffn)

if L == 2:
print(wavg.abs().mean())
wscl = (-1.0) * gradient_step
wavg = wavg * wscl
# apply affine to the nonlinear?
# need to save the average
wavgA = apply_transforms(fixed = xavgNew, moving = wavg, imagetype=1, transformlist=afffn, whichtoinvert=[1])
wavgfn = mktemp(suffix=".nii.gz")
iio.image_write(wavgA, wavgfn)
xavg = apply_transforms(fixed=xavgNew, moving=xavgNew, transformlist=[wavgfn, afffn], whichtoinvert=[0, 1])
else:
xavg = apply_transforms(fixed=xavgNew, moving=xavgNew, transformlist=[afffn], whichtoinvert=[1])

os.remove(afffn)
if blending_weight is not None:
xavg = xavg * blending_weight + utils.iMath(xavg, "Sharpen") * (
1.0 - blending_weight
Expand Down
1 change: 1 addition & 0 deletions ants/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@
from .smooth_image import *
from .threshold_image import *
from .weingarten_image_curvature import *
from .average_transform import *
43 changes: 43 additions & 0 deletions ants/utils/average_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from .. import utils, core
from tempfile import mktemp
import os



def _average_affine_transform_driver(transformlist, referencetransform=None, funcname="AverageAffineTransform"):
"""
takes a list of transforms (files at the moment)
and returns the average
"""

# AverageAffineTransform deals with transform files,
# so this function will need to deal with already
# loaded files. Doesn't look like the magic
# available for images has been added for transforms.
res_temp_file = mktemp(suffix='.mat')

# could do some stuff here to cope with transform lists that
# aren't files

# load one of the transforms to figure out the dimension
tf = core.ants_transform_io.read_transform(transformlist[0])
if referencetransform is None:
args = [tf.dimension, res_temp_file] + transformlist
else:
args = [tf.dimension, res_temp_file] + ['-R', referencetransform] + transformlist
pargs = utils._int_antsProcessArguments(args)
print(pargs)
libfun = utils.get_lib_fn(funcname)
status = libfun(pargs)

res = core.ants_transform_io.read_transform(res_temp_file)
os.remove(res_temp_file)
return res

def average_affine_transform(transformlist, referencetransform=None):
return _average_affine_transform_driver(transformlist, referencetransform, "AverageAffineTransform")


def average_affine_transform_no_rigid(transformlist, referencetransform=None):
return _average_affine_transform_driver(transformlist, referencetransform, "AverageAffineTransformNoRigid")

0 comments on commit 719bee7

Please sign in to comment.