Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: upgrade pyo3 to 0.23 #188

Merged
merged 6 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/pyaugurs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ augurs-ets = { workspace = true, features = ["mstl"] }
augurs-forecaster.workspace = true
augurs-mstl.workspace = true
augurs-seasons.workspace = true
numpy = "0.21.0"
pyo3 = { version = "0.21.0", features = ["extension-module"] }
numpy = "0.23.0"
pyo3 = { version = "0.23.3", features = ["extension-module"] }
tracing = { version = "0.1.37", features = ["log"] }

[lints]
Expand Down
5 changes: 1 addition & 4 deletions crates/pyaugurs/augurs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,12 @@ class Forecast:
def lower(self) -> npt.NDArray[np.float64] | None: ...
def upper(self) -> npt.NDArray[np.float64] | None: ...

class PyTrendModel:
def __init__(self, trend_model: TrendModel) -> None: ...

class MSTL:
@classmethod
def ets(cls, periods: Sequence[int]) -> "MSTL": ...
@classmethod
def custom_trend(
cls, periods: Sequence[int], trend_model: PyTrendModel
cls, periods: Sequence[int], trend_model: TrendModel
) -> "MSTL": ...
def fit(self, y: npt.NDArray[np.float64]) -> None: ...
def predict(self, horizon: int, level: float | None) -> Forecast: ...
Expand Down
6 changes: 1 addition & 5 deletions crates/pyaugurs/src/clustering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,6 @@ impl Dbscan {
distance_matrix: InputDistanceMatrix<'_>,
) -> PyResult<Py<PyArray1<isize>>> {
let distance_matrix = distance_matrix.try_into()?;
Ok(self
.inner
.fit(&distance_matrix)
.into_pyarray_bound(py)
.into())
Ok(self.inner.fit(&distance_matrix).into_pyarray(py).into())
}
}
2 changes: 1 addition & 1 deletion crates/pyaugurs/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl DistanceMatrix {
*elem = *val;
}
}
arr.into_pyarray_bound(py).into()
arr.into_pyarray(py).into()
}
}

Expand Down
1 change: 1 addition & 0 deletions crates/pyaugurs/src/dtw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ impl Dtw {
}

#[new]
#[pyo3(signature = (window=None, distance_fn=None, max_distance=None, lower_bound=None, upper_bound=None))]
fn new(
window: Option<usize>,
distance_fn: Option<&str>,
Expand Down
2 changes: 2 additions & 0 deletions crates/pyaugurs/src/ets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ impl AutoETS {
/// # Errors
///
/// This function will return an error if no model has been fit yet (using [`AutoETS::fit`]).
#[pyo3(signature = (horizon, level=None))]
pub fn predict(&self, horizon: usize, level: Option<f64>) -> PyResult<Forecast> {
self.fitted
.as_ref()
Expand All @@ -80,6 +81,7 @@ impl AutoETS {
/// # Errors
///
/// This function will return an error if no model has been fit yet (using [`AutoETS::fit`]).
#[pyo3(signature = (level=None))]
pub fn predict_in_sample(&self, level: Option<f64>) -> PyResult<Forecast> {
self.fitted
.as_ref()
Expand Down
8 changes: 4 additions & 4 deletions crates/pyaugurs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ impl From<Forecast> for augurs_core::Forecast {
#[pymethods]
impl Forecast {
#[new]
#[pyo3(signature = (point, level=None, lower=None, upper=None))]
fn new(
py: Python<'_>,
point: Py<PyArray1<f64>>,
Expand Down Expand Up @@ -80,23 +81,23 @@ impl Forecast {
// We could also use `into_pyarray` to construct the
// numpy arrays in the Rust heap; let's see which ends up being
// faster and more convenient.
self.inner.point.to_pyarray_bound(py).into()
self.inner.point.to_pyarray(py).into()
}

/// Get the lower prediction interval.
fn lower(&self, py: Python<'_>) -> Option<Py<PyArray1<f64>>> {
self.inner
.intervals
.as_ref()
.map(|x| x.lower.to_pyarray_bound(py).into())
.map(|x| x.lower.to_pyarray(py).into())
}

/// Get the upper prediction interval.
fn upper(&self, py: Python<'_>) -> Option<Py<PyArray1<f64>>> {
self.inner
.intervals
.as_ref()
.map(|x| x.upper.to_pyarray_bound(py).into())
.map(|x| x.upper.to_pyarray(py).into())
}
}

Expand All @@ -106,7 +107,6 @@ fn augurs(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
// pyo3_log::init();
m.add_class::<ets::AutoETS>()?;
m.add_class::<mstl::MSTL>()?;
m.add_class::<trend::PyTrendModel>()?;
m.add_class::<Forecast>()?;
m.add_class::<clustering::Dbscan>()?;
m.add_class::<distance::DistanceMatrix>()?;
Expand Down
24 changes: 20 additions & 4 deletions crates/pyaugurs/src/mstl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,30 @@ impl MSTL {
}
}

/// Create a new MSTL model with the given periods using provided trend model.
/// Create a new MSTL model with the given periods using the custom Python trend model.
///
/// The custom trend model must implement the following methods:
///
/// - `fit(self, y: np.ndarray) -> None`
/// - `predict(self, horizon: int, level: float | None = None) -> augurs.Forecast`
/// - `predict_in_sample(self, level: float | None = None) -> augurs.Forecast`
#[classmethod]
pub fn custom_trend(
_cls: &Bound<'_, PyType>,
periods: Vec<usize>,
trend_model: PyTrendModel,
trend_model: Py<PyAny>,
) -> Self {
let trend_model_name = trend_model.name().to_string();
let trend_model_name = Python::with_gil(|py| {
let trend_model = trend_model.bind(py).get_type();
trend_model
.name()
.map_or_else(|_| "unknown Python class".into(), |s| s.to_string())
});
Self {
forecaster: Forecaster::new(MSTLModel::new(periods, Box::new(trend_model))),
forecaster: Forecaster::new(MSTLModel::new(
periods,
Box::new(PyTrendModel::new(trend_model)),
)),
trend_model_name,
fit: false,
}
Expand All @@ -72,6 +86,7 @@ impl MSTL {
/// intervals at the given level.
///
/// If provided, `level` must be a float between 0 and 1.
#[pyo3(signature = (horizon, level=None))]
pub fn predict(&self, horizon: usize, level: Option<f64>) -> PyResult<Forecast> {
self.forecaster
.predict(horizon, level)
Expand All @@ -83,6 +98,7 @@ impl MSTL {
/// intervals at the given level.
///
/// If provided, `level` must be a float between 0 and 1.
#[pyo3(signature = (level=None))]
pub fn predict_in_sample(&self, level: Option<f64>) -> PyResult<Forecast> {
self.forecaster
.predict_in_sample(level)
Expand Down
7 changes: 2 additions & 5 deletions crates/pyaugurs/src/seasons.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use augurs_seasons::{Detector, PeriodogramDetector};
/// The default is 0.9.
/// :return: an array of season lengths.
#[pyfunction]
#[pyo3(signature = (y, min_period=None, max_period=None, threshold=None))]
pub fn seasonalities(
py: Python<'_>,
y: PyReadonlyArray1<'_, f64>,
Expand All @@ -35,9 +36,5 @@ pub fn seasonalities(
builder = builder.threshold(threshold);
}

Ok(builder
.build()
.detect(y.as_slice()?)
.to_pyarray_bound(py)
.into())
Ok(builder.build().detect(y.as_slice()?).to_pyarray(py).into())
}
26 changes: 12 additions & 14 deletions crates/pyaugurs/src/trend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
//! - `fit(self, y: np.ndarray) -> None`
//! - `predict(self, horizon: int, level: float | None = None) -> augurs.Forecast`
//! - `predict_in_sample(self, level: float | None = None) -> augurs.Forecast`
use numpy::ToPyArray;
use pyo3::{exceptions::PyException, prelude::*};
use pyo3::{exceptions::PyException, prelude::*, types::PyAnyMethods};

use augurs_mstl::{FittedTrendModel, TrendModel};

Expand All @@ -28,8 +29,8 @@ use crate::Forecast;
/// - `predict(self, horizon: int, level: float | None = None) -> augurs.Forecast`
/// - `predict_in_sample(self, level: float | None = None) -> augurs.Forecast`
#[pyclass(name = "TrendModel")]
#[derive(Clone, Debug)]
pub struct PyTrendModel {
#[derive(Debug)]
pub(crate) struct PyTrendModel {
model: Py<PyAny>,
}

Expand All @@ -44,7 +45,7 @@ impl PyTrendModel {
/// The returned PyTrendModel can be used in MSTL models using the
/// `custom_trend` method of the MSTL class.
#[new]
pub fn new(model: Py<PyAny>) -> Self {
pub(crate) fn new(model: Py<PyAny>) -> Self {
Self { model }
}
}
Expand All @@ -56,7 +57,7 @@ impl TrendModel for PyTrendModel {
.bind(py)
.get_type()
.name()
.map(|s| s.into_owned().into())
.map(|s| s.to_string().into())
})
.unwrap_or_else(|_| "unknown Python class".into())
}
Expand All @@ -68,21 +69,18 @@ impl TrendModel for PyTrendModel {
Box<dyn FittedTrendModel + Sync + Send>,
Box<dyn std::error::Error + Send + Sync + 'static>,
> {
// TODO - `fitted` should be a `PyFittedTrendModel`
// which should implement `Fit` and `FittedTrendModel`
Python::with_gil(|py| {
let np = y.to_pyarray_bound(py);
self.model.call_method1(py, "fit", (np,))
let model = Python::with_gil(|py| {
let np = y.to_pyarray(py);
self.model.call_method1(py, "fit", (np,))?;
Ok::<_, PyErr>(self.model.clone_ref(py))
})?;
Ok(Box::new(PyFittedTrendModel {
model: self.model.clone(),
}) as _)
Ok(Box::new(PyFittedTrendModel { model }) as _)
}
}

/// A wrapper for a Python trend model that has been fitted to data.
#[derive(Debug)]
pub struct PyFittedTrendModel {
pub(crate) struct PyFittedTrendModel {
model: Py<PyAny>,
}

Expand Down
Loading