20
20
#define TVM_META_SCHEDULE_RUNNER_H_
21
21
22
22
#include < tvm/ir/expr.h>
23
+ #include < tvm/meta_schedule/arg_info.h>
23
24
24
25
namespace tvm {
25
26
namespace meta_schedule {
26
27
27
- /* ! \brief Runner's output containing measurement result of MeasureCandidate or error msg if any. */
28
+ /* ! \brief The runner's input. */
29
+ class RunnerInputNode : public runtime ::Object {
30
+ public:
31
+ /* ! \brief The path to the built artifact. */
32
+ String artifact_path;
33
+ /* ! \brief The type of device. */
34
+ String device_type;
35
+ /* ! \brief The argument information. */
36
+ Array<ArgInfo> args_info;
37
+
38
+ void VisitAttrs (tvm::AttrVisitor* v) {
39
+ v->Visit (" artifact_path" , &artifact_path);
40
+ v->Visit (" device_type" , &device_type);
41
+ v->Visit (" args_info" , &args_info);
42
+ }
43
+
44
+ static constexpr const char * _type_key = " meta_schedule.RunnerInput" ;
45
+ TVM_DECLARE_FINAL_OBJECT_INFO (RunnerInputNode, runtime::Object);
46
+ };
47
+
48
+ /* !
49
+ * \brief Managed reference to RunnerInputNode
50
+ * \sa RunnerInputNode
51
+ */
52
+ class RunnerInput : public runtime ::ObjectRef {
53
+ public:
54
+ /* !
55
+ * \brief Constructor of RunnerInput
56
+ * \param artifact_path The path to the built artifact.
57
+ * \param device_type The type of device.
58
+ * \param args_info The argument information.
59
+ */
60
+ TVM_DLL explicit RunnerInput (String artifact_path, String device_type, Array<ArgInfo> args_info);
61
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS (RunnerInput, runtime::ObjectRef, RunnerInputNode);
62
+ };
63
+
64
+ /* ! \brief The runner's output. */
28
65
class RunnerResultNode : public runtime ::Object {
29
66
public:
30
- /* ! \brief The run time in seconds. If not None, error_msg should be None. */
67
+ /* ! \brief The run time in seconds.*/
31
68
Optional<Array<FloatImm>> run_secs;
32
- /* ! \brief The error message, if any. If not None, run_secs should be None. */
69
+ /* ! \brief The error message, if any. */
33
70
Optional<String> error_msg;
34
71
35
72
void VisitAttrs (tvm::AttrVisitor* v) {
@@ -48,14 +85,134 @@ class RunnerResultNode : public runtime::Object {
48
85
class RunnerResult : public runtime ::ObjectRef {
49
86
public:
50
87
/* !
51
- * \brief Constructor for RunnerResult.
52
- * \param run_secs The run time in seconds.
53
- * \param error_msg The error message, if any.
88
+ * \brief Constructor
89
+ * \brief The run time in seconds.
90
+ * \brief The error message, if any.
54
91
*/
55
92
TVM_DLL explicit RunnerResult (Optional<Array<FloatImm>> run_secs, Optional<String> error_msg);
56
93
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS (RunnerResult, runtime::ObjectRef, RunnerResultNode);
57
94
};
58
95
96
+ /* !
97
+ * \brief A class to asynchronously fetch runner's output.
98
+ * \note The API design is consistent with python's concurrent.futures.Future:
99
+ * https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future
100
+ */
101
+ class RunnerFutureNode : public runtime ::Object {
102
+ public:
103
+ /* !
104
+ * \brief The function type to check whether the runner has finished.
105
+ * \return Whether the runner's output is ready.
106
+ */
107
+ using FDone = runtime::TypedPackedFunc<bool ()>;
108
+ /* !
109
+ * \brief The function type to fetch runner output if it is ready.
110
+ * \return The runner's output.
111
+ */
112
+ using FResult = runtime::TypedPackedFunc<RunnerResult()>;
113
+
114
+ /* ! \brief The packed function to check whether the runner has finished. */
115
+ FDone f_done;
116
+ /* ! \brief The packed function to fetch runner output if it is ready. */
117
+ FResult f_result;
118
+
119
+ void VisitAttrs (tvm::AttrVisitor* v) {
120
+ // `f_done` is not visited
121
+ // `f_result` is not visited
122
+ }
123
+
124
+ /* !
125
+ * \brief Check whether the runner has finished.
126
+ * \return A boolean indicating whether the runner has finished.
127
+ */
128
+ bool Done () const { return f_done (); }
129
+ /* !
130
+ * \brief Fetch the runner's output if it is ready.
131
+ * \return The runner's output.
132
+ */
133
+ RunnerResult Result () const { return f_result (); }
134
+
135
+ static constexpr const char * _type_key = " meta_schedule.RunnerFuture" ;
136
+ TVM_DECLARE_FINAL_OBJECT_INFO (RunnerFutureNode, runtime::Object);
137
+ };
138
+
139
+ /* !
140
+ * \brief Managed reference to RunnerFutureNode
141
+ * \sa RunnerFutureNode
142
+ */
143
+ class RunnerFuture : public runtime ::ObjectRef {
144
+ public:
145
+ using FDone = RunnerFutureNode::FDone;
146
+ using FResult = RunnerFutureNode::FResult;
147
+
148
+ /* !
149
+ * \brief Constructor of RunnerFuture
150
+ * \param f_done The packed function to check whether the runner has finished.
151
+ * \param f_result The packed function to fetch runner output if it is ready.
152
+ */
153
+ TVM_DLL explicit RunnerFuture (FDone f_done, FResult f_result);
154
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS (RunnerFuture, runtime::ObjectRef,
155
+ RunnerFutureNode);
156
+ };
157
+
158
+ /* ! \brief The abstract runner interface. */
159
+ class RunnerNode : public runtime ::Object {
160
+ public:
161
+ /* !
162
+ * \brief The function type to run the built artifacts and get runner futures.
163
+ * \param input The runner's inputs.
164
+ * \return The runner futures.
165
+ * \sa RunnerFuture
166
+ */
167
+ using FRun = runtime::TypedPackedFunc<Array<RunnerFuture>(Array<RunnerInput>)>;
168
+
169
+ /* ! \brief Default destructor */
170
+ virtual ~RunnerNode () = default ;
171
+
172
+ /* !
173
+ * \brief Run the built artifact and get runner futures.
174
+ * \param runner_inputs The runner's inputs.
175
+ * \return The runner futures.
176
+ */
177
+ virtual Array<RunnerFuture> Run (Array<RunnerInput> runner_inputs) = 0;
178
+
179
+ static constexpr const char * _type_key = " meta_schedule.Runner" ;
180
+ TVM_DECLARE_BASE_OBJECT_INFO (RunnerNode, runtime::Object);
181
+ };
182
+
183
+ /* !
184
+ * \brief Managed reference to RunnerNode
185
+ * \sa RunnerNode
186
+ */
187
+ class Runner : public runtime ::ObjectRef {
188
+ public:
189
+ using FRun = RunnerNode::FRun;
190
+
191
+ /* !
192
+ * \brief Create a runner with customized build method on the python-side.
193
+ * \param f_run The packed function to run the built artifacts and get runner futures.
194
+ * \return The runner created.
195
+ */
196
+ TVM_DLL static Runner PyRunner (FRun f_run);
197
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS (Runner, runtime::ObjectRef, RunnerNode);
198
+ };
199
+
200
+ /* ! \brief An abstract runner with customized build method on the python-side. */
201
+ class PyRunnerNode : public RunnerNode {
202
+ public:
203
+ /* ! \brief The packed function to run the built artifacts and get runner futures. */
204
+ FRun f_run;
205
+
206
+ void VisitAttrs (tvm::AttrVisitor* v) {
207
+ // `f_run` is not visited
208
+ }
209
+
210
+ Array<RunnerFuture> Run (Array<RunnerInput> runner_inputs) final { return f_run (runner_inputs); }
211
+
212
+ static constexpr const char * _type_key = " meta_schedule.PyRunner" ;
213
+ TVM_DECLARE_FINAL_OBJECT_INFO (PyRunnerNode, RunnerNode);
214
+ };
215
+
59
216
} // namespace meta_schedule
60
217
} // namespace tvm
61
218
0 commit comments