-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ec2109a
commit b77ba69
Showing
6 changed files
with
119 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
use light_river::mondrian_forest::mondrian_forest::MondrianForestRegressor; | ||
|
||
use light_river::common::{RegTarget, Regressor}; | ||
use light_river::datasets::machine_degradation::MachineDegradation; | ||
use light_river::stream::iter_csv::IterCsv; | ||
use ndarray::Array1; | ||
use num::ToPrimitive; | ||
|
||
use std::fs::File; | ||
use std::time::Instant; | ||
|
||
/// Get list of features of the dataset. | ||
/// | ||
/// e.g. features: ["H.e", "UD.t.i", "H.i", ...] | ||
fn get_features(transactions: IterCsv<f32, File>) -> Vec<String> { | ||
let sample = transactions.into_iter().next(); | ||
let observation = sample.unwrap().unwrap().get_observation(); | ||
let mut out: Vec<String> = observation.iter().map(|(k, _)| k.clone()).collect(); | ||
out.sort(); | ||
out | ||
} | ||
|
||
fn get_dataset_size(transactions: IterCsv<f32, File>) -> usize { | ||
let mut length = 0; | ||
for _ in transactions { | ||
length += 1; | ||
} | ||
length | ||
} | ||
|
||
fn main() { | ||
let n_trees: usize = 10; | ||
|
||
let transactions_f = MachineDegradation::load_data(); | ||
let features = get_features(transactions_f); | ||
|
||
println!("Features: {:?}", features); | ||
|
||
let mut mf: MondrianForestRegressor<f32> = | ||
MondrianForestRegressor::new(n_trees, features.len()); | ||
let mut err_total = 0.0; | ||
|
||
let transactions_l = MachineDegradation::load_data(); | ||
let dataset_size = get_dataset_size(transactions_l); | ||
|
||
let now = Instant::now(); | ||
|
||
let transactions = MachineDegradation::load_data(); | ||
for (idx, transaction) in transactions.enumerate() { | ||
let data = transaction.unwrap(); | ||
|
||
let x = data.get_observation(); | ||
let x = Array1::<f32>::from_vec(features.iter().map(|k| x[k]).collect()); | ||
|
||
let y = data.get_y().unwrap(); | ||
let y = data.to_regression_target("pCut::Motor_Torque").unwrap(); | ||
|
||
// println!("=M=1 idx={idx}, x={x}, y={y}"); | ||
|
||
// Skip first sample since tree has still no node | ||
if idx != 0 { | ||
let pred = mf.predict_one(&x, &y); | ||
let err = (pred - y).powi(2); | ||
err_total += err; | ||
// println!("idx={idx}, x={x}, y={y}, pred: {pred}, err: {err}"); | ||
} | ||
|
||
mf.learn_one(&x, &y); | ||
} | ||
|
||
let elapsed_time = now.elapsed(); | ||
println!("Took {}ms", elapsed_time.as_millis()); | ||
|
||
println!( | ||
"MSE: {} / {} = {}", | ||
err_total, | ||
dataset_size - 1, | ||
err_total / (dataset_size - 1).to_f32().unwrap() | ||
); | ||
|
||
let forest_size = mf.get_forest_size(); | ||
println!("Forest tree sizes: {:?}", forest_size); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
use crate::datasets::utils; | ||
use crate::stream::data_stream::Target; | ||
use crate::stream::iter_csv::IterCsv; | ||
use std::{fs::File, path::Path}; | ||
|
||
/// One Year Industrial Component Degradation | ||
/// | ||
/// References | ||
/// ---------- | ||
/// [^1]: [One Year Industrial Component Degradation](https://www.kaggle.com/datasets/inIT-OWL/one-year-industrial-component-degradation) | ||
pub struct MachineDegradation; | ||
impl MachineDegradation { | ||
pub fn load_data() -> IterCsv<f32, File> { | ||
// let url = "https://www.kaggle.com/datasets/inIT-OWL/one-year-industrial-component-degradation/download/fA53OHmuZ0enYASBqytj%2Fversions%2FvXObUJmxGJQSUSC2Wyc7%2Ffiles%2F01-04T184148_000_mode1.csv?datasetVersionNumber=1"; | ||
let file_name = "one-year-industrial-component-degradation.csv"; | ||
|
||
if !Path::new(file_name).exists() { | ||
panic!("Dataset not downloaded. Download it in file '{file_name}'"); | ||
} | ||
|
||
let file = File::open(file_name).unwrap(); | ||
let y_cols = Some(Target::Name("pCut::Motor_Torque".to_string())); | ||
IterCsv::<f32, File>::new(file, y_cols).unwrap() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
pub mod credit_card; | ||
pub mod keystroke; | ||
pub mod machine_degradation; | ||
pub mod synthetic; | ||
pub mod synthetic_regression; | ||
pub mod utils; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters