@@ -330,6 +330,32 @@ Tensor floor_divide_mps(const Tensor& self, const Tensor& other) {
   return floor_divide_out_mps(self, other, self);
 }
 
+TORCH_IMPL_FUNC(remainder_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) {
+  // torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
+  mps::BinaryOpBlock remainder_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
+    MPSGraph* mpsGraph = cachedGraph->graph();
+    // Rounding is a no-op for integral types, and is also a reasonable workaround
+    // for an MPSGraph bug on Apple Silicon that throws `Function floorOp_i64 was not found in the library`.
+    // See https://github.com/pytorch/pytorch/issues/84995
+
+    auto divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
+                                         secondaryTensor:secondaryCastTensor
+                                                    name:nil];
+    bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0;
+    if (isFloatOutput) {
+      divTensor = [mpsGraph floorWithTensor:divTensor name:nil];
+    }
+
+    auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:divTensor
+                                               secondaryTensor:secondaryCastTensor
+                                                          name:nil];
+    return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
+                                  secondaryTensor:mulTensor
+                                             name:nil];
+  };
+  mps::binaryOpTensor(self, other, Scalar(1.0), output, "remainder_out_mps", remainder_op_block);
+}
+
 TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output)
 {
   mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
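For reference, the identity in the comment above can be sanity-checked against the public PyTorch API. Below is a minimal Python sketch; the CPU fallback is illustrative only and is not part of this change:

    import torch

    # Run on the MPS backend when available; fall back to CPU otherwise.
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    a = torch.tensor([5.0, -5.0, 5.0, -5.0], device=device)
    b = torch.tensor([3.0, 3.0, -3.0, -3.0], device=device)

    # torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
    expected = a - a.div(b, rounding_mode="floor") * b
    print(torch.remainder(a, b))                            # tensor([ 2.,  1., -1., -2.])
    print(torch.allclose(torch.remainder(a, b), expected))  # True

Note that the result takes the sign of the divisor, which is exactly what the floor-rounded division in the graph above produces.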