Skip to content

Commit

Permalink
[DSL/TE] Scalar support for te.extern (apache#6079)
Browse files Browse the repository at this point in the history
* fix make shape with scalar shapes

* add test

* add test

* remove scalar shape assertion

* fix the data type for overflow problems

* add extra tests

Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-138.ec2.internal>
  • Loading branch information
2 people authored and Trevor Morris committed Aug 26, 2020
1 parent bfa5819 commit 64fc5ed
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/tvm/te/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def extern(shape,
out_buffers=None,
tag="",
attrs=None):
"""Compute several tensor via extern function.
"""Compute several tensors via an extern function.
Parameters
----------
Expand Down
1 change: 0 additions & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
def _pack_buffer(buf):
"""Build intrinsics that packs the buffer.
"""
assert buf.shape
shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape)
strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides) if buf.strides else 0
pack_args = [buf.data,
Expand Down
20 changes: 13 additions & 7 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class BuiltinLower : public StmtExprMutator {
stack_value_ = Var("stack_value", DataType::Handle());
stack_tcode_ = Var("stack_tcode", DataType::Handle());
stmt = this->VisitStmt(stmt);
if (max_shape_stack_ != 0) {
// create a shape var if any shape is made (including scalar shapes)
if (max_shape_stack_ != -1) {
stmt = LetStmt(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
}
if (max_array_stack_ != 0) {
Expand All @@ -69,7 +70,7 @@ class BuiltinLower : public StmtExprMutator {

Stmt VisitStmt(const Stmt& s) final {
auto stmt = StmtExprMutator::VisitStmt(s);
CHECK_EQ(run_shape_stack_, 0);
CHECK_EQ(run_shape_stack_, -1);
CHECK_EQ(run_array_stack_, 0);

if (prep_seq_.size() != 0) {
Expand Down Expand Up @@ -156,10 +157,15 @@ class BuiltinLower : public StmtExprMutator {
}
// call shape
PrimExpr MakeShape(const CallNode* op) {
size_t stack_begin = run_shape_stack_;
// if args.size() == 0, it represents a scalar shape ()
if (run_shape_stack_ == -1) {
run_shape_stack_ = 0;
}
int64_t stack_begin = run_shape_stack_;
run_shape_stack_ += op->args.size();
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
// no need to perform any store for a scalar shape
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(Store(stack_shape_, cast(DataType::Int(64), op->args[i]),
ConstInt32(stack_begin + i), const_true(1)));
Expand Down Expand Up @@ -206,7 +212,7 @@ class BuiltinLower : public StmtExprMutator {
}
// call packed.
PrimExpr MakeCallPacked(const CallNode* op) {
size_t restore_shape_stack = run_shape_stack_;
int64_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size();
Expand Down Expand Up @@ -245,7 +251,7 @@ class BuiltinLower : public StmtExprMutator {
}

PrimExpr MakeCallTracePacked(const CallNode* op) {
size_t restore_shape_stack = run_shape_stack_;
int64_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size();
Expand Down Expand Up @@ -307,11 +313,11 @@ class BuiltinLower : public StmtExprMutator {
Var stack_tcode_;
Var stack_value_;
// The running statistics
uint64_t run_shape_stack_{0};
int64_t run_shape_stack_{-1};
uint64_t run_array_stack_{0};
uint64_t run_arg_stack_{0};
// statistics of stacks
uint64_t max_shape_stack_{0};
int64_t max_shape_stack_{-1};
uint64_t max_array_stack_{0};
uint64_t max_arg_stack_{0};
};
Expand Down
49 changes: 49 additions & 0 deletions tests/python/unittest/test_te_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
from tvm import te
from topi.nn.pooling import pool

Expand Down Expand Up @@ -303,6 +304,52 @@ def intrin_func(ins, outs):
s[P].tensorize(oh, intrin)
tvm.lower(s, [A, P])

def test_tensor_scalar_mixed():
# test te with tensor and scalar
a = np.array(np.random.uniform(size=(10,)), 'float32')
b = np.array(np.random.uniform(size=(1))[0], 'float32')
c = np.array(np.random.uniform(size=(10,)), 'float32')

@tvm.register_func("tvm.test_tensor_scalar_scale")
def my_scale(tensor, scalar, out):
out_np = tensor.asnumpy() * scalar.asnumpy()
tvm.nd.array(out_np).copyto(out)

A = te.placeholder(a.shape, name='A')
B = te.placeholder(b.shape, name='B')
C = te.extern(a.shape, [A, B],
lambda ins, outs: tvm.tir.call_packed(
"tvm.test_tensor_scalar_scale", ins[0], ins[1], outs[0]), name="C")
s = te.create_schedule(C.op)
f = tvm.build(s, [A, B, C], 'llvm')

ta = tvm.nd.array(a)
tb = tvm.nd.array(b)
tc = tvm.nd.array(c)
f(ta, tb, tc)
tvm.testing.assert_allclose(a * b, tc.asnumpy())


def test_tensor_scalar():
# test te with scalar shape
a = np.array(np.random.uniform(size=(1))[0], 'float32')
b = np.array(0.0, 'float32')

@tvm.register_func("tvm.test_tensor_scalar_copy")
def mycopy(x, y):
x.copyto(y)

A = te.placeholder(a.shape, name='A')
B = te.extern(a.shape, [A],
lambda ins, outs: tvm.tir.call_packed(
"tvm.test_tensor_scalar_copy", ins[0], outs[0]), name="B")
s = te.create_schedule(B.op)
f = tvm.build(s, [A, B], 'llvm')

ta = tvm.nd.array(a)
tb = tvm.nd.array(b)
f(ta, tb)
tvm.testing.assert_allclose(ta.asnumpy(), tb.asnumpy())

if __name__ == "__main__":
test_rank_zero()
Expand All @@ -321,3 +368,5 @@ def intrin_func(ins, outs):
test_tuple_inputs()
test_tuple_with_different_deps()
test_tensor_pool()
test_tensor_scalar()
test_tensor_scalar_mixed()

0 comments on commit 64fc5ed

Please sign in to comment.