-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[REFACTOR][PY] establish tvm.ir, migrate base, expr, type
- Loading branch information
Showing
30 changed files
with
639 additions
and
628 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=unused-import | ||
"""Common data structures across all IR variants.""" | ||
from .base import SourceName, Span, Node | ||
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc | ||
from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType | ||
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType | ||
from .type_relation import TypeCall, TypeRelation | ||
from .tensor_type import TensorType |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""FFI APIs for tvm.ir""" | ||
import tvm._ffi | ||
|
||
# Exports functions registered via TVM_REGISTER_GLOBAL with the "runtime" prefix. | ||
# e.g. TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile") | ||
tvm._ffi._init_api("ir", __name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Span data structures.""" | ||
import tvm._ffi | ||
|
||
from tvm.runtime import Object | ||
from . import _ffi_api | ||
|
||
|
||
class Node(Object): | ||
"""Base class of all IR Nodes, implements astext function.""" | ||
def astext(self, show_meta_data=True, annotate=None): | ||
"""Get the text format of the expression. | ||
Parameters | ||
---------- | ||
show_meta_data : bool | ||
Whether to include meta data section in the text | ||
if there is meta data. | ||
annotate: Optional[Object->str] | ||
Optional annotate function to provide additional | ||
information in the comment block. | ||
Note | ||
---- | ||
The meta data section is necessary to fully parse the text format. | ||
However, it can contain dumps that are big (e.g constant weights), | ||
so it can be helpful to skip printing the meta data section. | ||
Returns | ||
------- | ||
text : str | ||
The text format of the expression. | ||
""" | ||
return _ffi_api.AsText(self, show_meta_data, annotate) | ||
|
||
def __str__(self): | ||
return self.astext(show_meta_data=False) | ||
|
||
|
||
@tvm._ffi.register_object("relay.SourceName") | ||
class SourceName(Object): | ||
"""A identifier for a source location. | ||
Parameters | ||
---------- | ||
name : str | ||
The name of the source. | ||
""" | ||
def __init__(self, name): | ||
self.__init_handle_by_constructor__(_ffi_api.SourceName, name) | ||
|
||
|
||
@tvm._ffi.register_object("relay.Span") | ||
class Span(Object): | ||
"""Specifies a location in a source program. | ||
Parameters | ||
---------- | ||
source : SourceName | ||
The source name. | ||
lineno : int | ||
The line number. | ||
col_offset : int | ||
The column offset of the location. | ||
""" | ||
def __init__(self, source, lineno, col_offset): | ||
self.__init_handle_by_constructor__( | ||
_ffi_api.Span, source, lineno, col_offset) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Common expressions data structures in the IR.""" | ||
import tvm._ffi | ||
|
||
from .base import Node | ||
from . import _ffi_api | ||
|
||
class BaseExpr(Node): | ||
"""Base class of all the expressions.""" | ||
|
||
|
||
class PrimExpr(BaseExpr): | ||
"""Base class of all primitive expressions. | ||
PrimExpr is used in the low-level code | ||
optimizations and integer analysis. | ||
""" | ||
|
||
|
||
class RelayExpr(BaseExpr): | ||
"""Base class of all non-primitive expressions.""" | ||
@property | ||
def checked_type(self): | ||
"""Get the checked type of tvm.relay.Expr. | ||
Returns | ||
------- | ||
checked_type : tvm.relay.Type | ||
The checked type. | ||
""" | ||
ret = self._checked_type_ | ||
if ret is None: | ||
raise ValueError("The type checker has not populated" | ||
" the checked_type for this node") | ||
return ret | ||
|
||
|
||
class BaseFunc(RelayExpr): | ||
"""Base class of all functions.""" | ||
|
||
|
||
@tvm._ffi.register_object("relay.GlobalVar") | ||
class GlobalVar(RelayExpr): | ||
"""A global variable in the IR. | ||
GlobalVar is used to refer to the global functions | ||
stored in the IRModule. | ||
Parameters | ||
---------- | ||
name_hint: str | ||
The name of the variable. | ||
""" | ||
def __init__(self, name_hint): | ||
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint) | ||
|
||
def __call__(self, *args): | ||
"""Call the global variable. | ||
Parameters | ||
---------- | ||
args: List[RelayExpr] | ||
The arguments to the call. | ||
Returns | ||
------- | ||
call: Call | ||
A call taking the variable as a function. | ||
""" | ||
# pylint: disable=import-outside-toplevel | ||
if all(isinstance(x, RelayExpr) for x in args): | ||
from tvm import relay | ||
return relay.Call(self, args) | ||
arg_types = [type(x) for x in args] | ||
raise RuntimeError( | ||
"Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Type relation and function for type checking.""" | ||
import tvm._ffi | ||
|
||
from .type import Type | ||
from . import _ffi_api | ||
|
||
|
||
@tvm._ffi.register_object("relay.TensorType") | ||
class TensorType(Type): | ||
"""A concrete TensorType in Relay. | ||
This is the type assigned to tensors with a known dtype and shape. | ||
For example, a tensor of `float32` and `(5, 5)`. | ||
Parameters | ||
---------- | ||
shape : List[tvm.ir.PrimExpr] | ||
The shape of the Tensor | ||
dtype : Optional[str] | ||
The content data type. | ||
""" | ||
def __init__(self, shape, dtype="float32"): | ||
self.__init_handle_by_constructor__( | ||
_ffi_api.TensorType, shape, dtype) | ||
|
||
@property | ||
def concrete_shape(self): | ||
"""Get shape of the type as concrete tuple of int. | ||
Returns | ||
------- | ||
shape : List[int] | ||
The concrete shape of the Type. | ||
Raises | ||
------ | ||
TypeError : If the shape is symbolic | ||
""" | ||
return tuple(int(x) for x in self.shape) |
Oops, something went wrong.