Skip to content

Commit

Permalink
[Unity] Lower shape_of to a builtin (#14093)
Browse files Browse the repository at this point in the history
This PR lowers shape_of op to a Relax VM builtin, and changes a utility function to take StructInfo as input.

Co-authored-by: Steven S. Lyubomirsky <slyubomirsky@gmail.com>
  • Loading branch information
2 people authored and tqchen committed Apr 1, 2023
1 parent 79fe0a2 commit 779c54d
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 6 deletions.
8 changes: 4 additions & 4 deletions include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#define TVM_RELAX_UTILS_H_

#include <tvm/ir/module.h>
#include <tvm/relax/expr.h>
#include <tvm/runtime/logging.h>

#include <algorithm>
Expand Down Expand Up @@ -110,9 +109,10 @@ class NameTable {
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

/*!
* \brief Check if the given type is a boolean scalar type (tensor of rank 0 with a boolean dtype).
* \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean
* dtype).
*
* \param ty The input type.
* \param sinfo The input StructInfo.
* \param permit_unknown_rank If true, it will permit the input type to have unknown rank
* (ndim of -1), which will require a dynamic check.
* \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype
Expand All @@ -121,7 +121,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
* \return True iff the input type is a boolean scalar type (or, depending on options, has unknown
* rank or dtype)
*/
TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true,
TVM_DLL bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank = true,
bool permit_unknown_dtype = true);

/*!
Expand Down
10 changes: 10 additions & 0 deletions src/relax/backend/vm/vm_builtin_lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class VMBuiltinLowerMutator : public ExprMutator {
return CallTIRDyn(call);
} else if (call->op == reshape_op_) {
return Reshape(call);
} else if (call->op == shape_of_op_) {
return ShapeOf(call);
} else if (call->op == make_closure_op_) {
return MakeClosure(call);
} else if (call->op == invoke_closure_op_) {
Expand Down Expand Up @@ -132,6 +134,12 @@ class VMBuiltinLowerMutator : public ExprMutator {
return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}

Expr ShapeOf(const Call& call_node) {
ICHECK(call_node->args.size() == 1);
ICHECK(call_node->struct_info_.defined());
return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}

Expr MakeClosure(const Call& call_node) {
ICHECK(call_node->args.size() == 2);
ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
Expand Down Expand Up @@ -173,6 +181,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
// object to pattern match.
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
const Op& reshape_op_ = Op::Get("relax.reshape");
const Op& shape_of_op_ = Op::Get("relax.shape_of");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor");
Expand All @@ -187,6 +196,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"};
const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
};
Expand Down
5 changes: 3 additions & 2 deletions src/relax/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
}
}

bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) {
const DynTensorTypeNode* tt = ty.as<DynTensorTypeNode>();
bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
bool permit_unknown_dtype) {
const TensorStructInfoNode* tt = sinfo.as<TensorStructInfoNode>();
if (!tt) {
return false;
}
Expand Down
62 changes: 62 additions & 0 deletions tests/python/relax/test_relax_operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import sys
import tempfile

import numpy as np
import tvm
import tvm.testing
from tvm import relax
from tvm._ffi.base import TVMError
from tvm.script import relax as R


def run_cpu(mod, func_name, *input):
target = tvm.target.Target("llvm")
ex = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
return vm[func_name](*input)


@tvm.script.ir_module
class ShapeOfTest:
@R.function
def get_shape(t: R.Tensor(ndim=-1, dtype="int32")) -> R.Shape(ndim=-1):
return R.shape_of(t)

@R.function
def get_shape_const() -> R.Shape(ndim=-1):
x: R.Tensor((), "int32") = R.const(1, dtype="int32")
return R.shape_of(x)


def test_op_shape_of():
const_shape = run_cpu(ShapeOfTest, "get_shape_const")
assert const_shape == tvm.runtime.ShapeTuple([])

scalar_shape = run_cpu(ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, dtype="int32")))
assert scalar_shape == tvm.runtime.ShapeTuple([])

tensor_shape = run_cpu(
ShapeOfTest, "get_shape", tvm.nd.array(np.zeros((1, 2, 3)).astype("int32"))
)
assert tensor_shape == tvm.runtime.ShapeTuple([1, 2, 3])


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 779c54d

Please sign in to comment.