Skip to content

Commit

Permalink
Revert "cinn(op): add broadcast compute (#62488)"
Browse files Browse the repository at this point in the history
This reverts commit d27c2ea.
  • Loading branch information
6clc authored Mar 10, 2024
1 parent 00266ae commit baa0126
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
7 changes: 6 additions & 1 deletion paddle/cinn/hlir/op/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,12 @@ std::shared_ptr<OpStrategy> StrategyForBroadcastToSymbolic(
output_shapes[0].end(),
out_shape.begin(),
[](const ir::Dim &dim) { return dim->dim_expr; });
std::vector<int> broadcast_axes;
CHECK_GT(attrs.attr_store.count("broadcast_axes"), 0);
broadcast_axes =
absl::get<std::vector<int>>(attrs.attr_store.at("broadcast_axes"));
VLOG(3) << "broadcast out shape: " << utils::Join(out_shape, ", ");
VLOG(3) << "broadcast_axes shape: " << utils::Join(broadcast_axes, ", ");

framework::CINNCompute broadcast_to_compute([=](lang::Args args,
lang::RetValue *ret) {
Expand All @@ -323,7 +328,7 @@ std::shared_ptr<OpStrategy> StrategyForBroadcastToSymbolic(
Expr A_expr = pack_args[0];
CHECK(A_expr.as_tensor());
ir::Tensor A = A_expr.as_tensor_ref();
auto out = pe::BroadcastTo(A, out_shape, tensor_name);
auto out = pe::BroadcastTo(A, out_shape, broadcast_axes, tensor_name);
auto stages = CreateStages({A, out});
*ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
});
Expand Down
25 changes: 18 additions & 7 deletions paddle/cinn/hlir/pe/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -374,25 +374,36 @@ Tensor BroadcastTo(const Tensor& A,

Tensor BroadcastTo(const Tensor& A,
const std::vector<ir::Expr>& out_shape,
const std::vector<int>& broadcast_axes,
const std::string& out_name) {
auto A_shape = A->shape;
CHECK_EQ(A_shape.size(), out_shape.size())
<< "broadcast_to's out_shape's size should be same with the input "
"shape's size";
CHECK_EQ(A_shape.size(), broadcast_axes.size())
<< "broadcast_axes's size should be same with the input shape's size";
CHECK_GE(out_shape.size(), broadcast_axes.size())
<< "broadcast_axes's size should be no more than out_shape's size";
auto axes = broadcast_axes;
for (auto& axis : axes) {
// if axis < 0, plus out_shape.size
if (axis < 0) {
axis = out_shape.size() + axis;
}
CHECK_LT(axis, out_shape.size());
}
std::sort(axes.begin(), axes.end());

return Compute(
ToCinnExprs(out_shape),
[=](const std::vector<Expr>& indice) {
std::vector<Expr> broadcast_indice;
for (int idx = 0; idx < out_shape.size(); ++idx) {
for (int idx = 0; idx < axes.size(); ++idx) {
ir::Expr a_shape_i = A_shape[idx];
if (MathEqual(a_shape_i, ir::Expr(1))) {
broadcast_indice.push_back(ir::Expr(0));
} else if (MathEqual(a_shape_i, out_shape[idx])) {
broadcast_indice.push_back(indice[idx]);
} else if (MathEqual(a_shape_i, out_shape[axes[idx]])) {
broadcast_indice.push_back(indice[axes[idx]]);
} else {
LOG(FATAL) << "fail to broad cast input shape " << a_shape_i
<< " to output shape " << out_shape[idx];
<< " to output shape " << out_shape[axes[idx]];
}
}
return A(broadcast_indice);
Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/hlir/pe/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ ir::Tensor BroadcastTo(
// Broadcasts tensor A to `out_shape` using explicit axis mapping: input
// dimension i of A is placed at output dimension broadcast_axes[i]
// (negative axis values are normalized by adding the output rank).
// `broadcast_axes` must have the same size as A's shape; each input
// dimension must either equal 1 or match the corresponding output extent.
// Returns the broadcast result tensor named `out_name`.
ir::Tensor BroadcastTo(
const ir::Tensor& A,
const std::vector<ir::Expr>& out_shape,
const std::vector<int>& broadcast_axes,
const std::string& out_name = cinn::common::UniqName("T_broadcast_to_out"));

// This operator checks if all x and y satisfy the condition: |x - y| <= atol +
Expand Down

0 comments on commit baa0126

Please sign in to comment.