Skip to content

Commit

Permalink
[Bug] Fix Infinite Loop Caused When Calling Methods Not Overrided In …
Browse files Browse the repository at this point in the history
…PyClass (apache#496)

* Add replay func.

* Simplify unittest.

* Fix the infinite loop for not implemented methods in PyClass.

* Fix funcname.

* Restore the test.

* Fix all the PyClaes with optional function override.

* Fix __str__ and __len__.

* Move NotImplementedError to declaration.

* Add docs.

* Rebase.
  • Loading branch information
zxybazh authored Nov 4, 2021
1 parent bf83546 commit 8f58137
Show file tree
Hide file tree
Showing 31 changed files with 321 additions and 51 deletions.
1 change: 1 addition & 0 deletions include/tvm/meta_schedule/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class PyBuilderNode : public BuilderNode {
}

Array<BuilderResult> Build(const Array<BuilderInput>& build_inputs) final {
ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!";
return f_build(build_inputs);
}

Expand Down
23 changes: 17 additions & 6 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,29 @@ class PyDatabaseNode : public DatabaseNode {
// `f_size` is not visited
}

static constexpr const char* _type_key = "meta_schedule.PyDatabase";
TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode);

Workload CommitWorkload(const IRModule& mod) final { return f_commit_workload(mod); }
Workload CommitWorkload(const IRModule& mod) final {
ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
return f_commit_workload(mod);
}

void CommitTuningRecord(const TuningRecord& record) final { f_commit_tuning_record(record); }
void CommitTuningRecord(const TuningRecord& record) final {
ICHECK(f_commit_tuning_record != nullptr)
<< "PyDatabase's CommitTuningRecord method not implemented!";
f_commit_tuning_record(record);
}

Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
return f_get_top_k(workload, top_k);
}

int64_t Size() final { return f_size(); }
int64_t Size() final {
ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
return f_size();
}

static constexpr const char* _type_key = "meta_schedule.PyDatabase";
TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode);
};

/*!
Expand Down
1 change: 1 addition & 0 deletions include/tvm/meta_schedule/measure_callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class PyMeasureCallbackNode : public MeasureCallbackNode {
const Array<MeasureCandidate>& measure_candidates, //
const Array<BuilderResult>& builds, //
const Array<RunnerResult>& results) final {
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results);
}

Expand Down
7 changes: 6 additions & 1 deletion include/tvm/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,15 @@ class PyMutatorNode : public MutatorNode {
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyMutator's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

Optional<tir::Trace> Apply(const tir::Trace& trace) final { return this->f_apply(trace); }
Optional<tir::Trace> Apply(const tir::Trace& trace) final {
ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!";
return this->f_apply(trace);
}

static constexpr const char* _type_key = "meta_schedule.PyMutator";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,15 @@ class PyPostprocNode : public PostprocNode {
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyPostproc's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

bool Apply(const tir::Schedule& sch) final { return this->f_apply(sch); }
bool Apply(const tir::Schedule& sch) final {
ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!";
return this->f_apply(sch);
}

static constexpr const char* _type_key = "meta_schedule.PyPostproc";
TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
Expand Down
5 changes: 4 additions & 1 deletion include/tvm/meta_schedule/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ class PyRunnerNode : public RunnerNode {
// `f_run` is not visited
}

Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) final { return f_run(runner_inputs); }
Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) final {
ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!";
return f_run(runner_inputs);
}

static constexpr const char* _type_key = "meta_schedule.PyRunner";
TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode);
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,13 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PyScheduleRule's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final {
ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!";
return this->f_apply(sch, block);
}

Expand Down
12 changes: 11 additions & 1 deletion include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,20 +188,30 @@ class PySearchStrategyNode : public SearchStrategyNode {
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PySearchStrategy's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

void PreTuning(const Array<tir::Schedule>& design_spaces) final {
ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!";
this->f_pre_tuning(design_spaces);
}

void PostTuning() final { this->f_post_tuning(); }
void PostTuning() final {
ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!";
this->f_post_tuning();
}

Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
ICHECK(f_generate_measure_candidates != nullptr)
<< "PySearchStrategy's GenerateMeasureCandidates method not implemented!";
return this->f_generate_measure_candidates();
}

void NotifyRunnerResults(const Array<RunnerResult>& results) final {
ICHECK(f_notify_runner_results != nullptr)
<< "PySearchStrategy's NotifyRunnerResults method not implemented!";
this->f_notify_runner_results(results);
}

Expand Down
4 changes: 4 additions & 0 deletions include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,14 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
}

void InitializeWithTuneContext(const TuneContext& tune_context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PySpaceGenerator's InitializeWithTuneContext !";
f_initialize_with_tune_context(tune_context);
}

Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final {
ICHECK(f_generate_design_space != nullptr)
<< "PySpaceGenerator's GenerateDesignSpace method not implemented!";
return f_generate_design_space(mod);
}

Expand Down
33 changes: 28 additions & 5 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ class TaskSchedulerNode : public runtime::Object {
TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object);
};

class TaskScheduler;

/*! \brief The task scheduler with customized methods on the python-side. */
class PyTaskSchedulerNode : public TaskSchedulerNode {
public:
Expand Down Expand Up @@ -183,26 +185,47 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
}

void Tune() final { //
f_tune();
if (f_tune == nullptr) {
TaskSchedulerNode::Tune();
} else {
f_tune();
}
}

void InitializeTask(int task_id) final { //
f_initialize_task(task_id);
if (f_initialize_task == nullptr) {
TaskSchedulerNode::InitializeTask(task_id);
} else {
f_initialize_task(task_id);
}
}

void SetTaskStopped(int task_id) final { //
f_set_task_stopped(task_id);
if (f_set_task_stopped == nullptr) {
TaskSchedulerNode::SetTaskStopped(task_id);
} else {
f_set_task_stopped(task_id);
}
}

bool IsTaskRunning(int task_id) final { //
return f_is_task_running(task_id);
if (f_is_task_running == nullptr) {
return TaskSchedulerNode::IsTaskRunning(task_id);
} else {
return f_is_task_running(task_id);
}
}

void JoinRunningTask(int task_id) final { //
f_join_running_task(task_id);
if (f_join_running_task == nullptr) {
return TaskSchedulerNode::JoinRunningTask(task_id);
} else {
return f_join_running_task(task_id);
}
}

int NextTaskId() final { //
ICHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!";
return f_next_task_id();
}

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm.target import Target

from .. import _ffi_api
from ..utils import check_override


@register_object("meta_schedule.BuilderInput")
Expand Down Expand Up @@ -119,6 +120,7 @@ class PyBuilder(Builder):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, Builder)
def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]:
return self.build(build_inputs)

Expand Down
8 changes: 6 additions & 2 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .. import _ffi_api
from ..arg_info import ArgInfo
from ..utils import _json_de_tvm
from ..utils import _json_de_tvm, check_override


@register_object("meta_schedule.Workload")
Expand Down Expand Up @@ -207,15 +207,19 @@ class PyDatabase(Database):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, Database)
def f_commit_workload(mod: IRModule) -> Workload:
return self.commit_workload(mod)

@check_override(self.__class__, Database)
def f_commit_tuning_record(record: TuningRecord) -> None:
self.commit_tuning_record(record)

@check_override(self.__class__, Database)
def f_get_top_k(workload: Workload, top_k: int) -> List[TuningRecord]:
return self.get_top_k(workload, top_k)

@check_override(self.__class__, Database, func_name="__len__")
def f_size() -> int:
return len(self)

Expand All @@ -225,4 +229,4 @@ def f_size() -> int:
f_commit_tuning_record,
f_get_top_k,
f_size,
)
)
13 changes: 7 additions & 6 deletions python/tvm/meta_schedule/measure_callback/measure_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@

from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.search_strategy import MeasureCandidate
from tvm.meta_schedule.builder import BuilderResult
from tvm.meta_schedule.runner import RunnerResult
from tvm.meta_schedule.utils import _get_hex_address

from ..tune_context import TuneContext
from ..search_strategy import MeasureCandidate
from ..builder import BuilderResult
from ..runner import RunnerResult
from ..utils import _get_hex_address, check_override

from .. import _ffi_api

if TYPE_CHECKING:
from ..tune_context import TuneContext
from ..task_scheduler import TaskScheduler


Expand Down Expand Up @@ -77,6 +77,7 @@ class PyMeasureCallback(MeasureCallback):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, MeasureCallback)
def f_apply(
task_scheduler: "TaskScheduler",
tasks: List[TuneContext],
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/meta_schedule/mutator/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm.runtime import Object
from tvm.tir.schedule import Trace

from ..utils import _get_hex_address
from ..utils import _get_hex_address, check_override
from .. import _ffi_api

if TYPE_CHECKING:
Expand Down Expand Up @@ -66,9 +66,11 @@ class PyMutator(Mutator):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, Mutator)
def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
self.initialize_with_tune_context(tune_context)

@check_override(self.__class__, Mutator)
def f_apply(trace: Trace) -> Optional[Trace]:
return self.apply(trace)

Expand Down
6 changes: 4 additions & 2 deletions python/tvm/meta_schedule/postproc/postproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@

from typing import TYPE_CHECKING

from tvm._ffi import register_object, register_func
from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.tir.schedule import Schedule
from tvm.meta_schedule.utils import _get_hex_address

from .. import _ffi_api
from ..utils import _get_hex_address, check_override

if TYPE_CHECKING:
from ..tune_context import TuneContext
Expand Down Expand Up @@ -75,9 +75,11 @@ class PyPostproc(Postproc):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, Postproc)
def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
self.initialize_with_tune_context(tune_context)

@check_override(self.__class__, Postproc)
def f_apply(sch: Schedule) -> bool:
return self.apply(sch)

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from .. import _ffi_api
from ..arg_info import ArgInfo
from ..utils import check_override


@register_object("meta_schedule.RunnerInput")
Expand Down Expand Up @@ -158,6 +159,7 @@ class PyRunner(Runner):
def __init__(self) -> None:
"""Constructor"""

@check_override(self.__class__, Runner)
def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
return self.run(runner_inputs)

Expand Down
4 changes: 3 additions & 1 deletion python/tvm/meta_schedule/schedule_rule/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tvm.runtime import Object
from tvm.tir.schedule import Schedule, BlockRV

from ..utils import _get_hex_address
from ..utils import _get_hex_address, check_override
from .. import _ffi_api

if TYPE_CHECKING:
Expand Down Expand Up @@ -72,9 +72,11 @@ class PyScheduleRule(ScheduleRule):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, ScheduleRule)
def f_initialize_with_tune_context(tune_context: "TuneContext") -> None:
self.initialize_with_tune_context(tune_context)

@check_override(self.__class__, ScheduleRule)
def f_apply(sch: Schedule, block: BlockRV) -> List[Schedule]:
return self.apply(sch, block)

Expand Down
Loading

0 comments on commit 8f58137

Please sign in to comment.