diff --git a/Cargo.lock b/Cargo.lock index 7f46e08..d5d1e80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -955,8 +955,8 @@ dependencies = [ "faer", "faer-ext", "log", - "matrixcompare", - "matrixcompare-core", + "matrixcompare 0.1.4", + "matrixcompare 0.3.0", "nalgebra", "num-dual", "paste", @@ -982,7 +982,7 @@ dependencies = [ "gemm", "libm", "log", - "matrixcompare", + "matrixcompare 0.3.0", "matrixcompare-core", "nano-gemm", "num-complex", @@ -1582,6 +1582,16 @@ dependencies = [ "libc", ] +[[package]] +name = "matrixcompare" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8f075a9d74616240a7f32afac53bc2a04f4ec93b55b27e6e4add1eed85fffaf" +dependencies = [ + "matrixcompare-core", + "num-traits", +] + [[package]] name = "matrixcompare" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index a250c3d..f541235 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ faer = { version = "0.19.0", default-features = false, features = [ faer-ext = { version = "0.2.0", features = ["nalgebra"] } nalgebra = { version = "0.32.5", features = ["compare"] } num-dual = "0.9.1" -matrixcompare-core = { version = "0.1" } +matrixcompare = { version = "0.1" } # serialization serde = { version = "1.0.203", optional = true } diff --git a/README.md b/README.md index a6b4771..04b7d43 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,12 @@ We recommend you checkout the [docs](https://docs.rs/factrs/latest/factrs/) (WIP ```rust use factrs::prelude::*; +// Assign symbols to variable types +assign_symbols!(X: SO2); + // Make all the values let mut values = Values::new(); + let x = SO2::from_theta(1.0); let y = SO2::from_theta(2.0); values.insert(X(0), SO2::identity()); @@ -28,13 +32,16 @@ values.insert(X(1), SO2::identity()); // Make the factors & insert into graph let mut graph = Graph::new(); let res = PriorResidual::new(x.clone()); -let factor = Factor::new_base(&[X(0)], res); +let factor = FactorBuilder::new1(res, X(0)).build(); graph.add_factor(factor); let res = BetweenResidual::new(y.minus(&x)); let noise = GaussianNoise::from_scalar_sigma(0.1); let robust = Huber::default(); -let factor = Factor::new_full(&[X(0), X(1)], res, noise, robust); +let factor = FactorBuilder::new2(res, X(0), X(1)) + .noise(noise) + .robust(robust) + .build(); graph.add_factor(factor); // Optimize! diff --git a/examples/serde.rs b/examples/serde.rs index eda9b15..02d1c45 100644 --- a/examples/serde.rs +++ b/examples/serde.rs @@ -2,6 +2,7 @@ use factrs::{ containers::{Graph, Values, X}, factors::Factor, noise::GaussianNoise, + prelude::FactorBuilder, residuals::{BetweenResidual, PriorResidual}, robust::{GemanMcClure, L2}, variables::{SE2, SO2}, @@ -26,13 +27,13 @@ fn main() { let prior = PriorResidual::new(x); let bet = BetweenResidual::new(y); - let prior = Factor::new_full( - &[X(0)], - prior, - GaussianNoise::from_scalar_cov(0.1), - GemanMcClure::default(), - ); - let bet = Factor::new_full(&[X(0), X(1)], bet, GaussianNoise::from_scalar_cov(10.0), L2); + let prior = FactorBuilder::new1(prior, X(0)) + .noise(GaussianNoise::from_scalar_cov(0.1)) + .robust(GemanMcClure::default()) + .build(); + let bet = FactorBuilder::new2(bet, X(0), X(1)) + .noise(GaussianNoise::from_scalar_cov(10.0)) + .build(); let mut graph = Graph::new(); graph.add_factor(prior); graph.add_factor(bet); diff --git a/src/containers/factor.rs b/src/containers/factor.rs index 463beb7..86a0913 100644 --- a/src/containers/factor.rs +++ b/src/containers/factor.rs @@ -1,10 +1,11 @@ +use super::{Symbol, TypedSymbol}; use crate::{ - containers::{Symbol, Values}, + containers::{Key, Values}, dtype, - linalg::{AllocatorBuffer, Const, DefaultAllocator, DiffResult, DualAllocator, MatrixBlock}, + linalg::{Const, DiffResult, MatrixBlock}, linear::LinearFactor, noise::{NoiseModel, NoiseModelSafe, UnitNoise}, - residuals::{Residual, ResidualSafe}, + residuals::ResidualSafe, robust::{RobustCostSafe, L2}, }; @@ -15,7 +16,7 @@ use crate::{ /// Factors are the main building block of the factor graph. They are composed /// of four pieces: /// - Keys: The variables that the factor depends on, given by a -/// slice of [Symbols](Symbol). +/// slice of [Keys](Key). /// - Residual: The vector-valued function that computes the /// error of the factor given a set of values, from the /// [residual](crate::residuals) module. @@ -24,103 +25,31 @@ use crate::{ /// - Robust Kernel: The robust kernel weights the error of the /// factor, given by the traits in the [robust](crate::robust) module. /// -/// Constructors are available for a number of default cases including default -/// robust kernel [L2], default noise model [UnitNoise]. Keys and residual are -/// always required. +/// To construct a factor, please see the [FactorBuilder] struct. /// /// During optimization the factor is linearized around a set of values into a /// [LinearFactor]. /// /// ``` /// # use factrs::prelude::*; +/// # assign_symbols!(X: VectorVar3); /// let prior = VectorVar3::new(1.0, 2.0, 3.0); /// let residual = PriorResidual::new(prior); /// let noise = GaussianNoise::<3>::from_diag_sigmas(1e-1, 2e-1, 3e-1); /// let robust = GemanMcClure::default(); -/// let factor = Factor::new_full(&[X(0)], residual, noise, robust); +/// let factor = FactorBuilder::new1(residual, +/// X(0)).noise(noise).robust(robust).build(); /// ``` #[derive(Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Factor { - keys: Vec, + keys: Vec, residual: Box, noise: Box, robust: Box, } impl Factor { - /// Build a new factor from a set of keys and a residual. - /// - /// Keys will be compile-time checked to ensure the size is consistent with - /// the residual. Noise will be set to [UnitNoise] and robust kernel to - /// [L2]. - pub fn new_base( - keys: &[Symbol; NUM_VARS], - residual: R, - ) -> Self - where - R: 'static + Residual, DimOut = Const> + ResidualSafe, - AllocatorBuffer: Sync + Send, - DefaultAllocator: DualAllocator, - UnitNoise: NoiseModelSafe, - { - Self { - keys: keys.to_vec(), - residual: Box::new(residual), - noise: Box::new(UnitNoise::), - robust: Box::new(L2), - } - } - - /// Build a new factor from a set of keys, a residual, and a noise model. - /// - /// Keys and noise will be compile-time checked to ensure the size is - /// consistent with the residual. Robust kernel will be set to [L2]. - pub fn new_noise( - keys: &[Symbol; NUM_VARS], - residual: R, - noise: N, - ) -> Self - where - R: 'static + Residual, DimOut = Const> + ResidualSafe, - N: 'static + NoiseModel> + NoiseModelSafe, - AllocatorBuffer: Sync + Send, - DefaultAllocator: DualAllocator, - { - Self { - keys: keys.to_vec(), - residual: Box::new(residual), - noise: Box::new(noise), - robust: Box::new(L2), - } - } - - /// Build a new factor from a set of keys, a residual, a noise model, and a - /// robust kernel. - /// - /// Keys and noise will be compile-time checked to ensure the size is - /// consistent with the residual. - pub fn new_full( - keys: &[Symbol; NUM_VARS], - residual: R, - noise: N, - robust: C, - ) -> Self - where - R: 'static + Residual, DimOut = Const> + ResidualSafe, - AllocatorBuffer: Sync + Send, - DefaultAllocator: DualAllocator, - N: 'static + NoiseModel> + NoiseModelSafe, - C: 'static + RobustCostSafe, - { - Self { - keys: keys.to_vec(), - residual: Box::new(residual), - noise: Box::new(noise), - robust: Box::new(robust), - } - } - /// Compute the error of the factor given a set of values. pub fn error(&self, values: &Values) -> dtype { let r = self.residual.residual(values, &self.keys); @@ -155,7 +84,7 @@ impl Factor { .iter() .scan(0, |sum, k| { let out = Some(*sum); - *sum += values.get(k).unwrap().dim(); + *sum += values.get_raw(*k).unwrap().dim(); out }) .collect::>(); @@ -165,11 +94,104 @@ impl Factor { } /// Get the keys of the factor. - pub fn keys(&self) -> &[Symbol] { + pub fn keys(&self) -> &[Key] { &self.keys } } +/// Builder for a factor. +/// +/// If the noise model or robust kernel aren't set, they default to [UnitNoise] +/// and [L2] respectively. +pub struct FactorBuilder { + keys: Vec, + residual: Box, + noise: Option>, + robust: Option>, +} + +macro_rules! impl_new_builder { + ($($num:expr, $( ($key:ident, $key_type:ident, $var:ident) ),*);* $(;)?) => {$( + paste::paste! { + #[doc = "Create a new factor with " $num " variable connections, while verifying the key types."] + pub fn [](residual: R, $($key: $key_type),*) -> Self + where + R: crate::residuals::[]> + ResidualSafe + 'static, + $( + $key_type: TypedSymbol, + )* + { + Self { + keys: vec![$( $key.into() ),*], + residual: Box::new(residual), + noise: None, + robust: None, + } + } + + #[doc = "Create a new factor with " $num " variable connections, without verifying the key types."] + pub fn [](residual: R, $($key: $key_type),*) -> Self + where + R: crate::residuals::[]> + ResidualSafe + 'static, + $( + $key_type: Symbol, + )* + { + Self { + keys: vec![$( $key.into() ),*], + residual: Box::new(residual), + noise: None, + robust: None, + } + } + } + )*}; +} + +impl FactorBuilder { + impl_new_builder! { + 1, (key1, K1, V1); + 2, (key1, K1, V1), (key2, K2, V2); + 3, (key1, K1, V1), (key2, K2, V2), (key3, K3, V3); + 4, (key1, K1, V1), (key2, K2, V2), (key3, K3, V3), (key4, K4, V4); + 5, (key1, K1, V1), (key2, K2, V2), (key3, K3, V3), (key4, K4, V4), (key5, K5, V5); + 6, (key1, K1, V1), (key2, K2, V2), (key3, K3, V3), (key4, K4, V4), (key5, K5, V5), (key6, K6, V6); + } + + /// Add a noise model to the factor. + pub fn noise(mut self, noise: N) -> Self + where + N: 'static + NoiseModel> + NoiseModelSafe, + { + self.noise = Some(Box::new(noise)); + self + } + + /// Add a robust kernel to the factor. + pub fn robust(mut self, robust: C) -> Self + where + C: 'static + RobustCostSafe, + { + self.robust = Some(Box::new(robust)); + self + } + + /// Build the factor. + pub fn build(self) -> Factor + where + UnitNoise: NoiseModelSafe, + { + let noise = self.noise.unwrap_or_else(|| Box::new(UnitNoise::)); + let robust = self.robust.unwrap_or_else(|| Box::new(L2)); + Factor { + keys: self.keys.to_vec(), + residual: self.residual, + noise, + robust, + } + } +} + #[cfg(test)] mod tests { @@ -177,7 +199,7 @@ mod tests { use super::*; use crate::{ - containers::X, + assign_symbols, linalg::{Diff, NumericalDiff}, noise::GaussianNoise, residuals::{BetweenResidual, PriorResidual}, @@ -195,6 +217,8 @@ mod tests { #[cfg(feature = "f32")] const TOL: f32 = 1e-3; + assign_symbols!(X: VectorVar3); + #[test] fn linearize_a() { let prior = VectorVar3::new(1.0, 2.0, 3.0); @@ -204,16 +228,19 @@ mod tests { let noise = GaussianNoise::<3>::from_diag_sigmas(1e-1, 2e-1, 3e-1); let robust = GemanMcClure::default(); - let factor = Factor::new_full(&[X(0)], residual, noise, robust); + let factor = FactorBuilder::new1(residual, X(0)) + .noise(noise) + .robust(robust) + .build(); let f = |x: VectorVar3| { let mut values = Values::new(); - values.insert(X(0), x); + values.insert_unchecked(X(0), x); factor.error(&values) }; let mut values = Values::new(); - values.insert(X(0), x.clone()); + values.insert_unchecked(X(0), x.clone()); let linear = factor.linearize(&values); let grad_got = -linear.a.mat().transpose() * linear.b; @@ -234,11 +261,14 @@ mod tests { let noise = GaussianNoise::<3>::from_diag_sigmas(1e-1, 2e-1, 3e-1); let robust = GemanMcClure::default(); - let factor = Factor::new_full(&[X(0), X(1)], residual, noise, robust); + let factor = FactorBuilder::new2(residual, X(0), X(1)) + .noise(noise) + .robust(robust) + .build(); let mut values = Values::new(); - values.insert(X(0), x.clone()); - values.insert(X(1), x); + values.insert_unchecked(X(0), x.clone()); + values.insert_unchecked(X(1), x); let linear = factor.linearize(&values); diff --git a/src/containers/graph.rs b/src/containers/graph.rs index c6445e1..fab3ad2 100644 --- a/src/containers/graph.rs +++ b/src/containers/graph.rs @@ -15,7 +15,8 @@ use crate::{containers::Factor, dtype, linear::LinearGraph}; /// /// ``` /// # use factrs::prelude::*; -/// # let factor = Factor::new_base(&[X(0)], PriorResidual::new(SO2::identity())); +/// # assign_symbols!(X: SO2); +/// # let factor = FactorBuilder::new1(PriorResidual::new(SO2::identity()), X(0)).build(); /// let mut graph = Graph::new(); /// graph.add_factor(factor); /// ``` @@ -30,6 +31,12 @@ impl Graph { Self::default() } + pub fn with_capacity(capacity: usize) -> Self { + Self { + factors: Vec::with_capacity(capacity), + } + } + pub fn add_factor(&mut self, factor: Factor) { self.factors.push(factor); } @@ -63,7 +70,7 @@ impl Graph { let Idx { idx: col, dim: col_dim, - } = order.get(key).unwrap(); + } = order.get(*key).unwrap(); (0..*col_dim).for_each(|j| { indices.push((row + i, col + j)); }); diff --git a/src/containers/mod.rs b/src/containers/mod.rs index 32ac644..1fec0e2 100644 --- a/src/containers/mod.rs +++ b/src/containers/mod.rs @@ -1,7 +1,7 @@ //! Various containers for storing variables, residuals, factors, etc. mod symbol; -pub use symbol::*; +pub use symbol::{DefaultSymbol, Key, Symbol, TypedSymbol}; mod values; pub use values::Values; @@ -13,4 +13,4 @@ mod graph; pub use graph::{Graph, GraphOrder}; mod factor; -pub use factor::Factor; +pub use factor::{Factor, FactorBuilder}; diff --git a/src/containers/order.rs b/src/containers/order.rs index de39feb..1222a7c 100644 --- a/src/containers/order.rs +++ b/src/containers/order.rs @@ -2,7 +2,7 @@ use std::collections::hash_map::Iter as HashMapIter; use ahash::HashMap; -use super::{Symbol, Values}; +use super::{Key, Symbol, Values}; /// Location of a variable in a list /// @@ -21,12 +21,12 @@ pub struct Idx { /// and len of each variable #[derive(Debug, Clone)] pub struct ValuesOrder { - map: HashMap, + map: HashMap, dim: usize, } impl ValuesOrder { - pub fn new(map: HashMap) -> Self { + pub fn new(map: HashMap) -> Self { let dim = map.values().map(|idx| idx.dim).sum(); Self { map, dim } } @@ -37,22 +37,22 @@ impl ValuesOrder { let order = *idx; *idx += val.dim(); Some(( - key.clone(), + *key, Idx { idx: order, dim: val.dim(), }, )) }) - .collect::>(); + .collect::>(); let dim = map.values().map(|idx| idx.dim).sum(); Self { map, dim } } - pub fn get(&self, key: &Symbol) -> Option<&Idx> { - self.map.get(key) + pub fn get(&self, symbol: impl Symbol) -> Option<&Idx> { + self.map.get(&symbol.into()) } pub fn dim(&self) -> usize { @@ -67,7 +67,7 @@ impl ValuesOrder { self.map.is_empty() } - pub fn iter(&self) -> HashMapIter { + pub fn iter(&self) -> HashMapIter { self.map.iter() } } @@ -76,7 +76,8 @@ impl ValuesOrder { mod test { use super::*; use crate::{ - containers::{Values, X}, + containers::Values, + symbols::X, variables::{Variable, VectorVar2, VectorVar3, VectorVar6}, }; @@ -84,9 +85,9 @@ mod test { fn from_values() { // Create some form of values let mut v = Values::new(); - v.insert(X(0), VectorVar2::identity()); - v.insert(X(1), VectorVar6::identity()); - v.insert(X(2), VectorVar3::identity()); + v.insert_unchecked(X(0), VectorVar2::identity()); + v.insert_unchecked(X(1), VectorVar6::identity()); + v.insert_unchecked(X(2), VectorVar3::identity()); // Create an order let order = ValuesOrder::from_values(&v); @@ -94,8 +95,8 @@ mod test { // Verify the order assert_eq!(order.len(), 3); assert_eq!(order.dim(), 11); - assert_eq!(order.get(&X(0)).unwrap().dim, 2); - assert_eq!(order.get(&X(1)).unwrap().dim, 6); - assert_eq!(order.get(&X(2)).unwrap().dim, 3); + assert_eq!(order.get(X(0)).unwrap().dim, 2); + assert_eq!(order.get(X(1)).unwrap().dim, 6); + assert_eq!(order.get(X(2)).unwrap().dim, 3); } } diff --git a/src/containers/symbol.rs b/src/containers/symbol.rs index eb5c6af..9adfc28 100644 --- a/src/containers/symbol.rs +++ b/src/containers/symbol.rs @@ -1,5 +1,10 @@ // Similar to gtsam: https://github.com/borglab/gtsam/blob/develop/gtsam/inference/Symbol.cpp -use std::{fmt, mem::size_of}; +use std::{ + fmt::{self}, + mem::size_of, +}; + +use crate::prelude::VariableUmbrella; // Char is stored in last CHR_BITS // Value is stored in the first IDX_BITS @@ -9,124 +14,116 @@ const IDX_BITS: usize = KEY_BITS - CHR_BITS; const CHR_MASK: u64 = (char::MAX as u64) << IDX_BITS; const IDX_MASK: u64 = !CHR_MASK; +// ------------------------- Symbol Parser ------------------------- // + /// Newtype wrap around u64 /// -/// First bits contain the index, last bits contain the character. -/// Helpers exist (such as [X], [B], [L]) to create new versions. -/// /// In implementation, the u64 is exclusively used, the chr/idx aren't at all. /// If you'd like to use a custom symbol (ie with two chars for multi-robot /// experiments), simply define a new trait that creates the u64 as you desire. -#[derive(Clone, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Eq, Hash, PartialEq, Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Symbol(u64); +pub struct Key(pub u64); -impl Symbol { - pub fn new_raw(key: u64) -> Self { - Symbol(key) - } +impl Symbol for Key {} - pub fn chr(&self) -> char { - ((self.0 & CHR_MASK) >> IDX_BITS) as u8 as char - } +/// This provides a custom conversion two and from a u64 key. +pub trait Symbol: fmt::Debug + Into {} - pub fn idx(&self) -> u64 { - self.0 & IDX_MASK +pub struct DefaultSymbol { + chr: char, + idx: u64, +} + +impl DefaultSymbol { + pub fn new(chr: char, idx: u64) -> Self { + Self { chr, idx } } +} + +impl Symbol for DefaultSymbol {} - pub fn new(c: char, i: u64) -> Self { - Symbol(((c as u64) << IDX_BITS) | (i & IDX_MASK)) +impl From for DefaultSymbol { + fn from(key: Key) -> Self { + let chr = ((key.0 & CHR_MASK) >> IDX_BITS) as u8 as char; + let idx = key.0 & IDX_MASK; + Self { chr, idx } } } -impl fmt::Display for Symbol { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}{}", self.chr(), self.idx()) +impl From for Key { + fn from(sym: DefaultSymbol) -> Key { + Key((sym.chr as u64) << IDX_BITS | sym.idx & IDX_MASK) } } -impl fmt::Debug for Symbol { +impl fmt::Debug for DefaultSymbol { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}{}", self.chr(), self.idx()) + write!(f, "({}, {})", self.chr, self.idx) } } -// ------------------------- Helpers ------------------------- // -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn A(i: u64) -> Symbol { Symbol::new('a', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn B(i: u64) -> Symbol { Symbol::new('b', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn C(i: u64) -> Symbol { Symbol::new('c', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn D(i: u64) -> Symbol { Symbol::new('d', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn E(i: u64) -> Symbol { Symbol::new('e', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn F(i: u64) -> Symbol { Symbol::new('f', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn G(i: u64) -> Symbol { Symbol::new('g', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn H(i: u64) -> Symbol { Symbol::new('h', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn I(i: u64) -> Symbol { Symbol::new('i', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn J(i: u64) -> Symbol { Symbol::new('j', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn K(i: u64) -> Symbol { Symbol::new('k', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn L(i: u64) -> Symbol { Symbol::new('l', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn M(i: u64) -> Symbol { Symbol::new('m', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn N(i: u64) -> Symbol { Symbol::new('n', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn O(i: u64) -> Symbol { Symbol::new('o', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn P(i: u64) -> Symbol { Symbol::new('p', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn Q(i: u64) -> Symbol { Symbol::new('q', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn R(i: u64) -> Symbol { Symbol::new('r', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn S(i: u64) -> Symbol { Symbol::new('s', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn T(i: u64) -> Symbol { Symbol::new('t', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn U(i: u64) -> Symbol { Symbol::new('u', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn V(i: u64) -> Symbol { Symbol::new('v', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn W(i: u64) -> Symbol { Symbol::new('w', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn X(i: u64) -> Symbol { Symbol::new('x', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn Y(i: u64) -> Symbol { Symbol::new('y', i) } -#[rustfmt::skip] -#[allow(non_snake_case)] -pub fn Z(i: u64) -> Symbol { Symbol::new('z', i) } +// ------------------------- Basic Keys ------------------------- // +/* +To figure out +- How to store in Values and still be print when debugging + - Probably still want store as a Symbol for this reason + +Just want to replace the above functions with some sort of typing + +Do the reverse and assign a letter to the variable? +- Not the best in case we have multiple letters for a single variable +*/ + +pub trait TypedSymbol: Symbol {} + +/// Creates and assigns symbols to variables +/// +/// To reduce runtime errors, fact.rs symbols are tagged +/// with the type they will be used with. This macro will create a new symbol +/// and implement all the necessary traits for it to be used as a symbol. +/// ``` +/// use factrs::prelude::*; +/// assign_symbols!(X: SO2; Y: SE2); +/// ``` +#[macro_export] +macro_rules! assign_symbols { + ($($name:ident : $($var:ident),+);* $(;)?) => {$( + assign_symbols!($name); + + $( + impl $crate::containers::TypedSymbol<$var> for $name {} + )* + )*}; + + ($($name:ident),*) => { + $( + #[derive(Clone, Copy)] + pub struct $name(pub u64); + + paste::paste! { + impl From<$name> for $crate::containers::DefaultSymbol { + fn from(key: $name) -> $crate::containers::DefaultSymbol { + let chr = stringify!([<$name:lower>]).chars().next().unwrap(); + let idx = key.0; + $crate::containers::DefaultSymbol::new(chr, idx) + } + } + } + + impl From<$name> for $crate::containers::Key { + fn from(key: $name) -> $crate::containers::Key { + $crate::containers::DefaultSymbol::from(key).into() + } + } + + impl std::fmt::Debug for $name { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + $crate::containers::DefaultSymbol::from(*self).fmt(f) + } + } + + impl $crate::containers::Symbol for $name {} + )* + }; +} diff --git a/src/containers/values.rs b/src/containers/values.rs index a111abe..61e5837 100644 --- a/src/containers/values.rs +++ b/src/containers/values.rs @@ -2,28 +2,34 @@ use std::{collections::hash_map::Entry, default::Default, fmt, iter::IntoIterato use ahash::AHashMap; -use super::Symbol; -use crate::{linear::LinearValues, variables::VariableSafe}; +use super::{Key, Symbol, TypedSymbol}; +use crate::{ + linear::LinearValues, + prelude::{DefaultSymbol, VariableUmbrella}, + variables::VariableSafe, +}; // Since we won't be passing dual numbers through any of this, // we can just use dtype rather than using generics with Numeric /// Structure to hold the Variables used in the graph. /// -/// Values is essentially a thing wrapper around a Hashmap that maps [Symbol] -> +/// Values is essentially a thin wrapper around a Hashmap that maps [Key] -> /// [VariableSafe]. If you'd like to define a custom variable to be used in /// Values, it must implement [Variable](crate::variables::Variable), and then /// will implement [VariableSafe] via a blanket implementation. /// ``` /// # use factrs::prelude::*; +/// # assign_symbols!(X: SO2); /// let x = SO2::from_theta(0.1); /// let mut values = Values::new(); /// values.insert(X(0), x); /// ``` + #[derive(Clone)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Values { - values: AHashMap>, + values: AHashMap>, } impl Values { @@ -41,69 +47,102 @@ impl Values { /// Returns an [std::collections::hash_map::Entry] from the underlying /// HashMap. - pub fn entry(&mut self, key: Symbol) -> Entry> { - self.values.entry(key) + pub fn entry(&mut self, key: impl Symbol) -> Entry> { + self.values.entry(key.into()) } - pub fn insert( - &mut self, - key: Symbol, - value: impl VariableSafe, - ) -> Option> { - self.values.insert(key, Box::new(value)) + pub fn insert(&mut self, symbol: S, value: V) -> Option> + where + S: TypedSymbol, + V: VariableUmbrella, + { + self.values.insert(symbol.into(), Box::new(value)) } - /// Returns a dynamic VariableSafe. - /// - /// If the underlying value is desired, use [Values::get_cast] - pub fn get(&self, key: &Symbol) -> Option<&dyn VariableSafe> { - self.values.get(key).map(|f| f.as_ref()) + /// Unchecked verison of [Values::insert]. + pub fn insert_unchecked(&mut self, symbol: S, value: V) -> Option> + where + S: Symbol, + V: VariableUmbrella, + { + self.values.insert(symbol.into(), Box::new(value)) + } + + pub(crate) fn get_raw(&self, symbol: S) -> Option<&dyn VariableSafe> + where + S: Symbol, + { + self.values.get(&symbol.into()).map(|f| f.as_ref()) } - // TODO: This should be some kind of error - /// Casts and returns the underlying variable. + /// Returns the underlying variable. /// - /// This will return the value if variable is in the graph, and if the cast - /// is successful. Returns None otherwise. + /// This will return the value if variable is in the graph. Requires a typed + /// symbol and as such is guaranteed to return the correct type. Returns + /// None if key isn't found. /// ``` /// # use factrs::prelude::*; + /// # assign_symbols!(X: SO2); /// # let x = SO2::from_theta(0.1); /// # let mut values = Values::new(); /// # values.insert(X(0), x); - /// let x_out = values.get_cast::(&X(0)); + /// let x_out = values.get(X(0)); /// ``` - pub fn get_cast(&self, key: &Symbol) -> Option<&T> { + pub fn get(&self, symbol: S) -> Option<&V> + where + S: TypedSymbol, + V: VariableUmbrella, + { self.values - .get(key) - .and_then(|value| value.downcast_ref::()) + .get(&symbol.into()) + .and_then(|value| value.downcast_ref::()) } - // TODO: Does this still fail if one is missing? - // pub fn get_multiple<'a>(&self, keys: impl IntoIterator) -> - // Option> where - // Symbol: 'a, - // { - // keys.into_iter().map(|key| self.values.get(key)).collect() - // } + /// Returns the underlying variable, not checking the type. + pub fn get_unchecked(&self, symbol: S) -> Option<&V> + where + S: Symbol, + V: VariableUmbrella, + { + self.values + .get(&symbol.into()) + .and_then(|value| value.downcast_ref::()) + } /// Mutable version of [Values::get]. - pub fn get_mut(&mut self, key: &Symbol) -> Option<&mut dyn VariableSafe> { - self.values.get_mut(key).map(|f| f.as_mut()) + pub fn get_mut(&mut self, symbol: S) -> Option<&mut V> + where + S: TypedSymbol, + V: VariableUmbrella, + { + self.values + .get_mut(&symbol.into()) + .and_then(|value| value.downcast_mut::()) } - // TODO: This should be some kind of error - /// Mutable version of [Values::get_cast]. - pub fn get_mut_cast(&mut self, key: &Symbol) -> Option<&mut T> { + /// Mutable version of [Values::get_unchecked]. + pub fn get_unchecked_mut(&mut self, symbol: S) -> Option<&mut V> + where + S: Symbol, + V: VariableUmbrella, + { self.values - .get_mut(key) - .and_then(|value| value.downcast_mut::()) + .get_mut(&symbol.into()) + .and_then(|value| value.downcast_mut::()) } - pub fn remove(&mut self, key: &Symbol) -> Option> { - self.values.remove(key) + pub fn remove(&mut self, symbol: S) -> Option + where + S: TypedSymbol, + V: VariableUmbrella, + { + self.values + .remove(&symbol.into()) + .and_then(|value| value.downcast::().ok()) + .map(|value| *value) } - pub fn iter(&self) -> impl Iterator)> { + pub fn iter(&self) -> impl Iterator)> { self.values.iter() } @@ -112,6 +151,7 @@ impl Values { /// /// ``` /// # use factrs::prelude::*; + /// # assign_symbols!(X: SO2); /// # let mut values = Values::new(); /// # (0..10).for_each(|i| {values.insert(X(0), SO2::identity());} ); /// let mine: Vec<&SO2> = values.filter().collect(); @@ -138,18 +178,20 @@ impl Values { } } +// TODO: Find a way to make this usable on custom symbols (not just +// DefaultSymbol) impl fmt::Display for Values { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if f.alternate() { writeln!(f, "{{")?; for (key, value) in self.values.iter() { - writeln!(f, " {}: {:?},", key, value)?; + writeln!(f, " {:?}: {:?},", DefaultSymbol::from(*key), value)?; } write!(f, "}}") } else { write!(f, "{{")?; for (key, value) in self.values.iter() { - write!(f, "{}: {:?}, ", key, value)?; + write!(f, "{:?}: {:?}, ", DefaultSymbol::from(*key), value)?; } write!(f, "}}") } @@ -163,8 +205,8 @@ impl fmt::Debug for Values { } impl IntoIterator for Values { - type Item = (Symbol, Box); - type IntoIter = std::collections::hash_map::IntoIter>; + type Item = (Key, Box); + type IntoIter = std::collections::hash_map::IntoIter>; fn into_iter(self) -> Self::IntoIter { self.values.into_iter() diff --git a/src/lib.rs b/src/lib.rs index e484803..e5f5cda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,6 +38,9 @@ //! ``` //! use factrs::prelude::*; //! +//! // Assign symbols to variable types +//! assign_symbols!(X: SO2); +//! //! // Make all the values //! let mut values = Values::new(); //! @@ -50,13 +53,16 @@ //! let mut graph = Graph::new(); //! //! let res = PriorResidual::new(x.clone()); -//! let factor = Factor::new_base(&[X(0)], res); +//! let factor = FactorBuilder::new1(res, X(0)).build(); //! graph.add_factor(factor); //! //! let res = BetweenResidual::new(y.minus(&x)); //! let noise = GaussianNoise::from_scalar_sigma(0.1); //! let robust = Huber::default(); -//! let factor = Factor::new_full(&[X(0), X(1)], res, noise, robust); +//! let factor = FactorBuilder::new2(res, X(0), X(1)) +//! .noise(noise) +//! .robust(robust) +//! .build(); //! graph.add_factor(factor); //! //! // Optimize! @@ -83,6 +89,18 @@ pub mod robust; pub mod utils; pub mod variables; +/// Untagged symnbols if `unchecked` API is desired. +/// +/// We strongly recommend using [assign_symbols](crate::assign_symbols) to +/// create and tag symbols with the appropriate types. However, we provide a +/// number of pre-defined symbols if desired. Note this objects can't be tagged +/// due to the orphan rules. +pub mod symbols { + crate::assign_symbols!( + A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z + ); +} + /// Helper module to import common types /// /// This module is meant to be glob imported to make it easier to use the @@ -92,6 +110,7 @@ pub mod variables; /// ``` pub mod prelude { pub use crate::{ + assign_symbols, containers::*, noise::*, optimizers::*, diff --git a/src/linear/factor.rs b/src/linear/factor.rs index 7379985..f4150a3 100644 --- a/src/linear/factor.rs +++ b/src/linear/factor.rs @@ -1,5 +1,5 @@ use crate::{ - containers::Symbol, + containers::Key, dtype, linalg::{MatrixBlock, VectorX}, linear::LinearValues, @@ -11,12 +11,12 @@ use crate::{ /// consists of the relevant keys, a [MatrixBlock] A, and a [VectorX] b. Again, /// this *shouldn't* ever need to be used by hand. pub struct LinearFactor { - pub keys: Vec, + pub keys: Vec, pub a: MatrixBlock, pub b: VectorX, } impl LinearFactor { - pub fn new(keys: Vec, a: MatrixBlock, b: VectorX) -> Self { + pub fn new(keys: Vec, a: MatrixBlock, b: VectorX) -> Self { assert!( keys.len() == a.idx().len(), "Mismatch between keys and matrix blocks in LinearFactor::new" @@ -40,7 +40,9 @@ impl LinearFactor { .map(|(idx, key)| { self.a.mul( idx, - vector.get(key).expect("Missing key in LinearValues::error"), + vector + .get(*key) + .expect("Missing key in LinearValues::error"), ) }) .sum(); diff --git a/src/linear/graph.rs b/src/linear/graph.rs index 5c334ba..0602286 100644 --- a/src/linear/graph.rs +++ b/src/linear/graph.rs @@ -49,7 +49,7 @@ impl LinearGraph { let Idx { idx: col, dim: col_dim, - } = order.get(key).unwrap(); + } = order.get(*key).unwrap(); (0..*col_dim).for_each(|j| { indices.push((row + i, col + j)); }); @@ -121,8 +121,9 @@ mod test { use super::*; use crate::{ - containers::{Idx, X}, + containers::Idx, linalg::{MatrixBlock, MatrixX, VectorX}, + symbols::X, }; #[test] @@ -134,22 +135,26 @@ mod test { let a1 = MatrixX::from_fn(2, 2, |i, j| (i + j) as dtype); let block1 = MatrixBlock::new(a1, vec![0]); let b1 = VectorX::from_fn(2, |i, j| (i + j) as dtype); - graph.add_factor(LinearFactor::new(vec![X(1)], block1.clone(), b1.clone())); + graph.add_factor(LinearFactor::new( + vec![X(1).into()], + block1.clone(), + b1.clone(), + )); let a2 = MatrixX::from_fn(3, 5, |i, j| (i + j) as dtype); let block2 = MatrixBlock::new(a2, vec![0, 2]); let b2 = VectorX::from_fn(3, |_, _| 5.0); graph.add_factor(LinearFactor::new( - vec![X(0), X(2)], + vec![X(0).into(), X(2).into()], block2.clone(), b2.clone(), )); // Make fake ordering let mut map = HashMap::default(); - map.insert(X(0), Idx { idx: 0, dim: 2 }); - map.insert(X(1), Idx { idx: 2, dim: 2 }); - map.insert(X(2), Idx { idx: 4, dim: 3 }); + map.insert(X(0).into(), Idx { idx: 0, dim: 2 }); + map.insert(X(1).into(), Idx { idx: 2, dim: 2 }); + map.insert(X(2).into(), Idx { idx: 4, dim: 3 }); let order = ValuesOrder::new(map); // Compute the residual and jacobian diff --git a/src/linear/values.rs b/src/linear/values.rs index 5cc360a..0ca1307 100644 --- a/src/linear/values.rs +++ b/src/linear/values.rs @@ -1,7 +1,7 @@ use std::collections::hash_map::Iter as HashMapIter; use crate::{ - containers::{Idx, Symbol, Values, ValuesOrder}, + containers::{Idx, Key, Symbol, Values, ValuesOrder}, linalg::{VectorViewX, VectorX}, }; @@ -69,7 +69,7 @@ impl LinearValues { } /// Retrieve a vector from the LinearValues - pub fn get(&self, key: &Symbol) -> Option> { + pub fn get(&self, key: impl Symbol) -> Option> { let idx = self.order.get(key)?; self.get_idx(idx).into() } @@ -84,11 +84,11 @@ impl LinearValues { pub struct Iter<'a> { values: &'a LinearValues, - idx: HashMapIter<'a, Symbol, Idx>, + idx: HashMapIter<'a, Key, Idx>, } impl<'a> Iterator for Iter<'a> { - type Item = (&'a Symbol, VectorViewX<'a>); + type Item = (&'a Key, VectorViewX<'a>); fn next(&mut self) -> Option { let n = self.idx.next()?; @@ -100,17 +100,17 @@ impl<'a> Iterator for Iter<'a> { mod test { use super::*; use crate::{ - containers::X, dtype, + symbols::X, variables::{Variable, VectorVar2, VectorVar3, VectorVar6}, }; fn make_order_vector() -> (ValuesOrder, VectorX) { // Create some form of values let mut v = Values::new(); - v.insert(X(0), VectorVar2::identity()); - v.insert(X(1), VectorVar6::identity()); - v.insert(X(2), VectorVar3::identity()); + v.insert_unchecked(X(0), VectorVar2::identity()); + v.insert_unchecked(X(1), VectorVar6::identity()); + v.insert_unchecked(X(2), VectorVar3::identity()); // Create an order let order = ValuesOrder::from_values(&v); @@ -126,10 +126,10 @@ mod test { let linear_values = LinearValues::from_order_and_vector(order, vector); assert!(linear_values.len() == 3); assert!(linear_values.dim() == 11); - assert!(linear_values.get(&X(0)).unwrap().len() == 2); - assert!(linear_values.get(&X(1)).unwrap().len() == 6); - assert!(linear_values.get(&X(2)).unwrap().len() == 3); - assert!(linear_values.get(&X(3)).is_none()); + assert!(linear_values.get(X(0)).unwrap().len() == 2); + assert!(linear_values.get(X(1)).unwrap().len() == 6); + assert!(linear_values.get(X(2)).unwrap().len() == 3); + assert!(linear_values.get(X(3)).is_none()); } #[test] diff --git a/src/optimizers/mod.rs b/src/optimizers/mod.rs index 35b034f..ddba5c4 100644 --- a/src/optimizers/mod.rs +++ b/src/optimizers/mod.rs @@ -57,14 +57,17 @@ pub use levenberg_marquardt::LevenMarquardt; #[cfg(test)] pub mod test { use faer::assert_matrix_eq; + use nalgebra::{DefaultAllocator, DimNameAdd, DimNameSum, ToTypenum}; use super::*; use crate::{ - containers::{Factor, Graph, Values, X}, + containers::{Graph, Values}, dtype, - linalg::{Const, VectorX}, + linalg::{AllocatorBuffer, Const, DualAllocator, DualVector, VectorX}, noise::{NoiseModelSafe, UnitNoise}, + prelude::FactorBuilder, residuals::{BetweenResidual, PriorResidual, Residual, ResidualSafe}, + symbols::X, variables::VariableUmbrella, }; @@ -79,17 +82,17 @@ pub mod test { let p = T::exp(t.as_view()); let mut values = Values::new(); - values.insert(X(0), T::identity()); + values.insert_unchecked(X(0), T::identity()); let mut graph = Graph::new(); let res = PriorResidual::new(p.clone()); - let factor = Factor::new_base(&[X(0)], res); + let factor = FactorBuilder::new1_unchecked(res, X(0)).build(); graph.add_factor(factor); let mut opt = O::new(graph); values = opt.optimize(values).unwrap(); - let out: &T = values.get_cast(&X(0)).unwrap(); + let out: &T = values.get_unchecked(X(0)).unwrap(); assert_matrix_eq!( out.ominus(&p), VectorX::zeros(T::DIM), @@ -107,6 +110,11 @@ pub mod test { BetweenResidual: ResidualSafe + Residual, DimOut = Const, NumVars = Const<2>>, O: Optimizer + GraphOptimizer, + Const: ToTypenum, + AllocatorBuffer, Const>>: Sync + Send, + DefaultAllocator: DualAllocator, Const>>, + DualVector, Const>>: Copy, + Const: DimNameAdd>, { let t = VectorX::from_fn(T::DIM, |_, i| ((i as dtype) - (T::DIM as dtype)) / 10.0); let p1 = T::exp(t.as_view()); @@ -115,23 +123,23 @@ pub mod test { let p2 = T::exp(t.as_view()); let mut values = Values::new(); - values.insert(X(0), T::identity()); - values.insert(X(1), T::identity()); + values.insert_unchecked(X(0), T::identity()); + values.insert_unchecked(X(1), T::identity()); let mut graph = Graph::new(); let res = PriorResidual::new(p1.clone()); - let factor = Factor::new_base(&[X(0)], res); + let factor = FactorBuilder::new1_unchecked(res, X(0)).build(); graph.add_factor(factor); let diff = p2.minus(&p1); let res = BetweenResidual::new(diff); - let factor = Factor::new_base(&[X(0), X(1)], res); + let factor = FactorBuilder::new2_unchecked(res, X(0), X(1)).build(); graph.add_factor(factor); let mut opt = O::new(graph); values = opt.optimize(values).unwrap(); - let out1: &T = values.get_cast(&X(0)).unwrap(); + let out1: &T = values.get_unchecked(X(0)).unwrap(); assert_matrix_eq!( out1.ominus(&p1), VectorX::zeros(T::DIM), @@ -139,7 +147,7 @@ pub mod test { tol = 1e-6 ); - let out2: &T = values.get_cast(&X(1)).unwrap(); + let out2: &T = values.get_unchecked(X(1)).unwrap(); assert_matrix_eq!( out2.ominus(&p2), VectorX::zeros(T::DIM), diff --git a/src/residuals/between.rs b/src/residuals/between.rs index f0e9662..cb55976 100644 --- a/src/residuals/between.rs +++ b/src/residuals/between.rs @@ -3,7 +3,7 @@ use nalgebra::{DimNameAdd, DimNameSum}; use super::{Residual, Residual2}; #[allow(unused_imports)] use crate::{ - containers::{Symbol, Values}, + containers::{Key, Values}, linalg::{ AllocatorBuffer, Const, @@ -99,11 +99,11 @@ where type DimIn = ::DimIn; type NumVars = Const<2>; - fn residual(&self, values: &Values, keys: &[Symbol]) -> VectorX { + fn residual(&self, values: &Values, keys: &[Key]) -> VectorX { self.residual2_values(values, keys) } - fn residual_jacobian(&self, values: &Values, keys: &[Symbol]) -> DiffResult { + fn residual_jacobian(&self, values: &Values, keys: &[Key]) -> DiffResult { self.residual2_jacobian(values, keys) } } diff --git a/src/residuals/prior.rs b/src/residuals/prior.rs index f6824af..1ec52da 100644 --- a/src/residuals/prior.rs +++ b/src/residuals/prior.rs @@ -1,7 +1,7 @@ use super::{Residual, Residual1}; #[allow(unused_imports)] use crate::{ - containers::{Symbol, Values}, + containers::{Key, Values}, linalg::{ AllocatorBuffer, Const, @@ -88,10 +88,10 @@ where type DimIn = ::DimIn; type DimOut = ::DimOut; type NumVars = Const<1>; - fn residual(&self, values: &Values, keys: &[Symbol]) -> VectorX { + fn residual(&self, values: &Values, keys: &[Key]) -> VectorX { self.residual1_values(values, keys) } - fn residual_jacobian(&self, values: &Values, keys: &[Symbol]) -> DiffResult { + fn residual_jacobian(&self, values: &Values, keys: &[Key]) -> DiffResult { self.residual1_jacobian(values, keys) } } @@ -103,8 +103,8 @@ mod test { use super::*; use crate::{ - containers::X, linalg::{vectorx, DefaultAllocator, Diff, DualAllocator, NumericalDiff}, + symbols::X, variables::{VectorVar3, SE3, SO3}, }; @@ -129,13 +129,15 @@ mod test { let x1 = P::identity(); let mut values = Values::new(); - values.insert(X(0), x1.clone()); - let jac = prior_residual.residual1_jacobian(&values, &[X(0)]).diff; + values.insert_unchecked(X(0), x1.clone()); + let jac = prior_residual + .residual1_jacobian(&values, &[X(0).into()]) + .diff; let f = |v: P| { let mut vals = Values::new(); - vals.insert(X(0), v.clone()); - Residual1::residual1_values(&prior_residual, &vals, &[X(0)]) + vals.insert_unchecked(X(0), v.clone()); + Residual1::residual1_values(&prior_residual, &vals, &[X(0).into()]) }; let jac_n = NumericalDiff::::jacobian_1(f, &x1).diff; diff --git a/src/residuals/traits.rs b/src/residuals/traits.rs index 7639c6e..c0926d5 100644 --- a/src/residuals/traits.rs +++ b/src/residuals/traits.rs @@ -1,7 +1,7 @@ use std::fmt::{Debug, Display}; use crate::{ - containers::{Symbol, Values}, + containers::{Key, Values}, linalg::{Diff, DiffResult, DimName, MatrixX, Numeric, VectorX}, variables::{Variable, VariableUmbrella}, }; @@ -27,9 +27,9 @@ pub trait Residual: Debug + Display { Self::DimOut::USIZE } - fn residual(&self, values: &Values, keys: &[Symbol]) -> VectorX; + fn residual(&self, values: &Values, keys: &[Key]) -> VectorX; - fn residual_jacobian(&self, values: &Values, keys: &[Symbol]) -> DiffResult; + fn residual_jacobian(&self, values: &Values, keys: &[Key]) -> DiffResult; } /// The object safe version of [Residual]. @@ -42,9 +42,9 @@ pub trait ResidualSafe: Debug + Display { fn dim_out(&self) -> usize; - fn residual(&self, values: &Values, keys: &[Symbol]) -> VectorX; + fn residual(&self, values: &Values, keys: &[Key]) -> VectorX; - fn residual_jacobian(&self, values: &Values, keys: &[Symbol]) -> DiffResult; + fn residual_jacobian(&self, values: &Values, keys: &[Key]) -> DiffResult; } impl< @@ -60,11 +60,11 @@ impl< Residual::dim_out(self) } - fn residual(&self, values: &Values, keys: &[Symbol]) -> VectorX { + fn residual(&self, values: &Values, keys: &[Key]) -> VectorX { Residual::residual(self, values, keys) } - fn residual_jacobian(&self, values: &Values, keys: &[Symbol]) -> DiffResult { + fn residual_jacobian(&self, values: &Values, keys: &[Key]) -> DiffResult { Residual::residual_jacobian(self, values, keys) } @@ -105,8 +105,8 @@ macro_rules! residual_maker { /// It is generic over the dtype to allow for differentiable types. fn [](&self, $($name: Alias,)*) -> VectorX; - #[doc=concat!("Wrapper that unpacks variables and calls [", stringify!([]), "](Self::", stringify!([]), ").")] - fn [](&self, values: &Values, keys: &[Symbol]) -> VectorX + #[doc="Wrapper that unpacks and calls [" [] "](Self::" [] ")."] + fn [](&self, values: &Values, keys: &[Key]) -> VectorX where $( Self::$var: 'static, @@ -114,7 +114,7 @@ macro_rules! residual_maker { { // Unwrap everything $( - let $name: &Self::$var = values.get_cast(&keys[$idx]).unwrap_or_else(|| { + let $name: &Self::$var = values.get_unchecked(keys[$idx]).unwrap_or_else(|| { panic!("Key not found in values: {:?} with type {}", keys[$idx], std::any::type_name::()) }); )* @@ -122,8 +122,8 @@ macro_rules! residual_maker { } - #[doc=concat!("Wrapper that unpacks variables and computes jacobians using [", stringify!([]), "](Self::", stringify!([]), ").")] - fn [](&self, values: &Values, keys: &[Symbol]) -> DiffResult + #[doc="Wrapper that unpacks variables and computes jacobians using [" [] "](Self::" [] ")."] + fn [](&self, values: &Values, keys: &[Key]) -> DiffResult where $( Self::$var: 'static, @@ -131,7 +131,7 @@ macro_rules! residual_maker { { // Unwrap everything $( - let $name: &Self::$var = values.get_cast(&keys[$idx]).unwrap_or_else(|| { + let $name: &Self::$var = values.get_unchecked(keys[$idx]).unwrap_or_else(|| { panic!("Key not found in values: {:?} with type {}", keys[$idx], std::any::type_name::()) }); )* diff --git a/src/utils.rs b/src/utils.rs index 014345e..e3f979a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -5,7 +5,8 @@ use std::{ }; use crate::{ - containers::{Factor, Graph, Values, X}, + assign_symbols, + containers::{FactorBuilder, Graph, Values}, dtype, linalg::{Matrix3, Matrix6, Vector3}, noise::GaussianNoise, @@ -13,6 +14,8 @@ use crate::{ variables::*, }; +assign_symbols!(X: SE2, SE3); + /// Load a g2o file /// /// Currently supports only SE2 and SE3 pose graphs. Will autodetect which one @@ -38,7 +41,7 @@ pub fn load_g20(file: &str) -> (Graph, Values) { // Add prior on whatever the first variable is if values.len() == 1 { - let factor = Factor::new_base(&[key.clone()], PriorResidual::new(var.clone())); + let factor = FactorBuilder::new1(PriorResidual::new(var.clone()), key).build(); graph.add_factor(factor); } @@ -64,7 +67,9 @@ pub fn load_g20(file: &str) -> (Graph, Values) { let key2 = X(id_curr); let var = SE2::new(theta, x, y); let noise = GaussianNoise::from_matrix_inf(inf.as_view()); - let factor = Factor::new_noise(&[key1, key2], BetweenResidual::new(var), noise); + let factor = FactorBuilder::new2(BetweenResidual::new(var), key1, key2) + .noise(noise) + .build(); graph.add_factor(factor); } @@ -87,8 +92,9 @@ pub fn load_g20(file: &str) -> (Graph, Values) { if values.len() == 1 { let noise = GaussianNoise::<6>::from_diag_covs(1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4); - let factor = - Factor::new_noise(&[key.clone()], PriorResidual::new(var.clone()), noise); + let factor = FactorBuilder::new1(PriorResidual::new(var.clone()), key) + .noise(noise) + .build(); graph.add_factor(factor); } @@ -141,14 +147,12 @@ pub fn load_g20(file: &str) -> (Graph, Values) { let xyz = Vector3::new(x, y, z); let var = SE3::from_rot_trans(rot, xyz); - // println!("var: {:?}", var); - // println!("2499 {:?}", values.get_cast::(&X(2499)).unwrap()); - // panic!(); - let key1 = X(id_prev); let key2 = X(id_curr); let noise = GaussianNoise::from_matrix_inf(inf.as_view()); - let factor = Factor::new_noise(&[key1, key2], BetweenResidual::new(var), noise); + let factor = FactorBuilder::new2(BetweenResidual::new(var), key1, key2) + .noise(noise) + .build(); graph.add_factor(factor); } diff --git a/src/variables/macros.rs b/src/variables/macros.rs index 1ee3c38..b010b2c 100644 --- a/src/variables/macros.rs +++ b/src/variables/macros.rs @@ -1,4 +1,4 @@ -/// Variable wrapper around [matrixcompare::assert_matrix_eq] +/// Variable wrapper around [assert_matrix_eq](matrixcompare::assert_matrix_eq) /// /// This compares two variables using /// [ominus](crate::variables::Variable::ominus)