
Feature Request: Add support for ScatterElements operator during compilation #531

Open
fabecode opened this issue Mar 7, 2024 · 18 comments


@fabecode

fabecode commented Mar 7, 2024

Summary

I have implemented a GNN model, quantised it using Brevitas, and compiled it using 'compile_brevitas_qat_model'.
During compilation, I get the following error: "ValueError: The following ONNX operators are required to convert the torch model to numpy but are not currently implemented: ScatterElements."

Would it be possible to add ScatterElements as a supported operator, or are there any suggested workarounds? Thank you!

@bcm-at-zama
Collaborator

Hello, the team is remote for the end of the week, so we'll answer next week. Would it be possible for you to share your code, to simplify reproducing the error? Or, at least, could you show us your model with Brevitas layers?

@fabecode
Author

fabecode commented Mar 7, 2024

Hello, I have just emailed the code of the GNN model with Brevitas layers to hello@zama.ai. I would be grateful if you could take a look. Thank you!

@bcm-at-zama
Collaborator

Great, we'll have a look! As I said, it will be next week, so you'll have to be a bit patient :)

@andrei-stoian-zama
Collaborator

Hi! Thanks for sending the code! Great job converting your network to Brevitas. Here are a few answers:

  • While the code gives us good information, reproducible code that shows the inputs (x, edge_index) that you pass to your model would help. It would allow us to inspect the ONNX graph that is generated. Can you add that information, please?
  • We have not yet implemented ScatterElements, which I guess comes from x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu(). It would take some time to implement, and it would only work if that call indexes all of the elements in x. Is that the case?
  • You also concatenate two tensors here: x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1). This will probably fail with quantization, since the two concatenated tensors will not have matching quantization parameters. This can be fixed by instantiating a QuantIdentity layer in your model's __init__ and applying it to both inputs before they are concatenated.
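A plain-NumPy illustration of the concatenation point (this is not Brevitas or Concrete ML code; the helper names, the 8-bit symmetric scheme and the way the shared scale is chosen are assumptions made for the sketch): two tensors quantized with their own scales produce integers that one scale cannot map back, while a shared quantizer, which is what applying a single QuantIdentity to both inputs aims for, keeps them on one integer grid.

```python
import numpy as np

N_BITS = 8
Q_MAX = 2 ** (N_BITS - 1) - 1  # 127: largest level of a signed 8-bit grid


def scale_for(t):
    """Per-tensor symmetric scale mapping max |t| onto the largest integer level."""
    return np.abs(t).max() / Q_MAX


def quantize(t, scale):
    """Round onto the integer grid defined by `scale`, clipped to the signed range."""
    return np.clip(np.round(t / scale), -Q_MAX, Q_MAX).astype(np.int64)


# Toy stand-ins for x and edge_attr with very different value ranges (hypothetical data)
x = np.array([[0.5, -1.0], [0.25, 0.75]])
edge_attr = np.array([[10.0], [-20.0]])
reference = np.concatenate([x, edge_attr], axis=1)

# Separate quantizers: each input gets its own scale, so the concatenated integers
# mean different things column by column; dequantizing with x's scale breaks edge_attr
s_x, s_e = scale_for(x), scale_for(edge_attr)
q_separate = np.concatenate([quantize(x, s_x), quantize(edge_attr, s_e)], axis=1)
err_separate = np.abs(q_separate * s_x - reference).max()

# Shared quantizer: one scale for both inputs (here simply the larger of the two),
# so a single scale dequantizes the whole concatenated tensor
s_shared = max(s_x, s_e)
q_shared = np.concatenate([quantize(x, s_shared), quantize(edge_attr, s_shared)], axis=1)
err_shared = np.abs(q_shared * s_shared - reference).max()

print(f"max error, separate scales: {err_separate:.3f}")
print(f"max error, shared scale:    {err_shared:.3f}")
```

The shared scale costs precision on the tensor with the smaller range: here the values of x land on only a handful of integer levels. How Brevitas actually calibrates the scale of a QuantIdentity is not modeled in this sketch.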


@bcm-at-zama
Collaborator

And of course @fabecode, for the information @andrei-stoian-zama is asking for, you can send it to my email address if you prefer to keep things confidential. Just make sure it opens correctly; we've had issues before.

@fabecode
Author

fabecode commented Mar 9, 2024

Hi Andrei and Benoit, thank you for the prompt reply! Here are my responses to the questions:

  • I have emailed the link to my code base, as well as the ONNX file of my quantised model, to @bcm-at-zama.
  • I believe so. Before the reshape, here are my input shapes - x: torch.Size([9142, 66]) and edge_attr: torch.Size([5851, 66])
  • Thank you for the advice! Could you kindly advise where to insert the QuantIdentity layer? When I tried to add QuantIdentity, I faced the "The quantizer passed does not adhere to the quantization protocol" error.

@fabecode
Author

Hello @andrei-stoian-zama @bcm-at-zama, I would also like to ask whether there is an anticipated timeline for implementing ScatterElements, and whether there is any way I could contribute to the implementation. Thank you!

@andrei-stoian-zama
Collaborator

Hello again, and thank you for the code!

Unfortunately I can't give a deadline for this. It would take a bit of work in both Concrete ML and Concrete.

@fabecode
Author

I see, thanks again!

@bcm-at-zama
Collaborator

Yes @fabecode, sorry, but it's certainly a complicated task for an external contributor. We should have it in 2024, hopefully!

If you want to contribute, may we encourage you to participate in our bounties? https://github.com/zama-ai/bounty-and-grant-program

@fabecode
Author

Thank you, will be happy to participate in the bounties when I have the time!

Actually, the GNN project I have embarked on is part of my undergraduate final thesis, which is due very soon, and my university has just given me the green light to apply for the Zama Grant Program. The final step is compiling the model to be TFHE-compatible, so I would be really grateful for any suggestions of temporary workarounds for the ScatterElements error that I can try out 🙏

@andrei-stoian-zama
Collaborator

andrei-stoian-zama commented Mar 13, 2024

For workarounds: is the indexing matrix a constant? I think so; can you confirm?

x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()

Is edge_index a constant? Does it change at any time during training?

@fabecode
Author

Yes, I just checked, edge_index is a constant!

@andrei-stoian-zama
Collaborator

andrei-stoian-zama commented Mar 13, 2024

Ok, and edge_index seems to be a list of values. If it's the same length as x, you could permute x with a matrix multiplication:

```python
>>> import numpy as np
>>> x = np.random.randint(64, size=(32,))
>>> x
array([46,  6, 20, 44, 37,  3, 57, 60, 48, 34,  9,  1, 46, 44, 25, 42, 11,
       23,  5, 27, 48, 37,  7,  8, 17, 18, 38, 13, 18, 52,  1, 55])
>>> perm = np.random.permutation(32)
>>> perm
array([18,  9, 20,  3, 23, 31,  2, 19,  7, 12,  4, 16, 26, 17, 15,  6, 24,
       21, 30, 25,  8, 28, 11, 10,  5, 29, 13, 22,  0, 27,  1, 14])
>>> perm_matrix = np.zeros((32,32),dtype=np.int32)
>>> perm_matrix[perm,np.arange(32)] = 1
>>> x[perm]
array([ 5, 34, 48, 44,  8, 55, 20, 27, 60, 46, 37, 11, 38, 23, 42, 57, 17,
       37,  1, 18, 48, 18,  1,  9,  3, 52, 44,  7, 46, 13,  6, 25])
>>> x @ perm_matrix
array([ 5, 34, 48, 44,  8, 55, 20, 27, 60, 46, 37, 11, 38, 23, 42, 57, 17,
       37,  1, 18, 48, 18,  1,  9,  3, 52, 44,  7, 46, 13,  6, 25])
```

Note that using such a constant matrix in your program won't work directly, as the matrix will be quantized.

I think you should create a ScatterElements operator in Concrete ML, and in its quantized computation definition you could use such a permutation matrix (provided the scattering is a 1D permutation).
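To make the "1D permutation" condition concrete, here is a plain-NumPy sketch of what ScatterElements computes for 1-D inputs (following the ONNX operator definition, ignoring its optional reduction attribute) next to the equivalent product with a constant 0/1 matrix. The function names are hypothetical and this is not Concrete ML code; it only checks the arithmetic.

```python
import numpy as np


def scatter_elements_1d(data, indices, updates):
    """Reference 1-D semantics of ONNX ScatterElements with no reduction:
    start from a copy of `data`, then write updates[i] to position indices[i]."""
    out = np.array(data, copy=True)
    for i, j in enumerate(indices):
        out[j] = updates[i]
    return out


def scatter_as_matmul(updates, indices):
    """Same result as a product with a constant 0/1 matrix. Only valid when
    `indices` is a permutation of range(n), so every position is overwritten once
    and the original `data` values no longer matter."""
    n = len(indices)
    m = np.zeros((n, n), dtype=np.int64)
    m[np.arange(n), indices] = 1  # row i routes updates[i] to column indices[i]
    return updates @ m


rng = np.random.default_rng(0)
perm = rng.permutation(8)
updates = rng.integers(64, size=8)
data = np.zeros(8, dtype=np.int64)

scattered = scatter_elements_1d(data, perm, updates)
via_matmul = scatter_as_matmul(updates, perm)
print(scattered, via_matmul)
```

The scatter matrix is the transpose of the gather matrix in the session above (perm_matrix[perm, np.arange(32)] = 1 gathers, m[np.arange(n), indices] = 1 scatters), so scattering x[perm] back with the same indices returns x. The quantization caveat about constant matrices applies equally here.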

@fabecode
Author

Thank you so much @andrei-stoian-zama for the suggestions, will try it out!

@fabecode
Author

fabecode commented Mar 13, 2024

After further analysis, below are my findings:

  • edge_index does not have the same length as x.
  • In x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu() line:
    • Tensor shape of edge_index: [2, num_edges]
      • where 2 represents source + destination nodes
    • Tensor shape of x before reshape: [num_nodes, num_features]
    • So, tensor shape of x[edge_index.T]: [num_edges, 2, num_features]
      • where x[edge_index.T] is a tensor where each row contains the features of the source and dest nodes in each edge
    • Tensor shape of x[edge_index.T].reshape(-1, 2*self.n_hidden): [num_edges, 2*num_features]
    • Tensor shape of x after relu: [num_edges, 2*num_features]
  • In x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1) line:
    • Tensor shape of edge_attr: [num_edges, num_features]
    • Tensor shape of (edge_attr.view(-1, edge_attr.shape[1])): [num_edges, num_features]
    • Tensor shape of (torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)): [num_edges, 3*num_features]

Also, looking at the ONNX graph, could it be that ScatterElements occurs in the torch.cat line rather than the reshape line? (You can view the full graph in the ONNX file I sent in the earlier email.)

Could you kindly advise? Thank you once again!
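The shape analysis above shows that x[edge_index.T] is a row gather (each edge picks a source row and a destination row of x) rather than a 1-D permutation. A gather can still be written with constant one-hot matrices, even when edge_index is not the same length as x. A plain-NumPy sketch with toy sizes; the helper name, sizes and random data are hypothetical:

```python
import numpy as np


def one_hot_rows(index, num_nodes):
    """Selection matrix S with S[e, index[e]] = 1, so S @ x == x[index]."""
    s = np.zeros((len(index), num_nodes), dtype=np.int64)
    s[np.arange(len(index)), index] = 1
    return s


num_nodes, num_features, num_edges = 5, 3, 4  # toy sizes (hypothetical)
rng = np.random.default_rng(0)
x = rng.integers(-8, 8, size=(num_nodes, num_features))
edge_index = rng.integers(0, num_nodes, size=(2, num_edges))  # row 0: source, row 1: destination

# What the model does: x[edge_index.T] has shape [num_edges, 2, num_features]
gathered = x[edge_index.T].reshape(-1, 2 * num_features)

# Same result with two constant selection matrices and a concatenation
src = one_hot_rows(edge_index[0], num_nodes) @ x  # [num_edges, num_features]
dst = one_hot_rows(edge_index[1], num_nodes) @ x  # [num_edges, num_features]
as_matmul = np.concatenate([src, dst], axis=1)    # [num_edges, 2 * num_features]
```

This only demonstrates the equivalence; it is not a tested workaround. With the sizes reported earlier (9142 nodes, 5851 edges), each selection matrix would have about 53.5 million entries, and the earlier caveat that constant matrices get quantized still applies.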

@andrei-stoian-zama
Collaborator

I'm afraid I don't know how to help more on this.

As a first step you should try to determine which torch operator produces ScatterElements. I don't think it's the concatenation though.

@fabecode
Author

No worries, I really appreciate all your advice thus far! I will do my best to find a temporary alternative to ScatterElements.
