Skip to content

Commit

Permalink
fix!: clean up JS APIs (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
sd2k authored Oct 18, 2024
1 parent b9846f2 commit 7fe33a8
Show file tree
Hide file tree
Showing 25 changed files with 2,507 additions and 185 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/js.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: augurs-js

on:
push:
branches: [ "main" ]
pull_request:

env:
CARGO_TERM_COLOR: always

jobs:
test:
name: JS tests
runs-on: ubuntu-latest
steps:
- name: Checkout sources
uses: actions/checkout@v4

- uses: dtolnay/rust-toolchain@master
with:
toolchain: nightly-2024-09-01
targets: wasm32-unknown-unknown
- uses: taiki-e/install-action@v2
with:
tool: just,wasm-pack

- name: Build augurs-js
run: just build-augurs-js

- uses: actions/setup-node@v4
- name: Install dependencies
run: npm ci
working-directory: crates/augurs-js/testpkg
- name: Run typecheck
run: npm run typecheck
working-directory: crates/augurs-js/testpkg
- name: Run tests
run: npm run test:ci
working-directory: crates/augurs-js/testpkg
5 changes: 4 additions & 1 deletion crates/augurs-js/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ doctest = false
test = false

[features]
default = ["logging"]
logging = ["wasm-tracing"]
parallel = ["wasm-bindgen-rayon"]

[dependencies]
Expand All @@ -38,10 +40,11 @@ js-sys = "0.3.64"
serde.workspace = true
serde-wasm-bindgen = "0.6.0"
tracing.workspace = true
tracing-wasm = { version = "0.2.1", optional = true }
tracing-subscriber = { workspace = true, features = ["registry"], default-features = false }
tsify-next = { version = "0.5.3", default-features = false, features = ["js"] }
wasm-bindgen = "=0.2.93"
wasm-bindgen-rayon = { version = "1.2.1", optional = true }
wasm-tracing = { version = "0.2.1", optional = true }

[package.metadata.wasm-pack.profile.release]
# previously had just ['-O4']
Expand Down
31 changes: 26 additions & 5 deletions crates/augurs-js/src/changepoints.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::num::NonZeroUsize;

use js_sys::Float64Array;
use serde::{Deserialize, Serialize};
use tsify_next::Tsify;
use wasm_bindgen::prelude::*;
Expand All @@ -9,6 +8,8 @@ use augurs_changepoint::{
dist, ArgpcpDetector, BocpdDetector, DefaultArgpcpDetector, Detector, NormalGammaDetector,
};

use crate::VecF64;

#[derive(Debug)]
enum EitherDetector {
NormalGamma(NormalGammaDetector),
Expand All @@ -24,6 +25,18 @@ impl EitherDetector {
}
}

/// The type of changepoint detector to use.
#[derive(Debug, Clone, Copy, Deserialize, Tsify)]
#[serde(rename_all = "kebab-case")]
#[tsify(from_wasm_abi)]
pub enum ChangepointDetectorType {
/// A Bayesian Online Changepoint Detector with a Normal Gamma prior.
NormalGamma,
/// An autoregressive Gaussian Process changepoint detector,
/// with the default kernel and parameters.
DefaultArgpcp,
}

/// A changepoint detector.
#[derive(Debug)]
#[wasm_bindgen]
Expand All @@ -35,6 +48,14 @@ const DEFAULT_HAZARD_LAMBDA: f64 = 250.0;

#[wasm_bindgen]
impl ChangepointDetector {
#[wasm_bindgen(constructor)]
pub fn new(detectorType: ChangepointDetectorType) -> Result<ChangepointDetector, JsValue> {
match detectorType {
ChangepointDetectorType::NormalGamma => Self::normal_gamma(None),
ChangepointDetectorType::DefaultArgpcp => Self::default_argpcp(None),
}
}

/// Create a new Bayesian Online changepoint detector with a Normal Gamma prior.
#[wasm_bindgen(js_name = "normalGamma")]
pub fn normal_gamma(
Expand Down Expand Up @@ -73,10 +94,10 @@ impl ChangepointDetector {

/// Detect changepoints in the given time series.
#[wasm_bindgen(js_name = "detectChangepoints")]
pub fn detect_changepoints(&mut self, y: Float64Array) -> Changepoints {
Changepoints {
indices: self.detector.detect_changepoints(&y.to_vec()),
}
pub fn detect_changepoints(&mut self, y: VecF64) -> Result<Changepoints, JsError> {
Ok(Changepoints {
indices: self.detector.detect_changepoints(&y.convert()?),
})
}
}

Expand Down
80 changes: 68 additions & 12 deletions crates/augurs-js/src/dtw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use wasm_bindgen::prelude::*;

use augurs_dtw::{Euclidean, Manhattan};

use crate::{VecF64, VecVecF64};

enum InnerDtw {
Euclidean(augurs_dtw::Dtw<Euclidean>),
Manhattan(augurs_dtw::Dtw<Manhattan>),
Expand Down Expand Up @@ -89,11 +91,45 @@ pub struct DistanceMatrix {
}

impl DistanceMatrix {
/// Get the inner distance matrix.
pub fn inner(&self) -> &augurs_core::DistanceMatrix {
&self.inner
}
}

#[wasm_bindgen]
impl DistanceMatrix {
/// Create a new `DistanceMatrix` from a raw distance matrix.
#[wasm_bindgen(constructor)]
pub fn new(distanceMatrix: VecVecF64) -> Result<DistanceMatrix, JsError> {
Ok(Self {
inner: augurs_core::DistanceMatrix::try_from_square(distanceMatrix.convert()?)?,
})
}

/// Get the shape of the distance matrix.
#[wasm_bindgen(js_name = shape)]
pub fn shape(&self) -> Vec<usize> {
let (m, n) = self.inner.shape();
vec![m, n]
}

/// Get the distance matrix as an array of arrays.
#[wasm_bindgen(js_name = toArray)]
pub fn to_array(&self) -> Vec<Float64Array> {
self.inner
.clone()
.into_inner()
.into_iter()
.map(|x| {
let arr = Float64Array::new_with_length(x.len() as u32);
arr.copy_from(&x);
arr
})
.collect()
}
}

impl From<augurs_core::DistanceMatrix> for DistanceMatrix {
fn from(inner: augurs_core::DistanceMatrix) -> Self {
Self { inner }
Expand All @@ -106,6 +142,17 @@ impl From<DistanceMatrix> for augurs_core::DistanceMatrix {
}
}

/// The distance function to use for Dynamic Time Warping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Tsify)]
#[serde(rename_all = "lowercase")]
#[tsify(from_wasm_abi)]
pub enum DistanceFunction {
/// Euclidean distance.
Euclidean,
/// Manhattan distance.
Manhattan,
}

/// Dynamic Time Warping.
///
/// The `window` parameter can be used to specify the Sakoe-Chiba band size.
Expand All @@ -119,9 +166,18 @@ pub struct Dtw {

#[wasm_bindgen]
impl Dtw {
/// Create a new `Dtw` instance.
#[wasm_bindgen(constructor)]
pub fn new(distanceFunction: DistanceFunction, opts: Option<DtwOptions>) -> Self {
match distanceFunction {
DistanceFunction::Euclidean => Self::euclidean(opts),
DistanceFunction::Manhattan => Self::manhattan(opts),
}
}

/// Create a new `Dtw` instance using the Euclidean distance.
#[wasm_bindgen]
pub fn euclidean(opts: Option<DtwOptions>) -> Result<Dtw, JsValue> {
pub fn euclidean(opts: Option<DtwOptions>) -> Dtw {
let opts = opts.unwrap_or_default();
let mut dtw = augurs_dtw::Dtw::euclidean();
if let Some(window) = opts.window {
Expand All @@ -140,14 +196,14 @@ impl Dtw {
if let Some(parallelize) = opts.parallelize {
dtw = dtw.parallelize(parallelize);
}
Ok(Dtw {
Dtw {
inner: InnerDtw::Euclidean(dtw),
})
}
}

/// Create a new `Dtw` instance using the Euclidean distance.
/// Create a new `Dtw` instance using the Manhattan distance.
#[wasm_bindgen]
pub fn manhattan(opts: Option<DtwOptions>) -> Result<Dtw, JsValue> {
pub fn manhattan(opts: Option<DtwOptions>) -> Dtw {
let opts = opts.unwrap_or_default();
let mut dtw = augurs_dtw::Dtw::manhattan();
if let Some(window) = opts.window {
Expand All @@ -162,24 +218,24 @@ impl Dtw {
if let Some(upper_bound) = opts.upper_bound {
dtw = dtw.with_upper_bound(upper_bound);
}
Ok(Dtw {
Dtw {
inner: InnerDtw::Manhattan(dtw),
})
}
}

/// Calculate the distance between two arrays under Dynamic Time Warping.
#[wasm_bindgen]
pub fn distance(&self, a: Float64Array, b: Float64Array) -> f64 {
self.inner.distance(&a.to_vec(), &b.to_vec())
pub fn distance(&self, a: VecF64, b: VecF64) -> Result<f64, JsError> {
Ok(self.inner.distance(&a.convert()?, &b.convert()?))
}

/// Compute the distance matrix between all pairs of series.
///
/// The series do not all have to be the same length.
#[wasm_bindgen(js_name = distanceMatrix)]
pub fn distance_matrix(&self, series: Vec<Float64Array>) -> DistanceMatrix {
let vecs: Vec<_> = series.iter().map(|x| x.to_vec()).collect();
pub fn distance_matrix(&self, series: VecVecF64) -> Result<DistanceMatrix, JsError> {
let vecs = series.convert()?;
let slices = vecs.iter().map(Vec::as_slice).collect::<Vec<_>>();
self.inner.distance_matrix(&slices)
Ok(self.inner.distance_matrix(&slices))
}
}
19 changes: 8 additions & 11 deletions crates/augurs-js/src/ets.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
//! JavaScript bindings for the AutoETS model.
use js_sys::Float64Array;
use wasm_bindgen::prelude::*;

use augurs_core::prelude::*;

use crate::Forecast;
use crate::{Forecast, VecF64};

/// Automatic ETS model selection.
#[derive(Debug)]
Expand All @@ -24,9 +23,8 @@ impl AutoETS {
///
/// If the `spec` string is invalid, this function returns an error.
#[wasm_bindgen(constructor)]
pub fn new(seasonLength: usize, spec: String) -> Result<AutoETS, JsValue> {
let inner =
augurs_ets::AutoETS::new(seasonLength, spec.as_str()).map_err(|e| e.to_string())?;
pub fn new(seasonLength: usize, spec: String) -> Result<AutoETS, JsError> {
let inner = augurs_ets::AutoETS::new(seasonLength, spec.as_str())?;
Ok(Self {
inner,
fitted: None,
Expand All @@ -43,8 +41,8 @@ impl AutoETS {
/// If no model can be found, or if any parameters are invalid, this function
/// returns an error.
#[wasm_bindgen]
pub fn fit(&mut self, y: Float64Array) -> Result<(), JsValue> {
self.fitted = Some(self.inner.fit(&y.to_vec()).map_err(|e| e.to_string())?);
pub fn fit(&mut self, y: VecF64) -> Result<(), JsError> {
self.fitted = Some(self.inner.fit(&y.convert()?)?);
Ok(())
}

Expand All @@ -57,13 +55,12 @@ impl AutoETS {
///
/// This function will return an error if no model has been fit yet (using [`AutoETS::fit`]).
#[wasm_bindgen]
pub fn predict(&self, horizon: usize, level: Option<f64>) -> Result<Forecast, JsValue> {
pub fn predict(&self, horizon: usize, level: Option<f64>) -> Result<Forecast, JsError> {
Ok(self
.fitted
.as_ref()
.map(|x| x.predict(horizon, level))
.ok_or("model not fit yet")?
.map(Into::into)
.map_err(|e| e.to_string())?)
.ok_or(JsError::new("model not fit yet"))?
.map(Into::into)?)
}
}
57 changes: 55 additions & 2 deletions crates/augurs-js/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ mod changepoints;
pub mod clustering;
mod dtw;
pub mod ets;
#[cfg(feature = "logging")]
pub mod logging;
pub mod mstl;
mod outlier;
mod prophet;
Expand All @@ -74,8 +76,6 @@ pub mod seasons;
#[wasm_bindgen(start)]
pub fn custom_init() {
console_error_panic_hook::set_once();
#[cfg(feature = "tracing-wasm")]
tracing_wasm::try_set_as_global_default().ok();
}

// Wrapper types for the core types, so we can derive `Tsify` for them.
Expand Down Expand Up @@ -122,3 +122,56 @@ impl From<augurs_core::Forecast> for Forecast {
}
}
}

// These custom types are needed to have the correct TypeScript types generated
// for functions which accept either `number[]` or typed arrays when called
// from Javascript.
// They should always be preferred over using `Vec<T>` directly in functions
// exported to Javascript, even if it is a bit of hassle to convert them.
// They can be converted using:
//
// let y = y.convert()?;
#[wasm_bindgen]
extern "C" {
/// Custom type for `Vec<u32>`.
#[wasm_bindgen(typescript_type = "number[] | Uint32Array")]
#[derive(Debug)]
pub type VecU32;

/// Custom type for `Vec<usize>`.
#[wasm_bindgen(typescript_type = "number[] | Uint32Array")]
#[derive(Debug)]
pub type VecUsize;

/// Custom type for `Vec<f64>`.
#[wasm_bindgen(typescript_type = "number[] | Float64Array")]
#[derive(Debug)]
pub type VecF64;

/// Custom type for `Vec<Vec<f64>>`.
#[wasm_bindgen(typescript_type = "number[][] | Float64Array[]")]
#[derive(Debug)]
pub type VecVecF64;
}

impl VecUsize {
fn convert(self) -> Result<Vec<usize>, JsError> {
serde_wasm_bindgen::from_value(self.into())
.map_err(|_| JsError::new("TypeError: expected array of integers or Uint32Array"))
}
}

impl VecF64 {
fn convert(self) -> Result<Vec<f64>, JsError> {
serde_wasm_bindgen::from_value(self.into())
.map_err(|_| JsError::new("TypeError: expected array of numbers or Float64Array"))
}
}

impl VecVecF64 {
fn convert(self) -> Result<Vec<Vec<f64>>, JsError> {
serde_wasm_bindgen::from_value(self.into()).map_err(|_| {
JsError::new("TypeError: expected array of number arrays or array of Float64Array")
})
}
}
Loading

0 comments on commit 7fe33a8

Please sign in to comment.