Skip to content

Commit

Permalink
[Testing] Enable Target object as argument to _target_to_requirement
Browse files Browse the repository at this point in the history
Previously, tvm.testing._target_to_requirement required the argument
to be a string.  This commit allows it to be either a string or a
`tvm.target.Target`.
  • Loading branch information
Lunderberg committed Aug 3, 2021
1 parent eede865 commit 0019d60
Showing 1 changed file with 29 additions and 15 deletions.
44 changes: 29 additions & 15 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,12 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap):
def _get_targets(target_str=None):
if target_str is None:
target_str = os.environ.get("TVM_TEST_TARGETS", "")
# Use dict instead of set for de-duplication so that the
# targets stay in the order specified.
target_names = list({t.strip(): None for t in target_str.split(";") if t.strip()})

if len(target_str) == 0:
target_str = DEFAULT_TEST_TARGETS

target_names = set(t.strip() for t in target_str.split(";") if t.strip())
if not target_names:
target_names = DEFAULT_TEST_TARGETS

targets = []
for target in target_names:
Expand Down Expand Up @@ -413,10 +414,18 @@ def _get_targets(target_str=None):
return targets


DEFAULT_TEST_TARGETS = (
"llvm;cuda;opencl;metal;rocm;vulkan -from_device=0;nvptx;"
"llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu"
)
DEFAULT_TEST_TARGETS = [
"llvm",
"llvm -device=arm_cpu",
"cuda",
"nvptx",
"vulkan -from_device=0",
"opencl",
"opencl -device=mali,aocl_sw_emu",
"opencl -device=intel_graphics",
"metal",
"rocm",
]


def device_enabled(target):
Expand Down Expand Up @@ -730,20 +739,25 @@ def requires_rpc(*args):


def _target_to_requirement(target):
if isinstance(target, str):
target_kind = target.split()[0]
else:
target_kind = target.kind.name

# mapping from target to decorator
if target.startswith("cuda"):
if target_kind == "cuda":
return requires_cuda()
if target.startswith("rocm"):
if target_kind == "rocm":
return requires_rocm()
if target.startswith("vulkan"):
if target_kind == "vulkan":
return requires_vulkan()
if target.startswith("nvptx"):
if target_kind == "nvptx":
return requires_nvptx()
if target.startswith("metal"):
if target_kind == "metal":
return requires_metal()
if target.startswith("opencl"):
if target_kind == "opencl":
return requires_opencl()
if target.startswith("llvm"):
if target_kind == "llvm":
return requires_llvm()
return []

Expand Down

0 comments on commit 0019d60

Please sign in to comment.