Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release GIL on GPU calls #528

Merged
merged 1 commit into from
Jan 25, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions implicit/gpu/_cuda.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ cdef class KnnQuery(object):
x = indices
y = distances

self.c_knn.topk(dereference(items.c_matrix), dereference(queries), k,
&x[0, 0], &y[0, 0], c_item_norms, c_query_filter, c_item_filter)
with nogil:
self.c_knn.topk(dereference(items.c_matrix), dereference(queries), k,
&x[0, 0], &y[0, 0], c_item_norms, c_query_filter, c_item_filter)

return indices, distances

Expand Down Expand Up @@ -217,20 +218,22 @@ cdef class LeastSquaresSolver(object):
self.c_solver = new CppLeastSquaresSolver()

def least_squares(self, CSRMatrix cui, Matrix X, Matrix YtY, Matrix Y, int cg_steps):
    """Run the C++ solver's ``least_squares`` routine with the GIL released.

    ``cui``, ``YtY`` and ``Y`` are handed to the solver as dereferenced
    ``c_matrix`` objects, while ``X`` is passed as its raw ``c_matrix``
    pointer, and ``cg_steps`` is forwarded unchanged.

    NOTE(review): X is presumably the output that the solver updates in
    place -- confirm against CppLeastSquaresSolver.

    Returns None.
    """
    # The native call is made exactly once, and only inside ``nogil``, so
    # other Python threads can keep running while it executes.
    with nogil:
        self.c_solver.least_squares(dereference(cui.c_matrix), X.c_matrix,
                                    dereference(YtY.c_matrix), dereference(Y.c_matrix),
                                    cg_steps)
def calculate_loss(self, CSRMatrix cui, Matrix X, Matrix Y,
                   float regularization):
    """Return the value computed by the C++ solver's ``calculate_loss``.

    ``cui``, ``X`` and ``Y`` are passed as dereferenced ``c_matrix`` objects
    together with ``regularization``. The native call runs with the GIL
    released; its result is stored in a C ``float`` and returned to Python
    once the GIL has been re-acquired.
    """
    # A C-typed local is required so the assignment inside ``nogil`` does
    # not create a Python object.
    cdef float loss
    with nogil:
        loss = self.c_solver.calculate_loss(dereference(cui.c_matrix), dereference(X.c_matrix),
                                            dereference(Y.c_matrix), regularization)
    return loss

def calculate_yty(self, Matrix Y, Matrix YtY, float regularization):
    """Run the C++ solver's ``calculate_yty`` routine with the GIL released.

    ``Y`` is passed as a dereferenced ``c_matrix``, ``YtY`` as its raw
    ``c_matrix`` pointer, and ``regularization`` is forwarded unchanged.

    If ``YtY`` is None, a ``Matrix(None)`` is substituted before the call.
    NOTE(review): that substitute is local to this method and nothing is
    returned, so a caller passing None cannot observe the result -- confirm
    whether None should be rejected instead.

    Returns None.
    """
    # Must happen while the GIL is still held: constructing a Python object
    # is not allowed inside ``nogil``, and ``YtY.c_matrix`` must not be read
    # from None there.
    if YtY is None:
        YtY = Matrix(None)

    with nogil:
        self.c_solver.calculate_yty(dereference(Y.c_matrix), YtY.c_matrix, regularization)

def __dealloc__(self):
del self.c_solver
Expand All @@ -249,9 +252,10 @@ def get_device_count():
def bpr_update(IntVector userids, IntVector itemids, IntVector indptr,
               Matrix X, Matrix Y,
               float learning_rate, float regularization, long seed, bool verify_negative):
    """Call ``cpp_bpr_update`` with the GIL released and return its result.

    ``userids``, ``itemids`` and ``indptr`` are passed as dereferenced
    ``c_vector`` objects; ``X`` and ``Y`` are passed as their raw
    ``c_matrix`` pointers; the remaining scalar arguments are forwarded
    unchanged.

    Returns a 2-tuple ``(ret.first, ret.second)`` built from the pair-like
    value that ``cpp_bpr_update`` returns.
    NOTE(review): the meaning of the two members is defined by the C++
    implementation -- confirm there before documenting them further.
    """
    # The native call is made exactly once, inside ``nogil``; making it a
    # second time would apply the update step twice.
    with nogil:
        ret = cpp_bpr_update(dereference(userids.c_vector),
                             dereference(itemids.c_vector),
                             dereference(indptr.c_vector),
                             X.c_matrix, Y.c_matrix,
                             learning_rate, regularization, seed, verify_negative)
    # Converting the members to Python objects needs the GIL, so the tuple is
    # built after the ``nogil`` block has ended.
    return ret.first, ret.second