From 0f54159712dd60de694715bfb4ac73ada795feba Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 29 Sep 2021 09:35:56 -0700 Subject: [PATCH] [Meta Schedule][M3b] Runner (#9111) This PR is part of the meta schedule project (#8473) that adds the asynchronous program runner interface, as well as a reference implementation of RPCRunner. LocalRunner will be implemented with PopenPool executor in a follow-up PR. Co-authored-by: Xiyou Zhou Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng Address comments Co-authored-by: Cody Yu fix lint --- include/tvm/meta_schedule/runner.h | 169 +++++- python/tvm/meta_schedule/__init__.py | 5 +- .../meta_schedule/builder/local_builder.py | 17 +- python/tvm/meta_schedule/runner/__init__.py | 9 +- python/tvm/meta_schedule/runner/config.py | 190 ++++++ python/tvm/meta_schedule/runner/rpc_runner.py | 567 +++++++++++++++++ python/tvm/meta_schedule/runner/runner.py | 111 ++++ python/tvm/meta_schedule/testing.py | 74 +++ python/tvm/meta_schedule/tune_context.py | 4 +- python/tvm/meta_schedule/utils.py | 37 +- src/meta_schedule/runner/runner.cc | 45 +- .../unittest/test_meta_schedule_runner.py | 571 ++++++++++++++++++ 12 files changed, 1776 insertions(+), 23 deletions(-) create mode 100644 python/tvm/meta_schedule/runner/config.py create mode 100644 python/tvm/meta_schedule/runner/rpc_runner.py create mode 100644 python/tvm/meta_schedule/testing.py create mode 100644 tests/python/unittest/test_meta_schedule_runner.py diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index 36d07024559d..a45a4898d64a 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -20,16 +20,53 @@ #define TVM_META_SCHEDULE_RUNNER_H_ #include +#include namespace tvm { namespace meta_schedule { -/*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */ +/*! \brief The runner's input. */ +class RunnerInputNode : public runtime::Object { + public: + /*! \brief The path to the built artifact. */ + String artifact_path; + /*! \brief The type of device. */ + String device_type; + /*! \brief The argument information. */ + Array args_info; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("artifact_path", &artifact_path); + v->Visit("device_type", &device_type); + v->Visit("args_info", &args_info); + } + + static constexpr const char* _type_key = "meta_schedule.RunnerInput"; + TVM_DECLARE_FINAL_OBJECT_INFO(RunnerInputNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerInputNode + * \sa RunnerInputNode + */ +class RunnerInput : public runtime::ObjectRef { + public: + /*! + * \brief Constructor of RunnerInput + * \param artifact_path The path to the built artifact. + * \param device_type The type of device. + * \param args_info The argument information. + */ + TVM_DLL explicit RunnerInput(String artifact_path, String device_type, Array args_info); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode); +}; + +/*! \brief The runner's output. */ class RunnerResultNode : public runtime::Object { public: - /*! \brief The run time in seconds. If not None, error_msg should be None. */ + /*! \brief The run time in seconds.*/ Optional> run_secs; - /*! \brief The error message, if any. If not None, run_secs should be None. */ + /*! \brief The error message, if any. 
*/ Optional error_msg; void VisitAttrs(tvm::AttrVisitor* v) { @@ -48,14 +85,134 @@ class RunnerResultNode : public runtime::Object { class RunnerResult : public runtime::ObjectRef { public: /*! - * \brief Constructor for RunnerResult. - * \param run_secs The run time in seconds. - * \param error_msg The error message, if any. + * \brief Constructor + * \brief The run time in seconds. + * \brief The error message, if any. */ TVM_DLL explicit RunnerResult(Optional> run_secs, Optional error_msg); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode); }; +/*! + * \brief A class to asynchronously fetch runner's output. + * \note The API design is consistent with python's concurrent.futures.Future: + * https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future + */ +class RunnerFutureNode : public runtime::Object { + public: + /*! + * \brief The function type to check whether the runner has finished. + * \return Whether the runner's output is ready. + */ + using FDone = runtime::TypedPackedFunc; + /*! + * \brief The function type to fetch runner output if it is ready. + * \return The runner's output. + */ + using FResult = runtime::TypedPackedFunc; + + /*! \brief The packed function to check whether the runner has finished. */ + FDone f_done; + /*! \brief The packed function to fetch runner output if it is ready. */ + FResult f_result; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_done` is not visited + // `f_result` is not visited + } + + /*! + * \brief Check whether the runner has finished. + * \return A boolean indicating whether the runner has finished. + */ + bool Done() const { return f_done(); } + /*! + * \brief Fetch the runner's output if it is ready. + * \return The runner's output. + */ + RunnerResult Result() const { return f_result(); } + + static constexpr const char* _type_key = "meta_schedule.RunnerFuture"; + TVM_DECLARE_FINAL_OBJECT_INFO(RunnerFutureNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerFutureNode + * \sa RunnerFutureNode + */ +class RunnerFuture : public runtime::ObjectRef { + public: + using FDone = RunnerFutureNode::FDone; + using FResult = RunnerFutureNode::FResult; + + /*! + * \brief Constructor of RunnerFuture + * \param f_done The packed function to check whether the runner has finished. + * \param f_result The packed function to fetch runner output if it is ready. + */ + TVM_DLL explicit RunnerFuture(FDone f_done, FResult f_result); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerFuture, runtime::ObjectRef, + RunnerFutureNode); +}; + +/*! \brief The abstract runner interface. */ +class RunnerNode : public runtime::Object { + public: + /*! + * \brief The function type to run the built artifacts and get runner futures. + * \param input The runner's inputs. + * \return The runner futures. + * \sa RunnerFuture + */ + using FRun = runtime::TypedPackedFunc(Array)>; + + /*! \brief Default destructor */ + virtual ~RunnerNode() = default; + + /*! + * \brief Run the built artifact and get runner futures. + * \param runner_inputs The runner's inputs. + * \return The runner futures. + */ + virtual Array Run(Array runner_inputs) = 0; + + static constexpr const char* _type_key = "meta_schedule.Runner"; + TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerNode + * \sa RunnerNode + */ +class Runner : public runtime::ObjectRef { + public: + using FRun = RunnerNode::FRun; + + /*! 
+ * \brief Create a runner with customized run method on the python-side. + * \param f_run The packed function to run the built artifacts and get runner futures. + * \return The runner created. + */ + TVM_DLL static Runner PyRunner(FRun f_run); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Runner, runtime::ObjectRef, RunnerNode); +}; + +/*! \brief An abstract runner with customized run method on the python-side. */ +class PyRunnerNode : public RunnerNode { + public: + /*! \brief The packed function to run the built artifacts and get runner futures. */ + FRun f_run; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_run` is not visited + } + + Array Run(Array runner_inputs) final { return f_run(runner_inputs); } + + static constexpr const char* _type_key = "meta_schedule.PyRunner"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode); +}; + } // namespace meta_schedule } // namespace tvm diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index c22cc205bf35..2e280ef20ac3 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -16,10 +16,9 @@ # under the License. """Package `tvm.meta_schedule`. The meta schedule infrastructure.""" from . import arg_info -from . import builder from . import database +from . import builder +from . import runner from . import space_generator from . import search_strategy -from . import runner -from .database import TuningRecord from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index cefe5ec50cad..99dfaea56090 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -48,11 +48,20 @@ class LocalBuilder(PyBuilder): Attributes ---------- T_BUILD : typing._GenericAlias - The signature of the build function `f_build`, which is - `Callable[[IRModule, Target], Module]` + The signature of the function `f_build`, which is + + .. code-block:: python + + def default_build(mod: IRModule, target: Target) -> Module: + ... + T_EXPORT : typing._GenericAlias - The signature of the build function `f_export`, which is - `Callable[[Module], str]` + The signature of the function `f_export`, which is + + .. code-block:: python + + def default_export(mod: Module) -> str: + ... Note ---- diff --git a/python/tvm/meta_schedule/runner/__init__.py b/python/tvm/meta_schedule/runner/__init__.py index 65d2ef04e04c..47f4557e1d3a 100644 --- a/python/tvm/meta_schedule/runner/__init__.py +++ b/python/tvm/meta_schedule/runner/__init__.py @@ -14,5 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""meta_schedule.runner""" -from .runner import RunnerResult +""" +The tvm.meta_schedule.runner package. +Meta Schedule runners that run an artifact either locally or through the RPC interface +""" +from .config import EvaluatorConfig, RPCConfig +from .rpc_runner import RPCRunner +from .runner import PyRunner, Runner, RunnerFuture, RunnerInput, RunnerResult diff --git a/python/tvm/meta_schedule/runner/config.py b/python/tvm/meta_schedule/runner/config.py new file mode 100644 index 000000000000..712766de99c1 --- /dev/null +++ b/python/tvm/meta_schedule/runner/config.py @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Configurations for measurements in the runner""" +import os +from threading import Thread +from typing import NamedTuple, Optional, Union + +from tvm import rpc + + +class EvaluatorConfig(NamedTuple): + """Config Details of Evaluator + + Parameters + ---------- + number: int + The number of runs. + repeat: int + The number of times to repeat in each run. + min_repeat_ms: int + Minimum repeat time in ms. if the execution latency is too short, + increase the number of runs to the given time (in ms) to reduce the measurement error. + enable_cpu_cache_flush: bool + Whether to flush the cache on CPU. + + Note + ---- + The total number of actual executions is 1+number*repeat because we would warm up 1 time before + actual run. The number of runs would be increased if run time is below min_repeat_ms. + """ + + number: int = 3 + repeat: int = 1 + min_repeat_ms: int = 40 + enable_cpu_cache_flush: bool = False + + @staticmethod + def _normalized(config: Optional["EvaluatorConfig"]) -> "EvaluatorConfig": + if config is None: + return EvaluatorConfig() + config = EvaluatorConfig( + number=config.number, + repeat=config.repeat, + min_repeat_ms=config.min_repeat_ms, + enable_cpu_cache_flush=config.enable_cpu_cache_flush, + ) + return config + + +class RPCConfig(NamedTuple): + """RPC configuration + + Parameters + ---------- + tracker_host: str + Host of the RPC Tracker + tracker_port: int + Port of the RPC Tracker + tracker_key: str + Key of the Tracker + session_timeout_sec: float + Timeout of the RPC session + session_priority: int + Priority of the RPC session + """ + + tracker_host: Optional[str] = None + tracker_port: Union[None, int, str] = None + tracker_key: Optional[str] = None + session_priority: int = 1 + session_timeout_sec: int = 10 + + def _sanity_check(self) -> None: + err_str = ( + "RPCConfig.{0} is not provided. 
Please provide it explicitly," + "or set environment variable {1}" + ) + if self.tracker_host is None: + raise ValueError(err_str.format("tracker_host", "TVM_TRACKER_HOST")) + if self.tracker_port is None: + raise ValueError(err_str.format("tracker_port", "TVM_TRACKER_PORT")) + if self.tracker_key is None: + raise ValueError(err_str.format("tracker_key", "TVM_TRACKER_KEY")) + + @staticmethod + def _normalized(config: Optional["RPCConfig"]) -> "RPCConfig": + if config is None: + config = RPCConfig() + config = RPCConfig( + tracker_host=config.tracker_host or os.environ.get("TVM_TRACKER_HOST", None), + tracker_port=config.tracker_port or os.environ.get("TVM_TRACKER_PORT", None), + tracker_key=config.tracker_key or os.environ.get("TVM_TRACKER_KEY", None), + session_priority=config.session_priority, + session_timeout_sec=config.session_timeout_sec, + ) + config._sanity_check() # pylint: disable=protected-access + return config + + def connect_tracker(self) -> rpc.TrackerSession: + """Connect to the tracker + + Returns + ------- + tracker : TrackerSession + The connected tracker session + """ + tracker: Optional[rpc.TrackerSession] = None + + def _connect(): + nonlocal tracker + tracker = rpc.connect_tracker(self.tracker_host, self.tracker_port) + + t = Thread(target=_connect) + t.start() + t.join(self.session_timeout_sec) + if t.is_alive() or tracker is None: + raise ValueError( + "Unable to connect to the tracker using the following configuration:\n" + f" tracker host: {self.tracker_host}\n" + f" tracker port: {self.tracker_port}\n" + f" timeout (sec): {self.session_timeout_sec}\n" + "Please check the tracker status via the following command:\n" + " python3 -m tvm.exec.query_rpc_tracker " + f"--host {self.tracker_host} --port {self.tracker_port}" + ) + return tracker + + def connect_server(self) -> rpc.RPCSession: + """Connect to the server + + Returns + ------- + session : RPCSession + The connected rpc session + """ + tracker = self.connect_tracker() + session: rpc.RPCSession = tracker.request( + key=self.tracker_key, + priority=self.session_priority, + session_timeout=self.session_timeout_sec, + ) + return session + + def count_num_servers(self, allow_missing=True) -> int: + """Count the number of servers available in the tracker + + Parameters + ---------- + allow_missing : bool + Whether to allow no server to be found. 
+ + Returns + ------- + num_servers : int + The number of servers + """ + tracker = self.connect_tracker() + tracker_summary = tracker.summary() + result: int = 0 + for item in tracker_summary["server_info"]: + _, item_key = item["key"].split(":") + if item_key == self.tracker_key: + result += 1 + if result == 0 and not allow_missing: + raise ValueError( + "Unable to find servers with the specific key using the following configuration:\n" + f" tracker host: {self.tracker_host}\n" + f" tracker port: {self.tracker_port}\n" + f" tracker key: {self.tracker_key}\n" + f" timeout (sec): {self.session_timeout_sec}\n" + "Please check the tracker status via the following command:\n" + " python3 -m tvm.exec.query_rpc_tracker " + f"--host {self.tracker_host} --port {self.tracker_port}\n" + f'and look for key: "{self.tracker_key}"' + ) + return result diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py new file mode 100644 index 000000000000..d20e1707fcec --- /dev/null +++ b/python/tvm/meta_schedule/runner/rpc_runner.py @@ -0,0 +1,567 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""RPC Runner""" +import concurrent.futures +from contextlib import contextmanager +import itertools +import os.path as osp +from typing import Any, Callable, Dict, List, Optional, Union + +from tvm.contrib.popen_pool import PopenPoolExecutor +from tvm.rpc import RPCSession +from tvm.runtime import Device, Module, ndarray + +from ..utils import ( + get_global_func_on_rpc_session, + get_global_func_with_default_on_worker, +) +from .config import EvaluatorConfig, RPCConfig +from .runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult + + +class RPCRunnerFuture(RunnerFuture): + """RPC based runner future + + Parameters + ---------- + future: concurrent.futures.Future + The concurrent function to check when the function is done and to return the result. + timeout_sec: float + The timeout in seconds. + """ + + future: concurrent.futures.Future + timeout_sec: float + + def __init__(self, future: concurrent.futures.Future, timeout_sec: float) -> None: + """Constructor + + Parameters + ---------- + future: concurrent.futures.Future + The concurrent function to check when the function is done and to return the result. + timeout_sec: float + The timeout in seconds. 
+ """ + super().__init__() + self.future = future + self.timeout_sec = timeout_sec + + def done(self) -> bool: + return self.future.done() + + def result(self) -> RunnerResult: + try: + run_secs: List[float] = self.future.result() + except TimeoutError as exception: + return RunnerResult( + None, + error_msg=f"RPCRunner: Timeout, killed after {self.timeout_sec} seconds", + ) + except Exception as exception: # pylint: disable=broad-except + return RunnerResult( + None, + error_msg="RPCRunner: An exception occurred\n" + str(exception), + ) + return RunnerResult(run_secs, None) + + +T_ARG_INFO_JSON_OBJ = List[Any] # pylint: disable=invalid-name +T_ARG_INFO_JSON_OBJ_LIST = List[T_ARG_INFO_JSON_OBJ] # pylint: disable=invalid-name +T_ARGUMENT = Any # pylint: disable=invalid-name +T_ARGUMENT_LIST = List[T_ARGUMENT] # pylint: disable=invalid-name + + +class RPCRunner(PyRunner): + """RPC based runner + + Parameters + ---------- + rpc_config: RPCConfig + The rpc configuration. + evaluator_config: EvaluatorConfig + The evaluator configuration. + cooldown_sec: float + The cooldown in seconds. TODO(@junrushao1994,@zxybazh): This is not used yet. + alloc_repeat: int + The number of times to repeat the allocation. + f_create_session: Optional[str, Callable] + The function name to create the session or the function itself. + f_upload_module: Optional[str, Callable] + The function name to upload the module or the function itself. + f_alloc_argument: Optional[str, Callable] + The function name to allocate the arguments or the function itself. + f_run_evaluator: Optional[str, Callable] + The function name to run the evaluator or the function itself. + f_cleanup: Optional[str, Callable] + The function name to cleanup the session or the function itself. + pool: PopenPoolExecutor + The popen pool executor. + + Attributes + ---------- + T_CREATE_SESSION : typing._GenericAlias + The signature of the function `f_create_session`, which is: + + .. code-block:: python + + def default_create_session(rpc_config: RPCConfig) -> RPCSession: + ... + + T_UPLOAD_MODULE : typing._GenericAlias + The signature of the function `f_upload_module`, which is: + + .. code-block:: python + + def default_upload_module( + session: RPCSession, + local_path: str, + remote_path: str, + ) -> Module: + ... + + T_ALLOC_ARGUMENT : typing._GenericAlias + The signature of the function `f_alloc_argument`, which is: + + .. code-block:: python + + def default_alloc_argument( + session: RPCSession, + device: Device, + args_info: T_ARG_INFO_JSON_OBJ_LIST, + alloc_repeat: int, + ) -> List[T_ARGUMENT_LIST]: + ... + + T_RUN_EVALUATOR : typing._GenericAlias + The signature of the function `f_run_evaluator`, which is: + + .. code-block:: python + + def default_run_evaluator( + session: RPCSession, + rt_mod: Module, + device: Device, + evaluator_config: EvaluatorConfig, + repeated_args: List[T_ARGUMENT_LIST], + ) -> List[float]: + ... + + T_CLEANUP : typing._GenericAlias + The signature of the function `f_cleanup`, which is: + + .. code-block:: python + + def default_cleanup( + session: Optional[RPCSession], + remote_path: Optional[str], + ) -> None: + ... 
+ """ + + T_CREATE_SESSION = Callable[ + [RPCConfig], # The RPC configuration + RPCSession, # The RPC Session + ] + T_UPLOAD_MODULE = Callable[ + [ + RPCSession, # The RPC Session + str, # local path to the artifact + str, # remote path to the artifact + ], + Module, # the Module opened on the remote + ] + T_ALLOC_ARGUMENT = Callable[ + [ + RPCSession, # The RPC Session + Device, # The device on the remote + T_ARG_INFO_JSON_OBJ_LIST, # The metadata information of the arguments to be allocated + int, # The number of repeated allocations to be done + ], + List[T_ARGUMENT_LIST], # A list of argument lists + ] + T_RUN_EVALUATOR = Callable[ + [ + RPCSession, # The RPC Session + Module, # The Module opened on the remote + Device, # The device on the remote + EvaluatorConfig, # The evaluator configuration + List[T_ARGUMENT_LIST], # A list of argument lists + ], + List[float], # A list of running time + ] + T_CLEANUP = Callable[ + [ + Optional[RPCSession], # The RPC Session to be cleaned up + Optional[str], # remote path to the artifact + ], + None, + ] + + rpc_config: RPCConfig + evaluator_config: EvaluatorConfig + cooldown_sec: float + alloc_repeat: int + + f_create_session: Union[T_CREATE_SESSION, str, None] + f_upload_module: Union[T_UPLOAD_MODULE, str, None] + f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None] + f_run_evaluator: Union[T_RUN_EVALUATOR, str, None] + f_cleanup: Union[T_CLEANUP, str, None] + + pool: PopenPoolExecutor + + def __init__( + self, + rpc_config: Optional[RPCConfig] = None, + evaluator_config: Optional[EvaluatorConfig] = None, + cooldown_sec: float = 0.0, + alloc_repeat: int = 1, + f_create_session: Union[T_CREATE_SESSION, str, None] = None, + f_upload_module: Union[T_UPLOAD_MODULE, str, None] = None, + f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None] = None, + f_run_evaluator: Union[T_RUN_EVALUATOR, str, None] = None, + f_cleanup: Union[T_CLEANUP, str, None] = None, + max_connections: Optional[int] = None, + initializer: Optional[Callable[[], None]] = None, + ) -> None: + """Constructor + + Parameters + ---------- + rpc_config: RPCConfig + The rpc configuration. + evaluator_config: EvaluatorConfig + The evaluator configuration. + cooldown_sec: float + The cooldown in seconds. + alloc_repeat: int + The number of times to random fill the allocation. + f_create_session: Union[T_CREATE_SESSION, str, None] + The function name to create the session or the function itself. + f_upload_module: Union[T_UPLOAD_MODULE, str, None] + The function name to upload the module or the function itself. + f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None] + The function name to allocate the arguments or the function itself. + f_run_evaluator: Union[T_RUN_EVALUATOR, str, None] + The function name to run the evaluator or the function itself. + f_cleanup: Union[T_CLEANUP, str, None] + The function name to cleanup the session or the function itself. + max_connections: Optional[int] + The maximum number of connections. + initializer: Optional[Callable[[], None]] + The initializer function. 
+ """ + super().__init__() + self.rpc_config = RPCConfig._normalized(rpc_config) + self.evaluator_config = EvaluatorConfig._normalized(evaluator_config) + self.cooldown_sec = cooldown_sec + self.alloc_repeat = alloc_repeat + self.f_create_session = f_create_session + self.f_upload_module = f_upload_module + self.f_alloc_argument = f_alloc_argument + self.f_run_evaluator = f_run_evaluator + self.f_cleanup = f_cleanup + + num_servers = self.rpc_config.count_num_servers(allow_missing=False) + if max_connections is None: + max_connections = num_servers + else: + max_connections = min(max_connections, num_servers) + + self.pool = PopenPoolExecutor( + max_workers=max_connections, + timeout=rpc_config.session_timeout_sec, + initializer=initializer, + ) + self._sanity_check() + + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + results: List[RunnerFuture] = [] + for runner_input in runner_inputs: + future = RPCRunnerFuture( + future=self.pool.submit( + RPCRunner._worker_func, + self.f_create_session, + self.f_upload_module, + self.f_alloc_argument, + self.f_run_evaluator, + self.f_cleanup, + self.rpc_config, + self.evaluator_config, + self.alloc_repeat, + str(runner_input.artifact_path), + str(runner_input.device_type), + tuple(arg_info.as_json() for arg_info in runner_input.args_info), + ), + timeout_sec=self.rpc_config.session_timeout_sec, + ) + results.append(future) + return results + + def _sanity_check(self) -> None: + def _check( + f_create_session, + f_upload_module, + f_alloc_argument, + f_run_evaluator, + f_cleanup, + ) -> None: + get_global_func_with_default_on_worker(name=f_create_session, default=None) + get_global_func_with_default_on_worker(name=f_upload_module, default=None) + get_global_func_with_default_on_worker(name=f_alloc_argument, default=None) + get_global_func_with_default_on_worker(name=f_run_evaluator, default=None) + get_global_func_with_default_on_worker(name=f_cleanup, default=None) + + value = self.pool.submit( + _check, + self.f_create_session, + self.f_upload_module, + self.f_alloc_argument, + self.f_run_evaluator, + self.f_cleanup, + ) + value.result() + + @staticmethod + def _worker_func( + _f_create_session: Union[T_CREATE_SESSION, str, None], + _f_upload_module: Union[T_UPLOAD_MODULE, str, None], + _f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None], + _f_run_evaluator: Union[T_RUN_EVALUATOR, str, None], + _f_cleanup: Union[T_CLEANUP, str, None], + rpc_config: RPCConfig, + evaluator_config: EvaluatorConfig, + alloc_repeat: int, + artifact_path: str, + device_type: str, + args_info: T_ARG_INFO_JSON_OBJ_LIST, + ) -> List[float]: + # Step 0. Get the registered functions + f_create_session: RPCRunner.T_CREATE_SESSION = get_global_func_with_default_on_worker( + _f_create_session, default_create_session + ) + f_upload_module: RPCRunner.T_UPLOAD_MODULE = get_global_func_with_default_on_worker( + _f_upload_module, default_upload_module + ) + f_alloc_argument: RPCRunner.T_ALLOC_ARGUMENT = get_global_func_with_default_on_worker( + _f_alloc_argument, default_alloc_argument + ) + f_run_evaluator: RPCRunner.T_RUN_EVALUATOR = get_global_func_with_default_on_worker( + _f_run_evaluator, default_run_evaluator + ) + f_cleanup: RPCRunner.T_CLEANUP = get_global_func_with_default_on_worker( + _f_cleanup, default_cleanup + ) + # Managed resources + session: Optional[RPCSession] = None + remote_path: Optional[str] = None + + @contextmanager + def resource_handler(): + try: + yield + finally: + # Step 5. 
Clean up + f_cleanup(session, remote_path) + + with resource_handler(): + # Step 1. Create session + session = f_create_session(rpc_config) + device = session.device(dev_type=device_type, dev_id=0) + # Step 2. Upload the module + _, remote_path = osp.split(artifact_path) + local_path: str = artifact_path + rt_mod: Module = f_upload_module(session, local_path, remote_path) + # Step 3: Allocate input arguments + repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument( + session, + device, + args_info, + alloc_repeat, + ) + # Step 4: Run time_evaluator + costs: List[float] = f_run_evaluator( + session, + rt_mod, + device, + evaluator_config, + repeated_args, + ) + return costs + + +def default_create_session(rpc_config: RPCConfig) -> RPCSession: + """Default function to create the session + + Parameters + ---------- + rpc_config : RPCConfig + The configuration of the RPC session + + Returns + ------- + session : RPCSession + The created rpc session + """ + return rpc_config.connect_server() + + +def default_upload_module( + session: RPCSession, + local_path: str, + remote_path: str, +) -> Module: + """Default function to upload the module + + Parameters + ---------- + session: RPCSession + The session to upload the module + local_path: str + The local path of the module + remote_path: str + The remote path to place the module + + Returns + ------- + rt_mod : Module + The runtime module + """ + session.upload(local_path, remote_path) + rt_mod: Module = session.load_module(remote_path) + return rt_mod + + +def default_alloc_argument( + session: RPCSession, + device: Device, + args_info: T_ARG_INFO_JSON_OBJ_LIST, + alloc_repeat: int, +) -> List[T_ARGUMENT_LIST]: + """Default function to allocate the arguments + + Parameters + ---------- + session: RPCSession + The session to allocate the arguments + device: Device + The device to allocate the arguments + alloc_repeat: int + The number of times to repeat the allocation + args_info: PyArgsInfo + The arguments info + + Returns + ------- + repeated_args: List[Args] + The allocation args + """ + f_random_fill = get_global_func_on_rpc_session( + session, + "tvm.contrib.random.random_fill", + "Please make sure 'USE_RANDOM' is turned ON in the config.cmake on the RPC server.", + ) + + def alloc_tensor(_, dtype, shape) -> ndarray.NDArray: + arg = ndarray.empty(shape=shape, dtype=dtype, device=device) + f_random_fill(arg) + return arg + + def alloc_fail(*arg_info) -> None: + raise NotImplementedError(arg_info) + + dispatcher: Dict[Any, Callable] = { + "TENSOR": alloc_tensor, + None: alloc_fail, + } + + repeated_args: List[T_ARGUMENT_LIST] = [] + for _ in range(alloc_repeat): + args: T_ARGUMENT_LIST = [] + arg_info: T_ARG_INFO_JSON_OBJ + for arg_info in args_info: + arg_type = arg_info[0] + arg: Any = dispatcher.get(arg_type, None)(*arg_info) + args.append(arg) + repeated_args.append(args) + return repeated_args + + +def default_run_evaluator( + session: RPCSession, # pylint: disable=unused-argument + rt_mod: Module, + device: Device, + evaluator_config: EvaluatorConfig, + repeated_args: List[T_ARGUMENT_LIST], +) -> List[float]: + """Default function to run the evaluator + + Parameters + ---------- + session: RPCSession + The session to run the evaluator + rt_mod: Module + The runtime module + device: Device + The device to run the evaluator + evaluator_config: EvaluatorConfig + The evaluator config + repeated_args: List[Args] + The repeated arguments + + Returns + ------- + costs: List[float] + The evaluator results + """ + evaluator = 
rt_mod.time_evaluator( + func_name=rt_mod.entry_name, + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + device.sync() + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + return costs + + +def default_cleanup( + session: Optional[RPCSession], + remote_path: Optional[str], +) -> None: + """Default function to clean up the session + + Parameters + ---------- + session: RPCSession + The session to clean up + remote_path: str + The remote path to clean up + """ + if session is not None and remote_path is not None: + session.remove(remote_path) + session.remove(remote_path + ".so") + session.remove("") diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index b756c6e6b011..9f7be8ea4af4 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -21,6 +21,50 @@ from tvm.runtime import Object from .. import _ffi_api +from ..arg_info import ArgInfo + + +@register_object("meta_schedule.RunnerInput") +class RunnerInput(Object): + """The runner's input + + Parameters + ---------- + artifact_path : str + The path to the built artifact. + device_type : str + The device type. + args_info : List[ArgInfo] + The argument information. + """ + + artifact_path: str + device_type: str + args_info: List[ArgInfo] + + def __init__( + self, + artifact_path: str, + device_type: str, + args_info: List[ArgInfo], + ) -> None: + """Constructor + + Parameters + ---------- + artifact_path : str + The path to the built artifact. + device_type : str + The device type. + args_info : List[ArgInfo] + The argument information. + """ + self.__init_handle_by_constructor__( + _ffi_api.RunnerInput, # type: ignore # pylint: disable=no-member + artifact_path, + device_type, + args_info, + ) @register_object("meta_schedule.RunnerResult") @@ -57,3 +101,70 @@ def __init__( run_secs, error_msg, ) + + +@register_object("meta_schedule.RunnerFuture") +class RunnerFuture(Object): + """A class to fetch asynchronous runner's output.""" + + def __init__(self) -> None: + """Constructor""" + + def f_done(): + return self.done() + + def f_result(): + return self.result() + + self.__init_handle_by_constructor__( + _ffi_api.RunnerFuture, # type: ignore # pylint: disable=no-member + f_done, + f_result, + ) + + def done(self) -> bool: + """Check whether the runner has finished.""" + raise NotImplementedError + + def result(self) -> RunnerResult: + """Fetch the runner's output if it is ready.""" + raise NotImplementedError + + +@register_object("meta_schedule.Runner") +class Runner(Object): + """The abstract runner interface""" + + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + """Run the built artifact and get runner futures. + + Parameters + ---------- + runner_inputs : List[RunnerInput] + The inputs to the runner. + + Returns + ------- + runner_futures: List[RunnerFuture] + The runner futures. 
+ """ + return _ffi_api.RunnerRun(self, runner_inputs) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PyRunner") +class PyRunner(Runner): + """An abstract runner with customized build method on the python-side.""" + + def __init__(self) -> None: + """Constructor""" + + def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + return self.run(runner_inputs) + + self.__init_handle_by_constructor__( + _ffi_api.RunnerPyRunner, # type: ignore # pylint: disable=no-member + f_run, + ) + + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + raise NotImplementedError diff --git a/python/tvm/meta_schedule/testing.py b/python/tvm/meta_schedule/testing.py new file mode 100644 index 000000000000..4caaeb7553cc --- /dev/null +++ b/python/tvm/meta_schedule/testing.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Testing utilities in meta schedule""" +import time + +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server + + +class LocalRPC: + """A pair of RPC tracker/server running locally + + Parameters + ---------- + tracker_host : str + The host URL of the tracker + tracker_port : int + The port of the tracker + tracker_key: str + The key used in the tracker to refer to a worker + """ + + tracker_host: str + tracker_port: int + tracker_key: str + + def __init__( + self, + tracker_key: str = "key", + silent: bool = False, + no_fork: bool = False, + ) -> None: + self.tracker = Tracker( + silent=silent, + port=9190, + port_end=12345, + ) + time.sleep(0.5) + self.server = Server( + host="0.0.0.0", + is_proxy=False, + tracker_addr=(self.tracker.host, self.tracker.port), + key=tracker_key, + silent=silent, + no_fork=no_fork, + port=9190, + port_end=12345, + ) + self.tracker_host = self.tracker.host + self.tracker_port = self.tracker.port + self.tracker_key = tracker_key + + def __enter__(self): + return self + + def __exit__(self, _type, _value, _traceback): + if hasattr(self, "server"): + del self.server + if hasattr(self, "tracker"): + del self.tracker diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 4c83b9afa289..9c41b4d575da 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -19,10 +19,10 @@ from typing import Optional, TYPE_CHECKING from tvm import IRModule +from tvm._ffi import register_object +from tvm.meta_schedule.utils import cpu_count from tvm.runtime import Object from tvm.target import Target -from tvm.meta_schedule.utils import cpu_count -from tvm._ffi import register_object from . 
import _ffi_api diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index e710b0ed06f3..5f536994a9fd 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -18,14 +18,14 @@ import json import os import shutil -from typing import Any, Callable, List, Union +from typing import Any, Callable, List, Optional, Union import psutil - from tvm._ffi import get_global_func, register_func from tvm.error import TVMError from tvm.ir import Array, Map -from tvm.runtime import String +from tvm.rpc import RPCSession +from tvm.runtime import PackedFunc, String from tvm.tir import FloatImm, IntImm @@ -95,6 +95,37 @@ def get_global_func_with_default_on_worker( ) from error +def get_global_func_on_rpc_session( + session: RPCSession, + name: str, + extra_error_msg: Optional[str] = None, +) -> PackedFunc: + """Get a PackedFunc from the global registry from an RPCSession. + + Parameters + ---------- + session : RPCSession + The RPCSession to be retrieved from + name : str + The name of the PackedFunc + extra_error_msg : Optional[str] + Extra information to provide in the error message + + Returns + ------- + result : PackedFunc + The result + """ + try: + result = session.get_function(name) + except AttributeError as error: + error_msg = f'Unable to find function "{name}" on the remote RPC server.' + if extra_error_msg: + error_msg = f"{error_msg} {extra_error_msg}" + raise AttributeError(error_msg) from error + return result + + @register_func("meta_schedule.remove_build_dir") def remove_build_dir(artifact_path: str) -> None: """Clean up the build directory""" diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc index 8f509bdd7b84..800a76f21e65 100644 --- a/src/meta_schedule/runner/runner.cc +++ b/src/meta_schedule/runner/runner.cc @@ -16,13 +16,19 @@ * specific language governing permissions and limitations * under the License. 
*/ -#include - #include "../utils.h" namespace tvm { namespace meta_schedule { +RunnerInput::RunnerInput(String artifact_path, String device_type, Array args_info) { + ObjectPtr n = make_object(); + n->artifact_path = artifact_path; + n->device_type = device_type; + n->args_info = args_info; + this->data_ = n; +} + RunnerResult::RunnerResult(Optional> run_secs, Optional error_msg) { ObjectPtr n = make_object(); n->run_secs = run_secs; @@ -30,12 +36,45 @@ RunnerResult::RunnerResult(Optional> run_secs, Optional this->data_ = n; } -TVM_REGISTER_NODE_TYPE(RunnerResultNode); +RunnerFuture::RunnerFuture(RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) { + ObjectPtr n = make_object(); + n->f_done = f_done; + n->f_result = f_result; + this->data_ = n; +} +Runner Runner::PyRunner(Runner::FRun f_run) { + ObjectPtr n = make_object(); + n->f_run = f_run; + return Runner(n); +} + +/******** FFI ********/ + +TVM_REGISTER_NODE_TYPE(RunnerInputNode); +TVM_REGISTER_NODE_TYPE(RunnerResultNode); +TVM_REGISTER_NODE_TYPE(RunnerFutureNode); +TVM_REGISTER_OBJECT_TYPE(RunnerNode); +TVM_REGISTER_NODE_TYPE(PyRunnerNode); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerInput") + .set_body_typed([](String artifact_path, String device_type, + Array args_info) -> RunnerInput { + return RunnerInput(artifact_path, device_type, args_info); + }); TVM_REGISTER_GLOBAL("meta_schedule.RunnerResult") .set_body_typed([](Array run_secs, Optional error_msg) -> RunnerResult { return RunnerResult(run_secs, error_msg); }); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerFuture") + .set_body_typed([](RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) -> RunnerFuture { + return RunnerFuture(f_done, f_result); + }); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureDone") + .set_body_method(&RunnerFutureNode::Done); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureResult") + .set_body_method(&RunnerFutureNode::Result); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerRun").set_body_method(&RunnerNode::Run); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerPyRunner").set_body_typed(Runner::PyRunner); } // namespace meta_schedule } // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py new file mode 100644 index 000000000000..3c8aee0c6d58 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_runner.py @@ -0,0 +1,571 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" Test Meta Schedule Runner """ + +import itertools +import sys +import time +from typing import Any, List + +import numpy as np +import pytest + +import tvm +from tvm import tir +from tvm._ffi import register_func +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder +from tvm.meta_schedule.runner import ( + EvaluatorConfig, + PyRunner, + RPCConfig, + RPCRunner, + RunnerFuture, + RunnerInput, +) +from tvm.meta_schedule.runner.rpc_runner import ( + default_alloc_argument as rpc_default_alloc_argument, +) +from tvm.meta_schedule.testing import LocalRPC +from tvm.meta_schedule.utils import get_global_func_with_default_on_worker +from tvm.rpc import RPCSession +from tvm.runtime import Device, Module +from tvm.script import ty +from tvm.target import Target +import tvm.testing +from tvm.tir import FloatImm + +MATMUL_N = 16 +MATMUL_M = 32 + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking + + +@tvm.script.tir +class MatmulModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.tir +class MatmulReluModule: + def main(a: ty.handle, b: ty.handle, d: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + D = tir.match_buffer(d, (16, 16), "float32") + C = tir.alloc_buffer((16, 16), "float32") + with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + with tir.block([16, 16], "relu") as [vi, vj]: + D[vi, vj] = tir.max(C[vi, vj], 0.0) + + +@tvm.script.tir +class BatchMatmulModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, [16, 32, 32]) + B = tir.match_buffer(b, [16, 32, 32]) + C = tir.match_buffer(c, [16, 32, 32]) + with tir.block([16, 32, 32, tir.reduce_axis(0, 32)], "update") as [vn, vi, vj, vk]: + with tir.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + +@tvm.script.tir +class AddModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, [32], "float32") + B = tir.match_buffer(b, [32], "float32") + C = tir.match_buffer(c, [32], "float32") + with tir.block([32], "add") as [vi]: + C[vi] = A[vi] + B[vi] + + +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring + + +def _clean_build(artifact_path: str) -> None: + f_clean_build = get_global_func_with_default_on_worker("meta_schedule.remove_build_dir", None) + if f_clean_build is not None: + f_clean_build(artifact_path) + else: + raise RuntimeError("Unable to find remove_build_dir function.") + + +def test_meta_schedule_rpc_single_run(): + """Test meta schedule 
rpc runner for a single run""" + # Build the module + mod = MatmulModule() + builder = LocalBuilder() + (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))]) + assert builder_result.artifact_path is not None + assert builder_result.error_msg is None + + runner_input = RunnerInput( + builder_result.artifact_path, + "llvm", + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner(rpc_config, evaluator_config) + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result.error_msg is None + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + _clean_build(builder_result.artifact_path) + + +def test_meta_schedule_rpc_multiple_runs(): + """Test meta schedule rpc runner for multiple runs""" + # Build the module + mods = [ + MatmulModule(), + MatmulReluModule(), + BatchMatmulModule(), + ] + builder = LocalBuilder() + builder_inputs = [BuilderInput(mod, Target("llvm")) for mod in mods] + builder_results = builder.build(builder_inputs) + for builder_result in builder_results: + assert builder_result.artifact_path is not None + assert builder_result.error_msg is None + + args_infos = [ + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + [ + TensorInfo("float32", [16, MATMUL_M, MATMUL_M]), + TensorInfo("float32", [16, MATMUL_M, MATMUL_M]), + TensorInfo("float32", [16, MATMUL_M, MATMUL_M]), + ], + ] + + runner_inputs = [ + RunnerInput(builder_results[i].artifact_path, "llvm", args_infos[i]) + for i in range(len(mods)) + ] + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner(rpc_config, evaluator_config) + # Run the module + runner_futures = runner.run(runner_inputs) + runner_results = [runner_future.result() for runner_future in runner_futures] + + for runner_result in runner_results: + assert runner_result.error_msg is None + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + + for builder_result in builder_results: + _clean_build(builder_result.artifact_path) + + +def test_meta_schedule_py_runner(): + """Test meta schedule PyRunner""" + + class TestRunner(PyRunner): + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + raise ValueError("TestRunner") + + runner = TestRunner() + with pytest.raises(ValueError, match="TestRunner"): + runner.run([]) + + +def test_meta_schedule_rpc_runner_time_out(): + """Test meta schedule RPC Runner 
time out""" + + def initializer(): + @register_func("meta_schedule.runner.test_time_out") + def timeout_session_creator( # pylint: disable=unused-variable + rpc_config: RPCConfig, # pylint: disable=unused-argument + ) -> RPCSession: + time.sleep(2) + + runner_input = RunnerInput( + "test", + "llvm", + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=1, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner( + rpc_config, + evaluator_config, + initializer=initializer, + f_create_session="meta_schedule.runner.test_time_out", + ) + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + + assert runner_result.error_msg is not None and runner_result.error_msg.startswith( + "RPCRunner: Timeout, killed after" + ) + assert runner_result.run_secs is None + + +def test_meta_schedule_rpc_runner_exception(): + """Test meta schedule RPC Runner exception""" + + def initializer(): + @register_func("meta_schedule.runner.test_exception") + def exception_session_creator( # pylint: disable=unused-variable + rpc_config: RPCConfig, # pylint: disable=unused-argument + ) -> RPCSession: + raise Exception("Test") + + runner_input = RunnerInput( + "test", + "llvm", + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner( + rpc_config, + evaluator_config, + initializer=initializer, + f_create_session="meta_schedule.runner.test_exception", + ) + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + + assert runner_result.error_msg is not None and runner_result.error_msg.startswith( + "RPCRunner: An exception occurred\n" + ) + assert runner_result.run_secs is None + + +def test_meta_schedule_runner_matmul_test(): + """Test meta schedule runner with add module""" + + def _check_correct_matmul( + args_before: List[np.ndarray], + args_after: List[np.ndarray], + ) -> None: + a_before, b_before, c_before = args_before + a_after, b_after, c_after = args_after + c_before = np.matmul(a_before, b_before) + assert (a_before == a_after).all() + assert (b_before == b_after).all() + tvm.testing.assert_allclose(c_before, c_after, rtol=1e-5) + + def test_alloc_argument( + session: RPCSession, + device: Device, + args_info: Any, + alloc_repeat: int, + ) -> List[Any]: + global repeated_args_before # pylint: disable=global-variable-undefined, invalid-name + repeated_args_before = [] # type: ignore + repeated_args = rpc_default_alloc_argument(session, device, args_info, alloc_repeat) + for args in repeated_args: + repeated_args_before.append([arg.numpy() for arg in args]) # type: ignore + return repeated_args + + def test_run_evaluator( + session: RPCSession, # pylint: disable=unused-argument + rt_mod: Module, + device: Device, + 
evaluator_config: EvaluatorConfig, + repeated_args: List[Any], + ) -> List[float]: + global repeated_args_before # pylint: disable=global-variable-undefined, invalid-name + repeated_args_after = [] + evaluator = rt_mod.time_evaluator( + func_name=rt_mod.entry_name, + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + device.sync() + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + repeated_args_after.append([arg.numpy() for arg in args]) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + for args_before, args_after in zip( + repeated_args_before, # type: ignore + repeated_args_after, + ): + _check_correct_matmul(args_before, args_after) + del repeated_args_before # type: ignore + return costs + + # Build the module + mod = MatmulModule() + builder = LocalBuilder() + (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))]) + assert builder_result.artifact_path is not None + assert builder_result.error_msg is None + + runner_input = RunnerInput( + builder_result.artifact_path, + "llvm", + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner( + rpc_config, + evaluator_config, + f_alloc_argument=test_alloc_argument, + f_run_evaluator=test_run_evaluator, + ) + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result.error_msg is None + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + _clean_build(builder_result.artifact_path) + + +def test_meta_schedule_runner_add_test(): + """Test meta schedule runner with add module""" + + def _check_correct_add(args_before: List[np.ndarray], args_after: List[np.ndarray]) -> None: + a_before, b_before, c_before = args_before + a_after, b_after, c_after = args_after + c_before = a_before + b_before + assert (a_before == a_after).all() + assert (b_before == b_after).all() + assert (c_before == c_after).all() + + def test_alloc_argument( + session: RPCSession, + device: Device, + args_info: Any, + alloc_repeat: int, + ) -> List[Any]: + global repeated_args_before # pylint: disable=global-variable-undefined, invalid-name + repeated_args_before = [] # type: ignore + repeated_args = rpc_default_alloc_argument( + session, + device, + args_info, + alloc_repeat, + ) + for args in repeated_args: + repeated_args_before.append([arg.numpy() for arg in args]) # type: ignore + return repeated_args + + def test_run_evaluator( + session: RPCSession, # pylint: disable=unused-argument + rt_mod: Module, + device: Device, + evaluator_config: EvaluatorConfig, + repeated_args: List[Any], + ) -> List[float]: + global repeated_args_before # pylint: disable=global-variable-undefined, invalid-name + repeated_args_after = [] + evaluator = 
rt_mod.time_evaluator( + func_name=rt_mod.entry_name, + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + device.sync() + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + repeated_args_after.append([arg.numpy() for arg in args]) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + for args_before, args_after in zip( + repeated_args_before, # type: ignore + repeated_args_after, + ): + _check_correct_add(args_before, args_after) + del repeated_args_before # type: ignore + return costs + + # Build the module + mod = AddModule() + builder = LocalBuilder() + (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))]) + assert builder_result.artifact_path is not None + assert builder_result.error_msg is None + + runner_input = RunnerInput( + builder_result.artifact_path, + "llvm", + [ + TensorInfo("float32", [MATMUL_M]), + TensorInfo("float32", [MATMUL_M]), + TensorInfo("float32", [MATMUL_M]), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner( + rpc_config, + evaluator_config, + f_alloc_argument=test_alloc_argument, + f_run_evaluator=test_run_evaluator, + ) + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result.error_msg is None + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + _clean_build(builder_result.artifact_path) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))
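
The runner interface added in this patch is intentionally abstract on the Python side: a concrete runner overrides `PyRunner.run` and returns `RunnerFuture` objects whose `done`/`result` methods are polled later, exactly as `RPCRunnerFuture` does. Below is a minimal sketch of that pattern; `EchoRunner` and `EchoRunnerFuture` are hypothetical names used only for illustration (this is not the `LocalRunner` planned for the follow-up PR), and the runner reports a fixed run time instead of measuring anything.

.. code-block:: python

    from typing import List

    from tvm.meta_schedule.runner import (
        PyRunner,
        RunnerFuture,
        RunnerInput,
        RunnerResult,
    )


    class EchoRunnerFuture(RunnerFuture):
        """A future that is already complete when constructed (hypothetical)."""

        def __init__(self, result: RunnerResult) -> None:
            super().__init__()
            self.wrapped_result = result

        def done(self) -> bool:
            # The artifact is never actually run, so the result is always ready.
            return True

        def result(self) -> RunnerResult:
            return self.wrapped_result


    class EchoRunner(PyRunner):
        """Reports a fixed 1 ms run time for every input instead of measuring it."""

        def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
            return [
                EchoRunnerFuture(RunnerResult([1e-3], None)) for _ in runner_inputs
            ]

Because `done` and `result` are forwarded through the packed functions registered in `RunnerFuture.__init__`, a runner written this way can be polled exactly like the RPC-backed futures above.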
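For the reference `RPCRunner`, the configuration objects above are all that is needed to drive a measurement once an artifact has been built. The sketch below is illustrative only: it assumes an RPC tracker and at least one server registered under the key `"local"` are already running (for example via the `LocalRPC` helper in `testing.py`, or an external tracker whose coordinates could instead be supplied through the `TVM_TRACKER_HOST`/`TVM_TRACKER_PORT`/`TVM_TRACKER_KEY` environment variables), and `artifact_path` is a hypothetical placeholder for a library previously produced by `LocalBuilder`.

.. code-block:: python

    from tvm.meta_schedule.arg_info import TensorInfo
    from tvm.meta_schedule.runner import (
        EvaluatorConfig,
        RPCConfig,
        RPCRunner,
        RunnerInput,
    )

    # Fields left as None would be filled from TVM_TRACKER_HOST/PORT/KEY by
    # RPCConfig._normalized; here they are spelled out for clarity.
    rpc_config = RPCConfig(
        tracker_host="127.0.0.1",  # assumed tracker address
        tracker_port=9190,         # assumed tracker port
        tracker_key="local",       # assumed key the server registered with
        session_priority=1,
        session_timeout_sec=100,
    )
    # One warm-up run plus number * repeat timed runs per measurement.
    evaluator_config = EvaluatorConfig(
        number=3,
        repeat=1,
        min_repeat_ms=40,
        enable_cpu_cache_flush=False,
    )
    runner = RPCRunner(rpc_config, evaluator_config)

    # Hypothetical artifact: the path returned in BuilderResult.artifact_path
    # for a 16x16 float32 matmul; args_info mirrors its three buffer arguments.
    artifact_path = "/tmp/tvm_tmp_mod.tar"
    runner_input = RunnerInput(
        artifact_path,
        "llvm",
        [
            TensorInfo("float32", (16, 16)),
            TensorInfo("float32", (16, 16)),
            TensorInfo("float32", (16, 16)),
        ],
    )
    (future,) = runner.run([runner_input])
    result = future.result()  # blocks until the worker finishes or times out
    if result.error_msg is None:
        print(result.run_secs)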