Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PYTORCH]Minor bug fixes #5683

Merged
merged 3 commits into from
May 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
from .common import infer_type as _infer_type
from ..prelude import Prelude, StaticTensorArrayOps

Expand Down Expand Up @@ -152,19 +153,33 @@ def _impl(inputs, input_types):

def _arange():
def _impl(inputs, input_types):
def _get_value(val, dtype):
    """Return ``val`` expressed in ``dtype``.

    Values that are not ``_expr.Expr`` instances are wrapped with
    ``_create_typed_const``; expressions are cast to the data type
    obtained from ``_convert_data_type(dtype)``.
    """
    if not isinstance(val, _expr.Expr):
        return _create_typed_const(val, dtype)

    target_dtype = _convert_data_type(dtype)
    return _op.cast(val, target_dtype)

def _get_type(val, inp_type):
    """Return the type name to use for ``val``.

    For values that are not ``_expr.Expr`` instances the supplied
    ``inp_type`` is returned unchanged.  For expressions the string form
    of the inferred checked type is used, with ``"float32"`` reported
    as ``"float"``.
    """
    if not isinstance(val, _expr.Expr):
        return inp_type

    inferred = str(_infer_type(val).checked_type)
    if inferred == "float32":
        return "float"
    return inferred

if len(inputs) == 5:
dtype = "float" if "float" in input_types[0:1] else _convert_dtype_value(inputs[1])
start = _create_typed_const(0, dtype)
stop = _create_typed_const(inputs[0], dtype)
step = _create_typed_const(1, dtype)
dtype0 = _get_type(inputs[0], input_types[0])
dtype = "float" if dtype0 == "float" else _convert_dtype_value(inputs[1])
start = _get_value(0, dtype)
stop = _get_value(inputs[0], dtype)
step = _get_value(1, dtype)
elif len(inputs) == 7:
dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3])
start = _create_typed_const(inputs[0], dtype)
stop = _create_typed_const(inputs[1], dtype)
step = _create_typed_const(inputs[2], dtype)
types = [_get_type(inputs[i], input_types[i]) for i in range(3)]
dtype = "float" if "float" in types else _convert_dtype_value(inputs[3])
start = _get_value(inputs[0], dtype)
stop = _get_value(inputs[1], dtype)
step = _get_value(inputs[2], dtype)
else:
msg = "Unknown number of arguments (%d) to parse." % (len(inputs))
raise AssertionError(msg)

return _op.transform.arange(start=start,
stop=stop,
step=step,
Expand Down Expand Up @@ -235,12 +250,18 @@ def _impl(inputs, input_types):

begin = [0] * len(end)
dim = int(inputs[1])
begin[dim] = int(inputs[2])
if isinstance(inputs[2], _expr.Call):
begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
else:
begin[dim] = int(inputs[2])

if isinstance(inputs[3], str) and inputs[3].isdigit():
end[dim] = min(end[dim], int(inputs[3]))
else:
end[dim] = inputs[3]
if isinstance(inputs[3], _expr.Call):
end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
else:
end[dim] = inputs[3]

strides.append(int(inputs[4]))
return _op.transform.strided_slice(data, begin, end, strides)
Expand Down Expand Up @@ -997,7 +1018,10 @@ def _impl(inputs, input_types):
def _numtotensor():
def _impl(inputs, input_types):
val = inputs[0]
dtype = type(val)
dtype = input_types[0]

if isinstance(val, _expr.Expr):
return val

if isinstance(val, tvm.tir.IntImm):
val = val.__int__()
Expand All @@ -1019,16 +1043,22 @@ def _impl(inputs, input_types):
data = inputs[0]

if len(inputs) == 3:
new_shape = [inputs[1], _infer_shape(inputs[2])[0]]
shape_inp = [inputs[1], _infer_shape(inputs[2])[0]]
else:
if isinstance(inputs[1], list):
new_shape = inputs[1]
shape_inp = inputs[1]
else:
new_shape = _infer_shape(inputs[1])
shape_inp = _infer_shape(inputs[1])
new_shape = shape_inp
for i, shape in enumerate(shape_inp):
if isinstance(shape, _expr.Expr):
val = _infer_value_simulated(shape, {})
new_shape[i] = np.asscalar(val.asnumpy())

return _op.transform.reshape(data, new_shape)
return _impl


def _reshape():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down
183 changes: 183 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,28 +381,61 @@ def test_forward_arange():
class Arange1(Module):
    # arange called with a single positional integer argument (end only).
    def forward(self, *args):
        return torch.arange(5)

class Arange2(Module):
    # arange called with a single positional float argument (end only).
    def forward(self, *args):
        return torch.arange(2.5)

class Arange3(Module):
    # arange called with integer start and end, no explicit step.
    def forward(self, *args):
        return torch.arange(1, 4)

class Arange4(Module):
    # arange called with an integer start and float end/step.
    def forward(self, *args):
        return torch.arange(1, 2.5, 0.5)

class Arange5(Module):
    # arange called with start, end, step and an explicit int32 dtype.
    def forward(self, *args):
        return torch.arange(1, 2, 1, dtype=torch.int32)

class Arange6(Module):
    # arange called with start, end and step passed as keyword arguments.
    def forward(self, *args):
        return torch.arange(start=1, end=6, step=2)

class Arange7(Module):
    # arange called with integer start/end but an explicit float32 dtype.
    def forward(self, *args):
        return torch.arange(1, 4, dtype=torch.float32)

class Arange8(Module):
    # arange called with start, end, step and an explicit int16 dtype.
    def forward(self, *args):
        return torch.arange(1, 2, 1, dtype=torch.int16)

class Arange9(Module):
    # The end argument is computed with torch.add on an integer tensor
    # instead of being written as a Python literal.
    # NOTE(review): presumably this makes the converter receive a Relay
    # expression rather than a constant -- confirm against the frontend.
    def forward(self, *args):
        end = torch.add(torch.tensor(4), 1)
        # The range is added to an int64 ones tensor of length 5.
        return torch.arange(end) + torch.ones((5,), dtype=torch.int64)

class Arange10(Module):
    # Float counterpart of the computed-end case: the end argument is the
    # sum of two float tensors.
    def forward(self, *args):
        end = torch.add(torch.tensor(4.0), torch.tensor(1.0))
        # The range is added to a float ones tensor of length 5.
        return torch.arange(end) + torch.ones((5,), dtype=torch.float)

class Arange11(Module):
    # start, end and step are all computed with torch.add on integer
    # tensors instead of being written as Python literals.
    def forward(self, *args):
        start = torch.add(torch.tensor(1), 1)
        end = torch.add(torch.tensor(4), 1)
        step = torch.add(torch.tensor(2), 1)
        out = torch.arange(start, end, step)
        # NOTE(review): with start=2, end=5, step=3 the range appears to hold
        # a single element, so this addition relies on broadcasting against
        # the length-3 ones tensor -- confirm this is intended.
        return out + torch.ones((3,), dtype=torch.int64)

class Arange12(Module):
    # start and end are computed from integer tensors while step is the
    # sum of two float tensors, mixing integer and float arguments.
    def forward(self, *args):
        start = torch.add(torch.tensor(1), 1)
        end = torch.add(torch.tensor(4), 1)
        step = torch.add(torch.tensor(2.5), torch.tensor(4.1))
        out = torch.arange(start, end, step)
        # NOTE(review): with start=2, end=5, step~=6.6 the range appears to
        # hold a single element, so this addition relies on broadcasting
        # against the length-3 ones tensor -- confirm this is intended.
        return out + torch.ones((3,), dtype=torch.float)

verify_model(Arange1().float().eval())
verify_model(Arange2().float().eval())
verify_model(Arange3().float().eval())
Expand All @@ -411,6 +444,11 @@ def forward(self, *args):
verify_model(Arange6().float().eval())
verify_model(Arange7().float().eval())
verify_model(Arange8().float().eval())
verify_model(Arange9().float().eval())
verify_model(Arange10().float().eval())
verify_model(Arange11().float().eval())
verify_model(Arange12().float().eval())


def test_forward_abs():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -810,9 +848,15 @@ class View2(Module):
def forward(self, *args):
return args[0].view(args[0].shape[0], -1)

class View3(Module):
    # The second view dimension is the product of three integer tensors
    # (3 * 10 * 10) instead of a Python literal or -1.
    def forward(self, *args):
        d1 = torch.tensor(3) * torch.tensor(10) * torch.tensor(10)
        return args[0].view(args[0].shape[0], d1)

input_data = torch.rand(input_shape).float()
verify_model(View1().float().eval(), input_data=input_data)
verify_model(View2().float().eval(), input_data=input_data)
verify_model(View3().float().eval(), input_data=input_data)

def test_forward_select():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -896,9 +940,17 @@ class Slice2(Module):
def forward(self, *args):
return args[0][0, :, :, :]

class Slice3(Module):
    # Slice bounds are computed from tensor arithmetic instead of being
    # written as Python literals.
    def forward(self, *args):
        x0 = torch.tensor(2) - torch.tensor(1)  # begin bound for dim 1
        x1 = torch.tensor(3) + torch.tensor(1)  # end bound for dim 2
        return args[0][:, x0:, :x1, :]

input_data = torch.rand(input_shape).float()
verify_model(Slice1().float().eval(), input_data=input_data)
verify_model(Slice2().float().eval(), input_data=input_data)
verify_model(Slice3().float().eval(), input_data=input_data)


def test_forward_mean():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -2157,6 +2209,134 @@ def forward(self, *args):
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])


def test_forward_pretrained_bert_base_uncased():
    """Run a pretrained BERT masked-LM model through the PyTorch frontend.

    Follows the BERT example given at
    https://pypi.org/project/pytorch-pretrained-bert

    The ``pytorch_pretrained_bert`` package is a prerequisite.  When it
    cannot be imported, a notice is printed and the test returns without
    running anything else.

    .. code-block:: bash

        # install bert package
        pip install pytorch_pretrained_bert==0.6.2 --user
    """
    ######################################################################
    # This is an example of how to run BERT models using TVM
    # ------------------------------------------------------

    try:
        from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM
    except ImportError:
        # Only a missing/unimportable package should skip the test; any other
        # exception (including KeyboardInterrupt) must propagate.
        print("Torch pretrained bert package must be installed to run this script.")
        return

    ######################################################################
    # Load the tokenizer and tokenize the input
    # -----------------------------------------

    # Load pre-trained model tokenizer (vocabulary)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Tokenized input
    text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
    tokenized_text = tokenizer.tokenize(text)

    # Mask a token that we will try to predict back with `BertForMaskedLM`
    masked_index = 8
    tokenized_text[masked_index] = '[MASK]'
    assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet',
                              '##eer', '[SEP]']

    # Convert token to vocabulary indices
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
    segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    ######################################################################
    # Load a pretrained PyTorch model bert-base-uncased
    # -------------------------------------------------

    # Bert Model with a language modeling
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    model.eval()

    ######################################################################
    # Predict all tokens with pytorch
    # -------------------------------

    with torch.no_grad():
        torch_preds = model(tokens_tensor, segments_tensors)

    ######################################################################
    # Make TorchScripted model via jit trace
    # --------------------------------------

    scripted_model = torch.jit.trace(model, (tokens_tensor, segments_tensors)).eval()

    ######################################################################
    # Import the graph to Relay
    # -------------------------
    # Convert PyTorch graph to Relay graph. The input name can be arbitrary.

    input_1 = 'input_ids'
    input_2 = 'input.2'
    shape_list = [(input_1, list(tokens_tensor.shape)),
                  (input_2, list(segments_tensors.shape))]

    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

    ######################################################################
    # Compile the model with relay
    # ----------------------------

    target = 'llvm'
    with relay.build_config(opt_level=3):
        relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params)

    ######################################################################
    # Execute on TVM
    # --------------

    ctx = tvm.context(target, 0)
    relay_model = graph_runtime.create(relay_graph, relay_lib, ctx)
    relay_model.set_input(**relay_params)
    relay_model.set_input(input_1, tokens_tensor)
    relay_model.set_input(input_2, segments_tensors)
    relay_model.run()
    compiled_output = relay_model.get_output(0).asnumpy()

    ######################################################################
    # Validate the outputs
    # --------------------
    # Compare the torch and tvm outputs

    tvm.testing.assert_allclose(torch_preds, compiled_output, rtol=1e-3, atol=1e-3)

    ######################################################################
    # Process the output
    # ------------------
    # Process the model output to token.

    # Torch output to token
    torch_pred_idx = torch.argmax(torch_preds[0, masked_index]).item()
    torch_pred_token = tokenizer.convert_ids_to_tokens([torch_pred_idx])[0]

    # TVM output to token
    tvm_pred_idx = compiled_output[0, masked_index].argmax()
    tvm_pred_token = tokenizer.convert_ids_to_tokens([tvm_pred_idx])[0]

    assert torch_pred_idx == tvm_pred_idx
    assert torch_pred_token == tvm_pred_token

    # Print the outputs
    print('Torch top-1 id: {}, token: {}'.format(torch_pred_idx, torch_pred_token))
    print('TVM top-1 id: {}, token: {}'.format(tvm_pred_idx, tvm_pred_token))


if __name__ == "__main__":
# Single operator tests
test_forward_add()
Expand Down Expand Up @@ -2284,3 +2464,6 @@ def forward(self, *args):
from lstm_test import custom_lstm_test

custom_lstm_test()

# Test bert model
test_forward_pretrained_bert_base_uncased()