Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VTA] Fix vta rpc server, refactor launch cond to not depend on sys.argv #8671

Merged
merged 1 commit into from
Aug 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions vta/python/vta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
configure the hardware environment and access remote device through RPC.
"""
import sys
import tvm._ffi.base

from .autotvm import module_loader
from .bitstream import get_bitstream_path, download_bitstream
Expand All @@ -29,8 +30,9 @@

__version__ = "0.1.0"


# do not from tvm import topi when running vta.exec.rpc_server
# to maintain minimum dependency on the board
if sys.argv[0] not in ("-c", "-m"):
# in lib tvm runtime only mode
if not tvm._ffi.base._RUNTIME_ONLY:
from . import top
from .build_module import build_config, lower, build
65 changes: 64 additions & 1 deletion vta/python/vta/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
# pylint: disable=unused-argument, invalid-name
"""VTA specific buildin for runtime."""
import tvm
from tvm.ir import register_intrin_lowering
from . import transform
from .environment import get_env
from .environment import get_env, Environment


def EarlyRewrite():
Expand Down Expand Up @@ -134,3 +135,65 @@ def build(*args, **kwargs):

tvm.ir.register_op_attr("tir.vta.command_handle", "TGlobalSymbol", "VTATLSCommandHandle")
tvm.ir.register_op_attr("tir.vta.command_handle", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)

# The memory information for the compiler
@tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope)
def mem_info_inp_buffer():
spec = get_env()
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=spec.INP_ELEM_BITS,
max_simd_bits=spec.INP_ELEM_BITS,
max_num_bits=spec.INP_BUFF_SIZE * 8,
head_address=None,
)


@tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope)
def mem_info_wgt_buffer():
spec = get_env()
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=spec.WGT_ELEM_BITS,
max_simd_bits=spec.WGT_ELEM_BITS,
max_num_bits=spec.WGT_BUFF_SIZE * 8,
head_address=None,
)


@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
def mem_info_acc_buffer():
spec = get_env()
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=spec.ACC_ELEM_BITS,
max_simd_bits=spec.ACC_ELEM_BITS,
max_num_bits=spec.ACC_BUFF_SIZE * 8,
head_address=None,
)


# TVM Op related registration
@register_intrin_lowering("tir.vta.coproc_sync", "default")
def coproc_sync(op):
_ = op
return tvm.tir.call_extern(
"int32",
"VTASynchronize",
get_env().dev.command_handle,
tvm.runtime.const(1 << 31, dtype="uint32"),
)


@register_intrin_lowering("tir.vta.coproc_dep_push", "default")
def coproc_dep_push(op):
return tvm.tir.call_extern(
"int32", "VTADepPush", get_env().dev.command_handle, op.args[0], op.args[1]
)


@register_intrin_lowering("tir.vta.coproc_dep_pop", "default")
def coproc_dep_pop(op):
return tvm.tir.call_extern(
"int32", "VTADepPop", get_env().dev.command_handle, op.args[0], op.args[1]
)
64 changes: 0 additions & 64 deletions vta/python/vta/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import copy
import tvm
from tvm import te
from tvm.ir import register_intrin_lowering
from . import intrin


Expand Down Expand Up @@ -255,69 +254,6 @@ def get_env():
return Environment.current


# The memory information for the compiler
@tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope)
def mem_info_inp_buffer():
spec = get_env()
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=spec.INP_ELEM_BITS,
max_simd_bits=spec.INP_ELEM_BITS,
max_num_bits=spec.INP_BUFF_SIZE * 8,
head_address=None,
)


@tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope)
def mem_info_wgt_buffer():
spec = get_env()
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=spec.WGT_ELEM_BITS,
max_simd_bits=spec.WGT_ELEM_BITS,
max_num_bits=spec.WGT_BUFF_SIZE * 8,
head_address=None,
)


@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
def mem_info_acc_buffer():
spec = get_env()
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=spec.ACC_ELEM_BITS,
max_simd_bits=spec.ACC_ELEM_BITS,
max_num_bits=spec.ACC_BUFF_SIZE * 8,
head_address=None,
)


# TVM Op related registration
@register_intrin_lowering("tir.vta.coproc_sync", "default")
def coproc_sync(op):
_ = op
return tvm.tir.call_extern(
"int32",
"VTASynchronize",
get_env().dev.command_handle,
tvm.runtime.const(1 << 31, dtype="uint32"),
)


@register_intrin_lowering("tir.vta.coproc_dep_push", "default")
def coproc_dep_push(op):
return tvm.tir.call_extern(
"int32", "VTADepPush", get_env().dev.command_handle, op.args[0], op.args[1]
)


@register_intrin_lowering("tir.vta.coproc_dep_pop", "default")
def coproc_dep_pop(op):
return tvm.tir.call_extern(
"int32", "VTADepPop", get_env().dev.command_handle, op.args[0], op.args[1]
)


def _init_env():
"""Initialize the default global env"""
config_path = os.path.join(get_vta_hw_path(), "config/vta_config.json")
Expand Down