diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py
index 885eb0d1d0f8d..17d2001e3a84f 100644
--- a/python/tvm/auto_scheduler/workload_registry.py
+++ b/python/tvm/auto_scheduler/workload_registry.py
@@ -30,13 +30,14 @@
 When we need the dag, we decode the string and call the function, which will return the dag.
 """
+import json
 import logging
 import pickle
-import json
 
 import tvm._ffi
 from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
-from .utils import serialize_args, deserialize_args, get_func_name
+
+from .utils import deserialize_args, get_func_name, serialize_args
 
 logger = logging.getLogger("auto_scheduler")
@@ -194,7 +195,10 @@ def workload_key_to_tensors(workload_key):
     assert callable(value)
 
     args = deserialize_args(workload[1:])
-    return value(*args)
+    result = value(*args)
+    if isinstance(result, tuple):
+        result = list(result)
+    return result
 
 
 def serialize_workload_registry_entry(workload_key):
diff --git a/python/tvm/meta_schedule/runner/config.py b/python/tvm/meta_schedule/runner/config.py
index 712766de99c1a..585b88ed9939c 100644
--- a/python/tvm/meta_schedule/runner/config.py
+++ b/python/tvm/meta_schedule/runner/config.py
@@ -45,7 +45,7 @@ class EvaluatorConfig(NamedTuple):
 
     number: int = 3
     repeat: int = 1
-    min_repeat_ms: int = 40
+    min_repeat_ms: int = 100
     enable_cpu_cache_flush: bool = False
 
     @staticmethod
diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py
index 66dec30a71dba..5697f85f229e5 100644
--- a/python/tvm/meta_schedule/runner/rpc_runner.py
+++ b/python/tvm/meta_schedule/runner/rpc_runner.py
@@ -16,9 +16,9 @@
 # under the License.
 """RPC Runner"""
 import concurrent.futures
-from contextlib import contextmanager
 import logging
 import os.path as osp
+from contextlib import contextmanager
 from typing import Callable, List, Optional, Union
 
 from tvm.contrib.popen_pool import PopenPoolExecutor
@@ -31,15 +31,14 @@
     get_global_func_with_default_on_worker,
 )
 from .config import EvaluatorConfig, RPCConfig
-from .runner import PyRunner, RunnerFuture, PyRunnerFuture, RunnerInput, RunnerResult
+from .runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult
 from .utils import (
-    T_ARGUMENT_LIST,
     T_ARG_INFO_JSON_OBJ_LIST,
+    T_ARGUMENT_LIST,
     alloc_argument_common,
     run_evaluator_common,
 )
-
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@@ -118,7 +117,7 @@ def done(self) -> bool:
     def result(self) -> RunnerResult:
         try:
             run_secs: List[float] = self.future.result()
-        except TimeoutError as exception:
+        except TimeoutError:
             return RunnerResult(
                 None,
                 error_msg=f"RPCRunner: Timeout, killed after {self.timeout_sec} seconds",
diff --git a/python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py b/python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py
new file mode 100644
index 0000000000000..b52f88aaa8763
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/run_subgraph_auto_scheduler.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import argparse
+import os
+
+import tvm
+from tvm import auto_scheduler
+from tvm.meta_schedule.runner import RPCConfig
+from tvm.meta_schedule.testing.te_workload import CONFIGS
+
+
+def _parse_args():
+    args = argparse.ArgumentParser()
+    args.add_argument(
+        "--workload",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--target",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--num-trials",
+        type=int,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-host",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-port",
+        type=int,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-key",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--log-dir",
+        type=str,
+        required=True,
+    )
+    parsed = args.parse_args()
+    parsed.target = tvm.target.Target(parsed.target)
+    parsed.rpc_workers = RPCConfig(
+        tracker_host=parsed.rpc_host,
+        tracker_port=parsed.rpc_port,
+        tracker_key=parsed.rpc_key,
+        session_timeout_sec=30,
+    ).count_num_servers(allow_missing=True)
+    return parsed
+
+
+ARGS = _parse_args()
+
+
+def main():
+    log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json")
+    workload_func, params = CONFIGS[ARGS.workload]
+    params = params[0]  # type: ignore
+    workload_func = auto_scheduler.register_workload(workload_func)
+
+    if ARGS.target.kind.name == "llvm":
+        hardware_params = auto_scheduler.HardwareParams(
+            num_cores=int(ARGS.target.attrs["num-cores"]),
+            target=ARGS.target,
+        )
+    elif ARGS.target.kind.name == "cuda":
+        hardware_params = auto_scheduler.HardwareParams(
+            num_cores=-1,
+            vector_unit_bytes=16,
+            cache_line_bytes=64,
+            max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]),
+            max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]),
+            max_vthread_extent=8,
+            warp_size=32,
+        )
+    else:
+        raise NotImplementedError(f"Unsupported target {ARGS.target}")
+    task = auto_scheduler.SearchTask(
+        func=workload_func,
+        args=params,
+        target=ARGS.target,
+        hardware_params=hardware_params,
+    )
+    runner = auto_scheduler.RPCRunner(
+        key=ARGS.rpc_key,
+        host=ARGS.rpc_host,
+        port=ARGS.rpc_port,
+        n_parallel=ARGS.rpc_workers,
+        number=3,
+        repeat=1,
+        min_repeat_ms=100,
+        enable_cpu_cache_flush=False,
+    )
+
+    # Inspect the computational graph
+    print("Computational DAG:")
+    print(task.compute_dag)
+    tune_option = auto_scheduler.TuningOptions(
+        num_measure_trials=ARGS.num_trials,
+        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+        verbose=2,
+        runner=runner,
+    )
+    print("Running AutoTuning:")
+    task.tune(tune_option)
+    print("History Best:")
+    print(task.print_best(log_file))
+    sch, args = task.apply_best(log_file)
+    print("Lowered TIR:")
+    print(tvm.lower(sch, args, simple_mode=True))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py b/python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py
new file mode 100644
index 0000000000000..d4166b10f502e
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/run_subgraph_meta_schedule.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import argparse
+import logging
+from os import cpu_count
+from typing import Optional
+
+import tvm
+from tvm import meta_schedule as ms
+from tvm import tir
+from tvm.meta_schedule.testing.te_workload import create_te_workload
+
+
+def _parse_args():
+    args = argparse.ArgumentParser()
+    args.add_argument(
+        "--workload",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--target",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--num-trials",
+        type=int,
+        required=True,
+    )
+    args.add_argument(
+        "--work-dir",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-host",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-port",
+        type=int,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-key",
+        type=str,
+        required=True,
+    )
+    parsed = args.parse_args()
+    parsed.target = tvm.target.Target(parsed.target)
+    if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu":
+        parsed.alloc_repeat = 3
+    else:
+        parsed.alloc_repeat = 1
+    parsed.rpc_config = ms.runner.RPCConfig(
+        tracker_host=parsed.rpc_host,
+        tracker_port=parsed.rpc_port,
+        tracker_key=parsed.rpc_key,
+        session_timeout_sec=30,
+    )
+    parsed.rpc_workers = parsed.rpc_config.count_num_servers(allow_missing=False)
+    return parsed
+
+
+logging.basicConfig()
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+ARGS = _parse_args()
+
+
+def main():
+    runner = ms.runner.RPCRunner(
+        rpc_config=ARGS.rpc_config,
+        evaluator_config=ms.runner.EvaluatorConfig(
+            number=3,
+            repeat=1,
+            min_repeat_ms=100,
+            enable_cpu_cache_flush=False,
+        ),
+        alloc_repeat=ARGS.alloc_repeat,
+        max_workers=ARGS.rpc_workers,
+    )
+    sch: Optional[tir.Schedule] = ms.tune_tir(
+        mod=create_te_workload(ARGS.workload, 0),
+        target=ARGS.target,
+        config=ms.EvolutionarySearchConfig(
+            num_trials_per_iter=64,
+            num_trials_total=ARGS.num_trials,
+            init_min_unmeasured=50,
+        ),
+        runner=runner,  # type: ignore
+        task_name=ARGS.workload,
+        work_dir=ARGS.work_dir,
+        num_threads=cpu_count(),
+    )
+    if sch is None:
+        print("No valid schedule found!")
+    else:
+        print(sch.mod.script())
+        print(sch.trace)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc
index 6b34f69bc0b10..e2c71b7ec164f 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -104,8 +104,7 @@ class VerifyGPUCodeNode : public PostprocNode {
     ICHECK(context->target.defined());
     Target target = context->target.value();
     this->target_constraints_ = Map<String, ObjectRef>{
-        {"max_shared_memory_per_block", Extract(target, "shared_memory_per_block")},
-        {"max_local_memory_per_block", Extract(target, "registers_per_block")},
+        {"max_shared_memory_per_block", Extract(target, "max_shared_memory_per_block")},
{"max_shared_memory_per_block", Extract(target, "max_shared_memory_per_block")}, {"max_threads_per_block", Extract(target, "max_threads_per_block")}, {"max_vthread", Integer(8)}, {"max_vector_bytes", Integer(16)}}; diff --git a/src/target/tag.cc b/src/target/tag.cc index a931a288924ec..07a5a5f7c8122 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -70,14 +70,38 @@ Target TargetTag::AddTag(String name, Map config, bool overri /********** Register Target tags **********/ +TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64") + .set_config({{"kind", String("llvm")}, + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("cortex-a72")}, + {"mattr", Array{"+neon"}}, + {"num-cores", Integer(4)}, + {"host", Map{{"kind", String("llvm")}, + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("cortex-a72")}, + {"mattr", Array{"+neon"}}, + {"num-cores", Integer(4)}}}}); + +TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") + .set_config({{"kind", String("cuda")}, + {"arch", String("sm_72")}, + {"max_shared_memory_per_block", Integer(49152)}, + {"max_threads_per_block", Integer(1024)}, + {"thread_warp_size", Integer(32)}, + {"registers_per_block", Integer(65536)}, + {"host", Map{{"kind", String("llvm")}, + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("carmel")}, + {"num-cores", Integer(4)}}}}); + #define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \ TVM_REGISTER_TARGET_TAG(Name).set_config({ \ {"kind", String("cuda")}, \ {"arch", String(Arch)}, \ - {"shared_memory_per_block", Integer(SharedMem)}, \ - {"registers_per_block", Integer(RegPerBlock)}, \ + {"max_shared_memory_per_block", Integer(SharedMem)}, \ {"max_threads_per_block", Integer(1024)}, \ {"thread_warp_size", Integer(32)}, \ + {"registers_per_block", Integer(RegPerBlock)}, \ }); TVM_REGISTER_CUDA_TAG("nvidia/tesla-k80", "sm_37", 49152, 65536); @@ -318,7 +342,6 @@ TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-415m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-480m", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-710m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-410m", "sm_21", 49152, 32768); -TVM_REGISTER_CUDA_TAG("nvidia/jetson-agx-xavier", "sm_72", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/jetson-nano", "sm_53", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/jetson-tx2", "sm_62", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/jetson-tx1", "sm_53", 49152, 32768); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index c562c78bd1874..1131e6e7d2a83 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -286,11 +286,11 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("arch") .add_attr_option("system-lib") - .add_attr_option("max_num_threads", Integer(1024)) + .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_threads_per_block") .add_attr_option("thread_warp_size", Integer(32)) - .add_attr_option("shared_memory_per_block") .add_attr_option("registers_per_block") - .add_attr_option("max_threads_per_block") + .add_attr_option("max_num_threads", Integer(1024)) // TODO(@zxybazh): deprecate it .set_default_keys({"cuda", "gpu"}) .set_attrs_preprocessor(UpdateCUDAAttrs); diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index 2c37731b44f01..db302f4b7e4da 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++ 
@@ -17,6 +17,7 @@
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import sys
+
 import pytest
 import tvm
 from tvm import tir
@@ -380,7 +381,7 @@ def test_postproc_verify_gpu_1():
     mod = Conv2dCuda1
     ctx = _create_context(mod, target=_target())
     sch = tir.Schedule(mod, debug_mask="all")
-    assert not ctx.postprocs[0].apply(sch)
+    assert ctx.postprocs[0].apply(sch)
 
 
 def test_postproc_verify_gpu_2():
diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py
index 33f9a969a1049..99cdb86314e76 100644
--- a/tests/python/unittest/test_target_target.py
+++ b/tests/python/unittest/test_target_target.py
@@ -216,7 +216,7 @@ def test_target_tag_0():
     tgt = tvm.target.Target("nvidia/geforce-rtx-2080-ti")
     assert tgt.kind.name == "cuda"
     assert tgt.attrs["arch"] == "sm_75"
-    assert tgt.attrs["shared_memory_per_block"] == 49152
+    assert tgt.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.attrs["max_threads_per_block"] == 1024
     assert tgt.attrs["thread_warp_size"] == 32
     assert tgt.attrs["registers_per_block"] == 65536
@@ -226,7 +226,7 @@ def test_target_tag_1():
     tgt = tvm.target.Target("nvidia/jetson-nano")
     assert tgt.kind.name == "cuda"
     assert tgt.attrs["arch"] == "sm_53"
-    assert tgt.attrs["shared_memory_per_block"] == 49152
+    assert tgt.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.attrs["max_threads_per_block"] == 1024
     assert tgt.attrs["thread_warp_size"] == 32
     assert tgt.attrs["registers_per_block"] == 32768
@@ -243,13 +243,13 @@ def test_target_host_tags():
     tgt = tvm.target.Target("nvidia/jetson-nano", "nvidia/geforce-rtx-2080-ti")
     assert tgt.kind.name == "cuda"
     assert tgt.attrs["arch"] == "sm_53"
-    assert tgt.attrs["shared_memory_per_block"] == 49152
+    assert tgt.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.attrs["max_threads_per_block"] == 1024
     assert tgt.attrs["thread_warp_size"] == 32
     assert tgt.attrs["registers_per_block"] == 32768
     assert tgt.host.kind.name == "cuda"
     assert tgt.host.attrs["arch"] == "sm_75"
-    assert tgt.host.attrs["shared_memory_per_block"] == 49152
+    assert tgt.host.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.host.attrs["max_threads_per_block"] == 1024
     assert tgt.host.attrs["thread_warp_size"] == 32
     assert tgt.host.attrs["registers_per_block"] == 65536
@@ -259,7 +259,7 @@ def test_target_host_tag_dict():
     tgt = tvm.target.Target("nvidia/jetson-nano", {"kind": "llvm"})
     assert tgt.kind.name == "cuda"
     assert tgt.attrs["arch"] == "sm_53"
-    assert tgt.attrs["shared_memory_per_block"] == 49152
+    assert tgt.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.attrs["max_threads_per_block"] == 1024
     assert tgt.attrs["thread_warp_size"] == 32
     assert tgt.attrs["registers_per_block"] == 32768
@@ -271,7 +271,7 @@ def test_target_host_single_dict():
     assert tgt.kind.name == "llvm"
     assert tgt.host.kind.name == "cuda"
     assert tgt.host.attrs["arch"] == "sm_53"
-    assert tgt.host.attrs["shared_memory_per_block"] == 49152
+    assert tgt.host.attrs["max_shared_memory_per_block"] == 49152
     assert tgt.host.attrs["max_threads_per_block"] == 1024
     assert tgt.host.attrs["thread_warp_size"] == 32
     assert tgt.host.attrs["registers_per_block"] == 32768
@@ -288,7 +288,7 @@ def test_target_host_single_string_with_tag():
     assert tgt.kind.name == "cuda"
     assert tgt.host.kind.name == "cuda"
     assert tgt.host.attrs["arch"] == "sm_53"
-    assert tgt.host.attrs["shared_memory_per_block"] == 49152
+    assert tgt.host.attrs["max_shared_memory_per_block"] == 49152
tgt.host.attrs["max_shared_memory_per_block"] == 49152 assert tgt.host.attrs["max_threads_per_block"] == 1024 assert tgt.host.attrs["thread_warp_size"] == 32 assert tgt.host.attrs["registers_per_block"] == 32768 @@ -299,7 +299,7 @@ def test_target_host_merge_0(): assert tgt.kind.name == "cuda" assert tgt.host.kind.name == "cuda" assert tgt.host.attrs["arch"] == "sm_53" - assert tgt.host.attrs["shared_memory_per_block"] == 49152 + assert tgt.host.attrs["max_shared_memory_per_block"] == 49152 assert tgt.host.attrs["max_threads_per_block"] == 1024 assert tgt.host.attrs["thread_warp_size"] == 32 assert tgt.host.attrs["registers_per_block"] == 32768 @@ -346,7 +346,7 @@ def test_target_with_host(): tgt = tgt.with_host(cuda_host) assert tgt.host.kind.name == "cuda" assert tgt.host.attrs["arch"] == "sm_53" - assert tgt.host.attrs["shared_memory_per_block"] == 49152 + assert tgt.host.attrs["max_shared_memory_per_block"] == 49152 assert tgt.host.attrs["max_threads_per_block"] == 1024 assert tgt.host.attrs["thread_warp_size"] == 32 assert tgt.host.attrs["registers_per_block"] == 32768
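
---
Usage sketch (not part of the patch; assumes a TVM build that includes the changes above). The attribute rename can be observed on any CUDA target tag, mirroring the assertions updated in test_target_target.py:

    import tvm

    # "shared_memory_per_block" is renamed to "max_shared_memory_per_block"
    tgt = tvm.target.Target("nvidia/jetson-nano")
    assert tgt.attrs["max_shared_memory_per_block"] == 49152
    assert tgt.attrs["max_threads_per_block"] == 1024
    assert tgt.attrs["thread_warp_size"] == 32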
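
Likewise, a minimal sketch of the new measurement default from config.py (min_repeat_ms was raised from 40 to 100):

    from tvm.meta_schedule.runner import EvaluatorConfig

    cfg = EvaluatorConfig()  # number=3, repeat=1, enable_cpu_cache_flush=False
    assert cfg.min_repeat_ms == 100  # previously defaulted to 40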
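
Finally, a sketch of the nvidia/jetson-agx-xavier tag, which is now registered in tag.cc with its own aarch64 llvm host (mcpu=carmel) instead of going through the TVM_REGISTER_CUDA_TAG macro:

    import tvm

    tgt = tvm.target.Target("nvidia/jetson-agx-xavier")
    assert tgt.kind.name == "cuda"
    assert tgt.attrs["arch"] == "sm_72"
    assert tgt.host.kind.name == "llvm"  # host target bundled with the tag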