Skip to content

Commit

Permalink
[cherry-pick]add cast cuda kernel (PaddlePaddle#29352) PaddlePaddle#…
Browse files Browse the repository at this point in the history
…30263

 add cast cuda kernel

cherry-pick PaddlePaddle#29352
  • Loading branch information
zhangting2020 authored Jan 11, 2021
1 parent 6dd70b9 commit afbc636
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions paddle/fluid/operators/cast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,39 @@ limitations under the License. */

#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_launch_config.h"

namespace paddle {
namespace operators {

template <typename InT, typename OutT>
__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
}

template <typename InT>
struct CastOpFunctor<platform::CUDADeviceContext, InT> {
const framework::Tensor* in_;
framework::Tensor* out_;
const platform::CUDADeviceContext& ctx_;
CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
const platform::CUDADeviceContext& ctx)
: in_(in), out_(out), ctx_(ctx) {}

template <typename OutT>
void apply() const {
auto* in = in_->data<InT>();
auto size = in_->numel();
auto* out = out_->mutable_data<OutT>(ctx_.GetPlace());
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx_, size);
CastCUDAKernel<InT, OutT><<<config.block_per_grid, config.thread_per_block,
0, ctx_.stream()>>>(in, size, out);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

Expand Down

0 comments on commit afbc636

Please sign in to comment.