[Unity] Lower shape_of to a builtin #14093

Merged 1 commit on Feb 22, 2023
8 changes: 4 additions & 4 deletions include/tvm/relax/utils.h
@@ -25,7 +25,6 @@
#define TVM_RELAX_UTILS_H_

#include <tvm/ir/module.h>
#include <tvm/relax/expr.h>
#include <tvm/runtime/logging.h>

#include <algorithm>
@@ -110,9 +109,10 @@ class NameTable {
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

/*!
- * \brief Check if the given type is a boolean scalar type (tensor of rank 0 with a boolean dtype).
+ * \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean
+ * dtype).
*
- * \param ty The input type.
+ * \param sinfo The input StructInfo.
* \param permit_unknown_rank If true, it will permit the input type to have unknown rank
* (ndim of -1), which will require a dynamic check.
* \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype
@@ -121,7 +121,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
* \return True iff the input type is a boolean scalar type (or, depending on options, has unknown
* rank or dtype)
*/
- TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true,
+ TVM_DLL bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank = true,
bool permit_unknown_dtype = true);

/*!
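The rename tracks Relax's move from `Type`-based to `StructInfo`-based annotations: the check now pattern-matches a `TensorStructInfo` rather than a `DynTensorType`. Below is a minimal Python sketch of the same rank/dtype check; `is_bool_scalar_sinfo` is a hypothetical illustration, not a TVM API (the real implementation is the C++ `IsBoolStructInfo` in src/relax/utils.cc, and the unknown-dtype branch is omitted here for brevity):

```python
from tvm import relax

# Hypothetical mirror of the C++ IsBoolStructInfo check (illustration only).
def is_bool_scalar_sinfo(sinfo, permit_unknown_rank=True):
    if not isinstance(sinfo, relax.TensorStructInfo):
        return False  # only tensors can be boolean scalars
    rank_ok = sinfo.ndim == 0 or (permit_unknown_rank and sinfo.ndim == -1)
    return rank_ok and str(sinfo.dtype) == "bool"

assert is_bool_scalar_sinfo(relax.TensorStructInfo((), "bool"))        # rank-0 bool tensor
assert not is_bool_scalar_sinfo(relax.TensorStructInfo((3,), "bool"))  # rank 1: not a scalar
```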
10 changes: 10 additions & 0 deletions src/relax/backend/vm/vm_builtin_lower.cc
@@ -53,6 +53,8 @@ class VMBuiltinLowerMutator : public ExprMutator {
return CallTIRDyn(call);
} else if (call->op == reshape_op_) {
return Reshape(call);
+ } else if (call->op == shape_of_op_) {
+ return ShapeOf(call);
} else if (call->op == make_closure_op_) {
return MakeClosure(call);
} else if (call->op == invoke_closure_op_) {
@@ -132,6 +134,12 @@
return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}

+ Expr ShapeOf(const Call& call_node) {
+ ICHECK(call_node->args.size() == 1);
+ ICHECK(call_node->struct_info_.defined());
+ return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)});
+ }

Expr MakeClosure(const Call& call_node) {
ICHECK(call_node->args.size() == 2);
ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
@@ -173,6 +181,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
// object to pattern match.
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
const Op& reshape_op_ = Op::Get("relax.reshape");
+ const Op& shape_of_op_ = Op::Get("relax.shape_of");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor");
@@ -187,6 +196,7 @@
const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"};
const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
+ const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
};
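With these additions, the `VMBuiltinLower` pass rewrites `relax.shape_of` calls into calls to the `vm.builtin.shape_of` extern, which the VM resolves to a packed function at runtime. A hedged sketch of exercising the pass on its own, assuming the Unity-era `relax.transform.VMBuiltinLower` entry point:

```python
import tvm
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class Before:
    @R.function
    def main(t: R.Tensor(ndim=2, dtype="float32")) -> R.Shape(ndim=2):
        return R.shape_of(t)

# Assumption: the pass is exposed as relax.transform.VMBuiltinLower.
After = relax.transform.VMBuiltinLower()(Before)
print(After)  # main should now call the "vm.builtin.shape_of" extern
```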
5 changes: 3 additions & 2 deletions src/relax/utils.cc
@@ -67,8 +67,9 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
}
}

- bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) {
- const DynTensorTypeNode* tt = ty.as<DynTensorTypeNode>();
+ bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
+ bool permit_unknown_dtype) {
+ const TensorStructInfoNode* tt = sinfo.as<TensorStructInfoNode>();
if (!tt) {
return false;
}
62 changes: 62 additions & 0 deletions tests/python/relax/test_relax_operators.py
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import sys
import tempfile

import numpy as np
import tvm
import tvm.testing
from tvm import relax
from tvm._ffi.base import TVMError
from tvm.script import relax as R


def run_cpu(mod, func_name, *input):
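    # Build for LLVM and run the named function on the Relax VM (CPU).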
target = tvm.target.Target("llvm")
ex = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
return vm[func_name](*input)


@tvm.script.ir_module
class ShapeOfTest:
@R.function
def get_shape(t: R.Tensor(ndim=-1, dtype="int32")) -> R.Shape(ndim=-1):
return R.shape_of(t)

@R.function
def get_shape_const() -> R.Shape(ndim=-1):
x: R.Tensor((), "int32") = R.const(1, dtype="int32")
return R.shape_of(x)


def test_op_shape_of():
const_shape = run_cpu(ShapeOfTest, "get_shape_const")
assert const_shape == tvm.runtime.ShapeTuple([])

scalar_shape = run_cpu(ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, dtype="int32")))
assert scalar_shape == tvm.runtime.ShapeTuple([])

tensor_shape = run_cpu(
ShapeOfTest, "get_shape", tvm.nd.array(np.zeros((1, 2, 3)).astype("int32"))
)
assert tensor_shape == tvm.runtime.ShapeTuple([1, 2, 3])


if __name__ == "__main__":
tvm.testing.main()
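At runtime, `vm.builtin.shape_of` simply reads the input `NDArray`'s shape and returns it as a `ShapeTuple`. Assuming the builtin is registered as a global packed function (as VM builtins are), it can also be exercised directly, outside any compiled module:

```python
import numpy as np
import tvm

# Assumption: "vm.builtin.shape_of" is registered as a global packed function.
shape_of = tvm.get_global_func("vm.builtin.shape_of")
arr = tvm.nd.array(np.zeros((1, 2, 3), dtype="int32"))
print(shape_of(arr))  # expected: a ShapeTuple equal to [1, 2, 3]
```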