
keep parameter names from PyTorch #5887

Merged 1 commit into apache:master on Jun 22, 2020
Conversation

t-vi (Contributor) commented Jun 22, 2020

This patch uses the PyTorch parameter names as the name hints of the converted variables.
This means that one can load a stored PyTorch state dict and map the required inputs from it even after conversion.
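
For illustration, a minimal sketch of what this enables, using a hypothetical toy module (the exact key names are illustrative, not taken from the patch):

    # Sketch: with this patch, the params returned by relay.frontend.from_pytorch
    # are keyed by the PyTorch parameter names (e.g. "fc.weight" / "fc.bias"),
    # so entries from a saved state dict line up by name. Toy model is hypothetical.
    import torch
    import tvm
    from tvm import relay

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    scripted = torch.jit.trace(Toy().eval(), torch.randn(1, 4))
    mod, params = relay.frontend.from_pytorch(scripted, [("input", (1, 4))])
    print(sorted(params.keys()))  # expected to match the torch state-dict names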

masahi merged commit 3637164 into apache:master on Jun 22, 2020
masahi (Member) commented Jun 22, 2020

Thanks @t-vi

siju-samuel (Member) commented:

@t-vi Thanks for the PR.
I think this PR has a slight impact in some scenarios.
You can confirm by running the script below before and after your change.

pytorch_pretrained_bert_uncased.py

I think that after this PR, param_tensor and param are not getting assigned properly in some cases.
Please look into this.
Thanks in advance.
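
For context, a hedged sketch of the convert-and-compare pattern such a script runs (it does not reproduce pytorch_pretrained_bert_uncased.py; the model, input names, and dtypes here are assumptions):

    # Hypothetical convert-run-compare check; the final assertion is the
    # one quoted in the traceback further down in this thread.
    import torch
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor  # named graph_runtime in 2020-era TVM

    def check_against_pytorch(model, example_input):
        scripted = torch.jit.trace(model.eval(), example_input)
        torch_preds = scripted(example_input).detach().numpy()

        mod, params = relay.frontend.from_pytorch(
            scripted, [("input", list(example_input.shape))])
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build(mod, target="llvm", params=params)

        rt = graph_executor.GraphModule(lib["default"](tvm.cpu()))
        rt.set_input("input", example_input.numpy())
        rt.run()
        compiled_output = rt.get_output(0).numpy()

        tvm.testing.assert_allclose(torch_preds, compiled_output, rtol=1e-3, atol=1e-3)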

t-vi (Contributor, Author) commented Jun 23, 2020

What is the error you are seeing?

siju-samuel (Member) commented:

Output mismatch between PyTorch and TVM.

After this PR:

  File "pytorch_pretrained_bert_uncased.py", line 112, in <module>
    tvm.testing.assert_allclose(torch_preds, compiled_output, rtol=1e-3, atol=1e-3)
  File "/home/siju/workspace/tvm/python/tvm/testing.py", line 36, in assert_allclose
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)
  File "/home/siju/.local/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 1532, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/home/siju/.local/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 846, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=0.001, atol=0.001

Mismatched elements: 427297 / 427308 (100%)
Max absolute difference: 24.692068
Max relative difference: 196537.89
 x: array([[[ -7.879808,  -7.787371,  -7.786093, ...,  -7.043789,
          -6.745376,  -4.60134 ],
        [-13.363304, -13.769426, -13.781861, ..., -11.81282 ,...
 y: array([[[-0.419126, -0.420205, -0.41907 , ..., -0.789973, -0.782199,
         -0.496477],
        [-0.419126, -0.420205, -0.41907 , ..., -0.789973, -0.782199,...

t-vi (Contributor, Author) commented Jun 23, 2020

I found the problem and will send a fix.
The old code included the embedding weight twice (which arguably is a bug in itself and needs fixing), and the new code deduplicated the param but not the var (which is even worse).
We underappreciate the structure of the PyTorch input...
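
For illustration, a toy sketch of the kind of shared parameter involved (weight tying, as between BERT's embedding and its output projection; this is not the BERT model itself):

    # One tensor reachable under two state-dict keys: a converter must
    # deduplicate the parameter *and* the variable consistently, or the
    # graph ends up with a dangling input.
    import torch

    class Tied(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(10, 4)
            self.out = torch.nn.Linear(4, 10, bias=False)
            self.out.weight = self.emb.weight  # share the storage

        def forward(self, x):
            return self.out(self.emb(x))

    sd = Tied().state_dict()
    print(list(sd.keys()))  # ['emb.weight', 'out.weight']
    print(sd["emb.weight"].data_ptr() == sd["out.weight"].data_ptr())  # True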

t-vi (Contributor, Author) commented Jun 23, 2020

#5897 has the fix.

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jun 30, 2020
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Jul 2, 2020