Merge pull request #8 from valentingol/gpudtype
🐛 Solve GPU and different dtypes issues
valentingol authored Jun 24, 2024
2 parents 9bff944 + 225be4b commit ca605df
Showing 4 changed files with 45 additions and 6 deletions.
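
The net effect of the commit: fitting and transforming now work on CUDA tensors and across float16/float32/float64 inputs. A minimal sketch of the intended behavior, assuming a CUDA device is available (the tensor shape and `n_components` value are illustrative, not from the commit):

import torch

from torch_pca import PCA

x = torch.randn(200, 16, device="cuda:0", dtype=torch.float64)
out = PCA(n_components=4).fit_transform(x)
# Output stays on the input's device and keeps the input's dtype.
assert str(out.device) == "cuda:0" and out.dtype == x.dtype

The new test file below exercises exactly this contract for the `full`, `covariance_eigh`, and `randomized` solvers.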
25 changes: 21 additions & 4 deletions src/torch_pca/pca_main.py
@@ -159,6 +159,10 @@ def fit(self, inputs: Tensor, *, determinist: bool = True) -> "PCA":
         PCA
             The PCA model fitted on the input data.
         """
+        # Auto-cast to float32 because float16 is not supported
+        if inputs.dtype == torch.float16:
+            inputs = inputs.to(torch.float32)
+
         if self.svd_solver_ == "auto":
             self.svd_solver_ = choose_svd_solver(
                 inputs=inputs,
@@ -184,14 +188,16 @@ def fit(self, inputs: Tensor, *, determinist: bool = True) -> "PCA":
             eigenvals[eigenvals < 0.0] = 0.0
             # Inverted indices
             idx = range(eigenvals.size(0) - 1, -1, -1)
-            idx = torch.LongTensor(idx)
+            idx = torch.LongTensor(idx).to(eigenvals.device)
             explained_variance = eigenvals.index_select(0, idx)
             total_var = torch.sum(explained_variance)
             # Compute equivalent variables to full SVD output
             vh_mat = eigenvecs.T.index_select(0, idx)
             coefs = torch.sqrt(explained_variance * (self.n_samples_ - 1))
             u_mat = None
         elif self.svd_solver_ == "randomized":
+            if self.n_components_ is None:
+                self.n_components_ = min(inputs.shape[-2:])
             if (
                 not isinstance(self.n_components_, int)
                 or int(self.n_components_) != self.n_components_
@@ -267,11 +273,22 @@ def transform(self, inputs: Tensor, center: str = "fit") -> Tensor:
         """
         self._check_fitted("transform")
         assert self.components_ is not None  # for mypy
-        transformed = inputs @ self.components_.T
+        assert self.mean_ is not None  # for mypy
+        components = (
+            self.components_.to(torch.float16)
+            if inputs.dtype == torch.float16
+            else self.components_
+        )
+        mean = (
+            self.mean_.to(torch.float16)
+            if inputs.dtype == torch.float16
+            else self.mean_
+        )
+        transformed = inputs @ components.T
         if center == "fit":
-            transformed -= self.mean_ @ self.components_.T
+            transformed -= mean @ components.T
         elif center == "input":
-            transformed -= inputs.mean(dim=-2, keepdim=True) @ self.components_.T
+            transformed -= inputs.mean(dim=-2, keepdim=True) @ components.T
         elif center != "none":
             raise ValueError(
                 "Unknown centering, `center` argument should be "
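
Why the fix takes this shape: PyTorch's linear-algebra kernels (e.g. torch.linalg.svd, torch.linalg.eigh) do not support float16, so fit() upcasts half-precision inputs to float32 and stores float32 parameters; transform() then casts the stored components and mean back down so every matmul sees a single dtype. A minimal standalone sketch of that dtype-matching pattern (the tensors here are hypothetical stand-ins for the fitted state):

import torch

components = torch.randn(4, 16)                    # fitted parameters, float32
inputs = torch.randn(8, 16, dtype=torch.float16)   # half-precision query data

# float16 @ float32 raises a dtype-mismatch RuntimeError, so cast down first.
if inputs.dtype == torch.float16:
    components = components.to(torch.float16)
transformed = inputs @ components.T
assert transformed.dtype == torch.float16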
4 changes: 3 additions & 1 deletion src/torch_pca/random_svd.py
@@ -34,7 +34,9 @@ def lu_normalizer(inputs: Tensor) -> Tuple[Tensor, Tensor]:
 
     if random_state is not None:
         torch.manual_seed(random_state)
-    proj_mat = torch.randn(inputs.shape[-1], size, device=inputs.device)
+    proj_mat = torch.randn(
+        inputs.shape[-1], size, device=inputs.device, dtype=inputs.dtype
+    )
     if power_iteration_normalizer == "auto":
         power_iteration_normalizer = "none" if n_iter <= 2 else "QR"
     qr_normalizer = torch.linalg.qr
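
torch.randn creates tensors in the global default dtype (normally float32) when dtype is not given, so the projection matrix previously matched the input's device but not its dtype, and the projection matmul failed for float64 or float16 inputs. A short illustration (shapes are arbitrary):

import torch

inputs = torch.randn(100, 32, dtype=torch.float64)
proj_bad = torch.randn(32, 8)  # float32 by default
# inputs @ proj_bad  # raises a dtype-mismatch RuntimeError
proj = torch.randn(32, 8, dtype=inputs.dtype, device=inputs.device)
out = inputs @ proj  # OK: dtypes and devices agree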
2 changes: 1 addition & 1 deletion src/torch_pca/svd.py
@@ -109,7 +109,7 @@ def svd_flip(u_mat: Optional[Tensor], vh_mat: Tensor) -> Tuple[Tensor, Tensor]:
         Adjusted V^H matrix.
     """
     max_abs_v_rows = torch.argmax(torch.abs(vh_mat), dim=1)
-    shift = torch.arange(vh_mat.shape[0])
+    shift = torch.arange(vh_mat.shape[0]).to(vh_mat.device)
     indices = max_abs_v_rows + shift * vh_mat.shape[1]
     flat_vh = torch.reshape(vh_mat, (-1,))
     signs = torch.sign(torch.take_along_dim(flat_vh, indices, dim=0))
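
torch.arange allocates on the CPU unless told otherwise, so on a CUDA input `max_abs_v_rows + shift` mixed devices and raised a RuntimeError. An equivalent fix (not the one committed) passes the device at construction time and avoids the intermediate CPU tensor; a sketch assuming a CUDA device:

import torch

vh_mat = torch.randn(4, 16, device="cuda:0")
# Allocated on the GPU directly, so the later index arithmetic stays on one device.
shift = torch.arange(vh_mat.shape[0], device=vh_mat.device)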
20 changes: 20 additions & 0 deletions tests/test_gpu.py
@@ -0,0 +1,20 @@
"""Test PCA with GPU and different dtypes."""

# Copyright (c) 2024 Valentin Goldité. All Rights Reserved.
import pytest_check as check
import torch

from torch_pca import PCA


def test_gpu() -> None:
"""Test with GPU and different dtypes."""
inputs = torch.load("tests/input_data.pt").to("cuda:0")
for dtype in [torch.float32, torch.float16, torch.float64]:
inputs = inputs.to(dtype)
out1 = PCA(svd_solver="full").fit_transform(inputs)
out2 = PCA(svd_solver="covariance_eigh").fit_transform(inputs)
out3 = PCA(svd_solver="randomized", random_state=0).fit_transform(inputs)
for out in [out1, out2, out3]:
check.equal(str(out.device), "cuda:0")
check.equal(out.dtype, dtype)

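Note that test_gpu assumes a CUDA device and will fail on CPU-only machines. A common guard, not part of this commit, is a skip marker (test name hypothetical):

import pytest
import torch

# Skip gracefully instead of erroring when no GPU is present.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_gpu_guarded() -> None:
    ...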