Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DSL/TE] Scalar support for te.extern #6079

Merged
merged 6 commits into from
Jul 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/te/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def extern(shape,
out_buffers=None,
tag="",
attrs=None):
"""Compute several tensor via extern function.
"""Compute several tensors via an extern function.
Parameters
----------
Expand Down
1 change: 0 additions & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
def _pack_buffer(buf):
"""Build intrinsics that packs the buffer.
"""
assert buf.shape
shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape)
strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides) if buf.strides else 0
pack_args = [buf.data,
Expand Down
20 changes: 13 additions & 7 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class BuiltinLower : public StmtExprMutator {
stack_value_ = Var("stack_value", DataType::Handle());
stack_tcode_ = Var("stack_tcode", DataType::Handle());
stmt = this->VisitStmt(stmt);
if (max_shape_stack_ != 0) {
// create a shape var if any shape is made (including scalar shapes)
if (max_shape_stack_ != -1) {
stmt = LetStmt(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
}
if (max_array_stack_ != 0) {
Expand All @@ -69,7 +70,7 @@ class BuiltinLower : public StmtExprMutator {

Stmt VisitStmt(const Stmt& s) final {
auto stmt = StmtExprMutator::VisitStmt(s);
CHECK_EQ(run_shape_stack_, 0);
CHECK_EQ(run_shape_stack_, -1);
CHECK_EQ(run_array_stack_, 0);

if (prep_seq_.size() != 0) {
Expand Down Expand Up @@ -156,10 +157,15 @@ class BuiltinLower : public StmtExprMutator {
}
// call shape
PrimExpr MakeShape(const CallNode* op) {
size_t stack_begin = run_shape_stack_;
// if args.size() == 0, it represents a scalar shape ()
if (run_shape_stack_ == -1) {
run_shape_stack_ = 0;
}
int64_t stack_begin = run_shape_stack_;
run_shape_stack_ += op->args.size();
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
// no need to perform any store for a scalar shape
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(Store(stack_shape_, cast(DataType::Int(64), op->args[i]),
ConstInt32(stack_begin + i), const_true(1)));
Expand Down Expand Up @@ -206,7 +212,7 @@ class BuiltinLower : public StmtExprMutator {
}
// call packed.
PrimExpr MakeCallPacked(const CallNode* op) {
size_t restore_shape_stack = run_shape_stack_;
int64_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size();
Expand Down Expand Up @@ -245,7 +251,7 @@ class BuiltinLower : public StmtExprMutator {
}

PrimExpr MakeCallTracePacked(const CallNode* op) {
size_t restore_shape_stack = run_shape_stack_;
int64_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size();
Expand Down Expand Up @@ -307,11 +313,11 @@ class BuiltinLower : public StmtExprMutator {
Var stack_tcode_;
Var stack_value_;
// The running statistics
uint64_t run_shape_stack_{0};
int64_t run_shape_stack_{-1};
uint64_t run_array_stack_{0};
uint64_t run_arg_stack_{0};
// statistics of stacks
uint64_t max_shape_stack_{0};
int64_t max_shape_stack_{-1};
uint64_t max_array_stack_{0};
uint64_t max_arg_stack_{0};
};
Expand Down
49 changes: 49 additions & 0 deletions tests/python/unittest/test_te_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
from tvm import te
from topi.nn.pooling import pool

Expand Down Expand Up @@ -303,6 +304,52 @@ def intrin_func(ins, outs):
s[P].tensorize(oh, intrin)
tvm.lower(s, [A, P])

def test_tensor_scalar_mixed():
    """te.extern should accept a mix of rank-1 tensor and rank-0 (scalar) inputs."""
    vec_np = np.array(np.random.uniform(size=(10,)), 'float32')
    scalar_np = np.array(np.random.uniform(size=(1))[0], 'float32')
    res_np = np.array(np.random.uniform(size=(10,)), 'float32')

    @tvm.register_func("tvm.test_tensor_scalar_scale")
    def my_scale(tensor, scalar, out):
        # Scale every element of `tensor` by the rank-0 `scalar` value.
        scaled = tensor.asnumpy() * scalar.asnumpy()
        tvm.nd.array(scaled).copyto(out)

    vec_ph = te.placeholder(vec_np.shape, name='A')
    scalar_ph = te.placeholder(scalar_np.shape, name='B')
    out_tensor = te.extern(
        vec_np.shape, [vec_ph, scalar_ph],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.test_tensor_scalar_scale", ins[0], ins[1], outs[0]),
        name="C")
    sched = te.create_schedule(out_tensor.op)
    func = tvm.build(sched, [vec_ph, scalar_ph, out_tensor], 'llvm')

    ta = tvm.nd.array(vec_np)
    tb = tvm.nd.array(scalar_np)
    tc = tvm.nd.array(res_np)
    func(ta, tb, tc)
    tvm.testing.assert_allclose(vec_np * scalar_np, tc.asnumpy())


def test_tensor_scalar():
    """te.extern should handle a pure scalar (rank-0, shape ()) tensor."""
    src_np = np.array(np.random.uniform(size=(1))[0], 'float32')
    dst_np = np.array(0.0, 'float32')

    @tvm.register_func("tvm.test_tensor_scalar_copy")
    def mycopy(x, y):
        # Copy the scalar NDArray `x` into `y` unchanged.
        x.copyto(y)

    src_ph = te.placeholder(src_np.shape, name='A')
    out_tensor = te.extern(
        src_np.shape, [src_ph],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.test_tensor_scalar_copy", ins[0], outs[0]),
        name="B")
    sched = te.create_schedule(out_tensor.op)
    func = tvm.build(sched, [src_ph, out_tensor], 'llvm')

    ta = tvm.nd.array(src_np)
    tb = tvm.nd.array(dst_np)
    func(ta, tb)
    tvm.testing.assert_allclose(ta.asnumpy(), tb.asnumpy())

if __name__ == "__main__":
test_rank_zero()
Expand All @@ -321,3 +368,5 @@ def intrin_func(ins, outs):
test_tuple_inputs()
test_tuple_with_different_deps()
test_tensor_pool()
test_tensor_scalar()
test_tensor_scalar_mixed()