diff --git a/ACKNOWLEDGEMENTS.md b/ACKNOWLEDGEMENTS.md index 3f9fb8b..4a8b44f 100644 --- a/ACKNOWLEDGEMENTS.md +++ b/ACKNOWLEDGEMENTS.md @@ -3,6 +3,14 @@ listed below. We are extremely grateful to the authors of these projects! +## ff + +[ff: Traits and utilities for working with finite fields](https://github.com/zkcrypto/ff) dual-licensed +under [MIT][ff-mit] and [Apache 2.0][ff-apache] licenses. + +[ff-mit]: https://github.com/zkcrypto/ff/blob/b853db2c05a5901a8199012f80f5ee3784f52549/LICENSE-MIT +[ff-apache]: https://github.com/zkcrypto/ff/blob/b853db2c05a5901a8199012f80f5ee3784f52549/LICENSE-APACHE + ## bellman [bellman: zk-SNARK library](https://github.com/zkcrypto/bellman) dual-licensed under diff --git a/crates/boojum/.github/ISSUE_TEMPLATE/bug_report.md b/crates/boojum/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 2d3e38a..0000000 --- a/crates/boojum/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,39 +0,0 @@ ---- -name: Bug report -about: Use this template for reporting issues -title: '' -labels: bug -assignees: '' ---- - -### 🐛 Bug Report - -#### 📝 Description - -Provide a clear and concise description of the bug. - -#### 🔄 Reproduction Steps - -Steps to reproduce the behaviour - -#### 🤔 Expected Behavior - -Describe what you expected to happen. - -#### 😯 Current Behavior - -Describe what actually happened. - -#### 🖥️ Environment - -Any relevant environment details. - -#### 📋 Additional Context - -Add any other context about the problem here. If applicable, add screenshots to help explain. - -#### 📎 Log Output - -``` -Paste any relevant log output here. -``` diff --git a/crates/boojum/.github/ISSUE_TEMPLATE/feature_request.md b/crates/boojum/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index d921e06..0000000 --- a/crates/boojum/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,21 +0,0 @@ ---- -name: Feature request -about: Use this template for requesting features -title: '' -labels: feat -assignees: '' ---- - -### 🌟 Feature Request - -#### 📝 Description - -Provide a clear and concise description of the feature you'd like to see. - -#### 🤔 Rationale - -Explain why this feature is important and how it benefits the project. - -#### 📋 Additional Context - -Add any other context or information about the feature request here. diff --git a/crates/boojum/.github/pull_request_template.md b/crates/boojum/.github/pull_request_template.md deleted file mode 100644 index 8ce206c..0000000 --- a/crates/boojum/.github/pull_request_template.md +++ /dev/null @@ -1,20 +0,0 @@ -# What ❔ - - - - - -## Why ❔ - - - - -## Checklist - - - - -- [ ] PR title corresponds to the body of PR (we generate changelog entries from PRs). -- [ ] Tests for the changes have been added / updated. -- [ ] Documentation comments have been added / updated. -- [ ] Code has been formatted via `zk fmt` and `zk lint`. diff --git a/crates/boojum/.github/workflows/cargo-license.yaml b/crates/boojum/.github/workflows/cargo-license.yaml deleted file mode 100644 index 189b471..0000000 --- a/crates/boojum/.github/workflows/cargo-license.yaml +++ /dev/null @@ -1,8 +0,0 @@ -name: Cargo license check -on: pull_request -jobs: - cargo-deny: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: EmbarkStudios/cargo-deny-action@v1 diff --git a/crates/boojum/.github/workflows/ci.yaml b/crates/boojum/.github/workflows/ci.yaml deleted file mode 100644 index a984dd9..0000000 --- a/crates/boojum/.github/workflows/ci.yaml +++ /dev/null @@ -1,30 +0,0 @@ -name: "Rust CI" -on: - pull_request: - -jobs: - build: - name: cargo build and test - strategy: - matrix: - # Needs big runners to run tests - # Only macos-13-xlarge is Apple Silicon, as per: - # https://docs.github.com/en/actions/using-github-hosted-runners/about-larger-runners/about-larger-runners#about-macos-larger-runners - os: [ubuntu-22.04-github-hosted-16core, macos-13-xlarge] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - uses: actions-rust-lang/setup-rust-toolchain@v1 - - run: cargo build --verbose - - run: cargo test --verbose --all - - formatting: - name: cargo fmt - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions-rust-lang/setup-rust-toolchain@v1 - with: - components: rustfmt - - name: Rustfmt Check - uses: actions-rust-lang/rustfmt@v1 diff --git a/crates/boojum/.github/workflows/secrets_scanner.yaml b/crates/boojum/.github/workflows/secrets_scanner.yaml deleted file mode 100644 index 54054cf..0000000 --- a/crates/boojum/.github/workflows/secrets_scanner.yaml +++ /dev/null @@ -1,17 +0,0 @@ -name: Leaked Secrets Scan -on: [pull_request] -jobs: - TruffleHog: - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@ac593985615ec2ede58e132d2e21d2b1cbd6127c # v3 - with: - fetch-depth: 0 - - name: TruffleHog OSS - uses: trufflesecurity/trufflehog@0c66d30c1f4075cee1aada2e1ab46dabb1b0071a - with: - path: ./ - base: ${{ github.event.repository.default_branch }} - head: HEAD - extra_args: --debug --only-verified diff --git a/crates/ff/.gitignore b/crates/ff/.gitignore new file mode 100644 index 0000000..4308d82 --- /dev/null +++ b/crates/ff/.gitignore @@ -0,0 +1,3 @@ +target/ +**/*.rs.bk +Cargo.lock diff --git a/crates/ff/Cargo.toml b/crates/ff/Cargo.toml new file mode 100644 index 0000000..8f59b50 --- /dev/null +++ b/crates/ff/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "ff_ce" +version = "0.14.3" +authors = ["Sean Bowe ", + "Alex Gluchowski ", + "Alex Vlasov "] +description = "Library for building and interfacing with finite fields" +documentation = "https://docs.rs/ff/" +homepage = "https://github.com/matter-labs/ff" +license = "MIT/Apache-2.0" +repository = "https://github.com/matter-labs/ff" +edition = "2018" +exclude = [ + "ff_derive_const", + "tester", + "asm_tester" +] + +[dependencies] +byteorder = "1" +rand = "0.4" +ff_derive_ce = { version = "0.11", optional = true } +# ff_derive_ce = { path = "ff_derive", optional = true } +hex = {version = "0.4"} +serde = "1" + +[features] +default = [] +derive = ["ff_derive_ce"] +asm_derive = ["derive", "ff_derive_ce/asm"] diff --git a/crates/ff/LICENSE-APACHE b/crates/ff/LICENSE-APACHE new file mode 100644 index 0000000..1e5006d --- /dev/null +++ b/crates/ff/LICENSE-APACHE @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + diff --git a/crates/ff/LICENSE-MIT b/crates/ff/LICENSE-MIT new file mode 100644 index 0000000..ed3a13f --- /dev/null +++ b/crates/ff/LICENSE-MIT @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017 Sean Bowe + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/crates/ff/README.md b/crates/ff/README.md new file mode 100644 index 0000000..ec8ce95 --- /dev/null +++ b/crates/ff/README.md @@ -0,0 +1,64 @@ +# "FF community edition" + +This library is community maintained fork of the original `ff` library by Sean Bowe. Name of the library is changed to allow publishing to the `crates.io` + +## Original ff + +`ff` is a finite field library written in pure Rust, with no `unsafe{}` code. + +## Disclaimers + +* This library does not provide constant-time guarantees. + +## Usage + +Add the `ff_ce` crate to your `Cargo.toml`: + +```toml +[dependencies] +ff_ce = "0.6" +``` + +The `ff_ce` crate contains `Field`, `PrimeField`, `PrimeFieldRepr` and `SqrtField` traits. See the **[documentation](https://docs.rs/ff/0.4.0/ff/)** for more. + +### #![derive(PrimeField)] + +If you need an implementation of a prime field, this library also provides a procedural macro that will expand into an efficient implementation of a prime field when supplied with the modulus. `PrimeFieldGenerator` must be an element of Fp of p-1 order, that is also quadratic nonresidue. + +First, enable the `derive` crate feature: + +```toml +[dependencies] +ff = { ..., features = ["derive"] } +``` + +And then use the macro like so: + +```rust +extern crate rand; +#[macro_use] +extern crate ff_ce; + +#[derive(PrimeField)] +#[PrimeFieldModulus = "52435875175126190479447740508185965837690552500527637822603658699938581184513"] +#[PrimeFieldGenerator = "7"] +struct Fp(FpRepr); +``` + +And that's it! `Fp` now implements `Field` and `PrimeField`. `Fp` will also implement `SqrtField` if supported. The library implements `FpRepr` itself and derives `PrimeFieldRepr` for it. + +## License + +Licensed under either of + + * Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) + * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) + +at your option. + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally +submitted for inclusion in the work by you, as defined in the Apache-2.0 +license, shall be dual licensed as above, without any additional terms or +conditions. diff --git a/crates/ff/asm_tester/Cargo.toml b/crates/ff/asm_tester/Cargo.toml new file mode 100644 index 0000000..6381988 --- /dev/null +++ b/crates/ff/asm_tester/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "asm_tester" +version = "0.1.0" +authors = ["Alex Vlasov "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ff = {package = "ff_ce", path = "../", features = ["asm_derive"]} +rand = "0.4" +serde = "1" + +[dev-dependencies] +criterion = "0.3" + +[[bench]] +name = "multiplication" +harness = false \ No newline at end of file diff --git a/crates/ff/asm_tester/bench_with_features.sh b/crates/ff/asm_tester/bench_with_features.sh new file mode 100755 index 0000000..c2b4e75 --- /dev/null +++ b/crates/ff/asm_tester/bench_with_features.sh @@ -0,0 +1 @@ +RUSTFLAGS="-C target-cpu=native -C target_feature=+bmi2,+adx" cargo +nightly bench \ No newline at end of file diff --git a/crates/ff/asm_tester/benches/multiplication.rs b/crates/ff/asm_tester/benches/multiplication.rs new file mode 100644 index 0000000..bd26005 --- /dev/null +++ b/crates/ff/asm_tester/benches/multiplication.rs @@ -0,0 +1,192 @@ +extern crate ff; +extern crate rand; +extern crate asm_tester; + +use self::ff::*; +use asm_tester::test_large_field::{Fr, FrAsm}; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +fn mul_assing_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fr = rng.gen(); + let b: Fr = rng.gen(); + + c.bench_function("Mont mul assign 256", |bencher| bencher.iter(|| black_box(a).mul_assign(&black_box(b)))); +} + +fn mul_assing_asm_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: FrAsm = rng.gen(); + let b: FrAsm = rng.gen(); + + c.bench_function("Mont mul assign 256 ASM", |bencher| bencher.iter(|| black_box(a).mul_assign(&black_box(b)))); +} + +fn add_assing_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fr = rng.gen(); + let b: Fr = rng.gen(); + + c.bench_function("Add assign 256", |bencher| bencher.iter(|| black_box(a).add_assign(&black_box(b)))); +} + +fn add_assing_asm_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: FrAsm = rng.gen(); + let b: FrAsm = rng.gen(); + + c.bench_function("Add assign 256 ASM", |bencher| bencher.iter(|| black_box(a).add_assign(&black_box(b)))); +} + +fn sub_assing_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fr = rng.gen(); + let b: Fr = rng.gen(); + + c.bench_function("Sub assign 256", |bencher| bencher.iter(|| black_box(a).sub_assign(&black_box(b)))); +} + +fn sub_assing_asm_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: FrAsm = rng.gen(); + let b: FrAsm = rng.gen(); + + c.bench_function("Sub assign 256 ASM", |bencher| bencher.iter(|| black_box(a).sub_assign(&black_box(b)))); +} + +fn double_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fr = rng.gen(); + + c.bench_function("Double 256", |bencher| bencher.iter(|| black_box(a).double())); +} + +fn double_asm_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: FrAsm = rng.gen(); + + c.bench_function("Double 256 ASM", |bencher| bencher.iter(|| black_box(a).double())); +} + +fn square_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fr = rng.gen(); + + c.bench_function("Mont square 256", |bencher| bencher.iter(|| black_box(a).square())); +} + +fn square_asm_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: FrAsm = rng.gen(); + + c.bench_function("Mont square 256 ASM", |bencher| bencher.iter(|| black_box(a).square())); +} + +fn mul_assing_vector_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let mut a = [Fr::zero(); 1024]; + let mut b = [Fr::zero(); 1024]; + for (a, b) in a.iter_mut().zip(b.iter_mut()) { + *a = rng.gen(); + *b = rng.gen(); + } + + c.bench_function("Mont mul assign vector 256", |bencher| bencher.iter(|| + { + let mut a = black_box(a); + let b = black_box(b); + for (a, b) in a.iter_mut().zip(b.iter()) { + a.mul_assign(b); + } + } + )); +} + +fn mul_assing_asm_vector_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let mut a = [FrAsm::zero(); 1024]; + let mut b = [FrAsm::zero(); 1024]; + for (a, b) in a.iter_mut().zip(b.iter_mut()) { + *a = rng.gen(); + *b = rng.gen(); + } + + c.bench_function("Mont mul assign vector 256 ASM", |bencher| bencher.iter(|| + { + let mut a = black_box(a); + let b = black_box(b); + for (a, b) in a.iter_mut().zip(b.iter()) { + a.mul_assign(b); + } + } + )); +} + +fn square_vector_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let mut a = [Fr::zero(); 1024]; + for a in a.iter_mut() { + *a = rng.gen(); + } + + c.bench_function("Mont square vector 256", |bencher| bencher.iter(|| + { + let mut a = black_box(a); + for a in a.iter_mut() { + a.square(); + } + } + )); +} + +fn square_asm_vector_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let mut a = [FrAsm::zero(); 1024]; + for a in a.iter_mut() { + *a = rng.gen(); + } + + c.bench_function("Mont square vector 256 ASM", |bencher| bencher.iter(|| + { + let mut a = black_box(a); + for a in a.iter_mut() { + a.square(); + } + } + )); +} +criterion_group!( + name = advanced; + config = Criterion::default().warm_up_time(std::time::Duration::from_secs(5)); + targets = mul_assing_benchmark, mul_assing_asm_benchmark, add_assing_benchmark, add_assing_asm_benchmark, sub_assing_benchmark, sub_assing_asm_benchmark, double_benchmark, double_asm_benchmark, square_benchmark, square_asm_benchmark, mul_assing_vector_benchmark, mul_assing_asm_vector_benchmark, square_vector_benchmark, square_asm_vector_benchmark +); +criterion_main!(advanced); diff --git a/crates/ff/asm_tester/src/lib.rs b/crates/ff/asm_tester/src/lib.rs new file mode 100644 index 0000000..677fb44 --- /dev/null +++ b/crates/ff/asm_tester/src/lib.rs @@ -0,0 +1,5 @@ +extern crate ff; +extern crate rand; +extern crate serde; + +pub mod test_large_field; \ No newline at end of file diff --git a/crates/ff/asm_tester/src/test_large_field.rs b/crates/ff/asm_tester/src/test_large_field.rs new file mode 100644 index 0000000..9291dd8 --- /dev/null +++ b/crates/ff/asm_tester/src/test_large_field.rs @@ -0,0 +1,140 @@ +mod normal { + use ff::*; + + #[derive(PrimeField)] + #[PrimeFieldModulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583"] + #[PrimeFieldGenerator = "2"] + pub struct Fr(FrRepr); +} + +mod asm { + use ff::*; + + #[derive(PrimeFieldAsm)] + #[PrimeFieldModulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583"] + #[PrimeFieldGenerator = "2"] + #[UseADX = "true"] + pub struct FrAsm(FrReprAsm); +} + +pub use self::normal::Fr; +pub use self::asm::FrAsm; + +#[cfg(test)] +mod test { + use super::*; + + use rand::*; + use ff::*; + + #[test] + fn check_mul_asm() { + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + + for i in 0..10000 { + let a: Fr = rng.gen(); + let b: Fr = rng.gen(); + + let a_asm = unsafe { std::mem::transmute::<_, FrAsm>(a) }; + let b_asm = unsafe { std::mem::transmute::<_, FrAsm>(b) }; + + let mut c = a; + c.mul_assign(&b); + + let mut c_asm = a_asm; + c_asm.mul_assign(&b_asm); + + let c_back = unsafe { std::mem::transmute::<_, Fr>(c_asm) }; + + assert_eq!(c, c_back, "failed at iteration {}: a = {:?}, b = {:?}", i, a, b); + } + } + + #[test] + fn check_add_asm() { + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + + for i in 0..10000 { + let a: Fr = rng.gen(); + let b: Fr = rng.gen(); + + let a_asm = unsafe { std::mem::transmute::<_, FrAsm>(a) }; + let b_asm = unsafe { std::mem::transmute::<_, FrAsm>(b) }; + + let mut c = a; + c.add_assign(&b); + + let mut c_asm = a_asm; + c_asm.add_assign(&b_asm); + + let c_back = unsafe { std::mem::transmute::<_, Fr>(c_asm) }; + + assert_eq!(c, c_back, "failed at iteration {}: a = {:?}, b = {:?}", i, a, b); + } + } + + #[test] + fn check_sub_asm() { + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + + for i in 0..10000 { + let a: Fr = rng.gen(); + let b: Fr = rng.gen(); + + let a_asm = unsafe { std::mem::transmute::<_, FrAsm>(a) }; + let b_asm = unsafe { std::mem::transmute::<_, FrAsm>(b) }; + + let mut c = a; + c.sub_assign(&b); + + let mut c_asm = a_asm; + c_asm.sub_assign(&b_asm); + + let c_back = unsafe { std::mem::transmute::<_, Fr>(c_asm) }; + + assert_eq!(c, c_back, "failed at iteration {}: a = {:?}, b = {:?}", i, a, b); + } + } + + #[test] + fn check_double_asm() { + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + + for i in 0..10000 { + let a: Fr = rng.gen(); + + let a_asm = unsafe { std::mem::transmute::<_, FrAsm>(a) }; + + let mut c = a; + c.double(); + + let mut c_asm = a_asm; + c_asm.double(); + + let c_back = unsafe { std::mem::transmute::<_, Fr>(c_asm) }; + + assert_eq!(c, c_back, "failed at iteration {}: a = {:?}", i, a); + } + } + + #[test] + fn check_square_asm() { + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + + for i in 0..10000 { + let a: Fr = rng.gen(); + + let a_asm = unsafe { std::mem::transmute::<_, FrAsm>(a) }; + + let mut c = a; + c.square(); + + let mut c_asm = a_asm; + c_asm.square(); + + let c_back = unsafe { std::mem::transmute::<_, Fr>(c_asm) }; + + assert_eq!(c, c_back, "failed at iteration {}: a = {:?}", i, a); + } + } +} \ No newline at end of file diff --git a/crates/ff/asm_tester/test_with_features.sh b/crates/ff/asm_tester/test_with_features.sh new file mode 100755 index 0000000..f51abdb --- /dev/null +++ b/crates/ff/asm_tester/test_with_features.sh @@ -0,0 +1 @@ +RUSTFLAGS="-C target-cpu=native -C target_feature=+bmi2,+adx -Z macro-backtrace" cargo +nightly test \ No newline at end of file diff --git a/crates/ff/asm_tester/tmp.rs b/crates/ff/asm_tester/tmp.rs new file mode 100644 index 0000000..71bb4c3 --- /dev/null +++ b/crates/ff/asm_tester/tmp.rs @@ -0,0 +1,7 @@ +mod test_large_field { + use ff::*; + #[PrimeFieldModulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583"] + #[PrimeFieldGenerator = "2"] + #[UseADX = "true"] + struct FrAsm(FrReprAsm); +} diff --git a/crates/ff/ff_derive/Cargo.toml b/crates/ff/ff_derive/Cargo.toml new file mode 100644 index 0000000..c5026d7 --- /dev/null +++ b/crates/ff/ff_derive/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "ff_derive_ce" +version = "0.11.2" +authors = ["Sean Bowe ", + "Alex Gluchowski ", + "Alex Vlasov "] +description = "Procedural macro library used to build custom prime field implementations" +documentation = "https://docs.rs/ff/" +homepage = "https://github.com/matter-labs/ff" +license = "MIT/Apache-2.0" +repository = "https://github.com/matter-labs/ff" +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +num-bigint = "0.4" +num-traits = "0.2" +num-integer = "0.1" +proc-macro2 = "1" +quote = "1" +syn = "1" +serde = {version = "1", features = ["derive"]} + +[features] +default = [] +asm = [] +serde = [] diff --git a/crates/ff/ff_derive/src/asm/asm_derive.rs b/crates/ff/ff_derive/src/asm/asm_derive.rs new file mode 100644 index 0000000..57154c6 --- /dev/null +++ b/crates/ff/ff_derive/src/asm/asm_derive.rs @@ -0,0 +1,948 @@ +use num_bigint::BigUint; +use num_integer::Integer; +use num_traits::{One, ToPrimitive}; +use quote::TokenStreamExt; +use std::str::FromStr; + +use crate::utils::*; +use super::super::{fetch_wrapped_ident, fetch_attr, get_temp, get_temp_with_literal}; +use crate::asm::impls_4::*; + +const MODULUS_PREFIX: &str = "MODULUS_"; +const MODULUS_NEGATED_PREFIX: &str = "MODULUS_NEG_"; + +// #[proc_macro_derive(PrimeFieldAsm, attributes(PrimeFieldModulus, PrimeFieldGenerator, UseADX))] +pub fn prime_field_asm_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + // Parse the type definition + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + // The struct we're deriving for is a wrapper around a "Repr" type we must construct. + let repr_ident = fetch_wrapped_ident(&ast.data) + .expect("PrimeField derive only operates over tuple structs of a single item"); + + // We're given the modulus p of the prime field + let modulus: BigUint = fetch_attr("PrimeFieldModulus", &ast.attrs) + .expect("Please supply a PrimeFieldModulus attribute") + .parse() + .expect("PrimeFieldModulus should be a number"); + + // We may be provided with a generator of p - 1 order. It is required that this generator be quadratic + // nonresidue. + let generator: BigUint = fetch_attr("PrimeFieldGenerator", &ast.attrs) + .expect("Please supply a PrimeFieldGenerator attribute") + .parse() + .expect("PrimeFieldGenerator should be a number"); + + // User may opt-in for feature to generate CIOS based multiplication operation + let use_adx: Option = fetch_attr("UseADX", &ast.attrs) + .map(|el| el.parse().expect("UseADX should be `true` or `false`")); + + + assert!(use_adx.unwrap(), "For now only ADX backend is used"); + + // The arithmetic in this library only works if the modulus*2 is smaller than the backing + // representation. Compute the number of limbs we need. + let mut limbs = 1; + { + let mod2 = (&modulus) << 1; // modulus * 2 + let mut cur = BigUint::one() << 64; // always 64-bit limbs for now + while cur < mod2 { + limbs += 1; + cur = cur << 64; + } + } + + assert_eq!(limbs, 4, "can only derive for 4 limb fitting modulus"); + + let modulus_limbs = biguint_to_real_u64_vec(modulus.clone(), limbs); + let top_limb = modulus_limbs.last().unwrap().clone().to_u64().unwrap(); + let can_use_optimistic_cios_mul = { + let mut can_use = true; + if top_limb == 0 { + can_use = false; + } + + if top_limb > (std::u64::MAX / 2) - 1 { + can_use = false; + } + can_use + }; + + let can_use_optimistic_cios_sqr = { + let mut can_use = true; + if top_limb == 0 { + can_use = false; + } + + if top_limb > (std::u64::MAX / 4) - 1 { + assert!(!can_use, "can not use optimistic CIOS for this modulus"); + can_use = false; + } + can_use + }; + + assert!(can_use_optimistic_cios_mul, "Can only derive for moduluses that fit in 255 bits - epsilon"); + assert!(can_use_optimistic_cios_sqr, "Can only derive for moduluses that fit in 254 bits - epsilon"); + + let random_id = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .subsec_nanos(); + + let mut gen = proc_macro2::TokenStream::new(); + + let (constants_impl, mont_inv, sqrt_impl) = prime_field_constants_with_inv_and_sqrt( + &ast.ident, + &repr_ident, + modulus, + limbs, + generator, + random_id + ); + + gen.extend(constants_impl); + gen.extend(prime_field_repr_impl(&repr_ident, limbs)); + gen.extend(prime_field_impl(&ast.ident, &repr_ident, mont_inv, limbs, random_id)); + gen.extend(sqrt_impl); + + // Return the generated impl + gen.into() +} + +// Implement PrimeFieldRepr for the wrapped ident `repr` with `limbs` limbs. +fn prime_field_repr_impl(repr: &syn::Ident, limbs: usize) -> proc_macro2::TokenStream { + quote! { + + #[derive(Copy, Clone, PartialEq, Eq, Default, ::serde::Serialize, ::serde::Deserialize)] + pub struct #repr( + pub [u64; #limbs] + ); + + impl ::std::fmt::Debug for #repr + { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "0x")?; + for i in self.0.iter().rev() { + write!(f, "{:016x}", *i)?; + } + + Ok(()) + } + } + + impl ::rand::Rand for #repr { + #[inline(always)] + fn rand(rng: &mut R) -> Self { + #repr(rng.gen()) + } + } + + impl ::std::fmt::Display for #repr { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "0x")?; + for i in self.0.iter().rev() { + write!(f, "{:016x}", *i)?; + } + + Ok(()) + } + } + + impl std::hash::Hash for #repr { + fn hash(&self, state: &mut H) { + for limb in self.0.iter() { + limb.hash(state); + } + } + } + + impl AsRef<[u64]> for #repr { + #[inline(always)] + fn as_ref(&self) -> &[u64] { + &self.0 + } + } + + impl AsMut<[u64]> for #repr { + #[inline(always)] + fn as_mut(&mut self) -> &mut [u64] { + &mut self.0 + } + } + + impl From for #repr { + #[inline(always)] + fn from(val: u64) -> #repr { + use std::default::Default; + + let mut repr = Self::default(); + repr.0[0] = val; + repr + } + } + + impl Ord for #repr { + #[inline(always)] + fn cmp(&self, other: &#repr) -> ::std::cmp::Ordering { + for (a, b) in self.0.iter().rev().zip(other.0.iter().rev()) { + if a < b { + return ::std::cmp::Ordering::Less + } else if a > b { + return ::std::cmp::Ordering::Greater + } + } + + ::std::cmp::Ordering::Equal + } + } + + impl PartialOrd for #repr { + #[inline(always)] + fn partial_cmp(&self, other: &#repr) -> Option<::std::cmp::Ordering> { + Some(self.cmp(other)) + } + } + + impl crate::ff::PrimeFieldRepr for #repr { + #[inline(always)] + fn is_odd(&self) -> bool { + self.0[0] & 1 == 1 + } + + #[inline(always)] + fn is_even(&self) -> bool { + !self.is_odd() + } + + #[inline(always)] + fn is_zero(&self) -> bool { + self.0.iter().all(|&e| e == 0) + } + + #[inline(always)] + fn shr(&mut self, mut n: u32) { + if n as usize >= 64 * #limbs { + *self = Self::from(0); + return; + } + + while n >= 64 { + let mut t = 0; + for i in self.0.iter_mut().rev() { + ::std::mem::swap(&mut t, i); + } + n -= 64; + } + + if n > 0 { + let mut t = 0; + for i in self.0.iter_mut().rev() { + let t2 = *i << (64 - n); + *i >>= n; + *i |= t; + t = t2; + } + } + } + + #[inline(always)] + fn div2(&mut self) { + let mut t = 0; + for i in self.0.iter_mut().rev() { + let t2 = *i << 63; + *i >>= 1; + *i |= t; + t = t2; + } + } + + #[inline(always)] + fn mul2(&mut self) { + let mut last = 0; + for i in &mut self.0 { + let tmp = *i >> 63; + *i <<= 1; + *i |= last; + last = tmp; + } + } + + #[inline(always)] + fn shl(&mut self, mut n: u32) { + if n as usize >= 64 * #limbs { + *self = Self::from(0); + return; + } + + while n >= 64 { + let mut t = 0; + for i in &mut self.0 { + ::std::mem::swap(&mut t, i); + } + n -= 64; + } + + if n > 0 { + let mut t = 0; + for i in &mut self.0 { + let t2 = *i >> (64 - n); + *i <<= n; + *i |= t; + t = t2; + } + } + } + + #[inline(always)] + fn num_bits(&self) -> u32 { + let mut ret = (#limbs as u32) * 64; + for i in self.0.iter().rev() { + let leading = i.leading_zeros(); + ret -= leading; + if leading != 64 { + break; + } + } + + ret + } + + #[inline(always)] + fn add_nocarry(&mut self, other: &#repr) { + let mut carry = 0; + + for (a, b) in self.0.iter_mut().zip(other.0.iter()) { + *a = crate::ff::adc(*a, *b, &mut carry); + } + } + + #[inline(always)] + fn sub_noborrow(&mut self, other: &#repr) { + let mut borrow = 0; + + for (a, b) in self.0.iter_mut().zip(other.0.iter()) { + *a = crate::ff::sbb(*a, *b, &mut borrow); + } + } + } + } +} + +fn prime_field_constants_with_inv_and_sqrt( + name: &syn::Ident, + repr: &syn::Ident, + modulus: BigUint, + limbs: usize, + generator: BigUint, + random_id: u32 +) -> (proc_macro2::TokenStream, u64, proc_macro2::TokenStream) { + let modulus_num_bits = biguint_num_bits(modulus.clone()); + + // The number of bits we should "shave" from a randomly sampled reputation, i.e., + // if our modulus is 381 bits and our representation is 384 bits, we should shave + // 3 bits from the beginning of a randomly sampled 384 bit representation to + // reduce the cost of rejection sampling. + let repr_shave_bits = (64 * limbs as u32) - biguint_num_bits(modulus.clone()); + + // Compute R = 2**(64 * limbs) mod m + let r = (BigUint::one() << (limbs * 64)) % &modulus; + + // modulus - 1 = 2^s * t + let mut s: u32 = 0; + let mut t = &modulus - BigUint::from_str("1").unwrap(); + while t.is_even() { + t = t >> 1; + s += 1; + } + + // Compute 2^s root of unity given the generator + let root_of_unity = biguint_to_u64_vec( + (generator.clone().modpow(&t, &modulus) * &r) % &modulus, + limbs, + ); + let generator = biguint_to_u64_vec((generator.clone() * &r) % &modulus, limbs); + + let mod_minus_1_over_2 = + biguint_to_u64_vec((&modulus - BigUint::from_str("1").unwrap()) >> 1, limbs); + let legendre_impl = quote!{ + fn legendre(&self) -> crate::ff::LegendreSymbol { + // s = self^((modulus - 1) // 2) + let s = self.pow(#mod_minus_1_over_2); + if s == Self::zero() { + crate::ff::LegendreSymbol::Zero + } else if s == Self::one() { + crate::ff::LegendreSymbol::QuadraticResidue + } else { + crate::ff::LegendreSymbol::QuadraticNonResidue + } + } + }; + + let sqrt_impl = + if (&modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() { + let mod_minus_3_over_4 = + biguint_to_u64_vec((&modulus - BigUint::from_str("3").unwrap()) >> 2, limbs); + + // Compute -R as (m - r) + let rneg = biguint_to_u64_vec(&modulus - &r, limbs); + + quote!{ + impl crate::ff::SqrtField for #name { + #legendre_impl + + fn sqrt(&self) -> Option { + // Shank's algorithm for q mod 4 = 3 + // https://eprint.iacr.org/2012/685.pdf (page 9, algorithm 2) + + let mut a1 = self.pow(#mod_minus_3_over_4); + + let mut a0 = a1; + a0.square(); + a0.mul_assign(self); + + if a0.0 == #repr(#rneg) { + None + } else { + a1.mul_assign(self); + Some(a1) + } + } + } + } + } else if (&modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() { + let t_plus_1_over_2 = biguint_to_u64_vec((&t + BigUint::one()) >> 1, limbs); + let t = biguint_to_u64_vec(t.clone(), limbs); + + quote!{ + impl crate::ff::SqrtField for #name { + #legendre_impl + + fn sqrt(&self) -> Option { + // Tonelli-Shank's algorithm for q mod 16 = 1 + // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5) + + match self.legendre() { + crate::ff::LegendreSymbol::Zero => Some(*self), + crate::ff::LegendreSymbol::QuadraticNonResidue => None, + crate::ff::LegendreSymbol::QuadraticResidue => { + let mut c = #name(ROOT_OF_UNITY); + let mut r = self.pow(#t_plus_1_over_2); + let mut t = self.pow(#t); + let mut m = S; + + while t != Self::one() { + let mut i = 1; + { + let mut t2i = t; + t2i.square(); + loop { + if t2i == Self::one() { + break; + } + t2i.square(); + i += 1; + } + } + + for _ in 0..(m - i - 1) { + c.square(); + } + r.mul_assign(&c); + c.square(); + t.mul_assign(&c); + m = i; + } + + Some(r) + } + } + } + } + } + } else { + quote!{} + }; + + // Compute R^2 mod m + let r2 = biguint_to_u64_vec((&r * &r) % &modulus, limbs); + + let r = biguint_to_u64_vec(r, limbs); + + // 2^k - modulus + let modulus_negated = (BigUint::one() << (64 * limbs)) - &modulus; + + let modulus = biguint_to_real_u64_vec(modulus, limbs); + let modulus_negated = biguint_to_real_u64_vec(modulus_negated, limbs); + + // Compute -m^-1 mod 2**64 by exponentiating by totient(2**64) - 1 + let mut inv = 1u64; + for _ in 0..63 { + inv = inv.wrapping_mul(inv); + inv = inv.wrapping_mul(modulus[0]); + } + inv = inv.wrapping_neg(); + + let mut constants_gen = quote! { + /// This is the modulus m of the prime field + const MODULUS: #repr = #repr([#(#modulus,)*]); + + /// The number of bits needed to represent the modulus. + const MODULUS_BITS: u32 = #modulus_num_bits; + + /// The number of bits that must be shaved from the beginning of + /// the representation when randomly sampling. + const REPR_SHAVE_BITS: u32 = #repr_shave_bits; + + /// 2^{limbs*64} mod m + const R: #repr = #repr(#r); + + /// 2^{limbs*64*2} mod m + const R2: #repr = #repr(#r2); + + /// -(m^{-1} mod m) mod m + const INV: u64 = #inv; + + /// Multiplicative generator of `MODULUS` - 1 order, also quadratic + /// nonresidue. + const GENERATOR: #repr = #repr(#generator); + + /// 2^s * t = MODULUS - 1 with t odd + const S: u32 = #s; + + /// 2^s root of unity computed by GENERATOR^t + const ROOT_OF_UNITY: #repr = #repr(#root_of_unity); + }; + + for i in 0..4 { + let m = get_temp_with_literal(&format!("{}{}_", MODULUS_PREFIX, random_id), i); + let n = get_temp_with_literal(&format!("{}{}_", MODULUS_NEGATED_PREFIX, random_id), i); + let value = modulus[i]; + let limb_neg = modulus_negated[i]; + + constants_gen.extend( + quote!{ + #[no_mangle] + static #m: u64 = #value; + #[no_mangle] + static #n: u64 = #limb_neg; + } + ); + } + + (constants_gen, inv, sqrt_impl) +} + +/// Implement PrimeField for the derived type. +fn prime_field_impl( + name: &syn::Ident, + repr: &syn::Ident, + mont_inv: u64, + limbs: usize, + random_id: u32, +) -> proc_macro2::TokenStream { + // The parameter list for the mont_reduce() internal method. + // r0: u64, mut r1: u64, mut r2: u64, ... + let mut mont_paramlist = proc_macro2::TokenStream::new(); + mont_paramlist.append_separated( + (0..(limbs * 2)).map(|i| (i, get_temp(i))).map(|(i, x)| { + if i != 0 { + quote!{mut #x: u64} + } else { + quote!{#x: u64} + } + }), + proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone), + ); + + // Implement montgomery reduction for some number of limbs + fn mont_impl(limbs: usize) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + + for i in 0..limbs { + { + let temp = get_temp(i); + gen.extend(quote!{ + let k = #temp.wrapping_mul(INV); + let mut carry = 0; + crate::ff::mac_with_carry(#temp, k, MODULUS.0[0], &mut carry); + }); + } + + for j in 1..limbs { + let temp = get_temp(i + j); + gen.extend(quote!{ + #temp = crate::ff::mac_with_carry(#temp, k, MODULUS.0[#j], &mut carry); + }); + } + + let temp = get_temp(i + limbs); + + if i == 0 { + gen.extend(quote!{ + #temp = crate::ff::adc(#temp, 0, &mut carry); + }); + } else { + gen.extend(quote!{ + #temp = crate::ff::adc(#temp, carry2, &mut carry); + }); + } + + if i != (limbs - 1) { + gen.extend(quote!{ + let carry2 = carry; + }); + } + } + + for i in 0..limbs { + let temp = get_temp(limbs + i); + + gen.extend(quote!{ + (self.0).0[#i] = #temp; + }); + } + + gen + } + + let top_limb_index = limbs - 1; + + let montgomery_impl = mont_impl(limbs); + + // (self.0).0[0], (self.0).0[1], ..., 0, 0, 0, 0, ... + let mut into_repr_params = proc_macro2::TokenStream::new(); + into_repr_params.append_separated( + (0..limbs) + .map(|i| quote!{ (self.0).0[#i] }) + .chain((0..limbs).map(|_| quote!{0})), + proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone), + ); + + let modulus_random_prefix = format!("{}{}_", MODULUS_PREFIX, random_id); + let modulus_neg_random_prefix = format!("{}{}_", MODULUS_NEGATED_PREFIX, random_id); + + let mul_asm_impl = mul_impl(mont_inv, &modulus_random_prefix); + let sqr_asm_impl = sqr_impl(mont_inv, &modulus_random_prefix); + // let add_asm_impl = add_impl(MODULUS_PREFIX); + let add_asm_impl = add_impl(&modulus_neg_random_prefix); + let sub_asm_impl = sub_impl(&modulus_random_prefix); + // let sub_asm_impl = sub_impl(MODULUS_NEGATED_PREFIX); + // let double_asm_impl = double_impl(MODULUS_PREFIX); + let double_asm_impl = double_impl(&modulus_neg_random_prefix); + + quote!{ + impl ::std::marker::Copy for #name { } + + impl ::std::clone::Clone for #name { + fn clone(&self) -> #name { + *self + } + } + + impl ::std::cmp::PartialEq for #name { + fn eq(&self, other: &#name) -> bool { + self.0 == other.0 + } + } + + impl ::std::cmp::Eq for #name { } + + impl ::std::fmt::Debug for #name + { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{}({:?})", stringify!(#name), self.into_repr()) + } + } + + /// Elements are ordered lexicographically. + impl Ord for #name { + #[inline(always)] + fn cmp(&self, other: &#name) -> ::std::cmp::Ordering { + self.into_repr().cmp(&other.into_repr()) + } + } + + impl PartialOrd for #name { + #[inline(always)] + fn partial_cmp(&self, other: &#name) -> Option<::std::cmp::Ordering> { + Some(self.cmp(other)) + } + } + + impl ::std::fmt::Display for #name { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{}({})", stringify!(#name), self.into_repr()) + } + } + + impl ::rand::Rand for #name { + /// Computes a uniformly random element using rejection sampling. + fn rand(rng: &mut R) -> Self { + loop { + let mut tmp = #name(#repr::rand(rng)); + + // Mask away the unused bits at the beginning. + tmp.0.as_mut()[#top_limb_index] &= 0xffffffffffffffff >> REPR_SHAVE_BITS; + + if tmp.is_valid() { + return tmp + } + } + } + } + + impl From<#name> for #repr { + fn from(e: #name) -> #repr { + e.into_repr() + } + } + + impl crate::ff::PrimeField for #name { + type Repr = #repr; + + fn from_repr(r: #repr) -> Result<#name, crate::ff::PrimeFieldDecodingError> { + let mut r = #name(r); + if r.is_valid() { + r.mul_assign(&#name(R2)); + + Ok(r) + } else { + Err(crate::ff::PrimeFieldDecodingError::NotInField(format!("{}", r.0))) + } + } + + fn from_raw_repr(r: #repr) -> Result { + let mut r = #name(r); + if r.is_valid() { + Ok(r) + } else { + Err(crate::ff::PrimeFieldDecodingError::NotInField(format!("{}", r.0))) + } + } + + fn into_repr(&self) -> #repr { + let mut r = *self; + r.mont_reduce( + #into_repr_params + ); + + r.0 + } + + fn into_raw_repr(&self) -> #repr { + let r = *self; + + r.0 + } + + fn char() -> #repr { + MODULUS + } + + const NUM_BITS: u32 = MODULUS_BITS; + + const CAPACITY: u32 = Self::NUM_BITS - 1; + + fn multiplicative_generator() -> Self { + #name(GENERATOR) + } + + const S: u32 = S; + + fn root_of_unity() -> Self { + #name(ROOT_OF_UNITY) + } + + } + + impl crate::ff::Field for #name { + #[inline] + fn zero() -> Self { + #name(#repr::from(0)) + } + + #[inline] + fn one() -> Self { + #name(R) + } + + #[inline] + fn is_zero(&self) -> bool { + self.0.is_zero() + } + + #[inline] + fn add_assign(&mut self, other: &#name) { + (self.0).0 = #name::add_asm_adx_with_reduction(&(self.0).0, &(other.0).0); + } + + #[inline] + fn double(&mut self) { + (self.0).0 = Self::double_asm_adx_with_reduction(&(self.0).0); + } + + #[inline] + fn sub_assign(&mut self, other: &#name) { + (self.0).0 = Self::sub_asm_adx_with_reduction(&(self.0).0, &(other.0).0); + } + + #[inline] + fn negate(&mut self) { + if !self.is_zero() { + let mut tmp = MODULUS; + tmp.sub_noborrow(&self.0); + self.0 = tmp; + } + } + + fn inverse(&self) -> Option { + if self.is_zero() { + None + } else { + // Guajardo Kumar Paar Pelzl + // Efficient Software-Implementation of Finite Fields with Applications to Cryptography + // Algorithm 16 (BEA for Inversion in Fp) + + let one = #repr::from(1); + + let mut u = self.0; + let mut v = MODULUS; + let mut b = #name(R2); // Avoids unnecessary reduction step. + let mut c = Self::zero(); + + while u != one && v != one { + while u.is_even() { + u.div2(); + + if b.0.is_even() { + b.0.div2(); + } else { + b.0.add_nocarry(&MODULUS); + b.0.div2(); + } + } + + while v.is_even() { + v.div2(); + + if c.0.is_even() { + c.0.div2(); + } else { + c.0.add_nocarry(&MODULUS); + c.0.div2(); + } + } + + if v < u { + u.sub_noborrow(&v); + b.sub_assign(&c); + } else { + v.sub_noborrow(&u); + c.sub_assign(&b); + } + } + + if u == one { + Some(b) + } else { + Some(c) + } + } + } + + #[inline(always)] + fn frobenius_map(&mut self, _: usize) { + // This has no effect in a prime field. + } + + #[inline] + fn mul_assign(&mut self, other: &#name) + { + (self.0).0 = Self::mont_mul_asm_adx_with_reduction(&(self.0).0, &(other.0).0); + } + + #[inline] + fn square(&mut self) + { + (self.0).0 = Self::mont_sqr_asm_adx_with_reduction(&(self.0).0); + } + } + + impl std::default::Default for #name { + fn default() -> Self { + Self::zero() + } + } + + impl std::hash::Hash for #name { + fn hash(&self, state: &mut H) { + for limb in self.0.as_ref().iter() { + limb.hash(state); + } + } + } + + impl #name { + /// Determines if the element is really in the field. This is only used + /// internally. + #[inline(always)] + fn is_valid(&self) -> bool { + self.0 < MODULUS + } + + /// Subtracts the modulus from this element if this element is not in the + /// field. Only used interally. + #[inline(always)] + fn reduce(&mut self) { + if !self.is_valid() { + self.0.sub_noborrow(&MODULUS); + } + } + + #[inline(always)] + fn mont_reduce( + &mut self, + #mont_paramlist + ) + { + // The Montgomery reduction here is based on Algorithm 14.32 in + // Handbook of Applied Cryptography + // . + + #montgomery_impl + + self.reduce(); + } + + #mul_asm_impl + + #sqr_asm_impl + + #add_asm_impl + + #sub_asm_impl + + #double_asm_impl + } + + impl ::serde::Serialize for #name { + fn serialize(&self, serializer: S) -> Result + where S: ::serde::Serializer + { + let repr = self.into_repr(); + repr.serialize(serializer) + } + } + + impl<'de> ::serde::Deserialize<'de> for #name { + fn deserialize(deserializer: D) -> Result + where D: ::serde::Deserializer<'de> + { + let repr = #repr::deserialize(deserializer)?; + let new = Self::from_repr(repr).expect("serialized representation is expected to be valid"); + + Ok(new) + } + } + } +} diff --git a/crates/ff/ff_derive/src/asm/impls_4.rs b/crates/ff/ff_derive/src/asm/impls_4.rs new file mode 100644 index 0000000..9a0da6a --- /dev/null +++ b/crates/ff/ff_derive/src/asm/impls_4.rs @@ -0,0 +1,741 @@ +use super::super::{get_temp_with_literal}; + +pub(crate) fn mul_impl(mont_inv: u64, modulus_static_prefix: &str) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + + let m0 = get_temp_with_literal(modulus_static_prefix, 0); + let m1 = get_temp_with_literal(modulus_static_prefix, 1); + let m2 = get_temp_with_literal(modulus_static_prefix, 2); + let m3 = get_temp_with_literal(modulus_static_prefix, 3); + + gen.extend(quote!{ + #[allow(clippy::too_many_lines)] + #[inline(always)] + #[cfg(target_arch = "x86_64")] + // #[cfg(all(target_arch = "x86_64", target_feature = "adx"))] + fn mont_mul_asm_adx_with_reduction(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + // this is CIOS multiplication when top bit for top word of modulus is not set + + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + unsafe { + core::arch::asm!( + // round 0 + "mov rdx, qword ptr [{a_ptr} + 0]", + "xor r8d, r8d", + "mulx r14, r13, qword ptr [{b_ptr} + 0]", + "mulx r9, r8, qword ptr [{b_ptr} + 8]", + "mulx r10, r15, qword ptr [{b_ptr} + 16]", + "mulx r12, rdi, qword ptr [{b_ptr} + 24]", + "mov rdx, r13", + "mov r11, {inv}", + "mulx r11, rdx, r11", + "adcx r14, r8", + "adox r10, rdi", + "adcx r15, r9", + "mov r11, 0", + "adox r12, r11", + "adcx r10, r11", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r13, r8", + "adcx r14, rdi", + "adox r14, r9", + "adcx r15, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", + "adox r15, r8", + "adcx r10, rdi", + "adox r10, r9", + "adcx r12, r11", + "mov r9, 0", + "adox r12, r9", + + // round 1 + "mov rdx, qword ptr [{a_ptr} + 8]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r14, r8", + "adox r15, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r15, rdi", + "adox r10, r11", + "mulx r13, rdi, qword ptr [{b_ptr} + 24]", + "adcx r10, r8", + "adox r12, rdi", + "adcx r12, r9", + "mov rdi, 0", + "adox r13, rdi", + "adcx r13, rdi", + "mov rdx, r14", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r14, r8", + "adcx r15, rdi", + "adox r15, r9", + "adcx r10, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", + "adox r10, r8", + "adcx r12, r9", + "adox r12, rdi", + "adcx r13, r11", + "mov rdi, 0", + "adox r13, rdi", + + // round 2 + "mov rdx, qword ptr [{a_ptr} + 16]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r15, r8", + "adox r10, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r10, rdi", + "adox r12, r11", + "mulx r14, rdi, qword ptr [{b_ptr} + 24]", + "adcx r12, r8", + "adox r13, r9", + "adcx r13, rdi", + "mov r9, 0", + "adox r14, r9", + "adcx r14, r9", + "mov rdx, r15", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r15, r8", + "adcx r10, r9", + "adox r10, rdi", + "adcx r12, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", + "adox r12, r8", + "adcx r13, r9", + "adox r13, rdi", + "adcx r14, r11", + "mov rdi, 0", + "adox r14, rdi", + + // round 3 + "mov rdx, qword ptr [{a_ptr} + 24]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r10, r8", + "adox r12, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r12, rdi", + "adox r13, r11", + "mulx r15, rdi, qword ptr [{b_ptr} + 24]", + "adcx r13, r8", + "adox r14, r9", + "adcx r14, rdi", + "mov r9, 0", + "adox r15, r9", + "adcx r15, r9", + "mov rdx, r10", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r10, r8", + "adcx r12, r9", + "adox r12, rdi", + "adcx r13, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx rdx, rdi, qword ptr [rip + {q3_ptr}]", + "adox r13, r8", + "adcx r14, r9", + "adox r14, rdi", + "adcx r15, rdx", + "mov rdi, 0", + "adox r15, rdi", + // reduction. We use sub/sbb + + "mov r8, r12", + "mov rdx, qword ptr [rip + {q0_ptr}]", + "sub r8, rdx", + "mov r9, r13", + "mov rdx, qword ptr [rip + {q1_ptr}]", + "sbb r9, rdx", + "mov r10, r14", + "mov rdx, qword ptr [rip + {q2_ptr}]", + "sbb r10, rdx", + "mov r11, r15", + "mov rdx, qword ptr [rip + {q3_ptr}]", + "sbb r11, rdx", + + // if CF == 1 then original result was ok (reduction wa not necessary) + // so if not carry (CMOVNQ) then we copy + "cmovnc r12, r8", + "cmovnc r13, r9", + "cmovnc r14, r10", + "cmovnc r15, r11", + // end of reduction + q0_ptr = sym #m0, + q1_ptr = sym #m1, + q2_ptr = sym #m2, + q3_ptr = sym #m3, + inv = const #mont_inv, + a_ptr = in(reg) a.as_ptr(), + b_ptr = in(reg) b.as_ptr(), + out("rdx") _, + out("rdi") _, + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") r0, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, readonly, nostack) + ); + } + + [r0, r1, r2, r3] + } + + }); + + gen +} + +pub(crate) fn sqr_impl(mont_inv: u64, modulus_static_prefix: &str) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + + let m0 = get_temp_with_literal(modulus_static_prefix, 0); + let m1 = get_temp_with_literal(modulus_static_prefix, 1); + let m2 = get_temp_with_literal(modulus_static_prefix, 2); + let m3 = get_temp_with_literal(modulus_static_prefix, 3); + + gen.extend(quote!{ + #[allow(clippy::too_many_lines)] + #[inline(always)] + #[cfg(target_arch = "x86_64")] + // #[cfg(all(target_arch = "x86_64", target_feature = "adx"))] + fn mont_sqr_asm_adx_with_reduction(a: &[u64; 4]) -> [u64; 4] { + // this is CIOS multiplication when top bit for top word of modulus is not set + + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + unsafe { + core::arch::asm!( + // round 0 + "mov rdx, qword ptr [{a_ptr} + 0]", + "xor r8d, r8d", + "mulx r10, r9, qword ptr [{a_ptr} + 8]", + "mulx r15, r8, qword ptr [{a_ptr} + 16]", + "mulx r12, r11, qword ptr [{a_ptr} + 24]", + "adox r10, r8", + "adcx r11, r15", + "mov rdx, qword ptr [{a_ptr} + 8]", + "mulx r15, r8, qword ptr [{a_ptr} + 16]", + "mulx rcx, rdi, qword ptr [{a_ptr} + 24]", + "mov rdx, qword ptr [{a_ptr} + 16]", + "mulx r14, r13, qword ptr [{a_ptr} + 24]", + "adox r11, r8", + "adcx r12, rdi", + "adox r12, r15", + "adcx r13, rcx", + "mov r8, 0", + "adox r13, r8", + "adcx r14, r8", + + // double + "adox r9, r9", + "adcx r12, r12", + "adox r10, r10", + "adcx r13, r13", + "adox r11, r11", + "adcx r14, r14", + + // square contributions + "mov rdx, qword ptr [{a_ptr} + 0]", + "mulx rcx, r8, rdx", + "mov rdx, qword ptr [{a_ptr} + 16]", + "mulx rdi, rdx, rdx", + "adox r12, rdx", + "adcx r9, rcx", + "adox r13, rdi", + "mov rdx, qword ptr [{a_ptr} + 24]", + "mulx r15, rcx, rdx", + "mov rdx, qword ptr [{a_ptr} + 8]", + "mulx rdx, rdi, rdx", + "adcx r10, rdi", + "adox r14, rcx", + "mov rdi, 0", + "adcx r11, rdx", + "adox r15, rdi", + + // reduction round 0 + "mov rdx, r8", + "mov rdi, {inv}", + "mulx rdi, rdx, rdi", + "mulx rcx, rdi, qword ptr [rip + {q0_ptr}]", + "adox r8, rdi", + "mulx rdi, r8, qword ptr [rip + {q3_ptr}]", + "adcx r12, rdi", + "adox r9, rcx", + "mov rdi, 0", + "adcx r13, rdi", + "mulx rcx, rdi, qword ptr [rip + {q1_ptr}]", + "adox r10, rcx", + "adcx r9, rdi", + "adox r11, r8", + "mulx rcx, rdi, qword ptr [rip + {q2_ptr}]", + "adcx r10, rdi", + "adcx r11, rcx", + + // reduction round 1 + "mov rdx, r9", + "mov rdi, {inv}", + "mulx rdi, rdx, rdi", + "mulx rcx, rdi, qword ptr [rip + {q2_ptr}]", + "adox r12, rcx", + "mulx rcx, r8, qword ptr [rip + {q3_ptr}]", + "adcx r12, r8", + "adox r13, rcx", + "mov r8, 0", + "adcx r13, r8", + "adox r14, r8", + "mulx rcx, r8, qword ptr [rip + {q0_ptr}]", + "adcx r9, r8", + "adox r10, rcx", + "mulx rcx, r8, qword ptr [rip + {q1_ptr}]", + "adcx r10, r8", + "adox r11, rcx", + "adcx r11, rdi", + + // reduction round 2 + "mov rdx, r10", + "mov rdi, {inv}", + "mulx rdi, rdx, rdi", + "mulx rcx, rdi, qword ptr [rip + {q1_ptr}]", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "adox r12, rcx", + "adcx r12, r8", + "adox r13, r9", + "mulx r9, r8, qword ptr [rip + {q3_ptr}]", + "adcx r13, r8", + "adox r14, r9", + "mov r8, 0", + "adcx r14, r8", + "adox r15, r8", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "adcx r10, r8", + "adox r11, r9", + "mov r8, 0", + "adcx r11, rdi", + "adox r12, r8", + + // reduction round 3 + "mov rdx, r11", + "mov rdi, {inv}", + "mulx rdi, rdx, rdi", + "mulx rcx, rdi, qword ptr [rip + {q0_ptr}]", + "mulx r9, r8, qword ptr [rip + {q1_ptr}]", + "adox r11, rdi", + "adcx r12, r8", + "adox r12, rcx", + "adcx r13, r9", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx r11, r10, qword ptr [rip + {q3_ptr}]", + "adox r13, r8", + "adcx r14, r10", + "mov r8, 0", + "adox r14, r9", + "adcx r15, r11", + "adox r15, r8", + + // reduction. We use sub/sbb + "mov r8, r12", + "mov rdx, qword ptr [rip + {q0_ptr}]", + "sub r8, rdx", + "mov r9, r13", + "mov rdx, qword ptr [rip + {q1_ptr}]", + "sbb r9, rdx", + "mov r10, r14", + "mov rdx, qword ptr [rip + {q2_ptr}]", + "sbb r10, rdx", + "mov r11, r15", + "mov rdx, qword ptr [rip + {q3_ptr}]", + "sbb r11, rdx", + + // if CF == 1 then original result was ok (reduction wa not necessary) + // so if not carry (CMOVNQ) then we copy + "cmovnc r12, r8", + "cmovnc r13, r9", + "cmovnc r14, r10", + "cmovnc r15, r11", + // end of reduction + q0_ptr = sym #m0, + q1_ptr = sym #m1, + q2_ptr = sym #m2, + q3_ptr = sym #m3, + inv = const #mont_inv, + a_ptr = in(reg) a.as_ptr(), + out("rcx") _, + out("rdx") _, + out("rdi") _, + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") r0, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, readonly, nostack) + ); + } + + [r0, r1, r2, r3] + } + + }); + + gen +} + +pub(crate) fn add_impl(modulus_static_prefix: &str) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + + let m0 = get_temp_with_literal(modulus_static_prefix, 0); + let m1 = get_temp_with_literal(modulus_static_prefix, 1); + let m2 = get_temp_with_literal(modulus_static_prefix, 2); + let m3 = get_temp_with_literal(modulus_static_prefix, 3); + + gen.extend(quote!{ + #[allow(clippy::too_many_lines)] + #[inline(always)] + #[cfg(target_arch = "x86_64")] + // #[cfg(all(target_arch = "x86_64", target_feature = "adx"))] + fn add_asm_adx_with_reduction(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + unsafe { + core::arch::asm!( + // we sum (a+b) using addition chain with OF + // and sum (a+b) - p using addition chain with CF + // if (a+b) does not overflow the modulus + // then sum (a+b) will produce CF + "xor r12d, r12d", + "mov r12, qword ptr [{a_ptr} + 0]", + "mov r13, qword ptr [{a_ptr} + 8]", + "mov r14, qword ptr [{a_ptr} + 16]", + "mov r15, qword ptr [{a_ptr} + 24]", + "adox r12, qword ptr [{b_ptr} + 0]", + "mov r8, r12", + "adcx r8, qword ptr [rip + {q0_ptr}]", + "adox r13, qword ptr [{b_ptr} + 8]", + "mov r9, r13", + "adcx r9, qword ptr [rip + {q1_ptr}]", + "adox r14, qword ptr [{b_ptr} + 16]", + "mov r10, r14", + "adcx r10, qword ptr [rip + {q2_ptr}]", + "adox r15, qword ptr [{b_ptr} + 24]", + "mov r11, r15", + "adcx r11, qword ptr [rip + {q3_ptr}]", + + // if CF = 0 then take value (a+b) from [r12, .., r15] + // otherwise take (a+b) - p + + "cmovc r12, r8", + "cmovc r13, r9", + "cmovc r14, r10", + "cmovc r15, r11", + + q0_ptr = sym #m0, + q1_ptr = sym #m1, + q2_ptr = sym #m2, + q3_ptr = sym #m3, + // end of reduction + a_ptr = in(reg) a.as_ptr(), + b_ptr = in(reg) b.as_ptr(), + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") r0, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, readonly, nostack) + ); + } + + // unsafe { + // core::arch::asm!( + // "xor r12d, r12d", + // "mov r12, qword ptr [{a_ptr} + 0]", + // "mov r13, qword ptr [{a_ptr} + 8]", + // "mov r14, qword ptr [{a_ptr} + 16]", + // "mov r15, qword ptr [{a_ptr} + 24]", + // "add r12, qword ptr [{b_ptr} + 0]", + // "adc r13, qword ptr [{b_ptr} + 8]", + // "adc r14, qword ptr [{b_ptr} + 16]", + // "adc r15, qword ptr [{b_ptr} + 24]", + + // "mov r8, r12", + // "mov rdx, qword ptr [rip + {q0_ptr}]", + // "sub r8, rdx", + // "mov r9, r13", + // "mov rdx, qword ptr [rip + {q1_ptr}]", + // "sbb r9, rdx", + // "mov r10, r14", + // "mov rdx, qword ptr [rip + {q2_ptr}]", + // "sbb r10, rdx", + // "mov r11, r15", + // "mov rdx, qword ptr [rip + {q3_ptr}]", + // "sbb r11, rdx", + + // "cmovnc r12, r8", + // "cmovnc r13, r9", + // "cmovnc r14, r10", + // "cmovnc r15, r11", + + // q0_ptr = sym #m0, + // q1_ptr = sym #m1, + // q2_ptr = sym #m2, + // q3_ptr = sym #m3, + // // end of reduction + // a_ptr = in(reg) a.as_ptr(), + // b_ptr = in(reg) b.as_ptr(), + // out("rdx") _, + // out("r8") _, + // out("r9") _, + // out("r10") _, + // out("r11") _, + // out("r12") r0, + // out("r13") r1, + // out("r14") r2, + // out("r15") r3, + // options(pure, readonly, nostack) + // ); + // } + + [r0, r1, r2, r3] + } + }); + + gen +} + +pub(crate) fn double_impl(modulus_static_prefix: &str) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + + let m0 = get_temp_with_literal(modulus_static_prefix, 0); + let m1 = get_temp_with_literal(modulus_static_prefix, 1); + let m2 = get_temp_with_literal(modulus_static_prefix, 2); + let m3 = get_temp_with_literal(modulus_static_prefix, 3); + + gen.extend(quote!{ + #[allow(clippy::too_many_lines)] + #[inline(always)] + #[cfg(target_arch = "x86_64")] + // #[cfg(all(target_arch = "x86_64", target_feature = "adx"))] + fn double_asm_adx_with_reduction(a: &[u64; 4]) -> [u64; 4] { + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + unsafe { + core::arch::asm!( + // we sum (a+b) using addition chain with OF + // and sum (a+b) - p using addition chain with CF + // if (a+b) does not overflow the modulus + // then sum (a+b) will produce CF + "xor r12d, r12d", + "mov r12, qword ptr [{a_ptr} + 0]", + "adox r12, r12", + "mov r13, qword ptr [{a_ptr} + 8]", + "adox r13, r13", + "mov r14, qword ptr [{a_ptr} + 16]", + "adox r14, r14", + "mov r15, qword ptr [{a_ptr} + 24]", + "adox r15, r15", + + "mov r8, r12", + "adcx r8, qword ptr [rip + {q0_ptr}]", + "mov r9, r13", + "adcx r9, qword ptr [rip + {q1_ptr}]", + "mov r10, r14", + "adcx r10, qword ptr [rip + {q2_ptr}]", + "mov r11, r15", + "adcx r11, qword ptr [rip + {q3_ptr}]", + + // if CF = 0 then take value (a+b) from [r12, .., r15] + // otherwise take (a+b) - p + + "cmovc r12, r8", + "cmovc r13, r9", + "cmovc r14, r10", + "cmovc r15, r11", + + q0_ptr = sym #m0, + q1_ptr = sym #m1, + q2_ptr = sym #m2, + q3_ptr = sym #m3, + // end of reduction + a_ptr = in(reg) a.as_ptr(), + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") r0, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, readonly, nostack) + ); + } + + // unsafe { + // core::arch::asm!( + // "xor r12d, r12d", + // "mov r12, qword ptr [{a_ptr} + 0]", + // "mov r13, qword ptr [{a_ptr} + 8]", + // "mov r14, qword ptr [{a_ptr} + 16]", + // "mov r15, qword ptr [{a_ptr} + 24]", + // "add r12, r12", + // "adc r13, r13", + // "adc r14, r14", + // "adc r15, r15", + + // "mov r8, r12", + // "mov rdx, qword ptr [rip + {q0_ptr}]", + // "sub r8, rdx", + // "mov r9, r13", + // "mov rdx, qword ptr [rip + {q1_ptr}]", + // "sbb r9, rdx", + // "mov r10, r14", + // "mov rdx, qword ptr [rip + {q2_ptr}]", + // "sbb r10, rdx", + // "mov r11, r15", + // "mov rdx, qword ptr [rip + {q3_ptr}]", + // "sbb r11, rdx", + + // "cmovnc r12, r8", + // "cmovnc r13, r9", + // "cmovnc r14, r10", + // "cmovnc r15, r11", + + // q0_ptr = sym #m0, + // q1_ptr = sym #m1, + // q2_ptr = sym #m2, + // q3_ptr = sym #m3, + // // end of reduction + // a_ptr = in(reg) a.as_ptr(), + // out("rdx") _, + // out("r8") _, + // out("r9") _, + // out("r10") _, + // out("r11") _, + // out("r12") r0, + // out("r13") r1, + // out("r14") r2, + // out("r15") r3, + // options(pure, readonly, nostack) + // ); + // } + + [r0, r1, r2, r3] + } + }); + + gen +} + +pub(crate) fn sub_impl(modulus_static_prefix: &str) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + + let m0 = get_temp_with_literal(modulus_static_prefix, 0); + let m1 = get_temp_with_literal(modulus_static_prefix, 1); + let m2 = get_temp_with_literal(modulus_static_prefix, 2); + let m3 = get_temp_with_literal(modulus_static_prefix, 3); + + gen.extend(quote!{ + #[allow(clippy::too_many_lines)] + #[inline(always)] + #[cfg(target_arch = "x86_64")] + // #[cfg(all(target_arch = "x86_64", target_feature = "adx"))] + fn sub_asm_adx_with_reduction(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + unsafe { + core::arch::asm!( + "xor r12d, r12d", + "mov r12, qword ptr [{a_ptr} + 0]", + "sub r12, qword ptr [{b_ptr} + 0]", + "mov r13, qword ptr [{a_ptr} + 8]", + "sbb r13, qword ptr [{b_ptr} + 8]", + "mov r14, qword ptr [{a_ptr} + 16]", + "sbb r14, qword ptr [{b_ptr} + 16]", + "mov r15, qword ptr [{a_ptr} + 24]", + "sbb r15, qword ptr [{b_ptr} + 24]", + + // duplicate (a-b) into [r8, r9, r10, r11] + + // now make [r12, .., r15] + modulus; + // if (a-b) did underflow then 2^256 + (a-b) < modulus, + // so below we will get an overflow + + "mov r8, r12", + "add r12, qword ptr [rip + {q0_ptr}]", + // "adox r12, qword ptr [rip + {q0_ptr}]", + "mov r9, r13", + "adc r13, qword ptr [rip + {q1_ptr}]", + // "adox r13, qword ptr [rip + {q1_ptr}]", + "mov r10, r14", + "adc r14, qword ptr [rip + {q2_ptr}]", + // "adox r14, qword ptr [rip + {q2_ptr}]", + "mov r11, r15", + "adc r15, qword ptr [rip + {q3_ptr}]", + // "adox r15, qword ptr [rip + {q3_ptr}]", + + "cmovnc r12, r8", + "cmovnc r13, r9", + "cmovnc r14, r10", + "cmovnc r15, r11", + + q0_ptr = sym #m0, + q1_ptr = sym #m1, + q2_ptr = sym #m2, + q3_ptr = sym #m3, + // end of reduction + a_ptr = in(reg) a.as_ptr(), + b_ptr = in(reg) b.as_ptr(), + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") r0, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, readonly, nostack) + ); + } + + [r0, r1, r2, r3] + } + }); + + gen +} diff --git a/crates/ff/ff_derive/src/asm/mod.rs b/crates/ff/ff_derive/src/asm/mod.rs new file mode 100644 index 0000000..e1049f5 --- /dev/null +++ b/crates/ff/ff_derive/src/asm/mod.rs @@ -0,0 +1,4 @@ +mod asm_derive; +mod impls_4; + +pub(crate) use self::asm_derive::prime_field_asm_impl; \ No newline at end of file diff --git a/crates/ff/ff_derive/src/lib.rs b/crates/ff/ff_derive/src/lib.rs new file mode 100644 index 0000000..7bcfcfa --- /dev/null +++ b/crates/ff/ff_derive/src/lib.rs @@ -0,0 +1,1404 @@ +#![recursion_limit = "1024"] + +extern crate proc_macro; +extern crate proc_macro2; +extern crate syn; +#[macro_use] +extern crate quote; + +extern crate num_bigint; +extern crate num_integer; +extern crate num_traits; + +use num_bigint::BigUint; +use num_integer::Integer; +use num_traits::{One, ToPrimitive, Zero}; +use quote::TokenStreamExt; +use std::str::FromStr; + +mod utils; +use utils::*; + +#[cfg(feature = "asm")] +mod asm; + +#[cfg(feature = "asm")] +#[proc_macro_derive(PrimeFieldAsm, attributes(PrimeFieldModulus, PrimeFieldGenerator, UseADX))] +pub fn prime_field_asm(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + self::asm::prime_field_asm_impl(input) +} + +#[proc_macro_derive(PrimeField, attributes(PrimeFieldModulus, PrimeFieldGenerator, OptimisticCIOSMultiplication, OptimisticCIOSSquaring))] +pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + // Parse the type definition + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + // The struct we're deriving for is a wrapper around a "Repr" type we must construct. + let repr_ident = fetch_wrapped_ident(&ast.data) + .expect("PrimeField derive only operates over tuple structs of a single item"); + + // We're given the modulus p of the prime field + let modulus: BigUint = fetch_attr("PrimeFieldModulus", &ast.attrs) + .expect("Please supply a PrimeFieldModulus attribute") + .parse() + .expect("PrimeFieldModulus should be a number"); + + // We may be provided with a generator of p - 1 order. It is required that this generator be quadratic + // nonresidue. + let generator: BigUint = fetch_attr("PrimeFieldGenerator", &ast.attrs) + .expect("Please supply a PrimeFieldGenerator attribute") + .parse() + .expect("PrimeFieldGenerator should be a number"); + + // User may opt-in for feature to generate CIOS based multiplication operation + let opt_in_cios_mul: Option = fetch_attr("OptimisticCIOSMultiplication", &ast.attrs) + .map(|el| el.parse().expect("OptimisticCIOSMultiplication should be `true` or `false`")); + + // User may opt-in for feature to generate CIOS based squaring operation + let opt_in_cios_square: Option = fetch_attr("OptimisticCIOSSquaring", &ast.attrs) + .map(|el| el.parse().expect("OptimisticCIOSSquaring should be `true` or `false`")); + + // The arithmetic in this library only works if the modulus*2 is smaller than the backing + // representation. Compute the number of limbs we need. + let mut limbs = 1; + { + let mod2 = (&modulus) << 1; // modulus * 2 + let mut cur = BigUint::one() << 64; // always 64-bit limbs for now + while cur < mod2 { + limbs += 1; + cur = cur << 64; + } + } + + let modulus_limbs = biguint_to_real_u64_vec(modulus.clone(), limbs); + let top_limb = modulus_limbs.last().unwrap().clone().to_u64().unwrap(); + let can_use_optimistic_cios_mul = { + let mut can_use = if let Some(cios) = opt_in_cios_mul { + cios + } else { + false + }; + if top_limb == 0 { + can_use = false; + } + + if top_limb > (std::u64::MAX / 2) - 1 { + can_use = false; + } + can_use + }; + + let can_use_optimistic_cios_sqr = { + let mut can_use = if let Some(cios) = opt_in_cios_square { + cios + } else { + false + }; + if top_limb == 0 { + can_use = false; + } + + if top_limb > (std::u64::MAX / 4) - 1 { + assert!(!can_use, "can not use optimistic CIOS for this modulus"); + can_use = false; + } + can_use + }; + + let mut gen = proc_macro2::TokenStream::new(); + + let (constants_impl, sqrt_impl) = prime_field_constants_and_sqrt( + &ast.ident, + &repr_ident, + modulus, + limbs, + generator, + ); + + gen.extend(constants_impl); + gen.extend(prime_field_repr_impl(&repr_ident, limbs)); + gen.extend(prime_field_impl(&ast.ident, &repr_ident, can_use_optimistic_cios_mul, can_use_optimistic_cios_sqr, limbs)); + gen.extend(sqrt_impl); + + // Return the generated impl + gen.into() +} + +/// Fetches the ident being wrapped by the type we're deriving. +fn fetch_wrapped_ident(body: &syn::Data) -> Option { + match body { + &syn::Data::Struct(ref variant_data) => match variant_data.fields { + syn::Fields::Unnamed(ref fields) => { + if fields.unnamed.len() == 1 { + match fields.unnamed[0].ty { + syn::Type::Path(ref path) => { + if path.path.segments.len() == 1 { + return Some(path.path.segments[0].ident.clone()); + } + } + _ => {} + } + } + } + _ => {} + }, + _ => {} + }; + + None +} + +/// Fetch an attribute string from the derived struct. +fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option { + for attr in attrs { + if let Ok(meta) = attr.parse_meta() { + match meta { + syn::Meta::NameValue(nv) => { + if nv.path.is_ident(name) { + match nv.lit { + syn::Lit::Str(ref s) => return Some(s.value()), + _ => { + panic!("attribute {} should be a string", name); + } + } + } + } + _ => { + panic!("attribute {} should be a string", name); + } + } + } + } + + None +} + +// Implement PrimeFieldRepr for the wrapped ident `repr` with `limbs` limbs. +fn prime_field_repr_impl(repr: &syn::Ident, limbs: usize) -> proc_macro2::TokenStream { + quote! { + + #[derive(Copy, Clone, PartialEq, Eq, Default, ::serde::Serialize, ::serde::Deserialize)] + pub struct #repr( + pub [u64; #limbs] + ); + + impl ::std::fmt::Debug for #repr + { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "0x")?; + for i in self.0.iter().rev() { + write!(f, "{:016x}", *i)?; + } + + Ok(()) + } + } + + impl ::rand::Rand for #repr { + #[inline(always)] + fn rand(rng: &mut R) -> Self { + #repr(rng.gen()) + } + } + + impl ::std::fmt::Display for #repr { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "0x")?; + for i in self.0.iter().rev() { + write!(f, "{:016x}", *i)?; + } + + Ok(()) + } + } + + impl std::hash::Hash for #repr { + fn hash(&self, state: &mut H) { + for limb in self.0.iter() { + limb.hash(state); + } + } + } + + impl AsRef<[u64]> for #repr { + #[inline(always)] + fn as_ref(&self) -> &[u64] { + &self.0 + } + } + + impl AsMut<[u64]> for #repr { + #[inline(always)] + fn as_mut(&mut self) -> &mut [u64] { + &mut self.0 + } + } + + impl From for #repr { + #[inline(always)] + fn from(val: u64) -> #repr { + use std::default::Default; + + let mut repr = Self::default(); + repr.0[0] = val; + repr + } + } + + impl Ord for #repr { + #[inline(always)] + fn cmp(&self, other: &#repr) -> ::std::cmp::Ordering { + for (a, b) in self.0.iter().rev().zip(other.0.iter().rev()) { + if a < b { + return ::std::cmp::Ordering::Less + } else if a > b { + return ::std::cmp::Ordering::Greater + } + } + + ::std::cmp::Ordering::Equal + } + } + + impl PartialOrd for #repr { + #[inline(always)] + fn partial_cmp(&self, other: &#repr) -> Option<::std::cmp::Ordering> { + Some(self.cmp(other)) + } + } + + impl crate::ff::PrimeFieldRepr for #repr { + #[inline(always)] + fn is_odd(&self) -> bool { + self.0[0] & 1 == 1 + } + + #[inline(always)] + fn is_even(&self) -> bool { + !self.is_odd() + } + + #[inline(always)] + fn is_zero(&self) -> bool { + self.0.iter().all(|&e| e == 0) + } + + #[inline(always)] + fn shr(&mut self, mut n: u32) { + if n as usize >= 64 * #limbs { + *self = Self::from(0); + return; + } + + while n >= 64 { + let mut t = 0; + for i in self.0.iter_mut().rev() { + ::std::mem::swap(&mut t, i); + } + n -= 64; + } + + if n > 0 { + let mut t = 0; + for i in self.0.iter_mut().rev() { + let t2 = *i << (64 - n); + *i >>= n; + *i |= t; + t = t2; + } + } + } + + #[inline(always)] + fn div2(&mut self) { + let mut t = 0; + for i in self.0.iter_mut().rev() { + let t2 = *i << 63; + *i >>= 1; + *i |= t; + t = t2; + } + } + + #[inline(always)] + fn mul2(&mut self) { + let mut last = 0; + for i in &mut self.0 { + let tmp = *i >> 63; + *i <<= 1; + *i |= last; + last = tmp; + } + } + + #[inline(always)] + fn shl(&mut self, mut n: u32) { + if n as usize >= 64 * #limbs { + *self = Self::from(0); + return; + } + + while n >= 64 { + let mut t = 0; + for i in &mut self.0 { + ::std::mem::swap(&mut t, i); + } + n -= 64; + } + + if n > 0 { + let mut t = 0; + for i in &mut self.0 { + let t2 = *i >> (64 - n); + *i <<= n; + *i |= t; + t = t2; + } + } + } + + #[inline(always)] + fn num_bits(&self) -> u32 { + let mut ret = (#limbs as u32) * 64; + for i in self.0.iter().rev() { + let leading = i.leading_zeros(); + ret -= leading; + if leading != 64 { + break; + } + } + + ret + } + + #[inline(always)] + fn add_nocarry(&mut self, other: &#repr) { + let mut carry = 0; + + for (a, b) in self.0.iter_mut().zip(other.0.iter()) { + *a = crate::ff::adc(*a, *b, &mut carry); + } + } + + #[inline(always)] + fn sub_noborrow(&mut self, other: &#repr) { + let mut borrow = 0; + + for (a, b) in self.0.iter_mut().zip(other.0.iter()) { + *a = crate::ff::sbb(*a, *b, &mut borrow); + } + } + } + } +} + +fn prime_field_constants_and_sqrt( + name: &syn::Ident, + repr: &syn::Ident, + modulus: BigUint, + limbs: usize, + generator: BigUint, +) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { + let modulus_num_bits = biguint_num_bits(modulus.clone()); + + // The number of bits we should "shave" from a randomly sampled reputation, i.e., + // if our modulus is 381 bits and our representation is 384 bits, we should shave + // 3 bits from the beginning of a randomly sampled 384 bit representation to + // reduce the cost of rejection sampling. + let repr_shave_bits = (64 * limbs as u32) - biguint_num_bits(modulus.clone()); + let repr_shave_mask = if repr_shave_bits == 64 { + 0u64 + } else { + 0xffffffffffffffffu64 >> repr_shave_bits + }; + + // Compute R = 2**(64 * limbs) mod m + let r = (BigUint::one() << (limbs * 64)) % &modulus; + + // modulus - 1 = 2^s * t + let mut s: u32 = 0; + let mut t = &modulus - BigUint::from_str("1").unwrap(); + while t.is_even() { + t = t >> 1; + s += 1; + } + + // Compute 2^s root of unity given the generator + let root_of_unity = biguint_to_u64_vec( + (generator.clone().modpow(&t, &modulus) * &r) % &modulus, + limbs, + ); + let generator = biguint_to_u64_vec((generator.clone() * &r) % &modulus, limbs); + + let mod_minus_1_over_2 = + biguint_to_u64_vec((&modulus - BigUint::from_str("1").unwrap()) >> 1, limbs); + let legendre_impl = quote!{ + fn legendre(&self) -> crate::ff::LegendreSymbol { + // s = self^((modulus - 1) // 2) + let s = self.pow(#mod_minus_1_over_2); + if s == Self::zero() { + crate::ff::LegendreSymbol::Zero + } else if s == Self::one() { + crate::ff::LegendreSymbol::QuadraticResidue + } else { + crate::ff::LegendreSymbol::QuadraticNonResidue + } + } + }; + + let sqrt_impl = + if (&modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() { + let mod_minus_3_over_4 = + biguint_to_u64_vec((&modulus - BigUint::from_str("3").unwrap()) >> 2, limbs); + + // Compute -R as (m - r) + let rneg = biguint_to_u64_vec(&modulus - &r, limbs); + + quote!{ + impl crate::ff::SqrtField for #name { + #legendre_impl + + fn sqrt(&self) -> Option { + // Shank's algorithm for q mod 4 = 3 + // https://eprint.iacr.org/2012/685.pdf (page 9, algorithm 2) + + let mut a1 = self.pow(#mod_minus_3_over_4); + + let mut a0 = a1; + a0.square(); + a0.mul_assign(self); + + if a0.0 == #repr(#rneg) { + None + } else { + a1.mul_assign(self); + Some(a1) + } + } + } + } + } else if (&modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() { + let t_plus_1_over_2 = biguint_to_u64_vec((&t + BigUint::one()) >> 1, limbs); + let t = biguint_to_u64_vec(t.clone(), limbs); + + quote!{ + impl crate::ff::SqrtField for #name { + #legendre_impl + + fn sqrt(&self) -> Option { + // Tonelli-Shank's algorithm for q mod 16 = 1 + // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5) + + match self.legendre() { + crate::ff::LegendreSymbol::Zero => Some(*self), + crate::ff::LegendreSymbol::QuadraticNonResidue => None, + crate::ff::LegendreSymbol::QuadraticResidue => { + let mut c = #name(ROOT_OF_UNITY); + let mut r = self.pow(#t_plus_1_over_2); + let mut t = self.pow(#t); + let mut m = S; + + while t != Self::one() { + let mut i = 1; + { + let mut t2i = t; + t2i.square(); + loop { + if t2i == Self::one() { + break; + } + t2i.square(); + i += 1; + } + } + + for _ in 0..(m - i - 1) { + c.square(); + } + r.mul_assign(&c); + c.square(); + t.mul_assign(&c); + m = i; + } + + Some(r) + } + } + } + } + } + } else { + quote!{} + }; + + // Compute R^2 mod m + let r2 = biguint_to_u64_vec((&r * &r) % &modulus, limbs); + + let r = biguint_to_u64_vec(r, limbs); + let modulus = biguint_to_real_u64_vec(modulus, limbs); + + // Compute -m^-1 mod 2**64 by exponentiating by totient(2**64) - 1 + let mut inv = 1u64; + for _ in 0..63 { + inv = inv.wrapping_mul(inv); + inv = inv.wrapping_mul(modulus[0]); + } + inv = inv.wrapping_neg(); + + (quote! { + /// This is the modulus m of the prime field + const MODULUS: #repr = #repr([#(#modulus,)*]); + + /// The number of bits needed to represent the modulus. + const MODULUS_BITS: u32 = #modulus_num_bits; + + /// The number of bits that must be shaved from the beginning of + /// the representation when randomly sampling. + const REPR_SHAVE_BITS: u32 = #repr_shave_bits; + + /// Precalculated mask to shave bits from the top limb in random sampling + const TOP_LIMB_SHAVE_MASK: u64 = #repr_shave_mask; + + /// 2^{limbs*64} mod m + const R: #repr = #repr(#r); + + /// 2^{limbs*64*2} mod m + const R2: #repr = #repr(#r2); + + /// -(m^{-1} mod m) mod m + const INV: u64 = #inv; + + /// Multiplicative generator of `MODULUS` - 1 order, also quadratic + /// nonresidue. + const GENERATOR: #repr = #repr(#generator); + + /// 2^s * t = MODULUS - 1 with t odd + const S: u32 = #s; + + /// 2^s root of unity computed by GENERATOR^t + const ROOT_OF_UNITY: #repr = #repr(#root_of_unity); + }, sqrt_impl) +} + +// Returns r{n} as an ident. +fn get_temp(n: usize) -> syn::Ident { + syn::Ident::new(&format!("r{}", n), proc_macro2::Span::call_site()) +} + +fn get_temp_with_literal(literal: &str, n: usize) -> syn::Ident { + syn::Ident::new(&format!("{}{}", literal, n), proc_macro2::Span::call_site()) +} + +/// Implement PrimeField for the derived type. +fn prime_field_impl( + name: &syn::Ident, + repr: &syn::Ident, + can_use_cios_mul: bool, + can_use_cios_sqr: bool, + limbs: usize, +) -> proc_macro2::TokenStream { + + // The parameter list for the mont_reduce() internal method. + // r0: u64, mut r1: u64, mut r2: u64, ... + let mut mont_paramlist = proc_macro2::TokenStream::new(); + mont_paramlist.append_separated( + (0..(limbs * 2)).map(|i| (i, get_temp(i))).map(|(i, x)| { + if i != 0 { + quote!{mut #x: u64} + } else { + quote!{#x: u64} + } + }), + proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone), + ); + + // Implement montgomery reduction for some number of limbs + fn mont_impl(limbs: usize) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + + for i in 0..limbs { + { + let temp = get_temp(i); + gen.extend(quote!{ + let k = #temp.wrapping_mul(INV); + let mut carry = 0; + crate::ff::mac_with_carry(#temp, k, MODULUS.0[0], &mut carry); + }); + } + + for j in 1..limbs { + let temp = get_temp(i + j); + gen.extend(quote!{ + #temp = crate::ff::mac_with_carry(#temp, k, MODULUS.0[#j], &mut carry); + }); + } + + let temp = get_temp(i + limbs); + + if i == 0 { + gen.extend(quote!{ + #temp = crate::ff::adc(#temp, 0, &mut carry); + }); + } else { + gen.extend(quote!{ + #temp = crate::ff::adc(#temp, carry2, &mut carry); + }); + } + + if i != (limbs - 1) { + gen.extend(quote!{ + let carry2 = carry; + }); + } + } + + for i in 0..limbs { + let temp = get_temp(limbs + i); + + gen.extend(quote!{ + (self.0).0[#i] = #temp; + }); + } + + gen + } + + fn sqr_impl(a: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + + for i in 0..(limbs - 1) { + gen.extend(quote!{ + let mut carry = 0; + }); + + for j in (i + 1)..limbs { + let temp = get_temp(i + j); + if i == 0 { + gen.extend(quote!{ + let #temp = crate::ff::mac_with_carry(0, (#a.0).0[#i], (#a.0).0[#j], &mut carry); + }); + } else { + gen.extend(quote!{ + let #temp = crate::ff::mac_with_carry(#temp, (#a.0).0[#i], (#a.0).0[#j], &mut carry); + }); + } + } + + let temp = get_temp(i + limbs); + + gen.extend(quote!{ + let #temp = carry; + }); + } + + if limbs != 1 { + for i in 1..(limbs * 2) { + let temp0 = get_temp(limbs * 2 - i); + let temp1 = get_temp(limbs * 2 - i - 1); + + if i == 1 { + gen.extend(quote!{ + let #temp0 = #temp1 >> 63; + }); + } else if i == (limbs * 2 - 1) { + gen.extend(quote!{ + let #temp0 = #temp0 << 1; + }); + } else { + gen.extend(quote!{ + let #temp0 = (#temp0 << 1) | (#temp1 >> 63); + }); + } + } + } else { + gen.extend(quote!{ + let r1 = 0; + }); + } + + gen.extend(quote!{ + let mut carry = 0; + }); + + for i in 0..limbs { + let temp0 = get_temp(i * 2); + let temp1 = get_temp(i * 2 + 1); + if i == 0 { + gen.extend(quote!{ + let #temp0 = crate::ff::mac_with_carry(0, (#a.0).0[#i], (#a.0).0[#i], &mut carry); + }); + } else { + gen.extend(quote!{ + let #temp0 = crate::ff::mac_with_carry(#temp0, (#a.0).0[#i], (#a.0).0[#i], &mut carry); + }); + } + + gen.extend(quote!{ + let #temp1 = crate::ff::adc(#temp1, 0, &mut carry); + }); + } + + let mut mont_calling = proc_macro2::TokenStream::new(); + mont_calling.append_separated( + (0..(limbs * 2)).map(|i| get_temp(i)), + proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone), + ); + + gen.extend(quote!{ + self.mont_reduce(#mont_calling); + }); + + gen + } + + fn mul_impl( + a: proc_macro2::TokenStream, + b: proc_macro2::TokenStream, + limbs: usize, + ) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + + for i in 0..limbs { + gen.extend(quote!{ + let mut carry = 0; + }); + + for j in 0..limbs { + let temp = get_temp(i + j); + + if i == 0 { + gen.extend(quote!{ + let #temp = crate::ff::mac_with_carry(0, (#a.0).0[#i], (#b.0).0[#j], &mut carry); + }); + } else { + gen.extend(quote!{ + let #temp = crate::ff::mac_with_carry(#temp, (#a.0).0[#i], (#b.0).0[#j], &mut carry); + }); + } + } + + let temp = get_temp(i + limbs); + + gen.extend(quote!{ + let #temp = carry; + }); + } + + let mut mont_calling = proc_macro2::TokenStream::new(); + mont_calling.append_separated( + (0..(limbs * 2)).map(|i| get_temp(i)), + proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone), + ); + + gen.extend(quote!{ + self.mont_reduce(#mont_calling); + }); + + gen + } + + fn optimistic_cios_mul_impl( + a: proc_macro2::TokenStream, + b: proc_macro2::TokenStream, + name: &syn::Ident, + repr: &syn::Ident, + limbs: usize, + ) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + + let mut other_limbs_set = proc_macro2::TokenStream::new(); + other_limbs_set.append_separated( + (0..limbs).map(|i| get_temp_with_literal("b", i)), + proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone), + ); + + gen.extend(quote!{ + let [#other_limbs_set] = (#b.0).0; + }); + + for i in 0..limbs { + gen.extend(quote!{ + let a = (#a.0).0[#i]; + }); + + let temp = get_temp(0); + + let b = get_temp_with_literal("b", 0); + + if i == 0 { + gen.extend(quote!{ + let (#temp, carry) = crate::ff::full_width_mul(a, #b); + }); + } else { + gen.extend(quote!{ + let (#temp, carry) = crate::ff::mac_by_value(#temp, a, #b); + }); + } + gen.extend(quote!{ + let m = r0.wrapping_mul(INV); + let red_carry = crate::ff::mac_by_value_return_carry_only(#temp, m, MODULUS.0[0]); + }); + + for j in 1..limbs { + let temp = get_temp(j); + + let b = get_temp_with_literal("b", j); + + if i == 0 { + gen.extend(quote!{ + let (#temp, carry) = crate::ff::mac_by_value(carry, a, #b); + }); + } else { + gen.extend(quote!{ + let (#temp, carry) = crate::ff::mac_with_carry_by_value(#temp, a, #b, carry); + }); + } + + let temp_prev = get_temp(j-1); + + gen.extend(quote!{ + let (#temp_prev, red_carry) = crate::ff::mac_with_carry_by_value(#temp, m, MODULUS.0[#j], red_carry); + }); + } + + let temp = get_temp(limbs-1); + gen.extend(quote!{ + let #temp = red_carry + carry; + }); + } + + let mut limbs_set = proc_macro2::TokenStream::new(); + limbs_set.append_separated( + (0..limbs).map(|i| get_temp(i)), + proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone), + ); + + gen.extend(quote!{ + *self = #name(#repr([#limbs_set])); + self.reduce(); + }); + + gen + } + + fn optimistic_cios_sqr_impl( + a: proc_macro2::TokenStream, + name: &syn::Ident, + repr: &syn::Ident, + limbs: usize, + ) -> proc_macro2::TokenStream { + let mut gen = proc_macro2::TokenStream::new(); + + let mut this_limbs_set = proc_macro2::TokenStream::new(); + this_limbs_set.append_separated( + (0..limbs).map(|i| get_temp_with_literal("a", i)), + proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone), + ); + + gen.extend(quote!{ + let [#this_limbs_set] = (#a.0).0; + }); + + for i in 0..limbs { + for red_idx in 0..i { + if red_idx == 0 { + let temp = get_temp(0); + + gen.extend(quote!{ + let m = r0.wrapping_mul(INV); + let red_carry = crate::ff::mac_by_value_return_carry_only(#temp, m, MODULUS.0[0]); + }); + } else { + let temp = get_temp(red_idx); + let temp_prev = get_temp(red_idx-1); + gen.extend(quote!{ + let (#temp_prev, red_carry) = crate::ff::mac_with_carry_by_value(#temp, m, MODULUS.0[#red_idx], red_carry); + }); + + } + } + let a = get_temp_with_literal("a", i); + + // single square step + if i == 0 { + // for a first pass just square and reduce + let temp = get_temp(0); + + gen.extend(quote!{ + let (#temp, carry) = crate::ff::full_width_mul(#a, #a); + let m = r0.wrapping_mul(INV); + let red_carry = crate::ff::mac_by_value_return_carry_only(#temp, m, MODULUS.0[0]); + }); + } else { + // for next passes square, add previous value and reduce + let temp = get_temp(i); + let temp_prev = get_temp(i-1); + gen.extend(quote!{ + let (#temp, carry) = crate::ff::mac_by_value(#temp, #a, #a); + + }); + + if i == limbs - 1 { + gen.extend(quote!{ + let (#temp_prev, #temp) = crate::ff::mac_with_low_and_high_carry_by_value( + red_carry, m, MODULUS.0[#i], #temp, carry + ); + }); + } else { + gen.extend(quote!{ + let (#temp_prev, red_carry) = crate::ff::mac_with_carry_by_value(#temp, m, MODULUS.0[#i], red_carry); + }); + } + } + + // continue with propagation and reduction + for j in (i+1)..limbs { + let b = get_temp_with_literal("a", j); + + let temp = get_temp(j); + + if i == 0 { + if j == limbs - 1 { + let temp_prev = get_temp(j-1); + + gen.extend(quote!{ + let (#temp, carry) = crate::ff::mul_double_add_low_and_high_carry_by_value_ignore_superhi( + #a, #b, carry, superhi + ); + + let (#temp_prev, #temp) = crate::ff::mac_with_low_and_high_carry_by_value( + red_carry, m, MODULUS.0[#j], #temp, carry + ); + }); + } else { + if j == i+1 { + gen.extend(quote!{ + let (#temp, carry, superhi) = crate::ff::mul_double_add_by_value( + carry, #a, #b, + ); + }); + } else { + gen.extend(quote!{ + let (#temp, carry, superhi) = crate::ff::mul_double_add_low_and_high_carry_by_value( + #a, #b, carry, superhi + ); + }); + } + + let temp_prev = get_temp(j-1); + + gen.extend(quote!{ + let (#temp_prev, red_carry) = crate::ff::mac_with_carry_by_value(#temp, m, MODULUS.0[#j], red_carry); + }); + } + } else { + if j == limbs - 1 { + let temp_prev = get_temp(j-1); + + if j == i+1 { + gen.extend(quote!{ + let (#temp, carry) = crate::ff::mul_double_add_add_carry_by_value_ignore_superhi( + #temp, #a, #b, carry + ); + }); + } else { + gen.extend(quote!{ + let (#temp, carry) = crate::ff::mul_double_add_add_low_and_high_carry_by_value_ignore_superhi( + #temp, #a, #b, carry, superhi + ); + }); + + } + + gen.extend(quote!{ + let (#temp_prev, #temp) = crate::ff::mac_with_low_and_high_carry_by_value( + red_carry, m, MODULUS.0[#j], #temp, carry + ); + }); + } else { + if j == i+1 { + gen.extend(quote!{ + let (#temp, carry, superhi) = crate::ff::mul_double_add_add_carry_by_value( + #temp, #a, #b, carry + ); + }); + } else { + gen.extend(quote!{ + let (#temp, carry, superhi) = crate::ff::mul_double_add_add_low_and_high_carry_by_value_ignore_superhi( + #temp, #a, #b, carry, superhi + ); + }); + } + let temp_prev = get_temp(j-1); + + gen.extend(quote!{ + let (#temp_prev, red_carry) = mac_with_carry_by_value(#temp, m, MODULUS.0[#j], red_carry); + }); + } + } + + } + + // let temp = get_temp(limbs-1); + + // gen.extend(quote!{ + // let #temp = red_carry + carry; + // }); + } + + let mut limbs_set = proc_macro2::TokenStream::new(); + limbs_set.append_separated( + (0..limbs).map(|i| get_temp(i)), + proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone), + ); + + gen.extend(quote!{ + *self = #name(#repr([#limbs_set])); + self.reduce(); + }); + + gen + } + let multiply_impl = if can_use_cios_mul { + optimistic_cios_mul_impl(quote!{self}, quote!{other}, name, repr, limbs) + } else { + mul_impl(quote!{self}, quote!{other}, limbs) + }; + let squaring_impl = if can_use_cios_sqr { + optimistic_cios_sqr_impl(quote!{self}, name, repr, limbs) + } else { + sqr_impl(quote!{self}, limbs) + }; + + let top_limb_index = limbs - 1; + + let montgomery_impl = mont_impl(limbs); + + // (self.0).0[0], (self.0).0[1], ..., 0, 0, 0, 0, ... + let mut into_repr_params = proc_macro2::TokenStream::new(); + into_repr_params.append_separated( + (0..limbs) + .map(|i| quote!{ (self.0).0[#i] }) + .chain((0..limbs).map(|_| quote!{0})), + proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone), + ); + + quote!{ + impl ::std::marker::Copy for #name { } + + impl ::std::clone::Clone for #name { + fn clone(&self) -> #name { + *self + } + } + + impl ::std::cmp::PartialEq for #name { + fn eq(&self, other: &#name) -> bool { + self.0 == other.0 + } + } + + impl ::std::cmp::Eq for #name { } + + impl ::std::fmt::Debug for #name + { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{}({:?})", stringify!(#name), self.into_repr()) + } + } + + /// Elements are ordered lexicographically. + impl Ord for #name { + #[inline(always)] + fn cmp(&self, other: &#name) -> ::std::cmp::Ordering { + self.into_repr().cmp(&other.into_repr()) + } + } + + impl PartialOrd for #name { + #[inline(always)] + fn partial_cmp(&self, other: &#name) -> Option<::std::cmp::Ordering> { + Some(self.cmp(other)) + } + } + + impl ::std::fmt::Display for #name { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{}({})", stringify!(#name), self.into_repr()) + } + } + + impl ::rand::Rand for #name { + /// Computes a uniformly random element using rejection sampling. + fn rand(rng: &mut R) -> Self { + loop { + let mut tmp = #name(#repr::rand(rng)); + + // Mask away the unused bits at the beginning. + tmp.0.as_mut()[#top_limb_index] &= TOP_LIMB_SHAVE_MASK; + + if tmp.is_valid() { + return tmp + } + } + } + } + + impl From<#name> for #repr { + fn from(e: #name) -> #repr { + e.into_repr() + } + } + + impl crate::ff::PrimeField for #name { + type Repr = #repr; + + fn from_repr(r: #repr) -> Result<#name, crate::ff::PrimeFieldDecodingError> { + let mut r = #name(r); + if r.is_valid() { + r.mul_assign(&#name(R2)); + + Ok(r) + } else { + Err(crate::ff::PrimeFieldDecodingError::NotInField(format!("{}", r.0))) + } + } + + fn from_raw_repr(r: #repr) -> Result { + let mut r = #name(r); + if r.is_valid() { + Ok(r) + } else { + Err(crate::ff::PrimeFieldDecodingError::NotInField(format!("{}", r.0))) + } + } + + fn into_repr(&self) -> #repr { + let mut r = *self; + r.mont_reduce( + #into_repr_params + ); + + r.0 + } + + fn into_raw_repr(&self) -> #repr { + let r = *self; + + r.0 + } + + fn char() -> #repr { + MODULUS + } + + const NUM_BITS: u32 = MODULUS_BITS; + + const CAPACITY: u32 = Self::NUM_BITS - 1; + + fn multiplicative_generator() -> Self { + #name(GENERATOR) + } + + const S: u32 = S; + + fn root_of_unity() -> Self { + #name(ROOT_OF_UNITY) + } + + } + + impl crate::ff::Field for #name { + #[inline] + fn zero() -> Self { + #name(#repr::from(0)) + } + + #[inline] + fn one() -> Self { + #name(R) + } + + #[inline] + fn is_zero(&self) -> bool { + self.0.is_zero() + } + + #[inline] + fn add_assign(&mut self, other: &#name) { + // This cannot exceed the backing capacity. + self.0.add_nocarry(&other.0); + + // However, it may need to be reduced. + self.reduce(); + } + + #[inline] + fn double(&mut self) { + // This cannot exceed the backing capacity. + self.0.mul2(); + + // However, it may need to be reduced. + self.reduce(); + } + + #[inline] + fn sub_assign(&mut self, other: &#name) { + // If `other` is larger than `self`, we'll need to add the modulus to self first. + if other.0 > self.0 { + self.0.add_nocarry(&MODULUS); + } + + self.0.sub_noborrow(&other.0); + } + + #[inline] + fn negate(&mut self) { + if !self.is_zero() { + let mut tmp = MODULUS; + tmp.sub_noborrow(&self.0); + self.0 = tmp; + } + } + + fn inverse(&self) -> Option { + if self.is_zero() { + None + } else { + // Guajardo Kumar Paar Pelzl + // Efficient Software-Implementation of Finite Fields with Applications to Cryptography + // Algorithm 16 (BEA for Inversion in Fp) + + let one = #repr::from(1); + + let mut u = self.0; + let mut v = MODULUS; + let mut b = #name(R2); // Avoids unnecessary reduction step. + let mut c = Self::zero(); + + while u != one && v != one { + while u.is_even() { + u.div2(); + + if b.0.is_even() { + b.0.div2(); + } else { + b.0.add_nocarry(&MODULUS); + b.0.div2(); + } + } + + while v.is_even() { + v.div2(); + + if c.0.is_even() { + c.0.div2(); + } else { + c.0.add_nocarry(&MODULUS); + c.0.div2(); + } + } + + if v < u { + u.sub_noborrow(&v); + b.sub_assign(&c); + } else { + v.sub_noborrow(&u); + c.sub_assign(&b); + } + } + + if u == one { + Some(b) + } else { + Some(c) + } + } + } + + #[inline(always)] + fn frobenius_map(&mut self, _: usize) { + // This has no effect in a prime field. + } + + #[inline] + fn mul_assign(&mut self, other: &#name) + { + #multiply_impl + } + + #[inline] + fn square(&mut self) + { + #squaring_impl + } + } + + impl std::default::Default for #name { + fn default() -> Self { + Self::zero() + } + } + + impl std::hash::Hash for #name { + fn hash(&self, state: &mut H) { + for limb in self.0.as_ref().iter() { + limb.hash(state); + } + } + } + + impl #name { + /// Determines if the element is really in the field. This is only used + /// internally. + #[inline(always)] + fn is_valid(&self) -> bool { + self.0 < MODULUS + } + + /// Subtracts the modulus from this element if this element is not in the + /// field. Only used interally. + #[inline(always)] + fn reduce(&mut self) { + if !self.is_valid() { + self.0.sub_noborrow(&MODULUS); + } + } + + #[inline(always)] + fn mont_reduce( + &mut self, + #mont_paramlist + ) + { + // The Montgomery reduction here is based on Algorithm 14.32 in + // Handbook of Applied Cryptography + // . + + #montgomery_impl + + self.reduce(); + } + } + + impl ::serde::Serialize for #name { + fn serialize(&self, serializer: S) -> Result + where S: ::serde::Serializer + { + let repr = self.into_repr(); + repr.serialize(serializer) + } + } + + impl<'de> ::serde::Deserialize<'de> for #name { + fn deserialize(deserializer: D) -> Result + where D: ::serde::Deserializer<'de> + { + let repr = #repr::deserialize(deserializer)?; + let new = Self::from_repr(repr).expect("serialized representation is expected to be valid"); + + Ok(new) + } + } + } +} diff --git a/crates/ff/ff_derive/src/utils.rs b/crates/ff/ff_derive/src/utils.rs new file mode 100644 index 0000000..e6d91df --- /dev/null +++ b/crates/ff/ff_derive/src/utils.rs @@ -0,0 +1,38 @@ +use super::*; + +/// Convert BigUint into a vector of 64-bit limbs. +pub(crate) fn biguint_to_real_u64_vec(mut v: BigUint, limbs: usize) -> Vec { + let m = BigUint::one() << 64; + let mut ret = vec![]; + + while v > BigUint::zero() { + let rem: BigUint = &v % &m; + ret.push(rem.to_u64().unwrap()); + v = v >> 64; + } + + while ret.len() < limbs { + ret.push(0); + } + + assert!(ret.len() == limbs); + + ret +} + +/// Convert BigUint into a tokenized vector of 64-bit limbs. +pub(crate) fn biguint_to_u64_vec(v: BigUint, limbs: usize) -> proc_macro2::TokenStream { + let ret = biguint_to_real_u64_vec(v, limbs); + quote!([#(#ret,)*]) +} + +pub(crate) fn biguint_num_bits(mut v: BigUint) -> u32 { + let mut bits = 0; + + while v != BigUint::zero() { + v = v >> 1; + bits += 1; + } + + bits +} \ No newline at end of file diff --git a/crates/ff/ff_derive_const/Cargo.toml b/crates/ff/ff_derive_const/Cargo.toml new file mode 100644 index 0000000..9e23da1 --- /dev/null +++ b/crates/ff/ff_derive_const/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "ff_derive_const_ce" +version = "0.1.0" +authors = ["Alex Vlasov "] +description = "Procedural macro library used to build custom prime field implementations using const generics" +documentation = "https://docs.rs/ff/" +homepage = "https://github.com/matter-labs/ff" +license = "MIT/Apache-2.0" +repository = "https://github.com/matter-labs/ff" +edition = "2018" + +[lib] +#proc-macro = true + +[dependencies] +num-bigint = "0.2" +num-traits = "0.2" +num-integer = "0.1" +proc-macro2 = "0.4" +quote = "0.6" +syn = "0.14" +serde = "1.0.80" +hex = "0.3.2" +ff = { package = "ff_ce", version = "0.6" } +rand = "0.4" +crunchy = "0.2" + +[features] +default = [] +derive_serde = [] diff --git a/crates/ff/ff_derive_const/src/const_field_element.rs b/crates/ff/ff_derive_const/src/const_field_element.rs new file mode 100644 index 0000000..4a62084 --- /dev/null +++ b/crates/ff/ff_derive_const/src/const_field_element.rs @@ -0,0 +1,350 @@ +use crate::const_repr::BigintRepresentation; +use crate::const_repr::FullMultiplication; + +use ff::*; + +// #[macro_use] +// use crunchy::*; + +pub struct PrimeFieldElement< + P, + const N: usize +>(pub BigintRepresentation<{N}>, std::marker::PhantomData

) + where BigintRepresentation<{N}>: FullMultiplication, + P: FieldParameters<{N}>, + [u64; N]: std::array::LengthAtMost32; + +pub trait FieldParameters: Sized + Copy + Send + Sync + 'static { + const NUM_BITS: u32; + const CAPACITY: u32; + const REPR_SHAVE_BITS: u32; + const S: u32; + const MULTIPLICATIVE_GENERATOR: BigintRepresentation<{N}>; + const ROOT_OF_UNITY: BigintRepresentation<{N}>; + const MODULUS: BigintRepresentation<{N}>; + const R: BigintRepresentation<{N}>; + const R2: BigintRepresentation<{N}>; + const INV: u64; +} + +impl PrimeFieldElement + where P: FieldParameters<{N}>, + [u64; N]: std::array::LengthAtMost32 +{ + #[inline(always)] + fn is_valid(&self) -> bool { + self.0 < P::MODULUS + } + + #[inline(always)] + fn reduce(&mut self) { + if !self.is_valid() { + self.0.sub_noborrow(&P::MODULUS); + } + } + + #[inline(always)] + // fn mont_reduce(&mut self, mut mul_res: BigintRepresentation<{N*2}>) { + fn mont_reduce(&mut self, mut mul_res:< BigintRepresentation<{N}> as FullMultiplication >::MulResult) { + let mut carry2 = 0u64; + let mut carry = 0u64; + for j in 0..N { + let k = mul_res.0[j].wrapping_mul(P::INV); + for i in 0..N { + mul_res.0[i + j] = ::ff::mac_with_carry(mul_res.0[i + j], k, P::MODULUS.0[i], &mut carry); + } + mul_res.0[N + j] = ::ff::adc(mul_res.0[{N} + j], carry2, &mut carry); + carry2 = carry; + carry = 0u64; + } + + for j in 0..N { + (self.0).0[j] = (mul_res.0)[N + j]; + } + + self.reduce(); + } +} + +impl Copy for PrimeFieldElement + where P: FieldParameters<{N}>, + [u64; N]: std::array::LengthAtMost32 {} + +impl Clone for PrimeFieldElement + where P: FieldParameters<{N}>, + [u64; N]: std::array::LengthAtMost32 { + fn clone(&self) -> Self { + *self + } +} + +impl std::cmp::PartialEq for PrimeFieldElement + where P: FieldParameters<{N}>, + [u64; N]: std::array::LengthAtMost32 { + fn eq(&self, other: &Self) -> bool { + self.0.cmp(&other.0) == std::cmp::Ordering::Equal + } +} + +impl std::cmp::Eq for PrimeFieldElement + where P: FieldParameters<{N}>, + [u64; N]: std::array::LengthAtMost32 {} + +impl std::fmt::Debug for PrimeFieldElement + where P: FieldParameters<{N}>, + [u64; N]: std::array::LengthAtMost32 +{ + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Fp({})", self.into_repr()) + } +} + +impl ::rand::Rand for PrimeFieldElement + where P: FieldParameters<{N}>, + [u64; N]: std::array::LengthAtMost32 { + + #[inline(always)] + fn rand(rng: &mut R) -> Self { + let s = BigintRepresentation::<{N}>::rand(rng); + // TODO: shave + Self(s, std::marker::PhantomData) + } +} + +impl std::fmt::Display for PrimeFieldElement + where P: FieldParameters<{N}>, + [u64; N]: std::array::LengthAtMost32 { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Fp({})", self.into_repr()) + } +} + +impl Field for PrimeFieldElement + where P: FieldParameters<{N}>, + [u64; N]: std::array::LengthAtMost32 +{ + #[inline(always)] + fn zero() -> Self { + Self(BigintRepresentation::<{N}>::default(), std::marker::PhantomData) + } + + #[inline(always)] + fn one() -> Self { + Self(P::R, std::marker::PhantomData) + } + + #[inline(always)] + fn is_zero(&self) -> bool { + self.0.is_zero() + } + + #[inline(always)] + fn add_assign(&mut self, other: &Self) { + self.0.add_nocarry(&other.0); + self.reduce(); + } + + #[inline(always)] + fn double(&mut self) { + self.0.mul2(); + self.reduce(); + } + + #[inline(always)] + fn sub_assign(&mut self, other: &Self) { + + if other.0 > self.0 { + self.0.add_nocarry(&P::MODULUS); + } + + self.0.sub_noborrow(&other.0); + } + + #[inline(always)] + fn negate(&mut self) { + if !self.is_zero() { + let mut tmp = P::MODULUS; + tmp.sub_noborrow(&self.0); + self.0 = tmp; + } + } + + fn inverse(&self) -> Option { + None + // if self.is_zero() { + // None + // } else { + // // Guajardo Kumar Paar Pelzl + // // Efficient Software-Implementation of Finite Fields with Applications to Cryptography + // // Algorithm 16 (BEA for Inversion in Fp) + + // let one = #repr::from(1); + + // let mut u = self.0; + // let mut v = MODULUS; + // let mut b = #name(R2); // Avoids unnecessary reduction step. + // let mut c = Self::zero(); + + // while u != one && v != one { + // while u.is_even() { + // u.div2(); + + // if b.0.is_even() { + // b.0.div2(); + // } else { + // b.0.add_nocarry(&MODULUS); + // b.0.div2(); + // } + // } + + // while v.is_even() { + // v.div2(); + + // if c.0.is_even() { + // c.0.div2(); + // } else { + // c.0.add_nocarry(&MODULUS); + // c.0.div2(); + // } + // } + + // if v < u { + // u.sub_noborrow(&v); + // b.sub_assign(&c); + // } else { + // v.sub_noborrow(&u); + // c.sub_assign(&b); + // } + // } + + // if u == one { + // Some(b) + // } else { + // Some(c) + // } + // } + } + + #[inline(always)] + fn frobenius_map(&mut self, _: usize) { + // This has no effect in a prime field. + } + + #[inline(always)] + fn mul_assign(&mut self, other: &Self) { + // let mut interm = BigintRepresentation::<{N*2}>::default(); + let mut interm = < BigintRepresentation<{N}> as FullMultiplication >::MulResult::default(); + let mut carry = 0u64; + for j in 0..N { + let this_limb = (self.0).0[j]; + for i in 0..N { + interm.0[i + j] = ::ff::mac_with_carry(interm.0[i + j], this_limb, (other.0).0[i], &mut carry); + } + interm.0[N + j] = carry; + carry = 0u64; + } + + self.mont_reduce(interm); + } + + #[inline(always)] + fn square(&mut self) { + // let mut interm = BigintRepresentation::<{N*2}>::default(); + let mut interm = < BigintRepresentation<{N}> as FullMultiplication >::MulResult::default(); + let mut carry = 0u64; + + for j in 0..N { + let this_limb = (self.0).0[j]; + for i in (j+1)..N { + interm.0[i + j] = ::ff::mac_with_carry(interm.0[i + j], this_limb, (self.0).0[i], &mut carry); + } + interm.0[N + j] = carry; + carry = 0u64; + } + + interm.0[2*N - 1] = interm.0[2*N - 2] >> 63; + + for j in (2..=(2*N - 2)).rev() { + interm.0[j] = (interm.0[j] << 1) | (interm.0[j-1] >> 63); + } + + interm.0[1] = interm.0[1] << 1; + for j in 0..N { + let this_limb = (self.0).0[j]; + let idx = 2*j; + interm.0[idx] = ::ff::mac_with_carry(interm.0[idx], this_limb, this_limb, &mut carry); + interm.0[idx+1] = ::ff::adc(interm.0[idx+1], 0u64, &mut carry); + } + + self.mont_reduce(interm); + } +} + +impl From> for BigintRepresentation<{N}> + where P: FieldParameters<{N}>, + [u64; N]: std::array::LengthAtMost32 +{ + fn from(e: PrimeFieldElement) -> Self { + e.into_repr() + } +} + +impl PrimeField for PrimeFieldElement + where P: FieldParameters<{N}>, + [u64; N]: std::array::LengthAtMost32 +{ + const NUM_BITS: u32 = P::NUM_BITS; + const CAPACITY: u32 = P::CAPACITY; + const S: u32 = P::S; + + type Repr = BigintRepresentation<{N}>; + + fn from_repr(repr: Self::Repr) -> Result { + let mut r = Self(repr, std::marker::PhantomData); + if r.is_valid() { + r.mul_assign(&Self(P::R2, std::marker::PhantomData)); + + Ok(r) + } else { + Err(PrimeFieldDecodingError::NotInField(format!("{}", r.0))) + } + } + + // fn from_raw_repr(repr: Self::Repr) -> Result { + // let r = Self(repr, std::marker::PhantomData); + // if r.is_valid() { + // Ok(r) + // } else { + // Err(PrimeFieldDecodingError::NotInField(format!("{}", r.0))) + // } + // } + + fn into_repr(&self) -> Self::Repr { + let mut r = *self; + // let mut interm = BigintRepresentation::<{N*2}>::default(); + let mut interm = < BigintRepresentation<{N}> as FullMultiplication >::MulResult::default(); + for j in 0..N { + interm.0[j] = (self.0).0[j]; + } + r.mont_reduce(interm); + + r.0 + } + + // fn into_raw_repr(&self) -> Self::Repr { + // self.0 + // } + + fn char() -> Self::Repr { + P::MODULUS + } + + fn multiplicative_generator() -> Self { + Self(P::MULTIPLICATIVE_GENERATOR, std::marker::PhantomData) + } + + fn root_of_unity() -> Self { + Self(P::ROOT_OF_UNITY, std::marker::PhantomData) + } + +} \ No newline at end of file diff --git a/crates/ff/ff_derive_const/src/const_repr.rs b/crates/ff/ff_derive_const/src/const_repr.rs new file mode 100644 index 0000000..86d7f44 --- /dev/null +++ b/crates/ff/ff_derive_const/src/const_repr.rs @@ -0,0 +1,251 @@ +use ff; +use rand; + +pub struct BigintRepresentation< + const N: usize +>(pub [u64; N]); + +impl Copy for BigintRepresentation<{N}> {} +impl Clone for BigintRepresentation<{N}> { + fn clone(&self) -> Self { + *self + } +} + +impl std::cmp::PartialEq for BigintRepresentation<{N}> { + fn eq(&self, other: &Self) -> bool { + self.cmp(&other) == std::cmp::Ordering::Equal + } +} + +impl std::cmp::Eq for BigintRepresentation<{N}> {} + +impl std::fmt::Debug for BigintRepresentation<{N}> +{ + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "0x")?; + for i in self.0.iter().rev() { + write!(f, "{:016x}", *i)?; + } + + Ok(()) + } +} + +impl ::rand::Rand for BigintRepresentation<{N}> { + #[inline(always)] + fn rand(rng: &mut R) -> Self { + let mut s = Self::default(); + for el in s.0.iter_mut() { + *el = rng.gen(); + } + + s + } +} + +impl std::fmt::Display for BigintRepresentation<{N}> { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "0x")?; + for i in self.0.iter().rev() { + write!(f, "{:016x}", *i)?; + } + + Ok(()) + } +} + +impl std::default::Default for BigintRepresentation<{N}> { + #[inline(always)] + fn default() -> Self { + use std::mem::{MaybeUninit}; + let mut s: Self = unsafe {MaybeUninit::uninit().assume_init() }; + for el in s.0.iter_mut() { + *el = 0u64; + } + + s + // let repr: [u64; {N}] = [0u64; {N}]; + // BigintRepresentation::<{N}>(repr) + } +} + +impl AsRef<[u64]> for BigintRepresentation<{N}> { + #[inline(always)] + fn as_ref(&self) -> &[u64] { + &self.0 + } +} + +impl AsMut<[u64]> for BigintRepresentation<{N}> { + #[inline(always)] + fn as_mut(&mut self) -> &mut [u64] { + &mut self.0 + } +} + +impl From for BigintRepresentation<{N}> { + #[inline(always)] + fn from(val: u64) -> Self { + let mut repr = Self::default(); + repr.0[0] = val; + repr + } +} + +impl Ord for BigintRepresentation<{N}> { + #[inline(always)] + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + for (a, b) in self.0.iter().rev().zip(other.0.iter().rev()) { + if a < b { + return std::cmp::Ordering::Less + } else if a > b { + return std::cmp::Ordering::Greater + } + } + + std::cmp::Ordering::Equal + } +} + +impl PartialOrd for BigintRepresentation<{N}> { + #[inline(always)] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl ::ff::PrimeFieldRepr for BigintRepresentation<{N}> where [u64; N]: std::array::LengthAtMost32 { + #[inline(always)] + fn is_odd(&self) -> bool { + self.0[0] & 1 == 1 + } + + #[inline(always)] + fn is_even(&self) -> bool { + !self.is_odd() + } + + #[inline(always)] + fn is_zero(&self) -> bool { + self.0.iter().all(|&e| e == 0) + } + + #[inline(always)] + fn shr(&mut self, mut n: u32) { + if n as usize >= 64 * N { + *self = Self::from(0); + return; + } + + while n >= 64 { + let mut t = 0; + for i in self.0.iter_mut().rev() { + std::mem::swap(&mut t, i); + } + n -= 64; + } + + if n > 0 { + let mut t = 0; + for i in self.0.iter_mut().rev() { + let t2 = *i << (64 - n); + *i >>= n; + *i |= t; + t = t2; + } + } + } + + #[inline(always)] + fn div2(&mut self) { + let mut t = 0; + for i in self.0.iter_mut().rev() { + let t2 = *i << 63; + *i >>= 1; + *i |= t; + t = t2; + } + } + + #[inline(always)] + fn mul2(&mut self) { + let mut last = 0; + for i in &mut self.0 { + let tmp = *i >> 63; + *i <<= 1; + *i |= last; + last = tmp; + } + } + + #[inline(always)] + fn shl(&mut self, mut n: u32) { + if n as usize >= 64 * N { + *self = Self::from(0); + return; + } + + while n >= 64 { + let mut t = 0; + for i in &mut self.0 { + ::std::mem::swap(&mut t, i); + } + n -= 64; + } + + if n > 0 { + let mut t = 0; + for i in &mut self.0 { + let t2 = *i >> (64 - n); + *i <<= n; + *i |= t; + t = t2; + } + } + } + + #[inline(always)] + fn num_bits(&self) -> u32 { + let mut ret = (N as u32) * 64; + for i in self.0.iter().rev() { + let leading = i.leading_zeros(); + ret -= leading; + if leading != 64 { + break; + } + } + + ret + } + + #[inline(always)] + fn add_nocarry(&mut self, other: &Self) { + let mut carry = 0; + + for (a, b) in self.0.iter_mut().zip(other.0.iter()) { + *a = ::ff::adc(*a, *b, &mut carry); + } + } + + #[inline(always)] + fn sub_noborrow(&mut self, other: &Self) { + let mut borrow = 0; + + for (a, b) in self.0.iter_mut().zip(other.0.iter()) { + *a = ::ff::sbb(*a, *b, &mut borrow); + } + } +} + +pub trait FullMultiplication { + type Multiplicand; + type MulResult; +} + +impl FullMultiplication for BigintRepresentation<{N}> +{ + type Multiplicand = BigintRepresentation<{N}>; + type MulResult = BigintRepresentation<{N*2}>; +} + diff --git a/crates/ff/ff_derive_const/src/lib.rs b/crates/ff/ff_derive_const/src/lib.rs new file mode 100644 index 0000000..9c3835b --- /dev/null +++ b/crates/ff/ff_derive_const/src/lib.rs @@ -0,0 +1,87 @@ + +#![feature(const_generics)] +#![feature(const_generic_impls_guard)] + +#![recursion_limit = "1024"] + +extern crate ff; +extern crate rand; + +pub mod const_repr; +pub mod const_field_element; +// mod alt; + +#[cfg(test)] +mod tests { + use super::const_repr::*; + use super::const_field_element::*; + + const MODULUS: BigintRepresentation::<4> = BigintRepresentation::<4>([ + 0x677297dc392126f1, + 0xab3eedb83920ee0a, + 0x370a08b6d0302b0b, + 0x060c89ce5c263405, + ]); + + const MODULUS_BITS: u32 = 251; + const REPR_SHAVE_BITS: u32 = 5; + const R: BigintRepresentation::<4> = BigintRepresentation::<4>([ + 0x073315dea08f9c76, + 0xe7acffc6a098f24b, + 0xf85a9201d818f015, + 0x01f16424e1bb7724, + ]); + + const R2: BigintRepresentation::<4> = BigintRepresentation::<4>([ + 0x35e44abee7ecb21e, + 0x74646cacf5f84ec4, + 0xe472df203faa158f, + 0x0445b524f1ba50a8, + ]); + + const INV: u64 = 0x532ce5aebc48f5ef; + + const GENERATOR: BigintRepresentation::<4> = BigintRepresentation::<4>([ + 0x6380695df1aaf958, + 0xff3d22fdf1ecc3f8, + 0x5c65ec9f484e3a81, + 0x0180a96573d3d9f8, + ]); + + const S: u32 = 4; + + const ROOT_OF_UNITY: BigintRepresentation::<4> = BigintRepresentation::<4>([ + 0xa13885692e7afcb0, + 0xb789766cd18573ca, + 0xd5468c0174efc3b9, + 0x03534b612b0b6f7a, + ]); + + #[derive(Copy, Clone)] + struct FsParams; + + impl FieldParameters<4> for FsParams { + const NUM_BITS: u32 = MODULUS_BITS; + const CAPACITY: u32 = 250; + const REPR_SHAVE_BITS: u32 = REPR_SHAVE_BITS; + const S: u32 = S; + const MULTIPLICATIVE_GENERATOR: BigintRepresentation::<4> = GENERATOR; + const ROOT_OF_UNITY: BigintRepresentation::<4> = ROOT_OF_UNITY; + const MODULUS: BigintRepresentation::<4> = MODULUS; + const R: BigintRepresentation::<4> = R; + const R2: BigintRepresentation::<4> = R2; + const INV: u64 = INV; + } + + type Fs = PrimeFieldElement; + + #[test] + fn make_naive() { + use crate::ff::PrimeField; + + let repr = BigintRepresentation::<4>::from(3u64); + + let fe = Fs::from_repr(repr).unwrap(); + println!("{:?}", fe); + } +} \ No newline at end of file diff --git a/crates/ff/src/lib.rs b/crates/ff/src/lib.rs new file mode 100644 index 0000000..aa0b8ba --- /dev/null +++ b/crates/ff/src/lib.rs @@ -0,0 +1,610 @@ +#![allow(unused_imports)] + +extern crate byteorder; +extern crate hex as hex_ext; +extern crate rand; +extern crate serde; +pub mod hex { + pub use hex_ext::*; +} + +#[cfg(feature = "derive")] +#[macro_use] +extern crate ff_derive_ce; + +#[cfg(feature = "derive")] +pub use ff_derive_ce::*; + +use std::error::Error; +use std::fmt; +use std::hash; +use std::io::{self, Read, Write}; + +/// This trait represents an element of a field. +pub trait Field: Sized + Eq + Copy + Clone + Send + Sync + fmt::Debug + fmt::Display + 'static + rand::Rand + hash::Hash + Default + serde::Serialize + serde::de::DeserializeOwned { + /// Returns the zero element of the field, the additive identity. + fn zero() -> Self; + + /// Returns the one element of the field, the multiplicative identity. + fn one() -> Self; + + /// Returns true iff this element is zero. + fn is_zero(&self) -> bool; + + /// Squares this element. + fn square(&mut self); + + /// Doubles this element. + fn double(&mut self); + + /// Negates this element. + fn negate(&mut self); + + /// Adds another element to this element. + fn add_assign(&mut self, other: &Self); + + /// Subtracts another element from this element. + fn sub_assign(&mut self, other: &Self); + + /// Multiplies another element by this element. + fn mul_assign(&mut self, other: &Self); + + /// Computes the multiplicative inverse of this element, if nonzero. + fn inverse(&self) -> Option; + + /// Exponentiates this element by a power of the base prime modulus via + /// the Frobenius automorphism. + fn frobenius_map(&mut self, power: usize); + + /// Exponentiates this element by a number represented with `u64` limbs, + /// least significant digit first. + fn pow>(&self, exp: S) -> Self { + let mut res = Self::one(); + + let mut found_one = false; + + for i in BitIterator::new(exp) { + if found_one { + res.square(); + } else { + found_one = i; + } + + if i { + res.mul_assign(self); + } + } + + res + } +} + +/// This trait represents an element of a field that has a square root operation described for it. +pub trait SqrtField: Field { + /// Returns the Legendre symbol of the field element. + fn legendre(&self) -> LegendreSymbol; + + /// Returns the square root of the field element, if it is + /// quadratic residue. + fn sqrt(&self) -> Option; +} + +/// This trait represents a wrapper around a biginteger which can encode any element of a particular +/// prime field. It is a smart wrapper around a sequence of `u64` limbs, least-significant digit +/// first. +pub trait PrimeFieldRepr: + Sized + + Copy + + Clone + + Eq + + Ord + + Send + + Sync + + Default + + fmt::Debug + + fmt::Display + + 'static + + rand::Rand + + AsRef<[u64]> + + AsMut<[u64]> + + From + + hash::Hash + + serde::Serialize + + serde::de::DeserializeOwned +{ + /// Subtract another represetation from this one. + fn sub_noborrow(&mut self, other: &Self); + + /// Add another representation to this one. + fn add_nocarry(&mut self, other: &Self); + + /// Compute the number of bits needed to encode this number. Always a + /// multiple of 64. + fn num_bits(&self) -> u32; + + /// Returns true iff this number is zero. + fn is_zero(&self) -> bool; + + /// Returns true iff this number is odd. + fn is_odd(&self) -> bool; + + /// Returns true iff this number is even. + fn is_even(&self) -> bool; + + /// Performs a rightwise bitshift of this number, effectively dividing + /// it by 2. + fn div2(&mut self); + + /// Performs a rightwise bitshift of this number by some amount. + fn shr(&mut self, amt: u32); + + /// Performs a leftwise bitshift of this number, effectively multiplying + /// it by 2. Overflow is ignored. + fn mul2(&mut self); + + /// Performs a leftwise bitshift of this number by some amount. + fn shl(&mut self, amt: u32); + + /// Writes this `PrimeFieldRepr` as a big endian integer. + fn write_be(&self, mut writer: W) -> io::Result<()> { + use byteorder::{BigEndian, WriteBytesExt}; + + for digit in self.as_ref().iter().rev() { + writer.write_u64::(*digit)?; + } + + Ok(()) + } + + /// Reads a big endian integer into this representation. + fn read_be(&mut self, mut reader: R) -> io::Result<()> { + use byteorder::{BigEndian, ReadBytesExt}; + + for digit in self.as_mut().iter_mut().rev() { + *digit = reader.read_u64::()?; + } + + Ok(()) + } + + /// Writes this `PrimeFieldRepr` as a little endian integer. + fn write_le(&self, mut writer: W) -> io::Result<()> { + use byteorder::{LittleEndian, WriteBytesExt}; + + for digit in self.as_ref().iter() { + writer.write_u64::(*digit)?; + } + + Ok(()) + } + + /// Reads a little endian integer into this representation. + fn read_le(&mut self, mut reader: R) -> io::Result<()> { + use byteorder::{LittleEndian, ReadBytesExt}; + + for digit in self.as_mut().iter_mut() { + *digit = reader.read_u64::()?; + } + + Ok(()) + } +} + +#[derive(Debug, PartialEq)] +pub enum LegendreSymbol { + Zero = 0, + QuadraticResidue = 1, + QuadraticNonResidue = -1, +} + +/// An error that may occur when trying to interpret a `PrimeFieldRepr` as a +/// `PrimeField` element. +#[derive(Debug)] +pub enum PrimeFieldDecodingError { + /// The encoded value is not in the field + NotInField(String), +} + +impl Error for PrimeFieldDecodingError { + fn description(&self) -> &str { + match *self { + PrimeFieldDecodingError::NotInField(..) => "not an element of the field", + } + } +} + +impl fmt::Display for PrimeFieldDecodingError { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + match *self { + PrimeFieldDecodingError::NotInField(ref repr) => { + write!(f, "{} is not an element of the field", repr) + } + } + } +} + +/// This represents an element of a prime field. +pub trait PrimeField: Field { + /// The prime field can be converted back and forth into this biginteger + /// representation. + type Repr: PrimeFieldRepr + From; + + /// Interpret a string of numbers as a (congruent) prime field element. + /// Does not accept unnecessary leading zeroes or a blank string. + fn from_str(s: &str) -> Option { + if s.is_empty() { + return None; + } + + if s == "0" { + return Some(Self::zero()); + } + + let mut res = Self::zero(); + + let ten = Self::from_repr(Self::Repr::from(10)).unwrap(); + + let mut first_digit = true; + + for c in s.chars() { + match c.to_digit(10) { + Some(c) => { + if first_digit { + if c == 0 { + return None; + } + + first_digit = false; + } + + res.mul_assign(&ten); + res.add_assign(&Self::from_repr(Self::Repr::from(u64::from(c))).unwrap()); + } + None => { + return None; + } + } + } + + Some(res) + } + + /// Convert this prime field element into a biginteger representation. + fn from_repr(repr: Self::Repr) -> Result; + + /// Creates an element from raw representation in Montgommery form. + fn from_raw_repr(repr: Self::Repr) -> Result; + + /// Convert a biginteger representation into a prime field element, if + /// the number is an element of the field. + fn into_repr(&self) -> Self::Repr; + + /// Expose Montgommery represendation. + fn into_raw_repr(&self) -> Self::Repr; + + /// Returns the field characteristic; the modulus. + fn char() -> Self::Repr; + + /// How many bits are needed to represent an element of this field. + const NUM_BITS: u32; + + /// How many bits of information can be reliably stored in the field element. + const CAPACITY: u32; + + /// Returns the multiplicative generator of `char()` - 1 order. This element + /// must also be quadratic nonresidue. + fn multiplicative_generator() -> Self; + + /// 2^s * t = `char()` - 1 with t odd. + const S: u32; + + /// Returns the 2^s root of unity computed by exponentiating the `multiplicative_generator()` + /// by t. + fn root_of_unity() -> Self; +} + +/// An "engine" is a collection of types (fields, elliptic curve groups, etc.) +/// with well-defined relationships. Specific relationships (for example, a +/// pairing-friendly curve) can be defined in a subtrait. +pub trait ScalarEngine: Sized + 'static + Clone + Copy + Send + Sync + fmt::Debug { + /// This is the scalar field of the engine's groups. + type Fr: PrimeField + SqrtField; +} + +#[derive(Debug)] +pub struct BitIterator { + t: E, + n: usize, +} + +impl> BitIterator { + pub fn new(t: E) -> Self { + let n = t.as_ref().len() * 64; + + BitIterator { t, n } + } +} + +impl> Iterator for BitIterator { + type Item = bool; + + fn next(&mut self) -> Option { + if self.n == 0 { + None + } else { + self.n -= 1; + let part = self.n / 64; + let bit = self.n - (64 * part); + + Some(self.t.as_ref()[part] & (1 << bit) > 0) + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.n, Some(self.n)) + } +} + +impl> ExactSizeIterator for BitIterator { + fn len(&self) -> usize { + self.n + } +} + +#[test] +fn test_bit_iterator() { + let mut a = BitIterator::new([0xa953d79b83f6ab59, 0x6dea2059e200bd39]); + let expected = "01101101111010100010000001011001111000100000000010111101001110011010100101010011110101111001101110000011111101101010101101011001"; + + for e in expected.chars() { + assert!(a.next().unwrap() == (e == '1')); + } + + assert!(a.next().is_none()); + + let expected = "1010010101111110101010000101101011101000011101110101001000011001100100100011011010001011011011010001011011101100110100111011010010110001000011110100110001100110011101101000101100011100100100100100001010011101010111110011101011000011101000111011011101011001"; + + let mut a = BitIterator::new([0x429d5f3ac3a3b759, 0xb10f4c66768b1c92, 0x92368b6d16ecd3b4, 0xa57ea85ae8775219]); + + for e in expected.chars() { + assert!(a.next().unwrap() == (e == '1')); + } + + assert!(a.next().is_none()); +} + +#[test] +fn test_bit_iterator_length() { + let a = BitIterator::new([0xa953d79b83f6ab59, 0x6dea2059e200bd39]); + let trusted_len = a.len(); + let (lower, some_upper) = a.size_hint(); + let upper = some_upper.unwrap(); + assert_eq!(trusted_len, 128); + assert_eq!(lower, 128); + assert_eq!(upper, 128); + + let mut i = 0; + for _ in a { + i += 1; + } + + assert_eq!(trusted_len, i); +} + +pub use self::arith_impl::*; + +mod arith_impl { + /// Calculate a - b - borrow, returning the result and modifying + /// the borrow value. + #[inline(always)] + pub fn sbb(a: u64, b: u64, borrow: &mut u64) -> u64 { + use std::num::Wrapping; + + let tmp = (1u128 << 64).wrapping_add(u128::from(a)).wrapping_sub(u128::from(b)).wrapping_sub(u128::from(*borrow)); + + *borrow = if tmp >> 64 == 0 { 1 } else { 0 }; + + tmp as u64 + } + + /// Calculate a + b + carry, returning the sum and modifying the + /// carry value. + #[inline(always)] + pub fn adc(a: u64, b: u64, carry: &mut u64) -> u64 { + use std::num::Wrapping; + + let tmp = u128::from(a).wrapping_add(u128::from(b)).wrapping_add(u128::from(*carry)); + + *carry = (tmp >> 64) as u64; + + tmp as u64 + } + + /// Calculate a + (b * c) + carry, returning the least significant digit + /// and setting carry to the most significant digit. + #[inline(always)] + pub fn mac_with_carry(a: u64, b: u64, c: u64, carry: &mut u64) -> u64 { + use std::num::Wrapping; + + let tmp = (u128::from(a)).wrapping_add(u128::from(b).wrapping_mul(u128::from(c))).wrapping_add(u128::from(*carry)); + + *carry = (tmp >> 64) as u64; + + tmp as u64 + } + + #[inline(always)] + pub fn full_width_mul(a: u64, b: u64) -> (u64, u64) { + let tmp = (a as u128) * (b as u128); + + return (tmp as u64, (tmp >> 64) as u64); + } + + #[inline(always)] + pub fn mac_by_value(a: u64, b: u64, c: u64) -> (u64, u64) { + let tmp = ((b as u128) * (c as u128)) + (a as u128); + + (tmp as u64, (tmp >> 64) as u64) + } + + #[inline(always)] + pub fn mac_by_value_return_carry_only(a: u64, b: u64, c: u64) -> u64 { + let tmp = ((b as u128) * (c as u128)) + (a as u128); + + (tmp >> 64) as u64 + } + + #[inline(always)] + pub fn mac_with_carry_by_value(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) { + let tmp = ((b as u128) * (c as u128)) + (a as u128) + (carry as u128); + + (tmp as u64, (tmp >> 64) as u64) + } + + #[inline(always)] + pub fn mul_double_add_by_value(a: u64, b: u64, c: u64) -> (u64, u64, u64) { + // multiply + let tmp = (b as u128) * (c as u128); + // doulbe + let lo = tmp as u64; + let hi = (tmp >> 64) as u64; + let superhi = hi >> 63; + let hi = hi << 1 | lo >> 63; + let lo = lo << 1; + // add + let tmp = (lo as u128) + ((hi as u128) << 64) + (a as u128); + + (tmp as u64, (tmp >> 64) as u64, superhi) + } + + #[inline(always)] + pub fn mul_double_add_add_carry_by_value(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64, u64) { + // multiply + let tmp = (b as u128) * (c as u128); + // doulbe + let lo = tmp as u64; + let hi = (tmp >> 64) as u64; + let superhi = hi >> 63; + let hi = hi << 1 | lo >> 63; + let lo = lo << 1; + // add + let tmp = (lo as u128) + ((hi as u128) << 64) + (a as u128) + (carry as u128); + + (tmp as u64, (tmp >> 64) as u64, superhi) + } + + #[inline(always)] + pub fn mul_double_add_add_carry_by_value_ignore_superhi(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) { + // multiply + let tmp = (b as u128) * (c as u128); + // doulbe + let lo = tmp as u64; + let hi = (tmp >> 64) as u64; + let hi = hi << 1 | lo >> 63; + let lo = lo << 1; + // add + let tmp = (lo as u128) + ((hi as u128) << 64) + (a as u128) + (carry as u128); + + (tmp as u64, (tmp >> 64) as u64) + } + + #[inline(always)] + pub fn mul_double_add_low_and_high_carry_by_value(b: u64, c: u64, lo_carry: u64, hi_carry: u64) -> (u64, u64, u64) { + // multiply + let tmp = (b as u128) * (c as u128); + // doulbe + let lo = tmp as u64; + let hi = (tmp >> 64) as u64; + let superhi = hi >> 63; + let hi = hi << 1 | lo >> 63; + let lo = lo << 1; + // add + let tmp = (lo as u128) + ((hi as u128) << 64) + (lo_carry as u128) + ((hi_carry as u128) << 64); + + (tmp as u64, (tmp >> 64) as u64, superhi) + } + + #[inline(always)] + pub fn mul_double_add_low_and_high_carry_by_value_ignore_superhi(b: u64, c: u64, lo_carry: u64, hi_carry: u64) -> (u64, u64) { + // multiply + let tmp = (b as u128) * (c as u128); + // doulbe + let lo = tmp as u64; + let hi = (tmp >> 64) as u64; + let hi = hi << 1 | lo >> 63; + let lo = lo << 1; + // add + let tmp = (lo as u128) + ((hi as u128) << 64) + (lo_carry as u128) + ((hi_carry as u128) << 64); + + (tmp as u64, (tmp >> 64) as u64) + } + + #[inline(always)] + pub fn mul_double_add_add_low_and_high_carry_by_value(a: u64, b: u64, c: u64, lo_carry: u64, hi_carry: u64) -> (u64, u64, u64) { + // multiply + let tmp = (b as u128) * (c as u128); + // doulbe + let lo = tmp as u64; + let hi = (tmp >> 64) as u64; + let superhi = hi >> 63; + let hi = hi << 1 | lo >> 63; + let lo = lo << 1; + // add + let tmp = (lo as u128) + ((hi as u128) << 64) + (a as u128) + (lo_carry as u128) + ((hi_carry as u128) << 64); + + (tmp as u64, (tmp >> 64) as u64, superhi) + } + + #[inline(always)] + pub fn mul_double_add_add_low_and_high_carry_by_value_ignore_superhi(a: u64, b: u64, c: u64, lo_carry: u64, hi_carry: u64) -> (u64, u64) { + // multiply + let tmp = (b as u128) * (c as u128); + // doulbe + let lo = tmp as u64; + let hi = (tmp >> 64) as u64; + let hi = hi << 1 | lo >> 63; + let lo = lo << 1; + // add + let tmp = (lo as u128) + ((hi as u128) << 64) + (a as u128) + (lo_carry as u128) + ((hi_carry as u128) << 64); + + (tmp as u64, (tmp >> 64) as u64) + } + + #[inline(always)] + pub fn mac_with_low_and_high_carry_by_value(a: u64, b: u64, c: u64, lo_carry: u64, hi_carry: u64) -> (u64, u64) { + let tmp = ((b as u128) * (c as u128)) + (a as u128) + (lo_carry as u128) + ((hi_carry as u128) << 64); + + (tmp as u64, (tmp >> 64) as u64) + } +} + +pub use to_hex::{from_hex, to_hex}; + +mod to_hex { + use super::{hex_ext, PrimeField, PrimeFieldRepr}; + + pub fn to_hex(el: &F) -> String { + let repr = el.into_repr(); + let required_length = repr.as_ref().len() * 8; + let mut buf: Vec = Vec::with_capacity(required_length); + repr.write_be(&mut buf).unwrap(); + + hex_ext::encode(&buf) + } + + pub fn from_hex(value: &str) -> Result { + let value = if value.starts_with("0x") { &value[2..] } else { value }; + if value.len() % 2 != 0 { + return Err(format!("hex length must be even for full byte encoding: {}", value)); + } + let mut buf = hex_ext::decode(&value).map_err(|_| format!("could not decode hex: {}", value))?; + let mut repr = F::Repr::default(); + let required_length = repr.as_ref().len() * 8; + buf.reverse(); + buf.resize(required_length, 0); + + repr.read_le(&buf[..]).map_err(|e| format!("could not read {}: {}", value, &e))?; + + F::from_repr(repr).map_err(|e| format!("could not convert into prime field: {}: {}", value, &e)) + } +} diff --git a/crates/ff/src/tests.rs b/crates/ff/src/tests.rs new file mode 100644 index 0000000..ba0f5fd --- /dev/null +++ b/crates/ff/src/tests.rs @@ -0,0 +1,153 @@ +extern crate rand; + +extern crate test as rust_test; +extern crate num_bigint; +extern crate num_traits; +extern crate num_integer; + +#[cfg(feature = "derive")] +mod benches { + use crate::Field; + use crate::*; + use super::rust_test::Bencher; + use rand::{Rng, XorShiftRng, SeedableRng}; + use super::num_bigint::BigUint; + use super::num_traits::identities::{Zero, One}; + use super::num_traits::{ToPrimitive, Num}; + use super::num_integer::Integer; + + #[derive(PrimeField)] + #[PrimeFieldModulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583"] + #[PrimeFieldGenerator = "2"] + struct Fr(pub FrRepr); + + #[bench] + fn bench_arith(bencher: &mut Bencher) { + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let mut a: Fr = rng.gen(); + let b: Fr = rng.gen(); + + bencher.iter(|| { + for _ in 0..100 { + a.mul_assign(&b); + } + }); + + } + + #[bench] + fn bench_cios(bencher: &mut Bencher) { + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let mut a: Fr = rng.gen(); + let b: Fr = rng.gen(); + + fn modular_inverse(el: &BigUint, modulus: &BigUint) -> BigUint { + let mut a = el.clone(); + let mut new = BigUint::one(); + let mut old = BigUint::zero(); + let mut q = modulus.clone(); + let mut r = BigUint::zero(); + let mut h = BigUint::zero(); + let mut positive = false; + while !a.is_zero() { + let (q_new, r_new) = q.div_mod_floor(&a); + r = r_new; + q = q_new; + h = (q * new.clone() + old.clone()) % modulus; + old = new; + new = h; + q = a; + a = r; + positive = !positive; + } + if positive { + return old; + } else { + return modulus - old; + } + } + + let mut modulus = BigUint::from_str_radix("21888242871839275222246405745257275088696311157297823662689037894645226208583", 10).unwrap(); + let mut mont_r = (BigUint::one() << 256) % modulus.clone(); + let r_inv = modular_inverse(&mont_r, &modulus); + + let (_, rem) = (mont_r.clone() * r_inv.clone()).div_mod_floor(&modulus); + assert!(rem == BigUint::one()); + + let modulus_512 = (modulus.clone() << 256); + + let r_r_inv = (r_inv.clone() << 256); + let subtracted = r_r_inv - BigUint::one(); + + let (mont_k, rem) = subtracted.div_mod_floor(&modulus_512); + let mut mont_k = mont_k % (BigUint::one() << 256); + + let mut mont_k_fixed = [0u64; 4]; + for i in 0..4 { + let limb = mont_k.clone() % (BigUint::one() << 64); + mont_k_fixed[i] = limb.to_u64().unwrap(); + mont_k = mont_k.clone() >> 64; + } + + let mut modulus_fixed = [0u64; 4]; + for i in 0..4 { + let limb = modulus.clone() % (BigUint::one() << 64); + modulus_fixed[i] = limb.to_u64().unwrap(); + modulus = modulus.clone() >> 64; + } + + let mont_k = mont_k_fixed; + let modulus = modulus_fixed; + + + fn cios_mul(a: &[u64; 4], b: &[u64; 4], mont_k: &[u64; 4], modulus: &[u64; 4]) -> [u64; 4] { + let mut t = [0u64; 8]; + let limbs = b.len(); + for i in 0..limbs { + let mut carry = 0u64; + let limb = b[i]; + for j in 0..limbs { + t[j] = crate::arith_impl::mac_with_carry(t[j], a[j], limb, &mut carry); + } + t[limbs] = crate::arith_impl::adc(t[limbs], 0, &mut carry); + t[limbs+1] = carry; + let mut m = crate::arith_impl::mac_with_carry(0u64, t[0], mont_k[0], &mut 0u64); + let mut m = m as u64; + let mut carry = 0; + crate::arith_impl::mac_with_carry(t[0], m, modulus[0], &mut carry); + for j in 1..limbs { + t[j-1] = crate::arith_impl::mac_with_carry(t[j], m, modulus[j], &mut carry); + } + t[limbs-1] = crate::arith_impl::adc(t[limbs], 0, &mut carry); + t[limbs] = t[limbs+1] + carry; + } + + let mut u = t; + let mut borrow = 0; + for i in 0..limbs { + t[i] = crate::arith_impl::sbb(u[i], modulus[i], &mut borrow); + } + t[limbs] = crate::arith_impl::sbb(u[limbs], 0, &mut borrow); + if borrow == 0 { + return [t[0], t[1], t[2], t[3]]; + } else { + return [u[0], u[1], u[2], u[3]]; + } + } + + let a_repr = (a.0).0; + let b_repr = (b.0).0; + + bencher.iter(|| { + for _ in 0..100 { + cios_mul(&a_repr, &b_repr, &mont_k, &modulus); + } + }); + + } +} + + + + + diff --git a/crates/ff/tester/Cargo.toml b/crates/ff/tester/Cargo.toml new file mode 100644 index 0000000..bd33e77 --- /dev/null +++ b/crates/ff/tester/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "ff_ce_tester" +version = "0.1.0" +authors = ["Alex Vlasov "] +description = "Stub for testing ff_ce" +documentation = "https://docs.rs/ff/" +homepage = "https://github.com/matter-labs/ff" +license = "MIT/Apache-2.0" +repository = "https://github.com/matter-labs/ff" +edition = "2018" + +[dependencies] +ff = {package = "ff_ce", path = "../", features = ["derive"]} +rand = "0.4" + +[dev-dependencies] +criterion = "0.3" + +[[bench]] +name = "multiplication" +harness = false \ No newline at end of file diff --git a/crates/ff/tester/bench_with_features.sh b/crates/ff/tester/bench_with_features.sh new file mode 100755 index 0000000..c2b4e75 --- /dev/null +++ b/crates/ff/tester/bench_with_features.sh @@ -0,0 +1 @@ +RUSTFLAGS="-C target-cpu=native -C target_feature=+bmi2,+adx" cargo +nightly bench \ No newline at end of file diff --git a/crates/ff/tester/benches/multiplication.rs b/crates/ff/tester/benches/multiplication.rs new file mode 100644 index 0000000..33259e0 --- /dev/null +++ b/crates/ff/tester/benches/multiplication.rs @@ -0,0 +1,352 @@ +extern crate ff; +extern crate rand; +extern crate ff_ce_tester; + +use self::ff::*; + +mod fr { + use crate::ff::*; + + #[derive(PrimeField)] + #[PrimeFieldModulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583"] + #[PrimeFieldGenerator = "2"] + pub struct Fr(FrRepr); +} + +mod frcios{ + use crate::ff::*; + + #[derive(PrimeField)] + #[PrimeFieldModulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583"] + #[PrimeFieldGenerator = "2"] + #[OptimisticCIOSMultiplication = "true"] + #[OptimisticCIOSSquaring = "true"] + pub struct FrCios(FrCiosRepr); +} + +use self::fr::Fr; +use self::frcios::FrCios; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +// #[inline(always)] +// fn mul(a: F, b: &F) -> F { +// let mut c = a; +// c.mul_assign(b); + +// c +// } + +// fn multiplication_benchmark(c: &mut Criterion) { +// use rand::{Rng, XorShiftRng, SeedableRng}; + +// let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); +// let a: Fr = rng.gen(); +// let b: Fr = rng.gen(); + +// c.bench_function("Mont mul 256", |bencher| bencher.iter(|| mul(black_box(a), &black_box(b)))); +// } + +fn mul_assing_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fr = rng.gen(); + let b: Fr = rng.gen(); + + c.bench_function("Mont mul assign 256", |bencher| bencher.iter(|| black_box(a).mul_assign(&black_box(b)))); +} + +fn mul_assing_cios_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: FrCios = rng.gen(); + let b: FrCios = rng.gen(); + + c.bench_function("Mont mul assign 256 CIOS derive", |bencher| bencher.iter(|| black_box(a).mul_assign(&black_box(b)))); +} + +fn square_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fr = rng.gen(); + + c.bench_function("Mont square assign 256", |bencher| bencher.iter(|| black_box(a).square())); +} + +fn square_cios_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: FrCios = rng.gen(); + + c.bench_function("Mont square assign 256 CIOS derive", |bencher| bencher.iter(|| black_box(a).square())); +} +// fn mul_assing_custom_benchmark(c: &mut Criterion) { +// use rand::{Rng, XorShiftRng, SeedableRng}; + +// use self::ff_ce_tester::mul_variant0::Fs; +// // use self::mul_variant0::Fs; + +// let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); +// let a: Fs = rng.gen(); +// let b: Fs = rng.gen(); + +// c.bench_function("Mont mul assign 256 custom", |bencher| bencher.iter(|| black_box(a).mul_assign(&black_box(b)))); +// } + +fn mul_assing_rps_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + use ff_ce_tester::mul_variant0::Fs; + // use self::mul_variant0::Fs; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fs = rng.gen(); + let b: Fs = rng.gen(); + + c.bench_function("Mont mul assign 256 RPS", |bencher| bencher.iter(|| black_box(a).rps_mul_assign(&black_box(b)))); +} + +// fn mul_assing_optimistic_cios_benchmark(c: &mut Criterion) { +// use rand::{Rng, XorShiftRng, SeedableRng}; + +// use self::ff_ce_tester::mul_variant0::Fs; +// // use self::mul_variant0::Fs; + +// let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); +// let a: Fs = rng.gen(); +// let b: Fs = rng.gen(); + +// c.bench_function("Mont mul assign 256 optimistic CIOS", |bencher| bencher.iter(|| black_box(a).optimistic_cios_mul_assign(&black_box(b)))); +// } + +fn mul_assing_optimistic_cios_by_value_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + use self::ff_ce_tester::mul_variant0::Fs; + // use self::mul_variant0::Fs; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fs = rng.gen(); + let b: Fs = rng.gen(); + + c.bench_function("Mont mul assign 256 optimistic CIOS by value", |bencher| bencher.iter(|| black_box(a).optimistic_cios_by_value(black_box(b)))); +} + +// fn mul_assing_optimistic_cios_by_value_with_partial_red_benchmark(c: &mut Criterion) { +// use rand::{Rng, XorShiftRng, SeedableRng}; + +// use self::ff_ce_tester::mul_variant0::Fs; +// // use self::mul_variant0::Fs; + +// let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); +// let a: Fs = rng.gen(); +// let b: Fs = rng.gen(); + +// c.bench_function("Mont mul assign 256 optimistic CIOS by value with_partial_red", |bencher| bencher.iter(|| black_box(a).optimistic_cios_by_value_with_partial_red(black_box(b)))); +// } + +// fn mulx_mul_assing_benchmark(c: &mut Criterion) { +// use rand::{Rng, XorShiftRng, SeedableRng}; + +// use self::ff_ce_tester::mul_variant0::Fs; +// // use self::mul_variant0::Fs; + +// let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); +// let a: Fs = rng.gen(); +// let b: Fs = rng.gen(); + +// c.bench_function("Mont mul assign 256 with MULX latency", |bencher| bencher.iter(|| black_box(a).mulx_latency_mul_assign(&black_box(b)))); +// } + +fn llvm_asm_mul_assing_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + use self::ff_ce_tester::mul_variant0::{Fs, mont_mul_asm}; + // use self::mul_variant0::Fs; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fs = rng.gen(); + let b: Fs = rng.gen(); + + let a = a.into_raw_repr().0; + let b = b.into_raw_repr().0; + + c.bench_function("Mont mul assign 256 with LLVM assembly", |bencher| bencher.iter(|| mont_mul_asm( + black_box(&a), + black_box(&b) + ) + )); +} + +fn new_asm_mul_assing_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + use self::ff_ce_tester::assembly_4::*; + use ff_ce_tester::mul_variant0::Fs; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fs = rng.gen(); + let b: Fs = rng.gen(); + + let a = a.into_raw_repr().0; + let b = b.into_raw_repr().0; + + c.bench_function("Mont mul assign 256 with new assembly", |bencher| bencher.iter(|| mont_mul_asm( + black_box(&a), + black_box(&b) + ) + )); +} + +fn adx_asm_mul_assing_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + use self::ff_ce_tester::assembly_4::*; + use ff_ce_tester::mul_variant0::Fs; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fs = rng.gen(); + let b: Fs = rng.gen(); + + let a = a.into_raw_repr().0; + let b = b.into_raw_repr().0; + + c.bench_function("Mont mul assign 256 with new assembly with ADX", |bencher| bencher.iter(|| mont_mul_asm_adx( + black_box(&a), + black_box(&b) + ) + )); +} + +fn asm_mul_assing_with_register_abi_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + use self::ff_ce_tester::assembly_4::*; + use ff_ce_tester::mul_variant0::Fs; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fs = rng.gen(); + let b: Fs = rng.gen(); + + let [a0, a1, a2, a3] = a.into_raw_repr().0; + let b = b.into_raw_repr().0; + + c.bench_function("Mont mul assign 256 with new assembly and ABI through registers", |bencher| bencher.iter(|| mont_mul_asm_through_registers( + black_box(a0), + black_box(a1), + black_box(a2), + black_box(a3), + black_box(&b) + ) + )); +} + +fn proth_adx_asm_mul_assing_benchmark(c: &mut Criterion) { + use rand::{Rng, XorShiftRng, SeedableRng}; + + use self::ff_ce_tester::assembly_4::*; + use ff_ce_tester::mul_variant0::Fs; + + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + let a: Fs = rng.gen(); + let b: Fs = rng.gen(); + + let a = a.into_raw_repr().0; + let b = b.into_raw_repr().0; + + c.bench_function("Mont mul assign for Proth prime with new assembly with ADX", |bencher| bencher.iter(|| mont_mul_asm_adx_for_proth_prime( + black_box(&a), + black_box(&b) + ) + )); +} + +// fn mul_assing_optimistic_cios_with_different_semantics_benchmark(c: &mut Criterion) { +// use rand::{Rng, XorShiftRng, SeedableRng}; + +// use self::ff_ce_tester::mul_variant0::Fs; +// // use self::mul_variant0::Fs; + +// let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); +// let a: Fs = rng.gen(); +// let b: Fs = rng.gen(); + +// c.bench_function("Mont mul assign 256 optimistic CIOS different semantics", |bencher| bencher.iter(|| black_box(a).optimistic_cios_mul_assign_with_different_semantics(&black_box(b)))); +// } + +// fn mul_assing_vector_benchmark(c: &mut Criterion) { +// use rand::{Rng, XorShiftRng, SeedableRng}; + +// let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); +// let mut a = [Fr::zero(); 1024]; +// let mut b = [Fr::zero(); 1024]; +// for (a, b) in a.iter_mut().zip(b.iter_mut()) { +// *a = rng.gen(); +// *b = rng.gen(); +// } + +// c.bench_function("Mont mul assign vector 256", |bencher| bencher.iter(|| +// { +// let mut a = black_box(a); +// let b = black_box(b); +// for (a, b) in a.iter_mut().zip(b.iter()) { +// a.mul_assign(b); +// } +// } +// )); +// } + +// fn mul_assing_custom_vector_benchmark(c: &mut Criterion) { +// use rand::{Rng, XorShiftRng, SeedableRng}; + +// use self::ff_ce_tester::mul_variant0::Fs; + +// let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); +// let mut a = [Fs::zero(); 1024]; +// let mut b = [Fs::zero(); 1024]; +// for (a, b) in a.iter_mut().zip(b.iter_mut()) { +// *a = rng.gen(); +// *b = rng.gen(); +// } + +// c.bench_function("Mont mul assign custom vector 256", |bencher| bencher.iter(|| +// { +// let mut a = black_box(a); +// let b = black_box(b); +// for (a, b) in a.iter_mut().zip(b.iter()) { +// a.mul_assign(b); +// } +// } +// )); +// } + +// criterion_group!(benches, mul_assing_benchmark, mul_assing_rps_benchmark, mul_assing_optimistic_cios_benchmark); +// criterion_group!(benches, mul_assing_vector_benchmark, mul_assing_custom_vector_benchmark); +// criterion_main!(benches); + +// criterion_group!( +// name = advanced; +// config = Criterion::default().warm_up_time(std::time::Duration::from_secs(10)); +// targets = mul_assing_benchmark, mul_assing_rps_benchmark, mul_assing_optimistic_cios_benchmark, mul_assing_optimistic_cios_with_different_semantics_benchmark +// ); + +// criterion_group!( +// name = advanced; +// config = Criterion::default().warm_up_time(std::time::Duration::from_secs(10)); +// targets = mul_assing_optimistic_cios_benchmark, mul_assing_optimistic_cios_by_value_benchmark, asm_mul_assing_benchmark +// ); +// criterion_group!( +// name = advanced; +// config = Criterion::default().warm_up_time(std::time::Duration::from_secs(5)); +// targets = mul_assing_benchmark, mul_assing_cios_benchmark, mul_assing_optimistic_cios_by_value_benchmark, llvm_asm_mul_assing_benchmark, new_asm_mul_assing_benchmark, adx_asm_mul_assing_benchmark, asm_mul_assing_with_register_abi_benchmark, proth_adx_asm_mul_assing_benchmark +// ); +criterion_group!( + name = advanced; + config = Criterion::default().warm_up_time(std::time::Duration::from_secs(5)); + targets = llvm_asm_mul_assing_benchmark, new_asm_mul_assing_benchmark, adx_asm_mul_assing_benchmark, proth_adx_asm_mul_assing_benchmark +); +criterion_main!(advanced); diff --git a/crates/ff/tester/check_with_features.sh b/crates/ff/tester/check_with_features.sh new file mode 100755 index 0000000..fbd4a28 --- /dev/null +++ b/crates/ff/tester/check_with_features.sh @@ -0,0 +1 @@ +RUSTFLAGS="-C target-cpu=native -C target_feature=+bmi2,+adx -Z macro-backtrace" cargo +nightly check \ No newline at end of file diff --git a/crates/ff/tester/src/adx_4/mod.rs b/crates/ff/tester/src/adx_4/mod.rs new file mode 100644 index 0000000..7b9fc3e --- /dev/null +++ b/crates/ff/tester/src/adx_4/mod.rs @@ -0,0 +1,700 @@ +static ZERO: u64 = 0; +static MODULUS_0: u64 = 0xd0970e5ed6f72cb7; +static MODULUS_1: u64 = 0xa6682093ccc81082; +static MODULUS_2: u64 = 0x6673b0101343b00; +static MODULUS_3: u64 = 0xe7db4ea6533afa9; +static INV: u64 = 0x1ba3a358ef788ef9; + +static MODULUS_0_INV: u64 = MODULUS_0.wrapping_neg(); +static MODULUS_1_INV: u64 = MODULUS_1.wrapping_neg(); +static MODULUS_2_INV: u64 = MODULUS_2.wrapping_neg(); +static MODULUS_3_INV: u64 = MODULUS_3.wrapping_neg(); + +#[allow(dead_code)] +#[allow(clippy::too_many_lines)] +#[inline(always)] +#[cfg(all(target_arch = "x86_64", target_feature = "adx"))] +pub fn mont_mul_asm_adx(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + // this is CIOS multiplication when top bit for top word of modulus is not set + + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + // mulx dest_hi, dest_lo, src1 + // use notation of order (hi, lo) + + // | | b3 | b2 | b1 | b0 | + // | | | | | a0 | + // |---- |---- |---- |---- |---- | + // | | | | r14 | r13 | + // | | | r9 | r8 | | + // | | r10 | r15 | | | + // | r12 | rdi | | | | + // |---- |---- |---- |---- |---- | + // | | | | | | // rdx = m, r11 = garbage + // | | | CF | r14 | | + // | OF | r10 | | | | + // |---- |---- |---- |---- |---- | + // | | CF | r15 | | | + // | r12 | | | | | + // | CF | r10 | | | | + // |---- |---- |---- |---- |---- | + // | r12 | r10 | r15 | r14 | r13 | + + unsafe { + asm!( + // round 0 + "mov rdx, qword ptr [{a_ptr} + 0]", + "xor r8d, r8d", + "mulx r14, r13, qword ptr [{b_ptr} + 0]", // (r14, r13) = a[0] * b[0] + "mulx r9, r8, qword ptr [{b_ptr} + 8]", // (r9, r8) = a[0] * b[1] + "mulx r10, r15, qword ptr [{b_ptr} + 16]", // (r10, r15) = a[0] * b[2] + "mulx r12, rdi, qword ptr [{b_ptr} + 24]", // (r12, rdi) = a[0] * b[3] + // by this moment MULX for a[0] * b[0] is complete (latency = 4) + "mov rdx, r13", // rdx = r13 = (a[0] * b[0]).l0 + "mov r11, {inv}", + "mulx r11, rdx, r11", // (r11, rdx) = (a[0] * b[0]).lo * k, so rdx = m (we overwrite rdx cause (a[0] * b[0]).lo is not needed for anything else) + // "mulx r11, rdx, qword ptr [rip + {inv_ptr}]", // (r11, rdx) = (a[0] * b[0]).lo * k, so rdx = m (we overwrite rdx cause (a[0] * b[0]).lo is not needed for anything else) + "adcx r14, r8", // r14 = r14 + r8 = (a[0] * b[0]).hi + (a[0] * b[1]).lo, carry flag is set in CF register (CF = carry into 2nd word), 1st word calculation + "adox r10, rdi", // r10 = r10 + rdi = (a[0] * b[2]).hi + (a[0] * b[3]).lo, carry flag is set in OF register (OF = carry into 4th word), 3rd word calculation + "adcx r15, r9", // r15 = r15 + r9 + CF = (a[0] * b[1]).hi + (a[0] * b[2]).lo + CF, 2nd word continuation + "mov r11, 0", + "adox r12, r11", // r12 = r12 + OF = 4th word + "adcx r10, r11", // r10 = r10 + CF, 3rd word continuation + // "adox r12, qword ptr [rip + {zero_ptr}]", // r12 = r12 + OF = 4th word + // "adcx r10, qword ptr [rip + {zero_ptr}]", // r10 = r10 + CF, 3rd word continuation + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", // (r9, r8) = m * q0 + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", // (r11, rdi) = m * q1 + "adox r13, r8", // r13 = t[0] + (m * q0).lo, set OF + "adcx r14, rdi", // r14 = t[1] + (m * q1).lo, set CF + "adox r14, r9", // r14 = t[1] + (m * q0).hi + OF, set OF + "adcx r15, r11", // r15 = t[2] + (m * q1).hi + CF, set CF + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", // (r9, r8) = m * q2 + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", // (r11, rdi) = m * q3 + "adox r15, r8", // r15 = t[2] + (m * q2).lo + OF, set OF + "adcx r10, rdi", // r10 = t[3] + (m * q3).lo + CF, set CF + "adox r10, r9", // r10 = t[3] + (m * q2).hi + OF, set OF + "adcx r12, r11", // r12 = t[4] + (m * q3).hi + CF, set CF + "mov r9, 0", + "adox r12, r9", // r12 = r12 + OF + // "adox r12, qword ptr [rip + {zero_ptr}]", // r12 = r12 + OF + + // round 1 + "mov rdx, qword ptr [{a_ptr} + 8]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r14, r8", + "adox r15, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r15, rdi", + "adox r10, r11", + // "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r13, rdi, qword ptr [{b_ptr} + 24]", + "adcx r10, r8", + "adox r12, rdi", + "adcx r12, r9", + "mov rdi, 0", + "adox r13, rdi", + "adcx r13, rdi", + // "adox r13, qword ptr [rip + {zero_ptr}]", + // "adcx r13, qword ptr [rip + {zero_ptr}]", + "mov rdx, r14", + "mov r8, {inv}", + "mulx r8, rdx, r8", + // "mulx r8, rdx, qword ptr [rip + {inv_ptr}]", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r14, r8", + "adcx r15, rdi", + "adox r15, r9", + "adcx r10, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", + "adox r10, r8", + "adcx r12, r9", + "adox r12, rdi", + "adcx r13, r11", + "mov rdi, 0", + "adox r13, rdi", + // "adox r13, qword ptr [rip + {zero_ptr}]", + + // round 2 + "mov rdx, qword ptr [{a_ptr} + 16]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r15, r8", + "adox r10, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r10, rdi", + "adox r12, r11", + // "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r14, rdi, qword ptr [{b_ptr} + 24]", + "adcx r12, r8", + "adox r13, r9", + "adcx r13, rdi", + "mov r9, 0", + "adox r14, r9", + "adcx r14, r9", + // "adox r14, qword ptr [rip + {zero_ptr}]", + // "adcx r14, qword ptr [rip + {zero_ptr}]", + "mov rdx, r15", + "mov r8, {inv}", + "mulx r8, rdx, r8", + // "mulx r8, rdx, qword ptr [rip + {inv_ptr}]", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r15, r8", + "adcx r10, r9", + "adox r10, rdi", + "adcx r12, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", + "adox r12, r8", + "adcx r13, r9", + "adox r13, rdi", + "adcx r14, r11", + "mov rdi, 0", + "adox r14, rdi", + // "adox r14, qword ptr [rip + {zero_ptr}]", + + // round 3 + "mov rdx, qword ptr [{a_ptr} + 24]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r10, r8", + "adox r12, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r12, rdi", + "adox r13, r11", + // "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r15, rdi, qword ptr [{b_ptr} + 24]", + "adcx r13, r8", + "adox r14, r9", + "adcx r14, rdi", + "mov r9, 0", + "adox r15, r9", + "adcx r15, r9", + // "adox r15, qword ptr [rip + {zero_ptr}]", + // "adcx r15, qword ptr [rip + {zero_ptr}]", + "mov rdx, r10", + "mov r8, {inv}", + "mulx r8, rdx, r8", + // "mulx r8, rdx, qword ptr [rip + {inv_ptr}]", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r10, r8", + "adcx r12, r9", + "adox r12, rdi", + "adcx r13, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx rdx, rdi, qword ptr [rip + {q3_ptr}]", + "adox r13, r8", + "adcx r14, r9", + "adox r14, rdi", + "adcx r15, rdx", + "mov rdi, 0", + "adox r15, rdi", + // "adox r15, qword ptr [rip + {zero_ptr}]", + + // "mov [{out_ptr} + 0], r12", + // "mov [{out_ptr} + 8], r13", + // "mov [{out_ptr} + 16], r14", + // "mov [{out_ptr} + 24], r15", + + // zero_ptr = sym ZERO, + // inv_ptr = sym INV, + q0_ptr = sym MODULUS_0, + q1_ptr = sym MODULUS_1, + q2_ptr = sym MODULUS_2, + q3_ptr = sym MODULUS_3, + inv = const 0x1ba3a358ef788ef9u64, + // out_ptr = in(reg) result.as_mut_ptr(), + a_ptr = in(reg) a.as_ptr(), + b_ptr = in(reg) b.as_ptr(), + out("rdx") _, + out("rdi") _, + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") r0, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, readonly, nostack) + ); + } + + [r0, r1, r2, r3] +} + + +#[allow(dead_code)] +#[allow(clippy::too_many_lines)] +#[inline(always)] +#[cfg(all(target_arch = "x86_64", target_feature = "adx"))] +pub fn mont_mul_asm_adx_with_reduction(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + // this is CIOS multiplication when top bit for top word of modulus is not set + + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + // mulx dest_hi, dest_lo, src1 + // use notation of order (hi, lo) + + // | | b3 | b2 | b1 | b0 | + // | | | | | a0 | + // |---- |---- |---- |---- |---- | + // | | | | r14 | r13 | + // | | | r9 | r8 | | + // | | r10 | r15 | | | + // | r12 | rdi | | | | + // |---- |---- |---- |---- |---- | + // | | | | | | // rdx = m, r11 = garbage + // | | | CF | r14 | | + // | OF | r10 | | | | + // |---- |---- |---- |---- |---- | + // | | CF | r15 | | | + // | r12 | | | | | + // | CF | r10 | | | | + // |---- |---- |---- |---- |---- | + // | r12 | r10 | r15 | r14 | r13 | + + unsafe { + asm!( + // round 0 + "mov rdx, qword ptr [{a_ptr} + 0]", + "xor r8d, r8d", + "mulx r14, r13, qword ptr [{b_ptr} + 0]", // (r14, r13) = a[0] * b[0] + "mulx r9, r8, qword ptr [{b_ptr} + 8]", // (r9, r8) = a[0] * b[1] + "mulx r10, r15, qword ptr [{b_ptr} + 16]", // (r10, r15) = a[0] * b[2] + "mulx r12, rdi, qword ptr [{b_ptr} + 24]", // (r12, rdi) = a[0] * b[3] + // by this moment MULX for a[0] * b[0] is complete (latency = 4) + "mov rdx, r13", // rdx = r13 = (a[0] * b[0]).l0 + "mov r11, {inv}", + "mulx r11, rdx, r11", // (r11, rdx) = (a[0] * b[0]).lo * k, so rdx = m (we overwrite rdx cause (a[0] * b[0]).lo is not needed for anything else) + "adcx r14, r8", // r14 = r14 + r8 = (a[0] * b[0]).hi + (a[0] * b[1]).lo, carry flag is set in CF register (CF = carry into 2nd word), 1st word calculation + "adox r10, rdi", // r10 = r10 + rdi = (a[0] * b[2]).hi + (a[0] * b[3]).lo, carry flag is set in OF register (OF = carry into 4th word), 3rd word calculation + "adcx r15, r9", // r15 = r15 + r9 + CF = (a[0] * b[1]).hi + (a[0] * b[2]).lo + CF, 2nd word continuation + "mov r11, 0", + "adox r12, r11", // r12 = r12 + OF = 4th word + "adcx r10, r11", // r10 = r10 + CF, 3rd word continuation + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", // (r9, r8) = m * q0 + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", // (r11, rdi) = m * q1 + "adox r13, r8", // r13 = t[0] + (m * q0).lo, set OF + "adcx r14, rdi", // r14 = t[1] + (m * q1).lo, set CF + "adox r14, r9", // r14 = t[1] + (m * q0).hi + OF, set OF + "adcx r15, r11", // r15 = t[2] + (m * q1).hi + CF, set CF + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", // (r9, r8) = m * q2 + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", // (r11, rdi) = m * q3 + "adox r15, r8", // r15 = t[2] + (m * q2).lo + OF, set OF + "adcx r10, rdi", // r10 = t[3] + (m * q3).lo + CF, set CF + "adox r10, r9", // r10 = t[3] + (m * q2).hi + OF, set OF + "adcx r12, r11", // r12 = t[4] + (m * q3).hi + CF, set CF + "mov r9, 0", + "adox r12, r9", // r12 = r12 + OF + + // round 1 + "mov rdx, qword ptr [{a_ptr} + 8]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r14, r8", + "adox r15, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r15, rdi", + "adox r10, r11", + "mulx r13, rdi, qword ptr [{b_ptr} + 24]", + "adcx r10, r8", + "adox r12, rdi", + "adcx r12, r9", + "mov rdi, 0", + "adox r13, rdi", + "adcx r13, rdi", + "mov rdx, r14", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r14, r8", + "adcx r15, rdi", + "adox r15, r9", + "adcx r10, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", + "adox r10, r8", + "adcx r12, r9", + "adox r12, rdi", + "adcx r13, r11", + "mov rdi, 0", + "adox r13, rdi", + + // round 2 + "mov rdx, qword ptr [{a_ptr} + 16]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r15, r8", + "adox r10, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r10, rdi", + "adox r12, r11", + "mulx r14, rdi, qword ptr [{b_ptr} + 24]", + "adcx r12, r8", + "adox r13, r9", + "adcx r13, rdi", + "mov r9, 0", + "adox r14, r9", + "adcx r14, r9", + "mov rdx, r15", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r15, r8", + "adcx r10, r9", + "adox r10, rdi", + "adcx r12, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", + "adox r12, r8", + "adcx r13, r9", + "adox r13, rdi", + "adcx r14, r11", + "mov rdi, 0", + "adox r14, rdi", + + // round 3 + "mov rdx, qword ptr [{a_ptr} + 24]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r10, r8", + "adox r12, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r12, rdi", + "adox r13, r11", + "mulx r15, rdi, qword ptr [{b_ptr} + 24]", + "adcx r13, r8", + "adox r14, r9", + "adcx r14, rdi", + "mov r9, 0", + "adox r15, r9", + "adcx r15, r9", + "mov rdx, r10", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r10, r8", + "adcx r12, r9", + "adox r12, rdi", + "adcx r13, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx rdx, rdi, qword ptr [rip + {q3_ptr}]", + "adox r13, r8", + "adcx r14, r9", + "adox r14, rdi", + "adcx r15, rdx", + "mov rdi, 0", + "adox r15, rdi", + // reduction. We use add/adc cause it's more efficiently encoded + "mov r8, r12", + "mov rdx, {q0_neg}", + "add r12, rdx", + "mov r9, r13", + "mov rdx, {q1_neg}", + "adc r13, rdx", + "mov r10, r14", + "mov rdx, {q2_neg}", + "adc r14, rdx", + "mov r11, r15", + "mov rdx, {q3_neg}", + "adc r15, rdx", + + // "add r12, {q0_neg}", + // "adc r13, {q1_neg}", + // "adc r14, {q2_neg}", + // "adc r15, {q3_neg}", + // overflow flag is 1 => no reduction was necessary + "cmovnc r12, r9", + "cmovnc r13, r10", + "cmovnc r14, r11", + "cmovnc r15, r12", + q0_neg = const 1991615062597996281u64, + q1_neg = const 0x1ba3a358ef788ef9u64, + q2_neg = const 0x1ba3a358ef788ef9u64, + q3_neg = const 0x1ba3a358ef788ef9u64, + // end of reduction + q0_ptr = sym MODULUS_0, + q1_ptr = sym MODULUS_1, + q2_ptr = sym MODULUS_2, + q3_ptr = sym MODULUS_3, + inv = const 0x1ba3a358ef788ef9u64, + a_ptr = in(reg) a.as_ptr(), + b_ptr = in(reg) b.as_ptr(), + out("rdx") _, + out("rdi") _, + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") r0, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, readonly, nostack) + ); + } + + [r0, r1, r2, r3] +} + +#[allow(dead_code)] +#[allow(clippy::too_many_lines)] +#[inline(always)] +#[cfg(all(target_arch = "x86_64", target_feature = "adx"))] +pub fn add_asm_adx(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + unsafe { + asm!( + "xor r12d, r12d", + "mov r12, qword ptr [{a_ptr} + 0]", + "mov r13, qword ptr [{a_ptr} + 8]", + "mov r14, qword ptr [{a_ptr} + 16]", + "mov r15, qword ptr [{a_ptr} + 24]", + "add r12, qword ptr [{b_ptr} + 0]", + "adc r13, qword ptr [{b_ptr} + 8]", + "adc r14, qword ptr [{b_ptr} + 16]", + "adc r15, qword ptr [{b_ptr} + 24]", + + // "mov r8, qword ptr [{b_ptr} + 0]", + // "mov r9, qword ptr [{b_ptr} + 8]", + // "mov r10, qword ptr [{b_ptr} + 16]", + // "mov r11, qword ptr [{b_ptr} + 24]", + + // reduction. We use add/adc cause it's more efficiently encoded + "mov r8, r12", + "mov rdx, {q0_neg}", + "add r12, rdx", + "mov r9, r13", + "mov rdx, {q1_neg}", + "adc r13, rdx", + "mov r10, r14", + "mov rdx, {q2_neg}", + "adc r14, rdx", + "mov r11, r15", + "mov rdx, {q3_neg}", + "adc r15, rdx", + + // overflow flag is 1 => no reduction was necessary + "cmovnc r12, r9", + "cmovnc r13, r10", + "cmovnc r14, r11", + "cmovnc r15, r12", + q0_neg = const 1991615062597996281u64, + q1_neg = const 0x1ba3a358ef788ef9u64, + q2_neg = const 0x1ba3a358ef788ef9u64, + q3_neg = const 0x1ba3a358ef788ef9u64, + // end of reduction + a_ptr = in(reg) a.as_ptr(), + b_ptr = in(reg) b.as_ptr(), + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") r0, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, readonly, nostack) + ); + } + + [r0, r1, r2, r3] +} + +#[allow(dead_code)] +#[allow(clippy::too_many_lines)] +#[inline(always)] +#[cfg(all(target_arch = "x86_64", target_feature = "adx"))] +pub fn mont_mul_asm_adx_for_proth_prime(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + // this is CIOS multiplication when top bit for top work of modulus is not set + + // mulx dest_hi, dest_lo, src1 + // use notation of order (hi, lo) + + // | | b3 | b2 | b1 | b0 | + // | | | | | a0 | + // |---- |---- |---- |---- |---- | + // | | | | r14 | r13 | + // | | | r9 | r8 | | + // | | r10 | r15 | | | + // | r12 | rdi | | | | + // |---- |---- |---- |---- |---- | + // | | | | | | // rdx = m, r11 = garbage + // | | | CF | r14 | | + // | OF | r10 | | | | + // |---- |---- |---- |---- |---- | + // | | CF | r15 | | | + // | r12 | | | | | + // | CF | r10 | | | | + // |---- |---- |---- |---- |---- | + // | r12 | r10 | r15 | r14 | r13 | + + unsafe { + asm!( + // round 0 + "mov rdx, qword ptr [{a_ptr} + 0]", + "xor r8d, r8d", + "mulx r14, r13, qword ptr [{b_ptr} + 0]", // (r14, r13) = a[0] * b[0] + "mulx r9, r8, qword ptr [{b_ptr} + 8]", // (r9, r8) = a[0] * b[1] + "mulx r10, r15, qword ptr [{b_ptr} + 16]", // (r10, r15) = a[0] * b[2] + "mulx r12, rdi, qword ptr [{b_ptr} + 24]", // (r12, rdi) = a[0] * b[3] + // by this moment MULX for a[0] * b[0] is complete (latency = 4) + "mov rdx, r13", // rdx = r13 = (a[0] * b[0]).l0 + "mov r11, {inv}", + "mulx r11, rdx, r11", // (r11, rdx) = (a[0] * b[0]).lo * k, so rdx = m (we overwrite rdx cause (a[0] * b[0]).lo is not needed for anything else) + "adcx r14, r8", // r14 = r14 + r8 = (a[0] * b[0]).hi + (a[0] * b[1]).lo, carry flag is set in CF register (CF = carry into 2nd word), 1st word calculation + "adox r10, rdi", // r10 = r10 + rdi = (a[0] * b[2]).hi + (a[0] * b[3]).lo, carry flag is set in OF register (OF = carry into 4th word), 3rd word calculation + "adcx r15, r9", // r15 = r15 + r9 + CF = (a[0] * b[1]).hi + (a[0] * b[2]).lo + CF, 2nd word continuation + "mov r11, 0", + "adox r12, r11", // r12 = r12 + OF = 4th word + "adcx r10, r11", // r10 = r10 + CF, 3rd word continuation + "adox r13, rdx", // r13 = t[0] + (m * q0).lo, set OF + "adcx r14, r11", // r14 = t[1] + (m * q1).lo, set CF + "adox r14, r11", // r14 = t[1] + (m * q0).hi + OF, set OF + "adcx r15, r11", // r15 = t[2] + (m * q1).hi + CF, set CF + "mov r8, {q_3}", + "mulx r9, rdi, r8", // (r11, rdi) = m * q3 + "adox r15, r11", // r15 = t[2] + (m * q2).lo + OF, set OF + "adcx r10, rdi", // r10 = t[3] + (m * q3).lo + CF, set CF + "adox r10, r11", // r10 = t[3] + (m * q2).hi + OF, set OF + "adcx r12, r9", // r12 = t[4] + (m * q3).hi + CF, set CF + "adox r12, r11", // r12 = r12 + OF + + // round 1 + "mov rdx, qword ptr [{a_ptr} + 8]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r14, r8", + "adox r15, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r15, rdi", + "adox r10, r11", + // "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r13, rdi, qword ptr [{b_ptr} + 24]", + "adcx r10, r8", + "adox r12, rdi", + "adcx r12, r9", + "mov r11, 0", + "adox r13, r11", + "adcx r13, r11", + "mov rdx, r14", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "adox r14, rdx", + "adcx r15, r11", + "adox r15, r11", + "adcx r10, r11", + "mov r8, {q_3}", + "mulx r9, rdi, r8", + "adox r10, r11", + "adcx r12, r11", + "adox r12, rdi", + "adcx r13, r9", + "adox r13, r11", + + // round 2 + "mov rdx, qword ptr [{a_ptr} + 16]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r15, r8", + "adox r10, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r10, rdi", + "adox r12, r11", + // "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r14, rdi, qword ptr [{b_ptr} + 24]", + "adcx r12, r8", + "adox r13, r9", + "adcx r13, rdi", + "mov r11, 0", + "adox r14, r11", + "adcx r14, r11", + "mov rdx, r15", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "adox r15, rdx", + "adcx r10, r11", + "adox r10, r11", + "adcx r12, r11", + "mov r8, {q_3}", + "mulx r9, rdi, r8", + "adox r12, r11", + "adcx r13, r11", + "adox r13, rdi", + "adcx r14, r9", + "adox r14, r11", + + // round 3 + "mov rdx, qword ptr [{a_ptr} + 24]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r10, r8", + "adox r12, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r12, rdi", + "adox r13, r11", + // "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r15, rdi, qword ptr [{b_ptr} + 24]", + "adcx r13, r8", + "adox r14, r9", + "adcx r14, rdi", + "mov r11, 0", + "adox r15, r11", + "adcx r15, r11", + "mov rdx, r10", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "adox r10, rdx", + "adcx r12, r11", + "adox r12, r11", + "adcx r13, r11", + "mov r8, {q_3}", + "mulx r9, rdi, r8", + "adox r13, r11", + "adcx r14, r11", + "adox r14, rdi", + "adcx r15, r9", + "adox r15, r11", + q_3 = const 0xe7db4ea6533afa9u64, + inv = const 0xffffffffffffffffu64, + a_ptr = in(reg) a.as_ptr(), + b_ptr = in(reg) b.as_ptr(), + out("rdx") _, + out("rdi") _, + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") r0, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, nomem, nostack) + ); + } + + [r0, r1, r2, r3] +} \ No newline at end of file diff --git a/crates/ff/tester/src/assembly_4.rs b/crates/ff/tester/src/assembly_4.rs new file mode 100644 index 0000000..790186f --- /dev/null +++ b/crates/ff/tester/src/assembly_4.rs @@ -0,0 +1,1040 @@ +const MODULUS: [u64; 4] = [0xd0970e5ed6f72cb7, 0xa6682093ccc81082, 0x6673b0101343b00, 0xe7db4ea6533afa9]; +const INV: u64 = 0x1ba3a358ef788ef9; + +#[allow(dead_code)] +#[allow(clippy::too_many_lines)] +#[inline(always)] +#[cfg(all(target_arch = "x86_64", target_feature = "adx"))] +pub fn mont_mul_asm_adx(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + use core::mem::MaybeUninit; + + static ZERO: u64 = 0; + static MODULUS_0: u64 = 0xd0970e5ed6f72cb7; + static MODULUS_1: u64 = 0xa6682093ccc81082; + static MODULUS_2: u64 = 0x6673b0101343b00; + static MODULUS_3: u64 = 0xe7db4ea6533afa9; + static INV: u64 = 0x1ba3a358ef788ef9; + + static MODULUS_0_INV: u64 = MODULUS_0.wrapping_neg(); + static MODULUS_1_INV: u64 = MODULUS_1.wrapping_neg(); + static MODULUS_2_INV: u64 = MODULUS_2.wrapping_neg(); + static MODULUS_3_INV: u64 = MODULUS_3.wrapping_neg(); + + // this is CIOS multiplication when top bit for top word of modulus is not set + + // let mut result = MaybeUninit::<[u64; 4]>::uninit(); + + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + // mulx dest_hi, dest_lo, src1 + // use notation of order (hi, lo) + + // | | b3 | b2 | b1 | b0 | + // | | | | | a0 | + // |---- |---- |---- |---- |---- | + // | | | | r14 | r13 | + // | | | r9 | r8 | | + // | | r10 | r15 | | | + // | r12 | rdi | | | | + // |---- |---- |---- |---- |---- | + // | | | | | | // rdx = m, r11 = garbage + // | | | CF | r14 | | + // | OF | r10 | | | | + // |---- |---- |---- |---- |---- | + // | | CF | r15 | | | + // | r12 | | | | | + // | CF | r10 | | | | + // |---- |---- |---- |---- |---- | + // | r12 | r10 | r15 | r14 | r13 | + + unsafe { + asm!( + // round 0 + "mov rdx, qword ptr [{a_ptr} + 0]", + "xor r8d, r8d", + "mulx r14, r13, qword ptr [{b_ptr} + 0]", // (r14, r13) = a[0] * b[0] + "mulx r9, r8, qword ptr [{b_ptr} + 8]", // (r9, r8) = a[0] * b[1] + "mulx r10, r15, qword ptr [{b_ptr} + 16]", // (r10, r15) = a[0] * b[2] + "mulx r12, rdi, qword ptr [{b_ptr} + 24]", // (r12, rdi) = a[0] * b[3] + // by this moment MULX for a[0] * b[0] is complete (latency = 4) + "mov rdx, r13", // rdx = r13 = (a[0] * b[0]).l0 + "mov r11, {inv}", + "mulx r11, rdx, r11", // (r11, rdx) = (a[0] * b[0]).lo * k, so rdx = m (we overwrite rdx cause (a[0] * b[0]).lo is not needed for anything else) + // "mulx r11, rdx, qword ptr [rip + {inv_ptr}]", // (r11, rdx) = (a[0] * b[0]).lo * k, so rdx = m (we overwrite rdx cause (a[0] * b[0]).lo is not needed for anything else) + "adcx r14, r8", // r14 = r14 + r8 = (a[0] * b[0]).hi + (a[0] * b[1]).lo, carry flag is set in CF register (CF = carry into 2nd word), 1st word calculation + "adox r10, rdi", // r10 = r10 + rdi = (a[0] * b[2]).hi + (a[0] * b[3]).lo, carry flag is set in OF register (OF = carry into 4th word), 3rd word calculation + "adcx r15, r9", // r15 = r15 + r9 + CF = (a[0] * b[1]).hi + (a[0] * b[2]).lo + CF, 2nd word continuation + "mov r11, 0", + "adox r12, r11", // r12 = r12 + OF = 4th word + "adcx r10, r11", // r10 = r10 + CF, 3rd word continuation + // "adox r12, qword ptr [rip + {zero_ptr}]", // r12 = r12 + OF = 4th word + // "adcx r10, qword ptr [rip + {zero_ptr}]", // r10 = r10 + CF, 3rd word continuation + // "mov r8, {q0}", + // "mulx r9, r8, r8", // (r9, r8) = m * q0 + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", // (r9, r8) = m * q0 + // "mov rdi, {q1}", + // "mulx r11, rdi, rdi", // (r11, rdi) = m * q1 + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", // (r11, rdi) = m * q1 + "adox r13, r8", // r13 = t[0] + (m * q0).lo, set OF + "adcx r14, rdi", // r14 = t[1] + (m * q1).lo, set CF + "adox r14, r9", // r14 = t[1] + (m * q0).hi + OF, set OF + "adcx r15, r11", // r15 = t[2] + (m * q1).hi + CF, set CF + // "mov r8, {q2}", + // "mulx r9, r8, r8", // (r9, r8) = m * q2 + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", // (r9, r8) = m * q2 + // "mov rdi, {q3}", + // "mulx r11, rdi, rdi", // (r11, rdi) = m * q3 + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", // (r11, rdi) = m * q3 + "adox r15, r8", // r15 = t[2] + (m * q2).lo + OF, set OF + "adcx r10, rdi", // r10 = t[3] + (m * q3).lo + CF, set CF + "adox r10, r9", // r10 = t[3] + (m * q2).hi + OF, set OF + "adcx r12, r11", // r12 = t[4] + (m * q3).hi + CF, set CF + "mov r9, 0", + "adox r12, r9", // r12 = r12 + OF + // "adox r12, qword ptr [rip + {zero_ptr}]", // r12 = r12 + OF + + // round 1 + "mov rdx, qword ptr [{a_ptr} + 8]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r14, r8", + "adox r15, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r15, rdi", + "adox r10, r11", + // "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r13, rdi, qword ptr [{b_ptr} + 24]", + "adcx r10, r8", + "adox r12, rdi", + "adcx r12, r9", + "mov rdi, 0", + "adox r13, rdi", + "adcx r13, rdi", + // "adox r13, qword ptr [rip + {zero_ptr}]", + // "adcx r13, qword ptr [rip + {zero_ptr}]", + "mov rdx, r14", + "mov r8, {inv}", + "mulx r8, rdx, r8", + // "mulx r8, rdx, qword ptr [rip + {inv_ptr}]", + // "mov r8, {q0}", + // "mulx r9, r8, r8", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + // "mov rdi, {q1}", + // "mulx r11, rdi, rdi", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r14, r8", + "adcx r15, rdi", + "adox r15, r9", + "adcx r10, r11", + // "mov r8, {q2}", + // "mulx r9, r8, r8", + // "mov rdi, {q3}", + // "mulx r11, rdi, rdi", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", + "adox r10, r8", + "adcx r12, r9", + "adox r12, rdi", + "adcx r13, r11", + "mov rdi, 0", + "adox r13, rdi", + // "adox r13, qword ptr [rip + {zero_ptr}]", + + // round 2 + "mov rdx, qword ptr [{a_ptr} + 16]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r15, r8", + "adox r10, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r10, rdi", + "adox r12, r11", + // "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r14, rdi, qword ptr [{b_ptr} + 24]", + "adcx r12, r8", + "adox r13, r9", + "adcx r13, rdi", + "mov r9, 0", + "adox r14, r9", + "adcx r14, r9", + // "adox r14, qword ptr [rip + {zero_ptr}]", + // "adcx r14, qword ptr [rip + {zero_ptr}]", + "mov rdx, r15", + "mov r8, {inv}", + "mulx r8, rdx, r8", + // "mulx r8, rdx, qword ptr [rip + {inv_ptr}]", + // "mov r8, {q0}", + // "mulx r9, r8, r8", + // "mov rdi, {q1}", + // "mulx r11, rdi, rdi", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r15, r8", + "adcx r10, r9", + "adox r10, rdi", + "adcx r12, r11", + // "mov r8, {q2}", + // "mulx r9, r8, r8", + // "mov rdi, {q3}", + // "mulx r11, rdi, rdi", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", + "adox r12, r8", + "adcx r13, r9", + "adox r13, rdi", + "adcx r14, r11", + "mov rdi, 0", + "adox r14, rdi", + // "adox r14, qword ptr [rip + {zero_ptr}]", + + // round 3 + "mov rdx, qword ptr [{a_ptr} + 24]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r10, r8", + "adox r12, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r12, rdi", + "adox r13, r11", + // "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r15, rdi, qword ptr [{b_ptr} + 24]", + "adcx r13, r8", + "adox r14, r9", + "adcx r14, rdi", + "mov r9, 0", + "adox r15, r9", + "adcx r15, r9", + // "adox r15, qword ptr [rip + {zero_ptr}]", + // "adcx r15, qword ptr [rip + {zero_ptr}]", + "mov rdx, r10", + "mov r8, {inv}", + "mulx r8, rdx, r8", + // "mulx r8, rdx, qword ptr [rip + {inv_ptr}]", + // "mov r8, {q0}", + // "mulx r9, r8, r8", + // "mov rdi, {q1}", + // "mulx r11, rdi, rdi", + + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r10, r8", + "adcx r12, r9", + "adox r12, rdi", + "adcx r13, r11", + // "mov r8, {q2}", + // "mulx r9, r8, r8", + // "mov rdi, {q3}", + // "mulx rdx, rdi, rdi", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx rdx, rdi, qword ptr [rip + {q3_ptr}]", + "adox r13, r8", + "adcx r14, r9", + "adox r14, rdi", + "adcx r15, rdx", + "mov rdi, 0", + "adox r15, rdi", + // "adox r15, qword ptr [rip + {zero_ptr}]", + + // "mov [{out_ptr} + 0], r12", + // "mov [{out_ptr} + 8], r13", + // "mov [{out_ptr} + 16], r14", + // "mov [{out_ptr} + 24], r15", + + // zero_ptr = sym ZERO, + // inv_ptr = sym INV, + // q0 = const 0xd0970e5ed6f72cb7u64, + // q1 = const 0xa6682093ccc81082u64, + // q2 = const 0x6673b0101343b00u64, + // q3 = const 0xe7db4ea6533afa9u64, + + q0_ptr = sym MODULUS_0, + q1_ptr = sym MODULUS_1, + q2_ptr = sym MODULUS_2, + q3_ptr = sym MODULUS_3, + inv = const 0x1ba3a358ef788ef9u64, + // out_ptr = in(reg) result.as_mut_ptr(), + a_ptr = in(reg) a.as_ptr(), + b_ptr = in(reg) b.as_ptr(), + out("rdx") _, + out("rdi") _, + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") r0, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, readonly, nostack) + ); + } + + [r0, r1, r2, r3] + + // let result = unsafe { result.assume_init() }; + + // result +} + + +#[allow(dead_code)] +#[allow(clippy::too_many_lines)] +#[inline(always)] +#[cfg(all(target_arch = "x86_64", target_feature = "adx"))] +pub fn mont_mul_asm_adx_with_reduction(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + use core::mem::MaybeUninit; + + static ZERO: u64 = 0; + static MODULUS_0: u64 = 0xd0970e5ed6f72cb7; + static MODULUS_1: u64 = 0xa6682093ccc81082; + static MODULUS_2: u64 = 0x6673b0101343b00; + static MODULUS_3: u64 = 0xe7db4ea6533afa9; + static INV: u64 = 0x1ba3a358ef788ef9; + + static MODULUS_0_INV: u64 = MODULUS_0.wrapping_neg(); + static MODULUS_1_INV: u64 = MODULUS_1.wrapping_neg(); + static MODULUS_2_INV: u64 = MODULUS_2.wrapping_neg(); + static MODULUS_3_INV: u64 = MODULUS_3.wrapping_neg(); + + // this is CIOS multiplication when top bit for top word of modulus is not set + + // let mut result = MaybeUninit::<[u64; 4]>::uninit(); + + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + // mulx dest_hi, dest_lo, src1 + // use notation of order (hi, lo) + + // | | b3 | b2 | b1 | b0 | + // | | | | | a0 | + // |---- |---- |---- |---- |---- | + // | | | | r14 | r13 | + // | | | r9 | r8 | | + // | | r10 | r15 | | | + // | r12 | rdi | | | | + // |---- |---- |---- |---- |---- | + // | | | | | | // rdx = m, r11 = garbage + // | | | CF | r14 | | + // | OF | r10 | | | | + // |---- |---- |---- |---- |---- | + // | | CF | r15 | | | + // | r12 | | | | | + // | CF | r10 | | | | + // |---- |---- |---- |---- |---- | + // | r12 | r10 | r15 | r14 | r13 | + + unsafe { + asm!( + // round 0 + "mov rdx, qword ptr [{a_ptr} + 0]", + "xor r8d, r8d", + "mulx r14, r13, qword ptr [{b_ptr} + 0]", // (r14, r13) = a[0] * b[0] + "mulx r9, r8, qword ptr [{b_ptr} + 8]", // (r9, r8) = a[0] * b[1] + "mulx r10, r15, qword ptr [{b_ptr} + 16]", // (r10, r15) = a[0] * b[2] + "mulx r12, rdi, qword ptr [{b_ptr} + 24]", // (r12, rdi) = a[0] * b[3] + // by this moment MULX for a[0] * b[0] is complete (latency = 4) + "mov rdx, r13", // rdx = r13 = (a[0] * b[0]).l0 + "mov r11, {inv}", + "mulx r11, rdx, r11", // (r11, rdx) = (a[0] * b[0]).lo * k, so rdx = m (we overwrite rdx cause (a[0] * b[0]).lo is not needed for anything else) + "adcx r14, r8", // r14 = r14 + r8 = (a[0] * b[0]).hi + (a[0] * b[1]).lo, carry flag is set in CF register (CF = carry into 2nd word), 1st word calculation + "adox r10, rdi", // r10 = r10 + rdi = (a[0] * b[2]).hi + (a[0] * b[3]).lo, carry flag is set in OF register (OF = carry into 4th word), 3rd word calculation + "adcx r15, r9", // r15 = r15 + r9 + CF = (a[0] * b[1]).hi + (a[0] * b[2]).lo + CF, 2nd word continuation + "mov r11, 0", + "adox r12, r11", // r12 = r12 + OF = 4th word + "adcx r10, r11", // r10 = r10 + CF, 3rd word continuation + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", // (r9, r8) = m * q0 + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", // (r11, rdi) = m * q1 + "adox r13, r8", // r13 = t[0] + (m * q0).lo, set OF + "adcx r14, rdi", // r14 = t[1] + (m * q1).lo, set CF + "adox r14, r9", // r14 = t[1] + (m * q0).hi + OF, set OF + "adcx r15, r11", // r15 = t[2] + (m * q1).hi + CF, set CF + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", // (r9, r8) = m * q2 + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", // (r11, rdi) = m * q3 + "adox r15, r8", // r15 = t[2] + (m * q2).lo + OF, set OF + "adcx r10, rdi", // r10 = t[3] + (m * q3).lo + CF, set CF + "adox r10, r9", // r10 = t[3] + (m * q2).hi + OF, set OF + "adcx r12, r11", // r12 = t[4] + (m * q3).hi + CF, set CF + "mov r9, 0", + "adox r12, r9", // r12 = r12 + OF + + // round 1 + "mov rdx, qword ptr [{a_ptr} + 8]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r14, r8", + "adox r15, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r15, rdi", + "adox r10, r11", + "mulx r13, rdi, qword ptr [{b_ptr} + 24]", + "adcx r10, r8", + "adox r12, rdi", + "adcx r12, r9", + "mov rdi, 0", + "adox r13, rdi", + "adcx r13, rdi", + "mov rdx, r14", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r14, r8", + "adcx r15, rdi", + "adox r15, r9", + "adcx r10, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", + "adox r10, r8", + "adcx r12, r9", + "adox r12, rdi", + "adcx r13, r11", + "mov rdi, 0", + "adox r13, rdi", + + // round 2 + "mov rdx, qword ptr [{a_ptr} + 16]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r15, r8", + "adox r10, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r10, rdi", + "adox r12, r11", + "mulx r14, rdi, qword ptr [{b_ptr} + 24]", + "adcx r12, r8", + "adox r13, r9", + "adcx r13, rdi", + "mov r9, 0", + "adox r14, r9", + "adcx r14, r9", + "mov rdx, r15", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r15, r8", + "adcx r10, r9", + "adox r10, rdi", + "adcx r12, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q3_ptr}]", + "adox r12, r8", + "adcx r13, r9", + "adox r13, rdi", + "adcx r14, r11", + "mov rdi, 0", + "adox r14, rdi", + + // round 3 + "mov rdx, qword ptr [{a_ptr} + 24]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r10, r8", + "adox r12, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r12, rdi", + "adox r13, r11", + "mulx r15, rdi, qword ptr [{b_ptr} + 24]", + "adcx r13, r8", + "adox r14, r9", + "adcx r14, rdi", + "mov r9, 0", + "adox r15, r9", + "adcx r15, r9", + "mov rdx, r10", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "mulx r9, r8, qword ptr [rip + {q0_ptr}]", + "mulx r11, rdi, qword ptr [rip + {q1_ptr}]", + "adox r10, r8", + "adcx r12, r9", + "adox r12, rdi", + "adcx r13, r11", + "mulx r9, r8, qword ptr [rip + {q2_ptr}]", + "mulx rdx, rdi, qword ptr [rip + {q3_ptr}]", + "adox r13, r8", + "adcx r14, r9", + "adox r14, rdi", + "adcx r15, rdx", + "mov rdi, 0", + "adox r15, rdi", + // reduction. We use sub/sbb + + "mov r8, r12", + "mov rdx, {q0_neg}", + "sub r8, rdx", + "mov r9, r13", + "mov rdx, {q1_neg}", + "sbb r9, rdx", + "mov r10, r14", + "mov rdx, {q2_neg}", + "sbb r10, rdx", + "mov r11, r15", + "mov rdx, {q3_neg}", + "sbb r11, rdx", + + // if CF == 1 then original result was ok (reduction wa not necessary) + // so if not carry (CMOVNQ) then we copy + "cmovnc r12, r8", + "cmovnc r13, r9", + "cmovnc r14, r10", + "cmovnc r15, r11", + q0_neg = const 0xd0970e5ed6f72cb7u64, + q1_neg = const 0xa6682093ccc81082u64, + q2_neg = const 0x6673b0101343b00u64, + q3_neg = const 0xe7db4ea6533afa9u64, + // end of reduction + q0_ptr = sym MODULUS_0, + q1_ptr = sym MODULUS_1, + q2_ptr = sym MODULUS_2, + q3_ptr = sym MODULUS_3, + inv = const 0x1ba3a358ef788ef9u64, + a_ptr = in(reg) a.as_ptr(), + b_ptr = in(reg) b.as_ptr(), + out("rdx") _, + out("rdi") _, + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") r0, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, readonly, nostack) + ); + } + + [r0, r1, r2, r3] +} + +// // assumes that +// macro_rules! branchless_reduce_by_one_modulus { +// ($q0_neg: literal, $q1_neg: literal, $q2_neg: literal, $q3_neg: literal) => { + +// }; +// } + + +#[allow(dead_code)] +#[allow(clippy::too_many_lines)] +#[inline(always)] +#[cfg(target_arch = "x86_64")] +pub fn mont_mul_asm(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + use core::mem::MaybeUninit; + + static INV: u64 = 0x1ba3a358ef788ef9; + static MODULUS: [u64; 4] = [0xd0970e5ed6f72cb7, 0xa6682093ccc81082, 0x6673b0101343b00, 0xe7db4ea6533afa9]; + + // this is CIOS multiplication when top bit for top work of modulus is not set + let mut result = MaybeUninit::<[u64; 4]>::uninit(); + // mulx dest_hi, dest_lo, src1 + // use notation of order (hi, lo) + + unsafe { + asm!( + // round 0 + "mov rdx, qword ptr [{a_ptr} + 0]", + "xor r8d, r8d", + "mulx r14, r13, qword ptr [{b_ptr} + 0]", // (r14, r13) = a[0] * b[0] + "mulx r9, r8, qword ptr [{b_ptr} + 8]", // (r9, r8) = a[0] * b[1] + "mulx r10, r15, qword ptr [{b_ptr} + 16]", // (r10, r15) = a[0] * b[2] + "mulx r12, rdi, qword ptr [{b_ptr} + 24]", // (r12, rdi) = a[0] * b[3] + "mov rdx, r13", // rdx = r13 = (a[0] * b[0]).l0, r[0] in r13, + "mulx r11, rdx, qword ptr [rip + {inv_ptr}]", // (r11, rdx) = (a[0] * b[0]).lo * k, so rdx = m (we overwrite rdx cause (a[0] * b[0]).lo is not needed for anything else) + "add r14, r8", // t[1], set CF + "adc r15, r9", // t[2] + CF, set CF + "adc r10, rdi", // t[3] + CF, set CF + "adc r12, 0", // t[4] + CF + "mulx r9, r8, qword ptr [rip + {q_ptr} + 0]", // (r9, r8) = m * q0 + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 8]", // (r11, rdi) = m * q1 + "add r13, r8", // r[0] + "adc r14, rdi", // r[1] + "adc r15, r11", // r[2] + "adc r10, 0", // r[3] + "add r14, r9", // continue r[1] + "mulx r9, r8, qword ptr [rip + {q_ptr} + 16]", // (r9, r8) = m * q2 + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 24]", // (r11, rdi) = m * q3 + "adc r15, r8", // continue r[2] + "adc r10, rdi", // continue r[3] + "adc r12, r11", // r[4] + "add r10, r9", // finish r[3] + "adc r12, 0", // finish r[4] + + // round 1 + "mov rdx, qword ptr [{a_ptr} + 8]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "add r14, r8", + "adc r15, rdi", + "adc r10, r11", + "adc r12, 0", + "add r15, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r13, rdi, qword ptr [{b_ptr} + 24]", + "adc r10, r8", + "adc r12, rdi", + "adc r13, 0", + "add r12, r9", + "adc r13, 0", + "mov rdx, r14", + "mulx r8, rdx, qword ptr [rip + {inv_ptr}]", + "mulx r9, r8, qword ptr [rip + {q_ptr} + 0]", + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 8]", + "add r14, r8", + "adc r15, rdi", + "adc r10, r11", + "adc r12, 0", + "add r15, r9", + "mulx r9, r8, qword ptr [rip + {q_ptr} + 16]", + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 24]", + "adc r10, r8", + "adc r12, r9", + "adc r13, r11", + "add r12, rdi", + "adc r13, 0", + + // round 2 + "mov rdx, qword ptr [{a_ptr} + 16]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "add r15, r8", + "adc r10, r9", + "adc r12, r11", + "adc r13, 0", + "add r10, rdi", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r14, rdi, qword ptr [{b_ptr} + 24]", + "adc r12, r8", + "adc r13, r9", + "adc r14, 0", + "add r13, rdi", + "adc r14, 0", + "mov rdx, r15", + "mulx r8, rdx, qword ptr [rip + {inv_ptr}]", + "mulx r9, r8, qword ptr [rip + {q_ptr} + 0]", + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 8]", + "add r15, r8", + "adc r10, r9", + "adc r12, r11", + "adc r13, 0", + "add r10, rdi", + "mulx r9, r8, qword ptr [rip + {q_ptr} + 16]", + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 24]", + "adc r12, r8", + "adc r13, r9", + "adc r14, r11", + "add r13, rdi", + "adc r14, 0", + + // round 3 + "mov rdx, qword ptr [{a_ptr} + 24]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "add r10, r8", + "adc r12, r9", + "adc r13, r11", + "adc r14, 0", + "add r12, rdi", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r15, rdi, qword ptr [{b_ptr} + 24]", + "adc r13, r8", + "adc r14, r9", + "adc r15, 0", + "add r14, rdi", + "adc r15, 0", + "mov rdx, r10", + "mulx r8, rdx, qword ptr [rip + {inv_ptr}]", + "mulx r9, r8, qword ptr [rip + {q_ptr} + 0]", + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 8]", + "add r10, r8", + "adc r12, r9", + "adc r13, r11", + "adc r14, 0", + "add r12, rdi", + "mulx r9, r8, qword ptr [rip + {q_ptr} + 16]", + "mulx rdx, rdi, qword ptr [rip + {q_ptr} + 24]", + "adc r13, r8", + "adc r14, r9", + "adc r15, rdx", + "add r14, rdi", + "adc r15, 0", + "mov [{out_ptr} + 0], r12", + "mov [{out_ptr} + 8], r13", + "mov [{out_ptr} + 16], r14", + "mov [{out_ptr} + 24], r15", + q_ptr = sym MODULUS, + inv_ptr = sym INV, + out_ptr = in(reg) result.as_mut_ptr(), + a_ptr = in(reg) a.as_ptr(), + b_ptr = in(reg) b.as_ptr(), + out("rdx") _, + out("rdi") _, + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") _, + out("r13") _, + out("r14") _, + out("r15") _, + ); + } + + let result = unsafe { result.assume_init() }; + + result +} + + +#[allow(dead_code)] +#[allow(clippy::too_many_lines)] +#[inline(always)] +#[cfg(target_arch = "x86_64")] +pub fn mont_mul_asm_through_registers(mut a0: u64, mut a1: u64, mut a2: u64, mut a3: u64, b: &[u64; 4]) -> (u64, u64, u64, u64) { + static INV: u64 = 0x1ba3a358ef788ef9; + static MODULUS: [u64; 4] = [0xd0970e5ed6f72cb7, 0xa6682093ccc81082, 0x6673b0101343b00, 0xe7db4ea6533afa9]; + + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + unsafe { + asm!( + // round 0 + "mov rdx, r12", // move a0 to rdx + "xor r12d, r12d", // clear flags before we begin + // a0 is in rdx + "mulx r14, r13, qword ptr [{b_ptr} + 0]", // (r14, r13) = a[0] * b[0] + "mulx r9, r8, qword ptr [{b_ptr} + 8]", // (r9, r8) = a[0] * b[1] + "mulx r10, r15, qword ptr [{b_ptr} + 16]", // (r10, r15) = a[0] * b[2] + "mulx r12, rdi, qword ptr [{b_ptr} + 24]", // (r12, rdi) = a[0] * b[3] + "mov rdx, r13", // rdx = r13 = (a[0] * b[0]).l0, r[0] in r13, + "mulx r11, rdx, qword ptr [rip + {inv_ptr}]", // (r11, rdx) = (a[0] * b[0]).lo * k, so rdx = m (we overwrite rdx cause (a[0] * b[0]).lo is not needed for anything else) + "add r14, r8", // t[1], set CF + "adc r15, r9", // t[2] + CF, set CF + "adc r10, rdi", // t[3] + CF, set CF + "adc r12, 0", // t[4] + CF + "mulx r9, r8, qword ptr [rip + {q_ptr} + 0]", // (r9, r8) = m * q0 + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 8]", // (r11, rdi) = m * q1 + "add r13, r8", // r[0] + "adc r14, rdi", // r[1] + "adc r15, r11", // r[2] + "adc r10, 0", // r[3] + "add r14, r9", // continue r[1] + "mulx r9, r8, qword ptr [rip + {q_ptr} + 16]", // (r9, r8) = m * q2 + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 24]", // (r11, rdi) = m * q3 + "adc r15, r8", // continue r[2] + "adc r10, rdi", // continue r[3] + "adc r12, r11", // r[4] + "add r10, r9", // finish r[3] + "adc r12, 0", // finish r[4] + + // round 1 + "mov rdx, {a_1}", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "add r14, r8", + "adc r15, rdi", + "adc r10, r11", + "adc r12, 0", + "add r15, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r13, rdi, qword ptr [{b_ptr} + 24]", + "adc r10, r8", + "adc r12, rdi", + "adc r13, 0", + "add r12, r9", + "adc r13, 0", + "mov rdx, r14", + "mulx r8, rdx, qword ptr [rip + {inv_ptr}]", + "mulx r9, r8, qword ptr [rip + {q_ptr} + 0]", + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 8]", + "add r14, r8", + "adc r15, rdi", + "adc r10, r11", + "adc r12, 0", + "add r15, r9", + "mulx r9, r8, qword ptr [rip + {q_ptr} + 16]", + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 24]", + "adc r10, r8", + "adc r12, r9", + "adc r13, r11", + "add r12, rdi", + "adc r13, 0", + + // round 2 + "mov rdx, {a_2}", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "add r15, r8", + "adc r10, r9", + "adc r12, r11", + "adc r13, 0", + "add r10, rdi", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r14, rdi, qword ptr [{b_ptr} + 24]", + "adc r12, r8", + "adc r13, r9", + "adc r14, 0", + "add r13, rdi", + "adc r14, 0", + "mov rdx, r15", + "mulx r8, rdx, qword ptr [rip + {inv_ptr}]", + "mulx r9, r8, qword ptr [rip + {q_ptr} + 0]", + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 8]", + "add r15, r8", + "adc r10, r9", + "adc r12, r11", + "adc r13, 0", + "add r10, rdi", + "mulx r9, r8, qword ptr [rip + {q_ptr} + 16]", + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 24]", + "adc r12, r8", + "adc r13, r9", + "adc r14, r11", + "add r13, rdi", + "adc r14, 0", + + // round 3 + "mov rdx, {a_3}", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "add r10, r8", + "adc r12, r9", + "adc r13, r11", + "adc r14, 0", + "add r12, rdi", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r15, rdi, qword ptr [{b_ptr} + 24]", + "adc r13, r8", + "adc r14, r9", + "adc r15, 0", + "add r14, rdi", + "adc r15, 0", + "mov rdx, r10", + "mulx r8, rdx, qword ptr [rip + {inv_ptr}]", + "mulx r9, r8, qword ptr [rip + {q_ptr} + 0]", + "mulx r11, rdi, qword ptr [rip + {q_ptr} + 8]", + "add r10, r8", + "adc r12, r9", + "adc r13, r11", + "adc r14, 0", + "add r12, rdi", + "mulx r9, r8, qword ptr [rip + {q_ptr} + 16]", + "mulx rdx, rdi, qword ptr [rip + {q_ptr} + 24]", + "adc r13, r8", + "adc r14, r9", + "adc r15, rdx", + "add r14, rdi", + "adc r15, 0", + q_ptr = sym MODULUS, + inv_ptr = sym INV, + b_ptr = in(reg) b.as_ptr(), + a_1 = in(reg) a1, + a_2 = in(reg) a2, + a_3 = in(reg) a2, + inout("r12") a0 => r0, + out("rdx") _, + out("rdi") _, + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, readonly, nostack) + ); + } + + (r0, r1, r2, r3) +} + + +#[allow(dead_code)] +#[allow(clippy::too_many_lines)] +#[inline(always)] +#[cfg(all(target_arch = "x86_64", target_feature = "adx"))] +pub fn mont_mul_asm_adx_for_proth_prime(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + let mut r0: u64; + let mut r1: u64; + let mut r2: u64; + let mut r3: u64; + + // this is CIOS multiplication when top bit for top work of modulus is not set + + // mulx dest_hi, dest_lo, src1 + // use notation of order (hi, lo) + + // | | b3 | b2 | b1 | b0 | + // | | | | | a0 | + // |---- |---- |---- |---- |---- | + // | | | | r14 | r13 | + // | | | r9 | r8 | | + // | | r10 | r15 | | | + // | r12 | rdi | | | | + // |---- |---- |---- |---- |---- | + // | | | | | | // rdx = m, r11 = garbage + // | | | CF | r14 | | + // | OF | r10 | | | | + // |---- |---- |---- |---- |---- | + // | | CF | r15 | | | + // | r12 | | | | | + // | CF | r10 | | | | + // |---- |---- |---- |---- |---- | + // | r12 | r10 | r15 | r14 | r13 | + + unsafe { + asm!( + // round 0 + "mov rdx, qword ptr [{a_ptr} + 0]", + "xor r8d, r8d", + "mulx r14, r13, qword ptr [{b_ptr} + 0]", // (r14, r13) = a[0] * b[0] + "mulx r9, r8, qword ptr [{b_ptr} + 8]", // (r9, r8) = a[0] * b[1] + "mulx r10, r15, qword ptr [{b_ptr} + 16]", // (r10, r15) = a[0] * b[2] + "mulx r12, rdi, qword ptr [{b_ptr} + 24]", // (r12, rdi) = a[0] * b[3] + // by this moment MULX for a[0] * b[0] is complete (latency = 4) + "mov rdx, r13", // rdx = r13 = (a[0] * b[0]).l0 + "mov r11, {inv}", + "mulx r11, rdx, r11", // (r11, rdx) = (a[0] * b[0]).lo * k, so rdx = m (we overwrite rdx cause (a[0] * b[0]).lo is not needed for anything else) + "adcx r14, r8", // r14 = r14 + r8 = (a[0] * b[0]).hi + (a[0] * b[1]).lo, carry flag is set in CF register (CF = carry into 2nd word), 1st word calculation + "adox r10, rdi", // r10 = r10 + rdi = (a[0] * b[2]).hi + (a[0] * b[3]).lo, carry flag is set in OF register (OF = carry into 4th word), 3rd word calculation + "adcx r15, r9", // r15 = r15 + r9 + CF = (a[0] * b[1]).hi + (a[0] * b[2]).lo + CF, 2nd word continuation + "mov r11, 0", + "adox r12, r11", // r12 = r12 + OF = 4th word + "adcx r10, r11", // r10 = r10 + CF, 3rd word continuation + "adox r13, rdx", // r13 = t[0] + (m * q0).lo, set OF + "adcx r14, r11", // r14 = t[1] + (m * q1).lo, set CF + "adox r14, r11", // r14 = t[1] + (m * q0).hi + OF, set OF + "adcx r15, r11", // r15 = t[2] + (m * q1).hi + CF, set CF + "mov r8, {q_3}", + "mulx r9, rdi, r8", // (r11, rdi) = m * q3 + "adox r15, r11", // r15 = t[2] + (m * q2).lo + OF, set OF + "adcx r10, rdi", // r10 = t[3] + (m * q3).lo + CF, set CF + "adox r10, r11", // r10 = t[3] + (m * q2).hi + OF, set OF + "adcx r12, r9", // r12 = t[4] + (m * q3).hi + CF, set CF + "adox r12, r11", // r12 = r12 + OF + + // round 1 + "mov rdx, qword ptr [{a_ptr} + 8]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r14, r8", + "adox r15, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r15, rdi", + "adox r10, r11", + // "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r13, rdi, qword ptr [{b_ptr} + 24]", + "adcx r10, r8", + "adox r12, rdi", + "adcx r12, r9", + "mov r11, 0", + "adox r13, r11", + "adcx r13, r11", + "mov rdx, r14", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "adox r14, rdx", + "adcx r15, r11", + "adox r15, r11", + "adcx r10, r11", + "mov r8, {q_3}", + "mulx r9, rdi, r8", + "adox r10, r11", + "adcx r12, r11", + "adox r12, rdi", + "adcx r13, r9", + "adox r13, r11", + + // round 2 + "mov rdx, qword ptr [{a_ptr} + 16]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r15, r8", + "adox r10, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r10, rdi", + "adox r12, r11", + // "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r14, rdi, qword ptr [{b_ptr} + 24]", + "adcx r12, r8", + "adox r13, r9", + "adcx r13, rdi", + "mov r11, 0", + "adox r14, r11", + "adcx r14, r11", + "mov rdx, r15", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "adox r15, rdx", + "adcx r10, r11", + "adox r10, r11", + "adcx r12, r11", + "mov r8, {q_3}", + "mulx r9, rdi, r8", + "adox r12, r11", + "adcx r13, r11", + "adox r13, rdi", + "adcx r14, r9", + "adox r14, r11", + + // round 3 + "mov rdx, qword ptr [{a_ptr} + 24]", + "mulx r9, r8, qword ptr [{b_ptr} + 0]", + "mulx r11, rdi, qword ptr [{b_ptr} + 8]", + "adcx r10, r8", + "adox r12, r9", + "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "adcx r12, rdi", + "adox r13, r11", + // "mulx r9, r8, qword ptr [{b_ptr} + 16]", + "mulx r15, rdi, qword ptr [{b_ptr} + 24]", + "adcx r13, r8", + "adox r14, r9", + "adcx r14, rdi", + "mov r11, 0", + "adox r15, r11", + "adcx r15, r11", + "mov rdx, r10", + "mov r8, {inv}", + "mulx r8, rdx, r8", + "adox r10, rdx", + "adcx r12, r11", + "adox r12, r11", + "adcx r13, r11", + "mov r8, {q_3}", + "mulx r9, rdi, r8", + "adox r13, r11", + "adcx r14, r11", + "adox r14, rdi", + "adcx r15, r9", + "adox r15, r11", + q_3 = const 0xe7db4ea6533afa9u64, + inv = const 0xffffffffffffffffu64, + a_ptr = in(reg) a.as_ptr(), + b_ptr = in(reg) b.as_ptr(), + out("rdx") _, + out("rdi") _, + out("r8") _, + out("r9") _, + out("r10") _, + out("r11") _, + out("r12") r0, + out("r13") r1, + out("r14") r2, + out("r15") r3, + options(pure, nomem, nostack) + ); + } + + [r0, r1, r2, r3] +} \ No newline at end of file diff --git a/crates/ff/tester/src/check_assembly_4.rs b/crates/ff/tester/src/check_assembly_4.rs new file mode 100644 index 0000000..4249337 --- /dev/null +++ b/crates/ff/tester/src/check_assembly_4.rs @@ -0,0 +1,36 @@ +#[cfg(test)] +mod test { + use super::super::mul_variant0::Fs; + // use super::super::mul_variant0::mont_mul_asm; + use super::super::assembly_4::*; + + use rand::*; + use ff::*; + + #[test] + fn check_mul_asm() { + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + + for i in 0..10000 { + let a: Fs = rng.gen(); + let b: Fs = rng.gen(); + + let a_asm = unsafe { std::mem::transmute::<_, [u64; 4]>(a) }; + let b_asm = unsafe { std::mem::transmute::<_, [u64; 4]>(b) }; + + let mut c = a; + c.mul_assign(&b); + + // let c_asm = mont_mul_asm(&a_asm, &b_asm); + // let c_asm = mont_mul_asm_adx(&a_asm, &b_asm); + let c_asm = mont_mul_asm_adx_with_reduction(&a_asm, &b_asm); + + let mut c_back = unsafe { std::mem::transmute::<_, Fs>(c_asm) }; + // if !c_back.is_valid() { + // c_back.reduce(); + // } + + assert_eq!(c, c_back, "failed at iteration {}: a = {:?}, b = {:?}", i, a, b); + } + } +} \ No newline at end of file diff --git a/crates/ff/tester/src/check_cios.rs b/crates/ff/tester/src/check_cios.rs new file mode 100644 index 0000000..f2a0f2b --- /dev/null +++ b/crates/ff/tester/src/check_cios.rs @@ -0,0 +1,52 @@ +#[cfg(test)] +mod test { + use super::super::test_large_field::Fr as Fr; + use super::super::test_large_cios_field::Fr as FrCios; + + use rand::*; + use ff::*; + + #[test] + fn check_mul() { + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + + for _ in 0..10000 { + let a: Fr = rng.gen(); + let b: Fr = rng.gen(); + + let a_cios = unsafe { std::mem::transmute::<_, FrCios>(a) }; + let b_cios = unsafe { std::mem::transmute::<_, FrCios>(b) }; + + let mut c = a; + c.mul_assign(&b); + + let mut c_cios = a_cios; + c_cios.mul_assign(&b_cios); + + let c_back = unsafe { std::mem::transmute::<_, Fr>(c_cios) }; + + assert_eq!(c, c_back); + } + } + + #[test] + fn check_sqr() { + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + + for _ in 0..10000 { + let a: Fr = rng.gen(); + + let a_cios = unsafe { std::mem::transmute::<_, FrCios>(a) }; + + let mut c = a; + c.square(); + + let mut c_cios = a_cios; + c_cios.square(); + + let c_back = unsafe { std::mem::transmute::<_, Fr>(c_cios) }; + + assert_eq!(c, c_back); + } + } +} \ No newline at end of file diff --git a/crates/ff/tester/src/lib.rs b/crates/ff/tester/src/lib.rs new file mode 100644 index 0000000..c71f1f7 --- /dev/null +++ b/crates/ff/tester/src/lib.rs @@ -0,0 +1,15 @@ +#![feature(llvm_asm)] +#![feature(asm)] + +extern crate ff; +extern crate rand; + +mod test_short_field; +mod test_large_field; +mod test_large_cios_field; +mod check_cios; +mod check_assembly_4; +pub mod mul_variant0; +pub mod assembly_4; + +pub mod adx_4; \ No newline at end of file diff --git a/crates/ff/tester/src/mul_variant0.rs b/crates/ff/tester/src/mul_variant0.rs new file mode 100644 index 0000000..4de600f --- /dev/null +++ b/crates/ff/tester/src/mul_variant0.rs @@ -0,0 +1,1597 @@ +use ff::{adc, sbb, mac_with_carry}; +use ff::{Field, PrimeField, SqrtField, PrimeFieldRepr, PrimeFieldDecodingError, LegendreSymbol}; +use ff::LegendreSymbol::*; + +// s = 6554484396890773809930967563523245729705921265872317281365359162392183254199 +const MODULUS: FsRepr = FsRepr([0xd0970e5ed6f72cb7, 0xa6682093ccc81082, 0x6673b0101343b00, 0xe7db4ea6533afa9]); + +// The number of bits needed to represent the modulus. +const MODULUS_BITS: u32 = 252; + +// The number of bits that must be shaved from the beginning of +// the representation when randomly sampling. +const REPR_SHAVE_BITS: u32 = 4; + +// R = 2**256 % s +const R: FsRepr = FsRepr([0x25f80bb3b99607d9, 0xf315d62f66b6e750, 0x932514eeeb8814f4, 0x9a6fc6f479155c6]); + +// R2 = R^2 % s +const R2: FsRepr = FsRepr([0x67719aa495e57731, 0x51b0cef09ce3fc26, 0x69dab7fac026e9a5, 0x4f6547b8d127688]); + +// INV = -(s^{-1} mod 2^64) mod s +const INV: u64 = 0x1ba3a358ef788ef9; + +// GENERATOR = 6 (multiplicative generator of r-1 order, that is also quadratic nonresidue) +const GENERATOR: FsRepr = FsRepr([0x720b1b19d49ea8f1, 0xbf4aa36101f13a58, 0x5fa8cc968193ccbb, 0xe70cbdc7dccf3ac]); + +// 2^S * t = MODULUS - 1 with t odd +const S: u32 = 1; + +// 2^S root of unity computed by GENERATOR^t +const ROOT_OF_UNITY: FsRepr = FsRepr([0xaa9f02ab1d6124de, 0xb3524a6466112932, 0x7342261215ac260b, 0x4d6b87b1da259e2]); + +// -((2**256) mod s) mod s +const NEGATIVE_ONE: Fs = Fs(FsRepr([0xaa9f02ab1d6124de, 0xb3524a6466112932, 0x7342261215ac260b, 0x4d6b87b1da259e2])); + +/// This is the underlying representation of an element of `Fs`. +#[derive(Copy, Clone, PartialEq, Eq, Default, Debug, Hash)] +pub struct FsRepr(pub [u64; 4]); + +impl ::rand::Rand for FsRepr { + #[inline(always)] + fn rand(rng: &mut R) -> Self { + FsRepr(rng.gen()) + } +} + +impl ::std::fmt::Display for FsRepr +{ + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "0x")?; + for i in self.0.iter().rev() { + write!(f, "{:016x}", *i)?; + } + + Ok(()) + } +} + +impl AsRef<[u64]> for FsRepr { + #[inline(always)] + fn as_ref(&self) -> &[u64] { + &self.0 + } +} + +impl AsMut<[u64]> for FsRepr { + #[inline(always)] + fn as_mut(&mut self) -> &mut [u64] { + &mut self.0 + } +} + +impl From for FsRepr { + #[inline(always)] + fn from(val: u64) -> FsRepr { + let mut repr = Self::default(); + repr.0[0] = val; + repr + } +} + +impl Ord for FsRepr { + #[inline(always)] + fn cmp(&self, other: &FsRepr) -> ::std::cmp::Ordering { + for (a, b) in self.0.iter().rev().zip(other.0.iter().rev()) { + if a < b { + return ::std::cmp::Ordering::Less + } else if a > b { + return ::std::cmp::Ordering::Greater + } + } + + ::std::cmp::Ordering::Equal + } +} + +impl PartialOrd for FsRepr { + #[inline(always)] + fn partial_cmp(&self, other: &FsRepr) -> Option<::std::cmp::Ordering> { + Some(self.cmp(other)) + } +} + +impl PrimeFieldRepr for FsRepr { + #[inline(always)] + fn is_odd(&self) -> bool { + self.0[0] & 1 == 1 + } + + #[inline(always)] + fn is_even(&self) -> bool { + !self.is_odd() + } + + #[inline(always)] + fn is_zero(&self) -> bool { + self.0.iter().all(|&e| e == 0) + } + + #[inline(always)] + fn shr(&mut self, mut n: u32) { + if n >= 64 * 4 { + *self = Self::from(0); + return; + } + + while n >= 64 { + let mut t = 0; + for i in self.0.iter_mut().rev() { + ::std::mem::swap(&mut t, i); + } + n -= 64; + } + + if n > 0 { + let mut t = 0; + for i in self.0.iter_mut().rev() { + let t2 = *i << (64 - n); + *i >>= n; + *i |= t; + t = t2; + } + } + } + + #[inline(always)] + fn div2(&mut self) { + let mut t = 0; + for i in self.0.iter_mut().rev() { + let t2 = *i << 63; + *i >>= 1; + *i |= t; + t = t2; + } + } + + #[inline(always)] + fn mul2(&mut self) { + let mut last = 0; + for i in &mut self.0 { + let tmp = *i >> 63; + *i <<= 1; + *i |= last; + last = tmp; + } + } + + #[inline(always)] + fn shl(&mut self, mut n: u32) { + if n >= 64 * 4 { + *self = Self::from(0); + return; + } + + while n >= 64 { + let mut t = 0; + for i in &mut self.0 { + ::std::mem::swap(&mut t, i); + } + n -= 64; + } + + if n > 0 { + let mut t = 0; + for i in &mut self.0 { + let t2 = *i >> (64 - n); + *i <<= n; + *i |= t; + t = t2; + } + } + } + + #[inline(always)] + fn num_bits(&self) -> u32 { + let mut ret = (4 as u32) * 64; + for i in self.0.iter().rev() { + let leading = i.leading_zeros(); + ret -= leading; + if leading != 64 { + break; + } + } + + ret + } + + #[inline(always)] + fn add_nocarry(&mut self, other: &FsRepr) { + let mut carry = 0; + + for (a, b) in self.0.iter_mut().zip(other.0.iter()) { + *a = adc(*a, *b, &mut carry); + } + } + + #[inline(always)] + fn sub_noborrow(&mut self, other: &FsRepr) { + let mut borrow = 0; + + for (a, b) in self.0.iter_mut().zip(other.0.iter()) { + *a = sbb(*a, *b, &mut borrow); + } + } +} + +/// This is an element of the scalar field of the Jubjub curve. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)] +pub struct Fs(FsRepr); + +impl ::std::fmt::Display for Fs +{ + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "Fs({})", self.into_repr()) + } +} + +impl ::rand::Rand for Fs { + fn rand(rng: &mut R) -> Self { + loop { + let mut tmp = Fs(FsRepr::rand(rng)); + + // Mask away the unused bits at the beginning. + tmp.0.as_mut()[3] &= 0xffffffffffffffff >> REPR_SHAVE_BITS; + + if tmp.is_valid() { + return tmp + } + } + } +} + +impl From for FsRepr { + fn from(e: Fs) -> FsRepr { + e.into_repr() + } +} + +impl PrimeField for Fs { + type Repr = FsRepr; + + fn from_repr(r: FsRepr) -> Result { + let mut r = Fs(r); + if r.is_valid() { + r.mul_assign(&Fs(R2)); + + Ok(r) + } else { + Err(PrimeFieldDecodingError::NotInField(format!("{}", r.0))) + } + } + + fn from_raw_repr(r: FsRepr) -> Result { + let r = Fs(r); + if r.is_valid() { + Ok(r) + } else { + Err(PrimeFieldDecodingError::NotInField(format!("{}", r.0))) + } + } + + fn into_repr(&self) -> FsRepr { + let mut r = *self; + r.mont_reduce((self.0).0[0], (self.0).0[1], + (self.0).0[2], (self.0).0[3], + 0, 0, 0, 0); + r.0 + } + + fn into_raw_repr(&self) -> FsRepr { + let r = *self; + r.0 + } + + fn char() -> FsRepr { + MODULUS + } + + const NUM_BITS: u32 = MODULUS_BITS; + + const CAPACITY: u32 = Self::NUM_BITS - 1; + + fn multiplicative_generator() -> Self { + Fs(GENERATOR) + } + + const S: u32 = S; + + fn root_of_unity() -> Self { + Fs(ROOT_OF_UNITY) + } +} + +impl Field for Fs { + #[inline] + fn zero() -> Self { + Fs(FsRepr::from(0)) + } + + #[inline] + fn one() -> Self { + Fs(R) + } + + #[inline] + fn is_zero(&self) -> bool { + self.0.is_zero() + } + + #[inline] + fn add_assign(&mut self, other: &Fs) { + // This cannot exceed the backing capacity. + self.0.add_nocarry(&other.0); + + // However, it may need to be reduced. + self.reduce(); + } + + #[inline] + fn double(&mut self) { + // This cannot exceed the backing capacity. + self.0.mul2(); + + // However, it may need to be reduced. + self.reduce(); + } + + #[inline] + fn sub_assign(&mut self, other: &Fs) { + // If `other` is larger than `self`, we'll need to add the modulus to self first. + if other.0 > self.0 { + self.0.add_nocarry(&MODULUS); + } + + self.0.sub_noborrow(&other.0); + } + + #[inline] + fn negate(&mut self) { + if !self.is_zero() { + let mut tmp = MODULUS; + tmp.sub_noborrow(&self.0); + self.0 = tmp; + } + } + + fn inverse(&self) -> Option { + if self.is_zero() { + None + } else { + // Guajardo Kumar Paar Pelzl + // Efficient Software-Implementation of Finite Fields with Applications to Cryptography + // Algorithm 16 (BEA for Inversion in Fp) + + let one = FsRepr::from(1); + + let mut u = self.0; + let mut v = MODULUS; + let mut b = Fs(R2); // Avoids unnecessary reduction step. + let mut c = Self::zero(); + + while u != one && v != one { + while u.is_even() { + u.div2(); + + if b.0.is_even() { + b.0.div2(); + } else { + b.0.add_nocarry(&MODULUS); + b.0.div2(); + } + } + + while v.is_even() { + v.div2(); + + if c.0.is_even() { + c.0.div2(); + } else { + c.0.add_nocarry(&MODULUS); + c.0.div2(); + } + } + + if v < u { + u.sub_noborrow(&v); + b.sub_assign(&c); + } else { + v.sub_noborrow(&u); + c.sub_assign(&b); + } + } + + if u == one { + Some(b) + } else { + Some(c) + } + } + } + + #[inline(always)] + fn frobenius_map(&mut self, _: usize) { + // This has no effect in a prime field. + } + + #[inline] + fn mul_assign(&mut self, other: &Fs) + { + let mut carry = 0; + let r0 = mac_with_carry(0, (self.0).0[0], (other.0).0[0], &mut carry); + let r1 = mac_with_carry(0, (self.0).0[0], (other.0).0[1], &mut carry); + let r2 = mac_with_carry(0, (self.0).0[0], (other.0).0[2], &mut carry); + let r3 = mac_with_carry(0, (self.0).0[0], (other.0).0[3], &mut carry); + let r4 = carry; + let mut carry = 0; + let r1 = mac_with_carry(r1, (self.0).0[1], (other.0).0[0], &mut carry); + let r2 = mac_with_carry(r2, (self.0).0[1], (other.0).0[1], &mut carry); + let r3 = mac_with_carry(r3, (self.0).0[1], (other.0).0[2], &mut carry); + let r4 = mac_with_carry(r4, (self.0).0[1], (other.0).0[3], &mut carry); + let r5 = carry; + let mut carry = 0; + let r2 = mac_with_carry(r2, (self.0).0[2], (other.0).0[0], &mut carry); + let r3 = mac_with_carry(r3, (self.0).0[2], (other.0).0[1], &mut carry); + let r4 = mac_with_carry(r4, (self.0).0[2], (other.0).0[2], &mut carry); + let r5 = mac_with_carry(r5, (self.0).0[2], (other.0).0[3], &mut carry); + let r6 = carry; + let mut carry = 0; + let r3 = mac_with_carry(r3, (self.0).0[3], (other.0).0[0], &mut carry); + let r4 = mac_with_carry(r4, (self.0).0[3], (other.0).0[1], &mut carry); + let r5 = mac_with_carry(r5, (self.0).0[3], (other.0).0[2], &mut carry); + let r6 = mac_with_carry(r6, (self.0).0[3], (other.0).0[3], &mut carry); + let r7 = carry; + self.mont_reduce(r0, r1, r2, r3, r4, r5, r6, r7); + } + + #[inline] + fn square(&mut self) + { + let mut carry = 0; + let r1 = mac_with_carry(0, (self.0).0[0], (self.0).0[1], &mut carry); + let r2 = mac_with_carry(0, (self.0).0[0], (self.0).0[2], &mut carry); + let r3 = mac_with_carry(0, (self.0).0[0], (self.0).0[3], &mut carry); + let r4 = carry; + let mut carry = 0; + let r3 = mac_with_carry(r3, (self.0).0[1], (self.0).0[2], &mut carry); + let r4 = mac_with_carry(r4, (self.0).0[1], (self.0).0[3], &mut carry); + let r5 = carry; + let mut carry = 0; + let r5 = mac_with_carry(r5, (self.0).0[2], (self.0).0[3], &mut carry); + let r6 = carry; + + let r7 = r6 >> 63; + let r6 = (r6 << 1) | (r5 >> 63); + let r5 = (r5 << 1) | (r4 >> 63); + let r4 = (r4 << 1) | (r3 >> 63); + let r3 = (r3 << 1) | (r2 >> 63); + let r2 = (r2 << 1) | (r1 >> 63); + let r1 = r1 << 1; + + let mut carry = 0; + let r0 = mac_with_carry(0, (self.0).0[0], (self.0).0[0], &mut carry); + let r1 = adc(r1, 0, &mut carry); + let r2 = mac_with_carry(r2, (self.0).0[1], (self.0).0[1], &mut carry); + let r3 = adc(r3, 0, &mut carry); + let r4 = mac_with_carry(r4, (self.0).0[2], (self.0).0[2], &mut carry); + let r5 = adc(r5, 0, &mut carry); + let r6 = mac_with_carry(r6, (self.0).0[3], (self.0).0[3], &mut carry); + let r7 = adc(r7, 0, &mut carry); + self.mont_reduce(r0, r1, r2, r3, r4, r5, r6, r7); + } +} + +impl Fs { + /// Determines if the element is really in the field. This is only used + /// internally. + #[inline(always)] + pub fn is_valid(&self) -> bool { + self.0 < MODULUS + } + + /// Subtracts the modulus from this element if this element is not in the + /// field. Only used internally. + #[inline(always)] + pub fn reduce(&mut self) { + if !self.is_valid() { + self.0.sub_noborrow(&MODULUS); + } + } + + #[inline(always)] + fn mont_reduce( + &mut self, + r0: u64, + mut r1: u64, + mut r2: u64, + mut r3: u64, + mut r4: u64, + mut r5: u64, + mut r6: u64, + mut r7: u64 + ) + { + // The Montgomery reduction here is based on Algorithm 14.32 in + // Handbook of Applied Cryptography + // . + + let k = r0.wrapping_mul(INV); + let mut carry = 0; + mac_with_carry(r0, k, MODULUS.0[0], &mut carry); + r1 = mac_with_carry(r1, k, MODULUS.0[1], &mut carry); + r2 = mac_with_carry(r2, k, MODULUS.0[2], &mut carry); + r3 = mac_with_carry(r3, k, MODULUS.0[3], &mut carry); + r4 = adc(r4, 0, &mut carry); + let carry2 = carry; + let k = r1.wrapping_mul(INV); + let mut carry = 0; + mac_with_carry(r1, k, MODULUS.0[0], &mut carry); + r2 = mac_with_carry(r2, k, MODULUS.0[1], &mut carry); + r3 = mac_with_carry(r3, k, MODULUS.0[2], &mut carry); + r4 = mac_with_carry(r4, k, MODULUS.0[3], &mut carry); + r5 = adc(r5, carry2, &mut carry); + let carry2 = carry; + let k = r2.wrapping_mul(INV); + let mut carry = 0; + mac_with_carry(r2, k, MODULUS.0[0], &mut carry); + r3 = mac_with_carry(r3, k, MODULUS.0[1], &mut carry); + r4 = mac_with_carry(r4, k, MODULUS.0[2], &mut carry); + r5 = mac_with_carry(r5, k, MODULUS.0[3], &mut carry); + r6 = adc(r6, carry2, &mut carry); + let carry2 = carry; + let k = r3.wrapping_mul(INV); + let mut carry = 0; + mac_with_carry(r3, k, MODULUS.0[0], &mut carry); + r4 = mac_with_carry(r4, k, MODULUS.0[1], &mut carry); + r5 = mac_with_carry(r5, k, MODULUS.0[2], &mut carry); + r6 = mac_with_carry(r6, k, MODULUS.0[3], &mut carry); + r7 = adc(r7, carry2, &mut carry); + (self.0).0[0] = r4; + (self.0).0[1] = r5; + (self.0).0[2] = r6; + (self.0).0[3] = r7; + self.reduce(); + } + + fn branchless_reduction(&mut self) { + let mut borrow = 0; + let mut tmp = *self; + + for (a, &b) in (tmp.0).0.iter_mut().zip(MODULUS.0.iter()) { + *a = sbb(*a, b, &mut borrow); + } + + for (target, reduced) in (self.0).0.iter_mut().zip((tmp.0).0.iter()) { + // if there was a borrow then + // - it's equal to 2^64 - 1 + // - we take original value + // otherwise + // - bitflip borrow + // take reduced value + + *target = (*target & borrow) | (*reduced & (!borrow)); + } + } + + // fn mul_bits>(&self, bits: BitIterator) -> Self { + // let mut res = Self::zero(); + // for bit in bits { + // res.double(); + + // if bit { + // res.add_assign(self) + // } + // } + // res + // } + + pub fn rps_mul_assign(&mut self, other: &Fs) { + let [b0, b1, b2, b3] = (other.0).0; + + let a0 = (self.0).0[0]; + let a1 = (self.0).0[1]; + + // two temporary registers + let (t1, t0) = full_width_mul(a1, b1); + // make product a0 * b1, with low part going into z1 (which is empty) + // and high part is summed with t0 and propagated to t1 + let (z1, t0, t1) = mul_and_add_existing_high(a0, b1, t0, t1); + // make product a0 * b1, with low part beign added with z1, and then propagated + // and high part is summed with t0 and propagated to t1 + let (z1, t0, t1) = mul_and_add_existing(a1, b0, z1, t0, t1); + // make product a0 * b0, and propagate everything + let (z0, z1, z2, z3, c0) = mul_and_full_block_propagate(a0, b0, 0, z1, 0, 0, t0, t1); + + // round 2 + + let (t1, t0) = full_width_mul(a1, b3); + let (z3, t0, t1) = mul_and_add_existing(a0, b3, z3, t0, t1); + let (z3, t0, t1) = mul_and_add_existing(a1, b2, z3, t0, t1); + // we place c0 instead of empty (yet) z4 + let (z2, z3, z4, z5, c1) = mul_and_full_block_propagate(a0, b2, z2, z3, c0, 0, t0, t1); + + // round 3 + + drop(a0); + drop(a1); + + let a2 = (self.0).0[2]; + let a3 = (self.0).0[3]; + + let (t1, t0) = full_width_mul(a3, b1); + let (z3, t0, t1) = mul_and_add_existing(a2, b1, z3, t0, t1); + let (z3, t0, t1) = mul_and_add_existing(a3, b0, z3, t0, t1); + let (z2, z3, z4, z5, c1) = mul_and_full_block_propagate_into_existing_carry_catcher(a0, b2, z2, z3, z4, z5, t0, t1, c1); + + // round 4 + + let (t1, t0) = full_width_mul(a3, b3); + let (z5, t0, t1) = mul_and_add_existing(a3, b2, z5, t0, t1); + let (z5, t0, t1) = mul_and_add_existing(a2, b3, z5, t0, t1); + let (z4, z5, z6, z7) = mul_and_full_block_propagate_without_carry_catch(a2, b2, z4, z5, c1, 0, t0, t1); + + self.mont_reduce(z0, z1, z2, z3, z4, z5, z6, z7); + } + + pub fn optimistic_cios_mul_assign(&mut self, other: &Fs) { + let mut m; + + let [q0, q1, q2, q3] = MODULUS.0; + let [a0, a1, a2, a3] = (self.0).0; + + // round 0 + let b0 = (other.0).0[0]; + let (r0, carry) = full_width_mul(a0, b0); + m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = wrapping_mac_by_value(r0, m, q0); + + // loop over the rest + let (r1, carry) = mac_by_value(carry, a1, b0); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, q1, red_carry); + + let (r2, carry) = mac_by_value(carry, a2, b0); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, q2, red_carry); + + let (r3, carry) = mac_by_value(carry, a3, b0); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, q3, red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b0); + + // round 1 + let b1 = (other.0).0[1]; + let (r0, carry) = mac_by_value(r0, a0, b1); + m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = wrapping_mac_by_value(r0, m, q0); + + // loop over the rest + let (r1, carry) = mac_with_carry_by_value(r1, carry, a1, b1); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, q1, red_carry); + + let (r2, carry) = mac_with_carry_by_value(r2, carry, a2, b1); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, q2, red_carry); + + let (r3, carry) = mac_with_carry_by_value(r3, carry, a3, b1); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, q3, red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b1); + + // round 2 + let b2 = (other.0).0[2]; + let (r0, carry) = mac_by_value(r0, a0, b2); + m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = wrapping_mac_by_value(r0, m, q0); + + // loop over the rest + let (r1, carry) = mac_with_carry_by_value(r1, carry, a1, b2); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, q1, red_carry); + + let (r2, carry) = mac_with_carry_by_value(r2, carry, a2, b2); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, q2, red_carry); + + let (r3, carry) = mac_with_carry_by_value(r3, carry, a3, b2); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, q3, red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b2); + + // round 3 + let b3 = (other.0).0[3]; + let (r0, carry) = mac_by_value(r0, a0, b3); + m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = wrapping_mac_by_value(r0, m, q0); + + // loop over the rest + let (r1, carry) = mac_with_carry_by_value(r1, carry, a1, b3); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, q1, red_carry); + + let (r2, carry) = mac_with_carry_by_value(r2, carry, a2, b3); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, q2, red_carry); + + let (r3, carry) = mac_with_carry_by_value(r3, carry, a3, b3); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, q3, red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b3); + + (self.0).0 = [r0, r1, r2, r3]; + self.reduce(); + } + + pub fn optimistic_cios_mul_assign_with_different_semantics(&mut self, other: &Fs) { + let mut m; + + // round 0 + let b0 = (other.0).0[0]; + let (r0, carry) = full_width_mul((self.0).0[0], (other.0).0[0]); + m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = wrapping_mac_by_value(r0, m, MODULUS.0[1]); + + // loop over the rest + let (r1, carry) = mac_by_value(carry, (self.0).0[1], b0); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, MODULUS.0[1], red_carry); + + let (r2, carry) = mac_by_value(carry, (self.0).0[2], b0); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, MODULUS.0[2], red_carry); + + let (r3, carry) = mac_by_value(carry, (self.0).0[3], b0); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, MODULUS.0[3], red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b0); + + // round 1 + let b1 = (other.0).0[1]; + let (r0, carry) = mac_by_value(r0, (self.0).0[0], b1); + m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = wrapping_mac_by_value(r0, m, MODULUS.0[0]); + + // loop over the rest + let (r1, carry) = mac_with_carry_by_value(r1, carry, (self.0).0[1], b1); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, MODULUS.0[1], red_carry); + + let (r2, carry) = mac_with_carry_by_value(r2, carry, (self.0).0[2], b1); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, MODULUS.0[2], red_carry); + + let (r3, carry) = mac_with_carry_by_value(r3, carry, (self.0).0[3], b1); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, MODULUS.0[3], red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b1); + + // round 2 + let b2 = (other.0).0[2]; + let (r0, carry) = mac_by_value(r0, (self.0).0[0], b2); + m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = wrapping_mac_by_value(r0, m, MODULUS.0[0]); + + // loop over the rest + let (r1, carry) = mac_with_carry_by_value(r1, carry, (self.0).0[1], b2); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, MODULUS.0[1], red_carry); + + let (r2, carry) = mac_with_carry_by_value(r2, carry, (self.0).0[2], b2); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, MODULUS.0[2], red_carry); + + let (r3, carry) = mac_with_carry_by_value(r3, carry, (self.0).0[3], b2); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, MODULUS.0[3], red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b2); + + // round 3 + let b3 = (other.0).0[3]; + let (r0, carry) = mac_by_value(r0, (self.0).0[0], b3); + m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = wrapping_mac_by_value(r0, m, MODULUS.0[0]); + + // loop over the rest + let (r1, carry) = mac_with_carry_by_value(r1, carry, (self.0).0[1], b3); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, MODULUS.0[1], red_carry); + + let (r2, carry) = mac_with_carry_by_value(r2, carry, (self.0).0[2], b3); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, MODULUS.0[2], red_carry); + + let (r3, carry) = mac_with_carry_by_value(r3, carry, (self.0).0[3], b3); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, MODULUS.0[3], red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b3); + + (self.0).0 = [r0, r1, r2, r3]; + self.reduce(); + } + + pub fn optimistic_cios_by_value(self, other: Fs) -> Self { + let [q0, q1, q2, q3] = MODULUS.0; + let [b0, b1, b2, b3] = (other.0).0; + + // round 0 + let a0 = (self.0).0[0]; + let (r0, carry) = full_width_mul(a0, b0); + let m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = mac_by_value_return_carry_only(r0, m, q0); + + // loop over the rest + let (r1, carry) = mac_by_value(carry, a0, b1); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, q1, red_carry); + + let (r2, carry) = mac_by_value(carry, a0, b2); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, q2, red_carry); + + let (r3, carry) = mac_by_value(carry, a0, b3); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, q3, red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(a0); + drop(m); + + // round 1 + let a1 = (self.0).0[1]; + let (r0, carry) = mac_by_value(r0, a1, b0); + let m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = mac_by_value_return_carry_only(r0, m, q0); + + // loop over the rest + let (r1, carry) = mac_with_carry_by_value(r1, a1, b1, carry); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, q1, red_carry); + + let (r2, carry) = mac_with_carry_by_value(r2, a1, b2, carry); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, q2, red_carry); + + let (r3, carry) = mac_with_carry_by_value(r3, a1, b3, carry); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, q3, red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(a1); + drop(m); + + // round 2 + let a2 = (self.0).0[2]; + let (r0, carry) = mac_by_value(r0, a2, b0); + let m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = mac_by_value_return_carry_only(r0, m, q0); + + // loop over the rest + let (r1, carry) = mac_with_carry_by_value(r1, a2, b1, carry); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, q1, red_carry); + + let (r2, carry) = mac_with_carry_by_value(r2, a2, b2, carry); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, q2, red_carry); + + let (r3, carry) = mac_with_carry_by_value(r3, a2, b3, carry); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, q3, red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(a2); + drop(m); + + // round 3 + let a3 = (self.0).0[3]; + let (r0, carry) = mac_by_value(r0, a3, b0); + let m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = mac_by_value_return_carry_only(r0, m, q0); + + // loop over the rest + let (r1, carry) = mac_with_carry_by_value(r1, a3, b1, carry); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, q1, red_carry); + + let (r2, carry) = mac_with_carry_by_value(r2, a3, b2, carry); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, q2, red_carry); + + let (r3, carry) = mac_with_carry_by_value(r3, a3, b3, carry); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, q3, red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b3); + drop(m); + + let mut result = Fs(FsRepr([r0, r1, r2, r3])); + result.reduce(); + + result + } + + pub fn optimistic_cios_by_value_with_partial_red(self, other: Fs) -> Self { + let mut m; + + let [q0, q1, q2, q3] = MODULUS.0; + let [a0, a1, a2, a3] = (self.0).0; + + // round 0 + let b0 = (other.0).0[0]; + let (r0, carry) = full_width_mul(a0, b0); + m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = mac_by_value_return_carry_only(r0, m, q0); + + // loop over the rest + let (r1, carry) = mac_by_value(carry, a1, b0); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, q1, red_carry); + + let (r2, carry) = mac_by_value(carry, a2, b0); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, q2, red_carry); + + let (r3, carry) = mac_by_value(carry, a3, b0); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, q3, red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b0); + + // round 1 + let b1 = (other.0).0[1]; + let (r0, carry) = mac_by_value(r0, a0, b1); + m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = mac_by_value_return_carry_only(r0, m, q0); + + // loop over the rest + let (r1, carry) = mac_with_carry_by_value(r1, carry, a1, b1); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, q1, red_carry); + + let (r2, carry) = mac_with_carry_by_value(r2, carry, a2, b1); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, q2, red_carry); + + let (r3, carry) = mac_with_carry_by_value(r3, carry, a3, b1); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, q3, red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b1); + + // round 2 + let b2 = (other.0).0[2]; + let (r0, carry) = mac_by_value(r0, a0, b2); + m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = wrapping_mac_by_value(r0, m, q0); + + // loop over the rest + let (r1, carry) = mac_with_carry_by_value(r1, carry, a1, b2); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, q1, red_carry); + + let (r2, carry) = mac_with_carry_by_value(r2, carry, a2, b2); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, q2, red_carry); + + let (r3, carry) = mac_with_carry_by_value(r3, carry, a3, b2); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, q3, red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b2); + + // round 3 + let b3 = (other.0).0[3]; + let (r0, carry) = mac_by_value(r0, a0, b3); + m = r0.wrapping_mul(INV); + // everywhere semantic is arg0 + (arg1 * arg2) + let red_carry = wrapping_mac_by_value(r0, m, q0); + + // loop over the rest + let (r1, carry) = mac_with_carry_by_value(r1, carry, a1, b3); + let (r0, red_carry) = mac_with_carry_by_value(r1, m, q1, red_carry); + + let (r2, carry) = mac_with_carry_by_value(r2, carry, a2, b3); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, q2, red_carry); + + let (r3, carry) = mac_with_carry_by_value(r3, carry, a3, b3); + let (r2, red_carry) = mac_with_carry_by_value(r3, m, q3, red_carry); + + // this will check overflow in debug + let r3 = red_carry + carry; + drop(b3); + + Fs(FsRepr([r0, r1, r2, r3])) + } + + pub fn mulx_latency_mul_assign(&mut self, other: &Fs) { + let [b0, b1, b2, b3] = (other.0).0; + let [a0, a1, a2, a3] = (self.0).0; + + // round 0 + let (r1, r0) = full_width_mul(a0, b0); + let (r2_0, r1_0) = full_width_mul(a1, b0); + let (r2_1, r1_1) = full_width_mul(a0, b1); + let (r3, r2) = full_width_mul(a1, b1); + + // now propagate carries + let (r1, carry) = add(r1, r1_0); + // hope for carry chains with add and addx + let (r1, carry2) = add(r1, r1_1); + // + let (r2, carry) = add_three(r2_0, r2, carry); + let (r2_1, carry2) = add(r2_1, carry2); + // + let (r2, carry2) = add(r2, r2_1); + let r3 = r3 + carry + carry2; + + // round 1 + let (r3_0, r2_0) = full_width_mul(a2, b0); + let (r4_0, r3_0) = full_width_mul(a3, b0); + let (r4_1, r3_1) = full_width_mul(a2, b1); + let (r5, r4) = full_width_mul(a3, b1); + + let (r2, carry) = add(r2, r2_0); + + // now propagate carries + let (r3, carry) = add_three(r3, r3_0, carry); + // hope for carry chains with add and addx + let (r3, carry2) = add(r3, r3_1); + // + let (r4, carry) = add_three(r4_0, r4, carry); + let (r4_1, carry2) = add(r4_1, carry2); + // + let (r4, carry2) = add(r4, r4_1); + let (r5, r6) = add_three(r5, carry, carry2); + + // round 3 + let (r3_0, r2_0) = full_width_mul(a0, b2); + let (r4_0, r3_0) = full_width_mul(a0, b3); + let (r4_1, r3_1) = full_width_mul(a1, b2); + let (r5_0, r4_0) = full_width_mul(a1, b3); + + let (r2, carry) = add(r2, r2_0); + + // now propagate carries + let (r3, carry) = add_three(r3, r3_0, carry); + // hope for carry chains with add and addx + let (r3, carry2) = add(r3, r3_1); + // + let (r4, carry) = add_three(r4_0, r4, carry); + let (r4_1, carry2) = add(r4_1, carry2); + // + let (r4, carry2) = add(r4, r4_1); + let (r5, carry) = add_four(r5, r5_0, carry, carry2); + + let r6 = r6 + carry; + + // round 4 + let (r5_0, r4_0) = full_width_mul(a2, b2); + let (r6_0, r5_0) = full_width_mul(a2, b3); + let (r6_1, r5_1) = full_width_mul(a3, b2); + let (r7, r6_0) = full_width_mul(a3, b3); + + let (r4, carry) = add(r4, r2_0); + + // now propagate carries + let (r5, carry) = add_three(r5, r5_0, carry); + // hope for carry chains with add and addx + let (r5, carry2) = add(r5, r5_1); + // + let (r6, carry) = add_three(r6_0, r6, carry); + let (r6_1, carry2) = add(r6_1, carry2); + // + let (r6, carry2) = add(r6, r6_1); + let r7 = r7 + carry + carry2; + + self.mont_reduce(r0, r1, r2, r3, r4, r5, r6, r7); + } + + pub fn asm_mul_assign(&mut self, other: &Fs) { + *self = Fs(FsRepr(mont_mul_asm(&(self.0).0, &(other.0).0))); + } +} + +impl SqrtField for Fs { + + fn legendre(&self) -> LegendreSymbol { + // s = self^((s - 1) // 2) + let s = self.pow([0x684b872f6b7b965b, 0x53341049e6640841, 0x83339d80809a1d80, 0x73eda753299d7d4]); + if s == Self::zero() { Zero } + else { QuadraticNonResidue } + } + + fn sqrt(&self) -> Option { + // Shank's algorithm for s mod 4 = 3 + // https://eprint.iacr.org/2012/685.pdf (page 9, algorithm 2) + + // a1 = self^((s - 3) // 4) + let mut a1 = self.pow([0xb425c397b5bdcb2d, 0x299a0824f3320420, 0x4199cec0404d0ec0, 0x39f6d3a994cebea]); + let mut a0 = a1; + a0.square(); + a0.mul_assign(self); + + if a0 == NEGATIVE_ONE + { + None + } + else + { + a1.mul_assign(self); + Some(a1) + } + } +} + +#[inline(always)] +pub fn full_width_mul(a: u64, b: u64) -> (u64, u64) { + let tmp = (a as u128) * (b as u128); + + return (tmp as u64, (tmp >> 64) as u64); +} + +#[inline(always)] +pub fn mac_with_carry_by_value(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) { + let tmp = ((b as u128) * (c as u128)) + (a as u128) + (carry as u128); + + (tmp as u64, (tmp >> 64) as u64) +} + +#[inline(always)] +pub fn add(a: u64, b: u64) -> (u64, u64) { + let tmp = (a as u128) + (b as u128); + + (tmp as u64, (tmp >> 64) as u64) +} + +#[inline(always)] +pub fn add_three(a: u64, b: u64, c: u64) -> (u64, u64) { + let tmp = (a as u128) + (b as u128) + (c as u128); + + (tmp as u64, (tmp >> 64) as u64) +} + +#[inline(always)] +pub fn add_four(a: u64, b: u64, c: u64, d: u64) -> (u64, u64) { + let tmp = (a as u128) + (b as u128) + (c as u128) + (d as u128); + + (tmp as u64, (tmp >> 64) as u64) +} + +#[inline(always)] +pub fn wrapping_mac_by_value(a: u64, b: u64, c: u64) -> u64 { + b.wrapping_mul(c).wrapping_add(a) +} + +#[inline(always)] +pub fn mac_by_value_return_carry_only(a: u64, b: u64, c: u64) -> u64 { + let tmp = ((b as u128) * (c as u128)) + (a as u128); + + (tmp >> 64) as u64 +} + +#[inline(always)] +pub fn mac_by_value(a: u64, b: u64, c: u64) -> (u64, u64) { + let tmp = ((b as u128) * (c as u128)) + (a as u128); + + (tmp as u64, (tmp >> 64) as u64) +} + +#[inline(always)] +pub fn mul_with_high_carry_by_value(a: u64, b: u64, c: u64, mut carry: u64) -> (u64, u64, u64) { + let (hi, lo) = full_width_mul(b, c); + let (hi, of) = hi.overflowing_add(a); + if of { + carry += 1; + } + + (hi, lo, carry) +} + +#[inline(always)] +pub fn mul_and_add_existing_high(a: u64, b: u64, existing_hi: u64, carry: u64) -> (u64, u64, u64) { + let tmp = (a as u128) * (b as u128); + + let hi = tmp >> 64; + let lo = tmp as u64; + + let tmp = hi + (existing_hi as u128) + ((carry as u128) << 64); + + let carry = (tmp >> 64) as u64; + let hi = tmp as u64; + + (lo, hi, carry) +} + +#[inline(always)] +pub fn mul_and_add_existing(a: u64, b: u64, existing_lo: u64, existing_hi: u64, carry: u64) -> (u64, u64, u64) { + let tmp = ((a as u128) * (b as u128)) + (existing_lo as u128); + + let hi = tmp >> 64; + let lo = tmp as u64; + + let tmp = hi + (existing_hi as u128) + ((carry as u128) << 64); + + let carry = (tmp >> 64) as u64; + let hi = tmp as u64; + + (lo, hi, carry) +} + +#[inline(always)] +pub fn mul_and_full_block_propagate(a: u64, b: u64, z0: u64, z1: u64, z2: u64, z3: u64, t0: u64, t1: u64) -> (u64, u64, u64, u64, u64) { + let tmp = ((a as u128) * (b as u128)) + (z0 as u128); + + let hi = tmp >> 64; + let z0 = tmp as u64; + + // let tmp = hi + (z1 as u128); + + let tmp = hi + (z1 as u128) + ((t0 as u128) << 64); + + let hi = tmp >> 64; + let z1 = tmp as u64; + + let tmp = hi + (z2 as u128) + ((t1 as u128) << 64); + + let hi = tmp >> 64; + let z2 = tmp as u64; + + let tmp = hi + (z3 as u128); + + let c0 = (tmp >> 64) as u64; + let z3 = tmp as u64; + + (z0, z1, z2, z3, c0) +} + +#[inline(always)] +pub fn mul_and_full_block_propagate_into_existing_carry_catcher(a: u64, b: u64, z0: u64, z1: u64, z2: u64, z3: u64, t0: u64, t1: u64, c: u64) -> (u64, u64, u64, u64, u64) { + let tmp = ((a as u128) * (b as u128)) + (z0 as u128); + + let hi = tmp >> 64; + let z0 = tmp as u64; + + // let tmp = hi + (z1 as u128); + + let tmp = hi + (z1 as u128) + ((t0 as u128) << 64); + + let hi = tmp >> 64; + let z1 = tmp as u64; + + let tmp = hi + (z2 as u128) + ((t1 as u128) << 64); + + let hi = tmp >> 64; + let z2 = tmp as u64; + + let tmp = hi + (z3 as u128) + ((c as u128) << 64); + + let c0 = (tmp >> 64) as u64; + let z3 = tmp as u64; + + (z0, z1, z2, z3, c0) +} + +#[inline(always)] +pub fn mul_and_full_block_propagate_without_carry_catch(a: u64, b: u64, z0: u64, z1: u64, z2: u64, z3: u64, t0: u64, t1: u64) -> (u64, u64, u64, u64) { + let tmp = ((a as u128) * (b as u128)) + (z0 as u128); + + let hi = tmp >> 64; + let z0 = tmp as u64; + + // let tmp = hi + (z1 as u128); + + let tmp = hi + (z1 as u128) + ((t0 as u128) << 64); + + let hi = tmp >> 64; + let z1 = tmp as u64; + + let tmp = hi + (z2 as u128) + ((t1 as u128) << 64); + + let hi = (tmp >> 64) as u64; + let z2 = tmp as u64; + + let z3 = hi + z3; + + (z0, z1, z2, z3) +} + + +// // Computes one "row" of multiplication: [a0, a1, a2, a3] * b0 +// // Uses MULX +// #[allow(dead_code)] +// #[inline(always)] +// #[cfg(all(target_arch = "x86_64", target_feature = "adx"))] +// pub(crate) fn mul_1_asm(a: u64, b0: u64, b1: u64, b2: u64, b3: u64) -> (u64, u64, u64, u64, u64) { +// let r0: u64; +// let r1: u64; +// let r2: u64; +// let r3: u64; +// let r4: u64; +// let _lo: u64; +// // Binding `_lo` will not be used after assignment. +// #[allow(clippy::used_underscore_binding)] +// unsafe { +// asm!(r" +// mulx $7, $0, $1 // (r0, r1) = a * b0 +// mulx $8, $5, $2 // (lo, r2) = a * b1 +// add $5, $1 // r1 += lo (carry in CF) +// mulx $9, $5, $3 // (lo, r3) = a * b2 +// adc $5, $2 // r2 += lo + CF (carry in CF) +// mulx $10, $5, $4 // (lo, r4) = a * b3 +// adc $5, $3 // r3 += lo + CF (carry in CF) +// adc $11, $4 // r4 += 0 + CF (no carry, CF to 0) +// " +// : // Output constraints +// "=&r"(r0), // $0 r0..4 are in registers +// "=&r"(r1), // $1 +// "=&r"(r2), // $2 +// "=&r"(r3), // $3 +// "=&r"(r4) // $4 +// "=&r"(_lo) // $5 Temporary values can be in any register +// : // Input constraints +// "{rdx}"(a), // $6 a must be in RDX for MULX to work +// "rm"(b0), // $7 b0..b3 can be register or memory +// "rm"(b1), // $8 +// "rm"(b2), // $9 +// "rm"(b3), // $10 +// "i"(0) // $11 Immediate zero +// : // Clobbers +// "cc" // Flags +// ) +// } +// (r0, r1, r2, r3, r4) +// } + +// // Computes r[0..4] += a * b[0..4], returns carry +// // Uses MULX and ADCX/ADOX carry chain +// // Currently unused +// #[allow(dead_code)] +// #[allow(clippy::too_many_arguments)] +// #[inline(always)] +// #[cfg(all(target_arch = "x86_64", target_feature = "adx"))] +// pub(crate) fn mul_add_1_asm( +// r0: &mut u64, +// r1: &mut u64, +// r2: &mut u64, +// r3: &mut u64, +// a: u64, +// b0: u64, +// b1: u64, +// b2: u64, +// b3: u64, +// ) -> u64 { +// let _lo: u64; +// let _hi: u64; +// let r4: u64; +// // Bindings `_lo` and `_hi` will not be used after assignment. +// #[allow(clippy::used_underscore_binding)] +// unsafe { +// asm!(r" +// xor $4, $4 // r4 = CF = OF 0 +// mulx $8, $5, $6 // a * b0 +// adcx $5, $0 // r0 += lo + CF (carry in CF) +// adox $6, $1 // r1 += hi + OF (carry in OF) +// mulx $9, $5, $6 // a * b1 +// adcx $5, $1 // r1 += lo + CF (carry in CF) +// adox $6, $2 // r2 += hi + OF (carry in OF) +// mulx $10, $5, $6 // a * b2 +// adcx $5, $2 // r2 += lo + CF (carry in CF) +// adox $6, $3 // r3 += hi + OF (carry in OF) +// mulx $11, $5, $6 // a * b3 +// adcx $5, $3 // r3 += lo + CF (carry in CF) +// adcx $4, $4 // r4 += CF (no carry, CF = 0) +// adox $6, $4 // r4 += hi + OF (no carry, OF = 0) +// " +// : // Output constraints +// "+r"(*r0), // $0 r0..3 are in register and modified in place +// "+r"(*r1), // $1 +// "+r"(*r2), // $2 +// "+r"(*r3), // $3 +// "=&r"(r4) // $4 r4 is output to a register +// "=&r"(_lo), // $5 Temporary values can be in any register +// "=&r"(_hi) // $6 +// : // Input constraints +// "{rdx}"(a), // $7 a must be in RDX for MULX to work +// "rm"(b0), // $8 Second operand can be register or memory +// "rm"(b1), // $9 Second operand can be register or memory +// "rm"(b2), // $10 Second operand can be register or memory +// "rm"(b3) // $11 Second operand can be register or memory +// : // Clobbers +// "cc" // Flags +// ) +// } +// r4 +// } + +// Currently unused +#[allow(dead_code)] +#[allow(clippy::too_many_lines)] +#[inline(always)] +#[cfg(all(target_arch = "x86_64", target_feature = "adx"))] +pub fn mont_mul_asm(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + use core::mem::MaybeUninit; + + const ZERO: u64 = 0; // $3 + + let mut result = MaybeUninit::<[u64; 4]>::uninit(); + // MULX dst_high, dst_low, src_b (src_a = %rdx) + // src_b can be register or memory, not immediate + unsafe { + llvm_asm!(r" + // Assembly from Aztec's Barretenberg implementation, see + // + movq 0($1), %rdx + xorq %r8, %r8 + mulxq 8($2), %r8, %r9 + mulxq 24($2), %rdi, %r12 + mulxq 0($2), %r13, %r14 + mulxq 16($2), %r15, %r10 + movq %r13, %rdx + mulxq $8, %rdx, %r11 + adcxq %r8, %r14 + adoxq %rdi, %r10 + adcxq %r9, %r15 + adoxq $3, %r12 + adcxq $3, %r10 + mulxq $4, %r8, %r9 + mulxq $5, %rdi, %r11 + adoxq %r8, %r13 + adcxq %rdi, %r14 + adoxq %r9, %r14 + adcxq %r11, %r15 + mulxq $6, %r8, %r9 + mulxq $7, %rdi, %r11 + adoxq %r8, %r15 + adcxq %rdi, %r10 + adoxq %r9, %r10 + adcxq %r11, %r12 + adoxq $3, %r12 + movq 8($1), %rdx + mulxq 0($2), %r8, %r9 + mulxq 8($2), %rdi, %r11 + adcxq %r8, %r14 + adoxq %r9, %r15 + adcxq %rdi, %r15 + adoxq %r11, %r10 + mulxq 16($2), %r8, %r9 + mulxq 24($2), %rdi, %r13 + adcxq %r8, %r10 + adoxq %rdi, %r12 + adcxq %r9, %r12 + adoxq $3, %r13 + adcxq $3, %r13 + movq %r14, %rdx + mulxq $8, %rdx, %r8 + mulxq $4, %r8, %r9 + mulxq $5, %rdi, %r11 + adoxq %r8, %r14 + adcxq %rdi, %r15 + adoxq %r9, %r15 + adcxq %r11, %r10 + mulxq $6, %r8, %r9 + mulxq $7, %rdi, %r11 + adoxq %r8, %r10 + adcxq %r9, %r12 + adoxq %rdi, %r12 + adcxq %r11, %r13 + adoxq $3, %r13 + movq 16($1), %rdx + mulxq 0($2), %r8, %r9 + mulxq 8($2), %rdi, %r11 + adcxq %r8, %r15 + adoxq %r9, %r10 + adcxq %rdi, %r10 + adoxq %r11, %r12 + mulxq 16($2), %r8, %r9 + mulxq 24($2), %rdi, %r14 + adcxq %r8, %r12 + adoxq %r9, %r13 + adcxq %rdi, %r13 + adoxq $3, %r14 + adcxq $3, %r14 + movq %r15, %rdx + mulxq $8, %rdx, %r8 + mulxq $4, %r8, %r9 + mulxq $5, %rdi, %r11 + adoxq %r8, %r15 + adcxq %r9, %r10 + adoxq %rdi, %r10 + adcxq %r11, %r12 + mulxq $6, %r8, %r9 + mulxq $7, %rdi, %r11 + adoxq %r8, %r12 + adcxq %r9, %r13 + adoxq %rdi, %r13 + adcxq %r11, %r14 + adoxq $3, %r14 + movq 24($1), %rdx + mulxq 0($2), %r8, %r9 + mulxq 8($2), %rdi, %r11 + adcxq %r8, %r10 + adoxq %r9, %r12 + adcxq %rdi, %r12 + adoxq %r11, %r13 + mulxq 16($2), %r8, %r9 + mulxq 24($2), %rdi, %r15 + adcxq %r8, %r13 + adoxq %r9, %r14 + adcxq %rdi, %r14 + adoxq $3, %r15 + adcxq $3, %r15 + movq %r10, %rdx + mulxq $8, %rdx, %r8 + mulxq $4, %r8, %r9 + mulxq $5, %rdi, %r11 + adoxq %r8, %r10 + adcxq %r9, %r12 + adoxq %rdi, %r12 + adcxq %r11, %r13 + mulxq $6, %r8, %r9 + mulxq $7, %rdi, %rdx + adoxq %r8, %r13 + adcxq %r9, %r14 + adoxq %rdi, %r14 + adcxq %rdx, %r15 + adoxq $3, %r15 + movq %r12, 0($0) + movq %r13, 8($0) + movq %r14, 16($0) + movq %r15, 24($0) + " + : + : "r"(result.as_mut_ptr()), + "r"(a), "r"(b), + "m"(ZERO), + "m"(MODULUS.0[0]), + "m"(MODULUS.0[1]), + "m"(MODULUS.0[2]), + "m"(MODULUS.0[3]), + "m"(INV) + : "rdx", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "cc", "memory" + ); + } + let result = unsafe { result.assume_init() }; + + result +} + +#[cfg(test)] +mod test { + use super::Fs; + use ff::{Field, PrimeField}; + + use rand::{*}; + + #[test] + fn test_optimistic_cios() { + let rng = &mut XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + + for _ in 0..10000 { + let a: Fs = rng.gen(); + let b: Fs = rng.gen(); + + let mut c = a; + c.mul_assign(&b); + + let d = a.optimistic_cios_by_value(b); + + assert_eq!(c, d); + } + } +} \ No newline at end of file diff --git a/crates/ff/tester/src/test_large_cios_field.rs b/crates/ff/tester/src/test_large_cios_field.rs new file mode 100644 index 0000000..37f8e36 --- /dev/null +++ b/crates/ff/tester/src/test_large_cios_field.rs @@ -0,0 +1,8 @@ +use ff::*; + +#[derive(PrimeField)] +#[PrimeFieldModulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583"] +#[PrimeFieldGenerator = "2"] +#[OptimisticCIOSMultiplication = "true"] +#[OptimisticCIOSSquaring = "true"] +pub(crate) struct Fr(FrRepr); \ No newline at end of file diff --git a/crates/ff/tester/src/test_large_field.rs b/crates/ff/tester/src/test_large_field.rs new file mode 100644 index 0000000..3032bb1 --- /dev/null +++ b/crates/ff/tester/src/test_large_field.rs @@ -0,0 +1,24 @@ +use ff::*; + +#[derive(PrimeField)] +#[PrimeFieldModulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583"] +#[PrimeFieldGenerator = "2"] +pub(crate) struct Fr(FrRepr); + +#[test] +fn test_to_hex() { + let a = Fr::from_repr(FrRepr::from(2)).unwrap(); + assert_eq!("0000000000000000000000000000000000000000000000000000000000000002", to_hex(&a)); + println!("`2` into hex = {}", to_hex(&a)); +} + +#[test] +fn test_hash_impl() { + let mut hashset = std::collections::HashSet::new(); + + hashset.insert(Fr::from_repr(FrRepr::from(2)).unwrap()); + hashset.insert(Fr::from_repr(FrRepr::from(4)).unwrap()); + hashset.insert(Fr::from_repr(FrRepr::from(2)).unwrap()); + + assert_eq!(hashset.len(), 2); +} \ No newline at end of file diff --git a/crates/ff/tester/src/test_short_field.rs b/crates/ff/tester/src/test_short_field.rs new file mode 100644 index 0000000..603b62d --- /dev/null +++ b/crates/ff/tester/src/test_short_field.rs @@ -0,0 +1,21 @@ +use ff::*; + +#[derive(PrimeField)] +#[PrimeFieldModulus = "17"] +#[PrimeFieldGenerator = "3"] +struct ShortFr(ShortFrRepr); + +#[test] +fn test_short_square() { + let mut a = ShortFr::from_repr(ShortFrRepr::from(5)).unwrap(); + a.square(); + assert_eq!("0000000000000008", to_hex(&a)); + println!("`5*2 mod 17` into hex = {}", to_hex(&a)); +} + +#[test] +fn test_short_to_hex() { + let a = ShortFr::from_repr(ShortFrRepr::from(2)).unwrap(); + assert_eq!("0000000000000002", to_hex(&a)); + println!("`2` into hex = {}", to_hex(&a)); +} \ No newline at end of file diff --git a/crates/ff/tester/tmp.rs b/crates/ff/tester/tmp.rs new file mode 100644 index 0000000..00a6a70 --- /dev/null +++ b/crates/ff/tester/tmp.rs @@ -0,0 +1,811 @@ +mod test_large_cios_field { + use ff::*; + #[PrimeFieldModulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583"] + #[PrimeFieldGenerator = "2"] + #[OptimisticCIOSMultiplication = "true"] + #[OptimisticCIOSSquaring = "true"] + pub(crate) struct Fr(FrRepr); + /// This is the modulus m of the prime field + const MODULUS: FrRepr = FrRepr([ + 4332616871279656263u64, + 10917124144477883021u64, + 13281191951274694749u64, + 3486998266802970665u64, + ]); + /// The number of bits needed to represent the modulus. + const MODULUS_BITS: u32 = 254u32; + /// The number of bits that must be shaved from the beginning of + /// the representation when randomly sampling. + const REPR_SHAVE_BITS: u32 = 2u32; + /// 2^{limbs*64} mod m + const R: FrRepr = FrRepr([ + 15230403791020821917u64, + 754611498739239741u64, + 7381016538464732716u64, + 1011752739694698287u64, + ]); + /// 2^{limbs*64*2} mod m + const R2: FrRepr = FrRepr([ + 17522657719365597833u64, + 13107472804851548667u64, + 5164255478447964150u64, + 493319470278259999u64, + ]); + /// -(m^{-1} mod m) mod m + const INV: u64 = 9786893198990664585u64; + /// Multiplicative generator of `MODULUS` - 1 order, also quadratic + /// nonresidue. + const GENERATOR: FrRepr = FrRepr([ + 12014063508332092218u64, + 1509222997478479483u64, + 14762033076929465432u64, + 2023505479389396574u64, + ]); + /// 2^s * t = MODULUS - 1 with t odd + const S: u32 = 1u32; + /// 2^s root of unity computed by GENERATOR^t + const ROOT_OF_UNITY: FrRepr = FrRepr([ + 15230403791020821917u64, + 754611498739239741u64, + 7381016538464732716u64, + 1011752739694698287u64, + ]); + pub struct FrRepr(pub [u64; 4usize]); + #[automatically_derived] + #[allow(unused_qualifications)] + impl ::core::marker::Copy for FrRepr {} + #[automatically_derived] + #[allow(unused_qualifications)] + impl ::core::clone::Clone for FrRepr { + #[inline] + fn clone(&self) -> FrRepr { + { + let _: ::core::clone::AssertParamIsClone<[u64; 4usize]>; + *self + } + } + } + impl ::core::marker::StructuralPartialEq for FrRepr {} + #[automatically_derived] + #[allow(unused_qualifications)] + impl ::core::cmp::PartialEq for FrRepr { + #[inline] + fn eq(&self, other: &FrRepr) -> bool { + match *other { + FrRepr(ref __self_1_0) => match *self { + FrRepr(ref __self_0_0) => (*__self_0_0) == (*__self_1_0), + }, + } + } + #[inline] + fn ne(&self, other: &FrRepr) -> bool { + match *other { + FrRepr(ref __self_1_0) => match *self { + FrRepr(ref __self_0_0) => (*__self_0_0) != (*__self_1_0), + }, + } + } + } + impl ::core::marker::StructuralEq for FrRepr {} + #[automatically_derived] + #[allow(unused_qualifications)] + impl ::core::cmp::Eq for FrRepr { + #[inline] + #[doc(hidden)] + fn assert_receiver_is_total_eq(&self) -> () { + { + let _: ::core::cmp::AssertParamIsEq<[u64; 4usize]>; + } + } + } + #[automatically_derived] + #[allow(unused_qualifications)] + impl ::core::default::Default for FrRepr { + #[inline] + fn default() -> FrRepr { + FrRepr(::core::default::Default::default()) + } + } + impl ::std::fmt::Debug for FrRepr { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + f.write_fmt(::core::fmt::Arguments::new_v1( + &["0x"], + &match () { + () => [], + }, + ))?; + for i in self.0.iter().rev() { + f.write_fmt(::core::fmt::Arguments::new_v1_formatted( + &[""], + &match (&*i,) { + (arg0,) => [::core::fmt::ArgumentV1::new( + arg0, + ::core::fmt::LowerHex::fmt, + )], + }, + &[::core::fmt::rt::v1::Argument { + position: 0usize, + format: ::core::fmt::rt::v1::FormatSpec { + fill: ' ', + align: ::core::fmt::rt::v1::Alignment::Unknown, + flags: 8u32, + precision: ::core::fmt::rt::v1::Count::Implied, + width: ::core::fmt::rt::v1::Count::Is(16usize), + }, + }], + ))?; + } + Ok(()) + } + } + impl ::rand::Rand for FrRepr { + #[inline(always)] + fn rand(rng: &mut R) -> Self { + FrRepr(rng.gen()) + } + } + impl ::std::fmt::Display for FrRepr { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + f.write_fmt(::core::fmt::Arguments::new_v1( + &["0x"], + &match () { + () => [], + }, + ))?; + for i in self.0.iter().rev() { + f.write_fmt(::core::fmt::Arguments::new_v1_formatted( + &[""], + &match (&*i,) { + (arg0,) => [::core::fmt::ArgumentV1::new( + arg0, + ::core::fmt::LowerHex::fmt, + )], + }, + &[::core::fmt::rt::v1::Argument { + position: 0usize, + format: ::core::fmt::rt::v1::FormatSpec { + fill: ' ', + align: ::core::fmt::rt::v1::Alignment::Unknown, + flags: 8u32, + precision: ::core::fmt::rt::v1::Count::Implied, + width: ::core::fmt::rt::v1::Count::Is(16usize), + }, + }], + ))?; + } + Ok(()) + } + } + impl std::hash::Hash for FrRepr { + fn hash(&self, state: &mut H) { + for limb in self.0.iter() { + limb.hash(state); + } + } + } + impl AsRef<[u64]> for FrRepr { + #[inline(always)] + fn as_ref(&self) -> &[u64] { + &self.0 + } + } + impl AsMut<[u64]> for FrRepr { + #[inline(always)] + fn as_mut(&mut self) -> &mut [u64] { + &mut self.0 + } + } + impl From for FrRepr { + #[inline(always)] + fn from(val: u64) -> FrRepr { + use std::default::Default; + let mut repr = Self::default(); + repr.0[0] = val; + repr + } + } + impl Ord for FrRepr { + #[inline(always)] + fn cmp(&self, other: &FrRepr) -> ::std::cmp::Ordering { + for (a, b) in self.0.iter().rev().zip(other.0.iter().rev()) { + if a < b { + return ::std::cmp::Ordering::Less; + } else if a > b { + return ::std::cmp::Ordering::Greater; + } + } + ::std::cmp::Ordering::Equal + } + } + impl PartialOrd for FrRepr { + #[inline(always)] + fn partial_cmp(&self, other: &FrRepr) -> Option<::std::cmp::Ordering> { + Some(self.cmp(other)) + } + } + impl crate::ff::PrimeFieldRepr for FrRepr { + #[inline(always)] + fn is_odd(&self) -> bool { + self.0[0] & 1 == 1 + } + #[inline(always)] + fn is_even(&self) -> bool { + !self.is_odd() + } + #[inline(always)] + fn is_zero(&self) -> bool { + self.0.iter().all(|&e| e == 0) + } + #[inline(always)] + fn shr(&mut self, mut n: u32) { + if n as usize >= 64 * 4usize { + *self = Self::from(0); + return; + } + while n >= 64 { + let mut t = 0; + for i in self.0.iter_mut().rev() { + ::std::mem::swap(&mut t, i); + } + n -= 64; + } + if n > 0 { + let mut t = 0; + for i in self.0.iter_mut().rev() { + let t2 = *i << (64 - n); + *i >>= n; + *i |= t; + t = t2; + } + } + } + #[inline(always)] + fn div2(&mut self) { + let mut t = 0; + for i in self.0.iter_mut().rev() { + let t2 = *i << 63; + *i >>= 1; + *i |= t; + t = t2; + } + } + #[inline(always)] + fn mul2(&mut self) { + let mut last = 0; + for i in &mut self.0 { + let tmp = *i >> 63; + *i <<= 1; + *i |= last; + last = tmp; + } + } + #[inline(always)] + fn shl(&mut self, mut n: u32) { + if n as usize >= 64 * 4usize { + *self = Self::from(0); + return; + } + while n >= 64 { + let mut t = 0; + for i in &mut self.0 { + ::std::mem::swap(&mut t, i); + } + n -= 64; + } + if n > 0 { + let mut t = 0; + for i in &mut self.0 { + let t2 = *i >> (64 - n); + *i <<= n; + *i |= t; + t = t2; + } + } + } + #[inline(always)] + fn num_bits(&self) -> u32 { + let mut ret = (4usize as u32) * 64; + for i in self.0.iter().rev() { + let leading = i.leading_zeros(); + ret -= leading; + if leading != 64 { + break; + } + } + ret + } + #[inline(always)] + fn add_nocarry(&mut self, other: &FrRepr) { + let mut carry = 0; + for (a, b) in self.0.iter_mut().zip(other.0.iter()) { + *a = crate::ff::adc(*a, *b, &mut carry); + } + } + #[inline(always)] + fn sub_noborrow(&mut self, other: &FrRepr) { + let mut borrow = 0; + for (a, b) in self.0.iter_mut().zip(other.0.iter()) { + *a = crate::ff::sbb(*a, *b, &mut borrow); + } + } + } + impl ::std::marker::Copy for Fr {} + impl ::std::clone::Clone for Fr { + fn clone(&self) -> Fr { + *self + } + } + impl ::std::cmp::PartialEq for Fr { + fn eq(&self, other: &Fr) -> bool { + self.0 == other.0 + } + } + impl ::std::cmp::Eq for Fr {} + impl ::std::fmt::Debug for Fr { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + f.write_fmt(::core::fmt::Arguments::new_v1( + &["", "(", ")"], + &match (&"Fr", &self.into_repr()) { + (arg0, arg1) => [ + ::core::fmt::ArgumentV1::new(arg0, ::core::fmt::Display::fmt), + ::core::fmt::ArgumentV1::new(arg1, ::core::fmt::Debug::fmt), + ], + }, + )) + } + } + /// Elements are ordered lexicographically. + impl Ord for Fr { + #[inline(always)] + fn cmp(&self, other: &Fr) -> ::std::cmp::Ordering { + self.into_repr().cmp(&other.into_repr()) + } + } + impl PartialOrd for Fr { + #[inline(always)] + fn partial_cmp(&self, other: &Fr) -> Option<::std::cmp::Ordering> { + Some(self.cmp(other)) + } + } + impl ::std::fmt::Display for Fr { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + f.write_fmt(::core::fmt::Arguments::new_v1( + &["", "(", ")"], + &match (&"Fr", &self.into_repr()) { + (arg0, arg1) => [ + ::core::fmt::ArgumentV1::new(arg0, ::core::fmt::Display::fmt), + ::core::fmt::ArgumentV1::new(arg1, ::core::fmt::Display::fmt), + ], + }, + )) + } + } + impl ::rand::Rand for Fr { + /// Computes a uniformly random element using rejection sampling. + fn rand(rng: &mut R) -> Self { + loop { + let mut tmp = Fr(FrRepr::rand(rng)); + tmp.0.as_mut()[3usize] &= 0xffffffffffffffff >> REPR_SHAVE_BITS; + if tmp.is_valid() { + return tmp; + } + } + } + } + impl From for FrRepr { + fn from(e: Fr) -> FrRepr { + e.into_repr() + } + } + impl crate::ff::PrimeField for Fr { + type Repr = FrRepr; + fn from_repr(r: FrRepr) -> Result { + let mut r = Fr(r); + if r.is_valid() { + r.mul_assign(&Fr(R2)); + Ok(r) + } else { + Err(crate::ff::PrimeFieldDecodingError::NotInField({ + let res = ::alloc::fmt::format(::core::fmt::Arguments::new_v1( + &[""], + &match (&r.0,) { + (arg0,) => [::core::fmt::ArgumentV1::new( + arg0, + ::core::fmt::Display::fmt, + )], + }, + )); + res + })) + } + } + fn from_raw_repr(r: FrRepr) -> Result { + let mut r = Fr(r); + if r.is_valid() { + Ok(r) + } else { + Err(crate::ff::PrimeFieldDecodingError::NotInField({ + let res = ::alloc::fmt::format(::core::fmt::Arguments::new_v1( + &[""], + &match (&r.0,) { + (arg0,) => [::core::fmt::ArgumentV1::new( + arg0, + ::core::fmt::Display::fmt, + )], + }, + )); + res + })) + } + } + fn into_repr(&self) -> FrRepr { + let mut r = *self; + r.mont_reduce( + (self.0).0[0usize], + (self.0).0[1usize], + (self.0).0[2usize], + (self.0).0[3usize], + 0, + 0, + 0, + 0, + ); + r.0 + } + fn into_raw_repr(&self) -> FrRepr { + let r = *self; + r.0 + } + fn char() -> FrRepr { + MODULUS + } + const NUM_BITS: u32 = MODULUS_BITS; + const CAPACITY: u32 = Self::NUM_BITS - 1; + fn multiplicative_generator() -> Self { + Fr(GENERATOR) + } + const S: u32 = S; + fn root_of_unity() -> Self { + Fr(ROOT_OF_UNITY) + } + } + impl crate::ff::Field for Fr { + #[inline] + fn zero() -> Self { + Fr(FrRepr::from(0)) + } + #[inline] + fn one() -> Self { + Fr(R) + } + #[inline] + fn is_zero(&self) -> bool { + self.0.is_zero() + } + #[inline] + fn add_assign(&mut self, other: &Fr) { + self.0.add_nocarry(&other.0); + self.reduce(); + } + #[inline] + fn double(&mut self) { + self.0.mul2(); + self.reduce(); + } + #[inline] + fn sub_assign(&mut self, other: &Fr) { + if other.0 > self.0 { + self.0.add_nocarry(&MODULUS); + } + self.0.sub_noborrow(&other.0); + } + #[inline] + fn negate(&mut self) { + if !self.is_zero() { + let mut tmp = MODULUS; + tmp.sub_noborrow(&self.0); + self.0 = tmp; + } + } + fn inverse(&self) -> Option { + if self.is_zero() { + None + } else { + let one = FrRepr::from(1); + let mut u = self.0; + let mut v = MODULUS; + let mut b = Fr(R2); + let mut c = Self::zero(); + while u != one && v != one { + while u.is_even() { + u.div2(); + if b.0.is_even() { + b.0.div2(); + } else { + b.0.add_nocarry(&MODULUS); + b.0.div2(); + } + } + while v.is_even() { + v.div2(); + if c.0.is_even() { + c.0.div2(); + } else { + c.0.add_nocarry(&MODULUS); + c.0.div2(); + } + } + if v < u { + u.sub_noborrow(&v); + b.sub_assign(&c); + } else { + v.sub_noborrow(&u); + c.sub_assign(&b); + } + } + if u == one { + Some(b) + } else { + Some(c) + } + } + } + #[inline(always)] + fn frobenius_map(&mut self, _: usize) {} + #[inline] + fn mul_assign(&mut self, other: &Fr) { + let [b0, b1, b2, b3] = (other.0).0; + let a = (self.0).0[0usize]; + let (r0, carry) = crate::ff::full_width_mul(a, b0); + let m = r0.wrapping_mul(INV); + let red_carry = crate::ff::mac_by_value_return_carry_only(r0, m, MODULUS.0[0]); + let (r1, carry) = crate::ff::mac_by_value(carry, a, b1); + let (r0, red_carry) = + crate::ff::mac_with_carry_by_value(r1, m, MODULUS.0[1usize], red_carry); + let (r2, carry) = crate::ff::mac_by_value(carry, a, b2); + let (r1, red_carry) = + crate::ff::mac_with_carry_by_value(r2, m, MODULUS.0[2usize], red_carry); + let (r3, carry) = crate::ff::mac_by_value(carry, a, b3); + let (r2, red_carry) = + crate::ff::mac_with_carry_by_value(r3, m, MODULUS.0[3usize], red_carry); + let r3 = red_carry + carry; + let a = (self.0).0[1usize]; + let (r0, carry) = crate::ff::mac_by_value(r0, a, b0); + let m = r0.wrapping_mul(INV); + let red_carry = crate::ff::mac_by_value_return_carry_only(r0, m, MODULUS.0[0]); + let (r1, carry) = crate::ff::mac_with_carry_by_value(r1, a, b1, carry); + let (r0, red_carry) = + crate::ff::mac_with_carry_by_value(r1, m, MODULUS.0[1usize], red_carry); + let (r2, carry) = crate::ff::mac_with_carry_by_value(r2, a, b2, carry); + let (r1, red_carry) = + crate::ff::mac_with_carry_by_value(r2, m, MODULUS.0[2usize], red_carry); + let (r3, carry) = crate::ff::mac_with_carry_by_value(r3, a, b3, carry); + let (r2, red_carry) = + crate::ff::mac_with_carry_by_value(r3, m, MODULUS.0[3usize], red_carry); + let r3 = red_carry + carry; + let a = (self.0).0[2usize]; + let (r0, carry) = crate::ff::mac_by_value(r0, a, b0); + let m = r0.wrapping_mul(INV); + let red_carry = crate::ff::mac_by_value_return_carry_only(r0, m, MODULUS.0[0]); + let (r1, carry) = crate::ff::mac_with_carry_by_value(r1, a, b1, carry); + let (r0, red_carry) = + crate::ff::mac_with_carry_by_value(r1, m, MODULUS.0[1usize], red_carry); + let (r2, carry) = crate::ff::mac_with_carry_by_value(r2, a, b2, carry); + let (r1, red_carry) = + crate::ff::mac_with_carry_by_value(r2, m, MODULUS.0[2usize], red_carry); + let (r3, carry) = crate::ff::mac_with_carry_by_value(r3, a, b3, carry); + let (r2, red_carry) = + crate::ff::mac_with_carry_by_value(r3, m, MODULUS.0[3usize], red_carry); + let r3 = red_carry + carry; + let a = (self.0).0[3usize]; + let (r0, carry) = crate::ff::mac_by_value(r0, a, b0); + let m = r0.wrapping_mul(INV); + let red_carry = crate::ff::mac_by_value_return_carry_only(r0, m, MODULUS.0[0]); + let (r1, carry) = crate::ff::mac_with_carry_by_value(r1, a, b1, carry); + let (r0, red_carry) = + crate::ff::mac_with_carry_by_value(r1, m, MODULUS.0[1usize], red_carry); + let (r2, carry) = crate::ff::mac_with_carry_by_value(r2, a, b2, carry); + let (r1, red_carry) = + crate::ff::mac_with_carry_by_value(r2, m, MODULUS.0[2usize], red_carry); + let (r3, carry) = crate::ff::mac_with_carry_by_value(r3, a, b3, carry); + let (r2, red_carry) = + crate::ff::mac_with_carry_by_value(r3, m, MODULUS.0[3usize], red_carry); + let r3 = red_carry + carry; + *self = Fr(FrRepr([r0, r1, r2, r3])); + self.reduce(); + } + #[inline] + fn square(&mut self) { + let [a0, a1, a2, a3] = (self.0).0; + let (r0, carry) = crate::ff::full_width_mul(a0, a0); + let m = r0.wrapping_mul(INV); + let red_carry = crate::ff::mac_by_value_return_carry_only(r0, m, MODULUS.0[0]); + let (r1, carry, superhi) = crate::ff::mul_double_add_by_value(carry, a0, a1); + let (r0, red_carry) = + crate::ff::mac_with_carry_by_value(r1, m, MODULUS.0[1usize], red_carry); + let (r2, carry, superhi) = + crate::ff::mul_double_add_low_and_high_carry_by_value(a0, a2, carry, superhi); + let (r1, red_carry) = + crate::ff::mac_with_carry_by_value(r2, m, MODULUS.0[2usize], red_carry); + let (r3, carry) = crate::ff::mul_double_add_low_and_high_carry_by_value_ignore_superhi( + a0, a3, carry, superhi, + ); + let (r2, r3) = crate::ff::mac_with_low_and_high_carry_by_value( + red_carry, + m, + MODULUS.0[3usize], + r3, + carry, + ); + let m = r0.wrapping_mul(INV); + let red_carry = crate::ff::mac_by_value_return_carry_only(r0, m, MODULUS.0[0]); + let (r1, carry) = crate::ff::mac_by_value(r1, a1, a1); + let (r0, red_carry) = + crate::ff::mac_with_carry_by_value(r1, m, MODULUS.0[1usize], red_carry); + let (r2, carry, superhi) = + crate::ff::mul_double_add_add_carry_by_value(r2, a1, a2, carry); + let (r1, red_carry) = mac_with_carry_by_value(r2, m, MODULUS.0[2usize], red_carry); + let (r3, carry) = + crate::ff::mul_double_add_add_low_and_high_carry_by_value_ignore_superhi( + r3, a1, a3, carry, superhi, + ); + let (r2, r3) = crate::ff::mac_with_low_and_high_carry_by_value( + red_carry, + m, + MODULUS.0[3usize], + r3, + carry, + ); + let m = r0.wrapping_mul(INV); + let red_carry = crate::ff::mac_by_value_return_carry_only(r0, m, MODULUS.0[0]); + let (r0, red_carry) = + crate::ff::mac_with_carry_by_value(r1, m, MODULUS.0[1usize], red_carry); + let (r2, carry) = crate::ff::mac_by_value(r2, a2, a2); + let (r1, red_carry) = + crate::ff::mac_with_carry_by_value(r2, m, MODULUS.0[2usize], red_carry); + let (r3, carry) = + crate::ff::mul_double_add_add_carry_by_value_ignore_superhi(r3, a2, a3, carry); + let (r2, r3) = crate::ff::mac_with_low_and_high_carry_by_value( + red_carry, + m, + MODULUS.0[3usize], + r3, + carry, + ); + let m = r0.wrapping_mul(INV); + let red_carry = crate::ff::mac_by_value_return_carry_only(r0, m, MODULUS.0[0]); + let (r0, red_carry) = + crate::ff::mac_with_carry_by_value(r1, m, MODULUS.0[1usize], red_carry); + let (r1, red_carry) = + crate::ff::mac_with_carry_by_value(r2, m, MODULUS.0[2usize], red_carry); + let (r3, carry) = crate::ff::mac_by_value(r3, a3, a3); + let (r2, r3) = crate::ff::mac_with_low_and_high_carry_by_value( + red_carry, + m, + MODULUS.0[3usize], + r3, + carry, + ); + *self = Fr(FrRepr([r0, r1, r2, r3])); + self.reduce(); + } + } + impl std::default::Default for Fr { + fn default() -> Self { + Self::zero() + } + } + impl std::hash::Hash for Fr { + fn hash(&self, state: &mut H) { + for limb in self.0.as_ref().iter() { + limb.hash(state); + } + } + } + impl Fr { + /// Determines if the element is really in the field. This is only used + /// internally. + #[inline(always)] + fn is_valid(&self) -> bool { + self.0 < MODULUS + } + /// Subtracts the modulus from this element if this element is not in the + /// field. Only used interally. + #[inline(always)] + fn reduce(&mut self) { + if !self.is_valid() { + self.0.sub_noborrow(&MODULUS); + } + } + #[inline(always)] + fn mont_reduce( + &mut self, + r0: u64, + mut r1: u64, + mut r2: u64, + mut r3: u64, + mut r4: u64, + mut r5: u64, + mut r6: u64, + mut r7: u64, + ) { + let k = r0.wrapping_mul(INV); + let mut carry = 0; + crate::ff::mac_with_carry(r0, k, MODULUS.0[0], &mut carry); + r1 = crate::ff::mac_with_carry(r1, k, MODULUS.0[1usize], &mut carry); + r2 = crate::ff::mac_with_carry(r2, k, MODULUS.0[2usize], &mut carry); + r3 = crate::ff::mac_with_carry(r3, k, MODULUS.0[3usize], &mut carry); + r4 = crate::ff::adc(r4, 0, &mut carry); + let carry2 = carry; + let k = r1.wrapping_mul(INV); + let mut carry = 0; + crate::ff::mac_with_carry(r1, k, MODULUS.0[0], &mut carry); + r2 = crate::ff::mac_with_carry(r2, k, MODULUS.0[1usize], &mut carry); + r3 = crate::ff::mac_with_carry(r3, k, MODULUS.0[2usize], &mut carry); + r4 = crate::ff::mac_with_carry(r4, k, MODULUS.0[3usize], &mut carry); + r5 = crate::ff::adc(r5, carry2, &mut carry); + let carry2 = carry; + let k = r2.wrapping_mul(INV); + let mut carry = 0; + crate::ff::mac_with_carry(r2, k, MODULUS.0[0], &mut carry); + r3 = crate::ff::mac_with_carry(r3, k, MODULUS.0[1usize], &mut carry); + r4 = crate::ff::mac_with_carry(r4, k, MODULUS.0[2usize], &mut carry); + r5 = crate::ff::mac_with_carry(r5, k, MODULUS.0[3usize], &mut carry); + r6 = crate::ff::adc(r6, carry2, &mut carry); + let carry2 = carry; + let k = r3.wrapping_mul(INV); + let mut carry = 0; + crate::ff::mac_with_carry(r3, k, MODULUS.0[0], &mut carry); + r4 = crate::ff::mac_with_carry(r4, k, MODULUS.0[1usize], &mut carry); + r5 = crate::ff::mac_with_carry(r5, k, MODULUS.0[2usize], &mut carry); + r6 = crate::ff::mac_with_carry(r6, k, MODULUS.0[3usize], &mut carry); + r7 = crate::ff::adc(r7, carry2, &mut carry); + (self.0).0[0usize] = r4; + (self.0).0[1usize] = r5; + (self.0).0[2usize] = r6; + (self.0).0[3usize] = r7; + self.reduce(); + } + } + impl crate::ff::SqrtField for Fr { + fn legendre(&self) -> crate::ff::LegendreSymbol { + let s = self.pow([ + 11389680472494603939u64, + 14681934109093717318u64, + 15863968012492123182u64, + 1743499133401485332u64, + ]); + if s == Self::zero() { + crate::ff::LegendreSymbol::Zero + } else if s == Self::one() { + crate::ff::LegendreSymbol::QuadraticResidue + } else { + crate::ff::LegendreSymbol::QuadraticNonResidue + } + } + fn sqrt(&self) -> Option { + let mut a1 = self.pow([ + 5694840236247301969u64, + 7340967054546858659u64, + 7931984006246061591u64, + 871749566700742666u64, + ]); + let mut a0 = a1; + a0.square(); + a0.mul_assign(self); + if a0.0 + == FrRepr([ + 7548957153968385962u64, + 10162512645738643279u64, + 5900175412809962033u64, + 2475245527108272378u64, + ]) + { + None + } else { + a1.mul_assign(self); + Some(a1) + } + } + } +} diff --git a/crates/franklin-crypto/.github/CODEOWNERS b/crates/franklin-crypto/.github/CODEOWNERS deleted file mode 100644 index 0495698..0000000 --- a/crates/franklin-crypto/.github/CODEOWNERS +++ /dev/null @@ -1,5 +0,0 @@ -# about codeowners: -# https://help.github.com/articles/about-codeowners/ - -# notify about all changes -@shamatar diff --git a/crates/franklin-crypto/.github/ISSUE_TEMPLATE/bug_report.md b/crates/franklin-crypto/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 2d3e38a..0000000 --- a/crates/franklin-crypto/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,39 +0,0 @@ ---- -name: Bug report -about: Use this template for reporting issues -title: '' -labels: bug -assignees: '' ---- - -### 🐛 Bug Report - -#### 📝 Description - -Provide a clear and concise description of the bug. - -#### 🔄 Reproduction Steps - -Steps to reproduce the behaviour - -#### 🤔 Expected Behavior - -Describe what you expected to happen. - -#### 😯 Current Behavior - -Describe what actually happened. - -#### 🖥️ Environment - -Any relevant environment details. - -#### 📋 Additional Context - -Add any other context about the problem here. If applicable, add screenshots to help explain. - -#### 📎 Log Output - -``` -Paste any relevant log output here. -``` diff --git a/crates/franklin-crypto/.github/ISSUE_TEMPLATE/feature_request.md b/crates/franklin-crypto/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index d921e06..0000000 --- a/crates/franklin-crypto/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,21 +0,0 @@ ---- -name: Feature request -about: Use this template for requesting features -title: '' -labels: feat -assignees: '' ---- - -### 🌟 Feature Request - -#### 📝 Description - -Provide a clear and concise description of the feature you'd like to see. - -#### 🤔 Rationale - -Explain why this feature is important and how it benefits the project. - -#### 📋 Additional Context - -Add any other context or information about the feature request here. diff --git a/crates/franklin-crypto/.github/pull_request_template.md b/crates/franklin-crypto/.github/pull_request_template.md deleted file mode 100644 index 8ce206c..0000000 --- a/crates/franklin-crypto/.github/pull_request_template.md +++ /dev/null @@ -1,20 +0,0 @@ -# What ❔ - - - - - -## Why ❔ - - - - -## Checklist - - - - -- [ ] PR title corresponds to the body of PR (we generate changelog entries from PRs). -- [ ] Tests for the changes have been added / updated. -- [ ] Documentation comments have been added / updated. -- [ ] Code has been formatted via `zk fmt` and `zk lint`. diff --git a/crates/franklin-crypto/.github/workflows/cargo-license.yaml b/crates/franklin-crypto/.github/workflows/cargo-license.yaml deleted file mode 100644 index 189b471..0000000 --- a/crates/franklin-crypto/.github/workflows/cargo-license.yaml +++ /dev/null @@ -1,8 +0,0 @@ -name: Cargo license check -on: pull_request -jobs: - cargo-deny: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: EmbarkStudios/cargo-deny-action@v1 diff --git a/crates/franklin-crypto/.github/workflows/ci.yml b/crates/franklin-crypto/.github/workflows/ci.yml deleted file mode 100644 index bb0ffde..0000000 --- a/crates/franklin-crypto/.github/workflows/ci.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Build and compile - -on: - push: - branches: ["dev"] - pull_request: - branches: ["dev"] - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -env: - CARGO_TERM_COLOR: always - -jobs: - build: - runs-on: [ubuntu-latest] - - steps: - - uses: actions/checkout@v3 - - name: Setup rustup - run: | - wget -q -O - https://sh.rustup.rs | bash -s -- -y - echo "${HOME}/.cargo/bin" >> "${GITHUB_PATH}" - echo "CARGO_BUILD_JOBS=$(($(nproc) /2))" >> "${GITHUB_ENV}" - echo "export PATH=\"$HOME/.cargo/bin:\$PATH\"" >> "${HOME}/.bash_profile" - - name: Setup rust - run: | - rustup set profile minimal - rustup toolchain install nightly-2023-08-23 - rustup default nightly-2023-08-23 - - name: Compile - run: | - cargo build --verbose diff --git a/crates/franklin-crypto/.github/workflows/secrets_scanner.yaml b/crates/franklin-crypto/.github/workflows/secrets_scanner.yaml deleted file mode 100644 index 54054cf..0000000 --- a/crates/franklin-crypto/.github/workflows/secrets_scanner.yaml +++ /dev/null @@ -1,17 +0,0 @@ -name: Leaked Secrets Scan -on: [pull_request] -jobs: - TruffleHog: - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@ac593985615ec2ede58e132d2e21d2b1cbd6127c # v3 - with: - fetch-depth: 0 - - name: TruffleHog OSS - uses: trufflesecurity/trufflehog@0c66d30c1f4075cee1aada2e1ab46dabb1b0071a - with: - path: ./ - base: ${{ github.event.repository.default_branch }} - head: HEAD - extra_args: --debug --only-verified diff --git a/crates/rescue-poseidon/.github/ISSUE_TEMPLATE/bug_report.md b/crates/rescue-poseidon/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 2d3e38a..0000000 --- a/crates/rescue-poseidon/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,39 +0,0 @@ ---- -name: Bug report -about: Use this template for reporting issues -title: '' -labels: bug -assignees: '' ---- - -### 🐛 Bug Report - -#### 📝 Description - -Provide a clear and concise description of the bug. - -#### 🔄 Reproduction Steps - -Steps to reproduce the behaviour - -#### 🤔 Expected Behavior - -Describe what you expected to happen. - -#### 😯 Current Behavior - -Describe what actually happened. - -#### 🖥️ Environment - -Any relevant environment details. - -#### 📋 Additional Context - -Add any other context about the problem here. If applicable, add screenshots to help explain. - -#### 📎 Log Output - -``` -Paste any relevant log output here. -``` diff --git a/crates/rescue-poseidon/.github/ISSUE_TEMPLATE/feature_request.md b/crates/rescue-poseidon/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index d921e06..0000000 --- a/crates/rescue-poseidon/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,21 +0,0 @@ ---- -name: Feature request -about: Use this template for requesting features -title: '' -labels: feat -assignees: '' ---- - -### 🌟 Feature Request - -#### 📝 Description - -Provide a clear and concise description of the feature you'd like to see. - -#### 🤔 Rationale - -Explain why this feature is important and how it benefits the project. - -#### 📋 Additional Context - -Add any other context or information about the feature request here. diff --git a/crates/rescue-poseidon/.github/pull_request_template.md b/crates/rescue-poseidon/.github/pull_request_template.md deleted file mode 100644 index 8ce206c..0000000 --- a/crates/rescue-poseidon/.github/pull_request_template.md +++ /dev/null @@ -1,20 +0,0 @@ -# What ❔ - - - - - -## Why ❔ - - - - -## Checklist - - - - -- [ ] PR title corresponds to the body of PR (we generate changelog entries from PRs). -- [ ] Tests for the changes have been added / updated. -- [ ] Documentation comments have been added / updated. -- [ ] Code has been formatted via `zk fmt` and `zk lint`. diff --git a/crates/rescue-poseidon/.github/workflows/cargo-license.yaml b/crates/rescue-poseidon/.github/workflows/cargo-license.yaml deleted file mode 100644 index 189b471..0000000 --- a/crates/rescue-poseidon/.github/workflows/cargo-license.yaml +++ /dev/null @@ -1,8 +0,0 @@ -name: Cargo license check -on: pull_request -jobs: - cargo-deny: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: EmbarkStudios/cargo-deny-action@v1 diff --git a/crates/rescue-poseidon/.github/workflows/ci.yaml b/crates/rescue-poseidon/.github/workflows/ci.yaml deleted file mode 100644 index de913f0..0000000 --- a/crates/rescue-poseidon/.github/workflows/ci.yaml +++ /dev/null @@ -1,32 +0,0 @@ -name: "Rust CI" -on: - pull_request: - -jobs: - build: - name: cargo build and test - runs-on: [ubuntu-latest] - steps: - - uses: actions/checkout@v3 - - uses: actions-rust-lang/setup-rust-toolchain@v1 - with: - # Remove default `-D warnings`. - rustflags: "" - - name: Setup rust - run: | - rustup set profile minimal - rustup toolchain install nightly-2023-08-23 - rustup default nightly-2023-08-23 - - run: cargo build --verbose - - run: cargo test --verbose --all - - formatting: - name: cargo fmt - runs-on: [ubuntu-latest] - steps: - - uses: actions/checkout@v3 - - uses: actions-rust-lang/setup-rust-toolchain@v1 - with: - components: rustfmt - - name: Rustfmt Check - uses: actions-rust-lang/rustfmt@v1 diff --git a/crates/rescue-poseidon/.github/workflows/secrets_scanner.yaml b/crates/rescue-poseidon/.github/workflows/secrets_scanner.yaml deleted file mode 100644 index 54054cf..0000000 --- a/crates/rescue-poseidon/.github/workflows/secrets_scanner.yaml +++ /dev/null @@ -1,17 +0,0 @@ -name: Leaked Secrets Scan -on: [pull_request] -jobs: - TruffleHog: - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@ac593985615ec2ede58e132d2e21d2b1cbd6127c # v3 - with: - fetch-depth: 0 - - name: TruffleHog OSS - uses: trufflesecurity/trufflehog@0c66d30c1f4075cee1aada2e1ab46dabb1b0071a - with: - path: ./ - base: ${{ github.event.repository.default_branch }} - head: HEAD - extra_args: --debug --only-verified diff --git a/crates/snark-wrapper/.gitignore b/crates/snark-wrapper/.gitignore new file mode 100644 index 0000000..2c96eb1 --- /dev/null +++ b/crates/snark-wrapper/.gitignore @@ -0,0 +1,2 @@ +target/ +Cargo.lock diff --git a/crates/snark-wrapper/CONTRIBUTING.md b/crates/snark-wrapper/CONTRIBUTING.md new file mode 100644 index 0000000..dd3d458 --- /dev/null +++ b/crates/snark-wrapper/CONTRIBUTING.md @@ -0,0 +1,44 @@ +# Contribution Guidelines + +Hello! Thanks for your interest in joining the mission to accelerate the mass adoption of crypto for personal +sovereignty! We welcome contributions from anyone on the internet, and are grateful for even the smallest of fixes! + +## Ways to contribute + +There are many ways to contribute to the ZK Stack: + +1. Open issues: if you find a bug, have something you believe needs to be fixed, or have an idea for a feature, please + open an issue. +2. Add color to existing issues: provide screenshots, code snippets, and whatever you think would be helpful to resolve + issues. +3. Resolve issues: either by showing an issue isn't a problem and the current state is ok as is or by fixing the problem + and opening a PR. +4. Report security issues, see [our security policy](./github/SECURITY.md). +5. [Join the team!](https://matterlabs.notion.site/Shape-the-future-of-Ethereum-at-Matter-Labs-dfb3b5a037044bb3a8006af2eb0575e0) + +## Fixing issues + +To contribute code fixing issues, please fork the repo, fix an issue, commit, add documentation as per the PR template, +and the repo's maintainers will review the PR. +[here](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) +for guidance how to work with PRs created from a fork. + +## Licenses + +If you contribute to this project, your contributions will be made to the project under both Apache 2.0 and the MIT +license. + +## Resources + +We aim to make it as easy as possible to contribute to the mission. This is still WIP, and we're happy for contributions +and suggestions here too. Some resources to help: + +1. [In-repo docs aimed at developers](docs) +2. [zkSync Era docs!](https://era.zksync.io/docs/) +3. Company links can be found in the [repo's readme](README.md) + +## Code of Conduct + +Be polite and respectful. + +### Thank you diff --git a/crates/snark-wrapper/Cargo.toml b/crates/snark-wrapper/Cargo.toml new file mode 100644 index 0000000..7516416 --- /dev/null +++ b/crates/snark-wrapper/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "snark_wrapper" +version = "0.1.2" +edition = "2021" +authors = ["The Matter Labs Team "] +homepage = "https://zksync.io/" +repository = "https://github.com/matter-labs/zksync-era" +license = "MIT OR Apache-2.0" +keywords = ["blockchain", "zksync"] +categories = ["cryptography"] +description = "ZKsync snark wrapper" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +derivative = "2" +rand = "0.4" + +rescue_poseidon = "=0.5.2" +# rescue_poseidon = {path = "../rescue-poseidon"} diff --git a/crates/snark-wrapper/LICENSE-APACHE b/crates/snark-wrapper/LICENSE-APACHE new file mode 100644 index 0000000..16fe87b --- /dev/null +++ b/crates/snark-wrapper/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/crates/snark-wrapper/LICENSE-MIT b/crates/snark-wrapper/LICENSE-MIT new file mode 100644 index 0000000..31aa793 --- /dev/null +++ b/crates/snark-wrapper/LICENSE-MIT @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/crates/snark-wrapper/SECURITY.md b/crates/snark-wrapper/SECURITY.md new file mode 100644 index 0000000..2f2871c --- /dev/null +++ b/crates/snark-wrapper/SECURITY.md @@ -0,0 +1,74 @@ +# Security Policy + +We truly appreciate efforts to discover and disclose security issues responsibly! + +## Vulnerabilities + +If you'd like to report a security issue in the repositories of matter-labs organization, please proceed to our +[Bug Bounty Program on Immunefi](https://era.zksync.io/docs/reference/troubleshooting/audit-bug-bounty.html#bug-bounty-program). + +## Other Security Issues + +We take an impact-first approach instead of a rules-first approach. Therefore, if you believe you found the impactful +issue but can't report it via the Bug Bounty, please email us at +[security@matterlabs.dev](mailto:security@matterlabs.dev). + +### PGP Key + +The following PGP key may be used to communicate sensitive information to developers: + +Fingerprint: `5FED B2D0 EA2C 4906 DD66 71D7 A2C5 0B40 CE3C F297` + +``` +-----BEGIN PGP PUBLIC KEY BLOCK----- + +mQINBGEBmQkBEAD6tlkBEZFMvR8kOgxXX857nC2+oTik6TopJz4uCskuqDaeldMy +l+26BBzLkIeO1loS+bzVgnNFJRrGt9gv98MzNEHJVv6D7GsSLlUX/pz7Lxn0J4ry +o5XIk3MQTCUBdaXGs6GBLl5Xe8o+zNj4MKd4zjgDLinITNlE/YZCDsXyvYS3YFTQ +cwaUTNlawkKgw4BLaEqwB2JuyEhI9wx5X7ibjFL32sWMolYsNAlzFQzM09HCurTn +q0DYau9kPJARcEk9/DK2iq0z3gMCQ8iRTDaOWd8IbSP3HxcEoM5j5ZVAlULmjmUE +StDaMPLj0Kh01Tesh/j+vjchPXHT0n4zqi1+KOesAOk7SIwLadHfQMTpkU7G2fR1 +BrA5MtlzY+4Rm6o7qu3dpZ+Nc4iM3FUnaQRpvn4g5nTh8vjG94OCzX8DXWrCKyxx +amCs9PLDYOpx84fXYv4frkWpKh2digDSUGKhoHaOSnqyyvu3BNWXBCQZJ20rqEIu +sXOQMxWIoWCOOPRRvrHrKDA2hpoKjs3pGsProfpVRzb9702jhWpTfbDp9WjQlFtX +2ZIDxlwAxcugClgrp5JiUxvhg2A9lDNwCF7r1e68uNv5usBZQVKPJmnvS2nWgKy8 +x9oJsnwrEjxwiRHd34UvfMkwY9RENSJ+NoXqBdS7Lwz4m6vgbzq6K56WPQARAQAB +tCRaa1N5bmMgU2VjdXJpdHkgPHNlY3VyaXR5QHprc3luYy5pbz6JAk4EEwEKADgW +IQRf7bLQ6ixJBt1mcdeixQtAzjzylwUCYQGZCQIbAwULCQgHAgYVCgkICwIEFgID +AQIeAQIXgAAKCRCixQtAzjzyl5y8EAC/T3oq88Dak2b+5TlWdU2Gpm6924eAqlMt +y1KksDezzNQUlPiCUVllpin2PIjU/S+yzMWKXJA04LoVkEPfPOWjAaavLOjRumxu +MR6P2dVUg1InqzYVsJuRhKSpeexzNA5qO2BPM7/I2Iea1IoJPjogGbfXCo0r5kne +KU7a5GEa9eDHxpHTsbphQe2vpQ1239mUJrFpzAvILn6jV1tawMn5pNCXbsa8l6l2 +gtlyQPdOQECy77ZJxrgzaUBcs/RPzUGhwA/qNuvpF0whaCvZuUFMVuCTEu5LZka2 +I9Rixy+3jqBeONBgb+Fiz5phbiMX33M9JQwGONFaxdvpFTerLwPK2N1T8zcufa01 +ypzkWGheScFZemBxUwXwK4x579wjsnfrY11w0p1jtDgPTnLlXUA2mom4+7MyXPg0 +F75qh6vU1pdXaCVkruFgPVtIw+ccw2AxD50iZQ943ZERom9k165dR9+QxOVMXQ4P +VUxsFZWvK70/s8TLjsGljvSdSOa85iEUqSqh0AlCwIAxLMiDwh5s/ZgiHoIM6Xih +oCpuZyK9p0dn+DF/XkgAZ/S91PesMye3cGm6M5r0tS26aoc2Pk6X37Hha1pRALwo +MOHyaGjc/jjcXXxv6o55ALrOrzS0LQmLZ+EHuteCT15kmeY3kqYJ3og62KgiDvew +dKHENvg7d7kCDQRhAZleARAA6uD6WfdqGeKV5i170+kLsxR3QGav0qGNAbxpSJyn +iHQ8u7mQk3S+ziwN2AAopfBk1je+vCWtEGC3+DWRRfJSjLbtaBG8e6kLP3/cGA75 +qURz6glTG4nl5fcEAa6B1st0OxjVWiSLX3g/yjz8lznQb9awuRjdeHMnyx5DsJUN +d+Iu5KxGupQvKGOMKivSvC8VWk9taaQRpRF+++6stLCDk3ZtlxiopMs3X2jAp6xG +sOBbix1cv9BTsfaiL7XDL/gviqBPXYY5L42x6+jnPo5lROfnlLYkWrv6KZr7HD4k +tRXeaSwxLD2EkUyb16Jpp0be/ofvBtITGUDDLCGBiaXtx/v8d52MARjsyLJSYloj +1yiW01LfAiWHUC4z5jl2T7E7sicrlLH1M8Z6WbuqjdeaYwtfyPA2YCKr/3fn6pIo +D+pYaBSESmhA92P+XVaf5y2BZ6Qf8LveDpWwsVGdBGh9T0raA1ooe1GESLjmIjUa +z5AeQ/uXL5Md9I6bpMUUJYQiH19RPcFlJriI3phXyyf6Wlkk8oVEeCWyzcmw+x1V +deRTvE2x4WIwKGLXRNjin2j1AP7vU2HaNwlPrLijqdyi68+0irRQONoH7Qonr4ca +xWgL+pAaa3dWxf0xqK7uZFp4aTVWlr2uXtV/eaUtLmGMCU0jnjb109wg5L0F7WRT +PfEAEQEAAYkCNgQYAQoAIBYhBF/tstDqLEkG3WZx16LFC0DOPPKXBQJhAZleAhsM +AAoJEKLFC0DOPPKXAAEP/jK7ch9GkoaYlsuqY/aHtxEwVddUDOxjyn3FMDoln85L +/n8AmLQb2bcpKSqpaJwMbmfEyr5MDm8xnsBTfx3u6kgaLOWfKxjLQ6PM7kgIMdi4 +bfaRRuSEI1/R6c/hNpiGnzAeeexldH1we+eH1IVmh4crdat49S2xh7Qlv9ahvgsP +LfKl3rJ+aaX/Ok0AHzhvSfhFpPr1gAaGeaRt+rhlZsx2QyG4Ez8p2nDAcAzPiB3T +73ENoBIX6mTPfPm1UgrRyFKBqtUzAodz66j3r6ebBlWzIRg8iZenVMAxzjINAsxN +w1Bzfgsi5ZespfsSlmEaa7jJkqqDuEcLa2YuiFAue7Euqwz1aGeq1GfTicQioSCb +Ur/LGyz2Mj3ykbaP8p5mFVcUN51yQy6OcpvR/W1DfRT9SHFT/bCf9ixsjB2HlZGo +uxPJowwqmMgHd755ZzPDUM9YDgLI1yXdcYshObv3Wq537JAxnZJCGRK4Y8SwrMSh +8WRxlaM0AGWXiJFIDD4bQPIdnF3X8w0cGWE5Otkb8mMHOT+rFTVlDODwm1zF6oIG +PTwfVrpiZBwiUtfJol1exr/MzSPyGoJnYs3cRf2E3O+D1LbcR8w0LbjGuUy38Piz +ZO/vCeyJ3JZC5kE8nD+XBA4idwzh0BKEfH9t+WchQ3Up9rxyzLyQamoqt5Xby4pY +=xkM3 +-----END PGP PUBLIC KEY BLOCK----- +``` diff --git a/crates/snark-wrapper/deny.toml b/crates/snark-wrapper/deny.toml new file mode 100644 index 0000000..6977d43 --- /dev/null +++ b/crates/snark-wrapper/deny.toml @@ -0,0 +1,79 @@ +all-features = false +no-default-features = false + +[advisories] +vulnerability = "deny" +unmaintained = "warn" +yanked = "warn" +notice = "warn" +ignore = [ + #"RUSTSEC-0000-0000", +] + +[licenses] +unlicensed = "deny" +allow = [ + #"Apache-2.0 WITH LLVM-exception", + "MIT", + "Apache-2.0", + "ISC", + "Unlicense", + "MPL-2.0", + "Unicode-DFS-2016", + "CC0-1.0", + "BSD-2-Clause", + "BSD-3-Clause", + "Zlib", +] +deny = [ + #"Nokia", +] +copyleft = "warn" +allow-osi-fsf-free = "neither" +default = "deny" +confidence-threshold = 0.8 +exceptions = [ + # Each entry is the crate and version constraint, and its specific allow + # list + #{ allow = ["Zlib"], name = "adler32", version = "*" }, +] + +unused-allowed-license = "allow" + +[licenses.private] +ignore = false +registries = [ + #"https://sekretz.com/registry +] + +[bans] +multiple-versions = "warn" +wildcards = "allow" +highlight = "all" +workspace-default-features = "allow" +external-default-features = "allow" +allow = [ + #{ name = "ansi_term", version = "=0.11.0" }, +] +# List of crates to deny +deny = [ + # Each entry the name of a crate and a version range. If version is + # not specified, all versions will be matched. + #{ name = "ansi_term", version = "=0.11.0" }, +] + +skip = [ + #{ name = "ansi_term", version = "=0.11.0" }, +] +skip-tree = [ + #{ name = "ansi_term", version = "=0.11.0", depth = 20 }, +] + +[sources] +unknown-registry = "deny" +unknown-git = "allow" +allow-registry = ["https://github.com/rust-lang/crates.io-index"] +allow-git = [] + +[sources.allow-org] +#github = ["matter-labs"] diff --git a/crates/snark-wrapper/rust-toolchain b/crates/snark-wrapper/rust-toolchain new file mode 100644 index 0000000..03c040b --- /dev/null +++ b/crates/snark-wrapper/rust-toolchain @@ -0,0 +1 @@ +nightly-2024-08-01 diff --git a/crates/snark-wrapper/src/implementations/mod.rs b/crates/snark-wrapper/src/implementations/mod.rs new file mode 100644 index 0000000..5eaba5c --- /dev/null +++ b/crates/snark-wrapper/src/implementations/mod.rs @@ -0,0 +1,2 @@ +pub mod poseidon2; +pub mod verifier_builder; diff --git a/crates/snark-wrapper/src/implementations/poseidon2/mod.rs b/crates/snark-wrapper/src/implementations/poseidon2/mod.rs new file mode 100644 index 0000000..6fbc424 --- /dev/null +++ b/crates/snark-wrapper/src/implementations/poseidon2/mod.rs @@ -0,0 +1,217 @@ +use crate::franklin_crypto::bellman::pairing::Engine; +use crate::franklin_crypto::bellman::plonk::better_better_cs::cs::ConstraintSystem; +use crate::franklin_crypto::bellman::PrimeFieldRepr; +use crate::franklin_crypto::bellman::{PrimeField, SynthesisError}; +use crate::franklin_crypto::plonk::circuit::allocated_num::Num; +use crate::franklin_crypto::plonk::circuit::boolean::Boolean; +use crate::franklin_crypto::plonk::circuit::goldilocks::GoldilocksField; +use crate::franklin_crypto::plonk::circuit::linear_combination::LinearCombination; + +use rescue_poseidon::circuit::poseidon2::circuit_poseidon2_round_function; +use rescue_poseidon::poseidon2::Poseidon2Params; + +use crate::boojum::field::goldilocks::GoldilocksField as GL; +use crate::boojum::field::PrimeField as BoojumPrimeField; + +use derivative::*; + +pub mod transcript; +pub mod tree_hasher; + +#[derive(Derivative)] +#[derivative(Clone, Debug)] +pub struct CircuitPoseidon2Sponge { + pub(crate) state: [LinearCombination; WIDTH], + pub(crate) buffer: [LinearCombination; RATE], + pub(crate) gl_buffer: [GoldilocksField; CHUNK_BY], + pub(crate) filled: usize, + #[derivative(Debug = "ignore")] + pub(crate) params: Poseidon2Params, +} + +impl CircuitPoseidon2Sponge { + pub fn new() -> Self { + Self::new_from_params(Poseidon2Params::default()) + } + + pub fn new_from_params(params: Poseidon2Params) -> Self { + assert!(CHUNK_BY == (E::Fr::CAPACITY as usize) / (GL::CHAR_BITS as usize)); + assert!(ABSORB_BY_REPLACEMENT, "Only replacement mode is implemented"); + + Self { + state: [(); WIDTH].map(|_| LinearCombination::zero()), + buffer: [(); RATE].map(|_| LinearCombination::zero()), + gl_buffer: [GoldilocksField::zero(); CHUNK_BY], + filled: 0, + params, + } + } + + pub fn run_round_function>(&mut self, cs: &mut CS) -> Result<(), SynthesisError> { + circuit_poseidon2_round_function(cs, &self.params, &mut self.state) + } + + pub fn try_get_commitment>(&self, cs: &mut CS) -> Result; RATE]>, SynthesisError> { + if self.filled != 0 { + return Ok(None); + } + + let mut result = [Num::zero(); RATE]; + for (dst, src) in result.iter_mut().zip(self.state.iter()) { + *dst = src.clone().into_num(cs)?; + } + + Ok(Some(result)) + } + + pub fn absorb_buffer_to_state>(&mut self, cs: &mut CS) -> Result<(), SynthesisError> { + for (dst, src) in self.state.iter_mut().zip(self.buffer.iter_mut()) { + *dst = std::mem::replace(src, LinearCombination::zero()); + } + + self.run_round_function(cs)?; + self.filled = 0; + + Ok(()) + } + + pub fn absorb_single_gl>(&mut self, cs: &mut CS, value: &GoldilocksField) -> Result<(), SynthesisError> { + debug_assert!(self.filled < RATE * CHUNK_BY); + let pos = self.filled / CHUNK_BY; + let exp = self.filled % CHUNK_BY; + + let mut coeff = ::Repr::from(1); + coeff.shl((exp * GL::CHAR_BITS) as u32); + + self.buffer[pos].add_assign_number_with_coeff(&value.into_num(), E::Fr::from_repr(coeff).unwrap()); + self.filled += 1; + + if self.filled == RATE * CHUNK_BY { + self.absorb_buffer_to_state(cs)?; + } + + Ok(()) + } + + pub fn absorb_single>(&mut self, cs: &mut CS, value: Num) -> Result<(), SynthesisError> { + debug_assert!(self.filled < RATE * CHUNK_BY); + let pos = self.filled / CHUNK_BY; + let exp = self.filled % CHUNK_BY; + + match exp { + 0 => { + self.filled += CHUNK_BY; + self.buffer[pos] = value.into(); + } + _ => { + self.filled = (pos + 1) * CHUNK_BY; + + if self.filled == RATE * CHUNK_BY { + self.absorb_buffer_to_state(cs)?; + + self.buffer[0] = value.into(); + self.filled = CHUNK_BY; + } else { + self.filled += CHUNK_BY; + self.buffer[pos + 1] = value.into(); + } + } + } + + if self.filled == RATE * CHUNK_BY { + self.absorb_buffer_to_state(cs)?; + } + + Ok(()) + } + + pub fn absorb> + Clone, CS: ConstraintSystem>(&mut self, cs: &mut CS, values: &[T]) -> Result<(), SynthesisError> { + debug_assert!(self.filled < RATE * CHUNK_BY); + let mut pos = self.filled / CHUNK_BY; + let exp = self.filled % CHUNK_BY; + let len = values.len(); + + if exp != 0 { + pos += 1; + } + + if len + pos < RATE { + for (dst, src) in self.buffer[pos..pos + len].iter_mut().zip(values.iter()) { + *dst = src.clone().into(); + } + + self.filled += len * CHUNK_BY; + + return Ok(()); + } + + let chunks_start = RATE - pos; + let num_chunks = (len - chunks_start) / RATE; + let chunk_finish = chunks_start + num_chunks * RATE; + + for (i, value) in values[..chunks_start].iter().enumerate() { + self.buffer[pos + i] = value.clone().into(); + } + self.absorb_buffer_to_state(cs)?; + + for chunk in values[chunks_start..chunk_finish].chunks_exact(RATE) { + for (j, value) in chunk.iter().enumerate() { + self.state[j] = value.clone().into(); + } + self.run_round_function(cs)?; + } + + let new_pos = len - chunk_finish; + for (dst, src) in self.buffer[..new_pos].iter_mut().zip(values[chunk_finish..].iter()) { + *dst = src.clone().into(); + } + self.filled = new_pos * CHUNK_BY; + + Ok(()) + } + + pub fn finalize>(&mut self, cs: &mut CS) -> Result<[Num; RATE], SynthesisError> { + // padding + self.absorb_single_gl(cs, &GoldilocksField::one())?; + + if self.filled > 0 { + self.absorb_buffer_to_state(cs)?; + } + + let mut result = [Num::zero(); RATE]; + + for (dst, src) in result.iter_mut().zip(self.state.iter()) { + *dst = src.clone().into_num(cs)?; + } + + Ok(result) + } + + pub fn finalize_reset>(&mut self, cs: &mut CS) -> Result<[Num; RATE], SynthesisError> { + // padding + self.absorb_single_gl(cs, &GoldilocksField::one())?; + + // reset + let mut state = std::mem::replace(&mut self.state, [(); WIDTH].map(|_| LinearCombination::zero())); + + let filled = self.filled; + self.filled = 0; + + // run round function if necessary + if filled > 0 { + for (dst, src) in state.iter_mut().zip(self.buffer.iter_mut()) { + *dst = std::mem::replace(src, LinearCombination::zero()); + } + + circuit_poseidon2_round_function(cs, &self.params, &mut state)?; + } + + let mut result = [Num::zero(); RATE]; + + for (dst, src) in result.iter_mut().zip(state.into_iter()) { + *dst = src.into_num(cs)?; + } + + Ok(result) + } +} diff --git a/crates/snark-wrapper/src/implementations/poseidon2/transcript.rs b/crates/snark-wrapper/src/implementations/poseidon2/transcript.rs new file mode 100644 index 0000000..98809b2 --- /dev/null +++ b/crates/snark-wrapper/src/implementations/poseidon2/transcript.rs @@ -0,0 +1,185 @@ +use super::*; + +use crate::traits::transcript::CircuitGLTranscript; + +#[derive(Derivative)] +#[derivative(Clone, Debug)] +pub struct CircuitPoseidon2Transcript { + buffer: Vec>, + last_filled: usize, + available_challenges: Vec>, + #[derivative(Debug = "ignore")] + sponge: CircuitPoseidon2Sponge, +} + +impl CircuitPoseidon2Transcript { + pub fn new() -> Self { + Self { + buffer: Vec::new(), + last_filled: 0, + available_challenges: Vec::new(), + sponge: CircuitPoseidon2Sponge::::new(), + } + } +} + +impl CircuitGLTranscript + for CircuitPoseidon2Transcript +{ + type CircuitCompatibleCap = Num; + type TranscriptParameters = (); + + const IS_ALGEBRAIC: bool = true; + + fn new>(_cs: &mut CS, _params: Self::TranscriptParameters) -> Result { + Ok(Self::new()) + } + + fn witness_field_elements>(&mut self, _cs: &mut CS, field_els: &[GoldilocksField]) -> Result<(), SynthesisError> { + debug_assert!(self.last_filled < CHUNK_BY); + + let add_to_last = field_els.len().min((CHUNK_BY - self.last_filled) % CHUNK_BY); + + if add_to_last != 0 { + for (i, el) in field_els[..add_to_last].iter().enumerate() { + let mut coeff = ::Repr::from(1); + coeff.shl(((i + self.last_filled) * GL::CHAR_BITS) as u32); + + self.buffer.last_mut().unwrap().add_assign_number_with_coeff(&el.into_num(), E::Fr::from_repr(coeff).unwrap()); + } + } + + for chunk in field_els[add_to_last..].chunks(CHUNK_BY) { + let mut new = LinearCombination::zero(); + let mut coeff = ::Repr::from(1); + for el in chunk.iter() { + new.add_assign_number_with_coeff(&el.into_num(), E::Fr::from_repr(coeff).unwrap()); + coeff.shl(GL::CHAR_BITS as u32); + } + self.buffer.push(new); + } + + self.last_filled = (self.last_filled + field_els.len()) % CHUNK_BY; + + Ok(()) + } + + fn witness_merkle_tree_cap>(&mut self, _cs: &mut CS, cap: &Vec) -> Result<(), SynthesisError> { + self.last_filled = 0; + self.buffer.extend(cap.iter().map(|&el| el.into())); + + Ok(()) + } + + fn get_challenge>(&mut self, cs: &mut CS) -> Result, SynthesisError> { + assert_eq!(self.sponge.filled, 0); + + if self.buffer.is_empty() { + if self.available_challenges.len() > 0 { + let first_el = self.available_challenges.first().unwrap().clone(); + self.available_challenges.drain(..1); + return Ok(first_el); + } else { + self.sponge.run_round_function(cs)?; + + { + let commitment = self.sponge.try_get_commitment(cs)?.expect("must have no pending elements in the buffer"); + for &el in commitment.iter() { + self.available_challenges.extend(get_challenges_from_num(cs, el)?); + } + } + + return self.get_challenge(cs); + } + } + + let to_absorb = std::mem::replace(&mut self.buffer, vec![]); + self.sponge.absorb(cs, &to_absorb)?; + self.last_filled = 0; + + self.available_challenges = vec![]; + let commitment = self.sponge.finalize(cs)?; + for &el in commitment.iter() { + self.available_challenges.extend(get_challenges_from_num(cs, el)?); + } + + // to avoid duplication + self.get_challenge(cs) + } +} + +fn get_challenges_from_num>(cs: &mut CS, num: Num) -> Result>, SynthesisError> { + Ok(GoldilocksField::from_num_to_multiple_with_reduction::<_, 3>(cs, num)?.to_vec()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::boojum::cs::implementations::transcript::Transcript; + use crate::boojum::field::{SmallField, U64Representable}; + use rand::{Rand, Rng}; + + use crate::franklin_crypto::bellman::pairing::bn256::{Bn256, Fr}; + use crate::franklin_crypto::bellman::plonk::better_better_cs::cs::*; + use crate::franklin_crypto::plonk::circuit::bigint_new::BITWISE_LOGICAL_OPS_TABLE_NAME; + + use rescue_poseidon::poseidon2::transcript::Poseidon2Transcript; + + use crate::implementations::poseidon2::tree_hasher::AbsorptionModeReplacement; + + #[test] + fn test_poseidon2_transcript() { + let mut assembly = TrivialAssembly::::new(); + let _before = assembly.n(); + + let mut rng = rand::thread_rng(); + let buffer_u64 = [0; 100].map(|_| rng.gen_range(0, GL::CHAR)); + + let buffer_circuit = buffer_u64.map(|x| GoldilocksField::alloc_from_u64(&mut assembly, Some(x)).unwrap()); + + let buffer_gl = buffer_u64.map(|x| GL::from_u64_unchecked(x)); + + // add table for range check + let columns3 = vec![PolyIdentifier::VariablesPolynomial(0), PolyIdentifier::VariablesPolynomial(1), PolyIdentifier::VariablesPolynomial(2)]; + + let name = BITWISE_LOGICAL_OPS_TABLE_NAME; + let bitwise_logic_table = LookupTableApplication::new(name, TwoKeysOneValueBinopTable::::new(8, name), columns3.clone(), None, true); + assembly.add_table(bitwise_logic_table).unwrap(); + + let mut transcript = Poseidon2Transcript::, 2, 3>::new(); + let mut circuit_transcript = CircuitPoseidon2Transcript::::new(); + + transcript.witness_field_elements(&buffer_gl); + circuit_transcript.witness_field_elements(&mut assembly, &buffer_circuit).unwrap(); + + for _ in 0..5 { + let chal = transcript.get_challenge(); + let chal_circuit = circuit_transcript.get_challenge(&mut assembly).unwrap(); + + assert_eq!(chal, chal_circuit.into_num().get_value().unwrap().into_repr().as_ref()[0]); + } + + transcript.witness_field_elements(&buffer_gl); + circuit_transcript.witness_field_elements(&mut assembly, &buffer_circuit).unwrap(); + + for _ in 0..10 { + let chal = transcript.get_challenge(); + let chal_circuit = circuit_transcript.get_challenge(&mut assembly).unwrap(); + + assert_eq!(chal, chal_circuit.into_num().get_value().unwrap().into_repr().as_ref()[0]); + } + + let rand_fr: Vec<_> = (0..10).map(|_| Fr::rand(&mut rng)).collect(); + let num: Vec<_> = rand_fr.iter().map(|x| Num::alloc(&mut assembly, Some(*x)).unwrap()).collect(); + + transcript.witness_merkle_tree_cap(&rand_fr); + circuit_transcript.witness_merkle_tree_cap(&mut assembly, &num).unwrap(); + + for _ in 0..5 { + let chal = transcript.get_challenge(); + let chal_circuit = circuit_transcript.get_challenge(&mut assembly).unwrap(); + + assert_eq!(chal, chal_circuit.into_num().get_value().unwrap().into_repr().as_ref()[0]); + } + } +} diff --git a/crates/snark-wrapper/src/implementations/poseidon2/tree_hasher.rs b/crates/snark-wrapper/src/implementations/poseidon2/tree_hasher.rs new file mode 100644 index 0000000..84185b6 --- /dev/null +++ b/crates/snark-wrapper/src/implementations/poseidon2/tree_hasher.rs @@ -0,0 +1,164 @@ +use super::*; + +use crate::traits::tree_hasher::CircuitGLTreeHasher; +use rescue_poseidon::poseidon2::Poseidon2Sponge; + +impl CircuitGLTreeHasher + for CircuitPoseidon2Sponge +{ + type CircuitOutput = Num; + type NonCircuitSimulator = Poseidon2Sponge, RATE, WIDTH>; + + fn new>(_cs: &mut CS) -> Result { + Ok(Self::new()) + } + + fn placeholder_output>(_cs: &mut CS) -> Result { + Ok(Num::zero()) + } + + fn accumulate_into_leaf>(&mut self, cs: &mut CS, value: &GoldilocksField) -> Result<(), SynthesisError> { + self.absorb_single_gl(cs, value) + } + + fn finalize_into_leaf_hash_and_reset>(&mut self, cs: &mut CS) -> Result { + Ok(self.finalize_reset(cs)?[0]) + } + + fn hash_into_leaf<'a, S: IntoIterator>, CS: ConstraintSystem>(cs: &mut CS, source: S) -> Result + where + GoldilocksField: 'a, + { + let mut hasher = Self::new(); + + for el in source.into_iter() { + hasher.absorb_single_gl(cs, el)?; + } + Ok(hasher.finalize(cs)?[0]) + } + + fn hash_into_leaf_owned>, CS: ConstraintSystem>(cs: &mut CS, source: S) -> Result { + let mut hasher = Self::new(); + + for el in source.into_iter() { + hasher.absorb_single_gl(cs, &el)?; + } + Ok(hasher.finalize(cs)?[0]) + } + + fn swap_nodes>( + cs: &mut CS, + should_swap: Boolean, + left: &Self::CircuitOutput, + right: &Self::CircuitOutput, + _depth: usize, + ) -> Result<(Self::CircuitOutput, Self::CircuitOutput), SynthesisError> { + Num::conditionally_reverse(cs, left, right, &should_swap) + } + + fn hash_into_node>(cs: &mut CS, left: &Self::CircuitOutput, right: &Self::CircuitOutput, _depth: usize) -> Result { + let params = Poseidon2Params::::default(); + let mut state = [(); WIDTH].map(|_| LinearCombination::zero()); + state[0] = (*left).into(); + state[1] = (*right).into(); + + circuit_poseidon2_round_function(cs, ¶ms, &mut state)?; + + state[0].clone().into_num(cs) + } + + fn select_cap_node>(cs: &mut CS, cap_bits: &[Boolean], cap: &[Self::CircuitOutput]) -> Result { + assert_eq!(cap.len(), 1 << cap_bits.len()); + assert!(cap_bits.len() > 0); + + let mut input_space = Vec::with_capacity(cap.len() / 2); + let mut dst_space = Vec::with_capacity(cap.len() / 2); + + for (idx, bit) in cap_bits.iter().enumerate() { + let src = if idx == 0 { cap } else { &input_space }; + + debug_assert_eq!(cap.len() % 2, 0); + dst_space.clear(); + + for src in src.array_chunks::<2>() { + let [a, b] = src; + // NOTE order here + let selected = Num::conditionally_select(cs, bit, b, a)?; + dst_space.push(selected); + } + + std::mem::swap(&mut dst_space, &mut input_space); + } + + assert_eq!(input_space.len(), 1); + + Ok(input_space.pop().unwrap()) + } + + fn compare_output>(cs: &mut CS, a: &Self::CircuitOutput, b: &Self::CircuitOutput) -> Result { + Num::equals(cs, a, b) + } +} + +use crate::boojum::algebraic_props::round_function::AbsorptionModeTrait; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AbsorptionModeReplacement(std::marker::PhantomData); + +impl AbsorptionModeTrait for AbsorptionModeReplacement { + fn absorb(dst: &mut F, src: &F) { + *dst = *src; + } + + fn pad(_dst: &mut F) { + unimplemented!("pad is not used") + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::boojum::cs::oracle::TreeHasher; + use crate::boojum::field::{SmallField, U64Representable}; + use rand::{Rand, Rng}; + + use crate::franklin_crypto::bellman::pairing::bn256::{Bn256, Fr}; + use crate::franklin_crypto::bellman::plonk::better_better_cs::cs::*; + use crate::franklin_crypto::plonk::circuit::bigint_new::BITWISE_LOGICAL_OPS_TABLE_NAME; + + type TH = Poseidon2Sponge, 2, 3>; + type CTH = CircuitPoseidon2Sponge; + + #[test] + fn test_poseidon2_tree_hasher() { + let mut assembly = TrivialAssembly::::new(); + let _before = assembly.n(); + + let mut rng = rand::thread_rng(); + let buffer_u64 = [0; 100].map(|_| rng.gen_range(0, GL::CHAR)); + + let buffer_circuit = buffer_u64.map(|x| GoldilocksField::alloc_from_u64(&mut assembly, Some(x)).unwrap()); + + let buffer_gl = buffer_u64.map(|x| GL::from_u64_unchecked(x)); + + // add table for range check + let columns3 = vec![PolyIdentifier::VariablesPolynomial(0), PolyIdentifier::VariablesPolynomial(1), PolyIdentifier::VariablesPolynomial(2)]; + + let name = BITWISE_LOGICAL_OPS_TABLE_NAME; + let bitwise_logic_table = LookupTableApplication::new(name, TwoKeysOneValueBinopTable::::new(8, name), columns3.clone(), None, true); + assembly.add_table(bitwise_logic_table).unwrap(); + + let leaf_hash = TH::hash_into_leaf(&buffer_gl); + let leaf_hash_circuit = CTH::hash_into_leaf(&mut assembly, &buffer_circuit).unwrap(); + + assert_eq!(leaf_hash, leaf_hash_circuit.get_value().unwrap()); + + let rand_fr = [0; 2].map(|_| Fr::rand(&mut rng)); + let num = rand_fr.clone().map(|x| Num::alloc(&mut assembly, Some(x)).unwrap()); + + let node_hash = TH::hash_into_node(&rand_fr[0], &rand_fr[1], 3); + let node_hash_circuit = CTH::hash_into_node(&mut assembly, &num[0], &num[1], 3).unwrap(); + + assert_eq!(node_hash, node_hash_circuit.get_value().unwrap()); + } +} diff --git a/crates/snark-wrapper/src/implementations/verifier_builder.rs b/crates/snark-wrapper/src/implementations/verifier_builder.rs new file mode 100644 index 0000000..3d7711e --- /dev/null +++ b/crates/snark-wrapper/src/implementations/verifier_builder.rs @@ -0,0 +1,351 @@ +use std::any::TypeId; +use std::collections::HashMap; + +use crate::franklin_crypto::bellman::pairing::Engine; +use crate::franklin_crypto::bellman::plonk::better_better_cs::cs::ConstraintSystem; + +use crate::boojum::cs::cs_builder::new_builder; +use crate::boojum::cs::cs_builder::CsBuilderImpl; +use crate::boojum::cs::traits::circuit::{CircuitBuilder, CircuitBuilderProxy}; +use crate::boojum::cs::traits::evaluator::GateBatchEvaluationComparisonFunction; +use crate::boojum::cs::traits::evaluator::GatePlacementType; +use crate::boojum::cs::traits::evaluator::PerChunkOffset; +use crate::boojum::cs::{CSGeometry, LookupParameters}; +use crate::boojum::field::goldilocks::GoldilocksField as GL; + +use crate::traits::circuit::ErasedBuilderForWrapperVerifier; +use crate::verifier_structs::gate_evaluator::TypeErasedGateEvaluationWrapperVerificationFunction; +use crate::verifier_structs::WrapperVerifier; + +impl + 'static, T: CircuitBuilder> ErasedBuilderForWrapperVerifier for CircuitBuilderProxy { + fn geometry(&self) -> CSGeometry { + T::geometry() + } + + fn lookup_parameters(&self) -> LookupParameters { + T::lookup_parameters() + } + + fn create_wrapper_verifier(&self, cs: &mut CS) -> WrapperVerifier { + let geometry = T::geometry(); + let builder_impl = CsWrapperVerifierBuilder::<'_, E, CS>::new_from_parameters(cs, geometry); + let builder = new_builder::<_, GL>(builder_impl); + + let builder = T::configure_builder(builder); + let verifier = builder.build(()); + + verifier + } +} + +pub struct CsWrapperVerifierBuilder<'a, E: Engine, CS: ConstraintSystem + 'static> { + pub(crate) cs: &'a mut CS, + + pub parameters: CSGeometry, + pub lookup_parameters: LookupParameters, + + pub(crate) gate_type_ids_for_specialized_columns: Vec, + pub(crate) evaluators_over_specialized_columns: Vec>, + pub(crate) offsets_for_specialized_evaluators: Vec<(PerChunkOffset, PerChunkOffset, usize)>, + + pub(crate) evaluators_over_general_purpose_columns: Vec>, + pub(crate) general_purpose_evaluators_comparison_functions: HashMap>, + + pub(crate) total_num_variables_for_specialized_columns: usize, + pub(crate) total_num_witnesses_for_specialized_columns: usize, + pub(crate) total_num_constants_for_specialized_columns: usize, +} + +impl<'a, E: Engine, CS: ConstraintSystem + 'static> CsWrapperVerifierBuilder<'a, E, CS> { + pub fn new_from_parameters(cs: &'a mut CS, parameters: CSGeometry) -> Self { + Self { + cs, + + parameters: parameters, + lookup_parameters: LookupParameters::NoLookup, + + gate_type_ids_for_specialized_columns: Vec::with_capacity(16), + evaluators_over_specialized_columns: Vec::with_capacity(16), + offsets_for_specialized_evaluators: Vec::with_capacity(16), + + evaluators_over_general_purpose_columns: Vec::with_capacity(16), + general_purpose_evaluators_comparison_functions: HashMap::with_capacity(16), + + total_num_variables_for_specialized_columns: 0, + total_num_witnesses_for_specialized_columns: 0, + total_num_constants_for_specialized_columns: 0, + } + } +} + +use crate::boojum::cs::cs_builder::CsBuilder; +use crate::boojum::cs::gates::lookup_marker::LookupFormalGate; +use crate::boojum::cs::gates::LookupTooling; +use crate::boojum::cs::traits::gate::GatePlacementStrategy; +use crate::boojum::cs::traits::{evaluator::GateConstraintEvaluator, gate::Gate}; +use crate::boojum::cs::GateConfigurationHolder; +use crate::boojum::cs::GateTypeEntry; +use crate::boojum::cs::StaticToolboxHolder; +use crate::boojum::cs::Tool; + +impl<'a, E: Engine, CS: ConstraintSystem + 'static> CsBuilderImpl> for CsWrapperVerifierBuilder<'a, E, CS> { + type Final, TB: StaticToolboxHolder> = WrapperVerifier; + + type BuildParams<'b> = (); + + fn parameters, TB: StaticToolboxHolder>(builder: &CsBuilder) -> CSGeometry { + builder.implementation.parameters + } + + fn allow_gate, TB: StaticToolboxHolder, G: Gate, TAux: 'static + Send + Sync + Clone>( + mut builder: CsBuilder, + placement_strategy: GatePlacementStrategy, + params: <>::Evaluator as GateConstraintEvaluator>::UniqueParameterizationParams, + aux_data: TAux, + ) -> CsBuilder, GC), TB> { + // log!("Adding gate {:?}", std::any::type_name::()); + + let this = &mut builder.implementation; + + let new_configuration = builder.gates_config.add_gate::(placement_strategy, params.clone(), aux_data); + let evaluator_type_id = TypeId::of::(); + let gate_type_id = TypeId::of::(); + let evaluator = G::Evaluator::new_from_parameters(params.clone()); + + // // depending on the configuration we should place it into corresponding set, + // // and create some extra staff + + match placement_strategy { + GatePlacementStrategy::UseGeneralPurposeColumns => { + // we should batch gates that have the same evaluator + if let Some(comparison_fns) = this.general_purpose_evaluators_comparison_functions.get_mut(&evaluator_type_id) { + let (dynamic_evaluator, comparator) = TypeErasedGateEvaluationWrapperVerificationFunction::from_evaluator(this.cs, evaluator, &this.parameters, placement_strategy); + + let mut existing_idx = None; + for (other_comparator, idx) in comparison_fns.iter() { + if other_comparator.equals_to(&comparator) { + existing_idx = Some(*idx); + break; + } + } + + if let Some(_existing_idx) = existing_idx { + // nothing, same evaluator + } else { + if comparison_fns.len() > 0 { + panic!("not yet supported"); + } + let idx = this.evaluators_over_general_purpose_columns.len(); + this.evaluators_over_general_purpose_columns.push(dynamic_evaluator); + // evaluator_type_id_into_evaluator_index_over_general_purpose_columns.insert(evaluator_type_id, idx); + comparison_fns.push((comparator, idx)); + } + + // gate_type_ids_for_general_purpose_columns.push(gate_type_id); + } else { + // new one + let idx = this.evaluators_over_general_purpose_columns.len(); + let (dynamic_evaluator, comparator) = TypeErasedGateEvaluationWrapperVerificationFunction::from_evaluator(this.cs, evaluator, &this.parameters, placement_strategy); + this.evaluators_over_general_purpose_columns.push(dynamic_evaluator); + // gate_type_ids_for_general_purpose_columns.push(gate_type_id); + // evaluator_type_id_into_evaluator_index_over_general_purpose_columns.insert(evaluator_type_id, idx); + this.general_purpose_evaluators_comparison_functions.insert(evaluator_type_id, vec![(comparator, idx)]); + } + } + GatePlacementStrategy::UseSpecializedColumns { num_repetitions, share_constants } => { + // we always add an evaluator + + let (dynamic_evaluator, _comparator) = TypeErasedGateEvaluationWrapperVerificationFunction::from_evaluator(this.cs, evaluator.clone(), &this.parameters, placement_strategy); + + // we need to extend copy-permutation data and witness placement data, + // as well as keep track on offsets into them + + let _idx = this.evaluators_over_specialized_columns.len(); + this.gate_type_ids_for_specialized_columns.push(gate_type_id); + this.evaluators_over_specialized_columns.push(dynamic_evaluator); + // gate_type_id_into_evaluator_index_over_specialized_columns.insert(gate_type_id, idx); + + let principal_width = evaluator.instance_width(); + let mut total_width = principal_width; + + for _ in 1..num_repetitions { + total_width.num_variables += principal_width.num_variables; + total_width.num_witnesses += principal_width.num_witnesses; + if share_constants == false { + total_width.num_constants += principal_width.num_constants; + } + } + + let total_constants_available = principal_width.num_constants; + + if share_constants { + match evaluator.placement_type() { + GatePlacementType::MultipleOnRow { per_chunk_offset: _ } => {} + GatePlacementType::UniqueOnRow => { + panic!("Can not share constants if placement type is unique"); + } + } + } + + let initial_offset = PerChunkOffset { + variables_offset: this.parameters.num_columns_under_copy_permutation + this.total_num_variables_for_specialized_columns, + witnesses_offset: this.parameters.num_witness_columns + this.total_num_witnesses_for_specialized_columns, + constants_offset: this.total_num_constants_for_specialized_columns, // we use separate vector for them + }; + + let offset_per_repetition = if share_constants == false { + PerChunkOffset { + variables_offset: principal_width.num_variables, + witnesses_offset: principal_width.num_witnesses, + constants_offset: principal_width.num_constants, + } + } else { + let offset_per_repetition = match evaluator.placement_type() { + GatePlacementType::MultipleOnRow { per_chunk_offset } => per_chunk_offset, + GatePlacementType::UniqueOnRow => { + panic!("Can not share constants if placement type is unique"); + } + }; + + assert_eq!(offset_per_repetition.variables_offset, principal_width.num_variables); + assert_eq!(offset_per_repetition.witnesses_offset, principal_width.num_witnesses); + + // offset_per_repetition.variables_offset = principal_width.num_variables; + // offset_per_repetition.witnesses_offset = principal_width.num_witnesses; + // and we only leave constants untouched + + offset_per_repetition + }; + + this.offsets_for_specialized_evaluators.push((initial_offset, offset_per_repetition, total_constants_available)); + + this.total_num_variables_for_specialized_columns += total_width.num_variables; + this.total_num_witnesses_for_specialized_columns += total_width.num_witnesses; + this.total_num_constants_for_specialized_columns += total_width.num_constants; + } + } + + CsBuilder { + gates_config: new_configuration, + ..builder + } + } + + fn add_tool, TB: StaticToolboxHolder, M: 'static + Send + Sync + Clone, T: 'static + Send + Sync>( + builder: CsBuilder, + tool: T, + // ) -> CsBuilder> { + ) -> CsBuilder, TB)> { + // TODO: toolbox is not used in the verifier, so perhaps it should be + // moved out to the builder impl so it would not get in the way and just + // hold the type in the phantom. + let new_toolbox = builder.toolbox.add_tool(tool); + + CsBuilder { toolbox: new_toolbox, ..builder } + } + + // type GcWithLookup> = GC; + type GcWithLookup> = (GateTypeEntry, GC); + // GC::DescendantHolder; + + fn allow_lookup, TB: StaticToolboxHolder>( + builder: CsBuilder, + lookup_parameters: LookupParameters, + ) -> CsBuilder, TB> { + let mut builder = builder; + builder.implementation.lookup_parameters = lookup_parameters; + + let (placement_strategy, num_variables, num_constants, share_table_id) = match lookup_parameters { + LookupParameters::NoLookup => { + // this is formal + + ( + GatePlacementStrategy::UseSpecializedColumns { + num_repetitions: 0, + share_constants: false, + }, + 0, + 0, + false, + ) + } + LookupParameters::TableIdAsVariable { width, share_table_id } => { + assert!(share_table_id == false, "other option is not yet implemented"); + // we need to resize multiplicities + assert!(builder.implementation.parameters.num_columns_under_copy_permutation >= (width + 1) as usize); + + (GatePlacementStrategy::UseGeneralPurposeColumns, (width + 1) as usize, 0, share_table_id) + } + LookupParameters::TableIdAsConstant { width, share_table_id } => { + assert!(share_table_id == true, "other option is not yet implemented"); + assert!(builder.implementation.parameters.num_columns_under_copy_permutation >= width as usize); + + (GatePlacementStrategy::UseGeneralPurposeColumns, width as usize, 1, share_table_id) + } + LookupParameters::UseSpecializedColumnsWithTableIdAsVariable { + width, + num_repetitions, + share_table_id, + } => { + assert!(share_table_id == false, "other option is not yet implemented"); + + ( + GatePlacementStrategy::UseSpecializedColumns { + num_repetitions, + share_constants: false, + }, + (width + 1) as usize, + 0, + share_table_id, + ) + } + LookupParameters::UseSpecializedColumnsWithTableIdAsConstant { + width, + num_repetitions, + share_table_id, + } => { + assert!(share_table_id == true, "other option is not yet implemented"); + + ( + GatePlacementStrategy::UseSpecializedColumns { + num_repetitions, + share_constants: share_table_id, + }, + width as usize, + 1, + share_table_id, + ) + } + }; + + Self::allow_gate(builder, placement_strategy, (num_variables, num_constants, share_table_id), (Vec::with_capacity(32), 0)) + } + + fn build<'b, GC: GateConfigurationHolder, TB: StaticToolboxHolder, ARG: Into>>(builder: CsBuilder, _params: ARG) -> Self::Final { + let this: CsWrapperVerifierBuilder = builder.implementation; + + // capture small pieces of information from the gate configuration + assert_eq!(this.evaluators_over_specialized_columns.len(), this.gate_type_ids_for_specialized_columns.len()); + + let capacity = this.evaluators_over_specialized_columns.len(); + let mut placement_strategies = HashMap::with_capacity(capacity); + + for gate_type_id in this.gate_type_ids_for_specialized_columns.iter() { + let placement_strategy = builder.gates_config.placement_strategy_for_type_id(*gate_type_id).expect("gate must be allowed"); + placement_strategies.insert(*gate_type_id, placement_strategy); + } + + WrapperVerifier { + parameters: this.parameters, + lookup_parameters: this.lookup_parameters, + gate_type_ids_for_specialized_columns: this.gate_type_ids_for_specialized_columns, + evaluators_over_specialized_columns: this.evaluators_over_specialized_columns, + offsets_for_specialized_evaluators: this.offsets_for_specialized_evaluators, + evaluators_over_general_purpose_columns: this.evaluators_over_general_purpose_columns, + total_num_variables_for_specialized_columns: this.total_num_variables_for_specialized_columns, + total_num_witnesses_for_specialized_columns: this.total_num_witnesses_for_specialized_columns, + total_num_constants_for_specialized_columns: this.total_num_constants_for_specialized_columns, + placement_strategies, + } + } +} diff --git a/crates/snark-wrapper/src/lib.rs b/crates/snark-wrapper/src/lib.rs new file mode 100644 index 0000000..ddb4952 --- /dev/null +++ b/crates/snark-wrapper/src/lib.rs @@ -0,0 +1,13 @@ +#![feature(array_chunks)] +#![feature(allocator_api)] +#![feature(type_changing_struct_update)] + +pub mod traits; +pub mod verifier; +pub mod verifier_structs; + +pub mod implementations; + +pub extern crate rescue_poseidon; +pub use franklin_crypto::boojum; +pub use rescue_poseidon::franklin_crypto; diff --git a/crates/snark-wrapper/src/traits/circuit.rs b/crates/snark-wrapper/src/traits/circuit.rs new file mode 100644 index 0000000..f6b00da --- /dev/null +++ b/crates/snark-wrapper/src/traits/circuit.rs @@ -0,0 +1,17 @@ +use super::*; + +use crate::boojum::cs::implementations::prover::ProofConfig; +use crate::boojum::cs::{CSGeometry, LookupParameters}; +use crate::verifier_structs::WrapperVerifier; + +pub trait ErasedBuilderForWrapperVerifier> { + fn geometry(&self) -> CSGeometry; + fn lookup_parameters(&self) -> LookupParameters; + fn create_wrapper_verifier(&self, cs: &mut CS) -> WrapperVerifier; +} + +pub trait ProofWrapperFunction { + fn builder_for_wrapper + 'static>(&self) -> Box>; + + fn proof_config_for_compression_step(&self) -> ProofConfig; +} diff --git a/crates/snark-wrapper/src/traits/mod.rs b/crates/snark-wrapper/src/traits/mod.rs new file mode 100644 index 0000000..3007cdc --- /dev/null +++ b/crates/snark-wrapper/src/traits/mod.rs @@ -0,0 +1,11 @@ +use crate::boojum::field::goldilocks::GoldilocksField as GL; + +use crate::franklin_crypto::bellman::pairing::Engine; +use crate::franklin_crypto::bellman::plonk::better_better_cs::cs::ConstraintSystem; +use crate::franklin_crypto::bellman::SynthesisError; +use crate::franklin_crypto::plonk::circuit::boolean::Boolean; +use crate::franklin_crypto::plonk::circuit::goldilocks::*; + +pub mod circuit; +pub mod transcript; +pub mod tree_hasher; diff --git a/crates/snark-wrapper/src/traits/transcript.rs b/crates/snark-wrapper/src/traits/transcript.rs new file mode 100644 index 0000000..9ba16ac --- /dev/null +++ b/crates/snark-wrapper/src/traits/transcript.rs @@ -0,0 +1,64 @@ +use super::*; + +pub trait CircuitGLTranscript: Clone + Send + Sync + std::fmt::Debug { + type CircuitCompatibleCap: Clone; + type TranscriptParameters: Clone + Send + Sync; + + const IS_ALGEBRAIC: bool = true; + + fn new>(cs: &mut CS, params: Self::TranscriptParameters) -> Result; + + fn witness_field_elements>(&mut self, cs: &mut CS, field_els: &[GoldilocksField]) -> Result<(), SynthesisError>; + + fn witness_merkle_tree_cap>(&mut self, cs: &mut CS, cap: &Vec) -> Result<(), SynthesisError>; + + fn get_challenge>(&mut self, cs: &mut CS) -> Result, SynthesisError>; + + fn get_multiple_challenges_fixed, const N: usize>(&mut self, cs: &mut CS) -> Result<[GoldilocksField; N], SynthesisError> { + let mut result = [GoldilocksField::zero(); N]; + for res in result.iter_mut() { + *res = self.get_challenge(cs)?; + } + + Ok(result) + } + + fn get_multiple_challenges>(&mut self, cs: &mut CS, num_challenges: usize) -> Result>, SynthesisError> { + let mut result = Vec::with_capacity(num_challenges); + for _ in 0..num_challenges { + let chal = self.get_challenge(cs)?; + result.push(chal); + } + + Ok(result) + } +} + +pub(crate) struct BoolsBuffer { + pub(crate) available: Vec, + pub(crate) max_needed: usize, +} + +impl BoolsBuffer { + pub fn get_bits, T: CircuitGLTranscript>(&mut self, cs: &mut CS, transcript: &mut T, num_bits: usize) -> Result, SynthesisError> { + if self.available.len() >= num_bits { + let give: Vec<_> = self.available.drain(..num_bits).collect(); + + Ok(give) + } else { + let bits_avaiable = GoldilocksField::::ORDER_BITS - self.max_needed; + + // get 1 field element form transcript + let field_el = transcript.get_challenge(cs)?; + let el_bits = field_el.spread_into_bits::(cs)?; + let mut lsb_iterator = el_bits.iter(); + + for _ in 0..bits_avaiable { + let bit = lsb_iterator.next().unwrap(); + self.available.push(*bit); + } + + self.get_bits(cs, transcript, num_bits) + } + } +} diff --git a/crates/snark-wrapper/src/traits/tree_hasher.rs b/crates/snark-wrapper/src/traits/tree_hasher.rs new file mode 100644 index 0000000..4bf9ead --- /dev/null +++ b/crates/snark-wrapper/src/traits/tree_hasher.rs @@ -0,0 +1,45 @@ +use super::*; + +use crate::boojum::cs::oracle::TreeHasher; + +pub trait CircuitGLTreeHasher: 'static + Clone + Send + Sync { + type NonCircuitSimulator: TreeHasher; + + type CircuitOutput: Sized + + 'static + + Clone + + Copy + + Sync + + Send + // + PartialEq + // + Eq + + std::fmt::Debug; + + fn new>(cs: &mut CS) -> Result; + + fn placeholder_output>(cs: &mut CS) -> Result; + + fn accumulate_into_leaf>(&mut self, cs: &mut CS, value: &GoldilocksField) -> Result<(), SynthesisError>; + + fn finalize_into_leaf_hash_and_reset>(&mut self, cs: &mut CS) -> Result; + + fn hash_into_leaf<'a, S: IntoIterator>, CS: ConstraintSystem>(cs: &mut CS, source: S) -> Result + where + GoldilocksField: 'a; + + fn hash_into_leaf_owned>, CS: ConstraintSystem>(cs: &mut CS, source: S) -> Result; + + fn swap_nodes>( + cs: &mut CS, + should_swap: Boolean, + left: &Self::CircuitOutput, + right: &Self::CircuitOutput, + depth: usize, + ) -> Result<(Self::CircuitOutput, Self::CircuitOutput), SynthesisError>; + + fn hash_into_node>(cs: &mut CS, left: &Self::CircuitOutput, right: &Self::CircuitOutput, depth: usize) -> Result; + + fn select_cap_node>(cs: &mut CS, cap_bits: &[Boolean], cap: &[Self::CircuitOutput]) -> Result; + + fn compare_output>(cs: &mut CS, a: &Self::CircuitOutput, b: &Self::CircuitOutput) -> Result; +} diff --git a/crates/snark-wrapper/src/verifier/first_step.rs b/crates/snark-wrapper/src/verifier/first_step.rs new file mode 100644 index 0000000..bf15ac7 --- /dev/null +++ b/crates/snark-wrapper/src/verifier/first_step.rs @@ -0,0 +1,91 @@ +use super::*; + +/// First of all we should absorb commitments to transcript +/// and get challenges in the following order: +/// - absorb setup commitment +/// - absorb public inputs +/// - absorb witness commitment +/// - get beta and gamma challenges +/// - get lookup_beta and lookup_gamma challenges +/// - absorb stage_2 commitment +/// - get alpha challenge +/// - absorb quotient commitment +/// - get z challenge +/// - absorb evaluations at z +pub(crate) fn verify_first_step + 'static, H: CircuitGLTreeHasher, TR: CircuitGLTranscript>( + cs: &mut CS, + proof: &AllocatedProof, + vk: &AllocatedVerificationKey, + challenges: &mut ChallengesHolder, + transcript: &mut TR, + // parameters + verifier: &WrapperVerifier, + fixed_parameters: &VerificationKeyCircuitGeometry, + constants: &ConstantsHolder, +) -> Result)>)>, SynthesisError> { + // allocate everything + let setup_tree_cap = &vk.setup_merkle_tree_cap; + assert_eq!(fixed_parameters.cap_size, setup_tree_cap.len()); + transcript.witness_merkle_tree_cap(cs, &setup_tree_cap)?; + + if proof.public_inputs.len() != fixed_parameters.public_inputs_locations.len() { + panic!("Invalid number of public inputs"); + } + + let num_public_inputs = proof.public_inputs.len(); + let mut public_inputs_with_values = Vec::with_capacity(num_public_inputs); + let mut public_input_allocated = Vec::with_capacity(num_public_inputs); + + // commit public inputs + for ((column, row), value) in fixed_parameters.public_inputs_locations.iter().copied().zip(proof.public_inputs.iter().copied()) { + transcript.witness_field_elements(cs, &[value])?; + public_input_allocated.push(value); + let value = value.into(); + public_inputs_with_values.push((column, row, value)); + } + + // commit witness + assert_eq!(fixed_parameters.cap_size, proof.witness_oracle_cap.len()); + transcript.witness_merkle_tree_cap(cs, &proof.witness_oracle_cap)?; + + // draw challenges for stage 2 + challenges.get_beta_gamma_challenges(cs, transcript, verifier)?; + + assert_eq!(fixed_parameters.cap_size, proof.stage_2_oracle_cap.len()); + transcript.witness_merkle_tree_cap(cs, &proof.stage_2_oracle_cap)?; + + challenges.get_alpha_powers(cs, transcript, constants)?; + + // commit quotient + assert_eq!(fixed_parameters.cap_size, proof.quotient_oracle_cap.len()); + transcript.witness_merkle_tree_cap(cs, &proof.quotient_oracle_cap)?; + + challenges.get_z_challenge(cs, transcript, fixed_parameters)?; + + // commit claimed values at z, and form our poly storage + for set in proof.values_at_z.iter().chain(proof.values_at_z_omega.iter()).chain(proof.values_at_0.iter()) { + transcript.witness_field_elements(cs, set)?; + } + + // and public inputs should also go into quotient + let mut public_input_opening_tuples: Vec<(GL, Vec<(usize, GoldilocksAsFieldWrapper)>)> = vec![]; + { + let omega = domain_generator_for_size::(fixed_parameters.domain_size as u64); + + for (column, row, value) in public_inputs_with_values.into_iter() { + let open_at = BoojumField::pow_u64(&omega, row as u64); + let pos = public_input_opening_tuples.iter().position(|el| el.0 == open_at); + if let Some(pos) = pos { + public_input_opening_tuples[pos].1.push((column, value)); + } else { + public_input_opening_tuples.push((open_at, vec![(column, value)])); + } + } + } + + assert_eq!(proof.values_at_z.len(), constants.num_poly_values_at_z); + assert_eq!(proof.values_at_z_omega.len(), constants.num_poly_values_at_z_omega); + assert_eq!(proof.values_at_0.len(), constants.num_poly_values_at_zero); + + Ok(public_input_opening_tuples) +} diff --git a/crates/snark-wrapper/src/verifier/fri.rs b/crates/snark-wrapper/src/verifier/fri.rs new file mode 100644 index 0000000..c0eab25 --- /dev/null +++ b/crates/snark-wrapper/src/verifier/fri.rs @@ -0,0 +1,585 @@ +use super::*; + +use crate::traits::transcript::BoolsBuffer; +use crate::verifier_structs::allocated_queries::AllocatedSingleRoundQueries; + +pub(crate) fn verify_fri_part + 'static, H: CircuitGLTreeHasher, TR: CircuitGLTranscript>( + cs: &mut CS, + proof: &AllocatedProof, + vk: &AllocatedVerificationKey, + challenges: &mut ChallengesHolder, + transcript: &mut TR, + public_input_opening_tuples: Vec<(GL, Vec<(usize, GoldilocksAsFieldWrapper)>)>, + // parameters + verifier: &WrapperVerifier, + fixed_parameters: &VerificationKeyCircuitGeometry, + constants: &ConstantsHolder, +) -> Result, SynthesisError> { + let mut validity_flags = vec![]; + + // get challenges + challenges.get_challenges_for_fri_quotiening(cs, transcript, constants.total_num_challenges_for_fri_quotiening)?; + challenges.get_fri_intermediate_challenges(cs, transcript, proof, fixed_parameters, constants)?; + + assert_eq!(constants.final_expected_degree as usize, proof.final_fri_monomials[0].len()); + assert_eq!(constants.final_expected_degree as usize, proof.final_fri_monomials[1].len()); + assert!(proof.final_fri_monomials[0].len() > 0); + + // witness monomial coeffs + transcript.witness_field_elements(cs, &proof.final_fri_monomials[0])?; + transcript.witness_field_elements(cs, &proof.final_fri_monomials[1])?; + + assert_eq!(constants.new_pow_bits, 0, "PoW not supported yet"); + // if new_pow_bits != 0 { + // log!("Doing PoW verification for {} bits", new_pow_bits); + // // log!("Prover gave challenge 0x{:016x}", proof.pow_challenge); + + // // pull enough challenges from the transcript + // let mut num_challenges = 256 / F::CHAR_BITS; + // if num_challenges % F::CHAR_BITS != 0 { + // num_challenges += 1; + // } + // let _challenges: Vec<_> = transcript.get_multiple_challenges(cs, num_challenges); + + // todo!() + // } + + let max_needed_bits = (fixed_parameters.domain_size * fixed_parameters.fri_lde_factor as u64).trailing_zeros() as usize; + + let mut bools_buffer = BoolsBuffer { + available: vec![], + max_needed: max_needed_bits, + }; + + let num_bits_for_in_coset_index = max_needed_bits - fixed_parameters.fri_lde_factor.trailing_zeros() as usize; + let base_tree_index_shift = fixed_parameters.domain_size.trailing_zeros(); + assert_eq!(num_bits_for_in_coset_index, base_tree_index_shift as usize); + + assert_eq!(constants.num_fri_repetitions, proof.queries_per_fri_repetition.len()); + + let multiplicative_generator = GoldilocksAsFieldWrapper::constant(GL::multiplicative_generator(), cs); + + // precompute once, will be handy later + let mut precomputed_powers = vec![]; + let mut precomputed_powers_inversed = vec![]; + for i in 0..=(fixed_parameters.domain_size * fixed_parameters.fri_lde_factor as u64).trailing_zeros() { + let omega = domain_generator_for_size::(1u64 << i); + precomputed_powers.push(omega); + precomputed_powers_inversed.push(BoojumPrimeField::inverse(&omega).unwrap()); + } + + // we also want to precompute "steps" for different interpolation degrees + // e.g. if we interpolate 8 elements, + // then those will be ordered as bitreverses of [0..=7], namely + // [0, 4, 2, 6, 1, 5, 3, 7] + + // so we want to have exactly half of it, because separation by 4 + // is exactly -1, so we need [1, sqrt4(1), sqrt8(1), sqrt4(1)*sqrt8(1)] + + let mut interpolation_steps = vec![GL::ONE; 4]; // max size + + for idx in [1, 3].into_iter() { + BoojumField::mul_assign(&mut interpolation_steps[idx], &precomputed_powers_inversed[2]); + } + for idx in [2, 3].into_iter() { + BoojumField::mul_assign(&mut interpolation_steps[idx], &precomputed_powers_inversed[3]); + } + + assert_eq!(interpolation_steps[0], GL::ONE); + assert_eq!(BoojumField::pow_u64(&interpolation_steps[1], 4), GL::ONE); + assert_eq!(BoojumField::pow_u64(&interpolation_steps[2], 8), GL::ONE); + + let precomputed_powers: Vec<_> = precomputed_powers.into_iter().map(|el| GoldilocksAsFieldWrapper::constant(el, cs)).collect(); + let precomputed_powers_inversed: Vec<_> = precomputed_powers_inversed.into_iter().map(|el| GoldilocksAsFieldWrapper::constant(el, cs)).collect(); + let interpolation_steps: Vec<_> = interpolation_steps.into_iter().map(|el| GoldilocksAsFieldWrapper::constant(el, cs)).collect(); + + let base_oracle_depth = fixed_parameters.base_oracles_depth(); + + for queries in proof.queries_per_fri_repetition.iter() { + let query_index_lsb_first_bits = bools_buffer.get_bits(cs, transcript, max_needed_bits)?; + + // we consider it to be some convenient for us encoding of coset + inner index. + + // Small note on indexing: when we commit to elements we use bitreversal enumeration everywhere. + // So index `i` in the tree corresponds to the element of `omega^bitreverse(i)`. + // This gives us natural separation of LDE cosets, such that subtrees form independent cosets, + // and if cosets are in the form of `{1, gamma, ...} x {1, omega, ...} where gamma^lde_factor == omega, + // then subtrees are enumerated by bitreverse powers of gamma + + // let inner_idx = &query_index_lsb_first_bits[0..num_bits_for_in_coset_index]; + // let coset_idx = &query_index_lsb_first_bits[num_bits_for_in_coset_index..]; + let base_tree_idx = query_index_lsb_first_bits.clone(); + + // first verify basic inclusion proofs + validity_flags.extend(verify_inclusion_proofs(cs, queries, proof, vk, &base_tree_idx, constants, base_oracle_depth)?); + + // now perform the quotiening operation + let zero_ext = GoldilocksExtAsFieldWrapper::::zero(cs); + let mut simulated_ext_element = zero_ext; + + assert_eq!(query_index_lsb_first_bits.len(), precomputed_powers.len() - 1); + + let domain_element = pow_from_precomputations(cs, &precomputed_powers[1..], &query_index_lsb_first_bits); + + // we will find it handy to have power of the generator with some bits masked to be zero + let mut power_chunks = vec![]; + let mut skip_highest_powers = 0; + // TODO: we may save here (in circuits case especially) if we compute recursively + for interpolation_degree_log2 in constants.fri_folding_schedule.iter() { + let domain_element = pow_from_precomputations( + cs, + &precomputed_powers_inversed[(1 + interpolation_degree_log2)..], + &query_index_lsb_first_bits[(skip_highest_powers + interpolation_degree_log2)..], + ); + + skip_highest_powers += *interpolation_degree_log2; + power_chunks.push(domain_element); + } + + // don't forget that we are shifted + let mut domain_element_for_quotiening = domain_element; + domain_element_for_quotiening.mul_assign(&multiplicative_generator, cs); + + let mut domain_element_for_interpolation = domain_element_for_quotiening; + + verify_quotening_operations( + cs, + &mut simulated_ext_element, + queries, + proof, + &public_input_opening_tuples, + domain_element_for_quotiening, + challenges, + verifier, + constants, + )?; + + let base_coset_inverse = BoojumPrimeField::inverse(&GL::multiplicative_generator()).unwrap(); + + let mut current_folded_value: GoldilocksExtAsFieldWrapper = simulated_ext_element; + let mut subidx = base_tree_idx; + let mut coset_inverse = base_coset_inverse; + + let mut expected_fri_query_len = base_oracle_depth; + + for (idx, (interpolation_degree_log2, fri_query)) in constants.fri_folding_schedule.iter().zip(queries.fri_queries.iter()).enumerate() { + expected_fri_query_len -= *interpolation_degree_log2; + let interpolation_degree = 1 << *interpolation_degree_log2; + let subidx_in_leaf = &subidx[..*interpolation_degree_log2]; + let tree_idx = &subidx[*interpolation_degree_log2..]; + + assert_eq!(fri_query.leaf_elements.len(), interpolation_degree * 2); + + let [c0, c1] = current_folded_value.into_coeffs_in_base(); + + let c0_from_leaf = binary_select(cs, &fri_query.leaf_elements[..interpolation_degree], subidx_in_leaf)?; + let c1_from_leaf = binary_select(cs, &fri_query.leaf_elements[interpolation_degree..], subidx_in_leaf)?; + + let c0_is_valid = GoldilocksField::equals(cs, &c0, &c0_from_leaf)?; + let c1_is_valid = GoldilocksField::equals(cs, &c1, &c1_from_leaf)?; + + validity_flags.push(c0_is_valid); + validity_flags.push(c1_is_valid); + + // verify query itself + let cap = if idx == 0 { &proof.fri_base_oracle_cap } else { &proof.fri_intermediate_oracles_caps[idx - 1] }; + assert_eq!(fri_query.proof.len(), expected_fri_query_len); + validity_flags.push(check_if_included::(cs, &fri_query.leaf_elements, &fri_query.proof, &cap, tree_idx)?); + + // interpolate + let mut elements_to_interpolate = Vec::with_capacity(interpolation_degree); + for (c0, c1) in fri_query.leaf_elements[..interpolation_degree].iter().zip(fri_query.leaf_elements[interpolation_degree..].iter()) { + let as_ext = GoldilocksExtAsFieldWrapper::::from_coeffs_in_base([*c0, *c1]); + elements_to_interpolate.push(as_ext); + } + + let mut next = Vec::with_capacity(interpolation_degree / 2); + let challenges = &challenges.fri_intermediate_challenges[idx]; + assert_eq!(challenges.len(), *interpolation_degree_log2); + + let mut base_pow = power_chunks[idx]; + + for challenge in challenges.iter() { + for (i, [a, b]) in elements_to_interpolate.array_chunks::<2>().enumerate() { + let mut result = *a; + result.add_assign(b, cs); + + let mut diff = *a; + diff.sub_assign(&b, cs); + diff.mul_assign(&challenge, cs); + // divide by corresponding power + let mut pow = base_pow; + pow.mul_assign(&interpolation_steps[i], cs); + let coset_inverse = GoldilocksAsFieldWrapper::constant(coset_inverse, cs); + pow.mul_assign(&coset_inverse, cs); + + GoldilocksExtAsFieldWrapper::::mul_by_base_and_accumulate_into(&mut result, &pow, &diff, cs)?; + + // diff.mul_assign_by_base(&pow, cs); + // result.add_assign(&diff, cs); + + next.push(result); + } + + std::mem::swap(&mut next, &mut elements_to_interpolate); + next.clear(); + base_pow.square(cs); + BoojumField::square(&mut coset_inverse); + } + + for _ in 0..*interpolation_degree_log2 { + domain_element_for_interpolation.square(cs); + } + + // recompute the index + subidx = tree_idx.to_vec(); + current_folded_value = elements_to_interpolate[0]; + } + + // and we should evaluate monomial form and compare + + let mut result_from_monomial = zero_ext; + // horner rule + for (c0, c1) in proof.final_fri_monomials[0].iter().zip(proof.final_fri_monomials[1].iter()).rev() { + let coeff = GoldilocksExtAsFieldWrapper::::from_coeffs_in_base([*c0, *c1]); + + // result_from_monomial = result_from_monomial * z + coeff + + let mut tmp = coeff; + GoldilocksExtAsFieldWrapper::::mul_by_base_and_accumulate_into(&mut tmp, &domain_element_for_interpolation, &result_from_monomial, cs)?; + + result_from_monomial = tmp; + + // result_from_monomial.mul_assign_by_base(&domain_element_for_interpolation, cs); + // result_from_monomial.add_assign(&coeff, cs); + } + + let result_from_monomial = result_from_monomial.into_coeffs_in_base(); + let current_folded_value = current_folded_value.into_coeffs_in_base(); + + let c0_is_valid = GoldilocksField::equals(cs, &result_from_monomial[0], ¤t_folded_value[0])?; + let c1_is_valid = GoldilocksField::equals(cs, &result_from_monomial[1], ¤t_folded_value[1])?; + + validity_flags.push(c0_is_valid); + validity_flags.push(c1_is_valid); + } + + Ok(validity_flags) +} + +fn verify_inclusion_proofs + 'static, H: CircuitGLTreeHasher>( + cs: &mut CS, + queries: &AllocatedSingleRoundQueries, + proof: &AllocatedProof, + vk: &AllocatedVerificationKey, + base_tree_idx: &Vec, + constants: &ConstantsHolder, + base_oracle_depth: usize, +) -> Result, SynthesisError> { + let mut validity_flags = Vec::new(); + + assert_eq!(constants.witness_leaf_size, queries.witness_query.leaf_elements.len()); + assert_eq!(base_oracle_depth, queries.witness_query.proof.len()); + validity_flags.push(check_if_included::( + cs, + &queries.witness_query.leaf_elements, + &queries.witness_query.proof, + &proof.witness_oracle_cap, + &base_tree_idx, + )?); + + assert_eq!(constants.stage_2_leaf_size, queries.stage_2_query.leaf_elements.len()); + assert_eq!(base_oracle_depth, queries.stage_2_query.proof.len()); + validity_flags.push(check_if_included::( + cs, + &queries.stage_2_query.leaf_elements, + &queries.stage_2_query.proof, + &proof.stage_2_oracle_cap, + &base_tree_idx, + )?); + + assert_eq!(constants.quotient_leaf_size, queries.quotient_query.leaf_elements.len()); + assert_eq!(base_oracle_depth, queries.quotient_query.proof.len()); + validity_flags.push(check_if_included::( + cs, + &queries.quotient_query.leaf_elements, + &queries.quotient_query.proof, + &proof.quotient_oracle_cap, + &base_tree_idx, + )?); + + assert_eq!(constants.setup_leaf_size, queries.setup_query.leaf_elements.len()); + assert_eq!(base_oracle_depth, queries.setup_query.proof.len()); + validity_flags.push(check_if_included::( + cs, + &queries.setup_query.leaf_elements, + &queries.setup_query.proof, + &vk.setup_merkle_tree_cap, + &base_tree_idx, + )?); + + Ok(validity_flags) +} + +fn verify_quotening_operations + 'static, H: CircuitGLTreeHasher>( + cs: &mut CS, + simulated_ext_element: &mut GoldilocksExtAsFieldWrapper, + queries: &AllocatedSingleRoundQueries, + proof: &AllocatedProof, + public_input_opening_tuples: &Vec<(GL, Vec<(usize, GoldilocksAsFieldWrapper)>)>, + domain_element_for_quotiening: GoldilocksAsFieldWrapper, + challenges: &ChallengesHolder, + verifier: &WrapperVerifier, + constants: &ConstantsHolder, +) -> Result<(), SynthesisError> { + let zero_num = GoldilocksField::zero(); + let zero_base = GoldilocksAsFieldWrapper::::zero(cs); + let zero_ext = GoldilocksExtAsFieldWrapper::::zero(cs); + + let mut challenge_offset = 0; + + let z_polys_offset = 0; + let intermediate_polys_offset = 2; + let lookup_witness_encoding_polys_offset = intermediate_polys_offset + constants.num_intermediate_partial_product_relations * 2; + let lookup_multiplicities_encoding_polys_offset = lookup_witness_encoding_polys_offset + constants.num_lookup_subarguments * 2; + let copy_permutation_polys_offset = 0; + let constants_offset = 0 + constants.num_copy_permutation_polys; + let lookup_tables_values_offset = 0 + constants.num_copy_permutation_polys + constants.num_constant_polys; + let variables_offset = 0; + let witness_columns_offset = constants.num_variable_polys; + let lookup_multiplicities_offset = witness_columns_offset + constants.num_witness_polys; + + let evaluations = EvaluationsHolder::from_proof(proof); + + { + let z = challenges.z; + let z_omega = challenges.z_omega; + + let cast_from_base = move |el: &[GoldilocksField]| { + el.iter() + .map(|el| GoldilocksExtAsFieldWrapper::::from_coeffs_in_base([*el, GoldilocksField::zero()])) + .collect::>() + }; + + let cast_from_extension = move |el: &[GoldilocksField]| { + assert_eq!(el.len() % 2, 0); + + el.array_chunks::<2>() + .map(|[c0, c1]| GoldilocksExtAsFieldWrapper::::from_coeffs_in_base([*c0, *c1])) + .collect::>() + }; + + let mut sources = vec![]; + // witness + sources.extend(cast_from_base( + &queries.witness_query.leaf_elements[variables_offset..(variables_offset + constants.num_variable_polys)], + )); + sources.extend(cast_from_base( + &queries.witness_query.leaf_elements[witness_columns_offset..(witness_columns_offset + constants.num_witness_polys)], + )); + // normal setup + sources.extend(cast_from_base(&queries.setup_query.leaf_elements[constants_offset..(constants_offset + constants.num_constant_polys)])); + sources.extend(cast_from_base( + &queries.setup_query.leaf_elements[copy_permutation_polys_offset..(copy_permutation_polys_offset + constants.num_copy_permutation_polys)], + )); + // copy-permutation + sources.extend(cast_from_extension(&queries.stage_2_query.leaf_elements[z_polys_offset..intermediate_polys_offset])); + sources.extend(cast_from_extension( + &queries.stage_2_query.leaf_elements[intermediate_polys_offset..lookup_witness_encoding_polys_offset], + )); + // lookup if exists + sources.extend(cast_from_base( + &queries.witness_query.leaf_elements[lookup_multiplicities_offset..(lookup_multiplicities_offset + constants.num_multiplicities_polys)], + )); + sources.extend(cast_from_extension( + &queries.stage_2_query.leaf_elements[lookup_witness_encoding_polys_offset..lookup_multiplicities_encoding_polys_offset], + )); + sources.extend(cast_from_extension(&queries.stage_2_query.leaf_elements[lookup_multiplicities_encoding_polys_offset..])); + // lookup setup + if verifier.lookup_parameters.lookup_is_allowed() { + let num_lookup_setups = verifier.lookup_parameters.lookup_width() + 1; + sources.extend(cast_from_base( + &queries.setup_query.leaf_elements[lookup_tables_values_offset..(lookup_tables_values_offset + num_lookup_setups)], + )); + } + // quotient + sources.extend(cast_from_extension(&queries.quotient_query.leaf_elements)); + + assert_eq!(sources.len(), evaluations.all_values_at_z.len()); + // log!("Making quotiening at Z"); + quotening_operation( + cs, + simulated_ext_element, + &sources, + &evaluations.all_values_at_z, + domain_element_for_quotiening, + z, + &challenges.challenges_for_fri_quotiening[challenge_offset..(challenge_offset + sources.len())], + ); + challenge_offset += sources.len(); + + // now z*omega + let mut sources = vec![]; + sources.extend(cast_from_extension(&queries.stage_2_query.leaf_elements[z_polys_offset..intermediate_polys_offset])); + + assert_eq!(sources.len(), evaluations.all_values_at_z_omega.len()); + // log!("Making quotiening at Z*omega"); + quotening_operation( + cs, + simulated_ext_element, + &sources, + &evaluations.all_values_at_z_omega, + domain_element_for_quotiening, + z_omega, + &challenges.challenges_for_fri_quotiening[challenge_offset..(challenge_offset + sources.len())], + ); + + challenge_offset += sources.len(); + // now at 0 if lookup is needed + if verifier.lookup_parameters.lookup_is_allowed() { + let mut sources = vec![]; + // witness encoding + sources.extend(cast_from_extension( + &queries.stage_2_query.leaf_elements[lookup_witness_encoding_polys_offset..lookup_multiplicities_encoding_polys_offset], + )); + // multiplicities encoding + sources.extend(cast_from_extension(&queries.stage_2_query.leaf_elements[lookup_multiplicities_encoding_polys_offset..])); + + assert_eq!(sources.len(), evaluations.all_values_at_0.len()); + // log!("Making quotiening at 0 for lookups sumchecks"); + quotening_operation( + cs, + simulated_ext_element, + &sources, + &evaluations.all_values_at_0, + domain_element_for_quotiening, + zero_ext, + &challenges.challenges_for_fri_quotiening[challenge_offset..(challenge_offset + sources.len())], + ); + + challenge_offset += sources.len(); + } + } + + // and public inputs + for (open_at, set) in public_input_opening_tuples.iter() { + let mut sources = Vec::with_capacity(set.len()); + let mut values = Vec::with_capacity(set.len()); + for (column, expected_value) in set.into_iter() { + let c0 = queries.witness_query.leaf_elements[*column]; + let el = GoldilocksExtAsFieldWrapper::::from_coeffs_in_base([c0, zero_num]); + sources.push(el); + + let value = GoldilocksExtAsFieldWrapper::::from_wrapper_coeffs_in_base([*expected_value, zero_base]); + values.push(value); + } + let num_challenges_required = sources.len(); + assert_eq!(values.len(), num_challenges_required); + + // log!("Making quotiening at {} for public inputs", open_at); + + let open_at = GoldilocksAsFieldWrapper::constant(*open_at, cs); + let open_at = GoldilocksExtAsFieldWrapper::::from_wrapper_coeffs_in_base([open_at, zero_base]); + + quotening_operation( + cs, + simulated_ext_element, + &sources, + &values, + domain_element_for_quotiening, + open_at, + &challenges.challenges_for_fri_quotiening[challenge_offset..(challenge_offset + sources.len())], + ); + + challenge_offset += num_challenges_required; + } + + assert_eq!(challenge_offset, challenges.challenges_for_fri_quotiening.len()); + + Ok(()) +} + +fn check_if_included, H: CircuitGLTreeHasher>( + cs: &mut CS, + leaf_elements: &Vec>, + proof: &Vec, + tree_cap: &Vec, + path: &[Boolean], +) -> Result { + let leaf_hash = >::hash_into_leaf(cs, leaf_elements)?; + verify_proof_over_cap::(cs, proof, tree_cap, &leaf_hash, path) +} + +pub fn verify_proof_over_cap, CS: ConstraintSystem>( + cs: &mut CS, + proof: &[H::CircuitOutput], + cap: &[H::CircuitOutput], + leaf_hash: &H::CircuitOutput, + path: &[Boolean], +) -> Result { + assert!(path.len() >= proof.len()); + + let mut current = leaf_hash.clone(); + let path_bits = &path[..proof.len()]; + let cap_bits = &path[proof.len()..]; + + for (proof_el, path_bit) in proof.iter().zip(path_bits.iter()) { + let (left, right) = H::swap_nodes(cs, *path_bit, ¤t, &proof_el, 0)?; + current = >::hash_into_node(cs, &left, &right, 0)?; + } + + let selected_cap_el = H::select_cap_node(cs, cap_bits, cap)?; + + H::compare_output(cs, ¤t, &selected_cap_el) +} + +fn pow_from_precomputations + 'static>(cs: &mut CS, bases: &[GoldilocksAsFieldWrapper], bits: &[Boolean]) -> GoldilocksAsFieldWrapper { + let mut result = GoldilocksAsFieldWrapper::::one(cs); + + for (base, bit) in bases.iter().zip(bits.iter()) { + let mut tmp = result; + tmp.mul_assign(base, cs); + result = GoldilocksAsFieldWrapper::conditionally_select(cs, *bit, &tmp, &result); + } + + result +} + +fn quotening_operation + 'static>( + cs: &mut CS, + dst: &mut GoldilocksExtAsFieldWrapper, + polynomial_values: &Vec>, + values_at: &Vec>, + domain_element: GoldilocksAsFieldWrapper, + at: GoldilocksExtAsFieldWrapper, + challenges: &[GoldilocksExtAsFieldWrapper], +) { + // we precompute challenges outside to avoid any manual extension ops here + assert_eq!(polynomial_values.len(), values_at.len()); + assert_eq!(polynomial_values.len(), challenges.len()); + + let zero_base = GoldilocksAsFieldWrapper::zero(cs); + + let mut denom = GoldilocksExtAsFieldWrapper::::from_wrapper_coeffs_in_base([domain_element, zero_base]); + denom.sub_assign(&at, cs); + denom = denom.inverse(cs); + + let mut acc = GoldilocksExtAsFieldWrapper::::zero(cs); + + for ((poly_value, value_at), challenge) in polynomial_values.iter().zip(values_at.iter()).zip(challenges.iter()) { + // (f(x) - f(z))/(x - z) + let mut tmp = *poly_value; + tmp.sub_assign(&value_at, cs); + + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(&mut acc, &tmp, &challenge, cs); + + // let mut as_ext = *challenge; + // as_ext.mul_assign(&tmp, cs); + // acc.add_assign(&as_ext, cs); + } + + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(dst, &acc, &denom, cs); + + // acc.mul_assign(&denom, cs); + // dst.add_assign(&acc, cs); +} diff --git a/crates/snark-wrapper/src/verifier/mod.rs b/crates/snark-wrapper/src/verifier/mod.rs new file mode 100644 index 0000000..0f5d3cb --- /dev/null +++ b/crates/snark-wrapper/src/verifier/mod.rs @@ -0,0 +1,186 @@ +use crate::boojum::cs::implementations::proof::Proof; +use crate::boojum::cs::implementations::prover::ProofConfig; +use crate::boojum::cs::implementations::utils::domain_generator_for_size; +use crate::boojum::cs::implementations::verifier::VerificationKey; +use crate::boojum::cs::implementations::verifier::VerificationKeyCircuitGeometry; +use crate::boojum::cs::oracle::TreeHasher; +use crate::boojum::cs::LookupParameters; +use crate::boojum::field::goldilocks::{GoldilocksExt2 as GLExt2, GoldilocksField as GL}; +use crate::boojum::field::traits::field_like::PrimeFieldLike; +use crate::boojum::field::Field as BoojumField; +use crate::boojum::field::PrimeField as BoojumPrimeField; + +use crate::franklin_crypto::bellman::pairing::Engine; +use crate::franklin_crypto::bellman::plonk::better_better_cs::cs::ConstraintSystem; +use crate::franklin_crypto::bellman::plonk::better_better_cs::cs::*; +use crate::franklin_crypto::bellman::plonk::better_better_cs::gates::selector_optimized_with_d_next::SelectorOptimizedWidth4MainGateWithDNext; +use crate::franklin_crypto::bellman::{Field, PrimeField, PrimeFieldRepr, SynthesisError}; +use crate::franklin_crypto::plonk::circuit::allocated_num::{AllocatedNum, Num}; +use crate::franklin_crypto::plonk::circuit::bigint_new::BITWISE_LOGICAL_OPS_TABLE_NAME; +use crate::franklin_crypto::plonk::circuit::boolean::Boolean; +use crate::franklin_crypto::plonk::circuit::custom_rescue_gate::Rescue5CustomGate; +use crate::franklin_crypto::plonk::circuit::goldilocks::prime_field_like::{GoldilocksAsFieldWrapper, GoldilocksExtAsFieldWrapper}; +use crate::franklin_crypto::plonk::circuit::goldilocks::GoldilocksField; +use crate::franklin_crypto::plonk::circuit::linear_combination::LinearCombination; +use crate::franklin_crypto::plonk::circuit::Assignment; + +use crate::traits::circuit::*; +use crate::traits::transcript::CircuitGLTranscript; +use crate::traits::tree_hasher::CircuitGLTreeHasher; +use crate::verifier_structs::allocated_vk::AllocatedVerificationKey; +use crate::verifier_structs::challenges::{ChallengesHolder, EvaluationsHolder}; +use crate::verifier_structs::constants::ConstantsHolder; +use crate::verifier_structs::{allocated_proof::*, *}; + +mod first_step; +mod fri; +mod quotient_contributions; +pub(crate) mod utils; + +use first_step::*; +use fri::*; +use quotient_contributions::*; +use utils::*; + +#[derive(Clone, Debug)] +pub struct WrapperCircuit< + E: Engine, + HS: TreeHasher, + H: CircuitGLTreeHasher, NonCircuitSimulator = HS>, + TR: CircuitGLTranscript, + PWF: ProofWrapperFunction, +> { + pub witness: Option>, + pub vk: VerificationKey, + pub fixed_parameters: VerificationKeyCircuitGeometry, + pub transcript_params: TR::TranscriptParameters, + pub wrapper_function: PWF, +} + +impl< + E: Engine, + HS: TreeHasher, + H: CircuitGLTreeHasher, NonCircuitSimulator = HS>, + TR: CircuitGLTranscript, + PWF: ProofWrapperFunction, + > Circuit for WrapperCircuit +{ + type MainGate = SelectorOptimizedWidth4MainGateWithDNext; + + fn declare_used_gates() -> Result>>, SynthesisError> { + Ok(vec![Self::MainGate::default().into_internal(), Rescue5CustomGate::default().into_internal()]) + } + + fn synthesize + 'static>(&self, cs: &mut CS) -> Result<(), SynthesisError> { + // Add table for range check + let columns3 = vec![PolyIdentifier::VariablesPolynomial(0), PolyIdentifier::VariablesPolynomial(1), PolyIdentifier::VariablesPolynomial(2)]; + + let name = BITWISE_LOGICAL_OPS_TABLE_NAME; + let bitwise_logic_table = LookupTableApplication::new(name, TwoKeysOneValueBinopTable::::new(8, name), columns3.clone(), None, true); + cs.add_table(bitwise_logic_table).unwrap(); + + // Prepare for proof verification + let verifier_builder = self.wrapper_function.builder_for_wrapper(); + let verifier = verifier_builder.create_wrapper_verifier(cs); + + let proof_config = self.wrapper_function.proof_config_for_compression_step(); + let fixed_parameters = self.fixed_parameters.clone(); + + let vk = AllocatedVerificationKey::::allocate_constant(&self.vk, &fixed_parameters); + + let proof: AllocatedProof = AllocatedProof::allocate_from_witness(cs, &self.witness, &verifier, &fixed_parameters, &proof_config)?; + + // Verify proof + let correct = crate::verifier::verify::(cs, self.transcript_params.clone(), &proof_config, &proof, &verifier, &fixed_parameters, &vk)?; + Boolean::enforce_equal(cs, &correct, &Boolean::constant(true))?; + + // Aggregate PI + let _pi = aggregate_public_inputs(cs, &proof.public_inputs)?; + + Ok(()) + } +} + +pub fn verify< + E: Engine, + CS: ConstraintSystem + 'static, + H: CircuitGLTreeHasher, + TR: CircuitGLTranscript, + // TODO POW +>( + cs: &mut CS, + transcript_params: TR::TranscriptParameters, + proof_config: &ProofConfig, + proof: &AllocatedProof, + verifier: &WrapperVerifier, + fixed_parameters: &VerificationKeyCircuitGeometry, + vk: &AllocatedVerificationKey, +) -> Result { + let mut validity_flags = Vec::with_capacity(256); + + let mut transcript = TR::new(cs, transcript_params)?; + let mut challenges = ChallengesHolder::new(cs); + + // prepare constants + let constants = ConstantsHolder::generate(proof_config, verifier, fixed_parameters); + assert_eq!(fixed_parameters.cap_size, vk.setup_merkle_tree_cap.len()); + + let public_input_opening_tuples = verify_first_step(cs, proof, vk, &mut challenges, &mut transcript, verifier, fixed_parameters, &constants)?; + + validity_flags.extend(check_quotient_contributions_in_z(cs, proof, &challenges, verifier, fixed_parameters, &constants)?); + + validity_flags.extend(verify_fri_part::( + cs, + proof, + vk, + &mut challenges, + &mut transcript, + public_input_opening_tuples, + verifier, + fixed_parameters, + &constants, + )?); + + let correct = smart_and(cs, &validity_flags)?; + + Ok(correct) +} + +/// aggregate public inputs to one scalar field element +fn aggregate_public_inputs>(cs: &mut CS, public_inputs: &[GoldilocksField]) -> Result, SynthesisError> { + let chunk_bit_size = (GL::CAPACITY_BITS / 8) * 8; + assert!( + public_inputs.len() * chunk_bit_size <= E::Fr::CAPACITY as usize, + "scalar field capacity is not enough to fit all public inputs" + ); + + // Firstly we check that public inputs have correct size + use crate::franklin_crypto::plonk::circuit::bigint_new::enforce_range_check_using_bitop_table; + for pi in public_inputs.iter() { + let table = cs.get_table(BITWISE_LOGICAL_OPS_TABLE_NAME).unwrap(); + enforce_range_check_using_bitop_table(cs, &pi.into_num().get_variable(), chunk_bit_size, table, false)?; + } + + // compute aggregated pi value + let mut tmp = E::Fr::one(); + let mut shift_repr = ::Repr::from(1); + shift_repr.shl(chunk_bit_size as u32); + let shift = E::Fr::from_repr(shift_repr).unwrap(); + + let mut lc = LinearCombination::::zero(); + for pi in public_inputs.iter().rev() { + lc.add_assign_number_with_coeff(&pi.into_num(), tmp); + tmp.mul_assign(&shift); + } + + // allocate as pi + let pi = Num::Variable(AllocatedNum::alloc_input(cs, || Ok(*lc.get_value().get()?))?); + + // check sum + let mut minus_one = E::Fr::one(); + minus_one.negate(); + lc.add_assign_number_with_coeff(&pi, minus_one); + lc.enforce_zero(cs)?; + + Ok(pi) +} diff --git a/crates/snark-wrapper/src/verifier/quotient_contributions.rs b/crates/snark-wrapper/src/verifier/quotient_contributions.rs new file mode 100644 index 0000000..17c78e8 --- /dev/null +++ b/crates/snark-wrapper/src/verifier/quotient_contributions.rs @@ -0,0 +1,651 @@ +use super::*; + +use std::collections::HashMap; + +use crate::boojum::cs::gates::lookup_marker::LookupFormalGate; +use crate::boojum::cs::gates::lookup_marker::LookupGateMarkerFormalEvaluator; +use crate::boojum::cs::implementations::copy_permutation::non_residues_for_copy_permutation; +use crate::boojum::cs::implementations::verifier::*; +use crate::boojum::cs::traits::gate::GatePlacementStrategy; +use std::alloc::Global; + +/// Run verifier at z. +/// We should check: +/// - lookup contribution +/// - specialized gates contribution +/// - general purpose gates contribution +/// - copy permutation contribution +pub(crate) fn check_quotient_contributions_in_z + 'static, H: CircuitGLTreeHasher>( + cs: &mut CS, + proof: &AllocatedProof, + challenges: &ChallengesHolder, + // parameters + verifier: &WrapperVerifier, + fixed_parameters: &VerificationKeyCircuitGeometry, + constants: &ConstantsHolder, +) -> Result, SynthesisError> { + let mut validity_flags = vec![]; + + let zero_ext = GoldilocksExtAsFieldWrapper::::zero(cs); + let one_ext = GoldilocksExtAsFieldWrapper::::one(cs); + + let non_residues_for_copy_permutation = non_residues_for_copy_permutation::(fixed_parameters.domain_size as usize, constants.num_variable_polys); + + let non_residues_for_copy_permutation: Vec<_> = non_residues_for_copy_permutation.into_iter().map(|el| GoldilocksAsFieldWrapper::constant(el, cs)).collect(); + + let evaluations = EvaluationsHolder::from_proof(proof); + + let mut source_it = evaluations.all_values_at_z.iter(); + // witness + let variables_polys_values: Vec<_> = (&mut source_it).take(constants.num_variable_polys).copied().collect(); + let witness_polys_values: Vec<_> = (&mut source_it).take(constants.num_witness_polys).copied().collect(); + // normal setup + let constant_poly_values: Vec<_> = (&mut source_it).take(constants.num_constant_polys).copied().collect(); + let sigmas_values: Vec<_> = (&mut source_it).take(constants.num_copy_permutation_polys).copied().collect(); + let copy_permutation_z_at_z = *source_it.next().unwrap(); + let grand_product_intermediate_polys: Vec<_> = (&mut source_it).take(constants.num_intermediate_partial_product_relations).copied().collect(); + // lookup if exists + let multiplicities_polys_values: Vec<_> = (&mut source_it).take(constants.num_multiplicities_polys).copied().collect(); + let lookup_witness_encoding_polys_values: Vec<_> = (&mut source_it).take(constants.num_lookup_subarguments).copied().collect(); + let multiplicities_encoding_polys_values: Vec<_> = (&mut source_it).take(constants.num_multiplicities_polys).copied().collect(); + // lookup setup + let lookup_tables_columns: Vec<_> = (&mut source_it).take(constants.num_lookup_table_setup_polys).copied().collect(); + // quotient + let quotient_chunks: Vec<_> = source_it.copied().collect(); + + assert_eq!(quotient_chunks.len(), constants.quotient_degree); + + let mut source_it = evaluations.all_values_at_z_omega.iter(); + let copy_permutation_z_at_z_omega = *source_it.next().unwrap(); + + let mut t_accumulator = GoldilocksExtAsFieldWrapper::::zero(cs); + // precompute selectors at z + + let mut selectors_buffer = HashMap::new(); + for (gate_idx, evaluator) in verifier.evaluators_over_general_purpose_columns.iter().enumerate() { + if let Some(path) = fixed_parameters.selectors_placement.output_placement(gate_idx) { + if selectors_buffer.contains_key(&path) { + panic!("same selector for different gates"); + } + + compute_selector_subpath_at_z(path, &mut selectors_buffer, &constant_poly_values, cs); + } else { + assert!(evaluator.num_quotient_terms == 0); + } + } + + validity_flags.extend(check_lookup_contribution( + cs, + &evaluations, + challenges, + &mut t_accumulator, + &variables_polys_values, + &lookup_witness_encoding_polys_values, + &lookup_tables_columns, + &constant_poly_values, + &mut selectors_buffer, + &multiplicities_encoding_polys_values, + &multiplicities_polys_values, + verifier, + fixed_parameters, + constants, + )?); + + let constants_for_gates_over_general_purpose_columns = fixed_parameters.extra_constant_polys_for_selectors + verifier.parameters.num_constant_columns; + + let src = VerifierPolyStorage::new(variables_polys_values.clone(), witness_polys_values, constant_poly_values); + + check_specialized_gates_contribution(cs, challenges, &mut t_accumulator, &src, verifier, constants, constants_for_gates_over_general_purpose_columns)?; + + // log!("Evaluating general purpose gates"); + + check_general_purpose_gates_contribution( + cs, + challenges, + &mut t_accumulator, + &src, + &mut selectors_buffer, + verifier, + fixed_parameters, + constants, + constants_for_gates_over_general_purpose_columns, + )?; + + // then copy_permutation algorithm + + let z_in_domain_size = challenges.z.pow_u64(fixed_parameters.domain_size as u64, cs); + + let mut vanishing_at_z = z_in_domain_size; + vanishing_at_z.sub_assign(&one_ext, cs); + + check_copy_permutation_contribution( + cs, + challenges, + &mut t_accumulator, + &variables_polys_values, + copy_permutation_z_at_z, + copy_permutation_z_at_z_omega, + &grand_product_intermediate_polys, + &sigmas_values, + &non_residues_for_copy_permutation, + vanishing_at_z, + constants, + )?; + + let mut t_from_chunks = zero_ext; + let mut pow = one_ext; + for el in quotient_chunks.into_iter() { + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(&mut t_from_chunks, &el, &pow, cs); + + // let mut tmp = el; + // tmp.mul_assign(&pow, cs); + // t_from_chunks.add_assign(&tmp, cs); + + pow.mul_assign(&z_in_domain_size, cs); + } + + t_from_chunks.mul_assign(&vanishing_at_z, cs); + + let t_accumulator = t_accumulator.into_coeffs_in_base(); + let t_from_chunks = t_from_chunks.into_coeffs_in_base(); + + let c0_is_valid = GoldilocksField::equals(cs, &t_accumulator[0], &t_from_chunks[0])?; + let c1_is_valid = GoldilocksField::equals(cs, &t_accumulator[1], &t_from_chunks[1])?; + + validity_flags.push(c0_is_valid); + validity_flags.push(c1_is_valid); + + Ok(validity_flags) +} + +pub(crate) fn check_lookup_contribution + 'static>( + cs: &mut CS, + evaluations: &EvaluationsHolder, + challenges: &ChallengesHolder, + t_accumulator: &mut GoldilocksExtAsFieldWrapper, + // polynomial values + variables_polys_values: &Vec>, + lookup_witness_encoding_polys_values: &Vec>, + lookup_tables_columns: &Vec>, + constant_poly_values: &Vec>, + selectors_buffer: &mut HashMap, GoldilocksExtAsFieldWrapper>, + multiplicities_encoding_polys_values: &Vec>, + multiplicities_polys_values: &Vec>, + // parameters + verifier: &WrapperVerifier, + fixed_parameters: &VerificationKeyCircuitGeometry, + constants: &ConstantsHolder, +) -> Result, SynthesisError> { + let mut validity_flags = vec![]; + + let one_ext = GoldilocksExtAsFieldWrapper::::one(cs); + let lookup_challenges = &challenges.pregenerated_challenges_for_lookup; + + // first we do the lookup + if verifier.lookup_parameters != LookupParameters::NoLookup { + // immediatelly do sumchecks + let lookup_witness_encoding_polys_polys_at_0 = &evaluations.all_values_at_0[..constants.num_lookup_subarguments]; + let multiplicities_encoding_polys_at_0 = &evaluations.all_values_at_0[constants.num_lookup_subarguments..]; + + let mut witness_subsum = GoldilocksExtAsFieldWrapper::::zero(cs); + for a in lookup_witness_encoding_polys_polys_at_0.iter() { + witness_subsum.add_assign(a, cs); + } + + let mut multiplicities_subsum = GoldilocksExtAsFieldWrapper::::zero(cs); + for b in multiplicities_encoding_polys_at_0.iter() { + multiplicities_subsum.add_assign(b, cs); + } + + let witness_subsum = witness_subsum.into_coeffs_in_base(); + let multiplicities_subsum = multiplicities_subsum.into_coeffs_in_base(); + + let c0_is_valid = GoldilocksField::equals(cs, &witness_subsum[0], &multiplicities_subsum[0])?; + let c1_is_valid = GoldilocksField::equals(cs, &witness_subsum[1], &multiplicities_subsum[1])?; + + validity_flags.push(c0_is_valid); + validity_flags.push(c1_is_valid); + + // lookup argument related parts + match verifier.lookup_parameters { + LookupParameters::TableIdAsVariable { width: _, share_table_id: _ } | LookupParameters::TableIdAsConstant { width: _, share_table_id: _ } => { + // exists by our setup + let lookup_evaluator_id = 0; + let selector_subpath = fixed_parameters.selectors_placement.output_placement(lookup_evaluator_id).expect("lookup gate must be placed"); + let selector = selectors_buffer.remove(&selector_subpath).expect("path must be unique and precomputed"); + + let column_elements_per_subargument = verifier.lookup_parameters.columns_per_subargument() as usize; + assert!(fixed_parameters.table_ids_column_idxes.len() == 0 || fixed_parameters.table_ids_column_idxes.len() == 1); + + // this is our lookup width, either counted by number of witness columns only, or if one includes setup + let num_lookup_columns = column_elements_per_subargument + ((fixed_parameters.table_ids_column_idxes.len() == 1) as usize); + assert_eq!(lookup_tables_columns.len(), num_lookup_columns); + + let capacity = column_elements_per_subargument + ((fixed_parameters.table_ids_column_idxes.len() == 1) as usize); + let mut powers_of_gamma = Vec::with_capacity(capacity); + let mut tmp = GoldilocksExtAsFieldWrapper::::one(cs); + powers_of_gamma.push(tmp); + for _idx in 1..capacity { + if _idx == 1 { + tmp = challenges.lookup_gamma; + } else { + tmp.mul_assign(&challenges.lookup_gamma, cs); + } + + powers_of_gamma.push(tmp); + } + + // precompute aggregation of lookup table polys + assert_eq!(powers_of_gamma.len(), capacity); + let mut lookup_table_columns_aggregated = challenges.lookup_beta; + for (gamma, column) in powers_of_gamma.iter().zip(lookup_tables_columns.iter()) { + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(&mut lookup_table_columns_aggregated, gamma, column, cs); + } + + let mut challenges_it = lookup_challenges.iter(); + + // first A polys + let variables_columns_for_lookup = &variables_polys_values[..(column_elements_per_subargument * constants.num_lookup_subarguments)]; + assert_eq!( + lookup_witness_encoding_polys_values.len(), + variables_columns_for_lookup.chunks_exact(column_elements_per_subargument as usize).len() + ); + + for (a_poly, witness_columns) in lookup_witness_encoding_polys_values + .iter() + .zip(variables_columns_for_lookup.chunks_exact(column_elements_per_subargument as usize)) + { + let alpha = *challenges_it.next().expect("challenge for lookup A poly contribution"); + let mut contribution = challenges.lookup_beta; + + let table_id = if let Some(table_id_poly) = fixed_parameters.table_ids_column_idxes.get(0).copied() { + vec![constant_poly_values[table_id_poly]] + } else { + vec![] + }; + + for (gamma, column) in powers_of_gamma.iter().zip(witness_columns.iter().chain(table_id.iter())) { + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(&mut contribution, gamma, column, cs); + } + + // mul by A(x) + contribution.mul_assign(a_poly, cs); + // sub selector + contribution.sub_assign(&selector, cs); + + // mul by power of challenge and accumulate + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(t_accumulator, &alpha, &contribution, cs); + + // contribution.mul_assign(&alpha, cs); + // t_accumulator.add_assign(&contribution, cs); + } + + // then B polys + assert_eq!(multiplicities_encoding_polys_values.len(), multiplicities_polys_values.len()); + for (b_poly, multiplicities_poly) in multiplicities_encoding_polys_values.iter().zip(multiplicities_polys_values.iter()) { + let alpha = *challenges_it.next().expect("challenge for lookup B poly contribution"); + let mut contribution = lookup_table_columns_aggregated; + // mul by B(x) + contribution.mul_assign(b_poly, cs); + // sub multiplicity + contribution.sub_assign(multiplicities_poly, cs); + + // mul by power of challenge and accumulate + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(t_accumulator, &alpha, &contribution, cs); + + // contribution.mul_assign(&alpha, cs); + // t_accumulator.add_assign(&contribution, cs); + } + } + LookupParameters::UseSpecializedColumnsWithTableIdAsConstant { + width: _, + num_repetitions: _, + share_table_id: _, + } + | LookupParameters::UseSpecializedColumnsWithTableIdAsVariable { + width: _, + num_repetitions: _, + share_table_id: _, + } => { + let column_elements_per_subargument = verifier.lookup_parameters.specialized_columns_per_subargument() as usize; + assert!(fixed_parameters.table_ids_column_idxes.len() == 0 || fixed_parameters.table_ids_column_idxes.len() == 1); + + // this is our lookup width, either counted by number of witness columns only, or if one includes setup + let num_lookup_columns = column_elements_per_subargument + ((fixed_parameters.table_ids_column_idxes.len() == 1) as usize); + assert_eq!(lookup_tables_columns.len(), num_lookup_columns); + + let capacity = column_elements_per_subargument + ((fixed_parameters.table_ids_column_idxes.len() == 1) as usize); + let mut powers_of_gamma = Vec::with_capacity(capacity); + let mut tmp = GoldilocksExtAsFieldWrapper::::one(cs); + powers_of_gamma.push(tmp); + for _idx in 1..capacity { + if _idx == 1 { + tmp = challenges.lookup_gamma; + } else { + tmp.mul_assign(&challenges.lookup_gamma, cs); + } + + powers_of_gamma.push(tmp); + } + + // precompute aggregation of lookup table polys + assert_eq!(powers_of_gamma.len(), capacity); + let mut lookup_table_columns_aggregated = challenges.lookup_beta; + for (gamma, column) in powers_of_gamma.iter().zip(lookup_tables_columns.iter()) { + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(&mut lookup_table_columns_aggregated, gamma, column, cs); + } + + let mut challenges_it = lookup_challenges.iter(); + + // first A polys + let variables_columns_for_lookup = &variables_polys_values[verifier.parameters.num_columns_under_copy_permutation + ..(verifier.parameters.num_columns_under_copy_permutation + column_elements_per_subargument * constants.num_lookup_subarguments)]; + assert_eq!( + lookup_witness_encoding_polys_values.len(), + variables_columns_for_lookup.chunks_exact(column_elements_per_subargument as usize).len() + ); + + for (a_poly, witness_columns) in lookup_witness_encoding_polys_values + .iter() + .zip(variables_columns_for_lookup.chunks_exact(column_elements_per_subargument as usize)) + { + let alpha = *challenges_it.next().expect("challenge for lookup A poly contribution"); + let mut contribution = challenges.lookup_beta; + + let table_id = if let Some(table_id_poly) = fixed_parameters.table_ids_column_idxes.get(0).copied() { + vec![constant_poly_values[table_id_poly]] + } else { + vec![] + }; + + for (gamma, column) in powers_of_gamma.iter().zip(witness_columns.iter().chain(table_id.iter())) { + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(&mut contribution, gamma, column, cs); + } + + // mul by A(x) + contribution.mul_assign(a_poly, cs); + // sub numerator + contribution.sub_assign(&one_ext, cs); + + // mul by power of challenge and accumulate + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(t_accumulator, &alpha, &contribution, cs); + + // contribution.mul_assign(&alpha, cs); + // t_accumulator.add_assign(&contribution, cs); + } + + // then B polys + assert_eq!(multiplicities_encoding_polys_values.len(), multiplicities_polys_values.len()); + for (b_poly, multiplicities_poly) in multiplicities_encoding_polys_values.iter().zip(multiplicities_polys_values.iter()) { + let alpha = *challenges_it.next().expect("challenge for lookup B poly contribution"); + let mut contribution = lookup_table_columns_aggregated; + // mul by B(x) + contribution.mul_assign(b_poly, cs); + // sub multiplicity + contribution.sub_assign(multiplicities_poly, cs); + + // mul by power of challenge and accumulate + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(t_accumulator, &alpha, &contribution, cs); + + // contribution.mul_assign(&alpha, cs); + // t_accumulator.add_assign(&contribution, cs); + } + } + _ => { + unreachable!() + } + } + } + + Ok(validity_flags) +} + +pub(crate) fn check_specialized_gates_contribution + 'static>( + cs: &mut CS, + challenges: &ChallengesHolder, + t_accumulator: &mut GoldilocksExtAsFieldWrapper, + src: &VerifierPolyStorage>, + // parameters + verifier: &WrapperVerifier, + constants: &ConstantsHolder, + constants_for_gates_over_general_purpose_columns: usize, +) -> Result<(), SynthesisError> { + let zero_ext = GoldilocksExtAsFieldWrapper::::zero(cs); + let one_ext = GoldilocksExtAsFieldWrapper::::one(cs); + + let specialized_evaluators_challenges = &challenges.pregenerated_challenges_for_gates_over_specialized_columns; + + let mut specialized_placement_data = vec![]; + let mut evaluation_functions = vec![]; + + for (idx, (gate_type_id, evaluator)) in verifier + .gate_type_ids_for_specialized_columns + .iter() + .zip(verifier.evaluators_over_specialized_columns.iter()) + .enumerate() + { + if gate_type_id == &std::any::TypeId::of::() { + continue; + } + + assert!( + evaluator.total_quotient_terms_over_all_repetitions != 0, + "evaluator {} has no contribution to quotient", + &evaluator.debug_name, + ); + // log!( + // "Will be evaluating {} over specialized columns", + // &evaluator.debug_name + // ); + + let num_terms = evaluator.num_quotient_terms; + let placement_strategy = verifier.placement_strategies.get(gate_type_id).copied().expect("gate must be allowed"); + let GatePlacementStrategy::UseSpecializedColumns { num_repetitions, share_constants } = placement_strategy else { + unreachable!(); + }; + + let total_terms = num_terms * num_repetitions; + + let (initial_offset, per_repetition_offset, total_constants_available) = verifier.offsets_for_specialized_evaluators[idx]; + + let placement_data = (num_repetitions, share_constants, initial_offset, per_repetition_offset, total_constants_available, total_terms); + + specialized_placement_data.push(placement_data); + let t = &**evaluator.columnwise_satisfiability_function.as_ref().expect("must be properly configured"); + evaluation_functions.push(t); + } + + let mut challenges_offset = 0; + + for (placement_data, evaluation_fn) in specialized_placement_data.iter().zip(evaluation_functions.iter()) { + let (num_repetitions, share_constants, initial_offset, per_repetition_offset, _total_constants_available, total_terms) = *placement_data; + + // we self-check again + if share_constants { + assert_eq!(per_repetition_offset.constants_offset, 0); + } + let mut final_offset = initial_offset; + for _ in 0..num_repetitions { + final_offset.add_offset(&per_repetition_offset); + } + + let mut dst = VerifierRelationDestination { + accumulator: zero_ext, + selector_value: one_ext, + challenges: specialized_evaluators_challenges.clone(), + current_challenge_offset: challenges_offset, + _marker: std::marker::PhantomData, + }; + + let mut src = src.subset( + initial_offset.variables_offset..final_offset.variables_offset, + initial_offset.witnesses_offset..final_offset.witnesses_offset, + (constants_for_gates_over_general_purpose_columns + initial_offset.constants_offset)..(constants_for_gates_over_general_purpose_columns + final_offset.constants_offset), + ); + + evaluation_fn.evaluate_over_columns(&mut src, &mut dst, cs); + + t_accumulator.add_assign(&dst.accumulator, cs); + + challenges_offset += total_terms; + } + + assert_eq!(challenges_offset, constants.total_num_gate_terms_for_specialized_columns); + + Ok(()) +} + +pub(crate) fn check_general_purpose_gates_contribution + 'static>( + cs: &mut CS, + challenges: &ChallengesHolder, + t_accumulator: &mut GoldilocksExtAsFieldWrapper, + src: &VerifierPolyStorage>, + selectors_buffer: &mut HashMap, GoldilocksExtAsFieldWrapper>, + // parameters + verifier: &WrapperVerifier, + fixed_parameters: &VerificationKeyCircuitGeometry, + constants: &ConstantsHolder, + constants_for_gates_over_general_purpose_columns: usize, +) -> Result<(), SynthesisError> { + let src = src.subset( + 0..verifier.parameters.num_columns_under_copy_permutation, + 0..verifier.parameters.num_witness_columns, + 0..constants_for_gates_over_general_purpose_columns, + ); + + let mut challenges_offset = 0; + let zero_ext = GoldilocksExtAsFieldWrapper::::zero(cs); + + let general_purpose_challenges = &challenges.pregenerated_challenges_for_gates_over_general_purpose_columns; + + for (gate_idx, evaluator) in verifier.evaluators_over_general_purpose_columns.iter().enumerate() { + if &evaluator.evaluator_type_id == &std::any::TypeId::of::() { + continue; + } + + if evaluator.total_quotient_terms_over_all_repetitions == 0 { + // we MAY formally have NOP gate in the set here, but we should not evaluate it. + // NOP gate will affect selectors placement, but not the rest + continue; + } + + if let Some(path) = fixed_parameters.selectors_placement.output_placement(gate_idx) { + let selector = selectors_buffer.remove(&path).expect("path must be unique and precomputed"); + let constant_placement_offset = path.len(); + + let mut dst = VerifierRelationDestination { + accumulator: zero_ext, + selector_value: selector, + challenges: general_purpose_challenges.clone(), + current_challenge_offset: challenges_offset, + _marker: std::marker::PhantomData, + }; + + let mut source = src.clone(); + + let evaluation_fn = &**evaluator.rowwise_satisfiability_function.as_ref().expect("gate must be allowed"); + + evaluation_fn.evaluate_over_general_purpose_columns(&mut source, &mut dst, constant_placement_offset, cs); + + t_accumulator.add_assign(&dst.accumulator, cs); + challenges_offset += evaluator.total_quotient_terms_over_all_repetitions; + } else { + assert!(evaluator.num_quotient_terms == 0); + } + } + + assert_eq!(challenges_offset, constants.total_num_gate_terms_for_general_purpose_columns); + + Ok(()) +} + +pub(crate) fn check_copy_permutation_contribution + 'static>( + cs: &mut CS, + challenges: &ChallengesHolder, + t_accumulator: &mut GoldilocksExtAsFieldWrapper, + // polynomial values + variables_polys_values: &Vec>, + copy_permutation_z_at_z: GoldilocksExtAsFieldWrapper, + copy_permutation_z_at_z_omega: GoldilocksExtAsFieldWrapper, + grand_product_intermediate_polys: &Vec>, + sigmas_values: &Vec>, + non_residues_for_copy_permutation: &Vec>, + vanishing_at_z: GoldilocksExtAsFieldWrapper, + // parameters + constants: &ConstantsHolder, +) -> Result<(), SynthesisError> { + let one_ext = GoldilocksExtAsFieldWrapper::::one(cs); + + let mut challenges_it = challenges.remaining_challenges.iter(); + + { + // (x^n - 1) / (x - 1), + let mut z_minus_one = challenges.z; + z_minus_one.sub_assign(&one_ext, cs); + + let mut unnormalized_l1_inverse_at_z = vanishing_at_z; + let z_minus_one_inversed = z_minus_one.inverse(cs); + unnormalized_l1_inverse_at_z.mul_assign(&z_minus_one_inversed, cs); + + let alpha = *challenges_it.next().expect("challenge for z(1) == 1"); + // (z(x) - 1) * l(1) + let mut contribution = copy_permutation_z_at_z; + contribution.sub_assign(&one_ext, cs); + contribution.mul_assign(&unnormalized_l1_inverse_at_z, cs); + + // mul by power of challenge and accumulate + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(t_accumulator, &alpha, &contribution, cs); + + // contribution.mul_assign(&alpha, cs); + // t_accumulator.add_assign(&contribution, cs); + } + + // partial products + + let lhs = grand_product_intermediate_polys.iter().chain(std::iter::once(©_permutation_z_at_z_omega)); + + let rhs = std::iter::once(©_permutation_z_at_z).chain(grand_product_intermediate_polys.iter()); + + for (((((lhs, rhs), alpha), non_residues), variables), sigmas) in lhs + .zip(rhs) + .zip(&mut challenges_it) + .zip(non_residues_for_copy_permutation.chunks(constants.quotient_degree)) + .zip(variables_polys_values.chunks(constants.quotient_degree)) + .zip(sigmas_values.chunks(constants.quotient_degree)) + { + let mut lhs = *lhs; + for (variable, sigma) in variables.iter().zip(sigmas.iter()) { + // denominator is w + beta * sigma(x) + gamma + let mut subres = *sigma; + subres.mul_assign(&challenges.beta, cs); + subres.add_assign(&variable, cs); + subres.add_assign(&challenges.gamma, cs); + lhs.mul_assign(&subres, cs); + } + + let mut rhs = *rhs; + let x_poly_value = challenges.z; + for (non_res, variable) in non_residues.iter().zip(variables.iter()) { + // numerator is w + beta * non_res * x + gamma + let mut subres = x_poly_value; + subres.mul_assign_by_base(cs, non_res)?; + subres.mul_assign(&challenges.beta, cs); + subres.add_assign(&variable, cs); + subres.add_assign(&challenges.gamma, cs); + rhs.mul_assign(&subres, cs); + } + + let mut contribution = lhs; + contribution.sub_assign(&rhs, cs); + + // mul by power of challenge and accumulate + GoldilocksExtAsFieldWrapper::::mul_and_accumulate_into(t_accumulator, &alpha, &contribution, cs); + + // contribution.mul_assign(&alpha, cs); + // t_accumulator.add_assign(&contribution, cs); + } + + assert_eq!(challenges_it.len(), 0, "must exhaust all the challenges"); + + Ok(()) +} diff --git a/crates/snark-wrapper/src/verifier/utils.rs b/crates/snark-wrapper/src/verifier/utils.rs new file mode 100644 index 0000000..3b93291 --- /dev/null +++ b/crates/snark-wrapper/src/verifier/utils.rs @@ -0,0 +1,95 @@ +use super::*; + +pub fn smart_and>(cs: &mut CS, bools: &[Boolean]) -> Result { + const LIMIT: usize = 4; + assert!(bools.len() > 0); + if bools.len() == 1 { + return Ok(bools[0]); + } + + if bools.len() == 2 { + // 1 gate + let result = Boolean::and(cs, &bools[0], &bools[1])?; + return Ok(result); + } + + // 1 gate for 2, + // 2 gates for 3, etc + if bools.len() < LIMIT { + // 1 gate + let mut result = Boolean::and(cs, &bools[0], &bools[1])?; + // len - 2 gates + for b in bools[2..].iter() { + result = Boolean::and(cs, &result, &b)?; + } + return Ok(result); + } + + // 1 gate for 3 + // 2 gates for 6 + // 3 gates for 9, etc + let mut lc = LinearCombination::zero(); + let num_elements_as_fr = E::Fr::from_str(&bools.len().to_string()).unwrap(); + lc.sub_assign_constant(num_elements_as_fr); + for b in bools.iter() { + lc.add_assign_boolean_with_coeff(b, E::Fr::one()); + } + let as_num = lc.into_num(cs)?; + + // 2 gates here + let all_true = as_num.is_zero(cs)?; + + // so 2 gates for 3 + // 4 gates for 6 + // 5 gates for 9 + // so we win at 4+ + + Ok(all_true) +} + +pub(crate) fn binary_select>(cs: &mut CS, elements: &[GoldilocksField], bits: &[Boolean]) -> Result, SynthesisError> { + assert_eq!(elements.len(), 1 << bits.len()); + assert!(bits.len() > 0); + + let mut input_space = Vec::with_capacity(elements.len() / 2); + let mut dst_space = Vec::with_capacity(elements.len() / 2); + + for (idx, bit) in bits.iter().enumerate() { + let src = if idx == 0 { elements } else { &input_space }; + + debug_assert_eq!(elements.len() % 2, 0); + dst_space.clear(); + + for src in src.array_chunks::<2>() { + let [a, b] = src; + // NOTE order here + let selected = GoldilocksField::conditionally_select(cs, *bit, b, a)?; + dst_space.push(selected); + } + + std::mem::swap(&mut dst_space, &mut input_space); + } + + assert_eq!(input_space.len(), 1); + + Ok(input_space.pop().unwrap()) +} + +pub(crate) fn materialize_powers_serial + 'static>(cs: &mut CS, base: GoldilocksExtAsFieldWrapper, size: usize) -> Vec> { + if size == 0 { + return Vec::new(); + } + let mut storage = Vec::with_capacity(size); + let mut current = GoldilocksExtAsFieldWrapper::one(cs); + storage.push(current); + for idx in 1..size { + if idx == 1 { + current = base; + } else { + current.mul_assign(&base, cs); + } + storage.push(current); + } + + storage +} diff --git a/crates/snark-wrapper/src/verifier_structs/allocated_proof.rs b/crates/snark-wrapper/src/verifier_structs/allocated_proof.rs new file mode 100644 index 0000000..6a27d98 --- /dev/null +++ b/crates/snark-wrapper/src/verifier_structs/allocated_proof.rs @@ -0,0 +1,140 @@ +use super::*; + +use super::allocated_queries::AllocatedSingleRoundQueries; +pub struct AllocatedProof> { + pub public_inputs: Vec>, + + pub witness_oracle_cap: Vec, + pub stage_2_oracle_cap: Vec, + pub quotient_oracle_cap: Vec, + pub final_fri_monomials: [Vec>; 2], + + pub values_at_z: Vec<[GoldilocksField; 2]>, + pub values_at_z_omega: Vec<[GoldilocksField; 2]>, + pub values_at_0: Vec<[GoldilocksField; 2]>, + + pub fri_base_oracle_cap: Vec, + pub fri_intermediate_oracles_caps: Vec>, + + pub queries_per_fri_repetition: Vec>, + + pub pow_challenge: [Boolean; 64], +} + +impl, H: CircuitGLTreeHasher, NonCircuitSimulator = HS>> AllocatedProof { + pub fn allocate_from_witness>( + cs: &mut CS, + witness: &Option>, + verifier: &WrapperVerifier, + fixed_parameters: &VerificationKeyCircuitGeometry, + proof_config: &ProofConfig, + ) -> Result { + if let Some(config) = witness.as_ref().map(|el| &el.proof_config) { + assert_eq!(config, proof_config); + } + + let constants = ConstantsHolder::generate(proof_config, verifier, fixed_parameters); + + let num_elements = fixed_parameters.num_public_inputs(); + let public_inputs = witness.as_ref().map(|el| el.public_inputs.iter().copied()); + let public_inputs = allocate_num_elements(cs, num_elements, public_inputs, GoldilocksField::alloc_from_field)?; + + let num_elements = fixed_parameters.cap_size; + let witness_oracle_cap = witness.as_ref().map(|el| el.witness_oracle_cap.iter().cloned()); + let witness_oracle_cap = allocate_num_elements(cs, num_elements, witness_oracle_cap, Num::alloc)?; + + let num_elements = fixed_parameters.cap_size; + let stage_2_oracle_cap = witness.as_ref().map(|el| el.stage_2_oracle_cap.iter().cloned()); + let stage_2_oracle_cap = allocate_num_elements(cs, num_elements, stage_2_oracle_cap, Num::alloc)?; + + let num_elements = fixed_parameters.cap_size; + let quotient_oracle_cap = witness.as_ref().map(|el| el.quotient_oracle_cap.iter().cloned()); + let quotient_oracle_cap = allocate_num_elements(cs, num_elements, quotient_oracle_cap, Num::alloc)?; + + let num_elements = constants.final_expected_degree; + let final_fri_monomials_c0 = witness.as_ref().map(|el| el.final_fri_monomials[0].iter().cloned()); + let final_fri_monomials_c0 = allocate_num_elements(cs, num_elements, final_fri_monomials_c0, GoldilocksField::alloc_from_field)?; + + let num_elements = constants.final_expected_degree; + let final_fri_monomials_c1 = witness.as_ref().map(|el| el.final_fri_monomials[1].iter().cloned()); + let final_fri_monomials_c1 = allocate_num_elements(cs, num_elements, final_fri_monomials_c1, GoldilocksField::alloc_from_field)?; + + let num_elements = constants.num_poly_values_at_z; + let values_at_z = witness.as_ref().map(|el| el.values_at_z.iter().map(|el| el.into_coeffs_in_base())); + let values_at_z = allocate_num_elements(cs, num_elements, values_at_z, allocate_gl_array)?; + + let num_elements = constants.num_poly_values_at_z_omega; + let values_at_z_omega = witness.as_ref().map(|el| el.values_at_z_omega.iter().map(|el| el.into_coeffs_in_base())); + let values_at_z_omega = allocate_num_elements(cs, num_elements, values_at_z_omega, allocate_gl_array)?; + + let num_elements = constants.num_poly_values_at_zero; + let values_at_0 = witness.as_ref().map(|el| el.values_at_0.iter().map(|el| el.into_coeffs_in_base())); + let values_at_0 = allocate_num_elements(cs, num_elements, values_at_0, allocate_gl_array)?; + + let num_elements = fixed_parameters.cap_size; + let fri_base_oracle_cap = witness.as_ref().map(|el| el.fri_base_oracle_cap.iter().cloned()); + let fri_base_oracle_cap = allocate_num_elements(cs, num_elements, fri_base_oracle_cap, Num::alloc)?; + + let fri_folding_schedule = constants.fri_folding_schedule; + assert!(fri_folding_schedule.len() > 0); + let mut fri_intermediate_oracles_caps = Vec::with_capacity(fri_folding_schedule.len() - 1); + for idx in 0..(fri_folding_schedule.len() - 1) { + let num_elements = fixed_parameters.cap_size; + let fri_intermediate_cap = witness.as_ref().map(|el| el.fri_intermediate_oracles_caps[idx].iter().cloned()); + let fri_intermediate_cap = allocate_num_elements(cs, num_elements, fri_intermediate_cap, Num::alloc)?; + fri_intermediate_oracles_caps.push(fri_intermediate_cap); + } + + let num_items = constants.num_fri_repetitions; + let mut queries_per_fri_repetition = Vec::with_capacity(num_items); + for idx in 0..num_items { + let wit = witness.as_ref().map(|el| el.queries_per_fri_repetition[idx].clone()); + let queries = AllocatedSingleRoundQueries::allocate_from_witness(cs, wit, verifier, fixed_parameters, proof_config)?; + queries_per_fri_repetition.push(queries); + } + + let mut pow_challenge_boolean = [Boolean::Constant(true); 64]; + let pow_challenge = witness.as_ref().map(|el| vec![el.pow_challenge]).unwrap_or(vec![]); + + let mut lsb_iter = crate::boojum::utils::LSBIterator::new(&pow_challenge); + + for i in 0..64 { + pow_challenge_boolean[i] = Boolean::alloc(cs, lsb_iter.next())?; + } + + let final_fri_monomials = [final_fri_monomials_c0, final_fri_monomials_c1]; + + Ok(Self { + public_inputs, + + witness_oracle_cap, + stage_2_oracle_cap, + quotient_oracle_cap, + final_fri_monomials, + + values_at_z, + values_at_z_omega, + values_at_0, + + fri_base_oracle_cap, + fri_intermediate_oracles_caps, + + queries_per_fri_repetition, + + pow_challenge: pow_challenge_boolean, + }) + } +} + +pub fn allocate_gl_array, const N: usize>(cs: &mut CS, source: Option<[GL; N]>) -> Result<[GoldilocksField; N], SynthesisError> { + let mut result = [GoldilocksField::zero(); N]; + + let mut source_it = source.map(|s| s.into_iter()); + + for i in 0..N { + let el = source_it.as_mut().map(|el| el.next().expect("Should be enough elements in the source")); + result[i] = GoldilocksField::alloc_from_field(cs, el)?; + } + + Ok(result) +} diff --git a/crates/snark-wrapper/src/verifier_structs/allocated_queries.rs b/crates/snark-wrapper/src/verifier_structs/allocated_queries.rs new file mode 100644 index 0000000..1ca8d65 --- /dev/null +++ b/crates/snark-wrapper/src/verifier_structs/allocated_queries.rs @@ -0,0 +1,75 @@ +use super::*; + +pub struct AllocatedSingleRoundQueries> { + // we need query for witness, setup, stage 2 and quotient + pub witness_query: AllocatedOracleQuery, + pub stage_2_query: AllocatedOracleQuery, + pub quotient_query: AllocatedOracleQuery, + pub setup_query: AllocatedOracleQuery, + + pub fri_queries: Vec>, +} + +impl, H: CircuitGLTreeHasher, NonCircuitSimulator = HS>> AllocatedSingleRoundQueries { + pub fn allocate_from_witness>( + cs: &mut CS, + witness: Option>, + verifier: &WrapperVerifier, + fixed_parameters: &VerificationKeyCircuitGeometry, + proof_config: &ProofConfig, + ) -> Result { + let base_oracle_depth = fixed_parameters.base_oracles_depth(); + let constants = ConstantsHolder::generate(proof_config, verifier, fixed_parameters); + + let witness_leaf_size = constants.witness_leaf_size; + let witness_query = AllocatedOracleQuery::allocate_from_witness(cs, witness.as_ref().map(|el| el.witness_query.clone()), witness_leaf_size, base_oracle_depth)?; + + let stage_2_leaf_size = constants.stage_2_leaf_size; + let stage_2_query = AllocatedOracleQuery::allocate_from_witness(cs, witness.as_ref().map(|el| el.stage_2_query.clone()), stage_2_leaf_size, base_oracle_depth)?; + + let quotient_leaf_size = constants.quotient_leaf_size; + let quotient_query = AllocatedOracleQuery::allocate_from_witness(cs, witness.as_ref().map(|el| el.quotient_query.clone()), quotient_leaf_size, base_oracle_depth)?; + + let setup_leaf_size = constants.setup_leaf_size; + let setup_query = AllocatedOracleQuery::allocate_from_witness(cs, witness.as_ref().map(|el| el.setup_query.clone()), setup_leaf_size, base_oracle_depth)?; + + // fri is a little bit more involved + let mut expected_fri_query_len = base_oracle_depth; + let interpolation_schedule = constants.fri_folding_schedule; + let mut fri_queries = Vec::with_capacity(interpolation_schedule.len()); + for (idx, interpolation_log_2) in interpolation_schedule.into_iter().enumerate() { + expected_fri_query_len -= interpolation_log_2; + let leaf_size = (1 << interpolation_log_2) * 2; // in extension + let wit = witness.as_ref().map(|el| el.fri_queries[idx].clone()); + let query = AllocatedOracleQuery::allocate_from_witness(cs, wit, leaf_size, expected_fri_query_len)?; + fri_queries.push(query); + } + + Ok(Self { + witness_query, + stage_2_query, + quotient_query, + setup_query, + fri_queries, + }) + } +} + +pub struct AllocatedOracleQuery> { + pub leaf_elements: Vec>, + pub proof: Vec, +} + +impl, H: CircuitGLTreeHasher, NonCircuitSimulator = HS>> AllocatedOracleQuery { + pub fn allocate_from_witness>(cs: &mut CS, witness: Option>, leaf_size: usize, proof_depth: usize) -> Result { + let num_elements = leaf_size; + let leaf_elements = witness.as_ref().map(|el| el.leaf_elements.iter().copied()); + let leaf_elements = allocate_num_elements(cs, num_elements, leaf_elements, GoldilocksField::alloc_from_field)?; + + let num_elements = proof_depth; + let proof = witness.as_ref().map(|el| el.proof.iter().cloned()); + let proof = allocate_num_elements(cs, num_elements, proof, Num::alloc)?; + + Ok(Self { leaf_elements, proof }) + } +} diff --git a/crates/snark-wrapper/src/verifier_structs/allocated_vk.rs b/crates/snark-wrapper/src/verifier_structs/allocated_vk.rs new file mode 100644 index 0000000..d5b1003 --- /dev/null +++ b/crates/snark-wrapper/src/verifier_structs/allocated_vk.rs @@ -0,0 +1,49 @@ +use super::*; + +use crate::boojum::cs::implementations::verifier::VerificationKey; + +#[derive(Clone, Debug)] +pub struct AllocatedVerificationKey> { + pub setup_merkle_tree_cap: Vec, +} + +impl, H: CircuitGLTreeHasher, NonCircuitSimulator = HS>> AllocatedVerificationKey { + pub fn allocate_from_witness>( + cs: &mut CS, + witness: Option>, + vk_fixed_parameters: &VerificationKeyCircuitGeometry, + ) -> Result { + if let Some(VerificationKey { + setup_merkle_tree_cap, + fixed_parameters, + }) = witness + { + assert_eq!(vk_fixed_parameters, &fixed_parameters); + + // allocate fixed length + assert!(fixed_parameters.cap_size > 0); + let cap = allocate_num_elements(cs, fixed_parameters.cap_size, Some(setup_merkle_tree_cap.into_iter()), Num::alloc)?; + + Ok(Self { setup_merkle_tree_cap: cap }) + } else { + let cap = allocate_num_elements(cs, vk_fixed_parameters.cap_size, None::>, Num::alloc)?; + + Ok(Self { setup_merkle_tree_cap: cap }) + } + } + + pub fn allocate_constant(witness: &VerificationKey, vk_fixed_parameters: &VerificationKeyCircuitGeometry) -> Self { + let VerificationKey { + setup_merkle_tree_cap, + fixed_parameters, + } = witness; + + assert_eq!(vk_fixed_parameters, fixed_parameters); + + // allocate fixed length + assert!(fixed_parameters.cap_size > 0); + let cap = setup_merkle_tree_cap.iter().map(|x| Num::Constant(*x)).collect(); + + Self { setup_merkle_tree_cap: cap } + } +} diff --git a/crates/snark-wrapper/src/verifier_structs/challenges.rs b/crates/snark-wrapper/src/verifier_structs/challenges.rs new file mode 100644 index 0000000..b48d1e0 --- /dev/null +++ b/crates/snark-wrapper/src/verifier_structs/challenges.rs @@ -0,0 +1,196 @@ +use crate::traits::transcript::CircuitGLTranscript; + +use super::allocated_proof::AllocatedProof; +use super::*; +use crate::boojum::field::traits::field_like::PrimeFieldLike; + +use crate::franklin_crypto::bellman::plonk::better_better_cs::cs::ConstraintSystem; +use crate::franklin_crypto::plonk::circuit::goldilocks::prime_field_like::GoldilocksExtAsFieldWrapper; +use crate::franklin_crypto::plonk::circuit::goldilocks::prime_field_like::*; + +use crate::verifier::utils::materialize_powers_serial; + +pub(crate) struct ChallengesHolder> { + pub(crate) beta: GoldilocksExtAsFieldWrapper, + pub(crate) gamma: GoldilocksExtAsFieldWrapper, + pub(crate) lookup_beta: GoldilocksExtAsFieldWrapper, + pub(crate) lookup_gamma: GoldilocksExtAsFieldWrapper, + + pub(crate) alpha: GoldilocksExtAsFieldWrapper, + pub(crate) pregenerated_challenges_for_lookup: Vec>, + pub(crate) pregenerated_challenges_for_gates_over_specialized_columns: Vec>, + pub(crate) pregenerated_challenges_for_gates_over_general_purpose_columns: Vec>, + pub(crate) remaining_challenges: Vec>, + + pub(crate) z: GoldilocksExtAsFieldWrapper, + pub(crate) z_omega: GoldilocksExtAsFieldWrapper, + + pub(crate) challenges_for_fri_quotiening: Vec>, + pub(crate) fri_intermediate_challenges: Vec>>, + // pub(crate) challenges: Vec>, +} + +impl + 'static> ChallengesHolder { + pub fn new(cs: &mut CS) -> Self { + Self { + beta: GoldilocksExtAsFieldWrapper::zero(cs), + gamma: GoldilocksExtAsFieldWrapper::zero(cs), + lookup_beta: GoldilocksExtAsFieldWrapper::zero(cs), + lookup_gamma: GoldilocksExtAsFieldWrapper::zero(cs), + + alpha: GoldilocksExtAsFieldWrapper::zero(cs), + pregenerated_challenges_for_lookup: vec![], + pregenerated_challenges_for_gates_over_specialized_columns: vec![], + pregenerated_challenges_for_gates_over_general_purpose_columns: vec![], + remaining_challenges: vec![], + + z: GoldilocksExtAsFieldWrapper::zero(cs), + z_omega: GoldilocksExtAsFieldWrapper::zero(cs), + + challenges_for_fri_quotiening: vec![], + fri_intermediate_challenges: vec![], + } + } + + pub fn get_beta_gamma_challenges>(&mut self, cs: &mut CS, transcript: &mut T, verifier: &WrapperVerifier) -> Result<(), SynthesisError> { + let beta = transcript.get_multiple_challenges_fixed::<_, 2>(cs)?; + self.beta = GoldilocksExtAsFieldWrapper::from_coeffs_in_base(beta); + + let gamma = transcript.get_multiple_challenges_fixed::<_, 2>(cs)?; + self.gamma = GoldilocksExtAsFieldWrapper::from_coeffs_in_base(gamma); + + (self.lookup_beta, self.lookup_gamma) = if verifier.lookup_parameters != LookupParameters::NoLookup { + // lookup argument related parts + let lookup_beta = transcript.get_multiple_challenges_fixed::<_, 2>(cs)?; + let lookup_beta = GoldilocksExtAsFieldWrapper::from_coeffs_in_base(lookup_beta); + let lookup_gamma = transcript.get_multiple_challenges_fixed::<_, 2>(cs)?; + let lookup_gamma = GoldilocksExtAsFieldWrapper::from_coeffs_in_base(lookup_gamma); + + (lookup_beta, lookup_gamma) + } else { + let zero_ext = GoldilocksExtAsFieldWrapper::zero(cs); + (zero_ext, zero_ext) + }; + + Ok(()) + } + + pub fn get_alpha_powers>(&mut self, cs: &mut CS, transcript: &mut T, constants: &ConstantsHolder) -> Result<(), SynthesisError> { + let alpha = transcript.get_multiple_challenges_fixed::<_, 2>(cs)?; + self.alpha = GoldilocksExtAsFieldWrapper::from_coeffs_in_base(alpha); + + let powers: Vec<_> = materialize_powers_serial(cs, self.alpha, constants.total_num_terms); + let rest = &powers[..]; + let (take, rest) = rest.split_at(constants.total_num_lookup_argument_terms); + self.pregenerated_challenges_for_lookup = take.to_vec(); + let (take, rest) = rest.split_at(constants.total_num_gate_terms_for_specialized_columns); + self.pregenerated_challenges_for_gates_over_specialized_columns = take.to_vec(); + let (take, rest) = rest.split_at(constants.total_num_gate_terms_for_general_purpose_columns); + self.pregenerated_challenges_for_gates_over_general_purpose_columns = take.to_vec(); + self.remaining_challenges = rest.to_vec(); + + Ok(()) + } + + pub fn get_z_challenge>(&mut self, cs: &mut CS, transcript: &mut T, fixed_parameters: &VerificationKeyCircuitGeometry) -> Result<(), SynthesisError> { + let z = transcript.get_multiple_challenges_fixed::<_, 2>(cs)?; + self.z = GoldilocksExtAsFieldWrapper::from_coeffs_in_base(z); + + use crate::boojum::cs::implementations::utils::domain_generator_for_size; + let omega = domain_generator_for_size::(fixed_parameters.domain_size as u64); + let omega_cs_constant = GoldilocksAsFieldWrapper::constant(omega, cs); + self.z_omega = self.z; + self.z_omega.mul_assign_by_base(cs, &omega_cs_constant)?; + + Ok(()) + } + + pub fn get_challenges_for_fri_quotiening>(&mut self, cs: &mut CS, transcript: &mut T, total_num_challenges: usize) -> Result<(), SynthesisError> { + // get challenges + let c0 = transcript.get_challenge(cs)?; + let c1 = transcript.get_challenge(cs)?; + + let challenge = GoldilocksExtAsFieldWrapper::from_coeffs_in_base([c0, c1]); + + self.challenges_for_fri_quotiening = crate::verifier::utils::materialize_powers_serial(cs, challenge, total_num_challenges); + + Ok(()) + } + + pub fn get_fri_intermediate_challenges, TR: CircuitGLTranscript>( + &mut self, + cs: &mut CS, + transcript: &mut TR, + proof: &AllocatedProof, + fixed_parameters: &VerificationKeyCircuitGeometry, + constants: &ConstantsHolder, + ) -> Result<(), SynthesisError> { + { + // now witness base FRI oracle + assert_eq!(fixed_parameters.cap_size, proof.fri_base_oracle_cap.len()); + transcript.witness_merkle_tree_cap(cs, &proof.fri_base_oracle_cap)?; + + let reduction_degree_log_2 = constants.fri_folding_schedule[0]; + + let c0 = transcript.get_challenge(cs)?; + let c1 = transcript.get_challenge(cs)?; + + let mut challenge_powers = Vec::with_capacity(reduction_degree_log_2); + let as_extension = GoldilocksExtAsFieldWrapper::from_coeffs_in_base([c0, c1]); + challenge_powers.push(as_extension); + + let mut current = as_extension; + + for _ in 1..reduction_degree_log_2 { + current.square(cs); + challenge_powers.push(current); + } + + self.fri_intermediate_challenges.push(challenge_powers); + } + + assert_eq!(constants.fri_folding_schedule[1..].len(), proof.fri_intermediate_oracles_caps.len()); + + for (interpolation_degree_log2, cap) in constants.fri_folding_schedule[1..].iter().zip(proof.fri_intermediate_oracles_caps.iter()) { + // commit new oracle + assert_eq!(fixed_parameters.cap_size, cap.len()); + transcript.witness_merkle_tree_cap(cs, &cap)?; + + // get challenge + let reduction_degree_log_2 = *interpolation_degree_log2; + let c0 = transcript.get_challenge(cs)?; + let c1 = transcript.get_challenge(cs)?; + + let mut challenge_powers = Vec::with_capacity(reduction_degree_log_2); + let as_extension = GoldilocksExtAsFieldWrapper::from_coeffs_in_base([c0, c1]); + challenge_powers.push(as_extension); + + let mut current = as_extension; + + for _ in 1..reduction_degree_log_2 { + current.square(cs); + challenge_powers.push(current); + } + + self.fri_intermediate_challenges.push(challenge_powers); + } + + Ok(()) + } +} + +pub(crate) struct EvaluationsHolder> { + pub(crate) all_values_at_z: Vec>, + pub(crate) all_values_at_z_omega: Vec>, + pub(crate) all_values_at_0: Vec>, +} + +impl> EvaluationsHolder { + pub(crate) fn from_proof>(proof: &AllocatedProof) -> Self { + Self { + all_values_at_z: proof.values_at_z.iter().map(|el| GoldilocksExtAsFieldWrapper::::from_coeffs_in_base(*el)).collect(), + all_values_at_z_omega: proof.values_at_z_omega.iter().map(|el| GoldilocksExtAsFieldWrapper::::from_coeffs_in_base(*el)).collect(), + all_values_at_0: proof.values_at_0.iter().map(|el| GoldilocksExtAsFieldWrapper::::from_coeffs_in_base(*el)).collect(), + } + } +} diff --git a/crates/snark-wrapper/src/verifier_structs/constants.rs b/crates/snark-wrapper/src/verifier_structs/constants.rs new file mode 100644 index 0000000..7f30763 --- /dev/null +++ b/crates/snark-wrapper/src/verifier_structs/constants.rs @@ -0,0 +1,221 @@ +use super::*; + +use crate::boojum::cs::implementations::verifier::{SizeCalculator, VerificationKeyCircuitGeometry}; +use crate::boojum::field::goldilocks::{GoldilocksExt2 as GLExt2, GoldilocksField as GL}; + +#[derive(Clone, Default, Debug)] +pub(crate) struct ConstantsHolder { + // quotient parameters + pub(crate) quotient_degree: usize, + pub(crate) num_lookup_subarguments: usize, + pub(crate) num_variable_polys: usize, + pub(crate) num_witness_polys: usize, + pub(crate) num_constant_polys: usize, + pub(crate) num_multiplicities_polys: usize, + pub(crate) num_copy_permutation_polys: usize, + pub(crate) num_lookup_table_setup_polys: usize, + pub(crate) num_intermediate_partial_product_relations: usize, + pub(crate) total_num_gate_terms_for_specialized_columns: usize, + pub(crate) total_num_gate_terms_for_general_purpose_columns: usize, + pub(crate) total_num_lookup_argument_terms: usize, + pub(crate) total_num_terms: usize, + + // commitments parameters + pub(crate) witness_leaf_size: usize, + pub(crate) stage_2_leaf_size: usize, + pub(crate) quotient_leaf_size: usize, + pub(crate) setup_leaf_size: usize, + + // opening parameters + pub(crate) num_poly_values_at_z: usize, + pub(crate) num_poly_values_at_z_omega: usize, + pub(crate) num_poly_values_at_zero: usize, + pub(crate) num_public_inputs: usize, + + // fri parameters + pub(crate) new_pow_bits: usize, + pub(crate) num_fri_repetitions: usize, + pub(crate) fri_folding_schedule: Vec, + pub(crate) final_expected_degree: usize, + pub(crate) total_num_challenges_for_fri_quotiening: usize, +} + +impl ConstantsHolder { + pub fn generate>(proof_config: &ProofConfig, verifier: &WrapperVerifier, fixed_parameters: &VerificationKeyCircuitGeometry) -> Self { + assert_eq!(verifier.parameters, fixed_parameters.parameters); + assert_eq!(verifier.lookup_parameters, fixed_parameters.lookup_parameters); + assert!(proof_config.fri_folding_schedule.is_none()); + assert_eq!(fixed_parameters.cap_size, proof_config.merkle_tree_cap_size); + assert_eq!(fixed_parameters.fri_lde_factor, proof_config.fri_lde_factor,); + + let mut result = Self::default(); + + result.quotient_degree = SizeCalculator::::quotient_degree(fixed_parameters); + result.num_lookup_subarguments = SizeCalculator::::num_sublookup_arguments(&verifier.parameters, &verifier.lookup_parameters); + result.num_variable_polys = SizeCalculator::::num_variable_polys(&verifier.parameters, verifier.total_num_variables_for_specialized_columns); + result.num_witness_polys = SizeCalculator::::num_witness_polys(&verifier.parameters, verifier.total_num_witnesses_for_specialized_columns); + result.num_constant_polys = SizeCalculator::::num_constant_polys(&verifier.parameters, fixed_parameters, verifier.total_num_constants_for_specialized_columns); + result.num_multiplicities_polys = + SizeCalculator::::num_multipicities_polys(&verifier.lookup_parameters, fixed_parameters.total_tables_len as usize, fixed_parameters.domain_size); + result.num_copy_permutation_polys = result.num_variable_polys; + + result.num_lookup_table_setup_polys = SizeCalculator::::num_lookup_table_setup_polys(&verifier.lookup_parameters); + + result.witness_leaf_size = SizeCalculator::::witness_leaf_size( + &verifier.parameters, + &verifier.lookup_parameters, + fixed_parameters, + verifier.total_num_variables_for_specialized_columns, + verifier.total_num_witnesses_for_specialized_columns, + ); + result.stage_2_leaf_size = SizeCalculator::::stage_2_leaf_size( + &verifier.parameters, + &verifier.lookup_parameters, + fixed_parameters, + verifier.total_num_variables_for_specialized_columns, + ); + result.quotient_leaf_size = SizeCalculator::::quotient_leaf_size(fixed_parameters); + result.setup_leaf_size = SizeCalculator::::setup_leaf_size( + &verifier.parameters, + &verifier.lookup_parameters, + fixed_parameters, + verifier.total_num_variables_for_specialized_columns, + verifier.total_num_constants_for_specialized_columns, + ); + + result.total_num_lookup_argument_terms = result.num_lookup_subarguments + result.num_multiplicities_polys; + + use crate::boojum::cs::implementations::copy_permutation::num_intermediate_partial_product_relations; + result.num_intermediate_partial_product_relations = num_intermediate_partial_product_relations(result.num_copy_permutation_polys, result.quotient_degree); + + result.compute_num_gate_terms(verifier); + result.compute_total_num_terms(); + + result.compute_num_poly_values_at_z(fixed_parameters); + result.compute_num_poly_values_at_z_omega(); + result.compute_num_poly_values_at_zero(); + result.num_public_inputs = fixed_parameters.public_inputs_locations.len(); + + result.compute_fri_parameters(fixed_parameters, proof_config); + + result.compute_total_num_challenges_for_fri_quotiening(verifier); + + result + } + + fn compute_num_gate_terms>(&mut self, verifier: &WrapperVerifier) { + assert_eq!(verifier.evaluators_over_specialized_columns.len(), verifier.gate_type_ids_for_specialized_columns.len()); + + self.total_num_gate_terms_for_specialized_columns = verifier + .evaluators_over_specialized_columns + .iter() + .zip(verifier.gate_type_ids_for_specialized_columns.iter()) + .map(|(evaluator, gate_type_id)| { + let placement_strategy = verifier.placement_strategies.get(gate_type_id).copied().expect("gate must be allowed"); + let num_repetitions = match placement_strategy { + GatePlacementStrategy::UseSpecializedColumns { num_repetitions, .. } => num_repetitions, + _ => unreachable!(), + }; + assert_eq!(evaluator.num_repetitions_on_row, num_repetitions); + let terms_per_repetition = evaluator.num_quotient_terms; + + terms_per_repetition * num_repetitions + }) + .sum(); + + self.total_num_gate_terms_for_general_purpose_columns = verifier + .evaluators_over_general_purpose_columns + .iter() + .map(|evaluator| evaluator.total_quotient_terms_over_all_repetitions) + .sum(); + } + + fn compute_total_num_terms(&mut self) { + self.total_num_terms = self.total_num_lookup_argument_terms // and lookup is first + + self.total_num_gate_terms_for_specialized_columns // then gates over specialized columns + + self.total_num_gate_terms_for_general_purpose_columns // all getes terms over general purpose columns + + 1 // z(1) == 1 copy permutation + + 1 // z(x * omega) = ... + + self.num_intermediate_partial_product_relations; // chunking copy permutation part; + } + + fn compute_num_poly_values_at_z(&mut self, fixed_parameters: &VerificationKeyCircuitGeometry) { + let expected_lookup_polys_total = if fixed_parameters.lookup_parameters.lookup_is_allowed() { + self.num_lookup_subarguments + // lookup witness encoding polys + self.num_multiplicities_polys * 2 + // multiplicity and multiplicity encoding + fixed_parameters.lookup_parameters.lookup_width() + // encode tables itself + 1 // encode table IDs + } else { + 0 + }; + + self.num_poly_values_at_z = self.num_variable_polys + self.num_witness_polys + + self.num_constant_polys + self.num_copy_permutation_polys + + 1 + // z_poly + self.num_intermediate_partial_product_relations + // partial products in copy-permutation + expected_lookup_polys_total + // everything from lookup + self.quotient_degree; // chunks of quotient poly + } + + fn compute_num_poly_values_at_z_omega(&mut self) { + self.num_poly_values_at_z_omega = 1; + } + + fn compute_num_poly_values_at_zero(&mut self) { + self.num_poly_values_at_zero = self.num_lookup_subarguments + self.num_multiplicities_polys; + } + + fn compute_total_num_challenges_for_fri_quotiening>(&mut self, verifier: &WrapperVerifier) { + let expected_lookup_polys_total = if verifier.lookup_parameters.lookup_is_allowed() { + self.num_lookup_subarguments + // lookup witness encoding polys + self.num_multiplicities_polys * 2 + // multiplicity and multiplicity encoding + verifier.lookup_parameters.lookup_width() + // encode tables itself + 1 // encode table IDs + } else { + 0 + }; + + let num_poly_values_at_z = self.num_variable_polys + self.num_witness_polys + + self.num_constant_polys + self.num_copy_permutation_polys + + 1 + // z_poly + self.num_intermediate_partial_product_relations + // partial products in copy-permutation + expected_lookup_polys_total + // everything from lookup + self.quotient_degree; // chunks of quotient poly + + let mut total_num_challenges = 0; + total_num_challenges += num_poly_values_at_z; + total_num_challenges += 1; + total_num_challenges += self.total_num_lookup_argument_terms; + total_num_challenges += self.num_public_inputs; + + self.total_num_challenges_for_fri_quotiening = total_num_challenges; + } + + fn compute_fri_parameters(&mut self, fixed_parameters: &VerificationKeyCircuitGeometry, proof_config: &ProofConfig) { + let ( + new_pow_bits, // updated POW bits if needed + num_queries, // num queries + interpolation_log2s_schedule, // folding schedule + final_expected_degree, + ) = crate::boojum::cs::implementations::prover::compute_fri_schedule( + proof_config.security_level as u32, + proof_config.merkle_tree_cap_size, + proof_config.pow_bits, + fixed_parameters.fri_lde_factor.trailing_zeros(), + fixed_parameters.domain_size.trailing_zeros(), + ); + + let mut expected_degree = fixed_parameters.domain_size; + + for interpolation_degree_log2 in interpolation_log2s_schedule.iter() { + expected_degree >>= interpolation_degree_log2; + } + + assert_eq!(final_expected_degree, expected_degree as usize); + + self.new_pow_bits = new_pow_bits as usize; + self.num_fri_repetitions = num_queries; + self.fri_folding_schedule = interpolation_log2s_schedule; + self.final_expected_degree = final_expected_degree; + } +} diff --git a/crates/snark-wrapper/src/verifier_structs/gate_evaluator.rs b/crates/snark-wrapper/src/verifier_structs/gate_evaluator.rs new file mode 100644 index 0000000..439638b --- /dev/null +++ b/crates/snark-wrapper/src/verifier_structs/gate_evaluator.rs @@ -0,0 +1,189 @@ +use super::*; +use derivative::*; + +use crate::boojum::cs::implementations::verifier::VerifierPolyStorage; +use crate::boojum::cs::implementations::verifier::VerifierRelationDestination; +use crate::boojum::cs::traits::evaluator::GatePlacementType; +use crate::boojum::cs::traits::evaluator::GatePurpose; +use crate::boojum::cs::traits::evaluator::GenericColumnwiseEvaluator; +use crate::boojum::cs::traits::evaluator::GenericDynamicEvaluatorOverGeneralPurposeColumns; +use crate::boojum::cs::traits::evaluator::GenericDynamicEvaluatorOverSpecializedColumns; +use crate::boojum::cs::traits::evaluator::GenericRowwiseEvaluator; + +use crate::franklin_crypto::plonk::circuit::goldilocks::prime_field_like::GoldilocksExtAsFieldWrapper; + +#[derive(Derivative)] +#[derivative(Debug)] +pub(crate) struct TypeErasedGateEvaluationWrapperVerificationFunction + 'static> { + pub(crate) debug_name: String, + pub(crate) evaluator_type_id: TypeId, + pub(crate) gate_purpose: GatePurpose, + pub(crate) max_constraint_degree: usize, + pub(crate) num_quotient_terms: usize, + pub(crate) num_required_constants: usize, + pub(crate) total_quotient_terms_over_all_repetitions: usize, + pub(crate) num_repetitions_on_row: usize, + pub(crate) placement_type: GatePlacementType, + #[derivative(Debug = "ignore")] + pub(crate) columnwise_satisfiability_function: Option< + Box< + dyn GenericDynamicEvaluatorOverSpecializedColumns< + GL, + GoldilocksExtAsFieldWrapper, + VerifierPolyStorage>, + VerifierRelationDestination>, + > + + 'static + + Send + + Sync, + >, + >, + #[derivative(Debug = "ignore")] + pub(crate) rowwise_satisfiability_function: Option< + Box< + dyn GenericDynamicEvaluatorOverGeneralPurposeColumns< + GL, + GoldilocksExtAsFieldWrapper, + VerifierPolyStorage>, + VerifierRelationDestination>, + > + + 'static + + Send + + Sync, + >, + >, +} + +use crate::boojum::cs::traits::evaluator::GateBatchEvaluationComparisonFunction; +use crate::boojum::cs::traits::evaluator::GateConstraintEvaluator; +impl + 'static> TypeErasedGateEvaluationWrapperVerificationFunction { + pub fn from_evaluator>( + cs: &mut CS, + evaluator: EV, + geometry: &CSGeometry, + placement_strategy: GatePlacementStrategy, + ) -> (Self, GateBatchEvaluationComparisonFunction) { + let debug_name = evaluator.instance_name(); + let evaluator_type_id = std::any::TypeId::of::(); + let gate_purpose = EV::gate_purpose(); + let max_constraint_degree = EV::max_constraint_degree(); + let num_quotient_terms = EV::num_quotient_terms(); + let num_required_constants = evaluator.num_required_constants_in_geometry(geometry); + let placement_type = evaluator.placement_type(); + let mut final_per_chunk_offset = PerChunkOffset::zero(); + let (num_repetitions_on_row, total_quotient_terms_over_all_repetitions) = match placement_strategy { + GatePlacementStrategy::UseGeneralPurposeColumns => { + let num_repetitions_on_row = evaluator.num_repetitions_in_geometry(geometry); + if let GatePlacementType::MultipleOnRow { per_chunk_offset } = &placement_type { + debug_assert!(num_repetitions_on_row > 0, "invalid for evaluator {}", std::any::type_name::()); + final_per_chunk_offset = *per_chunk_offset; + } else { + debug_assert_eq!(num_repetitions_on_row, 1); + } + + let total_quotient_terms_in_geometry = evaluator.total_quotient_terms_in_geometry(geometry); + + (num_repetitions_on_row, total_quotient_terms_in_geometry) + } + GatePlacementStrategy::UseSpecializedColumns { num_repetitions, share_constants } => { + let principal_width = evaluator.instance_width(); + final_per_chunk_offset = PerChunkOffset { + variables_offset: principal_width.num_variables, + witnesses_offset: principal_width.num_witnesses, + constants_offset: principal_width.num_constants, + }; + if share_constants { + final_per_chunk_offset.constants_offset = 0; + } + + (num_repetitions, num_repetitions * num_quotient_terms) + } + }; + + let (specialized_satisfiability_evaluator, general_purpose_satisfiability_evaluator) = match placement_strategy { + GatePlacementStrategy::UseSpecializedColumns { .. } => { + let specialized_evaluator = GenericColumnwiseEvaluator { + evaluator: evaluator.clone(), + global_constants: evaluator.create_global_constants(cs), + num_repetitions: num_repetitions_on_row, + per_chunk_offset: final_per_chunk_offset, + }; + + // dbg!(&specialized_evaluator); + + ( + Some(Box::new(specialized_evaluator) + as Box< + dyn GenericDynamicEvaluatorOverSpecializedColumns< + GL, + GoldilocksExtAsFieldWrapper, + VerifierPolyStorage>, + VerifierRelationDestination>, + > + + 'static + + Send + + Sync, + >), + None, + ) + } + GatePlacementStrategy::UseGeneralPurposeColumns => { + let general_purpose_evaluator = GenericRowwiseEvaluator { + evaluator: evaluator.clone(), + global_constants: evaluator.create_global_constants(cs), + num_repetitions: num_repetitions_on_row, + per_chunk_offset: final_per_chunk_offset, + }; + + // dbg!(&general_purpose_evaluator); + + ( + None, + Some(Box::new(general_purpose_evaluator) + as Box< + dyn GenericDynamicEvaluatorOverGeneralPurposeColumns< + GL, + GoldilocksExtAsFieldWrapper, + VerifierPolyStorage>, + VerifierRelationDestination>, + > + + 'static + + Send + + Sync, + >), + ) + } + }; + + let this_params = evaluator.unique_params(); + + let comparison_fn = move |other_evaluator: &dyn std::any::Any| -> bool { + assert_eq!(other_evaluator.type_id(), evaluator_type_id); + let other_evaluator: &EV = other_evaluator.downcast_ref().expect("must downcast"); + + this_params == other_evaluator.unique_params() + }; + + let comparator = GateBatchEvaluationComparisonFunction { + type_id: evaluator_type_id, + evaluator_dyn: Box::new(evaluator), + equality_fn: Box::new(comparison_fn), + }; + + let new = Self { + debug_name, + evaluator_type_id, + gate_purpose, + max_constraint_degree, + num_quotient_terms, + num_required_constants, + total_quotient_terms_over_all_repetitions, + num_repetitions_on_row, + placement_type, + columnwise_satisfiability_function: specialized_satisfiability_evaluator, + rowwise_satisfiability_function: general_purpose_satisfiability_evaluator, + }; + + (new, comparator) + } +} diff --git a/crates/snark-wrapper/src/verifier_structs/mod.rs b/crates/snark-wrapper/src/verifier_structs/mod.rs new file mode 100644 index 0000000..c7a7ef7 --- /dev/null +++ b/crates/snark-wrapper/src/verifier_structs/mod.rs @@ -0,0 +1,67 @@ +use crate::franklin_crypto::plonk::circuit::goldilocks::GoldilocksField; + +use crate::boojum::cs::implementations::proof::Proof; +use crate::boojum::cs::implementations::proof::{OracleQuery, SingleRoundQueries}; +use crate::boojum::cs::implementations::prover::ProofConfig; +use crate::boojum::cs::implementations::verifier::VerificationKeyCircuitGeometry; +use crate::boojum::cs::oracle::TreeHasher; +use crate::boojum::cs::traits::evaluator::PerChunkOffset; +use crate::boojum::cs::traits::gate::GatePlacementStrategy; +use crate::boojum::cs::CSGeometry; +use crate::boojum::cs::LookupParameters; +use crate::boojum::field::goldilocks::{GoldilocksExt2 as GLExt2, GoldilocksField as GL}; + +use crate::franklin_crypto::bellman::pairing::Engine; +use crate::franklin_crypto::bellman::plonk::better_better_cs::cs::ConstraintSystem; +use crate::franklin_crypto::bellman::SynthesisError; +use crate::franklin_crypto::plonk::circuit::allocated_num::Num; +use crate::franklin_crypto::plonk::circuit::boolean::Boolean; + +use crate::traits::tree_hasher::CircuitGLTreeHasher; +use crate::verifier_structs::constants::ConstantsHolder; +use crate::verifier_structs::gate_evaluator::TypeErasedGateEvaluationWrapperVerificationFunction; + +use std::any::TypeId; +use std::collections::HashMap; + +pub mod allocated_proof; +pub mod allocated_queries; +pub mod allocated_vk; +pub mod challenges; +pub mod constants; +pub mod gate_evaluator; + +pub struct WrapperVerifier + 'static> { + // when we init we get the following from VK + pub parameters: CSGeometry, + pub lookup_parameters: LookupParameters, + + pub(crate) gate_type_ids_for_specialized_columns: Vec, + pub(crate) evaluators_over_specialized_columns: Vec>, + pub(crate) offsets_for_specialized_evaluators: Vec<(PerChunkOffset, PerChunkOffset, usize)>, + + pub(crate) evaluators_over_general_purpose_columns: Vec>, + + pub(crate) total_num_variables_for_specialized_columns: usize, + pub(crate) total_num_witnesses_for_specialized_columns: usize, + pub(crate) total_num_constants_for_specialized_columns: usize, + + pub(crate) placement_strategies: HashMap, +} + +pub fn allocate_num_elements>( + cs: &mut CS, + num_elements: usize, + mut source: Option>, + allocating_function: impl Fn(&mut CS, Option) -> Result, +) -> Result, SynthesisError> { + let mut result = Vec::with_capacity(num_elements); + + for _ in 0..num_elements { + let el = source.as_mut().map(|el| el.next().expect("Should be enough elements in the source")); + result.push(allocating_function(cs, el)?); + } + debug_assert!(source.as_mut().map(|el| el.next().is_none()).unwrap_or(true)); + + Ok(result) +}