Implement basic principal component analysis (PCA) functionality with SPU #213
Comments
hacker-jerry give it to me. |
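Hello, I have implemented a prototype PCA class using JAX.
The algorithm can be called directly,
and it can also be run in simulation through ppd.
However, when I simulate it with spsim, an error is raised.
How should I modify the code? |
By the way, it also compiles with jax.jit, but it still cannot run under spsim 😭 |
Hi @hacker-jerry, this looks like it happens because SPU does not yet support every JAX operator (for example |
Thanks for providing the reproduction code. We will look into it. Thanks |
Could you share the package versions you are using locally? (mainly spu and jax) |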
spu 0.3.3b0 |
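Thanks. I noticed that you use ppd and spsim in slightly different ways: |
Thanks. What causes the error under spsim, and how should I fix it? |
First, my understanding is that the fit method needs to be executable in the secret-shared (encrypted) setting, so the root cause is that SPU does not yet support the svd operator.
|
OK, thanks! |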
Hello, I implemented the eigh operator with the Jacobi method. The refactored code is as follows:

import jax
import jax.numpy as jnp
from jax import jit
from functools import partial

class PCA:
    def __init__(self, n_components=None, tol=1e-8, max_iters=100):
        self.n_components = n_components
        self.tol = tol
        self.max_iters = max_iters
        self.components_ = None
        self.explained_variance_ = None
        self.mean_ = None

    def fit_transform(self, X):
        # Center the data and build the covariance matrix
        self.mean_ = jnp.mean(X, axis=0)
        X_centered = X - self.mean_
        cov_matrix = jnp.cov(X_centered, rowvar=False)
        eigenvalues, eigenvectors = jacobi_eigh(cov_matrix, self.tol, self.max_iters)
        # Sort eigenpairs by descending eigenvalue
        idx = jnp.argsort(eigenvalues)[::-1]
        eigenvalues = eigenvalues[idx]
        eigenvectors = eigenvectors[:, idx]
        if self.n_components is None:
            self.n_components = X.shape[1]
        self.components_ = eigenvectors[:, :self.n_components]
        self.explained_variance_ = eigenvalues[:self.n_components]
        X_transformed = jnp.dot(X_centered, self.components_)
        return X_transformed, self.explained_variance_

def jacobi_eigh(A, tol, max_iters):
    n = A.shape[0]
    Q = jnp.eye(n)

    def body_fn(i, vals):
        A, Q = vals
        # Pivot on the largest off-diagonal entry
        p, q = jnp.unravel_index(jnp.argmax(jnp.abs(A - jnp.diag(jnp.diag(A)))), A.shape)
        phi = 0.5 * jnp.arctan(2 * A[p, q] / (A[q, q] - A[p, p]))
        rotation = jnp.eye(n)
        rotation = rotation.at[[p, q], [p, q]].set(jnp.cos(phi))
        rotation = rotation.at[q, p].set(jnp.sin(phi))
        rotation = rotation.at[p, q].set(-jnp.sin(phi))
        A_prime = rotation.T @ A @ rotation
        Q_prime = Q @ rotation
        A = jax.lax.cond(jnp.abs(A[p, q]) < tol, lambda _: A, lambda _: A_prime, None)
        Q = jax.lax.cond(jnp.abs(A[p, q]) < tol, lambda _: Q, lambda _: Q_prime, None)
        return A, Q

    A, Q = jax.lax.fori_loop(0, max_iters, body_fn, (A, Q))
    return jnp.diag(A), Q

This function can be called after compiling it with jit:

pca = PCA(n_components=2)
pca_fit_transform = jit(pca.fit_transform, static_argnums=1)
# Prepare some data
X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# Fit and transform with the compiled fit_transform function
X_transformed, explained_variance = pca_fit_transform(X)
print(explained_variance)
print(X_transformed)

However, when simulating with spsim, an error is raised again:

import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
import spu

sim_aby = spsim.Simulator.simple(
    3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)

def fit_transform(X, n_components=None):
    pca_fit_transform = jit(PCA(n_components=n_components).fit_transform, static_argnums=1)
    X_transformed, explained_variance = pca_fit_transform(X)
    return X_transformed, explained_variance

result = spsim.sim_jax(sim_aby, fit_transform, static_argnums=(1,))(X, 2)

The error message is as follows:
What is the cause? |
Hello, the main reason it cannot run is that the eigh implementation uses trigonometric functions, which SPU currently does not implement, hence the error. In addition, your eigh implementation itself appears to be incorrect. I ran your eigh:

def test_eigh():
    X = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    X_centered = X - jnp.mean(X, axis=0)
    cov_matrix = jnp.cov(X_centered, rowvar=False)
    eigenvalues, eigenvectors = jacobi_eigh(cov_matrix, 1e-8, 1000)
    # print(eigenvalues)
    # print(eigenvectors)
    print(cov_matrix @ eigenvectors)
    print(eigenvalues * eigenvectors)
    print()
    eigenvalues, eigenvectors = eigh(cov_matrix)
    # print(eigenvalues)
    # print(eigenvectors)
    print(cov_matrix @ eigenvectors)
    print(eigenvalues * eigenvectors)

Output:

[[2.2865321e+02 5.4445304e-06 2.5582359e+02]
 [2.2865321e+02 5.4445304e-06 2.5582359e+02]
 [2.2865321e+02 5.4445304e-06 2.5582359e+02]]
[[1.4377324e+02 1.2175019e-21 2.0135764e+02]
 [2.4747901e-06 1.2621775e-29 3.4660027e-06]
 [1.6085759e+02 9.3170446e-22 2.2528442e+02]]

[[ 4.1723251e-07 -2.9802322e-07 1.5588457e+01]
 [ 4.1723251e-07 -2.9802322e-07 1.5588457e+01]
 [ 4.1723251e-07 -2.9802322e-07 1.5588457e+01]]
[[-4.0569081e-08 -1.6165853e-06 1.5588456e+01]
 [-1.1870292e-07 1.1621918e-06 1.5588454e+01]
 [ 1.5927199e-07 4.5439356e-07 1.5588456e+01]] |
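For reference, a Jacobi rotation does not strictly need arctan, cos, or sin: the cosine and sine of the rotation angle can be obtained from a division and a square root. The sketch below is plain Python rather than JAX so the primitive operations stay visible. The name `jacobi_eigh_no_trig`, the fixed sweep count, and the skip threshold are illustrative assumptions, not part of the thread's code, and the thread does not establish whether the remaining operations (or the data-dependent skip) run under SPU.

```python
import math

def jacobi_eigh_no_trig(A, sweeps=10):
    """Cyclic Jacobi eigen-decomposition of a symmetric matrix (list of rows).

    Returns (eigenvalues, V); column j of V is the eigenvector for eigenvalues[j].
    The rotation's cosine/sine come from a square root, not from arctan/cos/sin.
    """
    n = len(A)
    A = [row[:] for row in A]  # work on a copy
    V = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for _ in range(sweeps):
        for p in range(n - 1):
            for q in range(p + 1, n):
                if abs(A[p][q]) < 1e-15:  # already (numerically) zero
                    continue
                # Solve t^2 + 2*tau*t - 1 = 0 for the smaller root t = tan(angle)
                tau = (A[q][q] - A[p][p]) / (2.0 * A[p][q])
                t = math.copysign(1.0, tau) / (abs(tau) + math.sqrt(1.0 + tau * tau))
                c = 1.0 / math.sqrt(1.0 + t * t)
                s = t * c
                # A <- J^T A J with J[p][p] = J[q][q] = c, J[p][q] = s, J[q][p] = -s
                for k in range(n):  # columns p and q of A J
                    akp, akq = A[k][p], A[k][q]
                    A[k][p] = c * akp - s * akq
                    A[k][q] = s * akp + c * akq
                for k in range(n):  # rows p and q of J^T (A J)
                    apk, aqk = A[p][k], A[q][k]
                    A[p][k] = c * apk - s * aqk
                    A[q][k] = s * apk + c * aqk
                for k in range(n):  # accumulate eigenvectors: V <- V J
                    vkp, vkq = V[k][p], V[k][q]
                    V[k][p] = c * vkp - s * vkq
                    V[k][q] = s * vkp + c * vkq
    return [A[i][i] for i in range(n)], V
```

The same check used in the test above applies: for a correct decomposition, each column of `A @ V` should match the corresponding column of `V` scaled by its eigenvalue.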
Jacobi needs to compute rotation transforms, which SPU currently cannot support. Here are two approaches you could try:
For reference only~ |
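The two approaches referred to above are not preserved in this copy of the thread. As a separate illustration, and not necessarily one of them, power iteration with deflation is a common way to get the leading principal components using only matrix-vector products, a square root, and division. The sketch below is plain Python; `matvec`, `top_eigenpairs`, `pca_fit_transform`, the start vector, and the iteration count are all illustrative assumptions. It also assumes the requested eigenvalues are distinct and nonzero, which does not hold for the rank-1 example matrix used earlier in the thread.

```python
import math

def matvec(M, v):
    """Multiply matrix M (list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def top_eigenpairs(C, k, iters=300):
    """Top-k (eigenvalue, unit eigenvector) pairs of a symmetric PSD matrix."""
    n = len(C)
    C = [row[:] for row in C]  # work on a copy; deflation modifies it
    pairs = []
    for _ in range(k):
        # Deterministic start vector (assumption: not orthogonal to the target)
        v = [float(i + 1) for i in range(n)]
        for _ in range(iters):
            w = matvec(C, v)
            norm = math.sqrt(sum(x * x for x in w))
            v = [x / norm for x in w]
        lam = sum(x * y for x, y in zip(v, matvec(C, v)))  # Rayleigh quotient
        pairs.append((lam, v))
        # Deflation: remove the component just found so the next one dominates
        for i in range(n):
            for j in range(n):
                C[i][j] -= lam * v[i] * v[j]
    return pairs

def pca_fit_transform(X, k):
    """Center X, build the sample covariance, and project onto the top-k components."""
    n, d = len(X), len(X[0])
    mean = [sum(row[j] for row in X) / n for j in range(d)]
    Xc = [[row[j] - mean[j] for j in range(d)] for row in X]
    cov = [[sum(r[a] * r[b] for r in Xc) / (n - 1) for b in range(d)] for a in range(d)]
    pairs = top_eigenpairs(cov, k)
    scores = [[sum(r[j] * v[j] for j in range(d)) for _, v in pairs] for r in Xc]
    return scores, [lam for lam, _ in pairs]
```

The loop bounds depend only on the shapes and on `iters`, so the control flow is fixed ahead of time. The price is a fixed number of iterations, and convergence slows when neighbouring eigenvalues are close.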
Thanks for your suggestions. Based on them, I re-implemented the class. The code is as follows:

import jax
import jax.numpy as jnp
from jax import jit, random

class PCA:
    def __init__(self, n_components):
        self.n_components = n_components
        self.mean = None
        self.components = None
        self.variances = None

    def fit(self, X):
        # Center the data and build the covariance matrix
        self.mean = jnp.mean(X, axis=0)
        X = X - self.mean
        cov_matrix = jnp.cov(X, rowvar=False)
        # Cholesky factor followed by a QR decomposition
        L = jnp.linalg.cholesky(cov_matrix)
        q, r = jnp.linalg.qr(L)
        eigvals = jnp.diag(r)
        # Keep the n_components largest entries
        idx = jnp.argsort(eigvals)[::-1][:self.n_components]
        self.components = q[:, idx]
        self.variances = eigvals[idx]

    def transform(self, X):
        X = X - self.mean
        return jnp.dot(X, self.components)

def fit_and_transform(X, n_components):
    pca = PCA(n_components)
    pca.fit(X)
    return pca.transform(X)

X = random.randint(random.PRNGKey(0), (10, 3), 0, 10)
fit_and_transform_jit = jit(fit_and_transform, static_argnums=1)
X_transformed = fit_and_transform_jit(X, 2)
print(X_transformed)

Does this meet the requirements? |
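Sorry, this looks like an spsim bug. In fact, neither cholesky nor qr can actually be executed: when execution reaches those two functions, the Python process seems to be shut down directly. We expect to fix this bug later. (So I am curious: do you really get the PCA-transformed matrix when you run it under spsim?) This means you also need to implement the Cholesky or QR decomposition by hand. Finally, when you submit the PR, please use comments to mark the earlier implementations that could not run because SPU lacks the required operators, so that we can revisit them once those operators are added. Thanks! |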
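As an illustration of what a hand-written Cholesky decomposition involves, here is a minimal sketch of the Cholesky-Banachiewicz recurrence in plain Python. The function name `cholesky_lower` is an illustrative assumption, and the input is assumed to be symmetric positive definite. It is not the implementation that was eventually submitted.

```python
import math

def cholesky_lower(A):
    """Return lower-triangular L with L @ L^T == A for a symmetric positive definite A."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            # Dot product of the already-computed parts of rows i and j
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][i] = math.sqrt(A[i][i] - s)      # diagonal entry
            else:
                L[i][j] = (A[i][j] - s) / L[j][j]     # below-diagonal entry
    return L

A = [[4.0, 12.0, -16.0], [12.0, 37.0, -43.0], [-16.0, -43.0, 98.0]]
print(cholesky_lower(A))  # [[2.0, 0.0, 0.0], [6.0, 1.0, 0.0], [-8.0, 5.0, 3.0]]
```

Only addition, multiplication, division, and square root appear, and the loop bounds depend on the matrix size alone.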
Thanks. When I run it locally it hangs indefinitely, so I need to check why. Also, could you run the code below and see whether it raises a division-by-zero error?

def test_run_eigh():
    X = jnp.array(np.random.rand(6, 3))
    cov_matrix = jnp.cov(X, rowvar=False)
    sim_aby = spsim.Simulator.simple(
        3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
    )
    print(cov_matrix)
    print(jnp.linalg.det(cov_matrix))
    print(jnp.linalg.cholesky(cov_matrix))
    print(spsim.sim_jax(sim_aby, jnp.linalg.cholesky)(cov_matrix))
    print(1 / 0) |
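I am using an M1 Mac |
I tested this locally and it appears to be a jax version issue: 0.4.8 works, and the error only occurs with 0.4.13, which is what I have locally... @anakinxc will need to take a look at this. Once you have implemented it, I suggest you:
Thanks! |
Hello, I would like to ask: are the Cholesky and QR decompositions already supported in SPU, or is this a bug in the spsim simulation? |
Hi, please upgrade spu to the latest version, or downgrade jax and jaxlib to 0.4.12 |
So it looks like the latest jax version is not fully compatible yet. Please keep developing with your current version for now |
OK, I tested it, and the results are close to those of sklearn's PCA. If I want to submit a PR, which files do I need to modify? |
Thanks for the quick response. You can start by looking at this kmeans PR. BTW: in the file with the spsim simulation tests, please also include a comparison against the plaintext sklearn results (these can go in separate unit tests). Thanks! |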
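One detail worth handling in such a comparison: principal components are only defined up to sign, so two correct implementations can return columns that differ by a factor of -1. The sketch below shows a sign-insensitive comparison inside a `unittest` case. The helper name `max_abs_diff_up_to_sign`, the tolerance, and the hard-coded matrices are stand-ins chosen for illustration. In a real test, `expected` would come from the plaintext reference and `actual` from the simulated run.

```python
import unittest

def max_abs_diff_up_to_sign(actual, expected):
    """Largest entry-wise gap between two component matrices (lists of rows),
    after flipping each column of `actual` to whichever sign matches better."""
    n_rows, n_cols = len(expected), len(expected[0])
    worst = 0.0
    for j in range(n_cols):
        same = max(abs(actual[i][j] - expected[i][j]) for i in range(n_rows))
        flipped = max(abs(-actual[i][j] - expected[i][j]) for i in range(n_rows))
        worst = max(worst, min(same, flipped))
    return worst

class PCAComparisonTest(unittest.TestCase):
    # Stand-in values only; replace with real reference and simulated outputs.
    expected = [[0.6, -0.8], [0.8, 0.6]]
    actual = [[-0.6001, -0.7999], [-0.8002, 0.6001]]

    def test_components_match_up_to_sign(self):
        # Loose tolerance (assumption): simulated results are approximate
        self.assertLess(max_abs_diff_up_to_sign(self.actual, self.expected), 1e-3)
```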
Already solved this issue @Candicepan . |
Task introduction
Detailed requirements
Skill requirements
Instructions