Skip to content

Fix negative indexing in row/col/slice functions #258

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/helloworld.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ fn main() {
);
println!("Revision: {}", get_revision());

let num_rows: u64 = 5;
let num_cols: u64 = 3;
let num_rows: i64 = 5;
let num_cols: i64 = 3;
let values: [f32; 3] = [1.0, 2.0, 3.0];
let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));

af_print!("Indices ", indices);

let dims = Dim4::new(&[num_rows, num_cols, 1, 1]);
let dims = Dim4::new(&[num_rows as u64, num_cols as u64, 1, 1]);

let mut a = randu::<f32>(dims);
af_print!("Create a 5-by-3 float matrix on the GPU", a);
Expand Down
90 changes: 72 additions & 18 deletions src/core/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ where
/// print(&a);
/// print(&row(&a, 4));
/// ```
pub fn row<T>(input: &Array<T>, row_num: u64) -> Array<T>
pub fn row<T>(input: &Array<T>, row_num: i64) -> Array<T>
where
T: HasAfEnum,
{
Expand All @@ -301,7 +301,7 @@ where
}

/// Set `row_num`^th row in `inout` Array to a new Array `new_row`
pub fn set_row<T>(inout: &mut Array<T>, new_row: &Array<T>, row_num: u64)
pub fn set_row<T>(inout: &mut Array<T>, new_row: &Array<T>, row_num: i64)
where
T: HasAfEnum,
{
Expand All @@ -313,22 +313,24 @@ where
}

/// Get an Array with all rows from `first` to `last` in the `input` Array
pub fn rows<T>(input: &Array<T>, first: u64, last: u64) -> Array<T>
pub fn rows<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
where
T: HasAfEnum,
{
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
index(
input,
&[Seq::new(first as f64, last as f64, 1.0), Seq::default()],
&[Seq::new(first as f64, last as f64, step), Seq::default()],
)
}

/// Set rows from `first` to `last` in `inout` Array with rows from Array `new_rows`
pub fn set_rows<T>(inout: &mut Array<T>, new_rows: &Array<T>, first: u64, last: u64)
pub fn set_rows<T>(inout: &mut Array<T>, new_rows: &Array<T>, first: i64, last: i64)
where
T: HasAfEnum,
{
let seqs = [Seq::new(first as f64, last as f64, 1.0), Seq::default()];
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
let seqs = [Seq::new(first as f64, last as f64, step), Seq::default()];
assign_seq(inout, &seqs, new_rows)
}

Expand All @@ -344,7 +346,7 @@ where
/// println!("Grab last col of the random matrix");
/// print(&col(&a, 4));
/// ```
pub fn col<T>(input: &Array<T>, col_num: u64) -> Array<T>
pub fn col<T>(input: &Array<T>, col_num: i64) -> Array<T>
where
T: HasAfEnum,
{
Expand All @@ -358,7 +360,7 @@ where
}

/// Set `col_num`^th col in `inout` Array to a new Array `new_col`
pub fn set_col<T>(inout: &mut Array<T>, new_col: &Array<T>, col_num: u64)
pub fn set_col<T>(inout: &mut Array<T>, new_col: &Array<T>, col_num: i64)
where
T: HasAfEnum,
{
Expand All @@ -370,29 +372,31 @@ where
}

/// Get all cols from `first` to `last` in the `input` Array
pub fn cols<T>(input: &Array<T>, first: u64, last: u64) -> Array<T>
pub fn cols<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
where
T: HasAfEnum,
{
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
index(
input,
&[Seq::default(), Seq::new(first as f64, last as f64, 1.0)],
&[Seq::default(), Seq::new(first as f64, last as f64, step)],
)
}

/// Set cols from `first` to `last` in `inout` Array with cols from Array `new_cols`
pub fn set_cols<T>(inout: &mut Array<T>, new_cols: &Array<T>, first: u64, last: u64)
pub fn set_cols<T>(inout: &mut Array<T>, new_cols: &Array<T>, first: i64, last: i64)
where
T: HasAfEnum,
{
let seqs = [Seq::default(), Seq::new(first as f64, last as f64, 1.0)];
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
let seqs = [Seq::default(), Seq::new(first as f64, last as f64, step)];
assign_seq(inout, &seqs, new_cols)
}

/// Get `slice_num`^th slice from `input` Array
///
/// Slices indicate that the indexing is along 3rd dimension
pub fn slice<T>(input: &Array<T>, slice_num: u64) -> Array<T>
pub fn slice<T>(input: &Array<T>, slice_num: i64) -> Array<T>
where
T: HasAfEnum,
{
Expand All @@ -407,7 +411,7 @@ where
/// Set slice `slice_num` in `inout` Array to a new Array `new_slice`
///
/// Slices indicate that the indexing is along 3rd dimension
pub fn set_slice<T>(inout: &mut Array<T>, new_slice: &Array<T>, slice_num: u64)
pub fn set_slice<T>(inout: &mut Array<T>, new_slice: &Array<T>, slice_num: i64)
where
T: HasAfEnum,
{
Expand All @@ -422,29 +426,31 @@ where
/// Get slices from `first` to `last` in `input` Array
///
/// Slices indicate that the indexing is along 3rd dimension
pub fn slices<T>(input: &Array<T>, first: u64, last: u64) -> Array<T>
pub fn slices<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
where
T: HasAfEnum,
{
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
let seqs = [
Seq::default(),
Seq::default(),
Seq::new(first as f64, last as f64, 1.0),
Seq::new(first as f64, last as f64, step),
];
index(input, &seqs)
}

/// Set `first` to `last` slices of `inout` Array to a new Array `new_slices`
///
/// Slices indicate that the indexing is along 3rd dimension
pub fn set_slices<T>(inout: &mut Array<T>, new_slices: &Array<T>, first: u64, last: u64)
pub fn set_slices<T>(inout: &mut Array<T>, new_slices: &Array<T>, first: i64, last: i64)
where
T: HasAfEnum,
{
let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
let seqs = [
Seq::default(),
Seq::default(),
Seq::new(first as f64, last as f64, 1.0),
Seq::new(first as f64, last as f64, step),
];
assign_seq(inout, &seqs, new_slices)
}
Expand Down Expand Up @@ -644,6 +650,7 @@ mod tests {
use super::super::data::constant;
use super::super::dim4::Dim4;
use super::super::index::{assign_gen, assign_seq, col, index, index_gen, row, Indexer};
use super::super::index::{cols, rows};
use super::super::random::randu;
use super::super::seq::Seq;

Expand Down Expand Up @@ -800,4 +807,51 @@ mod tests {
// 0.7896
// ANCHOR_END: setrow
}

#[test]
fn get_row() {
// ANCHOR: get_row
let a = randu::<f32>(dim4!(5, 5));
// [5 5 1 1]
// 0.6010 0.5497 0.1583 0.3636 0.6755
// 0.0278 0.2864 0.3712 0.4165 0.6105
// 0.9806 0.3410 0.3543 0.5814 0.5232
// 0.2126 0.7509 0.6450 0.8962 0.5567
// 0.0655 0.4105 0.9675 0.3712 0.7896
let _r = row(&a, -1);
// [1 5 1 1]
// 0.0655 0.4105 0.9675 0.3712 0.7896
let _c = col(&a, -1);
// [5 1 1 1]
// 0.6755
// 0.6105
// 0.5232
// 0.5567
// 0.7896
// ANCHOR_END: get_row
}

#[test]
fn get_rows() {
// ANCHOR: get_rows
let a = randu::<f32>(dim4!(5, 5));
// [5 5 1 1]
// 0.6010 0.5497 0.1583 0.3636 0.6755
// 0.0278 0.2864 0.3712 0.4165 0.6105
// 0.9806 0.3410 0.3543 0.5814 0.5232
// 0.2126 0.7509 0.6450 0.8962 0.5567
// 0.0655 0.4105 0.9675 0.3712 0.7896
let _r = rows(&a, -1, -2);
// [2 5 1 1]
// 0.2126 0.7509 0.6450 0.8962 0.5567
// 0.0655 0.4105 0.9675 0.3712 0.7896
let _c = cols(&a, -1, -3);
// [5 3 1 1]
// 0.1583 0.3636 0.6755
// 0.3712 0.4165 0.6105
// 0.3543 0.5814 0.5232
// 0.6450 0.8962 0.5567
// 0.9675 0.3712 0.7896
// ANCHOR_END: get_rows
}
}
33 changes: 28 additions & 5 deletions tutorials-book/src/indexing.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,20 @@
Indexing in ArrayFire is a powerful but easy to abuse feature. This feature allows you to reference
or copy subsections of a larger array and perform operations on only a subset of elements.

This chapter is split into the following sections:
- [Index an Array using Seq Objects](#using-seq-objects)
- [Create a view of an existing Array](#create-a-view-of-an-existing-array)
- [Modify a sub region of an existing Array](#modify-a-sub-region-of-an-existing-array)
- [Using Array and Seq combination](#using-array-and-seq-combination)
- [Create a view of an existing Array](#create-a-view-of-an-existing-array)
- [Modify a sub region of an existing Array](#modify-a-sub-region-of-an-existing-array)
- [Extract or Set rows/columns of an Array](#extract-or-set-rowscolumns-of-an-array)
- [Negative Indices](#negative-indices)

[Indexer][1] structure is the key element used in Rust wrapper of ArrayFire for creating references
to existing Arrays. Given below are few of such functions and their corresponding use cases. Use
[Indexer::new][2] to create an Indexer object and set either a `Seq` object or `Array` as indexing
object for a given dimension.
to existing Arrays. The above sections illustrate how it can be used in conjunction with `Seq`
and/or `Array`. Apart from that, each section also showcases a macro based equivalent
code(if one exists) that is more terse in syntax but offers the same functionality.

## Using Seq objects

Expand Down Expand Up @@ -74,7 +84,7 @@ We will use [assign\_gen][13] function to do it.
{{#include ../../src/core/macros.rs:macro_seq_array_assign}}
```

## Extract or Set rows/coloumns of an Array
## Extract or Set rows/columns of an Array

Extract a specific set of rows/coloumns from an existing Array.

Expand All @@ -88,8 +98,21 @@ Similarly, [set\_row][7] & [set\_rows][9] can be used to change the values in a
rows using another Array. [set\_col][8] & [set\_cols][10] has same functionality, except that it is
for coloumns.

## Negative Indices

Negative indices can also be used to refer elements from the end of a given axis. Negative value for
a row/column/slice will fetch corresponding row/column/slice in reverse order. Given below are some
examples that showcase getting row(s)/col(s) from an existing Array.

```rust,noplaypen
{{#include ../../src/core/index.rs:get_row}}
```

```rust,noplaypen
{{#include ../../src/core/index.rs:get_rows}}
```

[1]: http://arrayfire.org/arrayfire-rust/arrayfire/struct.Indexer.html
[2]: http://arrayfire.org/arrayfire-rust/arrayfire/struct.Indexer.html#method.new
[3]: http://arrayfire.org/arrayfire-rust/arrayfire/fn.index.html
[4]: http://arrayfire.org/arrayfire-rust/arrayfire/fn.assign_seq.html
[5]: http://arrayfire.org/arrayfire-rust/arrayfire/fn.rows.html
Expand Down