Fixes issue 403 for select functions: add, sub, mul (#413)
Types of changes
Motivation and Context / Related issue
At present, it is not possible to do things like `torch_tensor.add(cryptensor)` or `torch_tensor + cryptensor`. The problem is that reverse functions like `__radd__` never get called, because `torch.Tensor.add` fails with a `TypeError` rather than returning `NotImplemented` (which would trigger the reverse function to get called). This limitation leads to issues such as #403.
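For context, here is the standard Python protocol in a self-contained sketch (the class names are illustrative, not CrypTen code): a reverse method such as `__radd__` only runs if the left operand's `__add__` returns the `NotImplemented` singleton; raising a `TypeError` instead short-circuits the fallback.

```python
class RaisesTypeError:
    def __add__(self, other):
        raise TypeError("unsupported operand")  # mimics torch.Tensor.add

class ReturnsNotImplemented:
    def __add__(self, other):
        return NotImplemented  # lets Python try other.__radd__

class Encrypted:
    def __radd__(self, other):
        return "reverse add called"

print(ReturnsNotImplemented() + Encrypted())  # -> reverse add called

try:
    RaisesTypeError() + Encrypted()           # __radd__ never gets a chance
except TypeError as err:
    print(err)                                # -> unsupported operand
```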
This PR fixes this issue for the `add`, `sub`, and `mul` functions. The general approach is as follows (a condensed sketch follows the list):

- We register `torch.Tensor.{add, sub, mul}` in the `__torch_function__` handler via an `@implements` decorator.
- We add an `__init_subclass__` function in `CrypTensor` that ensures these decorators are inherited by subclasses of `CrypTensor`.
- Because `MPCTensor` dynamically adds functions like `add`, `sub`, and `mul` after the subclass is created, the registration is also done manually for those functions in `MPCTensor`.
- Because `MPCTensor.binary_wrapper_function` assumes a specific structure of `MPCTensor` that `torch.Tensor` does not have, we switch the order of the arguments if needed and alter the function name to `__radd__`, `__rsub__`, etc.
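A condensed, self-contained sketch of the mechanism (`ToyCrypTensor`, `HANDLED_FUNCTIONS`, and `_add` are illustrative stand-ins, not the actual `CrypTensor`/`MPCTensor` code):

```python
import torch

HANDLED_FUNCTIONS = {}

def implements(torch_function):
    # Register a handler to run when torch_function sees our type.
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

class ToyCrypTensor:
    def __init__(self, value):
        self.value = value  # stand-in for the secret-shared payload

    def add(self, other):
        other = other.value if isinstance(other, ToyCrypTensor) else other
        return ToyCrypTensor(self.value + other)

    # torch.Tensor.add(tensor, cryptensor) raises a TypeError before
    # __radd__ can run, but it consults __torch_function__ first, so we
    # hook in here.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **kwargs)
        return NotImplemented

@implements(torch.add)
@implements(torch.Tensor.add)
def _add(input, other, **kwargs):
    # The caller may pass (tensor, cryptensor); swap the arguments so the
    # method defined on the encrypted type does the work.
    if not isinstance(input, ToyCrypTensor):
        input, other = other, input
    return input.add(other)

x = torch.tensor([1.0, 2.0])
y = ToyCrypTensor(torch.tensor([10.0, 20.0]))
print((x + y).value)          # tensor([11., 22.])
print(torch.add(x, y).value)  # tensor([11., 22.])
```

For `add` and `mul` the argument swap is harmless because they are commutative; for `sub` the swapped call has to be routed to a reversed implementation instead, which is why the last step above renames the function to `__rsub__` and friends.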
Note that it is not immediately clear how to make the same approach work for other functions like `matmul` that do not have an `__rmatmul__`, or for functions that do not exist in PyTorch, like `conv1d`. It can be done, but things will get pretty messy. So the question with this PR is whether this is a path we want to continue on.

How Has This Been Tested
This PR is currently an RFC, so I have not yet deeply tested all the changes. I would first like to get feedback on whether we want to make this change at all.
That said, simple examples along the following lines pass:
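For instance, expressions of this shape (a representative sketch rather than the verbatim snippets; `crypten.init`, `crypten.cryptensor`, and `get_plain_text` are CrypTen's standard entry points):

```python
import crypten
import torch

crypten.init()

x = torch.tensor([1.0, 2.0, 3.0])
y = crypten.cryptensor([4.0, 5.0, 6.0])

# All of these place the plain torch.Tensor on the left, which used to fail.
print((x + y).get_plain_text())   # approx. tensor([5., 7., 9.])
print((x - y).get_plain_text())   # approx. tensor([-3., -3., -3.])
print((x * y).get_plain_text())   # approx. tensor([ 4., 10., 18.])
print(x.add(y).get_plain_text())  # approx. tensor([5., 7., 9.])
```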
The example from #403 also passes.
If we want to proceed in this direction, I will add full unit tests.
Checklist