Skip to content

Commit 33b6a96

Browse files
committed
Switch to using gpu_struct decorator
1 parent cf6f679 commit 33b6a96

File tree

6 files changed

+130
-105
lines changed

6 files changed

+130
-105
lines changed

python/cuda_parallel/cuda/parallel/experimental/_cccl.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,4 +215,9 @@ def to_cccl_iter(array_or_iterator) -> Iterator:
215215

216216
def host_array_to_value(array: np.ndarray) -> Value:
217217
info = _numpy_type_to_info(array.dtype)
218-
return Value(info, array.ctypes.data)
218+
if isinstance(array, np.ndarray):
219+
data = array.ctypes.data
220+
else:
221+
# it's a gpudataclass:
222+
data = ctypes.cast(ctypes.pointer(array._data), ctypes.c_void_p)
223+
return Value(info, data)

python/cuda_parallel/cuda/parallel/experimental/_structwrapper.py

Lines changed: 0 additions & 86 deletions
This file was deleted.

python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,14 @@
1616
from .. import _cccl as cccl
1717
from .._bindings import get_bindings, get_paths
1818
from .._caching import CachableFunction, cache_with_key
19-
from .._structwrapper import wrap_struct
2019
from .._utils import cai
2120
from ..iterators._iterators import IteratorBase
22-
from ..typing import DeviceArrayLike
21+
from ..typing import DeviceArrayLike, GpuStruct
2322

2423

2524
class _Op:
26-
def __init__(self, dtype: np.dtype, op: Callable):
27-
# if h_init is a struct, wrap it in a Record type:
28-
if dtype.names is not None:
29-
value_type = wrap_struct(dtype)
30-
else:
31-
value_type = numba.from_dtype(dtype)
25+
def __init__(self, h_init: np.ndarray | GpuStruct, op: Callable):
26+
value_type = numba.typeof(h_init)
3227
self.ltoir, _ = cuda.compile(op, sig=(value_type, value_type), output="ltoir")
3328
self.name = op.__name__.encode("utf-8")
3429

@@ -56,7 +51,7 @@ def __init__(
5651
d_in: DeviceArrayLike | IteratorBase,
5752
d_out: DeviceArrayLike,
5853
op: Callable,
59-
h_init: np.ndarray,
54+
h_init: np.ndarray | GpuStruct,
6055
):
6156
d_in_cccl = cccl.to_cccl_iter(d_in)
6257
self._ctor_d_in_cccl_type_enum_name = cccl.type_enum_as_name(
@@ -67,11 +62,10 @@ def __init__(
6762
cc_major, cc_minor = cuda.get_current_device().compute_capability
6863
cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
6964
bindings = get_bindings()
70-
self.op_wrapper = _Op(h_init.dtype, op)
65+
self.op_wrapper = _Op(h_init, op)
7166
d_out_cccl = cccl.to_cccl_iter(d_out)
7267
self.build_result = cccl.DeviceReduceBuildResult()
7368

74-
# TODO Figure out caching
7569
error = bindings.cccl_device_reduce_build(
7670
ctypes.byref(self.build_result),
7771
d_in_cccl,
@@ -88,7 +82,9 @@ def __init__(
8882
if error != enums.CUDA_SUCCESS:
8983
raise ValueError("Error building reduce")
9084

91-
def __call__(self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray):
85+
def __call__(
86+
self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray | GpuStruct
87+
):
9288
d_in_cccl = cccl.to_cccl_iter(d_in)
9389
if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
9490
assert num_items is not None

python/cuda_parallel/cuda/parallel/experimental/gpu_struct.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from dataclasses import dataclass
2+
from dataclasses import fields as dataclass_fields
3+
4+
import numba
5+
import numpy as np
6+
from numba.core import cgutils
7+
from numba.core.extending import (
8+
make_attribute_wrapper,
9+
models,
10+
register_model,
11+
typeof_impl,
12+
)
13+
from numba.core.typing import signature as nb_signature
14+
from numba.core.typing.templates import AttributeTemplate, ConcreteTemplate
15+
from numba.cuda.cudadecl import registry as cuda_registry
16+
from numba.extending import lower_builtin
17+
18+
from .typing import GpuStruct
19+
20+
21+
def gpu_struct(this: type) -> GpuStruct:
    """Class decorator turning an annotated class into a numba-CUDA-aware struct.

    The decorated class is converted into a ``dataclass`` and additionally:

    * gets a class-level ``dtype`` attribute: a numpy structured dtype built
      from the class annotations, in declaration order;
    * gets a ``__post_init__`` that stores a ctypes object holding the field
      values in ``self._data``;
    * is registered with numba: a dedicated numba type, its struct data
      model, attribute typing and lowering for each field, and typing and
      lowering for calling the class as a constructor in the CUDA registry.

    Args:
        this: The class to decorate. Each annotation must be something that
            both ``np.dtype`` and ``numba.from_dtype`` accept (e.g.
            ``np.int32``).

    Returns:
        The same class, now a dataclass with the registrations above applied.
    """
    anns = getattr(this, "__annotations__", {})

    # Set the .dtype attribute on the class for numpy compatibility.
    # NOTE(review): this dtype is packed (no ``align=True``) while the data
    # model below is a plain StructModel; for fields of mixed sizes the two
    # layouts may disagree -- confirm whether ``align=True`` is needed here.
    setattr(this, "dtype", np.dtype(list(anns.items())))

    # Define __post_init__ to create a ctypes object from the fields,
    # and keep a reference to it in the `._data` attribute.
    def __post_init__(self):
        ctypes_typ = np.ctypeslib.as_ctypes_type(this.dtype)
        self._data = ctypes_typ(*(getattr(self, name) for name in this.dtype.names))

    setattr(this, "__post_init__", __post_init__)

    # Create a dataclass:
    this = dataclass(this)
    fields = dataclass_fields(this)

    # (field name, numba type) pairs, computed once and shared by the data
    # model, the attribute typing and the constructor typing/lowering.
    members = [(field.name, numba.from_dtype(field.type)) for field in fields]
    member_types = [member_type for _, member_type in members]

    # Define a numba type corresponding to the dataclass:
    class ThisType(numba.types.Type):
        def __init__(self):
            super().__init__(name=this.__name__)

    this_type = ThisType()

    @typeof_impl.register(this)
    def typeof_this(val, c):
        return this_type

    # Data model corresponding to ThisType:
    @register_model(ThisType)
    class ThisModel(models.StructModel):
        def __init__(self, dmm, fe_type):
            super().__init__(dmm, fe_type, members)

    # Typing for accessing attributes (fields) of the dataclass:
    class ThisAttrsTemplate(AttributeTemplate):
        pass

    def make_resolver(member_type):
        # A factory is required so that each resolver captures its own
        # field type. Defining the resolver directly in the loop body would
        # close over the loop variable, and every resolver would then report
        # the type of the *last* field.
        def resolver(self, value):
            return member_type

        return resolver

    for name, member_type in members:
        setattr(ThisAttrsTemplate, f"resolve_{name}", make_resolver(member_type))

    @cuda_registry.register_attr
    class ThisAttrs(ThisAttrsTemplate):
        key = this_type

    # Lowering for attribute access:
    for name, _ in members:
        make_attribute_wrapper(ThisType, name, name)

    # Register typing for constructor.
    @cuda_registry.register
    class TypeConstructor(ConcreteTemplate):
        key = this
        cases = [nb_signature(this_type, *member_types)]

    cuda_registry.register_global(this, numba.types.Function(TypeConstructor))

    # Lowering for the constructor: fill a struct proxy field by field from
    # the positional arguments and return the resulting value.
    def type_constructor(context, builder, sig, args):
        ty = sig.return_type
        retval = cgutils.create_struct_proxy(ty)(context, builder)
        for (name, _), val in zip(members, args):
            setattr(retval, name, val)
        return retval._getvalue()

    lower_builtin(this, *member_types)(type_constructor)

    return this

python/cuda_parallel/cuda/parallel/experimental/typing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any
2+
13
from typing_extensions import (
24
Protocol,
35
) # TODO: typing_extensions required for Python 3.7 docs env
@@ -10,3 +12,7 @@ class DeviceArrayLike(Protocol):
1012
"""
1113

1214
__cuda_array_interface__: dict
15+
16+
17+
# return type of @gpu_struct
18+
GpuStruct = Any

python/cuda_parallel/tests/test_reduce.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import cuda.parallel.experimental.algorithms as algorithms
1414
import cuda.parallel.experimental.iterators as iterators
15+
from cuda.parallel.experimental.gpu_struct import gpu_struct
1516

1617

1718
def random_int(shape, dtype):
@@ -553,15 +554,19 @@ def binary_op(x, y):
553554

554555

555556
def test_reduce_struct_type():
556-
def max_g_value(x, y):
557-
return x if x["g"] > y["g"] else y
557+
@gpu_struct
558+
class Pixel:
559+
r: np.int32
560+
g: np.int32
561+
b: np.int32
558562

559-
dtype = np.dtype([("r", "int32"), ("g", "int32"), ("b", "int32")])
560-
d_rgb = cp.random.randint(0, 256, (10, 3), dtype=cp.int32).view(dtype)
563+
def max_g_value(x, y):
564+
return x if x.g > y.g else y
561565

562-
d_out = cp.zeros(1, dtype)
566+
d_rgb = cp.random.randint(0, 256, (10, 3), dtype=np.int32).view(Pixel.dtype)
567+
d_out = cp.zeros(1, Pixel.dtype)
563568

564-
h_init = np.asarray([(0, 0, 0)], dtype=dtype)
569+
h_init = Pixel(0, 0, 0)
565570

566571
reduce_into = algorithms.reduce_into(d_rgb, d_out, max_g_value, h_init)
567572
temp_storage_bytes = reduce_into(None, d_rgb, d_out, len(d_rgb), h_init)

0 commit comments

Comments
 (0)