diff --git a/README.md b/README.md
index 4d43acf..5606b91 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,7 @@
 Principal Component Analysis (PCA) in PyTorch. The intention is to provide a
 simple and easy to use implementation of PCA in PyTorch, as similar to
 `sklearn`'s PCA as possible (in terms of API and, of course, output).
+Plus, this implementation is **fully differentiable and faster** (thanks to GPU parallelization)!
 
 [![Release](https://img.shields.io/github/v/tag/valentingol/torch_pca?label=Pypi&logo=pypi&logoColor=yellow)](https://pypi.org/project/torch_pca/)
 ![PythonVersion](https://img.shields.io/badge/python-3.8%20%7E%203.11-informational)
@@ -58,6 +59,50 @@ pca_model = PCA(n_components=None, svd_solver='full')
 
 More details and features in the [API documentation](https://torch-pca.readthedocs.io/en/latest/api.html#torch_pca.pca_main.PCA).
 
+## Gradient backward pass
+
+Using the PyTorch framework allows automatic differentiation of the PCA!
+
+The PCA transform method is always differentiable, so it is always possible to
+compute gradients like this:
+
+```python
+pca = PCA()
+for ep in range(n_epochs):
+    optimizer.zero_grad()
+    out = neural_net(inputs)
+    with torch.no_grad():
+        pca.fit(out)
+    out = pca.transform(out)
+    loss = loss_fn(out, targets)
+    loss.backward()
+```
+
+If you want to compute the gradient over the full PCA model (including the
+fitted `pca.n_components`), you can do it by using the "full" SVD solver
+and removing the part of the `fit` method that enforces deterministic
+output, by passing `determinist=False` to the `fit` or `fit_transform` method.
+This part sorts the components by their singular values and flips their signs
+accordingly, so it is not differentiable by nature, but it may not be
+necessary if you don't care about the determinism of the output:
+
+```python
+pca = PCA(svd_solver="full")
+for ep in range(n_epochs):
+    optimizer.zero_grad()
+    out = neural_net(inputs)
+    out = pca.fit_transform(out, determinist=False)
+    loss = loss_fn(out, targets)
+    loss.backward()
+```
+
+## Comparison of execution time with sklearn's PCA
+
+As shown below, the PyTorch PCA is faster than sklearn's PCA in all tested
+configurations, using the default parameters of each PCA model:
+
+![include](docs/_static/comparison.png)
+
 ## Implemented features
 
 - [x] `fit`, `transform`, `fit_transform` methods.
@@ -65,20 +110,19 @@ More details and features in the [API documentation](https://torch-pca.readthedo
   `singular_values_`, `components_`, `mean_`, `noise_variance_`, ...
 - [x] Full SVD solver
 - [x] SVD by covariance matrix solver
+- [x] Randomized SVD solver
 - [x] (absent from sklearn) Decide how to center the input data in `transform`
   method (default is like sklearn's PCA)
 - [x] Find number of components with explained variance proportion
 - [x] Automatically find number of components with MLE
 - [x] `inverse_transform` method
+- [x] Whitening option
+- [x] `get_covariance` method
+- [x] `get_precision` method and `score`/`score_samples` methods (see the
+  example below)
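+
+For example, a minimal sketch of these options (random data, shapes chosen
+only for illustration):
+
+```python
+import torch
+from torch_pca import PCA
+
+inputs = torch.randn(200, 10)
+pca = PCA(n_components=5, whiten=True)
+transformed = pca.fit_transform(inputs)  # uncorrelated, unit-variance outputs
+covariance = pca.get_covariance()  # (10, 10) model covariance
+precision = pca.get_precision()  # inverse covariance, computed efficiently
+print(pca.score(inputs))  # average log-likelihood of the samples
+```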
 
 ## To be implemented
 
-- [ ] Whitening option
-- [ ] Randomized SVD solver
-- [ ] ARPACK solver
-- [ ] `get_covariance` method
-- [ ] `get_precision` method and `score` method
-- [ ] Support sparse matrices
+- [ ] Support sparse matrices with ARPACK solver
 
 ## Contributing
diff --git a/benchmark.py b/benchmark.py
new file mode 100644
index 0000000..b325fa5
--- /dev/null
+++ b/benchmark.py
@@ -0,0 +1,63 @@
+"""Comparison between sklearn and torch PCA models."""
+
+# Copyright (c) 2024 Valentin Goldité. All Rights Reserved.
+from time import time
+
+# NOTE: requires matplotlib (not in requirements(-dev).txt)
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from sklearn.decomposition import PCA as PCA_sklearn
+
+from torch_pca import PCA
+
+
+def main() -> None:
+    """Measure and compare the time of execution of the PCA."""
+    configs = [(75, 75), (100, 2000), (10_000, 500)]
+    torch_times, sklearn_times = [], []
+    for config in configs:
+        inputs = torch.randn(*config)
+        t0 = time()
+        PCA(n_components=50).fit_transform(inputs)
+        torch_times.append(round(time() - t0, 4))
+        t0 = time()
+        PCA_sklearn(n_components=50).fit_transform(inputs)
+        sklearn_times.append(round(time() - t0, 4))
+    ticks = np.arange(len(configs))
+    labels = [f"n_samples={config[0]}, n_features={config[1]}" for config in configs]
+    width = 0.35
+    fig, ax = plt.subplots()
+    rects1 = ax.bar(ticks - width / 2, torch_times, width, label="PyTorch PCA")
+    rects2 = ax.bar(ticks + width / 2, sklearn_times, width, label="Sklearn PCA")
+    ax.set_ylabel("Time of execution (s)")
+    ax.set_title("Comparison of execution time between PyTorch and Sklearn PCA.")
+    ax.set_xticks(ticks)
+    ax.set_xticklabels(labels)
+    ax.legend()
+    autolabel(rects1, ax)
+    autolabel(rects2, ax)
+    fig.tight_layout()
+    plt.show()
+
+
+def autolabel(rects: list, ax: plt.Axes) -> None:
+    """Attach a text label above each bar in *rects*, displaying its height.
+
+    From https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/barchart.html
+    """
+    for rect in rects:
+        height = rect.get_height()
+        ax.annotate(
+            str(height),
+            xy=(rect.get_x() + rect.get_width() / 2, height),
+            xytext=(0, 3),  # 3 points vertical offset
+            textcoords="offset points",
+            ha="center",
+            va="bottom",
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/docs/_static/comparison.png b/docs/_static/comparison.png
new file mode 100644
index 0000000..8b7c379
Binary files /dev/null and b/docs/_static/comparison.png differ
diff --git a/docs/comparison.md b/docs/comparison.md
new file mode 100644
index 0000000..924aae2
--- /dev/null
+++ b/docs/comparison.md
@@ -0,0 +1,6 @@
+# Comparison of execution time with sklearn's PCA
+
+As shown below, the PyTorch PCA is faster than sklearn's PCA in all tested
+configurations, using the default parameters of each PCA model:
+
+![include](https://raw.githubusercontent.com/valentingol/torch_pca/main/docs/_static/comparison.png)
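+
+The figure is produced by `benchmark.py` at the root of the repository (it
+requires `matplotlib`). A minimal sketch of one timing measurement, without
+the plotting part:
+
+```python
+from time import time
+
+import torch
+
+from torch_pca import PCA
+
+inputs = torch.randn(10_000, 500)  # one of the benchmarked configs
+t0 = time()
+PCA(n_components=50).fit_transform(inputs)
+print(f"torch_pca fit_transform: {time() - t0:.4f}s")
+```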
diff --git a/docs/grad.md b/docs/grad.md
new file mode 100644
index 0000000..4536a47
--- /dev/null
+++ b/docs/grad.md
@@ -0,0 +1,36 @@
+# Gradient backward pass
+
+Using the PyTorch framework allows automatic differentiation of the PCA!
+
+The PCA transform method is always differentiable, so it is always possible to
+compute gradients like this:
+
+```python
+pca = PCA()
+for ep in range(n_epochs):
+    optimizer.zero_grad()
+    out = neural_net(inputs)
+    with torch.no_grad():
+        pca.fit(out)
+    out = pca.transform(out)
+    loss = loss_fn(out, targets)
+    loss.backward()
+```
+
+If you want to compute the gradient over the full PCA model (including the
+fitted `pca.n_components`), you can do it by using the "full" SVD solver
+and removing the part of the `fit` method that enforces deterministic
+output, by passing `determinist=False` to the `fit` or `fit_transform` method.
+This part sorts the components by their singular values and flips their signs
+accordingly, so it is not differentiable by nature, but it may not be
+necessary if you don't care about the determinism of the output:
+
+```python
+pca = PCA(svd_solver="full")
+for ep in range(n_epochs):
+    optimizer.zero_grad()
+    out = neural_net(inputs)
+    out = pca.fit_transform(out, determinist=False)
+    loss = loss_fn(out, targets)
+    loss.backward()
+```
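+
+A quick way to check that gradients actually flow through the fit (the shapes
+here are arbitrary):
+
+```python
+inputs = torch.randn(100, 10, requires_grad=True)
+pca = PCA(n_components=5, svd_solver="full")
+out = pca.fit_transform(inputs, determinist=False)
+out.sum().backward()
+print(inputs.grad.shape)  # torch.Size([100, 10])
+```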
diff --git a/docs/index.rst b/docs/index.rst
index 3bc63d5..b22b4dd 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -4,7 +4,8 @@ Pytorch PCA
 Principal Component Analysis (PCA) in PyTorch. The intention is to
 provide a simple and easy to use implementation of PCA in PyTorch, as
 similar to ``sklearn``\ ’s PCA as possible (in terms of API
-and, of course, output).
+and, of course, output). Plus, this implementation is **fully differentiable and faster**
+(thanks to GPU parallelization)!
 
 |Release| |PythonVersion| |PytorchVersion|
@@ -12,7 +13,7 @@ and, of course, output).
 
 |Ruff_logo| |Black_logo|
 
-|Ruff| |Flake8| |Pydocstyle| |MyPy| |PyLint|
+|Ruff| |Flake8| |MyPy| |PyLint|
 
 |Tests| |Coverage|
 |Documentation Status|
@@ -41,8 +42,6 @@ Documentation: https://torch-pca.readthedocs.io/en/latest/
    :target: https://github.com/valentingol/Dinosor/actions/workflows/ruff.yaml
 .. |Flake8| image:: https://github.com/valentingol/torch_pca/actions/workflows/flake.yaml/badge.svg
    :target: https://github.com/valentingol/Dinosor/actions/workflows/flake.yaml
-.. |Pydocstyle| image:: https://github.com/valentingol/torch_pca/actions/workflows/pydocstyle.yaml/badge.svg
-   :target: https://github.com/valentingol/Dinosor/actions/workflows/pydocstyle.yaml
 .. |MyPy| image:: https://github.com/valentingol/torch_pca/actions/workflows/mypy.yaml/badge.svg
    :target: https://github.com/valentingol/Dinosor/actions/workflows/mypy.yaml
 .. |PyLint| image:: https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/valentingol/8fb4f3f78584e085dd7b0cca7e046d1f/raw/torch_pca_pylint.json
@@ -60,6 +59,7 @@ Documentation: https://torch-pca.readthedocs.io/en/latest/
 
    installation
    howto
+   grad
    api
    contributing.md
    license.md
diff --git a/src/torch_pca/pca_main.py b/src/torch_pca/pca_main.py
index f274f7f..8d819bd 100644
--- a/src/torch_pca/pca_main.py
+++ b/src/torch_pca/pca_main.py
@@ -1,13 +1,16 @@
 """Main module for PCA."""
 
 # Copyright (c) 2024 Valentin Goldité. All Rights Reserved.
-from typing import Optional
+# Inspired from https://github.com/scikit-learn (BSD-3-Clause License)
+# Copyright (c) Scikit-learn developers. All Rights Reserved.
+from math import log
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 
 from torch_pca.ncompo import NComponentsType, find_ncomponents
-from torch_pca.svd import choose_svd_solver, svd_flip
+from torch_pca.svd import choose_svd_solver, randomized_svd, svd_flip
 
 
 class PCA:
@@ -38,14 +41,44 @@ class PCA:
         * 'covariance_eigh': Compute the covariance matrix and take
           the eigenvalues decomposition with torch.linalg.eigh.
           Most efficient for small n_features and large n_samples.
+        * 'randomized': Compute the randomized SVD by the method of
+          Halko et al.
 
         By default, svd_solver='auto'.
+
+    whiten : bool, optional
+        If True, the transformed outputs are divided by the square root
+        of the explained variance to ensure uncorrelated outputs with
+        unit component-wise variances.
+        By default, False.
+
+    iterated_power : int | str, optional
+        Integer or 'auto'. Number of iterations for the power method
+        computed by randomized SVD. Must be >= 0.
+        Ignored if svd_solver != 'randomized'. By default, 'auto'.
+    n_oversamples : int, optional
+        Additional number of random vectors to sample the
+        range of input data in the randomized solver to ensure proper
+        conditioning.
+        Ignored if svd_solver != 'randomized'. By default, 10.
+    power_iteration_normalizer : str, optional
+        One of 'auto', 'QR', 'LU', 'none'.
+        Power iteration normalizer for the randomized SVD solver.
+        Ignored if svd_solver != 'randomized'. By default, 'auto'.
+    random_state : int | None, optional
+        Seed of the randomized SVD solver.
+        Ignored if svd_solver != 'randomized'. By default, None.
     """
 
     def __init__(
         self,
         n_components: NComponentsType = None,
+        *,
+        whiten: bool = False,
         svd_solver: str = "auto",
+        iterated_power: Union[str, int] = "auto",
+        n_oversamples: int = 10,
+        power_iteration_normalizer: str = "auto",
+        random_state: Optional[int] = None,
     ):
         #: Principal axes in feature space.
         self.components_: Optional[Tensor] = None
@@ -65,39 +98,66 @@ def __init__(
         self.noise_variance_: Optional[Tensor] = None
         #: Singular values corresponding to each of the selected components.
         self.singular_values_: Optional[Tensor] = None
+        #: Whether the data is whitened or not.
+        self.whiten: bool = whiten
         #: Solver to use for the PCA computation.
         self.svd_solver_: str = svd_solver
+        # Randomized SVD parameters
+        self.n_oversamples = n_oversamples
+        self.iterated_power = iterated_power
+        self.power_iteration_normalizer = power_iteration_normalizer
+        self.random_state = random_state
 
-        if self.svd_solver_ not in ["auto", "full", "covariance_eigh"]:
+        if self.svd_solver_ not in ["auto", "full", "covariance_eigh", "randomized"]:
             raise ValueError(
                 "Unknown SVD solver. `svd_solver` should be one of "
-                "'auto', 'full', 'covariance_eigh'."
+                "'auto', 'full', 'covariance_eigh', 'randomized'."
             )
 
-    def fit_transform(self, inputs: Tensor) -> Tensor:
+    def fit_transform(self, inputs: Tensor, *, determinist: bool = True) -> Tensor:
         """Fit the PCA model and apply the dimensionality reduction.
 
         Parameters
         ----------
         inputs : Tensor
             Input data of shape (n_samples, n_features).
+        determinist : bool, optional
+            If True, the SVD solver is deterministic but the gradient
+            cannot be computed through the PCA fit (the PCA transform is
+            always differentiable though).
+            If False, the SVD can be non-deterministic but the
+            gradient can be computed through the PCA fit.
+            By default, determinist=True.
+
         Returns
         -------
         transformed : Tensor
             Transformed data.
         """
-        self.fit(inputs)
+        self.fit(inputs, determinist=determinist)
         transformed = self.transform(inputs)
         return transformed
 
-    def fit(self, inputs: Tensor) -> "PCA":
+    def fit(self, inputs: Tensor, *, determinist: bool = True) -> "PCA":
         """Fit the PCA model and return it.
 
         Parameters
         ----------
         inputs : Tensor
             Input data of shape (n_samples, n_features).
+        determinist : bool, optional
+            If True, the SVD solver is deterministic but the gradient
+            cannot be computed through the PCA fit (the PCA transform is
+            always differentiable though).
+            If False, the SVD can be non-deterministic but the
+            gradient can be computed through the PCA fit.
+            By default, determinist=True.
+
+        Returns
+        -------
+        PCA
+            The PCA model fitted on the input data.
        """
         if self.svd_solver_ == "auto":
             self.svd_solver_ = choose_svd_solver(
@@ -113,6 +173,7 @@ def fit(self, inputs: Tensor) -> "PCA":
                 full_matrices=False,
             )
             explained_variance = coefs**2 / (inputs.shape[-2] - 1)
+            total_var = torch.sum(explained_variance)
         elif self.svd_solver_ == "covariance_eigh":
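+            # Sample covariance without materializing the centered data:
+            # C = (X^T X - n * mu^T mu) / (n - 1), with mu the row-vector mean.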
""" if self.svd_solver_ == "auto": self.svd_solver_ = choose_svd_solver( @@ -113,6 +173,7 @@ def fit(self, inputs: Tensor) -> "PCA": full_matrices=False, ) explained_variance = coefs**2 / (inputs.shape[-2] - 1) + total_var = torch.sum(explained_variance) elif self.svd_solver_ == "covariance_eigh": covariance = inputs.T @ inputs delta = self.n_samples_ * torch.transpose(self.mean_, -2, -1) * self.mean_ @@ -125,12 +186,34 @@ def fit(self, inputs: Tensor) -> "PCA": idx = range(eigenvals.size(0) - 1, -1, -1) idx = torch.LongTensor(idx) explained_variance = eigenvals.index_select(0, idx) + total_var = torch.sum(explained_variance) # Compute equivalent variables to full SVD output vh_mat = eigenvecs.T.index_select(0, idx) coefs = torch.sqrt(explained_variance * (self.n_samples_ - 1)) u_mat = None - _, vh_mat = svd_flip(u_mat, vh_mat) - total_var = torch.sum(explained_variance) + elif self.svd_solver_ == "randomized": + if ( + not isinstance(self.n_components_, int) + or int(self.n_components_) != self.n_components_ + ): + raise ValueError( + "Randomized SVD only supports integer number of components." + f"Found '{self.n_components_}'." + ) + inputs_centered = inputs - self.mean_ + u_mat, coefs, vh_mat = randomized_svd( + inputs=inputs_centered, + n_components=self.n_components_, + n_oversamples=self.n_oversamples, + n_iter=self.iterated_power, + power_iteration_normalizer=self.power_iteration_normalizer, + random_state=self.random_state, + ) + explained_variance = coefs**2 / (inputs.shape[-2] - 1) + total_var = torch.sum(inputs_centered**2) / (self.n_samples_ - 1) + + if determinist: + _, vh_mat = svd_flip(u_mat, vh_mat) explained_variance_ratio = explained_variance / total_var self.n_components_ = find_ncomponents( n_components=self.n_components_, @@ -152,6 +235,14 @@ def fit(self, inputs: Tensor) -> "PCA": ) return self + def _check_fitted(self, method_name: str) -> None: + """Check if the PCA model is fitted.""" + if self.components_ is None: + raise ValueError( + f"PCA not fitted when calling {method_name}. " + "Please call `fit` or `fit_transform` first." + ) + def transform(self, inputs: Tensor, center: str = "fit") -> Tensor: """Apply dimensionality reduction to X. @@ -174,11 +265,8 @@ def transform(self, inputs: Tensor, center: str = "fit") -> Tensor: transformed : Tensor Transformed data of shape (n_samples, n_components). """ - if self.components_ is None: - raise ValueError( - "PCA not fitted when calling transform. " - "Please call `fit` or `fit_transform` first." - ) + self._check_fitted("transform") + assert self.components_ is not None # for mypy transformed = inputs @ self.components_.T if center == "fit": transformed -= self.mean_ @ self.components_.T @@ -189,6 +277,11 @@ def transform(self, inputs: Tensor, center: str = "fit") -> Tensor: "Unknown centering, `center` argument should be " "one of 'fit', 'input' or 'none'." ) + + if self.whiten: + scale = torch.sqrt(self.explained_variance_) + scale[scale < 1e-8] = 1e-8 + transformed /= scale return transformed def inverse_transform(self, inputs: Tensor) -> Tensor: @@ -206,10 +299,83 @@ def inverse_transform(self, inputs: Tensor) -> Tensor: where n_features is the number of features in the input data before applying transform. """ - if self.components_ is None: - raise ValueError( - "PCA not fitted when calling inverse_transform. " - "Please call `fit` or `fit_transform` first." 
+            scale = torch.sqrt(self.explained_variance_)
+            scale[scale < 1e-8] = 1e-8
+            transformed /= scale
         return transformed
 
     def inverse_transform(self, inputs: Tensor) -> Tensor:
@@ -206,10 +299,83 @@ def inverse_transform(self, inputs: Tensor) -> Tensor:
             where n_features is the number of features in the input data
             before applying transform.
         """
-        if self.components_ is None:
-            raise ValueError(
-                "PCA not fitted when calling inverse_transform. "
-                "Please call `fit` or `fit_transform` first."
-            )
+        self._check_fitted("inverse_transform")
+        assert self.components_ is not None  # for mypy
         de_transformed = inputs @ self.components_ + self.mean_
         return de_transformed
+
+    def get_covariance(self) -> Tensor:
+        """Compute data covariance with the generative model."""
+        self._check_fitted("get_covariance")
+        assert self.components_ is not None  # for mypy
+        components, exp_variance_diff = self.get_exp_variance_diff()
+        covariance = (components.T * exp_variance_diff) @ components
+        covariance += self.noise_variance_ * torch.eye(components.shape[-1])
+        return covariance
+
+    def get_exp_variance_diff(self) -> Tuple[Tensor, Tensor]:
+        """Get explained variance difference (from noise)."""
+        assert self.noise_variance_ is not None  # for mypy
+        components = self.components_
+        explained_variance = self.explained_variance_
+        if self.whiten:
+            components = components * torch.sqrt(explained_variance)[:, None]
+        exp_variance_diff = explained_variance - self.noise_variance_
+        exp_variance_diff = torch.where(
+            exp_variance_diff > 0,
+            exp_variance_diff,
+            torch.tensor(0.0),
+        )
+        return components, exp_variance_diff
+
+    def get_precision(self) -> Tensor:
+        """Compute data precision matrix with the generative model.
+
+        It is the inverse of the covariance matrix, but this method is
+        more efficient than inverting the covariance directly.
+        """
+        self._check_fitted("get_precision")
+        assert self.noise_variance_ is not None  # for mypy
+        assert self.components_ is not None  # for mypy
+        n_features = self.components_.shape[-1]
+        if self.n_components_ == 0:
+            return torch.eye(n_features) / self.noise_variance_
+        if self.noise_variance_ == 0.0:
+            return torch.linalg.inv(self.get_covariance())
+        components, exp_variance_diff = self.get_exp_variance_diff()
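+        # Invert the covariance via the matrix inversion lemma (Woodbury
+        # identity): only an (n_components, n_components) matrix is
+        # explicitly inverted instead of the full covariance.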
+        precision = components @ components.T / self.noise_variance_
+        precision += (1.0 / exp_variance_diff) * torch.eye(precision.shape[0])
+        precision = components.T @ torch.linalg.inv(precision) @ components
+        precision /= -(self.noise_variance_**2)
+        precision += (1.0 / self.noise_variance_) * torch.eye(precision.shape[0])
+        return precision
+
+    def score_samples(self, inputs: Tensor) -> Tensor:
+        """Compute score of each sample based on log-likelihood.
+
+        Returns
+        -------
+        log_likelihood : Tensor
+            Log-likelihood of each sample under the current model,
+            of shape (n_samples,).
+        """
+        centered_inputs = inputs - self.mean_
+        n_features = centered_inputs.shape[-1]
+        precision = self.get_precision()
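+        # Gaussian log-likelihood of each sample:
+        # -0.5 * (d * log(2*pi) - logdet(precision) + x^T precision x)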
+        log_likelihood = -0.5 * (
+            n_features * log(2 * torch.pi)
+            - torch.linalg.slogdet(precision)[1]
+            + torch.sum((centered_inputs @ precision) * centered_inputs, dim=-1)
+        )
+        return log_likelihood
+
+    def score(self, inputs: Tensor) -> Tensor:
+        """Return the average score (log-likelihood) of all samples."""
+        return self.score_samples(inputs).mean()
+
+    @property
+    def _n_features_out(self) -> int:
+        """Number of transformed output features."""
+        self._check_fitted("_n_features_out")
+        assert self.components_ is not None  # for mypy
+        return self.components_.shape[0]
diff --git a/src/torch_pca/random_svd.py b/src/torch_pca/random_svd.py
new file mode 100644
index 0000000..c7d88ae
--- /dev/null
+++ b/src/torch_pca/random_svd.py
@@ -0,0 +1,51 @@
+"""Functions related to randomized SVD."""
+
+# Copyright (c) 2024 Valentin Goldité. All Rights Reserved.
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+
+
+def randomized_range_finder(
+    inputs: Tensor,
+    *,
+    size: int,
+    n_iter: int,
+    power_iteration_normalizer: str,
+    random_state: Optional[int],
+) -> Tensor:
+    """Compute an orthonormal matrix whose range approximates the range of inputs.
+
+    Returns
+    -------
+    proj_mat : Tensor
+        Orthonormal matrix whose range approximates the range of inputs.
+    """
+
+    def no_normalizer(inputs: Tensor) -> Tuple[Tensor, Tensor]:
+        """Disable normalizer."""
+        return inputs, None
+
+    def lu_normalizer(inputs: Tensor) -> Tuple[Tensor, Tensor]:
+        """LU normalizer."""
+        p_mat, l_mat, u_mat = torch.linalg.lu(inputs)
+        return p_mat @ l_mat, u_mat
+
+    if random_state is not None:
+        torch.manual_seed(random_state)
+    proj_mat = torch.randn(inputs.shape[-1], size, device=inputs.device)
+    if power_iteration_normalizer == "auto":
+        power_iteration_normalizer = "none" if n_iter <= 2 else "QR"
+    qr_normalizer = torch.linalg.qr
+    if power_iteration_normalizer == "QR":
+        normalizer = qr_normalizer
+    elif power_iteration_normalizer == "LU":
+        normalizer = lu_normalizer
+    else:
+        normalizer = no_normalizer
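+    # Power iterations: alternating multiplications by the matrix and its
+    # transpose concentrate the random projection on the leading singular
+    # subspace before the final orthonormalization.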
svd_solver="randomized", + random_state=0, + power_iteration_normalizer=method, + ).fit(inputs_) + sklearn_model = PCA_sklearn( + n_components=2, + svd_solver="randomized", + random_state=0, + power_iteration_normalizer=method, + whiten=False, + ).fit(inputs_) + for attr_name in ["components_", "explained_variance_", "singular_values_"]: + attr_torch = getattr(torch_model, attr_name) + attr_sklearn = getattr(sklearn_model, attr_name) + check.is_true( + torch.allclose( + attr_torch, + torch.tensor(attr_sklearn, dtype=torch.float32), + rtol=1e-5, + atol=1e-5, + ), + f"Failed for method {method}, attribute {attr_name}", + ) + with pytest.raises( + ValueError, + match="Randomized SVD only supports integer number of components.*", + ): + PCA(n_components=0.5, svd_solver="randomized").fit(inputs) + with pytest.raises( + ValueError, + match="`iterated_power` must be an integer or 'auto'.*", + ): + PCA(n_components=2, svd_solver="randomized", iterated_power="UNKNOWN").fit( + inputs + ) + + def test_fail_svd() -> None: """Test unknown SVD solver.""" input_1 = torch.randn(200, 10) @@ -121,4 +173,6 @@ def test_auto_svd() -> None: check.equal(model.svd_solver_, "full") input_3 = torch.randn(10, 1200) model = PCA(n_components=2).fit(input_3) + check.equal(model.svd_solver_, "randomized") + model = PCA(n_components=1000).fit(input_3) check.equal(model.svd_solver_, "full") diff --git a/tests/test_get_cov_prec.py b/tests/test_get_cov_prec.py new file mode 100644 index 0000000..5820395 --- /dev/null +++ b/tests/test_get_cov_prec.py @@ -0,0 +1,70 @@ +"""Tests around covariance, precision and score methods.""" + +# Copyright (c) 2024 Valentin Goldité. All Rights Reserved. +import pytest_check as check +import torch +from sklearn.decomposition import PCA as PCA_sklearn + +from torch_pca import PCA + + +def test_get_covariance() -> None: + """Test get_covariance method.""" + inputs = torch.load("tests/input_data.pt").to(torch.float32) + torch_model = PCA(n_components=2).fit(inputs) + sklearn_model = PCA_sklearn(n_components=2).fit(inputs) + check.is_true( + torch.allclose( + torch.tensor(sklearn_model.get_covariance(), dtype=torch.float32), + torch_model.get_covariance(), + rtol=1e-5, + atol=1e-5, + ) + ) + + +def test_precision() -> None: + """Test get_precision method.""" + inputs = torch.load("tests/input_data3.pt").to(torch.float32) + torch_model = PCA(n_components=2).fit(inputs) + sklearn_model = PCA_sklearn(n_components=2).fit(inputs) + print( + torch_model.get_precision(), + ) + print( + torch.tensor(sklearn_model.get_precision(), dtype=torch.float32), + ) + check.is_true( + torch.allclose( + torch.tensor(sklearn_model.get_precision(), dtype=torch.float32), + torch_model.get_precision(), + rtol=1e-5, + atol=1e-5, + ) + ) + + +def test_scores() -> None: + """Test score-related methods.""" + inputs = torch.load("tests/input_data3.pt").to(torch.float32) + for whiten in [True, False]: + torch_model = PCA(n_components=2, whiten=whiten).fit(inputs) + sklearn_model = PCA_sklearn(n_components=2, whiten=whiten).fit(inputs) + check.is_true( + torch.allclose( + torch.tensor(sklearn_model.score_samples(inputs), dtype=torch.float32), + torch_model.score_samples(inputs), + rtol=1e-5, + atol=1e-5, + ), + f"Wrong with whiten={whiten}", + ) + check.is_true( + torch.allclose( + torch.tensor(sklearn_model.score(inputs), dtype=torch.float32), + torch_model.score(inputs), + rtol=1e-5, + atol=1e-5, + ), + f"Wrong with whiten={whiten}", + ) diff --git a/tests/test_grad.py b/tests/test_grad.py new file mode 100644 index 
+    if n_samples < n_features:
+        inputs = inputs.T
+    proj_mat = randomized_range_finder(
+        inputs,
+        size=n_random,
+        n_iter=n_iter,
+        power_iteration_normalizer=power_iteration_normalizer,
+        random_state=random_state,
+    )
+    pseudo_inputs = proj_mat.T @ inputs
+    u_mat, coefs, vh_mat = torch.linalg.svd(pseudo_inputs, full_matrices=False)
+    u_mat = proj_mat @ u_mat
+    if n_samples < n_features:
+        return (
+            vh_mat[:n_components, :].T,
+            coefs[:n_components],
+            u_mat[:, :n_components].T,
+        )
+    return u_mat[:, :n_components], coefs[:n_components], vh_mat[:n_components, :]
+
+
 def svd_flip(u_mat: Optional[Tensor], vh_mat: Tensor) -> Tuple[Tensor, Tensor]:
     """Sign correction to ensure deterministic output from SVD.
diff --git a/tests/test_fit_transform.py b/tests/test_fit_transform.py
index e695b07..825715d 100644
--- a/tests/test_fit_transform.py
+++ b/tests/test_fit_transform.py
@@ -12,40 +12,48 @@
 def test_fullsvd() -> None:
     """Test basic full SVD."""
     input_1 = torch.load("tests/input_data.pt").to(torch.float32) + 2.0
-    torch_model = PCA(
-        n_components=2,
-        svd_solver="full",
-    ).fit(input_1)
-    sklearn_model = PCA_sklearn(
-        n_components=2,
-        svd_solver="full",
-        whiten=False,
-    ).fit(input_1)
-    check.is_true(
-        torch.allclose(
-            torch.tensor(sklearn_model.components_, dtype=torch.float32),
-            torch_model.components_,
-            rtol=1e-5,
-            atol=1e-5,
+    for whiten in [True, False]:
+        torch_model = PCA(
+            n_components=2,
+            svd_solver="full",
+            whiten=whiten,
+        ).fit(input_1)
+        sklearn_model = PCA_sklearn(
+            n_components=2,
+            svd_solver="full",
+            whiten=whiten,
+        ).fit(input_1)
+        check.is_true(
+            torch.allclose(
+                torch.tensor(sklearn_model.components_, dtype=torch.float32),
+                torch_model.components_,
+                rtol=1e-5,
+                atol=1e-5,
+            )
         )
-    )
-    torch_outs = torch_model.transform(input_1)
-    sklearn_outs = torch.tensor(sklearn_model.transform(input_1), dtype=torch.float32)
-    check.is_true(torch.allclose(torch_outs, sklearn_outs, rtol=1e-5, atol=1e-5))
+        torch_outs = torch_model.transform(input_1)
+        sklearn_outs = torch.tensor(
+            sklearn_model.transform(input_1), dtype=torch.float32
+        )
+        check.is_true(torch.allclose(torch_outs, sklearn_outs, rtol=1e-5, atol=1e-5))
 
-    check.is_true(
-        torch.allclose(
-            torch_model.fit_transform(input_1),
-            torch_outs,
-            rtol=1e-5,
-            atol=1e-5,
+        check.is_true(
+            torch.allclose(
+                torch_model.fit_transform(input_1),
+                torch_outs,
+                rtol=1e-5,
+                atol=1e-5,
+            )
         )
-    )
     # New data
     input_2 = torch.load("tests/input_data2.pt").to(torch.float32) - 1.0
     torch_outs = torch_model.transform(input_2)
     sklearn_outs = torch.tensor(sklearn_model.transform(input_2), dtype=torch.float32)
     check.is_true(torch.allclose(torch_outs, sklearn_outs, rtol=1e-5, atol=1e-5))
+    check.equal(
+        torch_model._n_features_out,  # pylint: disable=protected-access
+        sklearn_model._n_features_out,  # pylint: disable=protected-access
+    )
 
     # Fail if not fitted
     model = PCA(n_components=2)
@@ -104,6 +112,50 @@ def test_covariance_eigh() -> None:
     )
 
 
+def test_randomized() -> None:
+    """Test randomized SVD."""
+    inputs = torch.load("tests/input_data.pt").to(torch.float32) + 2.0
+    for inputs_ in [inputs, inputs.T]:
+        for method in ["auto", "QR", "LU", "none"]:
+            torch_model = PCA(
+                n_components=2,
+                svd_solver="randomized",
+                random_state=0,
+                power_iteration_normalizer=method,
+            ).fit(inputs_)
+            sklearn_model = PCA_sklearn(
+                n_components=2,
+                svd_solver="randomized",
+                random_state=0,
+                power_iteration_normalizer=method,
+                whiten=False,
+            ).fit(inputs_)
+            for attr_name in ["components_", "explained_variance_", "singular_values_"]:
+                attr_torch = getattr(torch_model, attr_name)
+                attr_sklearn = getattr(sklearn_model, attr_name)
+                check.is_true(
+                    torch.allclose(
+                        attr_torch,
+                        torch.tensor(attr_sklearn, dtype=torch.float32),
+                        rtol=1e-5,
+                        atol=1e-5,
+                    ),
+                    f"Failed for method {method}, attribute {attr_name}",
+                )
+    with pytest.raises(
+        ValueError,
+        match="Randomized SVD only supports integer number of components.*",
+    ):
+        PCA(n_components=0.5, svd_solver="randomized").fit(inputs)
+    with pytest.raises(
+        ValueError,
+        match="`iterated_power` must be an integer or 'auto'.*",
+    ):
+        PCA(n_components=2, svd_solver="randomized", iterated_power="UNKNOWN").fit(
+            inputs
+        )
+
+
 def test_fail_svd() -> None:
     """Test unknown SVD solver."""
     input_1 = torch.randn(200, 10)
@@ -121,4 +173,6 @@ def test_auto_svd() -> None:
     check.equal(model.svd_solver_, "full")
     input_3 = torch.randn(10, 1200)
     model = PCA(n_components=2).fit(input_3)
+    check.equal(model.svd_solver_, "randomized")
+    model = PCA(n_components=1000).fit(input_3)
     check.equal(model.svd_solver_, "full")
diff --git a/tests/test_get_cov_prec.py b/tests/test_get_cov_prec.py
new file mode 100644
index 0000000..5820395
--- /dev/null
+++ b/tests/test_get_cov_prec.py
@@ -0,0 +1,70 @@
+"""Tests around covariance, precision and score methods."""
+
+# Copyright (c) 2024 Valentin Goldité. All Rights Reserved.
+import pytest_check as check
+import torch
+from sklearn.decomposition import PCA as PCA_sklearn
+
+from torch_pca import PCA
+
+
+def test_get_covariance() -> None:
+    """Test get_covariance method."""
+    inputs = torch.load("tests/input_data.pt").to(torch.float32)
+    torch_model = PCA(n_components=2).fit(inputs)
+    sklearn_model = PCA_sklearn(n_components=2).fit(inputs)
+    check.is_true(
+        torch.allclose(
+            torch.tensor(sklearn_model.get_covariance(), dtype=torch.float32),
+            torch_model.get_covariance(),
+            rtol=1e-5,
+            atol=1e-5,
+        )
+    )
+
+
+def test_precision() -> None:
+    """Test get_precision method."""
+    inputs = torch.load("tests/input_data3.pt").to(torch.float32)
+    torch_model = PCA(n_components=2).fit(inputs)
+    sklearn_model = PCA_sklearn(n_components=2).fit(inputs)
+    check.is_true(
+        torch.allclose(
+            torch.tensor(sklearn_model.get_precision(), dtype=torch.float32),
+            torch_model.get_precision(),
+            rtol=1e-5,
+            atol=1e-5,
+        )
+    )
+
+
+def test_scores() -> None:
+    """Test score-related methods."""
+    inputs = torch.load("tests/input_data3.pt").to(torch.float32)
+    for whiten in [True, False]:
+        torch_model = PCA(n_components=2, whiten=whiten).fit(inputs)
+        sklearn_model = PCA_sklearn(n_components=2, whiten=whiten).fit(inputs)
+        check.is_true(
+            torch.allclose(
+                torch.tensor(sklearn_model.score_samples(inputs), dtype=torch.float32),
+                torch_model.score_samples(inputs),
+                rtol=1e-5,
+                atol=1e-5,
+            ),
+            f"Wrong with whiten={whiten}",
+        )
+        check.is_true(
+            torch.allclose(
+                torch.tensor(sklearn_model.score(inputs), dtype=torch.float32),
+                torch_model.score(inputs),
+                rtol=1e-5,
+                atol=1e-5,
+            ),
+            f"Wrong with whiten={whiten}",
+        )
diff --git a/tests/test_grad.py b/tests/test_grad.py
new file mode 100644
index 0000000..d80b396
--- /dev/null
+++ b/tests/test_grad.py
@@ -0,0 +1,51 @@
+"""Test gradient backward through PCA."""
+
+# Copyright (c) 2024 Valentin Goldité. All Rights Reserved.
+import torch
+from torch import nn
+
+from torch_pca import PCA
+
+
+def test_grad_transform() -> None:
+    """Test backward in transform."""
+    X = torch.randn(100, 32)
+    y = X[:, :10].sum(dim=1) / 10.0
+    model1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10))
+    model2 = nn.Sequential(nn.Linear(5, 1))
+    pca = PCA(n_components=5, svd_solver="full")
+    optimizer = torch.optim.Adam(
+        list(model1.parameters()) + list(model2.parameters()), lr=0.01
+    )
+    criterion = nn.MSELoss()
+    for _ in range(10):
+        optimizer.zero_grad()
+        out = model1(X)
+        with torch.no_grad():
+            pca.fit(out, determinist=True)
+        out_pca = pca.transform(out)
+        y_pred = model2(out_pca)
+        loss = criterion(y_pred, y.view(-1, 1))
+        loss.backward()
+        optimizer.step()
+
+
+def test_grad_fit_transform() -> None:
+    """Test backward in fit_transform."""
+    X = torch.randn(100, 32)
+    y = X[:, :10].sum(dim=1) / 10.0
+    model1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10))
+    model2 = nn.Sequential(nn.Linear(5, 1))
+    pca = PCA(n_components=5, svd_solver="full")
+    optimizer = torch.optim.Adam(
+        list(model1.parameters()) + list(model2.parameters()), lr=0.01
+    )
+    criterion = nn.MSELoss()
+    for _ in range(10):
+        optimizer.zero_grad()
+        out = model1(X)
+        out_pca = pca.fit_transform(out, determinist=False)
+        y_pred = model2(out_pca)
+        loss = criterion(y_pred, y.view(-1, 1))
+        loss.backward()
+        optimizer.step()