forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* ir builder in python * `ForFrame`s in python * python methods * rename and add @staticmethod for current builder * POC demo in python * apply code review suggestions * apply code review suggestions * apply code review suggestions * apply code review suggestions
- Loading branch information
Showing
21 changed files
with
505 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,4 +19,5 @@ | |
from . import tir | ||
from . import relax | ||
|
||
from .builder import Builder | ||
from .parser import ir_module, from_source |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=unused-import | ||
"""Namespace for the TVMScript Builder API.""" | ||
|
||
|
||
from .builder import Builder, def_, def_many | ||
from .frame import Frame, IRModuleFrame |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""FFI APIs for tvm.script.builder""" | ||
import tvm._ffi | ||
|
||
tvm._ffi._init_api("script.builder", __name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""TVM Script IR Builder""" | ||
from typing import List | ||
from tvm._ffi import register_object as _register_object | ||
from .frame import Frame | ||
|
||
from tvm.runtime import Object | ||
|
||
from . import _ffi_api | ||
|
||
from typing import TypeVar | ||
|
||
|
||
@_register_object("script.builder.Builder") | ||
class Builder(Object): | ||
def __init__(self) -> None: | ||
self.__init_handle_by_constructor__(_ffi_api.Builder) | ||
|
||
def __enter__(self) -> "Builder": | ||
_ffi_api.BuilderEnter(self) | ||
return self | ||
|
||
def __exit__(self, ptype, value, trace) -> None: | ||
_ffi_api.BuilderExit(self) | ||
|
||
@staticmethod | ||
def current(self) -> "Builder": | ||
return _ffi_api.BuilderCurrent(self) | ||
|
||
def get(self) -> Frame: | ||
return _ffi_api.BuilderGet(self) | ||
|
||
|
||
DefType = TypeVar("DefType", bound=Object) | ||
|
||
|
||
def def_(name: str, var: DefType) -> DefType: | ||
return _ffi_api.Def(name, var) | ||
|
||
|
||
def def_many(names: List[str], vars: List[DefType]) -> List[DefType]: | ||
assert len(names) == len(vars) | ||
return [def_(name, var) for name, var in zip(names, vars)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""TVM Script Frames""" | ||
from tvm._ffi import register_object as _register_object | ||
|
||
from tvm.runtime import Object | ||
|
||
from . import _ffi_api | ||
|
||
|
||
@_register_object("script.builder.Frame") | ||
class Frame(Object): | ||
def __enter__(self) -> "Frame": | ||
_ffi_api.FrameEnter(self) | ||
return self | ||
|
||
def __exit__(self, ptype, value, trace) -> None: | ||
_ffi_api.FrameExit(self) | ||
|
||
|
||
@_register_object("script.builder.IRModuleFrame") | ||
class IRModuleFrame(Frame): | ||
def __init__(self) -> None: | ||
self.__init_handle_by_constructor__(_ffi_api.IRModuleFrame) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=unused-import | ||
"""Namespace for the TVMScript TIR Builder API.""" | ||
|
||
from .base import TIRFrame | ||
from .for_frame import ( | ||
ForFrame, | ||
serial, | ||
parallel, | ||
vectorized, | ||
unroll, | ||
thread_binding, | ||
grid, | ||
) | ||
from .prim_func_frame import prim_func, arg | ||
from .block_frame import block | ||
from .var import Buffer | ||
from . import axis |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""FFI APIs for tvm.script.builder""" | ||
import tvm._ffi | ||
|
||
from .. import _ffi_api as _base_ffi_api | ||
|
||
tvm._ffi._init_api("script.builder.tir", __name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""TVM Script TIR Axis""" | ||
|
||
from . import _ffi_api | ||
from tvm.ir import Range | ||
from tvm.tir import IterVar | ||
|
||
|
||
def spatial(dom, binding, dtype="int32") -> IterVar: | ||
if not isinstance(dom, Range): | ||
dom = Range(0, dom) | ||
return _ffi_api.AxisSpatial(dom, binding, dtype) | ||
|
||
|
||
def reduce(dom, binding, dtype="int32") -> IterVar: | ||
if not isinstance(dom, Range): | ||
dom = Range(0, dom) | ||
return _ffi_api.AxisReduce(dom, binding, dtype) | ||
|
||
|
||
def remap(kinds, bindings, dtype="int32") -> IterVar: | ||
return _ffi_api.AxisRemap(kinds, bindings, dtype) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""TVM Script TIR Frame""" | ||
from tvm._ffi import register_object as _register_object | ||
|
||
from . import _ffi_api | ||
from ..frame import Frame | ||
|
||
|
||
@_register_object("script.builder.tir.TIRFrame") | ||
class TIRFrame(Frame): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""TVM Script TIR Block Frame""" | ||
from tvm._ffi import register_object as _register_object | ||
from .base import TIRFrame | ||
|
||
|
||
from . import _ffi_api | ||
|
||
|
||
@_register_object("script.builder.tir.BlockFrame") | ||
class BlockFrame(TIRFrame): | ||
pass | ||
|
||
|
||
def block(name) -> BlockFrame: | ||
return _ffi_api.BlockFrame(name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""TVM Script TIR For Frame""" | ||
from tvm._ffi import register_object as _register_object | ||
|
||
from tvm.tir import Var | ||
|
||
from . import _ffi_api | ||
from ._ffi_api import _base_ffi_api | ||
from .base import TIRFrame | ||
from typing import List | ||
|
||
|
||
@_register_object("script.builder.tir.ForFrame") | ||
class ForFrame(TIRFrame): | ||
def __enter__(self) -> List[Var]: | ||
_base_ffi_api.FrameEnter(self) | ||
return self.vars | ||
|
||
|
||
def serial(min_val, extent, attrs) -> ForFrame: | ||
return _ffi_api.Serial(min_val, extent, attrs) | ||
|
||
|
||
def parallel(min_val, extent, attrs) -> ForFrame: | ||
return _ffi_api.Parallel(min_val, extent, attrs) | ||
|
||
|
||
def vectorized(min_val, extent, attrs) -> ForFrame: | ||
return _ffi_api.Vectorized(min_val, extent, attrs) | ||
|
||
|
||
def unroll(min_val, extent, attrs) -> ForFrame: | ||
return _ffi_api.Unroll(min_val, extent, attrs) | ||
|
||
|
||
def thread_binding(min_val, extent, attrs) -> ForFrame: | ||
return _ffi_api.ThreadBinding(min_val, extent, attrs) | ||
|
||
|
||
def grid(*extents) -> ForFrame: | ||
return _ffi_api.Grid(extents) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""TVM Script TIR Prim Func Frame""" | ||
from tvm._ffi import register_object as _register_object | ||
|
||
from tvm.tir.expr import Var | ||
from tvm.tir.buffer import Buffer | ||
|
||
|
||
from . import _ffi_api | ||
from .base import TIRFrame | ||
|
||
from typing import Union | ||
|
||
|
||
@_register_object("script.builder.tir.PrimFuncFrame") | ||
class PrimFuncFrame(TIRFrame): | ||
pass | ||
|
||
|
||
def prim_func(name) -> PrimFuncFrame: | ||
return _ffi_api.PrimFuncFrame(name) | ||
|
||
|
||
def arg(name, arg) -> Union[Var, Buffer]: | ||
return _ffi_api.Arg(name, arg) |
Oops, something went wrong.