diff --git a/src/amrex/MultiFab.py b/src/amrex/MultiFab.py index 30ca5f4a..a890a303 100644 --- a/src/amrex/MultiFab.py +++ b/src/amrex/MultiFab.py @@ -6,8 +6,7 @@ License: BSD-3-Clause-LBNL """ - -def mf_to_numpy(self, copy=False, order="F"): +def mf_to_numpy(amr, self, copy=False, order="F"): """ Provide a Numpy view into a MultiFab. @@ -29,13 +28,24 @@ def mf_to_numpy(self, copy=False, order="F"): Returns ------- - list of np.array + list of numpy.array A list of numpy n-dimensional arrays, for each local block in the MultiFab. """ + mf = self + if copy: + mf = amr.MultiFab( + self.box_array(), + self.dm(), + self.n_comp(), + self.n_grow_vect(), + amr.MFInfo().set_arena(amr.The_Pinned_Arena()), + ) + amr.dtoh_memcpy(mf, self) + views = [] - for mfi in self: - views.append(self.array(mfi).to_numpy(copy, order)) + for mfi in mf: + views.append(mf.array(mfi).to_numpy(copy=False, order=order)) return views @@ -80,15 +90,9 @@ def mf_to_cupy(self, copy=False, order="F"): def register_MultiFab_extension(amr): """MultiFab helper methods""" - import inspect - import sys - - # register member functions for every MultiFab* type - for _, MultiFab_type in inspect.getmembers( - sys.modules[amr.__name__], - lambda member: inspect.isclass(member) - and member.__module__ == amr.__name__ - and member.__name__.startswith("MultiFab"), - ): - MultiFab_type.to_numpy = mf_to_numpy - MultiFab_type.to_cupy = mf_to_cupy + + # register member functions for the MultiFab type + amr.MultiFab.to_numpy = lambda self, copy=False, order="F": mf_to_numpy(amr, self, copy, order) + amr.MultiFab.to_numpy.__doc__ = mf_to_numpy.__doc__ + + amr.MultiFab.to_cupy = mf_to_cupy diff --git a/src/amrex/space1d/__init__.py b/src/amrex/space1d/__init__.py index 76b75647..1ad9f5fd 100644 --- a/src/amrex/space1d/__init__.py +++ b/src/amrex/space1d/__init__.py @@ -45,11 +45,13 @@ def Print(*args, **kwargs): from ..Array4 import register_Array4_extension +from ..MultiFab import register_MultiFab_extension from ..ArrayOfStructs import register_AoS_extension from ..PODVector import register_PODVector_extension from ..StructOfArrays import register_SoA_extension register_Array4_extension(amrex_1d_pybind) +register_MultiFab_extension(amrex_1d_pybind) register_PODVector_extension(amrex_1d_pybind) register_SoA_extension(amrex_1d_pybind) register_AoS_extension(amrex_1d_pybind) diff --git a/src/amrex/space2d/__init__.py b/src/amrex/space2d/__init__.py index b3626276..6dcf7c10 100644 --- a/src/amrex/space2d/__init__.py +++ b/src/amrex/space2d/__init__.py @@ -45,11 +45,13 @@ def Print(*args, **kwargs): from ..Array4 import register_Array4_extension +from ..MultiFab import register_MultiFab_extension from ..ArrayOfStructs import register_AoS_extension from ..PODVector import register_PODVector_extension from ..StructOfArrays import register_SoA_extension register_Array4_extension(amrex_2d_pybind) +register_MultiFab_extension(amrex_2d_pybind) register_PODVector_extension(amrex_2d_pybind) register_SoA_extension(amrex_2d_pybind) register_AoS_extension(amrex_2d_pybind) diff --git a/src/amrex/space3d/__init__.py b/src/amrex/space3d/__init__.py index 4b4fd623..df8b3727 100644 --- a/src/amrex/space3d/__init__.py +++ b/src/amrex/space3d/__init__.py @@ -45,11 +45,13 @@ def Print(*args, **kwargs): from ..Array4 import register_Array4_extension +from ..MultiFab import register_MultiFab_extension from ..ArrayOfStructs import register_AoS_extension from ..PODVector import register_PODVector_extension from ..StructOfArrays import register_SoA_extension register_Array4_extension(amrex_3d_pybind) +register_MultiFab_extension(amrex_3d_pybind) register_PODVector_extension(amrex_3d_pybind) register_SoA_extension(amrex_3d_pybind) register_AoS_extension(amrex_3d_pybind) diff --git a/tests/test_multifab.py b/tests/test_multifab.py index 80727615..0e0713aa 100644 --- a/tests/test_multifab.py +++ b/tests/test_multifab.py @@ -350,3 +350,12 @@ def test_mfab_dtoh_copy(make_mfab_device): device_max = mfab_device.max(0) assert device_min == device_max assert device_max == 11.0 + + # numpy bindings (w/ copy) + local_boxes_host = mfab_device.to_numpy(copy=True) + assert max([np.max(box) for box in local_boxes_host]) == device_max + + # cupy bindings (w/o copy) + import cupy as cp + local_boxes_device = mfab_device.to_cupy() + assert max([cp.max(box) for box in local_boxes_device]) == device_max