From 2789e669b59abf9f440504f66fe264203978bb34 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 29 Mar 2022 00:20:58 -0700 Subject: [PATCH] [MetaSchedule] Support grouping in the cost model --- .../tvm/meta_schedule/cost_model/xgb_model.py | 243 +++++++++++------- .../testing/tune_relay_auto_scheduler.py | 30 ++- .../testing/tune_relay_meta_schedule.py | 87 ++++++- python/tvm/meta_schedule/tune.py | 13 +- python/tvm/meta_schedule/utils.py | 11 +- python/tvm/relay/build_module.py | 2 +- python/tvm/rpc/client.py | 10 +- src/meta_schedule/tune_context.cc | 3 + src/meta_schedule/utils.h | 15 ++ .../unittest/test_meta_schedule_cost_model.py | 50 ++-- 10 files changed, 325 insertions(+), 139 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index 9a290516230d..9d95623c2bd6 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -17,26 +17,29 @@ """ XGBoost-based cost model """ -from itertools import chain as itertools_chain import logging import os import tempfile -from typing import Any, Callable, Dict, List, NamedTuple, Optional, TYPE_CHECKING, Tuple +from collections import OrderedDict +from itertools import chain as itertools_chain +from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple import numpy as np # type: ignore from ...contrib.tar import tar, untar +from ...runtime import NDArray from ..cost_model import PyCostModel from ..feature_extractor import FeatureExtractor from ..runner import RunnerResult from ..search_strategy import MeasureCandidate -from ..utils import cpu_count, derived_object +from ..utils import cpu_count, derived_object, shash2hex from .metric import max_curve if TYPE_CHECKING: - from ..tune_context import TuneContext import xgboost as xgb # type: ignore + from ..tune_context import TuneContext + logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -75,8 +78,8 @@ class PackSum: def __init__( self, - xs: List[np.ndarray], - ys: Optional[np.ndarray], + xs: List[np.ndarray], # pylint: disable=invalid-name + ys: Optional[np.ndarray], # pylint: disable=invalid-name ): """Create PackSum format given a batch of samples @@ -217,8 +220,15 @@ class XGBConfig(NamedTuple): Default is None, which means to use physical number of cores. 
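Taken together with the constructor further down, XGBConfig is a thin wrapper over the xgboost parameter dictionary. A minimal sketch of what to_dict() yields with the defaults above (import path as introduced by this patch; nthread is left as None here and is replaced with cpu_count(logical=True) in XGBModel.__init__ before the dictionary reaches xgboost):

    from tvm.meta_schedule.cost_model.xgb_model import XGBConfig

    config = XGBConfig()
    print(config.to_dict())
    # {'max_depth': 10, 'gamma': 0.001, 'min_child_weight': 0,
    #  'eta': 0.2, 'seed': 43, 'nthread': None}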
""" + max_depth: int = 10 + gamma: float = 0.001 + min_child_weight: float = 0 + eta: float = 0.2 + seed: int = 43 + nthread: Optional[int] = None + def to_dict(self): - xgb_params = { + return { "max_depth": self.max_depth, "gamma": self.gamma, "min_child_weight": self.min_child_weight, @@ -226,14 +236,47 @@ def to_dict(self): "seed": self.seed, "nthread": self.nthread, } - return xgb_params - max_depth: int = 10 - gamma: float = 0.001 - min_child_weight: float = 0 - eta: float = 0.2 - seed: int = 43 - nthread: Optional[int] = None + +class FeatureGroup: + """Feature group + + Parameters + ---------- + group_hash : str + The hash of the group + features : List[np.ndarray] + The features + costs : List[float] + The costs + min_cost : float + The minimum cost + """ + + group_hash: str + features: List[np.ndarray] + costs: np.ndarray + min_cost: float + + def __init__( + self, + group_hash: str, + features: List[np.ndarray], + costs: np.ndarray, + ) -> None: + self.group_hash = group_hash + self.features = features + self.costs = costs + self.min_cost = np.min(costs) + + def append( + self, + features: List[np.ndarray], + costs: np.ndarray, + ) -> None: + self.features.extend(features) + self.costs = np.append(self.costs, costs) + self.min_cost = np.min(self.costs) @derived_object @@ -268,9 +311,8 @@ class XGBModel(PyCostModel): verbose_eval: int average_peak_n: int # states - cached_features: List[np.ndarray] - cached_mean_costs: np.ndarray - cached_normalizer: Optional[float] + data: Dict[str, FeatureGroup] + data_size: int booster: Optional["xgb.Booster"] def __init__( @@ -293,7 +335,7 @@ def __init__( # model-related if config.nthread is None: # use physical core number - config = config._replace(nthread=cpu_count(logical=False)) + config = config._replace(nthread=cpu_count(logical=True)) self.config = config # behavior of randomness self.num_warmup_samples = num_warmup_samples @@ -302,9 +344,8 @@ def __init__( self.verbose_eval = verbose_eval self.average_peak_n = average_peak_n # states - self.cached_features = [] - self.cached_mean_costs = np.empty((0,), dtype="float64") - self.cached_normalizer = None + self.data = OrderedDict() + self.data_size = 0 self.booster = None def load(self, path: str) -> None: @@ -324,16 +365,29 @@ def load(self, path: str) -> None: import xgboost as xgb # pylint: disable=import-outside-toplevel with tempfile.TemporaryDirectory() as tmp_dir: + model_path = os.path.join(tmp_dir, "model.bin") + data_path = os.path.join(tmp_dir, "data.npy") + # Step 1. Untar untar(path, tmp_dir) - self.booster = xgb.Booster() - self.booster.load_model(os.path.join(tmp_dir, "model.bin")) - self.cached_features = list( - np.load(os.path.join(tmp_dir, "cached_features.npy"), allow_pickle=True) - ) - self.cached_mean_costs = np.load( - os.path.join(tmp_dir, "cached_mean_costs.npy"), allow_pickle=True - ) - self._set_cached_normalizer() + # Step 2. Load data + data = OrderedDict() + data_size = 0 + for group_hash, features, costs in np.load(data_path, allow_pickle=True): + data[group_hash] = FeatureGroup( + group_hash=group_hash, + features=list(features), + costs=costs, + ) + data_size += len(costs) + # Step 3. Load the model + if os.path.exists(model_path): + booster = xgb.Booster() + booster.load_model(model_path) + else: + self.booster = None + self.data = data + self.data_size = data_size + self.booster = booster def save(self, path: str) -> None: """Save the cost model to given file location. 
@@ -349,26 +403,30 @@ def save(self, path: str) -> None: previously cached feature vectors and results, so that the subsequent training process could use all the existing data being stored on disk. """ - import xgboost as xgb # pylint: disable=import-outside-toplevel - - if self.booster is None: - # save all the parameters - self.booster = xgb.Booster(self.config.to_dict()) with tempfile.TemporaryDirectory() as tmp_dir: - self.booster.save_model(os.path.join(tmp_dir, "model.bin")) + model_path = os.path.join(tmp_dir, "model.bin") + data_path = os.path.join(tmp_dir, "data.npy") + # Step 1. Save the model + booster = self.booster + if booster is not None: + booster.save_model(model_path) + else: + model_path = None + # Step 2. Save data + data = [ + ( + g.group_hash, + g.features, + g.costs, + ) + for g in self.data.values() + ] np.save( - os.path.join(tmp_dir, "cached_features.npy"), - np.array(self.cached_features, dtype=object), - ) - np.save(os.path.join(tmp_dir, "cached_mean_costs.npy"), self.cached_mean_costs) - tar( - path, - [ - os.path.join(tmp_dir, "model.bin"), - os.path.join(tmp_dir, "cached_features.npy"), - os.path.join(tmp_dir, "cached_mean_costs.npy"), - ], + file=data_path, + arr=np.array(data, dtype=object), ) + # Step 3. Tar it + tar(path, [x for x in [model_path, data_path] if x is not None]) logger.info("Saved XGBModel to %s", path) def update( @@ -391,39 +449,55 @@ def update( assert len(candidates) == len(results) if len(candidates) == 0: return - # extract feature and do validation + + # Step 1. Get the feature group + new_group_hash = shash2hex(context.mod) + group = self.data.get(new_group_hash, None) + + # Step 2. Extract features + def _feature(x: NDArray) -> np.ndarray: + return x.numpy().astype("float32") def _mean_cost(x: RunnerResult) -> float: if not x.run_secs: return 1e10 return float(np.median([float(s) for s in x.run_secs])) - new_features = [ - x.numpy().astype("float32") for x in self.extractor.extract_from(context, candidates) - ] - new_mean_costs = np.asarray( - [_mean_cost(x) for x in results], - dtype="float32", - ) - if self.booster is not None and self.cached_normalizer is not None: + new_features = [_feature(x) for x in self.extractor.extract_from(context, candidates)] + new_mean_costs = np.array([_mean_cost(x) for x in results]).astype("float32") + + # Steps 3. Run validation + if group is not None and self.booster is not None: logger.debug( "XGB validation: %s", "\t".join( f"{key}: {score:.6f}" for key, score in self._validate( xs=new_features, - ys=new_mean_costs, + ys=group.min_cost / new_mean_costs, ) ), ) - # use together with previous features - self.cached_features.extend(new_features) - self.cached_mean_costs = np.append(self.cached_mean_costs, new_mean_costs) - self._set_cached_normalizer() - # train xgb model + + # Step 4. Add the features into the data points + if group is None: + group = FeatureGroup( + group_hash=new_group_hash, + features=new_features, + costs=new_mean_costs, + ) + else: + group.append(new_features, new_mean_costs) + self.data[new_group_hash] = group + self.data_size += len(new_features) + + # Step 5. Re-train the model self._train( - xs=self.cached_features, - ys=self.cached_mean_costs, + xs=list(itertools_chain.from_iterable([g.features for g in self.data.values()])), + ys=np.concatenate( + [g.min_cost / g.costs for g in self.data.values()], + axis=0, + ), ) def predict( @@ -445,10 +519,16 @@ def predict( result : np.ndarray The predicted normalized score. 
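The effect of grouping shows up in the labels: both the validation call above (ys=group.min_cost / new_mean_costs) and re-training (g.min_cost / g.costs) normalize each measured cost by the best cost of its own workload group, so groups with very different absolute run times still produce targets on the same (0, 1] scale. A small numeric sketch with made-up run times:

    import numpy as np

    # Two workload groups with very different absolute run times (seconds).
    conv_costs = np.array([2.0e-3, 1.0e-3, 4.0e-3], dtype="float32")   # e.g. a large conv2d task
    dense_costs = np.array([8.0e-5, 2.0e-4], dtype="float32")          # e.g. a small dense task

    # Per-group normalization used by update()/_train(): min_cost / cost.
    # The best candidate of each group maps to 1.0, slower ones to smaller
    # positive scores, so the two groups contribute comparable targets.
    print(conv_costs.min() / conv_costs)    # -> [0.5  1.   0.25]
    print(dense_costs.min() / dense_costs)  # -> [1.   0.4]

    # With a single global normalizer (the old cached_normalizer), the dense
    # group's 8e-5 would dominate and the conv group's labels would all shrink
    # toward zero: roughly [0.04, 0.08, 0.02].
    global_min = min(conv_costs.min(), dense_costs.min())
    print(global_min / conv_costs)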
""" - n_measured = len(self.cached_features) - if self.booster is not None and n_measured >= self.num_warmup_samples: - features = self.extractor.extract_from(context, candidates) - ret = self._predict(xs=[x.numpy().astype("float32") for x in features]) + if self.data_size >= self.num_warmup_samples and self.booster is not None: + ret = self._predict( + xs=[ + x.numpy().astype("float32") + for x in self.extractor.extract_from( + context, + candidates, + ) + ] + ) else: ret = np.random.uniform( low=0, @@ -464,10 +544,7 @@ def _train( # type: ignore # pylint: disable=invalid-name ) -> None: import xgboost as xgb # type: ignore # pylint: disable=import-outside-toplevel - self.d_train = PackSum( - xs=xs, - ys=self.cached_normalizer / ys, - ) + self.d_train = PackSum(xs=xs, ys=ys) def obj(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument return self.d_train.obj_square_error(ys_pred) @@ -475,9 +552,7 @@ def obj(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: def rmse(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument return self.d_train.rmse(ys_pred) - def average_peak_score( - ys_pred: np.ndarray, d_train: "xgb.DMatrix" # type: ignore # pylint: disable = unused-argument - ): + def avg_peak_score(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument return self.d_train.average_peak_score(ys_pred, self.average_peak_n) self.booster = xgb.train( @@ -491,7 +566,7 @@ def average_peak_score( verbose_eval=self.verbose_eval, fevals=[ rmse, - average_peak_score, + avg_peak_score, ], evals=[(self.d_train.dmatrix, "tr")], ) @@ -528,13 +603,9 @@ def _validate( # type: ignore # pylint: disable=invalid-name scores: np.ndarray The predicted result for all inputs. 
""" - if self.booster is None or self.cached_normalizer is None: - return [] + assert self.booster is not None - d_valid = PackSum( - xs=xs, - ys=self.cached_normalizer / ys, - ) + d_valid = PackSum(xs=xs, ys=ys) def average_peak_score(ys_pred: np.ndarray): return d_valid.average_peak_score(ys_pred, n=self.average_peak_n) @@ -550,14 +621,6 @@ def average_peak_score(ys_pred: np.ndarray): eval_result.sort(key=make_metric_sorter("p-rmse")) return eval_result - def _set_cached_normalizer(self) -> None: - filtered = self.cached_mean_costs[self.cached_mean_costs > 0] - if filtered.size == 0: - self.cached_normalizer = 1.0 - else: - self.cached_normalizer = np.min(filtered) - assert self.cached_normalizer > 0 - def custom_callback( early_stopping_rounds: int, diff --git a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py index 37484226e85b..2a2c20868bb7 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py +++ b/python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py @@ -169,7 +169,7 @@ def main(): target=ARGS.target, params=params, ) - + graph, rt_mod, params = lib.graph_json, lib.lib, lib.params if input_dtype.startswith("float"): input_data = np.random.uniform(size=input_shape).astype(input_dtype) else: @@ -189,9 +189,10 @@ def f_timer(rt_mod, dev, input_data): min_repeat_ms=500, repeat=3, ) - return list(np.array(ftimer().results)) + results = list(np.array(ftimer().results) * 1000.0) # type: ignore + print("Running time in time_evaluator: ", results) - results = run_module_via_rpc( + run_module_via_rpc( rpc_config=ARGS.rpc_config, lib=lib, dev_type=ARGS.target.kind.name, @@ -199,7 +200,28 @@ def f_timer(rt_mod, dev, input_data): continuation=f_timer, ) - print(results) + def f_per_layer(rt_mod, dev, input_data): + # pylint: disable=import-outside-toplevel + from tvm.contrib.debugger.debug_executor import create + + # pylint: enable=import-outside-toplevel + mod = create(graph, rt_mod, dev) + mod.set_input(input_name, input_data) + graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]] + graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000) + print("|graph_nodes| = ", len(graph_nodes)) + print("|graph_time| = ", len(graph_time)) + graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, graph_time)} + for k, v in graph_nodes_time.items(): + print(f"{k} : {v:.3f}") + + run_module_via_rpc( + rpc_config=ARGS.rpc_config, + lib=rt_mod, + dev_type=ARGS.target.kind.name, + args=[input_data], + continuation=f_per_layer, + ) if __name__ == "__main__": diff --git a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py index c353684de52c..dde1b1f0489c 100644 --- a/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py +++ b/python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py @@ -18,12 +18,16 @@ import argparse import json import logging +import os import numpy as np # type: ignore import tvm from tvm import meta_schedule as ms +from tvm.ir.transform import PassContext +from tvm.meta_schedule.integration import extract_task_from_relay from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.meta_schedule.testing.relay_workload import get_network +from tvm.relay import build as relay_build def _parse_args(): @@ -85,7 +89,7 @@ def _parse_args(): tracker_host=parsed.rpc_host, tracker_port=parsed.rpc_port, tracker_key=parsed.rpc_key, - session_timeout_sec=60, + 
session_timeout_sec=3600, ) return parsed @@ -95,12 +99,62 @@ def _parse_args(): ARGS = _parse_args() +def tune_each_task( + mod, + target, + config, + runner, + work_dir, + params, +): + extracted_tasks = extract_task_from_relay(mod, target, params) + database = ms.database.JSONDatabase( + path_workload=os.path.join(work_dir, "default_database_workload.json"), + path_tuning_record=os.path.join(work_dir, "default_database_tuning_record.json"), + ) + for task in extracted_tasks: + # pylint: disable=protected-access + tune_context = ms.tune.Parse._tune_context( + tune_context=None, + mod=ms.tune.Parse._mod(task.dispatched[0]), + target=target, + config=config, + task_name=task.task_name, + space_generator=None, + sch_rules=None, + postprocs=None, + mutator_probs=None, + num_threads=os.cpu_count(), + ) + task_scheduler = ms.tune.Parse._task_scheduler( + None, + [tune_context], + builder=ms.tune.Parse._builder(None), + runner=ms.tune.Parse._runner(runner), + database=database, + cost_model=ms.tune.Parse._cost_model(None), + measure_callbacks=ms.tune.Parse._callbacks(None), + ) + # pylint: enable=protected-access + task_scheduler.tune() + with target, ms.integration.ApplyHistoryBest(database): + with PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + return relay_build(mod, target=target, params=params) + + def main(): mod, params, (input_name, input_shape, input_dtype) = get_network( ARGS.workload, ARGS.input_shape, cache_dir=ARGS.cache_dir, ) + print(f"Workload: {ARGS.workload}") + print(f" input_name: {input_name}") + print(f" input_shape: {input_shape}") + print(f" input_dtype: {input_dtype}") alloc_repeat = 1 runner = ms.runner.RPCRunner( rpc_config=ARGS.rpc_config, @@ -113,7 +167,7 @@ def main(): alloc_repeat=alloc_repeat, max_workers=ARGS.rpc_workers, ) - lib = ms.tune_relay( + lib = tune_each_task( # or ms.tune_relay mod=mod, target=ARGS.target, config=ms.EvolutionarySearchConfig( @@ -125,6 +179,7 @@ def main(): work_dir=ARGS.work_dir, params=params, ) + graph, rt_mod, params = lib.graph_json, lib.lib, lib.params if input_dtype.startswith("float"): input_data = np.random.uniform(size=input_shape).astype(input_dtype) else: @@ -144,9 +199,10 @@ def f_timer(rt_mod, dev, input_data): min_repeat_ms=500, repeat=3, ) - return list(np.array(ftimer().results)) + results = list(np.array(ftimer().results) * 1000.0) # type: ignore + print("Running time in time_evaluator: ", results) - results = run_module_via_rpc( + run_module_via_rpc( rpc_config=ARGS.rpc_config, lib=lib, dev_type=ARGS.target.kind.name, @@ -154,7 +210,28 @@ def f_timer(rt_mod, dev, input_data): continuation=f_timer, ) - print(results) + def f_per_layer(rt_mod, dev, input_data): + # pylint: disable=import-outside-toplevel + from tvm.contrib.debugger.debug_executor import create + + # pylint: enable=import-outside-toplevel + mod = create(graph, rt_mod, dev) + mod.set_input(input_name, input_data) + graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]] + graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000) + print("|graph_nodes| = ", len(graph_nodes)) + print("|graph_time| = ", len(graph_time)) + graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, graph_time)} + for k, v in graph_nodes_time.items(): + print(f"{k} : {v:.3f}") + + run_module_via_rpc( + rpc_config=ARGS.rpc_config, + lib=rt_mod, + dev_type=ARGS.target.kind.name, + args=[input_data], + continuation=f_per_layer, + ) if __name__ == "__main__": diff --git a/python/tvm/meta_schedule/tune.py 
b/python/tvm/meta_schedule/tune.py index 7a90b05f9b97..ba574010152b 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -20,9 +20,9 @@ import os.path from typing import Callable, Dict, List, Optional, Tuple, Union -import tvm from tvm._ffi.registry import register_func from tvm.ir import IRModule, structural_hash +from tvm.ir.transform import PassContext from tvm.relay import Function as RelayFunc from tvm.relay import build as relay_build from tvm.runtime import Module, NDArray @@ -48,6 +48,7 @@ from .space_generator import PostOrderApply, SpaceGenerator from .task_scheduler import RoundRobin, TaskScheduler from .tune_context import TuneContext +from .utils import autotvm_silencer logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -607,7 +608,7 @@ def deduplicate_extracted_tasks( for task in extracted_tasks: assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" - mod = Parse._mod(task.dispatched[0]) + mod = Parse._mod(task.dispatched[0]) # pylint: disable=protected-access shash = structural_hash(mod) if shash in hash2idx: count[hash2idx[shash]] += 1 @@ -714,6 +715,7 @@ def tune_extracted_tasks( ) # pylint: enable=protected-access task_scheduler.tune() + task_scheduler.cost_model.save(os.path.join(work_dir, "cost_model.xgb")) return database @@ -772,6 +774,9 @@ def tune_relay( """ logger.info("Working directory: %s", work_dir) + # pylint: disable=protected-access + target = Parse._target(target) + # parse the tuning contexts extracted_tasks = extract_task_from_relay(mod, target, params) database = tune_extracted_tasks( extracted_tasks, @@ -790,8 +795,8 @@ def tune_relay( mutator_probs=mutator_probs, num_threads=num_threads, ) - with ApplyHistoryBest(database): - with tvm.transform.PassContext( + with target, autotvm_silencer(), ApplyHistoryBest(database): + with PassContext( opt_level=3, config={"relay.backend.use_meta_schedule": True}, ): diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 7d751ea12fcb..6b36ace98586 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -23,7 +23,6 @@ from typing import Any, Callable, List, Optional, Union import psutil # type: ignore -import tvm from tvm._ffi import get_global_func, register_func from tvm.error import TVMError from tvm.ir import Array, IRModule, Map @@ -321,7 +320,7 @@ def batch_json_str2obj(json_strs: List[str]) -> List[Any]: ] -def structural_hash(mod: IRModule) -> str: +def shash2hex(mod: IRModule) -> str: """Get the structural hash of a module. Parameters @@ -334,12 +333,8 @@ def structural_hash(mod: IRModule) -> str: result : str The structural hash of the module. """ - shash = tvm.ir.structural_hash(mod) - if shash < 0: - # Workaround because `structural_hash` returns a size_t, i.e., unsigned integer - # but ffi can't handle unsigned integers properly so it's parsed into a negative number - shash += 1 << 64 - return str(shash) + func = get_global_func("meta_schedule._SHash2Hex") + return str(func(mod)) def _get_default_str(obj: Any) -> str: diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 97f7adce63ed..876145c63fc0 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -195,7 +195,7 @@ def build( # Turn off AutoTVM config not found warnings if auto_scheduler is enabled. 
old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent - autotvm.GLOBAL_SCOPE.silent = use_auto_scheduler + autotvm.GLOBAL_SCOPE.silent = use_auto_scheduler or old_autotvm_silent mod_name = mangle_module_name(mod_name) diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index 5bd4490d4d49..4e6c9025383f 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -16,19 +16,17 @@ # under the License. """RPC client tools""" import os -import stat import socket +import stat import struct import time import tvm._ffi -from tvm.contrib import utils from tvm._ffi.base import TVMError +from tvm.contrib import utils from tvm.runtime import ndarray as nd -from . import base -from . import server -from . import _ffi_api +from . import _ffi_api, base, server class RPCSession(object): @@ -332,7 +330,7 @@ def text_summary(self): sorted_server = sorted(data["server_info"], key=lambda x: x["key"]) for item in sorted_server: addr = item["addr"] - res += "%21s " % ":".join(addr) + res += "%21s " % ":".join(map(str, addr)) res += item["key"] + "\n" key = item["key"].split(":")[1] # 'server:rasp3b` -> 'rasp3b' if key not in total_ct: diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 3b7fd0200e1e..31a913e80798 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -87,5 +87,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") return TuneContext(mod, target, space_generator, search_strategy, sch_rules, postprocs, mutator_probs, task_name, rand_state, num_threads); }); + +TVM_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex); + } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index bd76ca794a9a..90d1e4755cac 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -149,6 +149,21 @@ inline String JSONObj2Str(const ObjectRef& json_obj) { */ inline String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); } +/*! + * \brief Converts an TVM object to the hex string representation of its structural hash. + * \param obj The TVM object. + * \return The hex string representation of the hash code. + */ +inline String SHash2Hex(const ObjectRef& obj) { + std::ostringstream os; + size_t hash_code = 0; + if (obj.defined()) { + hash_code = StructuralHash()(obj); + } + os << "0x" << std::setw(16) << std::setfill('0') << std::hex << hash_code; + return os.str(); +} + /*! * \brief Find the entry function of the given IRModule, i.e, functions marked by * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. 
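The group key used by the cost model is the string produced by SHash2Hex above: "0x" followed by exactly sixteen zero-padded hex digits of the 64-bit structural hash. A small Python sketch of just the formatting contract (the hash value itself still comes from StructuralHash on the C++ side, exposed to Python through the meta_schedule._SHash2Hex PackedFunc used by shash2hex):

    # Mirrors `os << "0x" << std::setw(16) << std::setfill('0') << std::hex << hash_code`
    def shash2hex_format(hash_code: int) -> str:
        return "0x%016x" % (hash_code & ((1 << 64) - 1))  # mask mirrors the 64-bit size_t

    print(shash2hex_format(0))          # 0x0000000000000000
    print(shash2hex_format(2**64 - 1))  # 0xffffffffffffffff

    # The old Python-side structural_hash had to undo the FFI's signed
    # reinterpretation by adding 1 << 64; routing through the registered C++
    # function keeps the key format in one place.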
diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py index b2e23049713b..621cf5f3264b 100644 --- a/tests/python/unittest/test_meta_schedule_cost_model.py +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -24,18 +24,17 @@ import numpy as np import pytest - import tvm -from tvm.meta_schedule.cost_model import PyCostModel, RandomModel +from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor from tvm.meta_schedule.runner import RunnerResult -from tvm.meta_schedule.cost_model import XGBModel from tvm.meta_schedule.search_strategy import MeasureCandidate from tvm.meta_schedule.tune_context import TuneContext from tvm.meta_schedule.utils import derived_object from tvm.script import tir as T from tvm.tir.schedule.schedule import Schedule + # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @tvm.script.ir_module class Matmul: @@ -175,25 +174,34 @@ def test_meta_schedule_xgb_model_reload(): [_dummy_result() for i in range(update_sample_count)], ) model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) - random_state = model.extractor.random_state # save feature extractor's random state - path = os.path.join(tempfile.mkdtemp(), "test_output_meta_schedule_xgb_model.bin") - cached = (model.cached_features.copy(), model.cached_mean_costs.copy()) - model.save(path) - res1 = model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) - model.extractor.random_state = random_state # load feature extractor's random state - model.cached_features = None - model.cached_mean_costs = None - model.load(path) - new_cached = (model.cached_features.copy(), model.cached_mean_costs.copy()) - res2 = model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) - shutil.rmtree(os.path.dirname(path)) + with tempfile.NamedTemporaryFile() as path: + # Backup + random_state = model.extractor.random_state # save feature extractor's random state + old_data = model.data + old_data_size = model.data_size + model.save(path.name) + res1 = model.predict( + TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)] + ) + # Load + model.extractor.random_state = random_state # load feature extractor's random state + model.load(path.name) + new_data = model.data + new_data_size = model.data_size + res2 = model.predict( + TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)] + ) assert (res1 == res2).all() - # cached feature does not change - assert len(cached[0]) == len(new_cached[0]) - for i in range(len(cached[0])): - assert (cached[0][i] == new_cached[0][i]).all() - # cached meaen cost does not change - assert (cached[1] == new_cached[1]).all() + assert old_data_size == new_data_size + assert len(old_data) == len(new_data) + for (k1, g1), (k2, g2) in zip(old_data.items(), new_data.items()): + assert k1 == k2 + assert k1 == g1.group_hash + assert k2 == g2.group_hash + assert (g1.costs == g2.costs).all() + assert len(g1.features) == len(g2.features) + for f1, f2 in zip(g1.features, g2.features): + assert (f1 == f2).all() def test_meta_schedule_xgb_model_reupdate():
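Both tuning scripts now print the same per-layer breakdown on top of the end-to-end time_evaluator numbers. The helper below is a sketch of running that breakdown locally instead of through run_module_via_rpc; it assumes a lib with graph_json/lib attributes (as unpacked in the scripts) and reuses the run_individual arguments from the patch.

    import json

    import tvm
    from tvm.contrib.debugger.debug_executor import create

    def profile_per_layer(lib, dev, input_name, input_data):
        """Per-layer timing, mirroring f_per_layer in the tuning scripts."""
        graph, rt_mod = lib.graph_json, lib.lib
        mod = create(graph, rt_mod, dev)
        mod.set_input(input_name, input_data)
        graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
        graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
        return {k: float(v) for k, v in zip(graph_nodes, graph_time)}

    # e.g. with a locally built module:
    #     for name, sec in profile_per_layer(lib, tvm.cpu(), input_name, input_data).items():
    #         print(f"{name} : {sec:.3f}")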