From 22a840a0f1a224558b36dbcbd52f3d8f1c4026da Mon Sep 17 00:00:00 2001
From: Tom White
Date: Wed, 25 Sep 2024 12:59:11 +0100
Subject: [PATCH] Use sgkit.distarray for PCA

---
 .github/workflows/cubed.yml  |   2 +-
 sgkit/stats/pca.py           |   4 +-
 sgkit/stats/preprocessing.py |   2 +-
 sgkit/stats/truncated_svd.py | 248 +++++++++++++++++++++++++++++++++++
 sgkit/tests/test_pca.py      |   2 +-
 5 files changed, 253 insertions(+), 5 deletions(-)
 create mode 100644 sgkit/stats/truncated_svd.py

diff --git a/.github/workflows/cubed.yml b/.github/workflows/cubed.yml
index 687b8d519..0b3165fc4 100644
--- a/.github/workflows/cubed.yml
+++ b/.github/workflows/cubed.yml
@@ -30,4 +30,4 @@ jobs:
 
       - name: Test with pytest
         run: |
-          pytest -v sgkit/tests/test_{aggregation,association,hwe}.py -k 'test_count_call_alleles or test_gwas_linear_regression or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed
\ No newline at end of file
+          pytest -v sgkit/tests/test_{aggregation,association,hwe,pca}.py -k 'test_count_call_alleles or test_gwas_linear_regression or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False]) or (test_pca__array_backend and tsqr)' --use-cubed
\ No newline at end of file
diff --git a/sgkit/stats/pca.py b/sgkit/stats/pca.py
index 5a2fea566..6754b3aa2 100644
--- a/sgkit/stats/pca.py
+++ b/sgkit/stats/pca.py
@@ -1,20 +1,20 @@
 from typing import Any, Optional, Union
 
-import dask.array as da
 import numpy as np
 import xarray as xr
-from dask_ml.decomposition import TruncatedSVD
 from sklearn.base import BaseEstimator
 from sklearn.pipeline import Pipeline
 from typing_extensions import Literal
 from xarray import DataArray, Dataset
 
+import sgkit.distarray as da
 from sgkit import variables
 
 from ..typing import ArrayLike, DType, RandomStateType
 from ..utils import conditional_merge_datasets
 from .aggregation import count_call_alleles
 from .preprocessing import PattersonScaler
+from .truncated_svd import TruncatedSVD
 
 
 def pca_est(
diff --git a/sgkit/stats/preprocessing.py b/sgkit/stats/preprocessing.py
index b7b993f31..1fb9176ab 100644
--- a/sgkit/stats/preprocessing.py
+++ b/sgkit/stats/preprocessing.py
@@ -1,11 +1,11 @@
 from typing import Hashable, Optional
 
-import dask.array as da
 import numpy as np
 import xarray as xr
 from sklearn.base import BaseEstimator, TransformerMixin
 from xarray import Dataset
 
+import sgkit.distarray as da
 from sgkit import variables
 
 from ..typing import ArrayLike
diff --git a/sgkit/stats/truncated_svd.py b/sgkit/stats/truncated_svd.py
new file mode 100644
index 000000000..1dac7ae16
--- /dev/null
+++ b/sgkit/stats/truncated_svd.py
@@ -0,0 +1,248 @@
+from dask.utils import has_keyword
+from sklearn.base import BaseEstimator, TransformerMixin
+
+import sgkit.distarray as da
+
+# Based on the implementation in Dask-ML, with minor changes to support the
+# array API so it can work with both Dask and Cubed.
+
+
+class TruncatedSVD(BaseEstimator, TransformerMixin):
+    def __init__(
+        self,
+        n_components=2,
+        algorithm="tsqr",
+        n_iter=5,
+        random_state=None,
+        tol=0.0,
+        compute=True,
+    ):
+        """Dimensionality reduction using truncated SVD (aka LSA).
+
+        This transformer performs linear dimensionality reduction by means of
+        truncated singular value decomposition (SVD).
+        Contrary to PCA, this estimator does not center the data before
+        computing the singular value decomposition.
+
+        Parameters
+        ----------
+        n_components : int, default = 2
+            Desired dimensionality of output data.
+            Must be strictly less than the number of features.
+            The default value is useful for visualization.
+
+        algorithm : {'tsqr', 'randomized'}
+            SVD solver to use. Both use the ``tsqr`` (for "tall-and-skinny
+            QR") algorithm internally. 'randomized' uses an approximate
+            algorithm that is faster, but not exact. See the References for
+            more.
+
+        n_iter : int, optional (default 5)
+            Number of power iterations, useful when the singular values
+            decay slowly. Error decreases exponentially as ``n_iter``
+            increases. In practice, set ``n_iter`` <= 4. Only used by the
+            'randomized' solver.
+
+        random_state : int, RandomState instance or None, optional
+            If int, random_state is the seed used by the random number
+            generator;
+            If RandomState instance, random_state is the random number
+            generator;
+            If None, the random number generator is the RandomState instance
+            used by ``np.random``.
+
+        tol : float, optional
+            Ignored.
+
+        compute : bool
+            Whether or not SVD results should be computed
+            eagerly, by default True.
+
+        Attributes
+        ----------
+        components_ : array, shape (n_components, n_features)
+
+        explained_variance_ : array, shape (n_components,)
+            The variance of the training samples transformed by a projection
+            to each component.
+
+        explained_variance_ratio_ : array, shape (n_components,)
+            Percentage of variance explained by each of the selected
+            components.
+
+        singular_values_ : array, shape (n_components,)
+            The singular values corresponding to each of the selected
+            components. The singular values are equal to the 2-norms of the
+            ``n_components`` variables in the lower-dimensional space.
+
+        See Also
+        --------
+        dask.array.linalg.tsqr
+        dask.array.linalg.svd_compressed
+
+        References
+        ----------
+        Direct QR factorizations for tall-and-skinny matrices in
+        MapReduce architectures.
+        A. Benson, D. Gleich, and J. Demmel.
+        IEEE International Conference on Big Data, 2013.
+        http://arxiv.org/abs/1301.1071
+
+        Notes
+        -----
+        SVD suffers from a problem called "sign indeterminacy", which means
+        the sign of the ``components_`` and the output from transform depend
+        on the algorithm and random state. To work around this, fit instances
+        of this class to data once, then keep the instance around to do
+        transformations.
+
+        .. warning::
+
+           The implementation currently does not support sparse matrices.
+
+        Examples
+        --------
+        >>> from sgkit.stats.truncated_svd import TruncatedSVD
+        >>> import dask.array as da
+        >>> X = da.random.normal(size=(1000, 20), chunks=(100, 20))
+        >>> svd = TruncatedSVD(n_components=5, n_iter=3, random_state=42)
+        >>> svd.fit(X)  # doctest: +NORMALIZE_WHITESPACE
+        TruncatedSVD(algorithm='tsqr', n_components=5, n_iter=3,
+                     random_state=42, tol=0.0)
+        >>> print(svd.explained_variance_ratio_)  # doctest: +ELLIPSIS
+        [0.06386323 0.06176776 0.05901293 0.0576399 0.05726607]
+        >>> print(svd.explained_variance_ratio_.sum())  # doctest: +ELLIPSIS
+        0.299...
+        >>> print(svd.singular_values_)  # doctest: +ELLIPSIS
+        array([35.92469517, 35.32922121, 34.53368856, 34.138..., 34.013...])
+
+        Note that ``transform`` returns a lazy ``dask.array.Array`` (or the
+        equivalent array type of the active ``sgkit.distarray`` backend).
+ + >>> svd.transform(X) + dask.array + """ + self.algorithm = algorithm + self.n_components = n_components + self.n_iter = n_iter + self.random_state = random_state + self.tol = tol + self.compute = compute + + def fit(self, X, y=None): + """Fit truncated SVD on training data X + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Training data. + + y : Ignored + + Returns + ------- + self : object + Returns the transformer object. + """ + self.fit_transform(X) + return self + + def _check_array(self, X): + if self.n_components >= X.shape[1]: + raise ValueError( + "n_components must be < n_features; " + "got {} >= {}".format(self.n_components, X.shape[1]) + ) + return X + + def fit_transform(self, X, y=None): + """Fit model to X and perform dimensionality reduction on X. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Training data. + + y : Ignored + + Returns + ------- + X_new : array, shape (n_samples, n_components) + Reduced version of X. This will always be a dense array, of the + same type as the input array. If ``X`` was a ``dask.array``, then + ``X_new`` will be a ``dask.array`` with the same chunks along the + first dimension. + """ + X = self._check_array(X) + if self.algorithm not in {"tsqr", "randomized"}: + raise ValueError( + "`algorithm` must be 'tsqr' or 'randomized', not '{}'".format( + self.algorithm + ) + ) + if self.algorithm == "tsqr": + if has_keyword(da.linalg.svd, "full_matrices"): + u, s, v = da.linalg.svd(X, full_matrices=False) + else: + u, s, v = da.linalg.svd(X) + u = u[:, : self.n_components] + s = s[: self.n_components] + v = v[: self.n_components] + else: + u, s, v = da.linalg.svd_compressed( + X, self.n_components, n_power_iter=self.n_iter, seed=self.random_state + ) + + X_transformed = u * s + explained_var = da.var(X_transformed, axis=0) + full_var = da.var(X, axis=0) + full_var = da.sum(full_var) + explained_variance_ratio = explained_var / full_var + + if self.compute: + v, explained_var, explained_variance_ratio, s = da.compute( + v, explained_var, explained_variance_ratio, s + ) + self.components_ = v + self.explained_variance_ = explained_var + self.explained_variance_ratio_ = explained_variance_ratio + self.singular_values_ = s + self.n_features_in_ = X.shape[1] + return X_transformed + + def transform(self, X, y=None): + """Perform dimensionality reduction on X. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Data to be transformed. + + y : Ignored + + Returns + ------- + X_new : array, shape (n_samples, n_components) + Reduced version of X. This will always be a dense array, of the + same type as the input array. If ``X`` was a ``dask.array``, then + ``X_new`` will be a ``dask.array`` with the same chunks along the + first dimension. + """ + return X @ self.components_.T + + def inverse_transform(self, X): + """Transform X back to its original space. + + Returns an array X_original whose transform would be X. + + Parameters + ---------- + X : array-like, shape (n_samples, n_components) + New data. + + Returns + ------- + X_original : array, shape (n_samples, n_features) + Note that this is always a dense array. 
+ """ + # X = check_array(X) + return X @ self.components_ diff --git a/sgkit/tests/test_pca.py b/sgkit/tests/test_pca.py index ecbee543e..c132d0c74 100644 --- a/sgkit/tests/test_pca.py +++ b/sgkit/tests/test_pca.py @@ -1,12 +1,12 @@ from typing import Any, Optional import allel -import dask.array as da import numpy as np import pytest import xarray as xr from xarray import Dataset +import sgkit.distarray as da from sgkit.stats import pca from sgkit.stats.pca import count_call_alternate_alleles from sgkit.testing import simulate_genotype_call_dataset