Skip to content

Commit

Permalink
[TVMC] Allow direct numpy inputs to run_module (apache#7788)
Browse files Browse the repository at this point in the history
* progress, graph params need to figure out

* black and lint

* change np.load(inputs_file) to happen in drive_run

* make inputs optional

Co-authored-by: Jocelyn <jocelyn@pop-os.localdomain>
  • Loading branch information
2 people authored and Trevor Morris committed May 6, 2021
1 parent 29dfd56 commit 7c02da3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
29 changes: 16 additions & 13 deletions python/tvm/driver/tvmc/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,17 @@ def drive_run(args):

rpc_hostname, rpc_port = common.tracker_host_port_from_cli(args.rpc_tracker)

try:
inputs = np.load(args.inputs) if args.inputs else {}
except IOError as ex:
raise TVMCException("Error loading inputs file: %s" % ex)

outputs, times = run_module(
args.FILE,
rpc_hostname,
rpc_port,
args.rpc_key,
inputs_file=args.inputs,
inputs=inputs,
device=args.device,
fill_mode=args.fill_mode,
repeat=args.repeat,
Expand Down Expand Up @@ -221,7 +226,7 @@ def generate_tensor_data(shape, dtype, fill_mode):
return tensor


def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode):
def make_inputs_dict(shape_dict, dtype_dict, inputs=None, fill_mode="random"):
"""Make the inputs dictionary for a graph.
Use data from 'inputs' where specified. For input tensors
Expand All @@ -230,13 +235,13 @@ def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode):
Parameters
----------
inputs_file : str
Path to a .npz file containing the inputs.
shape_dict : dict
Shape dictionary - {input_name: tuple}.
dtype_dict : dict
dtype dictionary - {input_name: dtype}.
fill_mode : str
inputs : dict, optional
A dictionary that maps input names to numpy values.
fill_mode : str, optional
The fill-mode to use when generating tensor data.
Can be either "zeros", "ones" or "random".
Expand All @@ -247,10 +252,8 @@ def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode):
"""
logger.debug("creating inputs dict")

try:
inputs = np.load(inputs_file) if inputs_file else {}
except IOError as ex:
raise TVMCException("Error loading inputs file: %s" % ex)
if inputs is None:
inputs = {}

# First check all the keys in inputs exist in the graph
for input_name in inputs:
Expand Down Expand Up @@ -291,7 +294,7 @@ def run_module(
port=9090,
rpc_key=None,
device=None,
inputs_file=None,
inputs=None,
fill_mode="random",
repeat=1,
profile=False,
Expand All @@ -316,8 +319,8 @@ def run_module(
device: str, optional
the device (e.g. "cpu" or "gpu") to be targeted by the RPC
session, local or remote).
inputs_file : str, optional
Path to an .npz file containing the inputs.
inputs : dict, optional
A dictionary that maps input names to numpy values.
fill_mode : str, optional
The fill-mode to use when generating data for input tensors.
Valid options are "zeros", "ones" and "random".
Expand Down Expand Up @@ -379,7 +382,7 @@ def run_module(
module.load_params(params)

shape_dict, dtype_dict = get_input_info(graph, params)
inputs_dict = make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode)
inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode)

logger.debug("setting inputs to the module")
module.set_input(**inputs_dict)
Expand Down
4 changes: 3 additions & 1 deletion tests/python/driver/tvmc/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@ def test_run_tflite_module__with_profile__valid_input(
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")

inputs = np.load(imagenet_cat)

outputs, times = tvmc.run(
tflite_compiled_module_as_tarfile,
inputs_file=imagenet_cat,
inputs=inputs,
hostname=None,
device="cpu",
profile=True,
Expand Down

0 comments on commit 7c02da3

Please sign in to comment.