From cadf65c5f35b90095de8caf236ee9b899119acf3 Mon Sep 17 00:00:00 2001 From: Thiago Machado Date: Tue, 6 Feb 2024 18:29:01 -0500 Subject: [PATCH] stateful has an implicit sequence of 1 --- dfdx/src/nn/layers/mamba_minimal.rs | 86 ++++++++++++----------------- 1 file changed, 35 insertions(+), 51 deletions(-) diff --git a/dfdx/src/nn/layers/mamba_minimal.rs b/dfdx/src/nn/layers/mamba_minimal.rs index 2b60b2b1..59e81fa3 100644 --- a/dfdx/src/nn/layers/mamba_minimal.rs +++ b/dfdx/src/nn/layers/mamba_minimal.rs @@ -810,7 +810,7 @@ pub mod stateful { T: Tape, > Module<( - Tensor<(Batch, C1, DModel), E, D, T>, + Tensor<(Batch, DModel), E, D, T>, MambaStateCache, )> for MambaBlock where @@ -842,7 +842,7 @@ pub mod stateful { ): dfdx_core::tensor_ops::TryConcatShapeAlong, Output = (Batch, DInner, DConv)>, { type Output = ( - Tensor<(Batch, C1, DModel), E, D, T>, + Tensor<(Batch, DModel), E, D, T>, MambaStateCache, ); @@ -850,27 +850,26 @@ pub mod stateful { fn try_forward( &self, x: ( - Tensor<(Batch, C1, DModel), E, D, T>, + Tensor<(Batch, DModel), E, D, T>, MambaStateCache, ), ) -> Result { let (x, mut cache) = x; - // let (batch, _d_model) = *x.shape(); let (batch, d_inner, d_conv) = *cache.conv_state.shape(); // layer 1 (in_proj) let (xs, res): ( - Tensor<(Batch, C1, DInner), _, _, _>, - Tensor<(Batch, C1, DInner), _, _, _>, + Tensor<(Batch, DInner), _, _, _>, + Tensor<(Batch, DInner), _, _, _>, ) = { // projects the input DModel into 2*DInner - let xs_and_res: Tensor<(Batch, C1, >::Output), _, _, _> = + let xs_and_res: Tensor<(Batch, >::Output), _, _, _> = self.in_proj.try_forward(x)?; // splits xs_and_res into (xs, res) let (xs, res, _tape) = - xs_and_res.try_split_tensor_along(Axis::<2>, d_inner, d_inner)?; + xs_and_res.try_split_tensor_along(Axis::<1>, d_inner, d_inner)?; (xs, res) }; @@ -893,12 +892,11 @@ pub mod stateful { )?; // then concat with the xs as the last column (by the right side) let xs: Tensor<(Batch, DInner, C1), _, _, _> = - xs.try_permute::<_, Axes3<0, 2, 1>>()?; - // let xs = xs.try_reshape_like(&(batch, d_inner, Const::<1>))?; + xs.try_reshape_like(&(batch, d_inner, Const::<1>))?; (conv_state, xs).try_concat_tensor_along(Axis::<2>)? }; - let xs: Tensor<(Batch, C1, DInner), E, _, _> = { + let xs: Tensor<(Batch, DInner), E, _, _> = { let conv1d = self .conv1d .weight @@ -913,9 +911,7 @@ pub mod stateful { let xs = self.conv1d_bias.try_forward(xs)?; // activation - let xs = xs.try_silu()?; - - xs.try_reshape_like(&(batch, Const::<1>, d_inner))? + xs.try_silu()? }; let (ss, cache_ssm_state) = ss_step::( @@ -929,7 +925,7 @@ pub mod stateful { )?; let ys = ss.try_mul(res.try_silu()?)?; - let y: Tensor<(Batch, C1, DModel), _, _, _> = self.out_proj.try_forward(ys)?; + let y: Tensor<(Batch, DModel), _, _, _> = self.out_proj.try_forward(ys)?; cache.ssm_state = cache_ssm_state; @@ -957,13 +953,13 @@ pub mod stateful { // a: Tensor<(DInner, DState), E, D, T>, d: Tensor<(DInner,), E, D, T>, - u: Tensor<(Batch, C1, DInner), E, D, T>, + u: Tensor<(Batch, DInner), E, D, T>, x_proj: &MatMul>::Output>>::Output, E, D>, dt_proj: &Linear, ssm_state_cache: Tensor<(Batch, DInner, DState), E, D, T>, ) -> Result< ( - Tensor<(Batch, C1, DInner), E, D, T>, + Tensor<(Batch, DInner), E, D, T>, Tensor<(Batch, DInner, DState), E, D, T>, ), dfdx::tensor::Error, @@ -987,25 +983,25 @@ pub mod stateful { // this is input independent (see Section 3.5.2 "Interpretation of A" form the Mamba paper for why A isn't selective) let a: Tensor<(DInner, DState), _, _, _> = a.try_exp()?.try_negate()?; - // (Batch, 1, DtRank + DState * 2) - let x_dbl: Tensor<(Batch, C1, _), _, _, _> = x_proj.try_forward(u.retaped::())?; + // (Batch, DtRank + DState * 2) + let x_dbl: Tensor<(Batch, _), _, _, _> = x_proj.try_forward(u.retaped::())?; // ∆ (part 1/2) // ∆ is input-dependent - let (delta, x_dbl_tail, _tape): (Tensor<(Batch, C1, DtRank), _, _, _>, _, _) = - x_dbl.try_split_tensor_along(Axis::<2>, dt_rank, d_state * Const::<2>)?; + let (delta, x_dbl_tail, _tape): (Tensor<(Batch, DtRank), _, _, _>, _, _) = + x_dbl.try_split_tensor_along(Axis::<1>, dt_rank, d_state * Const::<2>)?; // B and C // B and C are input-dependent let (b, c, _tape): ( - Tensor<(Batch, C1, DState), _, _, _>, - Tensor<(Batch, C1, DState), _, _, _>, + Tensor<(Batch, DState), _, _, _>, + Tensor<(Batch, DState), _, _, _>, _, - ) = x_dbl_tail.try_split_tensor_along(Axis::<2>, d_state, d_state)?; + ) = x_dbl_tail.try_split_tensor_along(Axis::<1>, d_state, d_state)?; // ∆ (part 2/2) // ∆ is input-dependent - let delta: Tensor<(Batch, C1, DInner), _, _, _> = { + let delta: Tensor<(Batch, DInner), _, _, _> = { // note: don't add dt_proj bias let delta = delta.try_matmul( dt_proj @@ -1021,22 +1017,14 @@ pub mod stateful { dt_proj .bias .retaped::() - .try_broadcast_like(&(batch, Const::<1>, d_inner))?, + .try_broadcast_like(&(batch, d_inner))?, )? .try_exp()? .try_add(one)?) .try_ln()? }; - selective_scan_step::( - delta.try_permute::<_, Axes3<0, 2, 1>>()?, - a, - b, - c.try_permute::<_, Axes3<1, 0, 2>>()?, - d, - u, - ssm_state_cache, - ) + selective_scan_step::(delta, a, b, c, d, u, ssm_state_cache) } // Selective Scan. @@ -1057,16 +1045,16 @@ pub mod stateful { D: Device, T: Tape, >( - delta: Tensor<(Batch, DInner, C1), E, D, T>, + delta: Tensor<(Batch, DInner), E, D, T>, a: Tensor<(DInner, DState), E, D, T>, - b: Tensor<(Batch, C1, DState), E, D, T>, - c: Tensor<(C1, Batch, DState), E, D, T>, + b: Tensor<(Batch, DState), E, D, T>, + c: Tensor<(Batch, DState), E, D, T>, d: Tensor<(DInner,), E, D, T>, - u: Tensor<(Batch, C1, DInner), E, D, T>, + u: Tensor<(Batch, DInner), E, D, T>, mut ssm_state_cache: Tensor<(Batch, DInner, DState), E, D, T>, ) -> Result< ( - Tensor<(Batch, C1, DInner), E, D, T>, + Tensor<(Batch, DInner), E, D, T>, Tensor<(Batch, DInner, DState), E, D, T>, ), dfdx::tensor::Error, @@ -1078,15 +1066,15 @@ pub mod stateful { // - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: // "A is the more important term and the performance doesn't change much with the simplification on B" let (delta_a, delta_bu): ( - Tensor<(Batch, DInner, C1, DState), _, _, _>, - Tensor<(Batch, DInner, C1, DState), _, _, _>, + Tensor<(Batch, DInner, DState), _, _, _>, + Tensor<(Batch, DInner, DState), _, _, _>, ) = { - let target_shape = (batch, d_inner, Const::<1>, d_state); + let target_shape = (batch, d_inner, d_state); let delta_broadcasted = delta.try_broadcast_like(&target_shape)?; let a = a.try_broadcast_like(&target_shape)?; - let delta_a: Tensor<(Batch, DInner, C1, DState), _, _, _> = + let delta_a: Tensor<(Batch, DInner, DState), _, _, _> = delta_broadcasted.retaped::().try_mul(a)?.try_exp()?; let b = b.try_broadcast_like(&target_shape)?; @@ -1106,13 +1094,9 @@ pub mod stateful { let y = ssm_state_cache .retaped::() - .try_matmul(c.try_permute::<_, Axes3<1, 2, 0>>()?)?; - let du = d - .try_broadcast_like(&(batch, Const::<1>, d_inner))? - .try_mul(u)?; - let y = y - .try_reshape_like(&(batch, Const::<1>, d_inner))? - .try_add(du)?; + .try_matmul(c.try_reshape_like(&(batch, d_state, Const::<1>))?)?; + let du = d.try_broadcast_like(&(batch, d_inner))?.try_mul(u)?; + let y = y.try_reshape_like(&(batch, d_inner))?.try_add(du)?; Ok((y, ssm_state_cache)) }