Skip to content

Commit a23f8ce

Browse files
junrushaoylc
authored andcommitted
[Meta Schedule][M3b] Runner (apache#9111)
This PR is part of the meta schedule project (apache#8473) that adds the asynchronous program runner interface, as well as a reference implementation of RPCRunner. LocalRunner will be implemented with PopenPool executor in a follow-up PR. Co-authored-by: Xiyou Zhou <xiyou@octoml.ai> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin <wuwei@apache.org> Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> Address comments Co-authored-by: Cody Yu <comaniac0422@gmail.com> fix lint
1 parent c7af20a commit a23f8ce

File tree

12 files changed

+1776
-23
lines changed

12 files changed

+1776
-23
lines changed

include/tvm/meta_schedule/runner.h

+163-6
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,53 @@
2020
#define TVM_META_SCHEDULE_RUNNER_H_
2121

2222
#include <tvm/ir/expr.h>
23+
#include <tvm/meta_schedule/arg_info.h>
2324

2425
namespace tvm {
2526
namespace meta_schedule {
2627

27-
/*! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */
28+
/*! \brief The runner's input. */
29+
class RunnerInputNode : public runtime::Object {
30+
public:
31+
/*! \brief The path to the built artifact. */
32+
String artifact_path;
33+
/*! \brief The type of device. */
34+
String device_type;
35+
/*! \brief The argument information. */
36+
Array<ArgInfo> args_info;
37+
38+
void VisitAttrs(tvm::AttrVisitor* v) {
39+
v->Visit("artifact_path", &artifact_path);
40+
v->Visit("device_type", &device_type);
41+
v->Visit("args_info", &args_info);
42+
}
43+
44+
static constexpr const char* _type_key = "meta_schedule.RunnerInput";
45+
TVM_DECLARE_FINAL_OBJECT_INFO(RunnerInputNode, runtime::Object);
46+
};
47+
48+
/*!
49+
* \brief Managed reference to RunnerInputNode
50+
* \sa RunnerInputNode
51+
*/
52+
class RunnerInput : public runtime::ObjectRef {
53+
public:
54+
/*!
55+
* \brief Constructor of RunnerInput
56+
* \param artifact_path The path to the built artifact.
57+
* \param device_type The type of device.
58+
* \param args_info The argument information.
59+
*/
60+
TVM_DLL explicit RunnerInput(String artifact_path, String device_type, Array<ArgInfo> args_info);
61+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode);
62+
};
63+
64+
/*! \brief The runner's output. */
2865
class RunnerResultNode : public runtime::Object {
2966
public:
30-
/*! \brief The run time in seconds. If not None, error_msg should be None. */
67+
/*! \brief The run time in seconds.*/
3168
Optional<Array<FloatImm>> run_secs;
32-
/*! \brief The error message, if any. If not None, run_secs should be None. */
69+
/*! \brief The error message, if any. */
3370
Optional<String> error_msg;
3471

3572
void VisitAttrs(tvm::AttrVisitor* v) {
@@ -48,14 +85,134 @@ class RunnerResultNode : public runtime::Object {
4885
class RunnerResult : public runtime::ObjectRef {
4986
public:
5087
/*!
51-
* \brief Constructor for RunnerResult.
52-
* \param run_secs The run time in seconds.
53-
* \param error_msg The error message, if any.
88+
* \brief Constructor
89+
* \brief The run time in seconds.
90+
* \brief The error message, if any.
5491
*/
5592
TVM_DLL explicit RunnerResult(Optional<Array<FloatImm>> run_secs, Optional<String> error_msg);
5693
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode);
5794
};
5895

96+
/*!
97+
* \brief A class to asynchronously fetch runner's output.
98+
* \note The API design is consistent with python's concurrent.futures.Future:
99+
* https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
100+
*/
101+
class RunnerFutureNode : public runtime::Object {
102+
public:
103+
/*!
104+
* \brief The function type to check whether the runner has finished.
105+
* \return Whether the runner's output is ready.
106+
*/
107+
using FDone = runtime::TypedPackedFunc<bool()>;
108+
/*!
109+
* \brief The function type to fetch runner output if it is ready.
110+
* \return The runner's output.
111+
*/
112+
using FResult = runtime::TypedPackedFunc<RunnerResult()>;
113+
114+
/*! \brief The packed function to check whether the runner has finished. */
115+
FDone f_done;
116+
/*! \brief The packed function to fetch runner output if it is ready. */
117+
FResult f_result;
118+
119+
void VisitAttrs(tvm::AttrVisitor* v) {
120+
// `f_done` is not visited
121+
// `f_result` is not visited
122+
}
123+
124+
/*!
125+
* \brief Check whether the runner has finished.
126+
* \return A boolean indicating whether the runner has finished.
127+
*/
128+
bool Done() const { return f_done(); }
129+
/*!
130+
* \brief Fetch the runner's output if it is ready.
131+
* \return The runner's output.
132+
*/
133+
RunnerResult Result() const { return f_result(); }
134+
135+
static constexpr const char* _type_key = "meta_schedule.RunnerFuture";
136+
TVM_DECLARE_FINAL_OBJECT_INFO(RunnerFutureNode, runtime::Object);
137+
};
138+
139+
/*!
140+
* \brief Managed reference to RunnerFutureNode
141+
* \sa RunnerFutureNode
142+
*/
143+
class RunnerFuture : public runtime::ObjectRef {
144+
public:
145+
using FDone = RunnerFutureNode::FDone;
146+
using FResult = RunnerFutureNode::FResult;
147+
148+
/*!
149+
* \brief Constructor of RunnerFuture
150+
* \param f_done The packed function to check whether the runner has finished.
151+
* \param f_result The packed function to fetch runner output if it is ready.
152+
*/
153+
TVM_DLL explicit RunnerFuture(FDone f_done, FResult f_result);
154+
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerFuture, runtime::ObjectRef,
155+
RunnerFutureNode);
156+
};
157+
158+
/*! \brief The abstract runner interface. */
159+
class RunnerNode : public runtime::Object {
160+
public:
161+
/*!
162+
* \brief The function type to run the built artifacts and get runner futures.
163+
* \param input The runner's inputs.
164+
* \return The runner futures.
165+
* \sa RunnerFuture
166+
*/
167+
using FRun = runtime::TypedPackedFunc<Array<RunnerFuture>(Array<RunnerInput>)>;
168+
169+
/*! \brief Default destructor */
170+
virtual ~RunnerNode() = default;
171+
172+
/*!
173+
* \brief Run the built artifact and get runner futures.
174+
* \param runner_inputs The runner's inputs.
175+
* \return The runner futures.
176+
*/
177+
virtual Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) = 0;
178+
179+
static constexpr const char* _type_key = "meta_schedule.Runner";
180+
TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, runtime::Object);
181+
};
182+
183+
/*!
184+
* \brief Managed reference to RunnerNode
185+
* \sa RunnerNode
186+
*/
187+
class Runner : public runtime::ObjectRef {
188+
public:
189+
using FRun = RunnerNode::FRun;
190+
191+
/*!
192+
* \brief Create a runner with customized build method on the python-side.
193+
* \param f_run The packed function to run the built artifacts and get runner futures.
194+
* \return The runner created.
195+
*/
196+
TVM_DLL static Runner PyRunner(FRun f_run);
197+
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Runner, runtime::ObjectRef, RunnerNode);
198+
};
199+
200+
/*! \brief An abstract runner with customized build method on the python-side. */
201+
class PyRunnerNode : public RunnerNode {
202+
public:
203+
/*! \brief The packed function to run the built artifacts and get runner futures. */
204+
FRun f_run;
205+
206+
void VisitAttrs(tvm::AttrVisitor* v) {
207+
// `f_run` is not visited
208+
}
209+
210+
Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) final { return f_run(runner_inputs); }
211+
212+
static constexpr const char* _type_key = "meta_schedule.PyRunner";
213+
TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode);
214+
};
215+
59216
} // namespace meta_schedule
60217
} // namespace tvm
61218

python/tvm/meta_schedule/__init__.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@
1616
# under the License.
1717
"""Package `tvm.meta_schedule`. The meta schedule infrastructure."""
1818
from . import arg_info
19-
from . import builder
2019
from . import database
20+
from . import builder
21+
from . import runner
2122
from . import space_generator
2223
from . import search_strategy
23-
from . import runner
24-
from .database import TuningRecord
2524
from .tune_context import TuneContext

python/tvm/meta_schedule/builder/local_builder.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,20 @@ class LocalBuilder(PyBuilder):
4848
Attributes
4949
----------
5050
T_BUILD : typing._GenericAlias
51-
The signature of the build function `f_build`, which is
52-
`Callable[[IRModule, Target], Module]`
51+
The signature of the function `f_build`, which is
52+
53+
.. code-block:: python
54+
55+
def default_build(mod: IRModule, target: Target) -> Module:
56+
...
57+
5358
T_EXPORT : typing._GenericAlias
54-
The signature of the build function `f_export`, which is
55-
`Callable[[Module], str]`
59+
The signature of the function `f_export`, which is
60+
61+
.. code-block:: python
62+
63+
def default_export(mod: Module) -> str:
64+
...
5665
5766
Note
5867
----

python/tvm/meta_schedule/runner/__init__.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,10 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
"""meta_schedule.runner"""
18-
from .runner import RunnerResult
17+
"""
18+
The tvm.meta_schedule.runner package.
19+
Meta Schedule runners that runs an artifact either locally or through the RPC interface
20+
"""
21+
from .config import EvaluatorConfig, RPCConfig
22+
from .rpc_runner import RPCRunner
23+
from .runner import PyRunner, Runner, RunnerFuture, RunnerInput, RunnerResult

0 commit comments

Comments
 (0)