Skip to content

Fix logic to sequence based indexing in row/col/slice functions #296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 93 additions & 31 deletions src/core/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,11 @@ pub fn row<T>(input: &Array<T>, row_num: i64) -> Array<T>
where
T: HasAfEnum,
{
index(
input,
&[
Seq::new(row_num as f64, row_num as f64, 1.0),
Seq::default(),
],
)
let mut seqs = vec![Seq::new(row_num as f64, row_num as f64, 1.0)];
for _d in 1..input.dims().ndims() {
seqs.push(Seq::default());
}
index(input, &seqs)
}

/// Set `row_num`^th row in `inout` Array to a new Array `new_row`
Expand All @@ -308,7 +306,7 @@ where
T: HasAfEnum,
{
let mut seqs = vec![Seq::new(row_num as f64, row_num as f64, 1.0)];
if inout.dims().ndims() > 1 {
for _d in 1..inout.dims().ndims() {
seqs.push(Seq::default());
}
assign_seq(inout, &seqs, new_row)
Expand All @@ -320,10 +318,11 @@ where
T: HasAfEnum,
{
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
index(
input,
&[Seq::new(first as f64, last as f64, step), Seq::default()],
)
let mut seqs = vec![Seq::new(first as f64, last as f64, step)];
for _d in 1..input.dims().ndims() {
seqs.push(Seq::default());
}
index(input, &seqs)
}

/// Set rows from `first` to `last` in `inout` Array with rows from Array `new_rows`
Expand All @@ -332,7 +331,10 @@ where
T: HasAfEnum,
{
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
let seqs = [Seq::new(first as f64, last as f64, step), Seq::default()];
let mut seqs = vec![Seq::new(first as f64, last as f64, step)];
for _d in 1..inout.dims().ndims() {
seqs.push(Seq::default());
}
assign_seq(inout, &seqs, new_rows)
}

Expand All @@ -352,24 +354,28 @@ pub fn col<T>(input: &Array<T>, col_num: i64) -> Array<T>
where
T: HasAfEnum,
{
index(
input,
&[
Seq::default(),
Seq::new(col_num as f64, col_num as f64, 1.0),
],
)
let mut seqs = vec![
Seq::default(),
Seq::new(col_num as f64, col_num as f64, 1.0),
];
for _d in 2..input.dims().ndims() {
seqs.push(Seq::default());
}
index(input, &seqs)
}

/// Set `col_num`^th col in `inout` Array to a new Array `new_col`
pub fn set_col<T>(inout: &mut Array<T>, new_col: &Array<T>, col_num: i64)
where
T: HasAfEnum,
{
let seqs = [
let mut seqs = vec![
Seq::default(),
Seq::new(col_num as f64, col_num as f64, 1.0),
];
for _d in 2..inout.dims().ndims() {
seqs.push(Seq::default());
}
assign_seq(inout, &seqs, new_col)
}

Expand All @@ -379,10 +385,11 @@ where
T: HasAfEnum,
{
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
index(
input,
&[Seq::default(), Seq::new(first as f64, last as f64, step)],
)
let mut seqs = vec![Seq::default(), Seq::new(first as f64, last as f64, step)];
for _d in 2..input.dims().ndims() {
seqs.push(Seq::default());
}
index(input, &seqs)
}

/// Set cols from `first` to `last` in `inout` Array with cols from Array `new_cols`
Expand All @@ -391,7 +398,10 @@ where
T: HasAfEnum,
{
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
let seqs = [Seq::default(), Seq::new(first as f64, last as f64, step)];
let mut seqs = vec![Seq::default(), Seq::new(first as f64, last as f64, step)];
for _d in 2..inout.dims().ndims() {
seqs.push(Seq::default());
}
assign_seq(inout, &seqs, new_cols)
}

Expand All @@ -402,11 +412,14 @@ pub fn slice<T>(input: &Array<T>, slice_num: i64) -> Array<T>
where
T: HasAfEnum,
{
let seqs = [
let mut seqs = vec![
Seq::default(),
Seq::default(),
Seq::new(slice_num as f64, slice_num as f64, 1.0),
];
for _d in 3..input.dims().ndims() {
seqs.push(Seq::default());
}
index(input, &seqs)
}

Expand All @@ -417,11 +430,14 @@ pub fn set_slice<T>(inout: &mut Array<T>, new_slice: &Array<T>, slice_num: i64)
where
T: HasAfEnum,
{
let seqs = [
let mut seqs = vec![
Seq::default(),
Seq::default(),
Seq::new(slice_num as f64, slice_num as f64, 1.0),
];
for _d in 3..inout.dims().ndims() {
seqs.push(Seq::default());
}
assign_seq(inout, &seqs, new_slice)
}

Expand All @@ -433,11 +449,14 @@ where
T: HasAfEnum,
{
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
let seqs = [
let mut seqs = vec![
Seq::default(),
Seq::default(),
Seq::new(first as f64, last as f64, step),
];
for _d in 3..input.dims().ndims() {
seqs.push(Seq::default());
}
index(input, &seqs)
}

Expand All @@ -449,11 +468,14 @@ where
T: HasAfEnum,
{
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
let seqs = [
let mut seqs = vec![
Seq::default(),
Seq::default(),
Seq::new(first as f64, last as f64, step),
];
for _d in 3..inout.dims().ndims() {
seqs.push(Seq::default());
}
assign_seq(inout, &seqs, new_slices)
}

Expand Down Expand Up @@ -655,7 +677,7 @@ mod tests {
use super::super::device::set_device;
use super::super::dim4::Dim4;
use super::super::index::{assign_gen, assign_seq, col, index, index_gen, row, Indexer};
use super::super::index::{cols, rows};
use super::super::index::{cols, rows, set_row, set_rows};
use super::super::random::randu;
use super::super::seq::Seq;

Expand Down Expand Up @@ -868,4 +890,44 @@ mod tests {
// 0.9675 0.3712 0.7896
// ANCHOR_END: get_rows
}

#[test]
fn change_row() {
set_device(0);

let v0: Vec<bool> = vec![true, true, true, true, true, true];
let mut a0 = Array::new(&v0, dim4!(v0.len() as u64));

let v1: Vec<bool> = vec![false];
let a1 = Array::new(&v1, dim4!(v1.len() as u64));

set_row(&mut a0, &a1, 2);

let mut res = vec![true; a0.elements()];
a0.host(&mut res);

let gold = vec![true, true, false, true, true, true];

assert_eq!(gold, res);
}

#[test]
fn change_rows() {
set_device(0);

let v0: Vec<bool> = vec![true, true, true, true, true, true];
let mut a0 = Array::new(&v0, dim4!(v0.len() as u64));

let v1: Vec<bool> = vec![false, false];
let a1 = Array::new(&v1, dim4!(v1.len() as u64));

set_rows(&mut a0, &a1, 2, 3);

let mut res = vec![true; a0.elements()];
a0.host(&mut res);

let gold = vec![true, true, false, false, true, true];

assert_eq!(gold, res);
}
}
56 changes: 55 additions & 1 deletion src/core/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ macro_rules! view {
$(
seq_vec.push($crate::seq!($start:$end:$step));
)*
for _d in seq_vec.len()..$array_ident.dims().ndims() {
seq_vec.push($crate::seq!());
}
$crate::index(&$array_ident, &seq_vec)
}
};
Expand Down Expand Up @@ -354,7 +357,7 @@ mod tests {
use super::super::array::Array;
use super::super::data::constant;
use super::super::device::set_device;
use super::super::index::index;
use super::super::index::{index, rows, set_rows};
use super::super::random::randu;

#[test]
Expand Down Expand Up @@ -505,4 +508,55 @@ mod tests {
let _ruu32_5x5 = randu!(u32; 5, 5);
let _ruu8_5x5 = randu!(u8; 5, 5);
}

#[test]
fn match_eval_macro_with_set_rows() {
set_device(0);

let inpt = vec![true, true, true, true, true, true, true, true, true, true];
let gold = vec![
true, true, false, false, true, true, true, false, false, true,
];

let mut orig_arr = Array::new(&inpt, dim4!(5, 2));
let mut orig_cln = orig_arr.clone();

let new_vals = vec![false, false, false, false];
let new_arr = Array::new(&new_vals, dim4!(2, 2));

eval!( orig_arr[2:3:1,1:1:0] = new_arr );
let mut res1 = vec![true; orig_arr.elements()];
orig_arr.host(&mut res1);

set_rows(&mut orig_cln, &new_arr, 2, 3);
let mut res2 = vec![true; orig_cln.elements()];
orig_cln.host(&mut res2);

assert_eq!(gold, res1);
assert_eq!(res1, res2);
}

#[test]
fn match_view_macro_with_get_rows() {
set_device(0);

let inpt: Vec<i32> = (0..10).collect();
let gold: Vec<i32> = vec![2, 3, 7, 8];

println!("input {:?}", inpt);
println!("gold {:?}", gold);

let orig_arr = Array::new(&inpt, dim4!(5, 2));

let view_out = view!( orig_arr[2:3:1] );
let mut res1 = vec![0i32; view_out.elements()];
view_out.host(&mut res1);

let rows_out = rows(&orig_arr, 2, 3);
let mut res2 = vec![0i32; rows_out.elements()];
rows_out.host(&mut res2);

assert_eq!(gold, res1);
assert_eq!(res1, res2);
}
}