From 65a8607788c44a8750c473cfa1c0e3713281dc6e Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Fri, 16 Jul 2021 14:00:02 +0300
Subject: [PATCH] transfer test_lstms to pytest format

---
 tests/python/frontend/pytorch/test_lstms.py | 433 +++++++++-----------
 1 file changed, 188 insertions(+), 245 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_lstms.py b/tests/python/frontend/pytorch/test_lstms.py
index 9884595d8374a..811363c548cd5 100644
--- a/tests/python/frontend/pytorch/test_lstms.py
+++ b/tests/python/frontend/pytorch/test_lstms.py
@@ -16,12 +16,13 @@
 # under the License.
 
 import tvm
+import tvm.testing
 import numpy as np
 import torch
 import onnx
-import argparse
 import sys
 import shutil
+import pytest
 
 from tvm import relay
 from tvm.contrib import graph_executor
@@ -167,255 +168,197 @@ def get_dummy_input(self):
         return res, shape
 
 
-def compare(input, gold_data, epsilon=1e-6):
-    remain = np.abs(gold_data - input)
-    err = np.max(remain)
-    if err < epsilon:
-        print("SUCCESS: RESULTS ARE THE SAME WITH MAX ERROR {} AND EPSILON {}".format(err, epsilon))
-    else:
-        print("WARNING: RESULTS ARE NOT THE SAME WITH ERROR {}".format(err))
+def compare(input, gold_data, rtol=1e-5, atol=1e-5):
+    tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol)
 
 
-if __name__ == "__main__":
-
-    class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
-        pass
-
-    parser = argparse.ArgumentParser(
-        description="It constructs neural network which conists of LSTM layer only. "
-        "But the layer can be different types and adjusted by special parameters (see https://pytorch.org/docs/1.8.0/generated/torch.nn.LSTM.html?highlight=lstm#torch.nn.LSTM). "
-        "There are several types: unidirectional, bidirectional, projection, stacked, stacked bidirectional, stacked projection bidirectional",
-        formatter_class=MyFormatter,
-    )
-    parser.add_argument(
-        "-t",
-        "--lstm_type",
-        type=str,
-        default="uni",
-        help="Type of lstm layer for test. There are several options: uni (unidirectional), b (bidirectional), "
-        "p (with projection), s (stacked), sb (stacked bidirectional), sp (stacked with projection), "
-        "bp (bidirectional with projection), sbp (stacked bidirectional with projection)",
-    )
-    parser.add_argument(
-        "-f",
-        "--format",
-        type=str,
-        default="ts",
-        help='Format of the model. There are two options: "ts"(TorchScript) and "onnx"(ONNX). The first one is used by default',
-    )
-    parser.add_argument(
-        "-l",
-        "--layer_num",
-        type=int,
-        default=model_num_layers,
-        help="Number of LSTM layers. It is useded for stacked LSTM",
-    )
-    parser.add_argument(
-        "-p",
-        "--projection_size",
-        type=int,
-        default=projection_size,
-        help="Projection size is used in LSTM with projection",
-    )
-    parser.add_argument(
-        "-b",
-        "--batch_first",
-        action="store_true",
-        default=False,
-        help="Batch first parameter used for LSTM layer initialization",
-    )
-    parser.add_argument(
-        "-bias",
-        "--use_bias",
-        action="store_false",
-        default=True,
-        help="Skip using biases weights for LSTM layer",
-    )
-    parser.add_argument(
-        "-w",
-        "--rnd_weights",
-        action="store_true",
-        default=False,
-        help="Generate random weights and biases for the model. NOTE: By default All the weights and biases are initialized from "
-        "\mathcal{U}(-\sqrt{k}, \sqrt{k}), where k = \frac{1}{\text{hidden\_size}}",
-    )
-    parser.add_argument(
-        "-o",
-        "--out_dir",
-        type=Path,
-        default=argparse.SUPPRESS,
-        help="Path to directory for saving of intermediate results. NOTE: At the end the directory is removed with all dependencies inside",
-    )
-
-    args = parser.parse_args()
-    if not hasattr(args, "out_dir"):
-        args.out_dir = Path.cwd().joinpath("output")
-    args.out_dir.mkdir(exist_ok=True, parents=True)
-    has_proj = "p" in args.lstm_type
+def check_lstm_with_type(lstm_type):
+    # Create output directory to keep temporary files
+    out_dir = Path.cwd().joinpath("output")
+    out_dir.mkdir(exist_ok=True, parents=True)
+    has_proj = "p" in lstm_type
 
     device = torch.device("cpu")
     hidden_layers_num = 1
     model = None
-    if args.lstm_type == "uni":
-        model = LSTM_Model(
-            device,
-            batch_first=args.batch_first,
-            rnd_weights_init=args.rnd_weights,
-            use_bias=args.use_bias,
-        )
-    elif args.lstm_type == "b":
-        model = LSTM_Model(
-            device,
-            batch_first=args.batch_first,
-            bidirectional=True,
-            rnd_weights_init=args.rnd_weights,
-            use_bias=args.use_bias,
-        )
-        hidden_layers_num = 2
-    elif args.lstm_type == "p":
-        model = LSTM_Model(
-            device,
-            batch_first=args.batch_first,
-            proj_size=args.projection_size,
-            rnd_weights_init=args.rnd_weights,
-            use_bias=args.use_bias,
-        )
-    elif args.lstm_type == "s":
-        model = LSTM_Model(
-            device,
-            batch_first=args.batch_first,
-            layer_num=args.layer_num,
-            rnd_weights_init=args.rnd_weights,
-            use_bias=args.use_bias,
-        )
-        hidden_layers_num = args.layer_num
-    elif args.lstm_type == "sb":
-        model = LSTM_Model(
-            device,
-            batch_first=args.batch_first,
-            bidirectional=True,
-            layer_num=args.layer_num,
-            rnd_weights_init=args.rnd_weights,
-            use_bias=args.use_bias,
-        )
-        hidden_layers_num = 2 * args.layer_num
-    elif args.lstm_type == "sp":
-        model = LSTM_Model(
-            device,
-            batch_first=args.batch_first,
-            layer_num=args.layer_num,
-            proj_size=args.projection_size,
-            rnd_weights_init=args.rnd_weights,
-            use_bias=args.use_bias,
-        )
-        hidden_layers_num = args.layer_num
-    elif args.lstm_type == "bp":
-        model = LSTM_Model(
-            device,
-            batch_first=args.batch_first,
-            bidirectional=True,
-            proj_size=args.projection_size,
-            rnd_weights_init=args.rnd_weights,
-            use_bias=args.use_bias,
-        )
-        hidden_layers_num = 2 * args.layer_num
-    elif args.lstm_type == "sbp":
-        model = LSTM_Model(
-            device,
-            batch_first=args.batch_first,
-            bidirectional=True,
-            layer_num=args.layer_num,
-            proj_size=args.projection_size,
-            rnd_weights_init=args.rnd_weights,
-            use_bias=args.use_bias,
-        )
-        hidden_layers_num = 2 * args.layer_num
-    else:
-        print("LSTM type {} is not supported here!".format(args.lstm_type))
-        sys.exit()
-
-    model.eval()
-
-    # Get golden output from original model
-    input_hidden_shape = (hidden_layers_num, batch_size, model_hidden_size)
-    input_hidden_shape_with_proj = (hidden_layers_num, batch_size, projection_size)
-    dummy_input, input_shape = model.get_dummy_input()
-    golden_output_batch = model.forward(dummy_input.to(device)).detach().cpu().numpy()
-
-    dtype = "float32"
-    h_zeros = np.zeros(input_hidden_shape, dtype=dtype)
-    if has_proj:
-        h_zeros = np.zeros(input_hidden_shape_with_proj, dtype=dtype)
-    c_zeros = np.zeros(input_hidden_shape, dtype=dtype)
-
-    tvm_output = None
-    if args.format == "ts":
-        # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
-        traced_script_module = torch.jit.trace(model, dummy_input).eval()
-
-        # Import model to Relay
-        shape_list = [("input", input_shape)]
-        mod, params = relay.frontend.from_pytorch(traced_script_module, shape_list)
-
-        # Model compilation by tvm
-        target = tvm.target.Target("llvm", host="llvm")
-        dev = tvm.cpu(0)
-        with tvm.transform.PassContext(opt_level=3):
-            lib = relay.build(mod, target=target, params=params)
-    elif args.format == "onnx":
-        if has_proj:
-            print(
-                "ERROR: torch.onnx.export does not support conversion LSTM with projection "
-                "from pytorch! TODO: waiting for the support and correct test after that."
-            )
-            sys.exit()
-        onnx_fpath = args.out_dir.joinpath("model_{}.onnx".format(args.lstm_type))
-
-        with torch.no_grad():
-            h0 = torch.rand(input_hidden_shape)
-            if has_proj:
-                h0 = torch.rand(input_hidden_shape_with_proj)
-            c0 = torch.rand(input_hidden_shape)
-            input_names = ["input", "h0", "c0"]
-
-            # default export (without dynamic input)
-            torch.onnx.export(model, (dummy_input, (h0, c0)), onnx_fpath, input_names=input_names)
-
-        onnx_model = onnx.load(onnx_fpath)
-
-        # Import model to Relay
-        shape_dict = {"input": input_shape, "h0": input_hidden_shape, "c0": input_hidden_shape}
-        if has_proj:
-            shape_dict = {
-                "input": input_shape,
-                "h0": input_hidden_shape_with_proj,
-                "c0": input_hidden_shape,
-            }
-        mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
-
-        # Model compilation by tvm
-        target = "llvm"
-        dev = tvm.cpu(0)
-        with tvm.transform.PassContext(opt_level=1):
-            lib = relay.build(mod, target=target, params=params)
-    else:
-        print("ERROR: {} format is unsupported".format(args.format))
-
-    # Inference of the model with given input data
-    m = graph_executor.GraphModule(lib["default"](dev))
-
-    # Set inputs
-    m.set_input(
-        input=tvm.nd.array(dummy_input.numpy().astype(dtype)),
-        h0=tvm.nd.array(h_zeros),
-        c0=tvm.nd.array(c_zeros),
-    )
-    # Execute
-    m.run()
-    # Get outputs (converted to numpy array)
-    tvm_output = m.get_output(0).numpy()
-
-    compare(tvm_output, golden_output_batch)
+    for batch_first in (True, False):
+        for use_bias in (True, False):
+            for rnd_weights in (True, False):
+                if lstm_type == "uni":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                elif lstm_type == "b":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        bidirectional=True,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                    hidden_layers_num = 2
+                elif lstm_type == "p":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        proj_size=projection_size,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                elif lstm_type == "s":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        layer_num=model_num_layers,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                    hidden_layers_num = model_num_layers
+                elif lstm_type == "sb":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        bidirectional=True,
+                        layer_num=model_num_layers,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                    hidden_layers_num = 2 * model_num_layers
+                elif lstm_type == "sp":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        layer_num=model_num_layers,
+                        proj_size=projection_size,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                    hidden_layers_num = model_num_layers
+                elif lstm_type == "bp":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        bidirectional=True,
+                        proj_size=projection_size,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                    hidden_layers_num = 2
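+                    # one bidirectional layer yields two hidden-state layers (num_layers * num_directions)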
+                elif lstm_type == "sbp":
+                    model = LSTM_Model(
+                        device,
+                        batch_first=batch_first,
+                        bidirectional=True,
+                        layer_num=model_num_layers,
+                        proj_size=projection_size,
+                        rnd_weights_init=rnd_weights,
+                        use_bias=use_bias,
+                    )
+                    hidden_layers_num = 2 * model_num_layers
+                else:
+                    print("WARNING: LSTM type {} is not supported here!".format(lstm_type))
+                    return
+
+                model.eval()
+
+                # Get golden output from original model
+                input_hidden_shape = (hidden_layers_num, batch_size, model_hidden_size)
+                input_hidden_shape_with_proj = (hidden_layers_num, batch_size, projection_size)
+                dummy_input, input_shape = model.get_dummy_input()
+                golden_output_batch = model.forward(dummy_input.to(device)).detach().cpu().numpy()
+
+                dtype = "float32"
+                h_zeros = np.zeros(input_hidden_shape, dtype=dtype)
+                if has_proj:
+                    h_zeros = np.zeros(input_hidden_shape_with_proj, dtype=dtype)
+                c_zeros = np.zeros(input_hidden_shape, dtype=dtype)
+
+                tvm_output = None
+                for fmt in ("ts", "onnx"):
+                    if fmt == "ts":
+                        # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
+                        traced_script_module = torch.jit.trace(model, dummy_input).eval()
+
+                        # Import model to Relay
+                        shape_list = [("input", input_shape)]
+                        mod, params = relay.frontend.from_pytorch(traced_script_module, shape_list)
+
+                        # Compile the model with TVM
+                        target = tvm.target.Target("llvm", host="llvm")
+                        dev = tvm.cpu(0)
+                        with tvm.transform.PassContext(opt_level=3):
+                            lib = relay.build(mod, target=target, params=params)
+                    elif fmt == "onnx":
+                        if has_proj:
+                            print(
+                                "WARNING: torch.onnx.export does not support conversion of LSTM "
+                                "with projection from PyTorch! TODO: extend the test once support is added."
+                            )
+                            continue
+                        onnx_fpath = out_dir.joinpath("model_{}.onnx".format(lstm_type))
+
+                        with torch.no_grad():
+                            h0 = torch.rand(input_hidden_shape)
+                            if has_proj:
+                                h0 = torch.rand(input_hidden_shape_with_proj)
+                            c0 = torch.rand(input_hidden_shape)
+                            input_names = ["input", "h0", "c0"]
+
+                            # default export (without dynamic input)
+                            torch.onnx.export(model, (dummy_input, (h0, c0)), onnx_fpath, input_names=input_names)
+
+                        onnx_model = onnx.load(onnx_fpath)
+
+                        # Import model to Relay
+                        shape_dict = {"input": input_shape, "h0": input_hidden_shape, "c0": input_hidden_shape}
+                        if has_proj:
+                            shape_dict = {
+                                "input": input_shape,
+                                "h0": input_hidden_shape_with_proj,
+                                "c0": input_hidden_shape,
+                            }
+                        mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
+
+                        # Compile the model with TVM
+                        target = "llvm"
+                        dev = tvm.cpu(0)
+                        with tvm.transform.PassContext(opt_level=1):
+                            lib = relay.build(mod, target=target, params=params)
+
+                    # Inference of the model with given input data
+                    m = graph_executor.GraphModule(lib["default"](dev))
+
+                    # Set inputs
+                    m.set_input(
+                        input=tvm.nd.array(dummy_input.numpy().astype(dtype)),
+                        h0=tvm.nd.array(h_zeros),
+                        c0=tvm.nd.array(c_zeros),
+                    )
+                    # Execute
+                    m.run()
+                    # Get outputs (converted to numpy array)
+                    tvm_output = m.get_output(0).numpy()
+
+                    compare(tvm_output, golden_output_batch)
 
     # Remove output directory with tmp files
-    shutil.rmtree(args.out_dir)
+    shutil.rmtree(out_dir)
+
+
+def test_lstms():
+    check_lstm_with_type("uni")
+    check_lstm_with_type("p")
+    check_lstm_with_type("s")
+    check_lstm_with_type("b")
+    check_lstm_with_type("bp")
+    check_lstm_with_type("sp")
+    check_lstm_with_type("sb")
+    check_lstm_with_type("sbp")
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])