[MPS] Fix embedding backward with scalar index (pytorch#82809)
### Description
Previously, the embedding backward on MPS always expanded a trailing (`-1`) dimension on the indices tensor. When the indices tensor is a scalar (rank 0), the resulting ranks no longer match what the `scatter_nd` op expects for the rank-2 weight, producing the following error:

```
 error: Rank of data array must equal number of outer dimensions in indices array + rank of slice to update, 2 != 1 + 0
-:8:10: note: see current operation: %5 = "mps.scatter_nd"(%0, %arg1, %4) {batch_dims = 0 : ui32, mode = 0 : i32} : (tensor<10x5xf16>,
```

This change makes the expansion conditional, applying it only when the indices tensor has a nonzero number of dimensions.

Reproducer:

```python
import torch

def repro():
    w = torch.tensor([[-2.6465,  2.5859,  0.4688,  1.7949,  3.2676],
        [-3.1641,  8.9375,  5.7578, -2.9453, -6.5469],
        [ 2.0469,  1.3516, -8.7344,  6.0000,  1.3906],
        [ 6.5781,  7.8438,  6.9766,  3.2891, -5.1172],
        [-7.9414,  7.7344,  4.1875,  2.8574,  2.9531],
        [-0.4844, -5.6328, -6.8359, -4.5156,  3.7891],
        [ 4.9375,  6.6094,  6.7031,  0.6719, -6.4219],
        [ 7.0469,  8.2031,  4.4453,  1.7129, -2.4688],
        [ 1.2207, -3.3750, -2.4531,  7.4062, -6.0469],
        [-8.9688,  2.2656,  2.4160, -1.0176,  8.4531]], dtype=torch.float32, requires_grad=True)
    x = torch.tensor(5)
    out = torch.nn.functional.embedding(x, w)
    out.sum().backward()

    w_mps = w.detach().clone().to("mps").requires_grad_()
    x_mps = x.to("mps")
    out = torch.nn.functional.embedding(x_mps, w_mps)
    out.sum().backward() # error
```
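
For reference, the CPU gradient in this scalar-index case is a one-hot-row matrix; the sketch below (the helper name is illustrative, not part of the patch) shows what both backends should produce once the fix is in place.

```python
import torch

def expected_grad_for_scalar_index(num_embeddings=10, dim=5, idx=5):
    # out = w[idx], so d(out.sum())/dw is zero everywhere except row `idx`,
    # which is all ones.
    grad = torch.zeros(num_embeddings, dim)
    grad[idx] = 1.0
    return grad

# After running repro() above, w.grad on CPU equals this matrix, and with the
# fix applied, w_mps.grad on MPS should match it as well.
print(expected_grad_for_scalar_index())
```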


Pull Request resolved: pytorch#82809
Approved by: https://github.com/malfet
qqaatw authored and pytorchmergebot committed Nov 4, 2022
1 parent 5b767d4 commit 15e5429
Showing 2 changed files with 14 additions and 10 deletions.
11 changes: 7 additions & 4 deletions aten/src/ATen/native/mps/operations/Indexing.mm

@@ -645,11 +645,14 @@ Tensor embedding_dense_backward_mps(
 
     MPSGraphTensor* indicesTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(indices.scalar_type()));
 
-    MPSGraphTensor *reshapedIndicesTensor = [mpsGraph expandDimsOfTensor:indicesTensor
-                                                                    axes:@[@-1]
-                                                                    name:nil];
+    MPSGraphTensor* reshapedIndicesTensor = indicesTensor;
 
-    MPSGraphTensor *outgoingGradTensor;
+    if (num_indices_dims != 0)
+      reshapedIndicesTensor = [mpsGraph expandDimsOfTensor:indicesTensor
+                                                      axes:@[@-1]
+                                                      name:nil];
+
+    MPSGraphTensor* outgoingGradTensor;
     outgoingGradTensor = [mpsGraph scatterNDWithUpdatesTensor:incomingGradTensor
                                                  indicesTensor:reshapedIndicesTensor
                                                          shape:native_mps::getMPSShape(IntArrayRef(outgoing_gradient_shape.data(), outgoing_gradient_shape.size()))
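
For readers less familiar with MPSGraph, the following is a rough Python analogue of the control flow introduced above (a sketch only; `num_indices_dims` in the Objective-C++ code is assumed to correspond to the rank of the indices tensor).

```python
import torch

def reshape_indices_for_scatter(indices: torch.Tensor) -> torch.Tensor:
    # Append a trailing dimension only for non-scalar indices; a rank-0
    # (scalar) index is passed through unchanged so the scatter ranks line up.
    if indices.dim() != 0:
        return indices.unsqueeze(-1)
    return indices
```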
13 changes: 7 additions & 6 deletions test/test_mps.py

@@ -4278,33 +4278,34 @@ def helper(shape, dim, index, idx_dtype=torch.int32):
         helper((2, 3, 3), -1, [1, 2])
 
     def test_embedding_dense_backward(self):
-        def helper(n, d, m):
+        def helper(n, d, m, idx):
             embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps')
             W_MPS = torch.randn((m, d), requires_grad=True, device='mps')
-            idx_MPS = torch.tensor([0, 1, 2]).to('mps')
+            idx_MPS = torch.tensor(idx).to('mps')
             a_MPS = embeddingMPS.weight.clone() @ W_MPS.t() # weight must be cloned for this to be differentiable
             a_MPS.retain_grad()
             b_MPS = embeddingMPS(idx_MPS) @ W_MPS.t() # modifies weight in-place
             b_MPS.retain_grad()
-            out_MPS = (a_MPS.unsqueeze(0) + b_MPS.unsqueeze(1))
+            out_MPS = (a_MPS.unsqueeze(0) + b_MPS)
             loss_MPS = out_MPS.sigmoid().prod()
             loss_MPS.backward()
 
             embeddingCPU = nn.Embedding(n, d, max_norm=True, scale_grad_by_freq=True)
             W_CPU = W_MPS.to('cpu')
-            idx_CPU = torch.tensor([0, 1, 2])
+            idx_CPU = torch.tensor(idx)
             a_CPU = embeddingCPU.weight.clone() @ W_CPU.t() # weight must be cloned for this to be differentiable
             a_CPU.retain_grad()
             b_CPU = embeddingCPU(idx_CPU) @ W_CPU.t() # modifies weight in-place
             b_CPU.retain_grad()
-            out_CPU = (a_CPU.unsqueeze(0) + b_CPU.unsqueeze(1))
+            out_CPU = (a_CPU.unsqueeze(0) + b_CPU)
             loss_CPU = out_CPU.sigmoid().prod()
             loss_CPU.backward()
 
             self.assertEqual(b_CPU.grad, b_MPS.grad)
             self.assertEqual(a_CPU.grad, a_MPS.grad)
 
-        helper(3, 5, 7)
+        helper(3, 5, 7, [0, 1, 2])
+        helper(3, 5, 7, 2) # test scalar index
 
     # Test pytorch gather
     def test_gather(self):
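
For local verification (a suggested invocation, not something stated in the commit), the updated test can be run on an MPS-capable machine with `pytest test/test_mps.py -k test_embedding_dense_backward`, which exercises both the list-index case and the new scalar-index case.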
