Skip to content

Commit 0bb4d73

Browse files
shwinadavebayer
authored andcommitted
cuda.parallel: Support structured types as algorithm inputs (NVIDIA#3218)
* Introduce gpu_struct decorator and typing * Enable `reduce` to accept arrays of structs as inputs * Add test for reducing arrays-of-struct * Update documentation * Use a numpy array rather than ctypes object * Change zeros -> empty for output array and temp storage * Add a TODO for typing GpuStruct * Documentation udpates * Remove test_reduce_struct_type from test_reduce.py * Revert to `to_cccl_value()` accepting ndarray + GpuStruct * Bump copyrights --------- Co-authored-by: Ashwin Srinath <shwina@users.noreply.github.com>
1 parent cae9d85 commit 0bb4d73

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

python/cuda_parallel/cuda/parallel/experimental/_cccl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
from numba import cuda, types
1313

14-
from ._utils.protocols import get_dtype, is_contiguous
14+
from ._utils.cai import get_dtype, is_contiguous
1515
from .iterators._iterators import IteratorBase
1616
from .typing import DeviceArrayLike, GpuStruct
1717

python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,7 @@ def __init__(
8989
raise ValueError("Error building reduce")
9090

9191
def __call__(
92-
self,
93-
temp_storage,
94-
d_in,
95-
d_out,
96-
num_items: int,
97-
h_init: np.ndarray | GpuStruct,
98-
stream=None,
92+
self, temp_storage, d_in, d_out, num_items: int, h_init: np.ndarray | GpuStruct
9993
):
10094
d_in_cccl = cccl.to_cccl_iter(d_in)
10195
if d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
@@ -110,7 +104,7 @@ def __call__(
110104
self._ctor_d_in_cccl_type_enum_name,
111105
cccl.type_enum_as_name(d_in_cccl.value_type.type.value),
112106
)
113-
_dtype_validation(self._ctor_d_out_dtype, protocols.get_dtype(d_out))
107+
_dtype_validation(self._ctor_d_out_dtype, cai.get_dtype(d_out))
114108
_dtype_validation(self._ctor_init_dtype, h_init.dtype)
115109
stream_handle = protocols.validate_and_get_stream(stream)
116110
bindings = get_bindings()
@@ -132,7 +126,7 @@ def __call__(
132126
ctypes.c_ulonglong(num_items),
133127
self.op_wrapper.handle(),
134128
cccl.to_cccl_value(h_init),
135-
stream_handle,
129+
None,
136130
)
137131
if error != enums.CUDA_SUCCESS:
138132
raise ValueError("Error reducing")

0 commit comments

Comments
 (0)