Skip to content

Commit

Permalink
Merge pull request #15 from rpl-cmu/dtype-rename
Browse files Browse the repository at this point in the history
Dtype rename
  • Loading branch information
contagon authored Nov 12, 2024
2 parents 00a6c00 + 9f11f78 commit f88873e
Show file tree
Hide file tree
Showing 19 changed files with 418 additions and 490 deletions.
8 changes: 4 additions & 4 deletions src/linalg/dual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ impl<

// TODO: Expand on this instead of including in Variable??
pub trait DualConvert {
type Alias<D: Numeric>;
fn dual_convert<D: Numeric>(other: &Self::Alias<dtype>) -> Self::Alias<D>;
type Alias<T: Numeric>;
fn dual_convert<T: Numeric>(other: &Self::Alias<dtype>) -> Self::Alias<T>;
}

impl<const R: usize, const C: usize, T: Numeric> DualConvert for Matrix<R, C, T> {
type Alias<D: Numeric> = Matrix<R, C, D>;
fn dual_convert<D: Numeric>(other: &Self::Alias<dtype>) -> Self::Alias<D> {
type Alias<TT: Numeric> = Matrix<R, C, TT>;
fn dual_convert<TT: Numeric>(other: &Self::Alias<dtype>) -> Self::Alias<TT> {
other.map(|x| x.into())
}
}
12 changes: 5 additions & 7 deletions src/linalg/forward_prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ use paste::paste;

use super::{
dual::{DualAllocator, DualVector},
AllocatorBuffer,
Diff,
MatrixDim,
AllocatorBuffer, Diff, MatrixDim,
};
use crate::{
dtype,
Expand All @@ -29,7 +27,7 @@ use crate::{
/// variables::SO2,
/// };
///
/// fn f<D: Numeric>(x: SO2<D>, y: SO2<D>) -> VectorX<D> {
/// fn f<T: Numeric>(x: SO2<T>, y: SO2<T>) -> VectorX<T> {
/// x.ominus(&y)
/// }
///
Expand All @@ -48,12 +46,12 @@ macro_rules! forward_maker {
($num:expr, $( ($name:ident: $var:ident) ),*) => {
paste! {
#[allow(unused_assignments)]
fn [<jacobian_ $num>]<$( $var: Variable<Alias<dtype> = $var>, )* F: Fn($($var::Alias<Self::D>,)*) -> VectorX<Self::D>>
fn [<jacobian_ $num>]<$( $var: Variable<Alias<dtype> = $var>, )* F: Fn($($var::Alias<Self::T>,)*) -> VectorX<Self::T>>
(f: F, $($name: &$var,)*) -> DiffResult<VectorX, MatrixX>{
// Prepare variables
let mut curr_dim = 0;
$(
let $name: $var::Alias<Self::D> = $var::dual($name, curr_dim);
let $name: $var::Alias<Self::T> = $var::dual($name, curr_dim);
curr_dim += $name.dim();
)*

Expand Down Expand Up @@ -85,7 +83,7 @@ where
DefaultAllocator: DualAllocator<N>,
DualVector<N>: Copy,
{
type D = DualVector<N>;
type T = DualVector<N>;

forward_maker!(1, (v1: V1));
forward_maker!(2, (v1: V1), (v2: V2));
Expand Down
8 changes: 4 additions & 4 deletions src/linalg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ pub struct DiffResult<V, G> {
macro_rules! fn_maker {
(grad, $num:expr, $( ($name:ident: $var:ident) ),*) => {
paste! {
fn [<gradient_ $num>]<$( $var: Variable<Alias<dtype> = $var>, )* F: Fn($($var::Alias<Self::D>,)*) -> Self::D>
fn [<gradient_ $num>]<$( $var: Variable<Alias<dtype> = $var>, )* F: Fn($($var::Alias<Self::T>,)*) -> Self::T>
(f: F, $($name: &$var,)*) -> DiffResult<dtype, VectorX>{
let f_wrapped = |$($name: $var::Alias<Self::D>,)*| vectorx![f($($name.clone(),)*)];
let f_wrapped = |$($name: $var::Alias<Self::T>,)*| vectorx![f($($name.clone(),)*)];
let DiffResult { value, diff } = Self::[<jacobian_ $num>](f_wrapped, $($name,)*);
let diff = VectorX::from_iterator(diff.len(), diff.iter().cloned());
DiffResult { value: value[0], diff }
Expand All @@ -85,7 +85,7 @@ macro_rules! fn_maker {

(jac, $num:expr, $( ($name:ident: $var:ident) ),*) => {
paste! {
fn [<jacobian_ $num>]<$( $var: Variable<Alias<$crate::dtype>=$var>, )* F: Fn($($var::Alias<Self::D>,)*) -> VectorX<Self::D>>
fn [<jacobian_ $num>]<$( $var: Variable<Alias<$crate::dtype>=$var>, )* F: Fn($($var::Alias<Self::T>,)*) -> VectorX<Self::T>>
(f: F, $($name: &$var,)*) -> DiffResult<VectorX, MatrixX>;
}
};
Expand All @@ -101,7 +101,7 @@ macro_rules! fn_maker {
/// recommend [ForwardProp] which functions using dual numbers.
pub trait Diff {
/// The dtype of the variables
type D: Numeric;
type T: Numeric;

fn_maker!(grad, 1, (v1: V1));
fn_maker!(grad, 2, (v1: V1), (v2: V2));
Expand Down
142 changes: 71 additions & 71 deletions src/linalg/nalgebra_wrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,100 +13,100 @@ pub type AllocatorBuffer<N> = <DefaultAllocator as Allocator<N>>::Buffer<dtype>;

// ------------------------- Vector/Matrix Aliases ------------------------- //
// Vectors
pub type Vector<const N: usize, D = dtype> = na::SVector<D, N>;
pub type VectorX<D = dtype> = na::DVector<D>;
pub type Vector1<D = dtype> = na::SVector<D, 1>;
pub type Vector2<D = dtype> = na::SVector<D, 2>;
pub type Vector3<D = dtype> = na::SVector<D, 3>;
pub type Vector4<D = dtype> = na::SVector<D, 4>;
pub type Vector5<D = dtype> = na::SVector<D, 5>;
pub type Vector6<D = dtype> = na::SVector<D, 6>;
pub type Vector<const N: usize, T = dtype> = na::SVector<T, N>;
pub type VectorX<T = dtype> = na::DVector<T>;
pub type Vector1<T = dtype> = na::SVector<T, 1>;
pub type Vector2<T = dtype> = na::SVector<T, 2>;
pub type Vector3<T = dtype> = na::SVector<T, 3>;
pub type Vector4<T = dtype> = na::SVector<T, 4>;
pub type Vector5<T = dtype> = na::SVector<T, 5>;
pub type Vector6<T = dtype> = na::SVector<T, 6>;

// Matrices
// square
pub type MatrixX<D = dtype> = na::DMatrix<D>;
pub type Matrix1<D = dtype> = na::Matrix1<D>;
pub type Matrix2<D = dtype> = na::Matrix2<D>;
pub type Matrix3<D = dtype> = na::Matrix3<D>;
pub type Matrix4<D = dtype> = na::Matrix4<D>;
pub type Matrix5<D = dtype> = na::Matrix5<D>;
pub type Matrix6<D = dtype> = na::Matrix6<D>;
pub type MatrixX<T = dtype> = na::DMatrix<T>;
pub type Matrix1<T = dtype> = na::Matrix1<T>;
pub type Matrix2<T = dtype> = na::Matrix2<T>;
pub type Matrix3<T = dtype> = na::Matrix3<T>;
pub type Matrix4<T = dtype> = na::Matrix4<T>;
pub type Matrix5<T = dtype> = na::Matrix5<T>;
pub type Matrix6<T = dtype> = na::Matrix6<T>;

// row
pub type Matrix1xX<D = dtype> = na::Matrix1xX<D>;
pub type Matrix1x2<D = dtype> = na::Matrix1x2<D>;
pub type Matrix1x3<D = dtype> = na::Matrix1x3<D>;
pub type Matrix1x4<D = dtype> = na::Matrix1x4<D>;
pub type Matrix1x5<D = dtype> = na::Matrix1x5<D>;
pub type Matrix1x6<D = dtype> = na::Matrix1x6<D>;
pub type Matrix1xX<T = dtype> = na::Matrix1xX<T>;
pub type Matrix1x2<T = dtype> = na::Matrix1x2<T>;
pub type Matrix1x3<T = dtype> = na::Matrix1x3<T>;
pub type Matrix1x4<T = dtype> = na::Matrix1x4<T>;
pub type Matrix1x5<T = dtype> = na::Matrix1x5<T>;
pub type Matrix1x6<T = dtype> = na::Matrix1x6<T>;

// two rows
pub type Matrix2xX<D = dtype> = na::Matrix2xX<D>;
pub type Matrix2x3<D = dtype> = na::Matrix2x3<D>;
pub type Matrix2x4<D = dtype> = na::Matrix2x4<D>;
pub type Matrix2x5<D = dtype> = na::Matrix2x5<D>;
pub type Matrix2x6<D = dtype> = na::Matrix2x6<D>;
pub type Matrix2xX<T = dtype> = na::Matrix2xX<T>;
pub type Matrix2x3<T = dtype> = na::Matrix2x3<T>;
pub type Matrix2x4<T = dtype> = na::Matrix2x4<T>;
pub type Matrix2x5<T = dtype> = na::Matrix2x5<T>;
pub type Matrix2x6<T = dtype> = na::Matrix2x6<T>;

// three rows
pub type Matrix3xX<D = dtype> = na::Matrix3xX<D>;
pub type Matrix3x2<D = dtype> = na::Matrix3x2<D>;
pub type Matrix3x4<D = dtype> = na::Matrix3x4<D>;
pub type Matrix3x5<D = dtype> = na::Matrix3x5<D>;
pub type Matrix3x6<D = dtype> = na::Matrix3x6<D>;
pub type Matrix3xX<T = dtype> = na::Matrix3xX<T>;
pub type Matrix3x2<T = dtype> = na::Matrix3x2<T>;
pub type Matrix3x4<T = dtype> = na::Matrix3x4<T>;
pub type Matrix3x5<T = dtype> = na::Matrix3x5<T>;
pub type Matrix3x6<T = dtype> = na::Matrix3x6<T>;

// four rows
pub type Matrix4xX<D = dtype> = na::Matrix4xX<D>;
pub type Matrix4x2<D = dtype> = na::Matrix4x2<D>;
pub type Matrix4x3<D = dtype> = na::Matrix4x3<D>;
pub type Matrix4x5<D = dtype> = na::Matrix4x5<D>;
pub type Matrix4x6<D = dtype> = na::Matrix4x6<D>;
pub type Matrix4xX<T = dtype> = na::Matrix4xX<T>;
pub type Matrix4x2<T = dtype> = na::Matrix4x2<T>;
pub type Matrix4x3<T = dtype> = na::Matrix4x3<T>;
pub type Matrix4x5<T = dtype> = na::Matrix4x5<T>;
pub type Matrix4x6<T = dtype> = na::Matrix4x6<T>;

// five rows
pub type Matrix5xX<D = dtype> = na::Matrix5xX<D>;
pub type Matrix5x2<D = dtype> = na::Matrix5x2<D>;
pub type Matrix5x3<D = dtype> = na::Matrix5x3<D>;
pub type Matrix5x4<D = dtype> = na::Matrix5x4<D>;
pub type Matrix5x6<D = dtype> = na::Matrix5x6<D>;
pub type Matrix5xX<T = dtype> = na::Matrix5xX<T>;
pub type Matrix5x2<T = dtype> = na::Matrix5x2<T>;
pub type Matrix5x3<T = dtype> = na::Matrix5x3<T>;
pub type Matrix5x4<T = dtype> = na::Matrix5x4<T>;
pub type Matrix5x6<T = dtype> = na::Matrix5x6<T>;

// six rows
pub type Matrix6xX<D = dtype> = na::Matrix6xX<D>;
pub type Matrix6x2<D = dtype> = na::Matrix6x2<D>;
pub type Matrix6x3<D = dtype> = na::Matrix6x3<D>;
pub type Matrix6x4<D = dtype> = na::Matrix6x4<D>;
pub type Matrix6x5<D = dtype> = na::Matrix6x5<D>;
pub type Matrix6xX<T = dtype> = na::Matrix6xX<T>;
pub type Matrix6x2<T = dtype> = na::Matrix6x2<T>;
pub type Matrix6x3<T = dtype> = na::Matrix6x3<T>;
pub type Matrix6x4<T = dtype> = na::Matrix6x4<T>;
pub type Matrix6x5<T = dtype> = na::Matrix6x5<T>;

// dynamic rows
pub type MatrixXx2<D = dtype> = na::MatrixXx2<D>;
pub type MatrixXx3<D = dtype> = na::MatrixXx3<D>;
pub type MatrixXx4<D = dtype> = na::MatrixXx4<D>;
pub type MatrixXx5<D = dtype> = na::MatrixXx5<D>;
pub type MatrixXx6<D = dtype> = na::MatrixXx6<D>;
pub type MatrixXxN<const N: usize, D = dtype> =
na::Matrix<D, Dyn, Const<N>, na::VecStorage<D, Dyn, Const<N>>>;
pub type MatrixXx2<T = dtype> = na::MatrixXx2<T>;
pub type MatrixXx3<T = dtype> = na::MatrixXx3<T>;
pub type MatrixXx4<T = dtype> = na::MatrixXx4<T>;
pub type MatrixXx5<T = dtype> = na::MatrixXx5<T>;
pub type MatrixXx6<T = dtype> = na::MatrixXx6<T>;
pub type MatrixXxN<const N: usize, T = dtype> =
na::Matrix<T, Dyn, Const<N>, na::VecStorage<T, Dyn, Const<N>>>;

// Views - aka references of matrices
pub type MatrixViewX<'a, D = dtype> = na::MatrixView<'a, D, Dyn, Dyn>;
pub type MatrixViewX<'a, T = dtype> = na::MatrixView<'a, T, Dyn, Dyn>;

pub type Matrix<const R: usize, const C: usize = 1, D = dtype> = na::Matrix<
D,
pub type Matrix<const R: usize, const C: usize = 1, T = dtype> = na::Matrix<
T,
Const<R>,
Const<C>,
<na::DefaultAllocator as Allocator<Const<R>, Const<C>>>::Buffer<D>,
<na::DefaultAllocator as Allocator<Const<R>, Const<C>>>::Buffer<T>,
>;
pub type MatrixView<'a, const R: usize, const C: usize = 1, D = dtype> =
na::MatrixView<'a, D, Const<R>, Const<C>>;
pub type MatrixView<'a, const R: usize, const C: usize = 1, T = dtype> =
na::MatrixView<'a, T, Const<R>, Const<C>>;

pub type VectorView<'a, const N: usize, D = dtype> = na::VectorView<'a, D, Const<N>>;
pub type VectorViewX<'a, D = dtype> = na::VectorView<'a, D, Dyn>;
pub type VectorView1<'a, D = dtype> = na::VectorView<'a, D, Const<1>>;
pub type VectorView2<'a, D = dtype> = na::VectorView<'a, D, Const<2>>;
pub type VectorView3<'a, D = dtype> = na::VectorView<'a, D, Const<3>>;
pub type VectorView4<'a, D = dtype> = na::VectorView<'a, D, Const<4>>;
pub type VectorView5<'a, D = dtype> = na::VectorView<'a, D, Const<5>>;
pub type VectorView6<'a, D = dtype> = na::VectorView<'a, D, Const<6>>;
pub type VectorView<'a, const N: usize, T = dtype> = na::VectorView<'a, T, Const<N>>;
pub type VectorViewX<'a, T = dtype> = na::VectorView<'a, T, Dyn>;
pub type VectorView1<'a, T = dtype> = na::VectorView<'a, T, Const<1>>;
pub type VectorView2<'a, T = dtype> = na::VectorView<'a, T, Const<2>>;
pub type VectorView3<'a, T = dtype> = na::VectorView<'a, T, Const<3>>;
pub type VectorView4<'a, T = dtype> = na::VectorView<'a, T, Const<4>>;
pub type VectorView5<'a, T = dtype> = na::VectorView<'a, T, Const<5>>;
pub type VectorView6<'a, T = dtype> = na::VectorView<'a, T, Const<6>>;

// Generic, taking in sizes with Const
pub type VectorDim<N, D = dtype> = OVector<D, N>;
pub type MatrixDim<R, C = Const<1>, D = dtype> =
na::Matrix<D, R, C, <na::DefaultAllocator as Allocator<R, C>>::Buffer<D>>;
pub type MatrixViewDim<'a, R, C = Const<1>, D = dtype> = na::MatrixView<'a, D, R, C>;
pub type VectorDim<N, T = dtype> = OVector<T, N>;
pub type MatrixDim<R, C = Const<1>, T = dtype> =
na::Matrix<T, R, C, <na::DefaultAllocator as Allocator<R, C>>::Buffer<T>>;
pub type MatrixViewDim<'a, R, C = Const<1>, T = dtype> = na::MatrixView<'a, T, R, C>;
2 changes: 1 addition & 1 deletion src/linalg/numerical_diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ macro_rules! numerical_maker {
}

impl<const PWR: i32> Diff for NumericalDiff<PWR> {
type D = dtype;
type T = dtype;

numerical_maker!(1, (0, v1, V1));
numerical_maker!(2, (0, v1, V1), (1, v2, V2));
Expand Down
4 changes: 2 additions & 2 deletions src/residuals/between.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ where
type DimOut = P::Dim;
type DimIn = DimNameSum<P::Dim, P::Dim>;

fn residual2<D: Numeric>(&self, v1: P::Alias<D>, v2: P::Alias<D>) -> VectorX<D> {
let delta = P::dual_convert::<D>(&self.delta);
fn residual2<T: Numeric>(&self, v1: P::Alias<T>, v2: P::Alias<T>) -> VectorX<T> {
let delta = P::dual_convert::<T>(&self.delta);
v1.compose(&delta).ominus(&v2)
}
}
Expand Down
Loading

0 comments on commit f88873e

Please sign in to comment.