Skip to content

Commit 71ebe62

Browse files
DenisVieriu97 authored and kulinseth committed
cherry-pick remainder op from upstream (#244)
1 parent a5c5c10 commit 71ebe62

File tree

3 files changed

+28
-1
lines changed

3 files changed

+28
-1
lines changed

aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,32 @@ Tensor floor_divide_mps(const Tensor& self, const Tensor& other) {
328328
return floor_divide_out_mps(self, other, self);
329329
}
330330

331+
// MPS implementation of the structured `remainder` kernel: builds an MPSGraph
// expression from `self` and `other` and hands it to mps::binaryOpTensor,
// which is given `output` as the destination tensor.
TORCH_IMPL_FUNC(remainder_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) {
332+
// Identity used below:
//   torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
333+
// Graph-building block: receives the cached graph and the two (cast) operand
// tensors and returns the tensor expression for the remainder.
mps::BinaryOpBlock remainder_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
334+
MPSGraph* mpsGraph = cachedGraph->graph();
335+
// The floor step is applied only to floating-point quotients. Rounding is a
// no-op for integral types, and skipping it is also a reasonable workaround
336+
// for an MPSGraph bug on Apple Silicon that throws
// `Function floorOp_i64 was not found in the library`.
337+
// See https://github.com/pytorch/pytorch/issues/84995
338+
339+
// Step 1: quotient = primary / secondary.
auto divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
340+
secondaryTensor:secondaryCastTensor
341+
name:nil];
342+
// Step 2: floor the quotient only when its data type has the float bit set.
// NOTE(review): for integral inputs the result relies on how MPSGraph rounds
// integer division; confirm it matches floor semantics for negative operands.
bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0;
343+
if (isFloatOutput) {
344+
divTensor = [mpsGraph floorWithTensor:divTensor name:nil];
345+
}
346+
347+
// Step 3: product = quotient * secondary.
auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:divTensor
348+
secondaryTensor:secondaryCastTensor
349+
name:nil];
350+
// Step 4: remainder = primary - product.
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
351+
secondaryTensor:mulTensor
352+
name: nil];
353+
};
354+
// Scalar(1.0) is presumably the `alpha` scaling argument of binaryOpTensor
// (i.e. no scaling) — TODO confirm against the helper's signature.
mps::binaryOpTensor(self, other, Scalar(1.0), output, "remainder_out_mps", remainder_op_block);
355+
}
356+
331357
TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
332358
{
333359
mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9218,6 +9218,7 @@
92189218
structured_inherits: TensorIteratorBase
92199219
dispatch:
92209220
CPU, CUDA: remainder_out
9221+
MPS: remainder_out_mps
92219222
tags: pointwise
92229223

92239224
- func: remainder.Tensor(Tensor self, Tensor other) -> Tensor

test/test_mps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9624,7 +9624,7 @@ class TestConsistency(TestCase):
96249624
'put': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
96259625
'qr': [torch.float32],
96269626
'quantile': [torch.float32],
9627-
'remainder': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
9627+
'remainder': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8],
96289628
'renorm': [torch.float16, torch.float32],
96299629
'roll': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
96309630
'rounddecimals_0': [torch.float32],

0 commit comments

Comments
 (0)