Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[Hexagon]: Add upload function to hexagon session (apache#13161)
Browse files Browse the repository at this point in the history
* [Hexagon]: Add upload to hexagon session

* lint

* fix typo

* fix serial device skips

* fix test on device

* add serial device to rpc key

* update error name

* fix comment

* move create_session to launcher

* add is_simulator

* lint

* address Eric comment on Session object

* rebase with main
  • Loading branch information
mehrdadh authored and liuxinwei committed Nov 10, 2022
1 parent 5d4cd57 commit f077faa
Show file tree
Hide file tree
Showing 17 changed files with 196 additions and 246 deletions.
195 changes: 46 additions & 149 deletions python/tvm/contrib/hexagon/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
import tempfile
from typing import Union

import tvm
from tvm.contrib.hexagon.hexagon_profiler import HexagonProfiler
from ..._ffi import libinfo
from .session import Session

from .tools import HEXAGON_SIMULATOR_NAME

HEXAGON_RPC_LIB_DIR = os.environ.get("HEXAGON_RPC_LIB_DIR")
ANDROID_BASH_FILE_NAME = "android_bash.sh"
HEXAGON_REMOTE_DEVICE_KEY = "hexagon-dev"


def _check_call_verbose(cmd, **kwargs) -> None:
Expand Down Expand Up @@ -103,14 +103,9 @@ class HexagonLauncherRPC(metaclass=abc.ABCMeta):
The basic flow of interaction with the launcher is
launcher = HexagonLauncher(...)
launcher.start_server()
with launcher.start_session() as session:
with launcher.create_session() as session:
# Do something with the session
launcher.stop_server()
"""

HEXAGON_REMOTE_DEVICE_KEY = "hexagon-dev"

"""Configure HexagonLauncherRPC.
Parameters
----------
Expand All @@ -129,7 +124,9 @@ class HexagonLauncherRPC(metaclass=abc.ABCMeta):
used.
"""

def __init__(self, rpc_info: dict, workspace: Union[str, pathlib.Path] = None):
def __init__(
self, rpc_info: dict, workspace: Union[str, pathlib.Path] = None, serial_number: str = None
):
self._rpc_info = {
"rpc_tracker_host": "0.0.0.0",
"rpc_tracker_port": 9190,
Expand All @@ -138,7 +135,7 @@ def __init__(self, rpc_info: dict, workspace: Union[str, pathlib.Path] = None):
}
self._rpc_info.update(rpc_info)
self._workspace = self._create_workspace(workspace)
self._device_key = self.HEXAGON_REMOTE_DEVICE_KEY
self._serial_number = serial_number

@abc.abstractmethod
def start_server(self):
Expand Down Expand Up @@ -205,138 +202,6 @@ def _create_workspace(self, workspace: Union[str, pathlib.Path]) -> pathlib.Path
workspace = os.path.join(base_dir, _get_test_directory_name())
return self._create_remote_directory(workspace)

def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str) -> pathlib.Path:
"""Upload a local file to the remote workspace.
Parameters
----------
local_path : str or pathlib.Path
Path to the local file to be copied.
remote_filename : str
Name of the file in the remote workspace.
Returns
-------
pathlib.Path :
Uploaded file remote path.
"""
assert self._workspace
remote_file_path = self._workspace / remote_filename
self._copy_to_remote(local_path, str(remote_file_path))
return remote_file_path

def start_session(self, session_name: str = "hexagon-rpc") -> Session:
"""Connect to the RPC server.
Parameters
----------
session_name : str
RPC session name.
Returns
-------
Session :
The session object.
"""
hexagon_remote_kw = {
"host": self._rpc_info["rpc_tracker_host"],
"port": self._rpc_info["rpc_tracker_port"],
"priority": 0,
"timeout": 0,
"key": self._device_key,
}
return Session(self, hexagon_remote_kw, session_name=session_name)

def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module], session: Session):
"""Load TVM module.
Parameters
----------
module : Union[str, pathlib.Path, tvm.runtime.Module]
The module to load. If `module` is a
`tvm.runtime.Module`, it will be uploaded to the remote
session and loaded.
If the object passed is a string or pathlib.Path, it must
be a full path in the remote system.
session : Session
Remote session. The session must be established (via __enter__)
prior to calling this function.
Returns
-------
TVMModule :
TVM module object.
"""
return session.load_module(module)

def get_graph_executor(
self,
graph_json: str,
module: Union[str, pathlib.Path, tvm.runtime.Module],
session: Session,
):
"""Create a local GraphModule which consumes a remote libmod.
Parameters
----------
graph_json : str
The string with the graph JSON.
module : Union[str, pathlib.Path, tvm.runtime.Module]
The module to load. If `module` is a
`tvm.runtime.Module`, it will be uploaded to the remote
session and loaded.
If the object passed is a string or pathlib.Path, it must
be a full path in the remote system.
session : Session
Remote session. The session must be established (via __enter__)
prior to calling this function.
Returns
-------
GraphModule :
Runtime graph module that can be used to execute the graph.
"""
return session.get_graph_executor(graph_json, module)

def get_graph_debug_executor(
self,
graph_json: str,
module: Union[str, pathlib.Path, tvm.runtime.Module],
session: Session,
dump_root: Union[str, pathlib.Path] = None,
):
"""Create a local GraphModuleDebug which consumes a remote libmod.
Parameters
----------
graph_json : str
The string with the graph JSON.
module : Union[str, pathlib.Path, tvm.runtime.Module]
The module to load. If `module` is a
`tvm.runtime.Module`, it will be uploaded to the remote
session and loaded.
If the object passed is a string or pathlib.Path, it must
be a full path in the remote system.
session : Session
Remote session. The session must be established (via __enter__)
prior to calling this function.
Returns
-------
GraphModuleDebug :
Runtime debug graph module that can be used to debug the graph.
"""
return session.get_graph_debug_executor(graph_json, module, dump_root=dump_root)

@abc.abstractmethod
def get_profile_output(
self,
Expand All @@ -360,6 +225,31 @@ def get_profile_output(
"""
...

def create_session(self, session_name: str = "hexagon-rpc") -> Session:
"""Create an RPC session.
Parameters
----------
session_name : str
RPC session name.
Returns
-------
Session :
The session object.
"""
hexagon_session_kw = {
"remote_workspace": self._workspace,
"rpc_tracker": (self._rpc_info["rpc_tracker_host"], self._rpc_info["rpc_tracker_port"]),
"rpc_server_key": self._rpc_info["device_key"],
"serial_number": self._serial_number,
"session_name": session_name,
}
return Session(**hexagon_session_kw)

def is_simulator(self):
return self._serial_number == HEXAGON_SIMULATOR_NAME


class HexagonLauncherAndroid(HexagonLauncherRPC):
"""Hexagon Launcher for Android."""
Expand Down Expand Up @@ -402,15 +292,18 @@ def __init__(
if not rpc_info.get("workspace_base"):
rpc_info["workspace_base"] = self.ANDROID_HEXAGON_TEST_BASE_DIR
self._serial_number = serial_number
assert self._serial_number != "", "Android serial number is not set."

adb_socket = rpc_info["adb_server_socket"] if rpc_info["adb_server_socket"] else "tcp:5037"
self._adb_device_sub_cmd = ["adb", "-L", adb_socket, "-s", self._serial_number]
self.forwarded_ports_ = []
self._hexagon_debug = hexagon_debug
self._clear_logcat = clear_logcat
self._sysmon_profile = sysmon_profile
self._sysmon_process = None
rpc_info["device_key"] = HEXAGON_REMOTE_DEVICE_KEY + "." + self._serial_number

super(HexagonLauncherAndroid, self).__init__(rpc_info, workspace)
super(HexagonLauncherAndroid, self).__init__(rpc_info, workspace, self._serial_number)

def _copy_to_remote(
self, local_path: Union[str, pathlib.Path], remote_path: Union[str, pathlib.Path]
Expand Down Expand Up @@ -442,7 +335,9 @@ def _copy_binaries(self):
"<RPC_TRACKER_PORT>", str(self._rpc_info["rpc_tracker_port"])
)
if "<HEXAGON_REMOTE_DEVICE_KEY>" in line:
line = line.replace("<HEXAGON_REMOTE_DEVICE_KEY>", self._device_key)
line = line.replace(
"<HEXAGON_REMOTE_DEVICE_KEY>", self._rpc_info["device_key"]
)
if "<RPC_SERVER_PORT>" in line:
line = line.replace(
"<RPC_SERVER_PORT>", str(self._rpc_info["rpc_server_port"])
Expand Down Expand Up @@ -691,12 +586,13 @@ def __init__(self, rpc_info: dict, workspace: Union[str, pathlib.Path] = None):
Parameters are same as for HexagonLauncherRPC.
"""
super(HexagonLauncherSimulator, self).__init__(rpc_info, workspace)

self._toolchain = os.environ.get("HEXAGON_TOOLCHAIN")
if not self._toolchain:
raise RuntimeError("Please set HEXAGON_TOOLCHAIN env variable")
self._serial_number = "simulator"
self._serial_number = HEXAGON_SIMULATOR_NAME

super(HexagonLauncherSimulator, self).__init__(rpc_info, workspace, self._serial_number)

def _copy_to_remote(
self, local_path: Union[str, pathlib.Path], remote_path: Union[str, pathlib.Path]
Expand Down Expand Up @@ -740,18 +636,19 @@ def start_server(self):
self._copy_to_remote(lib_dir / item, self._workspace / item)
# Copy libc++ from the toolchain to the workspace
self._copy_libcxx(self._workspace)
self._device_key = self.HEXAGON_REMOTE_DEVICE_KEY + "." + str(os.getpid())
self._rpc_info["device_key"] = HEXAGON_REMOTE_DEVICE_KEY + "." + str(os.getpid())

rpc_tracker_host = self._rpc_info["rpc_tracker_host"]
rpc_tracker_port = self._rpc_info["rpc_tracker_port"]
rpc_server_port = self._rpc_info["rpc_server_port"]
device_key = self._rpc_info["device_key"]
server_exe = os.path.join(".", "tvm_rpc_x86")

args = [
"server",
f"--tracker={rpc_tracker_host}:{rpc_tracker_port}",
f"--port={rpc_server_port}",
f"--key={self._device_key}",
f"--key={device_key}",
"--timeout=0",
]

Expand Down Expand Up @@ -823,7 +720,7 @@ def HexagonLauncher(
sysmon_profile: bool = False,
):
"""Creates a HexagonLauncher"""
if serial_number == "simulator":
if serial_number == HEXAGON_SIMULATOR_NAME:
return HexagonLauncherSimulator(rpc_info, workspace)
return HexagonLauncherAndroid(
serial_number, rpc_info, workspace, hexagon_debug, clear_logcat, sysmon_profile
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/hexagon/meta_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:


def _worker_func(hexagon_launcher, evaluator_config, alloc_repeat, artifact_path, args_info):
with hexagon_launcher.start_session() as session:
with hexagon_launcher.create_session() as session:
device = session.device
_, remote_path = os.path.split(artifact_path)
uploaded = session.upload(artifact_path, remote_path)
Expand Down
15 changes: 8 additions & 7 deletions python/tvm/contrib/hexagon/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import tvm.rpc.tracker
from tvm.contrib.hexagon.build import HexagonLauncher, HexagonLauncherRPC
from tvm.contrib.hexagon.session import Session
from tvm.contrib.hexagon.tools import HEXAGON_SIMULATOR_NAME

HEXAGON_TOOLCHAIN = "HEXAGON_TOOLCHAIN"
TVM_TRACKER_HOST = "TVM_TRACKER_HOST"
Expand Down Expand Up @@ -173,7 +174,7 @@ def hexagon_server_process(

if android_serial_num is None:
pytest.skip("ANDROID_SERIAL_NUMBER is not set.")
if android_serial_num == ["simulator"]:
if android_serial_num == [HEXAGON_SIMULATOR_NAME]:
yield None
else:
# Requesting these fixtures sets up a local tracker, if one
Expand Down Expand Up @@ -220,7 +221,7 @@ def pytest_configure(config):


def pytest_configure_node(node):
# the master for each node fills slaveinput dictionary
# the master for each node fills node input dictionary
# which pytest-xdist will transfer to the subprocess
if node.config.iplist is not None:
node.workerinput["device_adr"] = node.config.iplist.pop()
Expand All @@ -240,7 +241,7 @@ def hexagon_launcher(
"""Initials and returns hexagon launcher which reuses RPC info and Android serial number."""
android_serial_num = android_serial_number()

if android_serial_num != ["simulator"]:
if android_serial_num != [HEXAGON_SIMULATOR_NAME]:
rpc_info = hexagon_server_process["launcher"]._rpc_info
else:
rpc_info = {
Expand All @@ -250,7 +251,7 @@ def hexagon_launcher(
"adb_server_socket": adb_server_socket,
}
try:
if android_serial_num == ["simulator"]:
if android_serial_num == [HEXAGON_SIMULATOR_NAME]:
launcher = HexagonLauncher(serial_number=android_serial_num[0], rpc_info=rpc_info)
launcher.start_server()
else:
Expand All @@ -263,7 +264,7 @@ def hexagon_launcher(
)
yield launcher
finally:
if android_serial_num == ["simulator"]:
if android_serial_num == [HEXAGON_SIMULATOR_NAME]:
launcher.stop_server()
elif not hexagon_debug:
launcher.cleanup_directory()
Expand All @@ -274,7 +275,7 @@ def hexagon_session(hexagon_launcher: HexagonLauncherRPC) -> Session:
if hexagon_launcher is None:
yield None
else:
with hexagon_launcher.start_session() as session:
with hexagon_launcher.create_session() as session:
yield session


Expand All @@ -289,7 +290,7 @@ def terminate_rpc_servers():
# yield happens every time.
serial = os.environ.get(ANDROID_SERIAL_NUMBER)
yield []
if serial == ["simulator"]:
if serial == [HEXAGON_SIMULATOR_NAME]:
os.system("ps ax | grep tvm_rpc_x86 | awk '{print $1}' | xargs kill")


Expand Down
Loading

0 comments on commit f077faa

Please sign in to comment.