Skip to content

Commit

Permalink
Moves Pearson R calculation to Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
cuducos committed May 9, 2024
1 parent 0e9ccea commit 477a8d0
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 8 deletions.
28 changes: 28 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,25 @@ use walkdir::WalkDir;

use pyo3::{pyfunction, pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};

#[pyfunction]
fn pearson_r(x: Vec<u32>, y: Vec<u32>) -> PyResult<f64> {
let n = x.len() as f64;
let sum_x = x.iter().sum::<u32>() as f64;
let sum_y = y.iter().sum::<u32>() as f64;
let sum_x_sq = x.iter().map(|&i| i.pow(2)).sum::<u32>();
let sum_y_sq = y.iter().map(|&i| i.pow(2)).sum::<u32>();
let p_sum = x.iter().zip(y).map(|(i, j)| i * j).sum::<u32>();
let num = (p_sum as f64) - ((sum_x * sum_y) / n);
let multiplier_x: f64 = (sum_x_sq as f64) - ((sum_x.powi(2)) / n);
let multiplier_y: f64 = (sum_y_sq as f64) - ((sum_y.powi(2)) / n);
let den = (multiplier_x * multiplier_y).sqrt();
if den != 0.0 {
Ok(num / den)
} else {
Ok(0.0)
}
}

#[pyfunction]
fn latest_changed_at(dir: String) -> PyResult<String> {
let mut latest = SystemTime::UNIX_EPOCH;
Expand All @@ -26,6 +45,7 @@ fn latest_changed_at(dir: String) -> PyResult<String> {
#[pymodule]
fn crates(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(latest_changed_at, m)?)?;
m.add_function(wrap_pyfunction!(pearson_r, m)?)?;
Ok(())
}

Expand All @@ -44,4 +64,12 @@ mod tests {
let got = latest_changed_at(tmp.path().to_string_lossy().to_string());
assert_eq!(got.unwrap(), expected.format("%Y-%m-%d").to_string());
}

#[test]
fn test_pearson_r() {
let expected = 0.7904333328627509;
let x = vec![2, 2, 3, 1, 0, 2, 2, 1, 1, 1, 1, 2];
let y = vec![2, 2, 3, 1, 0, 2, 1, 1, 1, 2, 1, 1];
assert_eq!(pearson_r(x, y).unwrap(), expected);
}
}
8 changes: 4 additions & 4 deletions whiskyton/helpers/charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def __init__(
text_line_height=11,
):
"""
:param reference: whiskyton.models.Whisky or tastes (list of str)
:param comparison: whiskyton.models.Whisky or tastes (list of str)
:param reference: whiskyton.models.Whisky or tastes (tuple of integers)
:param comparison: whiskyton.models.Whisky or tastes (tuple of integers)
:param width: (int) width of the SVG chart
:param height: (int) width of the SVG chart
:param sides: (int) number of sides the grid (polygon)
Expand All @@ -29,9 +29,9 @@ def __init__(

# set whisky data
if not isinstance(reference, (list, tuple)) and reference is not None:
reference = reference.get_tastes()
reference = tuple(str(taste) for taste in reference.get_tastes())
if not isinstance(comparison, (list, tuple)) and comparison is not None:
comparison = comparison.get_tastes()
comparison = tuple(str(taste) for taste in comparison.get_tastes())

self.reference = reference
self.comparison = comparison
Expand Down
7 changes: 4 additions & 3 deletions whiskyton/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from re import compile

from crates import pearson_r
from flask import current_app
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import DeclarativeBase
Expand Down Expand Up @@ -40,10 +41,10 @@ def __repr__(self):
def get_tastes(self):
"""
Return a list of tastes of the whisky.
:return: (list of strings) tastes of the whisky
:return: (tuple of integers) tastes of the whisky
"""
tastes = current_app.config["TASTES"]
return [str(getattr(self, taste, None)) for taste in tastes]
return tuple(getattr(self, taste) for taste in tastes)

def get_slug(self):
"""
Expand All @@ -63,7 +64,7 @@ def get_correlation(self, comparison):
return {
"reference": self.id,
"whisky": comparison.id,
"r": self.__pearson_r(self.get_tastes(), comparison.get_tastes()),
"r": pearson_r(self.get_tastes(), comparison.get_tastes()),
}

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion whiskyton/tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_slug(self):
def test_get_tastes(self):
with self.app.app_context():
whisky = self.get_whisky(2)
tastes = ["1", "1", "1", "1", "1", "3", "2", "1", "0", "2", "0", "2"]
tastes = (1, 1, 1, 1, 1, 3, 2, 1, 0, 2, 0, 2)
self.assertEqual(whisky.get_tastes(), tastes)

# test methods from Chart (whiskyton/helpers/charts.py)
Expand Down

0 comments on commit 477a8d0

Please sign in to comment.