Skip to content

Commit

Permalink
problem_jacobi
Browse files Browse the repository at this point in the history
  • Loading branch information
cpz2024 committed Nov 12, 2024
1 parent a3e1e09 commit c18d773
Show file tree
Hide file tree
Showing 17 changed files with 1,148 additions and 9 deletions.
10 changes: 9 additions & 1 deletion libspu/device/pphlo/pphlo_intrinsic_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,15 @@ std::vector<Value> intrinsic_dispatcher(SPUContext* ctx,

return {inputs[0]};
}
// DO-NOT-EDIT: Add_DISPATCH_CODE

if (name == "permute") {
SPU_ENFORCE(inputs.size() == 2);
absl::Span<const spu::Value> input_span(inputs.data(), inputs.size());
return kernel::hal::internal::apply_inv_perm(ctx, input_span, inputs[1]);

}



// Default: Identity function
if (name == "example") {
Expand Down
5 changes: 5 additions & 0 deletions libspu/kernel/hal/permute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ spu::Value _apply_inv_perm_ss(SPUContext *ctx, const spu::Value &x,
return std::move(ret[0]);
}


// Ref: https://eprint.iacr.org/2019/695.pdf
// Algorithm 5: Composition of two share-vector permutations
//
Expand Down Expand Up @@ -949,6 +950,8 @@ spu::Value gen_inv_perm(SPUContext *ctx, absl::Span<spu::Value const> inputs,
return inv_perm;
}



std::vector<spu::Value> apply_inv_perm(SPUContext *ctx,
absl::Span<spu::Value const> inputs,
const spu::Value &perm) {
Expand Down Expand Up @@ -978,6 +981,8 @@ std::vector<spu::Value> apply_inv_perm(SPUContext *ctx,
}
}



// Secure Radix Sort
// Ref:
// https://eprint.iacr.org/2019/695.pdf
Expand Down
10 changes: 9 additions & 1 deletion libspu/kernel/hal/permute.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,12 @@ std::vector<Value> topk_1d(SPUContext *ctx, const spu::Value &input,
const SimpleCompFn &scalar_cmp,
const TopKConfig &config);

} // namespace spu::kernel::hal
} // namespace spu::kernel::hal


namespace spu::kernel::hal::internal{
std::vector<spu::Value> apply_inv_perm(SPUContext *ctx,
absl::Span<spu::Value const> inputs,
const spu::Value &perm);

}
31 changes: 25 additions & 6 deletions sml/decomposition/BUILD.bazel → sml/manifold/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Ant Group Co., Ltd.
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,12 +17,31 @@ load("@rules_python//python:defs.bzl", "py_library")
package(default_visibility = ["//visibility:public"])

py_library(
name = "pca",
srcs = ["pca.py"],
deps = ["//sml/utils:extmath"],
name = "jacobi",
srcs = ["jacobi.py"],
)

py_library(
name = "nmf",
srcs = ["nmf.py"],
name = "dijkstra",
srcs = ["dijkstra.py"],
)

py_library(
name = "MDS",
srcs = ["MDS.py"],
)

py_library(
name = "kneighbors",
srcs = ["kneighbors.py"],
)

py_library(
name = "floyd",
srcs = ["floyd.py"],
)

py_library(
name = "SE",
srcs = ["SE.py"],
)
46 changes: 46 additions & 0 deletions sml/manifold/MDS.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp

from sml.manifold.jacobi import Jacobi

def mds(D, num_samples, n_components):
D_2 = jnp.square(D)
B = jnp.zeros((num_samples, num_samples))
B = -0.5 * D_2
# 按行求和(使用英语的注释)
dist_2_i = jnp.sum(B, axis=1)
dist_2_i = dist_2_i / num_samples
# 按列求和
dist_2_j = dist_2_i.T
# 全部求和
dist_2 = jnp.sum(dist_2_i)
dist_2 = dist_2 / (num_samples)
for i in range(num_samples):
for j in range(num_samples):
B = B.at[i, j].set(B[i][j] - dist_2_i[i] - dist_2_j[j] + dist_2)

values, vectors = Jacobi(B, num_samples)
values = jnp.diag(values)
values = jnp.array(values)
values = jnp.expand_dims(values, axis=1).repeat(vectors.shape[1], axis=1)
values,vectors=jax.lax.sort_key_val(values.T,vectors.T)
vectors=vectors[:,num_samples - n_components:num_samples]
values=values[0,num_samples - n_components:num_samples]
values = jnp.sqrt(jnp.diag(values))

ans = jnp.dot(vectors, values)

return B, ans, values, vectors
46 changes: 46 additions & 0 deletions sml/manifold/SE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import numpy as np
import spu.intrinsic as si
from sml.manifold.jacobi import Jacobi

def se(X, num_samples, D, n_components):
X, Q = Jacobi(X, num_samples)
X = jnp.diag(X)
X = jnp.array(X)
# perm = jnp.argsort(X)
X2 = jnp.expand_dims(X, axis=1).repeat(Q.shape[1], axis=1)
X3,ans=jax.lax.sort_key_val(X2.T,Q.T)
ans=ans[:,1:n_components + 1]
D = jnp.diag(D)
ans = ans.T * jnp.reciprocal(jnp.sqrt(D))
return ans.T


def normalization(
adjacency, # 邻接矩阵
norm_laplacian=True, # 如果为 True,使用对称归一化拉普拉斯矩阵;如果为 False,使用非归一化的拉普拉斯矩阵。
):
D = jnp.sum(adjacency, axis=1)
D = jnp.diag(D)

L = D - adjacency
D2 = jnp.diag(jnp.reciprocal(jnp.sqrt(jnp.diag(D))))
if norm_laplacian == True:
# 归一化
L = jnp.dot(D2, L)
L = jnp.dot(L, D2)
return D, L
110 changes: 110 additions & 0 deletions sml/manifold/dijkstra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp


def set_value(x, index, value, n):
# 将数组x的index索引处的值修改为value,其中index是秘密共享的
perm = jnp.zeros(n, dtype=jnp.int16)
perm_2 = jnp.zeros(n, dtype=jnp.int16)
for i in range(n):
perm = perm.at[i].set(i)
perm_2 = perm_2.at[i].set(index)
flag = jnp.equal(perm, perm_2)
set_x = jnp.select([flag], [value], x)

return set_x


def get_value_1(x, index, n):
# 获得x[index]索引处的值,其中index是秘密共享的
perm = jnp.zeros(n, dtype=jnp.int16)
perm_2 = jnp.zeros(n, dtype=jnp.int16)
for i in range(n):
perm = perm.at[i].set(i)
perm_2 = perm_2.at[i].set(index)
flag = jnp.equal(perm, perm_2)
return jnp.sum(flag * x)


def get_value_2(x, index_1, index_2, n):
# 获得x[index_1][index_2]索引处的值,其中index_2是明文,index_1是秘密共享的
# 初始化行索引
perm_1 = jnp.zeros((n, n), dtype=jnp.int16)
perm_2_row = jnp.zeros((n, n), dtype=jnp.int16)

for i in range(n):
for j in range(n):
perm_1 = perm_1.at[i, j].set(i)
perm_2_row = perm_2_row.at[i, j].set(index_1)

# 行匹配
flag_row = jnp.equal(perm_1, perm_2_row)

# 使用明文 index_2 直接提取列的值
flag = flag_row[:, index_2]

# 返回匹配索引处的值
return jnp.sum(flag * x[:, index_2])


def mpc_dijkstra(adj_matrix, num_samples, start, dist_inf):
# adj_matrix:要求最短路径的邻接矩阵
# num_samples:邻接矩阵的大小
# start:要计算所有点到点start的最短路径
# dis_inf:所有点到点start的初始最短路径,设置为inf

# 用inf值初始化

sinf = dist_inf[0]
distances = dist_inf

# 使用 Dijkstra 算法计算从起始点到其他点的最短路径
distances = distances.at[start].set(0)
# visited = [False] * num_samples
visited = jnp.zeros(num_samples, dtype=bool) # 初始化为 False 的数组
visited = jnp.array(visited)

for i in range(num_samples):
# 找到当前未访问的最近节点

min_distance = sinf
min_index = -1
for v in range(num_samples):
flag = (visited[v] == 0) * (distances[v] < min_distance)
min_distance = min_distance + flag * (distances[v] - min_distance)
min_index = min_index + flag * (v - min_index)
# min_distance = jax.lax.cond(flag, lambda _: distances[v], lambda _: min_distance)
# min_index = jax.lax.cond(flag, lambda _: v, lambda _: min_index)

# 标记为已访问
# jax.lax.dynamic_update_slice(visited, 1, (min_index,))
# visited[min_index] = True
visited = set_value(visited, min_index, True, num_samples)

# 更新邻接节点的距离
temp_dis = get_value_1(distances, min_index, num_samples)

for v in range(num_samples):
temp_adj = get_value_2(adj_matrix, min_index, v, num_samples)
dist_new = temp_dis + temp_adj
distances = distances.at[v].set(
distances[v]
+ (temp_adj != 0)
* (visited[v] == 0)
* (dist_new < distances[v])
* (dist_new - distances[v])
)
return distances
61 changes: 61 additions & 0 deletions sml/manifold/emulations/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_binary")

package(default_visibility = ["//visibility:public"])

py_binary(
name = "Isomap_emul",
srcs = ["Isomap_emul.py"],
deps = [
"//sml/utils:emulation",
"//sml/manifold:jacobi",
"//sml/manifold:dijkstra",
"//sml/manifold:MDS",
"//sml/manifold:kneighbors",
"//sml/manifold:floyd",
"//sml/manifold:SE",
],
)

py_binary(
name = "se_emul",
srcs = ["se_emul.py"],
deps = [
"//sml/utils:emulation",
"//sml/manifold:jacobi",
"//sml/manifold:dijkstra",
"//sml/manifold:MDS",
"//sml/manifold:kneighbors",
"//sml/manifold:floyd",
"//sml/manifold:SE",
],
)

py_binary(
name = "test_emul",
srcs = ["test_emul.py"],
deps = [
"//sml/utils:emulation",
"//sml/manifold:jacobi",
"//sml/manifold:dijkstra",
"//sml/manifold:MDS",
"//sml/manifold:kneighbors",
"//sml/manifold:floyd",
"//sml/manifold:SE",
],
)


Loading

0 comments on commit c18d773

Please sign in to comment.