Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modularized the examples/pybind11/onemkl_gemv/sycl_timing_solver.py #838

Merged
merged 1 commit into from
May 17, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 59 additions & 34 deletions examples/pybind11/onemkl_gemv/sycl_timing_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,41 +63,66 @@

# Timer shared by every timing loop in this script; it is used as a context
# manager around each solver call and read back through `timer.dt`.
# NOTE(review): time_scale=1e3 presumably makes `timer.dt` report
# milliseconds - confirm against the dpctl.SyclTimer documentation.
timer = dpctl.SyclTimer(time_scale=1e3)

# NOTE(review): this text is a rendered diff whose +/- markers and leading
# indentation were lost in extraction. The inline loop below looks like the
# pre-change form of `time_python_solver` - confirm before relying on it.
iters = []
for i in range(6):
# Time one call of the Python-driven solver on the queue of `api_dev`.
with timer(api_dev.sycl_queue):
x, conv_in = solve.cg_solve(A, b)

# `timer.dt` is printed under the label "(host_dt, device_dt)=".
print(i, "(host_dt, device_dt)=", timer.dt)
iters.append(conv_in)
# The solution must share USM type and SYCL queue with both inputs.
assert x.usm_type == A.usm_type
assert x.usm_type == b.usm_type
assert x.sycl_queue == A.sycl_queue
assert x.sycl_queue == b.sycl_queue

print("Converged in: ", iters)

# Residual check for the Python solution: gemv -> sub -> norm_squared_blocking.
# NOTE(review): presumably r = A @ x and delta = r - b, with `[ev]` making
# the subtraction depend on the gemv - confirm against sycl_gemm.
hev, ev = sycl_gemm.gemv(q, A, x, r)
hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev])
rs = sycl_gemm.norm_squared_blocking(q, delta)
# Wait on the first element of each returned pair before reporting.
dpctl.SyclEvent.wait_for([hev, hev2])
print(f"Python solution residual norm squared: {rs}")

def time_python_solver(num_iters=6):
    """
    Time the solver implemented in Python, which submits SYCL kernels
    asynchronously.

    Each of the `num_iters` runs calls `solve.cg_solve(A, b)` inside
    `timer(api_dev.sycl_queue)`, prints the run index together with
    `timer.dt`, and asserts that the solution has the same `usm_type`
    and `sycl_queue` as both `A` and `b`.

    The most recent solution is stored in the module-level name `x`.

    Returns a list holding, for every run in order, the second value
    returned by `solve.cg_solve`.
    """
    global x
    convergence_log = []
    for run_idx in range(num_iters):
        with timer(api_dev.sycl_queue):
            x, n_steps = solve.cg_solve(A, b)

        print(run_idx, "(host_dt, device_dt)=", timer.dt)
        convergence_log.append(n_steps)

        # Same checks, in the same order, as four separate asserts:
        # USM type against A then b, then queue against A then b.
        for operand in (A, b):
            assert x.usm_type == operand.usm_type
        for operand in (A, b):
            assert x.sycl_queue == operand.sycl_queue

    return convergence_log


def time_cpp_solver(num_iters=6):
    """
    Time the solver implemented in C++ and called from Python as
    `sycl_gemm.cpp_cg_solve`.

    The C++ implementation uses the same algorithm and submits the same
    kernels asynchronously, but bypasses the Python binding overhead
    incurred when the algorithm is driven from Python.

    The module-level name `x_cpp` is rebound to `dpt.empty_like(b)` and
    passed to every call as the last argument.

    Returns a list holding, for every run in order, the value returned
    by `sycl_gemm.cpp_cg_solve`.
    """
    global x_cpp
    x_cpp = dpt.empty_like(b)
    convergence_log = []
    for run_idx in range(num_iters):
        with timer(api_dev.sycl_queue):
            n_steps = sycl_gemm.cpp_cg_solve(q, A, b, x_cpp)

        print(run_idx, "(host_dt, device_dt)=", timer.dt)
        convergence_log.append(n_steps)

    return convergence_log


def compute_residual(x):
    """
    Return the quality of the solution `x`, `norm_squared(A@x - b)`.

    `x` must be a `dpt.usm_ndarray`. The work is submitted on the queue
    of `A`, and the module-level arrays `r` and `delta` are passed to
    the `sycl_gemm` calls as intermediate arguments.
    """
    assert isinstance(x, dpt.usm_ndarray)
    exec_q = A.sycl_queue

    gemv_host_ev, gemv_ev = sycl_gemm.gemv(exec_q, A, x, r)
    # The second event of the subtraction is not used afterwards.
    sub_host_ev, _ = sycl_gemm.sub(exec_q, r, b, delta, [gemv_ev])
    residual = sycl_gemm.norm_squared_blocking(exec_q, delta)

    # Wait on the first event of each pair before handing back the result.
    dpctl.SyclEvent.wait_for([gemv_host_ev, sub_host_ev])
    return residual


# Run the Python-driven solver timings, then report the residual of the
# solution left in the module-level `x`.
print("Converged in: ", time_python_solver())
print(f"Python solution residual norm squared: {compute_residual(x)}")

# The queue used for the direct sycl_gemm calls must equal the queue of
# `api_dev`, which the timer is bound to.
assert q == api_dev.sycl_queue
print("")

# NOTE(review): this text is a rendered diff whose +/- markers and leading
# indentation were lost in extraction. The inline loop and residual check
# below look like the pre-change form of `time_cpp_solver` and
# `compute_residual` - confirm before relying on them.
x_cpp = dpt.empty_like(b)
iters = []
for i in range(6):
# Time one call of the C++ solver; `x_cpp` is passed as the last argument.
with timer(api_dev.sycl_queue):
conv_in = sycl_gemm.cpp_cg_solve(q, A, b, x_cpp)

print(i, "(host_dt, device_dt)=", timer.dt)
iters.append(conv_in)

print("Converged in: ", iters)
# Same gemv -> sub -> norm_squared_blocking chain as compute_residual,
# applied to `x_cpp`.
hev, ev = sycl_gemm.gemv(q, A, x_cpp, r)
hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev])
rs = sycl_gemm.norm_squared_blocking(q, delta)
dpctl.SyclEvent.wait_for([hev, hev2])
print(f"cpp_cg_solve solution residual norm squared: {rs}")
# Function-based form: time the C++ solver and report the residual of `x_cpp`.
print("Converged in: ", time_cpp_solver())
print(f"cpp_cg_solve solution residual norm squared: {compute_residual(x_cpp)}")