Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify linear axis assignment logic #265

Merged
merged 8 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/imagej/_java.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class MyJavaClasses(JavaClasses):
significantly easier and more readable.
"""

@JavaClasses.java_import
def Double(self):
return "java.lang.Double"

@JavaClasses.java_import
def Throwable(self):
return "java.lang.Throwable"
Expand All @@ -50,6 +54,14 @@ def MetadataWrapper(self):
def LabelingIOService(self):
return "io.scif.labeling.LabelingIOService"

@JavaClasses.java_import
def DefaultLinearAxis(self):
return "net.imagej.axis.DefaultLinearAxis"

@JavaClasses.java_import
def EnumeratedAxis(self):
return "net.imagej.axis.EnumeratedAxis"

@JavaClasses.java_import
def Dataset(self):
return "net.imagej.Dataset"
Expand Down
116 changes: 49 additions & 67 deletions src/imagej/dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Utility functions for querying and manipulating dimensional axis metadata.
"""
import logging
from typing import List, Tuple
from typing import List, Tuple, Union

import numpy as np
import scyjava as sj
Expand Down Expand Up @@ -177,49 +177,53 @@ def prioritize_rai_axes_order(
return permute_order


def _assign_axes(xarr: xr.DataArray):
def _assign_axes(
xarr: xr.DataArray,
) -> List[Union["jc.DefaultLinearAxis", "jc.EnumeratedAxis"]]:
"""
Obtain xarray axes names, origin, and scale and convert into ImageJ Axis;
currently supports EnumeratedAxis
:param xarr: xarray that holds the units
:return: A list of ImageJ Axis with the specified origin and scale
Obtain xarray axes names, origin, scale and convert into ImageJ Axis. Supports both
DefaultLinearAxis and the newer EnumeratedAxis.

Note that, in many cases, there are small discrepancies between the coordinates.
This can either be actually within the data, or it can be from floating point math
errors. In this case, we delegate to numpy.isclose to tell us whether our
coordinates are linear or not. If our coordinates are nonlinear, and the
EnumeratedAxis type is available, we will use it. Otherwise, this function
returns a DefaultLinearAxis.

:param xarr: xarray that holds the data.
:return: A list of ImageJ Axis with the specified origin and scale.
"""
Double = sj.jimport("java.lang.Double")

axes = [""] * len(xarr.dims)

# try to get EnumeratedAxis, if not then default to LinearAxis in the loop
try:
EnumeratedAxis = _get_enumerated_axis()
except (JException, TypeError):
EnumeratedAxis = None

axes = [""] * xarr.ndim
for dim in xarr.dims:
axis_str = _convert_dim(dim, direction="java")
axis_str = _convert_dim(dim, "java")
ax_type = jc.Axes.get(axis_str)
ax_num = _get_axis_num(xarr, dim)
scale = _get_scale(xarr.coords[dim])
coords_arr = xarr.coords[dim]

if scale is None:
# coerce numeric scale
if not _is_numeric_scale(coords_arr):
_logger.warning(
f"The {ax_type.label} axis is non-numeric and is translated "
f"The {ax_type.getLabel()} axis is non-numeric and is translated "
"to a linear index."
)
doub_coords = [
Double(np.double(x)) for x in np.arange(len(xarr.coords[dim]))
]
coords_arr = [np.double(x) for x in np.arange(len(xarr.coords[dim]))]
else:
doub_coords = [Double(np.double(x)) for x in xarr.coords[dim]]

# EnumeratedAxis is a new axis made for xarray, so is only present in
# ImageJ versions that are released later than March 2020.
# This actually returns a LinearAxis if using an earlier version.
if EnumeratedAxis is not None:
java_axis = EnumeratedAxis(ax_type, sj.to_java(doub_coords))
coords_arr = coords_arr.to_numpy().astype(np.double)

# check scale linearity
diffs = np.diff(coords_arr)
linear: bool = diffs.size and np.all(np.isclose(diffs, diffs[0]))

if not linear:
try:
j_coords = [jc.Double(x) for x in coords_arr]
axes[ax_num] = jc.EnumeratedAxis(ax_type, sj.to_java(j_coords))
except (JException, TypeError):
# if EnumeratedAxis not available - use DefaultLinearAxis
axes[ax_num] = _get_default_linear_axis(coords_arr, ax_type)
else:
java_axis = _get_linear_axis(ax_type, sj.to_java(doub_coords))

axes[ax_num] = java_axis
axes[ax_num] = _get_default_linear_axis(coords_arr, ax_type)

return axes

Expand Down Expand Up @@ -274,48 +278,26 @@ def _get_axes_coords(
return coords


def _get_scale(axis):
def _get_default_linear_axis(coords_arr: np.ndarray, ax_type: "jc.AxisType"):
"""
Get the scale of an axis, assuming it is linear and so the scale is simply
second - first coordinate.
Create a new DefaultLinearAxis with the given coordinate array and axis type.

:param axis: A 1D list like entry accessible with indexing, which contains the
axis coordinates
:return: The scale for this axis or None if it is a non-numeric scale.
:param coords_arr: A 1D NumPy array.
:return: An instance of net.imagej.axis.DefaultLinearAxis.
"""
try:
# HACK: This axis length check is a work around for singleton dimensions.
# You can't calculate the slope of a singleton dimension.
# This section will be removed when axis-scale-logic is merged.
if len(axis) <= 1:
return 1
else:
return axis.values[1] - axis.values[0]
except TypeError:
return None

scale = coords_arr[1] - coords_arr[0] if len(coords_arr) > 1 else 1
origin = coords_arr[0] if len(coords_arr) > 0 else 0
return jc.DefaultLinearAxis(ax_type, jc.Double(scale), jc.Double(origin))

def _get_enumerated_axis():
"""Get EnumeratedAxis.

EnumeratedAxis is only in releases later than March 2020. If using
an older version of ImageJ without EnumeratedAxis, use
_get_linear_axis() instead.
def _is_numeric_scale(coords_array: np.ndarray) -> bool:
"""
return sj.jimport("net.imagej.axis.EnumeratedAxis")


def _get_linear_axis(axis_type: "jc.AxisType", values):
"""Get linear axis.
Checks if the coordinates array of the given axis is numeric.

This is used if no EnumeratedAxis is found. If EnumeratedAxis
is available, use _get_enumerated_axis() instead.
:param coords_array: A 1D NumPy array.
:return: bool
"""
DefaultLinearAxis = sj.jimport("net.imagej.axis.DefaultLinearAxis")
origin = values[0]
scale = values[1] - values[0]
axis = DefaultLinearAxis(axis_type, scale, origin)
return axis
return np.issubdtype(coords_array.dtype, np.number)


def _dataset_to_imgplus(rai: "jc.RandomAccessibleInterval") -> "jc.ImgPlus":
Expand Down
99 changes: 99 additions & 0 deletions tests/test_image_conversion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
import string

import numpy as np
import pytest
Expand All @@ -7,6 +8,7 @@

import imagej.dims as dims
import imagej.images as images
from imagej._java import jc

# -- Image helpers --

Expand Down Expand Up @@ -94,6 +96,75 @@ def get_xarr(option="C"):
return xarr


def get_non_linear_coord_xarr(option="C"):
name: str = "non_linear_coord_data_array"
linear_coord_arr = np.arange(5)
# generate a 1D log scale array
non_linear_coord_arr = np.logspace(0, np.log10(100), num=30)
if option == "C":
xarr = xr.DataArray(
np.random.rand(30, 30, 5),
dims=["row", "col", "ch"],
coords={
"row": non_linear_coord_arr,
"col": non_linear_coord_arr,
"ch": linear_coord_arr,
},
attrs={"Hello": "World"},
name=name,
)
elif option == "F":
xarr = xr.DataArray(
np.ndarray([30, 30, 5], order="F"),
dims=["row", "col", "ch"],
coords={
"row": non_linear_coord_arr,
"col": non_linear_coord_arr,
"ch": linear_coord_arr,
},
attrs={"Hello": "World"},
name=name,
)
else:
xarr = xr.DataArray(np.random.rand(30, 30, 5), name=name)

return xarr


def get_non_numeric_coord_xarr(option="C"):
name: str = "non_numeric_coord_data_array"
non_numeric_coord_list = [random.choice(string.ascii_letters) for _ in range(30)]
linear_coord_arr = np.arange(5)
if option == "C":
xarr = xr.DataArray(
np.random.rand(30, 30, 5),
dims=["row", "col", "ch"],
coords={
"row": non_numeric_coord_list,
"col": non_numeric_coord_list,
"ch": linear_coord_arr,
},
attrs={"Hello": "World"},
name=name,
)
elif option == "F":
xarr = xr.DataArray(
np.ndarray([30, 30, 5], order="F"),
dims=["row", "col", "ch"],
coords={
"row": non_numeric_coord_list,
"col": non_numeric_coord_list,
"ch": linear_coord_arr,
},
attrs={"Hello": "World"},
name=name,
)
else:
xarr = xr.DataArray(np.random.rand(30, 30, 5), name=name)

return xarr


# -- Helpers --


Expand Down Expand Up @@ -359,6 +430,34 @@ def test_no_coords_or_dims_in_xarr(ij_fixture):
assert_inverted_xarr_equal_to_xarr(dataset, ij_fixture, xarr)


def test_linear_coord_on_xarr_conversion(ij_fixture):
xarr = get_xarr()
dataset = ij_fixture.py.to_java(xarr)
axes = dataset.dim_axes
# all axes should be DefaultLinearAxis
for ax in axes:
assert isinstance(ax, jc.DefaultLinearAxis)


def test_non_linear_coord_on_xarr_conversion(ij_fixture):
xarr = get_non_linear_coord_xarr()
dataset = ij_fixture.py.to_java(xarr)
axes = dataset.dim_axes
# axes [0, 1] should be EnumeratedAxis with axis 2 as DefaultLinearAxis
for i in range(2):
assert isinstance(axes[i], jc.EnumeratedAxis)
assert isinstance(axes[-1], jc.DefaultLinearAxis)


def test_non_numeric_coord_on_xarr_conversion(ij_fixture):
xarr = get_non_numeric_coord_xarr()
dataset = ij_fixture.py.to_java(xarr)
axes = dataset.dim_axes
# all axes should be DefaultLinearAxis
for ax in axes:
assert isinstance(ax, jc.DefaultLinearAxis)


dataset_conversion_parameters = [
(
get_img,
Expand Down