Skip to content

Commit

Permalink
ENH: pickle serialization for itk images (issue 1091)
Browse files Browse the repository at this point in the history
  • Loading branch information
GenevieveBuckley committed Nov 11, 2021
1 parent abb6b68 commit 27242f4
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 0 deletions.
17 changes: 17 additions & 0 deletions Wrapping/Generators/Python/PyBase/pyBase.i
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,23 @@ str = str
import itk
itk.array_view_from_image(self).__setitem__(key, value)
def __getstate__(self):
"""Get object state, necessary for serialization with pickle."""
import itk
state = itk.dict_from_image(self)
return state
def __setstate__(self, state):
"""Set object state, necessary for serialization with pickle."""
import itk
import numpy as np
deserialized = itk.image_from_dict(state)
self.__dict__['this'] = deserialized
self.SetOrigin(state['origin'])
self.SetSpacing(state['spacing'])
direction = np.asarray(self.GetDirection())
self.SetDirection(direction)
%}
// TODO: also add that method. But with which types?
Expand Down
16 changes: 16 additions & 0 deletions Wrapping/Generators/Python/Tests/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os
import numpy as np
import pathlib
import pickle

import itk

Expand Down Expand Up @@ -162,6 +163,21 @@ def custom_callback(name, progress):
image = itk.imread(filename, imageio=itk.PNGImageIO.New())
assert type(image) == itk.Image[itk.RGBPixel[itk.UC], 2]

# Test serialization with pickle
array = np.random.randint(0, 256, (8, 12)).astype(np.uint8)
image = itk.image_from_array(array)
image.SetSpacing([1.0, 2.0])
image.SetOrigin([11.0, 4.0])
theta = np.radians(30)
cosine = np.cos(theta)
sine = np.sin(theta)
rotation = np.array(((cosine, -sine), (sine, cosine)))
image.SetDirection(rotation)
serialize_deserialize = pickle.loads(pickle.dumps(image))
# verify_input_information checks origin, spacing, direction consistency
comparison = itk.comparison_image_filter(image, serialize_deserialize, verify_input_information=True)
assert np.sum(comparison) == 0.0

# Make sure we can read unsigned short, unsigned int, and cast
image = itk.imread(filename, itk.UI)
assert type(image) == itk.Image[itk.UI, 2]
Expand Down
32 changes: 32 additions & 0 deletions Wrapping/Generators/Python/itk/support/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

import itk.support.types as itkt

from .helpers import wasm_type_from_image_type, image_type_from_wasm_type

if TYPE_CHECKING:
try:
import xarray as xr
Expand Down Expand Up @@ -89,6 +91,8 @@
"image_from_xarray",
"vtk_image_from_image",
"image_from_vtk_image",
"dict_from_image",
"image_from_dict",
"image_intensity_min_max",
"imwrite",
"imread",
Expand Down Expand Up @@ -801,6 +805,34 @@ def image_from_vtk_image(vtk_image: "vtk.vtkImageData") -> "itkt.ImageBase":
return l_image


def dict_from_image(image: "itkt.Image") -> Dict:
"""Serialize a Python itk.Image object to a pickable Python dictionary."""
import itk

pixel_arr = itk.array_view_from_image(image)
imageType = wasm_type_from_image_type(image)
return dict(
imageType=imageType,
origin=tuple(image.GetOrigin()),
spacing=tuple(image.GetSpacing()),
size=tuple(image.GetBufferedRegion().GetSize()),
direction=np.asarray(image.GetDirection()),
data=pixel_arr
)


def image_from_dict(image_dict: Dict) -> "itkt.Image":
"""Deserialize an dictionary representing an itk.Image object."""
import itk

ImageType = image_type_from_wasm_type(image_dict['imageType'])
image = itk.PyBuffer[ImageType].GetImageViewFromArray(image_dict['data'])
image.SetOrigin(image_dict['origin'])
image.SetSpacing(image_dict['spacing'])
image.SetDirection(image_dict['direction'])
return image


# return an image


Expand Down
133 changes: 133 additions & 0 deletions Wrapping/Generators/Python/itk/support/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
#
# ==========================================================================*/

import os
import re
import functools

import numpy as np

_HAVE_XARRAY = False
try:
import xarray as xr
Expand Down Expand Up @@ -173,3 +176,133 @@ def image_filter_wrapper(*args, **kwargs):
return image_filter(*args, **kwargs)

return image_filter_wrapper


def wasm_type_from_image_type(itkimage): # noqa: C901
import itk

component = itk.template(itkimage)[1][0]
if component == itk.UL:
if os.name == 'nt':
return 'uint32_t', 1
else:
return 'uint64_t', 1
mangle = None
pixelType = 1
if component == itk.SL:
if os.name == 'nt':
return 'int32_t', 1,
else:
return 'int64_t', 1,
if component in (itk.SC, itk.UC, itk.SS, itk.US, itk.SI, itk.UI, itk.F,
itk.D, itk.B, itk.SL, itk.SLL, itk.UL, itk.ULL):
mangle = component
elif component in [i[1] for i in itk.Vector.items()]:
mangle = itk.template(component)[1][0]
pixelType = 5
elif component == itk.complex[itk.F]:
# complex float
return 'float', 10
elif component == itk.complex[itk.D]:
# complex float
return 'double', 10
elif component in [i[1] for i in itk.CovariantVector.items()]:
# CovariantVector
mangle = itk.template(component)[1][0]
pixelType = 7
elif component in [i[1] for i in itk.Offset.items()]:
# Offset
return 'int64_t', 4
elif component in [i[1] for i in itk.FixedArray.items()]:
# FixedArray
mangle = itk.template(component)[1][0]
pixelType = 11
elif component in [i[1] for i in itk.RGBAPixel.items()]:
# RGBA
mangle = itk.template(component)[1][0]
pixelType = 3
elif component in [i[1] for i in itk.RGBPixel.items()]:
# RGB
mangle = itk.template(component)[1][0]
pixelType = 2
elif component in [i[1] for i in itk.SymmetricSecondRankTensor.items()]:
# SymmetricSecondRankTensor
mangle = itk.template(component)[1][0]
pixelType = 8
else:
raise RuntimeError('Unrecognized component type: {0}'.format(str(component)))

def _long_type():
if os.name == 'nt':
return 'int32_t'
else:
return 'int64_t'
_python_to_js = {
itk.SC: 'int8_t',
itk.UC: 'uint8_t',
itk.SS: 'int16_t',
itk.US: 'uint16_t',
itk.SI: 'int32_t',
itk.UI: 'uint32_t',
itk.F: 'float',
itk.D: 'double',
itk.B: 'uint8_t',
itk.SL: _long_type(),
itk.UL: 'u' + _long_type(),
itk.SLL: 'int64_t',
itk.ULL: 'uint64_t',
}
imageType = dict(
dimension=itkimage.GetImageDimension(),
componentType=_python_to_js[mangle],
pixelType=pixelType,
components=itkimage.GetNumberOfComponentsPerPixel()
)
return imageType


def image_type_from_wasm_type(jstype):
import itk

_pixelType_to_prefix = {
1: '',
2: 'RGB',
3: 'RGBA',
4: 'O',
5: 'V',
7: 'CV',
8: 'SSRT',
11: 'FA'
}
pixelType = jstype['pixelType']
dimension = jstype['dimension']
if pixelType == 10:
if jstype['componentType'] == 'float':
return itk.Image[itk.complex, itk.F], np.float32
else:
return itk.Image[itk.complex, itk.D], np.float64

def _long_type():
if os.name == 'nt':
return 'LL'
else:
return 'L'
prefix = _pixelType_to_prefix[pixelType]
_js_to_python = {
'int8_t': 'SC',
'uint8_t': 'UC',
'int16_t': 'SS',
'uint16_t': 'US',
'int32_t': 'SI',
'uint32_t': 'UI',
'int64_t': 'S' + _long_type(),
'uint64_t': 'U' + _long_type(),
'float': 'F',
'double': 'D'
}
if pixelType != 4:
prefix += _js_to_python[jstype['componentType']]
if pixelType not in (1, 2, 3, 10):
prefix += str(dimension)
prefix += str(dimension)
return getattr(itk.Image, prefix)

0 comments on commit 27242f4

Please sign in to comment.