Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[clustering] Derive {Des,S}erialize for all public items #324

Merged
merged 1 commit into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions algorithms/linfa-clustering/src/appx_dbscan/cells_grid/cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@ use linfa::Float;
use linfa_nn::distance::{Distance, L2Dist};
use ndarray::{Array1, ArrayView1, ArrayView2, ArrayViewMut1};
use partitions::PartitionVec;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// A point in a D dimensional euclidean space that memorizes its
/// status: 'core' or 'non core'
pub struct StatusPoint {
Expand All @@ -16,10 +23,7 @@ pub struct StatusPoint {

impl StatusPoint {
pub fn new(point_index: usize) -> StatusPoint {
StatusPoint {
point_index,
is_core: false,
}
StatusPoint { point_index, is_core: false }
}

pub fn is_core(&self) -> bool {
Expand All @@ -32,6 +36,11 @@ impl StatusPoint {
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// Information regarding the cell used in various stages of the approximate DBSCAN
/// algorithm if it is a core cell
pub struct CoreCellInfo<F: Float> {
Expand All @@ -42,6 +51,11 @@ pub struct CoreCellInfo<F: Float> {
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// A cell from a grid that partitions the D dimensional euclidean space.
pub struct Cell<F: Float> {
/// The index of the intervals of the D dimensional axes where this cell lies
Expand Down
7 changes: 7 additions & 0 deletions algorithms/linfa-clustering/src/appx_dbscan/cells_grid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use linfa::Float;
use linfa_nn::{distance::L2Dist, NearestNeighbour};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use partitions::PartitionVec;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use cell::{Cell, StatusPoint};

Expand All @@ -16,6 +18,11 @@ pub type CellVector<F> = PartitionVec<Cell<F>>;
pub type CellTable = HashMap<Array1<i64>, usize>;

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct CellsGrid<F: Float> {
table: CellTable,
cells: CellVector<F>,
Expand Down
12 changes: 12 additions & 0 deletions algorithms/linfa-clustering/src/appx_dbscan/counting_tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,28 @@ use crate::appx_dbscan::AppxDbscanValidParams;
use linfa::Float;
use linfa_nn::distance::{Distance, L2Dist};
use ndarray::{Array1, Array2, ArrayView1, Axis};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// Outcome of an intersection test, with three mutually exclusive cases.
///
/// The type is serializable and deserializable when the `serde` feature is enabled.
///
/// NOTE(review): presumably describes how a cell of the counting tree relates to a
/// query range during approximate range counting; confirm against the call sites.
pub enum IntersectionType {
/// NOTE(review): presumably one region lies entirely within the other; confirm.
FullyCovered,
/// NOTE(review): presumably the two regions share no points; confirm.
Disjoint,
/// NOTE(review): presumably the two regions overlap only partially; confirm.
Intersecting,
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// Tree structure that divides the space in nested cells to perform approximate range counting.
/// Each member of this structure is a node in the tree.
pub struct TreeStructure<F: Float> {
Expand Down
7 changes: 7 additions & 0 deletions algorithms/linfa-clustering/src/dbscan/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,19 @@ use linfa_nn::{
CommonNearestNeighbour, NearestNeighbour, NearestNeighbourIndex,
};
use ndarray::{Array1, ArrayBase, Data, Ix2};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use std::collections::VecDeque;

use linfa::Float;
use linfa::{traits::Transformer, DatasetBase};

#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// DBSCAN (Density-based Spatial Clustering of Applications with Noise)
/// clusters together points which are close together with enough neighbors, and
/// labels points which are sparsely neighbored as noise. As points may be
Expand Down
5 changes: 5 additions & 0 deletions algorithms/linfa-clustering/src/k_means/hyperparams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ pub struct KMeansValidParams<F: Float, R: Rng, D: Distance<F>> {
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// A helper struct used to construct a set of [valid hyperparameters](KMeansValidParams) for
/// the [K-means algorithm](crate::KMeans) (using the builder pattern).
///
/// It is a newtype around [`KMeansValidParams`] and is (de)serializable when the `serde`
/// feature is enabled.
pub struct KMeansParams<F: Float, R: Rng, D: Distance<F>>(KMeansValidParams<F, R, D>);
Expand Down
10 changes: 10 additions & 0 deletions algorithms/linfa-clustering/src/optics/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ pub struct Optics;
/// This struct represents a data point in the dataset with its associated distances obtained from
/// the OPTICS analysis
#[derive(Debug, Clone)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct Sample<F> {
/// Index of the observation in the dataset
index: usize,
Expand Down Expand Up @@ -103,6 +108,11 @@ impl<F: Float> Ord for Sample<F> {
/// that of the dataset instead ordering based on the clustering structure worked out during
/// analysis.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct OpticsAnalysis<F: Float> {
/// A list of the samples in the dataset sorted and with their reachability and core distances
/// computed.
Expand Down
5 changes: 5 additions & 0 deletions algorithms/linfa-clustering/src/optics/hyperparams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ impl<F: Float, D, N> OpticsValidParams<F, D, N> {
}

#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
/// Newtype wrapper around [`OpticsValidParams`], (de)serializable when the `serde`
/// feature is enabled.
///
/// NOTE(review): presumably the unchecked builder counterpart of `OpticsValidParams`
/// for the OPTICS algorithm; confirm against the rest of this module.
pub struct OpticsParams<F, D, N>(OpticsValidParams<F, D, N>);

impl<F: Float, D, N> OpticsParams<F, D, N> {
Expand Down
Loading