Skip to content

Commit

Permalink
fix: tensor prod and prod dim containing nan values (#2515)
Browse files Browse the repository at this point in the history
quinton11 authored Nov 20, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent f64914b commit a0e8e4d
Showing 5 changed files with 77 additions and 0 deletions.
51 changes: 51 additions & 0 deletions crates/burn-fusion/src/ops/float.rs
Original file line number Diff line number Diff line change
@@ -1338,6 +1338,57 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out
}

fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float_ops!(ProdOps, B::float_prod, reduce);

let stream = tensor.stream;
let out = tensor
.client
.tensor_uninitialized(vec![1], B::FloatElem::dtype());

let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Prod(desc.clone()),
),
ProdOps::<B>::new(desc),
);

out
}

fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
scalar_float_ops!(ProdDimOps, B::float_prod_dim, usize, noconvert);

let stream = tensor.stream;
let mut shape = tensor.shape.clone();
shape[dim] = 1;
let out = tensor
.client
.tensor_uninitialized(shape, B::FloatElem::dtype());

let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::ProdDim(desc.clone()),
),
ProdDimOps::<B>::new(desc),
);

out
}

fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float_ops!(MeanOps, B::float_mean, reduce);

8 changes: 8 additions & 0 deletions crates/burn-ndarray/src/ops/tensor.rs
Original file line number Diff line number Diff line change
@@ -308,6 +308,14 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorO
NdArrayTensor::new(array)
}

fn float_prod(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
NdArrayMathOps::prod(tensor)
}

fn float_prod_dim(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
NdArrayMathOps::prod_dim(tensor, dim)
}

fn float_log1p(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
let array = tensor.array.mapv_into(|a| a.log1p_elem()).into_shared();

1 change: 1 addition & 0 deletions crates/burn-tensor/src/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -220,6 +220,7 @@ macro_rules! testgen_with_float_param {
burn_tensor::testgen_floor!();
burn_tensor::testgen_ceil!();
burn_tensor::testgen_select!();
burn_tensor::testgen_prod!();

// test stats
burn_tensor::testgen_var!();
1 change: 1 addition & 0 deletions crates/burn-tensor/src/tests/ops/mod.rs
Original file line number Diff line number Diff line change
@@ -44,6 +44,7 @@ mod padding;
mod permute;
mod powf;
mod powf_scalar;
mod prod;
mod random;
mod recip;
mod remainder;
16 changes: 16 additions & 0 deletions crates/burn-tensor/src/tests/ops/prod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(prod)]
mod tests {
use super::*;
use burn_tensor::{Tensor, TensorData};

#[test]
fn test_prod_float() {
let tensor_1 = TestTensor::<2>::from([[-5.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);

let output = tensor_1.prod();

output
.into_data()
.assert_eq(&TensorData::from([-600.0]), false);
}
}

0 comments on commit a0e8e4d

Please sign in to comment.