From 090c0a9acc789acbd6562f81848b2ade3ccf2408 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Wed, 22 Jun 2022 19:49:57 +0800 Subject: [PATCH 01/28] support fully fused mlp grad in eager --- .../gradient_funcs/cublas_fused_mlp.cpp | 125 +++++--- oneflow/core/functional/functional_api.yaml | 5 + .../core/functional/impl/nn_grad_functor.cpp | 39 +++ oneflow/ir/include/OneFlow/OneFlowUserOps.td | 19 ++ oneflow/user/kernels/cublas_fused_mlp_grad.cu | 290 ++++++++++++++++++ .../user/kernels/cublas_fused_mlp_util.cuh | 31 +- oneflow/user/ops/cublas_fused_mlp_grad_op.cpp | 103 +++++++ 7 files changed, 551 insertions(+), 61 deletions(-) create mode 100644 oneflow/user/kernels/cublas_fused_mlp_grad.cu create mode 100644 oneflow/user/ops/cublas_fused_mlp_grad_op.cpp diff --git a/oneflow/core/autograd/gradient_funcs/cublas_fused_mlp.cpp b/oneflow/core/autograd/gradient_funcs/cublas_fused_mlp.cpp index e0ae114b140..e01197fff0b 100644 --- a/oneflow/core/autograd/gradient_funcs/cublas_fused_mlp.cpp +++ b/oneflow/core/autograd/gradient_funcs/cublas_fused_mlp.cpp @@ -80,7 +80,7 @@ Maybe CublasFusedMLP::Capture(CublasFusedMLPCaptureState* ctx, const Tenso ctx->SaveTensorForBackward( JUST(VectorAt(outputs, i + 1))); // cublas aux. need minus 1. idx_sum:2+2w } - for (int32_t i = 0; i < weight_num - 1; i++) { + for (int32_t i = 0; i < weight_num; i++) { ctx->SaveTensorForBackward(JUST(VectorAt(outputs, i + 1 + weight_num))); // hidden. } @@ -102,14 +102,7 @@ Maybe CublasFusedMLP::Apply(const CublasFusedMLPCaptureState* ctx, JUST(VectorAt(ctx->SavedTensors(), 1 + weight_num)))); } - // step2: use reduce_sum to get last layer's bias grad. - std::vector reduce_axes_vec{0}; - if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) { - JUST(VectorAt(*in_grads, 2 * weight_num)) = - JUST(functional::ReduceSum(last_bias_dy, reduce_axes_vec, false)); - } - - TensorTuple hiddens(weight_num - 1); + TensorTuple hiddens(weight_num); TensorTuple weights(weight_num); TensorTuple cublas_auxs(weight_num); TensorTuple dgrad(weight_num); @@ -124,56 +117,88 @@ Maybe CublasFusedMLP::Apply(const CublasFusedMLPCaptureState* ctx, cublas_auxs[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + weight_num)); } - for (int32_t i = 0; i < weight_num - 1; ++i) { + for (int32_t i = 0; i < weight_num; ++i) { hiddens[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + 2 * weight_num)); } std::shared_ptr cublas_dy = last_bias_dy; - for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) { - // If it is final layer, we use out_grads[0] as dy. - if (hidden_layer_idx != weight_num - 1) { - cublas_dy = JUST(VectorAt(dgrad, hidden_layer_idx + 1)); + + // Use Fully Fused MLP Backward. + if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD", false)) { + const auto& fused_mlp_grad = JUST(functional::FusedMLPGrad( + cublas_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), weights, cublas_auxs, hiddens)); + if (ctx->x_requires_grad) { + // dx: + JUST(VectorAt(*in_grads, 0)) = fused_mlp_grad->at(0); } - /* - Here we use cublas to compute bias + relu + matmul grad. - Then use Matmul to compute weight grad. 
- */ - const auto& matmul_relu_bias_bgrad = JUST(functional::CublasBiasAddReluMatmulGrad( - cublas_dy, JUST(VectorAt(weights, hidden_layer_idx)), - JUST(VectorAt(cublas_auxs, hidden_layer_idx - 1)), /*alpha=*/1.0)); - - // dgrad - dgrad.at(hidden_layer_idx) = matmul_relu_bias_bgrad->at(0); // NOLINT - - if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx - 1)))) { - // dbias - JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx)) = - matmul_relu_bias_bgrad->at(1); // NOLINT + + for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > -1; hidden_layer_idx--) { + if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx)))) { + // dbias + JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx + 1)) = + fused_mlp_grad->at(1 + hidden_layer_idx); // NOLINT + } + + // dw + if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) { + JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = + fused_mlp_grad->at(1 + weight_num + hidden_layer_idx); + } } - // dw - if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) { - JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = JUST(functional::MatMul( - cublas_dy, JUST(VectorAt(hiddens, hidden_layer_idx - 1)), true, false, 1.0)); + } else { + // step2: use reduce_sum to get last layer's bias grad. + std::vector reduce_axes_vec{0}; + if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) { + JUST(VectorAt(*in_grads, 2 * weight_num)) = + JUST(functional::ReduceSum(last_bias_dy, reduce_axes_vec, false)); } - } - // For the first layer, we need to use 2 matmul to get grads. - std::shared_ptr last_dy; - if (weight_num != 1) { - last_dy = JUST(VectorAt(dgrad, 1)); - } else { - last_dy = last_bias_dy; - } + for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) { + // If it is final layer, we use out_grads[0] as dy. + if (hidden_layer_idx != weight_num - 1) { + cublas_dy = JUST(VectorAt(dgrad, hidden_layer_idx + 1)); + } + /* + Here we use cublas to compute bias + relu + matmul grad. + Then use Matmul to compute weight grad. + */ + const auto& matmul_relu_bias_bgrad = JUST(functional::CublasBiasAddReluMatmulGrad( + cublas_dy, JUST(VectorAt(weights, hidden_layer_idx)), + JUST(VectorAt(cublas_auxs, hidden_layer_idx - 1)), /*alpha=*/1.0)); + + // dgrad + dgrad.at(hidden_layer_idx) = matmul_relu_bias_bgrad->at(0); // NOLINT + + if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx - 1)))) { + // dbias + JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx)) = + matmul_relu_bias_bgrad->at(1); // NOLINT + } + // dw + if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) { + JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = JUST(functional::MatMul( + cublas_dy, JUST(VectorAt(hiddens, hidden_layer_idx - 1)), true, false, 1.0)); + } + } - if (ctx->x_requires_grad) { - // dx: - JUST(VectorAt(*in_grads, 0)) = - JUST(functional::MatMul(last_dy, JUST(VectorAt(weights, 0)), false, false, 1.0)); - } - if (JUST(VectorAt(ctx->weights_requires_grad, 0))) { - // dw: - JUST(VectorAt(*in_grads, 1)) = - JUST(functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0)); + // For the first layer, we need to use 2 matmul to get grads. 
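The first layer has no preceding hidden activation or ReLU mask, so its gradients cannot come out of the fused dReLU kernel; they are two independent GEMMs, dx = dy * W0 and dW0 = dy^T * x, which is what the two functional::MatMul calls below compute. A naive reference of the same math, assuming plain row-major arrays and hypothetical shapes (illustration only; the real work is dispatched to cuBLAS):

    // dy: [B, N], w0: [N, K], x: [B, K]  ->  dx: [B, K], dw0: [N, K]
    void FirstLayerGradRef(const float* dy, const float* w0, const float* x,
                           int B, int N, int K, float* dx, float* dw0) {
      for (int b = 0; b < B; ++b) {    // dx = dy * w0
        for (int k = 0; k < K; ++k) {
          float s = 0.0f;
          for (int n = 0; n < N; ++n) { s += dy[b * N + n] * w0[n * K + k]; }
          dx[b * K + k] = s;
        }
      }
      for (int n = 0; n < N; ++n) {    // dw0 = dy^T * x (the transpose_a=true matmul)
        for (int k = 0; k < K; ++k) {
          float s = 0.0f;
          for (int b = 0; b < B; ++b) { s += dy[b * N + n] * x[b * K + k]; }
          dw0[n * K + k] = s;
        }
      }
    }

The diff below picks last_dy (the deepest dgrad, or out_grads[0] for a single-layer MLP) and feeds it to those two matmuls: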
+ std::shared_ptr last_dy; + if (weight_num != 1) { + last_dy = JUST(VectorAt(dgrad, 1)); + } else { + last_dy = last_bias_dy; + } + + if (ctx->x_requires_grad) { + // dx: + JUST(VectorAt(*in_grads, 0)) = + JUST(functional::MatMul(last_dy, JUST(VectorAt(weights, 0)), false, false, 1.0)); + } + if (JUST(VectorAt(ctx->weights_requires_grad, 0))) { + // dw: + JUST(VectorAt(*in_grads, 1)) = JUST( + functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0)); + } } return Maybe::Ok(); diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index e6db3be942c..490ee3556f4 100755 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -973,6 +973,11 @@ "Tensor (Tensor x, TensorTuple weights, TensorTuple biases, Bool skip_final_activation) => FusedMLP" bind_python: True +- name: "fused_mlp_grad" + signature: + "TensorTuple (Tensor dy, Tensor x, TensorTuple weights, TensorTuple cublas_aux, TensorTuple hidden) => FusedMLPGrad" + bind_python: False + - name: "cublas_bias_add_relu_matmul_grad" signature: "TensorTuple (Tensor dy, Tensor weight, Tensor aux, Double alpha=1.0) => CublasBiasAddReluMatmulGrad" diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp index 5689710ac2b..36d619ff58f 100644 --- a/oneflow/core/functional/impl/nn_grad_functor.cpp +++ b/oneflow/core/functional/impl/nn_grad_functor.cpp @@ -1108,6 +1108,44 @@ class FusedCrossFeatureInteractionV2GradFunctor { std::shared_ptr v2_grad_op_; }; +class FusedMLPGradFunctor { + public: + FusedMLPGradFunctor() { +#if CUDA_VERSION >= 11060 + fused_op_.resize(kMaxInputCount /*the maximum number of inputs*/); + for (int n = 1; n < fused_op_.size(); ++n) { + fused_op_[n] = CHECK_JUST(one::OpBuilder("cublas_fused_mlp_grad") + .Input("dy") + .Input("x") + .Input("weights", n) + .Input("cublas_aux", n) + .Input("hidden", n) + .Output("d_grad") + .Output("d_biases", n) + .Output("d_weights", n) + .Build()); + } +#endif + } + Maybe operator()(const std::shared_ptr& dy, + const std::shared_ptr& x, const TensorTuple& weights, + const TensorTuple& cublas_aux, const TensorTuple& hidden) const { + const int64_t weight_size = weights.size(); + TensorTuple input(2 + 3 * weight_size); + input[0] = dy; + input[1] = x; + std::copy(weights.begin(), weights.end(), input.begin() + 2); + std::copy(cublas_aux.begin(), cublas_aux.end(), input.begin() + 2 + weight_size); + std::copy(hidden.begin(), hidden.end(), input.begin() + 2 + 2 * weight_size); + return OpInterpUtil::Dispatch(*fused_op_[weight_size], input); + } + + private: +#if CUDA_VERSION >= 11060 + std::vector> fused_op_; +#endif +}; + } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { @@ -1151,6 +1189,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { "FusedCrossFeatureInteractionV1Grad"); m.add_functor( "FusedCrossFeatureInteractionV2Grad"); + m.add_functor("FusedMLPGrad"); }; } // namespace functional diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 1305bfeb6c9..0a8cff09042 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -4554,6 +4554,25 @@ def OneFlow_CublasFusedMLPOp : OneFlow_BaseOp<"cublas_fused_mlp", [NoSideEffect, let has_data_type_infer_fn = 1; } +def OneFlow_CublasFusedMLPGradOp : OneFlow_BaseOp<"cublas_fused_mlp_grad", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + 
OneFlow_Tensor:$x, + Variadic:$weights, + Variadic:$cublas_aux, + Variadic:$hidden + ); + let output = (outs + OneFlow_Tensor:$d_grad, + Variadic:$d_biases, + Variadic:$d_weights + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + def OneFlow_CublasBiasAddReluMatmulGradOp : OneFlow_BaseOp<"cublas_bias_add_relu_matmul_grad", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad.cu b/oneflow/user/kernels/cublas_fused_mlp_grad.cu new file mode 100644 index 00000000000..d7ba3f99168 --- /dev/null +++ b/oneflow/user/kernels/cublas_fused_mlp_grad.cu @@ -0,0 +1,290 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/kernel/cuda_graph_support.h" +#include "oneflow/user/kernels/cublas_fused_mlp_util.cuh" +// CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link. +#if CUDA_VERSION >= 11060 + +namespace oneflow { + +namespace { + +class MatmulGradKernelState final : public user_op::OpKernelState { + public: + MatmulGradKernelState() { + OF_CUDA_CHECK(cudaStreamCreate(&cuda_stream_)); + OF_CUBLAS_CHECK(cublasLtCreate(&cublas_lt_handle_)); + OF_CUDA_CHECK(cudaMalloc(&workspace_, 8 * 1024 * 1024)); + } + ~MatmulGradKernelState() { + OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_)); + OF_CUBLAS_CHECK(cublasLtDestroy(cublas_lt_handle_)); + OF_CUDA_CHECK(cudaStreamDestroy(cuda_stream_)); + OF_CUDA_CHECK(cudaFree(workspace_)); + } + cudaStream_t cuda_stream() const { return cuda_stream_; } + cublasLtHandle_t cublas_lt_handle() const { return cublas_lt_handle_; } + size_t cublas_workspace_size() const { return 8 * 1024 * 1024; } + void* cublas_workspace() const { return workspace_; } + + private: + cudaStream_t cuda_stream_{}; + cublasLtHandle_t cublas_lt_handle_{}; + void* workspace_{}; +}; + +template +class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { + public: + CublasFusedMLPGradKernel() { + OF_CUDA_CHECK(cudaEventCreate(&main_stream_event)); + OF_CUDA_CHECK(cudaEventCreate(&async_weight_grad_event)); + }; + ~CublasFusedMLPGradKernel() override { + OF_CUDA_CHECK(cudaEventDestroy(main_stream_event)); + OF_CUDA_CHECK(cudaEventDestroy(async_weight_grad_event)); + }; + + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateCublasFusedMLPKernelCache(); + } + + std::shared_ptr CreateOpKernelState( + user_op::KernelInitContext* ctx) const override { + return std::make_shared(); + } + + private: + cudaEvent_t main_stream_event; + cudaEvent_t async_weight_grad_event; + + using user_op::OpKernel::Compute; + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache* cache) const override { + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + const 
user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); + user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + void* dy_tmp_buf = tmp_buffer->mut_dptr(); + size_t offset = 0; + + user_op::Tensor* d_grad = ctx->Tensor4ArgNameAndIndex("d_grad", 0); + + const int64_t weight_num = ctx->input_size("weights"); + + const auto* matmul_grad_cache = + CHECK_NOTNULL(dynamic_cast(cache)); + auto* cuda_stream = ctx->stream()->As(); + + auto* kernel_state = dynamic_cast(state); + + const DataType data_type = dy->data_type(); + const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type); + const cudaDataType_t cuda_data_type = GetCudaDataType(data_type); + size_t cublas_m = 0, cublas_n = 0, cublas_k = 0; + int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0; + + double alpha = 1.0; + auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); + double beta = 0.0; + auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); + + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; // = CUBLASLT_EPILOGUE_DRELU_BGRAD + + // currently only support 2D matmul. + DimVector weight_shape(2); + DimVector hidden_shape(2); + DimVector dy_shape(2); + dy->shape().ToDimVector(&dy_shape); + const void* dgrad_buf = dy->dptr(); + + // step1: Get last layer's dbias. + const int64_t batch_size = dy->shape().At(0); + const T* ones = nullptr; + auto* cuda_device = dynamic_cast(ctx->stream()->device()); + if (cuda_device != nullptr) { + ones = static_cast(cuda_device->GetConstOnes(dy->data_type(), batch_size)); + } + DimVector ones_buf_shape(2); + ones_buf_shape.at(0) = 1; + ones_buf_shape.at(1) = batch_size; + user_op::Tensor* last_d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); + + InferMatmulCublasMNK(ones_buf_shape, dy_shape, + /*transpose_a=*/ep::primitive::BlasTransposeType::N, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, &cublas_n, + &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); + SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, + /*transpose_a=*/ep::primitive::BlasTransposeType::N, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr, + cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); + OF_CUBLAS_CHECK(cublasLtMatmul( + kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, dgrad_buf, + matmul_grad_cache->cublas_a_desc, ones, matmul_grad_cache->cublas_b_desc, &sp_beta, + last_d_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, last_d_bias->mut_dptr(), + matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), + kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); + + for (int idx = weight_num - 1; idx > -1; idx--) { + if (idx != 0) { + const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", idx); + const user_op::Tensor* aux = ctx->Tensor4ArgNameAndIndex("cublas_aux", idx - 1); + user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx - 1); + + weight->shape().ToDimVector(&weight_shape); + epilogue = CUBLASLT_EPILOGUE_DRELU_BGRAD; + InferMatmulCublasMNK(dy_shape, weight_shape, + /*transpose_a=*/ep::primitive::BlasTransposeType::N, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, + &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); + SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/true, + /*transpose_a=*/ep::primitive::BlasTransposeType::N, + 
/*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, + d_bias->mut_dptr(), aux->dptr(), cublas_m, cublas_n, cublas_k, cublas_lda, + cublas_ldb, cublas_ldc); + /* + a = dy, b = weight + cublas_a=weight, cublas_b=dy + */ + OF_CUDA_CHECK(cudaEventRecord(main_stream_event, cuda_stream->cuda_stream())); + OF_CUBLAS_CHECK(cublasLtMatmul( + cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, + weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, + matmul_grad_cache->cublas_b_desc, &sp_beta, dy_tmp_buf, + matmul_grad_cache->cublas_c_desc, dy_tmp_buf, matmul_grad_cache->cublas_c_desc, nullptr, + cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), + cuda_stream->cuda_stream())); + } else { + const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", 0); + weight->shape().ToDimVector(&weight_shape); + epilogue = CUBLASLT_EPILOGUE_DEFAULT; + InferMatmulCublasMNK(dy_shape, weight_shape, + /*transpose_a=*/ep::primitive::BlasTransposeType::N, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, + &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); + SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, + /*transpose_a=*/ep::primitive::BlasTransposeType::N, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, + nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); + /* + a = dy, b = weight + cublas_a=weight, cublas_b=dy + */ + OF_CUDA_CHECK(cudaEventRecord(main_stream_event, cuda_stream->cuda_stream())); + OF_CUBLAS_CHECK(cublasLtMatmul( + cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, + weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, + matmul_grad_cache->cublas_b_desc, &sp_beta, d_grad->mut_dptr(), + matmul_grad_cache->cublas_c_desc, d_grad->mut_dptr(), matmul_grad_cache->cublas_c_desc, + nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), + cuda_stream->cuda_stream())); + } + alpha = 1.0; + sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); + beta = 0.0; + sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); + + // currently only support 2D matmul. 
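Each iteration's dgrad GEMM (the idx != 0 branch above) requests the CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue: while computing the matmul it also zeroes every position where the forward ReLU was inactive, using the aux bit mask the forward pass stored in cublas_aux, and reduces the masked gradient into d_bias. Written as a standalone kernel purely for illustration (simplified packing assumed: one mask bit per element, eight per byte; the real cuBLASLt mask additionally pads its leading dimension, cf. AlignReluAuxLd in cublas_fused_mlp_util.cuh):

    __global__ void ApplyReluMaskSketch(float* dgrad, const unsigned char* mask, int64_t n) {
      int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
      if (i >= n) { return; }
      // Bit i records whether the forward ReLU output was positive.
      bool active = (mask[i / 8] >> (i % 8)) & 1;
      if (!active) { dgrad[i] = 0.0f; }  // dReLU: gradient passes only where active
    }

The bias gradient the same epilogue emits is the column sum of this masked dgrad. With dgrad for layer idx in hand, the matching weight gradient follows as a second GEMM against the previous activation: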
+ if (idx != 0) { + const user_op::Tensor* hidden = ctx->Tensor4ArgNameAndIndex("hidden", idx - 1); // here + user_op::Tensor* d_weights = ctx->Tensor4ArgNameAndIndex("d_weights", idx); + hidden->shape().ToDimVector(&hidden_shape); + + epilogue = CUBLASLT_EPILOGUE_DEFAULT; + + InferMatmulCublasMNK(dy_shape, hidden_shape, + /*transpose_a=*/ep::primitive::BlasTransposeType::T, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, + &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); + + SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, + /*transpose_a=*/ep::primitive::BlasTransposeType::T, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, + nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); + + OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); + + OF_CUBLAS_CHECK(cublasLtMatmul( + kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, + hidden->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, + matmul_grad_cache->cublas_b_desc, &sp_beta, d_weights->mut_dptr(), + matmul_grad_cache->cublas_c_desc, d_weights->mut_dptr(), + matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), + kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); + + // compute dy shape + dy_shape.at(1) = weight_shape.at(1); + // compute dybuf + dgrad_buf = dy_tmp_buf; + offset += GetCudaAlignedSize(dy_shape.at(0) * dy_shape.at(1) * sizeof(T)); + dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + offset); + } else { + user_op::Tensor* d_weights = ctx->Tensor4ArgNameAndIndex("d_weights", 0); + x->shape().ToDimVector(&hidden_shape); + epilogue = CUBLASLT_EPILOGUE_DEFAULT; + InferMatmulCublasMNK(dy_shape, hidden_shape, + /*transpose_a=*/ep::primitive::BlasTransposeType::T, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, + &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); + SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, + /*transpose_a=*/ep::primitive::BlasTransposeType::T, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, + nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); + OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); + OF_CUBLAS_CHECK(cublasLtMatmul( + kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, + x->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, + matmul_grad_cache->cublas_b_desc, &sp_beta, d_weights->mut_dptr(), + matmul_grad_cache->cublas_c_desc, d_weights->mut_dptr(), + matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), + kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); + OF_CUDA_CHECK(cudaEventRecord(async_weight_grad_event, kernel_state->cuda_stream())); + OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), async_weight_grad_event)); + } + } + }; + + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(dtype) \ + REGISTER_USER_KERNEL("cublas_fused_mlp_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("x", 0) == GetDataType::value)) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ + const int64_t weight_num = ctx->input_size("weights"); \ + const Shape& dy_shape = ctx->InputShape("dy", 0); \ + 
int64_t m = dy_shape.At(0); \ + int64_t k = dy_shape.At(1); \ + int64_t tmp_buffer_size = 0; \ + for (int idx = weight_num - 1; idx > 0; idx--) { \ + const Shape& weight_shape = ctx->InputShape("weights", idx); \ + k = weight_shape.At(1); \ + tmp_buffer_size += GetCudaAlignedSize(m * k * sizeof(dtype)); \ + } \ + return tmp_buffer_size; \ + }); + +REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(float) +REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(double) +REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(half) + +} // namespace + +} // namespace oneflow + +#endif // CUDA_VERSION >= 11060 diff --git a/oneflow/user/kernels/cublas_fused_mlp_util.cuh b/oneflow/user/kernels/cublas_fused_mlp_util.cuh index 3d4a57ad936..d5b27523736 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_util.cuh +++ b/oneflow/user/kernels/cublas_fused_mlp_util.cuh @@ -28,6 +28,7 @@ namespace oneflow { namespace { constexpr int32_t kAuxReluLdAlignRequirement = 128; +constexpr size_t kDefaultWorkspaceSize = 4 * 1024 * 1024; // 4M long AlignReluAuxLd(long aux_ld) { /* @@ -47,17 +48,20 @@ class CublasFusedMLPKernelCache final : public user_op::OpKernelCache { OF_CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&cublas_a_desc, CUDA_R_32F, 1, 1, 1)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&cublas_b_desc, CUDA_R_32F, 1, 1, 1)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&cublas_c_desc, CUDA_R_32F, 1, 1, 1)); + OF_CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&cublas_preference)); } ~CublasFusedMLPKernelCache() override { OF_CUBLAS_CHECK(cublasLtMatmulDescDestroy(operation_desc)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(cublas_a_desc)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(cublas_b_desc)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(cublas_c_desc)); + OF_CUBLAS_CHECK(cublasLtMatmulPreferenceDestroy(cublas_preference)); } cublasLtMatmulDesc_t operation_desc; cublasLtMatrixLayout_t cublas_a_desc; cublasLtMatrixLayout_t cublas_b_desc; cublasLtMatrixLayout_t cublas_c_desc; + cublasLtMatmulPreference_t cublas_preference; }; std::shared_ptr CreateCublasFusedMLPKernelCache() { @@ -168,21 +172,24 @@ void SetCublasMatrixLayout(cublasLtMatrixLayout_t layout_desc, cudaDataType_t cu void SetCublasEpilogue(const CublasFusedMLPKernelCache* matmul_cache, cublasLtEpilogue_t epilogue, const void* bias_ptr, const void* aux_ptr) { + // Set epilogue + OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute( + matmul_cache->operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); if (epilogue == CUBLASLT_EPILOGUE_RELU_BIAS || epilogue == CUBLASLT_EPILOGUE_BIAS || epilogue == CUBLASLT_EPILOGUE_RELU_AUX_BIAS || epilogue == CUBLASLT_EPILOGUE_DRELU_BGRAD || epilogue == CUBLASLT_EPILOGUE_BGRADB) { - // Set epilogue - OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute( - matmul_cache->operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); // Set bias ptr OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_cache->operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr))); } else { - Error::UnimplementedError() << "Unsupported Epilogue. "; + // unset + bias_ptr = nullptr; + OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_cache->operation_desc, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, + sizeof(bias_ptr))); } - // TODO: Support GELU_AUX_BIAS if (epilogue == CUBLASLT_EPILOGUE_RELU_AUX_BIAS || epilogue == CUBLASLT_EPILOGUE_DRELU_BGRAD) { // Set aux ptr for backward. 
OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(matmul_cache->operation_desc, @@ -208,12 +215,14 @@ void SetCublasAttr(const CublasFusedMLPKernelCache* matmul_grad_cache, matmul_grad_cache->operation_desc, CUBLASLT_MATMUL_DESC_COMPUTE_TYPE, &cublas_compute_dtype, sizeof(cublas_compute_dtype))); - // For best performance when using the bias vector, specify beta == 0 and - // CUBLASLT_POINTER_MODE_HOST.(from - // https://docs.nvidia.com/cuda/cublas/index.html#cublasLtPointerMode_t) - cublasLtPointerMode_t mode = CUBLASLT_POINTER_MODE_HOST; - OF_CUBLAS_CHECK(cublasLtMatmulDescSetAttribute( - matmul_grad_cache->operation_desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &mode, sizeof(mode))); + OF_CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute( + matmul_grad_cache->cublas_preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &kDefaultWorkspaceSize, sizeof(kDefaultWorkspaceSize))); + + uint32_t pointer_mode = CUBLASLT_POINTER_MODE_MASK_HOST; + OF_CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(matmul_grad_cache->cublas_preference, + CUBLASLT_MATMUL_PREF_POINTER_MODE_MASK, + &pointer_mode, sizeof(pointer_mode))); // transpose_a = False, transpose_b = True. But in cublas is reversed. const cublasOperation_t cublas_trans_a = diff --git a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp new file mode 100644 index 00000000000..5c6f0bb07ba --- /dev/null +++ b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp @@ -0,0 +1,103 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+#include "oneflow/core/common/just.h"
+#include "oneflow/core/common/maybe.h"
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/infer_util.h"
+#include "oneflow/core/framework/op_generated.h"
+
+namespace oneflow {
+
+namespace {
+
+Maybe<void> InferTensorDesc4FusedMatmulBackward(user_op::InferContext* ctx) {
+  const int64_t weight_num = ctx->input_size("weights");
+  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
+  for (int idx = weight_num - 1; idx > -1; idx--) {
+    const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc("weights", idx);
+    *ctx->OutputShape("d_weights", idx) = weight_desc.shape();
+    *ctx->OutputShape("d_biases", idx) = Shape({weight_desc.shape().At(0)});
+  }
+  *ctx->OutputShape("d_grad", 0) = x_desc.shape();
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> InferDataType4MatmulBackward(user_op::InferContext* ctx) {
+  const int64_t weight_num = ctx->input_size("weights");
+  const int64_t dweight_num = ctx->output_size("d_weights");
+  CHECK_EQ(weight_num, dweight_num) << "The number of weights and d_weights should be equal. ";
+  const int64_t dbias_size = ctx->output_size("d_biases");
+  CHECK_EQ(weight_num, dbias_size) << "The number of d_biases should be equal to weight_num. "
+                                      "Because last layer's bias_grad is computed by ReduceSum. ";
+  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0);
+  for (int idx = weight_num - 1; idx > -1; idx--) {
+    *ctx->OutputDType("d_weights", idx) = dy_desc.data_type();
+    *ctx->OutputDType("d_biases", idx) = dy_desc.data_type();
+  }
+  *ctx->OutputDType("d_grad", 0) = dy_desc.data_type();
+  return Maybe<void>::Ok();
+}
+
+}  // namespace
+
+/* static */ Maybe<void> CublasFusedMLPGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  return InferTensorDesc4FusedMatmulBackward(ctx);
+}
+
+/*static*/ Maybe<void> CublasFusedMLPGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> CublasFusedMLPGradOp::GetSbp(user_op::SbpContext* ctx) {
+  auto builder = ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0);
+  builder.Split(user_op::OpArg("dy", 0), 0);
+  for (int i = 0; i < ctx->user_op_conf().input_size("weights"); ++i) {
+    builder.Broadcast(user_op::OpArg("weights", i));
+  }
+  for (int i = 0; i < ctx->user_op_conf().input_size("cublas_aux"); ++i) {
+    builder.Split(user_op::OpArg("cublas_aux", i), 0);
+  }
+  for (int i = 0; i < ctx->user_op_conf().input_size("hidden"); ++i) {
+    builder.Split(user_op::OpArg("hidden", i), 0);
+  }
+
+  builder.Split(user_op::OpArg("d_grad", 0), 0);
+  for (int i = 0; i < ctx->user_op_conf().input_size("d_biases"); ++i) {
+    builder.PartialSum(user_op::OpArg("d_biases", i));
+  }
+  for (int i = 0; i < ctx->user_op_conf().input_size("d_weights"); ++i) {
+    builder.PartialSum(user_op::OpArg("d_weights", i));
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> CublasFusedMLPGradOp::InferDataType(user_op::InferContext* ctx) {
+  return InferDataType4MatmulBackward(ctx);
+}
+
+}  // namespace oneflow

From 110a368a5a3ba535457d92dce1182d9f3ab619be Mon Sep 17 00:00:00 2001
From: MARD1NO <359521840@qq.com>
Date: Wed, 22 Jun 2022 20:16:13 +0800
Subject: [PATCH 02/28] support lazy backward

---
 oneflow/user/ops/cublas_fused_mlp_grad_op.cpp |  12 --
 oneflow/user/ops/cublas_fused_mlp_op.cpp      | 204 ++++++++++--------
 2 files changed, 120 insertions(+), 96 deletions(-)

diff --git a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp
index 5c6f0bb07ba..db5a7f8d671 100644
---
a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp @@ -13,18 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ #include "oneflow/core/common/just.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/framework.h" diff --git a/oneflow/user/ops/cublas_fused_mlp_op.cpp b/oneflow/user/ops/cublas_fused_mlp_op.cpp index 65619b14d0c..62ee7076501 100644 --- a/oneflow/user/ops/cublas_fused_mlp_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_op.cpp @@ -154,95 +154,131 @@ REGISTER_USER_OP_GRAD("cublas_fused_mlp") } else { last_bias_grad = op.GetGradTensorWithOpOutput("out", 0); } - - // step2: use reduce_sum to get last layer's bias grad. - // TODO: Currently Only support 2d fused_matmul. - // so here we hard encode bias reduce axis as 0. - std::vector reduce_axes_vec{0}; - user_op::UserOpConfWrapperBuilder bias_grad_builder(op.op_name() + "_bias_grad"); - user_op::UserOpConfWrapper bias_grad_op = bias_grad_builder.Op("reduce_sum") - .Input("input_tensor", last_bias_grad) - .Output("output_tensor") - .Attr("axis", reduce_axes_vec) - .Attr("keepdims", false) - .Build(); - AddOp(bias_grad_op); - if (op.NeedGenGradTensor4OpInput("biases", weight_num - 1)) { - op.BindGradTensorWithOpInput(bias_grad_op.output("output_tensor", 0), "biases", - weight_num - 1); - } std::string cublas_dy = last_bias_grad; - for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) { - user_op::UserOpConfWrapperBuilder cublas_bias_add_relu_matmul_grad_builder( - op.op_name() + "_cublas_bias_add_relu_matmul_grad_" + std::to_string(hidden_layer_idx)); - user_op::UserOpConfWrapper cublas_bias_add_relu_matmul_grad_op = - cublas_bias_add_relu_matmul_grad_builder.Op("cublas_bias_add_relu_matmul_grad") - .Input("dy", cublas_dy) - .Input("weight", op.input("weights", hidden_layer_idx)) - .Input("aux", op.output("cublas_aux", hidden_layer_idx - 1)) - .Attr("alpha", 1.0) - .Output("d_grad") - .Output("d_bias") - .Build(); - AddOp(cublas_bias_add_relu_matmul_grad_op); - if (op.NeedGenGradTensor4OpInput("biases", hidden_layer_idx - 1)) { - op.BindGradTensorWithOpInput(cublas_bias_add_relu_matmul_grad_op.output("d_bias", 0), - "biases", - hidden_layer_idx - 1); // previous layers bias grad + + if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD", false)) { + printf("Here use fully fusedmlp grad \n"); + // Use Fully Fused MLP Backward. 
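Both the eager Apply above and this lazy graph rewrite are gated on the same environment switch: with ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD unset, the backward stays the familiar chain of reduce_sum, cublas_bias_add_relu_matmul_grad and matmul ops, while setting it collapses the chain into one variadic cublas_fused_mlp_grad op, which is what lets the kernel push weight gradients onto a side stream. A rough standalone equivalent of the ParseBooleanFromEnv gate (the real helper lives in oneflow/core/common and accepts more spellings; this sketch assumes only "1" enables it):

    #include <cstdlib>
    #include <cstring>

    bool ParseBoolEnvSketch(const char* name, bool default_value) {
      const char* v = std::getenv(name);
      if (v == nullptr || *v == '\0') { return default_value; }
      return std::strcmp(v, "1") == 0;  // e.g. ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD=1
    }

When the switch is on, the builder below wires dy, x and the per-layer weights, aux masks and hiddens into that single op: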
+ user_op::UserOpConfWrapperBuilder fused_mlp_grad_builder(op.op_name() + "_fused_mlp_grad"); + fused_mlp_grad_builder.Op("cublas_fused_mlp_grad") + .Input("dy", cublas_dy) + .Input("x", op.input("x", 0)) + .Output("d_grad") + .Output("d_biases", weight_num) + .Output("d_weights", weight_num); + + for (int32_t hidden_layer_idx = 0; hidden_layer_idx < weight_num; hidden_layer_idx++) { + fused_mlp_grad_builder.Input("weights", op.input("weights", hidden_layer_idx)) + .Input("cublas_aux", op.output("cublas_aux", hidden_layer_idx)) + .Input("hidden", op.output("hidden", hidden_layer_idx)); } + user_op::UserOpConfWrapper fused_mlp_grad_op = fused_mlp_grad_builder.Build(); - user_op::UserOpConfWrapperBuilder matmul_weight_grad_builder( - op.op_name() + "_matmul_a_grad_" + std::to_string(hidden_layer_idx)); - user_op::UserOpConfWrapper matmul_weight_grad_op = - matmul_weight_grad_builder.Op("matmul") - .Input("a", cublas_dy) - .Input("b", op.output("hidden", hidden_layer_idx - 1)) - .Output("out") - .Attr("transpose_a", true) - .Attr("transpose_b", false) - .Attr("alpha", 1.0) - .Build(); - AddOp(matmul_weight_grad_op); - if (op.NeedGenGradTensor4OpInput("weights", hidden_layer_idx)) { - op.BindGradTensorWithOpInput(matmul_weight_grad_op.output("out", 0), "weights", - hidden_layer_idx); + AddOp(fused_mlp_grad_op); + + for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > -1; hidden_layer_idx--) { + if (op.NeedGenGradTensor4OpInput("biases", hidden_layer_idx)) { + op.BindGradTensorWithOpInput(fused_mlp_grad_op.output("d_biases", hidden_layer_idx), + "biases", hidden_layer_idx); + } + if (op.NeedGenGradTensor4OpInput("weights", hidden_layer_idx)) { + op.BindGradTensorWithOpInput(fused_mlp_grad_op.output("d_weights", hidden_layer_idx), + "weights", hidden_layer_idx); + } } - // update dgrad - cublas_dy = cublas_bias_add_relu_matmul_grad_op.output("d_grad", 0); - } + if (op.NeedGenGradTensor4OpInput("x", 0)) { + op.BindGradTensorWithOpInput(fused_mlp_grad_op.output("d_grad", 0), "x", 0); + } + } else { + // step2: use reduce_sum to get last layer's bias grad. + // TODO: Currently Only support 2d fused_matmul. + // so here we hard encode bias reduce axis as 0. 
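With a 2D dy of shape [batch, out_features] and a bias that was broadcast over the batch in the forward pass, the bias gradient is the sum of dy over axis 0, which is exactly the reduce_sum emitted below. As a plain reference (hypothetical shapes, illustration only):

    #include <vector>

    // reduce_sum(dy, axis=0): dbias[n] accumulates column n of dy.
    std::vector<float> BiasGradRef(const std::vector<float>& dy, int batch, int out_features) {
      std::vector<float> dbias(out_features, 0.0f);
      for (int b = 0; b < batch; ++b) {
        for (int n = 0; n < out_features; ++n) { dbias[n] += dy[b * out_features + n]; }
      }
      return dbias;
    }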
+ std::vector reduce_axes_vec{0}; + user_op::UserOpConfWrapperBuilder bias_grad_builder(op.op_name() + "_bias_grad"); + user_op::UserOpConfWrapper bias_grad_op = bias_grad_builder.Op("reduce_sum") + .Input("input_tensor", last_bias_grad) + .Output("output_tensor") + .Attr("axis", reduce_axes_vec) + .Attr("keepdims", false) + .Build(); + AddOp(bias_grad_op); + if (op.NeedGenGradTensor4OpInput("biases", weight_num - 1)) { + op.BindGradTensorWithOpInput(bias_grad_op.output("output_tensor", 0), "biases", + weight_num - 1); + } + for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) { + user_op::UserOpConfWrapperBuilder cublas_bias_add_relu_matmul_grad_builder( + op.op_name() + "_cublas_bias_add_relu_matmul_grad_" + + std::to_string(hidden_layer_idx)); + user_op::UserOpConfWrapper cublas_bias_add_relu_matmul_grad_op = + cublas_bias_add_relu_matmul_grad_builder.Op("cublas_bias_add_relu_matmul_grad") + .Input("dy", cublas_dy) + .Input("weight", op.input("weights", hidden_layer_idx)) + .Input("aux", op.output("cublas_aux", hidden_layer_idx - 1)) + .Attr("alpha", 1.0) + .Output("d_grad") + .Output("d_bias") + .Build(); + AddOp(cublas_bias_add_relu_matmul_grad_op); + if (op.NeedGenGradTensor4OpInput("biases", hidden_layer_idx - 1)) { + op.BindGradTensorWithOpInput(cublas_bias_add_relu_matmul_grad_op.output("d_bias", 0), + "biases", + hidden_layer_idx - 1); // previous layers bias grad + } - // For the first layer, we need to use 2 matmul to get grads. - std::string last_dy; - if (weight_num != 1) { last_dy = cublas_dy; } - // dx: - user_op::UserOpConfWrapperBuilder matmul_input_grad_builder(op.op_name() - + "_matmul_input_grad"); - user_op::UserOpConfWrapper matmul_input_grad_op = matmul_input_grad_builder.Op("matmul") - .Input("a", last_dy) - .Input("b", op.input("weights", 0)) - .Output("out") - .Attr("transpose_a", false) - .Attr("transpose_b", false) - .Attr("alpha", 1.0) - .Build(); - AddOp(matmul_input_grad_op); - if (op.NeedGenGradTensor4OpInput("x", 0)) { - op.BindGradTensorWithOpInput(matmul_input_grad_op.output("out", 0), "x", 0); - } - // dw: - user_op::UserOpConfWrapperBuilder matmul_weight_grad_builder(op.op_name() - + "_matmul_input_weight_grad"); - user_op::UserOpConfWrapper matmul_weight_grad_op = matmul_weight_grad_builder.Op("matmul") - .Input("a", last_dy) - .Input("b", op.input("x", 0)) - .Output("out") - .Attr("transpose_a", true) - .Attr("transpose_b", false) - .Attr("alpha", 1.0) - .Build(); - AddOp(matmul_weight_grad_op); - if (op.NeedGenGradTensor4OpInput("weights", 0)) { - op.BindGradTensorWithOpInput(matmul_weight_grad_op.output("out", 0), "weights", 0); + user_op::UserOpConfWrapperBuilder matmul_weight_grad_builder( + op.op_name() + "_matmul_a_grad_" + std::to_string(hidden_layer_idx)); + user_op::UserOpConfWrapper matmul_weight_grad_op = + matmul_weight_grad_builder.Op("matmul") + .Input("a", cublas_dy) + .Input("b", op.output("hidden", hidden_layer_idx - 1)) + .Output("out") + .Attr("transpose_a", true) + .Attr("transpose_b", false) + .Attr("alpha", 1.0) + .Build(); + AddOp(matmul_weight_grad_op); + if (op.NeedGenGradTensor4OpInput("weights", hidden_layer_idx)) { + op.BindGradTensorWithOpInput(matmul_weight_grad_op.output("out", 0), "weights", + hidden_layer_idx); + } + // update dgrad + cublas_dy = cublas_bias_add_relu_matmul_grad_op.output("d_grad", 0); + } + + // For the first layer, we need to use 2 matmul to get grads. 
+ std::string last_dy; + if (weight_num != 1) { last_dy = cublas_dy; } + // dx: + user_op::UserOpConfWrapperBuilder matmul_input_grad_builder(op.op_name() + + "_matmul_input_grad"); + user_op::UserOpConfWrapper matmul_input_grad_op = matmul_input_grad_builder.Op("matmul") + .Input("a", last_dy) + .Input("b", op.input("weights", 0)) + .Output("out") + .Attr("transpose_a", false) + .Attr("transpose_b", false) + .Attr("alpha", 1.0) + .Build(); + AddOp(matmul_input_grad_op); + if (op.NeedGenGradTensor4OpInput("x", 0)) { + op.BindGradTensorWithOpInput(matmul_input_grad_op.output("out", 0), "x", 0); + } + // dw: + user_op::UserOpConfWrapperBuilder matmul_weight_grad_builder(op.op_name() + + "_matmul_input_weight_grad"); + user_op::UserOpConfWrapper matmul_weight_grad_op = matmul_weight_grad_builder.Op("matmul") + .Input("a", last_dy) + .Input("b", op.input("x", 0)) + .Output("out") + .Attr("transpose_a", true) + .Attr("transpose_b", false) + .Attr("alpha", 1.0) + .Build(); + AddOp(matmul_weight_grad_op); + if (op.NeedGenGradTensor4OpInput("weights", 0)) { + op.BindGradTensorWithOpInput(matmul_weight_grad_op.output("out", 0), "weights", 0); + } } return Maybe::Ok(); From e06c240eeb86f24c58cc5db357bb050ba1397462 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 23 Jun 2022 09:15:28 +0800 Subject: [PATCH 03/28] fix output size --- oneflow/user/ops/cublas_fused_mlp_grad_op.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp index db5a7f8d671..4241a199770 100644 --- a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp @@ -75,10 +75,10 @@ Maybe InferDataType4MatmulBackward(user_op::InferContext* ctx) { } builder.Split(user_op::OpArg("d_grad", 0), 0); - for (int i = 0; i < ctx->user_op_conf().input_size("d_biases"); ++i) { + for (int i = 0; i < ctx->user_op_conf().output_size("d_biases"); ++i) { builder.PartialSum(user_op::OpArg("d_biases", i)); } - for (int i = 0; i < ctx->user_op_conf().input_size("d_weights"); ++i) { + for (int i = 0; i < ctx->user_op_conf().output_size("d_weights"); ++i) { builder.PartialSum(user_op::OpArg("d_weights", i)); } return Maybe::Ok(); From 9158bcafb600e5851071b28d5b327e8256262841 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 23 Jun 2022 09:23:08 +0800 Subject: [PATCH 04/28] add fallback to tmp_buf logic when ones buffer is not enough --- oneflow/user/kernels/cublas_fused_mlp_grad.cu | 44 +++++++++++-------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad.cu b/oneflow/user/kernels/cublas_fused_mlp_grad.cu index d7ba3f99168..3971677345b 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad.cu @@ -113,10 +113,13 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: // step1: Get last layer's dbias. 
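Inside the fused kernel, this axis-0 reduction is phrased as a GEMM, dbias(1 x N) = ones(1 x B) * dy(B x N), so that cuBLASLt can execute it like every other step. The kernel normally borrows a cached ones vector from the device (GetConstOnes); this hunk adds a fallback that carves the vector out of tmp_buffer when that cache is unavailable, with the InferTmpSizeFn below growing the buffer to match. Note the fallback region must actually hold ones before the GEMM consumes it; a minimal fill kernel of the kind required, as an illustration (name hypothetical):

    template<typename T>
    __global__ void FillOnesSketch(T* dst, int64_t n) {
      int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
      if (i < n) { dst[i] = static_cast<T>(1); }  // becomes the 1 x batch_size ones vector
    }

The fallback itself is the else-branch added below: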
const int64_t batch_size = dy->shape().At(0); - const T* ones = nullptr; + const void* ones = nullptr; auto* cuda_device = dynamic_cast(ctx->stream()->device()); if (cuda_device != nullptr) { - ones = static_cast(cuda_device->GetConstOnes(dy->data_type(), batch_size)); + ones = cuda_device->GetConstOnes(dy->data_type(), batch_size); + } else { + ones = dy_tmp_buf; + offset += batch_size; } DimVector ones_buf_shape(2); ones_buf_shape.at(0) = 1; @@ -260,23 +263,26 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -#define REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(dtype) \ - REGISTER_USER_KERNEL("cublas_fused_mlp_grad") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ - && (user_op::HobDataType("x", 0) == GetDataType::value)) \ - .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ - const int64_t weight_num = ctx->input_size("weights"); \ - const Shape& dy_shape = ctx->InputShape("dy", 0); \ - int64_t m = dy_shape.At(0); \ - int64_t k = dy_shape.At(1); \ - int64_t tmp_buffer_size = 0; \ - for (int idx = weight_num - 1; idx > 0; idx--) { \ - const Shape& weight_shape = ctx->InputShape("weights", idx); \ - k = weight_shape.At(1); \ - tmp_buffer_size += GetCudaAlignedSize(m * k * sizeof(dtype)); \ - } \ - return tmp_buffer_size; \ +#define REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(dtype) \ + REGISTER_USER_KERNEL("cublas_fused_mlp_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("x", 0) == GetDataType::value)) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ + const int64_t weight_num = ctx->input_size("weights"); \ + const Shape& dy_shape = ctx->InputShape("dy", 0); \ + int64_t m = dy_shape.At(0); \ + int64_t k = dy_shape.At(1); \ + int64_t tmp_buffer_size = 0; \ + if (m > 1024 * 1024) { \ + tmp_buffer_size += GetCudaAlignedSize(m * sizeof(dtype)); /*For last layer's bias grad*/ \ + } \ + for (int idx = weight_num - 1; idx > 0; idx--) { \ + const Shape& weight_shape = ctx->InputShape("weights", idx); \ + k = weight_shape.At(1); \ + tmp_buffer_size += GetCudaAlignedSize(m * k * sizeof(dtype)); \ + } \ + return tmp_buffer_size; \ }); REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(float) From 2395e6ef1d2b9c41d61f90659083f79130265b75 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 23 Jun 2022 09:36:59 +0800 Subject: [PATCH 05/28] build sbp --- oneflow/user/ops/cublas_fused_mlp_grad_op.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp index 4241a199770..5e582836760 100644 --- a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp @@ -81,6 +81,7 @@ Maybe InferDataType4MatmulBackward(user_op::InferContext* ctx) { for (int i = 0; i < ctx->user_op_conf().output_size("d_weights"); ++i) { builder.PartialSum(user_op::OpArg("d_weights", i)); } + builder.Build(); return Maybe::Ok(); } From 42fd871690bca48939855e3c6cbf62cdcf3c4dba Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 23 Jun 2022 11:38:33 +0800 Subject: [PATCH 06/28] overlap allreduce --- oneflow/user/kernels/cublas_fused_mlp_grad.cu | 114 +++++++++++++++--- oneflow/user/ops/cublas_fused_mlp_grad_op.cpp | 24 +++- 2 files changed, 118 insertions(+), 20 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad.cu 
b/oneflow/user/kernels/cublas_fused_mlp_grad.cu
index 3971677345b..ccbf824798f 100644
--- a/oneflow/user/kernels/cublas_fused_mlp_grad.cu
+++ b/oneflow/user/kernels/cublas_fused_mlp_grad.cu
@@ -13,6 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include "oneflow/core/device/nccl_util.h"
+#include "oneflow/core/job/eager_nccl_comm_manager.h"
+#include "oneflow/core/job/parallel_desc.h"
 #include "oneflow/core/kernel/cuda_graph_support.h"
 #include "oneflow/user/kernels/cublas_fused_mlp_util.cuh"
 // CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link.
@@ -24,8 +27,10 @@ namespace {
 
 class MatmulGradKernelState final : public user_op::OpKernelState {
  public:
-  MatmulGradKernelState() {
+  MatmulGradKernelState(user_op::KernelInitContext* ctx)
+      : parallel_desc_(ctx->parallel_desc()), stream_name_(EagerNcclCommMgr::kDefaultStreamName) {
     OF_CUDA_CHECK(cudaStreamCreate(&cuda_stream_));
+    OF_CUDA_CHECK(cudaStreamCreate(&allreduce_stream_));
     OF_CUBLAS_CHECK(cublasLtCreate(&cublas_lt_handle_));
     OF_CUDA_CHECK(cudaMalloc(&workspace_, 8 * 1024 * 1024));
   }
@@ -33,17 +38,48 @@ class MatmulGradKernelState final : public user_op::OpKernelState {
     OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_));
     OF_CUBLAS_CHECK(cublasLtDestroy(cublas_lt_handle_));
     OF_CUDA_CHECK(cudaStreamDestroy(cuda_stream_));
+    OF_CUDA_CHECK(cudaStreamSynchronize(allreduce_stream_));
+    OF_CUDA_CHECK(cudaStreamDestroy(allreduce_stream_));
     OF_CUDA_CHECK(cudaFree(workspace_));
   }
   cudaStream_t cuda_stream() const { return cuda_stream_; }
+  cudaStream_t allreduce_stream() const { return allreduce_stream_; }
   cublasLtHandle_t cublas_lt_handle() const { return cublas_lt_handle_; }
   size_t cublas_workspace_size() const { return 8 * 1024 * 1024; }
   void* cublas_workspace() const { return workspace_; }
+  ncclComm_t comm() { return GetOrCreate().comm; }
 
  private:
+  struct Comm {
+    Comm(ncclComm_t comm) : comm(comm) {}
+    ncclComm_t comm;
+  };
+
+  const Comm& GetOrCreate() {
+    if (!comm_) { Init(); }
+    return *comm_;
+  }
+
+  void Init() {
+    std::set<std::pair<int64_t, int64_t>> device_set;
+    for (int64_t parallel_id = 0; parallel_id < parallel_desc_.parallel_num(); ++parallel_id) {
+      int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id));
+      int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id));
+      device_set.emplace(std::make_pair(machine_id, device_id));
+    }
+    EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Global<EagerNcclCommMgr>::Get());
+    ncclComm_t comm;
+    comm = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_);
+    comm_.reset(new Comm(comm));
+  }
+
   cudaStream_t cuda_stream_{};
+  cudaStream_t allreduce_stream_{};
   cublasLtHandle_t cublas_lt_handle_{};
   void* workspace_{};
+  std::unique_ptr<Comm> comm_;
+  ParallelDesc parallel_desc_;
+  std::string stream_name_;
 };
 
 template<typename T>
@@ -52,10 +88,16 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
   CublasFusedMLPGradKernel() {
     OF_CUDA_CHECK(cudaEventCreate(&main_stream_event));
     OF_CUDA_CHECK(cudaEventCreate(&async_weight_grad_event));
+    OF_CUDA_CHECK(cudaEventCreate(&dweight_event));
+    OF_CUDA_CHECK(cudaEventCreate(&dbias_event));
+    OF_CUDA_CHECK(cudaEventCreate(&allreduce_event));
   };
   ~CublasFusedMLPGradKernel() override {
     OF_CUDA_CHECK(cudaEventDestroy(main_stream_event));
     OF_CUDA_CHECK(cudaEventDestroy(async_weight_grad_event));
+    OF_CUDA_CHECK(cudaEventDestroy(dweight_event));
+
OF_CUDA_CHECK(cudaEventDestroy(dbias_event)); + OF_CUDA_CHECK(cudaEventDestroy(allreduce_event)); }; std::shared_ptr InitOpKernelCache( @@ -65,12 +107,15 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { - return std::make_shared(); + return std::make_shared(ctx); } private: cudaEvent_t main_stream_event; cudaEvent_t async_weight_grad_event; + cudaEvent_t dweight_event; + cudaEvent_t dbias_event; + cudaEvent_t allreduce_event; using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, @@ -78,6 +123,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + auto* kernel_state = dynamic_cast(state); + ncclComm_t comm = kernel_state->comm(); void* dy_tmp_buf = tmp_buffer->mut_dptr(); size_t offset = 0; @@ -89,8 +136,6 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: CHECK_NOTNULL(dynamic_cast(cache)); auto* cuda_stream = ctx->stream()->As(); - auto* kernel_state = dynamic_cast(state); - const DataType data_type = dy->data_type(); const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type); const cudaDataType_t cuda_data_type = GetCudaDataType(data_type); @@ -124,7 +169,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: DimVector ones_buf_shape(2); ones_buf_shape.at(0) = 1; ones_buf_shape.at(1) = batch_size; - user_op::Tensor* last_d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); + user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); InferMatmulCublasMNK(ones_buf_shape, dy_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::N, @@ -137,15 +182,24 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: OF_CUBLAS_CHECK(cublasLtMatmul( kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, dgrad_buf, matmul_grad_cache->cublas_a_desc, ones, matmul_grad_cache->cublas_b_desc, &sp_beta, - last_d_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, last_d_bias->mut_dptr(), + d_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, d_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); + // allreduce first Dbias. 
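The overlap machinery added here follows a standard two-stream pattern: the bias-gradient GEMM is enqueued on the kernel state's compute stream, an event records its completion, the dedicated allreduce stream waits on that event, and ncclAllReduce runs there, so the communication for one layer proceeds concurrently with the next layer's GEMMs. The pattern in isolation, assuming an already-initialized communicator, streams and event (illustrative fragment, not the kernel itself):

    #include <cuda_runtime.h>
    #include <nccl.h>

    void OverlapBiasAllReduce(float* d_bias, size_t elem_cnt, ncclComm_t comm,
                              cudaStream_t compute_stream, cudaStream_t comm_stream,
                              cudaEvent_t dbias_ready) {
      // d_bias is valid on compute_stream after this point.
      cudaEventRecord(dbias_ready, compute_stream);
      // Make only the comm stream wait; the CPU and the compute stream keep going.
      cudaStreamWaitEvent(comm_stream, dbias_ready, 0);
      ncclAllReduce(d_bias, d_bias, elem_cnt, ncclFloat, ncclSum, comm, comm_stream);
    }

In the kernel this is gated on ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE: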
+ if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + OF_CUDA_CHECK(cudaEventRecord(dbias_event, kernel_state->cuda_stream())); + OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dbias_event)); + OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), + d_bias->shape().elem_cnt(), GetNcclDataType(d_bias->data_type()), + ncclRedOp_t::ncclSum, comm, kernel_state->allreduce_stream())); + } + for (int idx = weight_num - 1; idx > -1; idx--) { if (idx != 0) { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", idx); const user_op::Tensor* aux = ctx->Tensor4ArgNameAndIndex("cublas_aux", idx - 1); - user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx - 1); + d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx - 1); weight->shape().ToDimVector(&weight_shape); epilogue = CUBLASLT_EPILOGUE_DRELU_BGRAD; @@ -170,6 +224,9 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: matmul_grad_cache->cublas_c_desc, dy_tmp_buf, matmul_grad_cache->cublas_c_desc, nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream())); + if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + OF_CUDA_CHECK(cudaEventRecord(dbias_event, kernel_state->cuda_stream())); + } } else { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", 0); weight->shape().ToDimVector(&weight_shape); @@ -201,9 +258,9 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); // currently only support 2D matmul. + user_op::Tensor* d_weight = ctx->Tensor4ArgNameAndIndex("d_weights", idx); if (idx != 0) { const user_op::Tensor* hidden = ctx->Tensor4ArgNameAndIndex("hidden", idx - 1); // here - user_op::Tensor* d_weights = ctx->Tensor4ArgNameAndIndex("d_weights", idx); hidden->shape().ToDimVector(&hidden_shape); epilogue = CUBLASLT_EPILOGUE_DEFAULT; @@ -223,11 +280,14 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: OF_CUBLAS_CHECK(cublasLtMatmul( kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, hidden->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, - matmul_grad_cache->cublas_b_desc, &sp_beta, d_weights->mut_dptr(), - matmul_grad_cache->cublas_c_desc, d_weights->mut_dptr(), + matmul_grad_cache->cublas_b_desc, &sp_beta, d_weight->mut_dptr(), + matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); + if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + OF_CUDA_CHECK(cudaEventRecord(dweight_event, kernel_state->cuda_stream())); + } // compute dy shape dy_shape.at(1) = weight_shape.at(1); // compute dybuf @@ -235,7 +295,6 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: offset += GetCudaAlignedSize(dy_shape.at(0) * dy_shape.at(1) * sizeof(T)); dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + offset); } else { - user_op::Tensor* d_weights = ctx->Tensor4ArgNameAndIndex("d_weights", 0); x->shape().ToDimVector(&hidden_shape); epilogue = CUBLASLT_EPILOGUE_DEFAULT; InferMatmulCublasMNK(dy_shape, hidden_shape, @@ -250,13 +309,40 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: 
OF_CUBLAS_CHECK(cublasLtMatmul( kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, x->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, - matmul_grad_cache->cublas_b_desc, &sp_beta, d_weights->mut_dptr(), - matmul_grad_cache->cublas_c_desc, d_weights->mut_dptr(), + matmul_grad_cache->cublas_b_desc, &sp_beta, d_weight->mut_dptr(), + matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); OF_CUDA_CHECK(cudaEventRecord(async_weight_grad_event, kernel_state->cuda_stream())); - OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), async_weight_grad_event)); + if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + OF_CUDA_CHECK( + cudaStreamWaitEvent(kernel_state->allreduce_stream(), async_weight_grad_event)); + OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), + d_weight->shape().elem_cnt(), + GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, + comm, kernel_state->allreduce_stream())); + } } + + // Do Allreduce for d_bias and d_weight. + if (idx > 0) { + // Here we wait wgrad and dgrad event, and set a ncclGroup to Allreduce d_bias and d_weight. + OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dbias_event)); + OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); + OF_NCCL_CHECK(ncclGroupStart()); + OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), + d_bias->shape().elem_cnt(), + GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum, + comm, kernel_state->allreduce_stream())); + OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), + d_weight->shape().elem_cnt(), + GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, + comm, kernel_state->allreduce_stream())); + OF_NCCL_CHECK(ncclGroupEnd()); + } + } + if (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), async_weight_grad_event)); } }; diff --git a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp index 5e582836760..9bf278d059b 100644 --- a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp @@ -75,13 +75,25 @@ Maybe InferDataType4MatmulBackward(user_op::InferContext* ctx) { } builder.Split(user_op::OpArg("d_grad", 0), 0); - for (int i = 0; i < ctx->user_op_conf().output_size("d_biases"); ++i) { - builder.PartialSum(user_op::OpArg("d_biases", i)); + if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + // FusedMLPGradKernel do allreduce for dbias and dweight, so here convert from PartialSum to + // Broadcast. 
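+    // Worked example with 2 ranks and batch-split (S(0)) inputs: rank0 holds
+    //   dW_0 = x_0^T * dy_0 and rank1 holds dW_1 = x_1^T * dy_1.
+    // Under PartialSum each rank keeps only its partial dW_i and the framework
+    // inserts the reduction later; with the in-kernel ncclAllReduce every rank
+    // already holds dW_0 + dW_1, which is exactly the Broadcast SBP.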
+ for (int i = 0; i < ctx->user_op_conf().output_size("d_biases"); ++i) { + builder.Broadcast(user_op::OpArg("d_biases", i)); + } + for (int i = 0; i < ctx->user_op_conf().output_size("d_weights"); ++i) { + builder.Broadcast(user_op::OpArg("d_weights", i)); + } + } else { + for (int i = 0; i < ctx->user_op_conf().output_size("d_biases"); ++i) { + builder.PartialSum(user_op::OpArg("d_biases", i)); + } + for (int i = 0; i < ctx->user_op_conf().output_size("d_weights"); ++i) { + builder.PartialSum(user_op::OpArg("d_weights", i)); + } } - for (int i = 0; i < ctx->user_op_conf().output_size("d_weights"); ++i) { - builder.PartialSum(user_op::OpArg("d_weights", i)); - } - builder.Build(); + + builder.Build(); return Maybe::Ok(); } From f8ec2b97c108f2c048ecf63d5fa70954e0d736f0 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 23 Jun 2022 16:23:01 +0800 Subject: [PATCH 07/28] fix overlap order --- oneflow/user/kernels/cublas_fused_mlp_grad.cu | 121 ++++++++---------- 1 file changed, 55 insertions(+), 66 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad.cu b/oneflow/user/kernels/cublas_fused_mlp_grad.cu index ccbf824798f..2c0efd7864e 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad.cu @@ -89,14 +89,12 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: OF_CUDA_CHECK(cudaEventCreate(&main_stream_event)); OF_CUDA_CHECK(cudaEventCreate(&async_weight_grad_event)); OF_CUDA_CHECK(cudaEventCreate(&dweight_event)); - OF_CUDA_CHECK(cudaEventCreate(&dbias_event)); OF_CUDA_CHECK(cudaEventCreate(&allreduce_event)); }; ~CublasFusedMLPGradKernel() override { OF_CUDA_CHECK(cudaEventDestroy(main_stream_event)); OF_CUDA_CHECK(cudaEventDestroy(async_weight_grad_event)); OF_CUDA_CHECK(cudaEventDestroy(dweight_event)); - OF_CUDA_CHECK(cudaEventDestroy(dbias_event)); OF_CUDA_CHECK(cudaEventDestroy(allreduce_event)); }; @@ -114,7 +112,6 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: cudaEvent_t main_stream_event; cudaEvent_t async_weight_grad_event; cudaEvent_t dweight_event; - cudaEvent_t dbias_event; cudaEvent_t allreduce_event; using user_op::OpKernel::Compute; @@ -123,15 +120,16 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + const int64_t weight_num = ctx->input_size("weights"); + user_op::Tensor* d_grad = ctx->Tensor4ArgNameAndIndex("d_grad", 0); + // just a placeholder. + user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); + user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); + auto* kernel_state = dynamic_cast(state); ncclComm_t comm = kernel_state->comm(); void* dy_tmp_buf = tmp_buffer->mut_dptr(); size_t offset = 0; - - user_op::Tensor* d_grad = ctx->Tensor4ArgNameAndIndex("d_grad", 0); - - const int64_t weight_num = ctx->input_size("weights"); - const auto* matmul_grad_cache = CHECK_NOTNULL(dynamic_cast(cache)); auto* cuda_stream = ctx->stream()->As(); @@ -155,8 +153,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: DimVector dy_shape(2); dy->shape().ToDimVector(&dy_shape); const void* dgrad_buf = dy->dptr(); - - // step1: Get last layer's dbias. 
+ const int64_t batch_size = dy->shape().At(0); const void* ones = nullptr; auto* cuda_device = dynamic_cast(ctx->stream()->device()); @@ -165,34 +162,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: } else { ones = dy_tmp_buf; offset += batch_size; - } - DimVector ones_buf_shape(2); - ones_buf_shape.at(0) = 1; - ones_buf_shape.at(1) = batch_size; - user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); - - InferMatmulCublasMNK(ones_buf_shape, dy_shape, - /*transpose_a=*/ep::primitive::BlasTransposeType::N, - /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, &cublas_n, - &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); - SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, - /*transpose_a=*/ep::primitive::BlasTransposeType::N, - /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr, - cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); - OF_CUBLAS_CHECK(cublasLtMatmul( - kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, dgrad_buf, - matmul_grad_cache->cublas_a_desc, ones, matmul_grad_cache->cublas_b_desc, &sp_beta, - d_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, d_bias->mut_dptr(), - matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), - kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); - - // allreduce first Dbias. - if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { - OF_CUDA_CHECK(cudaEventRecord(dbias_event, kernel_state->cuda_stream())); - OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dbias_event)); - OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), - d_bias->shape().elem_cnt(), GetNcclDataType(d_bias->data_type()), - ncclRedOp_t::ncclSum, comm, kernel_state->allreduce_stream())); + dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + offset); } for (int idx = weight_num - 1; idx > -1; idx--) { @@ -224,9 +194,6 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: matmul_grad_cache->cublas_c_desc, dy_tmp_buf, matmul_grad_cache->cublas_c_desc, nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream())); - if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { - OF_CUDA_CHECK(cudaEventRecord(dbias_event, kernel_state->cuda_stream())); - } } else { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", 0); weight->shape().ToDimVector(&weight_shape); @@ -258,6 +225,30 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); // currently only support 2D matmul. + // step1: Get last layer's dbias. 
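+      // (Moved from before the loop: computing d_bias = ones^T * dy in the
+      // first iteration (idx == weight_num - 1) keeps the stream ordering
+      // uniform, so the async stream waits on main_stream_event exactly once
+      // per iteration before launching its dbias/wgrad GEMMs.)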
+ if(idx == weight_num - 1){ + DimVector ones_buf_shape(2); + ones_buf_shape.at(0) = 1; + ones_buf_shape.at(1) = batch_size; + + epilogue = CUBLASLT_EPILOGUE_DEFAULT; + InferMatmulCublasMNK(ones_buf_shape, dy_shape, + /*transpose_a=*/ep::primitive::BlasTransposeType::N, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, &cublas_n, + &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); + SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, + /*transpose_a=*/ep::primitive::BlasTransposeType::N, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr, + cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); + OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); + OF_CUBLAS_CHECK(cublasLtMatmul( + kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, dgrad_buf, + matmul_grad_cache->cublas_a_desc, ones, matmul_grad_cache->cublas_b_desc, &sp_beta, + d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, d_last_bias->mut_dptr(), + matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), + kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); + } + user_op::Tensor* d_weight = ctx->Tensor4ArgNameAndIndex("d_weights", idx); if (idx != 0) { const user_op::Tensor* hidden = ctx->Tensor4ArgNameAndIndex("hidden", idx - 1); // here @@ -274,8 +265,11 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: /*transpose_a=*/ep::primitive::BlasTransposeType::T, /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); - - OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); + + if(idx != weight_num - 1){ + // if idx == weight_num - 1, async_stream has wait main_stream_event in d_bias. + OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); + } OF_CUBLAS_CHECK(cublasLtMatmul( kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, @@ -314,34 +308,29 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); OF_CUDA_CHECK(cudaEventRecord(async_weight_grad_event, kernel_state->cuda_stream())); - if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { - OF_CUDA_CHECK( - cudaStreamWaitEvent(kernel_state->allreduce_stream(), async_weight_grad_event)); - OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), - d_weight->shape().elem_cnt(), - GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, - comm, kernel_state->allreduce_stream())); - } } // Do Allreduce for d_bias and d_weight. - if (idx > 0) { - // Here we wait wgrad and dgrad event, and set a ncclGroup to Allreduce d_bias and d_weight. 
- OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dbias_event)); - OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); - OF_NCCL_CHECK(ncclGroupStart()); - OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), - d_bias->shape().elem_cnt(), - GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum, - comm, kernel_state->allreduce_stream())); - OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), - d_weight->shape().elem_cnt(), - GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, - comm, kernel_state->allreduce_stream())); - OF_NCCL_CHECK(ncclGroupEnd()); + // Here we wait wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight. + OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); + OF_NCCL_CHECK(ncclGroupStart()); + OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), + d_bias->shape().elem_cnt(), + GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum, + comm, kernel_state->allreduce_stream())); + OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), + d_weight->shape().elem_cnt(), + GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, + comm, kernel_state->allreduce_stream())); + OF_NCCL_CHECK(ncclGroupEnd()); + if(idx == 0){ + // We should sync allreduce before the kernel finish. + OF_CUDA_CHECK(cudaEventRecord(allreduce_event, kernel_state->allreduce_stream())); } } - if (!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event)); + } else { OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), async_weight_grad_event)); } }; From b017dbddb76c0635426f2be3450c0119dfe54315 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 23 Jun 2022 16:52:46 +0800 Subject: [PATCH 08/28] fix format --- oneflow/user/kernels/cublas_fused_mlp_grad.cu | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad.cu b/oneflow/user/kernels/cublas_fused_mlp_grad.cu index 2c0efd7864e..60c53cef326 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad.cu @@ -122,9 +122,9 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t weight_num = ctx->input_size("weights"); user_op::Tensor* d_grad = ctx->Tensor4ArgNameAndIndex("d_grad", 0); - // just a placeholder. + // just a placeholder. 
user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); - user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); + user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); auto* kernel_state = dynamic_cast(state); ncclComm_t comm = kernel_state->comm(); @@ -153,7 +153,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: DimVector dy_shape(2); dy->shape().ToDimVector(&dy_shape); const void* dgrad_buf = dy->dptr(); - + const int64_t batch_size = dy->shape().At(0); const void* ones = nullptr; auto* cuda_device = dynamic_cast(ctx->stream()->device()); @@ -226,28 +226,29 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: // currently only support 2D matmul. // step1: Get last layer's dbias. - if(idx == weight_num - 1){ + if (idx == weight_num - 1) { DimVector ones_buf_shape(2); ones_buf_shape.at(0) = 1; ones_buf_shape.at(1) = batch_size; - + epilogue = CUBLASLT_EPILOGUE_DEFAULT; InferMatmulCublasMNK(ones_buf_shape, dy_shape, - /*transpose_a=*/ep::primitive::BlasTransposeType::N, - /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, &cublas_n, - &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); + /*transpose_a=*/ep::primitive::BlasTransposeType::N, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, + &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, /*transpose_a=*/ep::primitive::BlasTransposeType::N, - /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr, - cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); + /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, + nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); OF_CUBLAS_CHECK(cublasLtMatmul( - kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, dgrad_buf, - matmul_grad_cache->cublas_a_desc, ones, matmul_grad_cache->cublas_b_desc, &sp_beta, - d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, d_last_bias->mut_dptr(), - matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), - kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); - } + kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, + dgrad_buf, matmul_grad_cache->cublas_a_desc, ones, matmul_grad_cache->cublas_b_desc, + &sp_beta, d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, + d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, + kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), + kernel_state->cuda_stream())); + } user_op::Tensor* d_weight = ctx->Tensor4ArgNameAndIndex("d_weights", idx); if (idx != 0) { @@ -265,9 +266,9 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: /*transpose_a=*/ep::primitive::BlasTransposeType::T, /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); - - if(idx != weight_num - 1){ - // if idx == weight_num - 1, async_stream has wait main_stream_event in d_bias. + + if (idx != weight_num - 1) { + // if idx == weight_num - 1, async_stream has wait main_stream_event in d_bias. 
OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); } @@ -315,16 +316,15 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); OF_NCCL_CHECK(ncclGroupStart()); OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), - d_bias->shape().elem_cnt(), - GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum, - comm, kernel_state->allreduce_stream())); + d_bias->shape().elem_cnt(), GetNcclDataType(d_bias->data_type()), + ncclRedOp_t::ncclSum, comm, kernel_state->allreduce_stream())); OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), d_weight->shape().elem_cnt(), GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, comm, kernel_state->allreduce_stream())); OF_NCCL_CHECK(ncclGroupEnd()); - if(idx == 0){ - // We should sync allreduce before the kernel finish. + if (idx == 0) { + // We should sync allreduce before the kernel finish. OF_CUDA_CHECK(cudaEventRecord(allreduce_event, kernel_state->allreduce_stream())); } } From bff8a33c407cfb05913c9bc5ec50e8a8d5f6b288 Mon Sep 17 00:00:00 2001 From: liujuncheng Date: Thu, 23 Jun 2022 17:09:12 +0800 Subject: [PATCH 09/28] CUDA Graphs delayed capture --- oneflow/core/kernel/cuda_graph_support.h | 7 +++++++ oneflow/core/kernel/user_kernel.cpp | 16 ++++++++++++++-- oneflow/core/kernel/user_kernel.h | 1 + oneflow/core/lazy/actor/light_actor.cpp | 10 ++++++++-- 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/oneflow/core/kernel/cuda_graph_support.h b/oneflow/core/kernel/cuda_graph_support.h index 975dd08680c..2ec118c3a1d 100644 --- a/oneflow/core/kernel/cuda_graph_support.h +++ b/oneflow/core/kernel/cuda_graph_support.h @@ -19,7 +19,9 @@ namespace oneflow { namespace user_op { class KernelInitContext; +class KernelComputeContext; class OpKernelState; +class OpKernelCache; class CudaGraphSupport { public: @@ -29,6 +31,11 @@ class CudaGraphSupport { virtual bool IsCudaGraphSupported(KernelInitContext* ctx, OpKernelState* state) const { return true; } + + virtual bool IsReadyForCapture(KernelComputeContext* ctx, OpKernelState* state, + const OpKernelCache* cache) const { + return true; + } }; } // namespace user_op diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index 1f29ad41012..885add005fb 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -649,8 +649,13 @@ void UserKernel::ForwardUserKernel(const std::functionLaunchGraph(cuda_graph_exec_.get()); return; } - current_scope_capturing = true; - cuda_stream->BeginGraphCapture(); + const auto* cuda_graph_support = + CHECK_NOTNULL(dynamic_cast(kernel_.get())); + if (cuda_graph_support->IsReadyForCapture(ctx_.get(), opkernel_state, + opkernel_cache_.get())) { + current_scope_capturing = true; + cuda_stream->BeginGraphCapture(); + } } } #endif // WITH_CUDA_GRAPHS @@ -674,6 +679,13 @@ bool UserKernel::IsCudaGraphSupported() const { #endif // WITH_CUDA_GRAPHS } +bool UserKernel::IsReadyForCudaGraphCapture(KernelContext* ctx) const { + const auto* cuda_graph_support = dynamic_cast(kernel_.get()); + if (cuda_graph_support == nullptr) { return false; } + return cuda_graph_support->IsReadyForCapture(ctx_.get(), opkernel_state_.get(), + opkernel_cache_.get()); +} + void UserKernel::VirtualKernelInit(KernelContext* ctx) { InitUserKernel(ctx->stream()); CHECK(opkernel_state_.get() == nullptr); diff --git a/oneflow/core/kernel/user_kernel.h 
b/oneflow/core/kernel/user_kernel.h index ffe3c854927..d3915dd2479 100644 --- a/oneflow/core/kernel/user_kernel.h +++ b/oneflow/core/kernel/user_kernel.h @@ -51,6 +51,7 @@ class UserKernel final : public Kernel { void ForwardUserKernel(const std::function& BnInOp2Blob, user_op::OpKernelState* opkernel_state) const; bool IsCudaGraphSupported() const; + bool IsReadyForCudaGraphCapture(KernelContext* ctx) const; private: void VirtualKernelInit(KernelContext* ctx) override; diff --git a/oneflow/core/lazy/actor/light_actor.cpp b/oneflow/core/lazy/actor/light_actor.cpp index ef168540a5b..7650699fa13 100644 --- a/oneflow/core/lazy/actor/light_actor.cpp +++ b/oneflow/core/lazy/actor/light_actor.cpp @@ -465,18 +465,24 @@ class LightActor : public ActorBase, public KernelContext, public ActorContextPr inline void LaunchKernel() { #ifdef WITH_CUDA_GRAPHS + bool is_capturing = false; if (cuda_graph_exec_[0]) { auto* cuda_stream = stream_ctx_->stream()->As(); if (cuda_graph_exec_[0]->IsInstantiated()) { cuda_stream->LaunchGraph(cuda_graph_exec_[0].get()); return; } - cuda_stream->BeginGraphCapture(); + auto* user_kernel = + CHECK_NOTNULL(dynamic_cast(kernel_info_[0]->kernel.get())); + if (user_kernel->IsReadyForCudaGraphCapture(this)) { + is_capturing = true; + cuda_stream->BeginGraphCapture(); + } } #endif kernel_info_[0]->kernel->Launch(this); #ifdef WITH_CUDA_GRAPHS - if (cuda_graph_exec_[0]) { + if (cuda_graph_exec_[0] && is_capturing) { auto* cuda_stream = stream_ctx_->stream()->As(); cuda_stream->EndGraphCapture(cuda_graph_exec_[0].get()); cuda_stream->LaunchGraph(cuda_graph_exec_[0].get()); From 4eb6ac1f3d8f962c743e3b5e1ab56f9890f10c0f Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Fri, 24 Jun 2022 17:49:40 +0800 Subject: [PATCH 10/28] Add ifcomm create for graph --- oneflow/user/kernels/cublas_fused_mlp_grad.cu | 56 ++++++++++++------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad.cu b/oneflow/user/kernels/cublas_fused_mlp_grad.cu index 60c53cef326..3b0c4c9b33d 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad.cu @@ -48,6 +48,13 @@ class MatmulGradKernelState final : public user_op::OpKernelState { size_t cublas_workspace_size() const { return 8 * 1024 * 1024; } void* cublas_workspace() const { return workspace_; } ncclComm_t comm() { return GetOrCreate().comm; } + bool IfCommCreate(){ + if(!comm_){ + printf("Here no create. 
\n"); + return false; + } + return true; + } private: struct Comm { @@ -114,6 +121,11 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: cudaEvent_t dweight_event; cudaEvent_t allreduce_event; + bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { + auto* kernel_state = dynamic_cast(state); + return kernel_state->IfCommCreate(); + } + using user_op::OpKernel::Compute; void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { @@ -127,7 +139,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); auto* kernel_state = dynamic_cast(state); - ncclComm_t comm = kernel_state->comm(); + // ncclComm_t comm = kernel_state->comm(); + void* dy_tmp_buf = tmp_buffer->mut_dptr(); size_t offset = 0; const auto* matmul_grad_cache = @@ -151,10 +164,10 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: DimVector weight_shape(2); DimVector hidden_shape(2); DimVector dy_shape(2); - dy->shape().ToDimVector(&dy_shape); + dy->shape_view().ToDimVector(&dy_shape); const void* dgrad_buf = dy->dptr(); - const int64_t batch_size = dy->shape().At(0); + const int64_t batch_size = dy->shape_view().At(0); const void* ones = nullptr; auto* cuda_device = dynamic_cast(ctx->stream()->device()); if (cuda_device != nullptr) { @@ -171,7 +184,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: const user_op::Tensor* aux = ctx->Tensor4ArgNameAndIndex("cublas_aux", idx - 1); d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx - 1); - weight->shape().ToDimVector(&weight_shape); + weight->shape_view().ToDimVector(&weight_shape); epilogue = CUBLASLT_EPILOGUE_DRELU_BGRAD; InferMatmulCublasMNK(dy_shape, weight_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::N, @@ -196,7 +209,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: cuda_stream->cuda_stream())); } else { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", 0); - weight->shape().ToDimVector(&weight_shape); + weight->shape_view().ToDimVector(&weight_shape); epilogue = CUBLASLT_EPILOGUE_DEFAULT; InferMatmulCublasMNK(dy_shape, weight_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::N, @@ -253,7 +266,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: user_op::Tensor* d_weight = ctx->Tensor4ArgNameAndIndex("d_weights", idx); if (idx != 0) { const user_op::Tensor* hidden = ctx->Tensor4ArgNameAndIndex("hidden", idx - 1); // here - hidden->shape().ToDimVector(&hidden_shape); + hidden->shape_view().ToDimVector(&hidden_shape); epilogue = CUBLASLT_EPILOGUE_DEFAULT; @@ -290,7 +303,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: offset += GetCudaAlignedSize(dy_shape.at(0) * dy_shape.at(1) * sizeof(T)); dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + offset); } else { - x->shape().ToDimVector(&hidden_shape); + x->shape_view().ToDimVector(&hidden_shape); epilogue = CUBLASLT_EPILOGUE_DEFAULT; InferMatmulCublasMNK(dy_shape, hidden_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::T, @@ -313,19 +326,22 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: // Do Allreduce for d_bias and d_weight. 
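       // (Both collectives are in-place allreduces, sendbuf == recvbuf: each
       // rank's partial gradient buffer is overwritten by the cross-rank sum.)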
// Here we wait wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight. - OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); - OF_NCCL_CHECK(ncclGroupStart()); - OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), - d_bias->shape().elem_cnt(), GetNcclDataType(d_bias->data_type()), - ncclRedOp_t::ncclSum, comm, kernel_state->allreduce_stream())); - OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), - d_weight->shape().elem_cnt(), - GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, - comm, kernel_state->allreduce_stream())); - OF_NCCL_CHECK(ncclGroupEnd()); - if (idx == 0) { - // We should sync allreduce before the kernel finish. - OF_CUDA_CHECK(cudaEventRecord(allreduce_event, kernel_state->allreduce_stream())); + if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); + OF_NCCL_CHECK(ncclGroupStart()); + OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), + d_bias->shape_view().elem_cnt(), + GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum, + comm, kernel_state->allreduce_stream())); + OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), + d_weight->shape_view().elem_cnt(), + GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, + comm, kernel_state->allreduce_stream())); + OF_NCCL_CHECK(ncclGroupEnd()); + if (idx == 0) { + // We should sync allreduce before the kernel finish. + OF_CUDA_CHECK(cudaEventRecord(allreduce_event, kernel_state->allreduce_stream())); + } } } if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { From 84e4a645755d09eae8a115551469bd1c8a4558d7 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Fri, 24 Jun 2022 19:56:41 +0800 Subject: [PATCH 11/28] insert weight event roughly --- oneflow/user/kernels/cublas_fused_mlp_grad.cu | 315 ++++++++++++++++-- 1 file changed, 289 insertions(+), 26 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad.cu b/oneflow/user/kernels/cublas_fused_mlp_grad.cu index 3b0c4c9b33d..d85466fbbf6 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad.cu @@ -52,6 +52,8 @@ class MatmulGradKernelState final : public user_op::OpKernelState { if(!comm_){ printf("Here no create. \n"); return false; + } else { + printf("Here create. 
\n"); } return true; } @@ -89,6 +91,273 @@ class MatmulGradKernelState final : public user_op::OpKernelState { std::string stream_name_; }; +// template +// class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { +// public: +// CublasFusedMLPGradKernel() { +// OF_CUDA_CHECK(cudaEventCreate(&main_stream_event)); +// OF_CUDA_CHECK(cudaEventCreate(&async_weight_grad_event)); +// OF_CUDA_CHECK(cudaEventCreate(&dweight_event)); +// OF_CUDA_CHECK(cudaEventCreate(&allreduce_event)); +// }; +// ~CublasFusedMLPGradKernel() override { +// OF_CUDA_CHECK(cudaEventDestroy(main_stream_event)); +// OF_CUDA_CHECK(cudaEventDestroy(async_weight_grad_event)); +// OF_CUDA_CHECK(cudaEventDestroy(dweight_event)); +// OF_CUDA_CHECK(cudaEventDestroy(allreduce_event)); +// }; + +// std::shared_ptr InitOpKernelCache( +// user_op::KernelCacheContext* ctx) const override { +// return CreateCublasFusedMLPKernelCache(); +// } + +// std::shared_ptr CreateOpKernelState( +// user_op::KernelInitContext* ctx) const override { +// return std::make_shared(ctx); +// } + +// private: +// cudaEvent_t main_stream_event; +// cudaEvent_t async_weight_grad_event; +// cudaEvent_t dweight_event; +// cudaEvent_t allreduce_event; + +// bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { +// auto* kernel_state = dynamic_cast(state); +// return kernel_state->IfCommCreate(); +// } + +// using user_op::OpKernel::Compute; +// void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, +// const user_op::OpKernelCache* cache) const override { +// const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); +// const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); +// user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); +// const int64_t weight_num = ctx->input_size("weights"); +// user_op::Tensor* d_grad = ctx->Tensor4ArgNameAndIndex("d_grad", 0); +// // just a placeholder. +// user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); +// user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); + +// auto* kernel_state = dynamic_cast(state); +// ncclComm_t comm = kernel_state->comm(); + +// void* dy_tmp_buf = tmp_buffer->mut_dptr(); +// size_t offset = 0; +// const auto* matmul_grad_cache = +// CHECK_NOTNULL(dynamic_cast(cache)); +// auto* cuda_stream = ctx->stream()->As(); + +// const DataType data_type = dy->data_type(); +// const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type); +// const cudaDataType_t cuda_data_type = GetCudaDataType(data_type); +// size_t cublas_m = 0, cublas_n = 0, cublas_k = 0; +// int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0; + +// double alpha = 1.0; +// auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); +// double beta = 0.0; +// auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); + +// cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; // = CUBLASLT_EPILOGUE_DRELU_BGRAD + +// // currently only support 2D matmul. 
+// DimVector weight_shape(2); +// DimVector hidden_shape(2); +// DimVector dy_shape(2); +// dy->shape_view().ToDimVector(&dy_shape); +// const void* dgrad_buf = dy->dptr(); + +// const int64_t batch_size = dy->shape_view().At(0); +// const void* ones = nullptr; +// auto* cuda_device = dynamic_cast(ctx->stream()->device()); +// if (cuda_device != nullptr) { +// ones = cuda_device->GetConstOnes(dy->data_type(), batch_size); +// } else { +// ones = dy_tmp_buf; +// offset += batch_size; +// dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + offset); +// } + +// for (int idx = weight_num - 1; idx > -1; idx--) { +// if (idx != 0) { +// const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", idx); +// const user_op::Tensor* aux = ctx->Tensor4ArgNameAndIndex("cublas_aux", idx - 1); +// d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx - 1); + +// weight->shape_view().ToDimVector(&weight_shape); +// epilogue = CUBLASLT_EPILOGUE_DRELU_BGRAD; +// InferMatmulCublasMNK(dy_shape, weight_shape, +// /*transpose_a=*/ep::primitive::BlasTransposeType::N, +// /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, +// &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); +// SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/true, +// /*transpose_a=*/ep::primitive::BlasTransposeType::N, +// /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, +// d_bias->mut_dptr(), aux->dptr(), cublas_m, cublas_n, cublas_k, cublas_lda, +// cublas_ldb, cublas_ldc); +// /* +// a = dy, b = weight +// cublas_a=weight, cublas_b=dy +// */ +// OF_CUDA_CHECK(cudaEventRecord(main_stream_event, cuda_stream->cuda_stream())); +// OF_CUBLAS_CHECK(cublasLtMatmul( +// cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, +// weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, +// matmul_grad_cache->cublas_b_desc, &sp_beta, dy_tmp_buf, +// matmul_grad_cache->cublas_c_desc, dy_tmp_buf, matmul_grad_cache->cublas_c_desc, nullptr, +// cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), +// cuda_stream->cuda_stream())); +// } else { +// const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", 0); +// weight->shape_view().ToDimVector(&weight_shape); +// epilogue = CUBLASLT_EPILOGUE_DEFAULT; +// InferMatmulCublasMNK(dy_shape, weight_shape, +// /*transpose_a=*/ep::primitive::BlasTransposeType::N, +// /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, +// &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); +// SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, +// /*transpose_a=*/ep::primitive::BlasTransposeType::N, +// /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, +// nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); +// /* +// a = dy, b = weight +// cublas_a=weight, cublas_b=dy +// */ +// // OF_CUDA_CHECK(cudaEventRecord(main_stream_event, cuda_stream->cuda_stream())); +// OF_CUBLAS_CHECK(cublasLtMatmul( +// cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, +// weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, +// matmul_grad_cache->cublas_b_desc, &sp_beta, d_grad->mut_dptr(), +// matmul_grad_cache->cublas_c_desc, d_grad->mut_dptr(), matmul_grad_cache->cublas_c_desc, +// nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), +// cuda_stream->cuda_stream())); +// } +// alpha = 1.0; +// sp_alpha = 
GetCublasScalarParameter(alpha, cublas_compute_dtype); +// beta = 0.0; +// sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); + +// // currently only support 2D matmul. +// // step1: Get last layer's dbias. +// if (idx == weight_num - 1) { +// DimVector ones_buf_shape(2); +// ones_buf_shape.at(0) = 1; +// ones_buf_shape.at(1) = batch_size; + +// epilogue = CUBLASLT_EPILOGUE_DEFAULT; +// InferMatmulCublasMNK(ones_buf_shape, dy_shape, +// /*transpose_a=*/ep::primitive::BlasTransposeType::N, +// /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, +// &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); +// SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, +// /*transpose_a=*/ep::primitive::BlasTransposeType::N, +// /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, +// nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); +// OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); +// OF_CUBLAS_CHECK(cublasLtMatmul( +// kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, +// dgrad_buf, matmul_grad_cache->cublas_a_desc, ones, matmul_grad_cache->cublas_b_desc, +// &sp_beta, d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, +// d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, +// kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), +// kernel_state->cuda_stream())); +// } + +// user_op::Tensor* d_weight = ctx->Tensor4ArgNameAndIndex("d_weights", idx); +// if (idx != 0) { +// const user_op::Tensor* hidden = ctx->Tensor4ArgNameAndIndex("hidden", idx - 1); // here +// hidden->shape_view().ToDimVector(&hidden_shape); + +// epilogue = CUBLASLT_EPILOGUE_DEFAULT; + +// InferMatmulCublasMNK(dy_shape, hidden_shape, +// /*transpose_a=*/ep::primitive::BlasTransposeType::T, +// /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, +// &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); + +// SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, +// /*transpose_a=*/ep::primitive::BlasTransposeType::T, +// /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, +// nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); + +// if (idx != weight_num - 1) { +// // if idx == weight_num - 1, async_stream has wait main_stream_event in d_bias. 
+// OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); +// } + +// OF_CUBLAS_CHECK(cublasLtMatmul( +// kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, +// hidden->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, +// matmul_grad_cache->cublas_b_desc, &sp_beta, d_weight->mut_dptr(), +// matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(), +// matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), +// kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); + +// if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { +// OF_CUDA_CHECK(cudaEventRecord(dweight_event, kernel_state->cuda_stream())); +// } +// // compute dy shape +// dy_shape.at(1) = weight_shape.at(1); +// // compute dybuf +// dgrad_buf = dy_tmp_buf; +// offset += GetCudaAlignedSize(dy_shape.at(0) * dy_shape.at(1) * sizeof(T)); +// dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + offset); +// } else { +// x->shape_view().ToDimVector(&hidden_shape); +// epilogue = CUBLASLT_EPILOGUE_DEFAULT; +// InferMatmulCublasMNK(dy_shape, hidden_shape, +// /*transpose_a=*/ep::primitive::BlasTransposeType::T, +// /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, +// &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); +// SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, +// /*transpose_a=*/ep::primitive::BlasTransposeType::T, +// /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, +// nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); +// OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); +// OF_CUBLAS_CHECK(cublasLtMatmul( +// kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, +// x->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, +// matmul_grad_cache->cublas_b_desc, &sp_beta, d_weight->mut_dptr(), +// matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(), +// matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), +// kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); +// if(!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)){ +// OF_CUDA_CHECK(cudaEventRecord(async_weight_grad_event, kernel_state->cuda_stream())); +// } +// } + +// // Do Allreduce for d_bias and d_weight. +// // Here we wait wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight. +// if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { +// OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); +// OF_NCCL_CHECK(ncclGroupStart()); +// OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), +// d_bias->shape_view().elem_cnt(), +// GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum, +// comm, kernel_state->allreduce_stream())); +// OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), +// d_weight->shape_view().elem_cnt(), +// GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, +// comm, kernel_state->allreduce_stream())); +// OF_NCCL_CHECK(ncclGroupEnd()); +// if (idx == 0) { +// // We should sync allreduce before the kernel finish. 
+// OF_CUDA_CHECK(cudaEventRecord(allreduce_event, kernel_state->allreduce_stream())); +// } +// } +// } +// if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { +// OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event)); +// } else { +// OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), async_weight_grad_event)); +// } +// }; + +// bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +// }; + template class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: @@ -139,7 +408,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); auto* kernel_state = dynamic_cast(state); - // ncclComm_t comm = kernel_state->comm(); + ncclComm_t comm = kernel_state->comm(); void* dy_tmp_buf = tmp_buffer->mut_dptr(); size_t offset = 0; @@ -293,9 +562,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); - if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { - OF_CUDA_CHECK(cudaEventRecord(dweight_event, kernel_state->cuda_stream())); - } + OF_CUDA_CHECK(cudaEventRecord(dweight_event, kernel_state->cuda_stream())); + // compute dy shape dy_shape.at(1) = weight_shape.at(1); // compute dybuf @@ -321,39 +589,34 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); - OF_CUDA_CHECK(cudaEventRecord(async_weight_grad_event, kernel_state->cuda_stream())); + OF_CUDA_CHECK(cudaEventRecord(dweight_event, kernel_state->cuda_stream())); } // Do Allreduce for d_bias and d_weight. // Here we wait wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight. - if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { - OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); - OF_NCCL_CHECK(ncclGroupStart()); - OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), - d_bias->shape_view().elem_cnt(), - GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum, - comm, kernel_state->allreduce_stream())); - OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), - d_weight->shape_view().elem_cnt(), - GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, - comm, kernel_state->allreduce_stream())); - OF_NCCL_CHECK(ncclGroupEnd()); - if (idx == 0) { - // We should sync allreduce before the kernel finish. 
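+        // (Only the event record happens on the allreduce stream here; the
+        // launcher stream waits on allreduce_event once after the loop, so
+        // downstream kernels see fully reduced d_bias / d_weight buffers.)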
- OF_CUDA_CHECK(cudaEventRecord(allreduce_event, kernel_state->allreduce_stream())); - } + OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); + OF_NCCL_CHECK(ncclGroupStart()); + OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), + d_bias->shape_view().elem_cnt(), + GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum, + comm, kernel_state->allreduce_stream())); + OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), + d_weight->shape_view().elem_cnt(), + GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, + comm, kernel_state->allreduce_stream())); + OF_NCCL_CHECK(ncclGroupEnd()); + if (idx == 0) { + // We should sync allreduce before the kernel finish. + OF_CUDA_CHECK(cudaEventRecord(allreduce_event, kernel_state->allreduce_stream())); } } - if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { - OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event)); - } else { - OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), async_weight_grad_event)); - } + OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event)); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; + #define REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("cublas_fused_mlp_grad") \ .SetCreateFn>() \ From 9086ba623548d7f730e58450f74cf99dafe4d423 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Fri, 24 Jun 2022 20:44:28 +0800 Subject: [PATCH 12/28] fix dbias allreduce error --- oneflow/user/kernels/cublas_fused_mlp_grad.cu | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad.cu b/oneflow/user/kernels/cublas_fused_mlp_grad.cu index d85466fbbf6..e6f828eb333 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad.cu @@ -448,7 +448,9 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: } for (int idx = weight_num - 1; idx > -1; idx--) { + printf("Here idx is: %d \n", idx); if (idx != 0) { + printf("Here BGRAD idx is: %d \n", idx); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", idx); const user_op::Tensor* aux = ctx->Tensor4ArgNameAndIndex("cublas_aux", idx - 1); d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx - 1); @@ -477,6 +479,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream())); } else { + printf("Here Dx idx is: %d \n", idx); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", 0); weight->shape_view().ToDimVector(&weight_shape); epilogue = CUBLASLT_EPILOGUE_DEFAULT; @@ -509,6 +512,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: // currently only support 2D matmul. // step1: Get last layer's dbias. 
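         // (At idx == weight_num - 1 the last layer's bias grad is produced by
         // the ones GEMM below and lands in d_last_bias; d_bias still points at
         // the epilogue-produced grad of the layer underneath, so the group
         // allreduce later in this iteration must reduce d_last_bias instead.)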
if (idx == weight_num - 1) { + printf("Here last dbias idx is: %d \n", idx); + DimVector ones_buf_shape(2); ones_buf_shape.at(0) = 1; ones_buf_shape.at(1) = batch_size; @@ -534,6 +539,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: user_op::Tensor* d_weight = ctx->Tensor4ArgNameAndIndex("d_weights", idx); if (idx != 0) { + printf("Here dw idx is: %d \n", idx); + const user_op::Tensor* hidden = ctx->Tensor4ArgNameAndIndex("hidden", idx - 1); // here hidden->shape_view().ToDimVector(&hidden_shape); @@ -571,6 +578,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: offset += GetCudaAlignedSize(dy_shape.at(0) * dy_shape.at(1) * sizeof(T)); dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + offset); } else { + printf("Here dw idx is: %d \n", idx); + x->shape_view().ToDimVector(&hidden_shape); epilogue = CUBLASLT_EPILOGUE_DEFAULT; InferMatmulCublasMNK(dy_shape, hidden_shape, @@ -596,10 +605,22 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: // Here we wait wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight. OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); OF_NCCL_CHECK(ncclGroupStart()); - OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), - d_bias->shape_view().elem_cnt(), - GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum, - comm, kernel_state->allreduce_stream())); + if(idx == weight_num - 1){ + printf("Here last dbias allreduce, idx is: %d \n", idx); + + OF_NCCL_CHECK(ncclAllReduce(d_last_bias->mut_dptr(), d_last_bias->mut_dptr(), + d_last_bias->shape_view().elem_cnt(), + GetNcclDataType(d_last_bias->data_type()), ncclRedOp_t::ncclSum, + comm, kernel_state->allreduce_stream())); + } else { + printf("Here dbias allreduce, idx is: %d \n", idx); + + OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), + d_bias->shape_view().elem_cnt(), + GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum, + comm, kernel_state->allreduce_stream())); + } + printf("Here dw allreduce, idx is: %d \n", idx); OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), d_weight->shape_view().elem_cnt(), GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, From 45db76ecedf3952b7921495c77e9dffab4e6bc34 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Mon, 27 Jun 2022 14:05:10 +0800 Subject: [PATCH 13/28] simplify code --- oneflow/user/kernels/cublas_fused_mlp_grad.cu | 386 +++--------------- 1 file changed, 46 insertions(+), 340 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad.cu b/oneflow/user/kernels/cublas_fused_mlp_grad.cu index e6f828eb333..91fbc036761 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad.cu @@ -48,14 +48,9 @@ class MatmulGradKernelState final : public user_op::OpKernelState { size_t cublas_workspace_size() const { return 8 * 1024 * 1024; } void* cublas_workspace() const { return workspace_; } ncclComm_t comm() { return GetOrCreate().comm; } - bool IfCommCreate(){ - if(!comm_){ - printf("Here no create. \n"); - return false; - } else { - printf("Here create. 
\n"); - } - return true; + bool IfCommCreate() { + if (!comm_) { return false; } + return true; } private: @@ -91,273 +86,6 @@ class MatmulGradKernelState final : public user_op::OpKernelState { std::string stream_name_; }; -// template -// class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { -// public: -// CublasFusedMLPGradKernel() { -// OF_CUDA_CHECK(cudaEventCreate(&main_stream_event)); -// OF_CUDA_CHECK(cudaEventCreate(&async_weight_grad_event)); -// OF_CUDA_CHECK(cudaEventCreate(&dweight_event)); -// OF_CUDA_CHECK(cudaEventCreate(&allreduce_event)); -// }; -// ~CublasFusedMLPGradKernel() override { -// OF_CUDA_CHECK(cudaEventDestroy(main_stream_event)); -// OF_CUDA_CHECK(cudaEventDestroy(async_weight_grad_event)); -// OF_CUDA_CHECK(cudaEventDestroy(dweight_event)); -// OF_CUDA_CHECK(cudaEventDestroy(allreduce_event)); -// }; - -// std::shared_ptr InitOpKernelCache( -// user_op::KernelCacheContext* ctx) const override { -// return CreateCublasFusedMLPKernelCache(); -// } - -// std::shared_ptr CreateOpKernelState( -// user_op::KernelInitContext* ctx) const override { -// return std::make_shared(ctx); -// } - -// private: -// cudaEvent_t main_stream_event; -// cudaEvent_t async_weight_grad_event; -// cudaEvent_t dweight_event; -// cudaEvent_t allreduce_event; - -// bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { -// auto* kernel_state = dynamic_cast(state); -// return kernel_state->IfCommCreate(); -// } - -// using user_op::OpKernel::Compute; -// void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, -// const user_op::OpKernelCache* cache) const override { -// const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); -// const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); -// user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); -// const int64_t weight_num = ctx->input_size("weights"); -// user_op::Tensor* d_grad = ctx->Tensor4ArgNameAndIndex("d_grad", 0); -// // just a placeholder. -// user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); -// user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); - -// auto* kernel_state = dynamic_cast(state); -// ncclComm_t comm = kernel_state->comm(); - -// void* dy_tmp_buf = tmp_buffer->mut_dptr(); -// size_t offset = 0; -// const auto* matmul_grad_cache = -// CHECK_NOTNULL(dynamic_cast(cache)); -// auto* cuda_stream = ctx->stream()->As(); - -// const DataType data_type = dy->data_type(); -// const cublasComputeType_t cublas_compute_dtype = GetComputeType(data_type); -// const cudaDataType_t cuda_data_type = GetCudaDataType(data_type); -// size_t cublas_m = 0, cublas_n = 0, cublas_k = 0; -// int64_t cublas_lda = 0, cublas_ldb = 0, cublas_ldc = 0; - -// double alpha = 1.0; -// auto sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); -// double beta = 0.0; -// auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); - -// cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; // = CUBLASLT_EPILOGUE_DRELU_BGRAD - -// // currently only support 2D matmul. 
-// DimVector weight_shape(2); -// DimVector hidden_shape(2); -// DimVector dy_shape(2); -// dy->shape_view().ToDimVector(&dy_shape); -// const void* dgrad_buf = dy->dptr(); - -// const int64_t batch_size = dy->shape_view().At(0); -// const void* ones = nullptr; -// auto* cuda_device = dynamic_cast(ctx->stream()->device()); -// if (cuda_device != nullptr) { -// ones = cuda_device->GetConstOnes(dy->data_type(), batch_size); -// } else { -// ones = dy_tmp_buf; -// offset += batch_size; -// dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + offset); -// } - -// for (int idx = weight_num - 1; idx > -1; idx--) { -// if (idx != 0) { -// const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", idx); -// const user_op::Tensor* aux = ctx->Tensor4ArgNameAndIndex("cublas_aux", idx - 1); -// d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx - 1); - -// weight->shape_view().ToDimVector(&weight_shape); -// epilogue = CUBLASLT_EPILOGUE_DRELU_BGRAD; -// InferMatmulCublasMNK(dy_shape, weight_shape, -// /*transpose_a=*/ep::primitive::BlasTransposeType::N, -// /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, -// &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); -// SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/true, -// /*transpose_a=*/ep::primitive::BlasTransposeType::N, -// /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, -// d_bias->mut_dptr(), aux->dptr(), cublas_m, cublas_n, cublas_k, cublas_lda, -// cublas_ldb, cublas_ldc); -// /* -// a = dy, b = weight -// cublas_a=weight, cublas_b=dy -// */ -// OF_CUDA_CHECK(cudaEventRecord(main_stream_event, cuda_stream->cuda_stream())); -// OF_CUBLAS_CHECK(cublasLtMatmul( -// cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, -// weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, -// matmul_grad_cache->cublas_b_desc, &sp_beta, dy_tmp_buf, -// matmul_grad_cache->cublas_c_desc, dy_tmp_buf, matmul_grad_cache->cublas_c_desc, nullptr, -// cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), -// cuda_stream->cuda_stream())); -// } else { -// const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", 0); -// weight->shape_view().ToDimVector(&weight_shape); -// epilogue = CUBLASLT_EPILOGUE_DEFAULT; -// InferMatmulCublasMNK(dy_shape, weight_shape, -// /*transpose_a=*/ep::primitive::BlasTransposeType::N, -// /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, -// &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); -// SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, -// /*transpose_a=*/ep::primitive::BlasTransposeType::N, -// /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, -// nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); -// /* -// a = dy, b = weight -// cublas_a=weight, cublas_b=dy -// */ -// // OF_CUDA_CHECK(cudaEventRecord(main_stream_event, cuda_stream->cuda_stream())); -// OF_CUBLAS_CHECK(cublasLtMatmul( -// cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, -// weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, -// matmul_grad_cache->cublas_b_desc, &sp_beta, d_grad->mut_dptr(), -// matmul_grad_cache->cublas_c_desc, d_grad->mut_dptr(), matmul_grad_cache->cublas_c_desc, -// nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), -// cuda_stream->cuda_stream())); -// } -// alpha = 1.0; -// sp_alpha = 
GetCublasScalarParameter(alpha, cublas_compute_dtype); -// beta = 0.0; -// sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); - -// // currently only support 2D matmul. -// // step1: Get last layer's dbias. -// if (idx == weight_num - 1) { -// DimVector ones_buf_shape(2); -// ones_buf_shape.at(0) = 1; -// ones_buf_shape.at(1) = batch_size; - -// epilogue = CUBLASLT_EPILOGUE_DEFAULT; -// InferMatmulCublasMNK(ones_buf_shape, dy_shape, -// /*transpose_a=*/ep::primitive::BlasTransposeType::N, -// /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, -// &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); -// SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, -// /*transpose_a=*/ep::primitive::BlasTransposeType::N, -// /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, -// nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); -// OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); -// OF_CUBLAS_CHECK(cublasLtMatmul( -// kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, -// dgrad_buf, matmul_grad_cache->cublas_a_desc, ones, matmul_grad_cache->cublas_b_desc, -// &sp_beta, d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, -// d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, -// kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), -// kernel_state->cuda_stream())); -// } - -// user_op::Tensor* d_weight = ctx->Tensor4ArgNameAndIndex("d_weights", idx); -// if (idx != 0) { -// const user_op::Tensor* hidden = ctx->Tensor4ArgNameAndIndex("hidden", idx - 1); // here -// hidden->shape_view().ToDimVector(&hidden_shape); - -// epilogue = CUBLASLT_EPILOGUE_DEFAULT; - -// InferMatmulCublasMNK(dy_shape, hidden_shape, -// /*transpose_a=*/ep::primitive::BlasTransposeType::T, -// /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, -// &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); - -// SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, -// /*transpose_a=*/ep::primitive::BlasTransposeType::T, -// /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, -// nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); - -// if (idx != weight_num - 1) { -// // if idx == weight_num - 1, async_stream has wait main_stream_event in d_bias. 
-// OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); -// } - -// OF_CUBLAS_CHECK(cublasLtMatmul( -// kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, -// hidden->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, -// matmul_grad_cache->cublas_b_desc, &sp_beta, d_weight->mut_dptr(), -// matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(), -// matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), -// kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); - -// if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { -// OF_CUDA_CHECK(cudaEventRecord(dweight_event, kernel_state->cuda_stream())); -// } -// // compute dy shape -// dy_shape.at(1) = weight_shape.at(1); -// // compute dybuf -// dgrad_buf = dy_tmp_buf; -// offset += GetCudaAlignedSize(dy_shape.at(0) * dy_shape.at(1) * sizeof(T)); -// dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + offset); -// } else { -// x->shape_view().ToDimVector(&hidden_shape); -// epilogue = CUBLASLT_EPILOGUE_DEFAULT; -// InferMatmulCublasMNK(dy_shape, hidden_shape, -// /*transpose_a=*/ep::primitive::BlasTransposeType::T, -// /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, -// &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); -// SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, -// /*transpose_a=*/ep::primitive::BlasTransposeType::T, -// /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, -// nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); -// OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); -// OF_CUBLAS_CHECK(cublasLtMatmul( -// kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, -// x->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, -// matmul_grad_cache->cublas_b_desc, &sp_beta, d_weight->mut_dptr(), -// matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(), -// matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), -// kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); -// if(!ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)){ -// OF_CUDA_CHECK(cudaEventRecord(async_weight_grad_event, kernel_state->cuda_stream())); -// } -// } - -// // Do Allreduce for d_bias and d_weight. -// // Here we wait wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight. -// if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { -// OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); -// OF_NCCL_CHECK(ncclGroupStart()); -// OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), -// d_bias->shape_view().elem_cnt(), -// GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum, -// comm, kernel_state->allreduce_stream())); -// OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), -// d_weight->shape_view().elem_cnt(), -// GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, -// comm, kernel_state->allreduce_stream())); -// OF_NCCL_CHECK(ncclGroupEnd()); -// if (idx == 0) { -// // We should sync allreduce before the kernel finish. 
-// OF_CUDA_CHECK(cudaEventRecord(allreduce_event, kernel_state->allreduce_stream())); -// } -// } -// } -// if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { -// OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event)); -// } else { -// OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), async_weight_grad_event)); -// } -// }; - -// bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -// }; - template class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: @@ -390,7 +118,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: cudaEvent_t dweight_event; cudaEvent_t allreduce_event; - bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { + bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache* cache) const override { auto* kernel_state = dynamic_cast(state); return kernel_state->IfCommCreate(); } @@ -408,7 +137,10 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); auto* kernel_state = dynamic_cast(state); - ncclComm_t comm = kernel_state->comm(); + ncclComm_t comm{}; + if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + comm = kernel_state->comm(); + } void* dy_tmp_buf = tmp_buffer->mut_dptr(); size_t offset = 0; @@ -448,19 +180,16 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: } for (int idx = weight_num - 1; idx > -1; idx--) { - printf("Here idx is: %d \n", idx); + const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", idx); + weight->shape_view().ToDimVector(&weight_shape); + InferMatmulCublasMNK(dy_shape, weight_shape, + /*transpose_a=*/ep::primitive::BlasTransposeType::N, + /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, + &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); if (idx != 0) { - printf("Here BGRAD idx is: %d \n", idx); - const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", idx); const user_op::Tensor* aux = ctx->Tensor4ArgNameAndIndex("cublas_aux", idx - 1); d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx - 1); - - weight->shape_view().ToDimVector(&weight_shape); epilogue = CUBLASLT_EPILOGUE_DRELU_BGRAD; - InferMatmulCublasMNK(dy_shape, weight_shape, - /*transpose_a=*/ep::primitive::BlasTransposeType::N, - /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, - &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/true, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, @@ -479,14 +208,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream())); } else { - printf("Here Dx idx is: %d \n", idx); - const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", 0); - weight->shape_view().ToDimVector(&weight_shape); epilogue = CUBLASLT_EPILOGUE_DEFAULT; - InferMatmulCublasMNK(dy_shape, weight_shape, - /*transpose_a=*/ep::primitive::BlasTransposeType::N, - 
/*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, - &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/false, /*transpose_a=*/ep::primitive::BlasTransposeType::N, /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, @@ -506,18 +228,12 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: } alpha = 1.0; sp_alpha = GetCublasScalarParameter(alpha, cublas_compute_dtype); - beta = 0.0; - sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype); - // currently only support 2D matmul. // step1: Get last layer's dbias. if (idx == weight_num - 1) { - printf("Here last dbias idx is: %d \n", idx); - DimVector ones_buf_shape(2); ones_buf_shape.at(0) = 1; ones_buf_shape.at(1) = batch_size; - epilogue = CUBLASLT_EPILOGUE_DEFAULT; InferMatmulCublasMNK(ones_buf_shape, dy_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::N, @@ -538,14 +254,10 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: } user_op::Tensor* d_weight = ctx->Tensor4ArgNameAndIndex("d_weights", idx); + epilogue = CUBLASLT_EPILOGUE_DEFAULT; if (idx != 0) { - printf("Here dw idx is: %d \n", idx); - const user_op::Tensor* hidden = ctx->Tensor4ArgNameAndIndex("hidden", idx - 1); // here hidden->shape_view().ToDimVector(&hidden_shape); - - epilogue = CUBLASLT_EPILOGUE_DEFAULT; - InferMatmulCublasMNK(dy_shape, hidden_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::T, /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, @@ -555,12 +267,10 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: /*transpose_a=*/ep::primitive::BlasTransposeType::T, /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr, cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc); - if (idx != weight_num - 1) { // if idx == weight_num - 1, async_stream has wait main_stream_event in d_bias. OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event)); } - OF_CUBLAS_CHECK(cublasLtMatmul( kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, hidden->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, @@ -568,9 +278,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(), kernel_state->cuda_stream())); - OF_CUDA_CHECK(cudaEventRecord(dweight_event, kernel_state->cuda_stream())); - // compute dy shape dy_shape.at(1) = weight_shape.at(1); // compute dybuf @@ -578,10 +286,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: offset += GetCudaAlignedSize(dy_shape.at(0) * dy_shape.at(1) * sizeof(T)); dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + offset); } else { - printf("Here dw idx is: %d \n", idx); - x->shape_view().ToDimVector(&hidden_shape); - epilogue = CUBLASLT_EPILOGUE_DEFAULT; InferMatmulCublasMNK(dy_shape, hidden_shape, /*transpose_a=*/ep::primitive::BlasTransposeType::T, /*transpose_b=*/ep::primitive::BlasTransposeType::N, &cublas_m, @@ -601,43 +306,44 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: OF_CUDA_CHECK(cudaEventRecord(dweight_event, kernel_state->cuda_stream())); } - // Do Allreduce for d_bias and d_weight. 
-      // Here we wait wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight.
-      OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event));
-      OF_NCCL_CHECK(ncclGroupStart());
-      if(idx == weight_num - 1){
-        printf("Here last dbias allreduce, idx is: %d \n", idx);
-
-        OF_NCCL_CHECK(ncclAllReduce(d_last_bias->mut_dptr(), d_last_bias->mut_dptr(),
-                                    d_last_bias->shape_view().elem_cnt(),
-                                    GetNcclDataType(d_last_bias->data_type()), ncclRedOp_t::ncclSum,
-                                    comm, kernel_state->allreduce_stream()));
-      } else {
-        printf("Here dbias allreduce, idx is: %d \n", idx);
-
-        OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(),
-                                    d_bias->shape_view().elem_cnt(),
-                                    GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum,
+      if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) {
+        // Do Allreduce for d_bias and d_weight.
+        // Here we wait on the wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight.
+        OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event));
+        OF_NCCL_CHECK(ncclGroupStart());
+        if (idx == weight_num - 1) {
+          OF_NCCL_CHECK(ncclAllReduce(
+              d_last_bias->mut_dptr(), d_last_bias->mut_dptr(),
+              d_last_bias->shape_view().elem_cnt(), GetNcclDataType(d_last_bias->data_type()),
+              ncclRedOp_t::ncclSum, comm, kernel_state->allreduce_stream()));
+        } else {
+          OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(),
+                                      d_bias->shape_view().elem_cnt(),
+                                      GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum,
+                                      comm, kernel_state->allreduce_stream()));
+        }
+        OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(),
+                                    d_weight->shape_view().elem_cnt(),
+                                    GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum,
                                     comm, kernel_state->allreduce_stream()));
-      }
-      printf("Here dw allreduce, idx is: %d \n", idx);
-      OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(),
-                                  d_weight->shape_view().elem_cnt(),
-                                  GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum,
-                                  comm, kernel_state->allreduce_stream()));
-      OF_NCCL_CHECK(ncclGroupEnd());
-      if (idx == 0) {
-        // We should sync allreduce before the kernel finish.
-        OF_CUDA_CHECK(cudaEventRecord(allreduce_event, kernel_state->allreduce_stream()));
+        OF_NCCL_CHECK(ncclGroupEnd());
+        if (idx == 0) {
+          // We should sync the allreduce before the kernel finishes.
+          OF_CUDA_CHECK(cudaEventRecord(allreduce_event, kernel_state->allreduce_stream()));
+        }
       }
     }
-    OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event));
+
+    if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) {
+      OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event));
+    } else {
+      OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), dweight_event));
+    }
   };

   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };

-
 #define REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(dtype)           \
   REGISTER_USER_KERNEL("cublas_fused_mlp_grad")                \
       .SetCreateFn<CublasFusedMLPGradKernel<dtype>>()          \

From dfec4718bd2c7e4c0de0c2959b00ca80da1d201b Mon Sep 17 00:00:00 2001
From: MARD1NO <359521840@qq.com>
Date: Mon, 27 Jun 2022 15:05:11 +0800
Subject: [PATCH 14/28] Add 11060 limit

---
 oneflow/core/functional/impl/nn_grad_functor.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp
index 36d619ff58f..4fca5f41e76 100644
--- a/oneflow/core/functional/impl/nn_grad_functor.cpp
+++ b/oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include "oneflow/core/common/maybe.h"
 #include "oneflow/core/common/scalar.h"
 #include "oneflow/core/framework/attr_map.h"
 #include "oneflow/core/framework/op_builder.h"
@@ -1137,7 +1138,10 @@ class FusedMLPGradFunctor {
     std::copy(weights.begin(), weights.end(), input.begin() + 2);
     std::copy(cublas_aux.begin(), cublas_aux.end(), input.begin() + 2 + weight_size);
     std::copy(hidden.begin(), hidden.end(), input.begin() + 2 + 2 * weight_size);
+#if CUDA_VERSION >= 11060
     return OpInterpUtil::Dispatch<TensorTuple>(*fused_op_[weight_size], input);
+#endif
+    UNIMPLEMENTED_THEN_RETURN() << "Only supported when CUDA_VERSION >= 11060";
   }

 private:

From 4e3fa249e0162a1f9b3f620deff219d799947641 Mon Sep 17 00:00:00 2001
From: MARD1NO <359521840@qq.com>
Date: Tue, 28 Jun 2022 15:40:58 +0800
Subject: [PATCH 15/28] Remove print

---
 oneflow/user/ops/cublas_fused_mlp_op.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/oneflow/user/ops/cublas_fused_mlp_op.cpp b/oneflow/user/ops/cublas_fused_mlp_op.cpp
index 62ee7076501..c9e8dc6872e 100644
--- a/oneflow/user/ops/cublas_fused_mlp_op.cpp
+++ b/oneflow/user/ops/cublas_fused_mlp_op.cpp
@@ -157,7 +157,6 @@ REGISTER_USER_OP_GRAD("cublas_fused_mlp")
         std::string cublas_dy = last_bias_grad;
         if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD", false)) {
-          printf("Here use fully fusedmlp grad \n");
           // Use Fully Fused MLP Backward.
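          // For orientation, the grad op wired up below consumes dy, x, weights[n],
          // cublas_aux[n] and hidden[n], and its outputs flatten in declaration order
          // (a sketch of the indexing, assuming an n-layer MLP; not code from this patch):
          //   out[0]         -> d_grad (renamed d_x later in this series)
          //   out[1 + i]     -> d_biases[i],  0 <= i < n
          //   out[1 + n + i] -> d_weights[i], 0 <= i < n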
user_op::UserOpConfWrapperBuilder fused_mlp_grad_builder(op.op_name() + "_fused_mlp_grad"); fused_mlp_grad_builder.Op("cublas_fused_mlp_grad") From 5115c697338e53d21597585c4aeba0e26528973e Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Tue, 28 Jun 2022 15:41:31 +0800 Subject: [PATCH 16/28] Rename --- .../{cublas_fused_mlp_grad.cu => cublas_fused_mlp_grad_kernel.cu} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename oneflow/user/kernels/{cublas_fused_mlp_grad.cu => cublas_fused_mlp_grad_kernel.cu} (100%) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad.cu b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu similarity index 100% rename from oneflow/user/kernels/cublas_fused_mlp_grad.cu rename to oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu From 68c1e3ecf25cb6b27aecae2616d0d9bcc17c6395 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Tue, 28 Jun 2022 16:26:18 +0800 Subject: [PATCH 17/28] fix fill bug and remove comm to cache --- .../kernels/cublas_fused_mlp_grad_kernel.cu | 76 +++++++------------ .../user/kernels/cublas_fused_mlp_util.cuh | 52 ++++++++++++- 2 files changed, 78 insertions(+), 50 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu index 91fbc036761..ba16f9aabe3 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu @@ -13,11 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/device/nccl_util.h" -#include "oneflow/core/job/eager_nccl_comm_manager.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/cublas_fused_mlp_util.cuh" +#include "oneflow/core/ep/include/primitive/fill.h" // CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link. 
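// (With CUDA_VERSION < 11060 this whole translation unit compiles away; the matching
// guard added to FusedMLPGradFunctor above is what turns that into a clear
// "unimplemented" error at dispatch time instead of a missing-kernel failure.)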
#if CUDA_VERSION >= 11060 @@ -27,8 +26,7 @@ namespace { class MatmulGradKernelState final : public user_op::OpKernelState { public: - MatmulGradKernelState(user_op::KernelInitContext* ctx) - : parallel_desc_(ctx->parallel_desc()), stream_name_(EagerNcclCommMgr::kDefaultStreamName) { + MatmulGradKernelState(user_op::KernelInitContext* ctx) { OF_CUDA_CHECK(cudaStreamCreate(&cuda_stream_)); OF_CUDA_CHECK(cudaStreamCreate(&allreduce_stream_)); OF_CUBLAS_CHECK(cublasLtCreate(&cublas_lt_handle_)); @@ -47,43 +45,12 @@ class MatmulGradKernelState final : public user_op::OpKernelState { cublasLtHandle_t cublas_lt_handle() const { return cublas_lt_handle_; } size_t cublas_workspace_size() const { return 8 * 1024 * 1024; } void* cublas_workspace() const { return workspace_; } - ncclComm_t comm() { return GetOrCreate().comm; } - bool IfCommCreate() { - if (!comm_) { return false; } - return true; - } private: - struct Comm { - Comm(ncclComm_t comm) : comm(comm) {} - ncclComm_t comm; - }; - - const Comm& GetOrCreate() { - if (!comm_) { Init(); } - return *comm_; - } - - void Init() { - std::set> device_set; - for (int64_t parallel_id = 0; parallel_id < parallel_desc_.parallel_num(); ++parallel_id) { - int64_t machine_id = CHECK_JUST(parallel_desc_.MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(parallel_desc_.DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } - EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Global::Get()); - ncclComm_t comm; - comm = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_); - comm_.reset(new Comm(comm)); - } - cudaStream_t cuda_stream_{}; cudaStream_t allreduce_stream_{}; cublasLtHandle_t cublas_lt_handle_{}; void* workspace_{}; - std::unique_ptr comm_; - ParallelDesc parallel_desc_; - std::string stream_name_; }; template @@ -104,7 +71,9 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { - return CreateCublasFusedMLPKernelCache(); + std::shared_ptr kernel_cache = CreateCublasFusedMLPKernelCache(); + kernel_cache->Init(ctx); + return kernel_cache; } std::shared_ptr CreateOpKernelState( @@ -120,8 +89,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { - auto* kernel_state = dynamic_cast(state); - return kernel_state->IfCommCreate(); + auto* kernel_cache = dynamic_cast(state); + return kernel_cache->IfCommCreate(); } using user_op::OpKernel::Compute; @@ -137,15 +106,18 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); auto* kernel_state = dynamic_cast(state); + const auto* matmul_grad_cache = + CHECK_NOTNULL(dynamic_cast(cache)); + ncclComm_t comm{}; - if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { - comm = kernel_state->comm(); + bool if_comm_create = matmul_grad_cache->IfCommCreate(); + if (if_comm_create + && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + comm = matmul_grad_cache->comm(); } void* dy_tmp_buf = tmp_buffer->mut_dptr(); size_t offset = 0; - const auto* matmul_grad_cache = - CHECK_NOTNULL(dynamic_cast(cache)); auto* cuda_stream = ctx->stream()->As(); const DataType data_type = 
dy->data_type(); @@ -171,11 +143,15 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: const int64_t batch_size = dy->shape_view().At(0); const void* ones = nullptr; auto* cuda_device = dynamic_cast(ctx->stream()->device()); - if (cuda_device != nullptr) { - ones = cuda_device->GetConstOnes(dy->data_type(), batch_size); - } else { - ones = dy_tmp_buf; - offset += batch_size; + if (cuda_device != nullptr) { ones = cuda_device->GetConstOnes(dy->data_type(), batch_size); } + if (ones == nullptr) { + std::unique_ptr fill = + ep::primitive::NewPrimitive(ctx->stream()->device_type(), + data_type); + CHECK(fill); + fill->Launch(ctx->stream(), tmp_buffer->mut_dptr(), 1.0, batch_size); + ones = tmp_buffer->mut_dptr(); + offset += GetCudaAlignedSize(batch_size * sizeof(T)); dy_tmp_buf = reinterpret_cast(tmp_buffer->mut_dptr() + offset); } @@ -306,7 +282,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: OF_CUDA_CHECK(cudaEventRecord(dweight_event, kernel_state->cuda_stream())); } - if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + if (if_comm_create + && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { // Do Allreduce for d_bias and d_weight. // Here we wait wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight. OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); @@ -334,7 +311,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: } } - if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + if (if_comm_create + && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event)); } else { OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), dweight_event)); diff --git a/oneflow/user/kernels/cublas_fused_mlp_util.cuh b/oneflow/user/kernels/cublas_fused_mlp_util.cuh index d5b27523736..33d6fb5f6b4 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_util.cuh +++ b/oneflow/user/kernels/cublas_fused_mlp_util.cuh @@ -20,6 +20,8 @@ limitations under the License. #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include +#include "oneflow/core/device/nccl_util.h" +#include "oneflow/core/job/eager_nccl_comm_manager.h" // CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link. #if CUDA_VERSION >= 11060 @@ -40,9 +42,15 @@ long AlignReluAuxLd(long aux_ld) { * kAuxReluLdAlignRequirement; } +struct Comm { + Comm(ncclComm_t comm) : comm(comm) {} + ncclComm_t comm; +}; + class CublasFusedMLPKernelCache final : public user_op::OpKernelCache { public: - CublasFusedMLPKernelCache() { + CublasFusedMLPKernelCache() + : stream_name_(EagerNcclCommMgr::kDefaultStreamName), if_support_comm_(true) { // Just for init. 
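    // (Nothing here has to match the real problem: the dtype, shapes, transposes and
    // epilogue are overwritten on every call through SetCublasAttr further below, so
    // the CUDA_R_32F / 1x1 values only keep the cublasLt objects valid as placeholders.)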
OF_CUBLAS_CHECK(cublasLtMatmulDescCreate(&operation_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&cublas_a_desc, CUDA_R_32F, 1, 1, 1)); @@ -62,6 +70,48 @@ class CublasFusedMLPKernelCache final : public user_op::OpKernelCache { cublasLtMatrixLayout_t cublas_b_desc; cublasLtMatrixLayout_t cublas_c_desc; cublasLtMatmulPreference_t cublas_preference; + + bool IfCommCreate() const { + if (!comm_) { return false; } + return true; + } + ncclComm_t comm() const { return Get().comm; } + + void Init(user_op::KernelCacheContext* ctx) { + if (ctx->parallel_ctx().parallel_num() > 1) { + const int64_t d_weights_size = ctx->output_size("d_weights"); + for (int i = 0; i < d_weights_size; i++) { + if (!ctx->SbpParallel4ArgNameAndIndex("d_weights", i).has_broadcast_parallel() + || !ctx->SbpParallel4ArgNameAndIndex("d_biases", i).has_broadcast_parallel() + || !ctx->SbpParallel4ArgNameAndIndex("dy", 0).has_split_parallel()) { + if_support_comm_ = false; + break; + } + } + } else { + if_support_comm_ = false; + } + if (if_support_comm_ + && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + std::set> device_set; + for (int64_t parallel_id = 0; parallel_id < ctx->parallel_desc().parallel_num(); + ++parallel_id) { + int64_t machine_id = CHECK_JUST(ctx->parallel_desc().MachineId4ParallelId(parallel_id)); + int64_t device_id = CHECK_JUST(ctx->parallel_desc().DeviceId4ParallelId(parallel_id)); + device_set.emplace(std::make_pair(machine_id, device_id)); + } + EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Global::Get()); + ncclComm_t comm; + comm = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_); + comm_.reset(new Comm(comm)); + } + } + + private: + const Comm& Get() const { return *comm_; } + std::string stream_name_; + std::unique_ptr comm_; + bool if_support_comm_; }; std::shared_ptr CreateCublasFusedMLPKernelCache() { From 87dfba963966352fecf5aae04eca3965ba6fe867 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 30 Jun 2022 11:46:22 +0800 Subject: [PATCH 18/28] Rename variable and add debug code for cache --- .../core/functional/impl/nn_grad_functor.cpp | 2 +- oneflow/ir/include/OneFlow/OneFlowUserOps.td | 2 +- .../kernels/cublas_fused_mlp_grad_kernel.cu | 22 +++++++++++++++---- .../user/kernels/cublas_fused_mlp_util.cuh | 2 +- oneflow/user/ops/cublas_fused_mlp_grad_op.cpp | 6 ++--- oneflow/user/ops/cublas_fused_mlp_op.cpp | 4 ++-- 6 files changed, 26 insertions(+), 12 deletions(-) diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp index 4fca5f41e76..ba5a603a5cd 100644 --- a/oneflow/core/functional/impl/nn_grad_functor.cpp +++ b/oneflow/core/functional/impl/nn_grad_functor.cpp @@ -1121,7 +1121,7 @@ class FusedMLPGradFunctor { .Input("weights", n) .Input("cublas_aux", n) .Input("hidden", n) - .Output("d_grad") + .Output("d_x") .Output("d_biases", n) .Output("d_weights", n) .Build()); diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index e1eefdf686c..d57ca7cadd8 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -4563,7 +4563,7 @@ def OneFlow_CublasFusedMLPGradOp : OneFlow_BaseOp<"cublas_fused_mlp_grad", [NoSi Variadic:$hidden ); let output = (outs - OneFlow_Tensor:$d_grad, + OneFlow_Tensor:$d_x, Variadic:$d_biases, Variadic:$d_weights ); diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu 
b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu index ba16f9aabe3..6b684261245 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu @@ -72,6 +72,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { std::shared_ptr kernel_cache = CreateCublasFusedMLPKernelCache(); + // if (if_comm_create && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) kernel_cache->Init(ctx); return kernel_cache; } @@ -90,7 +91,14 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { auto* kernel_cache = dynamic_cast(state); - return kernel_cache->IfCommCreate(); + if(ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)){ + printf("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE! \n"); + return kernel_cache->IfCommCreate(); + } else { + printf("Always ready for capture! \n"); + return true; + } + // return kernel_cache->IfCommCreate(); } using user_op::OpKernel::Compute; @@ -100,7 +108,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t weight_num = ctx->input_size("weights"); - user_op::Tensor* d_grad = ctx->Tensor4ArgNameAndIndex("d_grad", 0); + user_op::Tensor* d_x = ctx->Tensor4ArgNameAndIndex("d_x", 0); // just a placeholder. user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); @@ -111,6 +119,12 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: ncclComm_t comm{}; bool if_comm_create = matmul_grad_cache->IfCommCreate(); + if(if_comm_create){ + printf("Comm create! 
\n"); + } else { + printf("No create \n"); + } + if (if_comm_create && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { comm = matmul_grad_cache->comm(); @@ -197,8 +211,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: OF_CUBLAS_CHECK(cublasLtMatmul( cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha, weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, - matmul_grad_cache->cublas_b_desc, &sp_beta, d_grad->mut_dptr(), - matmul_grad_cache->cublas_c_desc, d_grad->mut_dptr(), matmul_grad_cache->cublas_c_desc, + matmul_grad_cache->cublas_b_desc, &sp_beta, d_x->mut_dptr(), + matmul_grad_cache->cublas_c_desc, d_x->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr, cuda_stream->cublas_workspace(), cuda_stream->cublas_workspace_size(), cuda_stream->cuda_stream())); } diff --git a/oneflow/user/kernels/cublas_fused_mlp_util.cuh b/oneflow/user/kernels/cublas_fused_mlp_util.cuh index 33d6fb5f6b4..6c24a01ad90 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_util.cuh +++ b/oneflow/user/kernels/cublas_fused_mlp_util.cuh @@ -100,7 +100,7 @@ class CublasFusedMLPKernelCache final : public user_op::OpKernelCache { int64_t device_id = CHECK_JUST(ctx->parallel_desc().DeviceId4ParallelId(parallel_id)); device_set.emplace(std::make_pair(machine_id, device_id)); } - EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Global::Get()); + EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); ncclComm_t comm; comm = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_); comm_.reset(new Comm(comm)); diff --git a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp index 9bf278d059b..c1f7d34f1e3 100644 --- a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp @@ -31,7 +31,7 @@ Maybe InferTensorDesc4FusedMatmulBackward(user_op::InferContext* ctx) { *ctx->OutputShape("d_weights", idx) = weight_desc.shape(); *ctx->OutputShape("d_biases", idx) = Shape({weight_desc.shape().At(0)}); } - *ctx->OutputShape("d_grad", 0) = x_desc.shape(); + *ctx->OutputShape("d_x", 0) = x_desc.shape(); return Maybe::Ok(); } @@ -47,7 +47,7 @@ Maybe InferDataType4MatmulBackward(user_op::InferContext* ctx) { *ctx->OutputDType("d_weights", idx) = dy_desc.data_type(); *ctx->OutputDType("d_biases", idx) = dy_desc.data_type(); } - *ctx->OutputDType("d_grad", 0) = dy_desc.data_type(); + *ctx->OutputDType("d_x", 0) = dy_desc.data_type(); return Maybe::Ok(); } @@ -74,7 +74,7 @@ Maybe InferDataType4MatmulBackward(user_op::InferContext* ctx) { builder.Split(user_op::OpArg("hidden", i), 0); } - builder.Split(user_op::OpArg("d_grad", 0), 0); + builder.Split(user_op::OpArg("d_x", 0), 0); if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { // FusedMLPGradKernel do allreduce for dbias and dweight, so here convert from PartialSum to // Broadcast. 
diff --git a/oneflow/user/ops/cublas_fused_mlp_op.cpp b/oneflow/user/ops/cublas_fused_mlp_op.cpp index c9e8dc6872e..f67e3ce624d 100644 --- a/oneflow/user/ops/cublas_fused_mlp_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_op.cpp @@ -162,7 +162,7 @@ REGISTER_USER_OP_GRAD("cublas_fused_mlp") fused_mlp_grad_builder.Op("cublas_fused_mlp_grad") .Input("dy", cublas_dy) .Input("x", op.input("x", 0)) - .Output("d_grad") + .Output("d_x") .Output("d_biases", weight_num) .Output("d_weights", weight_num); @@ -186,7 +186,7 @@ REGISTER_USER_OP_GRAD("cublas_fused_mlp") } } if (op.NeedGenGradTensor4OpInput("x", 0)) { - op.BindGradTensorWithOpInput(fused_mlp_grad_op.output("d_grad", 0), "x", 0); + op.BindGradTensorWithOpInput(fused_mlp_grad_op.output("d_x", 0), "x", 0); } } else { // step2: use reduce_sum to get last layer's bias grad. From 36b773fdd5ff8d5ee71fd5f38fe9eabba6a14792 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 30 Jun 2022 16:19:10 +0800 Subject: [PATCH 19/28] Use kernel state and fix bug --- .../kernels/cublas_fused_mlp_grad_kernel.cu | 101 +++++++++++++----- .../user/kernels/cublas_fused_mlp_util.cuh | 59 ++-------- 2 files changed, 83 insertions(+), 77 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu index 6b684261245..3573b957b5e 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu @@ -24,13 +24,22 @@ namespace oneflow { namespace { +struct Comm { + Comm(ncclComm_t comm) : comm(comm) {} + ncclComm_t comm; +}; + class MatmulGradKernelState final : public user_op::OpKernelState { public: - MatmulGradKernelState(user_op::KernelInitContext* ctx) { + MatmulGradKernelState(user_op::KernelInitContext* ctx) + : if_need_comm_(false), stream_name_(EagerNcclCommMgr::kDefaultStreamName) { OF_CUDA_CHECK(cudaStreamCreate(&cuda_stream_)); OF_CUDA_CHECK(cudaStreamCreate(&allreduce_stream_)); OF_CUBLAS_CHECK(cublasLtCreate(&cublas_lt_handle_)); - OF_CUDA_CHECK(cudaMalloc(&workspace_, 8 * 1024 * 1024)); + workspace_size_ = + ParseIntegerFromEnv("ONEFLOW_EP_CUDA_CUBLAS_WORKSPACE_SIZE_MB", kDefaultWorkspaceSizeMb) + * 1024 * 1024; + OF_CUDA_CHECK(cudaMalloc(&workspace_, workspace_size_)); } ~MatmulGradKernelState() { OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_)); @@ -43,14 +52,58 @@ class MatmulGradKernelState final : public user_op::OpKernelState { cudaStream_t cuda_stream() const { return cuda_stream_; } cudaStream_t allreduce_stream() const { return allreduce_stream_; } cublasLtHandle_t cublas_lt_handle() const { return cublas_lt_handle_; } - size_t cublas_workspace_size() const { return 8 * 1024 * 1024; } + size_t cublas_workspace_size() const { return workspace_size_; } void* cublas_workspace() const { return workspace_; } + bool IfCommCreate() const { + if (!comm_) { return false; } + return true; + } + + bool IfNeedComm() const { return if_need_comm_; } + + ncclComm_t comm() const { return comm_->comm; } + + void InitNeedComm(user_op::KernelInitContext* ctx) { + if_need_comm_ = true; + if (ctx->parallel_ctx().parallel_num() > 1) { + const int64_t d_weights_size = ctx->output_size("d_weights"); + for (int i = 0; i < d_weights_size; i++) { + if (!ctx->SbpParallel4ArgNameAndIndex("d_weights", i).has_broadcast_parallel() + || !ctx->SbpParallel4ArgNameAndIndex("d_biases", i).has_broadcast_parallel() + || !ctx->SbpParallel4ArgNameAndIndex("dy", 0).has_split_parallel()) { + if_need_comm_ = false; + break; + } + } + } else { + 
if_need_comm_ = false; + } + } + + void InitCommMgr(user_op::KernelInitContext* ctx) { + std::set> device_set; + for (int64_t parallel_id = 0; parallel_id < ctx->parallel_desc().parallel_num(); + ++parallel_id) { + int64_t machine_id = CHECK_JUST(ctx->parallel_desc().MachineId4ParallelId(parallel_id)); + int64_t device_id = CHECK_JUST(ctx->parallel_desc().DeviceId4ParallelId(parallel_id)); + device_set.emplace(std::make_pair(machine_id, device_id)); + } + EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); + ncclComm_t comm; + comm = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_); + comm_.reset(new Comm(comm)); + } + private: cudaStream_t cuda_stream_{}; cudaStream_t allreduce_stream_{}; cublasLtHandle_t cublas_lt_handle_{}; void* workspace_{}; + size_t workspace_size_; + std::string stream_name_; + std::unique_ptr comm_; + bool if_need_comm_; }; template @@ -71,15 +124,18 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: std::shared_ptr InitOpKernelCache( user_op::KernelCacheContext* ctx) const override { - std::shared_ptr kernel_cache = CreateCublasFusedMLPKernelCache(); - // if (if_comm_create && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) - kernel_cache->Init(ctx); - return kernel_cache; + return CreateCublasFusedMLPKernelCache(); } std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + std::shared_ptr kernel_state = + std::make_shared(ctx); + if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + kernel_state->InitNeedComm(ctx); + if (kernel_state->IfNeedComm()) { kernel_state->InitCommMgr(ctx); } + } + return kernel_state; } private: @@ -90,15 +146,14 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { - auto* kernel_cache = dynamic_cast(state); - if(ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)){ - printf("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE! \n"); - return kernel_cache->IfCommCreate(); + auto* kernel_state = dynamic_cast(state); + if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + printf("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE! \n"); + return kernel_state->IfCommCreate(); } else { - printf("Always ready for capture! \n"); - return true; + printf("Always ready for capture! \n"); + return true; } - // return kernel_cache->IfCommCreate(); } using user_op::OpKernel::Compute; @@ -118,16 +173,11 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: CHECK_NOTNULL(dynamic_cast(cache)); ncclComm_t comm{}; - bool if_comm_create = matmul_grad_cache->IfCommCreate(); - if(if_comm_create){ - printf("Comm create! 
\n"); - } else { - printf("No create \n"); - } + bool if_need_comm = kernel_state->IfNeedComm(); - if (if_comm_create + if (if_need_comm && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { - comm = matmul_grad_cache->comm(); + comm = kernel_state->comm(); } void* dy_tmp_buf = tmp_buffer->mut_dptr(); @@ -296,8 +346,9 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: OF_CUDA_CHECK(cudaEventRecord(dweight_event, kernel_state->cuda_stream())); } - if (if_comm_create + if (if_need_comm && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + // printf("Here need comm and overlap allreduce. \n"); // Do Allreduce for d_bias and d_weight. // Here we wait wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight. OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); @@ -325,7 +376,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: } } - if (if_comm_create + if (if_need_comm && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event)); } else { diff --git a/oneflow/user/kernels/cublas_fused_mlp_util.cuh b/oneflow/user/kernels/cublas_fused_mlp_util.cuh index 6c24a01ad90..9c4d330d395 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_util.cuh +++ b/oneflow/user/kernels/cublas_fused_mlp_util.cuh @@ -42,15 +42,9 @@ long AlignReluAuxLd(long aux_ld) { * kAuxReluLdAlignRequirement; } -struct Comm { - Comm(ncclComm_t comm) : comm(comm) {} - ncclComm_t comm; -}; - class CublasFusedMLPKernelCache final : public user_op::OpKernelCache { public: - CublasFusedMLPKernelCache() - : stream_name_(EagerNcclCommMgr::kDefaultStreamName), if_support_comm_(true) { + CublasFusedMLPKernelCache() { // Just for init. 
OF_CUBLAS_CHECK(cublasLtMatmulDescCreate(&operation_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); OF_CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&cublas_a_desc, CUDA_R_32F, 1, 1, 1)); @@ -70,48 +64,6 @@ class CublasFusedMLPKernelCache final : public user_op::OpKernelCache { cublasLtMatrixLayout_t cublas_b_desc; cublasLtMatrixLayout_t cublas_c_desc; cublasLtMatmulPreference_t cublas_preference; - - bool IfCommCreate() const { - if (!comm_) { return false; } - return true; - } - ncclComm_t comm() const { return Get().comm; } - - void Init(user_op::KernelCacheContext* ctx) { - if (ctx->parallel_ctx().parallel_num() > 1) { - const int64_t d_weights_size = ctx->output_size("d_weights"); - for (int i = 0; i < d_weights_size; i++) { - if (!ctx->SbpParallel4ArgNameAndIndex("d_weights", i).has_broadcast_parallel() - || !ctx->SbpParallel4ArgNameAndIndex("d_biases", i).has_broadcast_parallel() - || !ctx->SbpParallel4ArgNameAndIndex("dy", 0).has_split_parallel()) { - if_support_comm_ = false; - break; - } - } - } else { - if_support_comm_ = false; - } - if (if_support_comm_ - && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { - std::set> device_set; - for (int64_t parallel_id = 0; parallel_id < ctx->parallel_desc().parallel_num(); - ++parallel_id) { - int64_t machine_id = CHECK_JUST(ctx->parallel_desc().MachineId4ParallelId(parallel_id)); - int64_t device_id = CHECK_JUST(ctx->parallel_desc().DeviceId4ParallelId(parallel_id)); - device_set.emplace(std::make_pair(machine_id, device_id)); - } - EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - ncclComm_t comm; - comm = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_); - comm_.reset(new Comm(comm)); - } - } - - private: - const Comm& Get() const { return *comm_; } - std::string stream_name_; - std::unique_ptr comm_; - bool if_support_comm_; }; std::shared_ptr CreateCublasFusedMLPKernelCache() { @@ -265,9 +217,12 @@ void SetCublasAttr(const CublasFusedMLPKernelCache* matmul_grad_cache, matmul_grad_cache->operation_desc, CUBLASLT_MATMUL_DESC_COMPUTE_TYPE, &cublas_compute_dtype, sizeof(cublas_compute_dtype))); - OF_CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute( - matmul_grad_cache->cublas_preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &kDefaultWorkspaceSize, sizeof(kDefaultWorkspaceSize))); + size_t workspace_size = + ParseIntegerFromEnv("ONEFLOW_EP_CUDA_CUBLAS_WORKSPACE_SIZE_MB", kDefaultWorkspaceSizeMb) + * 1024 * 1024; + OF_CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(matmul_grad_cache->cublas_preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, sizeof(workspace_size))); uint32_t pointer_mode = CUBLASLT_POINTER_MODE_MASK_HOST; OF_CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(matmul_grad_cache->cublas_preference, From e2e2069d1586164536944c0a3c510fd563afa4fa Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 30 Jun 2022 16:25:01 +0800 Subject: [PATCH 20/28] remove print --- oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu index 3573b957b5e..21718aad8df 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu @@ -148,10 +148,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: const user_op::OpKernelCache* cache) const override { auto* kernel_state = dynamic_cast(state); if 
(ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { - printf("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE! \n"); return kernel_state->IfCommCreate(); } else { - printf("Always ready for capture! \n"); return true; } } From 6e60e9c0b7269431779e409b3d7224b59b25fbe8 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 30 Jun 2022 18:02:35 +0800 Subject: [PATCH 21/28] fix allreduce dbias bug --- .../kernels/cublas_fused_mlp_grad_kernel.cu | 22 ++++++------------- .../user/kernels/cublas_fused_mlp_util.cuh | 4 +--- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu index 21718aad8df..9c329ba9c8c 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu @@ -162,9 +162,6 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const int64_t weight_num = ctx->input_size("weights"); user_op::Tensor* d_x = ctx->Tensor4ArgNameAndIndex("d_x", 0); - // just a placeholder. - user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); - user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); auto* kernel_state = dynamic_cast(state); const auto* matmul_grad_cache = @@ -226,7 +223,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: &cublas_n, &cublas_k, &cublas_lda, &cublas_ldb, &cublas_ldc); if (idx != 0) { const user_op::Tensor* aux = ctx->Tensor4ArgNameAndIndex("cublas_aux", idx - 1); - d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx - 1); + user_op::Tensor* d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx - 1); epilogue = CUBLASLT_EPILOGUE_DRELU_BGRAD; SetCublasAttr(matmul_grad_cache, cublas_compute_dtype, cuda_data_type, /*need_aux=*/true, /*transpose_a=*/ep::primitive::BlasTransposeType::N, @@ -269,6 +266,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: // step1: Get last layer's dbias. if (idx == weight_num - 1) { + user_op::Tensor* d_last_bias = ctx->Tensor4ArgNameAndIndex("d_biases", weight_num - 1); DimVector ones_buf_shape(2); ones_buf_shape.at(0) = 1; ones_buf_shape.at(1) = batch_size; @@ -351,17 +349,11 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: // Here we wait wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight. 
OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event)); OF_NCCL_CHECK(ncclGroupStart()); - if (idx == weight_num - 1) { - OF_NCCL_CHECK(ncclAllReduce( - d_last_bias->mut_dptr(), d_last_bias->mut_dptr(), - d_last_bias->shape_view().elem_cnt(), GetNcclDataType(d_last_bias->data_type()), - ncclRedOp_t::ncclSum, comm, kernel_state->allreduce_stream())); - } else { - OF_NCCL_CHECK(ncclAllReduce(d_bias->mut_dptr(), d_bias->mut_dptr(), - d_bias->shape_view().elem_cnt(), - GetNcclDataType(d_bias->data_type()), ncclRedOp_t::ncclSum, - comm, kernel_state->allreduce_stream())); - } + user_op::Tensor* allreduce_d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx); + OF_NCCL_CHECK(ncclAllReduce(allreduce_d_bias->mut_dptr(), allreduce_d_bias->mut_dptr(), + allreduce_d_bias->shape_view().elem_cnt(), + GetNcclDataType(allreduce_d_bias->data_type()), + ncclRedOp_t::ncclSum, comm, kernel_state->allreduce_stream())); OF_NCCL_CHECK(ncclAllReduce(d_weight->mut_dptr(), d_weight->mut_dptr(), d_weight->shape_view().elem_cnt(), GetNcclDataType(d_weight->data_type()), ncclRedOp_t::ncclSum, diff --git a/oneflow/user/kernels/cublas_fused_mlp_util.cuh b/oneflow/user/kernels/cublas_fused_mlp_util.cuh index 9c4d330d395..d02d4c2f898 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_util.cuh +++ b/oneflow/user/kernels/cublas_fused_mlp_util.cuh @@ -20,8 +20,6 @@ limitations under the License. #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include -#include "oneflow/core/device/nccl_util.h" -#include "oneflow/core/job/eager_nccl_comm_manager.h" // CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link. #if CUDA_VERSION >= 11060 @@ -30,7 +28,7 @@ namespace oneflow { namespace { constexpr int32_t kAuxReluLdAlignRequirement = 128; -constexpr size_t kDefaultWorkspaceSize = 4 * 1024 * 1024; // 4M +constexpr size_t kDefaultWorkspaceSizeMb = 4; // 4M long AlignReluAuxLd(long aux_ld) { /* From 2a5606078fe48aba49d9bd8022971bab15fe305b Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Mon, 4 Jul 2022 10:50:29 +0800 Subject: [PATCH 22/28] fix header file --- oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu index 9c329ba9c8c..000c5c5c8fe 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu @@ -17,6 +17,8 @@ limitations under the License. #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/user/kernels/cublas_fused_mlp_util.cuh" #include "oneflow/core/ep/include/primitive/fill.h" +#include "oneflow/core/device/nccl_util.h" +#include "oneflow/core/job/eager_nccl_comm_manager.h" // CUBLAS_AUX_EPILOGUE only support in cuda11.4 or higher version, in cuda11.4 it need static link. 
#if CUDA_VERSION >= 11060 From 90cd291f4949250df743a86bf7513f142a8063e2 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Tue, 5 Jul 2022 11:32:19 +0800 Subject: [PATCH 23/28] fix comment --- .../core/functional/impl/nn_grad_functor.cpp | 2 +- .../kernels/cublas_fused_mlp_grad_kernel.cu | 120 +++++++++--------- oneflow/user/ops/cublas_fused_mlp_grad_op.cpp | 4 +- oneflow/user/ops/cublas_fused_mlp_op.cpp | 2 +- 4 files changed, 66 insertions(+), 62 deletions(-) diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp index ba5a603a5cd..52d102f42bb 100644 --- a/oneflow/core/functional/impl/nn_grad_functor.cpp +++ b/oneflow/core/functional/impl/nn_grad_functor.cpp @@ -1113,7 +1113,7 @@ class FusedMLPGradFunctor { public: FusedMLPGradFunctor() { #if CUDA_VERSION >= 11060 - fused_op_.resize(kMaxInputCount /*the maximum number of inputs*/); + fused_op_.resize(kMaxInputCount /*the maximum number of layers*/); for (int n = 1; n < fused_op_.size(); ++n) { fused_op_[n] = CHECK_JUST(one::OpBuilder("cublas_fused_mlp_grad") .Input("dy") diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu index 000c5c5c8fe..db4a88c84e6 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu @@ -51,7 +51,7 @@ class MatmulGradKernelState final : public user_op::OpKernelState { OF_CUDA_CHECK(cudaStreamDestroy(allreduce_stream_)); OF_CUDA_CHECK(cudaFree(workspace_)); } - cudaStream_t cuda_stream() const { return cuda_stream_; } + cudaStream_t grad_cuda_stream() const { return cuda_stream_; } cudaStream_t allreduce_stream() const { return allreduce_stream_; } cublasLtHandle_t cublas_lt_handle() const { return cublas_lt_handle_; } size_t cublas_workspace_size() const { return workspace_size_; } @@ -67,19 +67,27 @@ class MatmulGradKernelState final : public user_op::OpKernelState { ncclComm_t comm() const { return comm_->comm; } void InitNeedComm(user_op::KernelInitContext* ctx) { - if_need_comm_ = true; + if_need_comm_ = false; if (ctx->parallel_ctx().parallel_num() > 1) { const int64_t d_weights_size = ctx->output_size("d_weights"); + bool has_d_weight_broadcast = false; for (int i = 0; i < d_weights_size; i++) { - if (!ctx->SbpParallel4ArgNameAndIndex("d_weights", i).has_broadcast_parallel() - || !ctx->SbpParallel4ArgNameAndIndex("d_biases", i).has_broadcast_parallel() - || !ctx->SbpParallel4ArgNameAndIndex("dy", 0).has_split_parallel()) { - if_need_comm_ = false; + if (ctx->SbpParallel4ArgNameAndIndex("d_weights", i).has_broadcast_parallel()) { + has_d_weight_broadcast = true; break; } } - } else { - if_need_comm_ = false; + if (has_d_weight_broadcast) { + for (int i = 0; i < d_weights_size; i++) { + CHECK(ctx->SbpParallel4ArgNameAndIndex("d_weights", i).has_broadcast_parallel()) + << "All d_weight's SBP should be Broadcast. "; + CHECK(ctx->SbpParallel4ArgNameAndIndex("d_biases", i).has_broadcast_parallel()) + << "All d_bias's SBP should be Broadcast. 
"; + } + if (ctx->SbpParallel4ArgNameAndIndex("dy", 0).has_split_parallel()) { + if_need_comm_ = true; + } + } } } @@ -112,16 +120,16 @@ template class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { public: CublasFusedMLPGradKernel() { - OF_CUDA_CHECK(cudaEventCreate(&main_stream_event)); - OF_CUDA_CHECK(cudaEventCreate(&async_weight_grad_event)); - OF_CUDA_CHECK(cudaEventCreate(&dweight_event)); - OF_CUDA_CHECK(cudaEventCreate(&allreduce_event)); + OF_CUDA_CHECK(cudaEventCreate(&main_stream_event_)); + OF_CUDA_CHECK(cudaEventCreate(&async_weight_grad_event_)); + OF_CUDA_CHECK(cudaEventCreate(&dweight_event_)); + OF_CUDA_CHECK(cudaEventCreate(&allreduce_event_)); }; ~CublasFusedMLPGradKernel() override { - OF_CUDA_CHECK(cudaEventDestroy(main_stream_event)); - OF_CUDA_CHECK(cudaEventDestroy(async_weight_grad_event)); - OF_CUDA_CHECK(cudaEventDestroy(dweight_event)); - OF_CUDA_CHECK(cudaEventDestroy(allreduce_event)); + OF_CUDA_CHECK(cudaEventDestroy(main_stream_event_)); + OF_CUDA_CHECK(cudaEventDestroy(async_weight_grad_event_)); + OF_CUDA_CHECK(cudaEventDestroy(dweight_event_)); + OF_CUDA_CHECK(cudaEventDestroy(allreduce_event_)); }; std::shared_ptr InitOpKernelCache( @@ -133,23 +141,21 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: user_op::KernelInitContext* ctx) const override { std::shared_ptr kernel_state = std::make_shared(ctx); - if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { - kernel_state->InitNeedComm(ctx); - if (kernel_state->IfNeedComm()) { kernel_state->InitCommMgr(ctx); } - } + kernel_state->InitNeedComm(ctx); + if (kernel_state->IfNeedComm()) { kernel_state->InitCommMgr(ctx); } return kernel_state; } private: - cudaEvent_t main_stream_event; - cudaEvent_t async_weight_grad_event; - cudaEvent_t dweight_event; - cudaEvent_t allreduce_event; + cudaEvent_t main_stream_event_; + cudaEvent_t async_weight_grad_event_; + cudaEvent_t dweight_event_; + cudaEvent_t allreduce_event_; bool IsReadyForCapture(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, const user_op::OpKernelCache* cache) const override { auto* kernel_state = dynamic_cast(state); - if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { + if (kernel_state->IfNeedComm()) { return kernel_state->IfCommCreate(); } else { return true; @@ -162,6 +168,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + int64_t tmp_buf_elem_cnt = tmp_buffer->shape_view().elem_cnt(); const int64_t weight_num = ctx->input_size("weights"); user_op::Tensor* d_x = ctx->Tensor4ArgNameAndIndex("d_x", 0); @@ -172,13 +179,10 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op: ncclComm_t comm{}; bool if_need_comm = kernel_state->IfNeedComm(); - if (if_need_comm - && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) { - comm = kernel_state->comm(); - } + if (if_need_comm) { comm = kernel_state->comm(); } void* dy_tmp_buf = tmp_buffer->mut_dptr(); - size_t offset = 0; + size_t tmp_buf_offset = 0; auto* cuda_stream = ctx->stream()->As(); const DataType data_type = dy->data_type(); @@ -192,7 +196,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public 
@@ -192,7 +196,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op:
     double beta = 0.0;
     auto sp_beta = GetCublasScalarParameter(beta, cublas_compute_dtype);

-    cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;  // = CUBLASLT_EPILOGUE_DRELU_BGRAD
+    cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

     // currently only support 2D matmul.
     DimVector weight_shape(2);
@@ -203,8 +207,9 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op:
     const int64_t batch_size = dy->shape_view().At(0);
     const void* ones = nullptr;
-    auto* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device());
-    if (cuda_device != nullptr) { ones = cuda_device->GetConstOnes(dy->data_type(), batch_size); }
+    ep::CudaDevice* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device());
+    CHECK_NOTNULL(cuda_device);
+    ones = cuda_device->GetConstOnes(dy->data_type(), batch_size);
     if (ones == nullptr) {
       std::unique_ptr<ep::primitive::Fill> fill =
           ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),
                                                                   data_type);
       CHECK(fill);
       fill->Launch(ctx->stream(), tmp_buffer->mut_dptr(), 1.0, batch_size);
       ones = tmp_buffer->mut_dptr();
-      offset += GetCudaAlignedSize(batch_size * sizeof(T));
-      dy_tmp_buf = reinterpret_cast<void*>(tmp_buffer->mut_dptr<char>() + offset);
+      tmp_buf_offset += GetCudaAlignedSize(batch_size * sizeof(T));
+      dy_tmp_buf = reinterpret_cast<void*>(tmp_buffer->mut_dptr<char>() + tmp_buf_offset);
     }

-    for (int idx = weight_num - 1; idx > -1; idx--) {
+    for (int idx = weight_num - 1; idx >= 0; idx--) {
       const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weights", idx);
       weight->shape_view().ToDimVector(&weight_shape);
       InferMatmulCublasMNK(dy_shape, weight_shape,
@@ -236,7 +241,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op:
          a = dy, b = weight
          cublas_a=weight, cublas_b=dy
       */
-      OF_CUDA_CHECK(cudaEventRecord(main_stream_event, cuda_stream->cuda_stream()));
+      OF_CUDA_CHECK(cudaEventRecord(main_stream_event_, cuda_stream->cuda_stream()));
       OF_CUBLAS_CHECK(cublasLtMatmul(
           cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha,
           weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf,
@@ -254,7 +259,7 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op:
          a = dy, b = weight
          cublas_a=weight, cublas_b=dy
       */
-      OF_CUDA_CHECK(cudaEventRecord(main_stream_event, cuda_stream->cuda_stream()));
+      OF_CUDA_CHECK(cudaEventRecord(main_stream_event_, cuda_stream->cuda_stream()));
       OF_CUBLAS_CHECK(cublasLtMatmul(
           cuda_stream->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha,
           weight->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf,
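The next hunk computes the last layer's bias gradient as a GEMM against the constant ones vector prepared above, using the identity d_bias = dy^T * 1: multiplying by a ones vector of length batch_size sums dy over the batch dimension, which is exactly the bias gradient. A plain C++ check of that identity on a 2x3 example (all names illustrative):

    #include <cstdio>
    #include <vector>

    // d_bias[j] = sum_i dy[i][j]: reducing dy over the batch equals the
    // GEMM dy^T * ones used by the kernel.
    int main() {
      const int m = 2, n = 3;                      // batch size, out_features
      std::vector<float> dy = {1, 2, 3, 4, 5, 6};  // row-major m x n
      std::vector<float> ones(m, 1.0f), d_bias(n, 0.0f);
      for (int j = 0; j < n; j++) {
        for (int i = 0; i < m; i++) { d_bias[j] += dy[i * n + j] * ones[i]; }
      }
      for (int j = 0; j < n; j++) { printf("%f\n", d_bias[j]); }  // 5 7 9
      return 0;
    }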
@@ -281,14 +286,14 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op:
             /*transpose_a=*/ep::primitive::BlasTransposeType::N,
             /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr,
             cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc);
-        OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event));
+        OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->grad_cuda_stream(), main_stream_event_));
         OF_CUBLAS_CHECK(cublasLtMatmul(
             kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha,
             dgrad_buf, matmul_grad_cache->cublas_a_desc, ones, matmul_grad_cache->cublas_b_desc,
             &sp_beta, d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc,
             d_last_bias->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr,
             kernel_state->cublas_workspace(), kernel_state->cublas_workspace_size(),
-            kernel_state->cuda_stream()));
+            kernel_state->grad_cuda_stream()));
       }

       user_op::Tensor* d_weight = ctx->Tensor4ArgNameAndIndex("d_weights", idx);
@@ -306,8 +311,8 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op:
             /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr,
             cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc);
         if (idx != weight_num - 1) {
-          // if idx == weight_num - 1, async_stream has wait main_stream_event in d_bias.
-          OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event));
+          // if idx == weight_num - 1, the async stream has already waited on main_stream_event_ in d_bias.
+          OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->grad_cuda_stream(), main_stream_event_));
         }
         OF_CUBLAS_CHECK(cublasLtMatmul(
             kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha,
             matmul_grad_cache->cublas_b_desc, &sp_beta, d_weight->mut_dptr(),
             matmul_grad_cache->cublas_c_desc, d_weight->mut_dptr(), matmul_grad_cache->cublas_c_desc,
             nullptr, kernel_state->cublas_workspace(),
-            kernel_state->cublas_workspace_size(), kernel_state->cuda_stream()));
-        OF_CUDA_CHECK(cudaEventRecord(dweight_event, kernel_state->cuda_stream()));
+            kernel_state->cublas_workspace_size(), kernel_state->grad_cuda_stream()));
+        OF_CUDA_CHECK(cudaEventRecord(dweight_event_, kernel_state->grad_cuda_stream()));
         // compute dy shape
         dy_shape.at(1) = weight_shape.at(1);
         // compute dybuf
         dgrad_buf = dy_tmp_buf;
-        offset += GetCudaAlignedSize(dy_shape.at(0) * dy_shape.at(1) * sizeof(T));
-        dy_tmp_buf = reinterpret_cast<void*>(tmp_buffer->mut_dptr<char>() + offset);
+        tmp_buf_offset += GetCudaAlignedSize(dy_shape.at(0) * dy_shape.at(1) * sizeof(T));
+        CHECK_LE(tmp_buf_offset, tmp_buf_elem_cnt)
+            << "Tmp buffer offset should be <= tmp buffer elem_cnt. ";
+        dy_tmp_buf = reinterpret_cast<void*>(tmp_buffer->mut_dptr<char>() + tmp_buf_offset);
       } else {
         x->shape_view().ToDimVector(&hidden_shape);
         InferMatmulCublasMNK(dy_shape, hidden_shape,
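The CHECK_LE added above guards the cursor-style sub-allocation inside tmp_buffer: each dy buffer is carved out at a CUDA-aligned offset, and the cursor must never pass the size returned by the kernel's InferTmpSizeFn. A standalone sketch of that offset arithmetic; the 512-byte constant mirrors OneFlow's CUDA alignment and is an assumption of this sketch:

    #include <cstddef>
    #include <cstdio>

    // Assumed 512-byte alignment, in the spirit of GetCudaAlignedSize.
    constexpr size_t kCudaAlignSize = 512;
    size_t GetAlignedSize(size_t bytes) {
      return (bytes + kCudaAlignSize - 1) / kCudaAlignSize * kCudaAlignSize;
    }

    int main() {
      size_t offset = 0;
      const size_t batch_size = 1000, elem = sizeof(float);
      offset += GetAlignedSize(batch_size * elem);      // ones-vector slot: 4096
      offset += GetAlignedSize(batch_size * 512 * elem);  // one dy buffer slot
      printf("cursor after two slots: %zu bytes\n", offset);
      // The kernel's CHECK_LE asserts this cursor stays within the inferred
      // tmp buffer size at every step.
      return 0;
    }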
@@ -333,23 +340,21 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op:
             /*transpose_a=*/ep::primitive::BlasTransposeType::T,
             /*transpose_b=*/ep::primitive::BlasTransposeType::N, epilogue, nullptr, nullptr,
             cublas_m, cublas_n, cublas_k, cublas_lda, cublas_ldb, cublas_ldc);
-        OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->cuda_stream(), main_stream_event));
+        OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->grad_cuda_stream(), main_stream_event_));
         OF_CUBLAS_CHECK(cublasLtMatmul(
             kernel_state->cublas_lt_handle(), matmul_grad_cache->operation_desc, &sp_alpha,
             x->dptr(), matmul_grad_cache->cublas_a_desc, dgrad_buf, matmul_grad_cache->cublas_b_desc,
             &sp_beta, d_weight->mut_dptr(), matmul_grad_cache->cublas_c_desc,
             d_weight->mut_dptr(), matmul_grad_cache->cublas_c_desc, nullptr,
             kernel_state->cublas_workspace(),
-            kernel_state->cublas_workspace_size(), kernel_state->cuda_stream()));
-        OF_CUDA_CHECK(cudaEventRecord(dweight_event, kernel_state->cuda_stream()));
+            kernel_state->cublas_workspace_size(), kernel_state->grad_cuda_stream()));
+        OF_CUDA_CHECK(cudaEventRecord(dweight_event_, kernel_state->grad_cuda_stream()));
       }

-      if (if_need_comm
-          && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) {
-        // printf("Here need comm and overlap allreduce. \n");
+      if (if_need_comm) {
        // Do Allreduce for d_bias and d_weight.
        // Here we wait wgrad event, and set a ncclGroup to Allreduce d_bias and d_weight.
-        OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event));
+        OF_CUDA_CHECK(cudaStreamWaitEvent(kernel_state->allreduce_stream(), dweight_event_));
         OF_NCCL_CHECK(ncclGroupStart());
         user_op::Tensor* allreduce_d_bias = ctx->Tensor4ArgNameAndIndex("d_biases", idx);
         OF_NCCL_CHECK(ncclAllReduce(allreduce_d_bias->mut_dptr(), allreduce_d_bias->mut_dptr(),
@@ -363,16 +368,15 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op:
         OF_NCCL_CHECK(ncclGroupEnd());
         if (idx == 0) {
           // We should sync allreduce before the kernel finish.
-          OF_CUDA_CHECK(cudaEventRecord(allreduce_event, kernel_state->allreduce_stream()));
+          OF_CUDA_CHECK(cudaEventRecord(allreduce_event_, kernel_state->allreduce_stream()));
         }
       }
     }

-    if (if_need_comm
-        && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_GRAD_OVERLAP_ALLREDUCE", false)) {
-      OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event));
+    if (if_need_comm) {
+      OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), allreduce_event_));
     } else {
-      OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), dweight_event));
+      OF_CUDA_CHECK(cudaStreamWaitEvent(cuda_stream->cuda_stream(), dweight_event_));
     }
   };

diff --git a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp
index c1f7d34f1e3..23127620eab 100644
--- a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp
+++ b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp
@@ -26,7 +26,7 @@ namespace {
 Maybe<void> InferTensorDesc4FusedMatmulBackward(user_op::InferContext* ctx) {
   const int64_t weight_num = ctx->input_size("weights");
   const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
-  for (int idx = weight_num - 1; idx > -1; idx--) {
+  for (int idx = weight_num - 1; idx >= 0; idx--) {
     const user_op::TensorDesc& weight_desc = ctx->InputTensorDesc("weights", idx);
     *ctx->OutputShape("d_weights", idx) = weight_desc.shape();
     *ctx->OutputShape("d_biases", idx) = Shape({weight_desc.shape().At(0)});
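For reference, the grouped allreduce used in the kernel above, reduced to its shape (comm, stream, and both device buffers are assumed to already exist; OF_NCCL_CHECK-style error handling is elided). Grouping the two calls lets NCCL batch both tensors into one launch instead of paying two communication latencies:

    #include <nccl.h>

    // Sketch of the ncclGroupStart/End pattern; names are illustrative.
    void AllreduceBiasAndWeight(float* d_bias, size_t bias_cnt, float* d_weight,
                                size_t weight_cnt, ncclComm_t comm,
                                cudaStream_t stream) {
      ncclGroupStart();
      // In-place sum-reduction across ranks for both gradients.
      ncclAllReduce(d_bias, d_bias, bias_cnt, ncclFloat, ncclSum, comm, stream);
      ncclAllReduce(d_weight, d_weight, weight_cnt, ncclFloat, ncclSum, comm, stream);
      ncclGroupEnd();
    }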
"; const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - for (int idx = weight_num - 1; idx > -1; idx--) { + for (int idx = weight_num - 1; idx >= 0; idx--) { *ctx->OutputDType("d_weights", idx) = dy_desc.data_type(); *ctx->OutputDType("d_biases", idx) = dy_desc.data_type(); } diff --git a/oneflow/user/ops/cublas_fused_mlp_op.cpp b/oneflow/user/ops/cublas_fused_mlp_op.cpp index f67e3ce624d..9bc5d9f1b57 100644 --- a/oneflow/user/ops/cublas_fused_mlp_op.cpp +++ b/oneflow/user/ops/cublas_fused_mlp_op.cpp @@ -175,7 +175,7 @@ REGISTER_USER_OP_GRAD("cublas_fused_mlp") AddOp(fused_mlp_grad_op); - for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > -1; hidden_layer_idx--) { + for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx >= 0; hidden_layer_idx--) { if (op.NeedGenGradTensor4OpInput("biases", hidden_layer_idx)) { op.BindGradTensorWithOpInput(fused_mlp_grad_op.output("d_biases", hidden_layer_idx), "biases", hidden_layer_idx); From 9ed87a936ce3388a64b9f0b85a851e4e4870165a Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Tue, 5 Jul 2022 11:34:11 +0800 Subject: [PATCH 24/28] remove redundant headerfile --- oneflow/core/functional/impl/nn_grad_functor.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp index 62e5309a238..09dc532b65b 100644 --- a/oneflow/core/functional/impl/nn_grad_functor.cpp +++ b/oneflow/core/functional/impl/nn_grad_functor.cpp @@ -14,7 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/common/maybe.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_builder.h" From 9179bf5cf30e3f40eba32163f22e1a7708c226c1 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Tue, 5 Jul 2022 11:47:18 +0800 Subject: [PATCH 25/28] fix userops build error --- oneflow/ir/include/OneFlow/OneFlowUserOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 836d7ae6cc7..bb39bf37ad5 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -4584,7 +4584,7 @@ def OneFlow_CublasFusedMLPOp : OneFlow_BaseOp<"cublas_fused_mlp", [NoSideEffect, let has_data_type_infer_fn = 1; } -def OneFlow_CublasFusedMLPGradOp : OneFlow_BaseOp<"cublas_fused_mlp_grad", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { +def OneFlow_CublasFusedMLPGradOp : OneFlow_BaseOp<"cublas_fused_mlp_grad", [NoSideEffect, NoGrad, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, OneFlow_Tensor:$x, From b64a4a455e6189220f8be2dee7168181763e8da4 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Tue, 5 Jul 2022 14:59:15 +0800 Subject: [PATCH 26/28] refine --- .../kernels/cublas_fused_mlp_grad_kernel.cu | 47 ++++++++----------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu index db4a88c84e6..03f17c67761 100644 --- a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu +++ b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu @@ -70,14 +70,7 @@ class MatmulGradKernelState final : public user_op::OpKernelState { if_need_comm_ = false; if 
From b64a4a455e6189220f8be2dee7168181763e8da4 Mon Sep 17 00:00:00 2001
From: MARD1NO <359521840@qq.com>
Date: Tue, 5 Jul 2022 14:59:15 +0800
Subject: [PATCH 26/28] refine

---
 .../kernels/cublas_fused_mlp_grad_kernel.cu   | 47 ++++++++-----------
 1 file changed, 19 insertions(+), 28 deletions(-)

diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu
index db4a88c84e6..03f17c67761 100644
--- a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu
+++ b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu
@@ -70,14 +70,7 @@ class MatmulGradKernelState final : public user_op::OpKernelState {
     if_need_comm_ = false;
     if (ctx->parallel_ctx().parallel_num() > 1) {
       const int64_t d_weights_size = ctx->output_size("d_weights");
-      bool has_d_weight_broadcast = false;
-      for (int i = 0; i < d_weights_size; i++) {
-        if (ctx->SbpParallel4ArgNameAndIndex("d_weights", i).has_broadcast_parallel()) {
-          has_d_weight_broadcast = true;
-          break;
-        }
-      }
-      if (has_d_weight_broadcast) {
+      if (ctx->SbpParallel4ArgNameAndIndex("d_weights", 0).has_broadcast_parallel()) {
         for (int i = 0; i < d_weights_size; i++) {
           CHECK(ctx->SbpParallel4ArgNameAndIndex("d_weights", i).has_broadcast_parallel())
               << "All d_weight's SBP should be Broadcast. ";
@@ -383,26 +376,24 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op:
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };

-#define REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(dtype) \
-  REGISTER_USER_KERNEL("cublas_fused_mlp_grad") \
-      .SetCreateFn<CublasFusedMLPGradKernel<dtype>>() \
-      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
-                       && (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)) \
-      .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \
-        const int64_t weight_num = ctx->input_size("weights"); \
-        const Shape& dy_shape = ctx->InputShape("dy", 0); \
-        int64_t m = dy_shape.At(0); \
-        int64_t k = dy_shape.At(1); \
-        int64_t tmp_buffer_size = 0; \
-        if (m > 1024 * 1024) { \
-          tmp_buffer_size += GetCudaAlignedSize(m * sizeof(dtype)); /*For last layer's bias grad*/ \
-        } \
-        for (int idx = weight_num - 1; idx > 0; idx--) { \
-          const Shape& weight_shape = ctx->InputShape("weights", idx); \
-          k = weight_shape.At(1); \
-          tmp_buffer_size += GetCudaAlignedSize(m * k * sizeof(dtype)); \
-        } \
-        return tmp_buffer_size; \
+#define REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(dtype) \
+  REGISTER_USER_KERNEL("cublas_fused_mlp_grad") \
+      .SetCreateFn<CublasFusedMLPGradKernel<dtype>>() \
+      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
+                       && (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)) \
+      .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \
+        const int64_t weight_num = ctx->input_size("weights"); \
+        const Shape& dy_shape = ctx->InputShape("dy", 0); \
+        int64_t m = dy_shape.At(0); \
+        int64_t k = dy_shape.At(1); \
+        int64_t tmp_buffer_size = 0; \
+        tmp_buffer_size += GetCudaAlignedSize(m * sizeof(dtype)); /*For last layer's bias grad*/ \
+        for (int idx = weight_num - 1; idx > 0; idx--) { \
+          const Shape& weight_shape = ctx->InputShape("weights", idx); \
+          k = weight_shape.At(1); \
+          tmp_buffer_size += GetCudaAlignedSize(m * k * sizeof(dtype)); \
+        } \
+        return tmp_buffer_size; \
       });

 REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(float)

From f4ebaf99d27199a461dfe7a3eda1627c82010456 Mon Sep 17 00:00:00 2001
From: MARD1NO <359521840@qq.com>
Date: Tue, 5 Jul 2022 17:49:13 +0800
Subject: [PATCH 27/28] init nccl comm before execute kernel

---
 oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu
index 03f17c67761..1c041298fa5 100644
--- a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu
+++ b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu
@@ -400,6 +400,8 @@ REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(float)
 REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(double)
 REGISTER_CUBLAS_FUSED_MLP_GRAD_KERNEL(half)

+REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("cublas_fused_mlp_grad");
+
 }  // namespace

 }  // namespace oneflow
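REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT asks the runtime to create the op's NCCL communicator before the kernel executes rather than lazily inside Compute(), so ranks do not stall mid-kernel while a communicator is still being negotiated. A generic single-process sketch of eager communicator setup (a multi-process launcher would instead broadcast the unique id out of band; error checks elided):

    #include <cuda_runtime.h>
    #include <nccl.h>
    #include <vector>

    int main() {
      int nranks = 0;
      cudaGetDeviceCount(&nranks);
      ncclUniqueId id;
      ncclGetUniqueId(&id);
      std::vector<ncclComm_t> comms(nranks);
      ncclGroupStart();
      for (int r = 0; r < nranks; r++) {
        cudaSetDevice(r);
        ncclCommInitRank(&comms[r], nranks, id, r);  // done once, up front
      }
      ncclGroupEnd();
      // ... kernels that issue ncclAllReduce can now run without init stalls ...
      for (int r = 0; r < nranks; r++) { ncclCommDestroy(comms[r]); }
      return 0;
    }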
From 63b7b2c7963ac7310838acec3c97583b8ee87d79 Mon Sep 17 00:00:00 2001
From: MARD1NO <359521840@qq.com>
Date: Thu, 7 Jul 2022 15:07:39 +0800
Subject: [PATCH 28/28] fix comment

---
 .../kernels/cublas_fused_mlp_grad_kernel.cu   | 22 +++++++++++++------
 oneflow/user/ops/cublas_fused_mlp_grad_op.cpp |  3 ---
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu
index 1c041298fa5..ac95f2be059 100644
--- a/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu
+++ b/oneflow/user/kernels/cublas_fused_mlp_grad_kernel.cu
@@ -42,6 +42,9 @@ class MatmulGradKernelState final : public user_op::OpKernelState {
         ParseIntegerFromEnv("ONEFLOW_EP_CUDA_CUBLAS_WORKSPACE_SIZE_MB", kDefaultWorkspaceSizeMb)
         * 1024 * 1024;
     OF_CUDA_CHECK(cudaMalloc(&workspace_, workspace_size_));
+    if (ctx->parallel_ctx().parallel_num() > 1) {
+      parallel_conf_ = ctx->parallel_desc().parallel_conf();
+    }
   }
   ~MatmulGradKernelState() {
     OF_CUDA_CHECK(cudaStreamSynchronize(cuda_stream_));
@@ -64,7 +67,12 @@ class MatmulGradKernelState final : public user_op::OpKernelState {

   bool IfNeedComm() const { return if_need_comm_; }

-  ncclComm_t comm() const { return comm_->comm; }
+  ncclComm_t comm() { return GetOrCreate().comm; }
+
+  const Comm& GetOrCreate() {
+    if (!comm_) { InitCommMgr(); }
+    return *comm_;
+  }

   void InitNeedComm(user_op::KernelInitContext* ctx) {
     if_need_comm_ = false;
@@ -84,12 +92,12 @@ class MatmulGradKernelState final : public user_op::OpKernelState {
     }
   }

-  void InitCommMgr(user_op::KernelInitContext* ctx) {
+  void InitCommMgr() {
     std::set<std::pair<int64_t, int64_t>> device_set;
-    for (int64_t parallel_id = 0; parallel_id < ctx->parallel_desc().parallel_num();
-         ++parallel_id) {
-      int64_t machine_id = CHECK_JUST(ctx->parallel_desc().MachineId4ParallelId(parallel_id));
-      int64_t device_id = CHECK_JUST(ctx->parallel_desc().DeviceId4ParallelId(parallel_id));
+    const ParallelDesc parallel_desc(parallel_conf_);
+    for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {
+      int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
+      int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
       device_set.emplace(std::make_pair(machine_id, device_id));
     }
     EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton<EagerNcclCommMgr>::Get());
@@ -107,6 +115,7 @@ class MatmulGradKernelState final : public user_op::OpKernelState {
   std::string stream_name_;
   std::unique_ptr<Comm> comm_;
   bool if_need_comm_;
+  ParallelConf parallel_conf_;
 };

 template<typename T>
@@ -135,7 +144,6 @@ class CublasFusedMLPGradKernel final : public user_op::OpKernel, public user_op:
     std::shared_ptr<MatmulGradKernelState> kernel_state = std::make_shared<MatmulGradKernelState>(ctx);
     kernel_state->InitNeedComm(ctx);
-    if (kernel_state->IfNeedComm()) { kernel_state->InitCommMgr(ctx); }
     return kernel_state;
   }

diff --git a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp
index 23127620eab..cf4fd9d3bcd 100644
--- a/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp
+++ b/oneflow/user/ops/cublas_fused_mlp_grad_op.cpp
@@ -13,10 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-#include "oneflow/core/common/just.h"
-#include "oneflow/core/common/maybe.h"
 #include "oneflow/core/framework/framework.h"
-#include "oneflow/core/framework/infer_util.h"
 #include "oneflow/core/framework/op_generated.h"

 namespace oneflow {
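The GetOrCreate() accessor introduced in this last patch defers communicator construction to first use while keeping comm() cheap afterwards. The same lazy-init shape in isolation; Comm is a stand-in struct, and, like the kernel state above, the sketch assumes single-threaded access and adds no locking:

    #include <memory>

    struct Comm { int id = 42; };  // stand-in for the NCCL comm wrapper

    class State {
     public:
      const Comm& GetOrCreate() {
        if (!comm_) { comm_ = std::make_unique<Comm>(); }  // built on first use
        return *comm_;
      }

     private:
      std::unique_ptr<Comm> comm_;
    };

    int main() {
      State s;
      return s.GetOrCreate().id == 42 ? 0 : 1;  // second call reuses the instance
    }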