Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Developer Ergonomics Enhancement #11622

Merged
merged 1 commit into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,10 @@ class SearchStrategyNode : public runtime::Object {

/*!
* \brief Update the search strategy with measurement results.
* \param context The tuning context.
* \param measure_candidates The candidates to be measured.
* \param results The measurement results from the runner.
*/
virtual void NotifyRunnerResults(const TuneContext& context,
const Array<MeasureCandidate>& measure_candidates,
virtual void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
const Array<RunnerResult>& results) = 0;

static constexpr const char* _type_key = "meta_schedule.SearchStrategy";
Expand Down Expand Up @@ -150,8 +148,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
* \brief The function type of `NotifyRunnerResults` method.
* \param results The measurement results from the runner.
*/
using FNotifyRunnerResults = runtime::TypedPackedFunc<void(
const TuneContext&, const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
using FNotifyRunnerResults =
runtime::TypedPackedFunc<void(const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;

/*! \brief The packed function to the `InitializeWithTuneContext` method. */
FInitializeWithTuneContext f_initialize_with_tune_context;
Expand All @@ -177,8 +175,7 @@ class PySearchStrategyNode : public SearchStrategyNode {
const Optional<CostModel>& cost_model) final;
void PostTuning() final;
Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final;
void NotifyRunnerResults(const TuneContext& context,
const Array<MeasureCandidate>& measure_candidates,
void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
const Array<RunnerResult>& results);

static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
Expand Down
27 changes: 23 additions & 4 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace tvm {
namespace meta_schedule {

class TaskSchedulerNode;
class MeasureCallback;

/*! \brief The auto tuning context. */
class TuneContextNode : public runtime::Object {
Expand Down Expand Up @@ -70,7 +71,7 @@ class TuneContextNode : public runtime::Object {
int num_threads;

/*! \brief Whether the tuning task has been stopped or finished. */
bool is_terminated;
bool is_terminated; // TODO(@junrushao1994): move to TaskScheduler
/*! \brief The measure candidates. */
Optional<Array<MeasureCandidate>> measure_candidates;
/*! \brief The building results. */
Expand All @@ -87,18 +88,36 @@ class TuneContextNode : public runtime::Object {
v->Visit("postprocs", &postprocs);
v->Visit("mutator_probs", &mutator_probs);
v->Visit("task_name", &task_name);
// `logging_func` is not visited
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
v->Visit("is_terminated", &is_terminated);
v->Visit("measure_candidates", &measure_candidates);
v->Visit("builder_results", &builder_results);
v->Visit("runner_futures", &runner_futures);
v->Visit("measure_candidates", &measure_candidates);
// `logging_func` is not visited
}

/*! \brief Initialize members that needs initialization with tune context. */
void Initialize();

/*! \brief Set the measure candidates from the SearchStrategy */
void _SetMeasureCandidates(const Array<MeasureCandidate>& candidates);
/*!
* \brief Send the measure candidates to builder.
* \param builder The builder to send the candidates to.
*/
void _SendToBuilder(const Builder& builder);
/*!
* \brief Send the built measure candidates to runner.
* \param runner The runner to send the candidates to.
*/
void _SendToRunner(const Runner& runner);
/*!
* \brief Join the running tasks.
* \returns The results from the runner
*/
Array<RunnerResult> _Join();
/*! \brief Set `measure_candidates`, `builder_results` and `runner_futures` to null. */
void _ClearMeasureState();
static constexpr const char* _type_key = "meta_schedule.TuneContext";
TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object);
};
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
builder,
cost_model,
database,
default_config,
feature_extractor,
measure_callback,
mutator,
postproc,
runner,
Expand All @@ -32,5 +34,6 @@
from .extracted_task import ExtractedTask
from .relay_integration import extract_task_from_relay
from .search_strategy import MeasureCandidate
from .tune import TuneConfig, tune_relay, tune_te, tune_tir
from .tune import TuneConfig, tune_extracted_tasks, tune_relay, tune_te, tune_tir
from .tune_context import TuneContext
from .utils import derived_object
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
"""
from .database import Database, PyDatabase, TuningRecord, Workload
from .json_database import JSONDatabase
from .memory_database import MemoryDatabase
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tuning record database"""
"""TuningRecord database"""
from typing import Any, Callable, List, Optional

from tvm._ffi import register_object
Expand Down
63 changes: 63 additions & 0 deletions python/tvm/meta_schedule/database/memory_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A database that stores TuningRecords in memory"""
from typing import List

from ...ir import IRModule, structural_equal
from ..utils import derived_object
from .database import PyDatabase, TuningRecord, Workload


@derived_object
class MemoryDatabase(PyDatabase):
"""An in-memory database based on python list for testing."""

def __init__(self):
super().__init__()
self.records = []
self.workload_reg = []

def has_workload(self, mod: IRModule) -> bool:
for workload in self.workload_reg:
if structural_equal(workload.mod, mod):
return True
return False

def commit_tuning_record(self, record: TuningRecord) -> None:
self.records.append(record)

def commit_workload(self, mod: IRModule) -> Workload:
for workload in self.workload_reg:
if structural_equal(workload.mod, mod):
return workload
workload = Workload(mod)
self.workload_reg.append(workload)
return workload

def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
return list(
filter(
lambda x: x.workload == workload,
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
)
)[: int(top_k)]

def __len__(self) -> int:
return len(self.records)

def print_results(self) -> None:
print("\n".join([str(r) for r in self.records]))
Loading