Skip to content

Commit

Permalink
Merge pull request #59 from constantinpape/ngff-trafos
Browse files Browse the repository at this point in the history
Add support for ngff transformations
  • Loading branch information
constantinpape authored May 28, 2022
2 parents 3e6969e + 397b433 commit 885fd0f
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 1 deletion.
1 change: 1 addition & 0 deletions elf/transformation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
matrix_to_parameters,
native_to_bdv,
parameters_to_matrix)
from .ngff import native_to_ngff, ngff_to_native
from .resize import transform_subvolume_resize
2 changes: 1 addition & 1 deletion elf/transformation/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def scale_from_matrix(matrix):
""" Return the scales from the affine matrix """
ndim = matrix.shape[0] - 1
scale = [np.linalg.norm(matrix[:ndim, d]) for d in range(ndim)]
return scale
return np.array(scale)


# TODO need to figure out how to go from affine elements to euler angles
Expand Down
154 changes: 154 additions & 0 deletions elf/transformation/ngff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import json
import os
import warnings

from .affine import (affine_matrix_2d, affine_matrix_3d,
scale_from_matrix, translation_from_matrix)

SUPPORTED_NGFF_VERSIONS = ("0.4",)


def _parse_04_transformation(ngff_trafo, indices):
assert len(ngff_trafo) <= 2

scale, translation = None, None
for trafo in ngff_trafo:
trafo_type = trafo["type"]
assert trafo_type in ("scale", "translation"), f"Expected scale or translation transform, got {trafo_type}"
if trafo_type == "scale":
scale = trafo["scale"]
if trafo_type == "translation":
translation = trafo["translation"]

assert sum((scale is not None, translation is not None)) > 0
if scale is not None and translation is not None:
assert len(scale) == len(translation)

if indices and scale:
scale = [scale[index] for index in indices]
if indices and translation:
translation = [translation[index] for index in indices]

ndim = len(translation) if scale is None else len(scale)
if ndim == 2:
transform = affine_matrix_2d(scale=scale, translation=translation)
elif ndim == 3:
transform = affine_matrix_3d(scale=scale, translation=translation)
else:
raise RuntimeError(f"Only support 2d or 3d affines, got {ndim}")
return transform


def _parse_transformation(ngff_trafo, version, indices):
if version == "0.4":
return _parse_04_transformation(ngff_trafo, indices)
else:
raise RuntimeError(f"Unsupported version {version}")


def _get_04_axis_indices(multiscales, axes):
indices = []
for i, ax in enumerate(multiscales["axes"]):
if ax["name"] in axes:
indices.append(i)
assert len(indices) == len(axes)
return indices


def _get_axis_indices(multiscales, axes, version):
if version == "0.4":
return _get_04_axis_indices(multiscales, axes)
else:
raise RuntimeError(f"Unsupported version {version}")


def ngff_to_native(multiscales, scale_level=0, axes=None):
"""Convert NGFF transformation to affine transformation matrix.
Arguments:
multiscales [str, lis[dict] or dict] - the ngff multiscales metadata.
Can be either a filepath to the corresponding zarr array or a dict
containing the deserialzed ngff metadata.
scale_level [int] - the scale level for which to compute the transformation (default: 0)
axes [str] - subset of axes for which to compute the transformation.
E.g. "zyx" to compute only for spatial axes (default: None)
Returns:
np.ndarray - the 3x3 (2d data) or 4x4 (3d data) transformation matrix
"""
if isinstance(multiscales, str):
assert os.path.exists(multiscales)
if os.path.isdir(multiscales):
multiscales = os.path.join(multiscales, ".zattrs")
with open(multiscales) as f:
multiscales = json.load(f)

if isinstance(multiscales, dict) and len(multiscales) == 1:
assert "multiscales" in multiscales
multiscales = multiscales["multiscales"]
if isinstance(multiscales, list):
multiscales = multiscales[0]
assert isinstance(multiscales, dict)

if "version" in multiscales:
version = multiscales["version"]
else:
version = SUPPORTED_NGFF_VERSIONS[-1]
warnings.warn(f"Could not find version field in multiscales metadata, assuming latest version: {version}")
if version not in SUPPORTED_NGFF_VERSIONS:
raise RuntimeError(
f"NGFF version {version} is not in supported versions: {SUPPORTED_NGFF_VERSIONS}"
)

indices = None if axes is None else _get_axis_indices(multiscales, axes, version)
transformation = multiscales["datasets"][scale_level].get("coordinateTransformations", None)
if transformation is not None:
transformation = _parse_transformation(transformation, version, indices)

if "coordinateTransformations" in multiscales:
global_transformation = multiscales["coordinateTransformations"]
global_transformation = _parse_transformation(global_transformation, version, indices)
if transformation is None:
transformation = global_transformation
else:
assert transformation.shape == global_transformation.shape
transformation = transformation @ global_transformation

return transformation


def _to_04_trafo(transformation):
trafos = []
scale = scale_from_matrix(transformation)
if any(sc != 1.0 for sc in scale):
trafos.append({"type": "scale", "scale": scale.tolist()})
translation = translation_from_matrix(transformation)
if any(trans != 0.0 for trans in translation):
trafos.append({"type": "translation", "translation": translation.tolist()})
return {"coordinateTransformations": trafos}


# TODO implement expanding to axes, e.g. expanding zyx to tczyx trafo
def native_to_ngff(transformation, version=None):
"""Convert affine transformation matrix to NGFF transformation.
Arguments:
transformation [np.ndarray] - the transformation matrix
version [str] - the ngff version to use.
By default will use the latest supported version (default: None)
Returns:
dict - the ngff transformation
"""
if transformation.shape not in [(3, 3), (4, 4)]:
raise ValueError(
f"Invalid shape of the transformation matrix: {transformation.shape}, expect 3x3 or 4x4 matrix"
)

if version is None:
version = SUPPORTED_NGFF_VERSIONS[-1]
if version == "0.4":
trafo = _to_04_trafo(transformation)
else:
raise RuntimeError(
f"NGFF version {version} is not in supported versions: {SUPPORTED_NGFF_VERSIONS}"
)
return trafo
93 changes: 93 additions & 0 deletions test/transformation/test_ngff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os
import json
import unittest
from shutil import rmtree
from sys import platform

import numpy as np
import requests
from elf.transformation import affine as affine_utils


NGFF_EXAMPLES = {
"0.4": {
"yx": "https://s3.embl.de/i2k-2020/ngff-example-data/v0.4/yx.ome.zarr",
"zyx": "https://s3.embl.de/i2k-2020/ngff-example-data/v0.4/zyx.ome.zarr",
"tczyx": "https://s3.embl.de/i2k-2020/ngff-example-data/v0.4/tczyx.ome.zarr",
}
}


@unittest.skipIf(platform == "win32", "Download fails on windows")
class TestNgff(unittest.TestCase):
versions = list(NGFF_EXAMPLES.keys())
tmp_folder = "./tmp"

@classmethod
def setUpClass(cls):
os.makedirs(cls.tmp_folder, exist_ok=True)
for version in cls.versions:
version_folder = os.path.join(cls.tmp_folder, version)
os.makedirs(version_folder, exist_ok=True)
examples = NGFF_EXAMPLES[version]
for name, example_url in examples.items():
url = os.path.join(example_url, ".zattrs")
out_path = os.path.join(version_folder, f"{name}.json")
with requests.get(url) as r, open(out_path, "w") as f:
f.write(r.content.decode("utf8"))

@classmethod
def tearDownClass(cls):
try:
rmtree(cls.tmp_folder)
except OSError:
pass

def test_ngff_to_native_simple(self):
from elf.transformation import ngff_to_native
for version in self.versions:
for name in ("yx", "zyx"):
for scale_level in (0, 2):
example = os.path.join(self.tmp_folder, version, f"{name}.json")
with open(example) as f:
multiscales = json.load(f)
trafo = ngff_to_native(multiscales, scale_level=scale_level)
self.assertIsInstance(trafo, np.ndarray)
exp_shape = (3, 3) if name == "yx" else (4, 4)
self.assertEqual(trafo.shape, exp_shape)
scale = affine_utils.scale_from_matrix(trafo)
ds_trafos = multiscales["multiscales"][0]["datasets"][scale_level]["coordinateTransformations"]
exp_scale = ds_trafos[0]["scale"]
self.assertTrue(np.allclose(scale, exp_scale))

def test_ngff_to_native_axes(self):
from elf.transformation import ngff_to_native
axes = "zyx"
for version in self.versions:
name = "tczyx"
example = os.path.join(self.tmp_folder, version, f"{name}.json")
with open(example) as f:
multiscales = json.load(f)
trafo = ngff_to_native(multiscales, axes=axes)
self.assertIsInstance(trafo, np.ndarray)
exp_shape = (4, 4)
self.assertEqual(trafo.shape, exp_shape)
scale = affine_utils.scale_from_matrix(trafo)
exp_scale = multiscales["multiscales"][0]["datasets"][0]["coordinateTransformations"][0]["scale"][2:]
self.assertTrue(np.allclose(scale, exp_scale))

def test_native_to_ngff_2d(self):
from elf.transformation import native_to_ngff
scale, translation = np.random.rand(2), np.random.rand(2)
matrix = affine_utils.affine_matrix_2d(scale=scale, translation=translation)
trafo = native_to_ngff(matrix)

trafo_parts = trafo["coordinateTransformations"]
self.assertEqual(trafo_parts[0]["type"], "scale")
self.assertTrue(np.allclose(scale, trafo_parts[0]["scale"]))
self.assertEqual(trafo_parts[1]["type"], "translation")
self.assertTrue(np.allclose(translation, trafo_parts[1]["translation"]))


if __name__ == "__main__":
unittest.main()

0 comments on commit 885fd0f

Please sign in to comment.