Skip to content

Commit 1d50f47

Browse files
cherry-pick remainder op from upstream (#244)
1 parent 892bc60 commit 1d50f47

File tree

3 files changed

+28
-1
lines changed

3 files changed

+28
-1
lines changed

aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,32 @@ Tensor floor_divide_mps(const Tensor& self, const Tensor& other) {
330330
return floor_divide_out_mps(self, other, self);
331331
}
332332

333+
TORCH_IMPL_FUNC(remainder_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) {
334+
// torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
335+
mps::BinaryOpBlock remainder_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
336+
MPSGraph* mpsGraph = cachedGraph->graph();
337+
// Rounding is a no-op for integral types, and also a reasonable workaround
338+
// For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
339+
// See https://github.com/pytorch/pytorch/issues/84995
340+
341+
auto divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
342+
secondaryTensor:secondaryCastTensor
343+
name:nil];
344+
bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0;
345+
if (isFloatOutput) {
346+
divTensor = [mpsGraph floorWithTensor:divTensor name:nil];
347+
}
348+
349+
auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:divTensor
350+
secondaryTensor:secondaryCastTensor
351+
name:nil];
352+
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
353+
secondaryTensor:mulTensor
354+
name: nil];
355+
};
356+
mps::binaryOpTensor(self, other, Scalar(1.0), output, "remainder_out_mps", remainder_op_block);
357+
}
358+
333359
TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
334360
{
335361
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9186,6 +9186,7 @@
91869186
structured_inherits: TensorIteratorBase
91879187
dispatch:
91889188
CPU, CUDA: remainder_out
9189+
MPS: remainder_out_mps
91899190
tags: pointwise
91909191

91919192
- func: remainder.Tensor(Tensor self, Tensor other) -> Tensor

test/test_mps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9485,7 +9485,7 @@ class TestConsistency(TestCase):
94859485
'put': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
94869486
'qr': [torch.float32],
94879487
'quantile': [torch.float32],
9488-
'remainder': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
9488+
'remainder': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8],
94899489
'renorm': [torch.float16, torch.float32],
94909490
'roll': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
94919491
'rounddecimals_0': [torch.float32],

0 commit comments

Comments
 (0)