Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle more than 2**31 parameters on GPU models #599

Merged
merged 2 commits into from
Jul 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions implicit/gpu/_cuda.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ cdef class KnnQuery(object):
cdef CppMatrix * queries = m.c_matrix
cdef CppCOOMatrix * c_query_filter = NULL
cdef CppVector[int] * c_item_filter = NULL
cdef int rows = queries.rows
cdef size_t rows = queries.rows
cdef int[:, :] x
cdef float[:, :] y

Expand Down Expand Up @@ -154,7 +154,7 @@ cdef class Matrix(object):
rows = IntVector(np.array(rowids).astype("int32"))
self.c_matrix.assign_rows(dereference(rows.c_vector), dereference(other.c_matrix))

def resize(self, int rows, int cols):
def resize(self, size_t rows, size_t cols):
self.c_matrix.resize(rows, cols)

def to_numpy(self):
Expand Down
16 changes: 8 additions & 8 deletions implicit/gpu/als.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ namespace gpu {

using std::invalid_argument;

__global__ void least_squares_cg_kernel(int factors, int user_count,
int item_count, float *X,
__global__ void least_squares_cg_kernel(int factors, size_t user_count,
size_t item_count, float *X,
const float *Y, const float *YtY,
const int *indptr, const int *indices,
const float *data, int cg_steps) {
Expand Down Expand Up @@ -94,7 +94,7 @@ __global__ void least_squares_cg_kernel(int factors, int user_count,
// complain and don't let it perpetuate
if (isnan(rsold)) {
if (threadIdx.x == 0) {
printf("Warning NaN Detected in row %d of %d\n", u, user_count);
printf("Warning NaN Detected in row %i of %lu\n", u, user_count);
}
x[threadIdx.x] = 0;
} else {
Expand All @@ -103,7 +103,7 @@ __global__ void least_squares_cg_kernel(int factors, int user_count,
}
}

__global__ void l2_regularize_kernel(int factors, float regularization,
__global__ void l2_regularize_kernel(size_t factors, float regularization,
float *YtY) {
YtY[threadIdx.x * factors + threadIdx.x] += regularization;
}
Expand All @@ -120,7 +120,7 @@ void LeastSquaresSolver::calculate_yty(const Matrix &Y, Matrix *YtY,
// calculate YtY: note that cuBLAS expects col-major (and we basically have
// row-major), so we invert the CUBLAS_OP_T/CUBLAS_OP_N ordering to
// overcome this (i.e. calculate YYt instead of YtY)
int factors = Y.cols, item_count = Y.rows;
size_t factors = Y.cols, item_count = Y.rows;
float alpha = 1.0, beta = 0.;
CHECK_CUBLAS(cublasSgemm(blas_handle, CUBLAS_OP_N, CUBLAS_OP_T, factors,
factors, item_count, &alpha, Y.data, factors, Y.data,
Expand Down Expand Up @@ -164,8 +164,8 @@ void LeastSquaresSolver::least_squares(const CSRMatrix &Cui, Matrix *X,
CHECK_CUDA(cudaDeviceSynchronize());
}

__global__ void calculate_loss_kernel(int factors, int user_count,
int item_count, const float *X,
__global__ void calculate_loss_kernel(int factors, size_t user_count,
size_t item_count, const float *X,
const float *Y, const float *YtY,
const int *indptr, const int *indices,
const float *data, float regularization,
Expand Down Expand Up @@ -220,7 +220,7 @@ __global__ void calculate_loss_kernel(int factors, int user_count,
float LeastSquaresSolver::calculate_loss(const CSRMatrix &Cui, const Matrix &X,
const Matrix &Y,
float regularization) {
int item_count = Y.rows, factors = Y.cols, user_count = X.rows;
size_t item_count = Y.rows, factors = Y.cols, user_count = X.rows;

Matrix YtY(factors, factors, NULL);
calculate_yty(Y, &YtY, regularization);
Expand Down
2 changes: 1 addition & 1 deletion implicit/gpu/bpr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace gpu {

__global__ void bpr_update_kernel(int samples, unsigned int *random_likes,
unsigned int *random_dislikes, int *itemids,
int *userids, int *indptr, int factors,
int *userids, int *indptr, size_t factors,
float *X, float *Y, float learning_rate,
float reg, bool verify_negative_samples,
int *stats) {
Expand Down
6 changes: 3 additions & 3 deletions implicit/gpu/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ void KnnQuery::topk(const Matrix &items, const Matrix &query, int k,
batch_size *= 0.15;
}

batch_size = std::min(batch_size, static_cast<size_t>(query.rows));
batch_size = std::min(batch_size, query.rows);
batch_size = std::max(batch_size, static_cast<size_t>(1));

rmm::device_uvector<float> temp_mem(batch_size * temp_distances_cols, stream,
Expand All @@ -170,7 +170,7 @@ void KnnQuery::topk(const Matrix &items, const Matrix &query, int k,
}

for (int start = 0; start < query.rows; start += batch_size) {
auto end = std::min(query.rows, start + static_cast<int>(batch_size));
auto end = std::min(query.rows, start + batch_size);

Matrix batch(query, start, end);
temp_distances.rows = batch.rows;
Expand Down Expand Up @@ -249,7 +249,7 @@ void KnnQuery::topk(const Matrix &items, const Matrix &query, int k,

void KnnQuery::argpartition(const Matrix &items, int k, int *indices,
float *distances, bool allow_tiling) {
k = std::min(k, items.cols);
k = std::min(k, static_cast<int>(items.cols));

if (k >= GPU_MAX_SELECTION_K) {
rmm::cuda_stream_view stream;
Expand Down
40 changes: 20 additions & 20 deletions implicit/gpu/matrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace implicit {
namespace gpu {
template <typename T>
Vector<T>::Vector(int size, const T *host_data)
Vector<T>::Vector(size_t size, const T *host_data)
: size(size),
storage(new rmm::device_uvector<T>(size, rmm::cuda_stream_view())),
data(storage->data()) {
Expand All @@ -28,15 +28,15 @@ template struct Vector<char>;
template struct Vector<int>;
template struct Vector<float>;

Matrix::Matrix(const Matrix &other, int rowid)
Matrix::Matrix(const Matrix &other, size_t rowid)
: rows(1), cols(other.cols), data(other.data + rowid * other.cols),
storage(other.storage) {
if (rowid >= other.rows) {
throw std::invalid_argument("row index out of bounds for matrix");
}
}

Matrix::Matrix(const Matrix &other, int start_rowid, int end_rowid)
Matrix::Matrix(const Matrix &other, size_t start_rowid, size_t end_rowid)
: rows(end_rowid - start_rowid), cols(other.cols),
data(other.data + start_rowid * other.cols), storage(other.storage) {
if (end_rowid < start_rowid) {
Expand All @@ -47,13 +47,13 @@ Matrix::Matrix(const Matrix &other, int start_rowid, int end_rowid)
}
}

void copy_rowids(const float *input, const int *rowids, int rows, int cols,
float *output) {
void copy_rowids(const float *input, const int *rowids, size_t rows,
size_t cols, float *output) {
// copy rows over
auto count = thrust::make_counting_iterator<int>(0);
thrust::for_each(count, count + (rows * cols), [=] __device__(int i) {
int col = i % cols;
int row = rowids[i / cols];
auto count = thrust::make_counting_iterator<size_t>(0);
thrust::for_each(count, count + (rows * cols), [=] __device__(size_t i) {
size_t col = i % cols;
size_t row = rowids[i / cols];
output[i] = input[col + row * cols];
});
}
Expand All @@ -66,7 +66,7 @@ Matrix::Matrix(const Matrix &other, const Vector<int> &rowids)
copy_rowids(other.data, rowids.data, rows, cols, data);
}

Matrix::Matrix(int rows, int cols, float *host_data, bool allocate)
Matrix::Matrix(size_t rows, size_t cols, float *host_data, bool allocate)
: rows(rows), cols(cols) {
if (allocate) {
storage.reset(
Expand All @@ -81,7 +81,7 @@ Matrix::Matrix(int rows, int cols, float *host_data, bool allocate)
}
}

void Matrix::resize(int rows, int cols) {
void Matrix::resize(size_t rows, size_t cols) {
if (cols != this->cols) {
throw std::logic_error(
"changing number of columns in Matrix::resize is not implemented yet");
Expand All @@ -95,7 +95,7 @@ void Matrix::resize(int rows, int cols) {
CHECK_CUDA(cudaMemcpy(new_storage->data(), data,
this->rows * this->cols * sizeof(float),
cudaMemcpyDeviceToDevice));
int extra_rows = rows - this->rows;
size_t extra_rows = rows - this->rows;
CHECK_CUDA(cudaMemset(new_storage->data() + this->rows * this->cols, 0,
extra_rows * cols * sizeof(float)));
storage.reset(new_storage);
Expand All @@ -110,24 +110,24 @@ void Matrix::assign_rows(const Vector<int> &rowids, const Matrix &other) {
"column dimensionality mismatch in Matrix::assign_rows");
}

auto count = thrust::make_counting_iterator<int>(0);
int other_cols = other.cols, other_rows = other.rows;
auto count = thrust::make_counting_iterator<size_t>(0);
size_t other_cols = other.cols, other_rows = other.rows;

int *rowids_data = rowids.data;
float *other_data = other.data;
float *self_data = data;

thrust::for_each(count, count + (other_rows * other_cols),
[=] __device__(int i) {
int col = i % other_cols;
int row = rowids_data[i / other_cols];
int idx = col + row * other_cols;
[=] __device__(size_t i) {
size_t col = i % other_cols;
size_t row = rowids_data[i / other_cols];
size_t idx = col + row * other_cols;
self_data[idx] = other_data[i];
});
}

__global__ void calculate_norms_kernel(const float *input, int rows, int cols,
float *output) {
__global__ void calculate_norms_kernel(const float *input, size_t rows,
size_t cols, float *output) {
static __shared__ float shared[32];
for (int i = blockIdx.x; i < rows; i += gridDim.x) {
float value = input[i * cols + threadIdx.x];
Expand Down
14 changes: 7 additions & 7 deletions implicit/gpu/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ namespace implicit {
namespace gpu {
// Thin wrappers of CUDA memory: copies to/from host, frees in destructor, etc.
template <typename T> struct Vector {
Vector(int size, const T *data = NULL);
Vector(size_t size, const T *data = NULL);
void to_host(T *output) const;

std::shared_ptr<rmm::device_uvector<T>> storage;
int size;
size_t size;
T *data;
};

Expand All @@ -21,28 +21,28 @@ struct Matrix {
// device (if allocate=True and data != null). If allocate=false, this assumes
// the data is preallocated on the gpu (cupy etc) and doesn't allocate any new
// storage
Matrix(int rows, int cols, float *data = NULL, bool allocate = true);
Matrix(size_t rows, size_t cols, float *data = NULL, bool allocate = true);

// Create a new Matrix by slicing a single row from an existing one. The
// underlying storage buffer is shared in this case.
Matrix(const Matrix &other, int rowid);
Matrix(const Matrix &other, size_t rowid);

// Slice a contiguous series of rows from this Matrix. The underlying storage
// buffer is shared here.
Matrix(const Matrix &other, int start_rowid, int end_rowid);
Matrix(const Matrix &other, size_t start_rowid, size_t end_rowid);

// Select a bunch of rows from this matrix. This creates a copy.
Matrix(const Matrix &other, const Vector<int> &rowids);

void resize(int rows, int cols);
void resize(size_t rows, size_t cols);
void assign_rows(const Vector<int> &rowids, const Matrix &other);

Matrix() : rows(0), cols(0), data(NULL) {}

// Copy the Matrix to host memory.
void to_host(float *output) const;

int rows, cols;
size_t rows, cols;
float *data;

std::shared_ptr<rmm::device_uvector<float>> storage;
Expand Down
14 changes: 7 additions & 7 deletions implicit/gpu/matrix.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ cdef extern from "implicit/gpu/matrix.h" namespace "implicit::gpu" nogil:
const int * row, const int * col, const float * data) except +

cdef cppclass Vector[T]:
Vector(int size, T * data) except +
Vector(size_t size, T * data) except +
void to_host(T * output) except +
T * data
int size
size_t size

cdef cppclass Matrix:
Matrix(int rows, int cols, float * data, bool host) except +
Matrix(const Matrix & other, int rowid) except +
Matrix(const Matrix & other, int start, int end) except +
Matrix(size_t rows, size_t cols, float * data, bool host) except +
Matrix(const Matrix & other, size_t rowid) except +
Matrix(const Matrix & other, size_t start, size_t end) except +
Matrix(const Matrix & other, const Vector[int] & rowids) except +
Matrix(Matrix && other) except +
void to_host(float * output) except +
void resize(int rows, int cols) except +
void resize(size_t rows, size_t cols) except +
void assign_rows(const Vector[int] & rowids, const Matrix & other) except +
int rows, cols
size_t rows, cols
float * data

Matrix calculate_norms(const Matrix & items) except +
Expand Down
4 changes: 2 additions & 2 deletions implicit/gpu/random.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ RandomState::RandomState(long seed) {
CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(rng, seed));
}

Matrix RandomState::uniform(int rows, int cols, float low, float high) {
Matrix RandomState::uniform(size_t rows, size_t cols, float low, float high) {
Matrix ret(rows, cols, NULL);
CHECK_CURAND(curandGenerateUniform(rng, ret.data, rows * cols));

Expand All @@ -30,7 +30,7 @@ Matrix RandomState::uniform(int rows, int cols, float low, float high) {
return ret;
}

Matrix RandomState::randn(int rows, int cols, float mean, float stddev) {
Matrix RandomState::randn(size_t rows, size_t cols, float mean, float stddev) {
Matrix ret(rows, cols, NULL);
CHECK_CURAND(curandGenerateNormal(rng, ret.data, rows * cols, mean, stddev));
return ret;
Expand Down
4 changes: 2 additions & 2 deletions implicit/gpu/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ struct RandomState {
RandomState(long seed);
~RandomState();

Matrix uniform(int rows, int cols, float low = 0.0, float high = 1.0);
Matrix randn(int rows, int cols, float mean = 0, float stddev = 1);
Matrix uniform(size_t rows, size_t cols, float low = 0.0, float high = 1.0);
Matrix randn(size_t rows, size_t cols, float mean = 0, float stddev = 1);

RandomState(const RandomState &) = delete;
RandomState &operator=(const RandomState &) = delete;
Expand Down
4 changes: 2 additions & 2 deletions implicit/gpu/random.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ from .matrix cimport Matrix
cdef extern from "implicit/gpu/random.h" namespace "implicit::gpu" nogil:
cdef cppclass RandomState:
RandomState(long rows) except +
Matrix uniform(int rows, int cols, float low, float high) except +
Matrix randn(int rows, int cols, float mean, float stdev) except +
Matrix uniform(size_t rows, size_t cols, float low, float high) except +
Matrix randn(size_t rows, size_t cols, float mean, float stdev) except +