diff --git a/tests/python/frontend/pytorch/test_rnns.py b/tests/python/frontend/pytorch/test_rnns.py index 18f8e984ac38..b5784a6fe1e1 100644 --- a/tests/python/frontend/pytorch/test_rnns.py +++ b/tests/python/frontend/pytorch/test_rnns.py @@ -17,7 +17,6 @@ import tvm import tvm.testing -import numpy as np import torch import onnx import io @@ -31,22 +30,72 @@ ## LSTM parameters lstm_feature_size = 16 lstm_hidden_size = 32 -lstm_num_layers = 2 -projection_size = 20 +lstm_projection_size = 20 ## GRU parameters gru_feature_size = 8 gru_hidden_size = 16 -gru_num_layers = 2 +num_layers = 2 seqs_length = 2 batch_size = 2 -class GRU_Model(nn.Module): +class RNN_Model(nn.Module): + """ + It is base class for RNN layer classes. + It contains some common fields and methods for child classes. + """ + + def __init__( + self, + ): + super().__init__() + + # model is defined in child class + self.model = None + + def forward(self, input, hidden_init=None): + """ + Computes the output tensor after input inference along RNN layer. + + :param input: batch of data as a tensor of shape (seqs_length, batch_size, feature_size) or (batch_size, seqs_length, feature_size) if self.batch_first = True + :param hidden_init: initial hidden state(s) of the RNN as a tensor(s) of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None. + :return: the output tensor of shape (batch_size, hidden_size) + """ + if self.model is None: + raise NotImplementedError("self.model must be defined in subclasses!") + out, _ = self.model(input, hidden_init) + + return out + + def gen_rnd_weights(self): + """ + Generate random weigths for the model + """ + if self.model is None: + raise NotImplementedError("self.model must be defined in subclasses!") + with torch.no_grad(): + for weight_group in self.model.all_weights: + for weight in weight_group: + weight.data = torch.rand(weight.shape) + + def get_dummy_inputs(self): + raise NotImplementedError("subclasses must override get_dummy_inputs()!") + + def get_input_names(self): + raise NotImplementedError("subclasses must override get_input_names()!") + + def get_shape_desc(self, frontend_type): + raise NotImplementedError("subclasses must override get_shape_desc(frontend_type)!") + + def get_tvm_inputs(self, dtype): + raise NotImplementedError("subclasses must override get_tvm_inputs(dtype)!") + + +class GRU_Model(RNN_Model): def __init__( self, - device, seq_len=seqs_length, batch_size=batch_size, feature_size=gru_feature_size, @@ -59,35 +108,27 @@ def __init__( ): super().__init__() - self.batch_first = batch_first - self.seqs_length = seq_len - self.batch_size = batch_size - self.feature_size = feature_size - - self.gru = nn.GRU( - input_size=self.feature_size, + # Shapes + self.shape = [seq_len, batch_size, feature_size] + if batch_first: + self.shape = [batch_size, seq_len, feature_size] + layers_num = 2 * layer_num if bidirectional else layer_num + self.h0_shape = [layers_num, batch_size, hidden_size] + # Dummy inputs + self.dummy_inputs = (torch.rand(self.shape), torch.zeros(self.h0_shape)) + + self.model = nn.GRU( + input_size=feature_size, hidden_size=hidden_size, num_layers=layer_num, bidirectional=bidirectional, batch_first=batch_first, bias=use_bias, - ).to(device) + ) if rnd_weights_init: self.gen_rnd_weights() - def forward(self, input, hidden_init=None): - """ - Computes the output tensor after input inference along GRU layer. - - :param input: batch of data as a tensor of shape (seqs_length, batch_size, feature_size) or (batch_size, seqs_length, feature_size) if self.batch_first = True - :param hidden_init: initial hidden state of the GRU as a tensor of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None. - :return: the output tensor of shape (batch_size, hidden_size) - """ - out, hidden = self.gru(input, hidden_init) - - return out - def gen_rnd_weights(self): """ Generate random weigths for the model with biases @@ -103,18 +144,30 @@ def gen_rnd_weights(self): Bh (3*hidden_size) For generation of random weigths for the model without biases the Bi and Bh weights are skipped """ - with torch.no_grad(): - for weight_group in self.gru.all_weights: - for weight in weight_group: - weight.data = torch.rand(weight.shape) + super().gen_rnd_weights() - def get_dummy_input(self): - shape = [self.seqs_length, self.batch_size, self.feature_size] - if self.batch_first: - shape = [self.batch_size, self.seqs_length, self.feature_size] - res = torch.rand(shape) + def get_dummy_inputs(self): + return self.dummy_inputs - return res, shape + def get_input_names(self): + return ["input", "h0"] + + def get_shape_desc(self, frontend_type): + shape_desc = None + if frontend_type == "pt": # PyTorch + shape_desc = [("input", self.shape)] + elif frontend_type == "onnx": # ONNX + shape_desc = { + "input": self.shape, + "h0": self.h0_shape, + } + return shape_desc + + def get_tvm_inputs(self, dtype): + return { + "input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)), + "h0": tvm.nd.array(self.dummy_inputs[1].numpy().astype(dtype)), + } def check_torch_version_for_proj_in_lstm(): @@ -135,10 +188,13 @@ def check_torch_version_for_proj_in_lstm(): return me -class LSTM_Model(nn.Module): +class LSTM_Model(RNN_Model): def __init__( self, - device, + seq_len=seqs_length, + batch_size=batch_size, + feature_size=lstm_feature_size, + hidden_size=lstm_hidden_size, batch_first=False, layer_num=1, bidirectional=False, @@ -148,12 +204,23 @@ def __init__( ): super().__init__() - self.device = device - self.batch_first = batch_first - self.use_bias = use_bias + # Shapes + self.shape = [seq_len, batch_size, feature_size] + if batch_first: + self.shape = [batch_size, seq_len, feature_size] + layers_num = 2 * layer_num if bidirectional else layer_num + self.h0_shape = [layers_num, batch_size, hidden_size] + if proj_size > 0: + self.h0_shape = [layers_num, batch_size, proj_size] + self.c0_shape = [layers_num, batch_size, hidden_size] + # Dummy inputs + self.dummy_inputs = ( + torch.rand(self.shape), + (torch.zeros(self.h0_shape), torch.zeros(self.c0_shape)), + ) if check_torch_version_for_proj_in_lstm(): - self.lstm = nn.LSTM( + self.model = nn.LSTM( input_size=lstm_feature_size, hidden_size=lstm_hidden_size, num_layers=layer_num, @@ -161,7 +228,7 @@ def __init__( proj_size=proj_size, batch_first=batch_first, bias=use_bias, - ).to(device) + ) else: if proj_size > 0: print( @@ -169,32 +236,18 @@ def __init__( "LSTM was constructed without projection!", ) # sys.exit() - self.lstm = nn.LSTM( + self.model = nn.LSTM( input_size=lstm_feature_size, hidden_size=lstm_hidden_size, num_layers=layer_num, bidirectional=bidirectional, batch_first=batch_first, bias=use_bias, - ).to(device) + ) if rnd_weights_init: self.gen_rnd_weights() - def forward(self, input, hidden_init=None): - """ - Computes the output tensor after input inference along LSTM layer. - - :param input: batch of data as a tensor of shape (seqs_length, batch_size, lstm_feature_size) or (batch_size, seqs_length, lstm_feature_size) if self.batch_first = True - :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None. - :return: the output tensor of shape (batch_size, lstm_hidden_size) - """ - # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state - # and the final cell state. - out, (hidden, cell) = self.lstm(input, hidden_init) - - return out - def gen_rnd_weights(self): """ Generate random weigths for the model with biases @@ -235,328 +288,143 @@ def gen_rnd_weights(self): P (proj_size, lstm_hidden_size) For generation of random weigths for the model without biases Bi and Bh are skipped """ - with torch.no_grad(): - for weight_group in self.lstm.all_weights: - for weight in weight_group: - weight.data = torch.rand(weight.shape) - - def get_dummy_input(self): - shape = [seqs_length, batch_size, lstm_feature_size] - if self.batch_first: - shape = [batch_size, seqs_length, lstm_feature_size] - res = torch.rand(shape) - - return res, shape + super().gen_rnd_weights() + + def get_dummy_inputs(self): + return self.dummy_inputs + + def get_input_names(self): + return ["input", "h0", "c0"] + + def get_shape_desc(self, frontend_type): + shape_desc = None + if frontend_type == "pt": # PyTorch + shape_desc = [("input", self.shape)] + elif frontend_type == "onnx": # ONNX + shape_desc = { + "input": self.shape, + "h0": self.h0_shape, + "c0": self.c0_shape, + } + return shape_desc + + def get_tvm_inputs(self, dtype): + return { + "input": tvm.nd.array(self.dummy_inputs[0].numpy().astype(dtype)), + "h0": tvm.nd.array(self.dummy_inputs[1][0].numpy().astype(dtype)), + "c0": tvm.nd.array(self.dummy_inputs[1][1].numpy().astype(dtype)), + } def compare(input, gold_data, rtol=1e-5, atol=1e-5): tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol) -def check_gru_with_type(gru_type, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0)): - device = torch.device("cpu") - hidden_layers_num = 1 - model = None - for batch_first in (True, False): - for use_bias in (True, False): - for rnd_weights in [True]: # (True, False): - if gru_type == "uni": - model = GRU_Model( - device, - batch_first=batch_first, - rnd_weights_init=rnd_weights, - use_bias=use_bias, - ) - elif gru_type == "b": - model = GRU_Model( - device, - batch_first=batch_first, - bidirectional=True, - rnd_weights_init=rnd_weights, - use_bias=use_bias, - ) - hidden_layers_num = 2 - elif gru_type == "s": - model = GRU_Model( - device, - batch_first=batch_first, - layer_num=gru_num_layers, - rnd_weights_init=rnd_weights, - use_bias=use_bias, - ) - hidden_layers_num = gru_num_layers - elif gru_type == "sb": - model = GRU_Model( - device, - batch_first=batch_first, - bidirectional=True, - layer_num=gru_num_layers, - rnd_weights_init=rnd_weights, - use_bias=use_bias, - ) - hidden_layers_num = 2 * gru_num_layers - else: - print("WARNING: GRU type {} is not supported here!".format(gru_type)) - return - - model.eval() - - # Get golden output from original model - input_hidden_shape = (hidden_layers_num, batch_size, gru_hidden_size) - dummy_input, input_shape = model.get_dummy_input() - golden_output_batch = model.forward(dummy_input.to(device)).detach().cpu().numpy() - - dtype = "float32" - h_zeros = np.zeros(input_hidden_shape, dtype=dtype) - - tvm_output = None - for format in ["ts"]: # ["ts", "onnx"]: - if format == "ts": - # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing. - traced_script_module = torch.jit.trace(model, dummy_input).eval() - - # Import model to Relay - shape_list = [("input", input_shape)] - mod, params = relay.frontend.from_pytorch(traced_script_module, shape_list) - - # Model compilation by tvm - with tvm.transform.PassContext(opt_level=3): - lib = relay.build(mod, target=target, params=params) - elif format == "onnx": - onnx_io = io.BytesIO() - with torch.no_grad(): - h0 = torch.rand(input_hidden_shape) - input_names = ["input", "h0"] - - # default export (without dynamic input) - torch.onnx.export( - model, (dummy_input, h0), onnx_io, input_names=input_names - ) - onnx_io.seek(0, 0) - onnx_model = onnx.load_model(onnx_io) - - # Import model to Relay - shape_dict = { - "input": input_shape, - "h0": input_hidden_shape, - } - mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) - - # Model compilation by tvm - with tvm.transform.PassContext(opt_level=1): - lib = relay.build(mod, target=target, params=params) - - # Inference of the model with given input data - m = graph_executor.GraphModule(lib["default"](dev)) - - # Set inputs - m.set_input( - input=tvm.nd.array(dummy_input.numpy().astype(dtype)), - h0=tvm.nd.array(h_zeros), - ) - # Execute - m.run() - # Get outputs (converted to numpy array) - tvm_output = m.get_output(0).numpy() - - compare(tvm_output, golden_output_batch) +def check_rnn(rnn_type, rnn_mod, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0)): + def get_model( + rnn_type, + rnn_mod, + args, + ): + # Fill args + if "b" in rnn_mod: + args["bidirectional"] = True + if "s" in rnn_mod: + args["layer_num"] = num_layers + + if rnn_type == "GRU": + RNN_Model_selector = GRU_Model + elif rnn_type == "LSTM": + RNN_Model_selector = LSTM_Model + if "p" in rnn_mod: + args["proj_size"] = lstm_projection_size + + return RNN_Model_selector(**args) + + def get_onnx_model(model): + onnx_io = io.BytesIO() + with torch.no_grad(): + input_names = model.get_input_names() + inputs = model.get_dummy_inputs() + # default export (without dynamic input) + torch.onnx.export(model, inputs, onnx_io, input_names=input_names) -def check_lstm_with_type( - lstm_type, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0) -): - has_proj = "p" in lstm_type + onnx_io.seek(0, 0) + return onnx.load_model(onnx_io) - device = torch.device("cpu") - hidden_layers_num = 1 model = None + dtype = "float32" + device = torch.device("cpu") for batch_first in (True, False): for use_bias in (True, False): for rnd_weights in [True]: # (True, False): - if lstm_type == "uni": - model = LSTM_Model( - device, - batch_first=batch_first, - rnd_weights_init=rnd_weights, - use_bias=use_bias, - ) - elif lstm_type == "b": - model = LSTM_Model( - device, - batch_first=batch_first, - bidirectional=True, - rnd_weights_init=rnd_weights, - use_bias=use_bias, - ) - hidden_layers_num = 2 - elif lstm_type == "p": - model = LSTM_Model( - device, - batch_first=batch_first, - proj_size=projection_size, - rnd_weights_init=rnd_weights, - use_bias=use_bias, - ) - elif lstm_type == "s": - model = LSTM_Model( - device, - batch_first=batch_first, - layer_num=lstm_num_layers, - rnd_weights_init=rnd_weights, - use_bias=use_bias, - ) - hidden_layers_num = lstm_num_layers - elif lstm_type == "sb": - model = LSTM_Model( - device, - batch_first=batch_first, - bidirectional=True, - layer_num=lstm_num_layers, - rnd_weights_init=rnd_weights, - use_bias=use_bias, - ) - hidden_layers_num = 2 * lstm_num_layers - elif lstm_type == "sp": - model = LSTM_Model( - device, - batch_first=batch_first, - layer_num=lstm_num_layers, - proj_size=projection_size, - rnd_weights_init=rnd_weights, - use_bias=use_bias, - ) - hidden_layers_num = lstm_num_layers - elif lstm_type == "bp": - model = LSTM_Model( - device, - batch_first=batch_first, - bidirectional=True, - proj_size=projection_size, - rnd_weights_init=rnd_weights, - use_bias=use_bias, - ) - hidden_layers_num = 2 - elif lstm_type == "sbp": - model = LSTM_Model( - device, - batch_first=batch_first, - bidirectional=True, - layer_num=lstm_num_layers, - proj_size=projection_size, - rnd_weights_init=rnd_weights, - use_bias=use_bias, - ) - hidden_layers_num = 2 * lstm_num_layers - else: - print("WARNING: LSTM type {} is not supported here!".format(lstm_type)) - return - + model_inputs = { + "batch_first": batch_first, + "use_bias": use_bias, + "rnd_weights_init": rnd_weights, + } + model = get_model(rnn_type, rnn_mod, model_inputs) + model.to(device) model.eval() # Get golden output from original model - input_hidden_shape = (hidden_layers_num, batch_size, lstm_hidden_size) - input_hidden_shape_with_proj = (hidden_layers_num, batch_size, projection_size) - dummy_input, input_shape = model.get_dummy_input() - golden_output_batch = model.forward(dummy_input.to(device)).detach().cpu().numpy() - - dtype = "float32" - h_zeros = np.zeros(input_hidden_shape, dtype=dtype) - if has_proj: - h_zeros = np.zeros(input_hidden_shape_with_proj, dtype=dtype) - c_zeros = np.zeros(input_hidden_shape, dtype=dtype) + dummy_inputs = model.get_dummy_inputs() + golden_output = model.forward(dummy_inputs[0].to(device)).detach().cpu().numpy() tvm_output = None - for format in ["ts"]: # ["ts", "onnx"]: - if format == "ts": + for format in ["pt"]: # ["pt", "onnx"]: + shape_desc = model.get_shape_desc(format) + if format == "pt": # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing. - traced_script_module = torch.jit.trace(model, dummy_input).eval() + traced_script_module = torch.jit.trace(model, dummy_inputs[0]).eval() # Import model to Relay - shape_list = [("input", input_shape)] - mod, params = relay.frontend.from_pytorch(traced_script_module, shape_list) - - # Model compilation by tvm - with tvm.transform.PassContext(opt_level=3): - lib = relay.build(mod, target=target, params=params) + mod, params = relay.frontend.from_pytorch(traced_script_module, shape_desc) elif format == "onnx": - if has_proj: + try: + onnx_model = get_onnx_model(model) + except: print( "WARNING: torch.onnx.export does not support conversion LSTM with projection " "from pytorch! TODO: waiting for the support and correct test after that." ) continue - onnx_io = io.BytesIO() - with torch.no_grad(): - h0 = torch.rand(input_hidden_shape) - if has_proj: - h0 = torch.rand(input_hidden_shape_with_proj) - c0 = torch.rand(input_hidden_shape) - input_names = ["input", "h0", "c0"] - - # default export (without dynamic input) - torch.onnx.export( - model, (dummy_input, (h0, c0)), onnx_io, input_names=input_names - ) - onnx_io.seek(0, 0) - onnx_model = onnx.load_model(onnx_io) # Import model to Relay - shape_dict = { - "input": input_shape, - "h0": input_hidden_shape, - "c0": input_hidden_shape, - } - if has_proj: - shape_dict = { - "input": input_shape, - "h0": input_hidden_shape_with_proj, - "c0": input_hidden_shape, - } - mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) - - # Model compilation by tvm - with tvm.transform.PassContext(opt_level=1): - lib = relay.build(mod, target=target, params=params) + mod, params = relay.frontend.from_onnx(onnx_model, shape_desc) + + # Model compilation by tvm + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) # Inference of the model with given input data m = graph_executor.GraphModule(lib["default"](dev)) # Set inputs - m.set_input( - input=tvm.nd.array(dummy_input.numpy().astype(dtype)), - h0=tvm.nd.array(h_zeros), - c0=tvm.nd.array(c_zeros), - ) + tvm_inputs = model.get_tvm_inputs(dtype) + m.set_input(**tvm_inputs) # Execute m.run() # Get outputs (converted to numpy array) tvm_output = m.get_output(0).numpy() - compare(tvm_output, golden_output_batch) - - -@tvm.testing.uses_gpu -def test_grus(): - for target, dev in tvm.testing.enabled_targets(): - check_gru_with_type("uni", target, dev) - check_gru_with_type("s", target, dev) - check_gru_with_type("b", target, dev) - check_gru_with_type("sb", target, dev) + compare(tvm_output, golden_output) @tvm.testing.uses_gpu -def test_lstms(): +def test_rnns(): for target, dev in tvm.testing.enabled_targets(): - check_lstm_with_type("uni", target, dev) - # check_lstm_with_type("p", target, dev) - check_lstm_with_type("s", target, dev) - check_lstm_with_type("b", target, dev) - # check_lstm_with_type("bp", target, dev) - # check_lstm_with_type("sp", target, dev) - check_lstm_with_type("sb", target, dev) - # check_lstm_with_type("sbp", target, dev) + # RNN types: GRU, LSTM + # GRU modifications: unidirectional, stacked, bidirectional, stacked bidirectional + for mod_type in ["uni", "s", "b", "sb"]: + check_rnn("GRU", mod_type, target, dev) + # LSTM modifications: unidirectional, stacked, bidirectional, stacked bidirectional, + # and all these types with projection ("p", "sp", "bp", "sbp") + # The latter are skiped for test acceleration + for mod_type in ["uni", "s", "b", "sb"]: + check_rnn("LSTM", mod_type, target, dev) if __name__ == "__main__": - test_lstms() - test_grus() + test_rnns()