Skip to content

Commit

Permalink
add device<->host transfer mappers (#282)
Browse files Browse the repository at this point in the history
* implement TransferTo{Host,Device}Mapper

* do not support transfer to same type

* use {to,from}_numpy

* lint fixes

* add comment

* ruff

* disable spurious pylint warning

* add to docs

* more doc fixes

* datawrapper doc

* more doc fixes

* Improve mapper names

* Bring back numpy actx constructor docs

---------

Co-authored-by: Andreas Kloeckner <inform@tiker.net>
  • Loading branch information
matthiasdiener and inducer authored Nov 14, 2024
1 parent a59a9c8 commit dee0ca4
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 7 deletions.
4 changes: 2 additions & 2 deletions arraycontext/impl/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations


"""
__doc__ = """
.. currentmodule:: arraycontext
A mod :`numpy`-based array context.
A :mod:`numpy`-based array context.
.. autoclass:: NumpyArrayContext
"""
Expand Down
14 changes: 10 additions & 4 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
__doc__ = """
.. currentmodule:: arraycontext
A :mod:`pytato`-based array context defers the evaluation of an array until its
A :mod:`pytato`-based array context defers the evaluation of an array until it is
frozen. The execution contexts for the evaluations are specific to an
:class:`~arraycontext.ArrayContext` type. For ex.
:class:`~arraycontext.ArrayContext` type. For example,
:class:`~arraycontext.PytatoPyOpenCLArrayContext` uses :mod:`pyopencl` to
JIT-compile and execute the array expressions.
Following :mod:`pytato`-based array context are provided:
The following :mod:`pytato`-based array contexts are provided:
.. autoclass:: PytatoPyOpenCLArrayContext
.. autoclass:: PytatoJAXArrayContext
Expand All @@ -20,6 +20,12 @@
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: arraycontext.impl.pytato.compile
Utils
^^^^^
.. automodule:: arraycontext.impl.pytato.utils
"""
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
Expand Down Expand Up @@ -227,7 +233,7 @@ def get_target(self):

class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
"""
A :class:`ArrayContext` that uses :mod:`pytato` data types to represent
An :class:`ArrayContext` that uses :mod:`pytato` data types to represent
the arrays targeting OpenCL for offloading operations.
.. attribute:: queue
Expand Down
92 changes: 91 additions & 1 deletion arraycontext/impl/pytato/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
__doc__ = """
.. autofunction:: transfer_from_numpy
.. autofunction:: transfer_to_numpy
"""


__copyright__ = """
Copyright (C) 2021 University of Illinois Board of Trustees
"""
Expand All @@ -22,6 +28,7 @@
THE SOFTWARE.
"""


from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast

Expand All @@ -36,9 +43,10 @@
make_placeholder,
)
from pytato.target.loopy import LoopyPyOpenCLTarget
from pytato.transform import CopyMapper
from pytato.transform import ArrayOrNames, CopyMapper
from pytools import UniqueNameGenerator, memoize_method

from arraycontext import ArrayContext
from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis


Expand Down Expand Up @@ -125,4 +133,86 @@ def get_loopy_target(self) -> "lp.PyOpenCLTarget":

# }}}


# {{{ Transfer mappers

class TransferFromNumpyMapper(CopyMapper):
"""A mapper to transfer arrays contained in :class:`~pytato.array.DataWrapper`
instances to be device arrays, using
:meth:`~arraycontext.ArrayContext.from_numpy`.
"""
def __init__(self, actx: ArrayContext) -> None:
super().__init__()
self.actx = actx

def map_data_wrapper(self, expr: DataWrapper) -> Array:
import numpy as np

if not isinstance(expr.data, np.ndarray):
raise ValueError("TransferFromNumpyMapper: tried to transfer data that "
"is already on the device")

# Ideally, this code should just do
# return self.actx.from_numpy(expr.data).tagged(expr.tags),
# but there seems to be no way to transfer the non_equality_tags in that case.
new_dw = self.actx.from_numpy(expr.data)
assert isinstance(new_dw, DataWrapper)

# https://github.com/pylint-dev/pylint/issues/3893
# pylint: disable=unexpected-keyword-arg
return DataWrapper(
data=new_dw.data,
shape=expr.shape,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)


class TransferToNumpyMapper(CopyMapper):
"""A mapper to transfer arrays contained in :class:`~pytato.array.DataWrapper`
instances to be :class:`numpy.ndarray` instances, using
:meth:`~arraycontext.ArrayContext.to_numpy`.
"""
def __init__(self, actx: ArrayContext) -> None:
super().__init__()
self.actx = actx

def map_data_wrapper(self, expr: DataWrapper) -> Array:
import numpy as np

import arraycontext.impl.pyopencl.taggable_cl_array as tga
if not isinstance(expr.data, tga.TaggableCLArray):
raise ValueError("TransferToNumpyMapper: tried to transfer data that "
"is already on the host")

np_data = self.actx.to_numpy(expr.data)
assert isinstance(np_data, np.ndarray)

# https://github.com/pylint-dev/pylint/issues/3893
# pylint: disable=unexpected-keyword-arg
return DataWrapper(
data=np_data,
shape=expr.shape,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)


def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
"""Transfer arrays contained in :class:`~pytato.array.DataWrapper`
instances to be device arrays, using
:meth:`~arraycontext.ArrayContext.from_numpy`.
"""
return TransferFromNumpyMapper(actx)(expr)


def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
"""Transfer arrays contained in :class:`~pytato.array.DataWrapper`
instances to be :class:`numpy.ndarray` instances, using
:meth:`~arraycontext.ArrayContext.to_numpy`.
"""
return TransferToNumpyMapper(actx)(expr)

# }}}

# vim: foldmethod=marker
53 changes: 53 additions & 0 deletions test/test_pytato_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,59 @@ def twice(x):
assert res == 198


def test_transfer(actx_factory):
import numpy as np

import pytato as pt
actx = actx_factory()

# {{{ simple tests

a = actx.from_numpy(np.array([0, 1, 2, 3])).tagged(FooTag())

from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
assert isinstance(a.data, TaggableCLArray)

from arraycontext.impl.pytato.utils import transfer_from_numpy, transfer_to_numpy

ah = transfer_to_numpy(a, actx)
assert ah != a
assert a.tags == ah.tags
assert a.non_equality_tags == ah.non_equality_tags
assert isinstance(ah.data, np.ndarray)

with pytest.raises(ValueError):
_ahh = transfer_to_numpy(ah, actx)

ad = transfer_from_numpy(ah, actx)
assert isinstance(ad.data, TaggableCLArray)
assert ad != ah
assert ad != a # copied DataWrappers compare unequal
assert ad.tags == ah.tags
assert ad.non_equality_tags == ah.non_equality_tags
assert np.array_equal(a.data.get(), ad.data.get())

with pytest.raises(ValueError):
_add = transfer_from_numpy(ad, actx)

# }}}

# {{{ test with DictOfNamedArrays

dag = pt.make_dict_of_named_arrays({
"a_expr": a + 2
})

dagh = transfer_to_numpy(dag, actx)
assert dagh != dag
assert isinstance(dagh["a_expr"].expr.bindings["_in0"].data, np.ndarray)

daghd = transfer_from_numpy(dagh, actx)
assert isinstance(daghd["a_expr"].expr.bindings["_in0"].data, TaggableCLArray)

# }}}


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down

0 comments on commit dee0ca4

Please sign in to comment.