Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ica bench #261

Merged
merged 16 commits into from
Nov 5, 2022
5 changes: 5 additions & 0 deletions algorithms/linfa-ica/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,8 @@ linfa = { version = "0.6.0", path = "../.." }
[dev-dependencies]
ndarray-npy = { version = "0.8", default-features = false }
paste = "1.0"
criterion = "0.4.0"

[[bench]]
name = "fast_ica"
harness = false
68 changes: 68 additions & 0 deletions algorithms/linfa-ica/benches/fast_ica.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use linfa::{dataset::DatasetBase, traits::Fit};
use linfa_ica::fast_ica::{FastIca, GFunc};
use ndarray::{array, concatenate};
use ndarray::{Array, Array2, Axis};
use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
use rand_xoshiro::Xoshiro256Plus;

fn perform_ica(size: usize, gfunc: GFunc) {
let sources_mixed = create_data(size);

let ica = FastIca::params().gfunc(gfunc).random_state(10);

let ica = ica.fit(&DatasetBase::from(sources_mixed.view()));
}

fn create_data(nsamples: usize) -> Array2<f64> {
// Creating a sine wave signal
let source1 = Array::linspace(0., 8., nsamples).mapv(|x| (2f64 * x).sin());

// Creating a sawtooth signal
let source2 = Array::linspace(0., 8., nsamples).mapv(|x| {
let tmp = (4f64 * x).sin();
if tmp > 0. {
return 1.;
}
-1.
});

// Column concatenating both the signals
let mut sources_original = concatenate![
Axis(1),
source1.insert_axis(Axis(1)),
source2.insert_axis(Axis(1))
];

// Adding noise to the signals
let mut rng = Xoshiro256Plus::seed_from_u64(42);
sources_original +=
&Array::random_using((nsamples, 2), Uniform::new(0.0, 1.0), &mut rng).mapv(|x| x * 0.2);

// Mixing the two signals
let mixing = array![[1., 1.], [0.5, 2.]];
let sources_mixed = sources_original.dot(&mixing.t());

sources_mixed
}

fn bench(c: &mut Criterion) {
for (gfunc, name) in [
(GFunc::Cube, "GFunc_Cube"),
(GFunc::Logcosh(1.0), "GFunc_Logcosh"),
(GFunc::Exp, "Exp"),
] {
let mut group = c.benchmark_group("Fast ICA");
let sizes: [usize; 3] = [1_000, 10_000, 100_000];
for size in sizes {
let input = (size, gfunc);
group.bench_with_input(BenchmarkId::new(name, size), &input, |b, (size, gfunc)| {
b.iter(|| perform_ica(*size, *gfunc));
});
}
group.finish();
}
}

criterion_group!(benches, bench);
criterion_main!(benches);
2 changes: 1 addition & 1 deletion algorithms/linfa-ica/src/fast_ica.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ impl<F: Float> PredictInplace<Array2<F>, Array2<F>> for FastIca<F> {
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq, Copy)]
pub enum GFunc {
Logcosh(f64),
Exp,
Expand Down
2 changes: 1 addition & 1 deletion src/metrics_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ mod tests {

// randomly sample ground truth
let ground_truth = (0..1000)
.map(|_| rng.sample(&range) == 1)
.map(|_| rng.sample(range) == 1)
.collect::<Vec<_>>();

// ROC Area-Under-Curve should be approximately 0.5
Expand Down