Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ica bench #261

Merged
merged 16 commits into from
Nov 5, 2022
5 changes: 5 additions & 0 deletions algorithms/linfa-ica/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,8 @@ linfa = { version = "0.6.0", path = "../.." }
[dev-dependencies]
ndarray-npy = { version = "0.8", default-features = false }
paste = "1.0"
criterion = "0.4"

[[bench]]
name = "fast_ica"
harness = false
60 changes: 60 additions & 0 deletions algorithms/linfa-ica/benches/fast_ica.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use linfa::{dataset::DatasetBase, traits::Fit};
use linfa_ica::fast_ica::{FastIca, GFunc};
use ndarray::{array, concatenate};
use ndarray::{Array, Array2, Axis};
use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
use rand_xoshiro::Xoshiro256Plus;

fn perform_ica(size: usize) -> () {
oojo12 marked this conversation as resolved.
Show resolved Hide resolved
let sources_mixed = create_data(size);

let ica = FastIca::params().gfunc(GFunc::Logcosh(1.0));
YuhanLiin marked this conversation as resolved.
Show resolved Hide resolved

let ica = ica.fit(&DatasetBase::from(sources_mixed.view()));
}

fn create_data(nsamples: usize) -> Array2<f64> {
// Creating a sine wave signal
let source1 = Array::linspace(0., 8., nsamples).mapv(|x| (2f64 * x).sin());

// Creating a sawtooth signal
let source2 = Array::linspace(0., 8., nsamples).mapv(|x| {
let tmp = (4f64 * x).sin();
if tmp > 0. {
return 1.;
}
-1.
});

// Column concatenating both the signals
let mut sources_original = concatenate![
Axis(1),
source1.insert_axis(Axis(1)),
source2.insert_axis(Axis(1))
];

// Adding noise to the signals
let mut rng = Xoshiro256Plus::seed_from_u64(42);
sources_original +=
&Array::random_using((nsamples, 2), Uniform::new(0.0, 1.0), &mut rng).mapv(|x| x * 0.2);

// Mixing the two signals
let mixing = array![[1., 1.], [0.5, 2.]];
let sources_mixed = sources_original.dot(&mixing.t());

sources_mixed
}

fn bench(c: &mut Criterion) {
let mut group = c.benchmark_group("fast_ica_bench");
for size in [1_000, 10_000, 100_000].iter() {
YuhanLiin marked this conversation as resolved.
Show resolved Hide resolved
group.bench_with_input(BenchmarkId::new("fast-ica-{}", size), size, |b, &size| {
b.iter(|| perform_ica(size));
});
}
group.finish();
}

criterion_group!(benches, bench);
criterion_main!(benches);