diff --git a/src/data/mod.rs b/src/data/mod.rs index e917a757f..67e9b3d49 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -8,6 +8,7 @@ use crate::defines::AfError; use crate::dim4::Dim4; use crate::error::HANDLE_ERROR; use crate::util::{AfArray, DimT, HasAfEnum, Intl, MutAfArray, Uintl}; +use std::option::Option; use std::vec::Vec; #[allow(dead_code)] @@ -468,46 +469,114 @@ where temp.into() } -macro_rules! data_func_def { - ($doc_str: expr, $fn_name:ident, $ffi_name: ident) => { - #[doc=$doc_str] - /// - ///# Parameters - /// - /// - `input` is the input Array - /// - `dims` is the target(output) dimensions - /// - ///# Return Values - /// - /// An Array with modified data. - #[allow(unused_mut)] - pub fn $fn_name(input: &Array, dims: Dim4) -> Array - where - T: HasAfEnum, - { - let mut temp: i64 = 0; - unsafe { - let err_val = $ffi_name( - &mut temp as MutAfArray, - input.get() as AfArray, - dims[0] as c_uint, - dims[1] as c_uint, - dims[2] as c_uint, - dims[3] as c_uint, - ); - HANDLE_ERROR(AfError::from(err_val)); +/// Tile the input array along specified dimension +/// +/// Tile essentially creates copies of data along each dimension. +/// The number of copies created is provided by the user on per +/// axis basis using [Dim4](./struct.dim4.html) +/// +///# Parameters +/// +/// - `input` is the input Array +/// - `dims` is the target(output) dimensions +/// +///# Return Values +/// +/// Tiled input array as per the tiling dimensions provided +#[allow(unused_mut)] +pub fn tile(input: &Array, dims: Dim4) -> Array +where + T: HasAfEnum, +{ + let mut temp: i64 = 0; + unsafe { + let err_val = af_tile( + &mut temp as MutAfArray, + input.get() as AfArray, + dims[0] as c_uint, + dims[1] as c_uint, + dims[2] as c_uint, + dims[3] as c_uint, + ); + HANDLE_ERROR(AfError::from(err_val)); + } + temp.into() +} + +/// Reorder the array in specified order +/// +/// The default order of axes in ArrayFire is axis with smallest distance +/// between adjacent elements towards an axis with highest distance between +/// adjacent elements. +/// +///# Parameters +/// +/// - `input` is the input Array +/// - `new_axis0` is the new first axis for output +/// - `new_axis1` is the new second axis for output +/// - `next_axes` is the new axes order for output +/// +///# Return Values +/// +/// Array with data reordered as per the new axes order +pub fn reorder_v2( + input: &Array, + new_axis0: u64, + new_axis1: u64, + next_axes: Option>, +) -> Array +where + T: HasAfEnum, +{ + let mut new_axes = vec![new_axis0, new_axis1]; + match next_axes { + Some(v) => { + for axis in v { + new_axes.push(axis); } - temp.into() + } + None => { + new_axes.push(2); + new_axes.push(3); } }; + + let mut temp: i64 = 0; + unsafe { + let err_val = af_reorder( + &mut temp as MutAfArray, + input.get() as AfArray, + new_axes[0] as c_uint, + new_axes[1] as c_uint, + new_axes[2] as c_uint, + new_axes[3] as c_uint, + ); + HANDLE_ERROR(AfError::from(err_val)); + } + temp.into() } -data_func_def!( - "Tile the input array along specified dimension", - tile, - af_tile -); -data_func_def!("Reorder the array in specified order", reorder, af_reorder); +/// Reorder the array in specified order +/// +/// The default order of axes in ArrayFire is axis with smallest distance +/// between adjacent elements towards an axis with highest distance between +/// adjacent elements. +/// +///# Parameters +/// +/// - `input` is the input Array +/// - `dims` is the target(output) dimensions +/// +///# Return Values +/// +/// Array with data reordered as per the new axes order +#[deprecated(since = "3.6.3", note = "Please use new reorder API")] +pub fn reorder(input: &Array, dims: Dim4) -> Array +where + T: HasAfEnum, +{ + reorder_v2(input, dims[0], dims[1], Some(vec![dims[2], dims[3]])) +} ///"Circular shift of values along specified dimension ///