diff --git a/nibabel/pointset.py b/nibabel/pointset.py index b40449801..db70f815d 100644 --- a/nibabel/pointset.py +++ b/nibabel/pointset.py @@ -48,7 +48,12 @@ def __array__(self, dtype: _DType, /) -> np.ndarray[ty.Any, _DType]: ... # pragma: no cover -@dataclass +class HasMeshAttrs(ty.Protocol): + coordinates: CoordinateArray + triangles: CoordinateArray + + +@dataclass(init=False) class Pointset: """A collection of points described by coordinates. @@ -65,7 +70,7 @@ class Pointset: coordinates: CoordinateArray affine: np.ndarray - homogeneous: bool = False + homogeneous: bool # Force use of __rmatmul__ with numpy arrays __array_priority__ = 99 @@ -144,6 +149,82 @@ def get_coords(self, *, as_homogeneous: bool = False): return coords +@dataclass(init=False) +class TriangularMesh(Pointset): + triangles: CoordinateArray + + def __init__( + self, + coordinates: CoordinateArray, + triangles: CoordinateArray, + affine: np.ndarray | None = None, + homogeneous: bool = False, + ): + super().__init__(coordinates, affine=affine, homogeneous=homogeneous) + self.triangles = triangles + + @classmethod + def from_tuple( + cls, + mesh: tuple[CoordinateArray, CoordinateArray], + affine: np.ndarray | None = None, + homogeneous: bool = False, + **kwargs, + ) -> Self: + return cls(mesh[0], mesh[1], affine=affine, homogeneous=homogeneous, **kwargs) + + @classmethod + def from_object( + cls, + mesh: HasMeshAttrs, + affine: np.ndarray | None = None, + homogeneous: bool = False, + **kwargs, + ) -> Self: + return cls( + mesh.coordinates, mesh.triangles, affine=affine, homogeneous=homogeneous, **kwargs + ) + + @property + def n_triangles(self): + """Number of faces + + Subclasses should override with more efficient implementations. + """ + return self.triangles.shape[0] + + def get_triangles(self): + """Mx3 array of indices into coordinate table""" + return np.asanyarray(self.triangles) + + def get_mesh(self, *, as_homogeneous: bool = False): + return self.get_coords(as_homogeneous=as_homogeneous), self.get_triangles() + + +class CoordinateFamilyMixin(Pointset): + def __init__(self, *args, name='original', **kwargs): + mapping = kwargs.pop('mapping', {}) + super().__init__(*args, **kwargs) + self._coords = {name: self.coordinates, **mapping} + + def get_names(self): + """List of surface names that can be passed to :meth:`with_name`""" + return list(self._coords) + + def with_name(self, name: str) -> Self: + new_coords = self._coords[name] + if new_coords is self.coordinates: + return self + # Make a copy, preserving all dataclass fields + new = replace(self, coordinates=new_coords) + # Conserve exact _coords mapping + new._coords = self._coords + return new + + def add_coordinates(self, name, coordinates): + self._coords[name] = coordinates + + class Grid(Pointset): r"""A regularly-spaced collection of coordinates diff --git a/nibabel/tests/test_pointset.py b/nibabel/tests/test_pointset.py index fb9a7c5c8..31c5335a1 100644 --- a/nibabel/tests/test_pointset.py +++ b/nibabel/tests/test_pointset.py @@ -1,3 +1,4 @@ +from collections import namedtuple from math import prod from pathlib import Path from unittest import skipUnless @@ -12,7 +13,7 @@ from nibabel.onetime import auto_attr from nibabel.optpkg import optional_package from nibabel.spatialimages import SpatialImage -from nibabel.tests.nibabel_data import get_nibabel_data +from nibabel.tests.nibabel_data import get_nibabel_data, needs_nibabel_data h5, has_h5py, _ = optional_package('h5py') @@ -182,3 +183,247 @@ def test_to_mask(self): ], ) assert np.array_equal(mask_img.affine, np.eye(4)) + + +class TestTriangularMeshes: + def test_api(self): + # Tetrahedron + coords = np.array( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + ] + ) + triangles = np.array( + [ + [0, 2, 1], + [0, 3, 2], + [0, 1, 3], + [1, 2, 3], + ] + ) + + mesh = namedtuple('mesh', ('coordinates', 'triangles'))(coords, triangles) + + tm1 = ps.TriangularMesh(coords, triangles) + tm2 = ps.TriangularMesh.from_tuple(mesh) + tm3 = ps.TriangularMesh.from_object(mesh) + + assert np.allclose(tm1.affine, np.eye(4)) + assert np.allclose(tm2.affine, np.eye(4)) + assert np.allclose(tm3.affine, np.eye(4)) + + assert tm1.homogeneous is False + assert tm2.homogeneous is False + assert tm3.homogeneous is False + + assert (tm1.n_coords, tm1.dim) == (4, 3) + assert (tm2.n_coords, tm2.dim) == (4, 3) + assert (tm3.n_coords, tm3.dim) == (4, 3) + + assert tm1.n_triangles == 4 + assert tm2.n_triangles == 4 + assert tm3.n_triangles == 4 + + out_coords, out_tris = tm1.get_mesh() + # Currently these are the exact arrays, but I don't think we should + # bake that assumption into the tests + assert np.allclose(out_coords, coords) + assert np.allclose(out_tris, triangles) + + +class TestCoordinateFamilyMixin(TestPointsets): + def test_names(self): + coords = np.array( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + ] + ) + cfm = ps.CoordinateFamilyMixin(coords) + + assert cfm.get_names() == ['original'] + assert np.allclose(cfm.with_name('original').coordinates, coords) + + cfm.add_coordinates('shifted', coords + 1) + assert set(cfm.get_names()) == {'original', 'shifted'} + shifted = cfm.with_name('shifted') + assert np.allclose(shifted.coordinates, coords + 1) + assert set(shifted.get_names()) == {'original', 'shifted'} + original = shifted.with_name('original') + assert np.allclose(original.coordinates, coords) + + # Avoid duplicating objects + assert original.with_name('original') is original + # But don't try too hard + assert original.with_name('original') is not cfm + + # with_name() preserves the exact coordinate mapping of the source object. + # Modifications of one are immediately available to all others. + # This is currently an implementation detail, and the expectation is that + # a family will be created once and then queried, but this behavior could + # potentially become confusing or relied upon. + # Change with care. + shifted.add_coordinates('shifted-again', coords + 2) + shift2 = shifted.with_name('shifted-again') + shift3 = cfm.with_name('shifted-again') + + +class H5ArrayProxy: + def __init__(self, file_like, dataset_name): + self.file_like = file_like + self.dataset_name = dataset_name + with h5.File(file_like, 'r') as h5f: + arr = h5f[dataset_name] + self._shape = arr.shape + self._dtype = arr.dtype + + @property + def is_proxy(self): + return True + + @property + def shape(self): + return self._shape + + @property + def ndim(self): + return len(self.shape) + + @property + def dtype(self): + return self._dtype + + def __array__(self, dtype=None): + with h5.File(self.file_like, 'r') as h5f: + return np.asanyarray(h5f[self.dataset_name], dtype) + + def __getitem__(self, slicer): + with h5.File(self.file_like, 'r') as h5f: + return h5f[self.dataset_name][slicer] + + +class H5Geometry(ps.CoordinateFamilyMixin, ps.TriangularMesh): + """Simple Geometry file structure that combines a single topology + with one or more coordinate sets + """ + + @classmethod + def from_filename(klass, pathlike): + coords = {} + with h5.File(pathlike, 'r') as h5f: + triangles = H5ArrayProxy(pathlike, '/topology') + for name in h5f['coordinates']: + coords[name] = H5ArrayProxy(pathlike, f'/coordinates/{name}') + self = klass(next(iter(coords.values())), triangles, mapping=coords) + return self + + def to_filename(self, pathlike): + with h5.File(pathlike, 'w') as h5f: + h5f.create_dataset('/topology', data=self.get_triangles()) + for name, coord in self._coords.items(): + h5f.create_dataset(f'/coordinates/{name}', data=coord) + + +class FSGeometryProxy: + def __init__(self, pathlike): + self._file_like = str(Path(pathlike)) + self._offset = None + self._vnum = None + self._fnum = None + + def _peek(self): + from nibabel.freesurfer.io import _fread3 + + with open(self._file_like, 'rb') as fobj: + magic = _fread3(fobj) + if magic != 16777214: + raise NotImplementedError('Triangle files only!') + fobj.readline() + fobj.readline() + self._vnum = np.fromfile(fobj, '>i4', 1)[0] + self._fnum = np.fromfile(fobj, '>i4', 1)[0] + self._offset = fobj.tell() + + @property + def vnum(self): + if self._vnum is None: + self._peek() + return self._vnum + + @property + def fnum(self): + if self._fnum is None: + self._peek() + return self._fnum + + @property + def offset(self): + if self._offset is None: + self._peek() + return self._offset + + @auto_attr + def coordinates(self): + return ArrayProxy(self._file_like, ((self.vnum, 3), '>f4', self.offset), order='C') + + @auto_attr + def triangles(self): + return ArrayProxy( + self._file_like, + ((self.fnum, 3), '>i4', self.offset + 12 * self.vnum), + order='C', + ) + + +class FreeSurferHemisphere(ps.CoordinateFamilyMixin, ps.TriangularMesh): + @classmethod + def from_filename(klass, pathlike): + path = Path(pathlike) + hemi, default = path.name.split('.') + self = klass.from_object(FSGeometryProxy(path), name=default) + mesh_names = ( + 'orig', + 'white', + 'smoothwm', + 'pial', + 'inflated', + 'sphere', + 'midthickness', + 'graymid', + ) # Often created + + for mesh in mesh_names: + if mesh != default: + fpath = path.parent / f'{hemi}.{mesh}' + if fpath.exists(): + self.add_coordinates(mesh, FSGeometryProxy(fpath).coordinates) + return self + + +@needs_nibabel_data('nitest-freesurfer') +def test_FreeSurferHemisphere(): + lh = FreeSurferHemisphere.from_filename(FS_DATA / 'fsaverage/surf/lh.white') + assert lh.n_coords == 163842 + assert lh.n_triangles == 327680 + + +@skipUnless(has_h5py, reason='Test requires h5py') +@needs_nibabel_data('nitest-freesurfer') +def test_make_H5Geometry(tmp_path): + lh = FreeSurferHemisphere.from_filename(FS_DATA / 'fsaverage/surf/lh.white') + h5geo = H5Geometry.from_object(lh) + for name in ('white', 'pial'): + h5geo.add_coordinates(name, lh.with_name(name).coordinates) + h5geo.to_filename(tmp_path / 'geometry.h5') + + rt_h5geo = H5Geometry.from_filename(tmp_path / 'geometry.h5') + assert set(h5geo._coords) == set(rt_h5geo._coords) + assert np.array_equal( + lh.with_name('white').get_coords(), rt_h5geo.with_name('white').get_coords() + ) + assert np.array_equal(lh.get_triangles(), rt_h5geo.get_triangles())