PyTorch frontend: fix handling of duplicate use of a model weight
This happens e.g. in shared input/output embeddings in BERT
or in Siamese networks.

Thank you @siju-samuel for reporting.
t-vi committed Jun 23, 2020
1 parent b94e8b7 commit c7399e0
Showing 2 changed files with 27 additions and 4 deletions.
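For context, the duplicate-use pattern the commit message describes looks roughly like this in PyTorch. This is an illustrative sketch, not code from the commit; the module and names are made up:

    import torch

    class TiedEmbeddings(torch.nn.Module):
        # Input embedding and output projection share one weight tensor,
        # so a traced graph reaches the same state_dict entry
        # ("embed.weight") through two prim::GetAttr chains.
        def __init__(self, vocab=10, dim=4):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab, dim)

        def forward(self, ids):
            h = self.embed(ids)               # first use of embed.weight
            return h @ self.embed.weight.t()  # second use, as output projection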
13 changes: 9 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
@@ -2335,6 +2335,7 @@ def convert_params(graph, state_dict):
     params = {}
     param_tensors = {}
     packed_param_map = {}
+    vars = {}
     seen = set()
 
     for node in getattr_nodes:
@@ -2352,10 +2353,14 @@ def convert_params(graph, state_dict):
                 assert full_attr in state_dict, err_msg
                 packed_param_map[full_attr_node_name] = full_attr
             elif full_attr in state_dict:
-                torch_tensor = state_dict[full_attr]
-                tensor, var = _get_tensor_and_var(torch_tensor,
-                                                  full_attr)
-                param_tensors[full_attr] = tensor
+                if full_attr in vars:
+                    var = vars[full_attr]
+                else:
+                    torch_tensor = state_dict[full_attr]
+                    tensor, var = _get_tensor_and_var(torch_tensor,
+                                                      full_attr)
+                    param_tensors[full_attr] = tensor
+                    vars[full_attr] = var
                 params[full_attr_node_name] = var
 
     return params, param_tensors, packed_param_map
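The fix caches the Relay var created for each fully-qualified attribute name, so every GetAttr chain that resolves to the same weight maps to one var instead of minting a new one per use. A minimal standalone sketch of that pattern (plain Python, not TVM API; object() stands in for relay.var):

    attr_uses = ["lin.weight", "lin.bias", "lin.weight"]  # lin.weight read twice

    vars_cache = {}
    params = {}
    for i, full_attr in enumerate(attr_uses):
        node_name = "node_%d" % i        # stands in for full_attr_node_name
        if full_attr in vars_cache:
            var = vars_cache[full_attr]  # second use: reuse the existing var
        else:
            var = object()               # first use: create the var once
            vars_cache[full_attr] = var
        params[node_name] = var

    assert params["node_0"] is params["node_2"]  # both uses share one var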
18 changes: 18 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -2390,6 +2390,23 @@ def test_weight_names():
     assert set(params.keys()) == set(n for n, p in tm.named_parameters())
 
 
+def test_duplicate_weight_use():
+    # The test case doesn't make sense as a neural network;
+    # the issue popped up in shared input/output embeddings of BERT,
+    # but this is quicker to exercise.
+    class Test(Module):
+        def __init__(self):
+            super().__init__()
+            self.lin = torch.nn.Linear(5, 3)
+
+        def forward(self, x):
+            x = self.lin(x)
+            x = x @ self.lin.weight
+            return x
+
+    verify_model(Test(), input_data=[torch.randn(5, 5)])
+
+
 def test_forward_matmul():
     torch.set_grad_enabled(False)
 
@@ -2556,6 +2573,7 @@ def test_forward_pretrained_bert_base_uncased():
     test_forward_traced_function()
     test_forward_dtypes()
     test_weight_names()
+    test_duplicate_weight_use()
 
     # Single operator tests
     test_forward_add()
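Outside the test harness, the fix can be checked by converting such a module directly. A hedged repro sketch, assuming the 2020-era relay.frontend.from_pytorch(script_module, input_shapes) signature and a module like Test above in scope:

    import torch
    from tvm import relay

    model = Test().eval()     # the Test module from the test above
    inp = torch.randn(5, 5)
    traced = torch.jit.trace(model, inp)
    mod, params = relay.frontend.from_pytorch(traced, [("input", (5, 5))])
    # With this commit, "lin.weight" appears once in params even though
    # the traced graph reads it via two prim::GetAttr chains.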
