Skip to content

Commit

Permalink
feat: parallel DTW calculations in augurs-js (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
sd2k authored Sep 12, 2024
1 parent 22d95f3 commit 7d27a33
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 3 deletions.
1 change: 1 addition & 0 deletions crates/augurs-dtw/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ description = "Dynamic Time Warping (DTW) algorithm for Rust"
[dependencies]
augurs-core.workspace = true
rayon = { version = "1.10.0", optional = true }
tracing.workspace = true

[features]
parallel = ["dep:rayon"]
Expand Down
2 changes: 2 additions & 0 deletions crates/augurs-dtw/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use augurs_core::DistanceMatrix;

#[cfg(feature = "parallel")]
use rayon::prelude::*;
use tracing::debug;

/// A trait for defining a distance function.
///
Expand Down Expand Up @@ -477,6 +478,7 @@ impl<T: Distance + Send + Sync> Dtw<T> {
let matrix = if self.parallelize {
let n = series.len();
let mut matrix = Vec::with_capacity(n);
debug!("Calculating distance matrix in parallel");
series
.par_iter()
.map(|s| {
Expand Down
8 changes: 8 additions & 0 deletions crates/augurs-js/.cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[target.wasm32-unknown-unknown]
rustflags = ["-C", "target-feature=+atomics,+bulk-memory,+mutable-globals"]

[unstable]
build-std = ["panic_abort", "std"]

[build]
target = "wasm32-unknown-unknown"
4 changes: 3 additions & 1 deletion crates/augurs-js/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ crate-type = ["cdylib", "rlib"]

[features]
default = ["console_error_panic_hook"]
parallel = ["wasm-bindgen-rayon"]

[dependencies]
augurs-changepoint = { workspace = true }
augurs-clustering = { workspace = true }
augurs-core = { workspace = true }
augurs-dtw = { workspace = true }
augurs-dtw = { workspace = true, features = ["parallel"] }
augurs-ets = { workspace = true, features = ["mstl"] }
augurs-forecaster.workspace = true
augurs-mstl = { workspace = true }
Expand All @@ -39,3 +40,4 @@ serde-wasm-bindgen = "0.6.0"
tracing-wasm = { version = "0.2.1", optional = true }
tsify-next = { version = "0.5.3", default-features = false, features = ["js"] }
wasm-bindgen = "0.2.87"
wasm-bindgen-rayon = { version = "1.2.1", optional = true }
29 changes: 29 additions & 0 deletions crates/augurs-js/prepublish.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// This script adds the "snippets/" directory to the files array in package.json.
// Needed because of https://github.com/rustwasm/wasm-pack/issues/1206.
const fs = require('fs');
const path = require('path');

try {
const pkgPath = path.join(__dirname, "pkg/package.json");

// Check if package.json exists
if (!fs.existsSync(pkgPath)) {
console.error(`Error: File ${pkgPath} not found.`);
process.exit(1);
}

const pkg = JSON.parse(fs.readFileSync(pkgPath, 'utf-8'));

// Add snippets to the files array. If no files array exists, create one.
pkg.files = pkg.files || [];
if (!pkg.files.includes('snippets/')) {
pkg.files.push('snippets/');
fs.writeFileSync(pkgPath, JSON.stringify(pkg, null, 2));
console.log('Successfully added "snippets/" to package.json.');
} else {
console.log('"snippets/" already exists in package.json.');
}
} catch (error) {
console.error(`An error occurred: ${error.message}`);
process.exit(1);
}
8 changes: 8 additions & 0 deletions crates/augurs-js/rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Build augurs-js with the nightly toolchain and wasm32-unknown-unknown target.
# This is required for the `wasm-bindgen-rayon` dependency, which requires
# some nightly-only features (see .cargo/config.toml).
[toolchain]
channel = "nightly-2024-09-01"
components = ["rust-src"]
targets = ["wasm32-unknown-unknown"]
profile = "minimal"
10 changes: 9 additions & 1 deletion crates/augurs-js/src/dtw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ pub struct DtwOpts {
/// calculation and return this bound instead.
#[tsify(optional)]
pub upper_bound: Option<f64>,

#[tsify(optional)]
pub parallelize: Option<bool>,
}

/// A distance matrix.
Expand Down Expand Up @@ -131,6 +134,9 @@ impl Dtw {
if let Some(upper_bound) = opts.upper_bound {
dtw = dtw.with_upper_bound(upper_bound);
}
if let Some(parallelize) = opts.parallelize {
dtw = dtw.parallelize(parallelize);
}
Ok(Dtw {
inner: InnerDtw::Euclidean(dtw),
})
Expand Down Expand Up @@ -159,14 +165,16 @@ impl Dtw {
}

/// Calculate the distance between two arrays under Dynamic Time Warping.
#[wasm_bindgen]
pub fn distance(&self, a: Float64Array, b: Float64Array) -> f64 {
self.inner.distance(&a.to_vec(), &b.to_vec())
}

/// Compute the distance matrix between all pairs of series.
///
/// The series do not all have to be the same length.
pub fn distanceMatrix(&self, series: Vec<Float64Array>) -> DistanceMatrix {
#[wasm_bindgen(js_name = distanceMatrix)]
pub fn distance_matrix(&self, series: Vec<Float64Array>) -> DistanceMatrix {
let vecs: Vec<_> = series.iter().map(|x| x.to_vec()).collect();
let slices = vecs.iter().map(Vec::as_slice).collect::<Vec<_>>();
self.inner.distance_matrix(&slices)
Expand Down
43 changes: 43 additions & 0 deletions crates/augurs-js/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,49 @@ use serde::Serialize;
use tsify_next::Tsify;
use wasm_bindgen::prelude::*;

/// Initialize the rayon thread pool.
///
/// This must be called once (from a Javascript context) and awaited
/// before using parallel mode of algorithms, to set up the thread pool.
///
/// # Example (JS)
///
/// ```js
/// // worker.ts
/// import init, { Dbscan, Dtw, initThreadPool} from '@bsull/augurs';
///
/// init().then(async () => {
/// console.debug('augurs initialized');
/// await initThreadPool(navigator.hardwareConcurrency * 2);
/// console.debug('augurs thread pool initialized');
/// });
///
/// export function dbscan(series: Float64Array[], epsilon: number, minClusterSize: number): number[] {
/// const distanceMatrix = Dtw.euclidean({ window: 10, parallelize: true }).distanceMatrix(series);
/// const clusterLabels = new Dbscan({ epsilon, minClusterSize }).fit(distanceMatrix);
/// return Array.from(clusterLabels);
/// }
///
/// // index.js
/// import { dbscan } from './worker';
///
/// async function runClustering(series: Float64Array[]): Promise<number[]> {
/// return dbscan(series, 0.1, 10); // await only required if using workerize-loader
/// }
///
/// // or using e.g. workerize-loader to run in a dedicated worker:
/// import worker from 'workerize-loader?ready&name=augurs!./worker';
///
/// const instance = worker()
///
/// async function runClustering(series: Float64Array[]): Promise<number[]> {
/// await instance.ready;
/// return instance.dbscan(series, 0.1, 10);
/// }
/// ```
#[cfg(feature = "parallel")]
pub use wasm_bindgen_rayon::init_thread_pool;

mod changepoints;
pub mod clustering;
mod dtw;
Expand Down
3 changes: 2 additions & 1 deletion justfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Build and publish the augurs-js package to npm with the @bsull scope.
publish-npm:
cd crates/augurs-js && \
wasm-pack build --release --scope bsull --target web && \
wasm-pack build --release --scope bsull --target web -- --features parallel && \
node prepublish && \
wasm-pack publish --access public

0 comments on commit 7d27a33

Please sign in to comment.