Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Support for fp16 in SpGeMM
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohit Kumar Srivastava committed Aug 15, 2020
1 parent be12c8d commit 810d62d
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/operator/tensor/dot-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace op {
* \brief GPU scalar kernel of dot(csr, dns1) = dns2
* Parallelization by output matrix elements: 1 thread/element
*/
template<int req>
template<int req, typename AType>
struct DotCsrDnsDnsScalarKernel {
/*!
* \brief This function represents performing an inner product between a row of lhs
Expand All @@ -63,20 +63,20 @@ struct DotCsrDnsDnsScalarKernel {
const nnvm::dim_t num_cols_r) {
const nnvm::dim_t irow = tid / num_cols_r; // row id of the lhs
const nnvm::dim_t icol = tid % num_cols_r; // col id of the rhs
DType sum = 0;
AType sum = 0;
for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) {
const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs
sum += data_l[j] * data_r[cur_col*num_cols_r+icol];
}
KERNEL_ASSIGN(out[tid], req, sum);
KERNEL_ASSIGN(out[tid], req, static_cast<DType>(sum));
}
};

/*!
* \brief GPU vector kernel of dot(csr, dns1) = dns2
* Parallelization by output matrix elements: 1 warp/element
*/
template<int req>
template<int req, typename AType>
struct DotCsrDnsDnsVectorKernel {
/*!
* \brief see DotCsrDnsDnsScalarKernel Map for documentation.
Expand All @@ -90,7 +90,7 @@ struct DotCsrDnsDnsVectorKernel {
const DType* data_r,
const nnvm::dim_t num_cols_r) {
using nnvm::dim_t;
__shared__ volatile DType vals[mshadow::cuda::kBaseThreadNum];
__shared__ volatile AType vals[mshadow::cuda::kBaseThreadNum];
const dim_t warp_id = tid / 32; // global warp id
const dim_t lane = tid & (32-1); // local thread id within warp
const dim_t irow = warp_id / num_cols_r; // lhs row that this warp computes
Expand All @@ -101,9 +101,9 @@ struct DotCsrDnsDnsVectorKernel {
const dim_t high = static_cast<dim_t>(indptr_l[irow+1]);

// Compute running sum per thread
DType sum = 0;
AType sum = 0;
for (dim_t j = low+lane; j < high; j+=32) {
sum += data_l[j] * data_r[col_idx_l[j]*num_cols_r + kcol];
sum += static_cast<AType>(data_l[j]) * static_cast<AType>(data_r[col_idx_l[j]*num_cols_r + kcol]);
}
vals[threadIdx.x] = sum; __syncwarp();

Expand All @@ -115,7 +115,7 @@ struct DotCsrDnsDnsVectorKernel {
if (lane < 1) {vals[threadIdx.x] += vals[threadIdx.x+ 1];} __syncwarp();

if (lane == 0) {
KERNEL_ASSIGN(out[irow*num_cols_r+kcol], req, vals[threadIdx.x]);
KERNEL_ASSIGN(out[irow*num_cols_r+kcol], req, static_cast<DType>(vals[threadIdx.x]));
}
}
};
Expand Down Expand Up @@ -418,7 +418,7 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
const TBlob& data_r = rhs;
const TBlob data_out = *ret;

MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MXNET_REAL_ACC_TYPE_SWITCH(data_l.type_flag_, DType, AType, { // data type and accumulator type
MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
if (kWriteTo == req) {
Expand Down Expand Up @@ -513,14 +513,14 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
if (num_cols_r > 4) {
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
Kernel<DotCsrDnsDnsScalarKernel<ReqType, AType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), csc_data_ptr, csc_indptr_ptr,
csc_indices_ptr, data_r.dptr<DType>(), num_cols_r);
});
} else {
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
Kernel<DotCsrDnsDnsVectorKernel<ReqType, AType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), csc_data_ptr, csc_indptr_ptr,
csc_indices_ptr, data_r.dptr<DType>(), num_cols_r);
});
Expand All @@ -529,14 +529,14 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
if (num_cols_r > 4) {
num_threads = data_out.Size();
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
Kernel<DotCsrDnsDnsScalarKernel<ReqType, AType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
} else {
num_threads = threads_per_warp * num_rows_l * num_cols_r;
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
Kernel<DotCsrDnsDnsVectorKernel<ReqType, AType>, gpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
});
Expand Down

0 comments on commit 810d62d

Please sign in to comment.