Skip to content

Commit

Permalink
Release GIL on GPU calls (#528)
Browse files Browse the repository at this point in the history
Previous code was holding on the GIL when training models on the GPU
or when calculating results. This caused some jank on progress bars
displaying in jupyter notebooks, as well as just being poor form.
Fix by releasing the GIL before starting on gpu code that will take
an appreciable amount of time.
  • Loading branch information
benfred authored Jan 25, 2022
1 parent eccb1a9 commit af44c79
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions implicit/gpu/_cuda.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ cdef class KnnQuery(object):
x = indices
y = distances

self.c_knn.topk(dereference(items.c_matrix), dereference(queries), k,
&x[0, 0], &y[0, 0], c_item_norms, c_query_filter, c_item_filter)
with nogil:
self.c_knn.topk(dereference(items.c_matrix), dereference(queries), k,
&x[0, 0], &y[0, 0], c_item_norms, c_query_filter, c_item_filter)

return indices, distances

Expand Down Expand Up @@ -217,20 +218,22 @@ cdef class LeastSquaresSolver(object):
self.c_solver = new CppLeastSquaresSolver()

def least_squares(self, CSRMatrix cui, Matrix X, Matrix YtY, Matrix Y, int cg_steps):
self.c_solver.least_squares(dereference(cui.c_matrix), X.c_matrix,
dereference(YtY.c_matrix), dereference(Y.c_matrix),
cg_steps)

with nogil:
self.c_solver.least_squares(dereference(cui.c_matrix), X.c_matrix,
dereference(YtY.c_matrix), dereference(Y.c_matrix),
cg_steps)
def calculate_loss(self, CSRMatrix cui, Matrix X, Matrix Y,
float regularization):
return self.c_solver.calculate_loss(dereference(cui.c_matrix), dereference(X.c_matrix),
dereference(Y.c_matrix), regularization)
cdef float loss
with nogil:
loss = self.c_solver.calculate_loss(dereference(cui.c_matrix), dereference(X.c_matrix),
dereference(Y.c_matrix), regularization)
return loss

def calculate_yty(self, Matrix Y, Matrix YtY, float regularization):
if YtY is None:
YtY = Matrix(None)

self.c_solver.calculate_yty(dereference(Y.c_matrix), YtY.c_matrix, regularization)
def calculate_yty(self, Matrix Y, Matrix YtY, float regularization):
with nogil:
self.c_solver.calculate_yty(dereference(Y.c_matrix), YtY.c_matrix, regularization)

def __dealloc__(self):
del self.c_solver
Expand All @@ -249,9 +252,10 @@ def get_device_count():
def bpr_update(IntVector userids, IntVector itemids, IntVector indptr,
Matrix X, Matrix Y,
float learning_rate, float regularization, long seed, bool verify_negative):
ret = cpp_bpr_update(dereference(userids.c_vector),
dereference(itemids.c_vector),
dereference(indptr.c_vector),
X.c_matrix, Y.c_matrix,
learning_rate, regularization, seed, verify_negative)
with nogil:
ret = cpp_bpr_update(dereference(userids.c_vector),
dereference(itemids.c_vector),
dereference(indptr.c_vector),
X.c_matrix, Y.c_matrix,
learning_rate, regularization, seed, verify_negative)
return ret.first, ret.second

0 comments on commit af44c79

Please sign in to comment.