From dee0ca44032d46a4fa4480fd6594a0ff348fafa7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 14 Nov 2024 08:43:49 -0600 Subject: [PATCH] add device<->host transfer mappers (#282) * implement TransferTo{Host,Device}Mapper * do not support transfer to same type * use {to,from}_numpy * lint fixes * add comment * ruff * disable spurious pylint warning * add to docs * more doc fixes * datawrapper doc * more doc fixes * Improve mapper names * Bring back numpy actx constructor docs --------- Co-authored-by: Andreas Kloeckner --- arraycontext/impl/numpy/__init__.py | 4 +- arraycontext/impl/pytato/__init__.py | 14 +++-- arraycontext/impl/pytato/utils.py | 92 +++++++++++++++++++++++++++- test/test_pytato_arraycontext.py | 53 ++++++++++++++++ 4 files changed, 156 insertions(+), 7 deletions(-) diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index f8ba95e3..c2f884a6 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -1,10 +1,10 @@ from __future__ import annotations -""" +__doc__ = """ .. currentmodule:: arraycontext -A mod :`numpy`-based array context. +A :mod:`numpy`-based array context. .. autoclass:: NumpyArrayContext """ diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 1d36971c..e3ce52a7 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -4,13 +4,13 @@ __doc__ = """ .. currentmodule:: arraycontext -A :mod:`pytato`-based array context defers the evaluation of an array until its +A :mod:`pytato`-based array context defers the evaluation of an array until it is frozen. The execution contexts for the evaluations are specific to an -:class:`~arraycontext.ArrayContext` type. For ex. +:class:`~arraycontext.ArrayContext` type. For example, :class:`~arraycontext.PytatoPyOpenCLArrayContext` uses :mod:`pyopencl` to JIT-compile and execute the array expressions. -Following :mod:`pytato`-based array context are provided: +The following :mod:`pytato`-based array contexts are provided: .. autoclass:: PytatoPyOpenCLArrayContext .. autoclass:: PytatoJAXArrayContext @@ -20,6 +20,12 @@ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. automodule:: arraycontext.impl.pytato.compile + + +Utils +^^^^^ + +.. automodule:: arraycontext.impl.pytato.utils """ __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -227,7 +233,7 @@ def get_target(self): class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): """ - A :class:`ArrayContext` that uses :mod:`pytato` data types to represent + An :class:`ArrayContext` that uses :mod:`pytato` data types to represent the arrays targeting OpenCL for offloading operations. .. attribute:: queue diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index d0c80a33..5b059992 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -1,3 +1,9 @@ +__doc__ = """ +.. autofunction:: transfer_from_numpy +.. autofunction:: transfer_to_numpy +""" + + __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees """ @@ -22,6 +28,7 @@ THE SOFTWARE. """ + from collections.abc import Mapping from typing import TYPE_CHECKING, Any, cast @@ -36,9 +43,10 @@ make_placeholder, ) from pytato.target.loopy import LoopyPyOpenCLTarget -from pytato.transform import CopyMapper +from pytato.transform import ArrayOrNames, CopyMapper from pytools import UniqueNameGenerator, memoize_method +from arraycontext import ArrayContext from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis @@ -125,4 +133,86 @@ def get_loopy_target(self) -> "lp.PyOpenCLTarget": # }}} + +# {{{ Transfer mappers + +class TransferFromNumpyMapper(CopyMapper): + """A mapper to transfer arrays contained in :class:`~pytato.array.DataWrapper` + instances to be device arrays, using + :meth:`~arraycontext.ArrayContext.from_numpy`. + """ + def __init__(self, actx: ArrayContext) -> None: + super().__init__() + self.actx = actx + + def map_data_wrapper(self, expr: DataWrapper) -> Array: + import numpy as np + + if not isinstance(expr.data, np.ndarray): + raise ValueError("TransferFromNumpyMapper: tried to transfer data that " + "is already on the device") + + # Ideally, this code should just do + # return self.actx.from_numpy(expr.data).tagged(expr.tags), + # but there seems to be no way to transfer the non_equality_tags in that case. + new_dw = self.actx.from_numpy(expr.data) + assert isinstance(new_dw, DataWrapper) + + # https://github.com/pylint-dev/pylint/issues/3893 + # pylint: disable=unexpected-keyword-arg + return DataWrapper( + data=new_dw.data, + shape=expr.shape, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + +class TransferToNumpyMapper(CopyMapper): + """A mapper to transfer arrays contained in :class:`~pytato.array.DataWrapper` + instances to be :class:`numpy.ndarray` instances, using + :meth:`~arraycontext.ArrayContext.to_numpy`. + """ + def __init__(self, actx: ArrayContext) -> None: + super().__init__() + self.actx = actx + + def map_data_wrapper(self, expr: DataWrapper) -> Array: + import numpy as np + + import arraycontext.impl.pyopencl.taggable_cl_array as tga + if not isinstance(expr.data, tga.TaggableCLArray): + raise ValueError("TransferToNumpyMapper: tried to transfer data that " + "is already on the host") + + np_data = self.actx.to_numpy(expr.data) + assert isinstance(np_data, np.ndarray) + + # https://github.com/pylint-dev/pylint/issues/3893 + # pylint: disable=unexpected-keyword-arg + return DataWrapper( + data=np_data, + shape=expr.shape, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + +def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames: + """Transfer arrays contained in :class:`~pytato.array.DataWrapper` + instances to be device arrays, using + :meth:`~arraycontext.ArrayContext.from_numpy`. + """ + return TransferFromNumpyMapper(actx)(expr) + + +def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames: + """Transfer arrays contained in :class:`~pytato.array.DataWrapper` + instances to be :class:`numpy.ndarray` instances, using + :meth:`~arraycontext.ArrayContext.to_numpy`. + """ + return TransferToNumpyMapper(actx)(expr) + +# }}} + # vim: foldmethod=marker diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index 7922f383..a14df50f 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -192,6 +192,59 @@ def twice(x): assert res == 198 +def test_transfer(actx_factory): + import numpy as np + + import pytato as pt + actx = actx_factory() + + # {{{ simple tests + + a = actx.from_numpy(np.array([0, 1, 2, 3])).tagged(FooTag()) + + from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray + assert isinstance(a.data, TaggableCLArray) + + from arraycontext.impl.pytato.utils import transfer_from_numpy, transfer_to_numpy + + ah = transfer_to_numpy(a, actx) + assert ah != a + assert a.tags == ah.tags + assert a.non_equality_tags == ah.non_equality_tags + assert isinstance(ah.data, np.ndarray) + + with pytest.raises(ValueError): + _ahh = transfer_to_numpy(ah, actx) + + ad = transfer_from_numpy(ah, actx) + assert isinstance(ad.data, TaggableCLArray) + assert ad != ah + assert ad != a # copied DataWrappers compare unequal + assert ad.tags == ah.tags + assert ad.non_equality_tags == ah.non_equality_tags + assert np.array_equal(a.data.get(), ad.data.get()) + + with pytest.raises(ValueError): + _add = transfer_from_numpy(ad, actx) + + # }}} + + # {{{ test with DictOfNamedArrays + + dag = pt.make_dict_of_named_arrays({ + "a_expr": a + 2 + }) + + dagh = transfer_to_numpy(dag, actx) + assert dagh != dag + assert isinstance(dagh["a_expr"].expr.bindings["_in0"].data, np.ndarray) + + daghd = transfer_from_numpy(dagh, actx) + assert isinstance(daghd["a_expr"].expr.bindings["_in0"].data, TaggableCLArray) + + # }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: