Skip to content

Commit

Permalink
[Testing] Auto-target parametrization, handle pytest ParameterSet
Browse files Browse the repository at this point in the history
If the unit test has already been parametrized with pytest.params to
add parameter-specific marks, respect those existing marks.

This can happen in some cases in the CI, uncertain yet what is causing
them.  Maybe pytest-xdist related, but there's some difficulty in
reproducing it locally.
  • Loading branch information
Lunderberg committed Aug 5, 2021
1 parent eb3e2c1 commit c49a86b
Showing 1 changed file with 29 additions and 11 deletions.
40 changes: 29 additions & 11 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,19 +817,37 @@ def update_parametrize_target_arg(
):
args = [arg.strip() for arg in argnames.split(",") if arg.strip()]
if "target" in args:
if len(args) == 1:
targets = argvalues
param_sets = [(target,) for target in targets]
else:
target_i = args.index("target")
targets = [param_set[target_i] for param_set in argvalues]
param_sets = argvalues
target_i = args.index("target")

new_argvalues = []
for argvalue in argvalues:

if isinstance(argvalue, _pytest.mark.structures.ParameterSet):
# The parametrized value is already a
# pytest.param, so track any marks already
# defined.
param_set = argvalue.values
target = param_set[target_i]
additional_marks = argvalue.marks
elif len(args) == 1:
# Single value parametrization, argvalue is a list of values.
target = argvalue
param_set = (target,)
additional_marks = []
else:
# Multiple correlated parameters, argvalue is a list of tuple of values.
param_set = argvalue
target = param_set[target_i]
additional_marks = []

new_argvalues.append(
pytest.param(
*param_set, marks=_target_to_requirement(target) + additional_marks
)
)

try:
argvalues[:] = [
pytest.param(*param_set, marks=_target_to_requirement(target))
for target, param_set in zip(targets, param_sets)
]
argvalues[:] = new_argvalues
except TypeError as e:
pyfunc = metafunc.definition.function
filename = pyfunc.__code__.co_filename
Expand Down

0 comments on commit c49a86b

Please sign in to comment.