Support inplace div add sub (#7293)
* fix arange bug

* init

* add inplace div sub

* remove annotation

* fix docs and add unittest

* fix comment

* use flow._C

* fix docs

* add type promotion logic for inplaceable binary func

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
MARD1NO and oneflow-ci-bot authored Jan 25, 2022
1 parent 9968dc5 commit 40d0107
Showing 12 changed files with 198 additions and 37 deletions.
3 changes: 3 additions & 0 deletions docs/source/tensor.rst
@@ -47,6 +47,7 @@ OneFlow Tensor Class
     diagonal,
     dim,
     div,
+    div_,
     double,
     dtype,
     element_size,
@@ -100,6 +101,7 @@ OneFlow Tensor Class
     min,
     mish,
     mul,
+    mul_,
     narrow,
     ndim,
     ndimension,
@@ -149,6 +151,7 @@ OneFlow Tensor Class
     stride,
     swapaxes,
     sub,
+    sub_,
     tan,
     tanh,
     tile,
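The three entries added above document the new in-place counterparts of div, mul, and sub. A minimal usage sketch, assuming a build that includes this commit (illustration only, not part of the diff):

import oneflow as flow

x = flow.tensor([4.0, 6.0, 8.0])
y = flow.tensor([2.0, 2.0, 2.0])

x.div_(y)   # writes the quotient into x's storage: x is now [2., 3., 4.]
x.sub_(1)   # scalar operand: x becomes [1., 2., 3.]
x.mul_(10)  # x becomes [10., 20., 30.]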
7 changes: 4 additions & 3 deletions oneflow/core/autograd/gradient_funcs/variance.cpp
@@ -89,9 +89,10 @@ Maybe<void> Variance::Apply(const VarianceState* ctx, const TensorTuple& out_gra
   in_grads->resize(1);
   in_grads->at(0) = JUST(functional::Mul(
       out_grad,
-      JUST(functional::ScalarMul(Scalar(2.0 / (elem_cnt - correction)),
-                                 JUST(functional::Sub(x, JUST(functional::ReduceMean(
-                                                             x, ctx->axis, /*keepdim=*/true))))))));
+      JUST(functional::ScalarMul(
+          Scalar(2.0 / (elem_cnt - correction)),
+          JUST(functional::Sub(x, JUST(functional::ReduceMean(x, ctx->axis, /*keepdim=*/true)),
+                               /*inplace=*/false))))));
 
   return Maybe<void>::Ok();
 }
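This hunk only threads the now-explicit inplace=false argument through functional::Sub; the gradient itself is unchanged. For reference, with correction c over N = elem_cnt elements,

    \mathrm{Var}(x) = \frac{1}{N - c} \sum_{i=1}^{N} (x_i - \bar{x})^2,
    \qquad
    \frac{\partial\,\mathrm{Var}}{\partial x_i} = \frac{2\,(x_i - \bar{x})}{N - c},

where the \bar{x} cross terms cancel because \sum_j (x_j - \bar{x}) = 0. The backward pass is therefore out_grad * 2/(elem_cnt - correction) * (x - mean(x)), exactly the Mul/ScalarMul/Sub chain above.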
10 changes: 9 additions & 1 deletion oneflow/core/functional/functional_api.yaml
@@ -33,7 +33,7 @@
 - name: "sub"
   signature:
     [
-      "Tensor (Tensor input, Tensor other) => Sub",
+      "Tensor (Tensor input, Tensor other, *, Bool inplace=False) => Sub",
       "Tensor (Tensor input, Scalar other, *, Bool inplace=False) => ScalarSub",
       "Tensor (Scalar input, Tensor other) => ScalarSub",
     ]
@@ -65,6 +65,14 @@
     ]
   bind_python: true
 
+- name: "div_"
+  signature:
+    [
+      "Tensor (Tensor input, Tensor other) => InplaceDiv",
+      "Tensor (Tensor input, Scalar other) => InplaceScalarDiv",
+    ]
+  bind_python: true
+
 - name: "div_grad"
   signature: "Tensor (Tensor y, Tensor z, Tensor dz) => DivGrad"
   bind_python: False
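Entries with bind_python: true are exposed under oneflow._C, with overloads resolved by operand type. A rough sketch of how the new signatures surface in Python, assuming the generated bindings follow the name: fields (illustration only):

import oneflow as flow

x = flow.tensor([[1.0, 2.0], [3.0, 4.0]])
y = flow.tensor([[2.0, 2.0], [2.0, 2.0]])

flow._C.sub(x, y, inplace=True)  # keyword-only flag from "*, Bool inplace=False"
flow._C.div_(x, 2.0)             # Scalar operand resolves to InplaceScalarDiv
flow._C.div_(x, y)               # Tensor operand resolves to InplaceDiv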
34 changes: 33 additions & 1 deletion oneflow/core/functional/impl/binary_functor.cpp
@@ -111,7 +111,7 @@ class BroadcastPowFunctor : public BinaryFloatFunctor {
   }
 };
 
-class SubFunctor : public BinaryFunctor {
+class SubFunctor : public InplaceableBinaryFunctor {
  public:
   SubFunctor() {
     op_ = CHECK_JUST(one::OpBuilder("broadcast_sub").Input("x").Input("y").Output("z").Build());
@@ -185,6 +185,37 @@ class DivFunctor : public BinaryFloatFunctor {
   }
 };
 
+class InplaceDivFunctor {
+ public:
+  InplaceDivFunctor() {
+    broadcast_div_op_ =
+        CHECK_JUST(one::OpBuilder("broadcast_div").Input("x").Input("y").Output("z").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const std::shared_ptr<one::Tensor>& y) const {
+    TensorProcessor tensor_processor;
+    if (y->requires_grad()) {
+      JUST(tensor_processor.PromoteInputsToCommonDtype(true)
+               .AddInputs({JUST(Identity(x)), y})
+               .Apply());
+    } else {
+      JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x, y}).Apply());
+    }
+    const TensorTuple& input_vec = JUST(tensor_processor.GetInputs());
+    const std::shared_ptr<one::Tensor>& x_cast = input_vec.at(0);
+    const std::shared_ptr<one::Tensor>& y_cast = input_vec.at(1);
+    JUST(CheckInplaceValid(x));
+    JUST(CheckInplaceCastValid(x, x_cast));
+    JUST(CheckShapeCanExpandTo(*y_cast->shape(), *x_cast->shape()));
+    std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
+    outputs->at(0) = x;
+    JUST(OpInterpUtil::Dispatch(*broadcast_div_op_, input_vec, outputs.get()));
+    return outputs->at(0);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> broadcast_div_op_;
+};
+
 class Atan2Functor : public BinaryFloatFunctor {
  public:
   Atan2Functor() {
@@ -339,6 +370,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::SubFunctor>("Sub");
   m.add_functor<impl::MulFunctor>("Mul");
   m.add_functor<impl::InplaceMulFunctor>("InplaceMul");
+  m.add_functor<impl::InplaceDivFunctor>("InplaceDiv");
   m.add_functor<impl::DivFunctor>("Div");
   m.add_functor<impl::PowFunctor>("Pow");
   m.add_functor<impl::BroadcastPowFunctor>("BroadcastPow");
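Note that CheckShapeCanExpandTo(*y_cast->shape(), *x_cast->shape()) makes broadcasting one-directional for the in-place form: the divisor may expand up to the target's shape, but the result must still fit in x's storage. The Identity(x) detour taken when y requires grad appears to keep a copy of the pre-division x in the autograd graph, so the in-place update does not clobber a value the backward pass needs. A hedged sketch of the shape rule (not part of the diff):

import oneflow as flow

x = flow.ones(4, 2, 3)
y = flow.tensor([[1.0, 2.0, 4.0], [1.0, 2.0, 4.0]])  # (2, 3) expands to (4, 2, 3)

x.div_(y)    # fine: the quotient has x's shape and is written into x
# y.div_(x)  # expected to fail: a (4, 2, 3) result cannot live in y's (2, 3) storage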
15 changes: 10 additions & 5 deletions oneflow/core/functional/impl/binary_functor.h
@@ -85,15 +85,20 @@ class InplaceableBinaryFunctor {
  public:
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
                            const std::shared_ptr<one::Tensor>& y, bool inplace) const {
+    TensorProcessor tensor_processor;
+    JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x, y}).Apply());
+    TensorTuple input_tuple = JUST(tensor_processor.GetInputs());
     if (inplace) {
       JUST(CheckInplaceValid(x));
-      JUST(CheckShapeCanExpandTo(*y->shape(), *x->shape()));
+      std::shared_ptr<one::Tensor>& x_cast = input_tuple.at(0);
+      std::shared_ptr<one::Tensor>& y_cast = input_tuple.at(1);
+      JUST(CheckInplaceCastValid(x, x_cast));
+      JUST(CheckShapeCanExpandTo(*y_cast->shape(), *x_cast->shape()));
       std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
-      outputs->at(0) = x;
-      JUST(OpInterpUtil::Dispatch(*op_, {x, y}, outputs.get()));
+      outputs->at(0) = x_cast;
+      JUST(OpInterpUtil::Dispatch(*op_, input_tuple, outputs.get()));
       return outputs->at(0);
     } else {
-      return OpInterpUtil::Dispatch<Tensor>(*op_, {x, y});
+      return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple);
     }
   }
 
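This is the "type promotion logic for inplaceable binary func" from the commit message: inputs are first promoted to a common dtype, and CheckInplaceCastValid then insists the promoted dtype still matches the in-place target's dtype. A sketch of the intended behavior, assuming promotion rules like PyTorch's (illustration only):

import oneflow as flow

x = flow.ones(3, dtype=flow.float32)
y = flow.ones(3, dtype=flow.int32)
x.sub_(y)    # ok: float32 and int32 promote to float32, already x's dtype

a = flow.ones(3, dtype=flow.int32)
b = flow.ones(3, dtype=flow.float32)
# a.sub_(b)  # expected to raise: the promoted dtype (float32) differs from a's
#            # int32, so CheckInplaceCastValid rejects it instead of truncating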
13 changes: 11 additions & 2 deletions oneflow/core/functional/impl/math_functor.cpp
@@ -220,6 +220,14 @@ class ScalarDiv2Functor {
   }
 };
 
+class InplaceScalarDivFunctor : public ScalarMathBaseFunctor {
+ public:
+  InplaceScalarDivFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_mul") {}
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar& scalar) const {
+    return ScalarMathBaseFunctor::operator()(x, Scalar(1.0) / scalar, true);
+  }
+};
+
 class ScalarPowFunctor : public ScalarMathBaseFunctor {
  public:
   ScalarPowFunctor() : ScalarMathBaseFunctor(/*op_name=*/"scalar_pow") {}
@@ -1619,7 +1627,7 @@ class StandardDeviationFunctor {
                                                        Scalar((double)reduce_count)));
     const auto& square = JUST(functional::Square(JUST(functional::ScalarDiv(
         JUST(functional::ReduceSum(input, axis, keepdims)), Scalar((double)reduce_count)))));
-    const auto& sub = JUST(functional::Sub(sum, square));
+    const auto& sub = JUST(functional::Sub(sum, square, /*inplace=*/false));
     if (unbias) {
       return functional::Sqrt(JUST(functional::ScalarMul(
           sub, Scalar((double)reduce_count / (double)(reduce_count - 1)), false)));
@@ -1652,7 +1660,7 @@ class StandardDeviationFunctor {
     const auto& square = JUST(functional::Square(
         JUST(functional::ScalarDiv(JUST(functional::ReduceSum(double_input, axis, keepdims)),
                                    Scalar((double)reduce_count)))));
-    const auto& sub = JUST(functional::Sub(sum, square));
+    const auto& sub = JUST(functional::Sub(sum, square, /*inplace=*/false));
     if (unbias) {
       return functional::Cast(
           JUST(functional::Sqrt(JUST(functional::ScalarMul(
@@ -1898,6 +1906,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<ScalarMulFunctor, ScalarMul2Functor>("ScalarMul");
   m.add_functor<InplaceScalarMulFunctor>("InplaceScalarMul");
   m.add_functor<ScalarDivFunctor, ScalarDiv2Functor>("ScalarDiv");
+  m.add_functor<InplaceScalarDivFunctor>("InplaceScalarDiv");
   m.add_functor<ScalarPowFunctor>("ScalarPow");
   m.add_functor<ScalarPowGradFunctor>("ScalarPowGrad");
   m.add_functor<ReduceMaxFunctor>("ReduceMax");
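InplaceScalarDivFunctor never touches a division kernel: it reuses scalar_mul with the reciprocal of the scalar, so only one elementwise kernel has to support the in-place path. Roughly (illustration only):

import oneflow as flow

x = flow.tensor([2.0, 4.0, 8.0])
x.div_(4)   # executed as x.mul_(1.0 / 4) on the scalar_mul kernel
print(x)    # [0.5, 1.0, 2.0]

One consequence worth noting: the scalar is inverted in floating point, so results for integer divisors follow float semantics rather than integer division.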
34 changes: 20 additions & 14 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -482,8 +482,9 @@ class MseLossFunctor : public LossFunctorBase {
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
                            const std::shared_ptr<one::Tensor>& target,
                            const std::string& reduction) const {
-    const auto out =
-        sequence_function(functional::Sub).then(functional::Square).call(input, target);
+    const auto out = sequence_function(functional::Sub)
+                         .then(functional::Square)
+                         .call(input, target, /*inplace=*/false);
     return apply_reduction(out, reduction);
   }
 };
@@ -494,7 +495,9 @@ class L1LossFunctor : public LossFunctorBase {
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
                            const std::shared_ptr<one::Tensor>& target,
                            const std::string& reduction) const {
-    const auto out = sequence_function(functional::Sub).then(functional::Abs).call(input, target);
+    const auto out = sequence_function(functional::Sub)
+                         .then(functional::Abs)
+                         .call(input, target, /*inplace=*/false);
     return apply_reduction(out, reduction);
   }
 };
@@ -549,7 +552,7 @@ class MarginRankingLossFunctor : public LossFunctorBase {
              return functional::ScalarAdd(x, Scalar(margin), /*alpha=*/1, /*inplace=*/true);
            })
            .then(std::bind(functional::Clamp, std::placeholders::_1, Scalar(0), NullOpt))
-           .call(input_1, input_2);
+           .call(input_1, input_2, /*inplace=*/false);
     return apply_reduction(out, reduction);
   }
 };
@@ -1146,19 +1149,22 @@ class TripletMarginLossFunctor {
     if ((reduction != "none") && (reduction != "sum") && (reduction != "mean")) return false;
     return true;
   }());
-  auto da_p = JUST(VectorNorm(JUST(ScalarAdd(eps, JUST(Sub(anchor, positive)), /*alpha=*/1)), p,
-                              dim, /*keepdim=*/false, anchor->dtype()));
-  auto da_n = JUST(VectorNorm(JUST(ScalarAdd(eps, JUST(Sub(anchor, negative)), /*alpha=*/1)), p,
-                              dim, /*keepdim=*/false, anchor->dtype()));
+  auto da_p = JUST(VectorNorm(
+      JUST(ScalarAdd(eps, JUST(Sub(anchor, positive, /*inplace=*/false)), /*alpha=*/1)), p, dim,
+      /*keepdim=*/false, anchor->dtype()));
+  auto da_n = JUST(VectorNorm(
+      JUST(ScalarAdd(eps, JUST(Sub(anchor, negative, /*inplace=*/false)), /*alpha=*/1)), p, dim,
+      /*keepdim=*/false, anchor->dtype()));
   if (swap) {
-    auto distance_swap =
-        JUST(VectorNorm(JUST(ScalarAdd(eps, JUST(Sub(positive, negative)), /*alpha=*/1)), p, dim,
-                        /*keepdim=*/false, positive->dtype()));
+    auto distance_swap = JUST(VectorNorm(
+        JUST(ScalarAdd(eps, JUST(Sub(positive, negative, /*inplace=*/false)), /*alpha=*/1)), p,
+        dim,
+        /*keepdim=*/false, positive->dtype()));
     da_n = JUST(Minimum(distance_swap, da_n));
   }
-  auto triplet_loss =
-      JUST(Clamp(JUST(ScalarAdd(JUST(Sub(da_p, da_n)), margin, /*alpha=*/1, /*inplace=*/false)),
-                 /*min=*/0.0, NullOpt));
+  auto triplet_loss = JUST(Clamp(JUST(ScalarAdd(JUST(Sub(da_p, da_n, /*inplace=*/false)), margin,
+                                                /*alpha=*/1, /*inplace=*/false)),
+                                 /*min=*/0.0, NullOpt));
   int32_t ndim = triplet_loss->ndim() - 1;
   std::vector<int32_t> axis(1, ndim);
 
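These call-site updates are mechanical: with the Sub signature now carrying an explicit inplace flag, every sequence_function chain must pass it through. For orientation, a rough Python mirror of what MseLossFunctor composes (illustrative only, not part of the diff):

import oneflow as flow

def mse_loss_sketch(input, target, reduction="mean"):
    out = flow.square(flow.sub(input, target))  # sequence_function(Sub).then(Square)
    if reduction == "mean":
        return out.mean()
    if reduction == "sum":
        return out.sum()
    return out  # reduction == "none"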
2 changes: 1 addition & 1 deletion python/oneflow/__init__.py
@@ -107,7 +107,7 @@ def is_deprecated(func_or_class):
 from oneflow._C import diag
 from oneflow._C import log1p
 from oneflow._C import add
-from oneflow._C import div
+from oneflow._C import div, div_
 from oneflow._C import floor, floor_
 from oneflow._C import floor_divide
 from oneflow._C import mul
21 changes: 18 additions & 3 deletions python/oneflow/framework/docstr/tensor.py
@@ -533,15 +533,30 @@
 
 add_docstr(
     oneflow.Tensor.mul,
-    """
+    """Tensor.mul(value) -> Tensor
     See :func:`oneflow.mul`
     """,
 )
 
 add_docstr(
     oneflow.Tensor.mul_,
-    """
-    In-place version of :func`oneflow.Tensor.mul`.
+    """Tensor.mul_(value) -> Tensor
+    In-place version of :func:`oneflow.Tensor.mul`.
     """,
 )
 
+add_docstr(
+    oneflow.Tensor.div_,
+    """Tensor.div_(value) -> Tensor
+    In-place version of :func:`oneflow.Tensor.div`.
+    """,
+)
+
+add_docstr(
+    oneflow.Tensor.sub_,
+    """Tensor.sub_(value) -> Tensor
+    In-place version of :func:`oneflow.Tensor.sub`.
+    """,
+)
 
22 changes: 16 additions & 6 deletions python/oneflow/framework/tensor.py
@@ -219,7 +219,7 @@ def _le(self, other):
 
 
 def _mul(self, other):
-    return flow.mul(self, other)
+    return flow._C.mul(self, other)
 
 
 def _mul_(self, other):
@@ -231,11 +231,11 @@ def _rmul(self, other):
 
 
 def _add(self, other):
-    return flow.add(self, other)
+    return flow._C.add(self, other)
 
 
 def _add_inplace(self, other):
-    return flow.add(self, other, inplace=True)
+    return flow._C.add(self, other, inplace=True)
 
 
 def _iadd(self, other):
@@ -247,15 +247,23 @@ def _radd(self, other):
 
 
 def _sub(self, other):
-    return flow.sub(self, other)
+    return flow._C.sub(self, other)
 
 
+def _sub_inplace(self, other):
+    return flow._C.sub(self, other, inplace=True)
+
+
 def _rsub(self, other):
-    return flow.sub(other, self)
+    return flow._C.sub(other, self)
 
 
 def _truediv(self, other):
-    return flow.div(self, other)
+    return flow._C.div(self, other)
 
 
+def _truediv_inplace(self, other):
+    return flow._C.div_(self, other)
+
+
 def _rtruediv(self, other):
@@ -895,10 +903,12 @@ def RegisterMethods():
     Tensor.add = _add
     Tensor.add_ = _add_inplace
     Tensor.div = _truediv
+    Tensor.div_ = _truediv_inplace
     Tensor.mul = _mul
     Tensor.mul_ = _mul_
     Tensor.reciprocal = _reciprocal
     Tensor.sub = _sub
+    Tensor.sub_ = _sub_inplace
     Tensor.asin = _asin
     Tensor.arcsin = _arcsin
     Tensor.asinh = _asinh
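After RegisterMethods runs, the underscore methods are available on every Tensor. Since the functors return the tensor they wrote into, calls can be chained; a small illustration (not part of the diff):

import oneflow as flow

x = flow.ones(2, 2)
x.sub_(flow.ones(2, 2) * 0.5).mul_(4).div_(2)  # each call rewrites x in place
print(x)  # all elements are now 1.0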
2 changes: 1 addition & 1 deletion python/oneflow/test/modules/test_mul.py
@@ -190,7 +190,7 @@ def test_mul(test_case):
         arg[0](test_case, *arg[1:])
 
     @autotest(check_graph=False)
-    def test_boardcast_mul(test_case):
+    def test_broadcast_mul(test_case):
         device = random_device()
         x_0 = random_pytorch_tensor(ndim=3, dim0=4, dim1=2, dim2=3).to(device)
         y = random_pytorch_tensor(ndim=2, dim0=2, dim1=3).to(device)