-
Notifications
You must be signed in to change notification settings - Fork 111
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
272 additions
and
0 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[package] | ||
name = "bars" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
|
||
[lib] | ||
crate-type = ["cdylib"] | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[profile.release] | ||
lto = "fat" | ||
|
||
[profile.dev] | ||
lto = "fat" | ||
|
||
[dependencies] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
#![feature(autodiff)] | ||
#![feature(slice_first_last_chunk)] | ||
#![allow(non_snake_case)] | ||
|
||
//#define BA_NCAMPARAMS 11 | ||
static BA_NCAMPARAMS: usize = 11; | ||
|
||
fn sqsum(x: &[f64]) -> f64 { | ||
x.iter().map(|&v| v * v).sum() | ||
} | ||
|
||
#[inline] | ||
fn cross(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] { | ||
[ | ||
a[1] * b[2] - a[2] * b[1], | ||
a[2] * b[0] - a[0] * b[2], | ||
a[0] * b[1] - a[1] * b[0], | ||
] | ||
} | ||
|
||
fn radial_distort(rad_params: &[f64], proj: &mut [f64]) { | ||
let rsq = sqsum(proj); | ||
let l = 1. + rad_params[0] * rsq + rad_params[1] * rsq * rsq; | ||
proj[0] = proj[0] * l; | ||
proj[1] = proj[1] * l; | ||
} | ||
|
||
fn rodrigues_rotate_point(rot: &[f64; 3], pt: &[f64; 3], rotated_pt: &mut [f64; 3]) { | ||
let sqtheta = sqsum(rot); | ||
if sqtheta != 0. { | ||
let theta = sqtheta.sqrt(); | ||
let costheta = theta.cos(); | ||
let sintheta = theta.sin(); | ||
let theta_inverse = 1. / theta; | ||
let w = rot.map(|v| v * theta_inverse); | ||
let w_cross_pt = cross(&w, &pt); | ||
let tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) * (1. - costheta); | ||
for i in 0..3 { | ||
rotated_pt[i] = pt[i] * costheta + w_cross_pt[i] * sintheta + w[i] * tmp; | ||
} | ||
} else { | ||
let rot_cross_pt = cross(&rot, &pt); | ||
for i in 0..3 { | ||
rotated_pt[i] = pt[i] + rot_cross_pt[i]; | ||
} | ||
} | ||
} | ||
|
||
fn project(cam: &[f64; 11], X: &[f64; 3], proj: &mut [f64; 2]) { | ||
let C = &cam[3..6]; | ||
let mut Xo = [0.; 3]; | ||
let mut Xcam = [0.; 3]; | ||
|
||
Xo[0] = X[0] - C[0]; | ||
Xo[1] = X[1] - C[1]; | ||
Xo[2] = X[2] - C[2]; | ||
|
||
rodrigues_rotate_point(cam.first_chunk::<3>().unwrap(), &Xo, &mut Xcam); | ||
|
||
proj[0] = Xcam[0] / Xcam[2]; | ||
proj[1] = Xcam[1] / Xcam[2]; | ||
|
||
radial_distort(&cam[9..], proj); | ||
|
||
proj[0] = proj[0] * cam[6] + cam[7]; | ||
proj[1] = proj[1] * cam[6] + cam[8]; | ||
} | ||
|
||
#[no_mangle] | ||
pub extern "C" fn dcompute_reproj_error( | ||
cam: *const [f64; 11], | ||
dcam: *mut [f64; 11], | ||
x: *const [f64; 3], | ||
dx: *mut [f64; 3], | ||
w: *const [f64; 1], | ||
wb: *mut [f64; 1], | ||
feat: *const [f64; 2], | ||
err: *mut [f64; 2], | ||
derr: *mut [f64; 2], | ||
) { | ||
rust_dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr); | ||
} | ||
|
||
#[no_mangle] | ||
pub extern "C" fn dcompute_zach_weight_error( | ||
w: *const f64, | ||
dw: *mut f64, | ||
err: *mut f64, | ||
derr: *mut f64, | ||
) { | ||
rust_dcompute_zach_weight_error(w, dw, err, derr); | ||
} | ||
|
||
#[autodiff( | ||
rust_dcompute_reproj_error, | ||
Reverse, | ||
Duplicated, | ||
Duplicated, | ||
Duplicated, | ||
Const, | ||
Duplicated | ||
)] | ||
pub fn compute_reproj_error( | ||
cam: *const [f64; 11], | ||
x: *const [f64; 3], | ||
w: *const [f64; 1], | ||
feat: *const [f64; 2], | ||
err: *mut [f64; 2], | ||
) { | ||
let cam = unsafe { &*cam }; | ||
let w = unsafe { *(*w).get_unchecked(0) }; | ||
let x = unsafe { &*x }; | ||
let feat = unsafe { &*feat }; | ||
let mut err = unsafe { &mut *err }; | ||
let mut proj = [0.; 2]; | ||
project(cam, x, &mut proj); | ||
err[0] = w * (proj[0] - feat[0]); | ||
err[1] = w * (proj[1] - feat[1]); | ||
} | ||
|
||
#[autodiff(rust_dcompute_zach_weight_error, Reverse, Duplicated, Duplicated)] | ||
pub fn compute_zach_weight_error(w: *const f64, err: *mut f64) { | ||
let w = unsafe { *w }; | ||
let mut err = unsafe { *err }; | ||
err = 1. - w * w; | ||
} | ||
|
||
// n number of cameras | ||
// m number of points | ||
// p number of observations | ||
// cams: 11*n cameras in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2] | ||
// r1, r2, r3 are angle - axis rotation parameters(Rodrigues) | ||
// [C1 C2 C3]' is the camera center | ||
// f is the focal length in pixels | ||
// [u0 v0]' is the principal point | ||
// k1, k2 are radial distortion parameters | ||
// X: 3*m points | ||
// obs: 2*p observations (pairs cameraIdx, pointIdx) | ||
// feats: 2*p features (x,y coordinates corresponding to observations) | ||
// reproj_err: 2*p errors of observations | ||
// w_err: p weight "error" terms | ||
fn rust_ba_objective( | ||
n: usize, | ||
m: usize, | ||
p: usize, | ||
cams: &[f64], | ||
x: &[f64], | ||
w: &[f64], | ||
obs: &[i32], | ||
feats: &[f64], | ||
reproj_err: &mut [f64], | ||
w_err: &mut [f64], | ||
) { | ||
assert_eq!(cams.len(), n * 11); | ||
assert_eq!(x.len(), m * 3); | ||
assert_eq!(w.len(), p); | ||
assert_eq!(obs.len(), p * 2); | ||
assert_eq!(feats.len(), p * 2); | ||
assert_eq!(reproj_err.len(), p * 2); | ||
assert_eq!(w_err.len(), p); | ||
|
||
for i in 0..p { | ||
let cam_idx = obs[i * 2 + 0] as usize; | ||
let pt_idx = obs[i * 2 + 1] as usize; | ||
let start = cam_idx * BA_NCAMPARAMS; | ||
let cam: &[f64; 11] = unsafe { | ||
cams[start..] | ||
.get_unchecked(..11) | ||
.try_into() | ||
.unwrap_unchecked() | ||
}; | ||
let x: &[f64; 3] = unsafe { | ||
x[pt_idx * 3..] | ||
.get_unchecked(..3) | ||
.try_into() | ||
.unwrap_unchecked() | ||
}; | ||
let w: &[f64; 1] = unsafe { w[i..].get_unchecked(..1).try_into().unwrap_unchecked() }; | ||
let feat: &[f64; 2] = unsafe { | ||
feats[i * 2..] | ||
.get_unchecked(..2) | ||
.try_into() | ||
.unwrap_unchecked() | ||
}; | ||
let reproj_err: &mut [f64; 2] = unsafe { | ||
reproj_err[i * 2..] | ||
.get_unchecked_mut(..2) | ||
.try_into() | ||
.unwrap_unchecked() | ||
}; | ||
compute_reproj_error(cam, x, w, feat, reproj_err); | ||
} | ||
|
||
for i in 0..p { | ||
let w_err: &mut f64 = unsafe { w_err.get_unchecked_mut(i) }; | ||
compute_zach_weight_error(w[i..].as_ptr(), w_err as *mut f64); | ||
} | ||
} | ||
|
||
#[no_mangle] | ||
extern "C" fn ba_objective( | ||
n: usize, | ||
m: usize, | ||
p: usize, | ||
cams: *const f64, | ||
x: *const f64, | ||
w: *const f64, | ||
obs: *const i32, | ||
feats: *const f64, | ||
reproj_err: *mut f64, | ||
w_err: *mut f64, | ||
) { | ||
let cams = unsafe { std::slice::from_raw_parts(cams, n * 11) }; | ||
let x = unsafe { std::slice::from_raw_parts(x, m * 3) }; | ||
let w = unsafe { std::slice::from_raw_parts(w, p) }; | ||
let obs = unsafe { std::slice::from_raw_parts(obs, p * 2) }; | ||
let feats = unsafe { std::slice::from_raw_parts(feats, p * 2) }; | ||
let reproj_err = unsafe { std::slice::from_raw_parts_mut(reproj_err, p * 2) }; | ||
let w_err = unsafe { std::slice::from_raw_parts_mut(w_err, p) }; | ||
rust_ba_objective(n, m, p, cams, x, w, obs, feats, reproj_err, w_err); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
use bars::{dcompute_reproj_error, dcompute_zach_weight_error}; | ||
fn main() { | ||
let cam = [0.0; 11]; | ||
let mut dcam = [0.0; 11]; | ||
let x = [0.0; 3]; | ||
let mut dx = [0.0; 3]; | ||
let w = [0.0; 1]; | ||
let mut dw = [0.0; 1]; | ||
let feat = [0.0; 2]; | ||
let mut err = [0.0; 2]; | ||
let mut derr = [0.0; 2]; | ||
dcompute_reproj_error( | ||
&cam as *const [f64;11], | ||
&mut dcam as *mut [f64;11], | ||
&x as *const [f64;3], | ||
&mut dx as *mut [f64;3], | ||
&w as *const [f64;1], | ||
&mut dw as *mut [f64;1], | ||
&feat as *const [f64;2], | ||
&mut err as *mut [f64;2], | ||
&mut derr as *mut [f64;2], | ||
); | ||
|
||
let mut wb = 0.0; | ||
dcompute_zach_weight_error(&w as *const f64, &mut dw as *mut f64, &mut err as *mut f64, &mut derr as *mut f64); | ||
} |