
Fixes issue 403 for select functions: add, sub, mul #413

Open · wants to merge 1 commit into main

Conversation

@lvdmaaten (Member) commented Oct 9, 2022

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Docs change / refactoring / dependency upgrade

Motivation and Context / Related issue

At present, it is not possible to do things like torch_tensor.add(cryptensor) or torch_tensor + cryptensor. The problem is that functions like __radd__ never get called, because torch.Tensor.add raises a TypeError rather than returning NotImplemented (which is what would trigger the reflected function). This limitation leads to issues such as #403.
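
For context, a minimal plain-Python sketch of why that matters (the class names here are illustrative): Python only falls back to the right operand's __radd__ when the left operand's __add__ returns NotImplemented, and a raised TypeError short-circuits that fallback:

class RaisingLeft:
    def __add__(self, other):
        raise TypeError("unsupported operand")  # mimics torch.Tensor.add today

class PoliteLeft:
    def __add__(self, other):
        return NotImplemented  # tells Python to try other.__radd__

class Right:
    def __radd__(self, other):
        return "handled by __radd__"

print(PoliteLeft() + Right())  # -> handled by __radd__
try:
    RaisingLeft() + Right()    # TypeError propagates; __radd__ is never tried
except TypeError as exc:
    print("reflected path skipped:", exc)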

This PR fixes this issue for the add, sub, and mul functions. The general approach is as follows:

  • Handle torch.Tensor.{add,sub,mul} in the __torch_function__ handler via an @implements decorator (a sketch of this pattern follows below).
  • Add an __init_subclass__ function in CrypTensor that ensures these registrations are inherited by subclasses of CrypTensor.
  • Because MPCTensor dynamically adds functions like add, sub, and mul after the subclass is created, the registration is also done manually for those functions in MPCTensor.
  • Because MPCTensor.binary_wrapper_function assumes a structure specific to MPCTensor that torch.Tensor does not have, we swap the order of the arguments when needed and rename the function to __radd__, __rsub__, etc.
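
As referenced above, here is a minimal, self-contained sketch of the @implements / __torch_function__ registration pattern. The names HANDLED, FakeCrypTensor, and _add are illustrative stand-ins, not CrypTen's actual internals:

import torch

HANDLED = {}  # maps torch functions to cryptensor-aware handlers

def implements(torch_func):
    """Register a handler for torch_func (e.g., torch.Tensor.add)."""
    def register(handler):
        HANDLED[torch_func] = handler
        return handler
    return register

class FakeCrypTensor:
    def __init__(self, plain):
        self.plain = plain  # stand-in for an encrypted share

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in HANDLED:
            return HANDLED[func](*args, **kwargs)
        return NotImplemented

@implements(torch.Tensor.add)
def _add(input, other, *args, **kwargs):
    # torch.Tensor.add(plain_tensor, cryptensor) lands here with the plain
    # tensor first; swap so the cryptensor leads, mirroring the PR's
    # argument-swapping step for the reflected (__radd__) case.
    if not isinstance(input, FakeCrypTensor):
        input, other = other, input
    return FakeCrypTensor(input.plain + other)

t = torch.ones(2)
out = t.add(FakeCrypTensor(torch.full((2,), 2.0)))  # dispatched via __torch_function__
print(out.plain)  # tensor([3., 3.])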

Note that it is not immediately clear how to make the same approach work for functions like matmul that do not have an __rmatmul__, or for functions like conv1d that do not exist as torch.Tensor methods. It can be done, but things will get pretty messy. So the question with this PR is whether this is a path we want to continue down.

How Has This Been Tested

This PR is currently an RFC, so I have not deeply tested all the changes yet. I would first like to get feedback on whether we want to make this change at all.

That said, these simple examples pass:

import crypten
import torch

crypten.init()

a = torch.randn(3, 3) * 5
b = torch.randn(3, 3) * 5

result = a.add(crypten.cryptensor(b))
reference = a.add(b)
# tolerance accounts for CrypTen's fixed-point encoding error
assert torch.allclose(reference, result.get_plain_text(), atol=2e-4)

result = a.sub(crypten.cryptensor(b))
reference = a.sub(b)
assert torch.allclose(reference, result.get_plain_text(), atol=2e-4)

result = a.mul(crypten.cryptensor(b))
reference = a.mul(b)
assert torch.allclose(reference, result.get_plain_text(), atol=2e-4)
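
Assuming the operator form routes through the same torch.Tensor handlers (this is an added illustration, not part of the PR's snippet), the motivating expression should pass as well:

result = a + crypten.cryptensor(b)
reference = a + b
assert torch.allclose(reference, result.get_plain_text(), atol=2e-4)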

Similarly, the example from #403 passes:

import crypten
import torch
import torch.nn as nn

crypten.init()


class NN(nn.Sequential):
    # network architecture:
    def __init__(self, mask):
        super(NN, self).__init__()
        self.mask = mask
        self.fc = nn.Linear(3, 3)

    def forward(self, attn):
        masked = attn.masked_fill(self.mask == 0, 2.)
        result = self.fc(masked)
        return result


model = NN(torch.randn(3, 3) > 0.)
dummy_input = torch.randn(3, 3)
private_model = crypten.nn.from_pytorch(model, dummy_input).encrypt()
private_inputs = crypten.cryptensor(dummy_input)
private_model(private_inputs)

If we want to proceed in this direction, I will add full unit tests.

Checklist

  • The documentation is up-to-date with the changes I made.
  • I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
  • All tests passed, and additional code has been covered with new tests.

@facebook-github-bot added the CLA Signed label on Oct 9, 2022