Skip to content

Commit

Permalink
[REFACTOR][PY] establish tvm.ir, migrate base, expr, type
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Feb 11, 2020
1 parent 902e21b commit b4cf4d3
Show file tree
Hide file tree
Showing 30 changed files with 639 additions and 628 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class RelayExpr : public BaseExpr {

class GlobalVar;
/*!
* \brief Global variable that leaves in the top-level module.
* \brief Global variable that lives in the top-level module.
*
* A GlobalVar only refers to function definitions.
* This is used to enable recursive calls between function.
Expand Down
7 changes: 4 additions & 3 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,12 @@ enum TypeKind : int {
};

/*!
* \brief Type parameter in the function.
* This can be viewed as template parameter in c++ template function.
* \brief Type parameter in functions.
*
* A type variable can be viewed as template parameter in c++ template function.
*
* For example, in the following pesudo code,
* the TypeVar of f is TypeVar(kind=kShapeVar, var=n).
* the TypeVar of f is TypeVar("n", kind=kShapeVar).
* This function can take in a Tensor with shape=(3, 3) and
* returns a Tensor with shape=(9,)
*
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/type_relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ using TypeRelationFn =
const TypeReporter& reporter)>;

/*!
* \brief User defined type relation, is an input-output relation on types.
* \brief User defined type relation, it is an input-output relation on types.
*
* TypeRelation is more generalized than type call as it allows inference
* of both inputs and outputs.
Expand Down
24 changes: 24 additions & 0 deletions python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc
from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .type_relation import TypeCall, TypeRelation
from .tensor_type import TensorType
22 changes: 22 additions & 0 deletions python/tvm/ir/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.ir"""
import tvm._ffi

# Exports functions registered via TVM_REGISTER_GLOBAL with the "runtime" prefix.
# e.g. TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile")
tvm._ffi._init_api("ir", __name__)
86 changes: 86 additions & 0 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Span data structures."""
import tvm._ffi

from tvm.runtime import Object
from . import _ffi_api


class Node(Object):
"""Base class of all IR Nodes, implements astext function."""
def astext(self, show_meta_data=True, annotate=None):
"""Get the text format of the expression.
Parameters
----------
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[Object->str]
Optional annotate function to provide additional
information in the comment block.
Note
----
The meta data section is necessary to fully parse the text format.
However, it can contain dumps that are big (e.g constant weights),
so it can be helpful to skip printing the meta data section.
Returns
-------
text : str
The text format of the expression.
"""
return _ffi_api.AsText(self, show_meta_data, annotate)

def __str__(self):
return self.astext(show_meta_data=False)


@tvm._ffi.register_object("relay.SourceName")
class SourceName(Object):
"""A identifier for a source location.
Parameters
----------
name : str
The name of the source.
"""
def __init__(self, name):
self.__init_handle_by_constructor__(_ffi_api.SourceName, name)


@tvm._ffi.register_object("relay.Span")
class Span(Object):
"""Specifies a location in a source program.
Parameters
----------
source : SourceName
The source name.
lineno : int
The line number.
col_offset : int
The column offset of the location.
"""
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(
_ffi_api.Span, source, lineno, col_offset)
91 changes: 91 additions & 0 deletions python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Common expressions data structures in the IR."""
import tvm._ffi

from .base import Node
from . import _ffi_api

class BaseExpr(Node):
"""Base class of all the expressions."""


class PrimExpr(BaseExpr):
"""Base class of all primitive expressions.
PrimExpr is used in the low-level code
optimizations and integer analysis.
"""


class RelayExpr(BaseExpr):
"""Base class of all non-primitive expressions."""
@property
def checked_type(self):
"""Get the checked type of tvm.relay.Expr.
Returns
-------
checked_type : tvm.relay.Type
The checked type.
"""
ret = self._checked_type_
if ret is None:
raise ValueError("The type checker has not populated"
" the checked_type for this node")
return ret


class BaseFunc(RelayExpr):
"""Base class of all functions."""


@tvm._ffi.register_object("relay.GlobalVar")
class GlobalVar(RelayExpr):
"""A global variable in the IR.
GlobalVar is used to refer to the global functions
stored in the IRModule.
Parameters
----------
name_hint: str
The name of the variable.
"""
def __init__(self, name_hint):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint)

def __call__(self, *args):
"""Call the global variable.
Parameters
----------
args: List[RelayExpr]
The arguments to the call.
Returns
-------
call: Call
A call taking the variable as a function.
"""
# pylint: disable=import-outside-toplevel
if all(isinstance(x, RelayExpr) for x in args):
from tvm import relay
return relay.Call(self, args)
arg_types = [type(x) for x in args]
raise RuntimeError(
"Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types))
56 changes: 56 additions & 0 deletions python/tvm/ir/tensor_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Type relation and function for type checking."""
import tvm._ffi

from .type import Type
from . import _ffi_api


@tvm._ffi.register_object("relay.TensorType")
class TensorType(Type):
"""A concrete TensorType in Relay.
This is the type assigned to tensors with a known dtype and shape.
For example, a tensor of `float32` and `(5, 5)`.
Parameters
----------
shape : List[tvm.ir.PrimExpr]
The shape of the Tensor
dtype : Optional[str]
The content data type.
"""
def __init__(self, shape, dtype="float32"):
self.__init_handle_by_constructor__(
_ffi_api.TensorType, shape, dtype)

@property
def concrete_shape(self):
"""Get shape of the type as concrete tuple of int.
Returns
-------
shape : List[int]
The concrete shape of the Type.
Raises
------
TypeError : If the shape is symbolic
"""
return tuple(int(x) for x in self.shape)
Loading

0 comments on commit b4cf4d3

Please sign in to comment.