@@ -328,6 +328,32 @@ Tensor floor_divide_mps(const Tensor& self, const Tensor& other) {
   return floor_divide_out_mps(self, other, self);
 }
 
+TORCH_IMPL_FUNC(remainder_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
+  // torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
+  mps::BinaryOpBlock remainder_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
+    MPSGraph* mpsGraph = cachedGraph->graph();
+    // Rounding is a no-op for integral types, and also a reasonable workaround
+    // for an MPSGraph bug on Apple Silicon that throws `Function floorOp_i64 was not found in the library`.
+    // See https://github.com/pytorch/pytorch/issues/84995
+
+    auto divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
+                                         secondaryTensor:secondaryCastTensor
+                                                    name:nil];
+    bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0;
+    if (isFloatOutput) {
+      divTensor = [mpsGraph floorWithTensor:divTensor name:nil];
+    }
+
+    auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:divTensor
+                                               secondaryTensor:secondaryCastTensor
+                                                          name:nil];
+    return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
+                                  secondaryTensor:mulTensor
+                                             name:nil];
+  };
+  mps::binaryOpTensor(self, other, Scalar(1.0), output, "remainder_out_mps", remainder_op_block);
+}
+
 TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output)
 {
   mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
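For reference, a minimal sketch in plain PyTorch (not part of the diff; the tensors `a` and `b` are illustrative values chosen to cover all sign combinations) of the identity the new kernel implements:

import torch

# torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
a = torch.tensor([5.0, -5.0, 5.0, -5.0])
b = torch.tensor([3.0, 3.0, -3.0, -3.0])

expected = a - a.div(b, rounding_mode="floor") * b
print(expected)  # tensor([ 2.,  1., -1., -2.])
assert torch.equal(torch.remainder(a, b), expected)

Unlike torch.fmod, the result takes the sign of the divisor, which is why the quotient is floored rather than truncated before the multiply-subtract.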