diff --git a/python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py b/python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py index 704801b4a0549..48b91a47c175d 100644 --- a/python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py +++ b/python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py @@ -133,7 +133,7 @@ def _worker_func( "priority": 0, "timeout": 100, } - + build_result = namedtuple("BuildResult", ["filename"])(artifact_path) with module_loader(remote_kw, build_result) as (remote, mod): @@ -189,11 +189,11 @@ def get_rpc_runner_micro( rpc_timeout_sec: The rpc session timeout. serial_numbers: - List of board serial numbers to be used during tuning. - For "CRT" and "QEMU" platforms the serial numners are not used, + List of board serial numbers to be used during tuning. + For "CRT" and "QEMU" platforms the serial numners are not used, but the length of the list determines the number of runner instances. """ - + if evaluator_config is None: evaluator_config = EvaluatorConfig( number=3, @@ -204,7 +204,7 @@ def get_rpc_runner_micro( if tracker_host is None: tracker_host = "127.0.0.1" - + if tracker_port is None: tracker_port = 9000 else: @@ -221,7 +221,7 @@ def get_rpc_runner_micro( reuse_addr=True, timeout=60, ) - + servers = [] rpc_configs = [] for serial_number in serial_numbers: @@ -263,7 +263,7 @@ def handle_SIGINT(signal, frame): project_options=options, rpc_configs=rpc_configs, evaluator_config=evaluator_config, - session_timeout_sec=session_timeout_sec + session_timeout_sec=session_timeout_sec, ) finally: terminate() diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index eb8023f5c1a0f..8033883aaf4b0 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -153,10 +153,9 @@ def __call__(self, remote_kw, build_result): with open(build_result.filename, "rb") as build_file: build_result_bin = build_file.read() - if ("board" in self._project_options and - "$local$device" not in remote_kw["device_key"]): + if "board" in self._project_options and "$local$device" not in remote_kw["device_key"]: self._project_options["serial_number"] = remote_kw["device_key"] - + tracker = _rpc.connect_tracker(remote_kw["host"], remote_kw["port"]) remote = tracker.request( remote_kw["device_key"], diff --git a/tests/micro/zephyr/test_ms_tuning.py b/tests/micro/zephyr/test_ms_tuning.py index d78eb0bb46541..6c07327f3c53d 100644 --- a/tests/micro/zephyr/test_ms_tuning.py +++ b/tests/micro/zephyr/test_ms_tuning.py @@ -95,7 +95,10 @@ def test_ms_tuning_conv2d(workspace_dir, board, microtvm_debug, use_fvp, serial_ builder = get_local_builder_micro() with ms.Profiler() as profiler: with get_rpc_runner_micro( - platform=platform, options=project_options, session_timeout_sec=120, serial_numbers=["0", "1"] + platform=platform, + options=project_options, + session_timeout_sec=120, + serial_numbers=["0", "1"], ) as runner: db: ms.Database = ms.relay_integration.tune_relay(