From 1efd9434436916ecae3d5495c62540a245348065 Mon Sep 17 00:00:00 2001 From: anakinxc Date: Fri, 23 Feb 2024 15:22:03 +0800 Subject: [PATCH] repo-sync-2024-02-23T15:21:57+0800 --- bazel/repositories.bzl | 6 +- examples/python/README.md | 9 + examples/python/ml/ml_test.py | 4 +- examples/python/pir/BUILD.bazel | 65 ---- examples/python/pir/pir_client.py | 97 ----- examples/python/pir/pir_mem_server.py | 98 ------ examples/python/pir/pir_server.py | 94 ----- examples/python/pir/pir_setup.py | 74 ---- examples/python/psi/BUILD.bazel | 53 --- examples/python/psi/mem_psi.py | 104 ------ examples/python/psi/simple_psi.py | 101 ------ examples/python/psi/unbalanced_psi.py | 225 ------------ spu/BUILD.bazel | 20 +- spu/__init__.py | 3 +- spu/libpsi.cc | 70 +--- spu/pir.py | 60 ---- spu/psi.py | 15 +- spu/tests/BUILD.bazel | 59 +++- spu/tests/legacy_psi_test.py | 445 +++++++++++++++++++++++ spu/tests/pir_test.py | 125 +++++++ spu/tests/psi_test.py | 489 ++++---------------------- spu/tests/ub_psi_test.py | 136 +++++++ spu/tests/utils.py | 51 +++ spu/utils/distributed.py | 4 + spu/utils/frontend.py | 4 +- 25 files changed, 938 insertions(+), 1473 deletions(-) delete mode 100644 examples/python/pir/BUILD.bazel delete mode 100644 examples/python/pir/pir_client.py delete mode 100644 examples/python/pir/pir_mem_server.py delete mode 100644 examples/python/pir/pir_server.py delete mode 100644 examples/python/pir/pir_setup.py delete mode 100644 examples/python/psi/BUILD.bazel delete mode 100644 examples/python/psi/mem_psi.py delete mode 100644 examples/python/psi/simple_psi.py delete mode 100644 examples/python/psi/unbalanced_psi.py delete mode 100644 spu/pir.py create mode 100644 spu/tests/legacy_psi_test.py create mode 100644 spu/tests/pir_test.py create mode 100644 spu/tests/ub_psi_test.py create mode 100644 spu/tests/utils.py diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 09ace3098..ca3e6b1ae 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -51,10 +51,10 @@ def _libpsi(): http_archive, name = "psi", urls = [ - "https://github.com/secretflow/psi/archive/9225bc9626a4ed3e4a8a55b86e400aa5e50e3e93.tar.gz", + "https://github.com/secretflow/psi/archive/refs/tags/v0.2.0.dev240222.tar.gz", ], - strip_prefix = "psi-9225bc9626a4ed3e4a8a55b86e400aa5e50e3e93", - sha256 = "f90e9e9a2a931833ebdf8f08c01c09fe0c1b0f88458f0c18d45b1a548cf2c001", + strip_prefix = "psi-0.2.0.dev240222", + sha256 = "6b3aed24ae5dc4dce7acfc140093612f8753b9643a6a19177483f8462eda061f", ) def _rules_proto_grpc(): diff --git a/examples/python/README.md b/examples/python/README.md index 7011de063..c3eeeb75d 100644 --- a/examples/python/README.md +++ b/examples/python/README.md @@ -9,3 +9,12 @@ To use a specific layout configuration (i.e. change MPC protocol, change outsour > bazel run -c opt //examples/python/utils:nodectl -- -c examples/python/conf/2pc.json up Then please check the comment of each example to run. 
+ +## Examples for PSI/PIR + +Please check tests at this moment: + +- spu/tests/legacy_psi_test.py +- spu/tests/psi_test.py +- spu/tests/ub_psi_test.py +- spu/tests/pir_test.py diff --git a/examples/python/ml/ml_test.py b/examples/python/ml/ml_test.py index f69544b2b..bdd26781a 100644 --- a/examples/python/ml/ml_test.py +++ b/examples/python/ml/ml_test.py @@ -244,7 +244,9 @@ def suite(): # should put JAX tests above suite.addTest(UnitTests('test_tf_experiment')) suite.addTest(UnitTests('test_torch_lr_experiment')) - suite.addTest(UnitTests('test_torch_resnet_experiment')) + # TODO: torch_xla's stablehlo version is not compatibale with SPU, + # reopen when torch_xla upgrade its stablehlo version + # suite.addTest(UnitTests('test_torch_resnet_experiment')) return suite diff --git a/examples/python/pir/BUILD.bazel b/examples/python/pir/BUILD.bazel deleted file mode 100644 index 79cfeae54..000000000 --- a/examples/python/pir/BUILD.bazel +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2022 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@rules_python//python:defs.bzl", "py_binary") - -package(default_visibility = ["//visibility:public"]) - -py_binary( - name = "pir_setup", - srcs = ["pir_setup.py"], - data = [ - "//examples/data", - ], - deps = [ - "//spu:api", - "//spu:pir", - ], -) - -py_binary( - name = "pir_server", - srcs = ["pir_server.py"], - data = [ - "//examples/data", - ], - deps = [ - "//spu:api", - "//spu:pir", - ], -) - -py_binary( - name = "pir_mem_server", - srcs = ["pir_mem_server.py"], - data = [ - "//examples/data", - ], - deps = [ - "//spu:api", - "//spu:pir", - ], -) - -py_binary( - name = "pir_client", - srcs = ["pir_client.py"], - data = [ - "//examples/data", - ], - deps = [ - "//spu:api", - "//spu:pir", - ], -) diff --git a/examples/python/pir/pir_client.py b/examples/python/pir/pir_client.py deleted file mode 100644 index 32a29a186..000000000 --- a/examples/python/pir/pir_client.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
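The PSI and PIR examples under examples/python are removed in this patch; the replacement coverage lives in the test targets listed in the README above and declared in spu/tests/BUILD.bazel further down in this diff. Following the README's own convention, they can be run with, for example:

> bazel test -c opt //spu/tests:psi_test //spu/tests:ub_psi_test //spu/tests:pir_test //spu/tests:legacy_psi_test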
- - -# To run the example, start two terminals: -# > python pir_client.py --rank 0 --in_path examples/data/pir_client_data.csv --key_columns id --out_path /tmp/pir_client_out.csv -# - -from absl import app, flags - -import spu.libspu.link as link -import spu.libspu.logging as logging -import spu.pir as pir - -flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") -flags.DEFINE_string("party_ips", "127.0.0.1:61307,127.0.0.1:61308", "party addresses") -flags.DEFINE_string("in_path", "data.csv", "data input path") -flags.DEFINE_string("key_columns", "id", "csv file filed name") -flags.DEFINE_string("out_path", "simple_psi_out.csv", "data output path") -flags.DEFINE_bool("enable_tls", False, "whether to enable tls for link") -flags.DEFINE_string("link_server_certificate", "", "link server certificate file path") -flags.DEFINE_string("link_server_private_key", "", "link server private key file path") -flags.DEFINE_string( - "link_server_ca", "", "ca file used to verify other's link server certificate" -) -flags.DEFINE_string("link_client_certificate", "", "link client certificate file path") -flags.DEFINE_string("link_client_private_key", "", "link client private key file path") -flags.DEFINE_string( - "link_client_ca", "", "ca file used to verify other's link client certificate" -) - -FLAGS = flags.FLAGS - - -def setup_link(rank): - lctx_desc = link.Desc() - lctx_desc.id = f"root" - - lctx_desc.recv_timeout_ms = 30 * 60 * 1000 - # lctx_desc.connect_retry_times = 180 - - ips = FLAGS.party_ips.split(",") - for i, ip in enumerate(ips): - lctx_desc.add_party(f"id_{i}", ip) - print(f"id_{i} = {ip}") - - # config link tls - if FLAGS.enable_tls: - # two-way authentication - lctx_desc.server_ssl_opts.cert.certificate_path = FLAGS.link_server_certificate - lctx_desc.server_ssl_opts.cert.private_key_path = FLAGS.link_server_private_key - lctx_desc.server_ssl_opts.verify.ca_file_path = FLAGS.link_server_ca - lctx_desc.server_ssl_opts.verify.verify_depth = 1 - lctx_desc.client_ssl_opts.cert.certificate_path = FLAGS.link_client_certificate - lctx_desc.client_ssl_opts.cert.private_key_path = FLAGS.link_client_private_key - lctx_desc.client_ssl_opts.verify.ca_file_path = FLAGS.link_client_ca - lctx_desc.client_ssl_opts.verify.verify_depth = 1 - - return link.create_brpc(lctx_desc, rank) - - -def main(_): - opts = logging.LogOptions() - opts.system_log_path = "./tmp/spu.log" - opts.trace_log_path = "./tmp/trace.log" - opts.enable_console_logger = True - opts.log_level = logging.LogLevel.INFO - logging.setup_logging(opts) - - key_columns = FLAGS.key_columns.split(",") - - config = pir.PirClientConfig( - pir_protocol=pir.PirProtocol.Value('KEYWORD_PIR_LABELED_PSI'), - input_path=FLAGS.in_path, - key_columns=key_columns, - output_path=FLAGS.out_path, - ) - - link_ctx = setup_link(FLAGS.rank) - report = pir.pir_client(link_ctx, config) - - print(f"data_count: {report.data_count}") - - -if __name__ == '__main__': - app.run(main) diff --git a/examples/python/pir/pir_mem_server.py b/examples/python/pir/pir_mem_server.py deleted file mode 100644 index 542fc8e5d..000000000 --- a/examples/python/pir/pir_mem_server.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# To run the example, start two terminals: -# > python pir_setup.py --in_path examples/data/pir_server_data.csv --key_columns id --label_columns label \ -# > --count_per_query 1 -max_label_length 256 \ -# > --oprf_key_path oprf_key.bin --setup_path setup_path - -import time - -from absl import app, flags - -import spu.libspu.link as link -import spu.libspu.logging as logging -import spu.pir as pir - -flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") -flags.DEFINE_string("party_ips", "127.0.0.1:61307,127.0.0.1:61308", "party addresses") - -flags.DEFINE_string("in_path", "data.csv", "data input path") -flags.DEFINE_string("key_columns", "id", "csv file key filed name") -flags.DEFINE_string("label_columns", "label", "csv file label filed name") -flags.DEFINE_integer("count_per_query", 1, "count_per_query") -flags.DEFINE_integer("max_label_length", 256, "max_label_length") -flags.DEFINE_string("setup_path", "setup_path", "data output path") - -flags.DEFINE_boolean("compressed", False, "compress seal he plaintext") -flags.DEFINE_integer("bucket_size", 1000000, "bucket size of pir query") -flags.DEFINE_integer( - "max_items_per_bin", 0, "max items per bin, i.e. Interpolate polynomial max degree" -) - -FLAGS = flags.FLAGS - - -def setup_link(rank): - lctx_desc = link.Desc() - lctx_desc.id = f"root" - - lctx_desc.recv_timeout_ms = 2 * 60 * 1000 - # lctx_desc.connect_retry_times = 180 - - ips = FLAGS.party_ips.split(",") - for i, ip in enumerate(ips): - lctx_desc.add_party(f"id_{i}", ip) - print(f"id_{i} = {ip}") - - return link.create_brpc(lctx_desc, rank) - - -def main(_): - opts = logging.LogOptions() - opts.system_log_path = "./tmp/spu.log" - opts.enable_console_logger = True - opts.log_level = logging.LogLevel.INFO - logging.setup_logging(opts) - - key_columns = FLAGS.key_columns.split(",") - label_columns = FLAGS.label_columns.split(",") - - link = setup_link(FLAGS.rank) - - start = time.time() - - config = pir.PirSetupConfig( - pir_protocol=pir.PirProtocol.Value('KEYWORD_PIR_LABELED_PSI'), - store_type=pir.KvStoreType.Value('LEVELDB_KV_STORE'), - input_path=FLAGS.in_path, - key_columns=key_columns, - label_columns=label_columns, - num_per_query=FLAGS.count_per_query, - label_max_len=FLAGS.max_label_length, - oprf_key_path="", - setup_path='::memory', - compressed=FLAGS.compressed, - bucket_size=FLAGS.bucket_size, - max_items_per_bin=FLAGS.max_items_per_bin, - ) - - report = pir.pir_memory_server(link, config) - print(f"data_count: {report.data_count}") - print(f"memory server cost time: {time.time() - start}") - - -if __name__ == '__main__': - app.run(main) diff --git a/examples/python/pir/pir_server.py b/examples/python/pir/pir_server.py deleted file mode 100644 index 585590909..000000000 --- a/examples/python/pir/pir_server.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# To run the example, start two terminals: -# > python pir_server.py --rank 1 --setup_path setup_path --oprf_key_path oprf_key.bin -# - -from absl import app, flags - -import spu.libspu.link as link -import spu.libspu.logging as logging -import spu.pir as pir - -flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") -flags.DEFINE_string("party_ips", "127.0.0.1:61307,127.0.0.1:61308", "party addresses") - -flags.DEFINE_string("oprf_key_path", "oprf_key.bin", "oprf key file path") -flags.DEFINE_string("setup_path", "setup_path", "data output path") - -flags.DEFINE_bool("enable_tls", False, "whether to enable tls for link") -flags.DEFINE_string("link_server_certificate", "", "link server certificate file path") -flags.DEFINE_string("link_server_private_key", "", "link server private key file path") -flags.DEFINE_string( - "link_server_ca", "", "ca file used to verify other's link server certificate" -) -flags.DEFINE_string("link_client_certificate", "", "link client certificate file path") -flags.DEFINE_string("link_client_private_key", "", "link client private key file path") -flags.DEFINE_string( - "link_client_ca", "", "ca file used to verify other's link client certificate" -) -FLAGS = flags.FLAGS - - -def setup_link(rank): - lctx_desc = link.Desc() - lctx_desc.id = f"root" - - lctx_desc.recv_timeout_ms = 30 * 60 * 1000 - # lctx_desc.connect_retry_times = 180 - - ips = FLAGS.party_ips.split(",") - for i, ip in enumerate(ips): - lctx_desc.add_party(f"id_{i}", ip) - print(f"id_{i} = {ip}") - - # config link tls - if FLAGS.enable_tls: - # two-way authentication - lctx_desc.server_ssl_opts.cert.certificate_path = FLAGS.link_server_certificate - lctx_desc.server_ssl_opts.cert.private_key_path = FLAGS.link_server_private_key - lctx_desc.server_ssl_opts.verify.ca_file_path = FLAGS.link_server_ca - lctx_desc.server_ssl_opts.verify.verify_depth = 1 - lctx_desc.client_ssl_opts.cert.certificate_path = FLAGS.link_client_certificate - lctx_desc.client_ssl_opts.cert.private_key_path = FLAGS.link_client_private_key - lctx_desc.client_ssl_opts.verify.ca_file_path = FLAGS.link_client_ca - lctx_desc.client_ssl_opts.verify.verify_depth = 1 - - return link.create_brpc(lctx_desc, rank) - - -def main(_): - opts = logging.LogOptions() - opts.system_log_path = "./tmp/spu.log" - opts.enable_console_logger = True - opts.log_level = logging.LogLevel.INFO - logging.setup_logging(opts) - - config = pir.PirServerConfig( - pir_protocol=pir.PirProtocol.Value('KEYWORD_PIR_LABELED_PSI'), - store_type=pir.KvStoreType.Value('LEVELDB_KV_STORE'), - oprf_key_path=FLAGS.oprf_key_path, - setup_path=FLAGS.setup_path, - ) - - link_ctx = setup_link(FLAGS.rank) - report = pir.pir_server(link_ctx, config) - - print(f"data_count: {report.data_count}") - - -if __name__ == '__main__': - app.run(main) diff --git a/examples/python/pir/pir_setup.py b/examples/python/pir/pir_setup.py deleted file mode 100644 index bed2f599e..000000000 --- a/examples/python/pir/pir_setup.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# To run the example, start two terminals: -# > python pir_setup.py --in_path examples/data/pir_server_data.csv --key_columns id --label_columns label \ -# > --count_per_query 1 -max_label_length 256 \ -# > --oprf_key_path oprf_key.bin --setup_path setup_path - -from absl import app, flags - -import spu.libspu.link as link -import spu.libspu.logging as logging -import spu.pir as pir - -flags.DEFINE_string("in_path", "data.csv", "data input path") -flags.DEFINE_string("key_columns", "id", "csv file key filed name") -flags.DEFINE_string("label_columns", "label", "csv file label filed name") -flags.DEFINE_integer("count_per_query", 1, "count_per_query") -flags.DEFINE_integer("max_label_length", 256, "max_label_length") -flags.DEFINE_string("oprf_key_path", "oprf_key.bin", "oprf key file") -flags.DEFINE_string("setup_path", "setup_path", "data output path") - -flags.DEFINE_boolean("compressed", False, "compress seal he plaintext") -flags.DEFINE_integer("bucket_size", 1000000, "bucket size of pir query") -flags.DEFINE_integer( - "max_items_per_bin", 0, "max items per bin, i.e. Interpolate polynomial max degree" -) - -FLAGS = flags.FLAGS - - -def main(_): - opts = logging.LogOptions() - opts.system_log_path = "./tmp/spu.log" - opts.enable_console_logger = True - opts.log_level = logging.LogLevel.INFO - logging.setup_logging(opts) - - key_columns = FLAGS.key_columns.split(",") - label_columns = FLAGS.label_columns.split(",") - - config = pir.PirSetupConfig( - pir_protocol=pir.PirProtocol.Value('KEYWORD_PIR_LABELED_PSI'), - store_type=pir.KvStoreType.Value('LEVELDB_KV_STORE'), - input_path=FLAGS.in_path, - key_columns=key_columns, - label_columns=label_columns, - num_per_query=FLAGS.count_per_query, - label_max_len=FLAGS.max_label_length, - oprf_key_path=FLAGS.oprf_key_path, - setup_path=FLAGS.setup_path, - compressed=FLAGS.compressed, - bucket_size=FLAGS.bucket_size, - max_items_per_bin=FLAGS.max_items_per_bin, - ) - - report = pir.pir_setup(config) - print(f"data_count: {report.data_count}") - - -if __name__ == '__main__': - app.run(main) diff --git a/examples/python/psi/BUILD.bazel b/examples/python/psi/BUILD.bazel deleted file mode 100644 index aef966524..000000000 --- a/examples/python/psi/BUILD.bazel +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2022 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
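The pir_setup/pir_server/pir_client examples removed above are folded into the single psi.pir() entry point added later in this patch. A minimal offline-setup sketch, assuming the PirConfig fields exercised by the new spu/tests/pir_test.py (all paths here are placeholders), could look like:

from google.protobuf import json_format

import spu.psi as psi

# Offline stage: the server indexes its keyed, labeled CSV once; no link is needed.
server_setup = {
    "mode": "MODE_SERVER_SETUP",
    "pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI",
    "pir_server_config": {
        "input_path": "server_data.csv",  # placeholder
        "setup_path": "/tmp/pir_setup",  # placeholder
        "key_columns": ["id"],
        "label_columns": ["label"],
        "label_max_len": 288,
        "bucket_size": 1000000,
        "apsi_server_config": {
            "oprf_key_path": "/tmp/oprf_key.bin",  # placeholder, 32-byte key file
            "num_per_query": 1,
            "compressed": False,
        },
    },
}
report = psi.pir(json_format.ParseDict(server_setup, psi.PirConfig()))

The online stage then reuses the same psi.pir() call in MODE_SERVER_ONLINE and MODE_CLIENT with a connected link, as spu/tests/pir_test.py below shows.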
- -load("@rules_python//python:defs.bzl", "py_binary") - -package(default_visibility = ["//visibility:public"]) - -py_binary( - name = "simple_psi", - srcs = ["simple_psi.py"], - data = [ - "//examples/data", - ], - deps = [ - "//spu:api", - "//spu:psi", - ], -) - -py_binary( - name = "unbalanced_psi", - srcs = ["unbalanced_psi.py"], - data = [ - "//examples/data", - ], - deps = [ - "//spu:api", - "//spu:psi", - ], -) - -py_binary( - name = "mem_psi", - srcs = ["mem_psi.py"], - data = [ - "//examples/data", - ], - deps = [ - "//spu:api", - "//spu:psi", - ], -) diff --git a/examples/python/psi/mem_psi.py b/examples/python/psi/mem_psi.py deleted file mode 100644 index e21cace4c..000000000 --- a/examples/python/psi/mem_psi.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# To run the example, start two terminals: -# > bazel run //examples/python/psi:mem_psi -- --rank 0 --protocol ECDH_PSI_2PC --in_path examples/data/psi_1.csv --field_name id --out_path /tmp/p1.out -# > bazel run //examples/python/psi:mem_psi -- --rank 1 --protocol ECDH_PSI_2PC --in_path examples/data/psi_2.csv --field_name id --out_path /tmp/p2.out - -import pandas as pd -from absl import app, flags - -import spu.libspu.link as link -import spu.libspu.logging as logging -import spu.psi as psi - -flags.DEFINE_string("protocol", "ECDH_PSI_2PC", "psi protocol, see `spu/psi/psi.proto`") -flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") -flags.DEFINE_string("party_ips", "127.0.0.1:61307,127.0.0.1:61308", "party addresses") -flags.DEFINE_string("in_path", "data.csv", "data input path") -flags.DEFINE_string("field_name", "id", "csv file filed name") -flags.DEFINE_string("out_path", "mem_psi_out.csv", "data output path") -flags.DEFINE_integer("receiver_rank", 0, "main party for psi, will get result") -flags.DEFINE_bool("enable_tls", False, "whether to enable tls for link") -flags.DEFINE_string("link_server_certificate", "", "link server certificate file path") -flags.DEFINE_string("link_server_private_key", "", "link server private key file path") -flags.DEFINE_string( - "link_server_ca", "", "ca file used to verify other's link server certificate" -) -flags.DEFINE_string("link_client_certificate", "", "link client certificate file path") -flags.DEFINE_string("link_client_private_key", "", "link client private key file path") -flags.DEFINE_string( - "link_client_ca", "", "ca file used to verify other's link client certificate" -) -FLAGS = flags.FLAGS - - -def setup_link(rank): - lctx_desc = link.Desc() - lctx_desc.id = f"root" - - lctx_desc.recv_timeout_ms = 2 * 60 * 1000 - lctx_desc.connect_retry_times = 180 - - ips = FLAGS.party_ips.split(",") - for i, ip in enumerate(ips): - lctx_desc.add_party(f"id_{i}", ip) - print(f"id_{i} = {ip}") - - # config link tls - if FLAGS.enable_tls: - # two-way authentication - lctx_desc.enable_ssl = True - lctx_desc.server_ssl_opts.cert.certificate_path = FLAGS.link_server_certificate - 
lctx_desc.server_ssl_opts.cert.private_key_path = FLAGS.link_server_private_key - lctx_desc.server_ssl_opts.verify.ca_file_path = FLAGS.link_server_ca - lctx_desc.server_ssl_opts.verify.verify_depth = 1 - lctx_desc.client_ssl_opts.cert.certificate_path = FLAGS.link_client_certificate - lctx_desc.client_ssl_opts.cert.private_key_path = FLAGS.link_client_private_key - lctx_desc.client_ssl_opts.verify.ca_file_path = FLAGS.link_client_ca - lctx_desc.client_ssl_opts.verify.verify_depth = 1 - - return link.create_brpc(lctx_desc, rank) - - -def main(_): - logging.setup_logging() - - # read csv - in_df = pd.read_csv(FLAGS.in_path) - in_data = in_df[FLAGS.field_name].astype(str).tolist() - - config = psi.MemoryPsiConfig( - psi_type=psi.PsiType.Value(FLAGS.protocol), - broadcast_result=False, - receiver_rank=FLAGS.receiver_rank if FLAGS.receiver_rank >= 0 else 0, - curve_type=psi.CurveType.CURVE_25519, - ) - - if FLAGS.protocol == "DP_PSI_2PC": - config.dppsi_params.bob_sub_sampling = 0.9 - config.dppsi_params.epsilon = 3 - - intersection = psi.mem_psi(setup_link(FLAGS.rank), config, in_data) - - out_df = pd.DataFrame(columns=[FLAGS.field_name]) - out_df[FLAGS.field_name] = intersection - out_df.to_csv(FLAGS.out_path, index=False) - - print(f"original_count: {len(in_data)}, intersection_count: {len(intersection)}") - - -if __name__ == '__main__': - app.run(main) diff --git a/examples/python/psi/simple_psi.py b/examples/python/psi/simple_psi.py deleted file mode 100644 index 57d00094b..000000000 --- a/examples/python/psi/simple_psi.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2021 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
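The deleted mem_psi example corresponds to the in-memory API that the new spu/tests/legacy_psi_test.py continues to exercise; a two-thread sketch over an in-process link, with toy string lists as the assumed inputs, is:

import spu.libspu.link as link
import spu.psi as psi
from spu.utils.simulation import PropagatingThread

# Toy inputs for rank 0 and rank 1; the intersection is ["b", "c"].
data = [["a", "b", "c"], ["b", "c", "d"]]

desc = link.Desc()
for rank in range(2):
    desc.add_party(f"id_{rank}", f"thread_{rank}")


def run(rank):
    lctx = link.create_mem(desc, rank)
    config = psi.MemoryPsiConfig(
        psi_type=psi.PsiType.ECDH_PSI_2PC, broadcast_result=True
    )
    print(rank, sorted(psi.mem_psi(lctx, config, data[rank])))


jobs = [PropagatingThread(target=run, args=(rank,)) for rank in range(2)]
[job.start() for job in jobs]
[job.join() for job in jobs]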
- - -# To run the example, start two terminals: -# > bazel run //examples/python/psi:simple_psi -- --rank 0 --protocol ECDH_PSI_2PC --in_path examples/data/psi_1.csv --field_names id --out_path /tmp/p1.out -# > bazel run //examples/python/psi:simple_psi -- --rank 1 --protocol ECDH_PSI_2PC --in_path examples/data/psi_2.csv --field_names id --out_path /tmp/p2.out - -from absl import app, flags - -import spu.libspu.link as link -import spu.libspu.logging as logging -import spu.psi as psi - -flags.DEFINE_string("protocol", "ECDH_PSI_2PC", "psi protocol, see `spu/psi/psi.proto`") -flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") -flags.DEFINE_string("party_ips", "127.0.0.1:61307,127.0.0.1:61308", "party addresses") -flags.DEFINE_string("in_path", "data.csv", "data input path") -flags.DEFINE_string("field_names", "id", "csv file filed name") -flags.DEFINE_string("out_path", "simple_psi_out.csv", "data output path") -flags.DEFINE_integer("receiver_rank", -1, "main party for psi, will get result") -flags.DEFINE_bool("output_sort", True, "whether to sort result") -flags.DEFINE_bool("precheck_input", True, "whether to precheck input dataset") -flags.DEFINE_integer("bucket_size", 1048576, "hash bucket size") -flags.DEFINE_bool("ic_mode", False, "whether to run in interconnection mode") -FLAGS = flags.FLAGS - - -def setup_link(rank): - lctx_desc = link.Desc() - lctx_desc.id = f"root" - - lctx_desc.recv_timeout_ms = 2 * 60 * 1000 - lctx_desc.connect_retry_times = 180 - if FLAGS.ic_mode: - lctx_desc.brpc_channel_protocol = "h2:grpc" - - ips = FLAGS.party_ips.split(",") - for i, ip in enumerate(ips): - lctx_desc.add_party(f"id_{i}", ip) - print(f"id_{i} = {ip}") - - return link.create_brpc(lctx_desc, rank) - - -def progress_callback(data: psi.ProgressData): - print( - f"progress callback ---- percentage: {data.percentage}, total: {data.total}, finished: {data.finished}, running: {data.running}, description: {data.description}" - ) - - -def main(_): - opts = logging.LogOptions() - opts.system_log_path = "./tmp/spu.log" - opts.enable_console_logger = True - opts.log_level = logging.LogLevel.INFO - logging.setup_logging(opts) - - selected_fields = FLAGS.field_names.split(",") - - config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value(FLAGS.protocol), - broadcast_result=True if FLAGS.receiver_rank < 0 else False, - receiver_rank=FLAGS.receiver_rank if FLAGS.receiver_rank >= 0 else 0, - input_params=psi.InputParams( - path=FLAGS.in_path, - select_fields=selected_fields, - precheck=FLAGS.precheck_input, - ), - output_params=psi.OutputParams( - path=FLAGS.out_path, need_sort=FLAGS.output_sort - ), - bucket_size=FLAGS.bucket_size, - curve_type=psi.CurveType.CURVE_25519, - ) - - if FLAGS.protocol == "DP_PSI_2PC": - config.dppsi_params.bob_sub_sampling = 0.9 - config.dppsi_params.epsilon = 3 - - report = psi.bucket_psi( - setup_link(FLAGS.rank), config, progress_callback, 5 * 1000, FLAGS.ic_mode - ) - print( - f"original_count: {report.original_count}, intersection_count: {report.intersection_count}" - ) - - -if __name__ == '__main__': - app.run(main) diff --git a/examples/python/psi/unbalanced_psi.py b/examples/python/psi/unbalanced_psi.py deleted file mode 100644 index f278cbc8d..000000000 --- a/examples/python/psi/unbalanced_psi.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright 2021 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# To run the example, start two terminals: -# > bazel run //examples/python/psi:unbalanced_psi -- --rank 0 --in_path examples/data/psi_1.csv --field_names id --out_path /tmp/p1.out -# > bazel run //examples/python/psi:unbalanced_psi -- --rank 1 --in_path examples/data/psi_2.csv --field_names id --out_path /tmp/p2.out - -import time - -from absl import app, flags - -import spu.libspu.link as link -import spu.psi as psi - -flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") -flags.DEFINE_string("in_path", "data.csv", "data input path") -flags.DEFINE_string("field_names", "id", "csv file filed name") -flags.DEFINE_string("out_path", "data.out", "data output path") -flags.DEFINE_integer("receiver_rank", 0, "main party for psi, will get result") -flags.DEFINE_bool("output_sort", False, "whether to sort result") -flags.DEFINE_integer("bucket_size", 1048576, "hash bucket size") -FLAGS = flags.FLAGS - - -def setup_link(rank, port): - lctx_desc = link.Desc() - lctx_desc.id = f"desc_id" - lctx_desc.recv_timeout_ms = 3600 * 1000 # 3600 seconds - - lctx_desc.add_party(f"id_0", f"127.0.0.1:{port}") - lctx_desc.add_party(f"id_1", f"127.0.0.1:{port+10}") - - return link.create_brpc(lctx_desc, rank) - - -def main(_): - selected_fields = FLAGS.field_names.split(",") - - # one-way PSI, just one party get result - broadcast_result = False - - secret_key_path = "secret_key.bin" - with open(secret_key_path, 'wb') as f: - f.write( - bytes.fromhex( - "000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000" - ) - ) - - cache_path = "server_cache.bin" - link_ctx = setup_link(FLAGS.rank, 91827) - - # ===== gen cache phase ===== - if FLAGS.receiver_rank != FLAGS.rank: - gen_cache_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_GEN_CACHE'), - broadcast_result=False, - receiver_rank=FLAGS.receiver_rank, - input_params=psi.InputParams( - path=FLAGS.in_path, - select_fields=selected_fields, - precheck=False, - ), - output_params=psi.OutputParams(path=cache_path, need_sort=False), - bucket_size=10000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - gen_cache_config.ecdh_secret_key_path = secret_key_path - - start = time.time() - gen_cache_report = psi.bucket_psi(None, gen_cache_config) - print(f"gen cache cost time: {time.time() - start}") - print( - f"gen cache: rank: {FLAGS.rank} original_count: {gen_cache_report.original_count}" - ) - - # ===== transfer cache phase ===== - print("===== Transfer Cache Phase =====") - transfer_cache_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_TRANSFER_CACHE'), - broadcast_result=broadcast_result, - receiver_rank=FLAGS.receiver_rank, - input_params=psi.InputParams( - path=FLAGS.in_path, - select_fields=selected_fields, - precheck=False, - ), - output_params=psi.OutputParams( - path=FLAGS.out_path, need_sort=FLAGS.output_sort - ), - bucket_size=10000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if FLAGS.receiver_rank == link_ctx.rank: - transfer_cache_config.preprocess_path = 'tmp/preprocess_path_transfer_cache.csv' - transfer_cache_config.input_params.path = 'fake.csv' - else: - 
transfer_cache_config.input_params.path = cache_path - transfer_cache_config.ecdh_secret_key_path = secret_key_path - - start = time.time() - transfer_cache_report = psi.bucket_psi(link_ctx, transfer_cache_config) - print(f"transfer cache cost time: {time.time() - start}") - print( - f"transfer cache: rank: {FLAGS.rank} original_count: {transfer_cache_report.original_count}" - ) - print(f"intersection_count: {transfer_cache_report.intersection_count}") - - # ===== shuffle online phase ===== - print("===== shuffle online phase =====") - - server_rank = 1 - FLAGS.receiver_rank - print(f"shuffle online server_rank: {server_rank}") - - config_shuffle_online = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_SHUFFLE_ONLINE'), - broadcast_result=broadcast_result, - receiver_rank=server_rank, - input_params=psi.InputParams( - path=FLAGS.in_path, - select_fields=selected_fields, - precheck=False, - ), - output_params=psi.OutputParams( - path=FLAGS.out_path, need_sort=FLAGS.output_sort - ), - bucket_size=100000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - print(f"input path:{FLAGS.in_path}") - if server_rank == link_ctx.rank: - config_shuffle_online.preprocess_path = cache_path - config_shuffle_online.ecdh_secret_key_path = secret_key_path - else: - config_shuffle_online.preprocess_path = 'tmp/preprocess_path_transfer_cache.csv' - - start = time.time() - report_shuffle_online = psi.bucket_psi(link_ctx, config_shuffle_online) - print(f"shuffle online cost time: {time.time() - start}") - print( - f"shuffle online: rank:{FLAGS.rank} original_count: {report_shuffle_online.original_count}" - ) - print(f"intersection_count: {report_shuffle_online.intersection_count}") - - # ===== offline phase ===== - print("===== UB Offline Phase =====") - config_offline = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_OFFLINE'), - broadcast_result=broadcast_result, - receiver_rank=FLAGS.receiver_rank, - input_params=psi.InputParams( - path=FLAGS.in_path, - select_fields=selected_fields, - precheck=False, - ), - output_params=psi.OutputParams( - path=FLAGS.out_path, need_sort=FLAGS.output_sort - ), - bucket_size=10000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if FLAGS.receiver_rank == link_ctx.rank: - config_offline.preprocess_path = 'tmp/preprocess_path.csv' - config_offline.input_params.path = 'fake.csv' - else: - config_offline.ecdh_secret_key_path = secret_key_path - - start = time.time() - offline_report = psi.bucket_psi(link_ctx, config_offline) - print(f"offline cost time: {time.time() - start}") - print( - f"offline: rank: {FLAGS.rank} original_count: {offline_report.original_count}" - ) - print(f"intersection_count: {offline_report.intersection_count}") - - # ===== online phase ===== - print("===== online phase =====") - config_online = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_ONLINE'), - broadcast_result=broadcast_result, - receiver_rank=FLAGS.receiver_rank, - input_params=psi.InputParams( - path=FLAGS.in_path, - select_fields=selected_fields, - precheck=False, - ), - output_params=psi.OutputParams( - path=FLAGS.out_path, need_sort=FLAGS.output_sort - ), - bucket_size=100000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - print(f"input path:{FLAGS.in_path}") - if FLAGS.receiver_rank == link_ctx.rank: - config_online.preprocess_path = 'tmp/preprocess_path.csv' - else: - config_online.input_params.path = 'fake.csv' - config_online.ecdh_secret_key_path = secret_key_path - - start = time.time() - report_online = 
psi.bucket_psi(link_ctx, config_online) - print(f"online cost time: {time.time() - start}") - print(f"online: rank:{FLAGS.rank} original_count: {report_online.original_count}") - print(f"intersection_count: {report_online.intersection_count}") - - -if __name__ == '__main__': - app.run(main) diff --git a/spu/BUILD.bazel b/spu/BUILD.bazel index 737f7dd59..d7e7214f3 100644 --- a/spu/BUILD.bazel +++ b/spu/BUILD.bazel @@ -65,10 +65,8 @@ pybind_extension( deps = [ ":exported_symbols.lds", ":version_script.lds", - "@psi//psi/pir", - "@psi//psi/psi:bucket_psi", - "@psi//psi/psi:launch", - "@psi//psi/psi:memory_psi", + "@psi//psi:launch", + "@psi//psi/legacy:memory_psi", "@yacl//yacl/link", ], ) @@ -116,6 +114,7 @@ py_library( srcs = [ "psi.py", ":link_py_proto", + ":pir_py_proto", ":psi_py_proto", ":psi_v2_py_proto_fixed", ], @@ -131,25 +130,12 @@ python_proto_compile( protos = ["@psi//psi/proto:pir_proto"], ) -py_library( - name = "pir", - srcs = [ - "pir.py", - ":pir_py_proto", - ], - data = [ - ":libpsi.so", - ":libspu.so", - ], -) - py_library( name = "init", srcs = [ "__init__.py", "version.py", ":api", - ":pir", ":psi", "//spu/intrinsic:all_intrinsics", "//spu/utils:simulation", diff --git a/spu/__init__.py b/spu/__init__.py index 61c1d9cf3..b56c7ff90 100644 --- a/spu/__init__.py +++ b/spu/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. -from . import pir, psi +from . import psi from .api import Io, Runtime, check_cpu_feature, compile from .intrinsic import * from .spu_pb2 import ( # type: ignore @@ -49,7 +49,6 @@ # utils "simulation", # libs - "pir", "psi", # intrinsic ] + intrinsic.__all__ diff --git a/spu/libpsi.cc b/spu/libpsi.cc index ff30b456a..0a0d5e795 100644 --- a/spu/libpsi.cc +++ b/spu/libpsi.cc @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include "pybind11/functional.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" #include "yacl/base/exception.h" #include "yacl/link/context.h" -#include "psi/pir/pir.h" -#include "psi/psi/bucket_psi.h" -#include "psi/psi/launch.h" -#include "psi/psi/memory_psi.h" -#include "psi/psi/utils/progress.h" +#include "psi/launch.h" +#include "psi/legacy/memory_psi.h" +#include "psi/utils/progress.h" + +#include "psi/proto/pir.pb.h" +#include "psi/proto/psi.pb.h" +#include "psi/proto/psi_v2.pb.h" namespace py = pybind11; @@ -78,7 +78,7 @@ void BindLibs(py::module& m) { "Run bucket psi. 
ic_mode means run in interconnection mode", NO_GIL); m.def( - "psi_v2", + "psi", [](const std::string& config_pb, const std::shared_ptr& lctx) -> py::bytes { psi::v2::PsiConfig psi_config; @@ -104,58 +104,16 @@ void BindLibs(py::module& m) { "Run UB PSI with v2 API.", NO_GIL); m.def( - "pir_setup", - [](const std::string& config_pb) -> py::bytes { - pir::PirSetupConfig config; - YACL_ENFORCE(config.ParseFromString(config_pb)); - - config.set_bucket_size(1000000); - config.set_compressed(false); - - auto r = pir::PirSetup(config); - return r.SerializeAsString(); - }, - py::arg("pir_config"), "Run pir setup."); - - m.def( - "pir_server", - [](const std::shared_ptr& lctx, - const std::string& config_pb) -> py::bytes { - pir::PirServerConfig config; - YACL_ENFORCE(config.ParseFromString(config_pb)); - - auto r = pir::PirServer(lctx, config); - return r.SerializeAsString(); - }, - py::arg("link_context"), py::arg("pir_config"), "Run pir server"); - - m.def( - "pir_memory_server", - [](const std::shared_ptr& lctx, - const std::string& config_pb) -> py::bytes { - pir::PirSetupConfig config; - YACL_ENFORCE(config.ParseFromString(config_pb)); - YACL_ENFORCE(config.setup_path() == "::memory"); - - config.set_bucket_size(1000000); - config.set_compressed(false); - - auto r = pir::PirMemoryServer(lctx, config); - return r.SerializeAsString(); - }, - py::arg("link_context"), py::arg("pir_config"), "Run pir memory server"); - - m.def( - "pir_client", - [](const std::shared_ptr& lctx, - const std::string& config_pb) -> py::bytes { - pir::PirClientConfig config; + "pir", + [](const std::string& config_pb, + const std::shared_ptr& lctx) -> py::bytes { + psi::PirConfig config; YACL_ENFORCE(config.ParseFromString(config_pb)); - auto r = pir::PirClient(lctx, config); + auto r = psi::RunPir(config, lctx); return r.SerializeAsString(); }, - py::arg("link_context"), py::arg("pir_config"), "Run pir client"); + py::arg("pir_config"), py::arg("link_context") = nullptr, "Run PIR."); } PYBIND11_MODULE(libpsi, m) { diff --git a/spu/pir.py b/spu/pir.py deleted file mode 100644 index c34e6fc1f..000000000 --- a/spu/pir.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2021 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import List - -from .libspu.link import Context # type: ignore -from . 
import libpsi # type: ignore -from .pir_pb2 import ( # type: ignore - KvStoreType, - PirClientConfig, - PirProtocol, - PirResultReport, - PirServerConfig, - PirSetupConfig, -) - - -def pir_setup(config: PirSetupConfig) -> List[str]: - report_str = libpsi.libs.pir_setup(config.SerializeToString()) - - report = PirResultReport() - report.ParseFromString(report_str) - return report - - -def pir_server(link: libspu.link.Context, config: PirServerConfig) -> List[str]: - report_str = libpsi.libs.pir_server(link, config.SerializeToString()) - - report = PirResultReport() - report.ParseFromString(report_str) - return report - - -def pir_memory_server(link: libspu.link.Context, config: PirSetupConfig) -> List[str]: - report_str = libpsi.libs.pir_memory_server(link, config.SerializeToString()) - - report = PirResultReport() - report.ParseFromString(report_str) - return report - - -def pir_client(link: libspu.link.Context, config: PirClientConfig) -> List[str]: - report_str = libpsi.libs.pir_client(link, config.SerializeToString()) - - report = PirResultReport() - report.ParseFromString(report_str) - return report diff --git a/spu/psi.py b/spu/psi.py index f6060be76..2796ab526 100644 --- a/spu/psi.py +++ b/spu/psi.py @@ -19,6 +19,7 @@ from . import libpsi # type: ignore from .libpsi.libs import ProgressData from .libspu.link import Context # type: ignore +from .pir_pb2 import PirConfig, PirProtocol, PirResultReport # type: ignore from .psi_pb2 import ( # type: ignore BucketPsiConfig, CurveType, @@ -103,9 +104,9 @@ def gen_cache_for_2pc_ub_psi(config: BucketPsiConfig) -> PsiResultReport: return report -def psi_v2( +def psi( config: PsiConfig, - link: Context, + link: Context = None, ) -> PsiResultReport: """ Run PSI with v2 API. @@ -114,7 +115,7 @@ def psi_v2( :param link: the transport layer :return: statistical results """ - report_str = libpsi.libs.psi_v2( + report_str = libpsi.libs.psi( config.SerializeToString(), link, ) @@ -141,3 +142,11 @@ def ub_psi( report = PsiResultReport() report.ParseFromString(report_str) return report + + +def pir(config: PirProtocol, link: Context = None) -> PirResultReport: + report_str = libpsi.libs.pir(config.SerializeToString(), link) + + report = PirResultReport() + report.ParseFromString(report_str) + return report diff --git a/spu/tests/BUILD.bazel b/spu/tests/BUILD.bazel index 92149da42..18454b36a 100644 --- a/spu/tests/BUILD.bazel +++ b/spu/tests/BUILD.bazel @@ -26,6 +26,14 @@ py_library( ], ) +py_library( + name = "utils", + srcs = ["utils.py"], + deps = [ + "//spu:api", + ], +) + py_binary( name = "np_op_status", srcs = ["np_op_status.py"], @@ -220,6 +228,23 @@ py_test( ], ) +py_test( + name = "legacy_psi_test", + srcs = ["legacy_psi_test.py"], + data = [ + "//spu/tests/data", + ], + flaky = True, + tags = [ + "exclusive-if-local", + ], + deps = [ + ":utils", + "//spu:psi", + "//spu/utils:simulation", + ], +) + py_test( name = "psi_test", srcs = ["psi_test.py"], @@ -231,8 +256,40 @@ py_test( "exclusive-if-local", ], deps = [ + ":utils", + "//spu:psi", + ], +) + +py_test( + name = "ub_psi_test", + srcs = ["ub_psi_test.py"], + data = [ + "//spu/tests/data", + ], + flaky = True, + tags = [ + "exclusive-if-local", + ], + deps = [ + ":utils", + "//spu:psi", + ], +) + +py_test( + name = "pir_test", + srcs = ["pir_test.py"], + data = [ + "//spu/tests/data", + ], + flaky = True, + tags = [ + "exclusive-if-local", + ], + deps = [ + ":utils", "//spu:psi", - "//spu/utils:simulation", ], ) diff --git a/spu/tests/legacy_psi_test.py b/spu/tests/legacy_psi_test.py new file 
mode 100644 index 000000000..eae7e7c86 --- /dev/null +++ b/spu/tests/legacy_psi_test.py @@ -0,0 +1,445 @@ +# Copyright 2021 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time +import unittest + +import multiprocess + +import spu.libspu.link as link +import spu.psi as psi +from spu.tests.utils import get_free_port, wc_count +from spu.utils.simulation import PropagatingThread + + +class UnitTests(unittest.TestCase): + def run_psi(self, fn): + wsize = 2 + + lctx_desc = link.Desc() + for rank in range(wsize): + lctx_desc.add_party(f"id_{rank}", f"thread_{rank}") + + def wrap(rank): + lctx = link.create_mem(lctx_desc, rank) + return fn(lctx) + + jobs = [PropagatingThread(target=wrap, args=(rank,)) for rank in range(wsize)] + + [job.start() for job in jobs] + [job.join() for job in jobs] + + def run_streaming_psi(self, wsize, inputs, outputs, selected_fields, protocol): + time_stamp = time.time() + lctx_desc = link.Desc() + lctx_desc.id = str(round(time_stamp * 1000)) + + for rank in range(wsize): + port = get_free_port() + lctx_desc.add_party(f"id_{rank}", f"127.0.0.1:{port}") + + def wrap(rank, selected_fields, input_path, output_path, type): + lctx = link.create_brpc(lctx_desc, rank) + + config = psi.BucketPsiConfig( + psi_type=type, + broadcast_result=True, + input_params=psi.InputParams( + path=input_path, select_fields=selected_fields + ), + output_params=psi.OutputParams(path=output_path, need_sort=True), + curve_type=psi.CurveType.CURVE_25519, + ) + + if type == psi.PsiType.DP_PSI_2PC: + config.dppsi_params.bob_sub_sampling = 0.9 + config.dppsi_params.epsilon = 3 + + report = psi.bucket_psi(lctx, config) + + source_count = wc_count(input_path) + output_count = wc_count(output_path) + print( + f"id:{lctx.id()}, psi_type: {type}, original_count: {report.original_count}, intersection_count: {report.intersection_count}, source_count: {source_count}, output_count: {output_count}" + ) + + self.assertEqual(report.original_count, source_count - 1) + self.assertEqual(report.intersection_count, output_count - 1) + + lctx.stop_link() + + # launch with multiprocess + jobs = [ + multiprocess.Process( + target=wrap, + args=( + rank, + selected_fields, + inputs[rank], + outputs[rank], + protocol, + ), + ) + for rank in range(wsize) + ] + [job.start() for job in jobs] + for job in jobs: + job.join() + self.assertEqual(job.exitcode, 0) + + def prep_data(self): + data = [ + [f'r{idx}' for idx in range(1000) if idx % 3 == 0], + [f'r{idx}' for idx in range(1000) if idx % 7 == 0], + ] + + expected = [f'r{idx}' for idx in range(1000) if idx % 3 == 0 and idx % 7 == 0] + + return data, expected + + def test_reveal(self): + data, expected = self.prep_data() + expected.sort() + + def fn(lctx): + config = psi.MemoryPsiConfig( + psi_type=psi.PsiType.ECDH_PSI_2PC, broadcast_result=True + ) + joint = psi.mem_psi(lctx, config, data[lctx.rank]) + joint.sort() + return self.assertEqual(joint, expected) + + self.run_psi(fn) + + def test_reveal_to(self): + data, expected = self.prep_data() + 
expected.sort() + + reveal_to_rank = 0 + + def fn(lctx): + config = psi.MemoryPsiConfig( + psi_type=psi.PsiType.KKRT_PSI_2PC, + receiver_rank=reveal_to_rank, + broadcast_result=False, + ) + joint = psi.mem_psi(lctx, config, data[lctx.rank]) + + joint.sort() + + if lctx.rank == reveal_to_rank: + self.assertEqual(joint, expected) + else: + self.assertEqual(joint, []) + + self.run_psi(fn) + + def test_ecdh_3pc(self): + print("----------test_ecdh_3pc-------------") + + inputs = [ + "spu/tests/data/alice.csv", + "spu/tests/data/bob.csv", + "spu/tests/data/carol.csv", + ] + outputs = ["./alice-ecdh3pc.csv", "./bob-ecdh3pc.csv", "./carol-ecdh3pc.csv"] + selected_fields = ["id", "idx"] + + self.run_streaming_psi( + 3, inputs, outputs, selected_fields, psi.PsiType.ECDH_PSI_3PC + ) + + def test_kkrt_2pc(self): + print("----------test_kkrt_2pc-------------") + + inputs = ["spu/tests/data/alice.csv", "spu/tests/data/bob.csv"] + outputs = ["./alice-kkrt.csv", "./bob-kkrt.csv"] + selected_fields = ["id", "idx"] + + self.run_streaming_psi( + 2, inputs, outputs, selected_fields, psi.PsiType.KKRT_PSI_2PC + ) + + def test_ecdh_2pc(self): + print("----------test_ecdh_2pc-------------") + + inputs = ["spu/tests/data/alice.csv", "spu/tests/data/bob.csv"] + outputs = ["./alice-ecdh.csv", "./bob-ecdh.csv"] + selected_fields = ["id", "idx"] + + self.run_streaming_psi( + 2, inputs, outputs, selected_fields, psi.PsiType.ECDH_PSI_2PC + ) + + def test_dppsi_2pc(self): + print("----------test_dppsi_2pc-------------") + + inputs = ["spu/tests/data/alice.csv", "spu/tests/data/bob.csv"] + outputs = ["./alice-dppsi.csv", "./bob-dppsi.csv"] + selected_fields = ["id", "idx"] + + self.run_streaming_psi( + 2, inputs, outputs, selected_fields, psi.PsiType.DP_PSI_2PC + ) + + def test_ecdh_oprf_unbalanced(self): + print("----------test_ecdh_oprf_unbalanced-------------") + + offline_path = ["", "spu/tests/data/bob.csv"] + online_path = ["spu/tests/data/alice.csv", "spu/tests/data/bob.csv"] + outputs = ["./alice-ecdh-unbalanced.csv", "./bob-ecdh-unbalanced.csv"] + preprocess_path = ["./alice-preprocess.csv", ""] + secret_key_path = ["", "./secret_key.bin"] + selected_fields = ["id", "idx"] + + with open(secret_key_path[1], 'wb') as f: + f.write( + bytes.fromhex( + "000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000" + ) + ) + + time_stamp = time.time() + lctx_desc = link.Desc() + lctx_desc.id = str(round(time_stamp * 1000)) + + for rank in range(2): + port = get_free_port() + lctx_desc.add_party(f"id_{rank}", f"127.0.0.1:{port}") + + receiver_rank = 0 + server_rank = 1 + client_rank = 0 + # one-way PSI, just one party get result + broadcast_result = False + + precheck_input = False + server_cache_path = "server_cache.bin" + + def wrap( + rank, + offline_path, + online_path, + out_path, + preprocess_path, + ub_secret_key_path, + ): + link_ctx = link.create_brpc(lctx_desc, rank) + + if receiver_rank != link_ctx.rank: + print("===== gen cache phase =====") + print(f"{offline_path}, {server_cache_path}") + + gen_cache_config = psi.BucketPsiConfig( + psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_GEN_CACHE'), + input_params=psi.InputParams( + path=offline_path, + select_fields=selected_fields, + precheck=False, + ), + output_params=psi.OutputParams( + path=server_cache_path, need_sort=False + ), + bucket_size=1000000, + curve_type=psi.CurveType.CURVE_FOURQ, + ecdh_secret_key_path=ub_secret_key_path, + ) + + start = time.time() + + gen_cache_report = psi.gen_cache_for_2pc_ub_psi(gen_cache_config) + + server_source_count 
= wc_count(offline_path) + self.assertEqual( + gen_cache_report.original_count, server_source_count - 1 + ) + + print(f"offline cost time: {time.time() - start}") + print( + f"offline: rank: {rank} original_count: {gen_cache_report.original_count}" + ) + + print("===== transfer cache phase =====") + transfer_cache_config = psi.BucketPsiConfig( + psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_TRANSFER_CACHE'), + broadcast_result=broadcast_result, + receiver_rank=receiver_rank, + input_params=psi.InputParams( + path=offline_path, + select_fields=selected_fields, + precheck=precheck_input, + ), + bucket_size=1000000, + curve_type=psi.CurveType.CURVE_FOURQ, + ) + + if receiver_rank == link_ctx.rank: + transfer_cache_config.preprocess_path = preprocess_path + else: + transfer_cache_config.input_params.path = server_cache_path + + print( + f"rank:{link_ctx.rank} file:{transfer_cache_config.input_params.path}" + ) + + start = time.time() + transfer_cache_report = psi.bucket_psi(link_ctx, transfer_cache_config) + + if receiver_rank != link_ctx.rank: + server_source_count = wc_count(offline_path) + self.assertEqual( + transfer_cache_report.original_count, server_source_count - 1 + ) + + print(f"transfer cache cost time: {time.time() - start}") + print( + f"transfer cache: rank: {rank} original_count: {transfer_cache_report.original_count}" + ) + + print("===== shuffle online phase =====") + shuffle_online_config = psi.BucketPsiConfig( + psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_SHUFFLE_ONLINE'), + broadcast_result=False, + receiver_rank=server_rank, + input_params=psi.InputParams( + path=online_path, + select_fields=selected_fields, + precheck=precheck_input, + ), + output_params=psi.OutputParams(path=out_path, need_sort=False), + bucket_size=10000000, + curve_type=psi.CurveType.CURVE_FOURQ, + ) + + if client_rank == link_ctx.rank: + shuffle_online_config.preprocess_path = preprocess_path + else: + shuffle_online_config.preprocess_path = server_cache_path + shuffle_online_config.ecdh_secret_key_path = ub_secret_key_path + + print( + f"rank:{link_ctx.rank} file:{shuffle_online_config.input_params.path}" + ) + + start = time.time() + shuffle_online_report = psi.bucket_psi(link_ctx, shuffle_online_config) + + if server_rank == link_ctx.rank: + server_source_count = wc_count(offline_path) + self.assertEqual( + shuffle_online_report.original_count, server_source_count - 1 + ) + + print(f"shuffle online cost time: {time.time() - start}") + print( + f"shuffle online: rank: {rank} original_count: {shuffle_online_report.original_count}" + ) + print( + f"shuffle online: rank: {rank} intersection: {shuffle_online_report.intersection_count}" + ) + + print("===== offline phase =====") + offline_config = psi.BucketPsiConfig( + psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_OFFLINE'), + broadcast_result=broadcast_result, + receiver_rank=client_rank, + input_params=psi.InputParams( + path=offline_path, + select_fields=selected_fields, + precheck=precheck_input, + ), + output_params=psi.OutputParams(path="fake.out", need_sort=False), + bucket_size=1000000, + curve_type=psi.CurveType.CURVE_FOURQ, + ) + + if client_rank == link_ctx.rank: + offline_config.preprocess_path = preprocess_path + offline_config.input_params.path = "dummy.csv" + else: + offline_config.ecdh_secret_key_path = ub_secret_key_path + + start = time.time() + offline_report = psi.bucket_psi(link_ctx, offline_config) + + if receiver_rank != link_ctx.rank: + server_source_count = wc_count(offline_path) + 
self.assertEqual(offline_report.original_count, server_source_count - 1) + + print(f"offline cost time: {time.time() - start}") + print( + f"offline: rank: {rank} original_count: {offline_report.original_count}" + ) + print( + f"offline: rank: {rank} intersection_count: {offline_report.intersection_count}" + ) + + print("===== online phase =====") + online_config = psi.BucketPsiConfig( + psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_ONLINE'), + broadcast_result=broadcast_result, + receiver_rank=client_rank, + input_params=psi.InputParams( + path=online_path, + select_fields=selected_fields, + precheck=precheck_input, + ), + output_params=psi.OutputParams(path=out_path, need_sort=False), + bucket_size=300000, + curve_type=psi.CurveType.CURVE_FOURQ, + ) + + if receiver_rank == link_ctx.rank: + online_config.preprocess_path = preprocess_path + else: + online_config.ecdh_secret_key_path = ub_secret_key_path + online_config.input_params.path = "dummy.csv" + + start = time.time() + report_online = psi.bucket_psi(link_ctx, online_config) + + if receiver_rank == link_ctx.rank: + client_source_count = wc_count(online_path) + self.assertEqual(report_online.original_count, client_source_count - 1) + + print(f"online cost time: {time.time() - start}") + print(f"online: rank:{rank} original_count: {report_online.original_count}") + print(f"intersection_count: {report_online.intersection_count}") + + link_ctx.stop_link() + + # launch with multiprocess + jobs = [ + multiprocess.Process( + target=wrap, + args=( + rank, + offline_path[rank], + online_path[rank], + outputs[rank], + preprocess_path[rank], + secret_key_path[rank], + ), + ) + for rank in range(2) + ] + [job.start() for job in jobs] + for job in jobs: + job.join() + self.assertEqual(job.exitcode, 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/spu/tests/pir_test.py b/spu/tests/pir_test.py new file mode 100644 index 000000000..5bbf48a8e --- /dev/null +++ b/spu/tests/pir_test.py @@ -0,0 +1,125 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
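The legacy test above and the PIR test that follows repeat the same launch pattern: build a link descriptor, start one multiprocess.Process per rank, and assert a zero exit code. A hypothetical helper capturing that pattern (run_two_party is not part of this patch) might read:

import multiprocess

import spu.libspu.link as link
from spu.tests.utils import create_link_desc


def run_two_party(per_rank_fn):
    # per_rank_fn(rank, lctx) performs one party's share of the protocol.
    link_desc = create_link_desc(2)

    def wrap(rank):
        lctx = link.create_brpc(link_desc, rank)
        per_rank_fn(rank, lctx)

    jobs = [multiprocess.Process(target=wrap, args=(rank,)) for rank in range(2)]
    [job.start() for job in jobs]
    for job in jobs:
        job.join()
        assert job.exitcode == 0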
+ +import json +import unittest + +import multiprocess +from google.protobuf import json_format + +import spu.libspu.link as link +import spu.psi as psi +from spu.tests.utils import create_clean_folder, create_link_desc, wc_count + + +class UnitTests(unittest.TestCase): + def test_pir(self): + # setup stage + + server_setup_config = ''' + { + "mode": "MODE_SERVER_SETUP", + "pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI", + "pir_server_config": { + "input_path": "spu/tests/data/alice.csv", + "setup_path": "/tmp/spu_test_pir_pir_server_setup", + "key_columns": [ + "id" + ], + "label_columns": [ + "y" + ], + "label_max_len": 288, + "bucket_size": 1000000, + "apsi_server_config": { + "oprf_key_path": "/tmp/spu_test_pir_server_secret_key.bin", + "num_per_query": 1, + "compressed": false + } + } + } + ''' + + with open("/tmp/spu_test_pir_server_secret_key.bin", 'wb') as f: + f.write( + bytes.fromhex( + "000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000" + ) + ) + + create_clean_folder("/tmp/spu_test_pir_pir_server_setup") + + psi.pir(json_format.ParseDict(json.loads(server_setup_config), psi.PirConfig())) + + server_online_config = ''' + { + "mode": "MODE_SERVER_ONLINE", + "pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI", + "pir_server_config": { + "setup_path": "/tmp/spu_test_pir_pir_server_setup" + } + } + ''' + + client_online_config = ''' + { + "mode": "MODE_CLIENT", + "pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI", + "pir_client_config": { + "input_path": "/tmp/spu_test_pir_pir_client.csv", + "key_columns": [ + "id" + ], + "output_path": "/tmp/spu_test_pir_pir_output.csv" + } + } + ''' + + pir_client_input_content = '''id +user808 +xxx +''' + + with open("/tmp/spu_test_pir_pir_client.csv", 'w') as f: + f.write(pir_client_input_content) + + configs = [ + json_format.ParseDict(json.loads(server_online_config), psi.PirConfig()), + json_format.ParseDict(json.loads(client_online_config), psi.PirConfig()), + ] + + link_desc = create_link_desc(2) + + def wrap(rank, link_desc, configs): + link_ctx = link.create_brpc(link_desc, rank) + psi.pir(configs[rank], link_ctx) + + jobs = [ + multiprocess.Process( + target=wrap, + args=(rank, link_desc, configs), + ) + for rank in range(2) + ] + [job.start() for job in jobs] + for job in jobs: + job.join() + self.assertEqual(job.exitcode, 0) + + # including title, actual matched item cnt is 1. + self.assertEqual(wc_count("/tmp/spu_test_pir_pir_output.csv"), 2) + + +if __name__ == '__main__': + unittest.main() diff --git a/spu/tests/psi_test.py b/spu/tests/psi_test.py index 58c1248a7..da54f284f 100644 --- a/spu/tests/psi_test.py +++ b/spu/tests/psi_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 Ant Group Co., Ltd. +# Copyright 2024 Ant Group Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,438 +12,86 @@ # See the License for the specific language governing permissions and # limitations under the License. 
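+
+# The rewritten test drives the config-based 2-party PSI flow: each rank
+# parses a psi.PsiConfig from JSON (PROTOCOL_ECDH over CURVE_25519 with
+# broadcast_result enabled), opens a brpc link, and calls psi.psi(); the two
+# output CSVs are then checked to contain the same number of rows.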
- -import subprocess -import time +import json import unittest -from socket import socket import multiprocess +from google.protobuf import json_format import spu.libspu.link as link import spu.psi as psi -from spu.utils.simulation import PropagatingThread - - -def get_free_port(): - with socket() as s: - s.bind(("localhost", 0)) - return s.getsockname()[1] - - -def wc_count(file_name): - out = subprocess.getoutput("wc -l %s" % file_name) - return int(out.split()[0]) +from spu.tests.utils import create_link_desc, wc_count class UnitTests(unittest.TestCase): - def run_psi(self, fn): - wsize = 2 - - lctx_desc = link.Desc() - for rank in range(wsize): - lctx_desc.add_party(f"id_{rank}", f"thread_{rank}") - - def wrap(rank): - lctx = link.create_mem(lctx_desc, rank) - return fn(lctx) - - jobs = [PropagatingThread(target=wrap, args=(rank,)) for rank in range(wsize)] - - [job.start() for job in jobs] - [job.join() for job in jobs] - - def run_streaming_psi(self, wsize, inputs, outputs, selected_fields, protocol): - time_stamp = time.time() - lctx_desc = link.Desc() - lctx_desc.id = str(round(time_stamp * 1000)) - - for rank in range(wsize): - port = get_free_port() - lctx_desc.add_party(f"id_{rank}", f"127.0.0.1:{port}") - - def wrap(rank, selected_fields, input_path, output_path, type): - lctx = link.create_brpc(lctx_desc, rank) - - config = psi.BucketPsiConfig( - psi_type=type, - broadcast_result=True, - input_params=psi.InputParams( - path=input_path, select_fields=selected_fields - ), - output_params=psi.OutputParams(path=output_path, need_sort=True), - curve_type=psi.CurveType.CURVE_25519, - ) - - if type == psi.PsiType.DP_PSI_2PC: - config.dppsi_params.bob_sub_sampling = 0.9 - config.dppsi_params.epsilon = 3 - - report = psi.bucket_psi(lctx, config) - - source_count = wc_count(input_path) - output_count = wc_count(output_path) - print( - f"id:{lctx.id()}, psi_type: {type}, original_count: {report.original_count}, intersection_count: {report.intersection_count}, source_count: {source_count}, output_count: {output_count}" - ) - - self.assertEqual(report.original_count, source_count - 1) - self.assertEqual(report.intersection_count, output_count - 1) - - lctx.stop_link() - - # launch with multiprocess - jobs = [ - multiprocess.Process( - target=wrap, - args=( - rank, - selected_fields, - inputs[rank], - outputs[rank], - protocol, - ), - ) - for rank in range(wsize) - ] - [job.start() for job in jobs] - for job in jobs: - job.join() - self.assertEqual(job.exitcode, 0) - - def prep_data(self): - data = [ - [f'r{idx}' for idx in range(1000) if idx % 3 == 0], - [f'r{idx}' for idx in range(1000) if idx % 7 == 0], - ] - - expected = [f'r{idx}' for idx in range(1000) if idx % 3 == 0 and idx % 7 == 0] - - return data, expected - - def test_reveal(self): - data, expected = self.prep_data() - expected.sort() - - def fn(lctx): - config = psi.MemoryPsiConfig( - psi_type=psi.PsiType.ECDH_PSI_2PC, broadcast_result=True - ) - joint = psi.mem_psi(lctx, config, data[lctx.rank]) - joint.sort() - return self.assertEqual(joint, expected) - - self.run_psi(fn) - - def test_reveal_to(self): - data, expected = self.prep_data() - expected.sort() - - reveal_to_rank = 0 - - def fn(lctx): - config = psi.MemoryPsiConfig( - psi_type=psi.PsiType.KKRT_PSI_2PC, - receiver_rank=reveal_to_rank, - broadcast_result=False, - ) - joint = psi.mem_psi(lctx, config, data[lctx.rank]) - - joint.sort() - - if lctx.rank == reveal_to_rank: - self.assertEqual(joint, expected) - else: - self.assertEqual(joint, []) - - self.run_psi(fn) - - 
def test_ecdh_3pc(self): - print("----------test_ecdh_3pc-------------") - - inputs = [ - "spu/tests/data/alice.csv", - "spu/tests/data/bob.csv", - "spu/tests/data/carol.csv", + def test_psi(self): + link_desc = create_link_desc(2) + + receiver_config_json = ''' + { + "protocol_config": { + "protocol": "PROTOCOL_ECDH", + "ecdh_config": { + "curve": "CURVE_25519" + }, + "role": "ROLE_RECEIVER", + "broadcast_result": true + }, + "input_config": { + "type": "IO_TYPE_FILE_CSV", + "path": "spu/tests/data/alice.csv" + }, + "output_config": { + "type": "IO_TYPE_FILE_CSV", + "path": "/tmp/spu_test_psi_alice_psi_ouput.csv" + }, + "keys": [ + "id" + ], + "skip_duplicates_check": true, + "disable_alignment": true + } + ''' + + sender_config_json = ''' + { + "protocol_config": { + "protocol": "PROTOCOL_ECDH", + "ecdh_config": { + "curve": "CURVE_25519" + }, + "role": "ROLE_SENDER", + "broadcast_result": true + }, + "input_config": { + "type": "IO_TYPE_FILE_CSV", + "path": "spu/tests/data/bob.csv" + }, + "output_config": { + "type": "IO_TYPE_FILE_CSV", + "path": "/tmp/spu_test_psi_bob_psi_ouput.csv" + }, + "keys": [ + "id" + ], + "skip_duplicates_check": true, + "disable_alignment": true + } + ''' + + configs = [ + json_format.ParseDict(json.loads(receiver_config_json), psi.PsiConfig()), + json_format.ParseDict(json.loads(sender_config_json), psi.PsiConfig()), ] - outputs = ["./alice-ecdh3pc.csv", "./bob-ecdh3pc.csv", "./carol-ecdh3pc.csv"] - selected_fields = ["id", "idx"] - - self.run_streaming_psi( - 3, inputs, outputs, selected_fields, psi.PsiType.ECDH_PSI_3PC - ) - - def test_kkrt_2pc(self): - print("----------test_kkrt_2pc-------------") - - inputs = ["spu/tests/data/alice.csv", "spu/tests/data/bob.csv"] - outputs = ["./alice-kkrt.csv", "./bob-kkrt.csv"] - selected_fields = ["id", "idx"] - - self.run_streaming_psi( - 2, inputs, outputs, selected_fields, psi.PsiType.KKRT_PSI_2PC - ) - - def test_ecdh_2pc(self): - print("----------test_ecdh_2pc-------------") - - inputs = ["spu/tests/data/alice.csv", "spu/tests/data/bob.csv"] - outputs = ["./alice-ecdh.csv", "./bob-ecdh.csv"] - selected_fields = ["id", "idx"] - self.run_streaming_psi( - 2, inputs, outputs, selected_fields, psi.PsiType.ECDH_PSI_2PC - ) - - def test_dppsi_2pc(self): - print("----------test_dppsi_2pc-------------") - - inputs = ["spu/tests/data/alice.csv", "spu/tests/data/bob.csv"] - outputs = ["./alice-dppsi.csv", "./bob-dppsi.csv"] - selected_fields = ["id", "idx"] - - self.run_streaming_psi( - 2, inputs, outputs, selected_fields, psi.PsiType.DP_PSI_2PC - ) - - def test_ecdh_oprf_unbalanced(self): - print("----------test_ecdh_oprf_unbalanced-------------") - - offline_path = ["", "spu/tests/data/bob.csv"] - online_path = ["spu/tests/data/alice.csv", "spu/tests/data/bob.csv"] - outputs = ["./alice-ecdh-unbalanced.csv", "./bob-ecdh-unbalanced.csv"] - preprocess_path = ["./alice-preprocess.csv", ""] - secret_key_path = ["", "./secret_key.bin"] - selected_fields = ["id", "idx"] - - with open(secret_key_path[1], 'wb') as f: - f.write( - bytes.fromhex( - "000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000" - ) - ) - - time_stamp = time.time() - lctx_desc = link.Desc() - lctx_desc.id = str(round(time_stamp * 1000)) - - for rank in range(2): - port = get_free_port() - lctx_desc.add_party(f"id_{rank}", f"127.0.0.1:{port}") - - receiver_rank = 0 - server_rank = 1 - client_rank = 0 - # one-way PSI, just one party get result - broadcast_result = False - - precheck_input = False - server_cache_path = "server_cache.bin" - - def 
wrap( - rank, - offline_path, - online_path, - out_path, - preprocess_path, - ub_secret_key_path, - ): - link_ctx = link.create_brpc(lctx_desc, rank) - - if receiver_rank != link_ctx.rank: - print("===== gen cache phase =====") - print(f"{offline_path}, {server_cache_path}") - - gen_cache_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_GEN_CACHE'), - input_params=psi.InputParams( - path=offline_path, - select_fields=selected_fields, - precheck=False, - ), - output_params=psi.OutputParams( - path=server_cache_path, need_sort=False - ), - bucket_size=1000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ecdh_secret_key_path=ub_secret_key_path, - ) - - start = time.time() - - gen_cache_report = psi.gen_cache_for_2pc_ub_psi(gen_cache_config) - - server_source_count = wc_count(offline_path) - self.assertEqual( - gen_cache_report.original_count, server_source_count - 1 - ) - - print(f"offline cost time: {time.time() - start}") - print( - f"offline: rank: {rank} original_count: {gen_cache_report.original_count}" - ) - - print("===== transfer cache phase =====") - transfer_cache_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_TRANSFER_CACHE'), - broadcast_result=broadcast_result, - receiver_rank=receiver_rank, - input_params=psi.InputParams( - path=offline_path, - select_fields=selected_fields, - precheck=precheck_input, - ), - bucket_size=1000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if receiver_rank == link_ctx.rank: - transfer_cache_config.preprocess_path = preprocess_path - else: - transfer_cache_config.input_params.path = server_cache_path - - print( - f"rank:{link_ctx.rank} file:{transfer_cache_config.input_params.path}" - ) - - start = time.time() - transfer_cache_report = psi.bucket_psi(link_ctx, transfer_cache_config) - - if receiver_rank != link_ctx.rank: - server_source_count = wc_count(offline_path) - self.assertEqual( - transfer_cache_report.original_count, server_source_count - 1 - ) - - print(f"transfer cache cost time: {time.time() - start}") - print( - f"transfer cache: rank: {rank} original_count: {transfer_cache_report.original_count}" - ) - - print("===== shuffle online phase =====") - shuffle_online_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_SHUFFLE_ONLINE'), - broadcast_result=False, - receiver_rank=server_rank, - input_params=psi.InputParams( - path=online_path, - select_fields=selected_fields, - precheck=precheck_input, - ), - output_params=psi.OutputParams(path=out_path, need_sort=False), - bucket_size=10000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if client_rank == link_ctx.rank: - shuffle_online_config.preprocess_path = preprocess_path - else: - shuffle_online_config.preprocess_path = server_cache_path - shuffle_online_config.ecdh_secret_key_path = ub_secret_key_path - - print( - f"rank:{link_ctx.rank} file:{shuffle_online_config.input_params.path}" - ) - - start = time.time() - shuffle_online_report = psi.bucket_psi(link_ctx, shuffle_online_config) - - if server_rank == link_ctx.rank: - server_source_count = wc_count(offline_path) - self.assertEqual( - shuffle_online_report.original_count, server_source_count - 1 - ) - - print(f"shuffle online cost time: {time.time() - start}") - print( - f"shuffle online: rank: {rank} original_count: {shuffle_online_report.original_count}" - ) - print( - f"shuffle online: rank: {rank} intersection: {shuffle_online_report.intersection_count}" - ) - - print("===== offline phase =====") - offline_config = 
psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_OFFLINE'), - broadcast_result=broadcast_result, - receiver_rank=client_rank, - input_params=psi.InputParams( - path=offline_path, - select_fields=selected_fields, - precheck=precheck_input, - ), - output_params=psi.OutputParams(path="fake.out", need_sort=False), - bucket_size=1000000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if client_rank == link_ctx.rank: - offline_config.preprocess_path = preprocess_path - offline_config.input_params.path = "dummy.csv" - else: - offline_config.ecdh_secret_key_path = ub_secret_key_path + def wrap(rank, link_desc, configs): + link_ctx = link.create_brpc(link_desc, rank) + psi.psi(configs[rank], link_ctx) - start = time.time() - offline_report = psi.bucket_psi(link_ctx, offline_config) - - if receiver_rank != link_ctx.rank: - server_source_count = wc_count(offline_path) - self.assertEqual(offline_report.original_count, server_source_count - 1) - - print(f"offline cost time: {time.time() - start}") - print( - f"offline: rank: {rank} original_count: {offline_report.original_count}" - ) - print( - f"offline: rank: {rank} intersection_count: {offline_report.intersection_count}" - ) - - print("===== online phase =====") - online_config = psi.BucketPsiConfig( - psi_type=psi.PsiType.Value('ECDH_OPRF_UB_PSI_2PC_ONLINE'), - broadcast_result=broadcast_result, - receiver_rank=client_rank, - input_params=psi.InputParams( - path=online_path, - select_fields=selected_fields, - precheck=precheck_input, - ), - output_params=psi.OutputParams(path=out_path, need_sort=False), - bucket_size=300000, - curve_type=psi.CurveType.CURVE_FOURQ, - ) - - if receiver_rank == link_ctx.rank: - online_config.preprocess_path = preprocess_path - else: - online_config.ecdh_secret_key_path = ub_secret_key_path - online_config.input_params.path = "dummy.csv" - - start = time.time() - report_online = psi.bucket_psi(link_ctx, online_config) - - if receiver_rank == link_ctx.rank: - client_source_count = wc_count(online_path) - self.assertEqual(report_online.original_count, client_source_count - 1) - - print(f"online cost time: {time.time() - start}") - print(f"online: rank:{rank} original_count: {report_online.original_count}") - print(f"intersection_count: {report_online.intersection_count}") - - link_ctx.stop_link() - - # launch with multiprocess jobs = [ multiprocess.Process( target=wrap, - args=( - rank, - offline_path[rank], - online_path[rank], - outputs[rank], - preprocess_path[rank], - secret_key_path[rank], - ), + args=(rank, link_desc, configs), ) for rank in range(2) ] @@ -452,6 +100,11 @@ def wrap( job.join() self.assertEqual(job.exitcode, 0) + self.assertEqual( + wc_count("/tmp/spu_test_psi_alice_psi_ouput.csv"), + wc_count("/tmp/spu_test_psi_bob_psi_ouput.csv"), + ) + if __name__ == '__main__': unittest.main() diff --git a/spu/tests/ub_psi_test.py b/spu/tests/ub_psi_test.py new file mode 100644 index 000000000..78184a98e --- /dev/null +++ b/spu/tests/ub_psi_test.py @@ -0,0 +1,136 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import unittest + +import multiprocess +from google.protobuf import json_format + +import spu.libspu.link as link +import spu.psi as psi +from spu.tests.utils import create_clean_folder, create_link_desc + + +class UnitTests(unittest.TestCase): + def test_ub_psi(self): + link_desc = create_link_desc(2) + + # offline stage + server_offline_config = ''' + { + "mode": "MODE_OFFLINE", + "role": "ROLE_SERVER", + "cache_path": "/tmp/spu_test_ub_psi_server_cache", + "input_config": { + "path": "spu/tests/data/alice.csv" + }, + "keys": [ + "id" + ], + "server_secret_key_path": "/tmp/spu_test_ub_psi_server_secret_key.key" + } + ''' + + client_offline_config = ''' + { + "mode": "MODE_OFFLINE", + "role": "ROLE_CLIENT", + "cache_path": "/tmp/spu_test_ub_psi_client_cache" + } + ''' + + with open("/tmp/spu_test_ub_psi_server_secret_key.key", 'wb') as f: + f.write( + bytes.fromhex( + "000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000" + ) + ) + + create_clean_folder("/tmp/spu_test_ub_psi_server_cache") + + configs = [ + json_format.ParseDict(json.loads(server_offline_config), psi.UbPsiConfig()), + json_format.ParseDict(json.loads(client_offline_config), psi.UbPsiConfig()), + ] + + def wrap(rank, link_desc, configs): + link_ctx = link.create_brpc(link_desc, rank) + psi.ub_psi(configs[rank], link_ctx) + + jobs = [ + multiprocess.Process( + target=wrap, + args=(rank, link_desc, configs), + ) + for rank in range(2) + ] + [job.start() for job in jobs] + for job in jobs: + job.join() + self.assertEqual(job.exitcode, 0) + + # online stage + server_online_config = ''' + { + "mode": "MODE_ONLINE", + "role": "ROLE_SERVER", + "server_secret_key_path": "/tmp/spu_test_ub_psi_server_secret_key.key", + "cache_path": "/tmp/spu_test_ub_psi_server_cache" + } + ''' + + client_online_config = ''' + { + "mode": "MODE_ONLINE", + "role": "ROLE_CLIENT", + "input_config": { + "path": "spu/tests/data/bob.csv" + }, + "output_config": { + "path": "/tmp/spu_test_ubpsi_bob_psi_ouput.csv" + }, + "keys": [ + "id" + ], + "cache_path": "/tmp/spu_test_ub_psi_client_cache" + } + ''' + + configs = [ + json_format.ParseDict(json.loads(server_online_config), psi.UbPsiConfig()), + json_format.ParseDict(json.loads(client_online_config), psi.UbPsiConfig()), + ] + + link_desc = create_link_desc(2) + + def wrap(rank, link_desc, configs): + link_ctx = link.create_brpc(link_desc, rank) + psi.ub_psi(configs[rank], link_ctx) + + jobs = [ + multiprocess.Process( + target=wrap, + args=(rank, link_desc, configs), + ) + for rank in range(2) + ] + [job.start() for job in jobs] + for job in jobs: + job.join() + self.assertEqual(job.exitcode, 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/spu/tests/utils.py b/spu/tests/utils.py new file mode 100644 index 000000000..66a54f24a --- /dev/null +++ b/spu/tests/utils.py @@ -0,0 +1,51 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import subprocess +import time +from socket import socket + +import spu.libspu.link as link + + +def get_free_port(): + with socket() as s: + s.bind(("localhost", 0)) + return s.getsockname()[1] + + +def wc_count(file_name): + out = subprocess.getoutput("wc -l %s" % file_name) + return int(out.split()[0]) + + +def create_link_desc(world_size: int): + time_stamp = time.time() + lctx_desc = link.Desc() + lctx_desc.id = str(round(time_stamp * 1000)) + + for rank in range(world_size): + port = get_free_port() + lctx_desc.add_party(f"id_{rank}", f"127.0.0.1:{port}") + + return lctx_desc + + +def create_clean_folder(path: str): + if os.path.exists(path): + shutil.rmtree(path) + + os.mkdir(path) diff --git a/spu/utils/distributed.py b/spu/utils/distributed.py index 57b372278..480cbd53d 100644 --- a/spu/utils/distributed.py +++ b/spu/utils/distributed.py @@ -532,6 +532,10 @@ def __repr__(self): def builtin_spu_init( server, name: str, my_rank: int, addrs: List[str], spu_config_str: str ): + global logger + processNameFix = {'processNameCorrected': multiprocess.current_process().name} + logger = logging.LoggerAdapter(logger, processNameFix) + if f"{name}-rt" in server._locals: logger.info(f"spu-runtime ({name}) already exist, reuse it") return diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py index 68ea0179d..cd5db90f1 100644 --- a/spu/utils/frontend.py +++ b/spu/utils/frontend.py @@ -305,9 +305,11 @@ def torch_compile( args_params_flat.append(state_dict_list[state_dict_idx[loc.name]]) elif loc.type_ == VariableType.INPUT_ARG: args_params_flat.append(args_flat[loc.position]) + elif loc.type_ == VariableType.CONSTANT: + args_params_flat.append(shlo._bundle.additional_constants[loc.position]) else: raise RuntimeError( - 'Currently only torch models with named parameters and buffers are supported' + f'Currently only torch models with named parameters and buffers are supported. Type {loc.type_} is not supported.' ) input_names = [f'{id(name)}-in{idx}' for idx in range(len(args_params_flat))]
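
For reference, the config-driven flow exercised by the new tests reduces to three steps: build a PsiConfig, create a brpc link for each rank, and call psi.psi. Below is a minimal two-process sketch, assuming placeholder CSV paths and reusing create_link_desc from spu/tests/utils.py; the same pattern applies to psi.ub_psi and psi.pir with UbPsiConfig and PirConfig.

# Minimal sketch: 2-party ECDH PSI with the config-based API.
# CSV paths are placeholders; the config mirrors spu/tests/psi_test.py.
import multiprocess
from google.protobuf import json_format

import spu.libspu.link as link
import spu.psi as psi
from spu.tests.utils import create_link_desc


def make_config(role: str, input_path: str, output_path: str) -> psi.PsiConfig:
    # Same schema as the JSON configs in the tests, built as a plain dict.
    cfg = {
        "protocol_config": {
            "protocol": "PROTOCOL_ECDH",
            "ecdh_config": {"curve": "CURVE_25519"},
            "role": role,
            "broadcast_result": True,
        },
        "input_config": {"type": "IO_TYPE_FILE_CSV", "path": input_path},
        "output_config": {"type": "IO_TYPE_FILE_CSV", "path": output_path},
        "keys": ["id"],
    }
    return json_format.ParseDict(cfg, psi.PsiConfig())


def run_rank(rank: int, link_desc, configs):
    # Each party opens its brpc endpoint and runs PSI with its own config.
    link_ctx = link.create_brpc(link_desc, rank)
    psi.psi(configs[rank], link_ctx)


if __name__ == "__main__":
    configs = [
        make_config("ROLE_RECEIVER", "receiver.csv", "receiver_out.csv"),
        make_config("ROLE_SENDER", "sender.csv", "sender_out.csv"),
    ]
    link_desc = create_link_desc(2)
    jobs = [
        multiprocess.Process(target=run_rank, args=(rank, link_desc, configs))
        for rank in range(2)
    ]
    [job.start() for job in jobs]
    [job.join() for job in jobs]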