
Queue inference for numba_dpex kernel args returns incorrect values #1021

@diptorupd

The following code should raise a `ComputeFollowsDataInferenceError` because `a` and `b` are allocated on different SYCL queues (`q1` and `q2`). However, no such exception is raised:

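```python
import dpctl
import dpnp

import numba_dpex as dpex
from numba_dpex.core.exceptions import ComputeFollowsDataInferenceError


@dpex.kernel
def sum_kernel(a, b, c):
    i = dpex.get_global_id(0)
    c[i] = a[i] + b[i]


# Two separately constructed queues; they should not compare equal.
q1 = dpctl.SyclQueue()
q2 = dpctl.SyclQueue()
print(q1 == q2)

# a and c are allocated on q1, b is allocated on q2.
a = dpnp.ones(1, sycl_queue=q1)
b = dpnp.ones(1, sycl_queue=q2)
c = dpnp.empty_like(a)

# Expected: compute-follows-data inference fails because the arguments
# live on different queues. Observed: the kernel launches without error.
try:
    sum_kernel[dpex.Range(1)](a, b, c)
except ComputeFollowsDataInferenceError:
    print("ComputeFollowsDataInferenceError raised (expected)")
else:
    print("No exception raised (bug)")
```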

Inside the dispatcher, when the types of the input arguments are inferred, I see that all the arguments are inferred as `usm_ndarray` types with the same queue, even though `b` was allocated on `q2`.
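For reference, the check this issue expects can be sketched in plain Python. This is a minimal sketch under stated assumptions, not numba_dpex's dispatcher code: `infer_execution_queue`, `FakeQueue` and `FakeArray` are hypothetical names, and the exception class is a local stand-in for `numba_dpex.core.exceptions.ComputeFollowsDataInferenceError`. It relies only on array arguments exposing a `sycl_queue` attribute (as `dpnp` arrays do) and compares the arguments' actual queue objects rather than their inferred types.

```python
class ComputeFollowsDataInferenceError(Exception):
    """Hypothetical local stand-in for numba_dpex's exception of the same name."""


def infer_execution_queue(*args):
    """Return the single queue shared by all array arguments (hypothetical helper).

    Arguments without a ``sycl_queue`` attribute (e.g. scalars) are ignored.
    Returns None if no argument carries a queue.
    """
    execution_queue = None
    first_pos = None
    for pos, arg in enumerate(args):
        queue = getattr(arg, "sycl_queue", None)
        if queue is None:
            continue  # scalars and other non-USM arguments do not constrain the queue
        if execution_queue is None:
            execution_queue, first_pos = queue, pos
        elif queue != execution_queue:
            # Compute follows data: mixed queues leave no unambiguous execution queue.
            raise ComputeFollowsDataInferenceError(
                f"argument {pos} is on a different queue than argument {first_pos}"
            )
    return execution_queue


class FakeQueue:
    """Hypothetical queue object; distinct instances compare unequal (identity equality)."""


class FakeArray:
    """Hypothetical array object that only records the queue it was allocated on."""

    def __init__(self, queue):
        self.sycl_queue = queue
```

With `a` and `c` on one queue and `b` on another, `infer_execution_queue(a, c)` returns the shared queue, while `infer_execution_queue(a, b, c)` raises because argument 1 disagrees with argument 0, which mirrors the behaviour the reproducer above expects from `sum_kernel`.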
