Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TOPI]fix scatterND large shape problem (apache#12200)
Browse files Browse the repository at this point in the history
* fix scatterND large shape problem

* fix thread pool alloca

* add scatternd unit test

* update with comment

* Empty

Co-authored-by: wrongtest <wrongtest0@gmail.com>
  • Loading branch information
2 people authored and xinetzone committed Nov 25, 2022
1 parent 22a0104 commit 8e8da9b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,8 @@ CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const Array<Var>& vfields,
}
llvm::StructType* ctype = struct_name.size() ? llvm::StructType::create(fields, struct_name)
: llvm::StructType::create(fields);
llvm::Value* cvalue = builder_->CreateAlloca(ctype, ConstInt32(1));
llvm::AllocaInst* cvalue =
WithFunctionEntry([&]() { return builder_->CreateAlloca(ctype, ConstInt32(1)); });
llvm::Value* zero = ConstInt32(0);
for (size_t i = 0; i < vfields.size(); ++i) {
builder_->CreateStore(var_map_.at(vfields[i].get()),
Expand Down
20 changes: 20 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1909,6 +1909,26 @@ def test_cumprod(target, dev, executor_kind):

@tvm.testing.parametrize_targets
def test_scatter_nd(target, dev, executor_kind):
def test_scatter_nd_large_shape():
def before():
data = relay.const(np.zeros((1, 900, 300), dtype="float32"), dtype="float32")
indices = relay.const(np.ones((3, 1, 900, 300), dtype="int64"), dtype="int64")
update = relay.const(np.ones((1, 900, 300), dtype="float32"), dtype="float32")
b = relay.op.scatter_nd(data, indices, update)
return relay.Function(relay.analysis.free_vars(b), b)

passes = tvm.transform.Sequential(
[
relay.transform.InferType(),
relay.transform.FoldConstant(),
]
)
before_mod = tvm.IRModule.from_expr(before())
with tvm.transform.PassContext(opt_level=3):
after_mod = passes(before_mod)

test_scatter_nd_large_shape()

def verify_scatter_nd(
data_np, indices_np, updates_np, ref_res, mode="add", rtol=1e-5, atol=1e-5
):
Expand Down

0 comments on commit 8e8da9b

Please sign in to comment.