From baa01269b8ff85abfcd3359f5c65e9363241deb1 Mon Sep 17 00:00:00 2001
From: 6clc <569519574@qq.com>
Date: Sun, 10 Mar 2024 17:43:01 +0800
Subject: [PATCH] Revert "cinn(op): add broadcast compute (#62488)"

This reverts commit d27c2ea30d7d68eb2eddaedabe3e8f9c3a57fb06.
---
 paddle/cinn/hlir/op/broadcast.cc |  7 ++++++-
 paddle/cinn/hlir/pe/broadcast.cc | 25 ++++++++++++++++++-------
 paddle/cinn/hlir/pe/broadcast.h  |  1 +
 3 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/paddle/cinn/hlir/op/broadcast.cc b/paddle/cinn/hlir/op/broadcast.cc
index 444a6f69c5d52..c6c7ee00a9449 100644
--- a/paddle/cinn/hlir/op/broadcast.cc
+++ b/paddle/cinn/hlir/op/broadcast.cc
@@ -307,7 +307,12 @@ std::shared_ptr<OpStrategy> StrategyForBroadcastToSymbolic(
                  output_shapes[0].end(),
                  out_shape.begin(),
                  [](const ir::Dim &dim) { return dim->dim_expr; });
+  std::vector<int> broadcast_axes;
+  CHECK_GT(attrs.attr_store.count("broadcast_axes"), 0);
+  broadcast_axes =
+      absl::get<std::vector<int>>(attrs.attr_store.at("broadcast_axes"));
   VLOG(3) << "broadcast out shape: " << utils::Join(out_shape, ", ");
+  VLOG(3) << "broadcast_axes shape: " << utils::Join(broadcast_axes, ", ");
 
   framework::CINNCompute broadcast_to_compute([=](lang::Args args,
                                                   lang::RetValue *ret) {
@@ -323,7 +328,7 @@ std::shared_ptr<OpStrategy> StrategyForBroadcastToSymbolic(
     Expr A_expr = pack_args[0];
     CHECK(A_expr.as_tensor());
     ir::Tensor A = A_expr.as_tensor_ref();
-    auto out = pe::BroadcastTo(A, out_shape, tensor_name);
+    auto out = pe::BroadcastTo(A, out_shape, broadcast_axes, tensor_name);
     auto stages = CreateStages({A, out});
     *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
   });
diff --git a/paddle/cinn/hlir/pe/broadcast.cc b/paddle/cinn/hlir/pe/broadcast.cc
index 9ab00fc8ce5da..29189a5b1987c 100644
--- a/paddle/cinn/hlir/pe/broadcast.cc
+++ b/paddle/cinn/hlir/pe/broadcast.cc
@@ -374,25 +374,36 @@ Tensor BroadcastTo(const Tensor& A,
 
 Tensor BroadcastTo(const Tensor& A,
                    const std::vector<ir::Expr>& out_shape,
+                   const std::vector<int>& broadcast_axes,
                    const std::string& out_name) {
   auto A_shape = A->shape;
-  CHECK_EQ(A_shape.size(), out_shape.size())
-      << "broadcast_to's out_shape's size should be same with the input "
-         "shape's size";
+  CHECK_EQ(A_shape.size(), broadcast_axes.size())
+      << "broadcast_axes's size should be same with the input shape's size";
+  CHECK_GE(out_shape.size(), broadcast_axes.size())
+      << "broadcast_axes's size should be no more than out_shape's size";
+  auto axes = broadcast_axes;
+  for (auto& axis : axes) {
+    // if axis < 0, plus out_shape.size
+    if (axis < 0) {
+      axis = out_shape.size() + axis;
+    }
+    CHECK_LT(axis, out_shape.size());
+  }
+  std::sort(axes.begin(), axes.end());
 
   return Compute(
       ToCinnExprs(out_shape),
       [=](const std::vector<Expr>& indice) {
         std::vector<Expr> broadcast_indice;
-        for (int idx = 0; idx < out_shape.size(); ++idx) {
+        for (int idx = 0; idx < axes.size(); ++idx) {
           ir::Expr a_shape_i = A_shape[idx];
           if (MathEqual(a_shape_i, ir::Expr(1))) {
             broadcast_indice.push_back(ir::Expr(0));
-          } else if (MathEqual(a_shape_i, out_shape[idx])) {
-            broadcast_indice.push_back(indice[idx]);
+          } else if (MathEqual(a_shape_i, out_shape[axes[idx]])) {
+            broadcast_indice.push_back(indice[axes[idx]]);
           } else {
             LOG(FATAL) << "fail to broad cast input shape " << a_shape_i
-                       << " to output shape " << out_shape[idx];
+                       << " to output shape " << out_shape[axes[idx]];
           }
         }
         return A(broadcast_indice);
diff --git a/paddle/cinn/hlir/pe/broadcast.h b/paddle/cinn/hlir/pe/broadcast.h
index f2cb2649ad499..efdafee9c9dce 100644
--- a/paddle/cinn/hlir/pe/broadcast.h
+++ b/paddle/cinn/hlir/pe/broadcast.h
@@ -118,6 +118,7 @@ ir::Tensor BroadcastTo(
 ir::Tensor BroadcastTo(
     const ir::Tensor& A,
     const std::vector<ir::Expr>& out_shape,
+    const std::vector<int>& broadcast_axes,
     const std::string& out_name = cinn::common::UniqName("T_broadcast_to_out"));
 
 // This operator checks if all x and y satisfy the condition: |x - y| <= atol
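
For reference, the semantics this revert restores: broadcast_axes[i] names the
output axis that input dimension i maps to; negative axes are normalized by
adding out_shape.size() and the axes are then sorted, and at compute time a
size-1 input dimension always reads index 0 while a matching dimension passes
the output index through unchanged. The sketch below mirrors that index
mapping in plain C++ so it compiles standalone; it has no CINN dependencies,
and NormalizeAxes / MapIndex are illustrative helper names, not CINN API.

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

// Mirrors the axis handling in the restored pe::BroadcastTo: normalize
// negative axes against the output rank, bounds-check, then sort.
std::vector<int> NormalizeAxes(std::vector<int> axes, int out_rank) {
  for (int& axis : axes) {
    if (axis < 0) axis += out_rank;  // same as: axis = out_shape.size() + axis
    assert(axis < out_rank);         // same as: CHECK_LT(axis, out_shape.size())
  }
  std::sort(axes.begin(), axes.end());
  return axes;
}

// For one concrete output index, compute the input index the compute lambda
// would read: size-1 input dims map to index 0, matching dims pass through.
std::vector<int> MapIndex(const std::vector<int>& a_shape,
                          const std::vector<int>& out_shape,
                          const std::vector<int>& axes,
                          const std::vector<int>& out_index) {
  std::vector<int> in_index;
  for (size_t i = 0; i < axes.size(); ++i) {
    if (a_shape[i] == 1) {
      in_index.push_back(0);  // broadcast dimension
    } else {
      assert(a_shape[i] == out_shape[axes[i]]);  // otherwise LOG(FATAL)
      in_index.push_back(out_index[axes[i]]);    // pass-through dimension
    }
  }
  return in_index;
}

int main() {
  // Broadcast a [3, 1] input into axes {0, 2} of a [3, 4, 5] output;
  // the -1 normalizes to axis 2.
  std::vector<int> axes = NormalizeAxes({0, -1}, 3);  // -> {0, 2}
  std::vector<int> in_index = MapIndex({3, 1}, {3, 4, 5}, axes, {2, 3, 4});
  std::printf("reads A[%d][%d]\n", in_index[0], in_index[1]);  // A[2][0]
  return 0;
}

Sorting the axes keeps input dimension i paired with the i-th smallest output
axis, which is why the compute lambda can walk axes and A_shape with the same
loop index.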