Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Aug 22, 2022
1 parent c035189 commit 7364eac
Showing 1 changed file with 31 additions and 27 deletions.
58 changes: 31 additions & 27 deletions cinn/hlir/op/contrib/scatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,36 +50,39 @@ ir::Tensor Scatter(
CHECK_EQ(A->shape.size(), B->shape.size());
CHECK_EQ(A->shape.size(), C->shape.size());

// std::string extern_fun_name = "cinn_host_find_int_nd";
//
// int pos_axis = axis;
// if (pos_axis < 0) {
// pos_axis += C->shape.size();
// }
//
// std::vector<int> new_axes;
// for (int i = 0; i < A->shape.size(); ++i) {
// if (i != pos_axis) {
// new_axes.push_back(i);
// }
// }
// new_axes.push_back(axis);
std::string extern_fun_name = "cinn_host_find_int_nd";

int pos_axis = axis;
if (pos_axis < 0) {
pos_axis += C->shape.size();
}

std::vector<int> new_axes;
for (int i = 0; i < A->shape.size(); ++i) {
if (i != pos_axis) {
new_axes.push_back(i);
}
}
new_axes.push_back(axis);
// auto new_B = pe::Transpose(B, new_axes, name + "_index_transpose");
auto res = Compute(
C->shape,
[=](const std::vector<Expr> &indices) {
// auto offset = Expr(0);
// for (int i = 0; i < indices.size(); i++) {
// if (i != pos_axis) {
// offset = offset * C->shape[i] + indices[i];
// }
// }
// offset = common::AutoSimplify(offset);
auto offset = Expr(0);
for (int i = 0; i < indices.size(); i++) {
if (i != pos_axis) {
offset = offset * C->shape[i] + indices[i];
}
}
offset = common::AutoSimplify(offset);
auto idx = lang::CallExtern(extern_fun_name, {B, B->shape[-1], indices[pos_axis], offset, Expr(1)});
// auto idx = lang::CallExtern(extern_fun_name, {new_B, new_B->shape[-1], indices[pos_axis], offset,
// Expr(1)}); std::vector<Expr> A_indices(indices); A_indices[pos_axis] = idx; auto keep =
// ir::EQ::Make(idx, Expr(-1));
// return ir::Select::Make(keep, C(indices), A(A_indices));
return C(indices);
// Expr(1)});
std::vector<Expr> A_indices(indices);
A_indices[pos_axis] = idx;
auto keep = ir::EQ::Make(idx, Expr(-1));
return ir::Select::Make(keep, C(indices), A(A_indices));
// return C(indices);
},
name);
return res;
Expand Down Expand Up @@ -119,8 +122,9 @@ ir::Tensor ScatterNd(const ir::Tensor &A,
auto keep = Expr(true);
std::vector<Expr> idx;
for (int i = 0; i < pos_axes.size(); ++i) {
auto cur_idx = lang::CallExtern(
extern_fun_name, {B, B->shape[-2], indices[pos_axes[i]], offset + Expr(i), Expr(pos_axes.size())});
auto cur_idx = lang::CallExtern(extern_fun_name, {B, B->shape[-2], indices[pos_axes[i]]});
// extern_fun_name, {B, B->shape[-2], indices[pos_axes[i]], offset + Expr(i),
// Expr(pos_axes.size())});
if (idx.empty()) {
idx.push_back(cur_idx);
A_indices.push_back(cur_idx);
Expand Down

0 comments on commit 7364eac

Please sign in to comment.