Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Make dict_from_transform more consistent with other dict representations #4635

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,13 @@

keys_to_test1 = [
"name",
"parametersValueType",
"transformType",
"inputDimension",
"outputDimension",
"inputSpaceName",
"outputSpaceName",
"numberOfParameters",
"numberOfFixedParameters",
]
keys_to_test2 = ["parameters", "fixedParameters"]
keys_to_test3 = ["transformParameterization", "parametersValueType", "inputDimension", "outputDimension"]

transform_object_list = []
for i, transform_type in enumerate(transforms_to_test):
Expand All @@ -60,6 +57,8 @@
# Test all the parameters
for k in keys_to_test2:
assert np.array_equal(serialize_deserialize[k], transform[k])
for k in keys_to_test3:
assert serialize_deserialize["transformType"][k], transform["transformType"][k]
transform_object_list.append(transform)

print("Individual Transforms Test Done")
Expand Down Expand Up @@ -93,6 +92,9 @@
for k in keys_to_test2:
assert np.array_equal(transform_obj[k], transform_object_list[i][k])

for k in keys_to_test3:
assert transform_object_list[i]["transformType"][k], transform["transformType"][k]


# Test for transformation using de-serialized BSpline Transform
ImageDimension = 2
Expand Down
5 changes: 2 additions & 3 deletions Wrapping/Generators/Python/PyBase/pyBase.i
Original file line number Diff line number Diff line change
Expand Up @@ -430,15 +430,15 @@ str = str
Return keys related to the transform's metadata.
These keys are used in the dictionary resulting from dict(transform).
"""
result = ['name', 'inputDimension', 'outputDimension', 'inputSpaceName', 'outputSpaceName', 'numberOfParameters', 'numberOfFixedParameters', 'parameters', 'fixedParameters']
result = ['transformType', 'name', 'inputSpaceName', 'outputSpaceName', 'numberOfParameters', 'numberOfFixedParameters', 'parameters', 'fixedParameters']
return result

def __getitem__(self, key):
"""Access metadata keys, see help(transform.keys), for string keys."""
import itk
if isinstance(key, str):
state = itk.dict_from_transform(self)
return state[0][key]
return state[key]

def __setitem__(self, key, value):
if isinstance(key, str):
Expand Down Expand Up @@ -474,7 +474,6 @@ str = str
def __setstate__(self, state):
"""Set object state, necessary for serialization with pickle."""
import itk
import numpy as np
deserialized = itk.transform_from_dict(state)
self.__dict__['this'] = deserialized
%}
Expand Down
5 changes: 5 additions & 0 deletions Wrapping/Generators/Python/Tests/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,11 @@ def custom_callback(name, progress):
parameters = np.asarray(transforms[0].GetParameters())
assert np.allclose(parameters, np.array(baseline_additional_transform_params))

transform_dict = itk.dict_from_transform(transforms[0])
transform_back = itk.transform_from_dict(transform_dict)
transform_dict = itk.dict_from_transform(transforms)
transform_back = itk.transform_from_dict(transform_dict)

# pipeline, auto_pipeline and templated class are tested in other files

# BridgeNumPy
Expand Down
104 changes: 67 additions & 37 deletions Wrapping/Generators/Python/itk/support/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,57 +981,81 @@ def dict_from_pointset(pointset: "itkt.PointSet") -> Dict:
)


def dict_from_transform(transform: "itkt.TransformBase") -> Dict:
def dict_from_transform(transform: Union["itkt.TransformBase", List["itkt.TransformBase"]]) -> Union[List[Dict], Dict]:
"""Serialize a Python itk.Transform object to a pickable Python dictionary.

If the transform is a list of transforms, then a list of dictionaries is returned.
If the transform is a single, non-Composite transform, then a single dictionary is returned.
Composite transforms and nested composite transforms are flattened into a list of dictionaries.
"""
import itk
datatype_dict = {"double": itk.D, "float": itk.F}

def update_transform_dict(current_transform):
current_transform_type = current_transform.GetTransformTypeAsString()
current_transform_type_split = current_transform_type.split("_")
component = itk.template(current_transform)

in_transform_dict = dict()
in_transform_dict["name"] = current_transform.GetObjectName()
transform_type = dict()
transform_parameterization = current_transform_type_split[0].replace("Transform", "")
transform_type["transformParameterization"] = transform_parameterization

datatype_dict = {"double": itk.D, "float": itk.F}
in_transform_dict["parametersValueType"] = python_to_js(
transform_type["parametersValueType"] = python_to_js(
datatype_dict[current_transform_type_split[1]]
)
in_transform_dict["inputDimension"] = int(current_transform_type_split[2])
in_transform_dict["outputDimension"] = int(current_transform_type_split[3])
in_transform_dict["transformType"] = current_transform_type_split[0]
transform_type["inputDimension"] = int(current_transform_type_split[2])
transform_type["outputDimension"] = int(current_transform_type_split[3])

in_transform_dict["inputSpaceName"] = current_transform.GetInputSpaceName()
in_transform_dict["outputSpaceName"] = current_transform.GetOutputSpaceName()
transform_dict = dict()
transform_dict['transformType'] = transform_type
transform_dict["name"] = current_transform.GetObjectName()

transform_dict["inputSpaceName"] = current_transform.GetInputSpaceName()
transform_dict["outputSpaceName"] = current_transform.GetOutputSpaceName()

# To avoid copying the parameters for the Composite Transform
# as it is a copy of child transforms.
if "Composite" not in current_transform_type_split[0]:
p = np.array(current_transform.GetParameters())
in_transform_dict["parameters"] = p
transform_dict["parameters"] = p

fp = np.array(current_transform.GetFixedParameters())
in_transform_dict["fixedParameters"] = fp
transform_dict["fixedParameters"] = fp

in_transform_dict["numberOfParameters"] = p.shape[0]
in_transform_dict["numberOfFixedParameters"] = fp.shape[0]
transform_dict["numberOfParameters"] = p.shape[0]
transform_dict["numberOfFixedParameters"] = fp.shape[0]

return in_transform_dict
return transform_dict

dict_array = []
transform_type = transform.GetTransformTypeAsString()
if "CompositeTransform" in transform_type:
# Add the transforms inside the composite transform
# range is over-ridden so using this hack to create a list
for i, _ in enumerate([0] * transform.GetNumberOfTransforms()):
current_transform = transform.GetNthTransform(i)
dict_array.append(update_transform_dict(current_transform))
multi = False
def add_transform_dict(transform):
transform_type = transform.GetTransformTypeAsString()
if "CompositeTransform" in transform_type:
# Add the transforms inside the composite transform
# range is over-ridden so using this hack to create a list
for i, _ in enumerate([0] * transform.GetNumberOfTransforms()):
current_transform = transform.GetNthTransform(i)
dict_array.append(update_transform_dict(current_transform))
return True
else:
dict_array.append(update_transform_dict(transform))
return False
if isinstance(transform, list):
multi = True
for t in transform:
add_transform_dict(t)
else:
dict_array.append(update_transform_dict(transform))
multi = add_transform_dict(transform)

return dict_array
if multi:
return dict_array
else:
return dict_array[0]

def transform_from_dict(transform_dict: Union[Dict, List[Dict]]) -> "itkt.TransformBase":
"""Deserialize a dictionary representing an itk.Transform object.

def transform_from_dict(transform_dict: Dict) -> "itkt.TransformBase":
If the dictionary represents a list of transforms, then a Composite Transform is returned."""
import itk

def set_parameters(transform, transform_parameters, transform_fixed_parameters, data_type):
Expand All @@ -1055,35 +1079,41 @@ def special_transform_check(transform_name):

parametersValueType_dict = {"float32": itk.F, "float64": itk.D}

if not isinstance(transform_dict, list):
transform_dict = [transform_dict]

# Loop over all the transforms in the dictionary
transforms_list = []
for i, _ in enumerate(transform_dict):
data_type = parametersValueType_dict[transform_dict[i]["parametersValueType"]]
transform_type = transform_dict[i]["transformType"]
data_type = parametersValueType_dict[transform_type["parametersValueType"]]

transform_parameterization = transform_type["transformParameterization"] + 'Transform'

# No template parameter needed for transforms having 2D or 3D name
# Also for some selected transforms
if special_transform_check(transform_dict[i]["transformType"]):
transform_template = getattr(itk, transform_dict[i]["transformType"])
if special_transform_check(transform_parameterization):
transform_template = getattr(itk, transform_parameterization)
transform = transform_template[data_type].New()
# Currently only BSpline Transform has 3 template parameters
# For future extensions the information will have to be encoded in
# the transformType variable. The transform object once added in a
# composite transform lose the information for other template parameters ex. BSpline.
# The Spline order is fixed as 3 here.
elif transform_dict[i]["transformType"] == "BSplineTransform":
transform_template = getattr(itk, transform_dict[i]["transformType"])
elif transform_parameterization == "BSplineTransform":
transform_template = getattr(itk, transform_parameterization)
transform = transform_template[
data_type, transform_dict[i]["inputDimension"], 3
data_type, transform_type["inputDimension"], 3
].New()
else:
transform_template = getattr(itk, transform_dict[i]["transformType"])
transform_template = getattr(itk, transform_parameterization)
if len(transform_template.items()[0][0]) > 2:
transform = transform_template[
data_type, transform_dict[i]["inputDimension"], transform_dict[i]["outputDimension"]
data_type, transform_type["inputDimension"], transform_type["outputDimension"]
].New()
else:
transform = transform_template[
data_type, transform_dict[i]["inputDimension"]
data_type, transform_type["inputDimension"]
].New()

transform.SetObjectName(transform_dict[i]["name"])
Expand All @@ -1102,8 +1132,8 @@ def special_transform_check(transform_name):
if len(transforms_list) > 1:
# Create a Composite Transform object
# and add all the transforms in it.
data_type = parametersValueType_dict[transform_dict[0]["parametersValueType"]]
transform = itk.CompositeTransform[data_type, transforms_list[0]['inputDimension']].New()
data_type = parametersValueType_dict[transform_dict[0]["transformType"]["parametersValueType"]]
transform = itk.CompositeTransform[data_type, transforms_list[0]["transformType"]['inputDimension']].New()
for current_transform in transforms_list:
transform.AddTransform(current_transform)
else:
Expand Down