Skip to content

Commit

Permalink
fix reduce_ops 0size bug (#8551)
Browse files Browse the repository at this point in the history
* fix reduce_ops 0size bug

* fix comment

* auto format by CI

* fix bug

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 6, 2022
1 parent 91eab12 commit 28690a2
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 15 deletions.
10 changes: 0 additions & 10 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,16 +488,6 @@ class ReduceAllWholeFunctor {
one::OpBuilder("reduce_all").Input("input_tensor").Output("output_tensor").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {
bool IsZeroSize = [&]() {
for (int i = 0; i < x->shape()->NumAxes(); i++) {
if (x->shape()->at(i) == 0) return true;
}
return false;
}();
if (x->shape()->NumAxes() == 0 || IsZeroSize) {
return JUST(Squeeze(JUST(Constant(Shape{1}, Scalar(1), DType::Bool(), JUST(x->device()))),
std::vector<int32_t>({0})));
}
MutableAttrMap attrs;
std::vector<int32_t> reduce_axis(x->ndim());
std::iota(reduce_axis.begin(), reduce_axis.end(), 0);
Expand Down
30 changes: 26 additions & 4 deletions oneflow/user/kernels/reduce_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/ndarray/ndarray_util.h"
#include "oneflow/core/ndarray/xpu_var_ndarray.h"
Expand Down Expand Up @@ -57,6 +58,12 @@ std::unique_ptr<ep::primitive::Matmul> NewReduceMatmulNoTransAPrimitive(Context*
/*transpose_b=*/false);
}

// Builds a Fill primitive matching the dtype of the kernel's "output_tensor"
// output, so zero-size-input reductions can be initialized to a constant.
// Returns nullptr when the (device_type, data_type) combination is unsupported.
template<typename Context>
std::unique_ptr<ep::primitive::Fill> NewFillPrimitive(Context* ctx) {
  const auto* out_desc = ctx->TensorDesc4ArgNameAndIndex("output_tensor", 0);
  return ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->device_type(),
                                                                 out_desc->data_type());
}

auto ReduceMatmulTransAPrimitiveExists() {
return hob::make_custom("ReduceMatmulTransAPrimitiveExists",
[](const user_op::KernelRegContext& ctx) {
Expand All @@ -71,6 +78,12 @@ auto ReduceMatmulNoTransAPrimitiveExists() {
});
}

// Hob predicate used at kernel-registration time: the reduce kernel is only
// matched when a Fill primitive exists for the output's (device, dtype) pair.
auto FillPrimitiveExists() {
  return hob::make_custom("FillPrimitiveExists", [](const user_op::KernelRegContext& ctx) {
    // NewFillPrimitive returns nullptr on unsupported combinations;
    // convert the smart pointer to the boolean "exists" answer.
    return static_cast<bool>(NewFillPrimitive(&ctx));
  });
}

template<template<typename> class BinaryFunc, DeviceType device_type, typename T, typename K>
class ReduceKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
public:
Expand All @@ -83,12 +96,20 @@ class ReduceKernel final : public user_op::OpKernel, public user_op::CudaGraphSu
user_op::Tensor* output_tensor = ctx->Tensor4ArgNameAndIndex("output_tensor", 0);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const auto& axis = ctx->Attr<std::vector<int32_t>>("axis");
const int32_t output_elem_cnt = output_tensor->shape_view().elem_cnt();

if (input_tensor->shape_view().elem_cnt() == 0) {
if (output_tensor->shape_view().elem_cnt() != 0) {
Memset<device_type>(
ctx->stream(), output_tensor->mut_dptr<K>(), 0,
output_tensor->shape_view().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()));
Scalar init_value = [&]() {
if (std::is_same<BinaryFunc<T>, BinaryFuncAny<T>>::value) { return Scalar(0); }
if (std::is_same<BinaryFunc<T>, BinaryFuncAll<T>>::value) { return Scalar(1); }
return Scalar(0);
}();
CHECK_GE(output_elem_cnt, 0);
if (output_elem_cnt == 0) { return; }
std::unique_ptr<ep::primitive::Fill> fill = NewFillPrimitive(ctx);
CHECK(fill);
fill->Launch(ctx->stream(), output_tensor->mut_dptr<K>(), init_value, output_elem_cnt);
}
return;
}
Expand Down Expand Up @@ -119,7 +140,8 @@ class ReduceKernel final : public user_op::OpKernel, public user_op::CudaGraphSu
.SetCreateFn<ReduceKernel<binary_func, device, dtype, bool>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == device) \
&& (user_op::HobDataType("input_tensor", 0) == GetDataType<dtype>::value) \
&& (user_op::HobDataType("output_tensor", 0) == DataType::kBool)) \
&& (user_op::HobDataType("output_tensor", 0) == DataType::kBool) \
&& FillPrimitiveExists()) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) { \
const Shape& in_shape = ctx->InputShape("input_tensor", 0); \
return in_shape.elem_cnt() * sizeof(dtype); \
Expand Down
2 changes: 1 addition & 1 deletion python/oneflow/test/modules/test_logical_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_any_bool_input_with_random_data(test_case):
@autotest(n=5, auto_backward=False)
def test_reduce_all_0dim_tensor(test_case):
device = random_device()
x = torch.empty(0).to(device)
x = random_tensor(ndim=0, requires_grad=False).to(device)
return torch.all(x)

@autotest(n=5, auto_backward=False)
Expand Down

0 comments on commit 28690a2

Please sign in to comment.