diff --git a/tests/python/driver/tvmc/test_autotuner.py b/tests/python/driver/tvmc/test_autotuner.py index a1915a0251e96..66017823a6699 100644 --- a/tests/python/driver/tvmc/test_autotuner.py +++ b/tests/python/driver/tvmc/test_autotuner.py @@ -20,6 +20,7 @@ from unittest import mock from os import path +from pathlib import Path from tvm import autotvm from tvm.driver import tvmc @@ -163,9 +164,16 @@ def test_tune_tasks__invalid_tuner(onnx_mnist, tmpdir_factory): def test_tune_rpc_tracker_parsing(mock_load_model, mock_tune_model, mock_auto_scheduler): cli_args = mock.MagicMock() cli_args.rpc_tracker = "10.0.0.1:9999" + # FILE is not used but it's set to a valid value here to avoid it being set + # by mock to a MagicMock class, which won't pass the checks for valid FILE. + fake_input_file = "./fake_input_file.tflite" + Path(fake_input_file).touch() + cli_args.FILE = fake_input_file tvmc.autotuner.drive_tune(cli_args) + os.remove(fake_input_file) + mock_tune_model.assert_called_once() # inspect the mock call, to search for specific arguments