From 27242f4aa53b7209feb39f2600cabb317330e682 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 21 Oct 2021 16:22:40 +1100 Subject: [PATCH] ENH: pickle serialization for itk images (issue 1091) --- Wrapping/Generators/Python/PyBase/pyBase.i | 17 +++ Wrapping/Generators/Python/Tests/extras.py | 16 +++ .../Generators/Python/itk/support/extras.py | 32 +++++ .../Generators/Python/itk/support/helpers.py | 133 ++++++++++++++++++ 4 files changed, 198 insertions(+) diff --git a/Wrapping/Generators/Python/PyBase/pyBase.i b/Wrapping/Generators/Python/PyBase/pyBase.i index 0c83aa08546..8d764e67db8 100644 --- a/Wrapping/Generators/Python/PyBase/pyBase.i +++ b/Wrapping/Generators/Python/PyBase/pyBase.i @@ -522,6 +522,23 @@ str = str import itk itk.array_view_from_image(self).__setitem__(key, value) + def __getstate__(self): + """Get object state, necessary for serialization with pickle.""" + import itk + state = itk.dict_from_image(self) + return state + + def __setstate__(self, state): + """Set object state, necessary for serialization with pickle.""" + import itk + import numpy as np + deserialized = itk.image_from_dict(state) + self.__dict__['this'] = deserialized + self.SetOrigin(state['origin']) + self.SetSpacing(state['spacing']) + direction = np.asarray(self.GetDirection()) + self.SetDirection(direction) + %} // TODO: also add that method. But with which types? diff --git a/Wrapping/Generators/Python/Tests/extras.py b/Wrapping/Generators/Python/Tests/extras.py index 89839eb3ffe..1a1f94fb32c 100644 --- a/Wrapping/Generators/Python/Tests/extras.py +++ b/Wrapping/Generators/Python/Tests/extras.py @@ -22,6 +22,7 @@ import os import numpy as np import pathlib +import pickle import itk @@ -162,6 +163,21 @@ def custom_callback(name, progress): image = itk.imread(filename, imageio=itk.PNGImageIO.New()) assert type(image) == itk.Image[itk.RGBPixel[itk.UC], 2] +# Test serialization with pickle +array = np.random.randint(0, 256, (8, 12)).astype(np.uint8) +image = itk.image_from_array(array) +image.SetSpacing([1.0, 2.0]) +image.SetOrigin([11.0, 4.0]) +theta = np.radians(30) +cosine = np.cos(theta) +sine = np.sin(theta) +rotation = np.array(((cosine, -sine), (sine, cosine))) +image.SetDirection(rotation) +serialize_deserialize = pickle.loads(pickle.dumps(image)) +# verify_input_information checks origin, spacing, direction consistency +comparison = itk.comparison_image_filter(image, serialize_deserialize, verify_input_information=True) +assert np.sum(comparison) == 0.0 + # Make sure we can read unsigned short, unsigned int, and cast image = itk.imread(filename, itk.UI) assert type(image) == itk.Image[itk.UI, 2] diff --git a/Wrapping/Generators/Python/itk/support/extras.py b/Wrapping/Generators/Python/itk/support/extras.py index e690463bf0e..5d877d37591 100644 --- a/Wrapping/Generators/Python/itk/support/extras.py +++ b/Wrapping/Generators/Python/itk/support/extras.py @@ -35,6 +35,8 @@ import itk.support.types as itkt +from .helpers import wasm_type_from_image_type, image_type_from_wasm_type + if TYPE_CHECKING: try: import xarray as xr @@ -89,6 +91,8 @@ "image_from_xarray", "vtk_image_from_image", "image_from_vtk_image", + "dict_from_image", + "image_from_dict", "image_intensity_min_max", "imwrite", "imread", @@ -801,6 +805,34 @@ def image_from_vtk_image(vtk_image: "vtk.vtkImageData") -> "itkt.ImageBase": return l_image +def dict_from_image(image: "itkt.Image") -> Dict: + """Serialize a Python itk.Image object to a pickable Python dictionary.""" + import itk + + pixel_arr = itk.array_view_from_image(image) + imageType = wasm_type_from_image_type(image) + return dict( + imageType=imageType, + origin=tuple(image.GetOrigin()), + spacing=tuple(image.GetSpacing()), + size=tuple(image.GetBufferedRegion().GetSize()), + direction=np.asarray(image.GetDirection()), + data=pixel_arr + ) + + +def image_from_dict(image_dict: Dict) -> "itkt.Image": + """Deserialize an dictionary representing an itk.Image object.""" + import itk + + ImageType = image_type_from_wasm_type(image_dict['imageType']) + image = itk.PyBuffer[ImageType].GetImageViewFromArray(image_dict['data']) + image.SetOrigin(image_dict['origin']) + image.SetSpacing(image_dict['spacing']) + image.SetDirection(image_dict['direction']) + return image + + # return an image diff --git a/Wrapping/Generators/Python/itk/support/helpers.py b/Wrapping/Generators/Python/itk/support/helpers.py index 2ee053abe65..a7394f4b481 100644 --- a/Wrapping/Generators/Python/itk/support/helpers.py +++ b/Wrapping/Generators/Python/itk/support/helpers.py @@ -16,9 +16,12 @@ # # ==========================================================================*/ +import os import re import functools +import numpy as np + _HAVE_XARRAY = False try: import xarray as xr @@ -173,3 +176,133 @@ def image_filter_wrapper(*args, **kwargs): return image_filter(*args, **kwargs) return image_filter_wrapper + + +def wasm_type_from_image_type(itkimage): # noqa: C901 + import itk + + component = itk.template(itkimage)[1][0] + if component == itk.UL: + if os.name == 'nt': + return 'uint32_t', 1 + else: + return 'uint64_t', 1 + mangle = None + pixelType = 1 + if component == itk.SL: + if os.name == 'nt': + return 'int32_t', 1, + else: + return 'int64_t', 1, + if component in (itk.SC, itk.UC, itk.SS, itk.US, itk.SI, itk.UI, itk.F, + itk.D, itk.B, itk.SL, itk.SLL, itk.UL, itk.ULL): + mangle = component + elif component in [i[1] for i in itk.Vector.items()]: + mangle = itk.template(component)[1][0] + pixelType = 5 + elif component == itk.complex[itk.F]: + # complex float + return 'float', 10 + elif component == itk.complex[itk.D]: + # complex float + return 'double', 10 + elif component in [i[1] for i in itk.CovariantVector.items()]: + # CovariantVector + mangle = itk.template(component)[1][0] + pixelType = 7 + elif component in [i[1] for i in itk.Offset.items()]: + # Offset + return 'int64_t', 4 + elif component in [i[1] for i in itk.FixedArray.items()]: + # FixedArray + mangle = itk.template(component)[1][0] + pixelType = 11 + elif component in [i[1] for i in itk.RGBAPixel.items()]: + # RGBA + mangle = itk.template(component)[1][0] + pixelType = 3 + elif component in [i[1] for i in itk.RGBPixel.items()]: + # RGB + mangle = itk.template(component)[1][0] + pixelType = 2 + elif component in [i[1] for i in itk.SymmetricSecondRankTensor.items()]: + # SymmetricSecondRankTensor + mangle = itk.template(component)[1][0] + pixelType = 8 + else: + raise RuntimeError('Unrecognized component type: {0}'.format(str(component))) + + def _long_type(): + if os.name == 'nt': + return 'int32_t' + else: + return 'int64_t' + _python_to_js = { + itk.SC: 'int8_t', + itk.UC: 'uint8_t', + itk.SS: 'int16_t', + itk.US: 'uint16_t', + itk.SI: 'int32_t', + itk.UI: 'uint32_t', + itk.F: 'float', + itk.D: 'double', + itk.B: 'uint8_t', + itk.SL: _long_type(), + itk.UL: 'u' + _long_type(), + itk.SLL: 'int64_t', + itk.ULL: 'uint64_t', + } + imageType = dict( + dimension=itkimage.GetImageDimension(), + componentType=_python_to_js[mangle], + pixelType=pixelType, + components=itkimage.GetNumberOfComponentsPerPixel() + ) + return imageType + + +def image_type_from_wasm_type(jstype): + import itk + + _pixelType_to_prefix = { + 1: '', + 2: 'RGB', + 3: 'RGBA', + 4: 'O', + 5: 'V', + 7: 'CV', + 8: 'SSRT', + 11: 'FA' + } + pixelType = jstype['pixelType'] + dimension = jstype['dimension'] + if pixelType == 10: + if jstype['componentType'] == 'float': + return itk.Image[itk.complex, itk.F], np.float32 + else: + return itk.Image[itk.complex, itk.D], np.float64 + + def _long_type(): + if os.name == 'nt': + return 'LL' + else: + return 'L' + prefix = _pixelType_to_prefix[pixelType] + _js_to_python = { + 'int8_t': 'SC', + 'uint8_t': 'UC', + 'int16_t': 'SS', + 'uint16_t': 'US', + 'int32_t': 'SI', + 'uint32_t': 'UI', + 'int64_t': 'S' + _long_type(), + 'uint64_t': 'U' + _long_type(), + 'float': 'F', + 'double': 'D' + } + if pixelType != 4: + prefix += _js_to_python[jstype['componentType']] + if pixelType not in (1, 2, 3, 10): + prefix += str(dimension) + prefix += str(dimension) + return getattr(itk.Image, prefix)