Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Sep 19, 2022
1 parent 3b987b6 commit 760e747
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 26 deletions.
24 changes: 4 additions & 20 deletions cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,27 +350,15 @@ Variable NetBuilder::ReluGrad(const Variable& lhs, const Variable& rhs) {
}

Variable NetBuilder::Gather(const Variable& x, const Variable& index, const int& axis) {
Instruction instr("gather", {x, index});
instr.SetAttr("axis", axis);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
return CustomInstr("gather", {x, index}, {{"axis", axis}}).front();
}

Variable NetBuilder::GatherNd(const Variable& x, const Variable& index, const std::vector<int>& axes) {
Instruction instr("gather_nd", {x, index});
instr.SetAttr("axes", axes);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
return CustomInstr("gather_nd", {x, index}, {{"axes", axes}}).front();
}

Variable NetBuilder::Scatter(const Variable& src, const Variable& index, const Variable& out, const int& axis) {
Instruction instr("scatter", {src, index, out});
instr.SetAttr("axis", axis);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
return CustomInstr("scatter", {src, index, out}, {{"axis", axis}}).front();
}
Variable NetBuilder::Scatter(const Variable& src,
const Variable& index,
Expand All @@ -385,11 +373,7 @@ Variable NetBuilder::ScatterNd(const Variable& src,
const Variable& index,
const Variable& out,
const std::vector<int>& axes) {
Instruction instr("scatter_nd", {src, index, out});
instr.SetAttr("axes", axes);
InferShape(instr);
AppendInstruction(instr);
return instr.GetOutput(0);
return CustomInstr("scatter_nd", {src, index, out}, {{"axes", axes}}).front();
}
Variable NetBuilder::ScatterNd(const Variable& src,
const Variable& index,
Expand Down
4 changes: 2 additions & 2 deletions cinn/hlir/op/contrib/gather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ ir::Tensor Gather(const ir::Tensor &A, const ir::Tensor &B, const int &axis, con
for (size_t i = axis + 1; i < A->shape.size(); ++i) {
A_indices.push_back(indices[i]);
}
return A(A_indices);
return lang::Identity(A(A_indices));
},
name);
return res;
Expand All @@ -78,7 +78,7 @@ ir::Tensor GatherNd(const ir::Tensor &A, const ir::Tensor &B, const std::vector<
A_indices[axes[i]] = B(B_indices);
B_indices.pop_back();
}
return A(A_indices);
return lang::Identity(A(A_indices));
},
name);
return res;
Expand Down
22 changes: 18 additions & 4 deletions cinn/hlir/op/contrib/scatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,16 @@ std::shared_ptr<framework::OpStrategy> StrategyForScatter(const framework::NodeA
*ret = CINNValuePack{res};
});

framework::CINNSchedule scatter_schedule([=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of reshape schedule is empty! Please check.\n";
CINNValuePack arg_pack = args[0];
Expr out = arg_pack[0];
CHECK(out.as_tensor());
*ret = arg_pack;
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(
scatter_compute, framework::GetInjectiveScheduleFunc(output_shapes, target), "strategy.scatter.x86", 1);
strategy->AddImpl(scatter_compute, scatter_schedule, "strategy.scatter.x86", 1);
return strategy;
}

Expand Down Expand Up @@ -257,9 +264,16 @@ std::shared_ptr<framework::OpStrategy> StrategyForScatterNd(const framework::Nod
*ret = CINNValuePack{res};
});

framework::CINNSchedule scatter_schedule([=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of reshape schedule is empty! Please check.\n";
CINNValuePack arg_pack = args[0];
Expr out = arg_pack[0];
CHECK(out.as_tensor());
*ret = arg_pack;
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(
scatter_nd_compute, framework::GetInjectiveScheduleFunc(output_shapes, target), "strategy.scatter_nd.x86", 1);
strategy->AddImpl(scatter_compute, scatter_schedule, "strategy.scatter_nd.x86", 1);
return strategy;
}

Expand Down

0 comments on commit 760e747

Please sign in to comment.