Skip to content

Commit

Permalink
Fix infinite recursion, overflow, and off-by-one error in triu/tril (#…
Browse files Browse the repository at this point in the history
…1418)

* Fixes infinite recursion and off-by-one error

* Avoids overflow using saturating arithmetic

* Removes unused import

* Fixes bug for isize::MAX for triu

* Fix formatting

* Uses broadcast indices to remove D::Smaller: Copy trait bound
  • Loading branch information
akern40 authored Aug 11, 2024
1 parent f563af0 commit 1df6c32
Showing 1 changed file with 129 additions and 54 deletions.
183 changes: 129 additions & 54 deletions src/tri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,25 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use core::cmp::{max, min};
use core::cmp::min;

use num_traits::Zero;

use crate::{dimension::is_layout_f, Array, ArrayBase, Axis, Data, Dimension, IntoDimension, Zip};
use crate::{
dimension::{is_layout_c, is_layout_f},
Array,
ArrayBase,
Axis,
Data,
Dimension,
Zip,
};

impl<S, A, D> ArrayBase<S, D>
where
S: Data<Elem = A>,
D: Dimension,
A: Clone + Zero,
D::Smaller: Copy,
{
/// Upper triangular of an array.
///
Expand All @@ -30,38 +37,56 @@ where
/// ```
/// use ndarray::array;
///
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
/// let res = arr.triu(0);
/// assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
/// let arr = array![
/// [1, 2, 3],
/// [4, 5, 6],
/// [7, 8, 9]
/// ];
/// assert_eq!(
/// arr.triu(0),
/// array![
/// [1, 2, 3],
/// [0, 5, 6],
/// [0, 0, 9]
/// ]
/// );
/// ```
pub fn triu(&self, k: isize) -> Array<A, D>
{
if self.ndim() <= 1 {
return self.to_owned();
}
match is_layout_f(&self.dim, &self.strides) {
true => {
let n = self.ndim();
let mut x = self.view();
x.swap_axes(n - 2, n - 1);
let mut tril = x.tril(-k);
tril.swap_axes(n - 2, n - 1);

tril
}
false => {
let mut res = Array::zeros(self.raw_dim());
Zip::indexed(self.rows())
.and(res.rows_mut())
.for_each(|i, src, mut dst| {
let row_num = i.into_dimension().last_elem();
let lower = max(row_num as isize + k, 0);
dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..]));
});

res
}

// Performance optimization for F-order arrays.
// C-order array check prevents infinite recursion in edge cases like [[1]].
// k-size check prevents underflow when k == isize::MIN
let n = self.ndim();
if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN {
let mut x = self.view();
x.swap_axes(n - 2, n - 1);
let mut tril = x.tril(-k);
tril.swap_axes(n - 2, n - 1);

return tril;
}

let mut res = Array::zeros(self.raw_dim());
let ncols = self.len_of(Axis(n - 1));
let nrows = self.len_of(Axis(n - 2));
let indices = Array::from_iter(0..nrows);
Zip::from(self.rows())
.and(res.rows_mut())
.and_broadcast(&indices)
.for_each(|src, mut dst, row_num| {
let mut lower = match k >= 0 {
true => row_num.saturating_add(k as usize), // Avoid overflow
false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0
};
lower = min(lower, ncols);
dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..]));
});

res
}

/// Lower triangular of an array.
Expand All @@ -75,45 +100,65 @@ where
/// ```
/// use ndarray::array;
///
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
/// let res = arr.tril(0);
/// assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
/// let arr = array![
/// [1, 2, 3],
/// [4, 5, 6],
/// [7, 8, 9]
/// ];
/// assert_eq!(
/// arr.tril(0),
/// array![
/// [1, 0, 0],
/// [4, 5, 0],
/// [7, 8, 9]
/// ]
/// );
/// ```
pub fn tril(&self, k: isize) -> Array<A, D>
{
if self.ndim() <= 1 {
return self.to_owned();
}
match is_layout_f(&self.dim, &self.strides) {
true => {
let n = self.ndim();
let mut x = self.view();
x.swap_axes(n - 2, n - 1);
let mut tril = x.triu(-k);
tril.swap_axes(n - 2, n - 1);

tril
}
false => {
let mut res = Array::zeros(self.raw_dim());
let ncols = self.len_of(Axis(self.ndim() - 1)) as isize;
Zip::indexed(self.rows())
.and(res.rows_mut())
.for_each(|i, src, mut dst| {
let row_num = i.into_dimension().last_elem();
let upper = min(row_num as isize + k, ncols) + 1;
dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));
});

res
}

// Performance optimization for F-order arrays.
// C-order array check prevents infinite recursion in edge cases like [[1]].
// k-size check prevents underflow when k == isize::MIN
let n = self.ndim();
if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN {
let mut x = self.view();
x.swap_axes(n - 2, n - 1);
let mut tril = x.triu(-k);
tril.swap_axes(n - 2, n - 1);

return tril;
}

let mut res = Array::zeros(self.raw_dim());
let ncols = self.len_of(Axis(n - 1));
let nrows = self.len_of(Axis(n - 2));
let indices = Array::from_iter(0..nrows);
Zip::from(self.rows())
.and(res.rows_mut())
.and_broadcast(&indices)
.for_each(|src, mut dst, row_num| {
// let row_num = i.into_dimension().last_elem();
let mut upper = match k >= 0 {
true => row_num.saturating_add(k as usize).saturating_add(1), // Avoid overflow
false => row_num.saturating_sub((k + 1).unsigned_abs()), // Avoid underflow
};
upper = min(upper, ncols);
dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));
});

res
}
}

#[cfg(test)]
mod tests
{
use core::isize;

use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder};
use alloc::vec;

Expand Down Expand Up @@ -188,6 +233,19 @@ mod tests
assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
}

#[test]
fn test_2d_single()
{
let x = array![[1]];

assert_eq!(x.triu(0), array![[1]]);
assert_eq!(x.tril(0), array![[1]]);
assert_eq!(x.triu(1), array![[0]]);
assert_eq!(x.tril(1), array![[1]]);
assert_eq!(x.triu(-1), array![[1]]);
assert_eq!(x.tril(-1), array![[0]]);
}

#[test]
fn test_3d()
{
Expand Down Expand Up @@ -285,8 +343,25 @@ mod tests
let res = x.triu(0);
assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]);

let res = x.tril(0);
assert_eq!(res, array![[1, 0, 0], [4, 5, 0]]);

let x = array![[1, 2], [3, 4], [5, 6]];
let res = x.triu(0);
assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]);

let res = x.tril(0);
assert_eq!(res, array![[1, 0], [3, 4], [5, 6]]);
}

#[test]
fn test_odd_k()
{
let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
let z = Array2::zeros([3, 3]);
assert_eq!(x.triu(isize::MIN), x);
assert_eq!(x.tril(isize::MIN), z);
assert_eq!(x.triu(isize::MAX), z);
assert_eq!(x.tril(isize::MAX), x);
}
}

0 comments on commit 1df6c32

Please sign in to comment.