diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index e3d10658fc..1ec4290598 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -228,6 +228,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. | `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` | | `tensor.prod()` | `tensor.prod()` | | `tensor.prod_dim(dim)` | `tensor.prod(dim, keepdim=True)` | +| `tensor.rem(other)` or `tensor % other` | `tensor % other` | | `tensor.scatter(dim, indices, values)` | `tensor.scatter_add(dim, indices, values)` | | `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` | | `tensor.select_assign(dim, indices, values)` | N/A | diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 40674d6d64..705d3b0072 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2868,6 +2868,20 @@ where } } +impl core::ops::Rem for Tensor +where + E: ElementConversion, + B: Backend, + K: Numeric, + K::Elem: Element, +{ + type Output = Self; + + fn rem(self, other: E) -> Self { + Tensor::remainder_scalar(self, other) + } +} + impl core::ops::Mul> for Tensor where B: Backend, diff --git a/crates/burn-tensor/src/tests/ops/remainder.rs b/crates/burn-tensor/src/tests/ops/remainder.rs index e3f86cd0f5..6e69f5889e 100644 --- a/crates/burn-tensor/src/tests/ops/remainder.rs +++ b/crates/burn-tensor/src/tests/ops/remainder.rs @@ -95,4 +95,17 @@ mod tests { let data_expected = Data::from([9.0, 1.0]); data_expected.assert_approx_eq(&data_actual, 3); } + + #[test] + fn should_support_remainder_op() { + let data = Data::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]); + let device = Default::default(); + let tensor = Tensor::::from_data(data, &device); + + let output = tensor % 2.0; + + let data_actual = output.into_data(); + let data_expected = Data::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]); + data_expected.assert_approx_eq(&data_actual, 3); + } }