From 7bd43095a0807915b996b4584567d72041b093af Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Wed, 27 May 2020 20:26:40 +0530 Subject: [PATCH 1/3] [PYTORCH]Minor bug fixes --- python/tvm/relay/frontend/pytorch.py | 58 +++++++++++++++++++++------- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index cc7cd4830cd4..f68affd82726 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -34,6 +34,7 @@ from .common import get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value +from .common import infer_value_simulated as _infer_value_simulated from .common import infer_type as _infer_type from ..prelude import Prelude, StaticTensorArrayOps @@ -152,19 +153,33 @@ def _impl(inputs, input_types): def _arange(): def _impl(inputs, input_types): + def _get_value(val, dtype): + if isinstance(val, _expr.Expr): + return _op.cast(val, _convert_data_type(dtype)) + return _create_typed_const(val, dtype) + + def _get_type(val, inp_type): + if isinstance(val, _expr.Expr): + dtype = str(_infer_type(val).checked_type) + return dtype if dtype != "float32" else "float" + return inp_type + if len(inputs) == 5: - dtype = "float" if "float" in input_types[0:1] else _convert_dtype_value(inputs[1]) - start = _create_typed_const(0, dtype) - stop = _create_typed_const(inputs[0], dtype) - step = _create_typed_const(1, dtype) + dtype0 = _get_type(inputs[0], input_types[0]) + dtype = "float" if dtype0 == "float" else _convert_dtype_value(inputs[1]) + start = _get_value(0, dtype) + stop = _get_value(inputs[0], dtype) + step = _get_value(1, dtype) elif len(inputs) == 7: - dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3]) - start = _create_typed_const(inputs[0], dtype) - stop = _create_typed_const(inputs[1], dtype) - step = _create_typed_const(inputs[2], dtype) + types = [_get_type(inputs[i], input_types[i]) for i in range(3)] + dtype = "float" if "float" in types else _convert_dtype_value(inputs[3]) + start = _get_value(inputs[0], dtype) + stop = _get_value(inputs[1], dtype) + step = _get_value(inputs[2], dtype) else: msg = "Unknown number of arguments (%d) to parse." % (len(inputs)) raise AssertionError(msg) + return _op.transform.arange(start=start, stop=stop, step=step, @@ -235,12 +250,18 @@ def _impl(inputs, input_types): begin = [0] * len(end) dim = int(inputs[1]) - begin[dim] = int(inputs[2]) + if isinstance(inputs[2], _expr.Call): + begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int)) + else: + begin[dim] = int(inputs[2]) if isinstance(inputs[3], str) and inputs[3].isdigit(): end[dim] = min(end[dim], int(inputs[3])) else: - end[dim] = inputs[3] + if isinstance(inputs[3], _expr.Call): + end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int)) + else: + end[dim] = inputs[3] strides.append(int(inputs[4])) return _op.transform.strided_slice(data, begin, end, strides) @@ -997,7 +1018,10 @@ def _impl(inputs, input_types): def _numtotensor(): def _impl(inputs, input_types): val = inputs[0] - dtype = type(val) + dtype = input_types[0] + + if isinstance(val, _expr.Expr): + return val if isinstance(val, tvm.tir.IntImm): val = val.__int__() @@ -1019,16 +1043,22 @@ def _impl(inputs, input_types): data = inputs[0] if len(inputs) == 3: - new_shape = [inputs[1], _infer_shape(inputs[2])[0]] + shape_inp = [inputs[1], _infer_shape(inputs[2])[0]] else: if isinstance(inputs[1], list): - new_shape = inputs[1] + shape_inp = inputs[1] else: - new_shape = _infer_shape(inputs[1]) + shape_inp = _infer_shape(inputs[1]) + new_shape = shape_inp + for i, shape in enumerate(shape_inp): + if isinstance(shape, _expr.Expr): + val = _infer_value_simulated(shape, {}) + new_shape[i] = np.asscalar(val.asnumpy()) return _op.transform.reshape(data, new_shape) return _impl + def _reshape(): def _impl(inputs, input_types): data = inputs[0] From 4d9bb476b7c429549a0d692d1b5258a9daf1636c Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Thu, 28 May 2020 17:31:50 +0530 Subject: [PATCH 2/3] Review comment fix, testcase added --- tests/python/frontend/pytorch/test_forward.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 85928bfd60c2..af11b5ad4284 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -381,28 +381,61 @@ def test_forward_arange(): class Arange1(Module): def forward(self, *args): return torch.arange(5) + class Arange2(Module): def forward(self, *args): return torch.arange(2.5) + class Arange3(Module): def forward(self, *args): return torch.arange(1, 4) + class Arange4(Module): def forward(self, *args): return torch.arange(1, 2.5, 0.5) + class Arange5(Module): def forward(self, *args): return torch.arange(1, 2, 1, dtype=torch.int32) + class Arange6(Module): def forward(self, *args): return torch.arange(start=1, end=6, step=2) + class Arange7(Module): def forward(self, *args): return torch.arange(1, 4, dtype=torch.float32) + class Arange8(Module): def forward(self, *args): return torch.arange(1, 2, 1, dtype=torch.int16) + class Arange9(Module): + def forward(self, *args): + end = torch.add(torch.tensor(4), 1) + return torch.arange(end) + torch.ones((5,), dtype=torch.int64) + + class Arange10(Module): + def forward(self, *args): + end = torch.add(torch.tensor(4.0), torch.tensor(1.0)) + return torch.arange(end) + torch.ones((5,), dtype=torch.float) + + class Arange11(Module): + def forward(self, *args): + start = torch.add(torch.tensor(1), 1) + end = torch.add(torch.tensor(4), 1) + step = torch.add(torch.tensor(2), 1) + out = torch.arange(start, end, step) + return out + torch.ones((3,), dtype=torch.int64) + + class Arange12(Module): + def forward(self, *args): + start = torch.add(torch.tensor(1), 1) + end = torch.add(torch.tensor(4), 1) + step = torch.add(torch.tensor(2.5), torch.tensor(4.1)) + out = torch.arange(start, end, step) + return out + torch.ones((3,), dtype=torch.float) + verify_model(Arange1().float().eval()) verify_model(Arange2().float().eval()) verify_model(Arange3().float().eval()) @@ -411,6 +444,11 @@ def forward(self, *args): verify_model(Arange6().float().eval()) verify_model(Arange7().float().eval()) verify_model(Arange8().float().eval()) + verify_model(Arange9().float().eval()) + verify_model(Arange10().float().eval()) + verify_model(Arange11().float().eval()) + verify_model(Arange12().float().eval()) + def test_forward_abs(): torch.set_grad_enabled(False) @@ -810,9 +848,15 @@ class View2(Module): def forward(self, *args): return args[0].view(args[0].shape[0], -1) + class View3(Module): + def forward(self, *args): + d1 = torch.tensor(3) * torch.tensor(10) * torch.tensor(10) + return args[0].view(args[0].shape[0], d1) + input_data = torch.rand(input_shape).float() verify_model(View1().float().eval(), input_data=input_data) verify_model(View2().float().eval(), input_data=input_data) + verify_model(View3().float().eval(), input_data=input_data) def test_forward_select(): torch.set_grad_enabled(False) @@ -896,9 +940,17 @@ class Slice2(Module): def forward(self, *args): return args[0][0, :, :, :] + class Slice3(Module): + def forward(self, *args): + x0 = torch.tensor(2) - torch.tensor(1) + x1 = torch.tensor(3) + torch.tensor(1) + return args[0][:, x0:, :x1, :] + input_data = torch.rand(input_shape).float() verify_model(Slice1().float().eval(), input_data=input_data) verify_model(Slice2().float().eval(), input_data=input_data) + verify_model(Slice3().float().eval(), input_data=input_data) + def test_forward_mean(): torch.set_grad_enabled(False) From 06ac4709f2e5914d4639e4ada614fea8506af7a3 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Fri, 29 May 2020 09:33:59 +0530 Subject: [PATCH 3/3] Added testcase for bert model --- tests/python/frontend/pytorch/test_forward.py | 131 ++++++++++++++++++ 1 file changed, 131 insertions(+) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index af11b5ad4284..6159bb816ccf 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2209,6 +2209,134 @@ def forward(self, *args): verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2]) +def test_forward_pretrained_bert_base_uncased(): + ###################################################################### + # This is an example how to run BERT models using TVM + # --------------------------------------------------- + """ + Refer the bert example given in https://pypi.org/project/pytorch-pretrained-bert + + # To get started, pretrained bert package needs to be installed as prerequisite. + + .. code-block:: bash + + # install bert package + pip install pytorch_pretrained_bert==0.6.2 --user + """ + + try: + from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM + except: + print("Torch pretrained bert package must be installed to run this script.") + return + + ###################################################################### + # Load the tokenizer and tokenize the input + # ----------------------------------------- + + # Load pre-trained model tokenizer (vocabulary) + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + + # Tokenized input + text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]" + tokenized_text = tokenizer.tokenize(text) + + # Mask a token that we will try to predict back with `BertForMaskedLM` + masked_index = 8 + tokenized_text[masked_index] = '[MASK]' + assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', + '##eer', '[SEP]'] + + # Convert token to vocabulary indices + indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) + # Define sentence A and B indices associated to 1st and 2nd sentences (see paper) + segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] + + # Convert inputs to PyTorch tensors + tokens_tensor = torch.tensor([indexed_tokens]) + segments_tensors = torch.tensor([segments_ids]) + + ###################################################################### + # Load a pretrained PyTorch model bert-base-uncased + # ------------------------------------------------- + + # Bert Model with a language modeling + model = BertForMaskedLM.from_pretrained('bert-base-uncased') + model.eval() + + ###################################################################### + # Predict all tokens with pytorch + # ------------------------------- + + with torch.no_grad(): + torch_preds = model(tokens_tensor, segments_tensors) + + ###################################################################### + # Make TorchScripted model via jit trace + # -------------------------------------- + + scripted_model = torch.jit.trace(model, (tokens_tensor, segments_tensors)).eval() + + ###################################################################### + # Import the graph to Relay + # ------------------------- + # Convert PyTorch graph to Relay graph. The input name can be arbitrary. + + input_1 = 'input_ids' + input_2 = 'input.2' + shape_list = [(input_1, list(tokens_tensor.shape)), + (input_2, list(segments_tensors.shape))] + + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + + ###################################################################### + # Compile the model with relay + # ---------------------------- + + target = 'llvm' + with relay.build_config(opt_level=3): + relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params) + + ###################################################################### + # Execute on TVM + # -------------- + + ctx = tvm.context(target, 0) + relay_model = graph_runtime.create(relay_graph, relay_lib, ctx) + relay_model.set_input(**relay_params) + relay_model.set_input(input_1, tokens_tensor) + relay_model.set_input(input_2, segments_tensors) + relay_model.run() + compiled_output = relay_model.get_output(0).asnumpy() + + ###################################################################### + # Validate the outputs + # -------------------- + # Compare the torch and tvm outputs + + tvm.testing.assert_allclose(torch_preds, compiled_output, rtol=1e-3, atol=1e-3) + + ###################################################################### + # Process the output + # ------------------ + # Process the model output to token. + + # Torch output to token + torch_pred_idx = torch.argmax(torch_preds[0, masked_index]).item() + torch_pred_token = tokenizer.convert_ids_to_tokens([torch_pred_idx])[0] + + # TVM output to token + tvm_pred_idx = compiled_output[0, masked_index].argmax() + tvm_pred_token = tokenizer.convert_ids_to_tokens([tvm_pred_idx])[0] + + assert torch_pred_idx == tvm_pred_idx + assert torch_pred_token == tvm_pred_token + + # Print the outputs + print('Torch top-1 id: {}, token: {}'.format(torch_pred_idx, torch_pred_token)) + print('TVM top-1 id: {}, token: {}'.format(tvm_pred_idx, tvm_pred_token)) + + if __name__ == "__main__": # Single operator tests test_forward_add() @@ -2336,3 +2464,6 @@ def forward(self, *args): from lstm_test import custom_lstm_test custom_lstm_test() + + # Test bert model + test_forward_pretrained_bert_base_uncased()