diff --git a/libspu/device/pphlo/pphlo_intrinsic_executor.cc b/libspu/device/pphlo/pphlo_intrinsic_executor.cc index 7bfafc3b..75c10a93 100644 --- a/libspu/device/pphlo/pphlo_intrinsic_executor.cc +++ b/libspu/device/pphlo/pphlo_intrinsic_executor.cc @@ -67,7 +67,15 @@ std::vector intrinsic_dispatcher(SPUContext* ctx, return {inputs[0]}; } - // DO-NOT-EDIT: Add_DISPATCH_CODE + + if (name == "permute") { + SPU_ENFORCE(inputs.size() == 2); + absl::Span input_span(inputs.data(), inputs.size()); + return kernel::hal::internal::apply_inv_perm(ctx, input_span, inputs[1]); + + } + + // Default: Identity function if (name == "example") { diff --git a/libspu/kernel/hal/permute.cc b/libspu/kernel/hal/permute.cc index b7c34ff2..d0f0a3aa 100644 --- a/libspu/kernel/hal/permute.cc +++ b/libspu/kernel/hal/permute.cc @@ -678,6 +678,7 @@ spu::Value _apply_inv_perm_ss(SPUContext *ctx, const spu::Value &x, return std::move(ret[0]); } + // Ref: https://eprint.iacr.org/2019/695.pdf // Algorithm 5: Composition of two share-vector permutations // @@ -949,6 +950,8 @@ spu::Value gen_inv_perm(SPUContext *ctx, absl::Span inputs, return inv_perm; } + + std::vector apply_inv_perm(SPUContext *ctx, absl::Span inputs, const spu::Value &perm) { @@ -978,6 +981,8 @@ std::vector apply_inv_perm(SPUContext *ctx, } } + + // Secure Radix Sort // Ref: // https://eprint.iacr.org/2019/695.pdf diff --git a/libspu/kernel/hal/permute.h b/libspu/kernel/hal/permute.h index 9025bbec..d34de520 100644 --- a/libspu/kernel/hal/permute.h +++ b/libspu/kernel/hal/permute.h @@ -79,4 +79,12 @@ std::vector topk_1d(SPUContext *ctx, const spu::Value &input, const SimpleCompFn &scalar_cmp, const TopKConfig &config); -} // namespace spu::kernel::hal \ No newline at end of file +} // namespace spu::kernel::hal + + +namespace spu::kernel::hal::internal{ +std::vector apply_inv_perm(SPUContext *ctx, + absl::Span inputs, + const spu::Value &perm); + +} \ No newline at end of file diff --git a/sml/decomposition/BUILD.bazel 
b/sml/manifold/BUILD.bazel similarity index 64% rename from sml/decomposition/BUILD.bazel rename to sml/manifold/BUILD.bazel index c80a6751..b3b1b42c 100644 --- a/sml/decomposition/BUILD.bazel +++ b/sml/manifold/BUILD.bazel @@ -1,4 +1,4 @@ -# Copyright 2023 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,12 +17,31 @@ load("@rules_python//python:defs.bzl", "py_library") package(default_visibility = ["//visibility:public"]) py_library( - name = "pca", - srcs = ["pca.py"], - deps = ["//sml/utils:extmath"], + name = "jacobi", + srcs = ["jacobi.py"], ) py_library( - name = "nmf", - srcs = ["nmf.py"], + name = "dijkstra", + srcs = ["dijkstra.py"], +) + +py_library( + name = "MDS", + srcs = ["MDS.py"], +) + +py_library( + name = "kneighbors", + srcs = ["kneighbors.py"], +) + +py_library( + name = "floyd", + srcs = ["floyd.py"], +) + +py_library( + name = "SE", + srcs = ["SE.py"], ) diff --git a/sml/manifold/MDS.py b/sml/manifold/MDS.py new file mode 100644 index 00000000..ba4352f2 --- /dev/null +++ b/sml/manifold/MDS.py @@ -0,0 +1,46 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import jax
import jax.numpy as jnp

from sml.manifold.jacobi import Jacobi


def mds(D, num_samples, n_components):
    """Classical multidimensional scaling of a pairwise-distance matrix.

    Args:
        D: (num_samples, num_samples) matrix of pairwise distances.
        num_samples: number of rows/columns of ``D``.
        n_components: number of embedding dimensions to keep.

    Returns:
        ``(B, ans, values, vectors)`` where ``B`` is the double-centred
        matrix, ``vectors`` holds the ``n_components`` columns selected after
        sorting by eigenvalue (ascending, so the largest ones are kept),
        ``values`` is the diagonal matrix of the square roots of the matching
        eigenvalues, and ``ans = vectors @ values`` is the embedding.
    """
    # B = -1/2 * D^2, then subtract row means and column means and add the
    # grand mean.
    B = -0.5 * jnp.square(D)
    # Row means of B.
    row_mean = jnp.sum(B, axis=1) / num_samples
    # The row means are reused as the column means, as in the original code
    # (transposing a 1-D array is a no-op); this is exact for symmetric D.
    # Grand mean.
    grand_mean = jnp.sum(row_mean) / num_samples
    # One broadcast expression instead of num_samples**2 scalar updates; the
    # per-element operation order is unchanged.
    B = B - row_mean[:, None] - row_mean[None, :] + grand_mean

    # Jacobi returns the rotated (near-diagonal) matrix and the accumulated
    # rotations; the eigenvalue estimates are read from the diagonal.
    rotated, rotations = Jacobi(B, num_samples)
    eig_values = jnp.diag(rotated)
    # Replicate the eigenvalues so every row of ``rotations.T`` is sorted with
    # the same keys, i.e. its columns are reordered by ascending eigenvalue.
    keys = jnp.expand_dims(eig_values, axis=1).repeat(rotations.shape[1], axis=1)
    values, vectors = jax.lax.sort_key_val(keys.T, rotations.T)
    vectors = vectors[:, num_samples - n_components:num_samples]
    values = values[0, num_samples - n_components:num_samples]
    # NOTE(review): a negative eigenvalue estimate makes this sqrt invalid.
    values = jnp.sqrt(jnp.diag(values))

    ans = jnp.dot(vectors, values)

    return B, ans, values, vectors


# ---------------------------------------------------------------------------
# Patch boundary - next file: sml/manifold/SE.py
# (new file mode 100644, index 00000000..0076f843, hunk @@ -0,0 +1,46 @@)
# ---------------------------------------------------------------------------
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import numpy as np
import spu.intrinsic as si
from sml.manifold.jacobi import Jacobi


def se(X, num_samples, D, n_components):
    """Spectral embedding from a graph Laplacian.

    Args:
        X: (num_samples, num_samples) Laplacian, as returned by
            ``normalization``.
        num_samples: number of rows/columns of ``X``.
        D: diagonal degree matrix, as returned by ``normalization``.
        n_components: number of embedding dimensions.

    Returns:
        (num_samples, n_components) embedding: the columns at positions
        ``1 .. n_components`` after sorting by ascending eigenvalue (position
        0 is skipped), each row scaled by ``1 / sqrt(degree)``.
    """
    rotated, Q = Jacobi(X, num_samples)
    eig_values = jnp.diag(rotated)
    # Replicate the eigenvalues so every row of Q.T is sorted with the same
    # keys, i.e. the columns of Q.T are reordered by ascending eigenvalue.
    keys = jnp.expand_dims(eig_values, axis=1).repeat(Q.shape[1], axis=1)
    _, ans = jax.lax.sort_key_val(keys.T, Q.T)
    ans = ans[:, 1:n_components + 1]
    degrees = jnp.diag(D)
    ans = ans.T * jnp.reciprocal(jnp.sqrt(degrees))
    return ans.T


def normalization(
    adjacency,  # adjacency (affinity) matrix
    norm_laplacian=True,  # True: symmetric normalised Laplacian; False: unnormalised
):
    """Build the degree matrix and the graph Laplacian of ``adjacency``.

    Returns:
        ``(D, L)``: ``D`` is the diagonal degree matrix (row sums) and ``L``
        is ``D - adjacency``, or ``D^-1/2 (D - adjacency) D^-1/2`` when
        ``norm_laplacian`` is truthy.
    """
    degrees = jnp.sum(adjacency, axis=1)
    D = jnp.diag(degrees)

    L = D - adjacency
    if norm_laplacian:
        # Multiplying by a diagonal matrix on the left/right is a row/column
        # scaling, so scale directly instead of two dense matrix products.
        # NOTE(review): a zero-degree node gives an infinite scale factor.
        inv_sqrt_deg = jnp.reciprocal(jnp.sqrt(degrees))
        L = L * inv_sqrt_deg[:, None] * inv_sqrt_deg[None, :]
    return D, L


# ---------------------------------------------------------------------------
# Patch boundary - next file: sml/manifold/dijkstra.py
# (new file mode 100644, index 00000000..8de43c14, hunk @@ -0,0 +1,110 @@)
# ---------------------------------------------------------------------------
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp


def set_value(x, index, value, n):
    """Return a copy of the length-``n`` array ``x`` with ``x[index] = value``.

    ``index`` may be a secret-shared (traced) scalar: the position is found by
    comparing it with every index instead of by dynamic indexing. An index
    outside ``[0, n)`` (e.g. -1) leaves ``x`` unchanged.
    """
    mask = jnp.arange(n) == index
    return jnp.where(mask, value, x)


def get_value_1(x, index, n):
    """Return ``x[index]`` for a possibly secret-shared scalar ``index``.

    Returns 0 when ``index`` matches no position. Selection uses ``where``
    rather than multiplying by the mask, so infinite entries of ``x`` do not
    turn the sum into NaN.
    """
    mask = jnp.arange(n) == index
    return jnp.sum(jnp.where(mask, x, 0))


def get_value_2(x, index_1, index_2, n):
    """Return ``x[index_1][index_2]``.

    ``index_2`` is a plaintext index used directly; ``index_1`` may be secret
    shared and is matched against every row index.
    """
    row_mask = jnp.arange(n) == index_1
    return jnp.sum(jnp.where(row_mask, x[:, index_2], 0))


def mpc_dijkstra(adj_matrix, num_samples, start, dist_inf):
    """Single-source shortest paths with data-independent control flow.

    Args:
        adj_matrix: (num_samples, num_samples) adjacency matrix; an entry of
            0 means "no edge".
        num_samples: size of the adjacency matrix.
        start: source node.
        dist_inf: length-``num_samples`` array filled with the "infinite"
            sentinel; used as the initial distances.

    Returns:
        Array of shortest distances from ``start`` to every node.
    """
    sinf = dist_inf[0]
    distances = dist_inf.at[start].set(0)
    visited = jnp.zeros(num_samples, dtype=bool)

    for _ in range(num_samples):
        # Find the closest node that has not been visited yet. min_index
        # stays at -1 when every remaining node is unreachable.
        min_distance = sinf
        min_index = -1
        for v in range(num_samples):
            is_closer = jnp.logical_and(
                visited[v] == 0, distances[v] < min_distance
            )
            # where() selects the value; it never evaluates inf - inf the way
            # the arithmetic form a + flag * (b - a) does.
            min_distance = jnp.where(is_closer, distances[v], min_distance)
            min_index = jnp.where(is_closer, v, min_index)

        # Mark it as visited.
        visited = set_value(visited, min_index, True, num_samples)

        # Relax all edges leaving the selected node at once. Each target is
        # updated independently, so this matches the per-node loop.
        temp_dis = get_value_1(distances, min_index, num_samples)
        selected_row = jnp.arange(num_samples) == min_index
        adj_row = jnp.sum(
            jnp.where(selected_row[:, None], adj_matrix, 0), axis=0
        )
        dist_new = temp_dis + adj_row
        improves = (adj_row != 0) & (visited == 0) & (dist_new < distances)
        distances = jnp.where(improves, dist_new, distances)
    return distances


# ---------------------------------------------------------------------------
# Patch boundary - next file: sml/manifold/emulations/BUILD.bazel
# (new file mode 100644, index 00000000..7d768abf, hunk @@ -0,0 +1,61 @@)
# ---------------------------------------------------------------------------
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_binary")

package(default_visibility = ["//visibility:public"])

# All emulation binaries in this package are built against the same set of
# manifold libraries, so the list is declared once and shared.
_EMULATION_DEPS = [
    "//sml/utils:emulation",
    "//sml/manifold:jacobi",
    "//sml/manifold:dijkstra",
    "//sml/manifold:MDS",
    "//sml/manifold:kneighbors",
    "//sml/manifold:floyd",
    "//sml/manifold:SE",
]

py_binary(
    name = "Isomap_emul",
    srcs = ["Isomap_emul.py"],
    deps = _EMULATION_DEPS,
)

py_binary(
    name = "se_emul",
    srcs = ["se_emul.py"],
    deps = _EMULATION_DEPS,
)

py_binary(
    name = "test_emul",
    srcs = ["test_emul.py"],
    deps = _EMULATION_DEPS,
)

# ---------------------------------------------------------------------------
# Patch boundary - next file: sml/manifold/emulations/Isomap_emul.py
# (new file mode 100644, index 00000000..2f1819ef, hunk @@ -0,0 +1,162 @@)
# ---------------------------------------------------------------------------
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.manifold import Isomap
from sklearn.neighbors import kneighbors_graph

import sml.utils.emulation as emulation
from sml.manifold.dijkstra import mpc_dijkstra
from sml.manifold.kneighbors import mpc_kneighbors_graph
from sml.manifold.MDS import mds
from sml.manifold.floyd import floyd_opt


def emul_cpz(mode: emulation.Mode = emulation.Mode.MULTIPROCESS):
    """Run Isomap under the SPU emulator and print it next to sklearn's result.

    Args:
        mode: emulator mode; defaults to multi-process emulation.
    """
    # bandwidth and latency only work for docker mode
    # Built outside the try block so a constructor failure is reported as
    # itself rather than as an unbound `emulator` in the finally clause.
    emulator = emulation.Emulator(
        emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
    )
    try:
        emulator.up()

        def mpc_isomap_dijkstra(
            sX,
            mpc_dist_inf,
            mpc_shortest_paths,
            num_samples,
            num_features,
            k,
            num_components,
        ):
            # Alternative pipeline (Dijkstra from every source); it is defined
            # here but not run below.
            Knn = mpc_kneighbors_graph(sX, num_samples, num_features, k)

            def compute_distances_for_sample(i, Knn, num_samples, mpc_dist_inf):
                return mpc_dijkstra(Knn, num_samples, i, mpc_dist_inf)

            # Use vmap to batch the shortest-path computation over sources.
            compute_distances = jax.vmap(
                compute_distances_for_sample, in_axes=(0, None, None, None)
            )

            indices = jnp.arange(num_samples)  # source indices
            mpc_shortest_paths = compute_distances(
                indices, Knn, num_samples, mpc_dist_inf
            )
            B, ans, values, vectors = mds(
                mpc_shortest_paths, num_samples, num_components
            )
            return Knn, mpc_shortest_paths, B, ans, values, vectors

        def mpc_isomap_floyd(
            sX,
            mpc_dist_inf,
            mpc_shortest_paths,
            num_samples,
            num_features,
            k,
            num_components,
        ):
            # mpc_dist_inf and the incoming mpc_shortest_paths are accepted
            # only to keep the same signature as the Dijkstra variant.
            Knn = mpc_kneighbors_graph(sX, num_samples, num_features, k)
            mpc_shortest_paths = floyd_opt(Knn)
            B, ans, values, vectors = mds(
                mpc_shortest_paths, num_samples, num_components
            )
            return Knn, mpc_shortest_paths, B, ans, values, vectors

        # Number of samples, feature dimension, neighbours and output dims.
        num_samples = 6
        num_features = 3
        k = 3
        num_components = 2
        X = jnp.array(
            [
                [0.122, 0.114, 0.64],
                [0.136, 0.204, 0.25],
                [0.11, 0.145, 0.24],
                [0.16, 0.81, 0.91],
                [0.209, 0.122, 0.76],
                [0.148, 0.119, 0.15],
            ]
        )
        dist_inf = jnp.full(num_samples, np.inf)
        shortest_paths = jnp.zeros((num_samples, num_samples))

        sX, mpc_dist_inf, mpc_shortest_paths = emulator.seal(
            X, dist_inf, shortest_paths
        )

        Knn, mpc_shortest_paths, B, ans, values, vectors = emulator.run(
            mpc_isomap_floyd, static_argnums=(3, 4, 5, 6)
        )(
            sX,
            mpc_dist_inf,
            mpc_shortest_paths,
            num_samples,
            num_features,
            k,
            num_components,
        )
        print('ans: \n', ans)

        # sklearn reference
        affinity_matrix = kneighbors_graph(
            X, n_neighbors=k, mode="distance", include_self=False
        )
        # Make the matrix symmetric.
        affinity_matrix = 0.5 * (affinity_matrix + affinity_matrix.T)

        affinity_matrix = affinity_matrix.toarray()
        # Replace missing off-diagonal edges with a large finite distance.
        # The matrix is symmetric, so one mask covers both triangles.
        missing_edge = (affinity_matrix == 0) & ~np.eye(num_samples, dtype=bool)
        affinity_matrix[missing_edge] = 10000

        embedding = Isomap(n_components=num_components, metric='precomputed')
        X_transformed = embedding.fit_transform(affinity_matrix)
        print('X_transformed: \n', X_transformed)

    finally:
        emulator.down()


if __name__ == "__main__":
    emul_cpz(emulation.Mode.MULTIPROCESS)

# Patch boundary - header of the next file continues on the following line:
# diff --git a/sml/manifold/emulations/se_emul.py
# b/sml/manifold/emulations/se_emul.py
# (new file mode 100644, index 00000000..e48d3179, hunk @@ -0,0 +1,100 @@)
# ---------------------------------------------------------------------------
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.manifold import spectral_embedding
from sklearn.neighbors import kneighbors_graph

import sml.utils.emulation as emulation
from sml.manifold.SE import normalization, se
from sml.manifold.kneighbors import mpc_kneighbors_graph
import spu.intrinsic as si


def emul_cpz(mode: emulation.Mode = emulation.Mode.MULTIPROCESS):
    """Run spectral embedding under the SPU emulator and print it next to
    sklearn's ``spectral_embedding`` on the same random data.

    Args:
        mode: emulator mode; defaults to multi-process emulation.
    """
    # Built outside the try block so a constructor failure is reported as
    # itself rather than as an unbound `emulator` in the finally clause.
    emulator = emulation.Emulator(
        emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
    )
    try:
        emulator.up()

        def SE(sX, num_samples, num_features, k, num_components):
            Knn = mpc_kneighbors_graph(sX, num_samples, num_features, k)
            D, L = normalization(Knn)
            ans = se(L, num_samples, D, num_components)
            return ans

        # Number of samples, feature dimension, neighbours and output dims.
        num_samples = 6
        num_features = 3
        k = 3
        num_components = 2
        # NOTE(review): the seed is taken from the clock, so every run uses
        # different data.
        seed = int(time.time())
        key = jax.random.PRNGKey(seed)
        X = jax.random.uniform(
            key, shape=(num_samples, num_features), minval=0.0, maxval=1.0
        )

        sX = emulator.seal(X)
        ans = emulator.run(
            SE,
            static_argnums=(
                1,
                2,
                3,
                4,
            ),
        )(sX, num_samples, num_features, k, num_components)

        print('ans: \n', ans)

        # sklearn reference; uses the same neighbour count as the MPC run.
        affinity_matrix = kneighbors_graph(
            X, n_neighbors=k, mode="distance", include_self=False
        )
        # Make the matrix symmetric.
        affinity_matrix = 0.5 * (affinity_matrix + affinity_matrix.T)
        embedding = spectral_embedding(
            affinity_matrix, n_components=num_components, random_state=None
        )
        print('embedding: \n', embedding)
    finally:
        emulator.down()


if __name__ == "__main__":
    emul_cpz(emulation.Mode.MULTIPROCESS)

# ---------------------------------------------------------------------------
# Patch boundary - next file: sml/manifold/emulations/test_emul.py
# (new file mode 100644, index 00000000..6d2ccc44, hunk @@ -0,0 +1,134 @@)
# ---------------------------------------------------------------------------
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import numpy as np

import sml.utils.emulation as emulation
from sml.manifold.dijkstra import mpc_dijkstra
from sml.manifold.floyd import floyd
from sml.manifold.floyd import floyd_opt


def emul_cpz(mode: emulation.Mode = emulation.Mode.MULTIPROCESS):
    """Run all-pairs Dijkstra and ``floyd_opt`` under the SPU emulator on the
    same random symmetric graph and print both distance matrices.

    Args:
        mode: emulator mode; defaults to multi-process emulation.
    """
    # bandwidth and latency only work for docker mode
    # Built outside the try block so a constructor failure is reported as
    # itself rather than as an unbound `emulator` in the finally clause.
    emulator = emulation.Emulator(
        emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
    )
    try:
        emulator.up()

        def dijkstra_all_pairs(
            Knn,
            mpc_dist_inf,
            num_samples,
        ):
            def compute_distances_for_sample(i, Knn, num_samples, mpc_dist_inf):
                return mpc_dijkstra(Knn, num_samples, i, mpc_dist_inf)

            # Use vmap to batch the shortest-path computation over sources.
            compute_distances = jax.vmap(
                compute_distances_for_sample, in_axes=(0, None, None, None)
            )

            indices = jnp.arange(num_samples)  # source indices
            mpc_shortest_paths = compute_distances(
                indices, Knn, num_samples, mpc_dist_inf
            )

            return mpc_shortest_paths

        # Number of nodes.
        num_samples = 20
        dist_inf = jnp.full(num_samples, np.inf)

        # Random symmetric adjacency matrix with a zero diagonal.
        Knn = np.random.rand(num_samples, num_samples)
        Knn = (Knn + Knn.T) / 2
        Knn[Knn == 0] = np.inf
        np.fill_diagonal(Knn, 0)
        mpc_dist_inf = emulator.seal(dist_inf)

        # All-pairs Dijkstra.
        Knn = emulator.seal(Knn)
        shortest_paths_dijkstra = emulator.run(
            dijkstra_all_pairs, static_argnums=(2,)
        )(Knn, mpc_dist_inf, num_samples)

        # Optimised Floyd-Warshall.
        shortest_paths_opt_floyd = emulator.run(floyd_opt)(Knn)

        print("\nshortest_paths_dijkstra:\n")
        for row in shortest_paths_dijkstra:
            print(row)

        print("\nshortest_paths_opt_floyd:\n")
        for row in shortest_paths_opt_floyd:
            print(row)

    finally:
        emulator.down()


if __name__ == "__main__":
    emul_cpz(emulation.Mode.MULTIPROCESS)

# ---------------------------------------------------------------------------
# Patch boundary - next file: sml/manifold/floyd.py
# (new file mode 100644, index 00000000..05e157ce, hunk @@ -0,0 +1,86 @@)
# ---------------------------------------------------------------------------
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax.numpy as jnp
import numpy


def floyd(dist):
    """All-pairs shortest paths (Floyd-Warshall).

    Args:
        dist: square adjacency matrix; an entry of 0 means "no edge" and the
            diagonal is forced to 0.

    Returns:
        Matrix of shortest-path distances (``inf`` where unreachable).
    """
    # Select with where(): multiplying the mask by inf yields 0 * inf = NaN
    # for every existing edge in floating point.
    dist = jnp.where(dist == 0, jnp.inf, dist)
    n = dist.shape[0]
    dist = jnp.where(jnp.eye(n, dtype=bool), 0, dist)
    for k in range(n):
        # Row k and column k are unchanged by step k because dist[k, k] is 0
        # (assuming non-negative weights), so the whole step is one broadcast.
        dist = jnp.minimum(dist, dist[:, k][:, None] + dist[k, :][None, :])
    return dist


def floyd_opt(dist):
    """Floyd-Warshall for a symmetric graph, relaxing only the upper triangle.

    At step ``k`` only pairs ``i < j`` with ``i != k`` and ``j != k`` are
    compared; the result is mirrored into the lower triangle. Row ``k`` and
    column ``k`` are left untouched.

    Args:
        dist: square adjacency matrix; 0 means "no edge". Expected to be
            symmetric: the lower triangle is overwritten from the upper one.

    Returns:
        Matrix of shortest-path distances (``inf`` where unreachable).
    """
    dist = jnp.where(dist == 0, jnp.inf, dist)
    n = dist.shape[0]
    dist = jnp.where(jnp.eye(n, dtype=bool), 0, dist)

    # Static (plaintext) index lists of the strict upper triangle.
    upper_rows, upper_cols = numpy.triu_indices(n, k=1)
    for k in range(n):
        # Drop the pairs that touch row/column k.
        keep = (upper_rows != k) & (upper_cols != k)
        rows, cols = upper_rows[keep], upper_cols[keep]
        if rows.size == 0:
            continue
        relaxed = jnp.minimum(dist[rows, cols], dist[rows, k] + dist[k, cols])
        # Write the upper triangle back and mirror it.
        dist = dist.at[rows, cols].set(relaxed)
        dist = dist.at[cols, rows].set(relaxed)

    return dist


# ---------------------------------------------------------------------------
# Patch boundary - next file: sml/manifold/jacobi.py
# (new file mode 100644, index 00000000..05f183e9, hunk @@ -0,0 +1,186 @@)
# ---------------------------------------------------------------------------
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import numpy as np
import spu.intrinsic as si


def set_value_2d(x, index_1, index_2, value, n):
    """Return a copy of the (n, n) array ``x`` with ``x[index_1][index_2] = value``.

    Both indices may be secret-shared scalars: the target cell is found by
    comparing them with every row/column index instead of dynamic indexing.
    """
    positions = jnp.arange(n)
    row_hit = positions[:, None] == index_1
    col_hit = positions[None, :] == index_2
    # Only the cell matching both indices is replaced.
    return jnp.where(jnp.logical_and(row_hit, col_hit), value, x)


def compute_elements(X, k, l, n):
    """Cosine and sine of the Jacobi rotation that zeroes ``X[k][l]``.

    Args:
        X: symmetric matrix being diagonalised.
        k, l: row and column of the off-diagonal pivot.
        n: matrix size (unused; kept for the existing call signature).

    Returns:
        ``(cos, sin)``; ``(1, 0)`` when the pivot is already zero.
    """
    tar_elements = X[k][l]
    tar_diff = X[k][k] - X[l][l]

    cos_2theta = jnp.reciprocal(
        jnp.sqrt(1 + 4 * jnp.square(tar_elements * jnp.reciprocal(tar_diff)))
    )
    cos2 = 0.5 + 0.5 * cos_2theta
    sin2 = 0.5 - 0.5 * cos_2theta
    pivot_is_zero = jnp.equal(tar_elements, 0)
    # Select with where() instead of multiplying by the 0/1 flag: when pivot
    # and diagonal difference are both zero, cos2/sin2 are NaN in floating
    # point and 0 * NaN would leak into the result.
    cos = jnp.where(pivot_is_zero, 1.0, jnp.sqrt(cos2))
    sign = (jnp.greater(tar_elements * tar_diff, 0)) * 2 - 1
    nothing_to_rotate = jnp.logical_and(pivot_is_zero, tar_diff == 0)
    sin = jnp.where(nothing_to_rotate, 0.0, jnp.sqrt(sin2) * sign)

    return cos, sin


def update_J(J, k, l, cos, sin, n):
    """Write one plane rotation into ``J`` at possibly secret positions ``k``, ``l``."""
    J = set_value_2d(J, k, k, cos, n)
    J = set_value_2d(J, l, l, cos, n)
    J = set_value_2d(J, k, l, -sin, n)
    J = set_value_2d(J, l, k, sin, n)

    return J


def Rotation_Matrix(X, k, l, n, m, k_0, l_0):
    """Build the rotation matrix for a batch of pivots ``X[k[i]][l[i]]``.

    Args:
        X: symmetric matrix being diagonalised.
        k, l: arrays of pivot rows / columns.
        n: size of ``X``.
        m: number of pivots (unused; the count is taken from ``k``).
        k_0, l_0: plaintext first pivot; pivot ``i`` is written at
            ``(k_0 - i, l_0 + i)``.

    Returns:
        (n, n) matrix: identity with the rotations written in.
    """
    J = jnp.eye(n)
    k_values = jnp.array(k)  # make sure k and l are JAX arrays
    l_values = jnp.array(l)

    # Compute all rotation angles in one batched call.
    cos_values, sin_values = jax.vmap(compute_elements, in_axes=(None, 0, 0, None))(
        X, k_values, l_values, n
    )

    # Positions are written with plaintext indices. The original author noted
    # that update_J failed here with the SPU error
    # "should not be here x=Value<1x1xSF32,s=0,0>, to=Pub2k".
    for i in range(len(k_values)):
        t_k = k_0 - i
        t_l = l_0 + i
        J = J.at[t_k, t_k].set(cos_values[i])
        J = J.at[t_l, t_l].set(cos_values[i])
        J = J.at[t_k, t_l].set(-sin_values[i])
        J = J.at[t_l, t_k].set(sin_values[i])

    return J


def Jacobi(X, num_samples, num_sweeps=5):
    """Jacobi eigenvalue iteration with a fixed, data-independent schedule.

    Each sweep walks the anti-diagonals of the matrix and rotates a batch of
    pivots with pairwise distinct rows and columns at once.

    Args:
        X: (num_samples, num_samples) symmetric matrix.
        num_samples: size of ``X``.
        num_sweeps: number of full sweeps (default 5, the previously
            hard-coded value).

    Returns:
        ``(X, Q)``: the rotated matrix, whose diagonal holds the eigenvalue
        estimates, and the accumulated product of transposed rotations.
    """
    Q = jnp.eye(num_samples)
    for _ in range(num_sweeps):
        for i in range(1, 2 * num_samples - 2):
            if i < num_samples:
                l_0 = i
                r_0 = 0
            else:
                l_0 = num_samples - 1
                r_0 = i - l_0

            n = (l_0 - r_0 - 1) // 2 + 1
            # Pick positions whose indices are all different.
            offsets = jnp.arange(n)
            l = l_0 - offsets
            r = r_0 + offsets
            # Compute the rotation matrix.
            J = Rotation_Matrix(X, l, r, num_samples, n, l_0, r_0)
            # Apply the rotation to X and accumulate it into Q.
            X = jnp.dot(J.T, jnp.dot(X, J))
            Q = jnp.dot(J.T, Q)

    return X, Q


# ---------------------------------------------------------------------------
# Patch boundary - next file: sml/manifold/kneighbors.py
# (new file mode 100644, index 00000000..a1345df9, hunk @@ -0,0 +1,70 @@)
# ---------------------------------------------------------------------------
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import spu.intrinsic as si


def mpc_kneighbors_graph(
    X,  # input samples
    num_samples,  # number of samples
    num_features,  # feature dimension (currently unused)
    n_neighbors,  # number of neighbours per sample, not counting the sample itself
    *,
    mode="distance",  # currently unused
    metric="minkowski",  # currently unused; Euclidean distance is computed
    p=2,  # currently unused
):
    """Symmetric k-nearest-neighbour distance graph.

    Returns:
        (num_samples, num_samples) matrix holding the Euclidean distance to
        each of the ``n_neighbors`` nearest neighbours of every sample and 0
        elsewhere, averaged with its transpose.
    """
    # Squared Euclidean distance between every pair of samples.
    diffs = jnp.expand_dims(X, axis=1) - jnp.expand_dims(X, axis=0)
    Dis = jnp.sum(jnp.square(diffs), axis=-1)

    # Sort every row of Dis: compute the permutation, then apply it.
    Indix_Dis = jnp.argsort(Dis, axis=1)

    # NOTE(review): si.permute is dispatched to apply_inv_perm, so permuting
    # the identity presumably yields the inverse of the argsort permutation -
    # confirm against the kernel.
    identity_perm = jnp.arange(num_samples)
    Knn = jnp.zeros((num_samples, num_samples))
    for i in range(num_samples):
        per_dis = si.permute(Dis[i], si.permute(identity_perm, Indix_Dis[i]))
        Knn = Knn.at[i].set(per_dis)

    # Keep sorted positions 0..n_neighbors (the sample itself plus its
    # neighbours), take the square root there, and zero everything else.
    rank = jnp.arange(num_samples)
    Knn2 = jnp.where(rank <= n_neighbors, jnp.sqrt(Knn), 0)

    # Undo the sort to restore the original column order.
    Knn3 = jnp.zeros((num_samples, num_samples))
    for i in range(num_samples):
        Knn3 = Knn3.at[i].set(si.permute(Knn2[i], Indix_Dis[i]))

    # Make the neighbour matrix symmetric.
    Knn4 = 0.5 * (Knn3 + Knn3.T)
    return Knn4


# ---------------------------------------------------------------------------
# Patch boundary - the remainder of this line is a hunk of an existing file,
# kept verbatim; it continues on the following line.
# ---------------------------------------------------------------------------
# diff --git a/spu/intrinsic/BUILD.bazel b/spu/intrinsic/BUILD.bazel
# index 691e3fa4..5928ed5e 100644
# --- a/spu/intrinsic/BUILD.bazel
# +++ b/spu/intrinsic/BUILD.bazel
# @@ -25,6 +25,7 @@ py_library(
#      deps = [
#          ":example",
#          ":example_binary",
# +        ":permute",
#          # DO-NOT-EDIT:ADD_IMPORT
#      ],
#  )
# @@ -48,3 +49,15 @@ py_library(
#          "//visibility:private",
#      ],
#  )
# +
# +py_library(
# +    name = "permute",
# +    srcs = [
# +        "permute_impl.py",
+    ],
+    visibility = [
+        "//visibility:private",
+    ],
+)
+
+  
\ No newline at end of file
diff --git a/spu/intrinsic/__init__.py b/spu/intrinsic/__init__.py
index 3559fa8a..5d454d7f 100644
--- a/spu/intrinsic/__init__.py
+++ b/spu/intrinsic/__init__.py
@@ -15,10 +15,12 @@
 from .example_binary_impl import example_binary
 from .example_impl import example
+from .permute_impl import permute
 
 # DO-NOT-EDIT:ADD_IMPORT
 
 __all__ = [
     # "example",
     # "example_binary",
-    # DO-NOT-EDIT:EOL
+    "permute",
+    # DO-NOT-EDIT:EOL
 ]
diff --git a/spu/intrinsic/permute_impl.py b/spu/intrinsic/permute_impl.py
new file mode 100644
index 00000000..8fa31298
--- /dev/null
+++ b/spu/intrinsic/permute_impl.py
@@ -0,0 +1,83 @@
+__all__ = ["permute"]
+
+from functools import partial
+
+from jax import core, dtypes
+from jax.core import ShapedArray
+from jax.interpreters import ad, batching, mlir, xla
+
+# from jax.lib import xla_client
+from jaxlib.hlo_helpers import custom_call
+
+
+# Public facing interface
+def permute(in1, in2):
+    """Bind the ``permute`` primitive to the operands ``in1`` and ``in2``."""
+    return _permute_prim.bind(in1, in2)
+
+
+# *********************************
+# *  SUPPORT FOR JIT COMPILATION  *
+# *********************************
+
+
+# For JIT compilation JAX needs a rule that computes the shape and dtype of
+# the op's output from its abstract inputs.
+def _permute_abstract(in1, in2):
+    """Return a ``ShapedArray`` with ``in1``'s shape and canonical dtype."""
+    out_dtype = dtypes.canonicalize_dtype(in1.dtype)
+    return ShapedArray(in1.shape, out_dtype)
+
+
+# We also need a lowering rule to provide an MLIR "lowering" of our primitive.
+def _permute_lowering(ctx, in1, in2):
+    """Lower the primitive to a ``permute`` custom call shaped like ``in1``."""
+    # The result reuses the shape and element type of the first operand.
+    in_type = mlir.ir.RankedTensorType(in1.type)
+    out_type = mlir.ir.RankedTensorType.get(in_type.shape, in_type.element_type)
+    lowered = custom_call(
+        "permute",
+        # Output types
+        result_types=[out_type],
+        # The inputs:
+        operands=[in1, in2],
+    )
+    return lowered.results
+
+
+# **********************************
+# *  SUPPORT FOR FORWARD AUTODIFF  *
+# **********************************
+
+
+def _permute_jvp(args, tangents):
+    """Forward-mode differentiation is not implemented for this primitive."""
+    raise NotImplementedError()
+
+
+# ************************************
+# *  SUPPORT FOR BATCHING WITH VMAP  *
+# ************************************
+
+
+# No real batching rule is provided yet: the rule registered for this
+# primitive simply raises NotImplementedError.
+def _permute_batch(args, axes):
+    """Batching (``vmap``) is not implemented for this primitive."""
+    raise NotImplementedError()
+
+
+# *********************************************
+# *  BOILERPLATE TO REGISTER THE OP WITH JAX  *
+# *********************************************
+_permute_prim = core.Primitive("permute")
+# Set this to True if the op ever produces more than one output
+_permute_prim.multiple_results = False
+_permute_prim.def_impl(partial(xla.apply_primitive, _permute_prim))
+_permute_prim.def_abstract_eval(_permute_abstract)
+
+mlir.register_lowering(_permute_prim, _permute_lowering)
+
+# Hook up the JVP and batching rules defined earlier in this module
+ad.primitive_jvps[_permute_prim] = _permute_jvp
+batching.primitive_batchers[_permute_prim] = _permute_batch
\ No newline at end of file