Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stack kernel #7152

Merged
merged 28 commits into from
Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
0aee15e
fix arange bug
MARD1NO Nov 17, 2021
618c94b
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Nov 18, 2021
91def65
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Nov 19, 2021
b824bda
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Nov 22, 2021
b77ed0b
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Dec 1, 2021
62eb79a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Dec 1, 2021
1cb3f1a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Dec 21, 2021
454b51d
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Dec 30, 2021
cdb09aa
build init kernel
MARD1NO Dec 30, 2021
7eb558a
add stack backward
MARD1NO Dec 31, 2021
76523bd
remove annotation
MARD1NO Dec 31, 2021
c91c27f
reformat and fix sbp
MARD1NO Dec 31, 2021
02b38b0
fix ops td format
MARD1NO Dec 31, 2021
c01c50e
fix format
MARD1NO Dec 31, 2021
12318ee
fix comment
MARD1NO Jan 4, 2022
b996b31
add more test case in dim
MARD1NO Jan 4, 2022
2a70ac8
fiux user ops td
MARD1NO Jan 4, 2022
2b91499
fix to use size_t
MARD1NO Jan 4, 2022
07c2d03
fix annotation
MARD1NO Jan 4, 2022
7daca0d
fix less than
MARD1NO Jan 4, 2022
019ed5e
fix userop tabelgen
MARD1NO Jan 4, 2022
d08c367
Merge branch 'master' into add_stack_kernel
MARD1NO Jan 5, 2022
7be1596
Merge branch 'master' into add_stack_kernel
oneflow-ci-bot Jan 5, 2022
5f2ddc3
fix bug when num of inputs greater than 128
MARD1NO Jan 5, 2022
01306c5
Merge branch 'add_stack_kernel' of github.com:Oneflow-Inc/oneflow int…
MARD1NO Jan 5, 2022
61fa6d9
Merge branch 'master' into add_stack_kernel
oneflow-ci-bot Jan 5, 2022
2c9a469
Merge branch 'master' into add_stack_kernel
oneflow-ci-bot Jan 5, 2022
6f009aa
Merge branch 'master' into add_stack_kernel
oneflow-ci-bot Jan 5, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions oneflow/core/autograd/gradient_funcs/stack.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

// State captured in the forward pass of "stack" and consumed by the backward
// pass: per-input requires_grad flags, the stack axis, and the input count.
struct StackCaptureState : public AutoGradCaptureState {
// One flag per stacked input; Apply() only fills gradients for flagged slots.
std::vector<bool> requires_grad;
// Axis the inputs were stacked along (value read from the op's "axis" attr).
int64_t axis = 1;
// Number of tensors that were stacked (overwritten in Capture()).
int64_t input_num = 2;
};

// Gradient function for the "stack" user op: splits the output gradient back
// into one gradient per stacked input via functional::StackGrad.
class Stack : public OpExprGradFunction<StackCaptureState> {
public:
// Caches the forward op's attributes (see base_attrs_).
Maybe<void> Init(const OpExpr& op) override;
// Records requires_grad flags, the stack axis, and the inputs as "like"
// tensors for the backward pass.
Maybe<void> Capture(StackCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
// Produces in_grads from the single out_grad.
Maybe<void> Apply(const StackCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
// Default attributes of the forward op, composed with runtime attrs in Capture().
AttrMap base_attrs_;
};

// Caches the forward op's default attribute map so Capture() can compose it
// with the per-call attributes.
Maybe<void> Stack::Init(const OpExpr& op) {
  const auto* user_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(user_op_expr);
  base_attrs_ = MakeAttrMapFromUserOpConf(user_op_expr->proto());
  return Maybe<void>::Ok();
}

// Records everything Apply() needs: which inputs want gradients, the stack
// axis, and the inputs themselves (saved as "like" tensors so StackGrad can
// recover each input's shape).
Maybe<void> Stack::Capture(StackCaptureState* ctx, const TensorTuple& inputs,
                           const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad.resize(inputs.size());
  // size_t index: avoids the signed/unsigned comparison with inputs.size().
  for (size_t i = 0; i < inputs.size(); ++i) {
    ctx->requires_grad[i] = inputs.at(i)->requires_grad();
  }

  ComposedAttrMap composed_attrs(attrs, base_attrs_);
  ctx->axis = JUST(composed_attrs.GetAttr<int64_t>("axis"));
  // Save every input (not only grad-requiring ones) so SavedTensors() indices
  // line up one-to-one with input positions in Apply().
  for (const auto& input : inputs) { ctx->SaveTensorForBackward(input); }
  ctx->input_num = inputs.size();
  return Maybe<void>::Ok();
}

// Splits the single output gradient into per-input gradients. Only slots whose
// input had requires_grad set are filled; the rest stay empty.
Maybe<void> Stack::Apply(const StackCaptureState* ctx, const TensorTuple& out_grads,
                         TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 1);
  in_grads->resize(ctx->input_num);
  if (ctx->input_num == 1) {
    // Single stacked input: the gradient passes through unchanged — but, as in
    // the multi-input branch below, only when a gradient was actually requested.
    if (ctx->requires_grad.at(0)) { in_grads->at(0) = out_grads.at(0); }
    return Maybe<void>::Ok();
  }
  // Rebuild the "like" tuple from the tensors saved in Capture(); StackGrad
  // uses their shapes to size each output slice. (Built only on this path —
  // the single-input case above never needs it.)
  TensorTuple like(ctx->input_num);
  for (int i = 0; i < ctx->input_num; ++i) { like[i] = ctx->SavedTensors().at(i); }
  const auto& results = JUST(functional::StackGrad(out_grads.at(0), like, ctx->axis));
  CHECK_EQ_OR_RETURN(results->size(), ctx->input_num);
  for (int i = 0; i < ctx->input_num; ++i) {
    if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); }
  }
  return Maybe<void>::Ok();
}

// Registers Stack as the backward function for the "stack" user op.
REGISTER_OP_EXPR_GRAD_FUNCTION("stack", Stack);

} // namespace one
} // namespace oneflow
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,10 @@
signature: "Tensor (TensorTuple inputs, Int64 dim=0) => Stack"
bind_python: True

- name: "stack_grad"
signature: "TensorTuple (Tensor x, TensorTuple like, Int64 axis) => StackGrad"
bind_python: False

- name: "local_to_consistent"
signature: "Tensor (Tensor x, Placement placement, SbpList sbp, Shape shape, DataType dtype) => LocalToConsistent"
bind_python: False
Expand Down
81 changes: 66 additions & 15 deletions oneflow/core/functional/impl/array_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,26 +483,76 @@ class ConcatFunctor {

class StackFunctor {
public:
StackFunctor() = default;
StackFunctor() {
ops_.resize(kMaxInputCount);
for (int n = 0; n < ops_.size(); ++n) {
ops_[n] = CHECK_JUST(one::OpBuilder("stack").Input("in", n + 1).Output("out").Build());
}
}
Maybe<Tensor> operator()(const TensorTuple& inputs, const int64_t& dim) const {
CHECK_GE_OR_RETURN(inputs.size(), 1) << "Needs one input at least.";
int64_t ndims = inputs.at(0)->shape()->NumAxes();
const int64_t ninput = inputs.size();
int64_t ndims = inputs[0]->ndim();
int64_t stack_dim = dim;
for (int i = 1; i < inputs.size(); ++i) {
CHECK_EQ_OR_RETURN(inputs.at(i)->shape()->NumAxes(), ndims)
<< "The input dimensions are not equal.";
}
CHECK_OR_RETURN(dim >= -(ndims + 1) && dim <= ndims)
<< "( Dimension out of range, expected to be in range of [" << -(ndims + 1) << ", " << ndims
<< "], but got " << dim << " )";
if (dim < 0) { stack_dim = stack_dim + ndims + 1; }
TensorTuple expand_inputs(inputs.size());
if (inputs.size() == 1) { return ExpandDims(inputs.at(0), stack_dim); }
for (int i = 0; i < inputs.size(); ++i) {
expand_inputs[i] = JUST(ExpandDims(inputs.at(i), stack_dim));
CHECK_OR_RETURN(stack_dim >= 0 && stack_dim <= ndims)
<< "Index Error: Dimension out of range (expected in range of [" << -ndims - 1 << ", "
<< ndims << "], but got " << stack_dim;
const std::shared_ptr<const Shape>& first_in_shape = inputs[0]->shape();
for (const auto& input : inputs) {
for (int i = 0; i < ndims; ++i) {
CHECK_OR_RETURN(input->shape()->At(i) == first_in_shape->At(i))
<< " Stacks expects each tensor to be equal size"
", but got "
<< first_in_shape->ToString() << " at first input and " << input->shape()->ToString()
<< " which index is " << i;
}
}
int64_t max_dim_size = ninput;
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("axis", stack_dim));
JUST(attrs.SetAttr<int64_t>("max_dim_size", max_dim_size));
TensorTuple outputs;
for (int i = 0; i < ninput; i += kMaxInputCount) {
size_t size = (i + kMaxInputCount) < ninput ? kMaxInputCount : ninput - i;
TensorTuple partial_inputs(size);
for (int j = 0; j < size; ++j) { partial_inputs[j] = inputs[i + j]; }
outputs.emplace_back(
JUST(OpInterpUtil::Dispatch<Tensor>(*ops_.at(size - 1), partial_inputs, attrs)));
}
if (outputs.size() == 1) { return outputs.at(0); }
return this->operator()(outputs, stack_dim);
}

private:
std::vector<std::shared_ptr<OpExpr>> ops_;
};

class StackGradFunctor {
public:
StackGradFunctor() {
ops_.resize(kMaxInputCount);
for (int n = 1; n < ops_.size(); ++n) {
ops_[n] = CHECK_JUST(one::OpBuilder("stack_grad")
.Input("in")
.Input("like", n + 1)
.Output("out", n + 1)
.Build());
}
return Concat(expand_inputs, stack_dim);
}
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x, const TensorTuple& like,
const int64_t& axis) const {
CHECK_GE_OR_RETURN(like.size(), 2);
CHECK_LE_OR_RETURN(like.size(), kMaxInputCount);
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("axis", axis));
TensorTuple inputs(like.size() + 1);
inputs[0] = x;
for (int i = 0; i < like.size(); ++i) { inputs[i + 1] = like[i]; }
return OpInterpUtil::Dispatch<TensorTuple>(*ops_.at(like.size() - 1), inputs, attrs);
}

private:
std::vector<std::shared_ptr<OpExpr>> ops_;
};

class ExpandFunctor {
Expand Down Expand Up @@ -2493,6 +2543,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::BroadcastLikeFunctor>("BroadcastLike");
m.add_functor<impl::ConcatFunctor>("Concat");
m.add_functor<impl::StackFunctor>("Stack");
m.add_functor<impl::StackGradFunctor>("StackGrad");
m.add_functor<impl::ExpandFunctor>("Expand");
m.add_functor<impl::ExpandGradFunctor>("ExpandGrad");
m.add_functor<impl::ExpandDimsFunctor>("ExpandDims");
Expand Down
41 changes: 39 additions & 2 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4127,8 +4127,8 @@ def OneFlow_MatmulOp : OneFlow_BaseOp<"matmul", [NoSideEffect, DeclareOpInterfac
#endif // GET_ONEFLOW_MATMUL_OP_DEFINITIONS

// Group: MISC
// CategoricalOrdinalEncode, add_n, arange, coin_flip, concat, constant, dropout, elementwise_maximum_backward, elementwise_minimum_backward, empty, eye, grid_sample_grad, multi_count_not_finite, multi_square_sum, nll, nll_grad, pow_x_grad, pow_y_grad, prelu_grad, randperm, recv, send, split_like, ssp_variable_proxy, tf_prelu_grad, uniform, uniform_int, unique_with_counts, xdivy_x_grad, xdivy_y_grad
// Total: 30
// CategoricalOrdinalEncode, add_n, arange, coin_flip, concat, constant, dropout, elementwise_maximum_backward, elementwise_minimum_backward, empty, eye, grid_sample_grad, multi_count_not_finite, multi_square_sum, nll, nll_grad, pow_x_grad, pow_y_grad, prelu_grad, randperm, recv, send, split_like, ssp_variable_proxy, tf_prelu_grad, uniform, uniform_int, unique_with_counts, xdivy_x_grad, xdivy_y_grad, stack, stack_grad
// Total: 32

#ifdef GET_ONEFLOW_MISC_OP_DEFINITIONS

Expand Down Expand Up @@ -4659,6 +4659,43 @@ def OneFlow_XdivyYGradOp : OneFlow_BaseOp<"xdivy_y_grad", [NoSideEffect, Declare
let has_data_type_infer_fn = 1;
}

// "stack": joins a variadic tuple of tensors `in` into a single tensor `out`
// along a new dimension `axis`. `max_dim_size` bounds the total extent along
// that axis (used when the op is dispatched in chunks).
def OneFlow_StackOp : OneFlow_BaseOp<"stack", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
Variadic<OneFlow_Tensor>:$in
);
let output = (outs
OneFlow_Tensor:$out
);
let attrs = (ins
DefaultValuedAttr<SI64Attr, "0">:$axis,
DefaultValuedAttr<SI64Attr, "0">:$max_dim_size
);
// Generated C++ must supply these inference/check hooks.
let has_check_fn = 1;
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

// "stack_grad": splits `in` (the stacked tensor's gradient) along `axis` into
// one output per `like` tensor, each matching that like tensor's shape.
// NoGrad: this op itself has no higher-order gradient.
def OneFlow_StackGradOp : OneFlow_BaseOp<"stack_grad", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$in,
Variadic<OneFlow_Tensor>:$like
);
let output = (outs
Variadic<OneFlow_Tensor>:$out
);
let attrs = (ins
DefaultValuedAttr<SI64Attr, "0">:$axis
);
// Generated C++ must supply these inference/check hooks; input_arg_modify
// lets the op mark the "like" inputs appropriately.
let has_check_fn = 1;
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
let has_input_arg_modify_fn = 1;
}

#endif // GET_ONEFLOW_MISC_OP_DEFINITIONS

// Group: NCCL
Expand Down
Loading