Dev Fully fused MLP Grad[OneEmbedding] #8462

Merged on Jul 7, 2022 (38 commits)

Commits
090c0a9
support fully fused mlp grad in eager
MARD1NO Jun 22, 2022
110a368
support lazy backward
MARD1NO Jun 22, 2022
e06c240
fix output size
MARD1NO Jun 23, 2022
9158bca
add fallback to tmp_buf logic when ones buffer is not enough
MARD1NO Jun 23, 2022
2395e6e
build sbp
MARD1NO Jun 23, 2022
42fd871
overlap allreduce
MARD1NO Jun 23, 2022
f8ec2b9
fix overlap order
MARD1NO Jun 23, 2022
b017dbd
fix format
MARD1NO Jun 23, 2022
bff8a33
CUDA Graphs delayed capture
liujuncheng Jun 23, 2022
2fd8dd9
Merge branch 'master' into dev_fused_mlp_grad
MARD1NO Jun 23, 2022
4eb6ac1
Add ifcomm create for graph
MARD1NO Jun 24, 2022
84e4a64
insert weight event roughly
MARD1NO Jun 24, 2022
e2d45aa
Merge remote-tracking branch 'origin/dev_cuda_graph_delayed_capture' …
MARD1NO Jun 24, 2022
9086ba6
fix dbias allreduce error
MARD1NO Jun 24, 2022
45db76e
simplify code
MARD1NO Jun 27, 2022
dfec471
Add 11060 limit
MARD1NO Jun 27, 2022
63e125b
Merge branch 'master' into dev_fused_mlp_grad
MARD1NO Jun 28, 2022
6695a20
Merge branch 'dev_fused_mlp_grad' of github.com:Oneflow-Inc/oneflow i…
MARD1NO Jun 28, 2022
4e3fa24
Remove print
MARD1NO Jun 28, 2022
5115c69
Rename
MARD1NO Jun 28, 2022
68c1e3e
fix fill bug and remove comm to cache
MARD1NO Jun 28, 2022
8274d6f
Merge branch 'master' into dev_fused_mlp_grad
MARD1NO Jun 30, 2022
87dfba9
Rename variable and add debug code for cache
MARD1NO Jun 30, 2022
36b773f
Use kernel state and fix bug
MARD1NO Jun 30, 2022
e2e2069
remove print
MARD1NO Jun 30, 2022
6e60e9c
fix allreduce dbias bug
MARD1NO Jun 30, 2022
2a56060
fix header file
MARD1NO Jul 4, 2022
90cd291
fix comment
MARD1NO Jul 5, 2022
2c61af2
Merge branch 'master' into dev_fused_mlp_grad
MARD1NO Jul 5, 2022
9ed87a9
remove redundant headerfile
MARD1NO Jul 5, 2022
9179bf5
fix userops build error
MARD1NO Jul 5, 2022
b64a4a4
refine
MARD1NO Jul 5, 2022
7d00f49
Merge branch 'master' into dev_fused_mlp_grad
MARD1NO Jul 5, 2022
f4ebaf9
init nccl comm before execute kernel
MARD1NO Jul 5, 2022
e8cf34a
Merge branch 'dev_fused_mlp_grad' of github.com:Oneflow-Inc/oneflow i…
MARD1NO Jul 5, 2022
3f06229
Merge branch 'master' into dev_fused_mlp_grad
MARD1NO Jul 7, 2022
63b7b2c
fix comment
MARD1NO Jul 7, 2022
ac02327
Merge branch 'dev_fused_mlp_grad' of github.com:Oneflow-Inc/oneflow i…
MARD1NO Jul 7, 2022
125 changes: 75 additions & 50 deletions oneflow/core/autograd/gradient_funcs/cublas_fused_mlp.cpp
@@ -81,7 +81,7 @@ Maybe<void> CublasFusedMLP::Capture(CublasFusedMLPCaptureState* ctx, const Tenso
ctx->SaveTensorForBackward(
JUST(VectorAt(outputs, i + 1))); // cublas aux. need minus 1. idx_sum:2+2w
}
for (int32_t i = 0; i < weight_num - 1; i++) {
for (int32_t i = 0; i < weight_num; i++) {
ctx->SaveTensorForBackward(JUST(VectorAt(outputs, i + 1 + weight_num))); // hidden.
}

@@ -103,14 +103,7 @@ Maybe<void> CublasFusedMLP::Apply(const CublasFusedMLPCaptureState* ctx,
JUST(VectorAt(ctx->SavedTensors(), 1 + weight_num))));
}

// step2: use reduce_sum to get last layer's bias grad.
std::vector<int32_t> reduce_axes_vec{0};
if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) {
JUST(VectorAt(*in_grads, 2 * weight_num)) =
JUST(functional::ReduceSum(last_bias_dy, reduce_axes_vec, false));
}

TensorTuple hiddens(weight_num - 1);
TensorTuple hiddens(weight_num);
TensorTuple weights(weight_num);
TensorTuple cublas_auxs(weight_num);
TensorTuple dgrad(weight_num);
@@ -125,56 +118,88 @@ Maybe<void> CublasFusedMLP::Apply(const CublasFusedMLPCaptureState* ctx,
cublas_auxs[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + weight_num));
}

for (int32_t i = 0; i < weight_num - 1; ++i) {
for (int32_t i = 0; i < weight_num; ++i) {
hiddens[i] = JUST(VectorAt(ctx->SavedTensors(), i + 2 + 2 * weight_num));
}

std::shared_ptr<one::Tensor> cublas_dy = last_bias_dy;
for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) {
// If it is final layer, we use out_grads[0] as dy.
if (hidden_layer_idx != weight_num - 1) {
cublas_dy = JUST(VectorAt(dgrad, hidden_layer_idx + 1));

// Use Fully Fused MLP Backward.
if (ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_FUSED_MLP_ASYNC_GRAD", false)) {
const auto& fused_mlp_grad = JUST(functional::FusedMLPGrad(
cublas_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), weights, cublas_auxs, hiddens));
if (ctx->x_requires_grad) {
// dx:
JUST(VectorAt(*in_grads, 0)) = fused_mlp_grad->at(0);
}
/*
Here we use cublas to compute bias + relu + matmul grad.
Then use Matmul to compute weight grad.
*/
const auto& matmul_relu_bias_bgrad = JUST(functional::CublasBiasAddReluMatmulGrad(
cublas_dy, JUST(VectorAt(weights, hidden_layer_idx)),
JUST(VectorAt(cublas_auxs, hidden_layer_idx - 1)), /*alpha=*/1.0));

// dgrad
dgrad.at(hidden_layer_idx) = matmul_relu_bias_bgrad->at(0); // NOLINT

if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx - 1)))) {
// dbias
JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx)) =
matmul_relu_bias_bgrad->at(1); // NOLINT

for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > -1; hidden_layer_idx--) {
if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx)))) {
// dbias
JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx + 1)) =
fused_mlp_grad->at(1 + hidden_layer_idx); // NOLINT
}

// dw
if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) {
JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) =
fused_mlp_grad->at(1 + weight_num + hidden_layer_idx);
}
}
// dw
if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) {
JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = JUST(functional::MatMul(
cublas_dy, JUST(VectorAt(hiddens, hidden_layer_idx - 1)), true, false, 1.0));
} else {
// step2: use reduce_sum to get last layer's bias grad.
std::vector<int32_t> reduce_axes_vec{0};
if (JUST(VectorAt(ctx->biases_requires_grad, weight_num - 1))) {
JUST(VectorAt(*in_grads, 2 * weight_num)) =
JUST(functional::ReduceSum(last_bias_dy, reduce_axes_vec, false));
}
}

// For the first layer, we need to use 2 matmul to get grads.
std::shared_ptr<one::Tensor> last_dy;
if (weight_num != 1) {
last_dy = JUST(VectorAt(dgrad, 1));
} else {
last_dy = last_bias_dy;
}
for (int32_t hidden_layer_idx = weight_num - 1; hidden_layer_idx > 0; hidden_layer_idx--) {
// If it is final layer, we use out_grads[0] as dy.
if (hidden_layer_idx != weight_num - 1) {
cublas_dy = JUST(VectorAt(dgrad, hidden_layer_idx + 1));
}
/*
Here we use cublas to compute bias + relu + matmul grad.
Then use Matmul to compute weight grad.
*/
const auto& matmul_relu_bias_bgrad = JUST(functional::CublasBiasAddReluMatmulGrad(
cublas_dy, JUST(VectorAt(weights, hidden_layer_idx)),
JUST(VectorAt(cublas_auxs, hidden_layer_idx - 1)), /*alpha=*/1.0));

// dgrad
dgrad.at(hidden_layer_idx) = matmul_relu_bias_bgrad->at(0); // NOLINT

if (JUST(VectorAt(ctx->biases_requires_grad, (hidden_layer_idx - 1)))) {
// dbias
JUST(VectorAt(*in_grads, weight_num + hidden_layer_idx)) =
matmul_relu_bias_bgrad->at(1); // NOLINT
}
// dw
if (JUST(VectorAt(ctx->weights_requires_grad, hidden_layer_idx))) {
JUST(VectorAt(*in_grads, (1 + hidden_layer_idx))) = JUST(functional::MatMul(
cublas_dy, JUST(VectorAt(hiddens, hidden_layer_idx - 1)), true, false, 1.0));
}
}

if (ctx->x_requires_grad) {
// dx:
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::MatMul(last_dy, JUST(VectorAt(weights, 0)), false, false, 1.0));
}
if (JUST(VectorAt(ctx->weights_requires_grad, 0))) {
// dw:
JUST(VectorAt(*in_grads, 1)) =
JUST(functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0));
// For the first layer, we need to use 2 matmul to get grads.
std::shared_ptr<one::Tensor> last_dy;
if (weight_num != 1) {
last_dy = JUST(VectorAt(dgrad, 1));
} else {
last_dy = last_bias_dy;
}

if (ctx->x_requires_grad) {
// dx:
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::MatMul(last_dy, JUST(VectorAt(weights, 0)), false, false, 1.0));
}
if (JUST(VectorAt(ctx->weights_requires_grad, 0))) {
// dw:
JUST(VectorAt(*in_grads, 1)) = JUST(
functional::MatMul(last_dy, JUST(VectorAt(ctx->SavedTensors(), 0)), true, false, 1.0));
}
}

return Maybe<void>::Ok();
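Note on the index arithmetic in the hunk above: `in_grads` follows the forward op's input order (`x`, then the weights, then the biases), while the tuple returned by `FusedMLPGrad` is ordered `d_x`, then the bias grads, then the weight grads. The standalone sketch below (illustrative only, not part of this PR's diff) spells out that mapping with `n = weight_num`:

```cpp
#include <cassert>
#include <cstdint>

// Index bookkeeping assumed by CublasFusedMLP::Apply above.
//   in_grads            : [dx, dw_0 .. dw_{n-1}, db_0 .. db_{n-1}]              (size 1 + 2n)
//   FusedMLPGrad output : [d_x, d_bias_0 .. d_bias_{n-1}, d_weight_0 .. d_weight_{n-1}]
struct GradIndex {
  int32_t n;  // number of layers (weight_num)
  int32_t InGradsDx() const { return 0; }
  int32_t InGradsDw(int32_t i) const { return 1 + i; }
  int32_t InGradsDb(int32_t i) const { return 1 + n + i; }
  int32_t FusedOutDx() const { return 0; }
  int32_t FusedOutDb(int32_t i) const { return 1 + i; }
  int32_t FusedOutDw(int32_t i) const { return 1 + n + i; }
};

int main() {
  const GradIndex idx{/*n=*/3};
  // Mirrors: in_grads[weight_num + hidden_layer_idx + 1] = fused_mlp_grad->at(1 + hidden_layer_idx)
  assert(idx.InGradsDb(2) == 3 + 2 + 1);
  assert(idx.FusedOutDb(2) == 1 + 2);
  // Mirrors: in_grads[1 + hidden_layer_idx] = fused_mlp_grad->at(1 + weight_num + hidden_layer_idx)
  assert(idx.InGradsDw(2) == 1 + 2);
  assert(idx.FusedOutDw(2) == 1 + 3 + 2);
  return 0;
}
```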
5 changes: 5 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -987,6 +987,11 @@
"Tensor (Tensor x, TensorTuple weights, TensorTuple biases, Bool skip_final_activation) => FusedMLP"
bind_python: True

- name: "fused_mlp_grad"
signature:
"TensorTuple (Tensor dy, Tensor x, TensorTuple weights, TensorTuple cublas_aux, TensorTuple hidden) => FusedMLPGrad"
bind_python: False

- name: "cublas_bias_add_relu_matmul_grad"
signature:
"TensorTuple (Tensor dy, Tensor weight, Tensor aux, Double alpha=1.0) => CublasBiasAddReluMatmulGrad"
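Since `bind_python` is `False`, `FusedMLPGrad` is reachable only from C++; it is intended for the autograd function above rather than for Python users. Below is a minimal sketch of a C++-side caller; the helper name `ExampleFusedMLPGradCall` and the exact include set are assumptions, while the argument order and output layout follow the op definition in this PR:

```cpp
#include <memory>

#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

// Illustrative caller; mirrors how CublasFusedMLP::Apply consumes the result.
Maybe<void> ExampleFusedMLPGradCall(const std::shared_ptr<Tensor>& dy,
                                    const std::shared_ptr<Tensor>& x, const TensorTuple& weights,
                                    const TensorTuple& cublas_aux, const TensorTuple& hidden) {
  const int64_t n = weights.size();
  const auto& grads = JUST(functional::FusedMLPGrad(dy, x, weights, cublas_aux, hidden));
  // Output layout follows the op definition: d_x, then n bias grads, then n weight grads.
  const auto& d_x = grads->at(0);
  const auto& d_bias_last = grads->at(1 + (n - 1));
  const auto& d_weight_last = grads->at(1 + n + (n - 1));
  (void)d_x;
  (void)d_bias_last;
  (void)d_weight_last;
  return Maybe<void>::Ok();
}

}  // namespace one
}  // namespace oneflow
```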
42 changes: 42 additions & 0 deletions oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -1130,6 +1130,47 @@ class FusedCrossFeatureInteractionV2GradFunctor {
std::shared_ptr<OpExpr> v2_grad_op_;
};

class FusedMLPGradFunctor {
public:
FusedMLPGradFunctor() {
#if CUDA_VERSION >= 11060
fused_op_.resize(kMaxInputCount /*the maximum number of layers*/);
for (int n = 1; n < fused_op_.size(); ++n) {
fused_op_[n] = CHECK_JUST(one::OpBuilder("cublas_fused_mlp_grad")
.Input("dy")
.Input("x")
.Input("weights", n)
.Input("cublas_aux", n)
.Input("hidden", n)
.Output("d_x")
.Output("d_biases", n)
.Output("d_weights", n)
.Build());
}
#endif
}
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
const std::shared_ptr<one::Tensor>& x, const TensorTuple& weights,
const TensorTuple& cublas_aux, const TensorTuple& hidden) const {
const int64_t weight_size = weights.size();
TensorTuple input(2 + 3 * weight_size);
input[0] = dy;
input[1] = x;
std::copy(weights.begin(), weights.end(), input.begin() + 2);
std::copy(cublas_aux.begin(), cublas_aux.end(), input.begin() + 2 + weight_size);
std::copy(hidden.begin(), hidden.end(), input.begin() + 2 + 2 * weight_size);
#if CUDA_VERSION >= 11060
return OpInterpUtil::Dispatch<TensorTuple>(*fused_op_[weight_size], input);
#endif
UNIMPLEMENTED_THEN_RETURN() << "Only Support in CUDA_VERSION >= 11060";
}

private:
#if CUDA_VERSION >= 11060
std::vector<std::shared_ptr<OpExpr>> fused_op_;
#endif
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -1173,6 +1214,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
"FusedCrossFeatureInteractionV1Grad");
m.add_functor<impl::FusedCrossFeatureInteractionV2GradFunctor>(
"FusedCrossFeatureInteractionV2Grad");
m.add_functor<impl::FusedMLPGradFunctor>("FusedMLPGrad");
m.add_functor<impl::BinaryCrossEntropyWithLogitsReduceMeanLossGradFunctor>(
"BinaryCrossEntropyWithLogitsReduceMeanLossGrad");
};
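The functor above builds one `cublas_fused_mlp_grad` op expression per possible layer count (up to `kMaxInputCount`) at construction time and then dispatches on `weights.size()`, so no op building happens on the hot path. A standalone sketch of that caching pattern follows; the names `Op`, `BuildOp`, and `ArityCache` are placeholders, not OneFlow APIs:

```cpp
#include <cstdio>
#include <memory>
#include <vector>

struct Op {
  int num_layers;
};

// Stand-in for the OpBuilder call in the functor's constructor.
static std::shared_ptr<Op> BuildOp(int n) { return std::make_shared<Op>(Op{n}); }

class ArityCache {
 public:
  explicit ArityCache(int max_arity) : ops_(max_arity) {
    // Index 0 is left empty: a fused MLP grad always has at least one layer.
    for (int n = 1; n < static_cast<int>(ops_.size()); ++n) { ops_[n] = BuildOp(n); }
  }
  // Dispatch by layer count at call time, analogous to fused_op_[weight_size].
  const Op& Get(int num_layers) const { return *ops_.at(num_layers); }

 private:
  std::vector<std::shared_ptr<Op>> ops_;
};

int main() {
  ArityCache cache(/*max_arity=*/8);
  std::printf("op for 3 layers handles %d layers\n", cache.Get(3).num_layers);
  return 0;
}
```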
19 changes: 19 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -4584,6 +4584,25 @@ def OneFlow_CublasFusedMLPOp : OneFlow_BaseOp<"cublas_fused_mlp", [NoSideEffect,
let has_data_type_infer_fn = 1;
}

def OneFlow_CublasFusedMLPGradOp : OneFlow_BaseOp<"cublas_fused_mlp_grad", [NoSideEffect, NoGrad, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$dy,
OneFlow_Tensor:$x,
Variadic<OneFlow_Tensor>:$weights,
Variadic<OneFlow_Tensor>:$cublas_aux,
Variadic<OneFlow_Tensor>:$hidden
);
let output = (outs
OneFlow_Tensor:$d_x,
Variadic<OneFlow_Tensor>:$d_biases,
Variadic<OneFlow_Tensor>:$d_weights
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_CublasBiasAddReluMatmulGradOp : OneFlow_BaseOp<"cublas_bias_add_relu_matmul_grad", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$dy,
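Because the op declares several variadic operand groups and several variadic result groups, it carries the `AttrSizedOperandSegments`/`AttrSizedResultSegments` traits, which record how many values belong to each group. For an n-layer MLP the segment sizes follow directly from the declaration above; the plain C++ sketch below (illustrative only, not MLIR or TableGen) simply writes them out:

```cpp
#include <array>
#include <cstdio>

// Operand groups in declaration order: dy, x, weights, cublas_aux, hidden.
std::array<int, 5> OperandSegments(int n) {
  return {1 /*dy*/, 1 /*x*/, n /*weights*/, n /*cublas_aux*/, n /*hidden*/};
}

// Result groups in declaration order: d_x, d_biases, d_weights.
std::array<int, 3> ResultSegments(int n) {
  return {1 /*d_x*/, n /*d_biases*/, n /*d_weights*/};
}

int main() {
  const auto in = OperandSegments(3);
  const auto out = ResultSegments(3);
  std::printf("operands: %d %d %d %d %d | results: %d %d %d\n", in[0], in[1], in[2], in[3], in[4],
              out[0], out[1], out[2]);
  return 0;
}
```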