diff --git a/bazel/patches/seal.patch b/bazel/patches/seal.patch deleted file mode 100644 index 20467095..00000000 --- a/bazel/patches/seal.patch +++ /dev/null @@ -1,262 +0,0 @@ -diff --git a/native/src/seal/serializable.h b/native/src/seal/serializable.h -index a940190..e490b30 100644 ---- a/native/src/seal/serializable.h -+++ b/native/src/seal/serializable.h -@@ -135,6 +135,9 @@ namespace seal - return obj_.save(out, size, compr_mode); - } - -+ const T& obj() const { return obj_; } -+ -+ T& obj() { return obj_; } - private: - Serializable(T &&obj) : obj_(std::move(obj)) - {} - -diff --git a/native/src/seal/context.cpp b/native/src/seal/context.cpp -index 887a1312..932d9774 100644 ---- a/native/src/seal/context.cpp -+++ b/native/src/seal/context.cpp -@@ -477,7 +477,8 @@ namespace seal - // more than one modulus in coeff_modulus. This is equivalent to expanding - // the chain by one step. Otherwise, we set first_parms_id_ to equal - // key_parms_id_. -- if (!context_data_map_.at(key_parms_id_)->qualifiers_.parameters_set() || parms.coeff_modulus().size() == 1) -+ if (!context_data_map_.at(key_parms_id_)->qualifiers_.parameters_set() || parms.coeff_modulus().size() == 1 || -+ !parms.use_special_prime()) - { - first_parms_id_ = key_parms_id_; - } -diff --git a/native/src/seal/encryptionparams.cpp b/native/src/seal/encryptionparams.cpp -index 31e07441..c34d0a45 100644 ---- a/native/src/seal/encryptionparams.cpp -+++ b/native/src/seal/encryptionparams.cpp -@@ -23,8 +23,10 @@ namespace seal - uint64_t poly_modulus_degree64 = static_cast(poly_modulus_degree_); - uint64_t coeff_modulus_size64 = static_cast(coeff_modulus_.size()); - uint8_t scheme = static_cast(scheme_); -+ uint8_t use_special_prime = static_cast(use_special_prime); - - stream.write(reinterpret_cast(&scheme), sizeof(uint8_t)); -+ stream.write(reinterpret_cast(&use_special_prime), sizeof(uint8_t)); - stream.write(reinterpret_cast(&poly_modulus_degree64), sizeof(uint64_t)); - stream.write(reinterpret_cast(&coeff_modulus_size64), sizeof(uint64_t)); - for (const auto &mod : coeff_modulus_) -@@ -63,6 +65,10 @@ namespace seal - // This constructor will throw if scheme is invalid - EncryptionParameters parms(scheme); - -+ uint8_t use_special_prime; -+ stream.read(reinterpret_cast(&use_special_prime), sizeof(uint8_t)); -+ parms.set_use_special_prime(use_special_prime); -+ - // Read the poly_modulus_degree - uint64_t poly_modulus_degree64 = 0; - stream.read(reinterpret_cast(&poly_modulus_degree64), sizeof(uint64_t)); -@@ -128,7 +134,8 @@ namespace seal - size_t total_uint64_count = add_safe( - size_t(1), // scheme - size_t(1), // poly_modulus_degree -- coeff_modulus_size, plain_modulus_.uint64_count()); -+ size_t(1), // use_special_prime -+ coeff_modulus_size); - - auto param_data(allocate_uint(total_uint64_count, pool_)); - uint64_t *param_data_ptr = param_data.get(); -@@ -139,13 +146,15 @@ namespace seal - // Write the poly_modulus_degree. Note that it will always be positive. - *param_data_ptr++ = static_cast(poly_modulus_degree_); - -+ *param_data_ptr++ = static_cast(use_special_prime_); - for (const auto &mod : coeff_modulus_) - { - *param_data_ptr++ = mod.value(); - } - -- set_uint(plain_modulus_.data(), plain_modulus_.uint64_count(), param_data_ptr); -- param_data_ptr += plain_modulus_.uint64_count(); -+ // NOTE(juhou): we skip the plain modulus for parms_id -+ // set_uint(plain_modulus_.data(), plain_modulus_.uint64_count(), param_data_ptr); -+ // param_data_ptr += plain_modulus_.uint64_count(); - - HashFunction::hash(param_data.get(), total_uint64_count, parms_id_); - -diff --git a/native/src/seal/encryptionparams.h b/native/src/seal/encryptionparams.h -index 9e1fbe48..eb71c4ac 100644 ---- a/native/src/seal/encryptionparams.h -+++ b/native/src/seal/encryptionparams.h -@@ -266,6 +266,11 @@ namespace seal - random_generator_ = std::move(random_generator); - } - -+ inline void set_use_special_prime(bool flag) -+ { -+ use_special_prime_ = flag; -+ } -+ - /** - Returns the encryption scheme type. - */ -@@ -274,6 +279,11 @@ namespace seal - return scheme_; - } - -+ bool use_special_prime() const noexcept -+ { -+ return use_special_prime_; -+ } -+ - /** - Returns the degree of the polynomial modulus parameter. - */ -@@ -501,6 +511,8 @@ namespace seal - - Modulus plain_modulus_{}; - -+ bool use_special_prime_ = true; -+ - parms_id_type parms_id_ = parms_id_zero; - }; - } // namespace seal - -diff --git a/native/src/seal/evaluator.cpp b/native/src/seal/evaluator.cpp -index dabd3bab..afaa71dc 100644 ---- a/native/src/seal/evaluator.cpp -+++ b/native/src/seal/evaluator.cpp -@@ -2382,6 +2382,7 @@ namespace seal - size_t encrypted_size = encrypted.size(); - // Use key_context_data where permutation tables exist since previous runs. - auto galois_tool = context_.key_context_data()->galois_tool(); -+ bool is_ntt_form = encrypted.is_ntt_form(); - - // Size check - if (!product_fits_in(coeff_count, coeff_modulus_size)) -@@ -2412,7 +2413,7 @@ namespace seal - // DO NOT CHANGE EXECUTION ORDER OF FOLLOWING SECTION - // BEGIN: Apply Galois for each ciphertext - // Execution order is sensitive, since apply_galois is not inplace! -- if (parms.scheme() == scheme_type::bfv) -+ if (not is_ntt_form) - { - // !!! DO NOT CHANGE EXECUTION ORDER!!! - -@@ -2426,7 +2427,7 @@ namespace seal - // Next transform encrypted.data(1) - galois_tool->apply_galois(encrypted_iter[1], coeff_modulus_size, galois_elt, coeff_modulus, temp); - } -- else if (parms.scheme() == scheme_type::ckks || parms.scheme() == scheme_type::bgv) -+ else - { - // !!! DO NOT CHANGE EXECUTION ORDER!!! - -@@ -2440,10 +2441,6 @@ namespace seal - // Next transform encrypted.data(1) - galois_tool->apply_galois_ntt(encrypted_iter[1], coeff_modulus_size, galois_elt, temp); - } -- else -- { -- throw logic_error("scheme not implemented"); -- } - - // Wipe encrypted.data(1) - set_zero_poly(coeff_count, coeff_modulus_size, encrypted.data(1)); -@@ -2530,6 +2527,7 @@ namespace seal - auto &key_context_data = *context_.key_context_data(); - auto &key_parms = key_context_data.parms(); - auto scheme = parms.scheme(); -+ bool is_ntt_form = encrypted.is_ntt_form(); - - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted)) -@@ -2559,14 +2557,6 @@ namespace seal - { - throw invalid_argument("pool is uninitialized"); - } -- if (scheme == scheme_type::bfv && encrypted.is_ntt_form()) -- { -- throw invalid_argument("BFV encrypted cannot be in NTT form"); -- } -- if (scheme == scheme_type::ckks && !encrypted.is_ntt_form()) -- { -- throw invalid_argument("CKKS encrypted must be in NTT form"); -- } - if (scheme == scheme_type::bgv && !encrypted.is_ntt_form()) - { - throw invalid_argument("BGV encrypted must be in NTT form"); -@@ -2605,7 +2595,7 @@ namespace seal - set_uint(target_iter, decomp_modulus_size * coeff_count, t_target); - - // In CKKS or BGV, t_target is in NTT form; switch back to normal form -- if (scheme == scheme_type::ckks || scheme == scheme_type::bgv) -+ if (is_ntt_form) - { - inverse_ntt_negacyclic_harvey(t_target, decomp_modulus_size, key_ntt_tables); - } -@@ -2632,7 +2622,7 @@ namespace seal - ConstCoeffIter t_operand; - - // RNS-NTT form exists in input -- if ((scheme == scheme_type::ckks || scheme == scheme_type::bgv) && (I == J)) -+ if (is_ntt_form && (I == J)) - { - t_operand = target_iter[J]; - } -@@ -2789,7 +2779,7 @@ namespace seal - SEAL_ITERATE(t_ntt, coeff_count, [fix](auto &K) { K += fix; }); - - uint64_t qi_lazy = qi << 1; // some multiples of qi -- if (scheme == scheme_type::ckks) -+ if (is_ntt_form) - { - // This ntt_negacyclic_harvey_lazy results in [0, 4*qi). - ntt_negacyclic_harvey_lazy(t_ntt, get<2>(J)); -@@ -2802,7 +2792,7 @@ namespace seal - qi_lazy = qi << 2; - #endif - } -- else if (scheme == scheme_type::bfv) -+ else - { - inverse_ntt_negacyclic_harvey_lazy(get<0, 1>(J), get<2>(J)); - } - -diff --git a/CMakeLists.txt b/CMakeLists.txt -index 1a7a2bfd..bc4ad9d9 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -223,7 +223,7 @@ if(SEAL_USE_INTEL_HEXL) - message(STATUS "Intel HEXL: download ...") - seal_fetch_thirdparty_content(ExternalIntelHEXL) - else() -- find_package(HEXL 1.2.4) -+ find_package(HEXL 1.2.5) - if (NOT TARGET HEXL::hexl) - message(FATAL_ERROR "Intel HEXL: not found") - endif() - -diff --git a/native/src/seal/evaluator.h b/native/src/seal/evaluator.h -index 33bc3c7d..8a00ebea 100644 ---- a/native/src/seal/evaluator.h -+++ b/native/src/seal/evaluator.h -@@ -1199,6 +1199,10 @@ namespace seal - */ - struct EvaluatorPrivateHelper; - -+ void switch_key_inplace( -+ Ciphertext &encrypted, util::ConstRNSIter target_iter, const KSwitchKeys &kswitch_keys, -+ std::size_t key_index, MemoryPoolHandle pool = MemoryManager::GetPool()) const; -+ - private: - Evaluator(const Evaluator ©) = delete; - -@@ -1257,10 +1261,6 @@ namespace seal - apply_galois_inplace(encrypted, galois_tool->get_elt_from_step(0), galois_keys, std::move(pool)); - } - -- void switch_key_inplace( -- Ciphertext &encrypted, util::ConstRNSIter target_iter, const KSwitchKeys &kswitch_keys, -- std::size_t key_index, MemoryPoolHandle pool = MemoryManager::GetPool()) const; -- - void multiply_plain_normal(Ciphertext &encrypted, const Plaintext &plain, MemoryPoolHandle pool) const; - - void multiply_plain_ntt(Ciphertext &encrypted_ntt, const Plaintext &plain_ntt) const; diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 1318e460..7223f19e 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -29,7 +29,6 @@ def spu_deps(): _com_github_emptoolkit_emp_tool() _com_github_emptoolkit_emp_ot() _com_github_facebook_zstd() - _com_github_microsoft_seal() _com_github_eigenteam_eigen() _com_github_nvidia_cutlass() _yacl() @@ -40,10 +39,10 @@ def _yacl(): http_archive, name = "yacl", urls = [ - "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b3_nightly_20240722.tar.gz", + "https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b4_nightly_20240731.tar.gz", ], - strip_prefix = "yacl-0.4.5b3_nightly_20240722", - sha256 = "ccca599e6ded6089c5afbb87c8f5e09383195af256caacd50089f0c7443e8604", + strip_prefix = "yacl-0.4.5b4_nightly_20240731", + sha256 = "952715bd56f6d9386984e9963426a1399bd2bd3702cf3efede9c82591cfab99b", ) def _libpsi(): @@ -51,10 +50,10 @@ def _libpsi(): http_archive, name = "psi", urls = [ - "https://github.com/secretflow/psi/archive/refs/tags/v0.4.1.dev240722.tar.gz", + "https://github.com/secretflow/psi/archive/refs/tags/v0.4.0.dev240801.tar.gz", ], - strip_prefix = "psi-0.4.1.dev240722", - sha256 = "878cd8af2c7b9850944a27adf91f21dd4937d09d38e8365baad3b5165db8b39a", + strip_prefix = "psi-0.4.0.dev240801", + sha256 = "541ad74de0cd9e6bffe348c3bc97e659fccb1f1811e612f9d8e6b1debdd7c2a0", ) def _rules_proto_grpc(): @@ -225,21 +224,6 @@ def _com_github_emptoolkit_emp_ot(): build_file = "@spulib//bazel:emp-ot.BUILD", ) -def _com_github_microsoft_seal(): - maybe( - http_archive, - name = "com_github_microsoft_seal", - sha256 = "acc2a1a127a85d1e1ffcca3ffd148f736e665df6d6b072df0e42fff64795a13c", - strip_prefix = "SEAL-4.1.2", - type = "tar.gz", - patch_args = ["-p1"], - patches = ["@spulib//bazel:patches/seal.patch"], - urls = [ - "https://github.com/microsoft/SEAL/archive/refs/tags/v4.1.2.tar.gz", - ], - build_file = "@spulib//bazel:seal.BUILD", - ) - def _com_github_eigenteam_eigen(): EIGEN_COMMIT = "66e8f38891841bf88ee976a316c0c78a52f0cee5" EIGEN_SHA256 = "01fcd68409c038bbcfd16394274c2bf71e2bb6dda89a2319e23fc59a2da17210" diff --git a/experimental/squirrel/README.md b/experimental/squirrel/README.md index affce442..6d2eb815 100644 --- a/experimental/squirrel/README.md +++ b/experimental/squirrel/README.md @@ -37,7 +37,8 @@ Code under this folder is purely for research demonstration and it's **NOT desig ```sh bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=85 --rank1_nfeatures=85 --standalone=true --train=BinaryClassification_Aps_Test_60000_171.csv --test=BinaryClassification_Aps_Test_16000_171.csv --rank=1 --has_label=1 --lr=1.0 --subsample=0.8 ``` -* Run on distributed dataset, e.g., using the `breast_cancer` dataset from the SPU repo. + +* Run on distributed dataset, e.g., using the `breast_cancer` dataset from the SPU repo. * On one terminal ```sh @@ -49,4 +50,3 @@ Code under this folder is purely for research demonstration and it's **NOT desig ```sh bazel-bin/experimental/squirrel/squirrel_demo_main --rank0_nfeatures=15 --rank1_nfeatures=15 --standalone=false --train=examples/data/breast_cancer_b.csv --rank=1 --has_label=1 --lr=1.0 --subsample=0.8 ``` - diff --git a/experimental/squirrel/bin_matvec_prot_test.cc b/experimental/squirrel/bin_matvec_prot_test.cc index 9ef145ad..62c07ead 100644 --- a/experimental/squirrel/bin_matvec_prot_test.cc +++ b/experimental/squirrel/bin_matvec_prot_test.cc @@ -205,7 +205,7 @@ TEST_P(BinMatVecProtTest, EmptyMat) { }); NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using sT = std::make_signed::type; NdArrayView _vec(vec); auto expected = BinAccumuate(_vec, mat); @@ -255,7 +255,7 @@ TEST_P(BinMatVecProtTest, WithEmptyIndicator) { }); NdArrayRef reveal = ring_add(out_shr[0], out_shr[1]); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, [&]() { using sT = std::make_signed::type; NdArrayView _vec(vec); NdArrayView got(reveal); diff --git a/sml/linear_model/tests/glm_test.py b/sml/linear_model/tests/glm_test.py index a9e7c734..f42b09b6 100644 --- a/sml/linear_model/tests/glm_test.py +++ b/sml/linear_model/tests/glm_test.py @@ -106,7 +106,7 @@ def accuracy_test(model, std_model, y, coef, num=5): assert norm_diff < 1e-2 -def proc_test(proc): +def proc_test(proc, x, y): """ Test if the results of the specified fitting algorithm are correct. @@ -121,8 +121,8 @@ def proc_test(proc): """ # Run the simulation and get the results - sim_res = spsim.sim_jax(sim, proc)() - res = proc() + sim_res = spsim.sim_jax(sim, proc)(x, y) + res = proc(x, y) # Calculate the difference between simulation and actual results norm_diff = jnp.linalg.norm(sim_res - res) @@ -130,10 +130,10 @@ def proc_test(proc): print(proc.__name__, "-norm_diff:", "%.5f" % norm_diff) # Assert that the difference is within the tolerance - assert norm_diff < 1e-4 + assert norm_diff < 5e-1 -def proc_ncSolver(): +def proc_ncSolver(X, y): """ Fit Generalized Linear Regression model using Newton-Cholesky algorithm and return the model coefficients. @@ -163,7 +163,7 @@ def proc_lbfgsSolver(): return model.coef_ -def proc_Poisson(): +def proc_Poisson(X, round_exp_y): """ Fit Generalized Linear Regression model using PoissonRegressor and return the model coefficients. @@ -178,7 +178,7 @@ def proc_Poisson(): return model.coef_ -def proc_Gamma(): +def proc_Gamma(X, exp_y): """ Fit Generalized Linear Regression model using GammaRegressor and return the model coefficients. @@ -193,7 +193,7 @@ def proc_Gamma(): return model.coef_ -def proc_Tweedie(): +def proc_Tweedie(X, exp_y): """ Fit Generalized Linear Regression model using TweedieRegressor and return the model coefficients. @@ -239,22 +239,22 @@ def test_Tweedie_accuracy(self, power=1.5): def test_ncSolver_encrypted(self): # Test if the results of the Newton-Cholesky solver are correct after encryption - proc_test(proc_ncSolver) + proc_test(proc_ncSolver, X, y) print('test_ncSolver_encrypted: OK') def test_Poisson_encrypted(self): # Test if the results of the PoissonRegressor model are correct after encryption - proc_test(proc_Poisson) + proc_test(proc_Poisson, X, round_exp_y) print('test_Poisson_encrypted: OK') def test_gamma_encrypted(self): # Test if the results of the GammaRegressor model are correct after encryption - proc_test(proc_Gamma) + proc_test(proc_Gamma, X, exp_y) print('test_gamma_encrypted: OK') def test_Tweedie_encrypted(self): # Test if the results of the TweedieRegressor model are correct after encryption - proc_test(proc_Tweedie) + proc_test(proc_Tweedie, X, exp_y) print('test_Tweedie_encrypted: OK') diff --git a/spu/libpsi.cc b/spu/libpsi.cc index fec0f233..0beb06eb 100644 --- a/spu/libpsi.cc +++ b/spu/libpsi.cc @@ -104,16 +104,30 @@ void BindLibs(py::module& m) { "Run UB PSI with v2 API.", NO_GIL); m.def( - "pir", + "apsi_send", [](const std::string& config_pb, const std::shared_ptr& lctx) -> py::bytes { - psi::PirConfig config; + psi::ApsiSenderConfig config; YACL_ENFORCE(config.ParseFromString(config_pb)); auto r = psi::RunPir(config, lctx); return r.SerializeAsString(); }, - py::arg("pir_config"), py::arg("link_context") = nullptr, "Run PIR."); + py::arg("pir_config"), py::arg("link_context") = nullptr, + "Run APSI sender operations."); + + m.def( + "apsi_receive", + [](const std::string& config_pb, + const std::shared_ptr& lctx) -> py::bytes { + psi::ApsiReceiverConfig config; + YACL_ENFORCE(config.ParseFromString(config_pb)); + + auto r = psi::RunPir(config, lctx); + return r.SerializeAsString(); + }, + py::arg("pir_config"), py::arg("link_context") = nullptr, + "Run APSI receiver operations."); } PYBIND11_MODULE(libpsi, m) { diff --git a/spu/psi.py b/spu/psi.py index 2796ab52..581bd3d3 100644 --- a/spu/psi.py +++ b/spu/psi.py @@ -19,7 +19,11 @@ from . import libpsi # type: ignore from .libpsi.libs import ProgressData from .libspu.link import Context # type: ignore -from .pir_pb2 import PirConfig, PirProtocol, PirResultReport # type: ignore +from .pir_pb2 import ( # type: ignore + ApsiReceiverConfig, + ApsiSenderConfig, + PirResultReport, +) from .psi_pb2 import ( # type: ignore BucketPsiConfig, CurveType, @@ -144,8 +148,16 @@ def ub_psi( return report -def pir(config: PirProtocol, link: Context = None) -> PirResultReport: - report_str = libpsi.libs.pir(config.SerializeToString(), link) +def apsi_send(config: ApsiSenderConfig, link: Context = None) -> PirResultReport: + report_str = libpsi.libs.apsi_send(config.SerializeToString(), link) + + report = PirResultReport() + report.ParseFromString(report_str) + return report + + +def apsi_receive(config: ApsiReceiverConfig, link: Context = None) -> PirResultReport: + report_str = libpsi.libs.apsi_receive(config.SerializeToString(), link) report = PirResultReport() report.ParseFromString(report_str) diff --git a/spu/tests/data/100K-1-16.json b/spu/tests/data/100K-1-16.json new file mode 100644 index 00000000..6ebf1819 --- /dev/null +++ b/spu/tests/data/100K-1-16.json @@ -0,0 +1,19 @@ +{ + "table_params": { + "hash_func_count": 1, + "table_size": 409, + "max_items_per_bin": 42 + }, + "item_params": { + "felts_per_item": 5 + }, + "query_params": { + "ps_low_degree": 0, + "query_powers": [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42 ] + }, + "seal_params": { + "plain_modulus": 65537, + "poly_modulus_degree": 2048, + "coeff_modulus_bits": [ 48 ] + } +} diff --git a/spu/tests/data/BUILD.bazel b/spu/tests/data/BUILD.bazel index 1eadb581..c5e9d338 100644 --- a/spu/tests/data/BUILD.bazel +++ b/spu/tests/data/BUILD.bazel @@ -17,8 +17,12 @@ package(default_visibility = ["//visibility:public"]) filegroup( name = "data", data = [ + "100K-1-16.json", "alice.csv", "bob.csv", "carol.csv", + "db.csv", + "ground_truth.csv", + "query.csv", ], ) diff --git a/spu/tests/data/db.csv b/spu/tests/data/db.csv new file mode 100644 index 00000000..207e86e7 --- /dev/null +++ b/spu/tests/data/db.csv @@ -0,0 +1,100 @@ +aPYaKgvvcESwAtfghnRUIAYYIZeCsGeaWbAAEUCXQzNrxGRPVOcACqMBJmdfiveq,LdQNbKmBMhlpCctB +eXeLfYoSlEntpRaKuddYhtaImMOdEhNTIolxElSrlPYMhgZoWxccpUTOjFciaRcD,wcsaUIrVHmxDxnXr +gOUgSoWKKYaLpmWiuTyuVNostvTJUHxBZjYWJZukOzqICDlmKdavyERZimgOkaHn,FgdZUDbsfINsODTe +BrImQzkUpelejHVFeatNDJTqdTuxCJmHKunkODDPsUtsgEfCkXyLNEAZXniJZgvf,PkxxKRwEEEwrcLbR +RhoATcXotsJDbAICIkIbAYdDUXqreTDbZVPoSyziyQJvBVVaZhYqWmHVQsqNbdIP,OTFagQCudOpRtKqT +bibVQsoLIFjXrSFIjPFOTWclPDhakarsWICsYNvbMtlNwBIKqnaWmvCPcFPesBVy,YysQfuPwTiXayxlt +CQZcsLKmugDFfsYAOihJENLpBwCYwoLuATHrMiIChXkycvhejgjjUCoxKDpTDMuZ,OIGhXryCxSXCZKBn +aizidqjurwmpHHqioJclEQvccvnWQKmZTMBNSGgmfxxqHilZwFPdjHlZElQaayJN,SjCyivgzSTLpXmnz +eNGJWxNcevvcthSXdXLseTORLpDVdVCAdnbQtcovUVAislbievgctCGMidsncXyL,rFJEAlFBMUMZNPrc +OQArDeXFlhxHBPERmlYuwZHQAgCQtRorzCkDenmMdnbFdGKJggqmhApWMrbtqDZw,saeWaHguQALRalnZ +AZKVktpAYSLSZTpjylrifvdItAuhXvzvDhtWnbjMlshYFTooHXtwNfCsmiuWgXLa,sntRGrWUapWvBwAM +xNTfGmPFpdczqsPvlLFMSzXLdqGkyupfwXepkfJHECoAIDZaRwqoIullNzMfWHLK,jVerCzWaKQoYYBeP +hKFRJSHErdbPbPFRZpEkigOemyqRgRDzWHMhiVDlHDtTAjKHZcQCogTvFlbwVVbA,ftLFWLGNrSkavyEX +tGYkoEeaQinZmEsPtFyhyJxXGiJreOCTHTxPfSWFFQwqChLKADXjeWiTsVnLyFTb,vKUHgdGvKkJKSAcJ +KSyQGEXCvzziOdoaIwQYynSjQlyPOkJYroZAxblcGRHOqhjWuByOVIUYfuacruUQ,NnsQMUvXYZOhswVL +WsQkdjbaExqvNkQlbMWqwxtChiWRBsmylFanbONarkbmmWAnKPvCdPLlhQtvFaXP,phEafGnFQIcoYCKG +rESDWnSxXeTvwvTDxSsSmwPswujcWVlrvlgVAuGUYjvyXFjCSGenqEZhfFHgZXIB,sHxPvGwZWDbJgCYn +UKgGYBiPUdYdfuQZrpKEtprGxofPaLVqhdiuSUzKrSDzXuCgYuOQJlyhsNHyLrCx,sdNdnkjuRNLyPhaB +eEXLhpsjwQQRFVXxPgqQnFMCQDAMOMxwoAeeubrJXwxKaiWgnilOwjzoEZuRUJLV,WXVgFQQFBEwRomjA +QVGpdCqJIpZStCSUOEKyEoOJJurZSWmQZDCnIrANHGJYhpbfxAhsPvrVZnVrhQKn,bUIqpQPVvtiPhItz +UOmpkfibQXxQlYJQzZgdfoIHckIWHWiVBcoLaLSQnlpnIBQZcrnCEXCfHTDwsFDX,OdckqFMcdfnVnBoB +TDzDtfzzMugZZxhNGhmwYsMrOCFvCUWmmOUOLNGAYMRMnZVGuOMSXZZgaTufrqXK,DAToiYOldpgNGOqn +zmRKIEQFtIjCYUXaFgyAvZZEDIukHAwYlzUwxbttWndcAGFEoRzGyAUuLsKnbfZi,rfwdiePuXVtvKgat +vmtpIcBkPJyFRqKWIYHWcecKdgCoUShJwkhYvjHZPdhwmcdBGwQDDVynyOwSZcYj,PDNsnMKRZubVpMRT +EjQMHoawzxMREpZaJFKJBNsnKdzQTWeGmAMkhsuSfEzoDpQfdUWUeTFKvKClRNPz,rauheCdowFiOAMFk +mjjTWkjovIsCsMuZfdtXIKVZEcwuspLRUtCVPKpMdkkaGQtUUmFrXaZHaDuPKvsa,wCscOftxAHuBnsSW +MjvTMaePwIVFpEbspToomYGFAOmpGuKlmgJvIOhtVoHNgWaHReuMELUapHWAaZjL,MeizqMvAktGZLkCH +EwJZCBgPDuKRnTTZwuJRKfkznXpHGdbfMOZTnVjixKGciMLkdLSzWBXkBhMGzwSS,RhpjUzFsJtSSXund +HSaBXSGcBxYSUIXYnlFnYrdTclIehDdMhKqIRJuAYebfViJttknfMmCqbyYOJAXE,NIZPwgQebsKBehaN +fLOyQDLsIUaWZUjwzsrxlGHlGTYNWVZyTEWJZenWqZiMqHEpLWAvGojmOQvteOqS,XGTuKgtLshqQUtfr +wvKVPbYksmYXTsRqvJETrjXJethrvgmBLIwMQhJBCMTfLGOFKHwxrrBcGQqdMjZe,ICnEQAovJhrWaIiY +WpbdxcjMqKMkLdSlUBowCTWDGVtRJLiEDQytMenWEIkWFWLKByiEhrvIpCncUQDS,xEGgXHFKHYDlGdYF +AVSBpQelmdheyUZPdmRrhrEqHmKowFAIDNjxzphVCoLgSypBfHNtVuDgoIVqCoLF,ZLFAfNImdiEwcupl +XtndzLknyWTeElGFXjZfbrHGqYzqHcTzEtXquKkpckuwhkQPcCkmXIhfCnLYCrVG,UHbNwRmAJFMalnbt +lvLHBgacmSdZJqpzrezjTYTfWIBFUDIaMcGwErtmnAwgjDXwmIHxMDqYTrJvjUyq,NoWnhJQWDJVCtExB +SKenZwjdFvsAiARRmpBzTAXGWtByjJcIniiAhovlsAHLZXQJCmDyRJxKevZBttDa,mHAfCtCExIyoRuWG +vKgaRAkJVOSESTgmEVXpQIVXDADJiHmFFaAwxtjwUyFVrQouyJcZeDwhMUZPROkA,CiuvQKMCsVGLkoBM +cdxMDaSfvmlcpSfFqvfzgNIyUcmkDEyVswXcCKCJfYyAqrSCGWBGIEQlBxKTWSCj,LTIsYYovMEnpawzQ +bnfGLKRBqwOCQBNWSRAbEWVLqyUAlzrYNiJRWUZnmGnXtjMFBQhLHBhVJBygIrGU,vLYmsxCGzRXOMiqB +bxidiEYIOjbFDbHRqnaYXZuQcZJbNxynsmjNPCpzujEKzATaBeTrUchoylhvqLjx,AXlyMrMWiWGMqoIs +eSuMvTbZgMRDrEIxJwgFYdpNWkQzEzrsyyybeaJlUPEhEZZBWpPwQFqImIGnLFar,BTfGxXWpxvvPzcDS +uONlUmHMFQiPkCfdPrqqUDleaUBHKnxuQbFJAqfMqzoSpqPzawdOIvtVQMSWfCRv,uFDJTeJjZNFjIxKT +IEWyZnggNMulCyYklMdZaMYiIsqQNtbzbcpMHBUfPeOKaoSCMeezBSqcQwVXNJho,zpDlpaygXXZyslHj +ItlLEykbpzIeTErDcbaxxfzntBAYcPHVcLFneOzhNhxYYgwsbKZEmHHuHTnPnhSW,nKPAjZgLwcVTMNGi +uTQQIgVItVwPUSygrxoUwrLuFbAbahqbnixUuRIKRnJConAViRHYsRerKyEieFYI,fhPHVglAfXtHjNae +etMRJJVQSxIMgvzdSoCPsVpGJcKtjpXtMqtzXgaGQDryplTJifNvOFYGWChHLOUo,WJrEYWYGMOqveCgd +BYHllipDRYZVvMQYYhIRzLHabgftPTSnFbUCRmZejFUoeLLQoZtrZPJrTjVqWfNO,EztmadvCnpbgQtBl +ZEkjAbQOpxQrbtEDVlDhKgChCNsxTxSQtUXARrEeVJQrzPuVPkYHuuoMXjVCyeCk,kcgMJGMwDDDiFDSP +WdhxSOyjKNLItJZiXZtkKBdIcjGLuLbZHPmSfJCzlvqBxnrobjDTPxsFRXhEInhh,OPuAWqKpmpZHbvIY +wSGlgcmRasbLYVClIhhCppYezjZrWIhhaiASQcDrCDxdsGJIJjNmTWwbsFuKlbAt,VToAPASiTJHUGIEk +CFzAfFrDFyhndtOJalJiSiufNlCjWcwxwQjnjRgbqFlaAlIzXgrwVJmwISEAKHxx,jpylGyqZUinEIiVo +BwluSThaftzdTOrrzjBfqmHdbXqLUDEPPqmduYZESLtaSQAWLOKeuECRKPDEumJA,vgocckgnbQZCgpSq +OUguorluoGFpgVuXujAFBOIkBsIAaNCXgcywsWjutvEcrJrrDBRAHgcKwnfNLpXr,vwsRUvSJpupXvGta +TXQWeuibBvqDmmTaLAPZNsEHosjhcLBsixvomJaiAPmLmBDemETNOZMwrwREVRir,EIgmXjWGNYpcfdOs +BbQiYWirQnKUzBachIeJZgJWvmeZkpBUphzYEGrGcGpxUwWvSSnSYpQzBRsOFuzC,oUBiVRJNarkuUdRV +dugMhlZPXHEnWTZjaOFaxmmZgIHdVBmzUsIfUUPZkdVzDvCLRZyTBPrHhDVrYAVk,OdxERipabMlyoyEa +aKPvrmNwyclBMcPMEAgGItoshSJVrSonWVrcHMWiBXxqTpdHjuGKrRLTaHhQyunk,TmBLACtwsObAkgoz +LSndhzUuoIbprGCzRDfZryKqdcqwLpLWYfHoOHgBJDSkZRoYMQxmIoVCUdBSxHsZ,qvhvMIqFvXELhWqK +lInwjXRJTuZgrvcbDEvFHgPFGpqlnSuIJWtNRJYizWEfZJbZtfLexmEQMnGLxNlW,MURWsZhFTeLoqeAd +DOlyDBVEZFruhsBwHZgrnWTXckcVJcVzrwniSnJFEYUNiFgkIyukstlbdrluVhag,vUuhogLghMFjNjyK +zgDxGOAAYMGtAOoMviwtSQDLEpHluVuqFsqisVvoKCLfnMdPVTKgKCrchKrAmlmz,YaojDSYxfntenXUp +JpnLfyRVbNPLIfbvPGuakcXCvxtoElcbACKRUfMSiKUemqyOVmvLspaZEPUtqJxv,eweOmajnCQLZOgBp +nJJyhHkDkQRttGsYjkMBBuGeQuPPDHQEQQdnGQmMbOdRFifRDpZVUqdfqeskxngR,nyikwrUYIuWbTawJ +ZcwBhgoFFiWyDeZbeMpliSvhbfsAkccEQQhYreLTGdfVHuNLpmCsduhkIlRKMNkx,nWeXVNOUmsoakAMC +BCJYdYFkkLRWUxhnnwpJGbXJchPvVCtbcaMdkArTcNLdRmopwncgdgOLGhJkZOnC,JWAVSCWqKUcmfoqk +NkxNbLyBjixoCzClTdwshhuZcFRjJJdDdWgCfiQIttZWQWqouBkYyMGpampLdUAr,nkuaUoDVcCUBUoOv +WgwvZoDsjspDAWHEflQMWzlbqnssWiBElmABhLmhgDPqFbNmAHSnzQrbAqSVAmWS,IQIlCVVcjVcTdshJ +yXelVXMUEuAtfNgzPrhjvYOpiAVEMZuqPfsQEUQoshjSIekxxzkFxdftfqFzfzpa,AyTbbGVAXKCYUUln +sAtWqpSPxPSkDtmIJKfNvlKjgStnYMOmrLsQnzmIFAusETPPzLDTjKcBASKWNRAJ,bDQFllamogAjBEPU +RAOMDMMZkezCBWxQDWLjvHLkbpvFyrbUbDDEekWciXejYwKifSVumcsocUmMkmpa,MnwKAECsMVGVLIZM +LMizeIoxMxHCKwikjqOSSPbuiqXWDAmbTLMBOXpyorUmpunjWFTNLVqvHNcCNHrN,zDzyAkfeTYZzaxhG +IzLqciYsaKtWrsrjOQldeIvEqavoIZEYnupJZLizJVeOhoLqtLQFaoNRdvZMWSQH,klLegifptLAxnhha +VkyQsnlXcIGjGhcSJUcZKeQyiDUbIcgIWSbHaEsbfEydSTHqRlxImGdYGEurZczg,jyzyCGlvuBdKwyIX +dkSRcFpHPXWjNIHrCpWlOPaIkVqjtyPRhlJeYMksjieDxYhiUcGhbuvamVlrMDDx,FATEzwNXGerINvHD +KravfnzVbNOhLhstPcaLLVWpbqYzXckQGbuALlEiXqqhFfUyThFZFhzSLhjldPMB,sowcmsLQTFiKpNXy +ITZWxxjqnizlxuRWMlPQLUnBopyDtOxNfcaoFDbRKetIpVxKLSRoJOauSCcDwWUP,NmqaomxmRQGqiKCV +guhzFIXdUEACMMHObJptkrZqJgbclDcdRxCPYSvuAdITaKgHfJaKLNHFzdRmpHni,DiONbLbxHfUHhoTU +AnZjaIBmnsEBplkEmHBstdggPnYmhblyQQttVqYzxxNtOXwlNQetkvCOySSXRUpw,FCdUtLDyvAqszerb +wkCcCXGKcZJEDwTkzOoNRkMbxHdNciQlVruGSKcJrHpokspcZIVfupcTxapISupH,pzEvnKzLQbzNDSQN +ppieExXmNHqBXVgLFhjlHHHhHSAddipMCmPXhXDfHZVTtNhqcMMVauyjKOFGBHPe,tXpiHGkKGTzzMluO +lnVknQNrrYyqFbEKYPxsQWNPKLpsEVmUGtbbWMWDThMuScSByeZRwuusLYzKPbHE,HzxgCCtiIFYvgwWO +oYSnlwjpsWaNzYunBnhNLwiICrmAEFiZRczbdHYpQgwSrrMQCixgtjfCGOptTkmd,IYmbhaQueIKQvcBc +saWvQlIiiYqAPmcEDGsVXNAIJNNGTyZKhrMMKYHXJQnniGVuIClgwvAEXeIPGeFN,epReKmWNANFpINhn +PZIEJMwirPArOGCfJJAfdwGydRDBGGQojUzWFJtVoJZTFAFYwaDOuLFruvRjolHq,yCMGUjViZoTPMTtg +bMCOdGAYjXaDSPyZyGegyuRnnwYSySrRzbLbvtgBjfFXfCMPIVIGFTagRyBpiKLa,zRTXXYnUXmkIMDoP +JQWdUrNElPWswpQvnVqCmboMEjMhebKISRcmznakzemGxBjUughzOVbctPzmVTLW,CdcnGSQhRbdsoQrg +OdAvNbDdCQpTrAbJWrUrVprpgVIXwvvSStooIVwzUfDIThtvdBHldyUFFkvabfyj,ueAqmPurXOjNtvWr +zjIkKNEzyUFiubvlxYWNXdjoIIEwZavalnqwSCgDgcZUldjZOkzhKXuRciwSTNJg,MWwWigLZKqLgZkLp +WTaCVgYrnoyEoShtBDUmrRHeRSYIAjvUpZnVAUTxTyaIGzvQIdwcPafAnkIbplSq,mdIEdIajBbeAyPCk +tNdAhhdqVzLdGfPoctgRkehzEOIRvjEwDpmAQrMjbWtfRQGjeUiVJNafrhVKFieX,IxHvvUzKrlMpWhpR +vhgPWqsRnvDRMFIYHppovDbKlWPzEFwbBXSihpYbwCYpkeXFIXbIYdWSLfcHpnWX,qhwFSRGRKwcPlLJs +ARxEUJokZaGDgXHGxwPiSqqvNSmoowUxRDDkozqbvcUvQuPtdNaeaKOKykMIUkmR,XqmCMQPKPtAPzBZd +RwztiezZCzbSLKzIyYqfEMjDTcLpASCiGWoaseuxBWpvSVutmtdEgdZornGkHrQf,YcRFZNgodJFPNoop +YTcHYrADMhlKAnvdGBdQBXWBqcftxkNpFceODelYVRXwFOZTHdXkVGAfJTzZcyhD,tBGtrQaLFgACGOEE +fHFCvDLRGGhYZWSnxaIqKTgvNbCPLzyvOnpHyAhrKEAsApdPgkxAptCTtgYAnmEq,vxGOPFzvJOVBEblg +zckpuLjSVdhSFnhTqPfDoHdJdjpfZBDdlzGbYgzVbKgDMJQDBGCHZSJBdtzlvHro,TeeGbXAcEbwzglGf +muAQTPuNCQTZurKTDlYzTQgvlWNyRXOlKizgsnGSrKdYWCSBlQtOvIyEWVthaYhO,ZnYBDVQYoJOoTMlS +UQswwuiprHWAbguGNZgOAdFrgEIdsDRImrqXXTmbqppVgnJrjjiOdZaNUpIQGcTR,VwugWpNMzEKHAFqo +GDRPaAUIAymOEEksSqccGOqpUYvGUyvBKjfRqKSTAyNadpaMYnMYboPOrEEfXVWf,noDbJmsjYCgqHsBu +cVjSBnCUnKfKXwETABIPvavwLXMGSLSpoVylUSCRlRCzpDvDVjfNAIrSiRWNHJZS,OszhlCboIvNdCTYH diff --git a/spu/tests/data/ground_truth.csv b/spu/tests/data/ground_truth.csv new file mode 100644 index 00000000..fa97d0fa --- /dev/null +++ b/spu/tests/data/ground_truth.csv @@ -0,0 +1 @@ +JpnLfyRVbNPLIfbvPGuakcXCvxtoElcbACKRUfMSiKUemqyOVmvLspaZEPUtqJxv,eweOmajnCQLZOgBp diff --git a/spu/tests/data/query.csv b/spu/tests/data/query.csv new file mode 100644 index 00000000..79dbd224 --- /dev/null +++ b/spu/tests/data/query.csv @@ -0,0 +1 @@ +JpnLfyRVbNPLIfbvPGuakcXCvxtoElcbACKRUfMSiKUemqyOVmvLspaZEPUtqJxv diff --git a/spu/tests/pir_test.py b/spu/tests/pir_test.py index 45574044..c75d6bbe 100644 --- a/spu/tests/pir_test.py +++ b/spu/tests/pir_test.py @@ -13,122 +13,88 @@ # limitations under the License. import json +import tempfile import unittest -from tempfile import TemporaryDirectory import multiprocess -from google.protobuf import json_format - import spu.libspu.link as link import spu.psi as psi -from spu.tests.utils import create_link_desc, wc_count +from google.protobuf import json_format +from spu.tests.utils import create_link_desc class UnitTests(unittest.TestCase): - def setUp(self) -> None: - self.tempdir_ = TemporaryDirectory() - return super().setUp() - - def tearDown(self) -> None: - self.tempdir_.cleanup() - return super().tearDown() def test_pir(self): - # setup stage - server_setup_config = f''' - {{ - "mode": "MODE_SERVER_SETUP", - "pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI", - "pir_server_config": {{ - "input_path": "spu/tests/data/alice.csv", - "setup_path": "{self.tempdir_.name}/spu_test_pir_pir_server_setup", - "key_columns": [ - "id" - ], - "label_columns": [ - "y" - ], - "label_max_len": 288, - "bucket_size": 1000000, - "apsi_server_config": {{ - "oprf_key_path": "{self.tempdir_.name}/spu_test_pir_server_secret_key.bin", - "num_per_query": 1, - "compressed": false - }} + with tempfile.TemporaryDirectory() as temp_dir: + # setup stage + sender_setup_config_json = f''' + {{ + "db_file": "spu/tests/data/db.csv", + "params_file": "spu/tests/data/100K-1-16.json", + "sdb_out_file": "{temp_dir}/sdb", + "save_db_only": true }} - }} - ''' - - with open( - f"{self.tempdir_.name}/spu_test_pir_server_secret_key.bin", 'wb' - ) as f: - f.write( - bytes.fromhex( - "000102030405060708090a0b0c0d0e0ff0e0d0c0b0a090807060504030201000" + ''' + + psi.apsi_send( + json_format.ParseDict( + json.loads(sender_setup_config_json), psi.ApsiSenderConfig() ) ) - psi.pir(json_format.ParseDict(json.loads(server_setup_config), psi.PirConfig())) - - server_online_config = f''' - {{ - "mode": "MODE_SERVER_ONLINE", - "pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI", - "pir_server_config": {{ - "setup_path": "{self.tempdir_.name}/spu_test_pir_pir_server_setup" + sender_online_config_json = f''' + {{ + "db_file": "{temp_dir}/sdb" }} - }} - ''' - - client_online_config = f''' - {{ - "mode": "MODE_CLIENT", - "pir_protocol": "PIR_PROTOCOL_KEYWORD_PIR_APSI", - "pir_client_config": {{ - "input_path": "{self.tempdir_.name}/spu_test_pir_pir_client.csv", - "key_columns": [ - "id" - ], - "output_path": "{self.tempdir_.name}/spu_test_pir_pir_output.csv" + ''' + + receiver_online_config_json = f''' + {{ + "query_file": "spu/tests/data/query.csv", + "output_file": "{temp_dir}/result.csv", + "params_file": "spu/tests/data/100K-1-16.json" }} - }} - ''' + ''' - pir_client_input_content = '''id -user808 -xxx -''' + sender_online_config = json_format.ParseDict( + json.loads(sender_online_config_json), psi.ApsiSenderConfig() + ) + + receiver_online_config = json_format.ParseDict( + json.loads(receiver_online_config_json), psi.ApsiReceiverConfig() + ) - with open(f"{self.tempdir_.name}/spu_test_pir_pir_client.csv", 'w') as f: - f.write(pir_client_input_content) + link_desc = create_link_desc(2) - configs = [ - json_format.ParseDict(json.loads(server_online_config), psi.PirConfig()), - json_format.ParseDict(json.loads(client_online_config), psi.PirConfig()), - ] + def sender_wrap(rank, link_desc, config): + link_ctx = link.create_brpc(link_desc, rank) + psi.apsi_send(config, link_ctx) - link_desc = create_link_desc(2) + def receiver_wrap(rank, link_desc, config): + link_ctx = link.create_brpc(link_desc, rank) + psi.apsi_receive(config, link_ctx) - def wrap(rank, link_desc, configs): - link_ctx = link.create_brpc(link_desc, rank) - psi.pir(configs[rank], link_ctx) + jobs = [ + multiprocess.Process( + target=sender_wrap, args=(0, link_desc, sender_online_config) + ), + multiprocess.Process( + target=receiver_wrap, args=(1, link_desc, receiver_online_config) + ), + ] - jobs = [ - multiprocess.Process( - target=wrap, - args=(rank, link_desc, configs), - ) - for rank in range(2) - ] - [job.start() for job in jobs] - for job in jobs: - job.join() - self.assertEqual(job.exitcode, 0) - - # including title, actual matched item cnt is 1. - self.assertEqual( - wc_count(f"{self.tempdir_.name}/spu_test_pir_pir_output.csv"), 2 - ) + [job.start() for job in jobs] + for job in jobs: + job.join() + self.assertEqual(job.exitcode, 0) + + import pandas as pd + + df1 = pd.read_csv(f'{temp_dir}/result.csv') + df2 = pd.read_csv('spu/tests/data/ground_truth.csv') + + self.assertTrue(df1.equals(df2)) if __name__ == '__main__':