Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[MetaSchedule][UX] Convenient Object Creation (apache#12643)
Browse files Browse the repository at this point in the history
This PR introduces a set of `.create` methods making it easier to create
MetaSchedule objects.

For example:

```python
ms.database.JSONDatabase(...)
ms.database.create("json")

ms.runner.RPCRunner(...)
ms.runner.create("rpc")
```

Besides, this PR allows `JSONDatabase` to be created via `work_dir`:

```python
db = ms.database.create("json", work_dir="/path/to/db/")
db = ms.database.create(work_dir="/path/to/db/")  # or even simpler
```
  • Loading branch information
junrushao authored and xinetzone committed Nov 25, 2022
1 parent da01998 commit 557a5a6
Show file tree
Hide file tree
Showing 14 changed files with 198 additions and 18 deletions.
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/builder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
Meta Schedule builders that translate IRModule to runtime.Module,
and then export
"""
from .builder import Builder, BuilderInput, BuilderResult, PyBuilder
from .builder import Builder, BuilderInput, BuilderResult, PyBuilder, create
from .local_builder import LocalBuilder
17 changes: 17 additions & 0 deletions python/tvm/meta_schedule/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
"""Meta Schedule builders that translate IRModule to runtime.Module, and then export"""
from typing import Callable, Dict, List, Optional

# isort: off
from typing_extensions import Literal

# isort: on
from tvm._ffi import register_object
from tvm.ir import IRModule
from tvm.runtime import NDArray, Object
Expand Down Expand Up @@ -164,3 +168,16 @@ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
The results of building the given inputs.
"""
raise NotImplementedError


def create( # pylint: disable=keyword-arg-before-vararg
kind: Literal["local"] = "local",
*args,
**kwargs,
) -> Builder:
"""Create a Builder."""
from . import LocalBuilder # pylint: disable=import-outside-toplevel

if kind == "local":
return LocalBuilder(*args, **kwargs) # type: ignore
raise ValueError(f"Unknown Builder: {kind}")
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
The tvm.meta_schedule.database package.
The database that stores serialized tuning records and workloads
"""
from .database import Database, PyDatabase, TuningRecord, Workload
from .database import Database, PyDatabase, TuningRecord, Workload, create
from .json_database import JSONDatabase
from .memory_database import MemoryDatabase
from .ordered_union_database import OrderedUnionDatabase
Expand Down
41 changes: 40 additions & 1 deletion python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
"""TuningRecord database"""
from typing import Any, Callable, List, Optional, Union

# isort: off
from typing_extensions import Literal

# isort: on

from tvm._ffi import register_object
from tvm.ir.module import IRModule
from tvm.runtime import Object
from tvm.target import Target
from tvm.tir.schedule import Schedule, Trace
from typing_extensions import Literal # pylint: disable=wrong-import-order

from .. import _ffi_api
from ..arg_info import ArgInfo
Expand Down Expand Up @@ -483,3 +487,38 @@ def __len__(self) -> int:
The number of records in the database
"""
raise NotImplementedError


def create( # pylint: disable=keyword-arg-before-vararg
kind: Union[
Literal[
"json",
"memory",
"union",
"ordered_union",
],
Callable[[Schedule], bool],
] = "json",
*args,
**kwargs,
) -> Database:
"""Create a Database."""
from . import ( # pylint: disable=import-outside-toplevel
JSONDatabase,
MemoryDatabase,
OrderedUnionDatabase,
ScheduleFnDatabase,
UnionDatabase,
)

if callable(kind):
return ScheduleFnDatabase(kind, *args, **kwargs) # type: ignore
if kind == "json":
return JSONDatabase(*args, **kwargs)
if kind == "memory":
return MemoryDatabase(*args, **kwargs) # type: ignore
if kind == "union":
return UnionDatabase(*args, **kwargs) # type: ignore
if kind == "ordered_union":
return OrderedUnionDatabase(*args, **kwargs) # type: ignore
raise ValueError(f"Unknown Database: {kind}")
31 changes: 25 additions & 6 deletions python/tvm/meta_schedule/database/json_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""The default database that uses a JSON File to store tuning records"""
import os.path as osp
from typing import Optional

from tvm._ffi import register_object

from .. import _ffi_api
Expand All @@ -38,21 +41,37 @@ class JSONDatabase(Database):

def __init__(
self,
path_workload: str,
path_tuning_record: str,
path_workload: Optional[str] = None,
path_tuning_record: Optional[str] = None,
*,
work_dir: Optional[str] = None,
allow_missing: bool = True,
) -> None:
"""Constructor.
Parameters
----------
path_workload : str
The path to the workload table.
path_tuning_record : str
The path to the tuning record table.
path_workload : Optional[str] = None
The path to the workload table. If not specified,
will be generated from `work_dir` as `$work_dir/database_workload.json`.
path_tuning_record : Optional[str] = None
The path to the tuning record table. If not specified,
will be generated from `work_dir` as `$work_dir/database_tuning_record.json`.
work_dir : Optional[str] = None
The work directory, if specified, will be used to generate `path_tuning_record`
and `path_workload`.
allow_missing : bool
Whether to create new file when the given path is not found.
"""
if work_dir is not None:
if path_workload is None:
path_workload = osp.join(work_dir, "database_workload.json")
if path_tuning_record is None:
path_tuning_record = osp.join(work_dir, "database_tuning_record.json")
if path_workload is None:
raise ValueError("`path_workload` is not specified.")
if path_tuning_record is None:
raise ValueError("`path_tuning_record` is not specified.")
self.__init_handle_by_constructor__(
_ffi_api.DatabaseJSONDatabase, # type: ignore # pylint: disable=no-member
path_workload,
Expand Down
12 changes: 10 additions & 2 deletions python/tvm/meta_schedule/runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
Meta Schedule runners that runs an artifact either locally or through the RPC interface
"""
from .config import EvaluatorConfig, RPCConfig
from .rpc_runner import RPCRunner
from .local_runner import LocalRunner, LocalRunnerFuture
from .runner import PyRunner, Runner, RunnerFuture, RunnerInput, RunnerResult, PyRunnerFuture
from .rpc_runner import RPCRunner
from .runner import (
PyRunner,
PyRunnerFuture,
Runner,
RunnerFuture,
RunnerInput,
RunnerResult,
create,
)
22 changes: 21 additions & 1 deletion python/tvm/meta_schedule/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""Runners"""
from typing import Callable, Optional, List
from typing import Callable, List, Optional

# isort: off
from typing_extensions import Literal

# isort: on

from tvm._ffi import register_object
from tvm.runtime import Object
Expand Down Expand Up @@ -223,3 +228,18 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
The runner futures.
"""
raise NotImplementedError


def create( # pylint: disable=keyword-arg-before-vararg
kind: Literal["local", "rpc"] = "local",
*args,
**kwargs,
) -> Runner:
"""Create a Runner."""
from . import LocalRunner, RPCRunner # pylint: disable=import-outside-toplevel

if kind == "local":
return LocalRunner(*args, **kwargs) # type: ignore
elif kind == "rpc":
return RPCRunner(*args, **kwargs) # type: ignore
raise ValueError(f"Unknown Runner: {kind}")
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/search_strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@
from .evolutionary_search import EvolutionarySearch
from .replay_func import ReplayFunc
from .replay_trace import ReplayTrace
from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy
from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy, create
29 changes: 29 additions & 0 deletions python/tvm/meta_schedule/search_strategy/search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
"""
from typing import TYPE_CHECKING, Callable, List, Optional

# isort: off
from typing_extensions import Literal

# isort: on
from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.tir.schedule import Schedule
Expand Down Expand Up @@ -245,3 +249,28 @@ def notify_runner_results(
The profiling results from the runner.
"""
raise NotImplementedError


def create( # pylint: disable=keyword-arg-before-vararg
kind: Literal[
"evolutionary",
"replay_trace",
"replay_func",
] = "evolutionary",
*args,
**kwargs,
) -> SearchStrategy:
"""Create a search strategy."""
from . import ( # pylint: disable=import-outside-toplevel
EvolutionarySearch,
ReplayFunc,
ReplayTrace,
)

if kind == "evolutionary":
return EvolutionarySearch(*args, **kwargs)
if kind == "replay_trace":
return ReplayTrace(*args, **kwargs)
if kind == "replay_func":
return ReplayFunc(*args, **kwargs)
raise ValueError(f"Unknown SearchStrategy: {kind}")
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/space_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@
"""
from .post_order_apply import PostOrderApply
from .schedule_fn import ScheduleFn
from .space_generator import PySpaceGenerator, ScheduleFnType, SpaceGenerator
from .space_generator import PySpaceGenerator, ScheduleFnType, SpaceGenerator, create
from .space_generator_union import SpaceGeneratorUnion
28 changes: 28 additions & 0 deletions python/tvm/meta_schedule/space_generator/space_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
"""
from typing import TYPE_CHECKING, Callable, List, Optional, Union

# isort: off
from typing_extensions import Literal

# isort: on
from tvm._ffi import register_object
from tvm.ir import IRModule
from tvm.runtime import Object
Expand Down Expand Up @@ -132,3 +136,27 @@ def generate_design_space(self, mod: IRModule) -> List[Schedule]:
The generated design spaces, i.e., schedules.
"""
raise NotImplementedError


def create( # pylint: disable=keyword-arg-before-vararg
kind: Union[
Literal["post_order_apply", "union"],
ScheduleFnType,
] = "post_order_apply",
*args,
**kwargs,
) -> SpaceGenerator:
"""Create a design space generator."""
from . import ( # pylint: disable=import-outside-toplevel
PostOrderApply,
ScheduleFn,
SpaceGeneratorUnion,
)

if callable(kind):
return ScheduleFn(kind, *args, **kwargs) # type: ignore
if kind == "post_order_apply":
return PostOrderApply(*args, **kwargs)
if kind == "union":
return SpaceGeneratorUnion(*args, **kwargs)
raise ValueError(f"Unknown SpaceGenerator: {kind}")
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/task_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@
for measure candidates generation and measurement, then save
records to the database.
"""
from .task_scheduler import TaskScheduler, PyTaskScheduler
from .round_robin import RoundRobin
from .gradient_based import GradientBased
from .round_robin import RoundRobin
from .task_scheduler import PyTaskScheduler, TaskScheduler, create
20 changes: 20 additions & 0 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
import logging
from typing import Callable, List, Optional

# isort: off
from typing_extensions import Literal

# isort: on

from tvm._ffi import register_object
from tvm.runtime import Object

Expand Down Expand Up @@ -255,3 +260,18 @@ def touch_task(self, task_id: int) -> None:
"""
# Using self._outer to replace the self pointer
_ffi_api.TaskSchedulerTouchTask(self._outer(), task_id) # type: ignore # pylint: disable=no-member


def create( # pylint: disable=keyword-arg-before-vararg
kind: Literal["round-robin", "gradient"] = "gradient",
*args,
**kwargs,
) -> "TaskScheduler":
"""Create a task scheduler."""
from . import GradientBased, RoundRobin # pylint: disable=import-outside-toplevel

if kind == "round-robin":
return RoundRobin(*args, **kwargs)
if kind == "gradient":
return GradientBased(*args, **kwargs)
raise ValueError(f"Unknown TaskScheduler name: {kind}")
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/testing/relay_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _get_network(
"float32": torch.float32, # pylint: disable=no-member
}[dtype]
)
scripted_model = torch.jit.trace(model, input_data).eval()
scripted_model = torch.jit.trace(model, input_data).eval() # type: ignore
input_name = "input0"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
Expand Down Expand Up @@ -149,7 +149,7 @@ def _get_network(
input_dtype = "int64"
a = torch.randint(10000, input_shape) # pylint: disable=no-member
model.eval()
scripted_model = torch.jit.trace(model, [a], strict=False)
scripted_model = torch.jit.trace(model, [a], strict=False) # type: ignore
input_name = "input_ids"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
Expand Down

0 comments on commit 557a5a6

Please sign in to comment.