diff --git a/Cargo.lock b/Cargo.lock
index 7f46e08..d5d1e80 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -955,8 +955,8 @@ dependencies = [
"faer",
"faer-ext",
"log",
- "matrixcompare",
- "matrixcompare-core",
+ "matrixcompare 0.1.4",
+ "matrixcompare 0.3.0",
"nalgebra",
"num-dual",
"paste",
@@ -982,7 +982,7 @@ dependencies = [
"gemm",
"libm",
"log",
- "matrixcompare",
+ "matrixcompare 0.3.0",
"matrixcompare-core",
"nano-gemm",
"num-complex",
@@ -1582,6 +1582,16 @@ dependencies = [
"libc",
]
+[[package]]
+name = "matrixcompare"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f8f075a9d74616240a7f32afac53bc2a04f4ec93b55b27e6e4add1eed85fffaf"
+dependencies = [
+ "matrixcompare-core",
+ "num-traits",
+]
+
[[package]]
name = "matrixcompare"
version = "0.3.0"
diff --git a/Cargo.toml b/Cargo.toml
index a250c3d..f541235 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,7 +23,7 @@ faer = { version = "0.19.0", default-features = false, features = [
faer-ext = { version = "0.2.0", features = ["nalgebra"] }
nalgebra = { version = "0.32.5", features = ["compare"] }
num-dual = "0.9.1"
-matrixcompare-core = { version = "0.1" }
+matrixcompare = { version = "0.1" }
# serialization
serde = { version = "1.0.203", optional = true }
diff --git a/README.md b/README.md
index a6b4771..04b7d43 100644
--- a/README.md
+++ b/README.md
@@ -18,8 +18,12 @@ We recommend you checkout the [docs](https://docs.rs/factrs/latest/factrs/) (WIP
```rust
use factrs::prelude::*;
+// Assign symbols to variable types
+assign_symbols!(X: SO2);
+
// Make all the values
let mut values = Values::new();
+
let x = SO2::from_theta(1.0);
let y = SO2::from_theta(2.0);
values.insert(X(0), SO2::identity());
@@ -28,13 +32,16 @@ values.insert(X(1), SO2::identity());
// Make the factors & insert into graph
let mut graph = Graph::new();
let res = PriorResidual::new(x.clone());
-let factor = Factor::new_base(&[X(0)], res);
+let factor = FactorBuilder::new1(res, X(0)).build();
graph.add_factor(factor);
let res = BetweenResidual::new(y.minus(&x));
let noise = GaussianNoise::from_scalar_sigma(0.1);
let robust = Huber::default();
-let factor = Factor::new_full(&[X(0), X(1)], res, noise, robust);
+let factor = FactorBuilder::new2(res, X(0), X(1))
+ .noise(noise)
+ .robust(robust)
+ .build();
graph.add_factor(factor);
// Optimize!
diff --git a/examples/serde.rs b/examples/serde.rs
index eda9b15..02d1c45 100644
--- a/examples/serde.rs
+++ b/examples/serde.rs
@@ -2,6 +2,7 @@ use factrs::{
containers::{Graph, Values, X},
factors::Factor,
noise::GaussianNoise,
+ prelude::FactorBuilder,
residuals::{BetweenResidual, PriorResidual},
robust::{GemanMcClure, L2},
variables::{SE2, SO2},
@@ -26,13 +27,13 @@ fn main() {
let prior = PriorResidual::new(x);
let bet = BetweenResidual::new(y);
- let prior = Factor::new_full(
- &[X(0)],
- prior,
- GaussianNoise::from_scalar_cov(0.1),
- GemanMcClure::default(),
- );
- let bet = Factor::new_full(&[X(0), X(1)], bet, GaussianNoise::from_scalar_cov(10.0), L2);
+ let prior = FactorBuilder::new1(prior, X(0))
+ .noise(GaussianNoise::from_scalar_cov(0.1))
+ .robust(GemanMcClure::default())
+ .build();
+ let bet = FactorBuilder::new2(bet, X(0), X(1))
+ .noise(GaussianNoise::from_scalar_cov(10.0))
+ .build();
let mut graph = Graph::new();
graph.add_factor(prior);
graph.add_factor(bet);
diff --git a/src/containers/factor.rs b/src/containers/factor.rs
index 463beb7..86a0913 100644
--- a/src/containers/factor.rs
+++ b/src/containers/factor.rs
@@ -1,10 +1,11 @@
+use super::{Symbol, TypedSymbol};
use crate::{
- containers::{Symbol, Values},
+ containers::{Key, Values},
dtype,
- linalg::{AllocatorBuffer, Const, DefaultAllocator, DiffResult, DualAllocator, MatrixBlock},
+ linalg::{Const, DiffResult, MatrixBlock},
linear::LinearFactor,
noise::{NoiseModel, NoiseModelSafe, UnitNoise},
- residuals::{Residual, ResidualSafe},
+ residuals::ResidualSafe,
robust::{RobustCostSafe, L2},
};
@@ -15,7 +16,7 @@ use crate::{
/// Factors are the main building block of the factor graph. They are composed
/// of four pieces:
/// - Keys: The variables that the factor depends on, given by a
-/// slice of [Symbols](Symbol).
+/// slice of [Keys](Key).
/// - Residual: The vector-valued function that computes the
/// error of the factor given a set of values, from the
/// [residual](crate::residuals) module.
@@ -24,103 +25,31 @@ use crate::{
/// - Robust Kernel: The robust kernel weights the error of the
/// factor, given by the traits in the [robust](crate::robust) module.
///
-/// Constructors are available for a number of default cases including default
-/// robust kernel [L2], default noise model [UnitNoise]. Keys and residual are
-/// always required.
+/// To construct a factor, please see the [FactorBuilder] struct.
///
/// During optimization the factor is linearized around a set of values into a
/// [LinearFactor].
///
/// ```
/// # use factrs::prelude::*;
+/// # assign_symbols!(X: VectorVar3);
/// let prior = VectorVar3::new(1.0, 2.0, 3.0);
/// let residual = PriorResidual::new(prior);
/// let noise = GaussianNoise::<3>::from_diag_sigmas(1e-1, 2e-1, 3e-1);
/// let robust = GemanMcClure::default();
-/// let factor = Factor::new_full(&[X(0)], residual, noise, robust);
+/// let factor = FactorBuilder::new1(residual,
+/// X(0)).noise(noise).robust(robust).build();
/// ```
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Factor {
- keys: Vec,
+ keys: Vec,
residual: Box,
noise: Box,
robust: Box,
}
impl Factor {
- /// Build a new factor from a set of keys and a residual.
- ///
- /// Keys will be compile-time checked to ensure the size is consistent with
- /// the residual. Noise will be set to [UnitNoise] and robust kernel to
- /// [L2].
- pub fn new_base(
- keys: &[Symbol; NUM_VARS],
- residual: R,
- ) -> Self
- where
- R: 'static + Residual, DimOut = Const> + ResidualSafe,
- AllocatorBuffer: Sync + Send,
- DefaultAllocator: DualAllocator,
- UnitNoise: NoiseModelSafe,
- {
- Self {
- keys: keys.to_vec(),
- residual: Box::new(residual),
- noise: Box::new(UnitNoise::),
- robust: Box::new(L2),
- }
- }
-
- /// Build a new factor from a set of keys, a residual, and a noise model.
- ///
- /// Keys and noise will be compile-time checked to ensure the size is
- /// consistent with the residual. Robust kernel will be set to [L2].
- pub fn new_noise(
- keys: &[Symbol; NUM_VARS],
- residual: R,
- noise: N,
- ) -> Self
- where
- R: 'static + Residual, DimOut = Const> + ResidualSafe,
- N: 'static + NoiseModel> + NoiseModelSafe,
- AllocatorBuffer: Sync + Send,
- DefaultAllocator: DualAllocator,
- {
- Self {
- keys: keys.to_vec(),
- residual: Box::new(residual),
- noise: Box::new(noise),
- robust: Box::new(L2),
- }
- }
-
- /// Build a new factor from a set of keys, a residual, a noise model, and a
- /// robust kernel.
- ///
- /// Keys and noise will be compile-time checked to ensure the size is
- /// consistent with the residual.
- pub fn new_full(
- keys: &[Symbol; NUM_VARS],
- residual: R,
- noise: N,
- robust: C,
- ) -> Self
- where
- R: 'static + Residual, DimOut = Const> + ResidualSafe,
- AllocatorBuffer: Sync + Send,
- DefaultAllocator: DualAllocator,
- N: 'static + NoiseModel> + NoiseModelSafe,
- C: 'static + RobustCostSafe,
- {
- Self {
- keys: keys.to_vec(),
- residual: Box::new(residual),
- noise: Box::new(noise),
- robust: Box::new(robust),
- }
- }
-
/// Compute the error of the factor given a set of values.
pub fn error(&self, values: &Values) -> dtype {
let r = self.residual.residual(values, &self.keys);
@@ -155,7 +84,7 @@ impl Factor {
.iter()
.scan(0, |sum, k| {
let out = Some(*sum);
- *sum += values.get(k).unwrap().dim();
+ *sum += values.get_raw(*k).unwrap().dim();
out
})
.collect::>();
@@ -165,11 +94,104 @@ impl Factor {
}
/// Get the keys of the factor.
- pub fn keys(&self) -> &[Symbol] {
+ pub fn keys(&self) -> &[Key] {
&self.keys
}
}
+/// Builder for a factor.
+///
+/// If the noise model or robust kernel aren't set, they default to [UnitNoise]
+/// and [L2] respectively.
+pub struct FactorBuilder {
+ keys: Vec,
+ residual: Box,
+ noise: Option>,
+ robust: Option>,
+}
+
+macro_rules! impl_new_builder {
+ ($($num:expr, $( ($key:ident, $key_type:ident, $var:ident) ),*);* $(;)?) => {$(
+ paste::paste! {
+ #[doc = "Create a new factor with " $num " variable connections, while verifying the key types."]
+ pub fn [](residual: R, $($key: $key_type),*) -> Self
+ where
+ R: crate::residuals::[]> + ResidualSafe + 'static,
+ $(
+ $key_type: TypedSymbol,
+ )*
+ {
+ Self {
+ keys: vec![$( $key.into() ),*],
+ residual: Box::new(residual),
+ noise: None,
+ robust: None,
+ }
+ }
+
+ #[doc = "Create a new factor with " $num " variable connections, without verifying the key types."]
+ pub fn [](residual: R, $($key: $key_type),*) -> Self
+ where
+ R: crate::residuals::[]> + ResidualSafe + 'static,
+ $(
+ $key_type: Symbol,
+ )*
+ {
+ Self {
+ keys: vec![$( $key.into() ),*],
+ residual: Box::new(residual),
+ noise: None,
+ robust: None,
+ }
+ }
+ }
+ )*};
+}
+
+impl FactorBuilder {
+ impl_new_builder! {
+ 1, (key1, K1, V1);
+ 2, (key1, K1, V1), (key2, K2, V2);
+ 3, (key1, K1, V1), (key2, K2, V2), (key3, K3, V3);
+ 4, (key1, K1, V1), (key2, K2, V2), (key3, K3, V3), (key4, K4, V4);
+ 5, (key1, K1, V1), (key2, K2, V2), (key3, K3, V3), (key4, K4, V4), (key5, K5, V5);
+ 6, (key1, K1, V1), (key2, K2, V2), (key3, K3, V3), (key4, K4, V4), (key5, K5, V5), (key6, K6, V6);
+ }
+
+ /// Add a noise model to the factor.
+ pub fn noise(mut self, noise: N) -> Self
+ where
+ N: 'static + NoiseModel> + NoiseModelSafe,
+ {
+ self.noise = Some(Box::new(noise));
+ self
+ }
+
+ /// Add a robust kernel to the factor.
+ pub fn robust(mut self, robust: C) -> Self
+ where
+ C: 'static + RobustCostSafe,
+ {
+ self.robust = Some(Box::new(robust));
+ self
+ }
+
+ /// Build the factor.
+ pub fn build(self) -> Factor
+ where
+ UnitNoise: NoiseModelSafe,
+ {
+ let noise = self.noise.unwrap_or_else(|| Box::new(UnitNoise::));
+ let robust = self.robust.unwrap_or_else(|| Box::new(L2));
+ Factor {
+ keys: self.keys.to_vec(),
+ residual: self.residual,
+ noise,
+ robust,
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
@@ -177,7 +199,7 @@ mod tests {
use super::*;
use crate::{
- containers::X,
+ assign_symbols,
linalg::{Diff, NumericalDiff},
noise::GaussianNoise,
residuals::{BetweenResidual, PriorResidual},
@@ -195,6 +217,8 @@ mod tests {
#[cfg(feature = "f32")]
const TOL: f32 = 1e-3;
+ assign_symbols!(X: VectorVar3);
+
#[test]
fn linearize_a() {
let prior = VectorVar3::new(1.0, 2.0, 3.0);
@@ -204,16 +228,19 @@ mod tests {
let noise = GaussianNoise::<3>::from_diag_sigmas(1e-1, 2e-1, 3e-1);
let robust = GemanMcClure::default();
- let factor = Factor::new_full(&[X(0)], residual, noise, robust);
+ let factor = FactorBuilder::new1(residual, X(0))
+ .noise(noise)
+ .robust(robust)
+ .build();
let f = |x: VectorVar3| {
let mut values = Values::new();
- values.insert(X(0), x);
+ values.insert_unchecked(X(0), x);
factor.error(&values)
};
let mut values = Values::new();
- values.insert(X(0), x.clone());
+ values.insert_unchecked(X(0), x.clone());
let linear = factor.linearize(&values);
let grad_got = -linear.a.mat().transpose() * linear.b;
@@ -234,11 +261,14 @@ mod tests {
let noise = GaussianNoise::<3>::from_diag_sigmas(1e-1, 2e-1, 3e-1);
let robust = GemanMcClure::default();
- let factor = Factor::new_full(&[X(0), X(1)], residual, noise, robust);
+ let factor = FactorBuilder::new2(residual, X(0), X(1))
+ .noise(noise)
+ .robust(robust)
+ .build();
let mut values = Values::new();
- values.insert(X(0), x.clone());
- values.insert(X(1), x);
+ values.insert_unchecked(X(0), x.clone());
+ values.insert_unchecked(X(1), x);
let linear = factor.linearize(&values);
diff --git a/src/containers/graph.rs b/src/containers/graph.rs
index c6445e1..fab3ad2 100644
--- a/src/containers/graph.rs
+++ b/src/containers/graph.rs
@@ -15,7 +15,8 @@ use crate::{containers::Factor, dtype, linear::LinearGraph};
///
/// ```
/// # use factrs::prelude::*;
-/// # let factor = Factor::new_base(&[X(0)], PriorResidual::new(SO2::identity()));
+/// # assign_symbols!(X: SO2);
+/// # let factor = FactorBuilder::new1(PriorResidual::new(SO2::identity()), X(0)).build();
/// let mut graph = Graph::new();
/// graph.add_factor(factor);
/// ```
@@ -30,6 +31,12 @@ impl Graph {
Self::default()
}
+ pub fn with_capacity(capacity: usize) -> Self {
+ Self {
+ factors: Vec::with_capacity(capacity),
+ }
+ }
+
pub fn add_factor(&mut self, factor: Factor) {
self.factors.push(factor);
}
@@ -63,7 +70,7 @@ impl Graph {
let Idx {
idx: col,
dim: col_dim,
- } = order.get(key).unwrap();
+ } = order.get(*key).unwrap();
(0..*col_dim).for_each(|j| {
indices.push((row + i, col + j));
});
diff --git a/src/containers/mod.rs b/src/containers/mod.rs
index 32ac644..1fec0e2 100644
--- a/src/containers/mod.rs
+++ b/src/containers/mod.rs
@@ -1,7 +1,7 @@
//! Various containers for storing variables, residuals, factors, etc.
mod symbol;
-pub use symbol::*;
+pub use symbol::{DefaultSymbol, Key, Symbol, TypedSymbol};
mod values;
pub use values::Values;
@@ -13,4 +13,4 @@ mod graph;
pub use graph::{Graph, GraphOrder};
mod factor;
-pub use factor::Factor;
+pub use factor::{Factor, FactorBuilder};
diff --git a/src/containers/order.rs b/src/containers/order.rs
index de39feb..1222a7c 100644
--- a/src/containers/order.rs
+++ b/src/containers/order.rs
@@ -2,7 +2,7 @@ use std::collections::hash_map::Iter as HashMapIter;
use ahash::HashMap;
-use super::{Symbol, Values};
+use super::{Key, Symbol, Values};
/// Location of a variable in a list
///
@@ -21,12 +21,12 @@ pub struct Idx {
/// and len of each variable
#[derive(Debug, Clone)]
pub struct ValuesOrder {
- map: HashMap,
+ map: HashMap,
dim: usize,
}
impl ValuesOrder {
- pub fn new(map: HashMap) -> Self {
+ pub fn new(map: HashMap) -> Self {
let dim = map.values().map(|idx| idx.dim).sum();
Self { map, dim }
}
@@ -37,22 +37,22 @@ impl ValuesOrder {
let order = *idx;
*idx += val.dim();
Some((
- key.clone(),
+ *key,
Idx {
idx: order,
dim: val.dim(),
},
))
})
- .collect::>();
+ .collect::>();
let dim = map.values().map(|idx| idx.dim).sum();
Self { map, dim }
}
- pub fn get(&self, key: &Symbol) -> Option<&Idx> {
- self.map.get(key)
+ pub fn get(&self, symbol: impl Symbol) -> Option<&Idx> {
+ self.map.get(&symbol.into())
}
pub fn dim(&self) -> usize {
@@ -67,7 +67,7 @@ impl ValuesOrder {
self.map.is_empty()
}
- pub fn iter(&self) -> HashMapIter {
+ pub fn iter(&self) -> HashMapIter {
self.map.iter()
}
}
@@ -76,7 +76,8 @@ impl ValuesOrder {
mod test {
use super::*;
use crate::{
- containers::{Values, X},
+ containers::Values,
+ symbols::X,
variables::{Variable, VectorVar2, VectorVar3, VectorVar6},
};
@@ -84,9 +85,9 @@ mod test {
fn from_values() {
// Create some form of values
let mut v = Values::new();
- v.insert(X(0), VectorVar2::identity());
- v.insert(X(1), VectorVar6::identity());
- v.insert(X(2), VectorVar3::identity());
+ v.insert_unchecked(X(0), VectorVar2::identity());
+ v.insert_unchecked(X(1), VectorVar6::identity());
+ v.insert_unchecked(X(2), VectorVar3::identity());
// Create an order
let order = ValuesOrder::from_values(&v);
@@ -94,8 +95,8 @@ mod test {
// Verify the order
assert_eq!(order.len(), 3);
assert_eq!(order.dim(), 11);
- assert_eq!(order.get(&X(0)).unwrap().dim, 2);
- assert_eq!(order.get(&X(1)).unwrap().dim, 6);
- assert_eq!(order.get(&X(2)).unwrap().dim, 3);
+ assert_eq!(order.get(X(0)).unwrap().dim, 2);
+ assert_eq!(order.get(X(1)).unwrap().dim, 6);
+ assert_eq!(order.get(X(2)).unwrap().dim, 3);
}
}
diff --git a/src/containers/symbol.rs b/src/containers/symbol.rs
index eb5c6af..9adfc28 100644
--- a/src/containers/symbol.rs
+++ b/src/containers/symbol.rs
@@ -1,5 +1,10 @@
// Similar to gtsam: https://github.com/borglab/gtsam/blob/develop/gtsam/inference/Symbol.cpp
-use std::{fmt, mem::size_of};
+use std::{
+ fmt::{self},
+ mem::size_of,
+};
+
+use crate::prelude::VariableUmbrella;
// Char is stored in last CHR_BITS
// Value is stored in the first IDX_BITS
@@ -9,124 +14,116 @@ const IDX_BITS: usize = KEY_BITS - CHR_BITS;
const CHR_MASK: u64 = (char::MAX as u64) << IDX_BITS;
const IDX_MASK: u64 = !CHR_MASK;
+// ------------------------- Symbol Parser ------------------------- //
+
/// Newtype wrap around u64
///
-/// First bits contain the index, last bits contain the character.
-/// Helpers exist (such as [X], [B], [L]) to create new versions.
-///
/// In implementation, the u64 is exclusively used, the chr/idx aren't at all.
/// If you'd like to use a custom symbol (ie with two chars for multi-robot
/// experiments), simply define a new trait that creates the u64 as you desire.
-#[derive(Clone, Eq, Hash, PartialEq)]
+#[derive(Clone, Copy, Eq, Hash, PartialEq, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
-pub struct Symbol(u64);
+pub struct Key(pub u64);
-impl Symbol {
- pub fn new_raw(key: u64) -> Self {
- Symbol(key)
- }
+impl Symbol for Key {}
- pub fn chr(&self) -> char {
- ((self.0 & CHR_MASK) >> IDX_BITS) as u8 as char
- }
+/// This provides a custom conversion two and from a u64 key.
+pub trait Symbol: fmt::Debug + Into {}
- pub fn idx(&self) -> u64 {
- self.0 & IDX_MASK
+pub struct DefaultSymbol {
+ chr: char,
+ idx: u64,
+}
+
+impl DefaultSymbol {
+ pub fn new(chr: char, idx: u64) -> Self {
+ Self { chr, idx }
}
+}
+
+impl Symbol for DefaultSymbol {}
- pub fn new(c: char, i: u64) -> Self {
- Symbol(((c as u64) << IDX_BITS) | (i & IDX_MASK))
+impl From for DefaultSymbol {
+ fn from(key: Key) -> Self {
+ let chr = ((key.0 & CHR_MASK) >> IDX_BITS) as u8 as char;
+ let idx = key.0 & IDX_MASK;
+ Self { chr, idx }
}
}
-impl fmt::Display for Symbol {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- write!(f, "{}{}", self.chr(), self.idx())
+impl From for Key {
+ fn from(sym: DefaultSymbol) -> Key {
+ Key((sym.chr as u64) << IDX_BITS | sym.idx & IDX_MASK)
}
}
-impl fmt::Debug for Symbol {
+impl fmt::Debug for DefaultSymbol {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- write!(f, "{}{}", self.chr(), self.idx())
+ write!(f, "({}, {})", self.chr, self.idx)
}
}
-// ------------------------- Helpers ------------------------- //
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn A(i: u64) -> Symbol { Symbol::new('a', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn B(i: u64) -> Symbol { Symbol::new('b', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn C(i: u64) -> Symbol { Symbol::new('c', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn D(i: u64) -> Symbol { Symbol::new('d', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn E(i: u64) -> Symbol { Symbol::new('e', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn F(i: u64) -> Symbol { Symbol::new('f', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn G(i: u64) -> Symbol { Symbol::new('g', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn H(i: u64) -> Symbol { Symbol::new('h', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn I(i: u64) -> Symbol { Symbol::new('i', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn J(i: u64) -> Symbol { Symbol::new('j', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn K(i: u64) -> Symbol { Symbol::new('k', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn L(i: u64) -> Symbol { Symbol::new('l', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn M(i: u64) -> Symbol { Symbol::new('m', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn N(i: u64) -> Symbol { Symbol::new('n', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn O(i: u64) -> Symbol { Symbol::new('o', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn P(i: u64) -> Symbol { Symbol::new('p', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn Q(i: u64) -> Symbol { Symbol::new('q', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn R(i: u64) -> Symbol { Symbol::new('r', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn S(i: u64) -> Symbol { Symbol::new('s', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn T(i: u64) -> Symbol { Symbol::new('t', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn U(i: u64) -> Symbol { Symbol::new('u', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn V(i: u64) -> Symbol { Symbol::new('v', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn W(i: u64) -> Symbol { Symbol::new('w', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn X(i: u64) -> Symbol { Symbol::new('x', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn Y(i: u64) -> Symbol { Symbol::new('y', i) }
-#[rustfmt::skip]
-#[allow(non_snake_case)]
-pub fn Z(i: u64) -> Symbol { Symbol::new('z', i) }
+// ------------------------- Basic Keys ------------------------- //
+/*
+To figure out
+- How to store in Values and still be print when debugging
+ - Probably still want store as a Symbol for this reason
+
+Just want to replace the above functions with some sort of typing
+
+Do the reverse and assign a letter to the variable?
+- Not the best in case we have multiple letters for a single variable
+*/
+
+pub trait TypedSymbol: Symbol {}
+
+/// Creates and assigns symbols to variables
+///
+/// To reduce runtime errors, fact.rs symbols are tagged
+/// with the type they will be used with. This macro will create a new symbol
+/// and implement all the necessary traits for it to be used as a symbol.
+/// ```
+/// use factrs::prelude::*;
+/// assign_symbols!(X: SO2; Y: SE2);
+/// ```
+#[macro_export]
+macro_rules! assign_symbols {
+ ($($name:ident : $($var:ident),+);* $(;)?) => {$(
+ assign_symbols!($name);
+
+ $(
+ impl $crate::containers::TypedSymbol<$var> for $name {}
+ )*
+ )*};
+
+ ($($name:ident),*) => {
+ $(
+ #[derive(Clone, Copy)]
+ pub struct $name(pub u64);
+
+ paste::paste! {
+ impl From<$name> for $crate::containers::DefaultSymbol {
+ fn from(key: $name) -> $crate::containers::DefaultSymbol {
+ let chr = stringify!([<$name:lower>]).chars().next().unwrap();
+ let idx = key.0;
+ $crate::containers::DefaultSymbol::new(chr, idx)
+ }
+ }
+ }
+
+ impl From<$name> for $crate::containers::Key {
+ fn from(key: $name) -> $crate::containers::Key {
+ $crate::containers::DefaultSymbol::from(key).into()
+ }
+ }
+
+ impl std::fmt::Debug for $name {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ $crate::containers::DefaultSymbol::from(*self).fmt(f)
+ }
+ }
+
+ impl $crate::containers::Symbol for $name {}
+ )*
+ };
+}
diff --git a/src/containers/values.rs b/src/containers/values.rs
index a111abe..61e5837 100644
--- a/src/containers/values.rs
+++ b/src/containers/values.rs
@@ -2,28 +2,34 @@ use std::{collections::hash_map::Entry, default::Default, fmt, iter::IntoIterato
use ahash::AHashMap;
-use super::Symbol;
-use crate::{linear::LinearValues, variables::VariableSafe};
+use super::{Key, Symbol, TypedSymbol};
+use crate::{
+ linear::LinearValues,
+ prelude::{DefaultSymbol, VariableUmbrella},
+ variables::VariableSafe,
+};
// Since we won't be passing dual numbers through any of this,
// we can just use dtype rather than using generics with Numeric
/// Structure to hold the Variables used in the graph.
///
-/// Values is essentially a thing wrapper around a Hashmap that maps [Symbol] ->
+/// Values is essentially a thin wrapper around a Hashmap that maps [Key] ->
/// [VariableSafe]. If you'd like to define a custom variable to be used in
/// Values, it must implement [Variable](crate::variables::Variable), and then
/// will implement [VariableSafe] via a blanket implementation.
/// ```
/// # use factrs::prelude::*;
+/// # assign_symbols!(X: SO2);
/// let x = SO2::from_theta(0.1);
/// let mut values = Values::new();
/// values.insert(X(0), x);
/// ```
+
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Values {
- values: AHashMap>,
+ values: AHashMap>,
}
impl Values {
@@ -41,69 +47,102 @@ impl Values {
/// Returns an [std::collections::hash_map::Entry] from the underlying
/// HashMap.
- pub fn entry(&mut self, key: Symbol) -> Entry> {
- self.values.entry(key)
+ pub fn entry(&mut self, key: impl Symbol) -> Entry> {
+ self.values.entry(key.into())
}
- pub fn insert(
- &mut self,
- key: Symbol,
- value: impl VariableSafe,
- ) -> Option> {
- self.values.insert(key, Box::new(value))
+ pub fn insert(&mut self, symbol: S, value: V) -> Option>
+ where
+ S: TypedSymbol,
+ V: VariableUmbrella,
+ {
+ self.values.insert(symbol.into(), Box::new(value))
}
- /// Returns a dynamic VariableSafe.
- ///
- /// If the underlying value is desired, use [Values::get_cast]
- pub fn get(&self, key: &Symbol) -> Option<&dyn VariableSafe> {
- self.values.get(key).map(|f| f.as_ref())
+ /// Unchecked verison of [Values::insert].
+ pub fn insert_unchecked(&mut self, symbol: S, value: V) -> Option>
+ where
+ S: Symbol,
+ V: VariableUmbrella,
+ {
+ self.values.insert(symbol.into(), Box::new(value))
+ }
+
+ pub(crate) fn get_raw(&self, symbol: S) -> Option<&dyn VariableSafe>
+ where
+ S: Symbol,
+ {
+ self.values.get(&symbol.into()).map(|f| f.as_ref())
}
- // TODO: This should be some kind of error
- /// Casts and returns the underlying variable.
+ /// Returns the underlying variable.
///
- /// This will return the value if variable is in the graph, and if the cast
- /// is successful. Returns None otherwise.
+ /// This will return the value if variable is in the graph. Requires a typed
+ /// symbol and as such is guaranteed to return the correct type. Returns
+ /// None if key isn't found.
/// ```
/// # use factrs::prelude::*;
+ /// # assign_symbols!(X: SO2);
/// # let x = SO2::from_theta(0.1);
/// # let mut values = Values::new();
/// # values.insert(X(0), x);
- /// let x_out = values.get_cast::(&X(0));
+ /// let x_out = values.get(X(0));
/// ```
- pub fn get_cast(&self, key: &Symbol) -> Option<&T> {
+ pub fn get(&self, symbol: S) -> Option<&V>
+ where
+ S: TypedSymbol,
+ V: VariableUmbrella,
+ {
self.values
- .get(key)
- .and_then(|value| value.downcast_ref::())
+ .get(&symbol.into())
+ .and_then(|value| value.downcast_ref::())
}
- // TODO: Does this still fail if one is missing?
- // pub fn get_multiple<'a>(&self, keys: impl IntoIterator- ) ->
- // Option> where
- // Symbol: 'a,
- // {
- // keys.into_iter().map(|key| self.values.get(key)).collect()
- // }
+ /// Returns the underlying variable, not checking the type.
+ pub fn get_unchecked
(&self, symbol: S) -> Option<&V>
+ where
+ S: Symbol,
+ V: VariableUmbrella,
+ {
+ self.values
+ .get(&symbol.into())
+ .and_then(|value| value.downcast_ref::())
+ }
/// Mutable version of [Values::get].
- pub fn get_mut(&mut self, key: &Symbol) -> Option<&mut dyn VariableSafe> {
- self.values.get_mut(key).map(|f| f.as_mut())
+ pub fn get_mut(&mut self, symbol: S) -> Option<&mut V>
+ where
+ S: TypedSymbol,
+ V: VariableUmbrella,
+ {
+ self.values
+ .get_mut(&symbol.into())
+ .and_then(|value| value.downcast_mut::())
}
- // TODO: This should be some kind of error
- /// Mutable version of [Values::get_cast].
- pub fn get_mut_cast(&mut self, key: &Symbol) -> Option<&mut T> {
+ /// Mutable version of [Values::get_unchecked].
+ pub fn get_unchecked_mut(&mut self, symbol: S) -> Option<&mut V>
+ where
+ S: Symbol,
+ V: VariableUmbrella,
+ {
self.values
- .get_mut(key)
- .and_then(|value| value.downcast_mut::())
+ .get_mut(&symbol.into())
+ .and_then(|value| value.downcast_mut::())
}
- pub fn remove(&mut self, key: &Symbol) -> Option> {
- self.values.remove(key)
+ pub fn remove(&mut self, symbol: S) -> Option
+ where
+ S: TypedSymbol,
+ V: VariableUmbrella,
+ {
+ self.values
+ .remove(&symbol.into())
+ .and_then(|value| value.downcast::().ok())
+ .map(|value| *value)
}
- pub fn iter(&self) -> impl Iterator- )> {
+ pub fn iter(&self) -> impl Iterator
- )> {
self.values.iter()
}
@@ -112,6 +151,7 @@ impl Values {
///
/// ```
/// # use factrs::prelude::*;
+ /// # assign_symbols!(X: SO2);
/// # let mut values = Values::new();
/// # (0..10).for_each(|i| {values.insert(X(0), SO2::identity());} );
/// let mine: Vec<&SO2> = values.filter().collect();
@@ -138,18 +178,20 @@ impl Values {
}
}
+// TODO: Find a way to make this usable on custom symbols (not just
+// DefaultSymbol)
impl fmt::Display for Values {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if f.alternate() {
writeln!(f, "{{")?;
for (key, value) in self.values.iter() {
- writeln!(f, " {}: {:?},", key, value)?;
+ writeln!(f, " {:?}: {:?},", DefaultSymbol::from(*key), value)?;
}
write!(f, "}}")
} else {
write!(f, "{{")?;
for (key, value) in self.values.iter() {
- write!(f, "{}: {:?}, ", key, value)?;
+ write!(f, "{:?}: {:?}, ", DefaultSymbol::from(*key), value)?;
}
write!(f, "}}")
}
@@ -163,8 +205,8 @@ impl fmt::Debug for Values {
}
impl IntoIterator for Values {
- type Item = (Symbol, Box);
- type IntoIter = std::collections::hash_map::IntoIter>;
+ type Item = (Key, Box);
+ type IntoIter = std::collections::hash_map::IntoIter>;
fn into_iter(self) -> Self::IntoIter {
self.values.into_iter()
diff --git a/src/lib.rs b/src/lib.rs
index e484803..e5f5cda 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -38,6 +38,9 @@
//! ```
//! use factrs::prelude::*;
//!
+//! // Assign symbols to variable types
+//! assign_symbols!(X: SO2);
+//!
//! // Make all the values
//! let mut values = Values::new();
//!
@@ -50,13 +53,16 @@
//! let mut graph = Graph::new();
//!
//! let res = PriorResidual::new(x.clone());
-//! let factor = Factor::new_base(&[X(0)], res);
+//! let factor = FactorBuilder::new1(res, X(0)).build();
//! graph.add_factor(factor);
//!
//! let res = BetweenResidual::new(y.minus(&x));
//! let noise = GaussianNoise::from_scalar_sigma(0.1);
//! let robust = Huber::default();
-//! let factor = Factor::new_full(&[X(0), X(1)], res, noise, robust);
+//! let factor = FactorBuilder::new2(res, X(0), X(1))
+//! .noise(noise)
+//! .robust(robust)
+//! .build();
//! graph.add_factor(factor);
//!
//! // Optimize!
@@ -83,6 +89,18 @@ pub mod robust;
pub mod utils;
pub mod variables;
+/// Untagged symnbols if `unchecked` API is desired.
+///
+/// We strongly recommend using [assign_symbols](crate::assign_symbols) to
+/// create and tag symbols with the appropriate types. However, we provide a
+/// number of pre-defined symbols if desired. Note this objects can't be tagged
+/// due to the orphan rules.
+pub mod symbols {
+ crate::assign_symbols!(
+ A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z
+ );
+}
+
/// Helper module to import common types
///
/// This module is meant to be glob imported to make it easier to use the
@@ -92,6 +110,7 @@ pub mod variables;
/// ```
pub mod prelude {
pub use crate::{
+ assign_symbols,
containers::*,
noise::*,
optimizers::*,
diff --git a/src/linear/factor.rs b/src/linear/factor.rs
index 7379985..f4150a3 100644
--- a/src/linear/factor.rs
+++ b/src/linear/factor.rs
@@ -1,5 +1,5 @@
use crate::{
- containers::Symbol,
+ containers::Key,
dtype,
linalg::{MatrixBlock, VectorX},
linear::LinearValues,
@@ -11,12 +11,12 @@ use crate::{
/// consists of the relevant keys, a [MatrixBlock] A, and a [VectorX] b. Again,
/// this *shouldn't* ever need to be used by hand.
pub struct LinearFactor {
- pub keys: Vec,
+ pub keys: Vec,
pub a: MatrixBlock,
pub b: VectorX,
}
impl LinearFactor {
- pub fn new(keys: Vec, a: MatrixBlock, b: VectorX) -> Self {
+ pub fn new(keys: Vec, a: MatrixBlock, b: VectorX) -> Self {
assert!(
keys.len() == a.idx().len(),
"Mismatch between keys and matrix blocks in LinearFactor::new"
@@ -40,7 +40,9 @@ impl LinearFactor {
.map(|(idx, key)| {
self.a.mul(
idx,
- vector.get(key).expect("Missing key in LinearValues::error"),
+ vector
+ .get(*key)
+ .expect("Missing key in LinearValues::error"),
)
})
.sum();
diff --git a/src/linear/graph.rs b/src/linear/graph.rs
index 5c334ba..0602286 100644
--- a/src/linear/graph.rs
+++ b/src/linear/graph.rs
@@ -49,7 +49,7 @@ impl LinearGraph {
let Idx {
idx: col,
dim: col_dim,
- } = order.get(key).unwrap();
+ } = order.get(*key).unwrap();
(0..*col_dim).for_each(|j| {
indices.push((row + i, col + j));
});
@@ -121,8 +121,9 @@ mod test {
use super::*;
use crate::{
- containers::{Idx, X},
+ containers::Idx,
linalg::{MatrixBlock, MatrixX, VectorX},
+ symbols::X,
};
#[test]
@@ -134,22 +135,26 @@ mod test {
let a1 = MatrixX::from_fn(2, 2, |i, j| (i + j) as dtype);
let block1 = MatrixBlock::new(a1, vec![0]);
let b1 = VectorX::from_fn(2, |i, j| (i + j) as dtype);
- graph.add_factor(LinearFactor::new(vec![X(1)], block1.clone(), b1.clone()));
+ graph.add_factor(LinearFactor::new(
+ vec![X(1).into()],
+ block1.clone(),
+ b1.clone(),
+ ));
let a2 = MatrixX::from_fn(3, 5, |i, j| (i + j) as dtype);
let block2 = MatrixBlock::new(a2, vec![0, 2]);
let b2 = VectorX::from_fn(3, |_, _| 5.0);
graph.add_factor(LinearFactor::new(
- vec![X(0), X(2)],
+ vec![X(0).into(), X(2).into()],
block2.clone(),
b2.clone(),
));
// Make fake ordering
let mut map = HashMap::default();
- map.insert(X(0), Idx { idx: 0, dim: 2 });
- map.insert(X(1), Idx { idx: 2, dim: 2 });
- map.insert(X(2), Idx { idx: 4, dim: 3 });
+ map.insert(X(0).into(), Idx { idx: 0, dim: 2 });
+ map.insert(X(1).into(), Idx { idx: 2, dim: 2 });
+ map.insert(X(2).into(), Idx { idx: 4, dim: 3 });
let order = ValuesOrder::new(map);
// Compute the residual and jacobian
diff --git a/src/linear/values.rs b/src/linear/values.rs
index 5cc360a..0ca1307 100644
--- a/src/linear/values.rs
+++ b/src/linear/values.rs
@@ -1,7 +1,7 @@
use std::collections::hash_map::Iter as HashMapIter;
use crate::{
- containers::{Idx, Symbol, Values, ValuesOrder},
+ containers::{Idx, Key, Symbol, Values, ValuesOrder},
linalg::{VectorViewX, VectorX},
};
@@ -69,7 +69,7 @@ impl LinearValues {
}
/// Retrieve a vector from the LinearValues
- pub fn get(&self, key: &Symbol) -> Option> {
+ pub fn get(&self, key: impl Symbol) -> Option> {
let idx = self.order.get(key)?;
self.get_idx(idx).into()
}
@@ -84,11 +84,11 @@ impl LinearValues {
pub struct Iter<'a> {
values: &'a LinearValues,
- idx: HashMapIter<'a, Symbol, Idx>,
+ idx: HashMapIter<'a, Key, Idx>,
}
impl<'a> Iterator for Iter<'a> {
- type Item = (&'a Symbol, VectorViewX<'a>);
+ type Item = (&'a Key, VectorViewX<'a>);
fn next(&mut self) -> Option {
let n = self.idx.next()?;
@@ -100,17 +100,17 @@ impl<'a> Iterator for Iter<'a> {
mod test {
use super::*;
use crate::{
- containers::X,
dtype,
+ symbols::X,
variables::{Variable, VectorVar2, VectorVar3, VectorVar6},
};
fn make_order_vector() -> (ValuesOrder, VectorX) {
// Create some form of values
let mut v = Values::new();
- v.insert(X(0), VectorVar2::identity());
- v.insert(X(1), VectorVar6::identity());
- v.insert(X(2), VectorVar3::identity());
+ v.insert_unchecked(X(0), VectorVar2::identity());
+ v.insert_unchecked(X(1), VectorVar6::identity());
+ v.insert_unchecked(X(2), VectorVar3::identity());
// Create an order
let order = ValuesOrder::from_values(&v);
@@ -126,10 +126,10 @@ mod test {
let linear_values = LinearValues::from_order_and_vector(order, vector);
assert!(linear_values.len() == 3);
assert!(linear_values.dim() == 11);
- assert!(linear_values.get(&X(0)).unwrap().len() == 2);
- assert!(linear_values.get(&X(1)).unwrap().len() == 6);
- assert!(linear_values.get(&X(2)).unwrap().len() == 3);
- assert!(linear_values.get(&X(3)).is_none());
+ assert!(linear_values.get(X(0)).unwrap().len() == 2);
+ assert!(linear_values.get(X(1)).unwrap().len() == 6);
+ assert!(linear_values.get(X(2)).unwrap().len() == 3);
+ assert!(linear_values.get(X(3)).is_none());
}
#[test]
diff --git a/src/optimizers/mod.rs b/src/optimizers/mod.rs
index 35b034f..ddba5c4 100644
--- a/src/optimizers/mod.rs
+++ b/src/optimizers/mod.rs
@@ -57,14 +57,17 @@ pub use levenberg_marquardt::LevenMarquardt;
#[cfg(test)]
pub mod test {
use faer::assert_matrix_eq;
+ use nalgebra::{DefaultAllocator, DimNameAdd, DimNameSum, ToTypenum};
use super::*;
use crate::{
- containers::{Factor, Graph, Values, X},
+ containers::{Graph, Values},
dtype,
- linalg::{Const, VectorX},
+ linalg::{AllocatorBuffer, Const, DualAllocator, DualVector, VectorX},
noise::{NoiseModelSafe, UnitNoise},
+ prelude::FactorBuilder,
residuals::{BetweenResidual, PriorResidual, Residual, ResidualSafe},
+ symbols::X,
variables::VariableUmbrella,
};
@@ -79,17 +82,17 @@ pub mod test {
let p = T::exp(t.as_view());
let mut values = Values::new();
- values.insert(X(0), T::identity());
+ values.insert_unchecked(X(0), T::identity());
let mut graph = Graph::new();
let res = PriorResidual::new(p.clone());
- let factor = Factor::new_base(&[X(0)], res);
+ let factor = FactorBuilder::new1_unchecked(res, X(0)).build();
graph.add_factor(factor);
let mut opt = O::new(graph);
values = opt.optimize(values).unwrap();
- let out: &T = values.get_cast(&X(0)).unwrap();
+ let out: &T = values.get_unchecked(X(0)).unwrap();
assert_matrix_eq!(
out.ominus(&p),
VectorX::zeros(T::DIM),
@@ -107,6 +110,11 @@ pub mod test {
BetweenResidual: ResidualSafe
+ Residual, DimOut = Const, NumVars = Const<2>>,
O: Optimizer + GraphOptimizer,
+ Const: ToTypenum,
+ AllocatorBuffer, Const>>: Sync + Send,
+ DefaultAllocator: DualAllocator, Const>>,
+ DualVector, Const>>: Copy,
+ Const: DimNameAdd>,
{
let t = VectorX::from_fn(T::DIM, |_, i| ((i as dtype) - (T::DIM as dtype)) / 10.0);
let p1 = T::exp(t.as_view());
@@ -115,23 +123,23 @@ pub mod test {
let p2 = T::exp(t.as_view());
let mut values = Values::new();
- values.insert(X(0), T::identity());
- values.insert(X(1), T::identity());
+ values.insert_unchecked(X(0), T::identity());
+ values.insert_unchecked(X(1), T::identity());
let mut graph = Graph::new();
let res = PriorResidual::new(p1.clone());
- let factor = Factor::new_base(&[X(0)], res);
+ let factor = FactorBuilder::new1_unchecked(res, X(0)).build();
graph.add_factor(factor);
let diff = p2.minus(&p1);
let res = BetweenResidual::new(diff);
- let factor = Factor::new_base(&[X(0), X(1)], res);
+ let factor = FactorBuilder::new2_unchecked(res, X(0), X(1)).build();
graph.add_factor(factor);
let mut opt = O::new(graph);
values = opt.optimize(values).unwrap();
- let out1: &T = values.get_cast(&X(0)).unwrap();
+ let out1: &T = values.get_unchecked(X(0)).unwrap();
assert_matrix_eq!(
out1.ominus(&p1),
VectorX::zeros(T::DIM),
@@ -139,7 +147,7 @@ pub mod test {
tol = 1e-6
);
- let out2: &T = values.get_cast(&X(1)).unwrap();
+ let out2: &T = values.get_unchecked(X(1)).unwrap();
assert_matrix_eq!(
out2.ominus(&p2),
VectorX::zeros(T::DIM),
diff --git a/src/residuals/between.rs b/src/residuals/between.rs
index f0e9662..cb55976 100644
--- a/src/residuals/between.rs
+++ b/src/residuals/between.rs
@@ -3,7 +3,7 @@ use nalgebra::{DimNameAdd, DimNameSum};
use super::{Residual, Residual2};
#[allow(unused_imports)]
use crate::{
- containers::{Symbol, Values},
+ containers::{Key, Values},
linalg::{
AllocatorBuffer,
Const,
@@ -99,11 +99,11 @@ where
type DimIn = ::DimIn;
type NumVars = Const<2>;
- fn residual(&self, values: &Values, keys: &[Symbol]) -> VectorX {
+ fn residual(&self, values: &Values, keys: &[Key]) -> VectorX {
self.residual2_values(values, keys)
}
- fn residual_jacobian(&self, values: &Values, keys: &[Symbol]) -> DiffResult {
+ fn residual_jacobian(&self, values: &Values, keys: &[Key]) -> DiffResult {
self.residual2_jacobian(values, keys)
}
}
diff --git a/src/residuals/prior.rs b/src/residuals/prior.rs
index f6824af..1ec52da 100644
--- a/src/residuals/prior.rs
+++ b/src/residuals/prior.rs
@@ -1,7 +1,7 @@
use super::{Residual, Residual1};
#[allow(unused_imports)]
use crate::{
- containers::{Symbol, Values},
+ containers::{Key, Values},
linalg::{
AllocatorBuffer,
Const,
@@ -88,10 +88,10 @@ where
type DimIn = ::DimIn;
type DimOut = ::DimOut;
type NumVars = Const<1>;
- fn residual(&self, values: &Values, keys: &[Symbol]) -> VectorX {
+ fn residual(&self, values: &Values, keys: &[Key]) -> VectorX {
self.residual1_values(values, keys)
}
- fn residual_jacobian(&self, values: &Values, keys: &[Symbol]) -> DiffResult {
+ fn residual_jacobian(&self, values: &Values, keys: &[Key]) -> DiffResult {
self.residual1_jacobian(values, keys)
}
}
@@ -103,8 +103,8 @@ mod test {
use super::*;
use crate::{
- containers::X,
linalg::{vectorx, DefaultAllocator, Diff, DualAllocator, NumericalDiff},
+ symbols::X,
variables::{VectorVar3, SE3, SO3},
};
@@ -129,13 +129,15 @@ mod test {
let x1 = P::identity();
let mut values = Values::new();
- values.insert(X(0), x1.clone());
- let jac = prior_residual.residual1_jacobian(&values, &[X(0)]).diff;
+ values.insert_unchecked(X(0), x1.clone());
+ let jac = prior_residual
+ .residual1_jacobian(&values, &[X(0).into()])
+ .diff;
let f = |v: P| {
let mut vals = Values::new();
- vals.insert(X(0), v.clone());
- Residual1::residual1_values(&prior_residual, &vals, &[X(0)])
+ vals.insert_unchecked(X(0), v.clone());
+ Residual1::residual1_values(&prior_residual, &vals, &[X(0).into()])
};
let jac_n = NumericalDiff::::jacobian_1(f, &x1).diff;
diff --git a/src/residuals/traits.rs b/src/residuals/traits.rs
index 7639c6e..c0926d5 100644
--- a/src/residuals/traits.rs
+++ b/src/residuals/traits.rs
@@ -1,7 +1,7 @@
use std::fmt::{Debug, Display};
use crate::{
- containers::{Symbol, Values},
+ containers::{Key, Values},
linalg::{Diff, DiffResult, DimName, MatrixX, Numeric, VectorX},
variables::{Variable, VariableUmbrella},
};
@@ -27,9 +27,9 @@ pub trait Residual: Debug + Display {
Self::DimOut::USIZE
}
- fn residual(&self, values: &Values, keys: &[Symbol]) -> VectorX;
+ fn residual(&self, values: &Values, keys: &[Key]) -> VectorX;
- fn residual_jacobian(&self, values: &Values, keys: &[Symbol]) -> DiffResult;
+ fn residual_jacobian(&self, values: &Values, keys: &[Key]) -> DiffResult;
}
/// The object safe version of [Residual].
@@ -42,9 +42,9 @@ pub trait ResidualSafe: Debug + Display {
fn dim_out(&self) -> usize;
- fn residual(&self, values: &Values, keys: &[Symbol]) -> VectorX;
+ fn residual(&self, values: &Values, keys: &[Key]) -> VectorX;
- fn residual_jacobian(&self, values: &Values, keys: &[Symbol]) -> DiffResult;
+ fn residual_jacobian(&self, values: &Values, keys: &[Key]) -> DiffResult;
}
impl<
@@ -60,11 +60,11 @@ impl<
Residual::dim_out(self)
}
- fn residual(&self, values: &Values, keys: &[Symbol]) -> VectorX {
+ fn residual(&self, values: &Values, keys: &[Key]) -> VectorX {
Residual::residual(self, values, keys)
}
- fn residual_jacobian(&self, values: &Values, keys: &[Symbol]) -> DiffResult {
+ fn residual_jacobian(&self, values: &Values, keys: &[Key]) -> DiffResult {
Residual::residual_jacobian(self, values, keys)
}
@@ -105,8 +105,8 @@ macro_rules! residual_maker {
/// It is generic over the dtype to allow for differentiable types.
fn [](&self, $($name: Alias,)*) -> VectorX;
- #[doc=concat!("Wrapper that unpacks variables and calls [", stringify!([]), "](Self::", stringify!([]), ").")]
- fn [](&self, values: &Values, keys: &[Symbol]) -> VectorX
+ #[doc="Wrapper that unpacks and calls [" [] "](Self::" [] ")."]
+ fn [](&self, values: &Values, keys: &[Key]) -> VectorX
where
$(
Self::$var: 'static,
@@ -114,7 +114,7 @@ macro_rules! residual_maker {
{
// Unwrap everything
$(
- let $name: &Self::$var = values.get_cast(&keys[$idx]).unwrap_or_else(|| {
+ let $name: &Self::$var = values.get_unchecked(keys[$idx]).unwrap_or_else(|| {
panic!("Key not found in values: {:?} with type {}", keys[$idx], std::any::type_name::())
});
)*
@@ -122,8 +122,8 @@ macro_rules! residual_maker {
}
- #[doc=concat!("Wrapper that unpacks variables and computes jacobians using [", stringify!([]), "](Self::", stringify!([]), ").")]
- fn [](&self, values: &Values, keys: &[Symbol]) -> DiffResult
+ #[doc="Wrapper that unpacks variables and computes jacobians using [" [] "](Self::" [] ")."]
+ fn [](&self, values: &Values, keys: &[Key]) -> DiffResult
where
$(
Self::$var: 'static,
@@ -131,7 +131,7 @@ macro_rules! residual_maker {
{
// Unwrap everything
$(
- let $name: &Self::$var = values.get_cast(&keys[$idx]).unwrap_or_else(|| {
+ let $name: &Self::$var = values.get_unchecked(keys[$idx]).unwrap_or_else(|| {
panic!("Key not found in values: {:?} with type {}", keys[$idx], std::any::type_name::())
});
)*
diff --git a/src/utils.rs b/src/utils.rs
index 014345e..e3f979a 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -5,7 +5,8 @@ use std::{
};
use crate::{
- containers::{Factor, Graph, Values, X},
+ assign_symbols,
+ containers::{FactorBuilder, Graph, Values},
dtype,
linalg::{Matrix3, Matrix6, Vector3},
noise::GaussianNoise,
@@ -13,6 +14,8 @@ use crate::{
variables::*,
};
+assign_symbols!(X: SE2, SE3);
+
/// Load a g2o file
///
/// Currently supports only SE2 and SE3 pose graphs. Will autodetect which one
@@ -38,7 +41,7 @@ pub fn load_g20(file: &str) -> (Graph, Values) {
// Add prior on whatever the first variable is
if values.len() == 1 {
- let factor = Factor::new_base(&[key.clone()], PriorResidual::new(var.clone()));
+ let factor = FactorBuilder::new1(PriorResidual::new(var.clone()), key).build();
graph.add_factor(factor);
}
@@ -64,7 +67,9 @@ pub fn load_g20(file: &str) -> (Graph, Values) {
let key2 = X(id_curr);
let var = SE2::new(theta, x, y);
let noise = GaussianNoise::from_matrix_inf(inf.as_view());
- let factor = Factor::new_noise(&[key1, key2], BetweenResidual::new(var), noise);
+ let factor = FactorBuilder::new2(BetweenResidual::new(var), key1, key2)
+ .noise(noise)
+ .build();
graph.add_factor(factor);
}
@@ -87,8 +92,9 @@ pub fn load_g20(file: &str) -> (Graph, Values) {
if values.len() == 1 {
let noise =
GaussianNoise::<6>::from_diag_covs(1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4);
- let factor =
- Factor::new_noise(&[key.clone()], PriorResidual::new(var.clone()), noise);
+ let factor = FactorBuilder::new1(PriorResidual::new(var.clone()), key)
+ .noise(noise)
+ .build();
graph.add_factor(factor);
}
@@ -141,14 +147,12 @@ pub fn load_g20(file: &str) -> (Graph, Values) {
let xyz = Vector3::new(x, y, z);
let var = SE3::from_rot_trans(rot, xyz);
- // println!("var: {:?}", var);
- // println!("2499 {:?}", values.get_cast::(&X(2499)).unwrap());
- // panic!();
-
let key1 = X(id_prev);
let key2 = X(id_curr);
let noise = GaussianNoise::from_matrix_inf(inf.as_view());
- let factor = Factor::new_noise(&[key1, key2], BetweenResidual::new(var), noise);
+ let factor = FactorBuilder::new2(BetweenResidual::new(var), key1, key2)
+ .noise(noise)
+ .build();
graph.add_factor(factor);
}
diff --git a/src/variables/macros.rs b/src/variables/macros.rs
index 1ee3c38..b010b2c 100644
--- a/src/variables/macros.rs
+++ b/src/variables/macros.rs
@@ -1,4 +1,4 @@
-/// Variable wrapper around [matrixcompare::assert_matrix_eq]
+/// Variable wrapper around [assert_matrix_eq](matrixcompare::assert_matrix_eq)
///
/// This compares two variables using
/// [ominus](crate::variables::Variable::ominus)