Skip to content

Commit

Permalink
ENH: Adding Support for Transform Serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
PranjalSahu committed Mar 30, 2022
1 parent 9f00f9d commit abc5b22
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Modules/Core/Transform/wrapping/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
if(ITK_WRAP_PYTHON)
itk_python_add_test(NAME itkTransformSerializationTest COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/itkTransformSerializationTest.py)
endif()
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# ==========================================================================
#
# Copyright NumFOCUS
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==========================================================================*/

import itk
import numpy as np
import pickle

Dimension = 3
PixelType = itk.D

# List of Transforms to test
transforms_to_test = [itk.AffineTransform[PixelType, Dimension], itk.DisplacementFieldTransform[PixelType, Dimension], itk.Rigid3DTransform[PixelType], itk.BSplineTransform[PixelType, Dimension, 3], itk.QuaternionRigidTransform[PixelType]]

keys_to_test1 = ["name", "parametersValueType", "transformName", "transformType", "inDimension", "outDimension", "numberOfParameters", "numberOfFixedParameters"]
keys_to_test2 = ["parameters", "fixedParameters"]

transform_object_list = []
for i, transform_type in enumerate(transforms_to_test):
transform = transform_type.New()
transform.SetObjectName("transform"+str(i))

# Check the serialization
serialize_deserialize = pickle.loads(pickle.dumps(transform))

# Test all the attributes
for k in keys_to_test1:
assert serialize_deserialize[k] == transform[k]

# Test all the parameters
for k in keys_to_test2:
assert np.array_equal(serialize_deserialize[k], transform[k])

transform_object_list.append(transform)

print('Individual Transforms Test Done')

# Test Composite Transform
transformType = itk.CompositeTransform[PixelType, Dimension]
composite_transform = transformType.New()
composite_transform.SetObjectName('composite_transform')

# Add the above created transforms in the composite transform
for transform in transform_object_list:
composite_transform.AddTransform(transform)

# Check the serialization of composite transform
serialize_deserialize = pickle.loads(pickle.dumps(composite_transform))

assert serialize_deserialize.GetObjectName() == composite_transform.GetObjectName()
assert serialize_deserialize.GetNumberOfTransforms() == 5
assert serialize_deserialize["name"] == composite_transform["name"]

deserialized_object_list = []

keys_to_test1 = ["name", "parametersValueType", "transformName", "inDimension", "outDimension", "numberOfParameters", "numberOfFixedParameters"]

# Get the individual transform objects from the composite transform for testing
for i in range(len(transforms_to_test)):
transform_obj = serialize_deserialize.GetNthTransform(i)

# Test all the attributes
for k in keys_to_test1:
assert transform_obj[k] == transform_object_list[i][k]

# Test all the parameter arrays
for k in keys_to_test2:
assert np.array_equal(transform_obj[k], transform_object_list[i][k])

# Testing for loss of transformType in Composite transform
if i == 3:
# BSpline has same type here D33
assert transform_obj["transformType"] == transform_object_list[i]["transformType"]
else:
assert transform_obj["transformType"] != transform_object_list[i]["transformType"]
57 changes: 57 additions & 0 deletions Wrapping/Generators/Python/PyBase/pyBase.i
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,63 @@ str = str
%enddef
%define DECL_PYTHON_TRANSFORMBASETEMPLATE_CLASS(swig_name)
%extend swig_name {
%pythoncode %{
def keys(self):
"""
Return keys related to the transform's metadata.
These keys are used in the dictionary resulting from dict(transform).
"""
result = ['name', 'transformType', 'inDimension', 'outDimension', 'numberOfParameters', 'numberOfFixedParameters', 'parameters', 'fixedParameters']
return result
def __getitem__(self, key):
"""Access metadata keys, see help(transform.keys), for string keys."""
import itk
if isinstance(key, str):
state = itk.dict_from_transform(self)
return state[0][key]
def __setitem__(self, key, value):
if isinstance(key, str):
import numpy as np
if key == 'name':
self.SetObjectName(value)
elif key == 'fixedParameters' or key == 'parameters':
if key == 'fixedParameters':
o1 = self.GetFixedParameters()
else:
o1 = self.GetParameters()
o1.SetSize(value.shape[0])
for i, v in enumerate(value):
o1.SetElement(i, v)
if key == 'fixedParameters':
self.SetFixedParameters(o1)
else:
self.SetParameters(o1)
def __getstate__(self):
"""Get object state, necessary for serialization with pickle."""
import itk
state = itk.dict_from_transform(self)
return state
def __setstate__(self, state):
"""Set object state, necessary for serialization with pickle."""
import itk
import numpy as np
deserialized = itk.transform_from_dict(state)
self.__dict__['this'] = deserialized
%}
}
%enddef
%define DECL_PYTHON_IMAGEBASE_CLASS(swig_name, template_params)
%inline %{
#include "itkContinuousIndexSwigInterface.h"
Expand Down
141 changes: 141 additions & 0 deletions Wrapping/Generators/Python/itk/support/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#
# ==========================================================================*/

import enum
import re
from typing import Optional, Union, Dict, Any, List, Tuple, Sequence, TYPE_CHECKING
from sys import stderr as system_error_stream
Expand Down Expand Up @@ -104,6 +105,8 @@
"dict_from_mesh",
"pointset_from_dict",
"dict_from_pointset",
"transform_from_dict",
"dict_from_transform",
"transformwrite",
"transformread",
"search",
Expand Down Expand Up @@ -1012,6 +1015,144 @@ def dict_from_pointset(pointset: "itkt.PointSet") -> Dict:
pointData=point_data_numpy,
)

def dict_from_transform(transform: "itkt.TransformBase") -> Dict:
import itk

def update_transform_dict(current_transform):
current_transform_type = current_transform.GetTransformTypeAsString()
current_transform_type_split = current_transform_type.split('_')
component = itk.template(current_transform)

in_transform_dict = dict()
in_transform_dict["name"] = current_transform.GetObjectName()
in_transform_dict["numberOfTransforms"] = 1

datatype_dict = {'double': itk.D, 'float': itk.F}
in_transform_dict["parametersValueType"] = python_to_js(datatype_dict[current_transform_type_split[1]])
in_transform_dict["inDimension"] = int(current_transform_type_split[2])
in_transform_dict["outDimension"] = int(current_transform_type_split[3])
in_transform_dict["transformName"] = current_transform_type_split[0]

# transformType field to be used for single transform object only.
# For composite transforms we lose the information for child transform objects.
data_type_dict = {itk.D: 'D', itk.F: 'F'}
mangle = data_type_dict[component[1][0]]
for p in component[1][1:]:
mangle += str(p)
in_transform_dict["transformType"] = mangle

# To avoid copying the parameters for the Composite Transform as it is a copy of child transforms.
if 'Composite' not in current_transform_type_split[0]:
p = np.array(current_transform.GetParameters())
in_transform_dict["parameters"] = p

fp = np.array(current_transform.GetFixedParameters())
in_transform_dict["fixedParameters"] = fp

in_transform_dict["numberOfParameters"] = p.shape[0]
in_transform_dict["numberOfFixedParameters"] = fp.shape[0]
else:
in_transform_dict["parameters"] = np.array([])
in_transform_dict["fixedParameters"] = np.array([])
in_transform_dict["numberOfParameters"] = 0
in_transform_dict["numberOfFixedParameters"] = 0

return in_transform_dict


dict_array = []
transform_type = transform.GetTransformTypeAsString()
if 'CompositeTransform' in transform_type:
transform_dict = update_transform_dict(transform)
transform_dict["numberOfTransforms"] = transform.GetNumberOfTransforms()

# Add the first entry for the composite transform
dict_array.append(transform_dict)

# Rest follows the transforms inside the composite transform
# range is over-ridden so using this hack to create a list
for i, _ in enumerate([0]*transform.GetNumberOfTransforms()):
current_transform = transform.GetNthTransform(i)
dict_array.append(update_transform_dict(current_transform))
else:
dict_array.append(update_transform_dict(transform))

return dict_array

def transform_from_dict(transform_dict: Dict)-> "itkt.TransformBase":
import itk

def set_parameters(transform, transform_parameters, transform_fixed_parameters):
o1 = transform.GetParameters()
o1.SetSize(transform_parameters.shape[0])
for j, v in enumerate(transform_parameters):
o1.SetElement(j, v)
transform.SetParameters(o1)

o2 = transform.GetFixedParameters()
o2.SetSize(transform_fixed_parameters.shape[0])
for j, v in enumerate(transform_fixed_parameters):
o2.SetElement(j, v)
transform.SetFixedParameters(o2)


# For checking transforms which don't take additional parameters while instantiation
def special_transform_check(transform_name):
if '2D' in transform_name or '3D' in transform_name:
return True

check_list = ['VersorTransform', 'QuaternionRigidTransform']
for t in check_list:
if transform_name == t:
return True
return False

# We only check for the first transform as composite similar to the
# convention followed in the itkTxtTransformIO.cxx
if 'CompositeTransform' in transform_dict[0]["transformName"]:
# Loop over all the transforms in the dictionary
transforms_list = []
for i, _ in enumerate(transform_dict):
if transform_dict[i]["parametersValueType"] == "float32":
data_type = itk.F
else:
data_type = itk.D

# No template parameter needed for transforms having 2D or 3D name
# Also for some selected transforms
if special_transform_check(transform_dict[i]["transformName"]):
transform_template = getattr(itk, transform_dict[i]["transformName"])
transform = transform_template[data_type].New()
# Currently only BSpline Transform has 3 template parameters
# For future extensions the information will have to be encoded in
# the transformType variable. The transform object once added in a
# composite transform lose the information for other template parameters ex. BSpline.
# The Spline order is fixed as 3 here.
elif transform_dict[i]["transformName"] == 'BSplineTransform':
transform_template = getattr(itk, transform_dict[i]["transformName"])
transform = transform_template[data_type, transform_dict[i]["inDimension"], 3].New()
else:
transform_template = getattr(itk, transform_dict[i]["transformName"])
transform = transform_template[data_type, transform_dict[i]["inDimension"]].New()

transform.SetObjectName(transform_dict[i]["name"])
transforms_list.append(transform)

# Obtain the first object which is composite transform object
# and add all the transforms in it.
transform = transforms_list[0]
for current_transform in transforms_list[1:]:
transform.AddTransform(current_transform)
else:
# For handling single transform objects we rely on itk.template
# because that way we can handle future extensions easily.
transform_template = getattr(itk, transform_dict[0]["transformName"])
transform = getattr(transform_template, transform_dict[0]["transformType"]).New()
transform.SetObjectName(transform_dict[0]["name"])
set_parameters(transform, transform_dict[0]["parameters"], transform_dict[0]["fixedParameters"])

return transform

def image_intensity_min_max(image_or_filter: "itkt.ImageOrImageSource"):
"""Return the minimum and maximum of values in a image of in the output image of a filter
Expand Down
4 changes: 4 additions & 0 deletions Wrapping/TypedefMacros.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,10 @@ macro(itk_wrap_simple_type wrap_class swig_name)
set(ITK_WRAP_PYTHON_SWIG_EXT "${ITK_WRAP_PYTHON_SWIG_EXT}DECL_PYTHON_MESH_CLASS(${swig_name})\n\n")
endif()

if("${cpp_name}" STREQUAL "itk::TransformBaseTemplate")
set(ITK_WRAP_PYTHON_SWIG_EXT "${ITK_WRAP_PYTHON_SWIG_EXT}DECL_PYTHON_TRANSFORMBASETEMPLATE_CLASS(${swig_name})\n\n")
endif()

if("${cpp_name}" STREQUAL "itk::PyImageFilter" AND NOT "${swig_name}" MATCHES "Pointer$")
set(ITK_WRAP_PYTHON_SWIG_EXT "${ITK_WRAP_PYTHON_SWIG_EXT}DECL_PYIMAGEFILTER_CLASS(${swig_name})\n\n")
endif()
Expand Down

0 comments on commit abc5b22

Please sign in to comment.