diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index 388f9dbb43cd..1c25b2053bc2 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -27,6 +27,7 @@ from tvm import meta_schedule as ms from . import backend, transform +from .backend.utils import BackendDispatcher def zero_pipeline(*, enable_warning: bool = False): @@ -270,7 +271,8 @@ def library_dispatch_passes(target: tvm.target.Target): return backend.cpu_generic.library_dispatch_passes(target) if target.kind.name == "opencl" and "adreno" in target.keys: return backend.adreno.library_dispatch_passes(target) - # Todo(tvm-team): support gpu-generic + if BackendDispatcher.is_gpu_target(target): + return backend.gpu_generic.library_dispatch_passes(target) raise ValueError(f"Target {target} is not yet supported by library dispatch passes.") @@ -286,8 +288,9 @@ def legalize_passes(target: tvm.target.Target): return backend.cpu_generic.legalize_passes(target) if target.kind.name == "opencl" and "adreno" in target.keys: return backend.adreno.legalize_passes(target) - # Todo(tvm-team): support gpu-generic - raise ValueError(f"Target {target} is not yet supported by library dispatch passes.") + if BackendDispatcher.is_gpu_target(target): + return backend.gpu_generic.legalize_passes(target) + raise ValueError(f"Target {target} is not yet supported by legalize passes.") def dataflow_lower_passes(target: tvm.target.Target): @@ -302,7 +305,8 @@ def dataflow_lower_passes(target: tvm.target.Target): return backend.cpu_generic.dataflow_lower_passes(target) if target.kind.name == "opencl" and "adreno" in target.keys: return backend.adreno.dataflow_lower_passes(target) - # Todo(tvm-team): support gpu-generic + if BackendDispatcher.is_gpu_target(target): + return backend.gpu_generic.dataflow_lower_passes(target) raise ValueError(f"Target {target} is not yet supported by dataflow lowering passes.") @@ -318,7 +322,8 @@ def finalize_passes(target: tvm.target.Target): return backend.cpu_generic.finalize_passes(target) if target.kind.name == "opencl" and "adreno" in target.keys: return backend.adreno.finalize_passes(target) - # Todo(tvm-team): support gpu-generic + if BackendDispatcher.is_gpu_target(target): + return backend.gpu_generic.finalize_passes(target) raise ValueError(f"Target {target} is not yet supported by finalization passes.") @@ -334,7 +339,8 @@ def get_default_pipeline(target: tvm.target.Target): return backend.cpu_generic.get_default_pipeline(target) if target.kind.name == "opencl" and "adreno" in target.keys: return backend.adreno.get_default_pipeline(target) - # Todo(tvm-team): support gpu-generic + if BackendDispatcher.is_gpu_target(target): + return backend.gpu_generic.get_default_pipeline(target) raise ValueError( f"Target {target} is not yet supported by default pipeline. " "Please lower and build the IRModule manually." diff --git a/tests/python/relax/test_pipeline.py b/tests/python/relax/test_pipeline.py index f9bce3539645..482c45fbdd85 100644 --- a/tests/python/relax/test_pipeline.py +++ b/tests/python/relax/test_pipeline.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm import tvm.testing @@ -113,3 +114,37 @@ def main( cache_np[i, :] = x_np + y_np tvm.testing.assert_allclose(kv.numpy(), cache_np[: np_shape[0], :], rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("target_name", ["vulkan", "webgpu"]) +@pytest.mark.parametrize( + "pipeline_func", + [ + relax.pipeline.library_dispatch_passes, + relax.pipeline.legalize_passes, + relax.pipeline.dataflow_lower_passes, + relax.pipeline.finalize_passes, + relax.pipeline.get_default_pipeline, + ], +) +def test_gpu_generic_fallback(target_name, pipeline_func): + target = tvm.target.Target(target_name) + result = pipeline_func(target) + assert result is not None + + +@pytest.mark.parametrize("target_name", ["hexagon", "c"]) +@pytest.mark.parametrize( + "pipeline_func", + [ + relax.pipeline.library_dispatch_passes, + relax.pipeline.legalize_passes, + relax.pipeline.dataflow_lower_passes, + relax.pipeline.finalize_passes, + relax.pipeline.get_default_pipeline, + ], +) +def test_non_gpu_target_raises_error(target_name, pipeline_func): + target = tvm.target.Target(target_name) + with pytest.raises(ValueError, match="not yet supported"): + pipeline_func(target)