From da5a76521a707f3f42f0e2ca1f3bf35ed2299ad6 Mon Sep 17 00:00:00 2001 From: pradeep Date: Sat, 8 May 2021 14:19:53 +0530 Subject: [PATCH] Fix logic to sequence based indexing in row/col/slice functions Prior to this change, the following functions were not checking for additional dimensional data beyond the dimension concerned with the particular function. - row - col - slice - rows - cols - slices - set_row - set_col - slice - set_rows - set_cols - set_slices Similar logic was missing in one particular matching pattern of view macro which is also fixed in this change. Few additional unit tests are added in macro and index module checking for the pitfalls this change addresses --- src/core/index.rs | 124 +++++++++++++++++++++++++++++++++------------ src/core/macros.rs | 56 +++++++++++++++++++- 2 files changed, 148 insertions(+), 32 deletions(-) diff --git a/src/core/index.rs b/src/core/index.rs index 4055152c7..17ec7b4aa 100644 --- a/src/core/index.rs +++ b/src/core/index.rs @@ -293,13 +293,11 @@ pub fn row(input: &Array, row_num: i64) -> Array where T: HasAfEnum, { - index( - input, - &[ - Seq::new(row_num as f64, row_num as f64, 1.0), - Seq::default(), - ], - ) + let mut seqs = vec![Seq::new(row_num as f64, row_num as f64, 1.0)]; + for _d in 1..input.dims().ndims() { + seqs.push(Seq::default()); + } + index(input, &seqs) } /// Set `row_num`^th row in `inout` Array to a new Array `new_row` @@ -308,7 +306,7 @@ where T: HasAfEnum, { let mut seqs = vec![Seq::new(row_num as f64, row_num as f64, 1.0)]; - if inout.dims().ndims() > 1 { + for _d in 1..inout.dims().ndims() { seqs.push(Seq::default()); } assign_seq(inout, &seqs, new_row) @@ -320,10 +318,11 @@ where T: HasAfEnum, { let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 }; - index( - input, - &[Seq::new(first as f64, last as f64, step), Seq::default()], - ) + let mut seqs = vec![Seq::new(first as f64, last as f64, step)]; + for _d in 1..input.dims().ndims() { + seqs.push(Seq::default()); + } + index(input, &seqs) } /// Set rows from `first` to `last` in `inout` Array with rows from Array `new_rows` @@ -332,7 +331,10 @@ where T: HasAfEnum, { let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 }; - let seqs = [Seq::new(first as f64, last as f64, step), Seq::default()]; + let mut seqs = vec![Seq::new(first as f64, last as f64, step)]; + for _d in 1..inout.dims().ndims() { + seqs.push(Seq::default()); + } assign_seq(inout, &seqs, new_rows) } @@ -352,13 +354,14 @@ pub fn col(input: &Array, col_num: i64) -> Array where T: HasAfEnum, { - index( - input, - &[ - Seq::default(), - Seq::new(col_num as f64, col_num as f64, 1.0), - ], - ) + let mut seqs = vec![ + Seq::default(), + Seq::new(col_num as f64, col_num as f64, 1.0), + ]; + for _d in 2..input.dims().ndims() { + seqs.push(Seq::default()); + } + index(input, &seqs) } /// Set `col_num`^th col in `inout` Array to a new Array `new_col` @@ -366,10 +369,13 @@ pub fn set_col(inout: &mut Array, new_col: &Array, col_num: i64) where T: HasAfEnum, { - let seqs = [ + let mut seqs = vec![ Seq::default(), Seq::new(col_num as f64, col_num as f64, 1.0), ]; + for _d in 2..inout.dims().ndims() { + seqs.push(Seq::default()); + } assign_seq(inout, &seqs, new_col) } @@ -379,10 +385,11 @@ where T: HasAfEnum, { let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 }; - index( - input, - &[Seq::default(), Seq::new(first as f64, last as f64, step)], - ) + let mut seqs = vec![Seq::default(), Seq::new(first as f64, last as f64, step)]; + for _d in 2..input.dims().ndims() { + seqs.push(Seq::default()); + } + index(input, &seqs) } /// Set cols from `first` to `last` in `inout` Array with cols from Array `new_cols` @@ -391,7 +398,10 @@ where T: HasAfEnum, { let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 }; - let seqs = [Seq::default(), Seq::new(first as f64, last as f64, step)]; + let mut seqs = vec![Seq::default(), Seq::new(first as f64, last as f64, step)]; + for _d in 2..inout.dims().ndims() { + seqs.push(Seq::default()); + } assign_seq(inout, &seqs, new_cols) } @@ -402,11 +412,14 @@ pub fn slice(input: &Array, slice_num: i64) -> Array where T: HasAfEnum, { - let seqs = [ + let mut seqs = vec![ Seq::default(), Seq::default(), Seq::new(slice_num as f64, slice_num as f64, 1.0), ]; + for _d in 3..input.dims().ndims() { + seqs.push(Seq::default()); + } index(input, &seqs) } @@ -417,11 +430,14 @@ pub fn set_slice(inout: &mut Array, new_slice: &Array, slice_num: i64) where T: HasAfEnum, { - let seqs = [ + let mut seqs = vec![ Seq::default(), Seq::default(), Seq::new(slice_num as f64, slice_num as f64, 1.0), ]; + for _d in 3..inout.dims().ndims() { + seqs.push(Seq::default()); + } assign_seq(inout, &seqs, new_slice) } @@ -433,11 +449,14 @@ where T: HasAfEnum, { let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 }; - let seqs = [ + let mut seqs = vec![ Seq::default(), Seq::default(), Seq::new(first as f64, last as f64, step), ]; + for _d in 3..input.dims().ndims() { + seqs.push(Seq::default()); + } index(input, &seqs) } @@ -449,11 +468,14 @@ where T: HasAfEnum, { let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 }; - let seqs = [ + let mut seqs = vec![ Seq::default(), Seq::default(), Seq::new(first as f64, last as f64, step), ]; + for _d in 3..inout.dims().ndims() { + seqs.push(Seq::default()); + } assign_seq(inout, &seqs, new_slices) } @@ -655,7 +677,7 @@ mod tests { use super::super::device::set_device; use super::super::dim4::Dim4; use super::super::index::{assign_gen, assign_seq, col, index, index_gen, row, Indexer}; - use super::super::index::{cols, rows}; + use super::super::index::{cols, rows, set_row, set_rows}; use super::super::random::randu; use super::super::seq::Seq; @@ -868,4 +890,44 @@ mod tests { // 0.9675 0.3712 0.7896 // ANCHOR_END: get_rows } + + #[test] + fn change_row() { + set_device(0); + + let v0: Vec = vec![true, true, true, true, true, true]; + let mut a0 = Array::new(&v0, dim4!(v0.len() as u64)); + + let v1: Vec = vec![false]; + let a1 = Array::new(&v1, dim4!(v1.len() as u64)); + + set_row(&mut a0, &a1, 2); + + let mut res = vec![true; a0.elements()]; + a0.host(&mut res); + + let gold = vec![true, true, false, true, true, true]; + + assert_eq!(gold, res); + } + + #[test] + fn change_rows() { + set_device(0); + + let v0: Vec = vec![true, true, true, true, true, true]; + let mut a0 = Array::new(&v0, dim4!(v0.len() as u64)); + + let v1: Vec = vec![false, false]; + let a1 = Array::new(&v1, dim4!(v1.len() as u64)); + + set_rows(&mut a0, &a1, 2, 3); + + let mut res = vec![true; a0.elements()]; + a0.host(&mut res); + + let gold = vec![true, true, false, false, true, true]; + + assert_eq!(gold, res); + } } diff --git a/src/core/macros.rs b/src/core/macros.rs index d6b0b1596..4abba9ee7 100644 --- a/src/core/macros.rs +++ b/src/core/macros.rs @@ -190,6 +190,9 @@ macro_rules! view { $( seq_vec.push($crate::seq!($start:$end:$step)); )* + for _d in seq_vec.len()..$array_ident.dims().ndims() { + seq_vec.push($crate::seq!()); + } $crate::index(&$array_ident, &seq_vec) } }; @@ -354,7 +357,7 @@ mod tests { use super::super::array::Array; use super::super::data::constant; use super::super::device::set_device; - use super::super::index::index; + use super::super::index::{index, rows, set_rows}; use super::super::random::randu; #[test] @@ -505,4 +508,55 @@ mod tests { let _ruu32_5x5 = randu!(u32; 5, 5); let _ruu8_5x5 = randu!(u8; 5, 5); } + + #[test] + fn match_eval_macro_with_set_rows() { + set_device(0); + + let inpt = vec![true, true, true, true, true, true, true, true, true, true]; + let gold = vec![ + true, true, false, false, true, true, true, false, false, true, + ]; + + let mut orig_arr = Array::new(&inpt, dim4!(5, 2)); + let mut orig_cln = orig_arr.clone(); + + let new_vals = vec![false, false, false, false]; + let new_arr = Array::new(&new_vals, dim4!(2, 2)); + + eval!( orig_arr[2:3:1,1:1:0] = new_arr ); + let mut res1 = vec![true; orig_arr.elements()]; + orig_arr.host(&mut res1); + + set_rows(&mut orig_cln, &new_arr, 2, 3); + let mut res2 = vec![true; orig_cln.elements()]; + orig_cln.host(&mut res2); + + assert_eq!(gold, res1); + assert_eq!(res1, res2); + } + + #[test] + fn match_view_macro_with_get_rows() { + set_device(0); + + let inpt: Vec = (0..10).collect(); + let gold: Vec = vec![2, 3, 7, 8]; + + println!("input {:?}", inpt); + println!("gold {:?}", gold); + + let orig_arr = Array::new(&inpt, dim4!(5, 2)); + + let view_out = view!( orig_arr[2:3:1] ); + let mut res1 = vec![0i32; view_out.elements()]; + view_out.host(&mut res1); + + let rows_out = rows(&orig_arr, 2, 3); + let mut res2 = vec![0i32; rows_out.elements()]; + rows_out.host(&mut res2); + + assert_eq!(gold, res1); + assert_eq!(res1, res2); + } }