Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions python/tvm/relax/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tvm import meta_schedule as ms

from . import backend, transform
from .backend.utils import BackendDispatcher


def zero_pipeline(*, enable_warning: bool = False):
Expand Down Expand Up @@ -270,7 +271,8 @@ def library_dispatch_passes(target: tvm.target.Target):
return backend.cpu_generic.library_dispatch_passes(target)
if target.kind.name == "opencl" and "adreno" in target.keys:
return backend.adreno.library_dispatch_passes(target)
# Todo(tvm-team): support gpu-generic
if BackendDispatcher.is_gpu_target(target):
return backend.gpu_generic.library_dispatch_passes(target)
raise ValueError(f"Target {target} is not yet supported by library dispatch passes.")


Expand All @@ -286,8 +288,9 @@ def legalize_passes(target: tvm.target.Target):
return backend.cpu_generic.legalize_passes(target)
if target.kind.name == "opencl" and "adreno" in target.keys:
return backend.adreno.legalize_passes(target)
# Todo(tvm-team): support gpu-generic
raise ValueError(f"Target {target} is not yet supported by library dispatch passes.")
if BackendDispatcher.is_gpu_target(target):
return backend.gpu_generic.legalize_passes(target)
raise ValueError(f"Target {target} is not yet supported by legalize passes.")


def dataflow_lower_passes(target: tvm.target.Target):
Expand All @@ -302,7 +305,8 @@ def dataflow_lower_passes(target: tvm.target.Target):
return backend.cpu_generic.dataflow_lower_passes(target)
if target.kind.name == "opencl" and "adreno" in target.keys:
return backend.adreno.dataflow_lower_passes(target)
# Todo(tvm-team): support gpu-generic
if BackendDispatcher.is_gpu_target(target):
return backend.gpu_generic.dataflow_lower_passes(target)
raise ValueError(f"Target {target} is not yet supported by dataflow lowering passes.")


Expand All @@ -318,7 +322,8 @@ def finalize_passes(target: tvm.target.Target):
return backend.cpu_generic.finalize_passes(target)
if target.kind.name == "opencl" and "adreno" in target.keys:
return backend.adreno.finalize_passes(target)
# Todo(tvm-team): support gpu-generic
if BackendDispatcher.is_gpu_target(target):
return backend.gpu_generic.finalize_passes(target)
raise ValueError(f"Target {target} is not yet supported by finalization passes.")


Expand All @@ -334,7 +339,8 @@ def get_default_pipeline(target: tvm.target.Target):
return backend.cpu_generic.get_default_pipeline(target)
if target.kind.name == "opencl" and "adreno" in target.keys:
return backend.adreno.get_default_pipeline(target)
# Todo(tvm-team): support gpu-generic
if BackendDispatcher.is_gpu_target(target):
return backend.gpu_generic.get_default_pipeline(target)
raise ValueError(
f"Target {target} is not yet supported by default pipeline. "
"Please lower and build the IRModule manually."
Expand Down
35 changes: 35 additions & 0 deletions tests/python/relax/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pytest

import tvm
import tvm.testing
Expand Down Expand Up @@ -113,3 +114,37 @@ def main(

cache_np[i, :] = x_np + y_np
tvm.testing.assert_allclose(kv.numpy(), cache_np[: np_shape[0], :], rtol=1e-7, atol=1e-7)


@pytest.mark.parametrize("target_name", ["vulkan", "webgpu"])
@pytest.mark.parametrize(
"pipeline_func",
[
relax.pipeline.library_dispatch_passes,
relax.pipeline.legalize_passes,
relax.pipeline.dataflow_lower_passes,
relax.pipeline.finalize_passes,
relax.pipeline.get_default_pipeline,
],
Comment on lines +121 to +128
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This list of pipeline functions is duplicated in test_non_gpu_target_raises_error below. To improve maintainability and avoid this duplication, consider extracting the list into a module-level constant and reusing it in both pytest.mark.parametrize decorators.

For example:

PIPELINE_FUNCS_FOR_TESTING = [
    relax.pipeline.library_dispatch_passes,
    relax.pipeline.legalize_passes,
    relax.pipeline.dataflow_lower_passes,
    relax.pipeline.finalize_passes,
    relax.pipeline.get_default_pipeline,
]


@pytest.mark.parametrize("pipeline_func", PIPELINE_FUNCS_FOR_TESTING)
# ...

)
def test_gpu_generic_fallback(target_name, pipeline_func):
target = tvm.target.Target(target_name)
result = pipeline_func(target)
assert result is not None


@pytest.mark.parametrize("target_name", ["hexagon", "c"])
@pytest.mark.parametrize(
"pipeline_func",
[
relax.pipeline.library_dispatch_passes,
relax.pipeline.legalize_passes,
relax.pipeline.dataflow_lower_passes,
relax.pipeline.finalize_passes,
relax.pipeline.get_default_pipeline,
],
)
def test_non_gpu_target_raises_error(target_name, pipeline_func):
target = tvm.target.Target(target_name)
with pytest.raises(ValueError, match="not yet supported"):
pipeline_func(target)
Loading