Skip to content

Commit abc5b22

Browse files
committed
ENH: Adding Support for Transform Serialization
1 parent 9f00f9d commit abc5b22

File tree

5 files changed

+294
-0
lines changed

5 files changed

+294
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
if(ITK_WRAP_PYTHON)
2+
itk_python_add_test(NAME itkTransformSerializationTest COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/itkTransformSerializationTest.py)
3+
endif()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# ==========================================================================
2+
#
3+
# Copyright NumFOCUS
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0.txt
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
# ==========================================================================*/
18+
19+
import itk
20+
import numpy as np
21+
import pickle
22+
23+
Dimension = 3
24+
PixelType = itk.D
25+
26+
# List of Transforms to test
27+
transforms_to_test = [itk.AffineTransform[PixelType, Dimension], itk.DisplacementFieldTransform[PixelType, Dimension], itk.Rigid3DTransform[PixelType], itk.BSplineTransform[PixelType, Dimension, 3], itk.QuaternionRigidTransform[PixelType]]
28+
29+
keys_to_test1 = ["name", "parametersValueType", "transformName", "transformType", "inDimension", "outDimension", "numberOfParameters", "numberOfFixedParameters"]
30+
keys_to_test2 = ["parameters", "fixedParameters"]
31+
32+
transform_object_list = []
33+
for i, transform_type in enumerate(transforms_to_test):
34+
transform = transform_type.New()
35+
transform.SetObjectName("transform"+str(i))
36+
37+
# Check the serialization
38+
serialize_deserialize = pickle.loads(pickle.dumps(transform))
39+
40+
# Test all the attributes
41+
for k in keys_to_test1:
42+
assert serialize_deserialize[k] == transform[k]
43+
44+
# Test all the parameters
45+
for k in keys_to_test2:
46+
assert np.array_equal(serialize_deserialize[k], transform[k])
47+
48+
transform_object_list.append(transform)
49+
50+
print('Individual Transforms Test Done')
51+
52+
# Test Composite Transform
53+
transformType = itk.CompositeTransform[PixelType, Dimension]
54+
composite_transform = transformType.New()
55+
composite_transform.SetObjectName('composite_transform')
56+
57+
# Add the above created transforms in the composite transform
58+
for transform in transform_object_list:
59+
composite_transform.AddTransform(transform)
60+
61+
# Check the serialization of composite transform
62+
serialize_deserialize = pickle.loads(pickle.dumps(composite_transform))
63+
64+
assert serialize_deserialize.GetObjectName() == composite_transform.GetObjectName()
65+
assert serialize_deserialize.GetNumberOfTransforms() == 5
66+
assert serialize_deserialize["name"] == composite_transform["name"]
67+
68+
deserialized_object_list = []
69+
70+
keys_to_test1 = ["name", "parametersValueType", "transformName", "inDimension", "outDimension", "numberOfParameters", "numberOfFixedParameters"]
71+
72+
# Get the individual transform objects from the composite transform for testing
73+
for i in range(len(transforms_to_test)):
74+
transform_obj = serialize_deserialize.GetNthTransform(i)
75+
76+
# Test all the attributes
77+
for k in keys_to_test1:
78+
assert transform_obj[k] == transform_object_list[i][k]
79+
80+
# Test all the parameter arrays
81+
for k in keys_to_test2:
82+
assert np.array_equal(transform_obj[k], transform_object_list[i][k])
83+
84+
# Testing for loss of transformType in Composite transform
85+
if i == 3:
86+
# BSpline has same type here D33
87+
assert transform_obj["transformType"] == transform_object_list[i]["transformType"]
88+
else:
89+
assert transform_obj["transformType"] != transform_object_list[i]["transformType"]

Wrapping/Generators/Python/PyBase/pyBase.i

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,63 @@ str = str
405405
%enddef
406406
407407
408+
%define DECL_PYTHON_TRANSFORMBASETEMPLATE_CLASS(swig_name)
409+
%extend swig_name {
410+
%pythoncode %{
411+
def keys(self):
412+
"""
413+
Return keys related to the transform's metadata.
414+
These keys are used in the dictionary resulting from dict(transform).
415+
"""
416+
result = ['name', 'transformType', 'inDimension', 'outDimension', 'numberOfParameters', 'numberOfFixedParameters', 'parameters', 'fixedParameters']
417+
return result
418+
419+
def __getitem__(self, key):
420+
"""Access metadata keys, see help(transform.keys), for string keys."""
421+
import itk
422+
if isinstance(key, str):
423+
state = itk.dict_from_transform(self)
424+
return state[0][key]
425+
426+
def __setitem__(self, key, value):
427+
if isinstance(key, str):
428+
import numpy as np
429+
if key == 'name':
430+
self.SetObjectName(value)
431+
elif key == 'fixedParameters' or key == 'parameters':
432+
if key == 'fixedParameters':
433+
o1 = self.GetFixedParameters()
434+
else:
435+
o1 = self.GetParameters()
436+
437+
o1.SetSize(value.shape[0])
438+
for i, v in enumerate(value):
439+
o1.SetElement(i, v)
440+
441+
if key == 'fixedParameters':
442+
self.SetFixedParameters(o1)
443+
else:
444+
self.SetParameters(o1)
445+
446+
447+
def __getstate__(self):
448+
"""Get object state, necessary for serialization with pickle."""
449+
import itk
450+
state = itk.dict_from_transform(self)
451+
return state
452+
453+
def __setstate__(self, state):
454+
"""Set object state, necessary for serialization with pickle."""
455+
import itk
456+
import numpy as np
457+
deserialized = itk.transform_from_dict(state)
458+
self.__dict__['this'] = deserialized
459+
%}
460+
}
461+
462+
%enddef
463+
464+
408465
%define DECL_PYTHON_IMAGEBASE_CLASS(swig_name, template_params)
409466
%inline %{
410467
#include "itkContinuousIndexSwigInterface.h"

Wrapping/Generators/Python/itk/support/extras.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717
# ==========================================================================*/
1818

19+
import enum
1920
import re
2021
from typing import Optional, Union, Dict, Any, List, Tuple, Sequence, TYPE_CHECKING
2122
from sys import stderr as system_error_stream
@@ -104,6 +105,8 @@
104105
"dict_from_mesh",
105106
"pointset_from_dict",
106107
"dict_from_pointset",
108+
"transform_from_dict",
109+
"dict_from_transform",
107110
"transformwrite",
108111
"transformread",
109112
"search",
@@ -1012,6 +1015,144 @@ def dict_from_pointset(pointset: "itkt.PointSet") -> Dict:
10121015
pointData=point_data_numpy,
10131016
)
10141017

1018+
def dict_from_transform(transform: "itkt.TransformBase") -> Dict:
1019+
import itk
1020+
1021+
def update_transform_dict(current_transform):
1022+
current_transform_type = current_transform.GetTransformTypeAsString()
1023+
current_transform_type_split = current_transform_type.split('_')
1024+
component = itk.template(current_transform)
1025+
1026+
in_transform_dict = dict()
1027+
in_transform_dict["name"] = current_transform.GetObjectName()
1028+
in_transform_dict["numberOfTransforms"] = 1
1029+
1030+
datatype_dict = {'double': itk.D, 'float': itk.F}
1031+
in_transform_dict["parametersValueType"] = python_to_js(datatype_dict[current_transform_type_split[1]])
1032+
in_transform_dict["inDimension"] = int(current_transform_type_split[2])
1033+
in_transform_dict["outDimension"] = int(current_transform_type_split[3])
1034+
in_transform_dict["transformName"] = current_transform_type_split[0]
1035+
1036+
# transformType field to be used for single transform object only.
1037+
# For composite transforms we lose the information for child transform objects.
1038+
data_type_dict = {itk.D: 'D', itk.F: 'F'}
1039+
mangle = data_type_dict[component[1][0]]
1040+
for p in component[1][1:]:
1041+
mangle += str(p)
1042+
in_transform_dict["transformType"] = mangle
1043+
1044+
# To avoid copying the parameters for the Composite Transform as it is a copy of child transforms.
1045+
if 'Composite' not in current_transform_type_split[0]:
1046+
p = np.array(current_transform.GetParameters())
1047+
in_transform_dict["parameters"] = p
1048+
1049+
fp = np.array(current_transform.GetFixedParameters())
1050+
in_transform_dict["fixedParameters"] = fp
1051+
1052+
in_transform_dict["numberOfParameters"] = p.shape[0]
1053+
in_transform_dict["numberOfFixedParameters"] = fp.shape[0]
1054+
else:
1055+
in_transform_dict["parameters"] = np.array([])
1056+
in_transform_dict["fixedParameters"] = np.array([])
1057+
in_transform_dict["numberOfParameters"] = 0
1058+
in_transform_dict["numberOfFixedParameters"] = 0
1059+
1060+
return in_transform_dict
1061+
1062+
1063+
dict_array = []
1064+
transform_type = transform.GetTransformTypeAsString()
1065+
if 'CompositeTransform' in transform_type:
1066+
transform_dict = update_transform_dict(transform)
1067+
transform_dict["numberOfTransforms"] = transform.GetNumberOfTransforms()
1068+
1069+
# Add the first entry for the composite transform
1070+
dict_array.append(transform_dict)
1071+
1072+
# Rest follows the transforms inside the composite transform
1073+
# range is over-ridden so using this hack to create a list
1074+
for i, _ in enumerate([0]*transform.GetNumberOfTransforms()):
1075+
current_transform = transform.GetNthTransform(i)
1076+
dict_array.append(update_transform_dict(current_transform))
1077+
else:
1078+
dict_array.append(update_transform_dict(transform))
1079+
1080+
return dict_array
1081+
1082+
def transform_from_dict(transform_dict: Dict)-> "itkt.TransformBase":
1083+
import itk
1084+
1085+
def set_parameters(transform, transform_parameters, transform_fixed_parameters):
1086+
o1 = transform.GetParameters()
1087+
o1.SetSize(transform_parameters.shape[0])
1088+
for j, v in enumerate(transform_parameters):
1089+
o1.SetElement(j, v)
1090+
transform.SetParameters(o1)
1091+
1092+
o2 = transform.GetFixedParameters()
1093+
o2.SetSize(transform_fixed_parameters.shape[0])
1094+
for j, v in enumerate(transform_fixed_parameters):
1095+
o2.SetElement(j, v)
1096+
transform.SetFixedParameters(o2)
1097+
1098+
1099+
# For checking transforms which don't take additional parameters while instantiation
1100+
def special_transform_check(transform_name):
1101+
if '2D' in transform_name or '3D' in transform_name:
1102+
return True
1103+
1104+
check_list = ['VersorTransform', 'QuaternionRigidTransform']
1105+
for t in check_list:
1106+
if transform_name == t:
1107+
return True
1108+
return False
1109+
1110+
# We only check for the first transform as composite similar to the
1111+
# convention followed in the itkTxtTransformIO.cxx
1112+
if 'CompositeTransform' in transform_dict[0]["transformName"]:
1113+
# Loop over all the transforms in the dictionary
1114+
transforms_list = []
1115+
for i, _ in enumerate(transform_dict):
1116+
if transform_dict[i]["parametersValueType"] == "float32":
1117+
data_type = itk.F
1118+
else:
1119+
data_type = itk.D
1120+
1121+
# No template parameter needed for transforms having 2D or 3D name
1122+
# Also for some selected transforms
1123+
if special_transform_check(transform_dict[i]["transformName"]):
1124+
transform_template = getattr(itk, transform_dict[i]["transformName"])
1125+
transform = transform_template[data_type].New()
1126+
# Currently only BSpline Transform has 3 template parameters
1127+
# For future extensions the information will have to be encoded in
1128+
# the transformType variable. The transform object once added in a
1129+
# composite transform lose the information for other template parameters ex. BSpline.
1130+
# The Spline order is fixed as 3 here.
1131+
elif transform_dict[i]["transformName"] == 'BSplineTransform':
1132+
transform_template = getattr(itk, transform_dict[i]["transformName"])
1133+
transform = transform_template[data_type, transform_dict[i]["inDimension"], 3].New()
1134+
else:
1135+
transform_template = getattr(itk, transform_dict[i]["transformName"])
1136+
transform = transform_template[data_type, transform_dict[i]["inDimension"]].New()
1137+
1138+
transform.SetObjectName(transform_dict[i]["name"])
1139+
transforms_list.append(transform)
1140+
1141+
# Obtain the first object which is composite transform object
1142+
# and add all the transforms in it.
1143+
transform = transforms_list[0]
1144+
for current_transform in transforms_list[1:]:
1145+
transform.AddTransform(current_transform)
1146+
else:
1147+
# For handling single transform objects we rely on itk.template
1148+
# because that way we can handle future extensions easily.
1149+
transform_template = getattr(itk, transform_dict[0]["transformName"])
1150+
transform = getattr(transform_template, transform_dict[0]["transformType"]).New()
1151+
transform.SetObjectName(transform_dict[0]["name"])
1152+
set_parameters(transform, transform_dict[0]["parameters"], transform_dict[0]["fixedParameters"])
1153+
1154+
return transform
1155+
10151156
def image_intensity_min_max(image_or_filter: "itkt.ImageOrImageSource"):
10161157
"""Return the minimum and maximum of values in a image of in the output image of a filter
10171158

Wrapping/TypedefMacros.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,10 @@ macro(itk_wrap_simple_type wrap_class swig_name)
13261326
set(ITK_WRAP_PYTHON_SWIG_EXT "${ITK_WRAP_PYTHON_SWIG_EXT}DECL_PYTHON_MESH_CLASS(${swig_name})\n\n")
13271327
endif()
13281328

1329+
if("${cpp_name}" STREQUAL "itk::TransformBaseTemplate")
1330+
set(ITK_WRAP_PYTHON_SWIG_EXT "${ITK_WRAP_PYTHON_SWIG_EXT}DECL_PYTHON_TRANSFORMBASETEMPLATE_CLASS(${swig_name})\n\n")
1331+
endif()
1332+
13291333
if("${cpp_name}" STREQUAL "itk::PyImageFilter" AND NOT "${swig_name}" MATCHES "Pointer$")
13301334
set(ITK_WRAP_PYTHON_SWIG_EXT "${ITK_WRAP_PYTHON_SWIG_EXT}DECL_PYIMAGEFILTER_CLASS(${swig_name})\n\n")
13311335
endif()

0 commit comments

Comments
 (0)