Skip to content

Commit

Permalink
Fix hanging in tril
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy committed Jun 9, 2023
1 parent c286b52 commit 40be0e3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
4 changes: 4 additions & 0 deletions dpnp/backend/extensions/vm/div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ static sycl::event div_impl(sycl::queue exec_q,
{
type_utils::validate_type_for_device<T>(exec_q);

std::cout << typeid(T).name() << std::endl;

const T* a = reinterpret_cast<const T*>(in_a);
const T* b = reinterpret_cast<const T*>(in_b);
T* y = reinterpret_cast<T*>(out_y);
Expand Down Expand Up @@ -169,12 +171,14 @@ std::pair<sycl::event, sycl::event> div(sycl::queue exec_q,
throw py::value_error("Input and outpur arrays must be C-contiguous");
}

std::cout << "dst_typeid = " << int(dst_typeid) << std::endl;
auto div_fn = div_dispatch_vector[dst_typeid];
if (div_fn == nullptr)
{
throw py::value_error("No div implementation defined");
}
sycl::event sum_ev = div_fn(exec_q, src_nelems, src1_data, src2_data, dst_data, depends);
std::cout << "leaving div_fn" << std::endl;

sycl::event ht_ev = dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, {sum_ev});
return std::make_pair(ht_ev, sum_ev);
Expand Down
12 changes: 12 additions & 0 deletions dpnp/dpnp_algo/dpnp_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@ def dpnp_divide(x1, x2, out=None, order='K'):
def _call_divide(src1, src2, dst, sycl_queue, depends=[]):
"""A callback to register in BinaryElementwiseFunc class of dpctl.tensor"""

print("_call_divide", sycl_queue)
print("src1 =", src1, type(src1), src1.sycl_queue, src1.device, src1.usm_type, src1.ndim, src1.dtype, src1.shape)
print(src1.__sycl_usm_array_interface__)
print(src1._byte_bounds)
print("src2 =", src2, type(src2), src2.sycl_queue, src2.device, src2.usm_type, src2.ndim, src2.dtype, src2.shape)
print(src2.__sycl_usm_array_interface__)
print(src2._byte_bounds)
print("dst =", dst, type(dst), dst.sycl_queue, dst.device, dst.usm_type, dst.ndim, dst.dtype, dst.shape)
print(dst.__sycl_usm_array_interface__)
print(dst._byte_bounds)

if vmi._can_call_div(sycl_queue, src1, src2, dst):
# call pybind11 extension for div() function from OneMKL VM
return vmi._div(sycl_queue, src1, src2, dst, depends)
Expand All @@ -86,5 +97,6 @@ def _call_divide(src1, src2, dst, sycl_queue, depends=[]):
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

func = BinaryElementwiseFunc("divide", ti._divide_result_type, _call_divide, _divide_docstring_)
print("func is done")
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
return dpnp_array._create_from_usm_ndarray(res_usm)
4 changes: 4 additions & 0 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,10 @@ def __sub__(self, other):
# '__subclasshook__',

def __truediv__(self, other):
    """Return ``self / other``.

    The division is delegated to ``dpnp.true_divide`` and its result is
    returned unchanged.

    Parameters
    ----------
    other :
        Right-hand operand of the ``/`` operator; passed through to
        ``dpnp.true_divide`` as-is.

    Returns
    -------
    Whatever ``dpnp.true_divide(self, other)`` returns.
    """
    # Intentionally free of side effects: an arithmetic operator must not
    # write to stdout or stringify its operands on every call.
    return dpnp.true_divide(self, other)

def __xor__(self, other):
Expand Down

0 comments on commit 40be0e3

Please sign in to comment.