Skip to content

Commit

Permalink
Merge pull request #833 from IntelPython/refine-gemv-example
Browse files Browse the repository at this point in the history
Refine gemv example
  • Loading branch information
oleksandr-pavlyk authored May 13, 2022
2 parents b196e73 + 50a3243 commit b5d361f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 24 deletions.
30 changes: 13 additions & 17 deletions examples/pybind11/onemkl_gemv/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@
import dpctl.tensor as dpt


def empty_like(A):
    """Allocate an uninitialized array matching ``A``.

    The result has the same shape and data type as ``A``, is placed on
    the same device, and uses the same USM allocation type.

    Args:
        A: array exposing ``shape``, ``dtype``, ``device`` and
            ``usm_type`` attributes (e.g. a ``dpctl.tensor.usm_ndarray``).

    Returns:
        A new array created by ``dpt.empty``; its contents are
        uninitialized.
    """
    # Forward usm_type explicitly: otherwise dpt.empty falls back to its
    # default allocation type and the result may not match ``A``.
    return dpt.empty(
        A.shape, dtype=A.dtype, device=A.device, usm_type=A.usm_type
    )


def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]):
"""Chebyshev iterative solver using SYCL routines"""
d = (lMax + lMin) / 2
Expand All @@ -33,9 +29,9 @@ def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]):
x = dpt.copy(x0)
exec_queue = A.sycl_queue
assert exec_queue == x.sycl_queue
Ax = empty_like(A[:, 0])
r = empty_like(Ax)
p = empty_like(Ax)
Ax = dpt.empty_like(A[:, 0])
r = dpt.empty_like(Ax)
p = dpt.empty_like(Ax)

e_x = dpctl.SyclEvent()
# Ax = A @ x
Expand Down Expand Up @@ -131,12 +127,13 @@ def cg_solve(A, b):
converged is False if the solver has not converged; otherwise it is the iteration number
"""
exec_queue = A.sycl_queue
x = dpt.zeros(b.shape, dtype=b.dtype)
Ap = empty_like(x)
x = dpt.zeros_like(b)
Ap = dpt.empty_like(x)

all_host_tasks = []
r = dpt.copy(b)
p = dpt.copy(b)
r = dpt.copy(b) # synchronous copy
p = dpt.copy(b) # synchronous copy

rsold = sycl_gemm.norm_squared_blocking(exec_queue, r)
if rsold < 1e-20:
return (b, 0)
Expand All @@ -147,22 +144,21 @@ def cg_solve(A, b):
e_x = dpctl.SyclEvent()
for i in range(max_iters):
# Ap = A @ p
he_dot, e_dot = sycl_gemm.gemv(exec_queue, A, p, Ap, depends=[e_p])
all_host_tasks.append(he_dot)
he_gemv, e_gemv = sycl_gemm.gemv(exec_queue, A, p, Ap, depends=[e_p])
all_host_tasks.append(he_gemv)
# alpha = rsold / dot(p, Ap)
alpha = rsold / sycl_gemm.dot_blocking(
exec_queue, p, Ap, depends=[e_dot]
exec_queue, p, Ap, depends=[e_p, e_gemv]
)
# x = x + alpha * p
he1_x_update, e1_x_update = sycl_gemm.axpby_inplace(
exec_queue, alpha, p, 1, x, depends=[e_p, e_x]
exec_queue, alpha, p, 1, x, depends=[e_x]
)
all_host_tasks.append(he1_x_update)
e_x = e1_x_update

# r = r - alpha * Ap
he2_r_update, e2_r_update = sycl_gemm.axpby_inplace(
exec_queue, -alpha, Ap, 1, r, depends=[e_p]
exec_queue, -alpha, Ap, 1, r
)
all_host_tasks.append(he2_r_update)

Expand Down
4 changes: 4 additions & 0 deletions examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ py::object py_dot_blocking(sycl::queue q,
reinterpret_cast<const T *>(v2_typeless_ptr), 1, res_usm, depends);
T res_v{};
q.copy<T>(res_usm, &res_v, 1, {dot_ev}).wait_and_throw();
sycl::free(res_usm, q);
res = py::float_(res_v);
}
else if (v1_typenum == UAR_FLOAT) {
Expand All @@ -507,6 +508,7 @@ py::object py_dot_blocking(sycl::queue q,
reinterpret_cast<const T *>(v2_typeless_ptr), 1, res_usm, depends);
T res_v(0);
q.copy<T>(res_usm, &res_v, 1, {dot_ev}).wait_and_throw();
sycl::free(res_usm, q);
res = py::float_(res_v);
}
else if (v1_typenum == UAR_CDOUBLE) {
Expand All @@ -517,6 +519,7 @@ py::object py_dot_blocking(sycl::queue q,
reinterpret_cast<const T *>(v2_typeless_ptr), 1, res_usm, depends);
T res_v{};
q.copy<T>(res_usm, &res_v, 1, {dotc_ev}).wait_and_throw();
sycl::free(res_usm, q);
res = py::cast(res_v);
}
else if (v1_typenum == UAR_CFLOAT) {
Expand All @@ -527,6 +530,7 @@ py::object py_dot_blocking(sycl::queue q,
reinterpret_cast<const T *>(v2_typeless_ptr), 1, res_usm, depends);
T res_v{};
q.copy<T>(res_usm, &res_v, 1, {dotc_ev}).wait_and_throw();
sycl::free(res_usm, q);
res = py::cast(res_v);
}
else {
Expand Down
9 changes: 4 additions & 5 deletions examples/pybind11/onemkl_gemv/sycl_gemm/cg_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,16 @@ int cg_solve(sycl::queue exec_q,
}

int converged_at = max_iters;
sycl::event prev_dep = copy_to_p_ev;
sycl::event e_p = copy_to_p_ev;
sycl::event e_x = fill_ev;

for (std::int64_t i = 0; i < max_iters; ++i) {
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv(
exec_q, oneapi::mkl::transpose::N, n, n, T(1), Amat, n, p, 1, T(0),
Ap, 1, {prev_dep});
Ap, 1, {e_p});

sycl::event pAp_dot_ev = oneapi::mkl::blas::row_major::dot(
exec_q, n, p, 1, Ap, 1, pAp_dot_dev, {prev_dep, gemv_ev});
exec_q, n, p, 1, Ap, 1, pAp_dot_dev, {e_p, gemv_ev});

T pAp_dot_host{};
exec_q.copy<T>(pAp_dot_dev, &pAp_dot_host, 1, {pAp_dot_ev})
Expand All @@ -212,8 +212,7 @@ int cg_solve(sycl::queue exec_q,
T beta = rs_new / rs_old;

// p = r + beta * p
prev_dep =
detail::axpby_inplace(exec_q, n, T(1), r, beta, p, {r_update_ev});
e_p = detail::axpby_inplace(exec_q, n, T(1), r, beta, p, {r_update_ev});
e_x = x_update_ev;

rs_old = rs_new;
Expand Down
15 changes: 13 additions & 2 deletions examples/pybind11/onemkl_gemv/sycl_timing_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@
A = dpt.asarray(Anp, "d", device=api_dev)
b = dpt.asarray(bnp, "d", device=api_dev)

assert A.sycl_queue == b.sycl_queue

# allocate buffers for computation of residual
r = dpt.empty_like(b)
delta = dpt.empty_like(b)

timer = dpctl.SyclTimer(time_scale=1e3)

iters = []
Expand All @@ -64,17 +70,22 @@

print(i, "(host_dt, device_dt)=", timer.dt)
iters.append(conv_in)
assert x.usm_type == A.usm_type
assert x.usm_type == b.usm_type
assert x.sycl_queue == A.sycl_queue
assert x.sycl_queue == b.sycl_queue

print("Converged in: ", iters)

r = dpt.empty_like(b)
hev, ev = sycl_gemm.gemv(q, A, x, r)
delta = dpt.empty_like(b)
hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev])
rs = sycl_gemm.norm_squared_blocking(q, delta)
dpctl.SyclEvent.wait_for([hev, hev2])
print(f"Python solution residual norm squared: {rs}")

assert q == api_dev.sycl_queue
print("")

x_cpp = dpt.empty_like(b)
iters = []
for i in range(6):
Expand Down

0 comments on commit b5d361f

Please sign in to comment.