Skip to content

Commit

Permalink
Merge pull request #720 from guillaumekln/half-baddbmm
Browse files (browse the repository at this point in the history)
Support half precision in baddbmm
  • Loading branch information
soumith authored Mar 3, 2017
2 parents bbd8bfb + dfaf790 commit 359ee80
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions lib/THC/generic/THCTensorMathBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ __global__ void createBatchGemmBuffer(const real** buffer, real* data,
THC_API void
THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
real alpha, THCTensor *batch1, THCTensor *batch2) {
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, result, t, batch1, batch2));
THArgCheck(THCTensor_(nDimension)(state, t) == 3, 4, "expected 3D tensor");
THArgCheck(THCTensor_(nDimension)(state, batch1) == 3, 6, "expected 3D tensor");
Expand Down Expand Up @@ -522,8 +522,10 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
ldb = batch2_->stride[1];
}

// Compute pointers to matrices in each batch.
long num_batches = result_->size[0];

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
// Compute pointers to matrices in each batch.
size_t matrices_size = num_batches * sizeof(real*);

// Copy pointers to device.
Expand Down Expand Up @@ -580,6 +582,24 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
THCudaFree(state, d_matrices2);
THCudaFree(state, d_result_matrices);

#elif defined(THC_REAL_IS_HALF)
// cuBLAS currently has no HgemmBatched, so issue one Hgemm call per batch instead.
for (long i = 0; i < num_batches; ++i) {
THCudaBlas_Hgemm(
state,
transpose_batch1,
transpose_batch2,
result_->size[transpose_result ? 2 : 1],
result_->size[transpose_result ? 1 : 2],
batch1_->size[transpose_result ? 1 : 2],
alpha,
THCTensor_(data)(state, batch1_) + i * batch1_->stride[0], lda,
THCTensor_(data)(state, batch2_) + i * batch2_->stride[0], ldb,
beta,
THCTensor_(data)(state, result_) + i * result_->stride[0], ldc);
}
#endif

if (batch1_ != batch1) {
THCTensor_(free)(state, batch1_);
}
Expand Down

0 comments on commit 359ee80

Please sign in to comment.