diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index b4c4e75aa37a..d69e71f47c8f 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -107,12 +107,17 @@ def drive_run(args): rpc_hostname, rpc_port = common.tracker_host_port_from_cli(args.rpc_tracker) + try: + inputs = np.load(args.inputs) if args.inputs else {} + except IOError as ex: + raise TVMCException("Error loading inputs file: %s" % ex) + outputs, times = run_module( args.FILE, rpc_hostname, rpc_port, args.rpc_key, - inputs_file=args.inputs, + inputs=inputs, device=args.device, fill_mode=args.fill_mode, repeat=args.repeat, @@ -221,7 +226,7 @@ def generate_tensor_data(shape, dtype, fill_mode): return tensor -def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode): +def make_inputs_dict(shape_dict, dtype_dict, inputs=None, fill_mode="random"): """Make the inputs dictionary for a graph. Use data from 'inputs' where specified. For input tensors @@ -230,13 +235,13 @@ def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode): Parameters ---------- - inputs_file : str - Path to a .npz file containing the inputs. shape_dict : dict Shape dictionary - {input_name: tuple}. dtype_dict : dict dtype dictionary - {input_name: dtype}. - fill_mode : str + inputs : dict, optional + A dictionary that maps input names to numpy values. + fill_mode : str, optional The fill-mode to use when generating tensor data. Can be either "zeros", "ones" or "random". @@ -247,10 +252,8 @@ def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode): """ logger.debug("creating inputs dict") - try: - inputs = np.load(inputs_file) if inputs_file else {} - except IOError as ex: - raise TVMCException("Error loading inputs file: %s" % ex) + if inputs is None: + inputs = {} # First check all the keys in inputs exist in the graph for input_name in inputs: @@ -291,7 +294,7 @@ def run_module( port=9090, rpc_key=None, device=None, - inputs_file=None, + inputs=None, fill_mode="random", repeat=1, profile=False, @@ -316,8 +319,8 @@ def run_module( device: str, optional the device (e.g. "cpu" or "gpu") to be targeted by the RPC session, local or remote). - inputs_file : str, optional - Path to an .npz file containing the inputs. + inputs : dict, optional + A dictionary that maps input names to numpy values. fill_mode : str, optional The fill-mode to use when generating data for input tensors. Valid options are "zeros", "ones" and "random". @@ -379,7 +382,7 @@ def run_module( module.load_params(params) shape_dict, dtype_dict = get_input_info(graph, params) - inputs_dict = make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode) + inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) logger.debug("setting inputs to the module") module.set_input(**inputs_dict) diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py index 5fdf58fa8d64..366a6df4280f 100644 --- a/tests/python/driver/tvmc/test_runner.py +++ b/tests/python/driver/tvmc/test_runner.py @@ -73,9 +73,11 @@ def test_run_tflite_module__with_profile__valid_input( # some CI environments wont offer TFLite, so skip in case it is not present pytest.importorskip("tflite") + inputs = np.load(imagenet_cat) + outputs, times = tvmc.run( tflite_compiled_module_as_tarfile, - inputs_file=imagenet_cat, + inputs=inputs, hostname=None, device="cpu", profile=True,