diff --git a/python/tvm/meta_schedule/builder/__init__.py b/python/tvm/meta_schedule/builder/__init__.py index 859c74d75622..ac71e3a0c1fc 100644 --- a/python/tvm/meta_schedule/builder/__init__.py +++ b/python/tvm/meta_schedule/builder/__init__.py @@ -19,5 +19,5 @@ Meta Schedule builders that translate IRModule to runtime.Module, and then export """ -from .builder import Builder, BuilderInput, BuilderResult, PyBuilder +from .builder import Builder, BuilderInput, BuilderResult, PyBuilder, create from .local_builder import LocalBuilder diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index daa9f7be4214..a2254f243380 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -17,6 +17,10 @@ """Meta Schedule builders that translate IRModule to runtime.Module, and then export""" from typing import Callable, Dict, List, Optional +# isort: off +from typing_extensions import Literal + +# isort: on from tvm._ffi import register_object from tvm.ir import IRModule from tvm.runtime import NDArray, Object @@ -164,3 +168,16 @@ def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: The results of building the given inputs. """ raise NotImplementedError + + +def create( # pylint: disable=keyword-arg-before-vararg + kind: Literal["local"] = "local", + *args, + **kwargs, +) -> Builder: + """Create a Builder.""" + from . import LocalBuilder # pylint: disable=import-outside-toplevel + + if kind == "local": + return LocalBuilder(*args, **kwargs) # type: ignore + raise ValueError(f"Unknown Builder: {kind}") diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/database/__init__.py index 679923e47936..66d011ed5246 100644 --- a/python/tvm/meta_schedule/database/__init__.py +++ b/python/tvm/meta_schedule/database/__init__.py @@ -18,7 +18,7 @@ The tvm.meta_schedule.database package. The database that stores serialized tuning records and workloads """ -from .database import Database, PyDatabase, TuningRecord, Workload +from .database import Database, PyDatabase, TuningRecord, Workload, create from .json_database import JSONDatabase from .memory_database import MemoryDatabase from .ordered_union_database import OrderedUnionDatabase diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index aa509b715132..7a1338f46b20 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -17,12 +17,16 @@ """TuningRecord database""" from typing import Any, Callable, List, Optional, Union +# isort: off +from typing_extensions import Literal + +# isort: on + from tvm._ffi import register_object from tvm.ir.module import IRModule from tvm.runtime import Object from tvm.target import Target from tvm.tir.schedule import Schedule, Trace -from typing_extensions import Literal # pylint: disable=wrong-import-order from .. import _ffi_api from ..arg_info import ArgInfo @@ -483,3 +487,38 @@ def __len__(self) -> int: The number of records in the database """ raise NotImplementedError + + +def create( # pylint: disable=keyword-arg-before-vararg + kind: Union[ + Literal[ + "json", + "memory", + "union", + "ordered_union", + ], + Callable[[Schedule], bool], + ] = "json", + *args, + **kwargs, +) -> Database: + """Create a Database.""" + from . import ( # pylint: disable=import-outside-toplevel + JSONDatabase, + MemoryDatabase, + OrderedUnionDatabase, + ScheduleFnDatabase, + UnionDatabase, + ) + + if callable(kind): + return ScheduleFnDatabase(kind, *args, **kwargs) # type: ignore + if kind == "json": + return JSONDatabase(*args, **kwargs) + if kind == "memory": + return MemoryDatabase(*args, **kwargs) # type: ignore + if kind == "union": + return UnionDatabase(*args, **kwargs) # type: ignore + if kind == "ordered_union": + return OrderedUnionDatabase(*args, **kwargs) # type: ignore + raise ValueError(f"Unknown Database: {kind}") diff --git a/python/tvm/meta_schedule/database/json_database.py b/python/tvm/meta_schedule/database/json_database.py index 6897b82d9888..b36ac61ef2fb 100644 --- a/python/tvm/meta_schedule/database/json_database.py +++ b/python/tvm/meta_schedule/database/json_database.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. """The default database that uses a JSON File to store tuning records""" +import os.path as osp +from typing import Optional + from tvm._ffi import register_object from .. import _ffi_api @@ -38,21 +41,37 @@ class JSONDatabase(Database): def __init__( self, - path_workload: str, - path_tuning_record: str, + path_workload: Optional[str] = None, + path_tuning_record: Optional[str] = None, + *, + work_dir: Optional[str] = None, allow_missing: bool = True, ) -> None: """Constructor. Parameters ---------- - path_workload : str - The path to the workload table. - path_tuning_record : str - The path to the tuning record table. + path_workload : Optional[str] = None + The path to the workload table. If not specified, + will be generated from `work_dir` as `$work_dir/database_workload.json`. + path_tuning_record : Optional[str] = None + The path to the tuning record table. If not specified, + will be generated from `work_dir` as `$work_dir/database_tuning_record.json`. + work_dir : Optional[str] = None + The work directory, if specified, will be used to generate `path_tuning_record` + and `path_workload`. allow_missing : bool Whether to create new file when the given path is not found. """ + if work_dir is not None: + if path_workload is None: + path_workload = osp.join(work_dir, "database_workload.json") + if path_tuning_record is None: + path_tuning_record = osp.join(work_dir, "database_tuning_record.json") + if path_workload is None: + raise ValueError("`path_workload` is not specified.") + if path_tuning_record is None: + raise ValueError("`path_tuning_record` is not specified.") self.__init_handle_by_constructor__( _ffi_api.DatabaseJSONDatabase, # type: ignore # pylint: disable=no-member path_workload, diff --git a/python/tvm/meta_schedule/runner/__init__.py b/python/tvm/meta_schedule/runner/__init__.py index 413bea6d2fab..f0e1028bbf28 100644 --- a/python/tvm/meta_schedule/runner/__init__.py +++ b/python/tvm/meta_schedule/runner/__init__.py @@ -19,6 +19,14 @@ Meta Schedule runners that runs an artifact either locally or through the RPC interface """ from .config import EvaluatorConfig, RPCConfig -from .rpc_runner import RPCRunner from .local_runner import LocalRunner, LocalRunnerFuture -from .runner import PyRunner, Runner, RunnerFuture, RunnerInput, RunnerResult, PyRunnerFuture +from .rpc_runner import RPCRunner +from .runner import ( + PyRunner, + PyRunnerFuture, + Runner, + RunnerFuture, + RunnerInput, + RunnerResult, + create, +) diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index 90b53fde8c29..539e47f15c41 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -15,7 +15,12 @@ # specific language governing permissions and limitations # under the License. """Runners""" -from typing import Callable, Optional, List +from typing import Callable, List, Optional + +# isort: off +from typing_extensions import Literal + +# isort: on from tvm._ffi import register_object from tvm.runtime import Object @@ -223,3 +228,18 @@ def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: The runner futures. """ raise NotImplementedError + + +def create( # pylint: disable=keyword-arg-before-vararg + kind: Literal["local", "rpc"] = "local", + *args, + **kwargs, +) -> Runner: + """Create a Runner.""" + from . import LocalRunner, RPCRunner # pylint: disable=import-outside-toplevel + + if kind == "local": + return LocalRunner(*args, **kwargs) # type: ignore + elif kind == "rpc": + return RPCRunner(*args, **kwargs) # type: ignore + raise ValueError(f"Unknown Runner: {kind}") diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py index 2046067d6c00..ffe7e1473954 100644 --- a/python/tvm/meta_schedule/search_strategy/__init__.py +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -23,4 +23,4 @@ from .evolutionary_search import EvolutionarySearch from .replay_func import ReplayFunc from .replay_trace import ReplayTrace -from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy +from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy, create diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 1cd8a448fe8e..e88cdf825a79 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -20,6 +20,10 @@ """ from typing import TYPE_CHECKING, Callable, List, Optional +# isort: off +from typing_extensions import Literal + +# isort: on from tvm._ffi import register_object from tvm.runtime import Object from tvm.tir.schedule import Schedule @@ -245,3 +249,28 @@ def notify_runner_results( The profiling results from the runner. """ raise NotImplementedError + + +def create( # pylint: disable=keyword-arg-before-vararg + kind: Literal[ + "evolutionary", + "replay_trace", + "replay_func", + ] = "evolutionary", + *args, + **kwargs, +) -> SearchStrategy: + """Create a search strategy.""" + from . import ( # pylint: disable=import-outside-toplevel + EvolutionarySearch, + ReplayFunc, + ReplayTrace, + ) + + if kind == "evolutionary": + return EvolutionarySearch(*args, **kwargs) + if kind == "replay_trace": + return ReplayTrace(*args, **kwargs) + if kind == "replay_func": + return ReplayFunc(*args, **kwargs) + raise ValueError(f"Unknown SearchStrategy: {kind}") diff --git a/python/tvm/meta_schedule/space_generator/__init__.py b/python/tvm/meta_schedule/space_generator/__init__.py index d2039c4511c9..c417ec2d7d4a 100644 --- a/python/tvm/meta_schedule/space_generator/__init__.py +++ b/python/tvm/meta_schedule/space_generator/__init__.py @@ -21,5 +21,5 @@ """ from .post_order_apply import PostOrderApply from .schedule_fn import ScheduleFn -from .space_generator import PySpaceGenerator, ScheduleFnType, SpaceGenerator +from .space_generator import PySpaceGenerator, ScheduleFnType, SpaceGenerator, create from .space_generator_union import SpaceGeneratorUnion diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index 74c29b4de0dd..9d7ebf3bae26 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -20,6 +20,10 @@ """ from typing import TYPE_CHECKING, Callable, List, Optional, Union +# isort: off +from typing_extensions import Literal + +# isort: on from tvm._ffi import register_object from tvm.ir import IRModule from tvm.runtime import Object @@ -132,3 +136,27 @@ def generate_design_space(self, mod: IRModule) -> List[Schedule]: The generated design spaces, i.e., schedules. """ raise NotImplementedError + + +def create( # pylint: disable=keyword-arg-before-vararg + kind: Union[ + Literal["post_order_apply", "union"], + ScheduleFnType, + ] = "post_order_apply", + *args, + **kwargs, +) -> SpaceGenerator: + """Create a design space generator.""" + from . import ( # pylint: disable=import-outside-toplevel + PostOrderApply, + ScheduleFn, + SpaceGeneratorUnion, + ) + + if callable(kind): + return ScheduleFn(kind, *args, **kwargs) # type: ignore + if kind == "post_order_apply": + return PostOrderApply(*args, **kwargs) + if kind == "union": + return SpaceGeneratorUnion(*args, **kwargs) + raise ValueError(f"Unknown SpaceGenerator: {kind}") diff --git a/python/tvm/meta_schedule/task_scheduler/__init__.py b/python/tvm/meta_schedule/task_scheduler/__init__.py index 1a67aa6f6831..51985570b06f 100644 --- a/python/tvm/meta_schedule/task_scheduler/__init__.py +++ b/python/tvm/meta_schedule/task_scheduler/__init__.py @@ -20,6 +20,6 @@ for measure candidates generation and measurement, then save records to the database. """ -from .task_scheduler import TaskScheduler, PyTaskScheduler -from .round_robin import RoundRobin from .gradient_based import GradientBased +from .round_robin import RoundRobin +from .task_scheduler import PyTaskScheduler, TaskScheduler, create diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index 3d57a6b01b9d..29a5f18dfb8a 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -19,6 +19,11 @@ import logging from typing import Callable, List, Optional +# isort: off +from typing_extensions import Literal + +# isort: on + from tvm._ffi import register_object from tvm.runtime import Object @@ -255,3 +260,18 @@ def touch_task(self, task_id: int) -> None: """ # Using self._outer to replace the self pointer _ffi_api.TaskSchedulerTouchTask(self._outer(), task_id) # type: ignore # pylint: disable=no-member + + +def create( # pylint: disable=keyword-arg-before-vararg + kind: Literal["round-robin", "gradient"] = "gradient", + *args, + **kwargs, +) -> "TaskScheduler": + """Create a task scheduler.""" + from . import GradientBased, RoundRobin # pylint: disable=import-outside-toplevel + + if kind == "round-robin": + return RoundRobin(*args, **kwargs) + if kind == "gradient": + return GradientBased(*args, **kwargs) + raise ValueError(f"Unknown TaskScheduler name: {kind}") diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py index 016263489527..f4f6336df33f 100644 --- a/python/tvm/meta_schedule/testing/relay_workload.py +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -85,7 +85,7 @@ def _get_network( "float32": torch.float32, # pylint: disable=no-member }[dtype] ) - scripted_model = torch.jit.trace(model, input_data).eval() + scripted_model = torch.jit.trace(model, input_data).eval() # type: ignore input_name = "input0" shape_list = [(input_name, input_shape)] mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) @@ -149,7 +149,7 @@ def _get_network( input_dtype = "int64" a = torch.randint(10000, input_shape) # pylint: disable=no-member model.eval() - scripted_model = torch.jit.trace(model, [a], strict=False) + scripted_model = torch.jit.trace(model, [a], strict=False) # type: ignore input_name = "input_ids" shape_list = [(input_name, input_shape)] mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)