
chore: add benchmark for Prophet #140

Merged (4 commits) on Oct 25, 2024

Conversation

sd2k (Collaborator) commented on Oct 24, 2024

Summary by CodeRabbit

  • New Features

    • Introduced integration tests for various functionalities of the augurs crate, including changepoint detection, clustering, dynamic time warping, exponential smoothing, forecasting, MSTL, outlier detection, and seasonal detection.
    • Added benchmarking tests for the Prophet model, measuring performance for fitting and predicting.
    • Added new benchmarks for the Dynamic Time Warping algorithm and the MSTL model on the VIC_ELEC dataset.
  • Chores

    • Updated the Cargo.toml file to include tracing-subscriber and criterion as development dependencies.
    • Modified the Prophet struct to support cloning, enhancing usability.
    • Updated the Modes struct to support cloning as well.
    • Added documentation suppression attributes to various benchmark files.
    • Expanded exclusion criteria in the test-all target of the justfile to exclude additional benchmarks.

Not sure if this will pass in CI?
coderabbitai bot (Contributor) commented on Oct 24, 2024

Walkthrough

The pull request introduces updates to the augurs-prophet package by adding new development dependencies, tracing-subscriber and criterion, to the Cargo.toml file. Additionally, it enhances the integration tests for the augurs wrapper crate by implementing multiple test functions that cover various functionalities, including changepoint detection, clustering, dynamic time warping, exponential smoothing, forecasting, MSTL, outlier detection, and seasonal detection. The real-life.rs file is added for benchmarking the Prophet model, and modifications are made to several structs to enable cloning capabilities.

Changes

  • crates/augurs-prophet/Cargo.toml: Added tracing-subscriber.workspace = true and criterion.workspace = true under [dev-dependencies]. Added a [lib] section with bench = false and a new [[bench]] named real-life.
  • crates/augurs/tests/integration.rs: Introduced multiple integration tests for various features of the augurs crate, organized by feature flags. Tests include changepoint detection, clustering, DTW, ETS, forecasting, MSTL, outlier detection, and seasonal detection.
  • crates/augurs-prophet/benches/real-life.rs: Added benchmarking tests for the Prophet model's fit and predict functions using the criterion library.
  • crates/augurs-prophet/src/prophet.rs: Updated the Prophet struct to derive Clone in addition to Debug.
  • crates/augurs-prophet/src/prophet/predict.rs: Refactored methods in Prophet for trend prediction and uncertainty estimation, simplifying handling of sampled values.
  • crates/augurs-prophet/src/prophet/prep.rs: Updated the Modes struct to derive Clone in addition to existing traits.
  • crates/augurs-clustering/benches/dbscan.rs: Added #![allow(missing_docs)] to suppress warnings for missing documentation.
  • crates/augurs-dtw/benches/dtw.rs: Introduced benchmarking functions for DTW distance and distance matrix calculations.
  • crates/augurs-ets/benches/air_passengers.rs: Added #![allow(missing_docs)].
  • crates/augurs-ets/benches/air_passengers_iai.rs: Added #![allow(missing_docs)].
  • crates/augurs-mstl/benches/vic_elec.rs: Added #![allow(missing_docs)].
  • crates/augurs-mstl/benches/vic_elec_iai.rs: Added new benchmark functions for MSTL model fitting on the VIC_ELEC dataset.
  • crates/augurs-seasons/benches/periodogram.rs: Added #![allow(missing_docs)].
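Putting the Cargo.toml rows above together, the manifest changes look roughly like the sketch below. The [dev-dependencies] and [lib] lines restate the summary; harness = false on the [[bench]] target is an assumption, since criterion benchmarks normally replace the default libtest harness with their own main():

```toml
[dev-dependencies]
tracing-subscriber.workspace = true
criterion.workspace = true

[lib]
# Stated in the summary: disable the default bench harness for the library target.
bench = false

[[bench]]
name = "real-life"
# Assumed: criterion provides its own entry point, so the libtest harness is disabled.
harness = false
```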

Possibly related PRs

  • chore: release #96: The changes in the main PR involve updates to the Cargo.toml file, including the addition of new dependencies and a new [lib] section, which aligns with the changes made in this PR that also updates the Cargo.toml for versioning and dependencies.
  • feat: add cmdstan-based optimizer for augurs-prophet #121: The main PR introduces a cmdstan optimizer for the augurs-prophet crate, which is related to the changes in this PR that enhance the cmdstan functionality and logging capabilities.
  • feat: add Prophet functionality to augurs-js #125: The main PR adds Prophet functionality to the augurs-js crate, which directly relates to the changes in this PR that introduce the augurs-prophet dependency and its features into the JavaScript bindings.
  • chore: release v0.4.1 #132: The main PR reflects a version bump for the workspace and its dependencies, which is consistent with the changes in this PR that also update the versions across multiple packages.
  • chore: release v0.4.3 #134: The main PR continues the versioning strategy by updating all packages to version 0.5.1, which is in line with the changes in this PR that also increment the versions of all dependencies.
  • chore: release v0.5.0 #137: The main PR announces the release of version 0.5.0, which is relevant as it sets the stage for the subsequent version updates seen in this PR.
  • chore: release v0.5.1 #139: The main PR details the release of version 0.5.1, which is directly related to the changes in this PR that also update the versions across all packages to 0.5.1.

🐇 In the garden, we hop and play,
With tests and crates, we pave the way.
New dependencies join our fun,
In the world of code, we’ve just begun!
So let’s cheer for changes, big and small,
Together we’ll conquer, one and all! 🌼



coderabbitai bot (Contributor) left a review comment

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (9)
crates/augurs/tests/integration.rs (9)

1-2: Enhance file-level documentation.

Consider adding more comprehensive documentation that includes:

  • Purpose and scope of these integration tests
  • Prerequisites for running the tests
  • Description of the feature flags and their dependencies
-//! Integration tests for the augurs wrapper crate.
+//! Integration tests for the augurs wrapper crate.
+//!
+//! This file contains integration tests for various features of the augurs crate:
+//! - Changepoint detection
+//! - Clustering
+//! - Dynamic Time Warping (DTW)
+//! - Exponential Smoothing (ETS)
+//! - Forecasting
+//! - MSTL (Multiple Seasonal-Trend decomposition using LOESS)
+//! - Outlier detection
+//! - Seasonal detection
+//!
+//! Each test is gated behind a feature flag and requires the corresponding feature
+//! to be enabled during testing.

Line range hint 4-24: Document test data and expected results.

The test uses synthetic data but lacks documentation explaining the data pattern and why specific indices are expected as changepoints.

 #[cfg(feature = "changepoint")]
 #[test]
 fn test_changepoint() {
     use augurs::changepoint::{ArgpcpDetector, Detector};
+    // Create synthetic data with two segments:
+    // - First segment (indices 0-33): constant value of 1.0
+    // - Second segment (indices 34-67): constant value of 2.0
+    // Expected changepoints: 0 (start) and 33 (transition point)
     let data = vec![
         1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
         1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0,
         2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
         2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
     ];
     let changepoints = ArgpcpDetector::builder().build().detect_changepoints(&data);
-    // 1 changepoint, but the start is considered a changepoint too.
+    // Verify both the start (0) and the transition point (33) are detected
     assert_eq!(changepoints, vec![0, 33]);
 }

Line range hint 26-63: Improve test structure and documentation for clustering tests.

The test uses different parameter combinations but lacks clear documentation about the test cases and expected outcomes.

Consider restructuring the test into separate test cases with clear documentation:

 #[cfg(feature = "clustering")]
-#[test]
-fn test_clustering() {
+mod clustering_tests {
     use augurs::{clustering::DbscanClusterer, DistanceMatrix};
-    let distance_matrix = vec![
+    
+    /// Creates a sample distance matrix for testing:
+    /// - Points 0 and 1 are close (distance 1.0)
+    /// - Point 2 is moderately distant (distance 2.0-3.0)
+    /// - Point 3 is far from all others (distance 3.0-4.0)
+    fn create_test_matrix() -> DistanceMatrix {
+        let matrix = vec![
         vec![0.0, 1.0, 2.0, 3.0],
         vec![1.0, 0.0, 3.0, 3.0],
         vec![2.0, 3.0, 0.0, 4.0],
         vec![3.0, 3.0, 4.0, 0.0],
-    ];
-    let distance_matrix = DistanceMatrix::try_from_square(distance_matrix).unwrap();
+        ];
+        DistanceMatrix::try_from_square(matrix).unwrap()
+    }
+
+    #[test]
+    fn test_strict_clustering() {
+        // Test with strict parameters (eps=0.5, min_points=2)
+        // Expect no clusters due to strict distance threshold
+        let distance_matrix = create_test_matrix();
+        let clusters = DbscanClusterer::new(0.5, 2).fit(&distance_matrix);
+        assert_eq!(clusters, vec![-1, -1, -1, -1], "All points should be noise");
+    }
+
+    #[test]
+    fn test_moderate_clustering() {
+        // Test with moderate parameters (eps=1.0, min_points=2)
+        // Expect points 0 and 1 to form a cluster
+        let distance_matrix = create_test_matrix();
+        let clusters = DbscanClusterer::new(1.0, 2).fit(&distance_matrix);
+        assert_eq!(clusters, vec![0, 0, -1, -1], "Points 0,1 should form a cluster");
+    }
     
-    let clusters = DbscanClusterer::new(0.5, 2).fit(&distance_matrix);
-    assert_eq!(clusters, vec![-1, -1, -1, -1]);
-
-    let clusters = DbscanClusterer::new(1.0, 2).fit(&distance_matrix);
-    assert_eq!(clusters, vec![0, 0, -1, -1]);
-
-    let clusters = DbscanClusterer::new(1.0, 3).fit(&distance_matrix);
-    assert_eq!(clusters, vec![-1, -1, -1, -1]);
-
-    let clusters = DbscanClusterer::new(2.0, 2).fit(&distance_matrix);
-    assert_eq!(clusters, vec![0, 0, 0, -1]);
-
-    let clusters = DbscanClusterer::new(2.0, 3).fit(&distance_matrix);
-    assert_eq!(clusters, vec![0, 0, 0, -1]);
-
-    let clusters = DbscanClusterer::new(3.0, 3).fit(&distance_matrix);
-    assert_eq!(clusters, vec![0, 0, 0, 0]);
+    // Add remaining test cases with similar documentation...
 }

Line range hint 65-95: Add error handling and improve documentation for DTW tests.

The tests cover good cases but could benefit from better documentation and error handling.

 #[cfg(feature = "dtw")]
 #[test]
 fn test_dtw() {
     use augurs::dtw::Dtw;
     use augurs_testing::assert_approx_eq;
+    // Test DTW distance between two sequences with constant difference of 3.0
     let result = Dtw::euclidean().distance(&[0.0, 1.0, 2.0], &[3.0, 4.0, 5.0]);
     assert_approx_eq!(result, 5.0990195135927845);
+
+    // Test with empty sequences
+    assert!(Dtw::euclidean().distance(&[], &[]).is_nan());
 }

 #[cfg(feature = "dtw")]
 #[test]
 fn test_dtw_distance_matrix() {
     use augurs::dtw::Dtw;
     use augurs_testing::assert_all_close;

     let dtw = Dtw::euclidean();
+    // Test with equal length sequences
     let series: &[&[f64]] = &[&[0.0, 1.0, 2.0], &[3.0, 4.0, 5.0], &[6.0, 7.0, 8.0]];
     let dists = dtw.distance_matrix(series);
     assert_eq!(dists.shape(), (3, 3));
     assert_all_close(&dists[0], &[0.0, 5.0990195135927845, 10.392304845413264]);

-    // Test with different length series.
+    // Test with varying length sequences to verify DTW's ability to handle unequal lengths
     let dtw = Dtw::euclidean();
     let series: &[&[f64]] = &[&[0.0, 1.0, 2.0], &[3.0], &[6.0, 7.0]];
     let dists = dtw.distance_matrix(series);
     assert_eq!(dists.shape(), (3, 3));
     assert_all_close(&dists[0], &[0.0, 3.7416573867739413, 9.273618495495704]);
+
+    // Test with empty sequences
+    let series: &[&[f64]] = &[&[], &[]];
+    let dists = dtw.distance_matrix(series);
+    assert_eq!(dists.shape(), (2, 2));
 }

Line range hint 97-121: Enhance error handling and test coverage for ETS model.

The test verifies basic functionality but could benefit from better error handling and edge cases.

 #[cfg(feature = "ets")]
 #[test]
 fn test_ets() {
     use augurs::{
         ets::{
             model::{ErrorComponent, SeasonalComponent, TrendComponent},
             AutoETS,
         },
         prelude::*,
     };
     use augurs_testing::{assert_within_pct, data::AIR_PASSENGERS};

+    // Test with valid data
     let auto = AutoETS::non_seasonal();
-    let fit = auto.fit(AIR_PASSENGERS).expect("fit failed");
+    let fit = auto.fit(AIR_PASSENGERS).unwrap_or_else(|e| {
+        panic!("Failed to fit model: {}", e);
+    });
+
+    // Verify model components
     assert_eq!(
         fit.model().model_type().error,
         ErrorComponent::Multiplicative
     );
     assert_eq!(fit.model().model_type().trend, TrendComponent::Additive);
     assert_eq!(fit.model().model_type().season, SeasonalComponent::None);
     assert_within_pct!(fit.model().log_likelihood(), -831.4883541595792, 0.01);
     assert_within_pct!(fit.model().aic(), 1672.9767083191584, 0.01);
+
+    // Test with edge cases
+    assert!(AutoETS::non_seasonal().fit(&[]).is_err(), "Empty data should fail");
+    assert!(
+        AutoETS::non_seasonal().fit(&[f64::NAN]).is_err(),
+        "NaN data should fail"
+    );
 }

Line range hint 123-166: Remove debug statement and improve test documentation.

The test includes a debug print statement and could benefit from better documentation of the transformation pipeline.

 #[cfg(feature = "forecaster")]
 #[test]
 fn test_forecaster() {
     use augurs::{
         forecaster::{transforms::MinMaxScaleParams, Forecaster, Transform},
         mstl::{MSTLModel, NaiveTrend},
     };
     use augurs_testing::{assert_all_close, data::AIR_PASSENGERS};
     use itertools::{Itertools, MinMaxResult};

+    // Calculate data bounds for scaling
     let MinMaxResult::MinMax(min, max) = AIR_PASSENGERS
         .iter()
         .copied()
         .minmax_by(|a, b| a.partial_cmp(b).unwrap())
     else {
         unreachable!()
     };
+
+    // Create transformation pipeline:
+    // 1. Linear interpolation for missing values
+    // 2. Min-max scaling to [0,1] range
+    // 3. Logit transform for bounded forecasting
     let transforms = vec![
         Transform::linear_interpolator(),
         Transform::min_max_scaler(MinMaxScaleParams::new(min - 1e-3, max + 1e-3)),
         Transform::logit(),
     ];
+
+    // Create and fit the forecasting model
     let model = MSTLModel::new(vec![2], NaiveTrend::new());
     let mut forecaster = Forecaster::new(model).with_transforms(transforms);
     forecaster.fit(AIR_PASSENGERS).unwrap();
+
+    // Generate and verify forecasts
     let forecasts = forecaster.predict(4, None).unwrap();
-    dbg!(&forecasts.point);
     assert_all_close(
         &forecasts.point,
         &[
             559.0587706145459,
             432.00000550710956,
             559.0587706145459,
             432.00000550710956,
         ],
     );
+
+    // Test error handling
+    assert!(forecaster.predict(0, None).is_err(), "Zero horizon should fail");
 }

Line range hint 168-238: Improve test structure and add validation for confidence intervals.

The test covers comprehensive functionality but could benefit from better structure and additional validations.

 #[cfg(feature = "mstl")]
-#[test]
-fn test_mstl() {
+mod mstl_tests {
     use augurs::{
         mstl::{stlrs, MSTLModel, NaiveTrend},
         prelude::*,
     };
     use augurs_testing::{assert_all_close, data::VIC_ELEC};

-    let mut stl_params = stlrs::params();
-    stl_params
-        .seasonal_degree(0)
-        .seasonal_jump(1)
-        .trend_degree(1)
-        .trend_jump(1)
-        .low_pass_degree(1)
-        .inner_loops(2)
-        .outer_loops(0);
-    let mut mstl_params = stlrs::MstlParams::new();
-    mstl_params.stl_params(stl_params);
-    let periods = vec![24, 24 * 7];
-    let trend_model = NaiveTrend::new();
-    let mstl = MSTLModel::new(periods, trend_model).mstl_params(mstl_params);
-    let fit = mstl.fit(&VIC_ELEC).unwrap();
+    fn create_test_model() -> MSTLModel<NaiveTrend> {
+        let mut stl_params = stlrs::params();
+        stl_params
+            .seasonal_degree(0)
+            .seasonal_jump(1)
+            .trend_degree(1)
+            .trend_jump(1)
+            .low_pass_degree(1)
+            .inner_loops(2)
+            .outer_loops(0);
+        let mut mstl_params = stlrs::MstlParams::new();
+        mstl_params.stl_params(stl_params);
+        let periods = vec![24, 24 * 7]; // Daily and weekly seasonality
+        let trend_model = NaiveTrend::new();
+        MSTLModel::new(periods, trend_model).mstl_params(mstl_params)
+    }

+    #[test]
+    fn test_in_sample_predictions() {
+        let mstl = create_test_model();
+        let fit = mstl.fit(&VIC_ELEC).unwrap();
+        let in_sample = fit.predict_in_sample(0.95).unwrap();
+
+        // Verify predictions length
+        assert_eq!(in_sample.point.len(), VIC_ELEC.len());
+
+        // Verify first 12 values against R implementation
+        let expected_in_sample = vec![
+            f64::NAN,
+            7952.216,
+            7269.439,
+            6878.110,
+            6606.999,
+            6402.581,
+            6659.523,
+            7457.488,
+            8111.359,
+            8693.762,
+            9255.807,
+            9870.213,
+        ];
+        assert_all_close(&in_sample.point[..12], &expected_in_sample);
+    }

+    #[test]
+    fn test_out_of_sample_predictions() {
+        let mstl = create_test_model();
+        let fit = mstl.fit(&VIC_ELEC).unwrap();
+        let out_of_sample = fit.predict(10, 0.95).unwrap();
+
+        // Verify predictions
+        let expected_out_of_sample = vec![
+            8920.670, 8874.234, 8215.508, 7782.726, 7697.259,
+            8216.241, 9664.907, 10914.452, 11536.929, 11664.737,
+        ];
+        assert_all_close(&out_of_sample.point, &expected_out_of_sample);
+
+        // Verify confidence intervals
+        let ForecastIntervals { lower, upper, .. } = out_of_sample.intervals.unwrap();
+        assert_eq!(lower.len(), 10);
+        assert_eq!(upper.len(), 10);
+
+        let expected_lower = vec![
+            8700.984, 8563.551, 7835.001, 7343.354, 7206.026,
+            7678.122, 9083.672, 10293.087, 10877.871, 10970.029,
+        ];
+        let expected_upper = vec![
+            9140.356, 9184.917, 8596.016, 8222.098, 8188.491,
+            8754.359, 10246.141, 11535.818, 12195.987, 12359.445,
+        ];
+        assert_all_close(&lower, &expected_lower);
+        assert_all_close(&upper, &expected_upper);
+
+        // Verify intervals are properly ordered
+        for (l, u) in lower.iter().zip(upper.iter()) {
+            assert!(l < u, "Lower bound should be less than upper bound");
+        }
+    }
 }

Line range hint 240-297: Improve test data documentation and add edge cases.

The outlier detection tests could benefit from better documentation and additional test cases.

 #[cfg(feature = "outlier")]
+mod outlier_tests {
+    use augurs::outlier::{DbscanDetector, MADDetector, OutlierDetector};
+
+    /// Creates test data with known outliers:
+    /// - First two series are similar (normal)
+    /// - Third series has outliers in its last two points
+    fn create_test_data() -> Vec<Vec<f64>> {
+        vec![
+            vec![1.0, 2.0, 1.5, 2.3],
+            vec![1.9, 2.2, 1.2, 2.4],
+            vec![1.5, 2.1, 6.4, 8.5], // Contains outliers
+        ]
+    }
+
     #[test]
     fn test_outlier_dbscan() {
-        use augurs::outlier::{DbscanDetector, OutlierDetector};
-        let data: &[&[f64]] = &[
-            &[1.0, 2.0, 1.5, 2.3],
-            &[1.9, 2.2, 1.2, 2.4],
-            &[1.5, 2.1, 6.4, 8.5],
-        ];
+        let data = create_test_data();
+        let data_refs: Vec<&[f64]> = data.iter().map(|v| v.as_slice()).collect();
+
         let detector =
             DbscanDetector::with_sensitivity(0.5).expect("sensitivity is between 0.0 and 1.0");
-        let processed = detector.preprocess(data).unwrap();
+        let processed = detector.preprocess(&data_refs).unwrap();
         let outliers = detector.detect(&processed).unwrap();

+        // Verify outlier detection
         assert_eq!(outliers.outlying_series.len(), 1);
         assert!(outliers.outlying_series.contains(&2));
         assert!(outliers.series_results[2].is_outlier);
         assert_eq!(outliers.series_results[2].scores, vec![0.0, 0.0, 1.0, 1.0]);
         assert!(outliers.cluster_band.is_some());
+
+        // Test edge cases
+        let empty: Vec<&[f64]> = vec![];
+        assert!(detector.preprocess(&empty).is_err(), "Empty data should fail");
     }

     #[test]
     fn test_outlier_mad() {
-        use augurs::outlier::{MADDetector, OutlierDetector};
-        let data: &[&[f64]] = &[
-            &[1.0, 2.0, 1.5, 2.3],
-            &[1.9, 2.2, 1.2, 2.4],
-            &[1.5, 2.1, 6.4, 8.5],
-        ];
+        let data = create_test_data();
+        let data_refs: Vec<&[f64]> = data.iter().map(|v| v.as_slice()).collect();
+
         let detector = MADDetector::with_sensitivity(0.5).unwrap();
-        let processed = detector.preprocess(data).unwrap();
+        let processed = detector.preprocess(&data_refs).unwrap();
         let outliers = detector.detect(&processed).unwrap();

+        // Verify outlier detection
         assert_eq!(outliers.outlying_series.len(), 1);
         assert!(outliers.outlying_series.contains(&2));
         assert!(outliers.series_results[2].is_outlier);
         assert_eq!(
             outliers.series_results[2].scores,
             vec![
                 0.6835259767082061,
                 0.057793242408848366,
                 5.028012089569781,
                 7.4553282707414
             ]
         );
         assert!(outliers.cluster_band.is_some());
+
+        // Test with invalid sensitivity
+        assert!(MADDetector::with_sensitivity(-1.0).is_err());
+        assert!(MADDetector::with_sensitivity(1.5).is_err());
     }
 }

Line range hint 299-332: Enhance seasonal detection test with better documentation and assertions.

The test uses synthetic data but lacks documentation about the data pattern and could include more comprehensive assertions.

 #[cfg(feature = "seasons")]
 #[test]
 fn test_seasonal() {
     use augurs::seasons::{Detector, PeriodogramDetector};

+    // Create synthetic data with a period of 4:
+    // - Pattern repeats every 4 points
+    // - Contains 8 complete cycles
     #[rustfmt::skip]
-        let y = &[
-            0.1, 0.3, 0.8, 0.5,
-            0.1, 0.31, 0.79, 0.48,
-            0.09, 0.29, 0.81, 0.49,
-            0.11, 0.28, 0.78, 0.53,
-            0.1, 0.3, 0.8, 0.5,
-            0.1, 0.31, 0.79, 0.48,
-            0.09, 0.29, 0.81, 0.49,
-            0.11, 0.28, 0.78, 0.53,
-        ];
+    let y = &[
+        0.1, 0.3, 0.8, 0.5,   // Cycle 1
+        0.1, 0.31, 0.79, 0.48, // Cycle 2
+        0.09, 0.29, 0.81, 0.49, // Cycle 3
+        0.11, 0.28, 0.78, 0.53, // Cycle 4
+        0.1, 0.3, 0.8, 0.5,    // Cycle 5
+        0.1, 0.31, 0.79, 0.48, // Cycle 6
+        0.09, 0.29, 0.81, 0.49, // Cycle 7
+        0.11, 0.28, 0.78, 0.53, // Cycle 8
+    ];
+
+    // Detect seasonal periods
     let periods = PeriodogramDetector::default().detect(y);
-    assert_eq!(periods[0], 4);
+
+    // Verify detected periods
+    assert!(!periods.is_empty(), "Should detect at least one period");
+    assert_eq!(periods[0], 4, "Primary period should be 4");
+
+    // Test edge cases
+    let empty: Vec<f64> = vec![];
+    assert!(PeriodogramDetector::default().detect(&empty).is_empty());
+
+    // Test with non-seasonal data
+    let random = vec![1.0, 2.0, 3.0, 4.0];
+    assert!(PeriodogramDetector::default().detect(&random).is_empty());
 }
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR, between commits 7e84a6b and 22a5e83.

📒 Files selected for processing (2)
  • crates/augurs-prophet/Cargo.toml (1 hunks)
  • crates/augurs/tests/integration.rs (1 hunks)
🔇 Additional comments (1)
crates/augurs-prophet/Cargo.toml (1)

32-32: LGTM! The addition of tracing-subscriber is well-placed.

The addition of tracing-subscriber as a development dependency is appropriate for supporting the new integration tests, and correctly uses workspace-level version management.

coderabbitai bot (Contributor) left a review comment

Actionable comments posted: 6

🧹 Outside diff range and nitpick comments (2)
crates/augurs-prophet/Cargo.toml (2)

31-31: Consider limiting tracing-subscriber features

While the dev-dependencies are appropriate, consider specifying only the required features for tracing-subscriber instead of using default-features = true. This can help reduce compilation time during development.

-tracing-subscriber = { workspace = true, default-features = true }
+tracing-subscriber = { workspace = true, default-features = false, features = ["fmt"] }

Also applies to: 33-33


53-56: Consider splitting benchmarks into a separate PR

While the benchmarks are valuable, they seem tangential to the main PR objective of adding integration tests. Consider:

  1. Splitting the benchmark addition into a separate PR to maintain focused changes
  2. This would simplify CI configuration and reduce the risk of CI failures
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR, between commits 22a5e83 and 4c52507.

📒 Files selected for processing (5)
  • crates/augurs-prophet/Cargo.toml (2 hunks)
  • crates/augurs-prophet/benches/real-life.rs (1 hunks)
  • crates/augurs-prophet/src/prophet.rs (1 hunks)
  • crates/augurs-prophet/src/prophet/predict.rs (5 hunks)
  • crates/augurs-prophet/src/prophet/prep.rs (1 hunks)
🔇 Additional comments (9)
crates/augurs-prophet/Cargo.toml (1)

50-51: LGTM: Appropriate library configuration

Disabling the default benchmark harness is correct since we're using criterion for custom benchmarks.

crates/augurs-prophet/src/prophet.rs (1)

21-21: Consider performance implications of cloning.

While adding Clone is necessary for integration testing, be mindful that cloning a trained Prophet model with large datasets could be memory-intensive due to deep copying of all internal buffers (training data, parameters, etc.).

Let's check the size of the struct's fields:

Consider implementing a more memory-efficient approach if cloning becomes a bottleneck:

  1. Use reference counting (Arc) for sharing large immutable data
  2. Implement a custom Clone that only copies necessary fields
  3. Add benchmarks to measure cloning overhead
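The Arc suggestion above can be sketched in isolation. TrainingData and Model here are hypothetical stand-ins for the crate's internal types, illustrating only why reference counting avoids the deep copy:

```rust
use std::sync::Arc;

// Hypothetical stand-in for a trained model's large immutable buffers.
struct TrainingData {
    y: Vec<f64>,
}

// Cloning `Model` copies only the Arc pointer, not the underlying vector.
#[derive(Clone)]
struct Model {
    data: Arc<TrainingData>,
}

fn main() {
    let model = Model {
        data: Arc::new(TrainingData { y: vec![0.0; 1_000_000] }),
    };
    let copy = model.clone();
    // Both handles point at the same allocation; no buffers were copied.
    assert!(Arc::ptr_eq(&model.data, &copy.data));
    assert_eq!(Arc::strong_count(&model.data), 2);
    assert_eq!(copy.data.y.len(), 1_000_000);
}
```

The trade-off is that the shared data becomes immutable (or needs interior mutability), which is why a custom Clone for only the mutable fields is listed as an alternative.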
crates/augurs-prophet/src/prophet/prep.rs (1)

31-31: LGTM: Adding Clone trait is appropriate.

The addition of the Clone trait to the Modes struct is well-justified as both its fields (HashSet<ComponentName>) already implement Clone. This change aligns with similar modifications in the codebase and supports the integration testing objectives.

crates/augurs-prophet/benches/real-life.rs (3)

1-8: LGTM: Appropriate imports for benchmarking Prophet

The imports cover all necessary components for benchmarking Prophet's performance, including criterion for benchmarking, Prophet types, and testing utilities.


10-35: LGTM: Well-structured benchmark for model fitting

The benchmark is well-implemented with:

  • Appropriate Prophet options for testing
  • Batched benchmarking for accurate measurements
  • Fixed seed (100) for reproducible results
  • Proper cloning of model and training data between iterations

2110-2111: LGTM: Proper criterion benchmark setup

The benchmark registration using criterion_group! and criterion_main! follows best practices for criterion benchmarks.

crates/augurs-prophet/src/prophet/predict.rs (3)

412-417: Efficient use of mutable buffers to reduce allocations

Introducing mutable buffers yhat and trend to reuse in each iteration reduces memory allocations and can improve performance in the sample_posterior_predictive method.


456-460: Clear buffers before reuse in sample_model

Clearing yhat_tmp and trend_tmp at the start of sample_model ensures that old data does not persist between iterations. This practice is essential for maintaining data integrity when reusing buffers.


484-487: Verify alignment and lengths in yhat computation

Ensure that the iterators trend_tmp, xb_a, xb_m, and noise are of the same length and correctly aligned. Misalignment could lead to incorrect calculations or runtime panics due to out-of-bounds access.

Consider adding assertions to confirm that all vectors have the same length:

assert_eq!(trend_tmp.len(), xb_a.len());
assert_eq!(xb_a.len(), xb_m.len());
assert_eq!(xb_m.len(), noise.len());

This verification ensures that the izip! macro will not panic and that each element corresponds correctly during the computation.

@sd2k sd2k changed the title chore: add integration test for Prophet chore: add benchmark for Prophet Oct 25, 2024

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (5)
crates/augurs-ets/benches/air_passengers_iai.rs (2)

Line range hint 13-15: Consider documenting the AutoETS parameters.

While the implementation is correct, it would be helpful to document what the parameters 1 and "ZZN" represent in the AutoETS::new(1, "ZZN") call. This would make the benchmark more maintainable and easier to understand.

 fn auto_fit() {
+    // 1: seasonal period (1 = non-seasonal data)
+    // "ZZN": auto-selected error (Z), auto-selected trend (Z), no seasonal component (N)
     AutoETS::new(1, "ZZN").unwrap().fit(black_box(AP)).unwrap();
 }

Line range hint 28-45: Consider alternative approaches for prediction benchmarking.

While iai doesn't support benchmark setup, you could consider using Criterion.rs for the prediction benchmarks, which does support setup through its Bencher API. This would allow you to measure prediction performance while keeping the iai benchmarks for the fitting operations.

Example implementation with Criterion:

use criterion::{criterion_group, criterion_main, Criterion};
// Assumes the ETS imports already present in this benchmark file
// (ErrorComponent, ModelType, TrendComponent, Unfit, and the AP dataset).

fn predict_benchmark(c: &mut Criterion) {
    let model = Unfit::new(ModelType {
        error: ErrorComponent::Additive,
        trend: TrendComponent::Additive,
        season: None,
    })
    .damped(true)
    .fit(AP)
    .unwrap();

    c.bench_function("predict", |b| {
        b.iter(|| model.predict(24, 0.95))
    });
}
crates/augurs-ets/benches/air_passengers.rs (1)

Line range hint 66-70: Consider optimizing the profiler configuration.

The current profiler configuration uses a high sampling rate (10000) which might be excessive for these benchmarks. Consider:

  1. Reducing the sampling rate for quicker benchmark runs
  2. Adding explicit warmup and measurement time configurations for more stable results
 criterion_group! {
     name = benches;
-    config = Criterion::default().with_profiler(PProfProfiler::new(10000, Output::Protobuf));
+    config = Criterion::default()
+        .with_profiler(PProfProfiler::new(1000, Output::Protobuf))
+        .warm_up_time(std::time::Duration::from_secs(1))
+        .measurement_time(std::time::Duration::from_secs(5));
     targets = auto_fit, fit, forecast,
 }
crates/augurs-dtw/benches/dtw.rs (2)

Line range hint 8-22: Consider improving error handling and using a CSV parser.

The current implementation has several potential issues:

  1. Uses unwrap() which could panic on malformed data
  2. Manual CSV parsing is error-prone
  3. Pre-allocates vectors without knowing valid data count

Consider these improvements:

-fn examples() -> Vec<Vec<f64>> {
+fn examples() -> Result<Vec<Vec<f64>>, Box<dyn std::error::Error>> {
     let raw = include_str!("../data/series.csv");
-    let n_columns = raw.lines().next().unwrap().split(',').count();
-    let n_rows = raw.lines().count();
-    let mut examples = vec![Vec::with_capacity(n_rows); n_columns];
-    for line in raw.lines() {
-        for (i, value) in line.split(',').enumerate() {
-            let value: f64 = value.parse().unwrap();
-            if !value.is_nan() {
-                examples[i].push(value);
-            }
-        }
-    }
-    examples
+    // Build the reader with has_headers(false): the series file has no
+    // header row, so the first line must be treated as data too.
+    let mut rdr = csv::ReaderBuilder::new()
+        .has_headers(false)
+        .from_reader(raw.as_bytes());
+    let mut examples: Vec<Vec<f64>> = Vec::new();
+    
+    for result in rdr.records() {
+        let record = result?;
+        for (i, value) in record.iter().enumerate() {
+            if i >= examples.len() {
+                examples.push(Vec::new());
+            }
+            if let Ok(value) = value.parse::<f64>() {
+                if !value.is_nan() {
+                    examples[i].push(value);
+                }
+            }
+        }
+    }
+    Ok(examples)
 }

This would require adding the csv crate to your dependencies:

[dev-dependencies]
csv = "1.2"

Line range hint 23-43: Add documentation for benchmark parameters.

While the benchmark is well-structured, it would benefit from documentation explaining:

  • The significance of the chosen window sizes
  • The expected impact on performance
  • The characteristics of the input data being used

Add documentation like this:

 fn distance_euclidean(c: &mut Criterion) {
     let mut group = c.benchmark_group("distance_euclidean");
     let examples = examples();
     let (s, t) = (&examples[0], &examples[1]);
+    // Window sizes chosen to demonstrate performance characteristics:
+    // - None: unrestricted DTW
+    // - Small windows (2, 5): tight constraints, fastest performance
+    // - Medium windows (10, 20): balanced accuracy/performance
+    // - Large window (50): similar to unrestricted for this data
     let windows = [None, Some(2), Some(5), Some(10), Some(20), Some(50)];
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 4c52507 and c260027.

📒 Files selected for processing (7)
  • crates/augurs-clustering/benches/dbscan.rs (1 hunks)
  • crates/augurs-dtw/benches/dtw.rs (1 hunks)
  • crates/augurs-ets/benches/air_passengers.rs (1 hunks)
  • crates/augurs-ets/benches/air_passengers_iai.rs (1 hunks)
  • crates/augurs-mstl/benches/vic_elec.rs (1 hunks)
  • crates/augurs-mstl/benches/vic_elec_iai.rs (1 hunks)
  • crates/augurs-seasons/benches/periodogram.rs (1 hunks)
✅ Files skipped from review due to trivial changes (4)
  • crates/augurs-clustering/benches/dbscan.rs
  • crates/augurs-mstl/benches/vic_elec.rs
  • crates/augurs-mstl/benches/vic_elec_iai.rs
  • crates/augurs-seasons/benches/periodogram.rs
🔇 Additional comments (7)
crates/augurs-ets/benches/air_passengers_iai.rs (2)

Line range hint 1-11: LGTM! Imports and attributes are well-organized.

The use of #![allow(missing_docs)] is appropriate for a benchmark file, and all necessary imports are present.


Line range hint 17-26: LGTM! Model configuration is well-structured.

The explicit model configuration with damped trend provides good coverage for benchmarking different ETS model variants.

crates/augurs-ets/benches/air_passengers.rs (2)

1-2: LGTM! Suppressing doc warnings is acceptable for benchmark files.

The #![allow(missing_docs)] attribute is appropriately used here since benchmark files typically don't require extensive documentation.


Line range hint 13-63: Consider enhancing benchmark coverage and methodology.

The benchmarks could be improved in several ways:

  1. The forecast benchmark uses a fixed horizon of 24. Consider parameterizing this to measure performance across different forecast horizons.
  2. The SmallInput batch size might not be optimal for measuring the fitting performance accurately.
  3. Memory allocations during model fitting aren't being measured separately.

Let's verify the dataset size to determine if SmallInput is appropriate:

Consider these improvements:

 fn forecast(c: &mut Criterion) {
     let model = Unfit::new(ModelType {
         error: ErrorComponent::Additive,
         trend: TrendComponent::Additive,
         season: None,
     })
     .damped(true)
     .fit(AP)
     .unwrap();
     let mut group = c.benchmark_group("forecast");
-    group.bench_function("air_passengers", |b| {
-        b.iter(|| {
-            model.predict(24, 0.95).unwrap();
-        })
-    });
+    for horizon in [12, 24, 48] {
+        group.bench_function(format!("air_passengers_h{horizon}"), |b| {
+            b.iter(|| {
+                model.predict(horizon, 0.95).unwrap();
+            })
+        });
+    }
 }
✅ Verification successful

Based on the output, I can now generate the final response since I have enough information about the dataset size. The AIR_PASSENGERS dataset contains approximately 144 data points (12 years of monthly data from the visible portion), which is relatively small.

The SmallInput batch size is appropriate, but forecast benchmarks need enhancement

The current benchmark implementation is mostly sound:

  • SmallInput batch size is appropriate given the dataset size (~144 points)
  • However, the forecast benchmark could be improved by:
    • Testing multiple forecast horizons (e.g., 12, 24, 48 periods)
    • Including confidence interval computation overhead in measurements
    • Adding separate benchmarks for point forecasts vs interval forecasts
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Find and analyze the Air Passengers dataset size
rg -A 5 "AIR_PASSENGERS" --type rust

Length of output: 5190

crates/augurs-dtw/benches/dtw.rs (3)

Line range hint 1-7: LGTM: Appropriate imports and setup for benchmarking.

The file correctly sets up the necessary imports for criterion benchmarking and DTW functionality.


Line range hint 72-73: LGTM: Proper criterion benchmark setup.

The benchmark configuration is correctly set up using criterion's standard macros.


Line range hint 44-71: Verify the choice of window sizes for matrix benchmarks.

The matrix benchmark uses a more limited set of window sizes (2, 10) compared to the single distance benchmark. This might not provide a complete performance profile.

Let's check if there are any constraints that influenced this decision:

@sd2k sd2k enabled auto-merge (squash) October 25, 2024 11:13

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (1)
justfile (1)

38-38: Consider documenting how to run Prophet benchmarks locally.

While it's appropriate to exclude Prophet benchmarks from CI due to STAN dependencies, it would be helpful to add documentation (either in the README or as a comment) explaining how developers can run these benchmarks locally when they have STAN installed.

Example comment to add above the test-all target:

+# To run Prophet benchmarks locally (requires STAN installation):
+# cargo bench --package augurs-prophet --bench real-life
+
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between c260027 and 02c7b76.

📒 Files selected for processing (1)
  • justfile (1 hunks)
🔇 Additional comments (1)
justfile (1)

29-30: LGTM! Clear explanation of benchmark exclusions.

The comment clearly explains why Prophet benchmarks are excluded from the test-all target.

@sd2k sd2k merged commit 4f65cab into main Oct 25, 2024
21 checks passed
@sd2k sd2k deleted the add-prophet-integration-test branch October 25, 2024 11:18
This was referenced Oct 25, 2024
@sd2k sd2k mentioned this pull request Oct 25, 2024
@sd2k sd2k mentioned this pull request Nov 8, 2024
@sd2k sd2k mentioned this pull request Nov 9, 2024
@coderabbitai coderabbitai bot mentioned this pull request Nov 9, 2024
@sd2k sd2k mentioned this pull request Nov 21, 2024
@sd2k sd2k mentioned this pull request Dec 9, 2024