Skip to content

Commit

Permalink
feat: add CUDA kernels (need to be fixed)
Browse files Browse the repository at this point in the history
  • Loading branch information
ManasviGoyal committed Jan 30, 2024
1 parent e04b01e commit e591460
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

// BEGIN PYTHON
// def f(grid, block, args):
// (numnull, mask, length, validwhen, invocation_index, err_code) = args
// scan_in_array = cupy.empty(length, dtype=cupy.int64)
// cuda_kernel_templates.get_function(fetch_specialization(['awkward_ByteMaskedArray_numnull_a', numnull.dtype, mask.dtype]))(grid, block, (numnull, mask, length, validwhen, scan_in_array, invocation_index, err_code))
// scan_in_array = inclusive_scan(grid, block, (scan_in_array, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(['awkward_ByteMaskedArray_numnull_b', numnull.dtype, mask.dtype]))(grid, block, (numnull, mask, length, validwhen, scan_in_array, invocation_index, err_code))
// out["awkward_ByteMaskedArray_numnull_a", {dtype_specializations}] = None
// out["awkward_ByteMaskedArray_numnull_b", {dtype_specializations}] = None
// END PYTHON

// Phase "a" of the ByteMaskedArray null count.
//
// For every position i < length this kernel stores a 0/1 flag in
// scan_in_array[i]: 1 when the mask byte, taken as a boolean
// (mask[i] != 0), differs from validwhen, and 0 otherwise. The host-side
// driver in the BEGIN PYTHON header then runs an inclusive scan over
// scan_in_array before launching the "_b" kernel.
//
// Parameters:
//   numnull          output scalar; reset to 0 here
//   mask             input mask values, one per element
//   length           number of elements in mask / scan_in_array
//   validwhen        boolean value of a mask entry that marks a valid element
//   scan_in_array    output buffer of `length` flags
//   invocation_index not read by this kernel
//   err_code         err_code[0] must equal NO_ERROR for any work to happen
//
// Launch layout: one thread per element on a 1-D grid; threads whose global
// index is >= length do not touch scan_in_array.
template <typename T, typename C>
__global__ void
awkward_ByteMaskedArray_numnull_a(T* numnull,
                                  const C* mask,
                                  int64_t length,
                                  bool validwhen,
                                  int64_t* scan_in_array,
                                  uint64_t invocation_index,
                                  uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise
    // evaluated in 32-bit unsigned arithmetic and can wrap on large grids.
    int64_t thread_id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;

    // A single writer is enough, and the result is also initialised when
    // length == 0 (no element thread exists in that case).
    if (thread_id == 0) {
      *numnull = 0;
    }

    if (thread_id < length) {
      scan_in_array[thread_id] =
          ((mask[thread_id] != 0) != validwhen) ? 1 : 0;
    }
  }
}

// Phase "b" of the ByteMaskedArray null count.
//
// Stores the last element of scan_in_array in *numnull. The host-side driver
// in the BEGIN PYTHON header runs an inclusive scan over scan_in_array
// between the "_a" and "_b" kernels, so that last element is the total of
// the flags. For an empty input (length <= 0) there is no last element, so
// 0 is stored instead of reading scan_in_array[-1].
//
// Parameters:
//   numnull          output scalar receiving the count
//   mask             not read by this kernel
//   length           number of elements in scan_in_array
//   validwhen        not read by this kernel
//   scan_in_array    scanned flag buffer
//   invocation_index not read by this kernel
//   err_code         err_code[0] must equal NO_ERROR for any work to happen
//
// Every launched thread stores the same value, so the result does not
// depend on the launch configuration.
template <typename T, typename C>
__global__ void
awkward_ByteMaskedArray_numnull_b(T* numnull,
                                  const C* mask,
                                  int64_t length,
                                  bool validwhen,
                                  int64_t* scan_in_array,
                                  uint64_t invocation_index,
                                  uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Guard the empty case: scan_in_array[length - 1] would be out of bounds.
    *numnull = length > 0 ? scan_in_array[length - 1] : 0;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

// BEGIN PYTHON
// def f(grid, block, args):
// (toindex, length, invocation_index, err_code) = args
// scan_in_array = cupy.empty(length, dtype=cupy.int64)
// scan_in_array_n_non_null = cupy.empty(length, dtype=cupy.int64)
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_Index_nones_as_index_a", toindex.dtype]))(grid, block, (toindex, length, scan_in_array, scan_in_array_n_non_null, invocation_index, err_code))
// scan_in_array = inclusive_scan(grid, block, (scan_in_array, invocation_index, err_code))
// scan_in_array_n_non_null = inclusive_scan(grid, block, (scan_in_array_n_non_null, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_Index_nones_as_index_b", toindex.dtype]))(grid, block, (toindex, length, scan_in_array, scan_in_array_n_non_null, invocation_index, err_code))
// out["awkward_Index_nones_as_index_a", {dtype_specializations}] = None
// out["awkward_Index_nones_as_index_b", {dtype_specializations}] = None
// END PYTHON

// Phase "a" of nones_as_index.
//
// For every position i < length this kernel writes two complementary 0/1
// flags:
//   scan_in_array[i]            = 1 when toindex[i] != -1, else 0
//   scan_in_array_n_non_null[i] = 1 when toindex[i] == -1, else 0
// The host-side driver in the BEGIN PYTHON header runs an inclusive scan
// over each buffer before launching the "_b" kernel.
//
// Parameters:
//   toindex                  index buffer; only read here
//   length                   number of elements in toindex and both buffers
//   scan_in_array            output flags for entries that are not -1
//   scan_in_array_n_non_null output flags for entries that are -1
//   invocation_index         not read by this kernel
//   err_code                 err_code[0] must equal NO_ERROR for any work
//
// Launch layout: one thread per element on a 1-D grid; threads whose global
// index is >= length do nothing.
template <typename T>
__global__ void
awkward_Index_nones_as_index_a(T* toindex,
                               int64_t length,
                               int64_t* scan_in_array,
                               int64_t* scan_in_array_n_non_null,
                               uint64_t invocation_index,
                               uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise
    // evaluated in 32-bit unsigned arithmetic and can wrap on large grids.
    int64_t thread_id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;

    if (thread_id < length) {
      const int64_t is_none = (toindex[thread_id] != -1) ? 0 : 1;
      scan_in_array[thread_id] = 1 - is_none;
      scan_in_array_n_non_null[thread_id] = is_none;
    }
  }
}

// Phase "b" of nones_as_index.
//
// Replaces every toindex[i] equal to -1 with
//   scan_in_array[length - 1] + scan_in_array_n_non_null[i] - 1
// and leaves all other entries unchanged. The host-side driver in the
// BEGIN PYTHON header runs an inclusive scan over both buffers between the
// "_a" and "_b" kernels, so the first term is the number of entries that
// are not -1 and the second is the running count of -1 entries up to and
// including position i.
//
// Parameters:
//   toindex                  index buffer, updated in place
//   length                   number of elements in toindex and both buffers
//   scan_in_array            scanned flags for entries that are not -1
//   scan_in_array_n_non_null scanned flags for entries that are -1
//   invocation_index         not read by this kernel
//   err_code                 err_code[0] must equal NO_ERROR for any work
//
// Launch layout: one thread per element on a 1-D grid; each thread reads and
// writes only its own toindex slot.
template <typename T>
__global__ void
awkward_Index_nones_as_index_b(T* toindex,
                               int64_t length,
                               int64_t* scan_in_array,
                               int64_t* scan_in_array_n_non_null,
                               uint64_t invocation_index,
                               uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise
    // evaluated in 32-bit unsigned arithmetic and can wrap on large grids.
    int64_t thread_id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;

    if (thread_id < length) {
      if (toindex[thread_id] == -1) {
        // Read the total only inside the bounds check: thread_id < length
        // implies length > 0, so scan_in_array[length - 1] is a valid slot.
        // Reading it unconditionally would touch scan_in_array[-1] when
        // length == 0.
        int64_t n_non_null = scan_in_array[length - 1];
        toindex[thread_id] =
            n_non_null + scan_in_array_n_non_null[thread_id] - 1;
      }
    }
  }
}

// FIXME: known to fail when the input index is the single-element array [-1].
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

// BEGIN PYTHON
// def f(grid, block, args):
// (numnull, fromindex, lenindex, invocation_index, err_code) = args
// scan_in_array = cupy.empty(lenindex, dtype=cupy.int64)
// cuda_kernel_templates.get_function(fetch_specialization(['awkward_IndexedArray_numnull_a', numnull.dtype, fromindex.dtype]))(grid, block, (numnull, fromindex, lenindex, scan_in_array, invocation_index, err_code))
// scan_in_array = inclusive_scan(grid, block, (scan_in_array, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(['awkward_IndexedArray_numnull_b', numnull.dtype, fromindex.dtype]))(grid, block, (numnull, fromindex, lenindex, scan_in_array, invocation_index, err_code))
// out["awkward_IndexedArray_numnull_a", {dtype_specializations}] = None
// out["awkward_IndexedArray_numnull_b", {dtype_specializations}] = None
// END PYTHON

// Phase "a" of the IndexedArray null count.
//
// For every position i < lenindex this kernel stores a 0/1 flag in
// scan_in_array[i]: 1 when fromindex[i] is negative, 0 otherwise. The
// host-side driver in the BEGIN PYTHON header then runs an inclusive scan
// over scan_in_array before launching the "_b" kernel.
//
// Parameters:
//   numnull          not touched by this kernel
//   fromindex        input index values
//   lenindex         number of elements in fromindex / scan_in_array
//   scan_in_array    output buffer of `lenindex` flags
//   invocation_index not read by this kernel
//   err_code         err_code[0] must equal NO_ERROR for any work to happen
//
// Launch layout: one thread per element on a 1-D grid; threads whose global
// index is >= lenindex do nothing.
template <typename T, typename C>
__global__ void
awkward_IndexedArray_numnull_a(T* numnull,
                               const C* fromindex,
                               int64_t lenindex,
                               int64_t* scan_in_array,
                               uint64_t invocation_index,
                               uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise
    // evaluated in 32-bit unsigned arithmetic and can wrap on large grids.
    int64_t thread_id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;

    if (thread_id < lenindex) {
      scan_in_array[thread_id] = (fromindex[thread_id] < 0) ? 1 : 0;
    }
  }
}

// Phase "b" of the IndexedArray null count.
//
// Stores the last element of scan_in_array in *numnull. The host-side driver
// in the BEGIN PYTHON header runs an inclusive scan over scan_in_array
// between the "_a" and "_b" kernels, so that last element is the total of
// the flags. For an empty input (lenindex <= 0) there is no last element,
// so 0 is stored instead of reading scan_in_array[-1].
//
// Parameters:
//   numnull          output scalar receiving the count
//   fromindex        not read by this kernel
//   lenindex         number of elements in scan_in_array
//   scan_in_array    scanned flag buffer
//   invocation_index not read by this kernel
//   err_code         err_code[0] must equal NO_ERROR for any work to happen
//
// Every launched thread stores the same value, so the result does not
// depend on the launch configuration.
template <typename T, typename C>
__global__ void
awkward_IndexedArray_numnull_b(T* numnull,
                               const C* fromindex,
                               int64_t lenindex,
                               int64_t* scan_in_array,
                               uint64_t invocation_index,
                               uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Guard the empty case: scan_in_array[lenindex - 1] would be out of
    // bounds.
    *numnull = lenindex > 0 ? scan_in_array[lenindex - 1] : 0;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

// BEGIN PYTHON
// def f(grid, block, args):
// (numnull, tolength, fromindex, lenindex, invocation_index, err_code) = args
// scan_in_array = cupy.empty(lenindex, dtype=cupy.int64)
// cuda_kernel_templates.get_function(fetch_specialization(['awkward_IndexedArray_numnull_parents_a', numnull.dtype, tolength.dtype, fromindex.dtype]))(grid, block, (numnull, tolength, fromindex, lenindex, scan_in_array, invocation_index, err_code))
// scan_in_array = inclusive_scan(grid, block, (scan_in_array, invocation_index, err_code))
// cuda_kernel_templates.get_function(fetch_specialization(['awkward_IndexedArray_numnull_parents_b', numnull.dtype, tolength.dtype, fromindex.dtype]))(grid, block, (numnull, tolength, fromindex, lenindex, scan_in_array, invocation_index, err_code))
// out["awkward_IndexedArray_numnull_parents_a", {dtype_specializations}] = None
// out["awkward_IndexedArray_numnull_parents_b", {dtype_specializations}] = None
// END PYTHON

// Phase "a" of IndexedArray_numnull_parents.
//
// For every position i < lenindex this kernel stores a 0/1 flag in
// scan_in_array[i]: 1 when fromindex[i] is negative, 0 otherwise. The
// host-side driver in the BEGIN PYTHON header then runs an inclusive scan
// over scan_in_array before launching the "_b" kernel.
//
// Template parameters are bound in the order the host-side driver
// specialises them (numnull dtype, tolength dtype, fromindex dtype):
//   T  element type of numnull
//   C  element type of tolength
//   U  element type of fromindex
//
// Parameters:
//   numnull          not touched by this kernel
//   tolength         not touched by this kernel
//   fromindex        input index values
//   lenindex         number of elements in fromindex / scan_in_array
//   scan_in_array    output buffer of `lenindex` flags
//   invocation_index not read by this kernel
//   err_code         err_code[0] must equal NO_ERROR for any work to happen
//
// Launch layout: one thread per element on a 1-D grid; threads whose global
// index is >= lenindex do nothing.
template <typename T, typename C, typename U>
__global__ void
awkward_IndexedArray_numnull_parents_a(T* numnull,
                                       C* tolength,
                                       const U* fromindex,
                                       int64_t lenindex,
                                       int64_t* scan_in_array,
                                       uint64_t invocation_index,
                                       uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise
    // evaluated in 32-bit unsigned arithmetic and can wrap on large grids.
    int64_t thread_id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;

    if (thread_id < lenindex) {
      scan_in_array[thread_id] = (fromindex[thread_id] < 0) ? 1 : 0;
    }
  }
}

// Phase "b" of IndexedArray_numnull_parents.
//
// Writes two results:
//   numnull[i] = 1 when fromindex[i] is negative, 0 otherwise (i < lenindex)
//   *tolength  = last element of scan_in_array, or 0 when lenindex <= 0
// The host-side driver in the BEGIN PYTHON header runs an inclusive scan
// over scan_in_array between the "_a" and "_b" kernels, so that last element
// is the total number of negative entries.
//
// Template parameters are bound in the order the host-side driver
// specialises them (numnull dtype, tolength dtype, fromindex dtype):
//   T  element type of numnull
//   C  element type of tolength
//   U  element type of fromindex
//
// Parameters:
//   numnull          output buffer of `lenindex` per-element flags
//   tolength         output scalar receiving the total
//   fromindex        input index values
//   lenindex         number of elements in fromindex / numnull / scan_in_array
//   scan_in_array    scanned flag buffer
//   invocation_index not read by this kernel
//   err_code         err_code[0] must equal NO_ERROR for any work to happen
//
// Launch layout: one thread per element on a 1-D grid. Every launched thread
// stores the same value to *tolength.
template <typename T, typename C, typename U>
__global__ void
awkward_IndexedArray_numnull_parents_b(T* numnull,
                                       C* tolength,
                                       const U* fromindex,
                                       int64_t lenindex,
                                       int64_t* scan_in_array,
                                       uint64_t invocation_index,
                                       uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise
    // evaluated in 32-bit unsigned arithmetic and can wrap on large grids.
    int64_t thread_id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;

    if (thread_id < lenindex) {
      numnull[thread_id] = (fromindex[thread_id] < 0) ? 1 : 0;
    }

    // Guard the empty case: scan_in_array[lenindex - 1] would be out of
    // bounds.
    *tolength = lenindex > 0 ? scan_in_array[lenindex - 1] : 0;
  }
}

// FIXME: known to fail when the input index is the single-element array [-1].

0 comments on commit e591460

Please sign in to comment.