From 290f1273fad93ebbfae3b1f2b6b597101d94710c Mon Sep 17 00:00:00 2001 From: carrotflakes Date: Mon, 11 Mar 2024 05:17:19 +0900 Subject: [PATCH] Add `flip` tensor operator --- burn-book/src/building-blocks/tensor.md | 1 + crates/burn-autodiff/src/ops/bool_tensor.rs | 4 + crates/burn-autodiff/src/ops/int_tensor.rs | 4 + crates/burn-autodiff/src/ops/tensor.rs | 56 +++++++ crates/burn-autodiff/src/tests/flip.rs | 28 ++++ crates/burn-autodiff/src/tests/mod.rs | 2 + crates/burn-candle/src/lib.rs | 1 + crates/burn-candle/src/ops/base.rs | 21 ++- crates/burn-candle/src/ops/bool_tensor.rs | 11 +- crates/burn-candle/src/ops/int_tensor.rs | 8 +- crates/burn-candle/src/ops/tensor.rs | 11 +- crates/burn-fusion/src/ops/boolean.rs | 41 ++++- crates/burn-fusion/src/ops/float.rs | 49 +++++- crates/burn-fusion/src/ops/int.rs | 40 ++++- crates/burn-fusion/src/stream/context.rs | 27 +-- crates/burn-fusion/src/stream/operation.rs | 21 ++- crates/burn-jit/src/kernel/index/flip.rs | 157 ++++++++++++++++++ crates/burn-jit/src/kernel/index/mod.rs | 2 + crates/burn-jit/src/ops/bool_ops.rs | 7 + crates/burn-jit/src/ops/float_ops.rs | 7 + crates/burn-jit/src/ops/int_ops.rs | 4 + crates/burn-ndarray/src/ops/base.rs | 29 ++++ crates/burn-ndarray/src/ops/bool_tensor.rs | 7 + crates/burn-ndarray/src/ops/int_tensor.rs | 7 + crates/burn-ndarray/src/ops/tensor.rs | 7 + crates/burn-tch/src/ops/base.rs | 6 + crates/burn-tch/src/ops/bool_tensor.rs | 4 + crates/burn-tch/src/ops/int_tensor.rs | 7 + crates/burn-tch/src/ops/tensor.rs | 7 + crates/burn-tensor/src/tensor/api/base.rs | 51 ++++++ crates/burn-tensor/src/tensor/api/check.rs | 28 ++++ .../burn-tensor/src/tensor/ops/bool_tensor.rs | 10 ++ .../burn-tensor/src/tensor/ops/int_tensor.rs | 10 ++ crates/burn-tensor/src/tensor/ops/tensor.rs | 10 ++ crates/burn-tensor/src/tests/mod.rs | 1 + crates/burn-tensor/src/tests/ops/flip.rs | 105 ++++++++++++ crates/burn-tensor/src/tests/ops/mod.rs | 1 + 37 files changed, 758 insertions(+), 34 deletions(-) create mode 100644 crates/burn-autodiff/src/tests/flip.rs create mode 100644 crates/burn-jit/src/kernel/index/flip.rs create mode 100644 crates/burn-tensor/src/tests/ops/flip.rs diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index 8db1b18c0d..3cefc6ab00 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -149,6 +149,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`. 
| `tensor.dims()` | `tensor.size()` | | `tensor.equal(other)` | `x == y` | | `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` | +| `tensor.flip(axes)` | `tensor.flip(axes)` | | `tensor.into_data()` | N/A | | `tensor.into_primitive()` | N/A | | `tensor.into_scalar()` | `tensor.item()` | diff --git a/crates/burn-autodiff/src/ops/bool_tensor.rs b/crates/burn-autodiff/src/ops/bool_tensor.rs index 4c7d336c17..1d509c6b17 100644 --- a/crates/burn-autodiff/src/ops/bool_tensor.rs +++ b/crates/burn-autodiff/src/ops/bool_tensor.rs @@ -117,6 +117,10 @@ impl BoolTensorOps for Autodiff { B::bool_permute(tensor, axes) } + fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor { + B::bool_flip(tensor, axes) + } + fn bool_argwhere(tensor: BoolTensor) -> IntTensor { B::bool_argwhere(tensor) } diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs index 940dbb1b96..a1ff3327ac 100644 --- a/crates/burn-autodiff/src/ops/int_tensor.rs +++ b/crates/burn-autodiff/src/ops/int_tensor.rs @@ -353,6 +353,10 @@ impl IntTensorOps for Autodiff { B::int_permute(tensor, axes) } + fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { + B::int_flip(tensor, axes) + } + fn int_sign(tensor: IntTensor) -> IntTensor { B::int_sign(tensor) } diff --git a/crates/burn-autodiff/src/ops/tensor.rs b/crates/burn-autodiff/src/ops/tensor.rs index 16303c1ef6..685c119208 100644 --- a/crates/burn-autodiff/src/ops/tensor.rs +++ b/crates/burn-autodiff/src/ops/tensor.rs @@ -738,6 +738,62 @@ impl FloatTensorOps for Autodiff } } + fn float_flip( + tensor: FloatTensor, + axes: &[usize], + ) -> FloatTensor { + #[derive(Debug)] + struct FlipDim; + + #[derive(new, Debug)] + struct RetroFlipDims { + input_id: NodeID, + axes: Vec, + _backend: PhantomData, + } + + impl RetroForward for RetroFlipDims { + fn forward(&self, states: &mut BackwardStates, out_node: NodeID) { + let input = states.get_state::>(&self.input_id); + let out = B::float_flip(input, &self.axes); + states.save(out_node, out) + } + } + + impl Backward for FlipDim { + type State = Vec; + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + _checkpointer: &mut Checkpointer, + ) { + let axes = ops.state; + + unary::(ops.parents, ops.node, grads, |grad| { + B::float_flip(grad, &axes) + }); + } + } + + match FlipDim + .prepare::([tensor.node.clone()], [tensor.graph.clone()]) + .memory_bound() + .retro_forward(RetroFlipDims::::new( + tensor.node.id.clone(), + axes.to_vec(), + )) + .parents([&tensor]) + .stateful() + { + OpsKind::Tracked(prep) => { + prep.finish(axes.to_vec(), B::float_flip(tensor.primitive, axes)) + } + OpsKind::UnTracked(prep) => prep.finish(B::float_flip(tensor.primitive, axes)), + } + } + fn float_reshape( tensor: FloatTensor, shape: Shape, diff --git a/crates/burn-autodiff/src/tests/flip.rs b/crates/burn-autodiff/src/tests/flip.rs new file mode 100644 index 0000000000..40ec29c765 --- /dev/null +++ b/crates/burn-autodiff/src/tests/flip.rs @@ -0,0 +1,28 @@ +#[burn_tensor_testgen::testgen(ad_flip)] +mod tests { + use super::*; + use burn_tensor::Data; + + #[test] + fn should_diff_flip() { + let data_1: Data = Data::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2 + let data_2: Data = Data::from([[[3.0, 2.0, 7.0], [3.0, 3.2, 1.0]]]); // 1x2x3 + + let device = Default::default(); + let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad(); + let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); + + let tensor_3 = tensor_2.clone().flip([1, 2]); + 
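// tensor_3 = [[[1.0, 3.2, 3.0], [7.0, 2.0, 3.0]]]: dims 1 and 2 of tensor_2 reversed
+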
let tensor_4 = tensor_1.clone().matmul(tensor_3); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + assert_eq!(grad_1.to_data(), Data::from([[[7.2, 12.0], [7.2, 12.0]]])); // 1x2x2 + assert_eq!( + grad_2.to_data(), + Data::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]) // 1x2x3 + ); + } +} diff --git a/crates/burn-autodiff/src/tests/mod.rs b/crates/burn-autodiff/src/tests/mod.rs index b1b9f4b09e..1e7585139a 100644 --- a/crates/burn-autodiff/src/tests/mod.rs +++ b/crates/burn-autodiff/src/tests/mod.rs @@ -21,6 +21,7 @@ mod cross_entropy; mod div; mod erf; mod exp; +mod flip; mod gather_scatter; mod gelu; mod gradients; @@ -114,6 +115,7 @@ macro_rules! testgen_all { burn_autodiff::testgen_ad_sigmoid!(); burn_autodiff::testgen_ad_transpose!(); burn_autodiff::testgen_ad_permute!(); + burn_autodiff::testgen_ad_flip!(); burn_autodiff::testgen_ad_nonzero!(); burn_autodiff::testgen_ad_sign!(); }; diff --git a/crates/burn-candle/src/lib.rs b/crates/burn-candle/src/lib.rs index 2dbf458f68..48879ef7bc 100644 --- a/crates/burn-candle/src/lib.rs +++ b/crates/burn-candle/src/lib.rs @@ -81,6 +81,7 @@ mod tests { burn_tensor::testgen_mul!(); burn_tensor::testgen_neg!(); burn_tensor::testgen_permute!(); + burn_tensor::testgen_flip!(); burn_tensor::testgen_argwhere_nonzero!(); burn_tensor::testgen_sign!(); diff --git a/crates/burn-candle/src/ops/base.rs b/crates/burn-candle/src/ops/base.rs index 71c9adc400..5a58c33628 100644 --- a/crates/burn-candle/src/ops/base.rs +++ b/crates/burn-candle/src/ops/base.rs @@ -23,7 +23,6 @@ pub fn from_data( ) -> CandleTensor { CandleTensor::from_data(data, *device) } - pub fn into_data(tensor: CandleTensor) -> Data { Data::new( tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(), @@ -60,6 +59,26 @@ pub fn permute( CandleTensor::new(tensor.tensor.permute(axes).unwrap()) } +pub fn flip( + tensor: CandleTensor, + axes: &[usize], +) -> CandleTensor { + // FIXME: Replace with an appropriate method when Candle provides one. 
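+    // Flip by gathering with reversed indices: for each flipped axis of
+    // length `d`, `arange_step(d - 1, -1, -1)` builds [d - 1, d - 2, ..., 0]
+    // and `index_select` reads the elements back to front along that axis.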
+ let mut tensor = tensor.tensor; + for &axis in axes { + let indexes = candle_core::Tensor::arange_step( + tensor.dim(axis).unwrap() as i64 - 1, + -1, + -1, + tensor.device(), + ) + .unwrap(); + tensor = tensor.index_select(&indexes, axis).unwrap(); + } + + CandleTensor::new(tensor) +} + pub fn reshape( tensor: CandleTensor, shape: Shape, diff --git a/crates/burn-candle/src/ops/bool_tensor.rs b/crates/burn-candle/src/ops/bool_tensor.rs index 994d630f97..6106856ae0 100644 --- a/crates/burn-candle/src/ops/bool_tensor.rs +++ b/crates/burn-candle/src/ops/bool_tensor.rs @@ -8,8 +8,6 @@ use crate::{ Candle, CandleTensor, }; -use super::base::permute; - impl BoolTensorOps for Candle { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { super::base::empty(shape, device) @@ -133,6 +131,13 @@ impl BoolTensorOps for Candle< tensor: BoolTensor, axes: [usize; D], ) -> BoolTensor { - permute(tensor, axes) + super::base::permute(tensor, axes) + } + + fn bool_flip( + tensor: BoolTensor, + axes: &[usize], + ) -> BoolTensor { + super::base::flip(tensor, axes) } } diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs index 946cc52fdb..62c9da1f56 100644 --- a/crates/burn-candle/src/ops/int_tensor.rs +++ b/crates/burn-candle/src/ops/int_tensor.rs @@ -8,8 +8,6 @@ use crate::{ Candle, CandleTensor, }; -use super::base::permute; - impl IntTensorOps for Candle { fn int_empty(shape: Shape, device: &Device) -> IntTensor { super::base::empty(shape, device) @@ -425,7 +423,11 @@ impl IntTensorOps for Candle, axes: [usize; D], ) -> IntTensor { - permute(tensor, axes) + super::base::permute(tensor, axes) + } + + fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { + super::base::flip(tensor, axes) } // TODO add sign operator once Candle supports it: diff --git a/crates/burn-candle/src/ops/tensor.rs b/crates/burn-candle/src/ops/tensor.rs index 0c9105531c..3693bcdd97 100644 --- a/crates/burn-candle/src/ops/tensor.rs +++ b/crates/burn-candle/src/ops/tensor.rs @@ -11,8 +11,6 @@ use crate::{ Candle, CandleTensor, }; -use super::base::permute; - impl FloatTensorOps for Candle { fn float_from_data( data: Data, @@ -522,7 +520,14 @@ impl FloatTensorOps for Candle tensor: FloatTensor, axes: [usize; D], ) -> FloatTensor { - permute(tensor, axes) + super::base::permute(tensor, axes) + } + + fn float_flip( + tensor: FloatTensor, + axes: &[usize], + ) -> FloatTensor { + super::base::flip(tensor, axes) } // TODO add sign operator once Candle supports it: diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index d6e418d2fc..b80ac5d6b7 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -4,9 +4,9 @@ use crate::{ ops::binary::binary_ops_shape, stream::{ BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, - CatOperationDescription, Operation, OperationDescription, PermuteOperationDescription, - ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, StreamId, - SwapDimsDescription, UnaryOperationDescription, + CatOperationDescription, FlipOperationDescription, Operation, OperationDescription, + PermuteOperationDescription, ReshapeDescription, SliceAssignOperationDescription, + SliceOperationDescription, StreamId, SwapDimsDescription, UnaryOperationDescription, }, Fusion, FusionBackend, }; @@ -466,4 +466,39 @@ impl BoolTensorOps for Fusion { out } + + fn bool_flip( + tensor: BoolTensor, + axes: &[usize], + ) -> BoolTensor { + #[derive(new)] + 
struct FlipOps<const D: usize> {
+            desc: FlipOperationDescription,
+        }
+
+        impl<const D: usize, B: FusionBackend> Operation<B> for FlipOps<D> {
+            fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
+                let input = handles.get_bool_tensor::<D>(&self.desc.input);
+                let output = B::bool_flip(input, self.desc.axes.as_slice());
+                handles.register_bool_tensor(&self.desc.out.id, output);
+            }
+        }
+
+        let stream = tensor.stream;
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
+
+        let desc = FlipOperationDescription {
+            input: tensor.into_description(),
+            out: out.to_description_out(),
+            axes: axes.to_vec(),
+        };
+
+        out.client.register(
+            vec![stream],
+            OperationDescription::BaseBool(BaseOperationDescription::Flip(desc.clone())),
+            FlipOps::<D>::new(desc),
+        );
+
+        out
+    }
 }
diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs
index 129205efa6..d04fc18c71 100644
--- a/crates/burn-fusion/src/ops/float.rs
+++ b/crates/burn-fusion/src/ops/float.rs
@@ -6,13 +6,13 @@ use crate::{
     scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
     stream::{
         BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
-        ClampOperationDescription, FloatOperationDescription, GatherOperationDescription,
-        MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
-        Operation, OperationDescription, PermuteOperationDescription, RandomOperationDescription,
-        ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
-        ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
-        SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
-        UnaryOperationDescription,
+        ClampOperationDescription, FlipOperationDescription, FloatOperationDescription,
+        GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
+        NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
+        RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
+        ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
+        SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
+        StreamId, SwapDimsDescription, UnaryOperationDescription,
     },
     unary_float_ops, Fusion, FusionBackend, TensorDescription,
 };
@@ -1846,4 +1846,39 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
 
         out
     }
+
+    fn float_flip<const D: usize>(
+        tensor: FloatTensor<Self, D>,
+        axes: &[usize],
+    ) -> FloatTensor<Self, D> {
+        #[derive(new)]
+        struct FlipOps<const D: usize> {
+            desc: FlipOperationDescription,
+        }
+
+        impl<const D: usize, B: FusionBackend> Operation<B> for FlipOps<D> {
+            fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
+                let input = handles.get_float_tensor::<D>(&self.desc.input);
+                let output = B::float_flip(input, &self.desc.axes);
+                handles.register_float_tensor(&self.desc.out.id, output);
+            }
+        }
+
+        let stream = tensor.stream;
+        let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
+
+        let desc = FlipOperationDescription {
+            input: tensor.into_description(),
+            axes: axes.to_vec(),
+            out: out.to_description_out(),
+        };
+
+        out.client.register(
+            vec![stream],
+            OperationDescription::BaseFloat(BaseOperationDescription::Flip(desc.clone())),
+            FlipOps::<D>::new(desc),
+        );
+
+        out
+    }
 }
diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs
index 5d2515ba6f..d2b324a73f 100644
--- a/crates/burn-fusion/src/ops/int.rs
+++ b/crates/burn-fusion/src/ops/int.rs
@@ -6,9 +6,9 @@ use crate::{
     scalar_int_cmp_ops, scalar_int_ops,
     stream::{
         self,
BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, - ClampOperationDescription, GatherOperationDescription, MaskFillOperationDescription, - MaskWhereOperationDescription, NumericOperationDescription, Operation, - OperationDescription, PermuteOperationDescription, RandomOperationDescription, + ClampOperationDescription, FlipOperationDescription, GatherOperationDescription, + MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, + Operation, OperationDescription, PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription, @@ -1552,4 +1552,38 @@ impl IntTensorOps for Fusion { out } + + fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { + #[derive(new)] + struct FlipDimsOps { + desc: FlipOperationDescription, + } + + impl Operation for FlipDimsOps { + fn execute(self: Box, handles: &mut crate::HandleContainer) { + let input = handles.get_int_tensor::(&self.desc.input); + let axes = &self.desc.axes; + let output = B::int_flip(input, axes); + handles.register_int_tensor(&self.desc.out.id, output); + } + } + + let stream = tensor.stream; + + let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); + + let desc = FlipOperationDescription { + input: tensor.into_description(), + axes: axes.to_vec(), + out: out.to_description_out(), + }; + + out.client.register( + vec![stream], + OperationDescription::BaseInt(BaseOperationDescription::Flip(desc.clone())), + FlipDimsOps::::new(desc), + ); + + out + } } diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index ac27d8fc93..ced3be8ab9 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -4,16 +4,16 @@ use super::{ AvgPool2dBackwardDescription, AvgPool2dDescription, BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, ClampOperationDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, ConvTranspose2dDescription, - EmbeddingBackwardDescription, EmbeddingDescription, FloatOperationDescription, - GatherOperationDescription, IntOperationDescription, InterpolateBackwardDescription, - InterpolateDescription, MaskFillOperationDescription, MaskWhereOperationDescription, - MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription, - MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, - ModuleOperationDescription, NumericOperationDescription, OperationDescription, - PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, - ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription, - SelectAssignOperationDescription, SelectOperationDescription, SliceOperationDescription, - SwapDimsDescription, UnaryOperationDescription, + EmbeddingBackwardDescription, EmbeddingDescription, FlipOperationDescription, + FloatOperationDescription, GatherOperationDescription, IntOperationDescription, + InterpolateBackwardDescription, InterpolateDescription, MaskFillOperationDescription, + MaskWhereOperationDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, + MaxPool1dWithIndicesDescription, MaxPool2dDescription, 
MaxPool2dWithIndicesBackwardDescription, + MaxPool2dWithIndicesDescription, ModuleOperationDescription, NumericOperationDescription, + OperationDescription, PermuteOperationDescription, RandomOperationDescription, + ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription, + ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, + SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, }; use crate::{FusionBackend, HandleContainer, TensorDescription, TensorId}; use burn_tensor::{Element, ElementConversion}; @@ -798,6 +798,13 @@ impl BaseOperationDescription { axes: desc.axes.clone(), }) } + BaseOperationDescription::Flip(desc) => { + BaseOperationDescription::Flip(FlipOperationDescription { + input: desc.input.to_relative(converter), + out: desc.out.to_relative(converter), + axes: desc.axes.clone(), + }) + } BaseOperationDescription::Slice(desc) => { BaseOperationDescription::Slice(SliceOperationDescription { tensor: desc.tensor.to_relative(converter), diff --git a/crates/burn-fusion/src/stream/operation.rs b/crates/burn-fusion/src/stream/operation.rs index a0b46c0091..1cd9e44717 100644 --- a/crates/burn-fusion/src/stream/operation.rs +++ b/crates/burn-fusion/src/stream/operation.rs @@ -155,6 +155,12 @@ pub enum BaseOperationDescription { /// Bool => [permute](burn_tensor::ops::BoolTensorOps::bool_permute). Permute(PermuteOperationDescription), + /// Operation corresponding to: + /// Float => [flip](burn_tensor::ops::FloatTensorOps::float_flip). + /// Int => [flip](burn_tensor::ops::IntTensorOps::int_flip). + /// Bool => [flip](burn_tensor::ops::BoolTensorOps::bool_flip). + Flip(FlipOperationDescription), + /// Operation corresponding to: /// /// Float => [slice](burn_tensor::ops::FloatTensorOps::float_slice). @@ -456,6 +462,17 @@ pub struct PermuteOperationDescription { pub axes: Vec, } +/// Flip operation description. +#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] +pub struct FlipOperationDescription { + /// Input tensor description. + pub input: TensorDescription, + /// Output tensor description. + pub out: TensorDescription, + /// The dimensions to flip. 
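+    /// Expected to be unique and within the tensor's rank; the public `Tensor::flip` API validates this.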
+ pub axes: Vec, +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct RandomOperationDescription { @@ -1034,10 +1051,12 @@ impl BaseOperationDescription { BaseOperationDescription::SwapDims(desc) => { vec![&desc.input, &desc.out] } - BaseOperationDescription::Permute(desc) => { vec![&desc.input, &desc.out] } + BaseOperationDescription::Flip(desc) => { + vec![&desc.input, &desc.out] + } BaseOperationDescription::Slice(desc) => { vec![&desc.tensor, &desc.out] } diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs new file mode 100644 index 0000000000..9fe98ef4db --- /dev/null +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -0,0 +1,157 @@ +use crate::{ + codegen::{ + dialect::gpu::{gpu, Elem, Scope, Variable, Visibility}, + execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, + InputInfo, OutputInfo, WorkgroupLaunch, + }, + element::JitElement, + kernel::{DynamicKernelSource, SourceTemplate}, + ops::numeric::empty_device, + tensor::JitTensor, + Runtime, RuntimeInt, +}; +use burn_tensor::ElementConversion; +use std::marker::PhantomData; + +#[derive(new)] +struct FlipEagerKernel { + rank: usize, + _runtime: PhantomData, + _elem: PhantomData, +} + +pub struct FlipComputeShader { + input: Variable, + output: Variable, + rank: usize, +} + +impl FlipComputeShader { + pub fn expand(self, scope: &mut Scope) { + let input = self.input; + let output = self.output; + let id = Variable::Id; + + let offset_input = scope.zero(Elem::UInt); + let offset_local = scope.create_local(Elem::UInt); + + let stride = scope.create_local(Elem::UInt); + let shape = scope.create_local(Elem::UInt); + let flip = scope.create_local(Elem::UInt); + let flip_bool = scope.create_local(Elem::Bool); + + for i in 0..self.rank { + gpu!(scope, stride = stride(input, i)); + gpu!(scope, shape = shape(output, i)); + gpu!( + scope, + flip = cast(Variable::GlobalScalar(i as u16, Elem::UInt)) + ); + gpu!(scope, flip_bool = flip == 1u32); + + gpu!(scope, offset_local = id / stride); + gpu!(scope, offset_local = offset_local % shape); + + gpu!(scope, if(flip_bool).then(|scope| { + gpu!(scope, offset_local = shape - offset_local); + gpu!(scope, offset_local = offset_local - 1u32); + })); + gpu!(scope, offset_local = offset_local * stride); + + gpu!(scope, offset_input += offset_local); + } + + let result = scope.create_local(input.item()); + gpu!(scope, result = input[offset_input]); + gpu!(scope, output[id] = result); + } +} + +impl DynamicKernelSource for FlipEagerKernel { + fn source(&self) -> crate::kernel::SourceTemplate { + let mut scope = Scope::root(); + let item = E::gpu_elem().into(); + + let input = Variable::GlobalInputArray(0, item); + let output = Variable::GlobalOutputArray(0, item); + + scope.write_global_custom(output); + + FlipComputeShader { + input, + output, + rank: self.rank, + } + .expand(&mut scope); + + let input = InputInfo::Array { + item, + visibility: Visibility::Read, + }; + let flip_dims = InputInfo::Scalar { + elem: Elem::UInt, + size: self.rank, + }; + let output = OutputInfo::Array { item }; + + let info = CompilationInfo { + inputs: vec![input, flip_dims], + outputs: vec![output], + scope, + }; + + let settings = CompilationSettings::default(); + let shader = Compilation::new(info).compile(settings); + let shader = ::compile(shader); + SourceTemplate::new(shader.to_string()) + } + + fn id(&self) -> String { + format!("{:?}-rank={:?}", core::any::TypeId::of::(), self.rank) + } +} 
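+
+// For each output element, the compute shader above decodes the linear id into
+// per-axis coordinates (id / stride % shape), mirrors the coordinate on every
+// flipped axis (shape - coordinate - 1), and re-accumulates the input offset
+// from the strides; the entry points below launch it.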
+ +pub(crate) fn flip( + tensor: JitTensor, + indices: &[usize], +) -> JitTensor { + let output = empty_device( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + flip_on_output(tensor, output, indices) +} + +pub(crate) fn flip_on_output( + tensor: JitTensor, + output: JitTensor, + indices: &[usize], +) -> JitTensor { + let mut scalars = Vec::with_capacity(D); + + for i in 0..D { + scalars.push((indices.contains(&i) as u32).elem()); + } + + let kernel = FlipEagerKernel::new(D); + + execute_dynamic::, RuntimeInt>( + &[EagerHandle::new( + &tensor.handle, + &tensor.strides, + &tensor.shape.dims, + )], + &[EagerHandle::new( + &output.handle, + &output.strides, + &output.shape.dims, + )], + Some(&scalars), + kernel, + WorkgroupLaunch::Output { pos: 0 }, + tensor.client, + ); + + output +} diff --git a/crates/burn-jit/src/kernel/index/mod.rs b/crates/burn-jit/src/kernel/index/mod.rs index cebd07c1e0..b9f7946716 100644 --- a/crates/burn-jit/src/kernel/index/mod.rs +++ b/crates/burn-jit/src/kernel/index/mod.rs @@ -1,3 +1,4 @@ +mod flip; mod gather; mod repeat; mod scatter; @@ -6,6 +7,7 @@ mod select_assign; mod slice; mod slice_assign; +pub use flip::*; pub use repeat::*; pub use select::*; pub use select_assign::*; diff --git a/crates/burn-jit/src/ops/bool_ops.rs b/crates/burn-jit/src/ops/bool_ops.rs index 73f47d4947..bba905f401 100644 --- a/crates/burn-jit/src/ops/bool_ops.rs +++ b/crates/burn-jit/src/ops/bool_ops.rs @@ -120,4 +120,11 @@ impl BoolTensorOps for JitBackend { ) -> BoolTensor { permute(tensor, axes) } + + fn bool_flip( + tensor: BoolTensor, + axes: &[usize], + ) -> BoolTensor { + kernel::flip(tensor, axes) + } } diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 5ebc62ec35..a8bba869a4 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -516,4 +516,11 @@ impl FloatTensorOps for JitBackend { ) -> FloatTensor { permute(tensor, axes) } + + fn float_flip( + tensor: FloatTensor, + axes: &[usize], + ) -> FloatTensor { + kernel::flip(tensor, axes) + } } diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 11fdfd362e..ff043577f9 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -348,4 +348,8 @@ impl IntTensorOps for JitBackend { ) -> IntTensor { permute(tensor, axes) } + + fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { + kernel::flip(tensor, axes) + } } diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs index d8a105b3f1..c3dbe77fb3 100644 --- a/crates/burn-ndarray/src/ops/base.rs +++ b/crates/burn-ndarray/src/ops/base.rs @@ -4,6 +4,7 @@ use burn_tensor::ElementConversion; use core::{marker::PhantomData, ops::Range}; use ndarray::s; use ndarray::Array2; +use ndarray::SliceInfo; use ndarray::Zip; use num_traits::Signed; @@ -111,6 +112,34 @@ where NdArrayTensor::new(array) } + + pub fn flip( + tensor: NdArrayTensor, + axes: &[usize], + ) -> NdArrayTensor { + let slice_items: Vec<_> = (0..D) + .map(|i| { + if axes.contains(&i) { + SliceInfoElem::Slice { + start: 0, + end: None, + step: -1, + } + } else { + SliceInfoElem::Slice { + start: 0, + end: None, + step: 1, + } + } + }) + .collect(); + let slice_info = + SliceInfo::, IxDyn, IxDyn>::try_from(slice_items).unwrap(); + let array = tensor.array.slice(slice_info).into_owned().into_shared(); + + NdArrayTensor::new(array) + } } impl NdArrayMathOps diff --git a/crates/burn-ndarray/src/ops/bool_tensor.rs 
b/crates/burn-ndarray/src/ops/bool_tensor.rs index 3b8c418aca..a760e48dc7 100644 --- a/crates/burn-ndarray/src/ops/bool_tensor.rs +++ b/crates/burn-ndarray/src/ops/bool_tensor.rs @@ -137,4 +137,11 @@ impl BoolTensorOps for NdArray { let array = tensor.array.permuted_axes(axes.into_dimension()); NdArrayTensor { array } } + + fn bool_flip( + tensor: burn_tensor::ops::BoolTensor, + axes: &[usize], + ) -> burn_tensor::ops::BoolTensor { + NdArrayOps::flip(tensor, axes) + } } diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index bd1315a1e8..d464371a7d 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -446,6 +446,13 @@ impl IntTensorOps for NdArray { NdArrayTensor { array } } + fn int_flip( + tensor: burn_tensor::ops::IntTensor, + axes: &[usize], + ) -> burn_tensor::ops::IntTensor { + NdArrayOps::flip(tensor, axes) + } + fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor { NdArrayMathOps::sign_op(tensor) } diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs index a15b0e5ce7..30457db558 100644 --- a/crates/burn-ndarray/src/ops/tensor.rs +++ b/crates/burn-ndarray/src/ops/tensor.rs @@ -494,6 +494,13 @@ impl FloatTensorOps for NdArray { NdArrayTensor { array } } + fn float_flip( + tensor: burn_tensor::ops::FloatTensor, + axes: &[usize], + ) -> burn_tensor::ops::FloatTensor { + NdArrayOps::flip(tensor, axes) + } + fn float_sign(tensor: NdArrayTensor) -> NdArrayTensor { NdArrayMathOps::sign_op(tensor) } diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index ca75a51d9b..568c985228 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -446,6 +446,12 @@ impl TchOps { TchTensor::new(tensor) } + pub fn flip(tensor: TchTensor, axes: &[usize]) -> TchTensor { + let dims = axes.iter().map(|x| *x as i64).collect::>(); + let tensor = tensor.tensor.flip(dims); + TchTensor::new(tensor) + } + pub fn narrow( tensor: TchTensor, dim: usize, diff --git a/crates/burn-tch/src/ops/bool_tensor.rs b/crates/burn-tch/src/ops/bool_tensor.rs index df8d7d2ebf..ddf3ac4ab4 100644 --- a/crates/burn-tch/src/ops/bool_tensor.rs +++ b/crates/burn-tch/src/ops/bool_tensor.rs @@ -139,6 +139,10 @@ impl BoolTensorOps for LibTorch { TchOps::permute(tensor, axes) } + fn bool_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor { + TchOps::flip(tensor, axes) + } + fn bool_argwhere( tensor: as Backend>::BoolTensorPrimitive, ) -> TchTensor { diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index 900b2d7cf6..f0564aa3c4 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -474,6 +474,13 @@ impl IntTensorOps for LibTorch { TchOps::permute(tensor, axes) } + fn int_flip( + tensor: burn_tensor::ops::IntTensor, + axes: &[usize], + ) -> burn_tensor::ops::IntTensor { + TchOps::flip(tensor, axes) + } + fn int_sign( tensor: as Backend>::IntTensorPrimitive, ) -> as Backend>::IntTensorPrimitive { diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs index 87089c4bb0..c373cc232a 100644 --- a/crates/burn-tch/src/ops/tensor.rs +++ b/crates/burn-tch/src/ops/tensor.rs @@ -483,6 +483,13 @@ impl FloatTensorOps for LibTorch { TchOps::permute(tensor, axes) } + fn float_flip( + tensor: burn_tensor::ops::FloatTensor, + axes: &[usize], + ) -> burn_tensor::ops::FloatTensor { + TchOps::flip(tensor, axes) + } + fn float_sign( tensor: as 
Backend>::FloatTensorPrimitive,
     ) -> as Backend>::FloatTensorPrimitive {
diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs
index c38b226e8c..5b10f71550 100644
--- a/crates/burn-tensor/src/tensor/api/base.rs
+++ b/crates/burn-tensor/src/tensor/api/base.rs
@@ -170,6 +170,33 @@ where
         Tensor::new(K::permute(self.primitive, transformed_axes))
     }
 
+    /// Reverse the order of elements in the tensor along the given dimensions.
+    ///
+    /// # Arguments
+    ///
+    /// * `axes` - The dimensions to reverse. Each value must be unique and within the rank of the tensor.
+    ///   Negative values index from the end, so `-1` refers to the last dimension.
+    ///
+    /// # Returns
+    ///
+    /// The tensor with the axes flipped.
+    pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
+        // Convert the axes to usize, resolving negative values without allocating a Vec
+        let mut transformed_axes: [usize; N] = [0; N];
+        for (i, &x) in axes.iter().enumerate() {
+            transformed_axes[i] = if x < 0 {
+                (D as isize + x) as usize
+            } else {
+                x as usize
+            };
+        }
+
+        // Check that the axes are unique and within bounds
+        check!(TensorCheck::flip(D, &transformed_axes));
+
+        Tensor::new(K::flip(self.primitive, &transformed_axes))
+    }
+
     /// Flatten the tensor along a given range of dimensions.
     ///
     /// This function collapses the specified range of dimensions into a single dimension,
@@ -1130,6 +1157,18 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
     /// The tensor with the dimensions permuted.
     fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D>;
 
+    /// Flips the tensor along the given axes.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to flip.
+    /// * `axes` - The axes to flip the tensor along.
+    ///
+    /// # Returns
+    ///
+    /// The tensor with the axes flipped.
+    fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D>;
+
     /// Select tensor elements corresponding for the given ranges.
     ///
     /// # Arguments
@@ -1558,6 +1597,10 @@ impl<B: Backend> BasicOps<B> for Float {
     fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D> {
         B::float_permute(tensor, axes)
     }
+
+    fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
+        B::float_flip(tensor, axes)
+    }
 }
 
 impl<B: Backend> BasicOps<B> for Int {
@@ -1672,6 +1715,10 @@
     fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D> {
         B::int_permute(tensor, axes)
     }
+
+    fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
+        B::int_flip(tensor, axes)
+    }
 }
 
 impl<B: Backend> BasicOps<B> for Bool {
@@ -1786,6 +1833,10 @@
     fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D> {
         B::bool_permute(tensor, axes)
     }
+
+    fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
+        B::bool_flip(tensor, axes)
+    }
 }
 
 /// Trait used for reshape arguments.
diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index e719bbab44..f31b280ef9 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -349,6 +349,34 @@ impl TensorCheck { check } + pub(crate) fn flip(rank: usize, axes: &[usize]) -> Self { + let check = Self::Ok; + + // Check if the axes are within the tensor dimensions + if let Some(axis) = axes.iter().find(|&x| *x >= rank) { + return check.register( + "flip", + TensorError::new("The axes must be smaller than the tensor dimension.").details( + format!("The '{axis}' axis is greater than {rank} dimensions."), + ), + ); + } + + // Check if the axes are unique + let mut dedup = axes.to_vec(); + dedup.sort_unstable(); + dedup.dedup(); + if dedup.len() != axes.len() { + return check.register( + "flip", + TensorError::new("The axes must be unique.") + .details(format!("The axes '{axes:?}' are not unique.")), + ); + } + + check + } + pub(crate) fn matmul( lhs: &Tensor, rhs: &Tensor, diff --git a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs index 0fca91c5ca..69896b3c90 100644 --- a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs @@ -291,6 +291,16 @@ pub trait BoolTensorOps { fn bool_permute(tensor: BoolTensor, axes: [usize; D]) -> BoolTensor; + /// Reverse the order of elements in a tensor along the given axes. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to reverse. + /// * `axes` - The axes to reverse. + /// + /// The tensor with the elements reversed. + fn bool_flip(tensor: BoolTensor, axes: &[usize]) -> BoolTensor; + /// Returns a new tensor with the given dimension narrowed to the given range. /// /// # Arguments diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index f8ee790f04..55ae6b735e 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -985,6 +985,16 @@ pub trait IntTensorOps { /// The tensor with the dimensions permuted. fn int_permute(tensor: IntTensor, axes: [usize; D]) -> IntTensor; + /// Reverse the order of elements in a tensor along the given axes. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to reverse. + /// * `axes` - The axes to reverse. + /// + /// The tensor with the elements reversed. + fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor; + /// Returns a new tensor with the given dimension narrowed to the given range. /// /// # Arguments diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs index d02126e2c9..dc022be5db 100644 --- a/crates/burn-tensor/src/tensor/ops/tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/tensor.rs @@ -447,6 +447,16 @@ pub trait FloatTensorOps { axes: [usize; D], ) -> FloatTensor; + /// Reverse the order of elements in a tensor along the given axes. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to reverse. + /// * `axes` - The axes to reverse. + /// + /// The tensor with the elements reversed. + fn float_flip(tensor: FloatTensor, axes: &[usize]) -> FloatTensor; + /// Reshapes a tensor. /// /// # Arguments diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index 619c212588..5eedccc7cc 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -86,6 +86,7 @@ macro_rules! 
testgen_all { burn_tensor::testgen_any!(); burn_tensor::testgen_all_op!(); burn_tensor::testgen_permute!(); + burn_tensor::testgen_flip!(); burn_tensor::testgen_bool!(); burn_tensor::testgen_argwhere_nonzero!(); burn_tensor::testgen_sign!(); diff --git a/crates/burn-tensor/src/tests/ops/flip.rs b/crates/burn-tensor/src/tests/ops/flip.rs new file mode 100644 index 0000000000..4c88166c49 --- /dev/null +++ b/crates/burn-tensor/src/tests/ops/flip.rs @@ -0,0 +1,105 @@ +#[burn_tensor_testgen::testgen(flip)] +mod tests { + use super::*; + use burn_tensor::backend::Backend; + use burn_tensor::{Data, Device, Int, Shape, Tensor}; + + #[test] + fn normal_int() { + let device = Default::default(); + let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + + let flipped = tensor.clone().flip([0, 2]); + + // from pytorch: + // import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)) + let data_expected = Data::from([ + [[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]], + [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]], + ]); + + assert_eq!(data_expected, flipped.into_data()); + + // Test with no flip + let flipped = tensor.clone().flip([]); + assert_eq!(tensor.into_data(), flipped.into_data()); + } + + #[test] + fn normal_float() { + let device = Default::default(); + let tensor = Tensor::::arange(0..24, &device) + .reshape([2, 3, 4]) + .float(); + + let flipped = tensor.clone().flip([0, 2]); + + // from pytorch: + // import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)).float() + let data_expected = Data::from([ + [ + [15., 14., 13., 12.], + [19., 18., 17., 16.], + [23., 22., 21., 20.], + ], + [[3., 2., 1., 0.], [7., 6., 5., 4.], [11., 10., 9., 8.]], + ]); + + assert_eq!(data_expected, flipped.into_data()); + + // Test with no flip + let flipped = tensor.clone().flip([]); + assert_eq!(tensor.into_data(), flipped.into_data()); + } + + #[test] + fn normal_bool() { + let device = Default::default(); + let tensor = Tensor::::arange(0..24, &device) + .reshape([2, 3, 4]) + .greater_elem(10); + + let flipped = tensor.clone().flip([0, 2]); + + // from pytorch: + // import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)).gt(10) + let data_expected = Data::from([ + [ + [true, true, true, true], + [true, true, true, true], + [true, true, true, true], + ], + [ + [false, false, false, false], + [false, false, false, false], + [true, false, false, false], + ], + ]); + + assert_eq!(data_expected, flipped.into_data()); + + // Test with no flip + let flipped = tensor.clone().flip([]); + assert_eq!(tensor.into_data(), flipped.into_data()); + } + + #[test] + #[should_panic] + fn edge_duplicated_axes() { + let device = Default::default(); + let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + + // Test with a duplicated axis + let _ = tensor.clone().flip([0, 0, 1]); + } + + #[test] + #[should_panic] + fn edge_out_of_bound_axis() { + let device = Default::default(); + let tensor = Tensor::::arange(0..24, &device).reshape([2, 3, 4]); + + // Test with an out of bound axis + let _ = tensor.clone().flip([3, 0, 1]); + } +} diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs index bac8521e27..bb4c4c8ccd 100644 --- a/crates/burn-tensor/src/tests/ops/mod.rs +++ b/crates/burn-tensor/src/tests/ops/mod.rs @@ -19,6 +19,7 @@ mod div; mod erf; mod exp; mod flatten; +mod flip; mod full; mod gather_scatter; mod init;
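-- 

A minimal usage sketch (not part of the patch; assumes some backend `B` is in
scope). `flip` accepts negative axes, which count back from the last dimension:

    let device = Default::default();
    let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    let flipped = tensor.flip([0, -1]);
    // flipped == [[4.0, 3.0], [2.0, 1.0]]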