Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hexagon]: Add upload function to hexagon session #13161

Merged
merged 13 commits into from
Oct 26, 2022
195 changes: 46 additions & 149 deletions python/tvm/contrib/hexagon/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
import tempfile
from typing import Union

import tvm
from tvm.contrib.hexagon.hexagon_profiler import HexagonProfiler
from ..._ffi import libinfo
from .session import Session

from .tools import HEXAGON_SIMULATOR_NAME

HEXAGON_RPC_LIB_DIR = os.environ.get("HEXAGON_RPC_LIB_DIR")
ANDROID_BASH_FILE_NAME = "android_bash.sh"
HEXAGON_REMOTE_DEVICE_KEY = "hexagon-dev"


def _check_call_verbose(cmd, **kwargs) -> None:
Expand Down Expand Up @@ -103,14 +103,9 @@ class HexagonLauncherRPC(metaclass=abc.ABCMeta):
The basic flow of interaction with the launcher is
launcher = HexagonLauncher(...)
launcher.start_server()
with launcher.start_session() as session:
with launcher.create_session() as session:
# Do something with the session
launcher.stop_server()
"""

HEXAGON_REMOTE_DEVICE_KEY = "hexagon-dev"

"""Configure HexagonLauncherRPC.

Parameters
----------
Expand All @@ -129,7 +124,9 @@ class HexagonLauncherRPC(metaclass=abc.ABCMeta):
used.
"""

def __init__(self, rpc_info: dict, workspace: Union[str, pathlib.Path] = None):
def __init__(
self, rpc_info: dict, workspace: Union[str, pathlib.Path] = None, serial_number: str = None
):
self._rpc_info = {
"rpc_tracker_host": "0.0.0.0",
"rpc_tracker_port": 9190,
Expand All @@ -138,7 +135,7 @@ def __init__(self, rpc_info: dict, workspace: Union[str, pathlib.Path] = None):
}
self._rpc_info.update(rpc_info)
self._workspace = self._create_workspace(workspace)
self._device_key = self.HEXAGON_REMOTE_DEVICE_KEY
self._serial_number = serial_number

@abc.abstractmethod
def start_server(self):
Expand Down Expand Up @@ -205,138 +202,6 @@ def _create_workspace(self, workspace: Union[str, pathlib.Path]) -> pathlib.Path
workspace = os.path.join(base_dir, _get_test_directory_name())
return self._create_remote_directory(workspace)

def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str) -> pathlib.Path:
"""Upload a local file to the remote workspace.

Parameters
----------
local_path : str or pathlib.Path
Path to the local file to be copied.
remote_filename : str
Name of the file in the remote workspace.

Returns
-------
pathlib.Path :
Uploaded file remote path.
"""
assert self._workspace
remote_file_path = self._workspace / remote_filename
self._copy_to_remote(local_path, str(remote_file_path))
return remote_file_path

def start_session(self, session_name: str = "hexagon-rpc") -> Session:
"""Connect to the RPC server.

Parameters
----------
session_name : str
RPC session name.

Returns
-------
Session :
The session object.
"""
hexagon_remote_kw = {
"host": self._rpc_info["rpc_tracker_host"],
"port": self._rpc_info["rpc_tracker_port"],
"priority": 0,
"timeout": 0,
"key": self._device_key,
}
return Session(self, hexagon_remote_kw, session_name=session_name)

def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module], session: Session):
"""Load TVM module.

Parameters
----------
module : Union[str, pathlib.Path, tvm.runtime.Module]

The module to load. If `module` is a
`tvm.runtime.Module`, it will be uploaded to the remote
session and loaded.

If the object passed is a string or pathlib.Path, it must
be a full path in the remote system.

session : Session

Remote session. The session must be established (via __enter__)
prior to calling this function.

Returns
-------
TVMModule :
TVM module object.

"""
return session.load_module(module)

def get_graph_executor(
self,
graph_json: str,
module: Union[str, pathlib.Path, tvm.runtime.Module],
session: Session,
):
"""Create a local GraphModule which consumes a remote libmod.

Parameters
----------
graph_json : str
The string with the graph JSON.
module : Union[str, pathlib.Path, tvm.runtime.Module]

The module to load. If `module` is a
`tvm.runtime.Module`, it will be uploaded to the remote
session and loaded.

If the object passed is a string or pathlib.Path, it must
be a full path in the remote system.
session : Session
Remote session. The session must be established (via __enter__)
prior to calling this function.

Returns
-------
GraphModule :
Runtime graph module that can be used to execute the graph.
"""
return session.get_graph_executor(graph_json, module)

def get_graph_debug_executor(
self,
graph_json: str,
module: Union[str, pathlib.Path, tvm.runtime.Module],
session: Session,
dump_root: Union[str, pathlib.Path] = None,
):
"""Create a local GraphModuleDebug which consumes a remote libmod.

Parameters
----------
graph_json : str
The string with the graph JSON.
module : Union[str, pathlib.Path, tvm.runtime.Module]

The module to load. If `module` is a
`tvm.runtime.Module`, it will be uploaded to the remote
session and loaded.

If the object passed is a string or pathlib.Path, it must
be a full path in the remote system.
session : Session
Remote session. The session must be established (via __enter__)
prior to calling this function.

Returns
-------
GraphModuleDebug :
Runtime debug graph module that can be used to debug the graph.
"""
return session.get_graph_debug_executor(graph_json, module, dump_root=dump_root)

@abc.abstractmethod
def get_profile_output(
self,
Expand All @@ -360,6 +225,31 @@ def get_profile_output(
"""
...

def create_session(self, session_name: str = "hexagon-rpc") -> Session:
"""Create an RPC session.

Parameters
----------
session_name : str
RPC session name.

Returns
-------
Session :
The session object.
"""
hexagon_session_kw = {
"remote_workspace": self._workspace,
"rpc_tracker": (self._rpc_info["rpc_tracker_host"], self._rpc_info["rpc_tracker_port"]),
"rpc_server_key": self._rpc_info["device_key"],
"serial_number": self._serial_number,
"session_name": session_name,
}
return Session(**hexagon_session_kw)

def is_simulator(self):
return self._serial_number == HEXAGON_SIMULATOR_NAME


class HexagonLauncherAndroid(HexagonLauncherRPC):
"""Hexagon Launcher for Android."""
Expand Down Expand Up @@ -402,15 +292,18 @@ def __init__(
if not rpc_info.get("workspace_base"):
rpc_info["workspace_base"] = self.ANDROID_HEXAGON_TEST_BASE_DIR
self._serial_number = serial_number
assert self._serial_number != "", "Android serial number is not set."

adb_socket = rpc_info["adb_server_socket"] if rpc_info["adb_server_socket"] else "tcp:5037"
self._adb_device_sub_cmd = ["adb", "-L", adb_socket, "-s", self._serial_number]
self.forwarded_ports_ = []
self._hexagon_debug = hexagon_debug
self._clear_logcat = clear_logcat
self._sysmon_profile = sysmon_profile
self._sysmon_process = None
rpc_info["device_key"] = HEXAGON_REMOTE_DEVICE_KEY + "." + self._serial_number
Lunderberg marked this conversation as resolved.
Show resolved Hide resolved

super(HexagonLauncherAndroid, self).__init__(rpc_info, workspace)
super(HexagonLauncherAndroid, self).__init__(rpc_info, workspace, self._serial_number)

def _copy_to_remote(
self, local_path: Union[str, pathlib.Path], remote_path: Union[str, pathlib.Path]
Expand Down Expand Up @@ -442,7 +335,9 @@ def _copy_binaries(self):
"<RPC_TRACKER_PORT>", str(self._rpc_info["rpc_tracker_port"])
)
if "<HEXAGON_REMOTE_DEVICE_KEY>" in line:
line = line.replace("<HEXAGON_REMOTE_DEVICE_KEY>", self._device_key)
line = line.replace(
"<HEXAGON_REMOTE_DEVICE_KEY>", self._rpc_info["device_key"]
)
if "<RPC_SERVER_PORT>" in line:
line = line.replace(
"<RPC_SERVER_PORT>", str(self._rpc_info["rpc_server_port"])
Expand Down Expand Up @@ -691,12 +586,13 @@ def __init__(self, rpc_info: dict, workspace: Union[str, pathlib.Path] = None):

Parameters are same as for HexagonLauncherRPC.
"""
super(HexagonLauncherSimulator, self).__init__(rpc_info, workspace)

self._toolchain = os.environ.get("HEXAGON_TOOLCHAIN")
if not self._toolchain:
raise RuntimeError("Please set HEXAGON_TOOLCHAIN env variable")
self._serial_number = "simulator"
self._serial_number = HEXAGON_SIMULATOR_NAME

super(HexagonLauncherSimulator, self).__init__(rpc_info, workspace, self._serial_number)

def _copy_to_remote(
self, local_path: Union[str, pathlib.Path], remote_path: Union[str, pathlib.Path]
Expand Down Expand Up @@ -740,18 +636,19 @@ def start_server(self):
self._copy_to_remote(lib_dir / item, self._workspace / item)
# Copy libc++ from the toolchain to the workspace
self._copy_libcxx(self._workspace)
self._device_key = self.HEXAGON_REMOTE_DEVICE_KEY + "." + str(os.getpid())
self._rpc_info["device_key"] = HEXAGON_REMOTE_DEVICE_KEY + "." + str(os.getpid())
Lunderberg marked this conversation as resolved.
Show resolved Hide resolved

rpc_tracker_host = self._rpc_info["rpc_tracker_host"]
rpc_tracker_port = self._rpc_info["rpc_tracker_port"]
rpc_server_port = self._rpc_info["rpc_server_port"]
device_key = self._rpc_info["device_key"]
server_exe = os.path.join(".", "tvm_rpc_x86")

args = [
"server",
f"--tracker={rpc_tracker_host}:{rpc_tracker_port}",
f"--port={rpc_server_port}",
f"--key={self._device_key}",
f"--key={device_key}",
"--timeout=0",
]

Expand Down Expand Up @@ -823,7 +720,7 @@ def HexagonLauncher(
sysmon_profile: bool = False,
):
"""Creates a HexagonLauncher"""
if serial_number == "simulator":
if serial_number == HEXAGON_SIMULATOR_NAME:
return HexagonLauncherSimulator(rpc_info, workspace)
return HexagonLauncherAndroid(
serial_number, rpc_info, workspace, hexagon_debug, clear_logcat, sysmon_profile
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/hexagon/meta_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:


def _worker_func(hexagon_launcher, evaluator_config, alloc_repeat, artifact_path, args_info):
with hexagon_launcher.start_session() as session:
with hexagon_launcher.create_session() as session:
device = session.device
_, remote_path = os.path.split(artifact_path)
uploaded = session.upload(artifact_path, remote_path)
Expand Down
15 changes: 8 additions & 7 deletions python/tvm/contrib/hexagon/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import tvm.rpc.tracker
from tvm.contrib.hexagon.build import HexagonLauncher, HexagonLauncherRPC
from tvm.contrib.hexagon.session import Session
from tvm.contrib.hexagon.tools import HEXAGON_SIMULATOR_NAME

HEXAGON_TOOLCHAIN = "HEXAGON_TOOLCHAIN"
TVM_TRACKER_HOST = "TVM_TRACKER_HOST"
Expand Down Expand Up @@ -173,7 +174,7 @@ def hexagon_server_process(

if android_serial_num is None:
pytest.skip("ANDROID_SERIAL_NUMBER is not set.")
if android_serial_num == ["simulator"]:
if android_serial_num == [HEXAGON_SIMULATOR_NAME]:
yield None
else:
# Requesting these fixtures sets up a local tracker, if one
Expand Down Expand Up @@ -220,7 +221,7 @@ def pytest_configure(config):


def pytest_configure_node(node):
# the master for each node fills slaveinput dictionary
# the master for each node fills node input dictionary
# which pytest-xdist will transfer to the subprocess
if node.config.iplist is not None:
node.workerinput["device_adr"] = node.config.iplist.pop()
Expand All @@ -240,7 +241,7 @@ def hexagon_launcher(
"""Initials and returns hexagon launcher which reuses RPC info and Android serial number."""
android_serial_num = android_serial_number()

if android_serial_num != ["simulator"]:
if android_serial_num != [HEXAGON_SIMULATOR_NAME]:
rpc_info = hexagon_server_process["launcher"]._rpc_info
else:
rpc_info = {
Expand All @@ -250,7 +251,7 @@ def hexagon_launcher(
"adb_server_socket": adb_server_socket,
}
try:
if android_serial_num == ["simulator"]:
if android_serial_num == [HEXAGON_SIMULATOR_NAME]:
launcher = HexagonLauncher(serial_number=android_serial_num[0], rpc_info=rpc_info)
launcher.start_server()
else:
Expand All @@ -263,7 +264,7 @@ def hexagon_launcher(
)
yield launcher
finally:
if android_serial_num == ["simulator"]:
if android_serial_num == [HEXAGON_SIMULATOR_NAME]:
launcher.stop_server()
elif not hexagon_debug:
launcher.cleanup_directory()
Expand All @@ -274,7 +275,7 @@ def hexagon_session(hexagon_launcher: HexagonLauncherRPC) -> Session:
if hexagon_launcher is None:
yield None
else:
with hexagon_launcher.start_session() as session:
with hexagon_launcher.create_session() as session:
yield session


Expand All @@ -289,7 +290,7 @@ def terminate_rpc_servers():
# yield happens every time.
serial = os.environ.get(ANDROID_SERIAL_NUMBER)
yield []
if serial == ["simulator"]:
if serial == [HEXAGON_SIMULATOR_NAME]:
os.system("ps ax | grep tvm_rpc_x86 | awk '{print $1}' | xargs kill")


Expand Down
Loading