Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support inplace div add sub #7293

Merged
merged 38 commits into from
Jan 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
0aee15e
fix arange bug
MARD1NO Nov 17, 2021
618c94b
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Nov 18, 2021
91def65
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Nov 19, 2021
b824bda
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Nov 22, 2021
b77ed0b
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Dec 1, 2021
62eb79a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Dec 1, 2021
1cb3f1a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Dec 21, 2021
454b51d
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Dec 30, 2021
cbb8614
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Jan 6, 2022
cd2597a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Jan 11, 2022
de4b8ee
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Jan 12, 2022
64c29e6
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Jan 14, 2022
b19dcc8
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
MARD1NO Jan 14, 2022
1240e80
init
MARD1NO Jan 19, 2022
b02b52e
add inplace div sub
MARD1NO Jan 19, 2022
ee5200e
remove annotation
MARD1NO Jan 19, 2022
b4df66e
fix docs and add unittest
MARD1NO Jan 19, 2022
843964d
Merge branch 'master' into support_inplace_div_add_sub
MARD1NO Jan 19, 2022
082b186
fix comment
MARD1NO Jan 19, 2022
876a8c6
Merge branch 'support_inplace_div_add_sub' of github.com:Oneflow-Inc/…
MARD1NO Jan 19, 2022
74ed8c8
use._C
MARD1NO Jan 19, 2022
6c4456e
fix docs
MARD1NO Jan 19, 2022
3022eb0
Merge branch 'master' into support_inplace_div_add_sub
MARD1NO Jan 19, 2022
3f62d44
Merge branch 'master' into support_inplace_div_add_sub
oneflow-ci-bot Jan 19, 2022
c290e58
Merge branch 'master' into support_inplace_div_add_sub
oneflow-ci-bot Jan 19, 2022
0624285
Merge branch 'master' into support_inplace_div_add_sub
MARD1NO Jan 20, 2022
87beba0
Merge branch 'master' into support_inplace_div_add_sub
oneflow-ci-bot Jan 20, 2022
7ee3c55
add type promotion logic for inplaceable binary func
MARD1NO Jan 20, 2022
446f0a9
Merge branch 'master' into support_inplace_div_add_sub
MARD1NO Jan 24, 2022
5e5953f
Merge branch 'master' into support_inplace_div_add_sub
oneflow-ci-bot Jan 24, 2022
47c105a
Merge branch 'master' into support_inplace_div_add_sub
MARD1NO Jan 24, 2022
ad9eb79
Merge branch 'master' into support_inplace_div_add_sub
oneflow-ci-bot Jan 24, 2022
642bcca
Merge branch 'master' into support_inplace_div_add_sub
oneflow-ci-bot Jan 24, 2022
1e80a25
Merge branch 'master' into support_inplace_div_add_sub
oneflow-ci-bot Jan 24, 2022
2524b1b
Merge branch 'master' into support_inplace_div_add_sub
oneflow-ci-bot Jan 24, 2022
6831201
Merge branch 'master' into support_inplace_div_add_sub
MARD1NO Jan 25, 2022
4ac1565
Merge branch 'master' into support_inplace_div_add_sub
oneflow-ci-bot Jan 25, 2022
6d6ada9
Merge branch 'master' into support_inplace_div_add_sub
oneflow-ci-bot Jan 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/tensor.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ OneFlow Tensor Class
diagonal,
dim,
div,
div_,
double,
dtype,
element_size,
Expand Down Expand Up @@ -100,6 +101,7 @@ OneFlow Tensor Class
min,
mish,
mul,
mul_,
narrow,
ndim,
ndimension,
Expand Down Expand Up @@ -149,6 +151,7 @@ OneFlow Tensor Class
stride,
swapaxes,
sub,
sub_,
tan,
tanh,
tile,
Expand Down
7 changes: 4 additions & 3 deletions oneflow/core/autograd/gradient_funcs/variance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,10 @@ Maybe<void> Variance::Apply(const VarianceState* ctx, const TensorTuple& out_gra
in_grads->resize(1);
in_grads->at(0) = JUST(functional::Mul(
out_grad,
JUST(functional::ScalarMul(Scalar(2.0 / (elem_cnt - correction)),
JUST(functional::Sub(x, JUST(functional::ReduceMean(
x, ctx->axis, /*keepdim=*/true))))))));
JUST(functional::ScalarMul(
Scalar(2.0 / (elem_cnt - correction)),
JUST(functional::Sub(x, JUST(functional::ReduceMean(x, ctx->axis, /*keepdim=*/true)),
/*inplace=*/false))))));

return Maybe<void>::Ok();
}
Expand Down
10 changes: 9 additions & 1 deletion oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
- name: "sub"
signature:
[
"Tensor (Tensor input, Tensor other) => Sub",
"Tensor (Tensor input, Tensor other, *, Bool inplace=False) => Sub",
"Tensor (Tensor input, Scalar other, *, Bool inplace=False) => ScalarSub",
"Tensor (Scalar input, Tensor other) => ScalarSub",
]
Expand Down Expand Up @@ -65,6 +65,14 @@
]
bind_python: true

- name: "div_"
signature:
[
"Tensor (Tensor input, Tensor other) => InplaceDiv",
"Tensor (Tensor input, Scalar other) => InplaceScalarDiv",
]
bind_python: true

- name: "div_grad"
signature: "Tensor (Tensor y, Tensor z, Tensor dz) => DivGrad"
bind_python: False
Expand Down
34 changes: 33 additions & 1 deletion oneflow/core/functional/impl/binary_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class BroadcastPowFunctor : public BinaryFloatFunctor {
}
};

class SubFunctor : public BinaryFunctor {
class SubFunctor : public InplaceableBinaryFunctor {
public:
SubFunctor() {
op_ = CHECK_JUST(one::OpBuilder("broadcast_sub").Input("x").Input("y").Output("z").Build());
Expand Down Expand Up @@ -185,6 +185,37 @@ class DivFunctor : public BinaryFloatFunctor {
}
};

class InplaceDivFunctor {
public:
InplaceDivFunctor() {
broadcast_div_op_ =
CHECK_JUST(one::OpBuilder("broadcast_div").Input("x").Input("y").Output("z").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
const std::shared_ptr<one::Tensor>& y) const {
TensorProcessor tensor_processor;
if (y->requires_grad()) {
JUST(tensor_processor.PromoteInputsToCommonDtype(true)
.AddInputs({JUST(Identity(x)), y})
.Apply());
} else {
JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x, y}).Apply());
}
const TensorTuple& input_vec = JUST(tensor_processor.GetInputs());
const std::shared_ptr<one::Tensor>& x_cast = input_vec.at(0);
const std::shared_ptr<one::Tensor>& y_cast = input_vec.at(1);
JUST(CheckInplaceValid(x));
JUST(CheckInplaceCastValid(x, x_cast));
JUST(CheckShapeCanExpandTo(*y_cast->shape(), *x_cast->shape()));
std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
outputs->at(0) = x;
JUST(OpInterpUtil::Dispatch(*broadcast_div_op_, input_vec, outputs.get()));
return outputs->at(0);
}

private:
std::shared_ptr<OpExpr> broadcast_div_op_;
};
class Atan2Functor : public BinaryFloatFunctor {
public:
Atan2Functor() {
Expand Down Expand Up @@ -339,6 +370,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::SubFunctor>("Sub");
m.add_functor<impl::MulFunctor>("Mul");
m.add_functor<impl::InplaceMulFunctor>("InplaceMul");
m.add_functor<impl::InplaceDivFunctor>("InplaceDiv");
m.add_functor<impl::DivFunctor>("Div");
m.add_functor<impl::PowFunctor>("Pow");
m.add_functor<impl::BroadcastPowFunctor>("BroadcastPow");
Expand Down
15 changes: 10 additions & 5 deletions oneflow/core/functional/impl/binary_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,20 @@ class InplaceableBinaryFunctor {
public:
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
const std::shared_ptr<one::Tensor>& y, bool inplace) const {
TensorProcessor tensor_processor;
JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x, y}).Apply());
TensorTuple input_tuple = JUST(tensor_processor.GetInputs());
if (inplace) {
JUST(CheckInplaceValid(x));
JUST(CheckShapeCanExpandTo(*y->shape(), *x->shape()));
std::shared_ptr<one::Tensor>& x_cast = input_tuple.at(0);
std::shared_ptr<one::Tensor>& y_cast = input_tuple.at(1);
JUST(CheckInplaceCastValid(x, x_cast));
JUST(CheckShapeCanExpandTo(*y_cast->shape(), *x_cast->shape()));
std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
outputs->at(0) = x;
JUST(OpInterpUtil::Dispatch(*op_, {x, y}, outputs.get()));
outputs->at(0) = x_cast;
JUST(OpInterpUtil::Dispatch(*op_, input_tuple, outputs.get()));
return outputs->at(0);
} else {
return OpInterpUtil::Dispatch<Tensor>(*op_, {x, y});
return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple);
}
}

Expand Down
13 changes: 11 additions & 2 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ class ScalarDiv2Functor {
}
};

class InplaceScalarDivFunctor : public ScalarMathBaseFunctor {
public:
InplaceScalarDivFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_mul") {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& scalar) const {
return ScalarMathBaseFunctor::operator()(x, Scalar(1.0) / scalar, true);
}
};

class ScalarPowFunctor : public ScalarMathBaseFunctor {
public:
ScalarPowFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_pow") {}
Expand Down Expand Up @@ -1619,7 +1627,7 @@ class StandardDeviationFunctor {
Scalar((double)reduce_count)));
const auto& square = JUST(functional::Square(JUST(functional::ScalarDiv(
JUST(functional::ReduceSum(input, axis, keepdims)), Scalar((double)reduce_count)))));
const auto& sub = JUST(functional::Sub(sum, square));
const auto& sub = JUST(functional::Sub(sum, square, /*inplace=*/false));
if (unbias) {
return functional::Sqrt(JUST(functional::ScalarMul(
sub, Scalar((double)reduce_count / (double)(reduce_count - 1)), false)));
Expand Down Expand Up @@ -1652,7 +1660,7 @@ class StandardDeviationFunctor {
const auto& square = JUST(functional::Square(
JUST(functional::ScalarDiv(JUST(functional::ReduceSum(double_input, axis, keepdims)),
Scalar((double)reduce_count)))));
const auto& sub = JUST(functional::Sub(sum, square));
const auto& sub = JUST(functional::Sub(sum, square, /*inplace=*/false));
if (unbias) {
return functional::Cast(
JUST(functional::Sqrt(JUST(functional::ScalarMul(
Expand Down Expand Up @@ -1898,6 +1906,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<ScalarMulFunctor, ScalarMul2Functor>("ScalarMul");
m.add_functor<InplaceScalarMulFunctor>("InplaceScalarMul");
m.add_functor<ScalarDivFunctor, ScalarDiv2Functor>("ScalarDiv");
m.add_functor<InplaceScalarDivFunctor>("InplaceScalarDiv");
m.add_functor<ScalarPowFunctor>("ScalarPow");
m.add_functor<ScalarPowGradFunctor>("ScalarPowGrad");
m.add_functor<ReduceMaxFunctor>("ReduceMax");
Expand Down
34 changes: 20 additions & 14 deletions oneflow/core/functional/impl/nn_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,9 @@ class MseLossFunctor : public LossFunctorBase {
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& target,
const std::string& reduction) const {
const auto out =
sequence_function(functional::Sub).then(functional::Square).call(input, target);
const auto out = sequence_function(functional::Sub)
.then(functional::Square)
.call(input, target, /*inplace=*/false);
return apply_reduction(out, reduction);
}
};
Expand All @@ -494,7 +495,9 @@ class L1LossFunctor : public LossFunctorBase {
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& target,
const std::string& reduction) const {
const auto out = sequence_function(functional::Sub).then(functional::Abs).call(input, target);
const auto out = sequence_function(functional::Sub)
.then(functional::Abs)
.call(input, target, /*inplace=*/false);
return apply_reduction(out, reduction);
}
};
Expand Down Expand Up @@ -549,7 +552,7 @@ class MarginRankingLossFunctor : public LossFunctorBase {
return functional::ScalarAdd(x, Scalar(margin), /*alpha=*/1, /*inplace=*/true);
})
.then(std::bind(functional::Clamp, std::placeholders::_1, Scalar(0), NullOpt))
.call(input_1, input_2);
.call(input_1, input_2, /*inplace=*/false);
return apply_reduction(out, reduction);
}
};
Expand Down Expand Up @@ -1146,19 +1149,22 @@ class TripletMarginLossFunctor {
if ((reduction != "none") && (reduction != "sum") && (reduction != "mean")) return false;
return true;
}());
auto da_p = JUST(VectorNorm(JUST(ScalarAdd(eps, JUST(Sub(anchor, positive)), /*alpha=*/1)), p,
dim, /*keepdim=*/false, anchor->dtype()));
auto da_n = JUST(VectorNorm(JUST(ScalarAdd(eps, JUST(Sub(anchor, negative)), /*alpha=*/1)), p,
dim, /*keepdim=*/false, anchor->dtype()));
auto da_p = JUST(VectorNorm(
JUST(ScalarAdd(eps, JUST(Sub(anchor, positive, /*inplace=*/false)), /*alpha=*/1)), p, dim,
/*keepdim=*/false, anchor->dtype()));
auto da_n = JUST(VectorNorm(
JUST(ScalarAdd(eps, JUST(Sub(anchor, negative, /*inplace=*/false)), /*alpha=*/1)), p, dim,
/*keepdim=*/false, anchor->dtype()));
if (swap) {
auto distance_swap =
JUST(VectorNorm(JUST(ScalarAdd(eps, JUST(Sub(positive, negative)), /*alpha=*/1)), p, dim,
/*keepdim=*/false, positive->dtype()));
auto distance_swap = JUST(VectorNorm(
JUST(ScalarAdd(eps, JUST(Sub(positive, negative, /*inplace=*/false)), /*alpha=*/1)), p,
dim,
/*keepdim=*/false, positive->dtype()));
da_n = JUST(Minimum(distance_swap, da_n));
}
auto triplet_loss =
JUST(Clamp(JUST(ScalarAdd(JUST(Sub(da_p, da_n)), margin, /*alpha=*/1, /*inplace=*/false)),
/*min=*/0.0, NullOpt));
auto triplet_loss = JUST(Clamp(JUST(ScalarAdd(JUST(Sub(da_p, da_n, /*inplace=*/false)), margin,
/*alpha=*/1, /*inplace=*/false)),
/*min=*/0.0, NullOpt));
int32_t ndim = triplet_loss->ndim() - 1;
std::vector<int32_t> axis(1, ndim);

Expand Down
2 changes: 1 addition & 1 deletion python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def is_deprecated(func_or_class):
from oneflow._C import diag
from oneflow._C import log1p
from oneflow._C import add
from oneflow._C import div
from oneflow._C import div, div_
from oneflow._C import floor, floor_
from oneflow._C import floor_divide
from oneflow._C import mul
Expand Down
21 changes: 18 additions & 3 deletions python/oneflow/framework/docstr/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,15 +533,30 @@

add_docstr(
oneflow.Tensor.mul,
"""
"""Tensor.mul(value) -> Tensor
See :func:`oneflow.mul`
""",
)

add_docstr(
oneflow.Tensor.mul_,
"""
In-place version of :func`oneflow.Tensor.mul`.
"""Tensor.mul_(value) -> Tensor

In-place version of :func:`oneflow.Tensor.mul`.
""",
)

add_docstr(
oneflow.Tensor.div_,
"""Tensor.div_(value) -> Tensor
In-place version of :func:`oneflow.Tensor.div`.
""",
)

add_docstr(
oneflow.Tensor.sub_,
"""Tensor.sub_(value) -> Tensor
In-place version of :func:`oneflow.Tensor.sub`.
""",
)

Expand Down
22 changes: 16 additions & 6 deletions python/oneflow/framework/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _le(self, other):


def _mul(self, other):
return flow.mul(self, other)
return flow._C.mul(self, other)


def _mul_(self, other):
Expand All @@ -231,11 +231,11 @@ def _rmul(self, other):


def _add(self, other):
return flow.add(self, other)
return flow._C.add(self, other)


def _add_inplace(self, other):
return flow.add(self, other, inplace=True)
return flow._C.add(self, other, inplace=True)


def _iadd(self, other):
Expand All @@ -247,15 +247,23 @@ def _radd(self, other):


def _sub(self, other):
return flow.sub(self, other)
return flow._C.sub(self, other)


def _sub_inplace(self, other):
return flow._C.sub(self, other, inplace=True)


def _rsub(self, other):
return flow.sub(other, self)
return flow._C.sub(other, self)


def _truediv(self, other):
return flow.div(self, other)
return flow._C.div(self, other)


def _truediv_inplace(self, other):
return flow._C.div_(self, other)


def _rtruediv(self, other):
Expand Down Expand Up @@ -895,10 +903,12 @@ def RegisterMethods():
Tensor.add = _add
Tensor.add_ = _add_inplace
Tensor.div = _truediv
Tensor.div_ = _truediv_inplace
Tensor.mul = _mul
Tensor.mul_ = _mul_
Tensor.reciprocal = _reciprocal
Tensor.sub = _sub
Tensor.sub_ = _sub_inplace
Tensor.asin = _asin
Tensor.arcsin = _arcsin
Tensor.asinh = _asinh
Expand Down
2 changes: 1 addition & 1 deletion python/oneflow/test/modules/test_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def test_mul(test_case):
arg[0](test_case, *arg[1:])

@autotest(check_graph=False)
def test_boardcast_mul(test_case):
def test_broadcast_mul(test_case):
device = random_device()
x_0 = random_pytorch_tensor(ndim=3, dim0=4, dim1=2, dim2=3).to(device)
y = random_pytorch_tensor(ndim=2, dim0=2, dim1=3).to(device)
Expand Down
Loading