support noncontiguous binary op #9986

Merged
merged 33 commits on May 23, 2023
Changes from 27 commits
Commits (33)
43785c7
init
ofhwei Mar 14, 2023
b02f763
add functional
ofhwei Mar 14, 2023
440a725
refine
ofhwei Mar 14, 2023
ebe863c
update v1
ofhwei Mar 16, 2023
d433fbc
add bwd
ofhwei Mar 16, 2023
06254bd
add strange comment
ofhwei Mar 16, 2023
2f00054
rename to noncontiguous_binary_op
ofhwei Mar 17, 2023
c433453
Merge branch 'master' of https://github.com/OneFlow-Inc/oneflow into …
ofhwei Mar 17, 2023
1e6d46c
add unittest
ofhwei Mar 17, 2023
a3858f0
set output contiguous if not inplace
ofhwei Mar 17, 2023
6fe307b
add y_stride==1 constraint
ofhwei Mar 17, 2023
b5fd6af
add requires_grad & op check when inplace
ofhwei Mar 17, 2023
27024c5
Merge branch 'master' of https://github.com/OneFlow-Inc/oneflow into …
ofhwei Mar 17, 2023
93e3d8c
refine
ofhwei Mar 17, 2023
ab0869e
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
ofhwei Mar 20, 2023
3ed25bf
Merge branch 'master' into dev_transpose_add
ofhwei Apr 24, 2023
6fd4935
Merge branch 'master' into dev_transpose_add
ofhwei Apr 24, 2023
351997e
Merge branch 'dev_transpose_add' of https://github.com/Oneflow-Inc/on…
ofhwei Apr 24, 2023
f4f578f
update ir with NoMemotyEffect
ofhwei Apr 25, 2023
f439825
init inplace
ofhwei Apr 25, 2023
1a273c2
auto format by CI
oneflow-ci-bot Apr 25, 2023
3fb4635
Merge branch 'master' into dev_transpose_add
ofhwei Apr 25, 2023
be276b6
rm unused var
ofhwei Apr 25, 2023
dd3bd34
Merge branch 'dev_transpose_add' of https://github.com/Oneflow-Inc/on…
ofhwei Apr 25, 2023
2dfe780
Merge branch 'master' into dev_transpose_add
ofhwei Apr 27, 2023
88d4f19
Merge branch 'master' into dev_transpose_add
ofhwei Apr 27, 2023
fbe0622
Merge branch 'master' into dev_transpose_add
ofhwei Apr 28, 2023
0553041
Merge branch 'master' of https://github.com/OneFlow-Inc/oneflow into …
ofhwei May 21, 2023
de42b2e
refine
ofhwei May 21, 2023
75ad151
merge master
ofhwei May 21, 2023
7e70c3c
refine
ofhwei May 22, 2023
53d94e3
Merge branch 'master' into dev_transpose_add
ofhwei May 22, 2023
7ee81ff
Merge branch 'master' into dev_transpose_add
ofhwei May 23, 2023
91 changes: 91 additions & 0 deletions oneflow/core/autograd/gradient_funcs/noncontiguous_binary_op.cpp
@@ -0,0 +1,91 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <glog/logging.h>
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"

namespace oneflow {
namespace one {

struct NonContiguousBinaryOpCaptureState : public AutoGradCaptureState {
  bool lhs_requires_grad = false;
  bool rhs_requires_grad = false;
  std::string op = "add";
  bool inplace = false;
};

class NonContiguousBinaryOp : public OpExprGradFunction<NonContiguousBinaryOpCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(NonContiguousBinaryOpCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const NonContiguousBinaryOpCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> NonContiguousBinaryOp::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> NonContiguousBinaryOp::Capture(NonContiguousBinaryOpCaptureState* ctx,
                                           const TensorTuple& inputs, const TensorTuple& outputs,
                                           const AttrMap& attrs) const {
  ctx->lhs_requires_grad = inputs.at(0)->requires_grad();
  ctx->rhs_requires_grad = inputs.at(1)->requires_grad();
  if (!ctx->lhs_requires_grad && !ctx->rhs_requires_grad) { return Maybe<void>::Ok(); }

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->inplace = JUST(composed_attrs.GetAttr<bool>("inplace"));
  ctx->op = JUST(composed_attrs.GetAttr<std::string>("op"));
  if (ctx->inplace && ctx->rhs_requires_grad) {
    CHECK_OR_RETURN(ctx->op == "add" || ctx->op == "sub")
        << "when inplace and rhs requires grad, op should be add/sub";
  }
  ctx->SaveTensorForBackward(inputs.at(0));
  ctx->SaveTensorForBackward(inputs.at(1));
  return Maybe<void>::Ok();
}

Maybe<void> NonContiguousBinaryOp::Apply(const NonContiguousBinaryOpCaptureState* ctx,
                                         const TensorTuple& out_grads,
                                         TensorTuple* in_grads) const {
  if (!ctx->lhs_requires_grad && !ctx->rhs_requires_grad) { return Maybe<void>::Ok(); }
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
  in_grads->resize(2);
  auto lhs = ctx->SavedTensors().at(0);
  auto rhs = ctx->SavedTensors().at(1);
  auto ret = JUST(functional::NonContiguousBinaryOpGrad(out_grads.at(0), lhs, rhs, ctx->op, false));
  if (ctx->lhs_requires_grad) in_grads->at(0) = ret->at(0);
  if (ctx->rhs_requires_grad) in_grads->at(1) = ret->at(1);
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("noncontiguous_binary_op", NonContiguousBinaryOp);

} // namespace one
} // namespace oneflow
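A minimal Python usage sketch of the autograd path registered above (not part of this diff). It assumes the forward functional is exposed as flow._C.noncontiguous_binary_op through the bind_python: True entry in functional_api.yaml below, and that a transposed view is an accepted noncontiguous layout; the kernel may impose extra stride constraints (see the "add y_stride==1 constraint" commit).

import oneflow as flow

lhs = flow.randn(2, 4, 8, requires_grad=True)               # contiguous operand
rhs = flow.randn(2, 8, 4).transpose(1, 2).requires_grad_()  # noncontiguous leaf view, shape (2, 4, 8)

y = flow._C.noncontiguous_binary_op(lhs, rhs, "add")        # inplace defaults to False
y.sum().backward()

# Apply() fills both entries of in_grads from one fused NonContiguousBinaryOpGrad call,
# but a gradient is kept only for the inputs that require grad.
print(lhs.grad.shape, rhs.grad.shape)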
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -2685,6 +2685,14 @@
  signature: 'Tensor (Tensor y, Tensor dy, Float scale=0.35355) => FusedScaleMaskBiasSoftmaxGrad'
  bind_python: False

- name: "noncontiguous_binary_op"
  signature: 'Tensor (Tensor lhs, Tensor rhs, String op="add", Bool inplace=False) => NonContiguousBinaryOp'
  bind_python: True

- name: "noncontiguous_binary_op_grad"
  signature: 'TensorTuple (Tensor dy, Tensor lhs, Tensor rhs, String op="add", Bool inplace=False) => NonContiguousBinaryOpGrad'
  bind_python: False

- name: "fused_get_center_dist"
  signature: "Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2) => FusedCenter"
  bind_python: True
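A hedged note on the two entries above: only the forward op is bound to Python (bind_python: True), and its op and inplace arguments carry defaults, so a call can omit them. A sketch, assuming the generated binding follows the yaml name and that these shapes are supported:

import oneflow as flow

lhs = flow.randn(2, 4, 8)
rhs = flow.randn(2, 8, 4).transpose(1, 2)                 # noncontiguous view, shape (2, 4, 8)

y_add = flow._C.noncontiguous_binary_op(lhs, rhs)         # op defaults to "add", inplace to False
y_sub = flow._C.noncontiguous_binary_op(lhs, rhs, "sub")  # explicit op string
# noncontiguous_binary_op_grad is bind_python: False, so it is reached only through autograd.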
51 changes: 51 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -5353,6 +5353,55 @@ class FusedScaleMaskBiasSoftmaxGradFunctor {
  std::shared_ptr<OpExpr> op_;
};

class NonContiguousBinaryOpFunctor {
 public:
  NonContiguousBinaryOpFunctor() {
    op_ = CHECK_JUST(
        one::OpBuilder("noncontiguous_binary_op").Input("lhs").Input("rhs").Output("y").Build());
  }

  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& lhs, const std::shared_ptr<Tensor>& rhs,
                           const std::string& op, const bool& inplace = false) const {
    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("op", "inplace");
    attrs.SetAllAttrs(op, inplace);
    if (inplace) {
      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
      outputs->at(0) = lhs;
      JUST(OpInterpUtil::Dispatch(*op_, {lhs, rhs}, outputs.get(), attrs));
      return outputs->at(0);
    }
    return OpInterpUtil::Dispatch<Tensor>(*op_, {lhs, rhs}, attrs);
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

class NonContiguousBinaryOpGradFunctor {
 public:
  NonContiguousBinaryOpGradFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("noncontiguous_binary_op_grad")
                         .Input("dy")
                         .Input("lhs")
                         .Input("rhs")
                         .Output("dlhs")
                         .Output("drhs")
                         .Build());
  }

  Maybe<TensorTuple> operator()(const std::shared_ptr<Tensor>& dy,
                                const std::shared_ptr<Tensor>& lhs,
                                const std::shared_ptr<Tensor>& rhs, const std::string& op,
                                const bool& inplace = false) const {
    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("op", "inplace");
    attrs.SetAllAttrs(op, inplace);
    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, lhs, rhs}, attrs);
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -5486,6 +5535,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
  m.add_functor<impl::SkipRMSNormFunctor>("SkipRMSNorm");
  m.add_functor<impl::FusedScaleMaskBiasSoftmaxFunctor>("FusedScaleMaskBiasSoftmax");
  m.add_functor<impl::FusedScaleMaskBiasSoftmaxGradFunctor>("FusedScaleMaskBiasSoftmaxGrad");
  m.add_functor<impl::NonContiguousBinaryOpFunctor>("NonContiguousBinaryOp");
  m.add_functor<impl::NonContiguousBinaryOpGradFunctor>("NonContiguousBinaryOpGrad");
  m.add_functor<impl::MultiTensorYoloV5WeightUpdateFunctor>("MultiTensorYoloV5WeightUpdate");
}

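A hedged sketch of the inplace branch of NonContiguousBinaryOpFunctor above: with inplace=True the functor places lhs into the output TensorTuple, so the dispatch writes the result into lhs rather than allocating a new output. Binding name and accepted layouts are assumptions, as in the earlier sketches.

import oneflow as flow

lhs = flow.randn(2, 4, 8)                   # no requires_grad, so the in-place update is unconstrained
rhs = flow.randn(2, 8, 4).transpose(1, 2)   # noncontiguous view, shape (2, 4, 8)
lhs_before = lhs.clone()

y = flow._C.noncontiguous_binary_op(lhs, rhs, "add", True)  # inplace=True

# y aliases lhs: lhs itself now holds lhs_before + rhs.
assert flow.allclose(lhs, lhs_before + rhs)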
38 changes: 38 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -3322,6 +3322,44 @@ def OneFlow_FusedCodegeexQkvReshapeOp : OneFlow_BaseOp<"fused_codegeex_qkv_resha
  let has_data_type_infer_fn = 1;
}

def OneFlow_NonContiguousBinaryOp : OneFlow_BaseOp<"noncontiguous_binary_op", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins
    OneFlow_Tensor:$lhs,
    OneFlow_Tensor:$rhs
  );
  let output = (outs
    OneFlow_Tensor:$y
  );
  let attrs = (ins
    DefaultValuedAttr<StrAttr, "\"add\"">:$op,
    DefaultValuedAttr<BoolAttr, "false">:$inplace
  );
  let has_logical_tensor_desc_infer_fn = 1;
  let has_physical_tensor_desc_infer_fn = 1;
  let has_get_sbp_fn = 1;
  let has_data_type_infer_fn = 1;
}

def OneFlow_NonContiguousBinaryOpGrad : OneFlow_BaseOp<"noncontiguous_binary_op_grad", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins
    OneFlow_Tensor:$dy,
    OneFlow_Tensor:$lhs,
    OneFlow_Tensor:$rhs
  );
  let output = (outs
    OneFlow_Tensor:$dlhs,
    OneFlow_Tensor:$drhs
  );
  let attrs = (ins
    DefaultValuedAttr<StrAttr, "\"add\"">:$op,
    DefaultValuedAttr<BoolAttr, "false">:$inplace
  );
  let has_logical_tensor_desc_infer_fn = 1;
  let has_physical_tensor_desc_infer_fn = 1;
  let has_get_sbp_fn = 1;
  let has_data_type_infer_fn = 1;
}

#endif // GET_ONEFLOW_FUSED_OP_DEFINITIONS

