From 507fafdb728db1ac16386e0a21d4e6998984cab7 Mon Sep 17 00:00:00 2001 From: AztecBot Date: Tue, 21 May 2024 10:41:19 +0000 Subject: [PATCH 1/3] [1 changes] feat: add native rust implementations of pedersen functions (https://github.com/noir-lang/noir/pull/4871) chore: add benchmarks for pedersen and schnorr verification (https://github.com/noir-lang/noir/pull/5056) --- .noir-sync-commit | 2 +- .../.github/workflows/formatting.yml | 2 +- .../.github/workflows/gates_report.yml | 144 +- noir/noir-repo/.tokeignore | 12 + noir/noir-repo/Cargo.lock | 2 +- .../acvm-repo/acir/benches/serialization.rs | 2 +- .../optimizers/constant_backpropagation.rs | 2 +- .../compiler/optimizers/redundant_range.rs | 2 +- noir/noir-repo/acvm-repo/acvm_js/build.sh | 2 +- .../bn254_blackbox_solver/Cargo.toml | 1 + .../benches/criterion.rs | 53 +- .../src/generator/generators.rs | 184 ++ .../src/generator/hash_to_curve.rs | 135 + .../src/generator/mod.rs | 8 + .../bn254_blackbox_solver/src/lib.rs | 24 +- .../src/pedersen/commitment.rs | 76 + .../src/pedersen/hash.rs | 68 + .../bn254_blackbox_solver/src/pedersen/mod.rs | 2 + .../src/wasm/barretenberg_structures.rs | 25 - .../bn254_blackbox_solver/src/wasm/mod.rs | 20 - .../src/wasm/pedersen.rs | 73 - .../compiler/noirc_driver/src/lib.rs | 29 +- .../noirc_driver/tests/stdlib_warnings.rs | 3 +- .../src/brillig/brillig_gen/brillig_block.rs | 87 +- .../noirc_evaluator/src/ssa/acir_gen/mod.rs | 91 +- .../noirc_evaluator/src/ssa/opt/inlining.rs | 8 +- .../src/ssa/opt/remove_bit_shifts.rs | 11 +- .../src/ssa/opt/remove_enable_side_effects.rs | 12 +- .../src/ssa/ssa_gen/context.rs | 49 +- .../noirc_frontend/src/ast/function.rs | 9 + .../compiler/noirc_frontend/src/ast/mod.rs | 3 + .../noirc_frontend/src/ast/statement.rs | 5 +- .../src/elaborator/expressions.rs | 604 ++++ .../noirc_frontend/src/elaborator/mod.rs | 782 ++++++ .../noirc_frontend/src/elaborator/patterns.rs | 465 +++ .../noirc_frontend/src/elaborator/scope.rs | 200 ++ .../src/elaborator/statements.rs | 409 +++ .../noirc_frontend/src/elaborator/types.rs | 1438 ++++++++++ .../src/hir/comptime/interpreter.rs | 64 + .../noirc_frontend/src/hir/comptime/tests.rs | 13 + .../noirc_frontend/src/hir/comptime/value.rs | 13 + .../src/hir/def_collector/dc_crate.rs | 104 +- .../src/hir/def_collector/dc_mod.rs | 16 +- .../noirc_frontend/src/hir/def_map/mod.rs | 10 +- .../src/hir/resolution/import.rs | 9 + .../src/hir/resolution/resolver.rs | 12 +- .../noirc_frontend/src/hir/type_check/expr.rs | 4 +- .../noirc_frontend/src/hir/type_check/mod.rs | 2 +- .../noirc_frontend/src/hir_def/expr.rs | 11 +- .../noirc_frontend/src/hir_def/function.rs | 5 +- .../noirc_frontend/src/hir_def/types.rs | 8 +- .../compiler/noirc_frontend/src/lib.rs | 1 + .../noirc_frontend/src/node_interner.rs | 6 +- .../noirc_frontend/src/parser/parser.rs | 7 +- .../compiler/noirc_frontend/src/tests.rs | 2482 ++++++++--------- .../src/tests/name_shadowing.rs | 419 +++ noir/noir-repo/compiler/wasm/src/compile.rs | 21 +- .../compiler/wasm/src/compile_new.rs | 30 +- noir/noir-repo/cspell.json | 1 + .../docs/noir/concepts/data_types/integers.md | 4 +- .../docs/docs/noir/standard_library/traits.md | 33 +- noir/noir-repo/noir_stdlib/src/aes128.nr | 1 - .../noir_stdlib/src/embedded_curve_ops.nr | 7 +- noir/noir-repo/noir_stdlib/src/ops.nr | 173 +- noir/noir-repo/noir_stdlib/src/ops/arith.nr | 103 + noir/noir-repo/noir_stdlib/src/ops/bit.nr | 109 + noir/noir-repo/noir_stdlib/src/uint128.nr | 270 +- noir/noir-repo/scripts/count_loc.sh | 33 + .../security/insectarium/noir_stdlib.md | 61 + .../brillig_embedded_curve/src/main.nr | 6 +- .../no_predicates_brillig/Nargo.toml | 7 + .../no_predicates_brillig/Prover.toml | 2 + .../no_predicates_brillig/src/main.nr | 16 + .../execution_success/u16_support/Nargo.toml | 7 + .../execution_success/u16_support/Prover.toml | 1 + .../execution_success/u16_support/src/main.nr | 24 + .../tooling/backend_interface/Cargo.toml | 1 - .../tooling/backend_interface/src/cli/info.rs | 62 - .../tooling/backend_interface/src/cli/mod.rs | 2 - .../backend_interface/src/proof_system.rs | 25 +- .../mock_backend/src/info_cmd.rs | 40 - .../test-binaries/mock_backend/src/main.rs | 3 - .../tooling/bb_abstraction_leaks/build.rs | 2 +- noir/noir-repo/tooling/lsp/src/lib.rs | 2 +- .../tooling/lsp/src/notifications/mod.rs | 4 +- .../lsp/src/requests/code_lens_request.rs | 2 +- .../lsp/src/requests/goto_declaration.rs | 2 +- .../lsp/src/requests/goto_definition.rs | 2 +- .../tooling/lsp/src/requests/test_run.rs | 2 +- .../tooling/lsp/src/requests/tests.rs | 2 +- .../tooling/nargo_cli/src/cli/check_cmd.rs | 11 +- .../nargo_cli/src/cli/codegen_verifier_cmd.rs | 3 +- .../tooling/nargo_cli/src/cli/compile_cmd.rs | 17 +- .../tooling/nargo_cli/src/cli/dap_cmd.rs | 18 +- .../tooling/nargo_cli/src/cli/debug_cmd.rs | 14 +- .../tooling/nargo_cli/src/cli/execute_cmd.rs | 14 +- .../tooling/nargo_cli/src/cli/export_cmd.rs | 8 +- .../tooling/nargo_cli/src/cli/info_cmd.rs | 23 +- .../tooling/nargo_cli/src/cli/lsp_cmd.rs | 8 +- .../tooling/nargo_cli/src/cli/mod.rs | 18 +- .../tooling/nargo_cli/src/cli/new_cmd.rs | 8 +- .../tooling/nargo_cli/src/cli/prove_cmd.rs | 7 +- .../tooling/nargo_cli/src/cli/test_cmd.rs | 10 +- .../tooling/nargo_cli/src/cli/verify_cmd.rs | 7 +- .../tooling/nargo_cli/tests/stdlib-tests.rs | 17 +- .../tooling/noir_js/test/node/execute.test.ts | 12 + .../noir_js_backend_barretenberg/package.json | 2 +- noir/noir-repo/tooling/noirc_abi/src/lib.rs | 10 +- noir/noir-repo/yarn.lock | 13 +- 109 files changed, 7370 insertions(+), 2140 deletions(-) create mode 100644 noir/noir-repo/.tokeignore create mode 100644 noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/generator/generators.rs create mode 100644 noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/generator/hash_to_curve.rs create mode 100644 noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/generator/mod.rs create mode 100644 noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/pedersen/commitment.rs create mode 100644 noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/pedersen/hash.rs create mode 100644 noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/pedersen/mod.rs delete mode 100644 noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/wasm/barretenberg_structures.rs delete mode 100644 noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/wasm/pedersen.rs create mode 100644 noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs create mode 100644 noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs create mode 100644 noir/noir-repo/compiler/noirc_frontend/src/elaborator/patterns.rs create mode 100644 noir/noir-repo/compiler/noirc_frontend/src/elaborator/scope.rs create mode 100644 noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs create mode 100644 noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs create mode 100644 noir/noir-repo/compiler/noirc_frontend/src/tests/name_shadowing.rs create mode 100644 noir/noir-repo/noir_stdlib/src/ops/arith.nr create mode 100644 noir/noir-repo/noir_stdlib/src/ops/bit.nr create mode 100755 noir/noir-repo/scripts/count_loc.sh create mode 100644 noir/noir-repo/security/insectarium/noir_stdlib.md create mode 100644 noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Nargo.toml create mode 100644 noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Prover.toml create mode 100644 noir/noir-repo/test_programs/execution_success/no_predicates_brillig/src/main.nr create mode 100644 noir/noir-repo/test_programs/execution_success/u16_support/Nargo.toml create mode 100644 noir/noir-repo/test_programs/execution_success/u16_support/Prover.toml create mode 100644 noir/noir-repo/test_programs/execution_success/u16_support/src/main.nr delete mode 100644 noir/noir-repo/tooling/backend_interface/src/cli/info.rs delete mode 100644 noir/noir-repo/tooling/backend_interface/test-binaries/mock_backend/src/info_cmd.rs diff --git a/.noir-sync-commit b/.noir-sync-commit index 61a3851ea0c..4195c98aff3 100644 --- a/.noir-sync-commit +++ b/.noir-sync-commit @@ -1 +1 @@ -c49d3a9ded819b828cffdfc031e86614da21e329 +fb039f74df23aea39bc0593a5d538d82b4efadf0 diff --git a/noir/noir-repo/.github/workflows/formatting.yml b/noir/noir-repo/.github/workflows/formatting.yml index 8166fb0f7c2..08c02af519f 100644 --- a/noir/noir-repo/.github/workflows/formatting.yml +++ b/noir/noir-repo/.github/workflows/formatting.yml @@ -44,7 +44,7 @@ jobs: save-if: ${{ github.event_name != 'merge_group' }} - name: Run `cargo clippy` - run: cargo clippy --workspace --locked --release + run: cargo clippy --all-targets --workspace --locked --release - name: Run `cargo fmt` run: cargo fmt --all --check diff --git a/noir/noir-repo/.github/workflows/gates_report.yml b/noir/noir-repo/.github/workflows/gates_report.yml index ba4cb600c59..3d4bef1940e 100644 --- a/noir/noir-repo/.github/workflows/gates_report.yml +++ b/noir/noir-repo/.github/workflows/gates_report.yml @@ -1,88 +1,88 @@ -name: Report gates diff +# name: Report gates diff -on: - push: - branches: - - master - pull_request: +# on: +# push: +# branches: +# - master +# pull_request: -jobs: - build-nargo: - runs-on: ubuntu-latest - strategy: - matrix: - target: [x86_64-unknown-linux-gnu] +# jobs: +# build-nargo: +# runs-on: ubuntu-latest +# strategy: +# matrix: +# target: [x86_64-unknown-linux-gnu] - steps: - - name: Checkout Noir repo - uses: actions/checkout@v4 +# steps: +# - name: Checkout Noir repo +# uses: actions/checkout@v4 - - name: Setup toolchain - uses: dtolnay/rust-toolchain@1.74.1 +# - name: Setup toolchain +# uses: dtolnay/rust-toolchain@1.74.1 - - uses: Swatinem/rust-cache@v2 - with: - key: ${{ matrix.target }} - cache-on-failure: true - save-if: ${{ github.event_name != 'merge_group' }} +# - uses: Swatinem/rust-cache@v2 +# with: +# key: ${{ matrix.target }} +# cache-on-failure: true +# save-if: ${{ github.event_name != 'merge_group' }} - - name: Build Nargo - run: cargo build --package nargo_cli --release +# - name: Build Nargo +# run: cargo build --package nargo_cli --release - - name: Package artifacts - run: | - mkdir dist - cp ./target/release/nargo ./dist/nargo - 7z a -ttar -so -an ./dist/* | 7z a -si ./nargo-x86_64-unknown-linux-gnu.tar.gz +# - name: Package artifacts +# run: | +# mkdir dist +# cp ./target/release/nargo ./dist/nargo +# 7z a -ttar -so -an ./dist/* | 7z a -si ./nargo-x86_64-unknown-linux-gnu.tar.gz - - name: Upload artifact - uses: actions/upload-artifact@v4 - with: - name: nargo - path: ./dist/* - retention-days: 3 +# - name: Upload artifact +# uses: actions/upload-artifact@v4 +# with: +# name: nargo +# path: ./dist/* +# retention-days: 3 - compare_gas_reports: - needs: [build-nargo] - runs-on: ubuntu-latest - permissions: - pull-requests: write +# compare_gas_reports: +# needs: [build-nargo] +# runs-on: ubuntu-latest +# permissions: +# pull-requests: write - steps: - - uses: actions/checkout@v4 +# steps: +# - uses: actions/checkout@v4 - - name: Download nargo binary - uses: actions/download-artifact@v4 - with: - name: nargo - path: ./nargo +# - name: Download nargo binary +# uses: actions/download-artifact@v4 +# with: +# name: nargo +# path: ./nargo - - name: Set nargo on PATH - run: | - nargo_binary="${{ github.workspace }}/nargo/nargo" - chmod +x $nargo_binary - echo "$(dirname $nargo_binary)" >> $GITHUB_PATH - export PATH="$PATH:$(dirname $nargo_binary)" - nargo -V +# - name: Set nargo on PATH +# run: | +# nargo_binary="${{ github.workspace }}/nargo/nargo" +# chmod +x $nargo_binary +# echo "$(dirname $nargo_binary)" >> $GITHUB_PATH +# export PATH="$PATH:$(dirname $nargo_binary)" +# nargo -V - - name: Generate gates report - working-directory: ./test_programs - run: | - ./gates_report.sh - mv gates_report.json ../gates_report.json +# - name: Generate gates report +# working-directory: ./test_programs +# run: | +# ./gates_report.sh +# mv gates_report.json ../gates_report.json - - name: Compare gates reports - id: gates_diff - uses: vezenovm/noir-gates-diff@acf12797860f237117e15c0d6e08d64253af52b6 - with: - report: gates_report.json - summaryQuantile: 0.9 # only display the 10% most significant circuit size diffs in the summary (defaults to 20%) +# - name: Compare gates reports +# id: gates_diff +# uses: vezenovm/noir-gates-diff@acf12797860f237117e15c0d6e08d64253af52b6 +# with: +# report: gates_report.json +# summaryQuantile: 0.9 # only display the 10% most significant circuit size diffs in the summary (defaults to 20%) - - name: Add gates diff to sticky comment - if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' - uses: marocchino/sticky-pull-request-comment@v2 - with: - # delete the comment in case changes no longer impact circuit sizes - delete: ${{ !steps.gates_diff.outputs.markdown }} - message: ${{ steps.gates_diff.outputs.markdown }} +# - name: Add gates diff to sticky comment +# if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' +# uses: marocchino/sticky-pull-request-comment@v2 +# with: +# # delete the comment in case changes no longer impact circuit sizes +# delete: ${{ !steps.gates_diff.outputs.markdown }} +# message: ${{ steps.gates_diff.outputs.markdown }} diff --git a/noir/noir-repo/.tokeignore b/noir/noir-repo/.tokeignore new file mode 100644 index 00000000000..55f24e41dbd --- /dev/null +++ b/noir/noir-repo/.tokeignore @@ -0,0 +1,12 @@ +docs +scripts + +# aztec_macros is explicitly considered OOS for Noir audit +aztec_macros + +# config files +*.toml +*.md +*.json +*.txt +*.config.mjs diff --git a/noir/noir-repo/Cargo.lock b/noir/noir-repo/Cargo.lock index 859579c077f..2e8eeb10a58 100644 --- a/noir/noir-repo/Cargo.lock +++ b/noir/noir-repo/Cargo.lock @@ -462,7 +462,6 @@ dependencies = [ "dirs", "flate2", "reqwest", - "serde", "serde_json", "tar", "tempfile", @@ -615,6 +614,7 @@ dependencies = [ "acvm_blackbox_solver", "ark-ec", "ark-ff", + "ark-std", "cfg-if 1.0.0", "criterion", "getrandom 0.2.10", diff --git a/noir/noir-repo/acvm-repo/acir/benches/serialization.rs b/noir/noir-repo/acvm-repo/acir/benches/serialization.rs index e51726e3901..a7f32b4a4c7 100644 --- a/noir/noir-repo/acvm-repo/acir/benches/serialization.rs +++ b/noir/noir-repo/acvm-repo/acir/benches/serialization.rs @@ -33,7 +33,7 @@ fn sample_program(num_opcodes: usize) -> Program { functions: vec![Circuit { current_witness_index: 4000, opcodes: assert_zero_opcodes.to_vec(), - expression_width: ExpressionWidth::Bounded { width: 3 }, + expression_width: ExpressionWidth::Bounded { width: 4 }, private_parameters: BTreeSet::from([Witness(1), Witness(2), Witness(3), Witness(4)]), public_parameters: PublicInputs(BTreeSet::from([Witness(5)])), return_values: PublicInputs(BTreeSet::from([Witness(6)])), diff --git a/noir/noir-repo/acvm-repo/acvm/src/compiler/optimizers/constant_backpropagation.rs b/noir/noir-repo/acvm-repo/acvm/src/compiler/optimizers/constant_backpropagation.rs index 0e7d28104da..5b778f63f07 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/compiler/optimizers/constant_backpropagation.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/compiler/optimizers/constant_backpropagation.rs @@ -282,7 +282,7 @@ mod tests { fn test_circuit(opcodes: Vec) -> Circuit { Circuit { current_witness_index: 1, - expression_width: ExpressionWidth::Bounded { width: 3 }, + expression_width: ExpressionWidth::Bounded { width: 4 }, opcodes, private_parameters: BTreeSet::new(), public_parameters: PublicInputs::default(), diff --git a/noir/noir-repo/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs b/noir/noir-repo/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs index c6ca18d30ae..0e1629717b3 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs @@ -164,7 +164,7 @@ mod tests { Circuit { current_witness_index: 1, - expression_width: ExpressionWidth::Bounded { width: 3 }, + expression_width: ExpressionWidth::Bounded { width: 4 }, opcodes, private_parameters: BTreeSet::new(), public_parameters: PublicInputs::default(), diff --git a/noir/noir-repo/acvm-repo/acvm_js/build.sh b/noir/noir-repo/acvm-repo/acvm_js/build.sh index c07d2d8a4c1..16fb26e55db 100755 --- a/noir/noir-repo/acvm-repo/acvm_js/build.sh +++ b/noir/noir-repo/acvm-repo/acvm_js/build.sh @@ -25,7 +25,7 @@ function run_if_available { require_command jq require_command cargo require_command wasm-bindgen -#require_command wasm-opt +require_command wasm-opt self_path=$(dirname "$(readlink -f "$0")") pname=$(cargo read-manifest | jq -r '.name') diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/Cargo.toml b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/Cargo.toml index 3a6c9b1d55b..b261be65735 100644 --- a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/Cargo.toml +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/Cargo.toml @@ -40,6 +40,7 @@ getrandom.workspace = true wasmer = "4.2.6" [dev-dependencies] +ark-std = { version = "^0.4.0", default-features = false } criterion = "0.5.0" pprof = { version = "0.12", features = [ "flamegraph", diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/benches/criterion.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/benches/criterion.rs index eb529ed2c11..a8fa7d8aae4 100644 --- a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/benches/criterion.rs +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/benches/criterion.rs @@ -2,7 +2,8 @@ use criterion::{criterion_group, criterion_main, Criterion}; use std::{hint::black_box, time::Duration}; use acir::FieldElement; -use bn254_blackbox_solver::poseidon2_permutation; +use acvm_blackbox_solver::BlackBoxFunctionSolver; +use bn254_blackbox_solver::{poseidon2_permutation, Bn254BlackBoxSolver}; use pprof::criterion::{Output, PProfProfiler}; @@ -12,10 +13,58 @@ fn bench_poseidon2(c: &mut Criterion) { c.bench_function("poseidon2", |b| b.iter(|| poseidon2_permutation(black_box(&inputs), 4))); } +fn bench_pedersen_commitment(c: &mut Criterion) { + let inputs = [FieldElement::one(); 2]; + let solver = Bn254BlackBoxSolver::new(); + + c.bench_function("pedersen_commitment", |b| { + b.iter(|| solver.pedersen_commitment(black_box(&inputs), 0)) + }); +} + +fn bench_pedersen_hash(c: &mut Criterion) { + let inputs = [FieldElement::one(); 2]; + let solver = Bn254BlackBoxSolver::new(); + + c.bench_function("pedersen_hash", |b| b.iter(|| solver.pedersen_hash(black_box(&inputs), 0))); +} + +fn bench_schnorr_verify(c: &mut Criterion) { + let solver = Bn254BlackBoxSolver::new(); + + let pub_key_x = FieldElement::from_hex( + "0x04b260954662e97f00cab9adb773a259097f7a274b83b113532bce27fa3fb96a", + ) + .unwrap(); + let pub_key_y = FieldElement::from_hex( + "0x2fd51571db6c08666b0edfbfbc57d432068bccd0110a39b166ab243da0037197", + ) + .unwrap(); + let sig_bytes: [u8; 64] = [ + 1, 13, 119, 112, 212, 39, 233, 41, 84, 235, 255, 93, 245, 172, 186, 83, 157, 253, 76, 77, + 33, 128, 178, 15, 214, 67, 105, 107, 177, 234, 77, 48, 27, 237, 155, 84, 39, 84, 247, 27, + 22, 8, 176, 230, 24, 115, 145, 220, 254, 122, 135, 179, 171, 4, 214, 202, 64, 199, 19, 84, + 239, 138, 124, 12, + ]; + + let message: &[u8] = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + + c.bench_function("schnorr_verify", |b| { + b.iter(|| { + solver.schnorr_verify( + black_box(&pub_key_x), + black_box(&pub_key_y), + black_box(&sig_bytes), + black_box(message), + ) + }) + }); +} + criterion_group!( name = benches; config = Criterion::default().sample_size(40).measurement_time(Duration::from_secs(20)).with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); - targets = bench_poseidon2 + targets = bench_poseidon2, bench_pedersen_commitment, bench_pedersen_hash, bench_schnorr_verify ); criterion_main!(benches); diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/generator/generators.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/generator/generators.rs new file mode 100644 index 00000000000..f89d582d167 --- /dev/null +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/generator/generators.rs @@ -0,0 +1,184 @@ +// Adapted from https://github.com/laudiacay/barustenberg/blob/df6bc6f095fe7f288bf6a12e7317fd8eb33d68ae/barustenberg/src/ecc/groups/affine_element.rshttps://github.com/laudiacay/barustenberg/blob/df6bc6f095fe7f288bf6a12e7317fd8eb33d68ae/barustenberg/src/ecc/groups/group.rs +//! +//! Code is used under the MIT license + +use std::sync::OnceLock; + +use ark_ec::short_weierstrass::Affine; + +use acvm_blackbox_solver::blake3; +use grumpkin::GrumpkinParameters; + +use super::hash_to_curve::hash_to_curve; + +pub(crate) const DEFAULT_DOMAIN_SEPARATOR: &[u8] = "DEFAULT_DOMAIN_SEPARATOR".as_bytes(); +const NUM_DEFAULT_GENERATORS: usize = 8; + +fn default_generators() -> &'static [Affine; NUM_DEFAULT_GENERATORS] { + static INSTANCE: OnceLock<[Affine; NUM_DEFAULT_GENERATORS]> = + OnceLock::new(); + INSTANCE.get_or_init(|| { + _derive_generators(DEFAULT_DOMAIN_SEPARATOR, NUM_DEFAULT_GENERATORS as u32, 0) + .try_into() + .expect("Should generate `NUM_DEFAULT_GENERATORS`") + }) +} + +/// Derives generator points via [hash-to-curve][hash_to_curve]. +/// +/// # ALGORITHM DESCRIPTION +/// +/// 1. Each generator has an associated "generator index" described by its location in the vector +/// 2. a 64-byte preimage buffer is generated with the following structure: +/// - bytes 0-31: BLAKE3 hash of domain_separator +/// - bytes 32-63: generator index in big-endian form +/// 3. The [hash-to-curve algorithm][hash_to_curve] is used to hash the above into a curve point. +/// +/// NOTE: The domain separator is included to ensure that it is possible to derive independent sets of +/// index-addressable generators. +/// +/// [hash_to_curve]: super::hash_to_curve::hash_to_curve +pub(crate) fn derive_generators( + domain_separator_bytes: &[u8], + num_generators: u32, + starting_index: u32, +) -> Vec> { + // We cache a small number of the default generators so we can reuse them without needing to repeatedly recalculate them. + if domain_separator_bytes == DEFAULT_DOMAIN_SEPARATOR + && starting_index + num_generators <= NUM_DEFAULT_GENERATORS as u32 + { + let start_index = starting_index as usize; + let end_index = (starting_index + num_generators) as usize; + default_generators()[start_index..end_index].to_vec() + } else { + _derive_generators(domain_separator_bytes, num_generators, starting_index) + } +} + +fn _derive_generators( + domain_separator_bytes: &[u8], + num_generators: u32, + starting_index: u32, +) -> Vec> { + let mut generator_preimage = [0u8; 64]; + let domain_hash = blake3(domain_separator_bytes).expect("hash should succeed"); + //1st 32 bytes are blake3 domain_hash + generator_preimage[..32].copy_from_slice(&domain_hash); + + // Convert generator index in big-endian form + let mut res = Vec::with_capacity(num_generators as usize); + for i in starting_index..(starting_index + num_generators) { + let generator_be_bytes: [u8; 4] = i.to_be_bytes(); + generator_preimage[32] = generator_be_bytes[0]; + generator_preimage[33] = generator_be_bytes[1]; + generator_preimage[34] = generator_be_bytes[2]; + generator_preimage[35] = generator_be_bytes[3]; + let generator = hash_to_curve(&generator_preimage, 0); + res.push(generator); + } + res +} + +#[cfg(test)] +mod test { + + use ark_ec::AffineRepr; + use ark_ff::{BigInteger, PrimeField}; + + use super::*; + + #[test] + fn test_derive_generators() { + let res = derive_generators("test domain".as_bytes(), 128, 0); + + let is_unique = |y: Affine, j: usize| -> bool { + for (i, res) in res.iter().enumerate() { + if i != j && *res == y { + return false; + } + } + true + }; + + for (i, res) in res.iter().enumerate() { + assert!(is_unique(*res, i)); + assert!(res.is_on_curve()); + } + } + + #[test] + fn derive_length_generator() { + let domain_separator = "pedersen_hash_length"; + let length_generator = derive_generators(domain_separator.as_bytes(), 1, 0)[0]; + + let expected_generator = ( + "2df8b940e5890e4e1377e05373fae69a1d754f6935e6a780b666947431f2cdcd", + "2ecd88d15967bc53b885912e0d16866154acb6aac2d3f85e27ca7eefb2c19083", + ); + assert_eq!( + hex::encode(length_generator.x().unwrap().into_bigint().to_bytes_be()), + expected_generator.0, + "Failed on x component" + ); + assert_eq!( + hex::encode(length_generator.y().unwrap().into_bigint().to_bytes_be()), + expected_generator.1, + "Failed on y component" + ); + } + + #[test] + fn derives_default_generators() { + const DEFAULT_GENERATORS: &[[&str; 2]] = &[ + [ + "083e7911d835097629f0067531fc15cafd79a89beecb39903f69572c636f4a5a", + "1a7f5efaad7f315c25a918f30cc8d7333fccab7ad7c90f14de81bcc528f9935d", + ], + [ + "054aa86a73cb8a34525e5bbed6e43ba1198e860f5f3950268f71df4591bde402", + "209dcfbf2cfb57f9f6046f44d71ac6faf87254afc7407c04eb621a6287cac126", + ], + [ + "1c44f2a5207c81c28a8321a5815ce8b1311024bbed131819bbdaf5a2ada84748", + "03aaee36e6422a1d0191632ac6599ae9eba5ac2c17a8c920aa3caf8b89c5f8a8", + ], + [ + "26d8b1160c6821a30c65f6cb47124afe01c29f4338f44d4a12c9fccf22fb6fb2", + "05c70c3b9c0d25a4c100e3a27bf3cc375f8af8cdd9498ec4089a823d7464caff", + ], + [ + "20ed9c6a1d27271c4498bfce0578d59db1adbeaa8734f7facc097b9b994fcf6e", + "29cd7d370938b358c62c4a00f73a0d10aba7e5aaa04704a0713f891ebeb92371", + ], + [ + "0224a8abc6c8b8d50373d64cd2a1ab1567bf372b3b1f7b861d7f01257052d383", + "2358629b90eafb299d6650a311e79914b0215eb0a790810b26da5a826726d711", + ], + [ + "0f106f6d46bc904a5290542490b2f238775ff3c445b2f8f704c466655f460a2a", + "29ab84d472f1d33f42fe09c47b8f7710f01920d6155250126731e486877bcf27", + ], + [ + "0298f2e42249f0519c8a8abd91567ebe016e480f219b8c19461d6a595cc33696", + "035bec4b8520a4ece27bd5aafabee3dfe1390d7439c419a8c55aceb207aac83b", + ], + ]; + + let generated_generators = + derive_generators(DEFAULT_DOMAIN_SEPARATOR, DEFAULT_GENERATORS.len() as u32, 0); + for (i, (generator, expected_generator)) in + generated_generators.iter().zip(DEFAULT_GENERATORS).enumerate() + { + assert_eq!( + hex::encode(generator.x().unwrap().into_bigint().to_bytes_be()), + expected_generator[0], + "Failed on x component of generator {i}" + ); + assert_eq!( + hex::encode(generator.y().unwrap().into_bigint().to_bytes_be()), + expected_generator[1], + "Failed on y component of generator {i}" + ); + } + } +} diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/generator/hash_to_curve.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/generator/hash_to_curve.rs new file mode 100644 index 00000000000..c0197883442 --- /dev/null +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/generator/hash_to_curve.rs @@ -0,0 +1,135 @@ +// Adapted from https://github.com/laudiacay/barustenberg/blob/df6bc6f095fe7f288bf6a12e7317fd8eb33d68ae/barustenberg/src/ecc/groups/affine_element.rs +//! +//! Code is used under the MIT license + +use acvm_blackbox_solver::blake3; + +use ark_ec::{short_weierstrass::Affine, AffineRepr, CurveConfig}; +use ark_ff::Field; +use ark_ff::{BigInteger, PrimeField}; +use grumpkin::GrumpkinParameters; + +/// Hash a seed buffer into a point +/// +/// # ALGORITHM DESCRIPTION +/// +/// 1. Initialize unsigned integer `attempt_count = 0` +/// 2. Copy seed into a buffer whose size is 2 bytes greater than `seed` (initialized to `0`) +/// 3. Interpret `attempt_count` as a byte and write into buffer at `[buffer.size() - 2]` +/// 4. Compute Blake3 hash of buffer +/// 5. Set the end byte of the buffer to `1` +/// 6. Compute Blake3 hash of buffer +/// 7. Interpret the two hash outputs as the high / low 256 bits of a 512-bit integer (big-endian) +/// 8. Derive x-coordinate of point by reducing the 512-bit integer modulo the curve's field modulus (Fq) +/// 9. Compute `y^2` from the curve formula `y^2 = x^3 + ax + b` (`a`, `b` are curve params. for BN254, `a = 0`, `b = 3`) +/// 10. IF `y^2` IS NOT A QUADRATIC RESIDUE: +/// +/// a. increment `attempt_count` by 1 and go to step 2 +/// +/// 11. IF `y^2` IS A QUADRATIC RESIDUE: +/// +/// a. derive y coordinate via `y = sqrt(y)` +/// +/// b. Interpret most significant bit of 512-bit integer as a 'parity' bit +/// +/// c. If parity bit is set AND `y`'s most significant bit is not set, invert `y` +/// +/// d. If parity bit is not set AND `y`'s most significant bit is set, invert `y` +/// +/// e. return (x, y) +/// +/// N.B. steps c. and e. are because the `sqrt()` algorithm can return 2 values, +/// we need to a way to canonically distinguish between these 2 values and select a "preferred" one +pub(crate) fn hash_to_curve(seed: &[u8], attempt_count: u8) -> Affine { + let seed_size = seed.len(); + // expand by 2 bytes to cover incremental hash attempts + let mut target_seed = seed.to_vec(); + target_seed.extend_from_slice(&[0u8; 2]); + + target_seed[seed_size] = attempt_count; + target_seed[seed_size + 1] = 0; + let hash_hi = blake3(&target_seed).expect("hash should succeed"); + target_seed[seed_size + 1] = 1; + let hash_lo = blake3(&target_seed).expect("hash should succeed"); + + let mut hash = hash_hi.to_vec(); + hash.extend_from_slice(&hash_lo); + + // Here we reduce the 512 bit number modulo the base field modulus to calculate `x` + let x = <::BaseField as Field>::BasePrimeField::from_be_bytes_mod_order(&hash); + let x = ::BaseField::from_base_prime_field(x); + + if let Some(point) = Affine::::get_point_from_x_unchecked(x, false) { + let parity_bit = hash_hi[0] > 127; + let y_bit_set = point.y().unwrap().into_bigint().get_bit(0); + if (parity_bit && !y_bit_set) || (!parity_bit && y_bit_set) { + -point + } else { + point + } + } else { + hash_to_curve(seed, attempt_count + 1) + } +} + +#[cfg(test)] +mod test { + + use ark_ec::AffineRepr; + use ark_ff::{BigInteger, PrimeField}; + + use super::hash_to_curve; + + #[test] + fn smoke_test() { + let test_cases: [(&[u8], u8, (&str, &str)); 4] = [ + ( + &[], + 0, + ( + "24c4cb9c1206ab5470592f237f1698abe684dadf0ab4d7a132c32b2134e2c12e", + "0668b8d61a317fb34ccad55c930b3554f1828a0e5530479ecab4defe6bbc0b2e", + ), + ), + ( + &[], + 1, + ( + "24c4cb9c1206ab5470592f237f1698abe684dadf0ab4d7a132c32b2134e2c12e", + "0668b8d61a317fb34ccad55c930b3554f1828a0e5530479ecab4defe6bbc0b2e", + ), + ), + ( + &[1], + 0, + ( + "107f1b633c6113f3222f39f6256f0546b41a4880918c86864b06471afb410454", + "050cd3823d0c01590b6a50adcc85d2ee4098668fd28805578aa05a423ea938c6", + ), + ), + ( + &[0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64], + 0, + ( + "037c5c229ae495f6e8d1b4bf7723fafb2b198b51e27602feb8a4d1053d685093", + "10cf9596c5b2515692d930efa2cf3817607e4796856a79f6af40c949b066969f", + ), + ), + ]; + + for (seed, attempt_count, expected_point) in test_cases { + let point = hash_to_curve(seed, attempt_count); + assert!(point.is_on_curve()); + assert_eq!( + hex::encode(point.x().unwrap().into_bigint().to_bytes_be()), + expected_point.0, + "Failed on x component with seed {seed:?}, attempt_count {attempt_count}" + ); + assert_eq!( + hex::encode(point.y().unwrap().into_bigint().to_bytes_be()), + expected_point.1, + "Failed on y component with seed {seed:?}, attempt_count {attempt_count}" + ); + } + } +} diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/generator/mod.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/generator/mod.rs new file mode 100644 index 00000000000..0f62642516a --- /dev/null +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/generator/mod.rs @@ -0,0 +1,8 @@ +//! This module is adapted from the [Barustenberg][barustenberg] Rust implementation of the Barretenberg library. +//! +//! Code is used under the MIT license +//! +//! [barustenberg]: https://github.com/laudiacay/barustenberg/blob/df6bc6f095fe7f288bf6a12e7317fd8eb33d68ae/ + +pub(crate) mod generators; +mod hash_to_curve; diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/lib.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/lib.rs index 4cb51b59755..ae6fb7999a0 100644 --- a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/lib.rs +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/lib.rs @@ -6,14 +6,17 @@ use acir::{BlackBoxFunc, FieldElement}; use acvm_blackbox_solver::{BlackBoxFunctionSolver, BlackBoxResolutionError}; mod embedded_curve_ops; +mod generator; +mod pedersen; mod poseidon2; mod wasm; +use ark_ec::AffineRepr; pub use embedded_curve_ops::{embedded_curve_add, multi_scalar_mul}; pub use poseidon2::poseidon2_permutation; use wasm::Barretenberg; -use self::wasm::{Pedersen, SchnorrSig}; +use self::wasm::SchnorrSig; pub struct Bn254BlackBoxSolver { blackbox_vendor: Barretenberg, @@ -72,10 +75,13 @@ impl BlackBoxFunctionSolver for Bn254BlackBoxSolver { inputs: &[FieldElement], domain_separator: u32, ) -> Result<(FieldElement, FieldElement), BlackBoxResolutionError> { - #[allow(deprecated)] - self.blackbox_vendor.encrypt(inputs.to_vec(), domain_separator).map_err(|err| { - BlackBoxResolutionError::Failed(BlackBoxFunc::PedersenCommitment, err.to_string()) - }) + let inputs: Vec = inputs.iter().map(|input| input.into_repr()).collect(); + let result = pedersen::commitment::commit_native_with_index(&inputs, domain_separator); + let res_x = + FieldElement::from_repr(*result.x().expect("should not commit to point at infinity")); + let res_y = + FieldElement::from_repr(*result.y().expect("should not commit to point at infinity")); + Ok((res_x, res_y)) } fn pedersen_hash( @@ -83,10 +89,10 @@ impl BlackBoxFunctionSolver for Bn254BlackBoxSolver { inputs: &[FieldElement], domain_separator: u32, ) -> Result { - #[allow(deprecated)] - self.blackbox_vendor.hash(inputs.to_vec(), domain_separator).map_err(|err| { - BlackBoxResolutionError::Failed(BlackBoxFunc::PedersenCommitment, err.to_string()) - }) + let inputs: Vec = inputs.iter().map(|input| input.into_repr()).collect(); + let result = pedersen::hash::hash_with_index(&inputs, domain_separator); + let result = FieldElement::from_repr(result); + Ok(result) } fn multi_scalar_mul( diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/pedersen/commitment.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/pedersen/commitment.rs new file mode 100644 index 00000000000..6769150508a --- /dev/null +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/pedersen/commitment.rs @@ -0,0 +1,76 @@ +// Taken from: https://github.com/laudiacay/barustenberg/blob/df6bc6f095fe7f288bf6a12e7317fd8eb33d68ae/barustenberg/src/crypto/pedersen/pederson.rs + +use ark_ec::{short_weierstrass::Affine, AffineRepr, CurveGroup}; +use ark_ff::{MontConfig, PrimeField}; +use grumpkin::{Fq, FqConfig, Fr, FrConfig, GrumpkinParameters}; + +use crate::generator::generators::{derive_generators, DEFAULT_DOMAIN_SEPARATOR}; + +/// Given a vector of fields, generate a pedersen commitment using the indexed generators. +pub(crate) fn commit_native_with_index( + inputs: &[Fq], + starting_index: u32, +) -> Affine { + let generators = + derive_generators(DEFAULT_DOMAIN_SEPARATOR, inputs.len() as u32, starting_index); + + // As |F_r| > |F_q|, we can safely convert any `F_q` into an `F_r` uniquely. + assert!(FrConfig::MODULUS > FqConfig::MODULUS); + + inputs.iter().enumerate().fold(Affine::zero(), |mut acc, (i, input)| { + acc = (acc + (generators[i] * Fr::from_bigint(input.into_bigint()).unwrap()).into_affine()) + .into_affine(); + acc + }) +} + +#[cfg(test)] +mod test { + + use acir::FieldElement; + use ark_ec::short_weierstrass::Affine; + use ark_std::{One, Zero}; + use grumpkin::Fq; + + use crate::pedersen::commitment::commit_native_with_index; + + #[test] + fn commitment() { + // https://github.com/AztecProtocol/aztec-packages/blob/72931bdb8202c34042cdfb8cee2ef44b75939879/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen.test.cpp#L10-L18 + let res = commit_native_with_index(&[Fq::one(), Fq::one()], 0); + let expected = Affine::new( + FieldElement::from_hex( + "0x2f7a8f9a6c96926682205fb73ee43215bf13523c19d7afe36f12760266cdfe15", + ) + .unwrap() + .into_repr(), + FieldElement::from_hex( + "0x01916b316adbbf0e10e39b18c1d24b33ec84b46daddf72f43878bcc92b6057e6", + ) + .unwrap() + .into_repr(), + ); + + assert_eq!(res, expected); + } + + #[test] + fn commitment_with_zero() { + // https://github.com/AztecProtocol/aztec-packages/blob/72931bdb8202c34042cdfb8cee2ef44b75939879/barretenberg/cpp/src/barretenberg/crypto/pedersen_commitment/pedersen.test.cpp#L20-L29 + let res = commit_native_with_index(&[Fq::zero(), Fq::one()], 0); + let expected = Affine::new( + FieldElement::from_hex( + "0x054aa86a73cb8a34525e5bbed6e43ba1198e860f5f3950268f71df4591bde402", + ) + .unwrap() + .into_repr(), + FieldElement::from_hex( + "0x209dcfbf2cfb57f9f6046f44d71ac6faf87254afc7407c04eb621a6287cac126", + ) + .unwrap() + .into_repr(), + ); + + assert_eq!(res, expected); + } +} diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/pedersen/hash.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/pedersen/hash.rs new file mode 100644 index 00000000000..28bf354edc9 --- /dev/null +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/pedersen/hash.rs @@ -0,0 +1,68 @@ +// Taken from: https://github.com/laudiacay/barustenberg/blob/df6bc6f095fe7f288bf6a12e7317fd8eb33d68ae/barustenberg/src/crypto/pedersen/pederson_hash.rs + +use std::sync::OnceLock; + +use ark_ec::{short_weierstrass::Affine, CurveConfig, CurveGroup}; +use grumpkin::GrumpkinParameters; + +use crate::generator::generators::derive_generators; + +use super::commitment::commit_native_with_index; + +/// Given a vector of fields, generate a pedersen hash using the indexed generators. +pub(crate) fn hash_with_index( + inputs: &[grumpkin::Fq], + starting_index: u32, +) -> ::BaseField { + let length_as_scalar: ::ScalarField = + (inputs.len() as u64).into(); + let length_prefix = *length_generator() * length_as_scalar; + let result = length_prefix + commit_native_with_index(inputs, starting_index); + result.into_affine().x +} + +fn length_generator() -> &'static Affine { + static INSTANCE: OnceLock> = OnceLock::new(); + INSTANCE.get_or_init(|| derive_generators("pedersen_hash_length".as_bytes(), 1, 0)[0]) +} + +#[cfg(test)] +pub(crate) mod test { + + use super::*; + + use acir::FieldElement; + use ark_std::One; + use grumpkin::Fq; + + //reference: https://github.com/AztecProtocol/barretenberg/blob/master/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.test.cpp + #[test] + fn hash_one() { + // https://github.com/AztecProtocol/aztec-packages/blob/72931bdb8202c34042cdfb8cee2ef44b75939879/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.test.cpp#L21-L26 + let res = hash_with_index(&[Fq::one(), Fq::one()], 0); + + assert_eq!( + res, + FieldElement::from_hex( + "0x07ebfbf4df29888c6cd6dca13d4bb9d1a923013ddbbcbdc3378ab8845463297b", + ) + .unwrap() + .into_repr(), + ); + } + + #[test] + fn test_hash_with_index() { + // https://github.com/AztecProtocol/aztec-packages/blob/72931bdb8202c34042cdfb8cee2ef44b75939879/barretenberg/cpp/src/barretenberg/crypto/pedersen_hash/pedersen.test.cpp#L28-L33 + let res = hash_with_index(&[Fq::one(), Fq::one()], 5); + + assert_eq!( + res, + FieldElement::from_hex( + "0x1c446df60816b897cda124524e6b03f36df0cec333fad87617aab70d7861daa6", + ) + .unwrap() + .into_repr(), + ); + } +} diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/pedersen/mod.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/pedersen/mod.rs new file mode 100644 index 00000000000..c3c4ed56450 --- /dev/null +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/pedersen/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod commitment; +pub(crate) mod hash; diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/wasm/barretenberg_structures.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/wasm/barretenberg_structures.rs deleted file mode 100644 index 302ffa8af9b..00000000000 --- a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/wasm/barretenberg_structures.rs +++ /dev/null @@ -1,25 +0,0 @@ -use acir::FieldElement; - -#[derive(Debug, Default)] -pub(crate) struct Assignments(Vec); - -impl Assignments { - pub(crate) fn to_bytes(&self) -> Vec { - let mut buffer = Vec::new(); - - let witness_len = self.0.len() as u32; - buffer.extend_from_slice(&witness_len.to_be_bytes()); - - for assignment in self.0.iter() { - buffer.extend_from_slice(&assignment.to_be_bytes()); - } - - buffer - } -} - -impl From> for Assignments { - fn from(w: Vec) -> Assignments { - Assignments(w) - } -} diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/wasm/mod.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/wasm/mod.rs index f4f6f56aa99..e0a5c4c9069 100644 --- a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/wasm/mod.rs +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/wasm/mod.rs @@ -4,13 +4,8 @@ //! //! As [`acvm`] includes rust implementations for these opcodes, this module can be removed. -mod barretenberg_structures; -mod pedersen; mod schnorr; -use barretenberg_structures::Assignments; - -pub(crate) use pedersen::Pedersen; pub(crate) use schnorr::SchnorrSig; /// The number of bytes necessary to store a `FieldElement`. @@ -208,10 +203,6 @@ impl Barretenberg { buf } - pub(crate) fn call(&self, name: &str, param: &WASMValue) -> Result { - self.call_multiple(name, vec![param]) - } - pub(crate) fn call_multiple( &self, name: &str, @@ -236,17 +227,6 @@ impl Barretenberg { Ok(WASMValue(option_value)) } - - /// Creates a pointer and allocates the bytes that the pointer references to, to the heap - pub(crate) fn allocate(&self, bytes: &[u8]) -> Result { - let ptr: i32 = self.call("bbmalloc", &bytes.len().into())?.try_into()?; - - let i32_bytes = ptr.to_be_bytes(); - let u32_bytes = u32::from_be_bytes(i32_bytes); - - self.transfer_to_heap(bytes, u32_bytes as usize); - Ok(ptr.into()) - } } fn init_memory_and_state() -> (Memory, Store, Imports) { diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/wasm/pedersen.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/wasm/pedersen.rs deleted file mode 100644 index c816e5b4d1b..00000000000 --- a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/wasm/pedersen.rs +++ /dev/null @@ -1,73 +0,0 @@ -use acir::FieldElement; - -use super::{Assignments, Barretenberg, Error, FIELD_BYTES}; - -pub(crate) trait Pedersen { - fn encrypt( - &self, - inputs: Vec, - hash_index: u32, - ) -> Result<(FieldElement, FieldElement), Error>; - - fn hash(&self, inputs: Vec, hash_index: u32) -> Result; -} - -impl Pedersen for Barretenberg { - fn encrypt( - &self, - inputs: Vec, - hash_index: u32, - ) -> Result<(FieldElement, FieldElement), Error> { - let input_buf = Assignments::from(inputs).to_bytes(); - let input_ptr = self.allocate(&input_buf)?; - let result_ptr: usize = 0; - - self.call_multiple( - "pedersen_plookup_commit_with_hash_index", - vec![&input_ptr, &result_ptr.into(), &hash_index.into()], - )?; - - let result_bytes: [u8; 2 * FIELD_BYTES] = self.read_memory(result_ptr); - let (point_x_bytes, point_y_bytes) = result_bytes.split_at(FIELD_BYTES); - - let point_x = FieldElement::from_be_bytes_reduce(point_x_bytes); - let point_y = FieldElement::from_be_bytes_reduce(point_y_bytes); - - Ok((point_x, point_y)) - } - - fn hash(&self, inputs: Vec, hash_index: u32) -> Result { - let input_buf = Assignments::from(inputs).to_bytes(); - let input_ptr = self.allocate(&input_buf)?; - let result_ptr: usize = 0; - - self.call_multiple( - "pedersen_plookup_compress_with_hash_index", - vec![&input_ptr, &result_ptr.into(), &hash_index.into()], - )?; - - let result_bytes: [u8; FIELD_BYTES] = self.read_memory(result_ptr); - - let hash = FieldElement::from_be_bytes_reduce(&result_bytes); - - Ok(hash) - } -} - -#[test] -fn pedersen_hash_to_point() -> Result<(), Error> { - let barretenberg = Barretenberg::new(); - let (x, y) = barretenberg.encrypt(vec![FieldElement::one(), FieldElement::one()], 1)?; - let expected_x = FieldElement::from_hex( - "0x12afb43195f5c621d1d2cabb5f629707095c5307fd4185a663d4e80bb083e878", - ) - .unwrap(); - let expected_y = FieldElement::from_hex( - "0x25793f5b5e62beb92fd18a66050293a9fd554a2ff13bceba0339cae1a038d7c1", - ) - .unwrap(); - - assert_eq!(expected_x.to_hex(), x.to_hex()); - assert_eq!(expected_y.to_hex(), y.to_hex()); - Ok(()) -} diff --git a/noir/noir-repo/compiler/noirc_driver/src/lib.rs b/noir/noir-repo/compiler/noirc_driver/src/lib.rs index ef874d45f88..801c0b685a9 100644 --- a/noir/noir-repo/compiler/noirc_driver/src/lib.rs +++ b/noir/noir-repo/compiler/noirc_driver/src/lib.rs @@ -54,8 +54,8 @@ pub const NOIR_ARTIFACT_VERSION_STRING: &str = #[derive(Args, Clone, Debug, Default)] pub struct CompileOptions { /// Override the expression width requested by the backend. - #[arg(long, value_parser = parse_expression_width)] - pub expression_width: Option, + #[arg(long, value_parser = parse_expression_width, default_value = "4")] + pub expression_width: ExpressionWidth, /// Force a full recompilation. #[arg(long = "force")] @@ -103,6 +103,10 @@ pub struct CompileOptions { /// Force Brillig output (for step debugging) #[arg(long, hide = true)] pub force_brillig: bool, + + /// Enable the experimental elaborator pass + #[arg(long, hide = true)] + pub use_elaborator: bool, } fn parse_expression_width(input: &str) -> Result { @@ -245,12 +249,13 @@ pub fn check_crate( crate_id: CrateId, deny_warnings: bool, disable_macros: bool, + use_elaborator: bool, ) -> CompilationResult<()> { let macros: &[&dyn MacroProcessor] = if disable_macros { &[] } else { &[&aztec_macros::AztecMacro as &dyn MacroProcessor] }; let mut errors = vec![]; - let diagnostics = CrateDefMap::collect_defs(crate_id, context, macros); + let diagnostics = CrateDefMap::collect_defs(crate_id, context, use_elaborator, macros); errors.extend(diagnostics.into_iter().map(|(error, file_id)| { let diagnostic = CustomDiagnostic::from(&error); diagnostic.in_file(file_id) @@ -282,8 +287,13 @@ pub fn compile_main( options: &CompileOptions, cached_program: Option, ) -> CompilationResult { - let (_, mut warnings) = - check_crate(context, crate_id, options.deny_warnings, options.disable_macros)?; + let (_, mut warnings) = check_crate( + context, + crate_id, + options.deny_warnings, + options.disable_macros, + options.use_elaborator, + )?; let main = context.get_main_function(&crate_id).ok_or_else(|| { // TODO(#2155): This error might be a better to exist in Nargo @@ -318,8 +328,13 @@ pub fn compile_contract( crate_id: CrateId, options: &CompileOptions, ) -> CompilationResult { - let (_, warnings) = - check_crate(context, crate_id, options.deny_warnings, options.disable_macros)?; + let (_, warnings) = check_crate( + context, + crate_id, + options.deny_warnings, + options.disable_macros, + options.use_elaborator, + )?; // TODO: We probably want to error if contracts is empty let contracts = context.get_all_contracts(&crate_id); diff --git a/noir/noir-repo/compiler/noirc_driver/tests/stdlib_warnings.rs b/noir/noir-repo/compiler/noirc_driver/tests/stdlib_warnings.rs index 6f437621123..327c8daad06 100644 --- a/noir/noir-repo/compiler/noirc_driver/tests/stdlib_warnings.rs +++ b/noir/noir-repo/compiler/noirc_driver/tests/stdlib_warnings.rs @@ -24,7 +24,8 @@ fn stdlib_does_not_produce_constant_warnings() -> Result<(), ErrorsAndWarnings> let mut context = Context::new(file_manager, parsed_files); let root_crate_id = prepare_crate(&mut context, file_name); - let ((), warnings) = noirc_driver::check_crate(&mut context, root_crate_id, false, false)?; + let ((), warnings) = + noirc_driver::check_crate(&mut context, root_crate_id, false, false, false)?; assert_eq!(warnings, Vec::new(), "stdlib is producing warnings"); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index 873ebe51e6f..f660c8e0b7a 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -1328,7 +1328,15 @@ impl<'block> BrilligBlock<'block> { self.brillig_context.binary_instruction(left, right, result_variable, brillig_binary_op); - self.add_overflow_check(brillig_binary_op, left, right, result_variable, is_signed); + self.add_overflow_check( + brillig_binary_op, + left, + right, + result_variable, + binary, + dfg, + is_signed, + ); } /// Splits a two's complement signed integer in the sign bit and the absolute value. @@ -1481,15 +1489,20 @@ impl<'block> BrilligBlock<'block> { self.brillig_context.deallocate_single_addr(bias); } + #[allow(clippy::too_many_arguments)] fn add_overflow_check( &mut self, binary_operation: BrilligBinaryOp, left: SingleAddrVariable, right: SingleAddrVariable, result: SingleAddrVariable, + binary: &Binary, + dfg: &DataFlowGraph, is_signed: bool, ) { let bit_size = left.bit_size; + let max_lhs_bits = dfg.get_value_max_num_bits(binary.lhs); + let max_rhs_bits = dfg.get_value_max_num_bits(binary.rhs); if bit_size == FieldElement::max_num_bits() { return; @@ -1497,6 +1510,11 @@ impl<'block> BrilligBlock<'block> { match (binary_operation, is_signed) { (BrilligBinaryOp::Add, false) => { + if std::cmp::max(max_lhs_bits, max_rhs_bits) < bit_size { + // `left` and `right` have both been casted up from smaller types and so cannot overflow. + return; + } + let condition = SingleAddrVariable::new(self.brillig_context.allocate_register(), 1); // Check that lhs <= result @@ -1511,6 +1529,12 @@ impl<'block> BrilligBlock<'block> { self.brillig_context.deallocate_single_addr(condition); } (BrilligBinaryOp::Sub, false) => { + if dfg.is_constant(binary.lhs) && max_lhs_bits > max_rhs_bits { + // `left` is a fixed constant and `right` is restricted such that `left - right > 0` + // Note strict inequality as `right > left` while `max_lhs_bits == max_rhs_bits` is possible. + return; + } + let condition = SingleAddrVariable::new(self.brillig_context.allocate_register(), 1); // Check that rhs <= lhs @@ -1527,39 +1551,36 @@ impl<'block> BrilligBlock<'block> { self.brillig_context.deallocate_single_addr(condition); } (BrilligBinaryOp::Mul, false) => { - // Multiplication overflow is only possible for bit sizes > 1 - if bit_size > 1 { - let is_right_zero = - SingleAddrVariable::new(self.brillig_context.allocate_register(), 1); - let zero = - self.brillig_context.make_constant_instruction(0_usize.into(), bit_size); - self.brillig_context.binary_instruction( - zero, - right, - is_right_zero, - BrilligBinaryOp::Equals, - ); - self.brillig_context.codegen_if_not(is_right_zero.address, |ctx| { - let condition = SingleAddrVariable::new(ctx.allocate_register(), 1); - let division = SingleAddrVariable::new(ctx.allocate_register(), bit_size); - // Check that result / rhs == lhs - ctx.binary_instruction( - result, - right, - division, - BrilligBinaryOp::UnsignedDiv, - ); - ctx.binary_instruction(division, left, condition, BrilligBinaryOp::Equals); - ctx.codegen_constrain( - condition, - Some("attempt to multiply with overflow".to_string()), - ); - ctx.deallocate_single_addr(condition); - ctx.deallocate_single_addr(division); - }); - self.brillig_context.deallocate_single_addr(is_right_zero); - self.brillig_context.deallocate_single_addr(zero); + if bit_size == 1 || max_lhs_bits + max_rhs_bits <= bit_size { + // Either performing boolean multiplication (which cannot overflow), + // or `left` and `right` have both been casted up from smaller types and so cannot overflow. + return; } + + let is_right_zero = + SingleAddrVariable::new(self.brillig_context.allocate_register(), 1); + let zero = self.brillig_context.make_constant_instruction(0_usize.into(), bit_size); + self.brillig_context.binary_instruction( + zero, + right, + is_right_zero, + BrilligBinaryOp::Equals, + ); + self.brillig_context.codegen_if_not(is_right_zero.address, |ctx| { + let condition = SingleAddrVariable::new(ctx.allocate_register(), 1); + let division = SingleAddrVariable::new(ctx.allocate_register(), bit_size); + // Check that result / rhs == lhs + ctx.binary_instruction(result, right, division, BrilligBinaryOp::UnsignedDiv); + ctx.binary_instruction(division, left, condition, BrilligBinaryOp::Equals); + ctx.codegen_constrain( + condition, + Some("attempt to multiply with overflow".to_string()), + ); + ctx.deallocate_single_addr(condition); + ctx.deallocate_single_addr(division); + }); + self.brillig_context.deallocate_single_addr(is_right_zero); + self.brillig_context.deallocate_single_addr(zero); } _ => {} } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs index 2e2f03a0012..05d2e3e3e6a 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -1729,13 +1729,16 @@ impl<'a> Context<'a> { // will expand the array if there is one. let return_acir_vars = self.flatten_value_list(return_values, dfg)?; let mut warnings = Vec::new(); - for acir_var in return_acir_vars { + for (acir_var, is_databus) in return_acir_vars { if self.acir_context.is_constant(&acir_var) { warnings.push(SsaReport::Warning(InternalWarning::ReturnConstant { call_stack: call_stack.clone(), })); } - self.acir_context.return_var(acir_var)?; + if !is_databus { + // We do not return value for the data bus. + self.acir_context.return_var(acir_var)?; + } } Ok(warnings) } @@ -1837,15 +1840,15 @@ impl<'a> Context<'a> { let binary_type = AcirType::from(binary_type); let bit_count = binary_type.bit_size(); - - match binary.operator { + let num_type = binary_type.to_numeric_type(); + let result = match binary.operator { BinaryOp::Add => self.acir_context.add_var(lhs, rhs), BinaryOp::Sub => self.acir_context.sub_var(lhs, rhs), BinaryOp::Mul => self.acir_context.mul_var(lhs, rhs), BinaryOp::Div => self.acir_context.div_var( lhs, rhs, - binary_type, + binary_type.clone(), self.current_side_effects_enabled_var, ), // Note: that this produces unnecessary constraints when @@ -1869,7 +1872,71 @@ impl<'a> Context<'a> { BinaryOp::Shl | BinaryOp::Shr => unreachable!( "ICE - bit shift operators do not exist in ACIR and should have been replaced" ), + }?; + + if let NumericType::Unsigned { bit_size } = &num_type { + // Check for integer overflow + self.check_unsigned_overflow( + result, + *bit_size, + binary.lhs, + binary.rhs, + dfg, + binary.operator, + )?; } + + Ok(result) + } + + /// Adds a range check against the bit size of the result of addition, subtraction or multiplication + fn check_unsigned_overflow( + &mut self, + result: AcirVar, + bit_size: u32, + lhs: ValueId, + rhs: ValueId, + dfg: &DataFlowGraph, + op: BinaryOp, + ) -> Result<(), RuntimeError> { + // We try to optimize away operations that are guaranteed not to overflow + let max_lhs_bits = dfg.get_value_max_num_bits(lhs); + let max_rhs_bits = dfg.get_value_max_num_bits(rhs); + + let msg = match op { + BinaryOp::Add => { + if std::cmp::max(max_lhs_bits, max_rhs_bits) < bit_size { + // `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow. + return Ok(()); + } + "attempt to add with overflow".to_string() + } + BinaryOp::Sub => { + if dfg.is_constant(lhs) && max_lhs_bits > max_rhs_bits { + // `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0` + // Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible. + return Ok(()); + } + "attempt to subtract with overflow".to_string() + } + BinaryOp::Mul => { + if bit_size == 1 || max_lhs_bits + max_rhs_bits <= bit_size { + // Either performing boolean multiplication (which cannot overflow), + // or `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow. + return Ok(()); + } + "attempt to multiply with overflow".to_string() + } + _ => return Ok(()), + }; + + let with_pred = self.acir_context.mul_var(result, self.current_side_effects_enabled_var)?; + self.acir_context.range_constrain_var( + with_pred, + &NumericType::Unsigned { bit_size }, + Some(msg), + )?; + Ok(()) } /// Operands in a binary operation are checked to have the same type. @@ -2595,12 +2662,22 @@ impl<'a> Context<'a> { &mut self, arguments: &[ValueId], dfg: &DataFlowGraph, - ) -> Result, InternalError> { + ) -> Result, InternalError> { let mut acir_vars = Vec::with_capacity(arguments.len()); for value_id in arguments { + let is_databus = if let Some(return_databus) = self.data_bus.return_data { + dfg[*value_id] == dfg[return_databus] + } else { + false + }; let value = self.convert_value(*value_id, dfg); acir_vars.append( - &mut self.acir_context.flatten(value)?.iter().map(|(var, _)| *var).collect(), + &mut self + .acir_context + .flatten(value)? + .iter() + .map(|(var, _)| (*var, is_databus)) + .collect(), ); } Ok(acir_vars) diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index bddfb25f26c..73dc3888184 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -11,7 +11,7 @@ use crate::ssa::{ ir::{ basic_block::BasicBlockId, dfg::{CallStack, InsertInstructionResult}, - function::{Function, FunctionId}, + function::{Function, FunctionId, RuntimeType}, instruction::{Instruction, InstructionId, TerminatorInstruction}, value::{Value, ValueId}, }, @@ -392,10 +392,12 @@ impl<'function> PerFunctionContext<'function> { Some(func_id) => { let function = &ssa.functions[&func_id]; // If we have not already finished the flattening pass, functions marked - // to not have predicates should be marked as entry points. + // to not have predicates should be marked as entry points unless we are inlining into brillig. + let entry_point = &ssa.functions[&self.context.entry_point]; let no_predicates_is_entry_point = self.context.no_predicates_is_entry_point - && function.is_no_predicates(); + && function.is_no_predicates() + && !matches!(entry_point.runtime(), RuntimeType::Brillig); if function.runtime().is_entry_point() || no_predicates_is_entry_point { self.push_instruction(*id); } else { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs index 42727054503..65a77552c79 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs @@ -109,7 +109,7 @@ impl Context<'_> { return InsertInstructionResult::SimplifiedTo(zero).first(); } } - let pow = self.numeric_constant(FieldElement::from(rhs_bit_size_pow_2), typ); + let pow = self.numeric_constant(FieldElement::from(rhs_bit_size_pow_2), typ.clone()); let max_lhs_bits = self.function.dfg.get_value_max_num_bits(lhs); @@ -123,15 +123,18 @@ impl Context<'_> { // we can safely cast to unsigned because overflow_checks prevent bit-shift with a negative value let rhs_unsigned = self.insert_cast(rhs, Type::unsigned(bit_size)); let pow = self.pow(base, rhs_unsigned); - let pow = self.insert_cast(pow, typ); + let pow = self.insert_cast(pow, typ.clone()); (FieldElement::max_num_bits(), self.insert_binary(predicate, BinaryOp::Mul, pow)) }; if max_bit <= bit_size { self.insert_binary(lhs, BinaryOp::Mul, pow) } else { - let result = self.insert_binary(lhs, BinaryOp::Mul, pow); - self.insert_truncate(result, bit_size, max_bit) + let lhs_field = self.insert_cast(lhs, Type::field()); + let pow_field = self.insert_cast(pow, Type::field()); + let result = self.insert_binary(lhs_field, BinaryOp::Mul, pow_field); + let result = self.insert_truncate(result, bit_size, max_bit); + self.insert_cast(result, typ) } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs index 02b9202b209..ea37d857e58 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs @@ -108,17 +108,19 @@ impl Context { fn responds_to_side_effects_var(dfg: &DataFlowGraph, instruction: &Instruction) -> bool { use Instruction::*; match instruction { - Binary(binary) => { - if matches!(binary.operator, BinaryOp::Div | BinaryOp::Mod) { + Binary(binary) => match binary.operator { + BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul => { + dfg.type_of_value(binary.lhs).is_unsigned() + } + BinaryOp::Div | BinaryOp::Mod => { if let Some(rhs) = dfg.get_numeric_constant(binary.rhs) { rhs == FieldElement::zero() } else { true } - } else { - false } - } + _ => false, + }, Cast(_, _) | Not(_) diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index f7ecdc8870d..ebcbfbabe73 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -304,7 +304,7 @@ impl<'a> FunctionContext<'a> { /// Insert constraints ensuring that the operation does not overflow the bit size of the result /// - /// If the result is unsigned, we simply range check against the bit size + /// If the result is unsigned, overflow will be checked during acir-gen (cf. issue #4456), except for bit-shifts, because we will convert them to field multiplication /// /// If the result is signed, we just prepare it for check_signed_overflow() by casting it to /// an unsigned value representing the signed integer. @@ -351,51 +351,12 @@ impl<'a> FunctionContext<'a> { } Type::Numeric(NumericType::Unsigned { bit_size }) => { let dfg = &self.builder.current_function.dfg; - - let max_lhs_bits = self.builder.current_function.dfg.get_value_max_num_bits(lhs); - let max_rhs_bits = self.builder.current_function.dfg.get_value_max_num_bits(rhs); + let max_lhs_bits = dfg.get_value_max_num_bits(lhs); match operator { - BinaryOpKind::Add => { - if std::cmp::max(max_lhs_bits, max_rhs_bits) < bit_size { - // `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow. - return result; - } - - let message = "attempt to add with overflow".to_string(); - self.builder.set_location(location).insert_range_check( - result, - bit_size, - Some(message), - ); - } - BinaryOpKind::Subtract => { - if dfg.is_constant(lhs) && max_lhs_bits > max_rhs_bits { - // `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0` - // Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible. - return result; - } - - let message = "attempt to subtract with overflow".to_string(); - self.builder.set_location(location).insert_range_check( - result, - bit_size, - Some(message), - ); - } - BinaryOpKind::Multiply => { - if bit_size == 1 || max_lhs_bits + max_rhs_bits <= bit_size { - // Either performing boolean multiplication (which cannot overflow), - // or `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow. - return result; - } - - let message = "attempt to multiply with overflow".to_string(); - self.builder.set_location(location).insert_range_check( - result, - bit_size, - Some(message), - ); + BinaryOpKind::Add | BinaryOpKind::Subtract | BinaryOpKind::Multiply => { + // Overflow check is deferred to acir-gen + return result; } BinaryOpKind::ShiftLeft => { if let Some(rhs_const) = dfg.get_numeric_constant(rhs) { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/function.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/function.rs index dc426a4642a..8acc068d86a 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/function.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/function.rs @@ -32,6 +32,15 @@ pub enum FunctionKind { Recursive, } +impl FunctionKind { + pub fn can_ignore_return_type(self) -> bool { + match self { + FunctionKind::LowLevel | FunctionKind::Builtin | FunctionKind::Oracle => true, + FunctionKind::Normal | FunctionKind::Recursive => false, + } + } +} + impl NoirFunction { pub fn normal(def: FunctionDefinition) -> NoirFunction { NoirFunction { kind: FunctionKind::Normal, def } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/mod.rs index 254ec4a7590..1c5a5c610aa 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/mod.rs @@ -32,6 +32,7 @@ use iter_extended::vecmap; pub enum IntegerBitSize { One, Eight, + Sixteen, ThirtyTwo, SixtyFour, } @@ -48,6 +49,7 @@ impl From for u32 { match size { One => 1, Eight => 8, + Sixteen => 16, ThirtyTwo => 32, SixtyFour => 64, } @@ -64,6 +66,7 @@ impl TryFrom for IntegerBitSize { match value { 1 => Ok(One), 8 => Ok(Eight), + 16 => Ok(Sixteen), 32 => Ok(ThirtyTwo), 64 => Ok(SixtyFour), _ => Err(InvalidIntegerBitSizeError(value)), diff --git a/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs b/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs index 0da39edfd85..94b5841e52c 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/ast/statement.rs @@ -565,7 +565,7 @@ impl ForRange { identifier: Ident, block: Expression, for_loop_span: Span, - ) -> StatementKind { + ) -> Statement { /// Counter used to generate unique names when desugaring /// code in the parser requires the creation of fresh variables. /// The parser is stateless so this is a static global instead. @@ -662,7 +662,8 @@ impl ForRange { let block = ExpressionKind::Block(BlockExpression { statements: vec![let_array, for_loop], }); - StatementKind::Expression(Expression::new(block, for_loop_span)) + let kind = StatementKind::Expression(Expression::new(block, for_loop_span)); + Statement { kind, span: for_loop_span } } } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs new file mode 100644 index 00000000000..ed8ed5305d1 --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -0,0 +1,604 @@ +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; +use regex::Regex; +use rustc_hash::FxHashSet as HashSet; + +use crate::{ + ast::{ + ArrayLiteral, ConstructorExpression, IfExpression, InfixExpression, Lambda, + UnresolvedTypeExpression, + }, + hir::{ + resolution::{errors::ResolverError, resolver::LambdaContext}, + type_check::TypeCheckError, + }, + hir_def::{ + expr::{ + HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression, + HirConstructorExpression, HirIdent, HirIfExpression, HirIndexExpression, + HirInfixExpression, HirLambda, HirMemberAccess, HirMethodCallExpression, + HirMethodReference, HirPrefixExpression, + }, + traits::TraitConstraint, + }, + macros_api::{ + BlockExpression, CallExpression, CastExpression, Expression, ExpressionKind, HirExpression, + HirLiteral, HirStatement, Ident, IndexExpression, Literal, MemberAccessExpression, + MethodCallExpression, PrefixExpression, + }, + node_interner::{DefinitionKind, ExprId, FuncId}, + Shared, StructType, Type, +}; + +use super::Elaborator; + +impl<'context> Elaborator<'context> { + pub(super) fn elaborate_expression(&mut self, expr: Expression) -> (ExprId, Type) { + let (hir_expr, typ) = match expr.kind { + ExpressionKind::Literal(literal) => self.elaborate_literal(literal, expr.span), + ExpressionKind::Block(block) => self.elaborate_block(block), + ExpressionKind::Prefix(prefix) => self.elaborate_prefix(*prefix), + ExpressionKind::Index(index) => self.elaborate_index(*index), + ExpressionKind::Call(call) => self.elaborate_call(*call, expr.span), + ExpressionKind::MethodCall(call) => self.elaborate_method_call(*call, expr.span), + ExpressionKind::Constructor(constructor) => self.elaborate_constructor(*constructor), + ExpressionKind::MemberAccess(access) => { + return self.elaborate_member_access(*access, expr.span) + } + ExpressionKind::Cast(cast) => self.elaborate_cast(*cast, expr.span), + ExpressionKind::Infix(infix) => return self.elaborate_infix(*infix, expr.span), + ExpressionKind::If(if_) => self.elaborate_if(*if_), + ExpressionKind::Variable(variable) => return self.elaborate_variable(variable), + ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple), + ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda), + ExpressionKind::Parenthesized(expr) => return self.elaborate_expression(*expr), + ExpressionKind::Quote(quote) => self.elaborate_quote(quote), + ExpressionKind::Comptime(comptime) => self.elaborate_comptime_block(comptime), + ExpressionKind::Error => (HirExpression::Error, Type::Error), + }; + let id = self.interner.push_expr(hir_expr); + self.interner.push_expr_location(id, expr.span, self.file); + self.interner.push_expr_type(id, typ.clone()); + (id, typ) + } + + pub(super) fn elaborate_block(&mut self, block: BlockExpression) -> (HirExpression, Type) { + self.push_scope(); + let mut block_type = Type::Unit; + let mut statements = Vec::with_capacity(block.statements.len()); + + for (i, statement) in block.statements.into_iter().enumerate() { + let (id, stmt_type) = self.elaborate_statement(statement); + statements.push(id); + + if let HirStatement::Semi(expr) = self.interner.statement(&id) { + let inner_expr_type = self.interner.id_type(expr); + let span = self.interner.expr_span(&expr); + + self.unify(&inner_expr_type, &Type::Unit, || TypeCheckError::UnusedResultError { + expr_type: inner_expr_type.clone(), + expr_span: span, + }); + + if i + 1 == statements.len() { + block_type = stmt_type; + } + } + } + + self.pop_scope(); + (HirExpression::Block(HirBlockExpression { statements }), block_type) + } + + fn elaborate_literal(&mut self, literal: Literal, span: Span) -> (HirExpression, Type) { + use HirExpression::Literal as Lit; + match literal { + Literal::Unit => (Lit(HirLiteral::Unit), Type::Unit), + Literal::Bool(b) => (Lit(HirLiteral::Bool(b)), Type::Bool), + Literal::Integer(integer, sign) => { + let int = HirLiteral::Integer(integer, sign); + (Lit(int), self.polymorphic_integer_or_field()) + } + Literal::Str(str) | Literal::RawStr(str, _) => { + let len = Type::Constant(str.len() as u64); + (Lit(HirLiteral::Str(str)), Type::String(Box::new(len))) + } + Literal::FmtStr(str) => self.elaborate_fmt_string(str, span), + Literal::Array(array_literal) => { + self.elaborate_array_literal(array_literal, span, true) + } + Literal::Slice(array_literal) => { + self.elaborate_array_literal(array_literal, span, false) + } + } + } + + fn elaborate_array_literal( + &mut self, + array_literal: ArrayLiteral, + span: Span, + is_array: bool, + ) -> (HirExpression, Type) { + let (expr, elem_type, length) = match array_literal { + ArrayLiteral::Standard(elements) => { + let first_elem_type = self.interner.next_type_variable(); + let first_span = elements.first().map(|elem| elem.span).unwrap_or(span); + + let elements = vecmap(elements.into_iter().enumerate(), |(i, elem)| { + let span = elem.span; + let (elem_id, elem_type) = self.elaborate_expression(elem); + + self.unify(&elem_type, &first_elem_type, || { + TypeCheckError::NonHomogeneousArray { + first_span, + first_type: first_elem_type.to_string(), + first_index: 0, + second_span: span, + second_type: elem_type.to_string(), + second_index: i, + } + .add_context("elements in an array must have the same type") + }); + elem_id + }); + + let length = Type::Constant(elements.len() as u64); + (HirArrayLiteral::Standard(elements), first_elem_type, length) + } + ArrayLiteral::Repeated { repeated_element, length } => { + let span = length.span; + let length = + UnresolvedTypeExpression::from_expr(*length, span).unwrap_or_else(|error| { + self.push_err(ResolverError::ParserError(Box::new(error))); + UnresolvedTypeExpression::Constant(0, span) + }); + + let length = self.convert_expression_type(length); + let (repeated_element, elem_type) = self.elaborate_expression(*repeated_element); + + let length_clone = length.clone(); + (HirArrayLiteral::Repeated { repeated_element, length }, elem_type, length_clone) + } + }; + let constructor = if is_array { HirLiteral::Array } else { HirLiteral::Slice }; + let elem_type = Box::new(elem_type); + let typ = if is_array { + Type::Array(Box::new(length), elem_type) + } else { + Type::Slice(elem_type) + }; + (HirExpression::Literal(constructor(expr)), typ) + } + + fn elaborate_fmt_string(&mut self, str: String, call_expr_span: Span) -> (HirExpression, Type) { + let re = Regex::new(r"\{([a-zA-Z0-9_]+)\}") + .expect("ICE: an invalid regex pattern was used for checking format strings"); + + let mut fmt_str_idents = Vec::new(); + let mut capture_types = Vec::new(); + + for field in re.find_iter(&str) { + let matched_str = field.as_str(); + let ident_name = &matched_str[1..(matched_str.len() - 1)]; + + let scope_tree = self.scopes.current_scope_tree(); + let variable = scope_tree.find(ident_name); + if let Some((old_value, _)) = variable { + old_value.num_times_used += 1; + let ident = HirExpression::Ident(old_value.ident.clone()); + let expr_id = self.interner.push_expr(ident); + self.interner.push_expr_location(expr_id, call_expr_span, self.file); + let ident = old_value.ident.clone(); + let typ = self.type_check_variable(ident, expr_id); + self.interner.push_expr_type(expr_id, typ.clone()); + capture_types.push(typ); + fmt_str_idents.push(expr_id); + } else if ident_name.parse::().is_ok() { + self.push_err(ResolverError::NumericConstantInFormatString { + name: ident_name.to_owned(), + span: call_expr_span, + }); + } else { + self.push_err(ResolverError::VariableNotDeclared { + name: ident_name.to_owned(), + span: call_expr_span, + }); + } + } + + let len = Type::Constant(str.len() as u64); + let typ = Type::FmtString(Box::new(len), Box::new(Type::Tuple(capture_types))); + (HirExpression::Literal(HirLiteral::FmtStr(str, fmt_str_idents)), typ) + } + + fn elaborate_prefix(&mut self, prefix: PrefixExpression) -> (HirExpression, Type) { + let span = prefix.rhs.span; + let (rhs, rhs_type) = self.elaborate_expression(prefix.rhs); + let ret_type = self.type_check_prefix_operand(&prefix.operator, &rhs_type, span); + (HirExpression::Prefix(HirPrefixExpression { operator: prefix.operator, rhs }), ret_type) + } + + fn elaborate_index(&mut self, index_expr: IndexExpression) -> (HirExpression, Type) { + let span = index_expr.index.span; + let (index, index_type) = self.elaborate_expression(index_expr.index); + + let expected = self.polymorphic_integer_or_field(); + self.unify(&index_type, &expected, || TypeCheckError::TypeMismatch { + expected_typ: "an integer".to_owned(), + expr_typ: index_type.to_string(), + expr_span: span, + }); + + // When writing `a[i]`, if `a : &mut ...` then automatically dereference `a` as many + // times as needed to get the underlying array. + let lhs_span = index_expr.collection.span; + let (lhs, lhs_type) = self.elaborate_expression(index_expr.collection); + let (collection, lhs_type) = self.insert_auto_dereferences(lhs, lhs_type); + + let typ = match lhs_type.follow_bindings() { + // XXX: We can check the array bounds here also, but it may be better to constant fold first + // and have ConstId instead of ExprId for constants + Type::Array(_, base_type) => *base_type, + Type::Slice(base_type) => *base_type, + Type::Error => Type::Error, + typ => { + self.push_err(TypeCheckError::TypeMismatch { + expected_typ: "Array".to_owned(), + expr_typ: typ.to_string(), + expr_span: lhs_span, + }); + Type::Error + } + }; + + let expr = HirExpression::Index(HirIndexExpression { collection, index }); + (expr, typ) + } + + fn elaborate_call(&mut self, call: CallExpression, span: Span) -> (HirExpression, Type) { + let (func, func_type) = self.elaborate_expression(*call.func); + + let mut arguments = Vec::with_capacity(call.arguments.len()); + let args = vecmap(call.arguments, |arg| { + let span = arg.span; + let (arg, typ) = self.elaborate_expression(arg); + arguments.push(arg); + (typ, arg, span) + }); + + let location = Location::new(span, self.file); + let call = HirCallExpression { func, arguments, location }; + let typ = self.type_check_call(&call, func_type, args, span); + (HirExpression::Call(call), typ) + } + + fn elaborate_method_call( + &mut self, + method_call: MethodCallExpression, + span: Span, + ) -> (HirExpression, Type) { + let object_span = method_call.object.span; + let (mut object, mut object_type) = self.elaborate_expression(method_call.object); + object_type = object_type.follow_bindings(); + + let method_name = method_call.method_name.0.contents.as_str(); + match self.lookup_method(&object_type, method_name, span) { + Some(method_ref) => { + // Automatically add `&mut` if the method expects a mutable reference and + // the object is not already one. + if let HirMethodReference::FuncId(func_id) = &method_ref { + if *func_id != FuncId::dummy_id() { + let function_type = self.interner.function_meta(func_id).typ.clone(); + + self.try_add_mutable_reference_to_object( + &function_type, + &mut object_type, + &mut object, + ); + } + } + + // These arguments will be given to the desugared function call. + // Compared to the method arguments, they also contain the object. + let mut function_args = Vec::with_capacity(method_call.arguments.len() + 1); + let mut arguments = Vec::with_capacity(method_call.arguments.len()); + + function_args.push((object_type.clone(), object, object_span)); + + for arg in method_call.arguments { + let span = arg.span; + let (arg, typ) = self.elaborate_expression(arg); + arguments.push(arg); + function_args.push((typ, arg, span)); + } + + let location = Location::new(span, self.file); + let method = method_call.method_name; + let method_call = HirMethodCallExpression { method, object, arguments, location }; + + // Desugar the method call into a normal, resolved function call + // so that the backend doesn't need to worry about methods + // TODO: update object_type here? + let ((function_id, function_name), function_call) = method_call.into_function_call( + &method_ref, + object_type, + location, + self.interner, + ); + + let func_type = self.type_check_variable(function_name, function_id); + + // Type check the new call now that it has been changed from a method call + // to a function call. This way we avoid duplicating code. + let typ = self.type_check_call(&function_call, func_type, function_args, span); + (HirExpression::Call(function_call), typ) + } + None => (HirExpression::Error, Type::Error), + } + } + + fn elaborate_constructor( + &mut self, + constructor: ConstructorExpression, + ) -> (HirExpression, Type) { + let span = constructor.type_name.span(); + + match self.lookup_type_or_error(constructor.type_name) { + Some(Type::Struct(r#type, struct_generics)) => { + let struct_type = r#type.clone(); + let generics = struct_generics.clone(); + + let fields = constructor.fields; + let field_types = r#type.borrow().get_fields(&struct_generics); + let fields = self.resolve_constructor_expr_fields( + struct_type.clone(), + field_types, + fields, + span, + ); + let expr = HirExpression::Constructor(HirConstructorExpression { + fields, + r#type, + struct_generics, + }); + (expr, Type::Struct(struct_type, generics)) + } + Some(typ) => { + self.push_err(ResolverError::NonStructUsedInConstructor { typ, span }); + (HirExpression::Error, Type::Error) + } + None => (HirExpression::Error, Type::Error), + } + } + + /// Resolve all the fields of a struct constructor expression. + /// Ensures all fields are present, none are repeated, and all + /// are part of the struct. + fn resolve_constructor_expr_fields( + &mut self, + struct_type: Shared, + field_types: Vec<(String, Type)>, + fields: Vec<(Ident, Expression)>, + span: Span, + ) -> Vec<(Ident, ExprId)> { + let mut ret = Vec::with_capacity(fields.len()); + let mut seen_fields = HashSet::default(); + let mut unseen_fields = struct_type.borrow().field_names(); + + for (field_name, field) in fields { + let expected_type = field_types.iter().find(|(name, _)| name == &field_name.0.contents); + let expected_type = expected_type.map(|(_, typ)| typ).unwrap_or(&Type::Error); + + let field_span = field.span; + let (resolved, field_type) = self.elaborate_expression(field); + + if unseen_fields.contains(&field_name) { + unseen_fields.remove(&field_name); + seen_fields.insert(field_name.clone()); + + self.unify_with_coercions(&field_type, expected_type, resolved, || { + TypeCheckError::TypeMismatch { + expected_typ: expected_type.to_string(), + expr_typ: field_type.to_string(), + expr_span: field_span, + } + }); + } else if seen_fields.contains(&field_name) { + // duplicate field + self.push_err(ResolverError::DuplicateField { field: field_name.clone() }); + } else { + // field not required by struct + self.push_err(ResolverError::NoSuchField { + field: field_name.clone(), + struct_definition: struct_type.borrow().name.clone(), + }); + } + + ret.push((field_name, resolved)); + } + + if !unseen_fields.is_empty() { + self.push_err(ResolverError::MissingFields { + span, + missing_fields: unseen_fields.into_iter().map(|field| field.to_string()).collect(), + struct_definition: struct_type.borrow().name.clone(), + }); + } + + ret + } + + fn elaborate_member_access( + &mut self, + access: MemberAccessExpression, + span: Span, + ) -> (ExprId, Type) { + let (lhs, lhs_type) = self.elaborate_expression(access.lhs); + let rhs = access.rhs; + // `is_offset` is only used when lhs is a reference and we want to return a reference to rhs + let access = HirMemberAccess { lhs, rhs, is_offset: false }; + let expr_id = self.intern_expr(HirExpression::MemberAccess(access.clone()), span); + let typ = self.type_check_member_access(access, expr_id, lhs_type, span); + self.interner.push_expr_type(expr_id, typ.clone()); + (expr_id, typ) + } + + pub fn intern_expr(&mut self, expr: HirExpression, span: Span) -> ExprId { + let id = self.interner.push_expr(expr); + self.interner.push_expr_location(id, span, self.file); + id + } + + fn elaborate_cast(&mut self, cast: CastExpression, span: Span) -> (HirExpression, Type) { + let (lhs, lhs_type) = self.elaborate_expression(cast.lhs); + let r#type = self.resolve_type(cast.r#type); + let result = self.check_cast(lhs_type, &r#type, span); + let expr = HirExpression::Cast(HirCastExpression { lhs, r#type }); + (expr, result) + } + + fn elaborate_infix(&mut self, infix: InfixExpression, span: Span) -> (ExprId, Type) { + let (lhs, lhs_type) = self.elaborate_expression(infix.lhs); + let (rhs, rhs_type) = self.elaborate_expression(infix.rhs); + let trait_id = self.interner.get_operator_trait_method(infix.operator.contents); + + let operator = HirBinaryOp::new(infix.operator, self.file); + let expr = HirExpression::Infix(HirInfixExpression { + lhs, + operator, + trait_method_id: trait_id, + rhs, + }); + + let expr_id = self.interner.push_expr(expr); + self.interner.push_expr_location(expr_id, span, self.file); + + let typ = match self.infix_operand_type_rules(&lhs_type, &operator, &rhs_type, span) { + Ok((typ, use_impl)) => { + if use_impl { + // Delay checking the trait constraint until the end of the function. + // Checking it now could bind an unbound type variable to any type + // that implements the trait. + let constraint = TraitConstraint { + typ: lhs_type.clone(), + trait_id: trait_id.trait_id, + trait_generics: Vec::new(), + }; + self.trait_constraints.push((constraint, expr_id)); + self.type_check_operator_method(expr_id, trait_id, &lhs_type, span); + } + typ + } + Err(error) => { + self.push_err(error); + Type::Error + } + }; + + self.interner.push_expr_type(expr_id, typ.clone()); + (expr_id, typ) + } + + fn elaborate_if(&mut self, if_expr: IfExpression) -> (HirExpression, Type) { + let expr_span = if_expr.condition.span; + let (condition, cond_type) = self.elaborate_expression(if_expr.condition); + let (consequence, mut ret_type) = self.elaborate_expression(if_expr.consequence); + + self.unify(&cond_type, &Type::Bool, || TypeCheckError::TypeMismatch { + expected_typ: Type::Bool.to_string(), + expr_typ: cond_type.to_string(), + expr_span, + }); + + let alternative = if_expr.alternative.map(|alternative| { + let expr_span = alternative.span; + let (else_, else_type) = self.elaborate_expression(alternative); + + self.unify(&ret_type, &else_type, || { + let err = TypeCheckError::TypeMismatch { + expected_typ: ret_type.to_string(), + expr_typ: else_type.to_string(), + expr_span, + }; + + let context = if ret_type == Type::Unit { + "Are you missing a semicolon at the end of your 'else' branch?" + } else if else_type == Type::Unit { + "Are you missing a semicolon at the end of the first block of this 'if'?" + } else { + "Expected the types of both if branches to be equal" + }; + + err.add_context(context) + }); + else_ + }); + + if alternative.is_none() { + ret_type = Type::Unit; + } + + let if_expr = HirIfExpression { condition, consequence, alternative }; + (HirExpression::If(if_expr), ret_type) + } + + fn elaborate_tuple(&mut self, tuple: Vec) -> (HirExpression, Type) { + let mut element_ids = Vec::with_capacity(tuple.len()); + let mut element_types = Vec::with_capacity(tuple.len()); + + for element in tuple { + let (id, typ) = self.elaborate_expression(element); + element_ids.push(id); + element_types.push(typ); + } + + (HirExpression::Tuple(element_ids), Type::Tuple(element_types)) + } + + fn elaborate_lambda(&mut self, lambda: Lambda) -> (HirExpression, Type) { + self.push_scope(); + let scope_index = self.scopes.current_scope_index(); + + self.lambda_stack.push(LambdaContext { captures: Vec::new(), scope_index }); + + let mut arg_types = Vec::with_capacity(lambda.parameters.len()); + let parameters = vecmap(lambda.parameters, |(pattern, typ)| { + let parameter = DefinitionKind::Local(None); + let typ = self.resolve_inferred_type(typ); + arg_types.push(typ.clone()); + (self.elaborate_pattern(pattern, typ.clone(), parameter), typ) + }); + + let return_type = self.resolve_inferred_type(lambda.return_type); + let body_span = lambda.body.span; + let (body, body_type) = self.elaborate_expression(lambda.body); + + let lambda_context = self.lambda_stack.pop().unwrap(); + self.pop_scope(); + + self.unify(&body_type, &return_type, || TypeCheckError::TypeMismatch { + expected_typ: return_type.to_string(), + expr_typ: body_type.to_string(), + expr_span: body_span, + }); + + let captured_vars = vecmap(&lambda_context.captures, |capture| { + self.interner.definition_type(capture.ident.id) + }); + + let env_type = + if captured_vars.is_empty() { Type::Unit } else { Type::Tuple(captured_vars) }; + + let captures = lambda_context.captures; + let expr = HirExpression::Lambda(HirLambda { parameters, return_type, body, captures }); + (expr, Type::Function(arg_types, Box::new(body_type), Box::new(env_type))) + } + + fn elaborate_quote(&mut self, block: BlockExpression) -> (HirExpression, Type) { + (HirExpression::Quote(block), Type::Code) + } + + fn elaborate_comptime_block(&mut self, _comptime: BlockExpression) -> (HirExpression, Type) { + todo!("Elaborate comptime block") + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs new file mode 100644 index 00000000000..446e5b62ead --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/mod.rs @@ -0,0 +1,782 @@ +#![allow(unused)] +use std::{ + collections::{BTreeMap, BTreeSet}, + rc::Rc, +}; + +use crate::hir::def_map::CrateDefMap; +use crate::{ + ast::{ + ArrayLiteral, ConstructorExpression, FunctionKind, IfExpression, InfixExpression, Lambda, + UnresolvedTraitConstraint, UnresolvedTypeExpression, + }, + hir::{ + def_collector::dc_crate::CompilationError, + resolution::{errors::ResolverError, path_resolver::PathResolver, resolver::LambdaContext}, + scope::ScopeForest as GenericScopeForest, + type_check::TypeCheckError, + }, + hir_def::{ + expr::{ + HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression, + HirConstructorExpression, HirIdent, HirIfExpression, HirIndexExpression, + HirInfixExpression, HirLambda, HirMemberAccess, HirMethodCallExpression, + HirMethodReference, HirPrefixExpression, + }, + traits::TraitConstraint, + }, + macros_api::{ + BlockExpression, CallExpression, CastExpression, Expression, ExpressionKind, HirExpression, + HirLiteral, HirStatement, Ident, IndexExpression, Literal, MemberAccessExpression, + MethodCallExpression, NodeInterner, NoirFunction, PrefixExpression, Statement, + StatementKind, StructId, + }, + node_interner::{DefinitionKind, DependencyId, ExprId, FuncId, StmtId, TraitId}, + Shared, StructType, Type, TypeVariable, +}; +use crate::{ + ast::{TraitBound, UnresolvedGenerics}, + graph::CrateId, + hir::{ + def_collector::{ + dc_crate::{CollectedItems, DefCollector}, + errors::DefCollectorErrorKind, + }, + def_map::{LocalModuleId, ModuleDefId, ModuleId, MAIN_FUNCTION}, + resolution::{ + errors::PubPosition, + import::{PathResolution, PathResolutionError}, + path_resolver::StandardPathResolver, + }, + Context, + }, + hir_def::function::{FuncMeta, HirFunction}, + macros_api::{Param, Path, UnresolvedType, UnresolvedTypeData, Visibility}, + node_interner::TraitImplId, + token::FunctionAttribute, + Generics, +}; + +mod expressions; +mod patterns; +mod scope; +mod statements; +mod types; + +use fm::FileId; +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; +use regex::Regex; +use rustc_hash::FxHashSet as HashSet; + +/// ResolverMetas are tagged onto each definition to track how many times they are used +#[derive(Debug, PartialEq, Eq)] +pub struct ResolverMeta { + num_times_used: usize, + ident: HirIdent, + warn_if_unused: bool, +} + +type ScopeForest = GenericScopeForest; + +pub struct Elaborator<'context> { + scopes: ScopeForest, + + errors: Vec<(CompilationError, FileId)>, + + interner: &'context mut NodeInterner, + + def_maps: &'context BTreeMap, + + file: FileId, + + in_unconstrained_fn: bool, + nested_loops: usize, + + /// True if the current module is a contract. + /// This is usually determined by self.path_resolver.module_id(), but it can + /// be overridden for impls. Impls are an odd case since the methods within resolve + /// as if they're in the parent module, but should be placed in a child module. + /// Since they should be within a child module, in_contract is manually set to false + /// for these so we can still resolve them in the parent module without them being in a contract. + in_contract: bool, + + /// Contains a mapping of the current struct or functions's generics to + /// unique type variables if we're resolving a struct. Empty otherwise. + /// This is a Vec rather than a map to preserve the order a functions generics + /// were declared in. + generics: Vec<(Rc, TypeVariable, Span)>, + + /// When resolving lambda expressions, we need to keep track of the variables + /// that are captured. We do this in order to create the hidden environment + /// parameter for the lambda function. + lambda_stack: Vec, + + /// Set to the current type if we're resolving an impl + self_type: Option, + + /// The current dependency item we're resolving. + /// Used to link items to their dependencies in the dependency graph + current_item: Option, + + /// If we're currently resolving methods within a trait impl, this will be set + /// to the corresponding trait impl ID. + current_trait_impl: Option, + + trait_id: Option, + + /// In-resolution names + /// + /// This needs to be a set because we can have multiple in-resolution + /// names when resolving structs that are declared in reverse order of their + /// dependencies, such as in the following case: + /// + /// ``` + /// struct Wrapper { + /// value: Wrapped + /// } + /// struct Wrapped { + /// } + /// ``` + resolving_ids: BTreeSet, + + trait_bounds: Vec, + + current_function: Option, + + /// All type variables created in the current function. + /// This map is used to default any integer type variables at the end of + /// a function (before checking trait constraints) if a type wasn't already chosen. + type_variables: Vec, + + /// Trait constraints are collected during type checking until they are + /// verified at the end of a function. This is because constraints arise + /// on each variable, but it is only until function calls when the types + /// needed for the trait constraint may become known. + trait_constraints: Vec<(TraitConstraint, ExprId)>, + + /// The current module this elaborator is in. + /// Initially empty, it is set whenever a new top-level item is resolved. + local_module: LocalModuleId, + + crate_id: CrateId, +} + +impl<'context> Elaborator<'context> { + pub fn new(context: &'context mut Context, crate_id: CrateId) -> Self { + Self { + scopes: ScopeForest::default(), + errors: Vec::new(), + interner: &mut context.def_interner, + def_maps: &context.def_maps, + file: FileId::dummy(), + in_unconstrained_fn: false, + nested_loops: 0, + in_contract: false, + generics: Vec::new(), + lambda_stack: Vec::new(), + self_type: None, + current_item: None, + trait_id: None, + local_module: LocalModuleId::dummy_id(), + crate_id, + resolving_ids: BTreeSet::new(), + trait_bounds: Vec::new(), + current_function: None, + type_variables: Vec::new(), + trait_constraints: Vec::new(), + current_trait_impl: None, + } + } + + pub fn elaborate( + context: &'context mut Context, + crate_id: CrateId, + items: CollectedItems, + ) -> Vec<(CompilationError, FileId)> { + let mut this = Self::new(context, crate_id); + + // the resolver filters literal globals first + for global in items.globals {} + + for alias in items.type_aliases {} + + for trait_ in items.traits {} + + for struct_ in items.types {} + + for trait_impl in &items.trait_impls { + // only collect now + } + + for impl_ in &items.impls { + // only collect now + } + + // resolver resolves non-literal globals here + + for functions in items.functions { + this.file = functions.file_id; + this.trait_id = functions.trait_id; // TODO: Resolve? + for (local_module, id, func) in functions.functions { + this.local_module = local_module; + this.elaborate_function(func, id); + } + } + + for impl_ in items.impls {} + + for trait_impl in items.trait_impls {} + + let cycle_errors = this.interner.check_for_dependency_cycles(); + this.errors.extend(cycle_errors); + + this.errors + } + + fn elaborate_function(&mut self, mut function: NoirFunction, id: FuncId) { + self.current_function = Some(id); + self.resolve_where_clause(&mut function.def.where_clause); + + // Without this, impl methods can accidentally be placed in contracts. See #3254 + if self.self_type.is_some() { + self.in_contract = false; + } + + self.scopes.start_function(); + self.current_item = Some(DependencyId::Function(id)); + + // Check whether the function has globals in the local module and add them to the scope + self.resolve_local_globals(); + self.add_generics(&function.def.generics); + + self.desugar_impl_trait_args(&mut function, id); + self.trait_bounds = function.def.where_clause.clone(); + + let is_low_level_or_oracle = function + .attributes() + .function + .as_ref() + .map_or(false, |func| func.is_low_level() || func.is_oracle()); + + if function.def.is_unconstrained { + self.in_unconstrained_fn = true; + } + + let func_meta = self.extract_meta(&function, id); + + self.add_trait_constraints_to_scope(&func_meta); + + let (hir_func, body_type) = match function.kind { + FunctionKind::Builtin | FunctionKind::LowLevel | FunctionKind::Oracle => { + (HirFunction::empty(), Type::Error) + } + FunctionKind::Normal | FunctionKind::Recursive => { + let block_span = function.def.span; + let (block, body_type) = self.elaborate_block(function.def.body); + let expr_id = self.intern_expr(block, block_span); + self.interner.push_expr_type(expr_id, body_type.clone()); + (HirFunction::unchecked_from_expr(expr_id), body_type) + } + }; + + if !func_meta.can_ignore_return_type() { + self.type_check_function_body(body_type, &func_meta, hir_func.as_expr()); + } + + // Default any type variables that still need defaulting. + // This is done before trait impl search since leaving them bindable can lead to errors + // when multiple impls are available. Instead we default first to choose the Field or u64 impl. + for typ in &self.type_variables { + if let Type::TypeVariable(variable, kind) = typ.follow_bindings() { + let msg = "TypeChecker should only track defaultable type vars"; + variable.bind(kind.default_type().expect(msg)); + } + } + + // Verify any remaining trait constraints arising from the function body + for (constraint, expr_id) in std::mem::take(&mut self.trait_constraints) { + let span = self.interner.expr_span(&expr_id); + self.verify_trait_constraint( + &constraint.typ, + constraint.trait_id, + &constraint.trait_generics, + expr_id, + span, + ); + } + + // Now remove all the `where` clause constraints we added + for constraint in &func_meta.trait_constraints { + self.interner.remove_assumed_trait_implementations_for_trait(constraint.trait_id); + } + + let func_scope_tree = self.scopes.end_function(); + + // The arguments to low-level and oracle functions are always unused so we do not produce warnings for them. + if !is_low_level_or_oracle { + self.check_for_unused_variables_in_scope_tree(func_scope_tree); + } + + self.trait_bounds.clear(); + + self.interner.push_fn_meta(func_meta, id); + self.interner.update_fn(id, hir_func); + self.current_function = None; + } + + /// This turns function parameters of the form: + /// fn foo(x: impl Bar) + /// + /// into + /// fn foo(x: T0_impl_Bar) where T0_impl_Bar: Bar + fn desugar_impl_trait_args(&mut self, func: &mut NoirFunction, func_id: FuncId) { + let mut impl_trait_generics = HashSet::default(); + let mut counter: usize = 0; + for parameter in func.def.parameters.iter_mut() { + if let UnresolvedTypeData::TraitAsType(path, args) = ¶meter.typ.typ { + let mut new_generic_ident: Ident = + format!("T{}_impl_{}", func_id, path.as_string()).into(); + let mut new_generic_path = Path::from_ident(new_generic_ident.clone()); + while impl_trait_generics.contains(&new_generic_ident) + || self.lookup_generic_or_global_type(&new_generic_path).is_some() + { + new_generic_ident = + format!("T{}_impl_{}_{}", func_id, path.as_string(), counter).into(); + new_generic_path = Path::from_ident(new_generic_ident.clone()); + counter += 1; + } + impl_trait_generics.insert(new_generic_ident.clone()); + + let is_synthesized = true; + let new_generic_type_data = + UnresolvedTypeData::Named(new_generic_path, vec![], is_synthesized); + let new_generic_type = + UnresolvedType { typ: new_generic_type_data.clone(), span: None }; + let new_trait_bound = TraitBound { + trait_path: path.clone(), + trait_id: None, + trait_generics: args.to_vec(), + }; + let new_trait_constraint = UnresolvedTraitConstraint { + typ: new_generic_type, + trait_bound: new_trait_bound, + }; + + parameter.typ.typ = new_generic_type_data; + func.def.generics.push(new_generic_ident); + func.def.where_clause.push(new_trait_constraint); + } + } + self.add_generics(&impl_trait_generics.into_iter().collect()); + } + + /// Add the given generics to scope. + /// Each generic will have a fresh Shared associated with it. + pub fn add_generics(&mut self, generics: &UnresolvedGenerics) -> Generics { + vecmap(generics, |generic| { + // Map the generic to a fresh type variable + let id = self.interner.next_type_variable_id(); + let typevar = TypeVariable::unbound(id); + let span = generic.0.span(); + + // Check for name collisions of this generic + let name = Rc::new(generic.0.contents.clone()); + + if let Some((_, _, first_span)) = self.find_generic(&name) { + self.push_err(ResolverError::DuplicateDefinition { + name: generic.0.contents.clone(), + first_span: *first_span, + second_span: span, + }); + } else { + self.generics.push((name, typevar.clone(), span)); + } + + typevar + }) + } + + fn push_err(&mut self, error: impl Into) { + self.errors.push((error.into(), self.file)); + } + + fn resolve_where_clause(&mut self, clause: &mut [UnresolvedTraitConstraint]) { + for bound in clause { + if let Some(trait_id) = self.resolve_trait_by_path(bound.trait_bound.trait_path.clone()) + { + bound.trait_bound.trait_id = Some(trait_id); + } + } + } + + fn resolve_trait_by_path(&mut self, path: Path) -> Option { + let path_resolver = StandardPathResolver::new(self.module_id()); + + let error = match path_resolver.resolve(self.def_maps, path.clone()) { + Ok(PathResolution { module_def_id: ModuleDefId::TraitId(trait_id), error }) => { + if let Some(error) = error { + self.push_err(error); + } + return Some(trait_id); + } + Ok(_) => DefCollectorErrorKind::NotATrait { not_a_trait_name: path }, + Err(_) => DefCollectorErrorKind::TraitNotFound { trait_path: path }, + }; + self.push_err(error); + None + } + + fn resolve_local_globals(&mut self) { + let globals = vecmap(self.interner.get_all_globals(), |global| { + (global.id, global.local_id, global.ident.clone()) + }); + for (id, local_module_id, name) in globals { + if local_module_id == self.local_module { + let definition = DefinitionKind::Global(id); + self.add_global_variable_decl(name, definition); + } + } + } + + /// TODO: This is currently only respected for generic free functions + /// there's a bunch of other places where trait constraints can pop up + fn resolve_trait_constraints( + &mut self, + where_clause: &[UnresolvedTraitConstraint], + ) -> Vec { + where_clause + .iter() + .cloned() + .filter_map(|constraint| self.resolve_trait_constraint(constraint)) + .collect() + } + + pub fn resolve_trait_constraint( + &mut self, + constraint: UnresolvedTraitConstraint, + ) -> Option { + let typ = self.resolve_type(constraint.typ); + let trait_generics = + vecmap(constraint.trait_bound.trait_generics, |typ| self.resolve_type(typ)); + + let span = constraint.trait_bound.trait_path.span(); + let the_trait = self.lookup_trait_or_error(constraint.trait_bound.trait_path)?; + let trait_id = the_trait.id; + + let expected_generics = the_trait.generics.len(); + let actual_generics = trait_generics.len(); + + if actual_generics != expected_generics { + let item_name = the_trait.name.to_string(); + self.push_err(ResolverError::IncorrectGenericCount { + span, + item_name, + actual: actual_generics, + expected: expected_generics, + }); + } + + Some(TraitConstraint { typ, trait_id, trait_generics }) + } + + /// Extract metadata from a NoirFunction + /// to be used in analysis and intern the function parameters + /// Prerequisite: self.add_generics() has already been called with the given + /// function's generics, including any generics from the impl, if any. + fn extract_meta(&mut self, func: &NoirFunction, func_id: FuncId) -> FuncMeta { + let location = Location::new(func.name_ident().span(), self.file); + let id = self.interner.function_definition_id(func_id); + let name_ident = HirIdent::non_trait_method(id, location); + + let attributes = func.attributes().clone(); + let has_no_predicates_attribute = attributes.is_no_predicates(); + let should_fold = attributes.is_foldable(); + if !self.inline_attribute_allowed(func) { + if has_no_predicates_attribute { + self.push_err(ResolverError::NoPredicatesAttributeOnUnconstrained { + ident: func.name_ident().clone(), + }); + } else if should_fold { + self.push_err(ResolverError::FoldAttributeOnUnconstrained { + ident: func.name_ident().clone(), + }); + } + } + // Both the #[fold] and #[no_predicates] alter a function's inline type and code generation in similar ways. + // In certain cases such as type checking (for which the following flag will be used) both attributes + // indicate we should code generate in the same way. Thus, we unify the attributes into one flag here. + let has_inline_attribute = has_no_predicates_attribute || should_fold; + let is_entry_point = self.is_entry_point_function(func); + + let mut generics = vecmap(&self.generics, |(_, typevar, _)| typevar.clone()); + let mut parameters = vec![]; + let mut parameter_types = vec![]; + + for Param { visibility, pattern, typ, span: _ } in func.parameters().iter().cloned() { + if visibility == Visibility::Public && !self.pub_allowed(func) { + self.push_err(ResolverError::UnnecessaryPub { + ident: func.name_ident().clone(), + position: PubPosition::Parameter, + }); + } + + let type_span = typ.span.unwrap_or_else(|| pattern.span()); + let typ = self.resolve_type_inner(typ, &mut generics); + self.check_if_type_is_valid_for_program_input( + &typ, + is_entry_point, + has_inline_attribute, + type_span, + ); + let pattern = self.elaborate_pattern(pattern, typ.clone(), DefinitionKind::Local(None)); + + parameters.push((pattern, typ.clone(), visibility)); + parameter_types.push(typ); + } + + let return_type = Box::new(self.resolve_type(func.return_type())); + + self.declare_numeric_generics(¶meter_types, &return_type); + + if !self.pub_allowed(func) && func.def.return_visibility == Visibility::Public { + self.push_err(ResolverError::UnnecessaryPub { + ident: func.name_ident().clone(), + position: PubPosition::ReturnType, + }); + } + + let is_low_level_function = + attributes.function.as_ref().map_or(false, |func| func.is_low_level()); + + if !self.crate_id.is_stdlib() && is_low_level_function { + let error = + ResolverError::LowLevelFunctionOutsideOfStdlib { ident: func.name_ident().clone() }; + self.push_err(error); + } + + // 'pub' is required on return types for entry point functions + if is_entry_point + && return_type.as_ref() != &Type::Unit + && func.def.return_visibility == Visibility::Private + { + self.push_err(ResolverError::NecessaryPub { ident: func.name_ident().clone() }); + } + // '#[recursive]' attribute is only allowed for entry point functions + if !is_entry_point && func.kind == FunctionKind::Recursive { + self.push_err(ResolverError::MisplacedRecursiveAttribute { + ident: func.name_ident().clone(), + }); + } + + if matches!(attributes.function, Some(FunctionAttribute::Test { .. })) + && !parameters.is_empty() + { + self.push_err(ResolverError::TestFunctionHasParameters { + span: func.name_ident().span(), + }); + } + + let mut typ = Type::Function(parameter_types, return_type, Box::new(Type::Unit)); + + if !generics.is_empty() { + typ = Type::Forall(generics, Box::new(typ)); + } + + self.interner.push_definition_type(name_ident.id, typ.clone()); + + let direct_generics = func.def.generics.iter(); + let direct_generics = direct_generics + .filter_map(|generic| self.find_generic(&generic.0.contents)) + .map(|(name, typevar, _span)| (name.clone(), typevar.clone())) + .collect(); + + FuncMeta { + name: name_ident, + kind: func.kind, + location, + typ, + direct_generics, + trait_impl: self.current_trait_impl, + parameters: parameters.into(), + return_type: func.def.return_type.clone(), + return_visibility: func.def.return_visibility, + has_body: !func.def.body.is_empty(), + trait_constraints: self.resolve_trait_constraints(&func.def.where_clause), + is_entry_point, + has_inline_attribute, + } + } + + /// Only sized types are valid to be used as main's parameters or the parameters to a contract + /// function. If the given type is not sized (e.g. contains a slice or NamedGeneric type), an + /// error is issued. + fn check_if_type_is_valid_for_program_input( + &mut self, + typ: &Type, + is_entry_point: bool, + has_inline_attribute: bool, + span: Span, + ) { + if (is_entry_point && !typ.is_valid_for_program_input()) + || (has_inline_attribute && !typ.is_valid_non_inlined_function_input()) + { + self.push_err(TypeCheckError::InvalidTypeForEntryPoint { span }); + } + } + + fn inline_attribute_allowed(&self, func: &NoirFunction) -> bool { + // Inline attributes are only relevant for constrained functions + // as all unconstrained functions are not inlined + !func.def.is_unconstrained + } + + /// True if the 'pub' keyword is allowed on parameters in this function + /// 'pub' on function parameters is only allowed for entry point functions + fn pub_allowed(&self, func: &NoirFunction) -> bool { + self.is_entry_point_function(func) || func.attributes().is_foldable() + } + + fn is_entry_point_function(&self, func: &NoirFunction) -> bool { + if self.in_contract { + func.attributes().is_contract_entry_point() + } else { + func.name() == MAIN_FUNCTION + } + } + + fn declare_numeric_generics(&mut self, params: &[Type], return_type: &Type) { + if self.generics.is_empty() { + return; + } + + for (name_to_find, type_variable) in Self::find_numeric_generics(params, return_type) { + // Declare any generics to let users use numeric generics in scope. + // Don't issue a warning if these are unused + // + // We can fail to find the generic in self.generics if it is an implicit one created + // by the compiler. This can happen when, e.g. eliding array lengths using the slice + // syntax [T]. + if let Some((name, _, span)) = + self.generics.iter().find(|(name, _, _)| name.as_ref() == &name_to_find) + { + let ident = Ident::new(name.to_string(), *span); + let definition = DefinitionKind::GenericType(type_variable); + self.add_variable_decl_inner(ident, false, false, false, definition); + } + } + } + + fn find_numeric_generics( + parameters: &[Type], + return_type: &Type, + ) -> Vec<(String, TypeVariable)> { + let mut found = BTreeMap::new(); + for parameter in parameters { + Self::find_numeric_generics_in_type(parameter, &mut found); + } + Self::find_numeric_generics_in_type(return_type, &mut found); + found.into_iter().collect() + } + + fn find_numeric_generics_in_type(typ: &Type, found: &mut BTreeMap) { + match typ { + Type::FieldElement + | Type::Integer(_, _) + | Type::Bool + | Type::Unit + | Type::Error + | Type::TypeVariable(_, _) + | Type::Constant(_) + | Type::NamedGeneric(_, _) + | Type::Code + | Type::Forall(_, _) => (), + + Type::TraitAsType(_, _, args) => { + for arg in args { + Self::find_numeric_generics_in_type(arg, found); + } + } + + Type::Array(length, element_type) => { + if let Type::NamedGeneric(type_variable, name) = length.as_ref() { + found.insert(name.to_string(), type_variable.clone()); + } + Self::find_numeric_generics_in_type(element_type, found); + } + + Type::Slice(element_type) => { + Self::find_numeric_generics_in_type(element_type, found); + } + + Type::Tuple(fields) => { + for field in fields { + Self::find_numeric_generics_in_type(field, found); + } + } + + Type::Function(parameters, return_type, _env) => { + for parameter in parameters { + Self::find_numeric_generics_in_type(parameter, found); + } + Self::find_numeric_generics_in_type(return_type, found); + } + + Type::Struct(struct_type, generics) => { + for (i, generic) in generics.iter().enumerate() { + if let Type::NamedGeneric(type_variable, name) = generic { + if struct_type.borrow().generic_is_numeric(i) { + found.insert(name.to_string(), type_variable.clone()); + } + } else { + Self::find_numeric_generics_in_type(generic, found); + } + } + } + Type::Alias(alias, generics) => { + for (i, generic) in generics.iter().enumerate() { + if let Type::NamedGeneric(type_variable, name) = generic { + if alias.borrow().generic_is_numeric(i) { + found.insert(name.to_string(), type_variable.clone()); + } + } else { + Self::find_numeric_generics_in_type(generic, found); + } + } + } + Type::MutableReference(element) => Self::find_numeric_generics_in_type(element, found), + Type::String(length) => { + if let Type::NamedGeneric(type_variable, name) = length.as_ref() { + found.insert(name.to_string(), type_variable.clone()); + } + } + Type::FmtString(length, fields) => { + if let Type::NamedGeneric(type_variable, name) = length.as_ref() { + found.insert(name.to_string(), type_variable.clone()); + } + Self::find_numeric_generics_in_type(fields, found); + } + } + } + + fn add_trait_constraints_to_scope(&mut self, func_meta: &FuncMeta) { + for constraint in &func_meta.trait_constraints { + let object = constraint.typ.clone(); + let trait_id = constraint.trait_id; + let generics = constraint.trait_generics.clone(); + + if !self.interner.add_assumed_trait_implementation(object, trait_id, generics) { + if let Some(the_trait) = self.interner.try_get_trait(trait_id) { + let trait_name = the_trait.name.to_string(); + let typ = constraint.typ.clone(); + let span = func_meta.location.span; + self.push_err(TypeCheckError::UnneededTraitConstraint { + trait_name, + typ, + span, + }); + } + } + } + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/patterns.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/patterns.rs new file mode 100644 index 00000000000..195d37878f1 --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/patterns.rs @@ -0,0 +1,465 @@ +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; +use rustc_hash::FxHashSet as HashSet; + +use crate::{ + ast::ERROR_IDENT, + hir::{ + resolution::errors::ResolverError, + type_check::{Source, TypeCheckError}, + }, + hir_def::{ + expr::{HirIdent, ImplKind}, + stmt::HirPattern, + }, + macros_api::{HirExpression, Ident, Path, Pattern}, + node_interner::{DefinitionId, DefinitionKind, ExprId, TraitImplKind}, + Shared, StructType, Type, TypeBindings, +}; + +use super::{Elaborator, ResolverMeta}; + +impl<'context> Elaborator<'context> { + pub(super) fn elaborate_pattern( + &mut self, + pattern: Pattern, + expected_type: Type, + definition_kind: DefinitionKind, + ) -> HirPattern { + self.elaborate_pattern_mut(pattern, expected_type, definition_kind, None) + } + + fn elaborate_pattern_mut( + &mut self, + pattern: Pattern, + expected_type: Type, + definition: DefinitionKind, + mutable: Option, + ) -> HirPattern { + match pattern { + Pattern::Identifier(name) => { + // If this definition is mutable, do not store the rhs because it will + // not always refer to the correct value of the variable + let definition = match (mutable, definition) { + (Some(_), DefinitionKind::Local(_)) => DefinitionKind::Local(None), + (_, other) => other, + }; + let ident = self.add_variable_decl(name, mutable.is_some(), true, definition); + self.interner.push_definition_type(ident.id, expected_type); + HirPattern::Identifier(ident) + } + Pattern::Mutable(pattern, span, _) => { + if let Some(first_mut) = mutable { + self.push_err(ResolverError::UnnecessaryMut { first_mut, second_mut: span }); + } + + let pattern = + self.elaborate_pattern_mut(*pattern, expected_type, definition, Some(span)); + let location = Location::new(span, self.file); + HirPattern::Mutable(Box::new(pattern), location) + } + Pattern::Tuple(fields, span) => { + let field_types = match expected_type { + Type::Tuple(fields) => fields, + Type::Error => Vec::new(), + expected_type => { + let tuple = + Type::Tuple(vecmap(&fields, |_| self.interner.next_type_variable())); + + self.push_err(TypeCheckError::TypeMismatchWithSource { + expected: expected_type, + actual: tuple, + span, + source: Source::Assignment, + }); + Vec::new() + } + }; + + let fields = vecmap(fields.into_iter().enumerate(), |(i, field)| { + let field_type = field_types.get(i).cloned().unwrap_or(Type::Error); + self.elaborate_pattern_mut(field, field_type, definition.clone(), mutable) + }); + let location = Location::new(span, self.file); + HirPattern::Tuple(fields, location) + } + Pattern::Struct(name, fields, span) => self.elaborate_struct_pattern( + name, + fields, + span, + expected_type, + definition, + mutable, + ), + } + } + + fn elaborate_struct_pattern( + &mut self, + name: Path, + fields: Vec<(Ident, Pattern)>, + span: Span, + expected_type: Type, + definition: DefinitionKind, + mutable: Option, + ) -> HirPattern { + let error_identifier = |this: &mut Self| { + // Must create a name here to return a HirPattern::Identifier. Allowing + // shadowing here lets us avoid further errors if we define ERROR_IDENT + // multiple times. + let name = ERROR_IDENT.into(); + let identifier = this.add_variable_decl(name, false, true, definition.clone()); + HirPattern::Identifier(identifier) + }; + + let (struct_type, generics) = match self.lookup_type_or_error(name) { + Some(Type::Struct(struct_type, generics)) => (struct_type, generics), + None => return error_identifier(self), + Some(typ) => { + self.push_err(ResolverError::NonStructUsedInConstructor { typ, span }); + return error_identifier(self); + } + }; + + let actual_type = Type::Struct(struct_type.clone(), generics); + let location = Location::new(span, self.file); + + self.unify(&actual_type, &expected_type, || TypeCheckError::TypeMismatchWithSource { + expected: expected_type.clone(), + actual: actual_type.clone(), + span: location.span, + source: Source::Assignment, + }); + + let typ = struct_type.clone(); + let fields = self.resolve_constructor_pattern_fields( + typ, + fields, + span, + expected_type.clone(), + definition, + mutable, + ); + + HirPattern::Struct(expected_type, fields, location) + } + + /// Resolve all the fields of a struct constructor expression. + /// Ensures all fields are present, none are repeated, and all + /// are part of the struct. + fn resolve_constructor_pattern_fields( + &mut self, + struct_type: Shared, + fields: Vec<(Ident, Pattern)>, + span: Span, + expected_type: Type, + definition: DefinitionKind, + mutable: Option, + ) -> Vec<(Ident, HirPattern)> { + let mut ret = Vec::with_capacity(fields.len()); + let mut seen_fields = HashSet::default(); + let mut unseen_fields = struct_type.borrow().field_names(); + + for (field, pattern) in fields { + let field_type = expected_type.get_field_type(&field.0.contents).unwrap_or(Type::Error); + let resolved = + self.elaborate_pattern_mut(pattern, field_type, definition.clone(), mutable); + + if unseen_fields.contains(&field) { + unseen_fields.remove(&field); + seen_fields.insert(field.clone()); + } else if seen_fields.contains(&field) { + // duplicate field + self.push_err(ResolverError::DuplicateField { field: field.clone() }); + } else { + // field not required by struct + self.push_err(ResolverError::NoSuchField { + field: field.clone(), + struct_definition: struct_type.borrow().name.clone(), + }); + } + + ret.push((field, resolved)); + } + + if !unseen_fields.is_empty() { + self.push_err(ResolverError::MissingFields { + span, + missing_fields: unseen_fields.into_iter().map(|field| field.to_string()).collect(), + struct_definition: struct_type.borrow().name.clone(), + }); + } + + ret + } + + pub(super) fn add_variable_decl( + &mut self, + name: Ident, + mutable: bool, + allow_shadowing: bool, + definition: DefinitionKind, + ) -> HirIdent { + self.add_variable_decl_inner(name, mutable, allow_shadowing, true, definition) + } + + pub fn add_variable_decl_inner( + &mut self, + name: Ident, + mutable: bool, + allow_shadowing: bool, + warn_if_unused: bool, + definition: DefinitionKind, + ) -> HirIdent { + if definition.is_global() { + return self.add_global_variable_decl(name, definition); + } + + let location = Location::new(name.span(), self.file); + let id = + self.interner.push_definition(name.0.contents.clone(), mutable, definition, location); + let ident = HirIdent::non_trait_method(id, location); + let resolver_meta = + ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused }; + + let scope = self.scopes.get_mut_scope(); + let old_value = scope.add_key_value(name.0.contents.clone(), resolver_meta); + + if !allow_shadowing { + if let Some(old_value) = old_value { + self.push_err(ResolverError::DuplicateDefinition { + name: name.0.contents, + first_span: old_value.ident.location.span, + second_span: location.span, + }); + } + } + + ident + } + + pub fn add_global_variable_decl( + &mut self, + name: Ident, + definition: DefinitionKind, + ) -> HirIdent { + let scope = self.scopes.get_mut_scope(); + + // This check is necessary to maintain the same definition ids in the interner. Currently, each function uses a new resolver that has its own ScopeForest and thus global scope. + // We must first check whether an existing definition ID has been inserted as otherwise there will be multiple definitions for the same global statement. + // This leads to an error in evaluation where the wrong definition ID is selected when evaluating a statement using the global. The check below prevents this error. + let mut global_id = None; + let global = self.interner.get_all_globals(); + for global_info in global { + if global_info.ident == name && global_info.local_id == self.local_module { + global_id = Some(global_info.id); + } + } + + let (ident, resolver_meta) = if let Some(id) = global_id { + let global = self.interner.get_global(id); + let hir_ident = HirIdent::non_trait_method(global.definition_id, global.location); + let ident = hir_ident.clone(); + let resolver_meta = ResolverMeta { num_times_used: 0, ident, warn_if_unused: true }; + (hir_ident, resolver_meta) + } else { + let location = Location::new(name.span(), self.file); + let id = + self.interner.push_definition(name.0.contents.clone(), false, definition, location); + let ident = HirIdent::non_trait_method(id, location); + let resolver_meta = + ResolverMeta { num_times_used: 0, ident: ident.clone(), warn_if_unused: true }; + (ident, resolver_meta) + }; + + let old_global_value = scope.add_key_value(name.0.contents.clone(), resolver_meta); + if let Some(old_global_value) = old_global_value { + self.push_err(ResolverError::DuplicateDefinition { + name: name.0.contents.clone(), + first_span: old_global_value.ident.location.span, + second_span: name.span(), + }); + } + ident + } + + // Checks for a variable having been declared before. + // (Variable declaration and definition cannot be separate in Noir.) + // Once the variable has been found, intern and link `name` to this definition, + // returning (the ident, the IdentId of `name`) + // + // If a variable is not found, then an error is logged and a dummy id + // is returned, for better error reporting UX + pub(super) fn find_variable_or_default(&mut self, name: &Ident) -> (HirIdent, usize) { + self.use_variable(name).unwrap_or_else(|error| { + self.push_err(error); + let id = DefinitionId::dummy_id(); + let location = Location::new(name.span(), self.file); + (HirIdent::non_trait_method(id, location), 0) + }) + } + + /// Lookup and use the specified variable. + /// This will increment its use counter by one and return the variable if found. + /// If the variable is not found, an error is returned. + pub(super) fn use_variable( + &mut self, + name: &Ident, + ) -> Result<(HirIdent, usize), ResolverError> { + // Find the definition for this Ident + let scope_tree = self.scopes.current_scope_tree(); + let variable = scope_tree.find(&name.0.contents); + + let location = Location::new(name.span(), self.file); + if let Some((variable_found, scope)) = variable { + variable_found.num_times_used += 1; + let id = variable_found.ident.id; + Ok((HirIdent::non_trait_method(id, location), scope)) + } else { + Err(ResolverError::VariableNotDeclared { + name: name.0.contents.clone(), + span: name.0.span(), + }) + } + } + + pub(super) fn elaborate_variable(&mut self, variable: Path) -> (ExprId, Type) { + let span = variable.span; + let expr = self.resolve_variable(variable); + let id = self.interner.push_expr(HirExpression::Ident(expr.clone())); + self.interner.push_expr_location(id, span, self.file); + let typ = self.type_check_variable(expr, id); + self.interner.push_expr_type(id, typ.clone()); + (id, typ) + } + + fn resolve_variable(&mut self, path: Path) -> HirIdent { + if let Some((method, constraint, assumed)) = self.resolve_trait_generic_path(&path) { + HirIdent { + location: Location::new(path.span, self.file), + id: self.interner.trait_method_id(method), + impl_kind: ImplKind::TraitMethod(method, constraint, assumed), + } + } else { + // If the Path is being used as an Expression, then it is referring to a global from a separate module + // Otherwise, then it is referring to an Identifier + // This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10; + // If the expression is a singular indent, we search the resolver's current scope as normal. + let (hir_ident, var_scope_index) = self.get_ident_from_path(path); + + if hir_ident.id != DefinitionId::dummy_id() { + match self.interner.definition(hir_ident.id).kind { + DefinitionKind::Function(id) => { + if let Some(current_item) = self.current_item { + self.interner.add_function_dependency(current_item, id); + } + } + DefinitionKind::Global(global_id) => { + if let Some(current_item) = self.current_item { + self.interner.add_global_dependency(current_item, global_id); + } + } + DefinitionKind::GenericType(_) => { + // Initialize numeric generics to a polymorphic integer type in case + // they're used in expressions. We must do this here since type_check_variable + // does not check definition kinds and otherwise expects parameters to + // already be typed. + if self.interner.definition_type(hir_ident.id) == Type::Error { + let typ = Type::polymorphic_integer_or_field(self.interner); + self.interner.push_definition_type(hir_ident.id, typ); + } + } + DefinitionKind::Local(_) => { + // only local variables can be captured by closures. + self.resolve_local_variable(hir_ident.clone(), var_scope_index); + } + } + } + + hir_ident + } + } + + pub(super) fn type_check_variable(&mut self, ident: HirIdent, expr_id: ExprId) -> Type { + let mut bindings = TypeBindings::new(); + + // Add type bindings from any constraints that were used. + // We need to do this first since otherwise instantiating the type below + // will replace each trait generic with a fresh type variable, rather than + // the type used in the trait constraint (if it exists). See #4088. + if let ImplKind::TraitMethod(_, constraint, _) = &ident.impl_kind { + let the_trait = self.interner.get_trait(constraint.trait_id); + assert_eq!(the_trait.generics.len(), constraint.trait_generics.len()); + + for (param, arg) in the_trait.generics.iter().zip(&constraint.trait_generics) { + // Avoid binding t = t + if !arg.occurs(param.id()) { + bindings.insert(param.id(), (param.clone(), arg.clone())); + } + } + } + + // An identifiers type may be forall-quantified in the case of generic functions. + // E.g. `fn foo(t: T, field: Field) -> T` has type `forall T. fn(T, Field) -> T`. + // We must instantiate identifiers at every call site to replace this T with a new type + // variable to handle generic functions. + let t = self.interner.id_type_substitute_trait_as_type(ident.id); + + // This instantiates a trait's generics as well which need to be set + // when the constraint below is later solved for when the function is + // finished. How to link the two? + let (typ, bindings) = t.instantiate_with_bindings(bindings, self.interner); + + // Push any trait constraints required by this definition to the context + // to be checked later when the type of this variable is further constrained. + if let Some(definition) = self.interner.try_definition(ident.id) { + if let DefinitionKind::Function(function) = definition.kind { + let function = self.interner.function_meta(&function); + + for mut constraint in function.trait_constraints.clone() { + constraint.apply_bindings(&bindings); + self.trait_constraints.push((constraint, expr_id)); + } + } + } + + if let ImplKind::TraitMethod(_, mut constraint, assumed) = ident.impl_kind { + constraint.apply_bindings(&bindings); + if assumed { + let trait_impl = TraitImplKind::Assumed { + object_type: constraint.typ, + trait_generics: constraint.trait_generics, + }; + self.interner.select_impl_for_expression(expr_id, trait_impl); + } else { + // Currently only one impl can be selected per expr_id, so this + // constraint needs to be pushed after any other constraints so + // that monomorphization can resolve this trait method to the correct impl. + self.trait_constraints.push((constraint, expr_id)); + } + } + + self.interner.store_instantiation_bindings(expr_id, bindings); + typ + } + + fn get_ident_from_path(&mut self, path: Path) -> (HirIdent, usize) { + let location = Location::new(path.span(), self.file); + + let error = match path.as_ident().map(|ident| self.use_variable(ident)) { + Some(Ok(found)) => return found, + // Try to look it up as a global, but still issue the first error if we fail + Some(Err(error)) => match self.lookup_global(path) { + Ok(id) => return (HirIdent::non_trait_method(id, location), 0), + Err(_) => error, + }, + None => match self.lookup_global(path) { + Ok(id) => return (HirIdent::non_trait_method(id, location), 0), + Err(error) => error, + }, + }; + self.push_err(error); + let id = DefinitionId::dummy_id(); + (HirIdent::non_trait_method(id, location), 0) + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/scope.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/scope.rs new file mode 100644 index 00000000000..cf10dbbc2b2 --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/scope.rs @@ -0,0 +1,200 @@ +use noirc_errors::Spanned; +use rustc_hash::FxHashMap as HashMap; + +use crate::ast::ERROR_IDENT; +use crate::hir::comptime::Value; +use crate::hir::def_map::{LocalModuleId, ModuleId}; +use crate::hir::resolution::path_resolver::{PathResolver, StandardPathResolver}; +use crate::hir::resolution::resolver::SELF_TYPE_NAME; +use crate::hir::scope::{Scope as GenericScope, ScopeTree as GenericScopeTree}; +use crate::macros_api::Ident; +use crate::{ + hir::{ + def_map::{ModuleDefId, TryFromModuleDefId}, + resolution::errors::ResolverError, + }, + hir_def::{ + expr::{HirCapturedVar, HirIdent}, + traits::Trait, + }, + macros_api::{Path, StructId}, + node_interner::{DefinitionId, TraitId, TypeAliasId}, + Shared, StructType, +}; +use crate::{Type, TypeAlias}; + +use super::{Elaborator, ResolverMeta}; + +type Scope = GenericScope; +type ScopeTree = GenericScopeTree; + +impl<'context> Elaborator<'context> { + pub(super) fn lookup(&mut self, path: Path) -> Result { + let span = path.span(); + let id = self.resolve_path(path)?; + T::try_from(id).ok_or_else(|| ResolverError::Expected { + expected: T::description(), + got: id.as_str().to_owned(), + span, + }) + } + + pub(super) fn module_id(&self) -> ModuleId { + assert_ne!(self.local_module, LocalModuleId::dummy_id(), "local_module is unset"); + ModuleId { krate: self.crate_id, local_id: self.local_module } + } + + pub(super) fn resolve_path(&mut self, path: Path) -> Result { + let resolver = StandardPathResolver::new(self.module_id()); + let path_resolution = resolver.resolve(self.def_maps, path)?; + + if let Some(error) = path_resolution.error { + self.push_err(error); + } + + Ok(path_resolution.module_def_id) + } + + pub(super) fn get_struct(&self, type_id: StructId) -> Shared { + self.interner.get_struct(type_id) + } + + pub(super) fn get_trait_mut(&mut self, trait_id: TraitId) -> &mut Trait { + self.interner.get_trait_mut(trait_id) + } + + pub(super) fn resolve_local_variable(&mut self, hir_ident: HirIdent, var_scope_index: usize) { + let mut transitive_capture_index: Option = None; + + for lambda_index in 0..self.lambda_stack.len() { + if self.lambda_stack[lambda_index].scope_index > var_scope_index { + // Beware: the same variable may be captured multiple times, so we check + // for its presence before adding the capture below. + let position = self.lambda_stack[lambda_index] + .captures + .iter() + .position(|capture| capture.ident.id == hir_ident.id); + + if position.is_none() { + self.lambda_stack[lambda_index].captures.push(HirCapturedVar { + ident: hir_ident.clone(), + transitive_capture_index, + }); + } + + if lambda_index + 1 < self.lambda_stack.len() { + // There is more than one closure between the current scope and + // the scope of the variable, so this is a propagated capture. + // We need to track the transitive capture index as we go up in + // the closure stack. + transitive_capture_index = Some(position.unwrap_or( + // If this was a fresh capture, we added it to the end of + // the captures vector: + self.lambda_stack[lambda_index].captures.len() - 1, + )); + } + } + } + } + + pub(super) fn lookup_global(&mut self, path: Path) -> Result { + let span = path.span(); + let id = self.resolve_path(path)?; + + if let Some(function) = TryFromModuleDefId::try_from(id) { + return Ok(self.interner.function_definition_id(function)); + } + + if let Some(global) = TryFromModuleDefId::try_from(id) { + let global = self.interner.get_global(global); + return Ok(global.definition_id); + } + + let expected = "global variable".into(); + let got = "local variable".into(); + Err(ResolverError::Expected { span, expected, got }) + } + + pub fn push_scope(&mut self) { + self.scopes.start_scope(); + } + + pub fn pop_scope(&mut self) { + let scope = self.scopes.end_scope(); + self.check_for_unused_variables_in_scope_tree(scope.into()); + } + + pub fn check_for_unused_variables_in_scope_tree(&mut self, scope_decls: ScopeTree) { + let mut unused_vars = Vec::new(); + for scope in scope_decls.0.into_iter() { + Self::check_for_unused_variables_in_local_scope(scope, &mut unused_vars); + } + + for unused_var in unused_vars.iter() { + if let Some(definition_info) = self.interner.try_definition(unused_var.id) { + let name = &definition_info.name; + if name != ERROR_IDENT && !definition_info.is_global() { + let ident = Ident(Spanned::from(unused_var.location.span, name.to_owned())); + self.push_err(ResolverError::UnusedVariable { ident }); + } + } + } + } + + fn check_for_unused_variables_in_local_scope(decl_map: Scope, unused_vars: &mut Vec) { + let unused_variables = decl_map.filter(|(variable_name, metadata)| { + let has_underscore_prefix = variable_name.starts_with('_'); // XXX: This is used for development mode, and will be removed + metadata.warn_if_unused && metadata.num_times_used == 0 && !has_underscore_prefix + }); + unused_vars.extend(unused_variables.map(|(_, meta)| meta.ident.clone())); + } + + /// Lookup a given trait by name/path. + pub fn lookup_trait_or_error(&mut self, path: Path) -> Option<&mut Trait> { + match self.lookup(path) { + Ok(trait_id) => Some(self.get_trait_mut(trait_id)), + Err(error) => { + self.push_err(error); + None + } + } + } + + /// Lookup a given struct type by name. + pub fn lookup_struct_or_error(&mut self, path: Path) -> Option> { + match self.lookup(path) { + Ok(struct_id) => Some(self.get_struct(struct_id)), + Err(error) => { + self.push_err(error); + None + } + } + } + + /// Looks up a given type by name. + /// This will also instantiate any struct types found. + pub(super) fn lookup_type_or_error(&mut self, path: Path) -> Option { + let ident = path.as_ident(); + if ident.map_or(false, |i| i == SELF_TYPE_NAME) { + if let Some(typ) = &self.self_type { + return Some(typ.clone()); + } + } + + match self.lookup(path) { + Ok(struct_id) => { + let struct_type = self.get_struct(struct_id); + let generics = struct_type.borrow().instantiate(self.interner); + Some(Type::Struct(struct_type, generics)) + } + Err(error) => { + self.push_err(error); + None + } + } + } + + pub fn lookup_type_alias(&mut self, path: Path) -> Option> { + self.lookup(path).ok().map(|id| self.interner.get_type_alias(id)) + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs new file mode 100644 index 00000000000..a7a2df4041e --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/statements.rs @@ -0,0 +1,409 @@ +use noirc_errors::{Location, Span}; + +use crate::{ + ast::{AssignStatement, ConstrainStatement, LValue}, + hir::{ + resolution::errors::ResolverError, + type_check::{Source, TypeCheckError}, + }, + hir_def::{ + expr::HirIdent, + stmt::{ + HirAssignStatement, HirConstrainStatement, HirForStatement, HirLValue, HirLetStatement, + }, + }, + macros_api::{ + ForLoopStatement, ForRange, HirStatement, LetStatement, Statement, StatementKind, + }, + node_interner::{DefinitionId, DefinitionKind, StmtId}, + Type, +}; + +use super::Elaborator; + +impl<'context> Elaborator<'context> { + fn elaborate_statement_value(&mut self, statement: Statement) -> (HirStatement, Type) { + match statement.kind { + StatementKind::Let(let_stmt) => self.elaborate_let(let_stmt), + StatementKind::Constrain(constrain) => self.elaborate_constrain(constrain), + StatementKind::Assign(assign) => self.elaborate_assign(assign), + StatementKind::For(for_stmt) => self.elaborate_for(for_stmt), + StatementKind::Break => self.elaborate_jump(true, statement.span), + StatementKind::Continue => self.elaborate_jump(false, statement.span), + StatementKind::Comptime(statement) => self.elaborate_comptime(*statement), + StatementKind::Expression(expr) => { + let (expr, typ) = self.elaborate_expression(expr); + (HirStatement::Expression(expr), typ) + } + StatementKind::Semi(expr) => { + let (expr, _typ) = self.elaborate_expression(expr); + (HirStatement::Semi(expr), Type::Unit) + } + StatementKind::Error => (HirStatement::Error, Type::Error), + } + } + + pub(super) fn elaborate_statement(&mut self, statement: Statement) -> (StmtId, Type) { + let span = statement.span; + let (hir_statement, typ) = self.elaborate_statement_value(statement); + let id = self.interner.push_stmt(hir_statement); + self.interner.push_stmt_location(id, span, self.file); + (id, typ) + } + + pub(super) fn elaborate_let(&mut self, let_stmt: LetStatement) -> (HirStatement, Type) { + let expr_span = let_stmt.expression.span; + let (expression, expr_type) = self.elaborate_expression(let_stmt.expression); + let definition = DefinitionKind::Local(Some(expression)); + let annotated_type = self.resolve_type(let_stmt.r#type); + + // First check if the LHS is unspecified + // If so, then we give it the same type as the expression + let r#type = if annotated_type != Type::Error { + // Now check if LHS is the same type as the RHS + // Importantly, we do not coerce any types implicitly + self.unify_with_coercions(&expr_type, &annotated_type, expression, || { + TypeCheckError::TypeMismatch { + expected_typ: annotated_type.to_string(), + expr_typ: expr_type.to_string(), + expr_span, + } + }); + if annotated_type.is_unsigned() { + self.lint_overflowing_uint(&expression, &annotated_type); + } + annotated_type + } else { + expr_type + }; + + let let_ = HirLetStatement { + pattern: self.elaborate_pattern(let_stmt.pattern, r#type.clone(), definition), + r#type, + expression, + attributes: let_stmt.attributes, + comptime: let_stmt.comptime, + }; + (HirStatement::Let(let_), Type::Unit) + } + + pub(super) fn elaborate_constrain(&mut self, stmt: ConstrainStatement) -> (HirStatement, Type) { + let expr_span = stmt.0.span; + let (expr_id, expr_type) = self.elaborate_expression(stmt.0); + + // Must type check the assertion message expression so that we instantiate bindings + let msg = stmt.1.map(|assert_msg_expr| self.elaborate_expression(assert_msg_expr).0); + + self.unify(&expr_type, &Type::Bool, || TypeCheckError::TypeMismatch { + expr_typ: expr_type.to_string(), + expected_typ: Type::Bool.to_string(), + expr_span, + }); + + (HirStatement::Constrain(HirConstrainStatement(expr_id, self.file, msg)), Type::Unit) + } + + pub(super) fn elaborate_assign(&mut self, assign: AssignStatement) -> (HirStatement, Type) { + let span = assign.expression.span; + let (expression, expr_type) = self.elaborate_expression(assign.expression); + let (lvalue, lvalue_type, mutable) = self.elaborate_lvalue(assign.lvalue, span); + + if !mutable { + let (name, span) = self.get_lvalue_name_and_span(&lvalue); + self.push_err(TypeCheckError::VariableMustBeMutable { name, span }); + } + + self.unify_with_coercions(&expr_type, &lvalue_type, expression, || { + TypeCheckError::TypeMismatchWithSource { + actual: expr_type.clone(), + expected: lvalue_type.clone(), + span, + source: Source::Assignment, + } + }); + + let stmt = HirAssignStatement { lvalue, expression }; + (HirStatement::Assign(stmt), Type::Unit) + } + + pub(super) fn elaborate_for(&mut self, for_loop: ForLoopStatement) -> (HirStatement, Type) { + let (start, end) = match for_loop.range { + ForRange::Range(start, end) => (start, end), + ForRange::Array(_) => { + let for_stmt = + for_loop.range.into_for(for_loop.identifier, for_loop.block, for_loop.span); + + return self.elaborate_statement_value(for_stmt); + } + }; + + let start_span = start.span; + let end_span = end.span; + + let (start_range, start_range_type) = self.elaborate_expression(start); + let (end_range, end_range_type) = self.elaborate_expression(end); + let (identifier, block) = (for_loop.identifier, for_loop.block); + + self.nested_loops += 1; + self.push_scope(); + + // TODO: For loop variables are currently mutable by default since we haven't + // yet implemented syntax for them to be optionally mutable. + let kind = DefinitionKind::Local(None); + let identifier = self.add_variable_decl(identifier, false, true, kind); + + // Check that start range and end range have the same types + let range_span = start_span.merge(end_span); + self.unify(&start_range_type, &end_range_type, || TypeCheckError::TypeMismatch { + expected_typ: start_range_type.to_string(), + expr_typ: end_range_type.to_string(), + expr_span: range_span, + }); + + let expected_type = self.polymorphic_integer(); + + self.unify(&start_range_type, &expected_type, || TypeCheckError::TypeCannotBeUsed { + typ: start_range_type.clone(), + place: "for loop", + span: range_span, + }); + + self.interner.push_definition_type(identifier.id, start_range_type); + + let (block, _block_type) = self.elaborate_expression(block); + + self.pop_scope(); + self.nested_loops -= 1; + + let statement = + HirStatement::For(HirForStatement { start_range, end_range, block, identifier }); + + (statement, Type::Unit) + } + + fn elaborate_jump(&mut self, is_break: bool, span: noirc_errors::Span) -> (HirStatement, Type) { + if !self.in_unconstrained_fn { + self.push_err(ResolverError::JumpInConstrainedFn { is_break, span }); + } + if self.nested_loops == 0 { + self.push_err(ResolverError::JumpOutsideLoop { is_break, span }); + } + + let expr = if is_break { HirStatement::Break } else { HirStatement::Continue }; + (expr, self.interner.next_type_variable()) + } + + fn get_lvalue_name_and_span(&self, lvalue: &HirLValue) -> (String, Span) { + match lvalue { + HirLValue::Ident(name, _) => { + let span = name.location.span; + + if let Some(definition) = self.interner.try_definition(name.id) { + (definition.name.clone(), span) + } else { + ("(undeclared variable)".into(), span) + } + } + HirLValue::MemberAccess { object, .. } => self.get_lvalue_name_and_span(object), + HirLValue::Index { array, .. } => self.get_lvalue_name_and_span(array), + HirLValue::Dereference { lvalue, .. } => self.get_lvalue_name_and_span(lvalue), + } + } + + fn elaborate_lvalue(&mut self, lvalue: LValue, assign_span: Span) -> (HirLValue, Type, bool) { + match lvalue { + LValue::Ident(ident) => { + let mut mutable = true; + let (ident, scope_index) = self.find_variable_or_default(&ident); + self.resolve_local_variable(ident.clone(), scope_index); + + let typ = if ident.id == DefinitionId::dummy_id() { + Type::Error + } else { + if let Some(definition) = self.interner.try_definition(ident.id) { + mutable = definition.mutable; + } + + let typ = self.interner.definition_type(ident.id).instantiate(self.interner).0; + typ.follow_bindings() + }; + + (HirLValue::Ident(ident.clone(), typ.clone()), typ, mutable) + } + LValue::MemberAccess { object, field_name, span } => { + let (object, lhs_type, mut mutable) = self.elaborate_lvalue(*object, assign_span); + let mut object = Box::new(object); + let field_name = field_name.clone(); + + let object_ref = &mut object; + let mutable_ref = &mut mutable; + let location = Location::new(span, self.file); + + let dereference_lhs = move |_: &mut Self, _, element_type| { + // We must create a temporary value first to move out of object_ref before + // we eventually reassign to it. + let id = DefinitionId::dummy_id(); + let ident = HirIdent::non_trait_method(id, location); + let tmp_value = HirLValue::Ident(ident, Type::Error); + + let lvalue = std::mem::replace(object_ref, Box::new(tmp_value)); + *object_ref = + Box::new(HirLValue::Dereference { lvalue, element_type, location }); + *mutable_ref = true; + }; + + let name = &field_name.0.contents; + let (object_type, field_index) = self + .check_field_access(&lhs_type, name, field_name.span(), Some(dereference_lhs)) + .unwrap_or((Type::Error, 0)); + + let field_index = Some(field_index); + let typ = object_type.clone(); + let lvalue = + HirLValue::MemberAccess { object, field_name, field_index, typ, location }; + (lvalue, object_type, mutable) + } + LValue::Index { array, index, span } => { + let expr_span = index.span; + let (index, index_type) = self.elaborate_expression(index); + let location = Location::new(span, self.file); + + let expected = self.polymorphic_integer_or_field(); + self.unify(&index_type, &expected, || TypeCheckError::TypeMismatch { + expected_typ: "an integer".to_owned(), + expr_typ: index_type.to_string(), + expr_span, + }); + + let (mut lvalue, mut lvalue_type, mut mutable) = + self.elaborate_lvalue(*array, assign_span); + + // Before we check that the lvalue is an array, try to dereference it as many times + // as needed to unwrap any &mut wrappers. + while let Type::MutableReference(element) = lvalue_type.follow_bindings() { + let element_type = element.as_ref().clone(); + lvalue = + HirLValue::Dereference { lvalue: Box::new(lvalue), element_type, location }; + lvalue_type = *element; + // We know this value to be mutable now since we found an `&mut` + mutable = true; + } + + let typ = match lvalue_type.follow_bindings() { + Type::Array(_, elem_type) => *elem_type, + Type::Slice(elem_type) => *elem_type, + Type::Error => Type::Error, + Type::String(_) => { + let (_lvalue_name, lvalue_span) = self.get_lvalue_name_and_span(&lvalue); + self.push_err(TypeCheckError::StringIndexAssign { span: lvalue_span }); + Type::Error + } + other => { + // TODO: Need a better span here + self.push_err(TypeCheckError::TypeMismatch { + expected_typ: "array".to_string(), + expr_typ: other.to_string(), + expr_span: assign_span, + }); + Type::Error + } + }; + + let array = Box::new(lvalue); + let array_type = typ.clone(); + (HirLValue::Index { array, index, typ, location }, array_type, mutable) + } + LValue::Dereference(lvalue, span) => { + let (lvalue, reference_type, _) = self.elaborate_lvalue(*lvalue, assign_span); + let lvalue = Box::new(lvalue); + let location = Location::new(span, self.file); + + let element_type = Type::type_variable(self.interner.next_type_variable_id()); + let expected_type = Type::MutableReference(Box::new(element_type.clone())); + + self.unify(&reference_type, &expected_type, || TypeCheckError::TypeMismatch { + expected_typ: expected_type.to_string(), + expr_typ: reference_type.to_string(), + expr_span: assign_span, + }); + + // Dereferences are always mutable since we already type checked against a &mut T + let typ = element_type.clone(); + let lvalue = HirLValue::Dereference { lvalue, element_type, location }; + (lvalue, typ, true) + } + } + } + + /// Type checks a field access, adding dereference operators as necessary + pub(super) fn check_field_access( + &mut self, + lhs_type: &Type, + field_name: &str, + span: Span, + dereference_lhs: Option, + ) -> Option<(Type, usize)> { + let lhs_type = lhs_type.follow_bindings(); + + match &lhs_type { + Type::Struct(s, args) => { + let s = s.borrow(); + if let Some((field, index)) = s.get_field(field_name, args) { + return Some((field, index)); + } + } + Type::Tuple(elements) => { + if let Ok(index) = field_name.parse::() { + let length = elements.len(); + if index < length { + return Some((elements[index].clone(), index)); + } else { + self.push_err(TypeCheckError::TupleIndexOutOfBounds { + index, + lhs_type, + length, + span, + }); + return None; + } + } + } + // If the lhs is a mutable reference we automatically transform + // lhs.field into (*lhs).field + Type::MutableReference(element) => { + if let Some(mut dereference_lhs) = dereference_lhs { + dereference_lhs(self, lhs_type.clone(), element.as_ref().clone()); + return self.check_field_access( + element, + field_name, + span, + Some(dereference_lhs), + ); + } else { + let (element, index) = + self.check_field_access(element, field_name, span, dereference_lhs)?; + return Some((Type::MutableReference(Box::new(element)), index)); + } + } + _ => (), + } + + // If we get here the type has no field named 'access.rhs'. + // Now we specialize the error message based on whether we know the object type in question yet. + if let Type::TypeVariable(..) = &lhs_type { + self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + } else if lhs_type != Type::Error { + self.push_err(TypeCheckError::AccessUnknownMember { + lhs_type, + field_name: field_name.to_string(), + span, + }); + } + + None + } + + pub(super) fn elaborate_comptime(&self, _statement: Statement) -> (HirStatement, Type) { + todo!("Comptime scanning") + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs new file mode 100644 index 00000000000..4c8364b6dda --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/elaborator/types.rs @@ -0,0 +1,1438 @@ +use std::rc::Rc; + +use iter_extended::vecmap; +use noirc_errors::{Location, Span}; + +use crate::{ + ast::{BinaryOpKind, IntegerBitSize, UnresolvedTraitConstraint, UnresolvedTypeExpression}, + hir::{ + def_map::ModuleDefId, + resolution::{ + errors::ResolverError, + import::PathResolution, + resolver::{verify_mutable_reference, SELF_TYPE_NAME}, + }, + type_check::{Source, TypeCheckError}, + }, + hir_def::{ + expr::{ + HirBinaryOp, HirCallExpression, HirIdent, HirMemberAccess, HirMethodReference, + HirPrefixExpression, + }, + function::FuncMeta, + traits::{Trait, TraitConstraint}, + }, + macros_api::{ + HirExpression, HirLiteral, HirStatement, Path, PathKind, SecondaryAttribute, Signedness, + UnaryOp, UnresolvedType, UnresolvedTypeData, + }, + node_interner::{DefinitionKind, ExprId, GlobalId, TraitId, TraitImplKind, TraitMethodId}, + Generics, Shared, StructType, Type, TypeAlias, TypeBinding, TypeVariable, TypeVariableKind, +}; + +use super::Elaborator; + +impl<'context> Elaborator<'context> { + /// Translates an UnresolvedType to a Type + pub(super) fn resolve_type(&mut self, typ: UnresolvedType) -> Type { + let span = typ.span; + let resolved_type = self.resolve_type_inner(typ, &mut vec![]); + if resolved_type.is_nested_slice() { + self.push_err(ResolverError::NestedSlices { span: span.unwrap() }); + } + + resolved_type + } + + /// Translates an UnresolvedType into a Type and appends any + /// freshly created TypeVariables created to new_variables. + pub fn resolve_type_inner( + &mut self, + typ: UnresolvedType, + new_variables: &mut Generics, + ) -> Type { + use crate::ast::UnresolvedTypeData::*; + + let resolved_type = match typ.typ { + FieldElement => Type::FieldElement, + Array(size, elem) => { + let elem = Box::new(self.resolve_type_inner(*elem, new_variables)); + let size = self.resolve_array_size(Some(size), new_variables); + Type::Array(Box::new(size), elem) + } + Slice(elem) => { + let elem = Box::new(self.resolve_type_inner(*elem, new_variables)); + Type::Slice(elem) + } + Expression(expr) => self.convert_expression_type(expr), + Integer(sign, bits) => Type::Integer(sign, bits), + Bool => Type::Bool, + String(size) => { + let resolved_size = self.resolve_array_size(size, new_variables); + Type::String(Box::new(resolved_size)) + } + FormatString(size, fields) => { + let resolved_size = self.convert_expression_type(size); + let fields = self.resolve_type_inner(*fields, new_variables); + Type::FmtString(Box::new(resolved_size), Box::new(fields)) + } + Code => Type::Code, + Unit => Type::Unit, + Unspecified => Type::Error, + Error => Type::Error, + Named(path, args, _) => self.resolve_named_type(path, args, new_variables), + TraitAsType(path, args) => self.resolve_trait_as_type(path, args, new_variables), + + Tuple(fields) => { + Type::Tuple(vecmap(fields, |field| self.resolve_type_inner(field, new_variables))) + } + Function(args, ret, env) => { + let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); + let ret = Box::new(self.resolve_type_inner(*ret, new_variables)); + + // expect() here is valid, because the only places we don't have a span are omitted types + // e.g. a function without return type implicitly has a spanless UnresolvedType::Unit return type + // To get an invalid env type, the user must explicitly specify the type, which will have a span + let env_span = + env.span.expect("Unexpected missing span for closure environment type"); + + let env = Box::new(self.resolve_type_inner(*env, new_variables)); + + match *env { + Type::Unit | Type::Tuple(_) | Type::NamedGeneric(_, _) => { + Type::Function(args, ret, env) + } + _ => { + self.push_err(ResolverError::InvalidClosureEnvironment { + typ: *env, + span: env_span, + }); + Type::Error + } + } + } + MutableReference(element) => { + Type::MutableReference(Box::new(self.resolve_type_inner(*element, new_variables))) + } + Parenthesized(typ) => self.resolve_type_inner(*typ, new_variables), + }; + + if let Type::Struct(_, _) = resolved_type { + if let Some(unresolved_span) = typ.span { + // Record the location of the type reference + self.interner.push_type_ref_location( + resolved_type.clone(), + Location::new(unresolved_span, self.file), + ); + } + } + resolved_type + } + + pub fn find_generic(&self, target_name: &str) -> Option<&(Rc, TypeVariable, Span)> { + self.generics.iter().find(|(name, _, _)| name.as_ref() == target_name) + } + + fn resolve_named_type( + &mut self, + path: Path, + args: Vec, + new_variables: &mut Generics, + ) -> Type { + if args.is_empty() { + if let Some(typ) = self.lookup_generic_or_global_type(&path) { + return typ; + } + } + + // Check if the path is a type variable first. We currently disallow generics on type + // variables since we do not support higher-kinded types. + if path.segments.len() == 1 { + let name = &path.last_segment().0.contents; + + if name == SELF_TYPE_NAME { + if let Some(self_type) = self.self_type.clone() { + if !args.is_empty() { + self.push_err(ResolverError::GenericsOnSelfType { span: path.span() }); + } + return self_type; + } + } + } + + let span = path.span(); + let mut args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); + + if let Some(type_alias) = self.lookup_type_alias(path.clone()) { + let type_alias = type_alias.borrow(); + let expected_generic_count = type_alias.generics.len(); + let type_alias_string = type_alias.to_string(); + let id = type_alias.id; + + self.verify_generics_count(expected_generic_count, &mut args, span, || { + type_alias_string + }); + + if let Some(item) = self.current_item { + self.interner.add_type_alias_dependency(item, id); + } + + // Collecting Type Alias references [Location]s to be used by LSP in order + // to resolve the definition of the type alias + self.interner.add_type_alias_ref(id, Location::new(span, self.file)); + + // Because there is no ordering to when type aliases (and other globals) are resolved, + // it is possible for one to refer to an Error type and issue no error if it is set + // equal to another type alias. Fixing this fully requires an analysis to create a DFG + // of definition ordering, but for now we have an explicit check here so that we at + // least issue an error that the type was not found instead of silently passing. + let alias = self.interner.get_type_alias(id); + return Type::Alias(alias, args); + } + + match self.lookup_struct_or_error(path) { + Some(struct_type) => { + if self.resolving_ids.contains(&struct_type.borrow().id) { + self.push_err(ResolverError::SelfReferentialStruct { + span: struct_type.borrow().name.span(), + }); + + return Type::Error; + } + + let expected_generic_count = struct_type.borrow().generics.len(); + if !self.in_contract + && self + .interner + .struct_attributes(&struct_type.borrow().id) + .iter() + .any(|attr| matches!(attr, SecondaryAttribute::Abi(_))) + { + self.push_err(ResolverError::AbiAttributeOutsideContract { + span: struct_type.borrow().name.span(), + }); + } + self.verify_generics_count(expected_generic_count, &mut args, span, || { + struct_type.borrow().to_string() + }); + + if let Some(current_item) = self.current_item { + let dependency_id = struct_type.borrow().id; + self.interner.add_type_dependency(current_item, dependency_id); + } + + Type::Struct(struct_type, args) + } + None => Type::Error, + } + } + + fn resolve_trait_as_type( + &mut self, + path: Path, + args: Vec, + new_variables: &mut Generics, + ) -> Type { + let args = vecmap(args, |arg| self.resolve_type_inner(arg, new_variables)); + + if let Some(t) = self.lookup_trait_or_error(path) { + Type::TraitAsType(t.id, Rc::new(t.name.to_string()), args) + } else { + Type::Error + } + } + + fn verify_generics_count( + &mut self, + expected_count: usize, + args: &mut Vec, + span: Span, + type_name: impl FnOnce() -> String, + ) { + if args.len() != expected_count { + self.push_err(ResolverError::IncorrectGenericCount { + span, + item_name: type_name(), + actual: args.len(), + expected: expected_count, + }); + + // Fix the generic count so we can continue typechecking + args.resize_with(expected_count, || Type::Error); + } + } + + pub fn lookup_generic_or_global_type(&mut self, path: &Path) -> Option { + if path.segments.len() == 1 { + let name = &path.last_segment().0.contents; + if let Some((name, var, _)) = self.find_generic(name) { + return Some(Type::NamedGeneric(var.clone(), name.clone())); + } + } + + // If we cannot find a local generic of the same name, try to look up a global + match self.resolve_path(path.clone()) { + Ok(ModuleDefId::GlobalId(id)) => { + if let Some(current_item) = self.current_item { + self.interner.add_global_dependency(current_item, id); + } + + Some(Type::Constant(self.eval_global_as_array_length(id, path))) + } + _ => None, + } + } + + fn resolve_array_size( + &mut self, + length: Option, + new_variables: &mut Generics, + ) -> Type { + match length { + None => { + let id = self.interner.next_type_variable_id(); + let typevar = TypeVariable::unbound(id); + new_variables.push(typevar.clone()); + + // 'Named'Generic is a bit of a misnomer here, we want a type variable that + // wont be bound over but this one has no name since we do not currently + // require users to explicitly be generic over array lengths. + Type::NamedGeneric(typevar, Rc::new("".into())) + } + Some(length) => self.convert_expression_type(length), + } + } + + pub(super) fn convert_expression_type(&mut self, length: UnresolvedTypeExpression) -> Type { + match length { + UnresolvedTypeExpression::Variable(path) => { + self.lookup_generic_or_global_type(&path).unwrap_or_else(|| { + self.push_err(ResolverError::NoSuchNumericTypeVariable { path }); + Type::Constant(0) + }) + } + UnresolvedTypeExpression::Constant(int, _) => Type::Constant(int), + UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => { + let (lhs_span, rhs_span) = (lhs.span(), rhs.span()); + let lhs = self.convert_expression_type(*lhs); + let rhs = self.convert_expression_type(*rhs); + + match (lhs, rhs) { + (Type::Constant(lhs), Type::Constant(rhs)) => { + Type::Constant(op.function()(lhs, rhs)) + } + (lhs, _) => { + let span = + if !matches!(lhs, Type::Constant(_)) { lhs_span } else { rhs_span }; + self.push_err(ResolverError::InvalidArrayLengthExpr { span }); + Type::Constant(0) + } + } + } + } + } + + // this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type) + // + // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + fn resolve_trait_static_method_by_self( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { + let trait_id = self.trait_id?; + + if path.kind == PathKind::Plain && path.segments.len() == 2 { + let name = &path.segments[0].0.contents; + let method = &path.segments[1]; + + if name == SELF_TYPE_NAME { + let the_trait = self.interner.get_trait(trait_id); + let method = the_trait.find_method(method.0.contents.as_str())?; + + let constraint = TraitConstraint { + typ: self.self_type.clone()?, + trait_generics: Type::from_generics(&the_trait.generics), + trait_id, + }; + return Some((method, constraint, false)); + } + } + None + } + + // this resolves TraitName::some_static_method + // + // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + fn resolve_trait_static_method( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { + if path.kind == PathKind::Plain && path.segments.len() == 2 { + let method = &path.segments[1]; + + let mut trait_path = path.clone(); + trait_path.pop(); + let trait_id = self.lookup(trait_path).ok()?; + let the_trait = self.interner.get_trait(trait_id); + + let method = the_trait.find_method(method.0.contents.as_str())?; + let constraint = TraitConstraint { + typ: Type::TypeVariable( + the_trait.self_type_typevar.clone(), + TypeVariableKind::Normal, + ), + trait_generics: Type::from_generics(&the_trait.generics), + trait_id, + }; + return Some((method, constraint, false)); + } + None + } + + // This resolves a static trait method T::trait_method by iterating over the where clause + // + // Returns the trait method, trait constraint, and whether the impl is assumed from a where + // clause. This is always true since this helper searches where clauses for a generic constraint. + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + fn resolve_trait_method_by_named_generic( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { + if path.segments.len() != 2 { + return None; + } + + for UnresolvedTraitConstraint { typ, trait_bound } in self.trait_bounds.clone() { + if let UnresolvedTypeData::Named(constraint_path, _, _) = &typ.typ { + // if `path` is `T::method_name`, we're looking for constraint of the form `T: SomeTrait` + if constraint_path.segments.len() == 1 + && path.segments[0] != constraint_path.last_segment() + { + continue; + } + + if let Ok(ModuleDefId::TraitId(trait_id)) = + self.resolve_path(trait_bound.trait_path.clone()) + { + let the_trait = self.interner.get_trait(trait_id); + if let Some(method) = + the_trait.find_method(path.segments.last().unwrap().0.contents.as_str()) + { + let constraint = TraitConstraint { + trait_id, + typ: self.resolve_type(typ.clone()), + trait_generics: vecmap(trait_bound.trait_generics, |typ| { + self.resolve_type(typ) + }), + }; + return Some((method, constraint, true)); + } + } + } + } + None + } + + // Try to resolve the given trait method path. + // + // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not + // E.g. `t.method()` with `where T: Foo` in scope will return `(Foo::method, T, vec![Bar])` + pub(super) fn resolve_trait_generic_path( + &mut self, + path: &Path, + ) -> Option<(TraitMethodId, TraitConstraint, bool)> { + self.resolve_trait_static_method_by_self(path) + .or_else(|| self.resolve_trait_static_method(path)) + .or_else(|| self.resolve_trait_method_by_named_generic(path)) + } + + fn eval_global_as_array_length(&mut self, global: GlobalId, path: &Path) -> u64 { + let Some(stmt) = self.interner.get_global_let_statement(global) else { + let path = path.clone(); + self.push_err(ResolverError::NoSuchNumericTypeVariable { path }); + return 0; + }; + + let length = stmt.expression; + let span = self.interner.expr_span(&length); + let result = self.try_eval_array_length_id(length, span); + + match result.map(|length| length.try_into()) { + Ok(Ok(length_value)) => return length_value, + Ok(Err(_cast_err)) => self.push_err(ResolverError::IntegerTooLarge { span }), + Err(Some(error)) => self.push_err(error), + Err(None) => (), + } + 0 + } + + fn try_eval_array_length_id( + &self, + rhs: ExprId, + span: Span, + ) -> Result> { + // Arbitrary amount of recursive calls to try before giving up + let fuel = 100; + self.try_eval_array_length_id_with_fuel(rhs, span, fuel) + } + + fn try_eval_array_length_id_with_fuel( + &self, + rhs: ExprId, + span: Span, + fuel: u32, + ) -> Result> { + if fuel == 0 { + // If we reach here, it is likely from evaluating cyclic globals. We expect an error to + // be issued for them after name resolution so issue no error now. + return Err(None); + } + + match self.interner.expression(&rhs) { + HirExpression::Literal(HirLiteral::Integer(int, false)) => { + int.try_into_u128().ok_or(Some(ResolverError::IntegerTooLarge { span })) + } + HirExpression::Ident(ident) => { + let definition = self.interner.definition(ident.id); + match definition.kind { + DefinitionKind::Global(global_id) => { + let let_statement = self.interner.get_global_let_statement(global_id); + if let Some(let_statement) = let_statement { + let expression = let_statement.expression; + self.try_eval_array_length_id_with_fuel(expression, span, fuel - 1) + } else { + Err(Some(ResolverError::InvalidArrayLengthExpr { span })) + } + } + _ => Err(Some(ResolverError::InvalidArrayLengthExpr { span })), + } + } + HirExpression::Infix(infix) => { + let lhs = self.try_eval_array_length_id_with_fuel(infix.lhs, span, fuel - 1)?; + let rhs = self.try_eval_array_length_id_with_fuel(infix.rhs, span, fuel - 1)?; + + match infix.operator.kind { + BinaryOpKind::Add => Ok(lhs + rhs), + BinaryOpKind::Subtract => Ok(lhs - rhs), + BinaryOpKind::Multiply => Ok(lhs * rhs), + BinaryOpKind::Divide => Ok(lhs / rhs), + BinaryOpKind::Equal => Ok((lhs == rhs) as u128), + BinaryOpKind::NotEqual => Ok((lhs != rhs) as u128), + BinaryOpKind::Less => Ok((lhs < rhs) as u128), + BinaryOpKind::LessEqual => Ok((lhs <= rhs) as u128), + BinaryOpKind::Greater => Ok((lhs > rhs) as u128), + BinaryOpKind::GreaterEqual => Ok((lhs >= rhs) as u128), + BinaryOpKind::And => Ok(lhs & rhs), + BinaryOpKind::Or => Ok(lhs | rhs), + BinaryOpKind::Xor => Ok(lhs ^ rhs), + BinaryOpKind::ShiftRight => Ok(lhs >> rhs), + BinaryOpKind::ShiftLeft => Ok(lhs << rhs), + BinaryOpKind::Modulo => Ok(lhs % rhs), + } + } + _other => Err(Some(ResolverError::InvalidArrayLengthExpr { span })), + } + } + + /// Check if an assignment is overflowing with respect to `annotated_type` + /// in a declaration statement where `annotated_type` is an unsigned integer + pub(super) fn lint_overflowing_uint(&mut self, rhs_expr: &ExprId, annotated_type: &Type) { + let expr = self.interner.expression(rhs_expr); + let span = self.interner.expr_span(rhs_expr); + match expr { + HirExpression::Literal(HirLiteral::Integer(value, false)) => { + let v = value.to_u128(); + if let Type::Integer(_, bit_count) = annotated_type { + let bit_count: u32 = (*bit_count).into(); + let max = 1 << bit_count; + if v >= max { + self.push_err(TypeCheckError::OverflowingAssignment { + expr: value, + ty: annotated_type.clone(), + range: format!("0..={}", max - 1), + span, + }); + }; + }; + } + HirExpression::Prefix(expr) => { + self.lint_overflowing_uint(&expr.rhs, annotated_type); + if matches!(expr.operator, UnaryOp::Minus) { + self.push_err(TypeCheckError::InvalidUnaryOp { + kind: "annotated_type".to_string(), + span, + }); + } + } + HirExpression::Infix(expr) => { + self.lint_overflowing_uint(&expr.lhs, annotated_type); + self.lint_overflowing_uint(&expr.rhs, annotated_type); + } + _ => {} + } + } + + pub(super) fn unify( + &mut self, + actual: &Type, + expected: &Type, + make_error: impl FnOnce() -> TypeCheckError, + ) { + let mut errors = Vec::new(); + actual.unify(expected, &mut errors, make_error); + self.errors.extend(errors.into_iter().map(|error| (error.into(), self.file))); + } + + /// Wrapper of Type::unify_with_coercions using self.errors + pub(super) fn unify_with_coercions( + &mut self, + actual: &Type, + expected: &Type, + expression: ExprId, + make_error: impl FnOnce() -> TypeCheckError, + ) { + let mut errors = Vec::new(); + actual.unify_with_coercions(expected, expression, self.interner, &mut errors, make_error); + self.errors.extend(errors.into_iter().map(|error| (error.into(), self.file))); + } + + /// Return a fresh integer or field type variable and log it + /// in self.type_variables to default it later. + pub(super) fn polymorphic_integer_or_field(&mut self) -> Type { + let typ = Type::polymorphic_integer_or_field(self.interner); + self.type_variables.push(typ.clone()); + typ + } + + /// Return a fresh integer type variable and log it + /// in self.type_variables to default it later. + pub(super) fn polymorphic_integer(&mut self) -> Type { + let typ = Type::polymorphic_integer(self.interner); + self.type_variables.push(typ.clone()); + typ + } + + /// Translates a (possibly Unspecified) UnresolvedType to a Type. + /// Any UnresolvedType::Unspecified encountered are replaced with fresh type variables. + pub(super) fn resolve_inferred_type(&mut self, typ: UnresolvedType) -> Type { + match &typ.typ { + UnresolvedTypeData::Unspecified => self.interner.next_type_variable(), + _ => self.resolve_type_inner(typ, &mut vec![]), + } + } + + pub(super) fn type_check_prefix_operand( + &mut self, + op: &crate::ast::UnaryOp, + rhs_type: &Type, + span: Span, + ) -> Type { + let mut unify = |this: &mut Self, expected| { + this.unify(rhs_type, &expected, || TypeCheckError::TypeMismatch { + expr_typ: rhs_type.to_string(), + expected_typ: expected.to_string(), + expr_span: span, + }); + expected + }; + + match op { + crate::ast::UnaryOp::Minus => { + if rhs_type.is_unsigned() { + self.push_err(TypeCheckError::InvalidUnaryOp { + kind: rhs_type.to_string(), + span, + }); + } + let expected = self.polymorphic_integer_or_field(); + self.unify(rhs_type, &expected, || TypeCheckError::InvalidUnaryOp { + kind: rhs_type.to_string(), + span, + }); + expected + } + crate::ast::UnaryOp::Not => { + let rhs_type = rhs_type.follow_bindings(); + + // `!` can work on booleans or integers + if matches!(rhs_type, Type::Integer(..)) { + return rhs_type; + } + + unify(self, Type::Bool) + } + crate::ast::UnaryOp::MutableReference => { + Type::MutableReference(Box::new(rhs_type.follow_bindings())) + } + crate::ast::UnaryOp::Dereference { implicitly_added: _ } => { + let element_type = self.interner.next_type_variable(); + unify(self, Type::MutableReference(Box::new(element_type.clone()))); + element_type + } + } + } + + /// Insert as many dereference operations as necessary to automatically dereference a method + /// call object to its base value type T. + pub(super) fn insert_auto_dereferences(&mut self, object: ExprId, typ: Type) -> (ExprId, Type) { + if let Type::MutableReference(element) = typ { + let location = self.interner.id_location(object); + + let object = self.interner.push_expr(HirExpression::Prefix(HirPrefixExpression { + operator: UnaryOp::Dereference { implicitly_added: true }, + rhs: object, + })); + self.interner.push_expr_type(object, element.as_ref().clone()); + self.interner.push_expr_location(object, location.span, location.file); + + // Recursively dereference to allow for converting &mut &mut T to T + self.insert_auto_dereferences(object, *element) + } else { + (object, typ) + } + } + + /// Given a method object: `(*foo).bar` of a method call `(*foo).bar.baz()`, remove the + /// implicitly added dereference operator if one is found. + /// + /// Returns Some(new_expr_id) if a dereference was removed and None otherwise. + fn try_remove_implicit_dereference(&mut self, object: ExprId) -> Option { + match self.interner.expression(&object) { + HirExpression::MemberAccess(mut access) => { + let new_lhs = self.try_remove_implicit_dereference(access.lhs)?; + access.lhs = new_lhs; + access.is_offset = true; + + // `object` will have a different type now, which will be filled in + // later when type checking the method call as a function call. + self.interner.replace_expr(&object, HirExpression::MemberAccess(access)); + Some(object) + } + HirExpression::Prefix(prefix) => match prefix.operator { + // Found a dereference we can remove. Now just replace it with its rhs to remove it. + UnaryOp::Dereference { implicitly_added: true } => Some(prefix.rhs), + _ => None, + }, + _ => None, + } + } + + fn bind_function_type_impl( + &mut self, + fn_params: &[Type], + fn_ret: &Type, + callsite_args: &[(Type, ExprId, Span)], + span: Span, + ) -> Type { + if fn_params.len() != callsite_args.len() { + self.push_err(TypeCheckError::ParameterCountMismatch { + expected: fn_params.len(), + found: callsite_args.len(), + span, + }); + return Type::Error; + } + + for (param, (arg, _, arg_span)) in fn_params.iter().zip(callsite_args) { + self.unify(arg, param, || TypeCheckError::TypeMismatch { + expected_typ: param.to_string(), + expr_typ: arg.to_string(), + expr_span: *arg_span, + }); + } + + fn_ret.clone() + } + + pub(super) fn bind_function_type( + &mut self, + function: Type, + args: Vec<(Type, ExprId, Span)>, + span: Span, + ) -> Type { + // Could do a single unification for the entire function type, but matching beforehand + // lets us issue a more precise error on the individual argument that fails to type check. + match function { + Type::TypeVariable(binding, TypeVariableKind::Normal) => { + if let TypeBinding::Bound(typ) = &*binding.borrow() { + return self.bind_function_type(typ.clone(), args, span); + } + + let ret = self.interner.next_type_variable(); + let args = vecmap(args, |(arg, _, _)| arg); + let env_type = self.interner.next_type_variable(); + let expected = Type::Function(args, Box::new(ret.clone()), Box::new(env_type)); + + if let Err(error) = binding.try_bind(expected, span) { + self.push_err(error); + } + ret + } + // The closure env is ignored on purpose: call arguments never place + // constraints on closure environments. + Type::Function(parameters, ret, _env) => { + self.bind_function_type_impl(¶meters, &ret, &args, span) + } + Type::Error => Type::Error, + found => { + self.push_err(TypeCheckError::ExpectedFunction { found, span }); + Type::Error + } + } + } + + pub(super) fn check_cast(&mut self, from: Type, to: &Type, span: Span) -> Type { + match from.follow_bindings() { + Type::Integer(..) + | Type::FieldElement + | Type::TypeVariable(_, TypeVariableKind::IntegerOrField) + | Type::TypeVariable(_, TypeVariableKind::Integer) + | Type::Bool => (), + + Type::TypeVariable(_, _) => { + self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + return Type::Error; + } + Type::Error => return Type::Error, + from => { + self.push_err(TypeCheckError::InvalidCast { from, span }); + return Type::Error; + } + } + + match to { + Type::Integer(sign, bits) => Type::Integer(*sign, *bits), + Type::FieldElement => Type::FieldElement, + Type::Bool => Type::Bool, + Type::Error => Type::Error, + _ => { + self.push_err(TypeCheckError::UnsupportedCast { span }); + Type::Error + } + } + } + + // Given a binary comparison operator and another type. This method will produce the output type + // and a boolean indicating whether to use the trait impl corresponding to the operator + // or not. A value of false indicates the caller to use a primitive operation for this + // operator, while a true value indicates a user-provided trait impl is required. + fn comparator_operand_type_rules( + &mut self, + lhs_type: &Type, + rhs_type: &Type, + op: &HirBinaryOp, + span: Span, + ) -> Result<(Type, bool), TypeCheckError> { + use Type::*; + + match (lhs_type, rhs_type) { + // Avoid reporting errors multiple times + (Error, _) | (_, Error) => Ok((Bool, false)), + (Alias(alias, args), other) | (other, Alias(alias, args)) => { + let alias = alias.borrow().get_type(args); + self.comparator_operand_type_rules(&alias, other, op, span) + } + + // Matches on TypeVariable must be first to follow any type + // bindings. + (TypeVariable(var, _), other) | (other, TypeVariable(var, _)) => { + if let TypeBinding::Bound(binding) = &*var.borrow() { + return self.comparator_operand_type_rules(other, binding, op, span); + } + + let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span); + Ok((Bool, use_impl)) + } + (Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => { + if sign_x != sign_y { + return Err(TypeCheckError::IntegerSignedness { + sign_x: *sign_x, + sign_y: *sign_y, + span, + }); + } + if bit_width_x != bit_width_y { + return Err(TypeCheckError::IntegerBitWidth { + bit_width_x: *bit_width_x, + bit_width_y: *bit_width_y, + span, + }); + } + Ok((Bool, false)) + } + (FieldElement, FieldElement) => { + if op.kind.is_valid_for_field_type() { + Ok((Bool, false)) + } else { + Err(TypeCheckError::FieldComparison { span }) + } + } + + // <= and friends are technically valid for booleans, just not very useful + (Bool, Bool) => Ok((Bool, false)), + + (lhs, rhs) => { + self.unify(lhs, rhs, || TypeCheckError::TypeMismatchWithSource { + expected: lhs.clone(), + actual: rhs.clone(), + span: op.location.span, + source: Source::Binary, + }); + Ok((Bool, true)) + } + } + } + + /// Handles the TypeVariable case for checking binary operators. + /// Returns true if we should use the impl for the operator instead of the primitive + /// version of it. + fn bind_type_variables_for_infix( + &mut self, + lhs_type: &Type, + op: &HirBinaryOp, + rhs_type: &Type, + span: Span, + ) -> bool { + self.unify(lhs_type, rhs_type, || TypeCheckError::TypeMismatchWithSource { + expected: lhs_type.clone(), + actual: rhs_type.clone(), + source: Source::Binary, + span, + }); + + let use_impl = !lhs_type.is_numeric(); + + // If this operator isn't valid for fields we have to possibly narrow + // TypeVariableKind::IntegerOrField to TypeVariableKind::Integer. + // Doing so also ensures a type error if Field is used. + // The is_numeric check is to allow impls for custom types to bypass this. + if !op.kind.is_valid_for_field_type() && lhs_type.is_numeric() { + let target = Type::polymorphic_integer(self.interner); + + use crate::ast::BinaryOpKind::*; + use TypeCheckError::*; + self.unify(lhs_type, &target, || match op.kind { + Less | LessEqual | Greater | GreaterEqual => FieldComparison { span }, + And | Or | Xor | ShiftRight | ShiftLeft => FieldBitwiseOp { span }, + Modulo => FieldModulo { span }, + other => unreachable!("Operator {other:?} should be valid for Field"), + }); + } + + use_impl + } + + // Given a binary operator and another type. This method will produce the output type + // and a boolean indicating whether to use the trait impl corresponding to the operator + // or not. A value of false indicates the caller to use a primitive operation for this + // operator, while a true value indicates a user-provided trait impl is required. + pub(super) fn infix_operand_type_rules( + &mut self, + lhs_type: &Type, + op: &HirBinaryOp, + rhs_type: &Type, + span: Span, + ) -> Result<(Type, bool), TypeCheckError> { + if op.kind.is_comparator() { + return self.comparator_operand_type_rules(lhs_type, rhs_type, op, span); + } + + use Type::*; + match (lhs_type, rhs_type) { + // An error type on either side will always return an error + (Error, _) | (_, Error) => Ok((Error, false)), + (Alias(alias, args), other) | (other, Alias(alias, args)) => { + let alias = alias.borrow().get_type(args); + self.infix_operand_type_rules(&alias, op, other, span) + } + + // Matches on TypeVariable must be first so that we follow any type + // bindings. + (TypeVariable(int, _), other) | (other, TypeVariable(int, _)) => { + if let TypeBinding::Bound(binding) = &*int.borrow() { + return self.infix_operand_type_rules(binding, op, other, span); + } + if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight { + self.unify( + rhs_type, + &Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), + || TypeCheckError::InvalidShiftSize { span }, + ); + let use_impl = if lhs_type.is_numeric() { + let integer_type = Type::polymorphic_integer(self.interner); + self.bind_type_variables_for_infix(lhs_type, op, &integer_type, span) + } else { + true + }; + return Ok((lhs_type.clone(), use_impl)); + } + let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span); + Ok((other.clone(), use_impl)) + } + (Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => { + if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight { + if *sign_y != Signedness::Unsigned || *bit_width_y != IntegerBitSize::Eight { + return Err(TypeCheckError::InvalidShiftSize { span }); + } + return Ok((Integer(*sign_x, *bit_width_x), false)); + } + if sign_x != sign_y { + return Err(TypeCheckError::IntegerSignedness { + sign_x: *sign_x, + sign_y: *sign_y, + span, + }); + } + if bit_width_x != bit_width_y { + return Err(TypeCheckError::IntegerBitWidth { + bit_width_x: *bit_width_x, + bit_width_y: *bit_width_y, + span, + }); + } + Ok((Integer(*sign_x, *bit_width_x), false)) + } + // The result of two Fields is always a witness + (FieldElement, FieldElement) => { + if !op.kind.is_valid_for_field_type() { + if op.kind == BinaryOpKind::Modulo { + return Err(TypeCheckError::FieldModulo { span }); + } else { + return Err(TypeCheckError::FieldBitwiseOp { span }); + } + } + Ok((FieldElement, false)) + } + + (Bool, Bool) => Ok((Bool, false)), + + (lhs, rhs) => { + if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight { + if rhs == &Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight) { + return Ok((lhs.clone(), true)); + } + return Err(TypeCheckError::InvalidShiftSize { span }); + } + self.unify(lhs, rhs, || TypeCheckError::TypeMismatchWithSource { + expected: lhs.clone(), + actual: rhs.clone(), + span: op.location.span, + source: Source::Binary, + }); + Ok((lhs.clone(), true)) + } + } + } + + /// Prerequisite: verify_trait_constraint of the operator's trait constraint. + /// + /// Although by this point the operator is expected to already have a trait impl, + /// we still need to match the operator's type against the method's instantiated type + /// to ensure the instantiation bindings are correct and the monomorphizer can + /// re-apply the needed bindings. + pub(super) fn type_check_operator_method( + &mut self, + expr_id: ExprId, + trait_method_id: TraitMethodId, + object_type: &Type, + span: Span, + ) { + let the_trait = self.interner.get_trait(trait_method_id.trait_id); + + let method = &the_trait.methods[trait_method_id.method_index]; + let (method_type, mut bindings) = method.typ.clone().instantiate(self.interner); + + match method_type { + Type::Function(args, _, _) => { + // We can cheat a bit and match against only the object type here since no operator + // overload uses other generic parameters or return types aside from the object type. + let expected_object_type = &args[0]; + self.unify(object_type, expected_object_type, || TypeCheckError::TypeMismatch { + expected_typ: expected_object_type.to_string(), + expr_typ: object_type.to_string(), + expr_span: span, + }); + } + other => { + unreachable!("Expected operator method to have a function type, but found {other}") + } + } + + // We must also remember to apply these substitutions to the object_type + // referenced by the selected trait impl, if one has yet to be selected. + let impl_kind = self.interner.get_selected_impl_for_expression(expr_id); + if let Some(TraitImplKind::Assumed { object_type, trait_generics }) = impl_kind { + let the_trait = self.interner.get_trait(trait_method_id.trait_id); + let object_type = object_type.substitute(&bindings); + bindings.insert( + the_trait.self_type_typevar_id, + (the_trait.self_type_typevar.clone(), object_type.clone()), + ); + self.interner.select_impl_for_expression( + expr_id, + TraitImplKind::Assumed { object_type, trait_generics }, + ); + } + + self.interner.store_instantiation_bindings(expr_id, bindings); + } + + pub(super) fn type_check_member_access( + &mut self, + mut access: HirMemberAccess, + expr_id: ExprId, + lhs_type: Type, + span: Span, + ) -> Type { + let access_lhs = &mut access.lhs; + + let dereference_lhs = |this: &mut Self, lhs_type, element| { + let old_lhs = *access_lhs; + *access_lhs = this.interner.push_expr(HirExpression::Prefix(HirPrefixExpression { + operator: crate::ast::UnaryOp::Dereference { implicitly_added: true }, + rhs: old_lhs, + })); + this.interner.push_expr_type(old_lhs, lhs_type); + this.interner.push_expr_type(*access_lhs, element); + + let old_location = this.interner.id_location(old_lhs); + this.interner.push_expr_location(*access_lhs, span, old_location.file); + }; + + // If this access is just a field offset, we want to avoid dereferencing + let dereference_lhs = (!access.is_offset).then_some(dereference_lhs); + + match self.check_field_access(&lhs_type, &access.rhs.0.contents, span, dereference_lhs) { + Some((element_type, index)) => { + self.interner.set_field_index(expr_id, index); + // We must update `access` in case we added any dereferences to it + self.interner.replace_expr(&expr_id, HirExpression::MemberAccess(access)); + element_type + } + None => Type::Error, + } + } + + pub(super) fn lookup_method( + &mut self, + object_type: &Type, + method_name: &str, + span: Span, + ) -> Option { + match object_type.follow_bindings() { + Type::Struct(typ, _args) => { + let id = typ.borrow().id; + match self.interner.lookup_method(object_type, id, method_name, false) { + Some(method_id) => Some(HirMethodReference::FuncId(method_id)), + None => { + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + None + } + } + } + // TODO: We should allow method calls on `impl Trait`s eventually. + // For now it is fine since they are only allowed on return types. + Type::TraitAsType(..) => { + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + None + } + Type::NamedGeneric(_, _) => { + let func_meta = self.interner.function_meta( + &self.current_function.expect("unexpected method outside a function"), + ); + + for constraint in &func_meta.trait_constraints { + if *object_type == constraint.typ { + if let Some(the_trait) = self.interner.try_get_trait(constraint.trait_id) { + for (method_index, method) in the_trait.methods.iter().enumerate() { + if method.name.0.contents == method_name { + let trait_method = TraitMethodId { + trait_id: constraint.trait_id, + method_index, + }; + return Some(HirMethodReference::TraitMethodId( + trait_method, + constraint.trait_generics.clone(), + )); + } + } + } + } + } + + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + None + } + // Mutable references to another type should resolve to methods of their element type. + // This may be a struct or a primitive type. + Type::MutableReference(element) => self + .interner + .lookup_primitive_trait_method_mut(element.as_ref(), method_name) + .map(HirMethodReference::FuncId) + .or_else(|| self.lookup_method(&element, method_name, span)), + + // If we fail to resolve the object to a struct type, we have no way of type + // checking its arguments as we can't even resolve the name of the function + Type::Error => None, + + // The type variable must be unbound at this point since follow_bindings was called + Type::TypeVariable(_, TypeVariableKind::Normal) => { + self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + None + } + + other => match self.interner.lookup_primitive_method(&other, method_name) { + Some(method_id) => Some(HirMethodReference::FuncId(method_id)), + None => { + self.push_err(TypeCheckError::UnresolvedMethodCall { + method_name: method_name.to_string(), + object_type: object_type.clone(), + span, + }); + None + } + }, + } + } + + pub(super) fn type_check_call( + &mut self, + call: &HirCallExpression, + func_type: Type, + args: Vec<(Type, ExprId, Span)>, + span: Span, + ) -> Type { + // Need to setup these flags here as `self` is borrowed mutably to type check the rest of the call expression + // These flags are later used to type check calls to unconstrained functions from constrained functions + let func_mod = self.current_function.map(|func| self.interner.function_modifiers(&func)); + let is_current_func_constrained = + func_mod.map_or(true, |func_mod| !func_mod.is_unconstrained); + + let is_unconstrained_call = self.is_unconstrained_call(call.func); + self.check_if_deprecated(call.func); + + // Check that we are not passing a mutable reference from a constrained runtime to an unconstrained runtime + if is_current_func_constrained && is_unconstrained_call { + for (typ, _, _) in args.iter() { + if matches!(&typ.follow_bindings(), Type::MutableReference(_)) { + self.push_err(TypeCheckError::ConstrainedReferenceToUnconstrained { span }); + } + } + } + + let return_type = self.bind_function_type(func_type, args, span); + + // Check that we are not passing a slice from an unconstrained runtime to a constrained runtime + if is_current_func_constrained && is_unconstrained_call { + if return_type.contains_slice() { + self.push_err(TypeCheckError::UnconstrainedSliceReturnToConstrained { span }); + } else if matches!(&return_type.follow_bindings(), Type::MutableReference(_)) { + self.push_err(TypeCheckError::UnconstrainedReferenceToConstrained { span }); + } + }; + + return_type + } + + fn check_if_deprecated(&mut self, expr: ExprId) { + if let HirExpression::Ident(HirIdent { location, id, impl_kind: _ }) = + self.interner.expression(&expr) + { + if let Some(DefinitionKind::Function(func_id)) = + self.interner.try_definition(id).map(|def| &def.kind) + { + let attributes = self.interner.function_attributes(func_id); + if let Some(note) = attributes.get_deprecated_note() { + self.push_err(TypeCheckError::CallDeprecated { + name: self.interner.definition_name(id).to_string(), + note, + span: location.span, + }); + } + } + } + } + + fn is_unconstrained_call(&self, expr: ExprId) -> bool { + if let HirExpression::Ident(HirIdent { id, .. }) = self.interner.expression(&expr) { + if let Some(DefinitionKind::Function(func_id)) = + self.interner.try_definition(id).map(|def| &def.kind) + { + let modifiers = self.interner.function_modifiers(func_id); + return modifiers.is_unconstrained; + } + } + false + } + + /// Check if the given method type requires a mutable reference to the object type, and check + /// if the given object type is already a mutable reference. If not, add one. + /// This is used to automatically transform a method call: `foo.bar()` into a function + /// call: `bar(&mut foo)`. + /// + /// A notable corner case of this function is where it interacts with auto-deref of `.`. + /// If a field is being mutated e.g. `foo.bar.mutate_bar()` where `foo: &mut Foo`, the compiler + /// will insert a dereference before bar `(*foo).bar.mutate_bar()` which would cause us to + /// mutate a copy of bar rather than a reference to it. We must check for this corner case here + /// and remove the implicitly added dereference operator if we find one. + pub(super) fn try_add_mutable_reference_to_object( + &mut self, + function_type: &Type, + object_type: &mut Type, + object: &mut ExprId, + ) { + let expected_object_type = match function_type { + Type::Function(args, _, _) => args.first(), + Type::Forall(_, typ) => match typ.as_ref() { + Type::Function(args, _, _) => args.first(), + typ => unreachable!("Unexpected type for function: {typ}"), + }, + typ => unreachable!("Unexpected type for function: {typ}"), + }; + + if let Some(expected_object_type) = expected_object_type { + let actual_type = object_type.follow_bindings(); + + if matches!(expected_object_type.follow_bindings(), Type::MutableReference(_)) { + if !matches!(actual_type, Type::MutableReference(_)) { + if let Err(error) = verify_mutable_reference(self.interner, *object) { + self.push_err(TypeCheckError::ResolverError(error)); + } + + let new_type = Type::MutableReference(Box::new(actual_type)); + *object_type = new_type.clone(); + + // First try to remove a dereference operator that may have been implicitly + // inserted by a field access expression `foo.bar` on a mutable reference `foo`. + let new_object = self.try_remove_implicit_dereference(*object); + + // If that didn't work, then wrap the whole expression in an `&mut` + *object = new_object.unwrap_or_else(|| { + let location = self.interner.id_location(*object); + + let new_object = + self.interner.push_expr(HirExpression::Prefix(HirPrefixExpression { + operator: UnaryOp::MutableReference, + rhs: *object, + })); + self.interner.push_expr_type(new_object, new_type); + self.interner.push_expr_location(new_object, location.span, location.file); + new_object + }); + } + // Otherwise if the object type is a mutable reference and the method is not, insert as + // many dereferences as needed. + } else if matches!(actual_type, Type::MutableReference(_)) { + let (new_object, new_type) = self.insert_auto_dereferences(*object, actual_type); + *object_type = new_type; + *object = new_object; + } + } + } + + pub fn type_check_function_body(&mut self, body_type: Type, meta: &FuncMeta, body_id: ExprId) { + let (expr_span, empty_function) = self.function_info(body_id); + let declared_return_type = meta.return_type(); + + let func_span = self.interner.expr_span(&body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet + if let Type::TraitAsType(trait_id, _, generics) = declared_return_type { + if self.interner.lookup_trait_implementation(&body_type, *trait_id, generics).is_err() { + self.push_err(TypeCheckError::TypeMismatchWithSource { + expected: declared_return_type.clone(), + actual: body_type, + span: func_span, + source: Source::Return(meta.return_type.clone(), expr_span), + }); + } + } else { + self.unify_with_coercions(&body_type, declared_return_type, body_id, || { + let mut error = TypeCheckError::TypeMismatchWithSource { + expected: declared_return_type.clone(), + actual: body_type.clone(), + span: func_span, + source: Source::Return(meta.return_type.clone(), expr_span), + }; + + if empty_function { + error = error.add_context( + "implicitly returns `()` as its body has no tail or `return` expression", + ); + } + error + }); + } + } + + fn function_info(&self, function_body_id: ExprId) -> (noirc_errors::Span, bool) { + let (expr_span, empty_function) = + if let HirExpression::Block(block) = self.interner.expression(&function_body_id) { + let last_stmt = block.statements().last(); + let mut span = self.interner.expr_span(&function_body_id); + + if let Some(last_stmt) = last_stmt { + if let HirStatement::Expression(expr) = self.interner.statement(last_stmt) { + span = self.interner.expr_span(&expr); + } + } + + (span, last_stmt.is_none()) + } else { + (self.interner.expr_span(&function_body_id), false) + }; + (expr_span, empty_function) + } + + pub fn verify_trait_constraint( + &mut self, + object_type: &Type, + trait_id: TraitId, + trait_generics: &[Type], + function_ident_id: ExprId, + span: Span, + ) { + match self.interner.lookup_trait_implementation(object_type, trait_id, trait_generics) { + Ok(impl_kind) => { + self.interner.select_impl_for_expression(function_ident_id, impl_kind); + } + Err(erroring_constraints) => { + if erroring_constraints.is_empty() { + self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + } else { + // Don't show any errors where try_get_trait returns None. + // This can happen if a trait is used that was never declared. + let constraints = erroring_constraints + .into_iter() + .map(|constraint| { + let r#trait = self.interner.try_get_trait(constraint.trait_id)?; + let mut name = r#trait.name.to_string(); + if !constraint.trait_generics.is_empty() { + let generics = + vecmap(&constraint.trait_generics, ToString::to_string); + name += &format!("<{}>", generics.join(", ")); + } + Some((constraint.typ, name)) + }) + .collect::>>(); + + if let Some(constraints) = constraints { + self.push_err(TypeCheckError::NoMatchingImplFound { constraints, span }); + } + } + } + } + } +} diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 26b7c212a30..84df3a0a244 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -401,6 +401,14 @@ impl<'a> Interpreter<'a> { let value = if is_negative { 0u8.wrapping_sub(value) } else { value }; Ok(Value::U8(value)) } + (Signedness::Unsigned, IntegerBitSize::Sixteen) => { + let value: u16 = + value.try_to_u64().and_then(|value| value.try_into().ok()).ok_or( + InterpreterError::IntegerOutOfRangeForType { value, typ, location }, + )?; + let value = if is_negative { 0u16.wrapping_sub(value) } else { value }; + Ok(Value::U16(value)) + } (Signedness::Unsigned, IntegerBitSize::ThirtyTwo) => { let value: u32 = value.try_to_u64().and_then(|value| value.try_into().ok()).ok_or( @@ -430,6 +438,14 @@ impl<'a> Interpreter<'a> { let value = if is_negative { -value } else { value }; Ok(Value::I8(value)) } + (Signedness::Signed, IntegerBitSize::Sixteen) => { + let value: i16 = + value.try_to_u64().and_then(|value| value.try_into().ok()).ok_or( + InterpreterError::IntegerOutOfRangeForType { value, typ, location }, + )?; + let value = if is_negative { -value } else { value }; + Ok(Value::I16(value)) + } (Signedness::Signed, IntegerBitSize::ThirtyTwo) => { let value: i32 = value.try_to_u64().and_then(|value| value.try_into().ok()).ok_or( @@ -509,9 +525,11 @@ impl<'a> Interpreter<'a> { crate::ast::UnaryOp::Minus => match rhs { Value::Field(value) => Ok(Value::Field(FieldElement::zero() - value)), Value::I8(value) => Ok(Value::I8(-value)), + Value::I16(value) => Ok(Value::I16(-value)), Value::I32(value) => Ok(Value::I32(-value)), Value::I64(value) => Ok(Value::I64(-value)), Value::U8(value) => Ok(Value::U8(0 - value)), + Value::U16(value) => Ok(Value::U16(0 - value)), Value::U32(value) => Ok(Value::U32(0 - value)), Value::U64(value) => Ok(Value::U64(0 - value)), value => { @@ -523,9 +541,11 @@ impl<'a> Interpreter<'a> { crate::ast::UnaryOp::Not => match rhs { Value::Bool(value) => Ok(Value::Bool(!value)), Value::I8(value) => Ok(Value::I8(!value)), + Value::I16(value) => Ok(Value::I16(!value)), Value::I32(value) => Ok(Value::I32(!value)), Value::I64(value) => Ok(Value::I64(!value)), Value::U8(value) => Ok(Value::U8(!value)), + Value::U16(value) => Ok(Value::U16(!value)), Value::U32(value) => Ok(Value::U32(!value)), Value::U64(value) => Ok(Value::U64(!value)), value => { @@ -559,9 +579,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::Add => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs + rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs + rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs + rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs + rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs + rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs + rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs + rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs + rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs + rhs)), (lhs, rhs) => { @@ -572,9 +594,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::Subtract => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs - rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs - rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs - rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs - rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs - rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs - rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs - rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs - rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs - rhs)), (lhs, rhs) => { @@ -585,9 +609,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::Multiply => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs * rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs * rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs * rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs * rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs * rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs * rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs * rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs * rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs * rhs)), (lhs, rhs) => { @@ -598,9 +624,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::Divide => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs / rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs / rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs / rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs / rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs / rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs / rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs / rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs / rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs / rhs)), (lhs, rhs) => { @@ -611,9 +639,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::Equal => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs == rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs == rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs == rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs == rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs == rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs == rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs == rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs == rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs == rhs)), (lhs, rhs) => { @@ -624,9 +654,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::NotEqual => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs != rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs != rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs != rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs != rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs != rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs != rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs != rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs != rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs != rhs)), (lhs, rhs) => { @@ -637,9 +669,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::Less => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs < rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs < rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs < rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs < rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs < rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs < rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs < rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs < rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs < rhs)), (lhs, rhs) => { @@ -650,9 +684,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::LessEqual => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs <= rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs <= rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs <= rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs <= rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs <= rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs <= rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs <= rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs <= rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs <= rhs)), (lhs, rhs) => { @@ -663,9 +699,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::Greater => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs > rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs > rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs > rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs > rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs > rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs > rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs > rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs > rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs > rhs)), (lhs, rhs) => { @@ -676,9 +714,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::GreaterEqual => match (lhs, rhs) { (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs >= rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs >= rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::Bool(lhs >= rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs >= rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs >= rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs >= rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::Bool(lhs >= rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs >= rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs >= rhs)), (lhs, rhs) => { @@ -689,9 +729,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::And => match (lhs, rhs) { (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs & rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs & rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs & rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs & rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs & rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs & rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs & rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs & rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs & rhs)), (lhs, rhs) => { @@ -702,9 +744,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::Or => match (lhs, rhs) { (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs | rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs | rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs | rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs | rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs | rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs | rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs | rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs | rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs | rhs)), (lhs, rhs) => { @@ -715,9 +759,11 @@ impl<'a> Interpreter<'a> { BinaryOpKind::Xor => match (lhs, rhs) { (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs ^ rhs)), (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs ^ rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs ^ rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs ^ rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs ^ rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs ^ rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs ^ rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs ^ rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs ^ rhs)), (lhs, rhs) => { @@ -727,9 +773,11 @@ impl<'a> Interpreter<'a> { }, BinaryOpKind::ShiftRight => match (lhs, rhs) { (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs >> rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs >> rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs >> rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs >> rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs >> rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs >> rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs >> rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs >> rhs)), (lhs, rhs) => { @@ -739,9 +787,11 @@ impl<'a> Interpreter<'a> { }, BinaryOpKind::ShiftLeft => match (lhs, rhs) { (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs << rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs << rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs << rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs << rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs << rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs << rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs << rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs << rhs)), (lhs, rhs) => { @@ -751,9 +801,11 @@ impl<'a> Interpreter<'a> { }, BinaryOpKind::Modulo => match (lhs, rhs) { (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs % rhs)), + (Value::I16(lhs), Value::I16(rhs)) => Ok(Value::I16(lhs % rhs)), (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs % rhs)), (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs % rhs)), (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs % rhs)), + (Value::U16(lhs), Value::U16(rhs)) => Ok(Value::U16(lhs % rhs)), (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs % rhs)), (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs % rhs)), (lhs, rhs) => { @@ -795,9 +847,11 @@ impl<'a> Interpreter<'a> { value.try_to_u64().expect("index could not fit into u64") as usize } Value::I8(value) => value as usize, + Value::I16(value) => value as usize, Value::I32(value) => value as usize, Value::I64(value) => value as usize, Value::U8(value) => value as usize, + Value::U16(value) => value as usize, Value::U32(value) => value as usize, Value::U64(value) => value as usize, value => { @@ -908,9 +962,11 @@ impl<'a> Interpreter<'a> { let (mut lhs, lhs_is_negative) = match self.evaluate(cast.lhs)? { Value::Field(value) => (value, false), Value::U8(value) => ((value as u128).into(), false), + Value::U16(value) => ((value as u128).into(), false), Value::U32(value) => ((value as u128).into(), false), Value::U64(value) => ((value as u128).into(), false), Value::I8(value) => signed_int_to_field!(value), + Value::I16(value) => signed_int_to_field!(value), Value::I32(value) => signed_int_to_field!(value), Value::I64(value) => signed_int_to_field!(value), Value::Bool(value) => { @@ -946,6 +1002,9 @@ impl<'a> Interpreter<'a> { Err(InterpreterError::TypeUnsupported { typ: cast.r#type, location }) } (Signedness::Unsigned, IntegerBitSize::Eight) => cast_to_int!(lhs, to_u128, u8, U8), + (Signedness::Unsigned, IntegerBitSize::Sixteen) => { + cast_to_int!(lhs, to_u128, u16, U16) + } (Signedness::Unsigned, IntegerBitSize::ThirtyTwo) => { cast_to_int!(lhs, to_u128, u32, U32) } @@ -957,6 +1016,9 @@ impl<'a> Interpreter<'a> { Err(InterpreterError::TypeUnsupported { typ: cast.r#type, location }) } (Signedness::Signed, IntegerBitSize::Eight) => cast_to_int!(lhs, to_i128, i8, I8), + (Signedness::Signed, IntegerBitSize::Sixteen) => { + cast_to_int!(lhs, to_i128, i16, I16) + } (Signedness::Signed, IntegerBitSize::ThirtyTwo) => { cast_to_int!(lhs, to_i128, i32, I32) } @@ -1149,9 +1211,11 @@ impl<'a> Interpreter<'a> { let get_index = |this: &mut Self, expr| -> IResult<(_, fn(_) -> _)> { match this.evaluate(expr)? { Value::I8(value) => Ok((value as i128, |i| Value::I8(i as i8))), + Value::I16(value) => Ok((value as i128, |i| Value::I16(i as i16))), Value::I32(value) => Ok((value as i128, |i| Value::I32(i as i32))), Value::I64(value) => Ok((value as i128, |i| Value::I64(i as i64))), Value::U8(value) => Ok((value as i128, |i| Value::U8(i as u8))), + Value::U16(value) => Ok((value as i128, |i| Value::U16(i as u16))), Value::U32(value) => Ok((value as i128, |i| Value::U32(i as u32))), Value::U64(value) => Ok((value as i128, |i| Value::U64(i as u64))), value => { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/tests.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/tests.rs index 5a12eb7292c..41475d3ccf4 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/tests.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/tests.rs @@ -103,6 +103,19 @@ fn for_loop() { assert_eq!(result, Value::U8(15)); } +#[test] +fn for_loop_u16() { + let program = "fn main() -> pub u16 { + let mut x = 0; + for i in 0 .. 6 { + x += i; + } + x + }"; + let result = interpret(program, vec!["main".into()]); + assert_eq!(result, Value::U16(15)); +} + #[test] fn for_loop_with_break() { let program = "unconstrained fn main() -> pub u32 { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/value.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/value.rs index 6845c6ac5a9..4e4a260871a 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -22,9 +22,11 @@ pub enum Value { Bool(bool), Field(FieldElement), I8(i8), + I16(i16), I32(i32), I64(i64), U8(u8), + U16(u16), U32(u32), U64(u64), String(Rc), @@ -45,9 +47,11 @@ impl Value { Value::Bool(_) => Type::Bool, Value::Field(_) => Type::FieldElement, Value::I8(_) => Type::Integer(Signedness::Signed, IntegerBitSize::Eight), + Value::I16(_) => Type::Integer(Signedness::Signed, IntegerBitSize::Sixteen), Value::I32(_) => Type::Integer(Signedness::Signed, IntegerBitSize::ThirtyTwo), Value::I64(_) => Type::Integer(Signedness::Signed, IntegerBitSize::SixtyFour), Value::U8(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), + Value::U16(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::Sixteen), Value::U32(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo), Value::U64(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::SixtyFour), Value::String(value) => { @@ -87,6 +91,12 @@ impl Value { let value = (value as u128).into(); HirExpression::Literal(HirLiteral::Integer(value, negative)) } + Value::I16(value) => { + let negative = value < 0; + let value = value.abs(); + let value = (value as u128).into(); + HirExpression::Literal(HirLiteral::Integer(value, negative)) + } Value::I32(value) => { let negative = value < 0; let value = value.abs(); @@ -102,6 +112,9 @@ impl Value { Value::U8(value) => { HirExpression::Literal(HirLiteral::Integer((value as u128).into(), false)) } + Value::U16(value) => { + HirExpression::Literal(HirLiteral::Integer((value as u128).into(), false)) + } Value::U32(value) => { HirExpression::Literal(HirLiteral::Integer((value as u128).into(), false)) } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 2f6b101e62f..4aac0fec9c3 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -1,5 +1,6 @@ use super::dc_mod::collect_defs; use super::errors::{DefCollectorErrorKind, DuplicateType}; +use crate::elaborator::Elaborator; use crate::graph::CrateId; use crate::hir::comptime::{Interpreter, InterpreterError}; use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleId}; @@ -129,14 +130,18 @@ pub struct UnresolvedGlobal { /// Given a Crate root, collect all definitions in that crate pub struct DefCollector { pub(crate) def_map: CrateDefMap, - pub(crate) collected_imports: Vec, - pub(crate) collected_functions: Vec, - pub(crate) collected_types: BTreeMap, - pub(crate) collected_type_aliases: BTreeMap, - pub(crate) collected_traits: BTreeMap, - pub(crate) collected_globals: Vec, - pub(crate) collected_impls: ImplMap, - pub(crate) collected_traits_impls: Vec, + pub(crate) imports: Vec, + pub(crate) items: CollectedItems, +} + +pub struct CollectedItems { + pub(crate) functions: Vec, + pub(crate) types: BTreeMap, + pub(crate) type_aliases: BTreeMap, + pub(crate) traits: BTreeMap, + pub(crate) globals: Vec, + pub(crate) impls: ImplMap, + pub(crate) trait_impls: Vec, } /// Maps the type and the module id in which the impl is defined to the functions contained in that @@ -210,14 +215,16 @@ impl DefCollector { fn new(def_map: CrateDefMap) -> DefCollector { DefCollector { def_map, - collected_imports: vec![], - collected_functions: vec![], - collected_types: BTreeMap::new(), - collected_type_aliases: BTreeMap::new(), - collected_traits: BTreeMap::new(), - collected_impls: HashMap::new(), - collected_globals: vec![], - collected_traits_impls: vec![], + imports: vec![], + items: CollectedItems { + functions: vec![], + types: BTreeMap::new(), + type_aliases: BTreeMap::new(), + traits: BTreeMap::new(), + impls: HashMap::new(), + globals: vec![], + trait_impls: vec![], + }, } } @@ -229,6 +236,7 @@ impl DefCollector { context: &mut Context, ast: SortedModule, root_file_id: FileId, + use_elaborator: bool, macro_processors: &[&dyn MacroProcessor], ) -> Vec<(CompilationError, FileId)> { let mut errors: Vec<(CompilationError, FileId)> = vec![]; @@ -242,7 +250,12 @@ impl DefCollector { let crate_graph = &context.crate_graph[crate_id]; for dep in crate_graph.dependencies.clone() { - errors.extend(CrateDefMap::collect_defs(dep.crate_id, context, macro_processors)); + errors.extend(CrateDefMap::collect_defs( + dep.crate_id, + context, + use_elaborator, + macro_processors, + )); let dep_def_root = context.def_map(&dep.crate_id).expect("ice: def map was just created").root; @@ -275,18 +288,13 @@ impl DefCollector { // Add the current crate to the collection of DefMaps context.def_maps.insert(crate_id, def_collector.def_map); - inject_prelude(crate_id, context, crate_root, &mut def_collector.collected_imports); + inject_prelude(crate_id, context, crate_root, &mut def_collector.imports); for submodule in submodules { - inject_prelude( - crate_id, - context, - LocalModuleId(submodule), - &mut def_collector.collected_imports, - ); + inject_prelude(crate_id, context, LocalModuleId(submodule), &mut def_collector.imports); } // Resolve unresolved imports collected from the crate, one by one. - for collected_import in def_collector.collected_imports { + for collected_import in std::mem::take(&mut def_collector.imports) { match resolve_import(crate_id, &collected_import, &context.def_maps) { Ok(resolved_import) => { if let Some(error) = resolved_import.error { @@ -323,6 +331,12 @@ impl DefCollector { } } + if use_elaborator { + let mut more_errors = Elaborator::elaborate(context, crate_id, def_collector.items); + more_errors.append(&mut errors); + return errors; + } + let mut resolved_module = ResolvedModule { errors, ..Default::default() }; // We must first resolve and intern the globals before we can resolve any stmts inside each function. @@ -330,26 +344,25 @@ impl DefCollector { // // Additionally, we must resolve integer globals before structs since structs may refer to // the values of integer globals as numeric generics. - let (literal_globals, other_globals) = - filter_literal_globals(def_collector.collected_globals); + let (literal_globals, other_globals) = filter_literal_globals(def_collector.items.globals); resolved_module.resolve_globals(context, literal_globals, crate_id); resolved_module.errors.extend(resolve_type_aliases( context, - def_collector.collected_type_aliases, + def_collector.items.type_aliases, crate_id, )); resolved_module.errors.extend(resolve_traits( context, - def_collector.collected_traits, + def_collector.items.traits, crate_id, )); // Must resolve structs before we resolve globals. resolved_module.errors.extend(resolve_structs( context, - def_collector.collected_types, + def_collector.items.types, crate_id, )); @@ -358,7 +371,7 @@ impl DefCollector { resolved_module.errors.extend(collect_trait_impls( context, crate_id, - &mut def_collector.collected_traits_impls, + &mut def_collector.items.trait_impls, )); // Before we resolve any function symbols we must go through our impls and @@ -368,11 +381,7 @@ impl DefCollector { // // These are resolved after trait impls so that struct methods are chosen // over trait methods if there are name conflicts. - resolved_module.errors.extend(collect_impls( - context, - crate_id, - &def_collector.collected_impls, - )); + resolved_module.errors.extend(collect_impls(context, crate_id, &def_collector.items.impls)); // We must wait to resolve non-integer globals until after we resolve structs since struct // globals will need to reference the struct type they're initialized to to ensure they are valid. @@ -383,7 +392,7 @@ impl DefCollector { &mut context.def_interner, crate_id, &context.def_maps, - def_collector.collected_functions, + def_collector.items.functions, None, &mut resolved_module.errors, ); @@ -392,13 +401,13 @@ impl DefCollector { &mut context.def_interner, crate_id, &context.def_maps, - def_collector.collected_impls, + def_collector.items.impls, &mut resolved_module.errors, )); resolved_module.trait_impl_functions = resolve_trait_impls( context, - def_collector.collected_traits_impls, + def_collector.items.trait_impls, crate_id, &mut resolved_module.errors, ); @@ -431,15 +440,18 @@ fn inject_prelude( crate_root: LocalModuleId, collected_imports: &mut Vec, ) { - let segments: Vec<_> = "std::prelude" - .split("::") - .map(|segment| crate::ast::Ident::new(segment.into(), Span::default())) - .collect(); + if !crate_id.is_stdlib() { + let segments: Vec<_> = "std::prelude" + .split("::") + .map(|segment| crate::ast::Ident::new(segment.into(), Span::default())) + .collect(); - let path = - Path { segments: segments.clone(), kind: crate::ast::PathKind::Dep, span: Span::default() }; + let path = Path { + segments: segments.clone(), + kind: crate::ast::PathKind::Dep, + span: Span::default(), + }; - if !crate_id.is_stdlib() { if let Ok(PathResolution { module_def_id, error }) = path_resolver::resolve_path( &context.def_maps, ModuleId { krate: crate_id, local_id: crate_root }, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index b2ec7dbc813..e688f192d3d 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -70,7 +70,7 @@ pub fn collect_defs( // Then add the imports to defCollector to resolve once all modules in the hierarchy have been resolved for import in ast.imports { - collector.def_collector.collected_imports.push(ImportDirective { + collector.def_collector.imports.push(ImportDirective { module_id: collector.module_id, path: import.path, alias: import.alias, @@ -126,7 +126,7 @@ impl<'a> ModCollector<'a> { errors.push((err.into(), self.file_id)); } - self.def_collector.collected_globals.push(UnresolvedGlobal { + self.def_collector.items.globals.push(UnresolvedGlobal { file_id: self.file_id, module_id: self.module_id, global_id, @@ -154,7 +154,7 @@ impl<'a> ModCollector<'a> { } let key = (r#impl.object_type, self.module_id); - let methods = self.def_collector.collected_impls.entry(key).or_default(); + let methods = self.def_collector.items.impls.entry(key).or_default(); methods.push((r#impl.generics, r#impl.type_span, unresolved_functions)); } } @@ -191,7 +191,7 @@ impl<'a> ModCollector<'a> { trait_generics: trait_impl.trait_generics, }; - self.def_collector.collected_traits_impls.push(unresolved_trait_impl); + self.def_collector.items.trait_impls.push(unresolved_trait_impl); } } @@ -269,7 +269,7 @@ impl<'a> ModCollector<'a> { } } - self.def_collector.collected_functions.push(unresolved_functions); + self.def_collector.items.functions.push(unresolved_functions); errors } @@ -316,7 +316,7 @@ impl<'a> ModCollector<'a> { } // And store the TypeId -> StructType mapping somewhere it is reachable - self.def_collector.collected_types.insert(id, unresolved); + self.def_collector.items.types.insert(id, unresolved); } definition_errors } @@ -354,7 +354,7 @@ impl<'a> ModCollector<'a> { errors.push((err.into(), self.file_id)); } - self.def_collector.collected_type_aliases.insert(type_alias_id, unresolved); + self.def_collector.items.type_aliases.insert(type_alias_id, unresolved); } errors } @@ -506,7 +506,7 @@ impl<'a> ModCollector<'a> { method_ids, fns_with_default_impl: unresolved_functions, }; - self.def_collector.collected_traits.insert(trait_id, unresolved); + self.def_collector.items.traits.insert(trait_id, unresolved); } errors } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/mod.rs index 590c2e3d6b6..19e06387d43 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/def_map/mod.rs @@ -73,6 +73,7 @@ impl CrateDefMap { pub fn collect_defs( crate_id: CrateId, context: &mut Context, + use_elaborator: bool, macro_processors: &[&dyn MacroProcessor], ) -> Vec<(CompilationError, FileId)> { // Check if this Crate has already been compiled @@ -116,7 +117,14 @@ impl CrateDefMap { }; // Now we want to populate the CrateDefMap using the DefCollector - errors.extend(DefCollector::collect(def_map, context, ast, root_file_id, macro_processors)); + errors.extend(DefCollector::collect( + def_map, + context, + ast, + root_file_id, + use_elaborator, + macro_processors, + )); errors.extend( parsing_errors.iter().map(|e| (e.clone().into(), root_file_id)).collect::>(), diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/import.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/import.rs index 8850331f683..343113836ed 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/import.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/import.rs @@ -2,11 +2,14 @@ use noirc_errors::{CustomDiagnostic, Span}; use thiserror::Error; use crate::graph::CrateId; +use crate::hir::def_collector::dc_crate::CompilationError; use std::collections::BTreeMap; use crate::ast::{Ident, ItemVisibility, Path, PathKind}; use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleDefId, ModuleId, PerNs}; +use super::errors::ResolverError; + #[derive(Debug, Clone)] pub struct ImportDirective { pub module_id: LocalModuleId, @@ -53,6 +56,12 @@ pub struct ResolvedImport { pub error: Option, } +impl From for CompilationError { + fn from(error: PathResolutionError) -> Self { + Self::ResolverError(ResolverError::PathResolutionError(error)) + } +} + impl<'a> From<&'a PathResolutionError> for CustomDiagnostic { fn from(error: &'a PathResolutionError) -> Self { match &error { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 60baaecab59..7dc307fe716 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -56,17 +56,17 @@ use crate::hir_def::{ use super::errors::{PubPosition, ResolverError}; use super::import::PathResolution; -const SELF_TYPE_NAME: &str = "Self"; +pub const SELF_TYPE_NAME: &str = "Self"; type Scope = GenericScope; type ScopeTree = GenericScopeTree; type ScopeForest = GenericScopeForest; pub struct LambdaContext { - captures: Vec, + pub captures: Vec, /// the index in the scope tree /// (sometimes being filled by ScopeTree's find method) - scope_index: usize, + pub scope_index: usize, } /// The primary jobs of the Resolver are to validate that every variable found refers to exactly 1 @@ -1345,7 +1345,7 @@ impl<'a> Resolver<'a> { range @ ForRange::Array(_) => { let for_stmt = range.into_for(for_loop.identifier, for_loop.block, for_loop.span); - self.resolve_stmt(for_stmt, for_loop.span) + self.resolve_stmt(for_stmt.kind, for_loop.span) } } } @@ -1361,7 +1361,7 @@ impl<'a> Resolver<'a> { StatementKind::Comptime(statement) => { let hir_statement = self.resolve_stmt(statement.kind, statement.span); let statement_id = self.interner.push_stmt(hir_statement); - self.interner.push_statement_location(statement_id, statement.span, self.file); + self.interner.push_stmt_location(statement_id, statement.span, self.file); HirStatement::Comptime(statement_id) } } @@ -1370,7 +1370,7 @@ impl<'a> Resolver<'a> { pub fn intern_stmt(&mut self, stmt: Statement) -> StmtId { let hir_stmt = self.resolve_stmt(stmt.kind, stmt.span); let id = self.interner.push_stmt(hir_stmt); - self.interner.push_statement_location(id, stmt.span, self.file); + self.interner.push_stmt_location(id, stmt.span, self.file); id } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/expr.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/expr.rs index 9b40c959981..48598109829 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -250,14 +250,14 @@ impl<'interner> TypeChecker<'interner> { } // TODO: update object_type here? - let function_call = method_call.into_function_call( + let (_, function_call) = method_call.into_function_call( &method_ref, object_type, location, self.interner, ); - self.interner.replace_expr(expr_id, function_call); + self.interner.replace_expr(expr_id, HirExpression::Call(function_call)); // Type check the new call now that it has been changed from a method call // to a function call. This way we avoid duplicating code. diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/mod.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/mod.rs index 0f8131d6ebb..2e448858d9e 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -25,7 +25,7 @@ use crate::{ Type, TypeBindings, }; -use self::errors::Source; +pub use self::errors::Source; pub struct TypeChecker<'interner> { interner: &'interner mut NodeInterner, diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs index bf7d9b7b4ba..8df6785e0eb 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/expr.rs @@ -200,13 +200,15 @@ pub enum HirMethodReference { impl HirMethodCallExpression { /// Converts a method call into a function call + /// + /// Returns ((func_var_id, func_var), call_expr) pub fn into_function_call( mut self, method: &HirMethodReference, object_type: Type, location: Location, interner: &mut NodeInterner, - ) -> HirExpression { + ) -> ((ExprId, HirIdent), HirCallExpression) { let mut arguments = vec![self.object]; arguments.append(&mut self.arguments); @@ -224,10 +226,11 @@ impl HirMethodCallExpression { (id, ImplKind::TraitMethod(*method_id, constraint, false)) } }; - let func = HirExpression::Ident(HirIdent { location, id, impl_kind }); - let func = interner.push_expr(func); + let func_var = HirIdent { location, id, impl_kind }; + let func = interner.push_expr(HirExpression::Ident(func_var.clone())); interner.push_expr_location(func, location.span, location.file); - HirExpression::Call(HirCallExpression { func, arguments, location }) + let expr = HirCallExpression { func, arguments, location }; + ((func, func_var), expr) } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/function.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/function.rs index c38dd41fd3d..ceec9ad8580 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/function.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/function.rs @@ -135,10 +135,7 @@ impl FuncMeta { /// So this method tells the type checker to ignore the return /// of the empty function, which is unit pub fn can_ignore_return_type(&self) -> bool { - match self.kind { - FunctionKind::LowLevel | FunctionKind::Builtin | FunctionKind::Oracle => true, - FunctionKind::Normal | FunctionKind::Recursive => false, - } + self.kind.can_ignore_return_type() } pub fn function_signature(&self) -> FunctionSignature { diff --git a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs index f3b2a24c1f0..f31aeea0552 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/hir_def/types.rs @@ -1423,14 +1423,14 @@ impl Type { /// Retrieves the type of the given field name /// Panics if the type is not a struct or tuple. - pub fn get_field_type(&self, field_name: &str) -> Type { + pub fn get_field_type(&self, field_name: &str) -> Option { match self { - Type::Struct(def, args) => def.borrow().get_field(field_name, args).unwrap().0, + Type::Struct(def, args) => def.borrow().get_field(field_name, args).map(|(typ, _)| typ), Type::Tuple(fields) => { let mut fields = fields.iter().enumerate(); - fields.find(|(i, _)| i.to_string() == *field_name).unwrap().1.clone() + fields.find(|(i, _)| i.to_string() == *field_name).map(|(_, typ)| typ).cloned() } - other => panic!("Tried to iterate over the fields of '{other}', which has none"), + _ => None, } } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/lib.rs b/noir/noir-repo/compiler/noirc_frontend/src/lib.rs index 958a18ac2fb..b05c635f436 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/lib.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/lib.rs @@ -12,6 +12,7 @@ pub mod ast; pub mod debug; +pub mod elaborator; pub mod graph; pub mod lexer; pub mod monomorphization; diff --git a/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs b/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs index 88adc7a9414..faf89016f96 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/node_interner.rs @@ -532,7 +532,7 @@ impl NodeInterner { self.id_to_type.insert(expr_id.into(), typ); } - /// Store the type for an interned expression + /// Store the type for a definition pub fn push_definition_type(&mut self, definition_id: DefinitionId, typ: Type) { self.definition_to_type.insert(definition_id, typ); } @@ -696,7 +696,7 @@ impl NodeInterner { let statement = self.push_stmt(HirStatement::Error); let span = name.span(); let id = self.push_global(name, local_id, statement, file, attributes, mutable); - self.push_statement_location(statement, span, file); + self.push_stmt_location(statement, span, file); id } @@ -942,7 +942,7 @@ impl NodeInterner { self.id_location(stmt_id) } - pub fn push_statement_location(&mut self, id: StmtId, span: Span, file: FileId) { + pub fn push_stmt_location(&mut self, id: StmtId, span: Span, file: FileId) { self.id_to_location.insert(id.into(), Location::new(span, file)); } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser.rs b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser.rs index b627714d2a6..b527284d1a9 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/parser/parser.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/parser/parser.rs @@ -1374,7 +1374,7 @@ mod test { fresh_statement(), true, ), - vec!["x as u8", "0 as Field", "(x + 3) as [Field; 8]"], + vec!["x as u8", "x as u16", "0 as Field", "(x + 3) as [Field; 8]"], ); parse_all_failing( atom_or_right_unary( @@ -1546,7 +1546,10 @@ mod test { // Let statements are not type checked here, so the parser will accept as // long as it is a type. Other statements such as Public are type checked // Because for now, they can only have one type - parse_all(declaration(expression()), vec!["let _ = 42", "let x = y", "let x : u8 = y"]); + parse_all( + declaration(expression()), + vec!["let _ = 42", "let x = y", "let x : u8 = y", "let x: u16 = y"], + ); } #[test] diff --git a/noir/noir-repo/compiler/noirc_frontend/src/tests.rs b/noir/noir-repo/compiler/noirc_frontend/src/tests.rs index 5f99e9e347a..fb80a7d8018 100644 --- a/noir/noir-repo/compiler/noirc_frontend/src/tests.rs +++ b/noir/noir-repo/compiler/noirc_frontend/src/tests.rs @@ -1,1236 +1,1215 @@ +#![cfg(test)] + +#[cfg(test)] +mod name_shadowing; + // XXX: These tests repeat a lot of code // what we should do is have test cases which are passed to a test harness // A test harness will allow for more expressive and readable tests -#[cfg(test)] -mod test { - - use core::panic; - use std::collections::BTreeMap; - - use fm::FileId; - - use iter_extended::vecmap; - use noirc_errors::Location; - - use crate::hir::def_collector::dc_crate::CompilationError; - use crate::hir::def_collector::errors::{DefCollectorErrorKind, DuplicateType}; - use crate::hir::def_map::ModuleData; - use crate::hir::resolution::errors::ResolverError; - use crate::hir::resolution::import::PathResolutionError; - use crate::hir::type_check::TypeCheckError; - use crate::hir::Context; - use crate::node_interner::{NodeInterner, StmtId}; - - use crate::hir::def_collector::dc_crate::DefCollector; - use crate::hir_def::expr::HirExpression; - use crate::hir_def::stmt::HirStatement; - use crate::monomorphization::monomorphize; - use crate::parser::ParserErrorReason; - use crate::ParsedModule; - use crate::{ - hir::def_map::{CrateDefMap, LocalModuleId}, - parse_program, - }; - use fm::FileManager; - use noirc_arena::Arena; +use core::panic; +use std::collections::BTreeMap; + +use fm::FileId; + +use iter_extended::vecmap; +use noirc_errors::Location; + +use crate::hir::def_collector::dc_crate::CompilationError; +use crate::hir::def_collector::errors::{DefCollectorErrorKind, DuplicateType}; +use crate::hir::def_map::ModuleData; +use crate::hir::resolution::errors::ResolverError; +use crate::hir::resolution::import::PathResolutionError; +use crate::hir::type_check::TypeCheckError; +use crate::hir::Context; +use crate::node_interner::{NodeInterner, StmtId}; + +use crate::hir::def_collector::dc_crate::DefCollector; +use crate::hir_def::expr::HirExpression; +use crate::hir_def::stmt::HirStatement; +use crate::monomorphization::monomorphize; +use crate::parser::ParserErrorReason; +use crate::ParsedModule; +use crate::{ + hir::def_map::{CrateDefMap, LocalModuleId}, + parse_program, +}; +use fm::FileManager; +use noirc_arena::Arena; + +pub(crate) fn has_parser_error(errors: &[(CompilationError, FileId)]) -> bool { + errors.iter().any(|(e, _f)| matches!(e, CompilationError::ParseError(_))) +} - pub(crate) fn has_parser_error(errors: &[(CompilationError, FileId)]) -> bool { - errors.iter().any(|(e, _f)| matches!(e, CompilationError::ParseError(_))) - } +pub(crate) fn remove_experimental_warnings(errors: &mut Vec<(CompilationError, FileId)>) { + errors.retain(|(error, _)| match error { + CompilationError::ParseError(error) => { + !matches!(error.reason(), Some(ParserErrorReason::ExperimentalFeature(..))) + } + _ => true, + }); +} - pub(crate) fn remove_experimental_warnings(errors: &mut Vec<(CompilationError, FileId)>) { - errors.retain(|(error, _)| match error { - CompilationError::ParseError(error) => { - !matches!(error.reason(), Some(ParserErrorReason::ExperimentalFeature(..))) - } - _ => true, - }); - } - - pub(crate) fn get_program( - src: &str, - ) -> (ParsedModule, Context, Vec<(CompilationError, FileId)>) { - let root = std::path::Path::new("/"); - let fm = FileManager::new(root); - - let mut context = Context::new(fm, Default::default()); - context.def_interner.populate_dummy_operator_traits(); - let root_file_id = FileId::dummy(); - let root_crate_id = context.crate_graph.add_crate_root(root_file_id); - - let (program, parser_errors) = parse_program(src); - let mut errors = vecmap(parser_errors, |e| (e.into(), root_file_id)); - remove_experimental_warnings(&mut errors); - - if !has_parser_error(&errors) { - // Allocate a default Module for the root, giving it a ModuleId - let mut modules: Arena = Arena::default(); - let location = Location::new(Default::default(), root_file_id); - let root = modules.insert(ModuleData::new(None, location, false)); - - let def_map = CrateDefMap { - root: LocalModuleId(root), - modules, - krate: root_crate_id, - extern_prelude: BTreeMap::new(), - }; +pub(crate) fn get_program(src: &str) -> (ParsedModule, Context, Vec<(CompilationError, FileId)>) { + let root = std::path::Path::new("/"); + let fm = FileManager::new(root); + + let mut context = Context::new(fm, Default::default()); + context.def_interner.populate_dummy_operator_traits(); + let root_file_id = FileId::dummy(); + let root_crate_id = context.crate_graph.add_crate_root(root_file_id); + + let (program, parser_errors) = parse_program(src); + let mut errors = vecmap(parser_errors, |e| (e.into(), root_file_id)); + remove_experimental_warnings(&mut errors); + + if !has_parser_error(&errors) { + // Allocate a default Module for the root, giving it a ModuleId + let mut modules: Arena = Arena::default(); + let location = Location::new(Default::default(), root_file_id); + let root = modules.insert(ModuleData::new(None, location, false)); + + let def_map = CrateDefMap { + root: LocalModuleId(root), + modules, + krate: root_crate_id, + extern_prelude: BTreeMap::new(), + }; - // Now we want to populate the CrateDefMap using the DefCollector - errors.extend(DefCollector::collect( - def_map, - &mut context, - program.clone().into_sorted(), - root_file_id, - &[], // No macro processors - )); - } - (program, context, errors) + // Now we want to populate the CrateDefMap using the DefCollector + errors.extend(DefCollector::collect( + def_map, + &mut context, + program.clone().into_sorted(), + root_file_id, + false, + &[], // No macro processors + )); } + (program, context, errors) +} - pub(crate) fn get_program_errors(src: &str) -> Vec<(CompilationError, FileId)> { - get_program(src).2 - } +pub(crate) fn get_program_errors(src: &str) -> Vec<(CompilationError, FileId)> { + get_program(src).2 +} - #[test] - fn check_trait_implemented_for_all_t() { - let src = " - trait Default { - fn default() -> Self; - } - - trait Eq { - fn eq(self, other: Self) -> bool; +#[test] +fn check_trait_implemented_for_all_t() { + let src = " + trait Default { + fn default() -> Self; + } + + trait Eq { + fn eq(self, other: Self) -> bool; + } + + trait IsDefault { + fn is_default(self) -> bool; + } + + impl IsDefault for T where T: Default + Eq { + fn is_default(self) -> bool { + self.eq(T::default()) } - - trait IsDefault { - fn is_default(self) -> bool; + } + + struct Foo { + a: u64, + } + + impl Eq for Foo { + fn eq(self, other: Foo) -> bool { self.a == other.a } + } + + impl Default for u64 { + fn default() -> Self { + 0 } - - impl IsDefault for T where T: Default + Eq { - fn is_default(self) -> bool { - self.eq(T::default()) - } + } + + impl Default for Foo { + fn default() -> Self { + Foo { a: Default::default() } } - - struct Foo { - a: u64, + } + + fn main(a: Foo) -> pub bool { + a.is_default() + }"; + + let errors = get_program_errors(src); + errors.iter().for_each(|err| println!("{:?}", err)); + assert!(errors.is_empty()); +} + +#[test] +fn check_trait_implementation_duplicate_method() { + let src = " + trait Default { + fn default(x: Field, y: Field) -> Field; + } + + struct Foo { + bar: Field, + array: [Field; 2], + } + + impl Default for Foo { + // Duplicate trait methods should not compile + fn default(x: Field, y: Field) -> Field { + y + 2 * x } - - impl Eq for Foo { - fn eq(self, other: Foo) -> bool { self.a == other.a } + // Duplicate trait methods should not compile + fn default(x: Field, y: Field) -> Field { + x + 2 * y } - - impl Default for u64 { - fn default() -> Self { - 0 + } + + fn main() {}"; + + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + + for (err, _file_id) in errors { + match &err { + CompilationError::DefinitionError(DefCollectorErrorKind::Duplicate { + typ, + first_def, + second_def, + }) => { + assert_eq!(typ, &DuplicateType::TraitAssociatedFunction); + assert_eq!(first_def, "default"); + assert_eq!(second_def, "default"); } - } - - impl Default for Foo { - fn default() -> Self { - Foo { a: Default::default() } + _ => { + panic!("No other errors are expected! Found = {:?}", err); } - } - - fn main(a: Foo) -> pub bool { - a.is_default() - }"; - - let errors = get_program_errors(src); - errors.iter().for_each(|err| println!("{:?}", err)); - assert!(errors.is_empty()); + }; } +} - #[test] - fn check_trait_implementation_duplicate_method() { - let src = " - trait Default { - fn default(x: Field, y: Field) -> Field; - } - - struct Foo { - bar: Field, - array: [Field; 2], +#[test] +fn check_trait_wrong_method_return_type() { + let src = " + trait Default { + fn default() -> Self; + } + + struct Foo { + } + + impl Default for Foo { + fn default() -> Field { + 0 } - - impl Default for Foo { - // Duplicate trait methods should not compile - fn default(x: Field, y: Field) -> Field { - y + 2 * x + } + + fn main() { + } + "; + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + + for (err, _file_id) in errors { + match &err { + CompilationError::TypeError(TypeCheckError::TypeMismatch { + expected_typ, + expr_typ, + expr_span: _, + }) => { + assert_eq!(expected_typ, "Foo"); + assert_eq!(expr_typ, "Field"); } - // Duplicate trait methods should not compile - fn default(x: Field, y: Field) -> Field { - x + 2 * y + _ => { + panic!("No other errors are expected! Found = {:?}", err); } - } - - fn main() {}"; - - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - - for (err, _file_id) in errors { - match &err { - CompilationError::DefinitionError(DefCollectorErrorKind::Duplicate { - typ, - first_def, - second_def, - }) => { - assert_eq!(typ, &DuplicateType::TraitAssociatedFunction); - assert_eq!(first_def, "default"); - assert_eq!(second_def, "default"); - } - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; - } + }; } +} - #[test] - fn check_trait_wrong_method_return_type() { - let src = " - trait Default { - fn default() -> Self; - } - - struct Foo { +#[test] +fn check_trait_wrong_method_return_type2() { + let src = " + trait Default { + fn default(x: Field, y: Field) -> Self; + } + + struct Foo { + bar: Field, + array: [Field; 2], + } + + impl Default for Foo { + fn default(x: Field, _y: Field) -> Field { + x } - - impl Default for Foo { - fn default() -> Field { - 0 + } + + fn main() { + }"; + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + + for (err, _file_id) in errors { + match &err { + CompilationError::TypeError(TypeCheckError::TypeMismatch { + expected_typ, + expr_typ, + expr_span: _, + }) => { + assert_eq!(expected_typ, "Foo"); + assert_eq!(expr_typ, "Field"); } - } - - fn main() { - } - "; - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - - for (err, _file_id) in errors { - match &err { - CompilationError::TypeError(TypeCheckError::TypeMismatch { - expected_typ, - expr_typ, - expr_span: _, - }) => { - assert_eq!(expected_typ, "Foo"); - assert_eq!(expr_typ, "Field"); - } - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; - } + _ => { + panic!("No other errors are expected! Found = {:?}", err); + } + }; } +} - #[test] - fn check_trait_wrong_method_return_type2() { - let src = " - trait Default { - fn default(x: Field, y: Field) -> Self; - } - - struct Foo { - bar: Field, - array: [Field; 2], +#[test] +fn check_trait_missing_implementation() { + let src = " + trait Default { + fn default(x: Field, y: Field) -> Self; + + fn method2(x: Field) -> Field; + + } + + struct Foo { + bar: Field, + array: [Field; 2], + } + + impl Default for Foo { + fn default(x: Field, y: Field) -> Self { + Self { bar: x, array: [x,y] } } - - impl Default for Foo { - fn default(x: Field, _y: Field) -> Field { - x + } + + fn main() { + } + "; + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + + for (err, _file_id) in errors { + match &err { + CompilationError::DefinitionError(DefCollectorErrorKind::TraitMissingMethod { + trait_name, + method_name, + trait_impl_span: _, + }) => { + assert_eq!(trait_name, "Default"); + assert_eq!(method_name, "method2"); } - } - - fn main() { - }"; - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - - for (err, _file_id) in errors { - match &err { - CompilationError::TypeError(TypeCheckError::TypeMismatch { - expected_typ, - expr_typ, - expr_span: _, - }) => { - assert_eq!(expected_typ, "Foo"); - assert_eq!(expr_typ, "Field"); - } - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; - } + _ => { + panic!("No other errors are expected! Found = {:?}", err); + } + }; } +} - #[test] - fn check_trait_missing_implementation() { - let src = " - trait Default { - fn default(x: Field, y: Field) -> Self; - - fn method2(x: Field) -> Field; - - } - - struct Foo { - bar: Field, - array: [Field; 2], +#[test] +fn check_trait_not_in_scope() { + let src = " + struct Foo { + bar: Field, + array: [Field; 2], + } + + // Default trait does not exist + impl Default for Foo { + fn default(x: Field, y: Field) -> Self { + Self { bar: x, array: [x,y] } } - - impl Default for Foo { - fn default(x: Field, y: Field) -> Self { - Self { bar: x, array: [x,y] } + } + + fn main() { + } + + "; + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + for (err, _file_id) in errors { + match &err { + CompilationError::DefinitionError(DefCollectorErrorKind::TraitNotFound { + trait_path, + }) => { + assert_eq!(trait_path.as_string(), "Default"); } - } - - fn main() { - } - "; - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - - for (err, _file_id) in errors { - match &err { - CompilationError::DefinitionError(DefCollectorErrorKind::TraitMissingMethod { - trait_name, - method_name, - trait_impl_span: _, - }) => { - assert_eq!(trait_name, "Default"); - assert_eq!(method_name, "method2"); - } - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; - } + _ => { + panic!("No other errors are expected! Found = {:?}", err); + } + }; } +} - #[test] - fn check_trait_not_in_scope() { - let src = " - struct Foo { - bar: Field, - array: [Field; 2], +#[test] +fn check_trait_wrong_method_name() { + let src = " + trait Default { + } + + struct Foo { + bar: Field, + array: [Field; 2], + } + + // wrong trait name method should not compile + impl Default for Foo { + fn does_not_exist(x: Field, y: Field) -> Self { + Self { bar: x, array: [x,y] } } - - // Default trait does not exist - impl Default for Foo { - fn default(x: Field, y: Field) -> Self { - Self { bar: x, array: [x,y] } + } + + fn main() { + }"; + let compilation_errors = get_program_errors(src); + assert!(!has_parser_error(&compilation_errors)); + assert!( + compilation_errors.len() == 1, + "Expected 1 compilation error, got: {:?}", + compilation_errors + ); + + for (err, _file_id) in compilation_errors { + match &err { + CompilationError::DefinitionError(DefCollectorErrorKind::MethodNotInTrait { + trait_name, + impl_method, + }) => { + assert_eq!(trait_name, "Default"); + assert_eq!(impl_method, "does_not_exist"); } - } - - fn main() { - } - - "; - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - for (err, _file_id) in errors { - match &err { - CompilationError::DefinitionError(DefCollectorErrorKind::TraitNotFound { - trait_path, - }) => { - assert_eq!(trait_path.as_string(), "Default"); - } - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; - } + _ => { + panic!("No other errors are expected! Found = {:?}", err); + } + }; } +} - #[test] - fn check_trait_wrong_method_name() { - let src = " - trait Default { - } - - struct Foo { - bar: Field, - array: [Field; 2], +#[test] +fn check_trait_wrong_parameter() { + let src = " + trait Default { + fn default(x: Field) -> Self; + } + + struct Foo { + bar: u32, + } + + impl Default for Foo { + fn default(x: u32) -> Self { + Foo {bar: x} } - - // wrong trait name method should not compile - impl Default for Foo { - fn does_not_exist(x: Field, y: Field) -> Self { - Self { bar: x, array: [x,y] } + } + + fn main() { + } + "; + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + + for (err, _file_id) in errors { + match &err { + CompilationError::TypeError(TypeCheckError::TraitMethodParameterTypeMismatch { + method_name, + expected_typ, + actual_typ, + .. + }) => { + assert_eq!(method_name, "default"); + assert_eq!(expected_typ, "Field"); + assert_eq!(actual_typ, "u32"); } - } - - fn main() { - }"; - let compilation_errors = get_program_errors(src); - assert!(!has_parser_error(&compilation_errors)); - assert!( - compilation_errors.len() == 1, - "Expected 1 compilation error, got: {:?}", - compilation_errors - ); - - for (err, _file_id) in compilation_errors { - match &err { - CompilationError::DefinitionError(DefCollectorErrorKind::MethodNotInTrait { - trait_name, - impl_method, - }) => { - assert_eq!(trait_name, "Default"); - assert_eq!(impl_method, "does_not_exist"); - } - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; - } + _ => { + panic!("No other errors are expected! Found = {:?}", err); + } + }; } +} - #[test] - fn check_trait_wrong_parameter() { - let src = " - trait Default { - fn default(x: Field) -> Self; - } - - struct Foo { - bar: u32, +#[test] +fn check_trait_wrong_parameter2() { + let src = " + trait Default { + fn default(x: Field, y: Field) -> Self; + } + + struct Foo { + bar: Field, + array: [Field; 2], + } + + impl Default for Foo { + fn default(x: Field, y: Foo) -> Self { + Self { bar: x, array: [x, y.bar] } } - - impl Default for Foo { - fn default(x: u32) -> Self { - Foo {bar: x} + } + + fn main() { + }"; + + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + + for (err, _file_id) in errors { + match &err { + CompilationError::TypeError(TypeCheckError::TraitMethodParameterTypeMismatch { + method_name, + expected_typ, + actual_typ, + .. + }) => { + assert_eq!(method_name, "default"); + assert_eq!(expected_typ, "Field"); + assert_eq!(actual_typ, "Foo"); } - } - - fn main() { - } - "; - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - - for (err, _file_id) in errors { - match &err { - CompilationError::TypeError(TypeCheckError::TraitMethodParameterTypeMismatch { - method_name, - expected_typ, - actual_typ, - .. - }) => { - assert_eq!(method_name, "default"); - assert_eq!(expected_typ, "Field"); - assert_eq!(actual_typ, "u32"); - } - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; - } + _ => { + panic!("No other errors are expected! Found = {:?}", err); + } + }; } +} - #[test] - fn check_trait_wrong_parameter2() { - let src = " - trait Default { - fn default(x: Field, y: Field) -> Self; - } - - struct Foo { - bar: Field, - array: [Field; 2], - } - - impl Default for Foo { - fn default(x: Field, y: Foo) -> Self { - Self { bar: x, array: [x, y.bar] } +#[test] +fn check_trait_wrong_parameter_type() { + let src = " + trait Default { + fn default(x: Field, y: NotAType) -> Field; + } + + fn main(x: Field, y: Field) { + assert(y == x); + }"; + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 2, "Expected 2 errors, got: {:?}", errors); + + for (err, _file_id) in errors { + match &err { + CompilationError::ResolverError(ResolverError::PathResolutionError( + PathResolutionError::Unresolved(ident), + )) => { + assert_eq!(ident, "NotAType"); } - } - - fn main() { - }"; - - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - - for (err, _file_id) in errors { - match &err { - CompilationError::TypeError(TypeCheckError::TraitMethodParameterTypeMismatch { - method_name, - expected_typ, - actual_typ, - .. - }) => { - assert_eq!(method_name, "default"); - assert_eq!(expected_typ, "Field"); - assert_eq!(actual_typ, "Foo"); - } - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; - } + _ => { + panic!("No other errors are expected! Found = {:?}", err); + } + }; } +} - #[test] - fn check_trait_wrong_parameter_type() { - let src = " - trait Default { - fn default(x: Field, y: NotAType) -> Field; - } - - fn main(x: Field, y: Field) { - assert(y == x); - }"; - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 2, "Expected 2 errors, got: {:?}", errors); - - for (err, _file_id) in errors { - match &err { - CompilationError::ResolverError(ResolverError::PathResolutionError( - PathResolutionError::Unresolved(ident), - )) => { - assert_eq!(ident, "NotAType"); - } - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; - } +#[test] +fn check_trait_wrong_parameters_count() { + let src = " + trait Default { + fn default(x: Field, y: Field) -> Self; } - - #[test] - fn check_trait_wrong_parameters_count() { - let src = " - trait Default { - fn default(x: Field, y: Field) -> Self; - } - - struct Foo { - bar: Field, - array: [Field; 2], + + struct Foo { + bar: Field, + array: [Field; 2], + } + + impl Default for Foo { + fn default(x: Field) -> Self { + Self { bar: x, array: [x, x] } } - - impl Default for Foo { - fn default(x: Field) -> Self { - Self { bar: x, array: [x, x] } + } + + fn main() { + } + "; + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + for (err, _file_id) in errors { + match &err { + CompilationError::TypeError(TypeCheckError::MismatchTraitImplNumParameters { + actual_num_parameters, + expected_num_parameters, + trait_name, + method_name, + .. + }) => { + assert_eq!(actual_num_parameters, &1_usize); + assert_eq!(expected_num_parameters, &2_usize); + assert_eq!(method_name, "default"); + assert_eq!(trait_name, "Default"); } - } - - fn main() { - } - "; - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - for (err, _file_id) in errors { - match &err { - CompilationError::TypeError(TypeCheckError::MismatchTraitImplNumParameters { - actual_num_parameters, - expected_num_parameters, - trait_name, - method_name, - .. - }) => { - assert_eq!(actual_num_parameters, &1_usize); - assert_eq!(expected_num_parameters, &2_usize); - assert_eq!(method_name, "default"); - assert_eq!(trait_name, "Default"); - } - _ => { - panic!("No other errors are expected in this test case! Found = {:?}", err); - } - }; - } + _ => { + panic!("No other errors are expected in this test case! Found = {:?}", err); + } + }; } +} - #[test] - fn check_trait_impl_for_non_type() { - let src = " - trait Default { - fn default(x: Field, y: Field) -> Field; - } - - impl Default for main { - fn default(x: Field, y: Field) -> Field { - x + y - } - } +#[test] +fn check_trait_impl_for_non_type() { + let src = " + trait Default { + fn default(x: Field, y: Field) -> Field; + } - fn main() {} - "; - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - for (err, _file_id) in errors { - match &err { - CompilationError::ResolverError(ResolverError::Expected { - expected, got, .. - }) => { - assert_eq!(expected, "type"); - assert_eq!(got, "function"); - } - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; + impl Default for main { + fn default(x: Field, y: Field) -> Field { + x + y } } - #[test] - fn check_impl_struct_not_trait() { - let src = " - struct Foo { - bar: Field, - array: [Field; 2], - } - - struct Default { - x: Field, - z: Field, - } - - // Default is struct not a trait - impl Default for Foo { - fn default(x: Field, y: Field) -> Self { - Self { bar: x, array: [x,y] } + fn main() {} + "; + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + for (err, _file_id) in errors { + match &err { + CompilationError::ResolverError(ResolverError::Expected { expected, got, .. }) => { + assert_eq!(expected, "type"); + assert_eq!(got, "function"); } - } - - fn main() { - } - - "; - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - for (err, _file_id) in errors { - match &err { - CompilationError::DefinitionError(DefCollectorErrorKind::NotATrait { - not_a_trait_name, - }) => { - assert_eq!(not_a_trait_name.to_string(), "plain::Default"); - } - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; - } + _ => { + panic!("No other errors are expected! Found = {:?}", err); + } + }; } +} - #[test] - fn check_trait_duplicate_declaration() { - let src = " - trait Default { - fn default(x: Field, y: Field) -> Self; - } - - struct Foo { - bar: Field, - array: [Field; 2], +#[test] +fn check_impl_struct_not_trait() { + let src = " + struct Foo { + bar: Field, + array: [Field; 2], + } + + struct Default { + x: Field, + z: Field, + } + + // Default is struct not a trait + impl Default for Foo { + fn default(x: Field, y: Field) -> Self { + Self { bar: x, array: [x,y] } } - - impl Default for Foo { - fn default(x: Field,y: Field) -> Self { - Self { bar: x, array: [x,y] } + } + + fn main() { + } + + "; + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + for (err, _file_id) in errors { + match &err { + CompilationError::DefinitionError(DefCollectorErrorKind::NotATrait { + not_a_trait_name, + }) => { + assert_eq!(not_a_trait_name.to_string(), "plain::Default"); } - } - - - trait Default { - fn default(x: Field) -> Self; - } - - fn main() { - }"; - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - for (err, _file_id) in errors { - match &err { - CompilationError::DefinitionError(DefCollectorErrorKind::Duplicate { - typ, - first_def, - second_def, - }) => { - assert_eq!(typ, &DuplicateType::Trait); - assert_eq!(first_def, "Default"); - assert_eq!(second_def, "Default"); - } - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; - } + _ => { + panic!("No other errors are expected! Found = {:?}", err); + } + }; } +} - #[test] - fn check_trait_duplicate_implementation() { - let src = " - trait Default { - } - struct Foo { - bar: Field, - } - - impl Default for Foo { - } - impl Default for Foo { - } - fn main() { - } - "; - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 2, "Expected 2 errors, got: {:?}", errors); - for (err, _file_id) in errors { - match &err { - CompilationError::DefinitionError(DefCollectorErrorKind::OverlappingImpl { - .. - }) => (), - CompilationError::DefinitionError(DefCollectorErrorKind::OverlappingImplNote { - .. - }) => (), - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; - } +#[test] +fn check_trait_duplicate_declaration() { + let src = " + trait Default { + fn default(x: Field, y: Field) -> Self; } - - #[test] - fn check_trait_duplicate_implementation_with_alias() { - let src = " - trait Default { - } - - struct MyStruct { - } - - type MyType = MyStruct; - - impl Default for MyStruct { - } - - impl Default for MyType { - } - - fn main() { - } - "; - let errors = get_program_errors(src); - assert!(!has_parser_error(&errors)); - assert!(errors.len() == 2, "Expected 2 errors, got: {:?}", errors); - for (err, _file_id) in errors { - match &err { - CompilationError::DefinitionError(DefCollectorErrorKind::OverlappingImpl { - .. - }) => (), - CompilationError::DefinitionError(DefCollectorErrorKind::OverlappingImplNote { - .. - }) => (), - _ => { - panic!("No other errors are expected! Found = {:?}", err); - } - }; + + struct Foo { + bar: Field, + array: [Field; 2], + } + + impl Default for Foo { + fn default(x: Field,y: Field) -> Self { + Self { bar: x, array: [x,y] } } } + + + trait Default { + fn default(x: Field) -> Self; + } + + fn main() { + }"; + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + for (err, _file_id) in errors { + match &err { + CompilationError::DefinitionError(DefCollectorErrorKind::Duplicate { + typ, + first_def, + second_def, + }) => { + assert_eq!(typ, &DuplicateType::Trait); + assert_eq!(first_def, "Default"); + assert_eq!(second_def, "Default"); + } + _ => { + panic!("No other errors are expected! Found = {:?}", err); + } + }; + } +} - #[test] - fn test_impl_self_within_default_def() { - let src = " - trait Bar { - fn ok(self) -> Self; - - fn ref_ok(self) -> Self { - self.ok() +#[test] +fn check_trait_duplicate_implementation() { + let src = " + trait Default { + } + struct Foo { + bar: Field, + } + + impl Default for Foo { + } + impl Default for Foo { + } + fn main() { + } + "; + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 2, "Expected 2 errors, got: {:?}", errors); + for (err, _file_id) in errors { + match &err { + CompilationError::DefinitionError(DefCollectorErrorKind::OverlappingImpl { + .. + }) => (), + CompilationError::DefinitionError(DefCollectorErrorKind::OverlappingImplNote { + .. + }) => (), + _ => { + panic!("No other errors are expected! Found = {:?}", err); } - } + }; + } +} - impl Bar for (T, T) where T: Bar { - fn ok(self) -> Self { - self +#[test] +fn check_trait_duplicate_implementation_with_alias() { + let src = " + trait Default { + } + + struct MyStruct { + } + + type MyType = MyStruct; + + impl Default for MyStruct { + } + + impl Default for MyType { + } + + fn main() { + } + "; + let errors = get_program_errors(src); + assert!(!has_parser_error(&errors)); + assert!(errors.len() == 2, "Expected 2 errors, got: {:?}", errors); + for (err, _file_id) in errors { + match &err { + CompilationError::DefinitionError(DefCollectorErrorKind::OverlappingImpl { + .. + }) => (), + CompilationError::DefinitionError(DefCollectorErrorKind::OverlappingImplNote { + .. + }) => (), + _ => { + panic!("No other errors are expected! Found = {:?}", err); } - }"; - let errors = get_program_errors(src); - errors.iter().for_each(|err| println!("{:?}", err)); - assert!(errors.is_empty()); + }; } +} - #[test] - fn check_trait_as_type_as_fn_parameter() { - let src = " - trait Eq { - fn eq(self, other: Self) -> bool; - } +#[test] +fn test_impl_self_within_default_def() { + let src = " + trait Bar { + fn ok(self) -> Self; - struct Foo { - a: u64, + fn ref_ok(self) -> Self { + self.ok() } + } - impl Eq for Foo { - fn eq(self, other: Foo) -> bool { self.a == other.a } + impl Bar for (T, T) where T: Bar { + fn ok(self) -> Self { + self } + }"; + let errors = get_program_errors(src); + errors.iter().for_each(|err| println!("{:?}", err)); + assert!(errors.is_empty()); +} - fn test_eq(x: impl Eq) -> bool { - x.eq(x) - } +#[test] +fn check_trait_as_type_as_fn_parameter() { + let src = " + trait Eq { + fn eq(self, other: Self) -> bool; + } - fn main(a: Foo) -> pub bool { - test_eq(a) - }"; + struct Foo { + a: u64, + } - let errors = get_program_errors(src); - errors.iter().for_each(|err| println!("{:?}", err)); - assert!(errors.is_empty()); + impl Eq for Foo { + fn eq(self, other: Foo) -> bool { self.a == other.a } } - #[test] - fn check_trait_as_type_as_two_fn_parameters() { - let src = " - trait Eq { - fn eq(self, other: Self) -> bool; - } + fn test_eq(x: impl Eq) -> bool { + x.eq(x) + } - trait Test { - fn test(self) -> bool; - } + fn main(a: Foo) -> pub bool { + test_eq(a) + }"; - struct Foo { - a: u64, - } + let errors = get_program_errors(src); + errors.iter().for_each(|err| println!("{:?}", err)); + assert!(errors.is_empty()); +} - impl Eq for Foo { - fn eq(self, other: Foo) -> bool { self.a == other.a } - } +#[test] +fn check_trait_as_type_as_two_fn_parameters() { + let src = " + trait Eq { + fn eq(self, other: Self) -> bool; + } - impl Test for u64 { - fn test(self) -> bool { self == self } - } + trait Test { + fn test(self) -> bool; + } - fn test_eq(x: impl Eq, y: impl Test) -> bool { - x.eq(x) == y.test() - } + struct Foo { + a: u64, + } - fn main(a: Foo, b: u64) -> pub bool { - test_eq(a, b) - }"; - - let errors = get_program_errors(src); - errors.iter().for_each(|err| println!("{:?}", err)); - assert!(errors.is_empty()); - } - - fn get_program_captures(src: &str) -> Vec> { - let (program, context, _errors) = get_program(src); - let interner = context.def_interner; - let mut all_captures: Vec> = Vec::new(); - for func in program.into_sorted().functions { - let func_id = interner.find_function(func.name()).unwrap(); - let hir_func = interner.function(&func_id); - // Iterate over function statements and apply filtering function - find_lambda_captures( - hir_func.block(&interner).statements(), - &interner, - &mut all_captures, - ); - } - all_captures - } - - fn find_lambda_captures( - stmts: &[StmtId], - interner: &NodeInterner, - result: &mut Vec>, - ) { - for stmt_id in stmts.iter() { - let hir_stmt = interner.statement(stmt_id); - let expr_id = match hir_stmt { - HirStatement::Expression(expr_id) => expr_id, - HirStatement::Let(let_stmt) => let_stmt.expression, - HirStatement::Assign(assign_stmt) => assign_stmt.expression, - HirStatement::Constrain(constr_stmt) => constr_stmt.0, - HirStatement::Semi(semi_expr) => semi_expr, - HirStatement::For(for_loop) => for_loop.block, - HirStatement::Error => panic!("Invalid HirStatement!"), - HirStatement::Break => panic!("Unexpected break"), - HirStatement::Continue => panic!("Unexpected continue"), - HirStatement::Comptime(_) => panic!("Unexpected comptime"), - }; - let expr = interner.expression(&expr_id); + impl Eq for Foo { + fn eq(self, other: Foo) -> bool { self.a == other.a } + } - get_lambda_captures(expr, interner, result); // TODO: dyn filter function as parameter - } + impl Test for u64 { + fn test(self) -> bool { self == self } } - fn get_lambda_captures( - expr: HirExpression, - interner: &NodeInterner, - result: &mut Vec>, - ) { - if let HirExpression::Lambda(lambda_expr) = expr { - let mut cur_capture = Vec::new(); + fn test_eq(x: impl Eq, y: impl Test) -> bool { + x.eq(x) == y.test() + } - for capture in lambda_expr.captures.iter() { - cur_capture.push(interner.definition(capture.ident.id).name.clone()); - } - result.push(cur_capture); + fn main(a: Foo, b: u64) -> pub bool { + test_eq(a, b) + }"; - // Check for other captures recursively within the lambda body - let hir_body_expr = interner.expression(&lambda_expr.body); - if let HirExpression::Block(block_expr) = hir_body_expr { - find_lambda_captures(block_expr.statements(), interner, result); - } - } + let errors = get_program_errors(src); + errors.iter().for_each(|err| println!("{:?}", err)); + assert!(errors.is_empty()); +} + +fn get_program_captures(src: &str) -> Vec> { + let (program, context, _errors) = get_program(src); + let interner = context.def_interner; + let mut all_captures: Vec> = Vec::new(); + for func in program.into_sorted().functions { + let func_id = interner.find_function(func.name()).unwrap(); + let hir_func = interner.function(&func_id); + // Iterate over function statements and apply filtering function + find_lambda_captures(hir_func.block(&interner).statements(), &interner, &mut all_captures); } + all_captures +} - #[test] - fn resolve_empty_function() { - let src = " - fn main() { +fn find_lambda_captures(stmts: &[StmtId], interner: &NodeInterner, result: &mut Vec>) { + for stmt_id in stmts.iter() { + let hir_stmt = interner.statement(stmt_id); + let expr_id = match hir_stmt { + HirStatement::Expression(expr_id) => expr_id, + HirStatement::Let(let_stmt) => let_stmt.expression, + HirStatement::Assign(assign_stmt) => assign_stmt.expression, + HirStatement::Constrain(constr_stmt) => constr_stmt.0, + HirStatement::Semi(semi_expr) => semi_expr, + HirStatement::For(for_loop) => for_loop.block, + HirStatement::Error => panic!("Invalid HirStatement!"), + HirStatement::Break => panic!("Unexpected break"), + HirStatement::Continue => panic!("Unexpected continue"), + HirStatement::Comptime(_) => panic!("Unexpected comptime"), + }; + let expr = interner.expression(&expr_id); - } - "; - assert!(get_program_errors(src).is_empty()); - } - #[test] - fn resolve_basic_function() { - let src = r#" - fn main(x : Field) { - let y = x + x; - assert(y == x); - } - "#; - assert!(get_program_errors(src).is_empty()); - } - #[test] - fn resolve_unused_var() { - let src = r#" - fn main(x : Field) { - let y = x + x; - assert(x == x); - } - "#; - - let errors = get_program_errors(src); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - // It should be regarding the unused variable - match &errors[0].0 { - CompilationError::ResolverError(ResolverError::UnusedVariable { ident }) => { - assert_eq!(&ident.0.contents, "y"); - } - _ => unreachable!("we should only have an unused var error"), - } + get_lambda_captures(expr, interner, result); // TODO: dyn filter function as parameter } +} - #[test] - fn resolve_unresolved_var() { - let src = r#" - fn main(x : Field) { - let y = x + x; - assert(y == z); - } - "#; - let errors = get_program_errors(src); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - // It should be regarding the unresolved var `z` (Maybe change to undeclared and special case) - match &errors[0].0 { - CompilationError::ResolverError(ResolverError::VariableNotDeclared { - name, - span: _, - }) => assert_eq!(name, "z"), - _ => unimplemented!("we should only have an unresolved variable"), +fn get_lambda_captures( + expr: HirExpression, + interner: &NodeInterner, + result: &mut Vec>, +) { + if let HirExpression::Lambda(lambda_expr) = expr { + let mut cur_capture = Vec::new(); + + for capture in lambda_expr.captures.iter() { + cur_capture.push(interner.definition(capture.ident.id).name.clone()); + } + result.push(cur_capture); + + // Check for other captures recursively within the lambda body + let hir_body_expr = interner.expression(&lambda_expr.body); + if let HirExpression::Block(block_expr) = hir_body_expr { + find_lambda_captures(block_expr.statements(), interner, result); } } +} + +#[test] +fn resolve_empty_function() { + let src = " + fn main() { - #[test] - fn unresolved_path() { - let src = " - fn main(x : Field) { - let _z = some::path::to::a::func(x); - } - "; - let errors = get_program_errors(src); - assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); - for (compilation_error, _file_id) in errors { - match compilation_error { - CompilationError::ResolverError(err) => { - match err { - ResolverError::PathResolutionError(PathResolutionError::Unresolved( - name, - )) => { - assert_eq!(name.to_string(), "some"); - } - _ => unimplemented!("we should only have an unresolved function"), - }; - } - _ => unimplemented!(), - } } + "; + assert!(get_program_errors(src).is_empty()); +} +#[test] +fn resolve_basic_function() { + let src = r#" + fn main(x : Field) { + let y = x + x; + assert(y == x); + } + "#; + assert!(get_program_errors(src).is_empty()); +} +#[test] +fn resolve_unused_var() { + let src = r#" + fn main(x : Field) { + let y = x + x; + assert(x == x); + } + "#; + + let errors = get_program_errors(src); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + // It should be regarding the unused variable + match &errors[0].0 { + CompilationError::ResolverError(ResolverError::UnusedVariable { ident }) => { + assert_eq!(&ident.0.contents, "y"); + } + _ => unreachable!("we should only have an unused var error"), } +} - #[test] - fn resolve_literal_expr() { - let src = r#" - fn main(x : Field) { - let y = 5; - assert(y == x); - } - "#; - assert!(get_program_errors(src).is_empty()); +#[test] +fn resolve_unresolved_var() { + let src = r#" + fn main(x : Field) { + let y = x + x; + assert(y == z); + } + "#; + let errors = get_program_errors(src); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + // It should be regarding the unresolved var `z` (Maybe change to undeclared and special case) + match &errors[0].0 { + CompilationError::ResolverError(ResolverError::VariableNotDeclared { name, span: _ }) => { + assert_eq!(name, "z"); + } + _ => unimplemented!("we should only have an unresolved variable"), } +} - #[test] - fn multiple_resolution_errors() { - let src = r#" - fn main(x : Field) { - let y = foo::bar(x); - let z = y + a; - } - "#; - - let errors = get_program_errors(src); - assert!(errors.len() == 3, "Expected 3 errors, got: {:?}", errors); - - // Errors are: - // `a` is undeclared - // `z` is unused - // `foo::bar` does not exist - for (compilation_error, _file_id) in errors { - match compilation_error { - CompilationError::ResolverError(err) => { - match err { - ResolverError::UnusedVariable { ident } => { - assert_eq!(&ident.0.contents, "z"); - } - ResolverError::VariableNotDeclared { name, .. } => { - assert_eq!(name, "a"); - } - ResolverError::PathResolutionError(PathResolutionError::Unresolved( - name, - )) => { - assert_eq!(name.to_string(), "foo"); - } - _ => unimplemented!(), - }; - } - _ => unimplemented!(), +#[test] +fn unresolved_path() { + let src = " + fn main(x : Field) { + let _z = some::path::to::a::func(x); + } + "; + let errors = get_program_errors(src); + assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors); + for (compilation_error, _file_id) in errors { + match compilation_error { + CompilationError::ResolverError(err) => { + match err { + ResolverError::PathResolutionError(PathResolutionError::Unresolved(name)) => { + assert_eq!(name.to_string(), "some"); + } + _ => unimplemented!("we should only have an unresolved function"), + }; } + _ => unimplemented!(), } } +} - #[test] - fn resolve_prefix_expr() { - let src = r#" - fn main(x : Field) { - let _y = -x; - } - "#; - assert!(get_program_errors(src).is_empty()); - } +#[test] +fn resolve_literal_expr() { + let src = r#" + fn main(x : Field) { + let y = 5; + assert(y == x); + } + "#; + assert!(get_program_errors(src).is_empty()); +} - #[test] - fn resolve_for_expr() { - let src = r#" - fn main(x : u64) { - for i in 1..20 { - let _z = x + i; +#[test] +fn multiple_resolution_errors() { + let src = r#" + fn main(x : Field) { + let y = foo::bar(x); + let z = y + a; + } + "#; + + let errors = get_program_errors(src); + assert!(errors.len() == 3, "Expected 3 errors, got: {:?}", errors); + + // Errors are: + // `a` is undeclared + // `z` is unused + // `foo::bar` does not exist + for (compilation_error, _file_id) in errors { + match compilation_error { + CompilationError::ResolverError(err) => { + match err { + ResolverError::UnusedVariable { ident } => { + assert_eq!(&ident.0.contents, "z"); + } + ResolverError::VariableNotDeclared { name, .. } => { + assert_eq!(name, "a"); + } + ResolverError::PathResolutionError(PathResolutionError::Unresolved(name)) => { + assert_eq!(name.to_string(), "foo"); + } + _ => unimplemented!(), }; } - "#; - assert!(get_program_errors(src).is_empty()); + _ => unimplemented!(), + } } +} - #[test] - fn resolve_call_expr() { - let src = r#" - fn main(x : Field) { - let _z = foo(x); - } +#[test] +fn resolve_prefix_expr() { + let src = r#" + fn main(x : Field) { + let _y = -x; + } + "#; + assert!(get_program_errors(src).is_empty()); +} - fn foo(x : Field) -> Field { - x - } - "#; - assert!(get_program_errors(src).is_empty()); - } - - #[test] - fn resolve_shadowing() { - let src = r#" - fn main(x : Field) { - let x = foo(x); - let x = x; - let (x, x) = (x, x); - let _ = x; - } +#[test] +fn resolve_for_expr() { + let src = r#" + fn main(x : u64) { + for i in 1..20 { + let _z = x + i; + }; + } + "#; + assert!(get_program_errors(src).is_empty()); +} - fn foo(x : Field) -> Field { - x - } - "#; - assert!(get_program_errors(src).is_empty()); - } +#[test] +fn resolve_call_expr() { + let src = r#" + fn main(x : Field) { + let _z = foo(x); + } - #[test] - fn resolve_basic_closure() { - let src = r#" - fn main(x : Field) -> pub Field { - let closure = |y| y + x; - closure(x) - } - "#; - assert!(get_program_errors(src).is_empty()); - } + fn foo(x : Field) -> Field { + x + } + "#; + assert!(get_program_errors(src).is_empty()); +} - #[test] - fn resolve_simplified_closure() { - // based on bug https://github.com/noir-lang/noir/issues/1088 +#[test] +fn resolve_shadowing() { + let src = r#" + fn main(x : Field) { + let x = foo(x); + let x = x; + let (x, x) = (x, x); + let _ = x; + } - let src = r#"fn do_closure(x: Field) -> Field { - let y = x; - let ret_capture = || { - y - }; - ret_capture() - } - - fn main(x: Field) { - assert(do_closure(x) == 100); - } - - "#; - let parsed_captures = get_program_captures(src); - let expected_captures = vec![vec!["y".to_string()]]; - assert_eq!(expected_captures, parsed_captures); - } - - #[test] - fn resolve_complex_closures() { - let src = r#" - fn main(x: Field) -> pub Field { - let closure_without_captures = |x: Field| -> Field { x + x }; - let a = closure_without_captures(1); - - let closure_capturing_a_param = |y: Field| -> Field { y + x }; - let b = closure_capturing_a_param(2); - - let closure_capturing_a_local_var = |y: Field| -> Field { y + b }; - let c = closure_capturing_a_local_var(3); - - let closure_with_transitive_captures = |y: Field| -> Field { - let d = 5; - let nested_closure = |z: Field| -> Field { - let doubly_nested_closure = |w: Field| -> Field { w + x + b }; - a + z + y + d + x + doubly_nested_closure(4) + x + y - }; - let res = nested_closure(5); - res + fn foo(x : Field) -> Field { + x + } + "#; + assert!(get_program_errors(src).is_empty()); +} + +#[test] +fn resolve_basic_closure() { + let src = r#" + fn main(x : Field) -> pub Field { + let closure = |y| y + x; + closure(x) + } + "#; + assert!(get_program_errors(src).is_empty()); +} + +#[test] +fn resolve_simplified_closure() { + // based on bug https://github.com/noir-lang/noir/issues/1088 + + let src = r#"fn do_closure(x: Field) -> Field { + let y = x; + let ret_capture = || { + y + }; + ret_capture() + } + + fn main(x: Field) { + assert(do_closure(x) == 100); + } + + "#; + let parsed_captures = get_program_captures(src); + let expected_captures = vec![vec!["y".to_string()]]; + assert_eq!(expected_captures, parsed_captures); +} + +#[test] +fn resolve_complex_closures() { + let src = r#" + fn main(x: Field) -> pub Field { + let closure_without_captures = |x: Field| -> Field { x + x }; + let a = closure_without_captures(1); + + let closure_capturing_a_param = |y: Field| -> Field { y + x }; + let b = closure_capturing_a_param(2); + + let closure_capturing_a_local_var = |y: Field| -> Field { y + b }; + let c = closure_capturing_a_local_var(3); + + let closure_with_transitive_captures = |y: Field| -> Field { + let d = 5; + let nested_closure = |z: Field| -> Field { + let doubly_nested_closure = |w: Field| -> Field { w + x + b }; + a + z + y + d + x + doubly_nested_closure(4) + x + y }; + let res = nested_closure(5); + res + }; + + a + b + c + closure_with_transitive_captures(6) + } + "#; + assert!(get_program_errors(src).is_empty(), "there should be no errors"); + + let expected_captures = vec![ + vec![], + vec!["x".to_string()], + vec!["b".to_string()], + vec!["x".to_string(), "b".to_string(), "a".to_string()], + vec!["x".to_string(), "b".to_string(), "a".to_string(), "y".to_string(), "d".to_string()], + vec!["x".to_string(), "b".to_string()], + ]; + + let parsed_captures = get_program_captures(src); + + assert_eq!(expected_captures, parsed_captures); +} + +#[test] +fn resolve_fmt_strings() { + let src = r#" + fn main() { + let string = f"this is i: {i}"; + println(string); + + println(f"I want to print {0}"); + + let new_val = 10; + println(f"random_string{new_val}{new_val}"); + } + fn println(x : T) -> T { + x + } + "#; + + let errors = get_program_errors(src); + assert!(errors.len() == 5, "Expected 5 errors, got: {:?}", errors); - a + b + c + closure_with_transitive_captures(6) + for (err, _file_id) in errors { + match &err { + CompilationError::ResolverError(ResolverError::VariableNotDeclared { + name, .. + }) => { + assert_eq!(name, "i"); } - "#; - assert!(get_program_errors(src).is_empty(), "there should be no errors"); - - let expected_captures = vec![ - vec![], - vec!["x".to_string()], - vec!["b".to_string()], - vec!["x".to_string(), "b".to_string(), "a".to_string()], - vec![ - "x".to_string(), - "b".to_string(), - "a".to_string(), - "y".to_string(), - "d".to_string(), - ], - vec!["x".to_string(), "b".to_string()], - ]; - - let parsed_captures = get_program_captures(src); - - assert_eq!(expected_captures, parsed_captures); - } - - #[test] - fn resolve_fmt_strings() { - let src = r#" - fn main() { - let string = f"this is i: {i}"; - println(string); - - println(f"I want to print {0}"); - - let new_val = 10; - println(f"random_string{new_val}{new_val}"); + CompilationError::ResolverError(ResolverError::NumericConstantInFormatString { + name, + .. + }) => { + assert_eq!(name, "0"); } - fn println(x : T) -> T { - x + CompilationError::TypeError(TypeCheckError::UnusedResultError { + expr_type: _, + expr_span, + }) => { + let a = src.get(expr_span.start() as usize..expr_span.end() as usize).unwrap(); + assert!( + a == "println(string)" + || a == "println(f\"I want to print {0}\")" + || a == "println(f\"random_string{new_val}{new_val}\")" + ); } - "#; - - let errors = get_program_errors(src); - assert!(errors.len() == 5, "Expected 5 errors, got: {:?}", errors); - - for (err, _file_id) in errors { - match &err { - CompilationError::ResolverError(ResolverError::VariableNotDeclared { - name, - .. - }) => { - assert_eq!(name, "i"); - } - CompilationError::ResolverError(ResolverError::NumericConstantInFormatString { - name, - .. - }) => { - assert_eq!(name, "0"); - } - CompilationError::TypeError(TypeCheckError::UnusedResultError { - expr_type: _, - expr_span, - }) => { - let a = src.get(expr_span.start() as usize..expr_span.end() as usize).unwrap(); - assert!( - a == "println(string)" - || a == "println(f\"I want to print {0}\")" - || a == "println(f\"random_string{new_val}{new_val}\")" - ); - } - _ => unimplemented!(), - }; - } + _ => unimplemented!(), + }; } +} - fn check_rewrite(src: &str, expected: &str) { - let (_program, mut context, _errors) = get_program(src); - let main_func_id = context.def_interner.find_function("main").unwrap(); - let program = monomorphize(main_func_id, &mut context.def_interner).unwrap(); - assert!(format!("{}", program) == expected); - } +fn check_rewrite(src: &str, expected: &str) { + let (_program, mut context, _errors) = get_program(src); + let main_func_id = context.def_interner.find_function("main").unwrap(); + let program = monomorphize(main_func_id, &mut context.def_interner).unwrap(); + assert!(format!("{}", program) == expected); +} - #[test] - fn simple_closure_with_no_captured_variables() { - let src = r#" - fn main() -> pub Field { - let x = 1; - let closure = || x; - closure() - } - "#; +#[test] +fn simple_closure_with_no_captured_variables() { + let src = r#" + fn main() -> pub Field { + let x = 1; + let closure = || x; + closure() + } + "#; - let expected_rewrite = r#"fn main$f0() -> Field { + let expected_rewrite = r#"fn main$f0() -> Field { let x$0 = 1; let closure$3 = { let closure_variable$2 = { @@ -1248,167 +1227,154 @@ fn lambda$f1(mut env$l1: (Field)) -> Field { env$l1.0 } "#; - check_rewrite(src, expected_rewrite); - } - - #[test] - fn deny_mutually_recursive_structs() { - let src = r#" - struct Foo { bar: Bar } - struct Bar { foo: Foo } - fn main() {} - "#; - assert_eq!(get_program_errors(src).len(), 1); - } - - #[test] - fn deny_cyclic_globals() { - let src = r#" - global A = B; - global B = A; - fn main() {} - "#; - assert_eq!(get_program_errors(src).len(), 1); - } - - #[test] - fn deny_cyclic_type_aliases() { - let src = r#" - type A = B; - type B = A; - fn main() {} - "#; - assert_eq!(get_program_errors(src).len(), 1); - } - - #[test] - fn ensure_nested_type_aliases_type_check() { - let src = r#" - type A = B; - type B = u8; - fn main() { - let _a: A = 0 as u16; - } - "#; - assert_eq!(get_program_errors(src).len(), 1); - } - - #[test] - fn type_aliases_in_entry_point() { - let src = r#" - type Foo = u8; - fn main(_x: Foo) {} - "#; - assert_eq!(get_program_errors(src).len(), 0); - } - - #[test] - fn operators_in_global_used_in_type() { - let src = r#" - global ONE = 1; - global COUNT = ONE + 2; - fn main() { - let _array: [Field; COUNT] = [1, 2, 3]; - } - "#; - assert_eq!(get_program_errors(src).len(), 0); - } + check_rewrite(src, expected_rewrite); +} - #[test] - fn break_and_continue_in_constrained_fn() { - let src = r#" - fn main() { - for i in 0 .. 10 { - if i == 2 { - continue; - } - if i == 5 { - break; - } +#[test] +fn deny_cyclic_globals() { + let src = r#" + global A = B; + global B = A; + fn main() {} + "#; + assert_eq!(get_program_errors(src).len(), 1); +} + +#[test] +fn deny_cyclic_type_aliases() { + let src = r#" + type A = B; + type B = A; + fn main() {} + "#; + assert_eq!(get_program_errors(src).len(), 1); +} + +#[test] +fn ensure_nested_type_aliases_type_check() { + let src = r#" + type A = B; + type B = u8; + fn main() { + let _a: A = 0 as u16; + } + "#; + assert_eq!(get_program_errors(src).len(), 1); +} + +#[test] +fn type_aliases_in_entry_point() { + let src = r#" + type Foo = u8; + fn main(_x: Foo) {} + "#; + assert_eq!(get_program_errors(src).len(), 0); +} + +#[test] +fn operators_in_global_used_in_type() { + let src = r#" + global ONE = 1; + global COUNT = ONE + 2; + fn main() { + let _array: [Field; COUNT] = [1, 2, 3]; + } + "#; + assert_eq!(get_program_errors(src).len(), 0); +} + +#[test] +fn break_and_continue_in_constrained_fn() { + let src = r#" + fn main() { + for i in 0 .. 10 { + if i == 2 { + continue; + } + if i == 5 { + break; } } - "#; - assert_eq!(get_program_errors(src).len(), 2); - } + } + "#; + assert_eq!(get_program_errors(src).len(), 2); +} - #[test] - fn break_and_continue_outside_loop() { - let src = r#" - unconstrained fn main() { - continue; - break; - } - "#; - assert_eq!(get_program_errors(src).len(), 2); - } +#[test] +fn break_and_continue_outside_loop() { + let src = r#" + unconstrained fn main() { + continue; + break; + } + "#; + assert_eq!(get_program_errors(src).len(), 2); +} - // Regression for #2540 - #[test] - fn for_loop_over_array() { - let src = r#" - fn hello(_array: [u1; N]) { - for _ in 0..N {} - } +// Regression for #2540 +#[test] +fn for_loop_over_array() { + let src = r#" + fn hello(_array: [u1; N]) { + for _ in 0..N {} + } - fn main() { - let array: [u1; 2] = [0, 1]; - hello(array); - } - "#; - assert_eq!(get_program_errors(src).len(), 0); - } - - // Regression for #4545 - #[test] - fn type_aliases_in_main() { - let src = r#" - type Outer = [u8; N]; - fn main(_arg: Outer<1>) {} - "#; - assert_eq!(get_program_errors(src).len(), 0); - } - - #[test] - fn ban_mutable_globals() { - // Mutable globals are only allowed in a comptime context - let src = r#" - mut global FOO: Field = 0; - fn main() {} - "#; - assert_eq!(get_program_errors(src).len(), 1); - } - - #[test] - fn deny_inline_attribute_on_unconstrained() { - let src = r#" - #[no_predicates] - unconstrained fn foo(x: Field, y: Field) { - assert(x != y); - } - "#; - let errors = get_program_errors(src); - assert_eq!(errors.len(), 1); - assert!(matches!( - errors[0].0, - CompilationError::ResolverError( - ResolverError::NoPredicatesAttributeOnUnconstrained { .. } - ) - )); - } + fn main() { + let array: [u1; 2] = [0, 1]; + hello(array); + } + "#; + assert_eq!(get_program_errors(src).len(), 0); +} - #[test] - fn deny_fold_attribute_on_unconstrained() { - let src = r#" - #[fold] - unconstrained fn foo(x: Field, y: Field) { - assert(x != y); - } - "#; - let errors = get_program_errors(src); - assert_eq!(errors.len(), 1); - assert!(matches!( - errors[0].0, - CompilationError::ResolverError(ResolverError::FoldAttributeOnUnconstrained { .. }) - )); - } +// Regression for #4545 +#[test] +fn type_aliases_in_main() { + let src = r#" + type Outer = [u8; N]; + fn main(_arg: Outer<1>) {} + "#; + assert_eq!(get_program_errors(src).len(), 0); +} + +#[test] +fn ban_mutable_globals() { + // Mutable globals are only allowed in a comptime context + let src = r#" + mut global FOO: Field = 0; + fn main() {} + "#; + assert_eq!(get_program_errors(src).len(), 1); +} + +#[test] +fn deny_inline_attribute_on_unconstrained() { + let src = r#" + #[no_predicates] + unconstrained fn foo(x: Field, y: Field) { + assert(x != y); + } + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + assert!(matches!( + errors[0].0, + CompilationError::ResolverError(ResolverError::NoPredicatesAttributeOnUnconstrained { .. }) + )); +} + +#[test] +fn deny_fold_attribute_on_unconstrained() { + let src = r#" + #[fold] + unconstrained fn foo(x: Field, y: Field) { + assert(x != y); + } + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + assert!(matches!( + errors[0].0, + CompilationError::ResolverError(ResolverError::FoldAttributeOnUnconstrained { .. }) + )); } diff --git a/noir/noir-repo/compiler/noirc_frontend/src/tests/name_shadowing.rs b/noir/noir-repo/compiler/noirc_frontend/src/tests/name_shadowing.rs new file mode 100644 index 00000000000..b0d83510039 --- /dev/null +++ b/noir/noir-repo/compiler/noirc_frontend/src/tests/name_shadowing.rs @@ -0,0 +1,419 @@ +#![cfg(test)] +use super::get_program_errors; +use std::collections::HashSet; + +#[test] +fn test_name_shadowing() { + let src = " + trait Default { + fn default() -> Self; + } + + impl Default for bool { + fn default() -> bool { + false + } + } + + impl Default for Field { + fn default() -> Field { + 0 + } + } + + impl Default for [T; N] where T: Default { + fn default() -> [T; N] { + [Default::default(); N] + } + } + + impl Default for (T, U) where T: Default, U: Default { + fn default() -> (T, U) { + (Default::default(), Default::default()) + } + } + + fn drop_var(_x: T, y: U) -> U { y } + + mod local_module { + use crate::{Default, drop_var}; + + global LOCAL_GLOBAL_N: Field = 0; + + global LOCAL_GLOBAL_M: Field = 1; + + struct LocalStruct { + field1: A, + field2: B, + field3: [A; N], + field4: ([A; N], [B; M]), + field5: &mut A, + } + + impl Default for LocalStruct where A: Default, B: Default { + fn default() -> Self { + let mut mut_field = &mut Default::default(); + Self { + field1: Default::default(), + field2: Default::default(), + field3: Default::default(), + field4: Default::default(), + field5: mut_field, + } + } + } + + trait DefinedInLocalModule1 { + fn trait_fn1(self, x: A); + fn trait_fn2(self, y: B); + fn trait_fn3(&mut self, x: A, y: B); + fn trait_fn4(self, x: [A; 0], y: [B]); + fn trait_fn5(self, x: [A; N], y: [B; M]) -> [A; 0]; + fn trait_fn6(self, x: [A; N], y: [B; M]) -> [A; 0]; + fn trait_fn7(self, _x: fn([A; 0]) -> B) -> Field { + drop_var(self, N + M) + } + } + + impl DefinedInLocalModule1 for LocalStruct { + fn trait_fn1(self, _x: A) { drop_var(self, ()) } + fn trait_fn2(self, _y: B) { drop_var(self, ()) } + fn trait_fn3(&mut self, _x: A, _y: B) { drop_var(self, ()) } + fn trait_fn4(self, _x: [A; 0], _y: [B]) { drop_var(self, ()) } + fn trait_fn5(self, _x: [A; N], _y: [B; M]) -> [A; 0] { drop_var(self, []) } + fn trait_fn6(self, _x: [A; N], _y: [B; M]) -> [A; 0] { drop_var(self, []) } + } + + pub fn local_fn4(_x: (A, B), _y: [Field; N], _z: [Field; M]) -> [A; 0] { + assert(LOCAL_GLOBAL_N != LOCAL_GLOBAL_M); + let x: Field = 0; + assert(x == 0); + let x: Field = 1; + assert(x == 1); + [] + } + } + + mod library { + use crate::{Default, drop_var}; + + mod library2 { + use crate::{Default, drop_var}; + + global IMPORT_GLOBAL_N_2: Field = 4; + + global IMPORT_GLOBAL_M_2: Field = 5; + + // When we re-export this type from another library and then use it in + // main, we get a panic + struct ReExportMeFromAnotherLib1 { + x : Field, + } + + struct PubLibLocalStruct3 { + pub_field1: A, + pub_field2: B, + pub_field3: [A; N], + pub_field4: ([A; N], [B; M]), + pub_field5: &mut A, + } + + impl Default for PubLibLocalStruct3 where A: Default, B: Default { + fn default() -> Self { + let mut mut_field = &mut Default::default(); + Self { + pub_field1: Default::default(), + pub_field2: Default::default(), + pub_field3: Default::default(), + pub_field4: Default::default(), + pub_field5: mut_field, + } + } + } + + trait PubLibDefinedInLocalModule3 { + fn pub_trait_fn1(self, x: A); + fn pub_trait_fn2(self, y: B); + fn pub_trait_fn3(&mut self, x: A, y: B); + fn pub_trait_fn4(self, x: [A; 0], y: [B]); + fn pub_trait_fn5(self, x: [A; N], y: [B; M]) -> [A; 0]; + fn pub_trait_fn6(self, x: [A; N], y: [B; M]) -> [A; 0]; + fn pub_trait_fn7(self, _x: fn([A; 0]) -> B) -> Field { + drop_var(self, N + M) + } + } + + impl PubLibDefinedInLocalModule3 for PubLibLocalStruct3 { + fn pub_trait_fn1(self, _x: A) { drop_var(self, ()) } + fn pub_trait_fn2(self, _y: B) { drop_var(self, ()) } + fn pub_trait_fn3(&mut self, _x: A, _y: B) { drop_var(self, ()) } + fn pub_trait_fn4(self, _x: [A; 0], _y: [B]) { drop_var(self, ()) } + fn pub_trait_fn5(self, _x: [A; N], _y: [B; M]) -> [A; 0] { drop_var(self, []) } + fn pub_trait_fn6(self, _x: [A; N], _y: [B; M]) -> [A; 0] { drop_var(self, []) } + } + + pub fn PubLiblocal_fn3(_x: (A, B), _y: [Field; N], _z: [Field; M]) -> [A; 0] { + assert(IMPORT_GLOBAL_N_2 != IMPORT_GLOBAL_M_2); + [] + } + } + + // Re-export + use library2::ReExportMeFromAnotherLib1; + + global IMPORT_GLOBAL_N_1: Field = 2; + + global IMPORT_GLOBAL_M_1: Field = 3; + + struct LibLocalStruct1 { + lib_field1: A, + lib_field2: B, + lib_field3: [A; N], + lib_field4: ([A; N], [B; M]), + lib_field5: &mut A, + } + + impl Default for LibLocalStruct1 where A: Default, B: Default { + fn default() -> Self { + let mut mut_field = &mut Default::default(); + Self { + lib_field1: Default::default(), + lib_field2: Default::default(), + lib_field3: Default::default(), + lib_field4: Default::default(), + lib_field5: mut_field, + } + } + } + + trait LibDefinedInLocalModule1 { + fn lib_trait_fn1(self, x: A); + fn lib_trait_fn2(self, y: B); + fn lib_trait_fn3(&mut self, x: A, y: B); + fn lib_trait_fn4(self, x: [A; 0], y: [B]); + fn lib_trait_fn5(self, x: [A; N], y: [B; M]) -> [A; 0]; + fn lib_trait_fn6(self, x: [A; N], y: [B; M]) -> [A; 0]; + fn lib_trait_fn7(self, _x: fn([A; 0]) -> B) -> Field { + drop_var(self, N + M) + } + } + + impl LibDefinedInLocalModule1 for LibLocalStruct1 { + fn lib_trait_fn1(self, _x: A) { drop_var(self, ()) } + fn lib_trait_fn2(self, _y: B) { drop_var(self, ()) } + fn lib_trait_fn3(&mut self, _x: A, _y: B) { drop_var(self, ()) } + fn lib_trait_fn4(self, _x: [A; 0], _y: [B]) { drop_var(self, ()) } + fn lib_trait_fn5(self, _x: [A; N], _y: [B; M]) -> [A; 0] { drop_var(self, []) } + fn lib_trait_fn6(self, _x: [A; N], _y: [B; M]) -> [A; 0] { drop_var(self, []) } + } + + pub fn Liblocal_fn1(_x: (A, B), _y: [Field; N], _z: [Field; M]) -> [A; 0] { + assert(IMPORT_GLOBAL_N_1 != IMPORT_GLOBAL_M_1); + [] + } + } + + mod library3 { + use crate::{Default, drop_var}; + + global IMPORT_GLOBAL_N_3: Field = 6; + + global IMPORT_GLOBAL_M_3: Field = 7; + + struct ReExportMeFromAnotherLib2 { + x : Field, + } + + struct PubCrateLibLocalStruct2 { + crate_field1: A, + crate_field2: B, + crate_field3: [A; N], + crate_field4: ([A; N], [B; M]), + crate_field5: &mut A, + } + + impl Default for PubCrateLibLocalStruct2 where A: Default, B: Default { + fn default() -> Self { + let mut mut_field = &mut Default::default(); + Self { + crate_field1: Default::default(), + crate_field2: Default::default(), + crate_field3: Default::default(), + crate_field4: Default::default(), + crate_field5: mut_field, + } + } + } + + trait PubCrateLibDefinedInLocalModule2 { + fn crate_trait_fn1(self, x: A); + fn crate_trait_fn2(self, y: B); + fn crate_trait_fn3(&mut self, x: A, y: B); + fn crate_trait_fn4(self, x: [A; 0], y: [B]); + fn crate_trait_fn5(self, x: [A; N], y: [B; M]) -> [A; 0]; + fn crate_trait_fn6(self, x: [A; N], y: [B; M]) -> [A; 0]; + fn crate_trait_fn7(self, _x: fn([A; 0]) -> B) -> Field { + drop_var(self, N + M) + } + } + + impl PubCrateLibDefinedInLocalModule2 for PubCrateLibLocalStruct2 { + fn crate_trait_fn1(self, _x: A) { drop_var(self, ()) } + fn crate_trait_fn2(self, _y: B) { drop_var(self, ()) } + fn crate_trait_fn3(&mut self, _x: A, _y: B) { drop_var(self, ()) } + fn crate_trait_fn4(self, _x: [A; 0], _y: [B]) { drop_var(self, ()) } + fn crate_trait_fn5(self, _x: [A; N], _y: [B; M]) -> [A; 0] { drop_var(self, ()); [] } + fn crate_trait_fn6(self, _x: [A; N], _y: [B; M]) -> [A; 0] { drop_var(self, ()); [] } + } + + pub(crate) fn PubCrateLiblocal_fn2(_x: (A, B), _y: [Field; N], _z: [Field; M]) -> [A; 0] { + assert(IMPORT_GLOBAL_N_3 != IMPORT_GLOBAL_M_3); + [] + } + } + + + use crate::local_module::{local_fn4, LocalStruct, DefinedInLocalModule1, LOCAL_GLOBAL_N, LOCAL_GLOBAL_M}; + + use library::{ReExportMeFromAnotherLib1, LibLocalStruct1, LibDefinedInLocalModule1, Liblocal_fn1, IMPORT_GLOBAL_N_1, IMPORT_GLOBAL_M_1}; + + // overlapping + // use library::library2::ReExportMeFromAnotherLib1; + use crate::library::library2::{PubLibLocalStruct3, PubLibDefinedInLocalModule3, PubLiblocal_fn3, IMPORT_GLOBAL_N_2, IMPORT_GLOBAL_M_2}; + + use library3::{ReExportMeFromAnotherLib2, PubCrateLibLocalStruct2, PubCrateLibDefinedInLocalModule2, PubCrateLiblocal_fn2, IMPORT_GLOBAL_N_3, IMPORT_GLOBAL_M_3}; + + + fn main(_x: ReExportMeFromAnotherLib1, _y: ReExportMeFromAnotherLib2) { + assert(LOCAL_GLOBAL_N != LOCAL_GLOBAL_M); + assert(IMPORT_GLOBAL_N_1 != IMPORT_GLOBAL_M_1); + assert(IMPORT_GLOBAL_N_2 != IMPORT_GLOBAL_M_2); + assert(IMPORT_GLOBAL_N_3 != IMPORT_GLOBAL_M_3); + + let x: LocalStruct = Default::default(); + assert(drop_var(x.trait_fn5([0; LOCAL_GLOBAL_N], [false; LOCAL_GLOBAL_M]), true)); + assert(drop_var(x.trait_fn6([0; LOCAL_GLOBAL_N], [false; LOCAL_GLOBAL_M]), true)); + + let x: LibLocalStruct1 = Default::default(); + assert(drop_var(x.lib_trait_fn5([0; IMPORT_GLOBAL_N_1], [false; IMPORT_GLOBAL_M_1]), true)); + assert(drop_var(x.lib_trait_fn6([0; IMPORT_GLOBAL_N_1], [false; IMPORT_GLOBAL_M_1]), true)); + + let x: PubLibLocalStruct3 = Default::default(); + assert(drop_var(x.pub_trait_fn5([0; IMPORT_GLOBAL_N_2], [false; IMPORT_GLOBAL_M_2]), true)); + assert(drop_var(x.pub_trait_fn6([0; IMPORT_GLOBAL_N_2], [false; IMPORT_GLOBAL_M_2]), true)); + + let x: PubCrateLibLocalStruct2 = Default::default(); + assert(drop_var(x.crate_trait_fn5([0; IMPORT_GLOBAL_N_3], [false; IMPORT_GLOBAL_M_3]), true)); + assert(drop_var(x.crate_trait_fn6([0; IMPORT_GLOBAL_N_3], [false; IMPORT_GLOBAL_M_3]), true)); + + assert(drop_var(local_fn2((0, 1), [], []), true)); + assert(drop_var(Liblocal_fn1((0, 1), [], []), true)); + assert(drop_var(PubLiblocal_fn4((0, 1), [], []), true)); + assert(drop_var(PubCrateLiblocal_fn3((0, 1), [], []), true)); + }"; + + // NOTE: these names must be "replacement-unique", i.e. + // replacing one in a discinct name should do nothing + let names_to_collapse = [ + "DefinedInLocalModule1", + "IMPORT_GLOBAL_M_1", + "IMPORT_GLOBAL_M_2", + "IMPORT_GLOBAL_M_3", + "IMPORT_GLOBAL_N_1", + "IMPORT_GLOBAL_N_2", + "IMPORT_GLOBAL_N_3", + "LOCAL_GLOBAL_M", + "LOCAL_GLOBAL_N", + "LibDefinedInLocalModule1", + "LibLocalStruct1", + "Liblocal_fn1", + "LocalStruct", + "PubCrateLibDefinedInLocalModule2", + "PubCrateLibLocalStruct2", + "PubCrateLiblocal_fn2", + "PubLibDefinedInLocalModule3", + "PubLibLocalStruct3", + "PubLiblocal_fn3", + "ReExportMeFromAnotherLib1", + "ReExportMeFromAnotherLib2", + "local_fn4", + "crate_field1", + "crate_field2", + "crate_field3", + "crate_field4", + "crate_field5", + "crate_trait_fn1", + "crate_trait_fn2", + "crate_trait_fn3", + "crate_trait_fn4", + "crate_trait_fn5", + "crate_trait_fn6", + "crate_trait_fn7", + "field1", + "field2", + "field3", + "field4", + "field5", + "lib_field1", + "lib_field2", + "lib_field3", + "lib_field4", + "lib_field5", + "lib_trait_fn1", + "lib_trait_fn2", + "lib_trait_fn3", + "lib_trait_fn4", + "lib_trait_fn5", + "lib_trait_fn6", + "lib_trait_fn7", + "pub_field1", + "pub_field2", + "pub_field3", + "pub_field4", + "pub_field5", + "pub_trait_fn1", + "pub_trait_fn2", + "pub_trait_fn3", + "pub_trait_fn4", + "pub_trait_fn5", + "pub_trait_fn6", + "pub_trait_fn7", + "trait_fn1", + "trait_fn2", + "trait_fn3", + "trait_fn4", + "trait_fn5", + "trait_fn6", + "trait_fn7", + ]; + + // TODO(https://github.com/noir-lang/noir/issues/4973): + // Name resolution panic from name shadowing test + let cases_to_skip = [ + (1, 21), + (2, 11), + (2, 21), + (3, 11), + (3, 18), + (3, 21), + (4, 21), + (5, 11), + (5, 21), + (6, 11), + (6, 18), + (6, 21), + ]; + let cases_to_skip: HashSet<(usize, usize)> = cases_to_skip.into_iter().collect(); + + for (i, x) in names_to_collapse.iter().enumerate() { + for (j, y) in names_to_collapse.iter().enumerate().filter(|(j, _)| i < *j) { + if !cases_to_skip.contains(&(i, j)) { + dbg!((i, j)); + + let modified_src = src.replace(x, y); + let errors = get_program_errors(&modified_src); + assert!(!errors.is_empty(), "Expected errors, got: {:?}", errors); + } + } + } +} diff --git a/noir/noir-repo/compiler/wasm/src/compile.rs b/noir/noir-repo/compiler/wasm/src/compile.rs index de157a1fe20..57b17a6f79e 100644 --- a/noir/noir-repo/compiler/wasm/src/compile.rs +++ b/noir/noir-repo/compiler/wasm/src/compile.rs @@ -1,3 +1,4 @@ +use acvm::acir::circuit::ExpressionWidth; use fm::FileManager; use gloo_utils::format::JsValueSerdeExt; use js_sys::{JsString, Object}; @@ -169,9 +170,10 @@ pub fn compile_program( console_error_panic_hook::set_once(); let (crate_id, mut context) = prepare_context(entry_point, dependency_graph, file_source_map)?; - let compile_options = CompileOptions::default(); - // For now we default to a bounded width of 3, though we can add it as a parameter - let expression_width = acvm::acir::circuit::ExpressionWidth::Bounded { width: 3 }; + let compile_options = CompileOptions { + expression_width: ExpressionWidth::Bounded { width: 4 }, + ..CompileOptions::default() + }; let compiled_program = noirc_driver::compile_main(&mut context, crate_id, &compile_options, None) @@ -184,7 +186,8 @@ pub fn compile_program( })? .0; - let optimized_program = nargo::ops::transform_program(compiled_program, expression_width); + let optimized_program = + nargo::ops::transform_program(compiled_program, compile_options.expression_width); let warnings = optimized_program.warnings.clone(); Ok(JsCompileProgramResult::new(optimized_program.into(), warnings)) @@ -199,9 +202,10 @@ pub fn compile_contract( console_error_panic_hook::set_once(); let (crate_id, mut context) = prepare_context(entry_point, dependency_graph, file_source_map)?; - let compile_options = CompileOptions::default(); - // For now we default to a bounded width of 3, though we can add it as a parameter - let expression_width = acvm::acir::circuit::ExpressionWidth::Bounded { width: 3 }; + let compile_options = CompileOptions { + expression_width: ExpressionWidth::Bounded { width: 4 }, + ..CompileOptions::default() + }; let compiled_contract = noirc_driver::compile_contract(&mut context, crate_id, &compile_options) @@ -214,7 +218,8 @@ pub fn compile_contract( })? .0; - let optimized_contract = nargo::ops::transform_contract(compiled_contract, expression_width); + let optimized_contract = + nargo::ops::transform_contract(compiled_contract, compile_options.expression_width); let functions = optimized_contract.functions.into_iter().map(ContractFunctionArtifact::from).collect(); diff --git a/noir/noir-repo/compiler/wasm/src/compile_new.rs b/noir/noir-repo/compiler/wasm/src/compile_new.rs index c187fe7f3de..4f11cafb975 100644 --- a/noir/noir-repo/compiler/wasm/src/compile_new.rs +++ b/noir/noir-repo/compiler/wasm/src/compile_new.rs @@ -3,6 +3,7 @@ use crate::compile::{ PathToFileSourceMap, }; use crate::errors::{CompileError, JsCompileError}; +use acvm::acir::circuit::ExpressionWidth; use nargo::artifacts::contract::{ContractArtifact, ContractFunctionArtifact}; use nargo::parse_all; use noirc_driver::{ @@ -96,11 +97,14 @@ impl CompilerContext { mut self, program_width: usize, ) -> Result { - let compile_options = CompileOptions::default(); - let np_language = acvm::acir::circuit::ExpressionWidth::Bounded { width: program_width }; + let expression_width = if program_width == 0 { + ExpressionWidth::Unbounded + } else { + ExpressionWidth::Bounded { width: 4 } + }; + let compile_options = CompileOptions { expression_width, ..CompileOptions::default() }; let root_crate_id = *self.context.root_crate_id(); - let compiled_program = compile_main(&mut self.context, root_crate_id, &compile_options, None) .map_err(|errs| { @@ -112,7 +116,8 @@ impl CompilerContext { })? .0; - let optimized_program = nargo::ops::transform_program(compiled_program, np_language); + let optimized_program = + nargo::ops::transform_program(compiled_program, compile_options.expression_width); let warnings = optimized_program.warnings.clone(); Ok(JsCompileProgramResult::new(optimized_program.into(), warnings)) @@ -122,10 +127,14 @@ impl CompilerContext { mut self, program_width: usize, ) -> Result { - let compile_options = CompileOptions::default(); - let np_language = acvm::acir::circuit::ExpressionWidth::Bounded { width: program_width }; - let root_crate_id = *self.context.root_crate_id(); + let expression_width = if program_width == 0 { + ExpressionWidth::Unbounded + } else { + ExpressionWidth::Bounded { width: 4 } + }; + let compile_options = CompileOptions { expression_width, ..CompileOptions::default() }; + let root_crate_id = *self.context.root_crate_id(); let compiled_contract = compile_contract(&mut self.context, root_crate_id, &compile_options) .map_err(|errs| { @@ -137,7 +146,8 @@ impl CompilerContext { })? .0; - let optimized_contract = nargo::ops::transform_contract(compiled_contract, np_language); + let optimized_contract = + nargo::ops::transform_contract(compiled_contract, compile_options.expression_width); let functions = optimized_contract.functions.into_iter().map(ContractFunctionArtifact::from).collect(); @@ -166,7 +176,7 @@ pub fn compile_program_( let compiler_context = prepare_compiler_context(entry_point, dependency_graph, file_source_map)?; - let program_width = 3; + let program_width = 4; compiler_context.compile_program(program_width) } @@ -183,7 +193,7 @@ pub fn compile_contract_( let compiler_context = prepare_compiler_context(entry_point, dependency_graph, file_source_map)?; - let program_width = 3; + let program_width = 4; compiler_context.compile_contract(program_width) } diff --git a/noir/noir-repo/cspell.json b/noir/noir-repo/cspell.json index bf3040265c2..1fbbe5c428d 100644 --- a/noir/noir-repo/cspell.json +++ b/noir/noir-repo/cspell.json @@ -18,6 +18,7 @@ "Backpropagation", "barebones", "barretenberg", + "barustenberg", "bincode", "bindgen", "bitand", diff --git a/noir/noir-repo/docs/docs/noir/concepts/data_types/integers.md b/noir/noir-repo/docs/docs/noir/concepts/data_types/integers.md index 1c6b375db49..6b2d3773912 100644 --- a/noir/noir-repo/docs/docs/noir/concepts/data_types/integers.md +++ b/noir/noir-repo/docs/docs/noir/concepts/data_types/integers.md @@ -5,7 +5,9 @@ keywords: [noir, integer types, methods, examples, arithmetic] sidebar_position: 1 --- -An integer type is a range constrained field type. The Noir frontend supports both unsigned and signed integer types. The allowed sizes are 1, 8, 32 and 64 bits. +An integer type is a range constrained field type. +The Noir frontend supports both unsigned and signed integer types. +The allowed sizes are 1, 8, 16, 32 and 64 bits. :::info diff --git a/noir/noir-repo/docs/docs/noir/standard_library/traits.md b/noir/noir-repo/docs/docs/noir/standard_library/traits.md index b32a2969563..96a7b8e2f22 100644 --- a/noir/noir-repo/docs/docs/noir/standard_library/traits.md +++ b/noir/noir-repo/docs/docs/noir/standard_library/traits.md @@ -186,10 +186,10 @@ These traits abstract over addition, subtraction, multiplication, and division r Implementing these traits for a given type will also allow that type to be used with the corresponding operator for that trait (`+` for Add, etc) in addition to the normal method names. -#include_code add-trait noir_stdlib/src/ops.nr rust -#include_code sub-trait noir_stdlib/src/ops.nr rust -#include_code mul-trait noir_stdlib/src/ops.nr rust -#include_code div-trait noir_stdlib/src/ops.nr rust +#include_code add-trait noir_stdlib/src/ops/arith.nr rust +#include_code sub-trait noir_stdlib/src/ops/arith.nr rust +#include_code mul-trait noir_stdlib/src/ops/arith.nr rust +#include_code div-trait noir_stdlib/src/ops/arith.nr rust The implementations block below is given for the `Add` trait, but the same types that implement `Add` also implement `Sub`, `Mul`, and `Div`. @@ -211,7 +211,7 @@ impl Add for u64 { .. } ### `std::ops::Rem` -#include_code rem-trait noir_stdlib/src/ops.nr rust +#include_code rem-trait noir_stdlib/src/ops/arith.nr rust `Rem::rem(a, b)` is the remainder function returning the result of what is left after dividing `a` and `b`. Implementing `Rem` allows the `%` operator @@ -234,18 +234,27 @@ impl Rem for i64 { fn rem(self, other: i64) -> i64 { self % other } } ### `std::ops::Neg` -#include_code neg-trait noir_stdlib/src/ops.nr rust +#include_code neg-trait noir_stdlib/src/ops/arith.nr rust `Neg::neg` is equivalent to the unary negation operator `-`. Implementations: -#include_code neg-trait-impls noir_stdlib/src/ops.nr rust +#include_code neg-trait-impls noir_stdlib/src/ops/arith.nr rust + +### `std::ops::Not` + +#include_code not-trait noir_stdlib/src/ops/bit.nr rust + +`Not::not` is equivalent to the unary bitwise NOT operator `!`. + +Implementations: +#include_code not-trait-impls noir_stdlib/src/ops/bit.nr rust ### `std::ops::{ BitOr, BitAnd, BitXor }` -#include_code bitor-trait noir_stdlib/src/ops.nr rust -#include_code bitand-trait noir_stdlib/src/ops.nr rust -#include_code bitxor-trait noir_stdlib/src/ops.nr rust +#include_code bitor-trait noir_stdlib/src/ops/bit.nr rust +#include_code bitand-trait noir_stdlib/src/ops/bit.nr rust +#include_code bitxor-trait noir_stdlib/src/ops/bit.nr rust Traits for the bitwise operations `|`, `&`, and `^`. @@ -272,8 +281,8 @@ impl BitOr for i64 { fn bitor(self, other: i64) -> i64 { self | other } } ### `std::ops::{ Shl, Shr }` -#include_code shl-trait noir_stdlib/src/ops.nr rust -#include_code shr-trait noir_stdlib/src/ops.nr rust +#include_code shl-trait noir_stdlib/src/ops/bit.nr rust +#include_code shr-trait noir_stdlib/src/ops/bit.nr rust Traits for a bit shift left and bit shift right. diff --git a/noir/noir-repo/noir_stdlib/src/aes128.nr b/noir/noir-repo/noir_stdlib/src/aes128.nr index ac5c2b48ad8..e6e2a5e4997 100644 --- a/noir/noir-repo/noir_stdlib/src/aes128.nr +++ b/noir/noir-repo/noir_stdlib/src/aes128.nr @@ -1,4 +1,3 @@ - #[foreign(aes128_encrypt)] // docs:start:aes128 pub fn aes128_encrypt(input: [u8; N], iv: [u8; 16], key: [u8; 16]) -> [u8] {} diff --git a/noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr b/noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr index 6a1f17dae98..21d658db615 100644 --- a/noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr +++ b/noir/noir-repo/noir_stdlib/src/embedded_curve_ops.nr @@ -1,4 +1,4 @@ -use crate::ops::{Add, Sub, Neg}; +use crate::ops::arith::{Add, Sub, Neg}; // TODO(https://github.com/noir-lang/noir/issues/4931) struct EmbeddedCurvePoint { @@ -76,7 +76,4 @@ fn embedded_curve_add( } #[foreign(embedded_curve_add)] -fn embedded_curve_add_array_return( - _point1: EmbeddedCurvePoint, - _point2: EmbeddedCurvePoint -) -> [Field; 2] {} +fn embedded_curve_add_array_return(_point1: EmbeddedCurvePoint, _point2: EmbeddedCurvePoint) -> [Field; 2] {} diff --git a/noir/noir-repo/noir_stdlib/src/ops.nr b/noir/noir-repo/noir_stdlib/src/ops.nr index e0814267aea..8b1903cff0b 100644 --- a/noir/noir-repo/noir_stdlib/src/ops.nr +++ b/noir/noir-repo/noir_stdlib/src/ops.nr @@ -1,170 +1,5 @@ -// docs:start:add-trait -trait Add { - fn add(self, other: Self) -> Self; -} -// docs:end:add-trait - -impl Add for Field { fn add(self, other: Field) -> Field { self + other } } - -impl Add for u64 { fn add(self, other: u64) -> u64 { self + other } } -impl Add for u32 { fn add(self, other: u32) -> u32 { self + other } } -impl Add for u8 { fn add(self, other: u8) -> u8 { self + other } } - -impl Add for i8 { fn add(self, other: i8) -> i8 { self + other } } -impl Add for i32 { fn add(self, other: i32) -> i32 { self + other } } -impl Add for i64 { fn add(self, other: i64) -> i64 { self + other } } - -// docs:start:sub-trait -trait Sub { - fn sub(self, other: Self) -> Self; -} -// docs:end:sub-trait - -impl Sub for Field { fn sub(self, other: Field) -> Field { self - other } } - -impl Sub for u64 { fn sub(self, other: u64) -> u64 { self - other } } -impl Sub for u32 { fn sub(self, other: u32) -> u32 { self - other } } -impl Sub for u8 { fn sub(self, other: u8) -> u8 { self - other } } - -impl Sub for i8 { fn sub(self, other: i8) -> i8 { self - other } } -impl Sub for i32 { fn sub(self, other: i32) -> i32 { self - other } } -impl Sub for i64 { fn sub(self, other: i64) -> i64 { self - other } } - -// docs:start:mul-trait -trait Mul { - fn mul(self, other: Self) -> Self; -} -// docs:end:mul-trait - -impl Mul for Field { fn mul(self, other: Field) -> Field { self * other } } - -impl Mul for u64 { fn mul(self, other: u64) -> u64 { self * other } } -impl Mul for u32 { fn mul(self, other: u32) -> u32 { self * other } } -impl Mul for u8 { fn mul(self, other: u8) -> u8 { self * other } } - -impl Mul for i8 { fn mul(self, other: i8) -> i8 { self * other } } -impl Mul for i32 { fn mul(self, other: i32) -> i32 { self * other } } -impl Mul for i64 { fn mul(self, other: i64) -> i64 { self * other } } - -// docs:start:div-trait -trait Div { - fn div(self, other: Self) -> Self; -} -// docs:end:div-trait - -impl Div for Field { fn div(self, other: Field) -> Field { self / other } } - -impl Div for u64 { fn div(self, other: u64) -> u64 { self / other } } -impl Div for u32 { fn div(self, other: u32) -> u32 { self / other } } -impl Div for u8 { fn div(self, other: u8) -> u8 { self / other } } - -impl Div for i8 { fn div(self, other: i8) -> i8 { self / other } } -impl Div for i32 { fn div(self, other: i32) -> i32 { self / other } } -impl Div for i64 { fn div(self, other: i64) -> i64 { self / other } } - -// docs:start:rem-trait -trait Rem{ - fn rem(self, other: Self) -> Self; -} -// docs:end:rem-trait - -impl Rem for u64 { fn rem(self, other: u64) -> u64 { self % other } } -impl Rem for u32 { fn rem(self, other: u32) -> u32 { self % other } } -impl Rem for u8 { fn rem(self, other: u8) -> u8 { self % other } } - -impl Rem for i8 { fn rem(self, other: i8) -> i8 { self % other } } -impl Rem for i32 { fn rem(self, other: i32) -> i32 { self % other } } -impl Rem for i64 { fn rem(self, other: i64) -> i64 { self % other } } - -// docs:start:neg-trait -trait Neg { - fn neg(self) -> Self; -} -// docs:end:neg-trait - -// docs:start:neg-trait-impls -impl Neg for Field { fn neg(self) -> Field { -self } } - -impl Neg for i8 { fn neg(self) -> i8 { -self } } -impl Neg for i32 { fn neg(self) -> i32 { -self } } -impl Neg for i64 { fn neg(self) -> i64 { -self } } -// docs:end:neg-trait-impls - -// docs:start:bitor-trait -trait BitOr { - fn bitor(self, other: Self) -> Self; -} -// docs:end:bitor-trait - -impl BitOr for bool { fn bitor(self, other: bool) -> bool { self | other } } - -impl BitOr for u64 { fn bitor(self, other: u64) -> u64 { self | other } } -impl BitOr for u32 { fn bitor(self, other: u32) -> u32 { self | other } } -impl BitOr for u8 { fn bitor(self, other: u8) -> u8 { self | other } } - -impl BitOr for i8 { fn bitor(self, other: i8) -> i8 { self | other } } -impl BitOr for i32 { fn bitor(self, other: i32) -> i32 { self | other } } -impl BitOr for i64 { fn bitor(self, other: i64) -> i64 { self | other } } - -// docs:start:bitand-trait -trait BitAnd { - fn bitand(self, other: Self) -> Self; -} -// docs:end:bitand-trait - -impl BitAnd for bool { fn bitand(self, other: bool) -> bool { self & other } } - -impl BitAnd for u64 { fn bitand(self, other: u64) -> u64 { self & other } } -impl BitAnd for u32 { fn bitand(self, other: u32) -> u32 { self & other } } -impl BitAnd for u8 { fn bitand(self, other: u8) -> u8 { self & other } } - -impl BitAnd for i8 { fn bitand(self, other: i8) -> i8 { self & other } } -impl BitAnd for i32 { fn bitand(self, other: i32) -> i32 { self & other } } -impl BitAnd for i64 { fn bitand(self, other: i64) -> i64 { self & other } } - -// docs:start:bitxor-trait -trait BitXor { - fn bitxor(self, other: Self) -> Self; -} -// docs:end:bitxor-trait - -impl BitXor for bool { fn bitxor(self, other: bool) -> bool { self ^ other } } - -impl BitXor for u64 { fn bitxor(self, other: u64) -> u64 { self ^ other } } -impl BitXor for u32 { fn bitxor(self, other: u32) -> u32 { self ^ other } } -impl BitXor for u8 { fn bitxor(self, other: u8) -> u8 { self ^ other } } - -impl BitXor for i8 { fn bitxor(self, other: i8) -> i8 { self ^ other } } -impl BitXor for i32 { fn bitxor(self, other: i32) -> i32 { self ^ other } } -impl BitXor for i64 { fn bitxor(self, other: i64) -> i64 { self ^ other } } - -// docs:start:shl-trait -trait Shl { - fn shl(self, other: u8) -> Self; -} -// docs:end:shl-trait - -impl Shl for u32 { fn shl(self, other: u8) -> u32 { self << other } } -impl Shl for u64 { fn shl(self, other: u8) -> u64 { self << other } } -impl Shl for u8 { fn shl(self, other: u8) -> u8 { self << other } } -impl Shl for u1 { fn shl(self, other: u8) -> u1 { self << other } } - -impl Shl for i8 { fn shl(self, other: u8) -> i8 { self << other } } -impl Shl for i32 { fn shl(self, other: u8) -> i32 { self << other } } -impl Shl for i64 { fn shl(self, other: u8) -> i64 { self << other } } - -// docs:start:shr-trait -trait Shr { - fn shr(self, other: u8) -> Self; -} -// docs:end:shr-trait - -impl Shr for u64 { fn shr(self, other: u8) -> u64 { self >> other } } -impl Shr for u32 { fn shr(self, other: u8) -> u32 { self >> other } } -impl Shr for u8 { fn shr(self, other: u8) -> u8 { self >> other } } -impl Shr for u1 { fn shr(self, other: u8) -> u1 { self >> other } } - -impl Shr for i8 { fn shr(self, other: u8) -> i8 { self >> other } } -impl Shr for i32 { fn shr(self, other: u8) -> i32 { self >> other } } -impl Shr for i64 { fn shr(self, other: u8) -> i64 { self >> other } } +mod arith; +mod bit; +use arith::{Add, Sub, Mul, Div, Rem, Neg}; +use bit::{Not, BitOr, BitAnd, BitXor, Shl, Shr}; diff --git a/noir/noir-repo/noir_stdlib/src/ops/arith.nr b/noir/noir-repo/noir_stdlib/src/ops/arith.nr new file mode 100644 index 00000000000..df0ff978a7c --- /dev/null +++ b/noir/noir-repo/noir_stdlib/src/ops/arith.nr @@ -0,0 +1,103 @@ +// docs:start:add-trait +trait Add { + fn add(self, other: Self) -> Self; +} +// docs:end:add-trait + +impl Add for Field { fn add(self, other: Field) -> Field { self + other } } + +impl Add for u64 { fn add(self, other: u64) -> u64 { self + other } } +impl Add for u32 { fn add(self, other: u32) -> u32 { self + other } } +impl Add for u16 { fn add(self, other: u16) -> u16 { self + other } } +impl Add for u8 { fn add(self, other: u8) -> u8 { self + other } } + +impl Add for i8 { fn add(self, other: i8) -> i8 { self + other } } +impl Add for i16 { fn add(self, other: i16) -> i16 { self + other } } +impl Add for i32 { fn add(self, other: i32) -> i32 { self + other } } +impl Add for i64 { fn add(self, other: i64) -> i64 { self + other } } + +// docs:start:sub-trait +trait Sub { + fn sub(self, other: Self) -> Self; +} +// docs:end:sub-trait + +impl Sub for Field { fn sub(self, other: Field) -> Field { self - other } } + +impl Sub for u64 { fn sub(self, other: u64) -> u64 { self - other } } +impl Sub for u32 { fn sub(self, other: u32) -> u32 { self - other } } +impl Sub for u16 { fn sub(self, other: u16) -> u16 { self - other } } +impl Sub for u8 { fn sub(self, other: u8) -> u8 { self - other } } + +impl Sub for i8 { fn sub(self, other: i8) -> i8 { self - other } } +impl Sub for i16 { fn sub(self, other: i16) -> i16 { self - other } } +impl Sub for i32 { fn sub(self, other: i32) -> i32 { self - other } } +impl Sub for i64 { fn sub(self, other: i64) -> i64 { self - other } } + +// docs:start:mul-trait +trait Mul { + fn mul(self, other: Self) -> Self; +} +// docs:end:mul-trait + +impl Mul for Field { fn mul(self, other: Field) -> Field { self * other } } + +impl Mul for u64 { fn mul(self, other: u64) -> u64 { self * other } } +impl Mul for u32 { fn mul(self, other: u32) -> u32 { self * other } } +impl Mul for u16 { fn mul(self, other: u16) -> u16 { self * other } } +impl Mul for u8 { fn mul(self, other: u8) -> u8 { self * other } } + +impl Mul for i8 { fn mul(self, other: i8) -> i8 { self * other } } +impl Mul for i16 { fn mul(self, other: i16) -> i16 { self * other } } +impl Mul for i32 { fn mul(self, other: i32) -> i32 { self * other } } +impl Mul for i64 { fn mul(self, other: i64) -> i64 { self * other } } + +// docs:start:div-trait +trait Div { + fn div(self, other: Self) -> Self; +} +// docs:end:div-trait + +impl Div for Field { fn div(self, other: Field) -> Field { self / other } } + +impl Div for u64 { fn div(self, other: u64) -> u64 { self / other } } +impl Div for u32 { fn div(self, other: u32) -> u32 { self / other } } +impl Div for u16 { fn div(self, other: u16) -> u16 { self / other } } +impl Div for u8 { fn div(self, other: u8) -> u8 { self / other } } + +impl Div for i8 { fn div(self, other: i8) -> i8 { self / other } } +impl Div for i16 { fn div(self, other: i16) -> i16 { self / other } } +impl Div for i32 { fn div(self, other: i32) -> i32 { self / other } } +impl Div for i64 { fn div(self, other: i64) -> i64 { self / other } } + +// docs:start:rem-trait +trait Rem{ + fn rem(self, other: Self) -> Self; +} +// docs:end:rem-trait + +impl Rem for u64 { fn rem(self, other: u64) -> u64 { self % other } } +impl Rem for u32 { fn rem(self, other: u32) -> u32 { self % other } } +impl Rem for u16 { fn rem(self, other: u16) -> u16 { self % other } } +impl Rem for u8 { fn rem(self, other: u8) -> u8 { self % other } } + +impl Rem for i8 { fn rem(self, other: i8) -> i8 { self % other } } +impl Rem for i16 { fn rem(self, other: i16) -> i16 { self % other } } +impl Rem for i32 { fn rem(self, other: i32) -> i32 { self % other } } +impl Rem for i64 { fn rem(self, other: i64) -> i64 { self % other } } + +// docs:start:neg-trait +trait Neg { + fn neg(self) -> Self; +} +// docs:end:neg-trait + +// docs:start:neg-trait-impls +impl Neg for Field { fn neg(self) -> Field { -self } } + +impl Neg for i8 { fn neg(self) -> i8 { -self } } +impl Neg for i16 { fn neg(self) -> i16 { -self } } +impl Neg for i32 { fn neg(self) -> i32 { -self } } +impl Neg for i64 { fn neg(self) -> i64 { -self } } +// docs:end:neg-trait-impls + diff --git a/noir/noir-repo/noir_stdlib/src/ops/bit.nr b/noir/noir-repo/noir_stdlib/src/ops/bit.nr new file mode 100644 index 00000000000..a31cfee878c --- /dev/null +++ b/noir/noir-repo/noir_stdlib/src/ops/bit.nr @@ -0,0 +1,109 @@ +// docs:start:not-trait +trait Not { + fn not(self: Self) -> Self; +} +// docs:end:not-trait + +// docs:start:not-trait-impls +impl Not for bool { fn not(self) -> bool { !self } } + +impl Not for u64 { fn not(self) -> u64 { !self } } +impl Not for u32 { fn not(self) -> u32 { !self } } +impl Not for u16 { fn not(self) -> u16 { !self } } +impl Not for u8 { fn not(self) -> u8 { !self } } +impl Not for u1 { fn not(self) -> u1 { !self } } + +impl Not for i8 { fn not(self) -> i8 { !self } } +impl Not for i16 { fn not(self) -> i16 { !self } } +impl Not for i32 { fn not(self) -> i32 { !self } } +impl Not for i64 { fn not(self) -> i64 { !self } } +// docs:end:not-trait-impls + +// docs:start:bitor-trait +trait BitOr { + fn bitor(self, other: Self) -> Self; +} +// docs:end:bitor-trait + +impl BitOr for bool { fn bitor(self, other: bool) -> bool { self | other } } + +impl BitOr for u64 { fn bitor(self, other: u64) -> u64 { self | other } } +impl BitOr for u32 { fn bitor(self, other: u32) -> u32 { self | other } } +impl BitOr for u16 { fn bitor(self, other: u16) -> u16 { self | other } } +impl BitOr for u8 { fn bitor(self, other: u8) -> u8 { self | other } } + +impl BitOr for i8 { fn bitor(self, other: i8) -> i8 { self | other } } +impl BitOr for i16 { fn bitor(self, other: i16) -> i16 { self | other } } +impl BitOr for i32 { fn bitor(self, other: i32) -> i32 { self | other } } +impl BitOr for i64 { fn bitor(self, other: i64) -> i64 { self | other } } + +// docs:start:bitand-trait +trait BitAnd { + fn bitand(self, other: Self) -> Self; +} +// docs:end:bitand-trait + +impl BitAnd for bool { fn bitand(self, other: bool) -> bool { self & other } } + +impl BitAnd for u64 { fn bitand(self, other: u64) -> u64 { self & other } } +impl BitAnd for u32 { fn bitand(self, other: u32) -> u32 { self & other } } +impl BitAnd for u16 { fn bitand(self, other: u16) -> u16 { self & other } } +impl BitAnd for u8 { fn bitand(self, other: u8) -> u8 { self & other } } + +impl BitAnd for i8 { fn bitand(self, other: i8) -> i8 { self & other } } +impl BitAnd for i16 { fn bitand(self, other: i16) -> i16 { self & other } } +impl BitAnd for i32 { fn bitand(self, other: i32) -> i32 { self & other } } +impl BitAnd for i64 { fn bitand(self, other: i64) -> i64 { self & other } } + +// docs:start:bitxor-trait +trait BitXor { + fn bitxor(self, other: Self) -> Self; +} +// docs:end:bitxor-trait + +impl BitXor for bool { fn bitxor(self, other: bool) -> bool { self ^ other } } + +impl BitXor for u64 { fn bitxor(self, other: u64) -> u64 { self ^ other } } +impl BitXor for u32 { fn bitxor(self, other: u32) -> u32 { self ^ other } } +impl BitXor for u16 { fn bitxor(self, other: u16) -> u16 { self ^ other } } +impl BitXor for u8 { fn bitxor(self, other: u8) -> u8 { self ^ other } } + +impl BitXor for i8 { fn bitxor(self, other: i8) -> i8 { self ^ other } } +impl BitXor for i16 { fn bitxor(self, other: i16) -> i16 { self ^ other } } +impl BitXor for i32 { fn bitxor(self, other: i32) -> i32 { self ^ other } } +impl BitXor for i64 { fn bitxor(self, other: i64) -> i64 { self ^ other } } + +// docs:start:shl-trait +trait Shl { + fn shl(self, other: u8) -> Self; +} +// docs:end:shl-trait + +impl Shl for u32 { fn shl(self, other: u8) -> u32 { self << other } } +impl Shl for u64 { fn shl(self, other: u8) -> u64 { self << other } } +impl Shl for u16 { fn shl(self, other: u8) -> u16 { self << other } } +impl Shl for u8 { fn shl(self, other: u8) -> u8 { self << other } } +impl Shl for u1 { fn shl(self, other: u8) -> u1 { self << other } } + +impl Shl for i8 { fn shl(self, other: u8) -> i8 { self << other } } +impl Shl for i16 { fn shl(self, other: u8) -> i16 { self << other } } +impl Shl for i32 { fn shl(self, other: u8) -> i32 { self << other } } +impl Shl for i64 { fn shl(self, other: u8) -> i64 { self << other } } + +// docs:start:shr-trait +trait Shr { + fn shr(self, other: u8) -> Self; +} +// docs:end:shr-trait + +impl Shr for u64 { fn shr(self, other: u8) -> u64 { self >> other } } +impl Shr for u32 { fn shr(self, other: u8) -> u32 { self >> other } } +impl Shr for u16 { fn shr(self, other: u8) -> u16 { self >> other } } +impl Shr for u8 { fn shr(self, other: u8) -> u8 { self >> other } } +impl Shr for u1 { fn shr(self, other: u8) -> u1 { self >> other } } + +impl Shr for i8 { fn shr(self, other: u8) -> i8 { self >> other } } +impl Shr for i16 { fn shr(self, other: u8) -> i16 { self >> other } } +impl Shr for i32 { fn shr(self, other: u8) -> i32 { self >> other } } +impl Shr for i64 { fn shr(self, other: u8) -> i64 { self >> other } } + diff --git a/noir/noir-repo/noir_stdlib/src/uint128.nr b/noir/noir-repo/noir_stdlib/src/uint128.nr index d0f38079e6f..173fa54863a 100644 --- a/noir/noir-repo/noir_stdlib/src/uint128.nr +++ b/noir/noir-repo/noir_stdlib/src/uint128.nr @@ -1,8 +1,9 @@ -use crate::ops::{Add, Sub, Mul, Div, Rem, BitOr, BitAnd, BitXor, Shl, Shr}; +use crate::ops::{Add, Sub, Mul, Div, Rem, Not, BitOr, BitAnd, BitXor, Shl, Shr}; use crate::cmp::{Eq, Ord, Ordering}; +use crate::println; global pow64 : Field = 18446744073709551616; //2^64; - +global pow63 : Field = 9223372036854775808; // 2^63; struct U128 { lo: Field, hi: Field, @@ -20,6 +21,13 @@ impl U128 { U128::from_u64s_le(lo, hi) } + pub fn zero() -> U128 { + U128 { lo: 0, hi: 0 } + } + + pub fn one() -> U128 { + U128 { lo: 1, hi: 0 } + } pub fn from_le_bytes(bytes: [u8; 16]) -> U128 { let mut lo = 0; let mut base = 1; @@ -87,27 +95,44 @@ impl U128 { U128 { lo: lo as Field, hi: hi as Field } } + unconstrained fn uconstrained_check_is_upper_ascii(ascii: u8) -> bool { + ((ascii >= 65) & (ascii <= 90)) // Between 'A' and 'Z' + } + fn decode_ascii(ascii: u8) -> Field { if ascii < 58 { ascii - 48 - } else if ascii < 71 { - ascii - 55 } else { + let ascii = ascii + 32 * (U128::uconstrained_check_is_upper_ascii(ascii) as u8); + assert(ascii >= 97); // enforce >= 'a' + assert(ascii <= 102); // enforce <= 'f' ascii - 87 } as Field } + // TODO: Replace with a faster version. + // A circuit that uses this function can be slow to compute + // (we're doing up to 127 calls to compute the quotient) unconstrained fn unconstrained_div(self: Self, b: U128) -> (U128, U128) { - if self < b { - (U128::from_u64s_le(0, 0), self) + if b == U128::zero() { + // Return 0,0 to avoid eternal loop + (U128::zero(), U128::zero()) + } else if self < b { + (U128::zero(), self) + } else if self == b { + (U128::one(), U128::zero()) } else { - //TODO check if this can overflow? - let (q,r) = self.unconstrained_div(b * U128::from_u64s_le(2, 0)); + let (q,r) = if b.hi as u64 >= pow63 as u64 { + // The result of multiplication by 2 would overflow + (U128::zero(), self) + } else { + self.unconstrained_div(b * U128::from_u64s_le(2, 0)) + }; let q_mul_2 = q * U128::from_u64s_le(2, 0); if r < b { (q_mul_2, r) } else { - (q_mul_2 + U128::from_u64s_le(1, 0), r - b) + (q_mul_2 + U128::one(), r - b) } } } @@ -129,11 +154,7 @@ impl U128 { let low = self.lo * b.lo; let lo = low as u64 as Field; let carry = (low - lo) / pow64; - let high = if crate::field::modulus_num_bits() as u32 > 196 { - (self.lo + self.hi) * (b.lo + b.hi) - low + carry - } else { - self.lo * b.hi + self.hi * b.lo + carry - }; + let high = self.lo * b.hi + self.hi * b.lo + carry; let hi = high as u64 as Field; U128 { lo, hi } } @@ -228,11 +249,20 @@ impl Ord for U128 { } } +impl Not for U128 { + fn not(self) -> U128 { + U128 { + lo: (!(self.lo as u64)) as Field, + hi: (!(self.hi as u64)) as Field + } + } +} + impl BitOr for U128 { fn bitor(self, other: U128) -> U128 { U128 { lo: ((self.lo as u64) | (other.lo as u64)) as Field, - hi: ((self.hi as u64) | (other.hi as u64))as Field + hi: ((self.hi as u64) | (other.hi as u64)) as Field } } } @@ -284,3 +314,213 @@ impl Shr for U128 { self / U128::from_integer(y) } } + +mod tests { + use crate::uint128::{U128, pow64, pow63}; + + #[test] + fn test_not() { + let num = U128::from_u64s_le(0, 0); + let not_num = num.not(); + + let max_u64: Field = pow64 - 1; + assert_eq(not_num.hi, max_u64); + assert_eq(not_num.lo, max_u64); + + let not_not_num = not_num.not(); + assert_eq(num, not_not_num); + } + #[test] + fn test_construction() { + // Check little-endian u64 is inversed with big-endian u64 construction + let a = U128::from_u64s_le(2, 1); + let b = U128::from_u64s_be(1, 2); + assert_eq(a, b); + // Check byte construction is equivalent + let c = U128::from_le_bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]); + let d = U128::from_u64s_le(0x0706050403020100, 0x0f0e0d0c0b0a0908); + assert_eq(c, d); + } + #[test] + fn test_byte_decomposition() { + let a = U128::from_u64s_le(0x0706050403020100, 0x0f0e0d0c0b0a0908); + // Get big-endian and little-endian byte decompostions + let le_bytes_a= a.to_le_bytes(); + let be_bytes_a= a.to_be_bytes(); + + // Check equivalence + for i in 0..16 { + assert_eq(le_bytes_a[i], be_bytes_a[15 - i]); + } + // Reconstruct U128 from byte decomposition + let b= U128::from_le_bytes(le_bytes_a); + // Check that it's the same element + assert_eq(a, b); + } + #[test] + fn test_hex_constuction() { + let a = U128::from_u64s_le(0x1, 0x2); + let b = U128::from_hex("0x20000000000000001"); + assert_eq(a, b); + + let c= U128::from_hex("0xffffffffffffffffffffffffffffffff"); + let d= U128::from_u64s_le(0xffffffffffffffff, 0xffffffffffffffff); + assert_eq(c, d); + + let e= U128::from_hex("0x00000000000000000000000000000000"); + let f= U128::from_u64s_le(0, 0); + assert_eq(e, f); + } + + // Ascii decode tests + + #[test] + fn test_ascii_decode_correct_range() { + // '0'..'9' range + for i in 0..10 { + let decoded= U128::decode_ascii(48 + i); + assert_eq(decoded, i as Field); + } + // 'A'..'F' range + for i in 0..6 { + let decoded = U128::decode_ascii(65 + i); + assert_eq(decoded, (i + 10) as Field); + } + // 'a'..'f' range + for i in 0..6 { + let decoded = U128::decode_ascii(97 + i); + assert_eq(decoded, (i + 10) as Field); + } + } + + #[test(should_fail)] + fn test_ascii_decode_range_less_than_48_fails_0() { + crate::println(U128::decode_ascii(0)); + } + #[test(should_fail)] + fn test_ascii_decode_range_less_than_48_fails_1() { + crate::println(U128::decode_ascii(47)); + } + + #[test(should_fail)] + fn test_ascii_decode_range_58_64_fails_0() { + let _ = U128::decode_ascii(58); + } + #[test(should_fail)] + fn test_ascii_decode_range_58_64_fails_1() { + let _ = U128::decode_ascii(64); + } + #[test(should_fail)] + fn test_ascii_decode_range_71_96_fails_0() { + let _ = U128::decode_ascii(71); + } + #[test(should_fail)] + fn test_ascii_decode_range_71_96_fails_1() { + let _ = U128::decode_ascii(96); + } + #[test(should_fail)] + fn test_ascii_decode_range_greater_than_102_fails() { + let _ = U128::decode_ascii(103); + } + + #[test(should_fail)] + fn test_ascii_decode_regression() { + // This code will actually fail because of ascii_decode, + // but in the past it was possible to create a value > (1<<128) + let a = U128::from_hex("0x~fffffffffffffffffffffffffffffff"); + let b:Field= a.to_integer(); + let c= b.to_le_bytes(17); + assert(c[16] != 0); + } + + #[test] + fn test_unconstrained_div() { + // Test the potential overflow case + let a= U128::from_u64s_le(0x0, 0xffffffffffffffff); + let b= U128::from_u64s_le(0x0, 0xfffffffffffffffe); + let c= U128::one(); + let d= U128::from_u64s_le(0x0, 0x1); + let (q,r) = a.unconstrained_div(b); + assert_eq(q, c); + assert_eq(r, d); + + let a = U128::from_u64s_le(2, 0); + let b = U128::one(); + // Check the case where a is a multiple of b + let (c,d ) = a.unconstrained_div(b); + assert_eq((c, d), (a, U128::zero())); + + // Check where b is a multiple of a + let (c,d) = b.unconstrained_div(a); + assert_eq((c, d), (U128::zero(), b)); + + // Dividing by zero returns 0,0 + let a = U128::from_u64s_le(0x1, 0x0); + let b = U128::zero(); + let (c,d)= a.unconstrained_div(b); + assert_eq((c, d), (U128::zero(), U128::zero())); + + // Dividing 1<<127 by 1<<127 (special case) + let a = U128::from_u64s_le(0x0, pow63 as u64); + let b = U128::from_u64s_le(0x0, pow63 as u64); + let (c,d )= a.unconstrained_div(b); + assert_eq((c, d), (U128::one(), U128::zero())); + } + + #[test] + fn integer_conversions() { + // Maximum + let start:Field = 0xffffffffffffffffffffffffffffffff; + let a = U128::from_integer(start); + let end = a.to_integer(); + assert_eq(start, end); + + // Minimum + let start:Field = 0x0; + let a = U128::from_integer(start); + let end = a.to_integer(); + assert_eq(start, end); + + // Low limb + let start:Field = 0xffffffffffffffff; + let a = U128::from_integer(start); + let end = a.to_integer(); + assert_eq(start, end); + + // High limb + let start:Field = 0xffffffffffffffff0000000000000000; + let a = U128::from_integer(start); + let end = a.to_integer(); + assert_eq(start, end); + } + #[test] + fn test_wrapping_mul() { + // 1*0==0 + assert_eq(U128::zero(), U128::zero().wrapping_mul(U128::one())); + + // 0*1==0 + assert_eq(U128::zero(), U128::one().wrapping_mul(U128::zero())); + + // 1*1==1 + assert_eq(U128::one(), U128::one().wrapping_mul(U128::one())); + + // 0 * ( 1 << 64 ) == 0 + assert_eq(U128::zero(), U128::zero().wrapping_mul(U128::from_u64s_le(0, 1))); + + // ( 1 << 64 ) * 0 == 0 + assert_eq(U128::zero(), U128::from_u64s_le(0, 1).wrapping_mul(U128::zero())); + + // 1 * ( 1 << 64 ) == 1 << 64 + assert_eq(U128::from_u64s_le(0, 1), U128::from_u64s_le(0, 1).wrapping_mul(U128::one())); + + // ( 1 << 64 ) * 1 == 1 << 64 + assert_eq(U128::from_u64s_le(0, 1), U128::one().wrapping_mul(U128::from_u64s_le(0, 1))); + + // ( 1 << 64 ) * ( 1 << 64 ) == 1 << 64 + assert_eq(U128::zero(), U128::from_u64s_le(0, 1).wrapping_mul(U128::from_u64s_le(0, 1))); + // -1 * -1 == 1 + assert_eq( + U128::one(), U128::from_u64s_le(0xffffffffffffffff, 0xffffffffffffffff).wrapping_mul(U128::from_u64s_le(0xffffffffffffffff, 0xffffffffffffffff)) + ); + } +} diff --git a/noir/noir-repo/scripts/count_loc.sh b/noir/noir-repo/scripts/count_loc.sh new file mode 100755 index 00000000000..91565aa6c4a --- /dev/null +++ b/noir/noir-repo/scripts/count_loc.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +set -eu + +# Run relative to repo root +cd $(dirname "$0")/../ + +if ! command -v "tokei" >/dev/null 2>&1; then + echo "Error: tokei is required but not installed." >&2 + echo "Error: Run \`cargo install --git https://github.com/TomAFrench/tokei --branch tf/add-noir-support tokei\`" >&2 + + exit 1 +fi + +echo "" +echo "Total:" + +tokei ./ --sort code + +echo "" +echo "ACIR/ACVM:" +tokei ./acvm-repo --sort code + +echo "" +echo "Compiler:" +tokei ./compiler --sort code + +echo "" +echo "Tooling:" +tokei ./tooling --sort code + +echo "" +echo "Standard Library:" +tokei ./noir_stdlib --sort code diff --git a/noir/noir-repo/security/insectarium/noir_stdlib.md b/noir/noir-repo/security/insectarium/noir_stdlib.md new file mode 100644 index 00000000000..5ec4eb5f6cd --- /dev/null +++ b/noir/noir-repo/security/insectarium/noir_stdlib.md @@ -0,0 +1,61 @@ +# Bugs found in Noir stdlib + +## U128 + +### decode_ascii +Old **decode_ascii** function didn't check that the values of individual bytes in the string were just in the range of [0-9a-f-A-F]. +```rust +fn decode_ascii(ascii: u8) -> Field { + if ascii < 58 { + ascii - 48 + } else if ascii < 71 { + ascii - 55 + } else { + ascii - 87 + } as Field +} +``` +Since the function used the assumption that decode_ascii returns values in range [0,15] to construct **lo** and **hi** it was possible to overflow these 64-bit limbs. + +### unconstrained_div +```rust + unconstrained fn unconstrained_div(self: Self, b: U128) -> (U128, U128) { + if self < b { + (U128::from_u64s_le(0, 0), self) + } else { + //TODO check if this can overflow? + let (q,r) = self.unconstrained_div(b * U128::from_u64s_le(2, 0)); + let q_mul_2 = q * U128::from_u64s_le(2, 0); + if r < b { + (q_mul_2, r) + } else { + (q_mul_2 + U128::from_u64s_le(1, 0), r - b) + } + } + } +``` +There were 2 issues in unconstrained_div: +1) Attempting to divide by zero resulted in an infinite loop, because there was no check. +2) $a >= 2^{127}$ cause the function to multiply b to such power of 2 that the result would be more than $2^{128}$ and lead to assertion failure even though it was a legitimate input + +N.B. initial fix by Rumata888 also had an edgecase missing for when a==b and b >= (1<<127). + +### wrapping_mul +```rust +fn wrapping_mul(self: Self, b: U128) -> U128 { + let low = self.lo * b.lo; + let lo = low as u64 as Field; + let carry = (low - lo) / pow64; + let high = if crate::field::modulus_num_bits() as u32 > 196 { + (self.lo + self.hi) * (b.lo + b.hi) - low + carry // Bug + } else { + self.lo * b.hi + self.hi * b.lo + carry + }; + let hi = high as u64 as Field; + U128 { lo, hi } + } +``` +Wrapping mul had the code copied from regular mul barring the assertion that the product of high limbs is zero. Because that check was removed, the optimized path for moduli > 196 bits was incorrect, since it included their product (as at least one of them was supposed to be zero originally, but not for wrapping multiplication) + + + diff --git a/noir/noir-repo/test_programs/execution_success/brillig_embedded_curve/src/main.nr b/noir/noir-repo/test_programs/execution_success/brillig_embedded_curve/src/main.nr index 1a183bb13d9..8a1a7f08975 100644 --- a/noir/noir-repo/test_programs/execution_success/brillig_embedded_curve/src/main.nr +++ b/noir/noir-repo/test_programs/execution_success/brillig_embedded_curve/src/main.nr @@ -1,10 +1,6 @@ use dep::std; -unconstrained fn main( - priv_key: Field, - pub_x: pub Field, - pub_y: pub Field, -) { +unconstrained fn main(priv_key: Field, pub_x: pub Field, pub_y: pub Field) { let g1_y = 17631683881184975370165255887551781615748388533673675138860; let g1 = std::embedded_curve_ops::EmbeddedCurvePoint { x: 1, y: g1_y }; diff --git a/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Nargo.toml b/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Nargo.toml new file mode 100644 index 00000000000..328d78c8f99 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "no_predicates_brillig" +type = "bin" +authors = [""] +compiler_version = ">=0.27.0" + +[dependencies] diff --git a/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Prover.toml b/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Prover.toml new file mode 100644 index 00000000000..93a825f609f --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/Prover.toml @@ -0,0 +1,2 @@ +x = "10" +y = "20" diff --git a/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/src/main.nr b/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/src/main.nr new file mode 100644 index 00000000000..65e2e5d61fe --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/no_predicates_brillig/src/main.nr @@ -0,0 +1,16 @@ +unconstrained fn main(x: u32, y: pub u32) { + intermediate_function(x, y); +} + +fn intermediate_function(x: u32, y: u32) { + basic_checks(x, y); +} + +#[no_predicates] +fn basic_checks(x: u32, y: u32) { + if x > y { + assert(x == 10); + } else { + assert(y == 20); + } +} diff --git a/noir/noir-repo/test_programs/execution_success/u16_support/Nargo.toml b/noir/noir-repo/test_programs/execution_success/u16_support/Nargo.toml new file mode 100644 index 00000000000..1c6b58e01e8 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/u16_support/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "u16_support" +type = "bin" +authors = [""] +compiler_version = ">=0.29.0" + +[dependencies] \ No newline at end of file diff --git a/noir/noir-repo/test_programs/execution_success/u16_support/Prover.toml b/noir/noir-repo/test_programs/execution_success/u16_support/Prover.toml new file mode 100644 index 00000000000..a56a84e61a4 --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/u16_support/Prover.toml @@ -0,0 +1 @@ +x = "2" diff --git a/noir/noir-repo/test_programs/execution_success/u16_support/src/main.nr b/noir/noir-repo/test_programs/execution_success/u16_support/src/main.nr new file mode 100644 index 00000000000..e8b418f16da --- /dev/null +++ b/noir/noir-repo/test_programs/execution_success/u16_support/src/main.nr @@ -0,0 +1,24 @@ +fn main(x: u16) { + test_u16(x); + test_u16_unconstrained(x); +} + +unconstrained fn test_u16_unconstrained(x: u16) { + test_u16(x) +} + +fn test_u16(x: u16) { + let t1: u16 = 1234; + let t2: u16 = 4321; + let t = t1 + t2; + + let t4 = t - t2; + assert(t4 == t1); + + let mut small_int = x as u16; + let shift = small_int << (x as u8); + assert(shift == 8); + assert(shift >> (x as u8) == small_int); + assert(shift >> 15 == 0); + assert(shift << 15 == 0); +} diff --git a/noir/noir-repo/tooling/backend_interface/Cargo.toml b/noir/noir-repo/tooling/backend_interface/Cargo.toml index b731c138c7d..f6b5d5d0132 100644 --- a/noir/noir-repo/tooling/backend_interface/Cargo.toml +++ b/noir/noir-repo/tooling/backend_interface/Cargo.toml @@ -13,7 +13,6 @@ license.workspace = true acvm.workspace = true dirs.workspace = true thiserror.workspace = true -serde.workspace = true serde_json.workspace = true bb_abstraction_leaks.workspace = true tracing.workspace = true diff --git a/noir/noir-repo/tooling/backend_interface/src/cli/info.rs b/noir/noir-repo/tooling/backend_interface/src/cli/info.rs deleted file mode 100644 index 6e6603ce53e..00000000000 --- a/noir/noir-repo/tooling/backend_interface/src/cli/info.rs +++ /dev/null @@ -1,62 +0,0 @@ -use acvm::acir::circuit::ExpressionWidth; - -use serde::Deserialize; -use std::path::{Path, PathBuf}; - -use crate::BackendError; - -use super::string_from_stderr; - -pub(crate) struct InfoCommand { - pub(crate) crs_path: PathBuf, -} - -#[derive(Deserialize)] -struct InfoResponse { - language: LanguageResponse, -} - -#[derive(Deserialize)] -struct LanguageResponse { - name: String, - width: Option, -} - -impl InfoCommand { - pub(crate) fn run(self, binary_path: &Path) -> Result { - let mut command = std::process::Command::new(binary_path); - - command.arg("info").arg("-c").arg(self.crs_path).arg("-o").arg("-"); - - let output = command.output()?; - - if !output.status.success() { - return Err(BackendError::CommandFailed(string_from_stderr(&output.stderr))); - } - - let backend_info: InfoResponse = - serde_json::from_slice(&output.stdout).expect("Backend should return valid json"); - let expression_width: ExpressionWidth = match backend_info.language.name.as_str() { - "PLONK-CSAT" => { - let width = backend_info.language.width.unwrap(); - ExpressionWidth::Bounded { width } - } - "R1CS" => ExpressionWidth::Unbounded, - _ => panic!("Unknown Expression width configuration"), - }; - - Ok(expression_width) - } -} - -#[test] -fn info_command() -> Result<(), BackendError> { - let backend = crate::get_mock_backend()?; - let crs_path = backend.backend_directory(); - - let expression_width = InfoCommand { crs_path }.run(backend.binary_path())?; - - assert!(matches!(expression_width, ExpressionWidth::Bounded { width: 4 })); - - Ok(()) -} diff --git a/noir/noir-repo/tooling/backend_interface/src/cli/mod.rs b/noir/noir-repo/tooling/backend_interface/src/cli/mod.rs index b4dec859839..df43bd5cc2f 100644 --- a/noir/noir-repo/tooling/backend_interface/src/cli/mod.rs +++ b/noir/noir-repo/tooling/backend_interface/src/cli/mod.rs @@ -2,7 +2,6 @@ mod contract; mod gates; -mod info; mod proof_as_fields; mod prove; mod verify; @@ -12,7 +11,6 @@ mod write_vk; pub(crate) use contract::ContractCommand; pub(crate) use gates::GatesCommand; -pub(crate) use info::InfoCommand; pub(crate) use proof_as_fields::ProofAsFieldsCommand; pub(crate) use prove::ProveCommand; pub(crate) use verify::VerifyCommand; diff --git a/noir/noir-repo/tooling/backend_interface/src/proof_system.rs b/noir/noir-repo/tooling/backend_interface/src/proof_system.rs index fa1f82a5722..20a6dcf70f1 100644 --- a/noir/noir-repo/tooling/backend_interface/src/proof_system.rs +++ b/noir/noir-repo/tooling/backend_interface/src/proof_system.rs @@ -3,7 +3,7 @@ use std::io::Write; use std::path::Path; use acvm::acir::{ - circuit::{ExpressionWidth, Program}, + circuit::Program, native_types::{WitnessMap, WitnessStack}, }; use acvm::FieldElement; @@ -11,8 +11,8 @@ use tempfile::tempdir; use tracing::warn; use crate::cli::{ - GatesCommand, InfoCommand, ProofAsFieldsCommand, ProveCommand, VerifyCommand, - VkAsFieldsCommand, WriteVkCommand, + GatesCommand, ProofAsFieldsCommand, ProveCommand, VerifyCommand, VkAsFieldsCommand, + WriteVkCommand, }; use crate::{Backend, BackendError}; @@ -33,25 +33,6 @@ impl Backend { .run(binary_path) } - pub fn get_backend_info(&self) -> Result { - let binary_path = self.assert_binary_exists()?; - self.assert_correct_version()?; - InfoCommand { crs_path: self.crs_directory() }.run(binary_path) - } - - /// If we cannot get a valid backend, returns `ExpressionWidth::Bound { width: 4 }`` - /// The function also prints a message saying we could not find a backend - pub fn get_backend_info_or_default(&self) -> ExpressionWidth { - if let Ok(expression_width) = self.get_backend_info() { - expression_width - } else { - warn!( - "No valid backend found, ExpressionWidth defaulting to Bounded with a width of 4" - ); - ExpressionWidth::Bounded { width: 4 } - } - } - #[tracing::instrument(level = "trace", skip_all)] pub fn prove( &self, diff --git a/noir/noir-repo/tooling/backend_interface/test-binaries/mock_backend/src/info_cmd.rs b/noir/noir-repo/tooling/backend_interface/test-binaries/mock_backend/src/info_cmd.rs deleted file mode 100644 index cdaebb95fc9..00000000000 --- a/noir/noir-repo/tooling/backend_interface/test-binaries/mock_backend/src/info_cmd.rs +++ /dev/null @@ -1,40 +0,0 @@ -use clap::Args; -use std::io::Write; -use std::path::PathBuf; - -const INFO_RESPONSE: &str = r#"{ - "language": { - "name": "PLONK-CSAT", - "width": 4 - }, - "opcodes_supported": ["arithmetic", "directive", "brillig", "memory_init", "memory_op"], - "black_box_functions_supported": [ - "and", - "xor", - "range", - "sha256", - "blake2s", - "blake3", - "keccak256", - "schnorr_verify", - "pedersen", - "pedersen_hash", - "ecdsa_secp256k1", - "ecdsa_secp256r1", - "multi_scalar_mul", - "recursive_aggregation" - ] -}"#; - -#[derive(Debug, Clone, Args)] -pub(crate) struct InfoCommand { - #[clap(short = 'c')] - pub(crate) crs_path: Option, - - #[clap(short = 'o')] - pub(crate) info_path: Option, -} - -pub(crate) fn run(_args: InfoCommand) { - std::io::stdout().write_all(INFO_RESPONSE.as_bytes()).unwrap(); -} diff --git a/noir/noir-repo/tooling/backend_interface/test-binaries/mock_backend/src/main.rs b/noir/noir-repo/tooling/backend_interface/test-binaries/mock_backend/src/main.rs index ef8819af94b..74ea82d28f8 100644 --- a/noir/noir-repo/tooling/backend_interface/test-binaries/mock_backend/src/main.rs +++ b/noir/noir-repo/tooling/backend_interface/test-binaries/mock_backend/src/main.rs @@ -7,7 +7,6 @@ use clap::{Parser, Subcommand}; mod contract_cmd; mod gates_cmd; -mod info_cmd; mod prove_cmd; mod verify_cmd; mod write_vk_cmd; @@ -21,7 +20,6 @@ struct BackendCli { #[derive(Subcommand, Clone, Debug)] enum BackendCommand { - Info(info_cmd::InfoCommand), Contract(contract_cmd::ContractCommand), Gates(gates_cmd::GatesCommand), Prove(prove_cmd::ProveCommand), @@ -34,7 +32,6 @@ fn main() { let BackendCli { command } = BackendCli::parse(); match command { - BackendCommand::Info(args) => info_cmd::run(args), BackendCommand::Contract(args) => contract_cmd::run(args), BackendCommand::Gates(args) => gates_cmd::run(args), BackendCommand::Prove(args) => prove_cmd::run(args), diff --git a/noir/noir-repo/tooling/bb_abstraction_leaks/build.rs b/noir/noir-repo/tooling/bb_abstraction_leaks/build.rs index b3dfff9e94c..45da7f9d00c 100644 --- a/noir/noir-repo/tooling/bb_abstraction_leaks/build.rs +++ b/noir/noir-repo/tooling/bb_abstraction_leaks/build.rs @@ -10,7 +10,7 @@ use const_format::formatcp; const USERNAME: &str = "AztecProtocol"; const REPO: &str = "aztec-packages"; -const VERSION: &str = "0.35.1"; +const VERSION: &str = "0.38.0"; const TAG: &str = formatcp!("aztec-packages-v{}", VERSION); const API_URL: &str = diff --git a/noir/noir-repo/tooling/lsp/src/lib.rs b/noir/noir-repo/tooling/lsp/src/lib.rs index be9b83e02f6..05345b96c80 100644 --- a/noir/noir-repo/tooling/lsp/src/lib.rs +++ b/noir/noir-repo/tooling/lsp/src/lib.rs @@ -345,7 +345,7 @@ fn prepare_package_from_source_string() { let mut state = LspState::new(&client, acvm::blackbox_solver::StubbedBlackBoxSolver); let (mut context, crate_id) = crate::prepare_source(source.to_string(), &mut state); - let _check_result = noirc_driver::check_crate(&mut context, crate_id, false, false); + let _check_result = noirc_driver::check_crate(&mut context, crate_id, false, false, false); let main_func_id = context.get_main_function(&crate_id); assert!(main_func_id.is_some()); } diff --git a/noir/noir-repo/tooling/lsp/src/notifications/mod.rs b/noir/noir-repo/tooling/lsp/src/notifications/mod.rs index 355bb7832c4..3856bdc79e9 100644 --- a/noir/noir-repo/tooling/lsp/src/notifications/mod.rs +++ b/noir/noir-repo/tooling/lsp/src/notifications/mod.rs @@ -56,7 +56,7 @@ pub(super) fn on_did_change_text_document( state.input_files.insert(params.text_document.uri.to_string(), text.clone()); let (mut context, crate_id) = prepare_source(text, state); - let _ = check_crate(&mut context, crate_id, false, false); + let _ = check_crate(&mut context, crate_id, false, false, false); let workspace = match resolve_workspace_for_source_path( params.text_document.uri.to_file_path().unwrap().as_path(), @@ -139,7 +139,7 @@ fn process_noir_document( let (mut context, crate_id) = prepare_package(&workspace_file_manager, &parsed_files, package); - let file_diagnostics = match check_crate(&mut context, crate_id, false, false) { + let file_diagnostics = match check_crate(&mut context, crate_id, false, false, false) { Ok(((), warnings)) => warnings, Err(errors_and_warnings) => errors_and_warnings, }; diff --git a/noir/noir-repo/tooling/lsp/src/requests/code_lens_request.rs b/noir/noir-repo/tooling/lsp/src/requests/code_lens_request.rs index 893ba33d845..744bddedd9d 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/code_lens_request.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/code_lens_request.rs @@ -67,7 +67,7 @@ fn on_code_lens_request_inner( let (mut context, crate_id) = prepare_source(source_string, state); // We ignore the warnings and errors produced by compilation for producing code lenses // because we can still get the test functions even if compilation fails - let _ = check_crate(&mut context, crate_id, false, false); + let _ = check_crate(&mut context, crate_id, false, false, false); let collected_lenses = collect_lenses_for_package(&context, crate_id, &workspace, package, None); diff --git a/noir/noir-repo/tooling/lsp/src/requests/goto_declaration.rs b/noir/noir-repo/tooling/lsp/src/requests/goto_declaration.rs index 8e6d519b895..5cff16b2348 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/goto_declaration.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/goto_declaration.rs @@ -46,7 +46,7 @@ fn on_goto_definition_inner( interner = def_interner; } else { // We ignore the warnings and errors produced by compilation while resolving the definition - let _ = noirc_driver::check_crate(&mut context, crate_id, false, false); + let _ = noirc_driver::check_crate(&mut context, crate_id, false, false, false); interner = &context.def_interner; } diff --git a/noir/noir-repo/tooling/lsp/src/requests/goto_definition.rs b/noir/noir-repo/tooling/lsp/src/requests/goto_definition.rs index 88bb667f2e8..32e13ce00f6 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/goto_definition.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/goto_definition.rs @@ -54,7 +54,7 @@ fn on_goto_definition_inner( interner = def_interner; } else { // We ignore the warnings and errors produced by compilation while resolving the definition - let _ = noirc_driver::check_crate(&mut context, crate_id, false, false); + let _ = noirc_driver::check_crate(&mut context, crate_id, false, false, false); interner = &context.def_interner; } diff --git a/noir/noir-repo/tooling/lsp/src/requests/test_run.rs b/noir/noir-repo/tooling/lsp/src/requests/test_run.rs index 1844a3d9bf0..83b05ba06a2 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/test_run.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/test_run.rs @@ -60,7 +60,7 @@ fn on_test_run_request_inner( Some(package) => { let (mut context, crate_id) = prepare_package(&workspace_file_manager, &parsed_files, package); - if check_crate(&mut context, crate_id, false, false).is_err() { + if check_crate(&mut context, crate_id, false, false, false).is_err() { let result = NargoTestRunResult { id: params.id.clone(), result: "error".to_string(), diff --git a/noir/noir-repo/tooling/lsp/src/requests/tests.rs b/noir/noir-repo/tooling/lsp/src/requests/tests.rs index 5b78fcc65c3..cdf4ad338c4 100644 --- a/noir/noir-repo/tooling/lsp/src/requests/tests.rs +++ b/noir/noir-repo/tooling/lsp/src/requests/tests.rs @@ -61,7 +61,7 @@ fn on_tests_request_inner( prepare_package(&workspace_file_manager, &parsed_files, package); // We ignore the warnings and errors produced by compilation for producing tests // because we can still get the test functions even if compilation fails - let _ = check_crate(&mut context, crate_id, false, false); + let _ = check_crate(&mut context, crate_id, false, false, false); // We don't add test headings for a package if it contains no `#[test]` functions get_package_tests_in_crate(&context, &crate_id, &package.name) diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/check_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/check_cmd.rs index 2b729e44b8a..d5313d96076 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/check_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/check_cmd.rs @@ -1,4 +1,3 @@ -use crate::backends::Backend; use crate::errors::CliError; use clap::Args; @@ -42,11 +41,7 @@ pub(crate) struct CheckCommand { compile_options: CompileOptions, } -pub(crate) fn run( - _backend: &Backend, - args: CheckCommand, - config: NargoConfig, -) -> Result<(), CliError> { +pub(crate) fn run(args: CheckCommand, config: NargoConfig) -> Result<(), CliError> { let toml_path = get_package_manifest(&config.program_dir)?; let default_selection = if args.workspace { PackageSelection::All } else { PackageSelection::DefaultOrAll }; @@ -92,6 +87,7 @@ fn check_package( compile_options.deny_warnings, compile_options.disable_macros, compile_options.silence_warnings, + compile_options.use_elaborator, )?; if package.is_library() || package.is_contract() { @@ -178,8 +174,9 @@ pub(crate) fn check_crate_and_report_errors( deny_warnings: bool, disable_macros: bool, silence_warnings: bool, + use_elaborator: bool, ) -> Result<(), CompileError> { - let result = check_crate(context, crate_id, deny_warnings, disable_macros); + let result = check_crate(context, crate_id, deny_warnings, disable_macros, use_elaborator); report_errors(result, &context.file_manager, deny_warnings, silence_warnings) } diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/codegen_verifier_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/codegen_verifier_cmd.rs index 259e209b65a..8c64d9cd935 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/codegen_verifier_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/codegen_verifier_cmd.rs @@ -44,7 +44,6 @@ pub(crate) fn run( insert_all_files_for_workspace_into_file_manager(&workspace, &mut workspace_file_manager); let parsed_files = parse_all(&workspace_file_manager); - let expression_width = backend.get_backend_info()?; let binary_packages = workspace.into_iter().filter(|package| package.is_binary()); for package in binary_packages { let compilation_result = compile_program( @@ -62,7 +61,7 @@ pub(crate) fn run( args.compile_options.silence_warnings, )?; - let program = nargo::ops::transform_program(program, expression_width); + let program = nargo::ops::transform_program(program, args.compile_options.expression_width); // TODO(https://github.com/noir-lang/noir/issues/4428): // We do not expect to have a smart contract verifier for a foldable program with multiple circuits. diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/compile_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/compile_cmd.rs index 54e8535f094..2f878406939 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/compile_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/compile_cmd.rs @@ -20,7 +20,6 @@ use noirc_frontend::hir::ParsedFiles; use notify::{EventKind, RecursiveMode, Watcher}; use notify_debouncer_full::new_debouncer; -use crate::backends::Backend; use crate::errors::CliError; use super::fs::program::only_acir; @@ -47,11 +46,7 @@ pub(crate) struct CompileCommand { watch: bool, } -pub(crate) fn run( - backend: &Backend, - mut args: CompileCommand, - config: NargoConfig, -) -> Result<(), CliError> { +pub(crate) fn run(args: CompileCommand, config: NargoConfig) -> Result<(), CliError> { let toml_path = get_package_manifest(&config.program_dir)?; let default_selection = if args.workspace { PackageSelection::All } else { PackageSelection::DefaultOrAll }; @@ -63,10 +58,6 @@ pub(crate) fn run( Some(NOIR_ARTIFACT_VERSION_STRING.to_owned()), )?; - if args.compile_options.expression_width.is_none() { - args.compile_options.expression_width = Some(backend.get_backend_info_or_default()); - }; - if args.watch { watch_workspace(&workspace, &args.compile_options) .map_err(|err| CliError::Generic(err.to_string()))?; @@ -128,8 +119,6 @@ fn compile_workspace_full( insert_all_files_for_workspace_into_file_manager(workspace, &mut workspace_file_manager); let parsed_files = parse_all(&workspace_file_manager); - let expression_width = - compile_options.expression_width.expect("expression width should have been set"); let compiled_workspace = compile_workspace(&workspace_file_manager, &parsed_files, workspace, compile_options); @@ -149,12 +138,12 @@ fn compile_workspace_full( // Save build artifacts to disk. let only_acir = compile_options.only_acir; for (package, program) in binary_packages.into_iter().zip(compiled_programs) { - let program = nargo::ops::transform_program(program, expression_width); + let program = nargo::ops::transform_program(program, compile_options.expression_width); save_program(program.clone(), &package, &workspace.target_directory_path(), only_acir); } let circuit_dir = workspace.target_directory_path(); for (package, contract) in contract_packages.into_iter().zip(compiled_contracts) { - let contract = nargo::ops::transform_contract(contract, expression_width); + let contract = nargo::ops::transform_contract(contract, compile_options.expression_width); save_contract(contract, &package, &circuit_dir); } diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/dap_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/dap_cmd.rs index ba4f91609ef..124e30069ae 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/dap_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/dap_cmd.rs @@ -1,6 +1,5 @@ use acvm::acir::circuit::ExpressionWidth; use acvm::acir::native_types::WitnessMap; -use backend_interface::Backend; use clap::Args; use nargo::constants::PROVER_INPUT_FILE; use nargo::workspace::Workspace; @@ -29,8 +28,8 @@ use noir_debugger::errors::{DapError, LoadError}; #[derive(Debug, Clone, Args)] pub(crate) struct DapCommand { /// Override the expression width requested by the backend. - #[arg(long, value_parser = parse_expression_width)] - expression_width: Option, + #[arg(long, value_parser = parse_expression_width, default_value = "4")] + expression_width: ExpressionWidth, #[clap(long)] preflight_check: bool, @@ -249,14 +248,7 @@ fn run_preflight_check( Ok(()) } -pub(crate) fn run( - backend: &Backend, - args: DapCommand, - _config: NargoConfig, -) -> Result<(), CliError> { - let expression_width = - args.expression_width.unwrap_or_else(|| backend.get_backend_info_or_default()); - +pub(crate) fn run(args: DapCommand, _config: NargoConfig) -> Result<(), CliError> { // When the --preflight-check flag is present, we run Noir's DAP server in "pre-flight mode", which test runs // the DAP initialization code without actually starting the DAP server. // @@ -270,12 +262,12 @@ pub(crate) fn run( // the DAP loop is established, which otherwise are considered "out of band" by the maintainers of the DAP spec. // More details here: https://github.com/microsoft/vscode/issues/108138 if args.preflight_check { - return run_preflight_check(expression_width, args).map_err(CliError::DapError); + return run_preflight_check(args.expression_width, args).map_err(CliError::DapError); } let output = BufWriter::new(std::io::stdout()); let input = BufReader::new(std::io::stdin()); let server = Server::new(input, output); - loop_uninitialized_dap(server, expression_width).map_err(CliError::DapError) + loop_uninitialized_dap(server, args.expression_width).map_err(CliError::DapError) } diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/debug_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/debug_cmd.rs index 7cb5cd7846b..f950cd0405c 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/debug_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/debug_cmd.rs @@ -24,7 +24,6 @@ use noirc_frontend::hir::ParsedFiles; use super::fs::{inputs::read_inputs_from_file, witness::save_witness_to_dir}; use super::NargoConfig; -use crate::backends::Backend; use crate::errors::CliError; /// Executes a circuit in debug mode @@ -53,11 +52,7 @@ pub(crate) struct DebugCommand { skip_instrumentation: Option, } -pub(crate) fn run( - backend: &Backend, - args: DebugCommand, - config: NargoConfig, -) -> Result<(), CliError> { +pub(crate) fn run(args: DebugCommand, config: NargoConfig) -> Result<(), CliError> { let acir_mode = args.acir_mode; let skip_instrumentation = args.skip_instrumentation.unwrap_or(acir_mode); @@ -69,10 +64,6 @@ pub(crate) fn run( Some(NOIR_ARTIFACT_VERSION_STRING.to_string()), )?; let target_dir = &workspace.target_directory_path(); - let expression_width = args - .compile_options - .expression_width - .unwrap_or_else(|| backend.get_backend_info_or_default()); let Some(package) = workspace.into_iter().find(|p| p.is_binary()) else { println!( @@ -89,7 +80,8 @@ pub(crate) fn run( args.compile_options.clone(), )?; - let compiled_program = nargo::ops::transform_program(compiled_program, expression_width); + let compiled_program = + nargo::ops::transform_program(compiled_program, args.compile_options.expression_width); run_async(package, compiled_program, &args.prover_name, &args.witness_name, target_dir) } diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/execute_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/execute_cmd.rs index 854ad559012..68f902dfe33 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/execute_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/execute_cmd.rs @@ -18,7 +18,6 @@ use noirc_frontend::graph::CrateName; use super::fs::{inputs::read_inputs_from_file, witness::save_witness_to_dir}; use super::NargoConfig; -use crate::backends::Backend; use crate::errors::CliError; /// Executes a circuit to calculate its return value @@ -48,11 +47,7 @@ pub(crate) struct ExecuteCommand { oracle_resolver: Option, } -pub(crate) fn run( - backend: &Backend, - args: ExecuteCommand, - config: NargoConfig, -) -> Result<(), CliError> { +pub(crate) fn run(args: ExecuteCommand, config: NargoConfig) -> Result<(), CliError> { let toml_path = get_package_manifest(&config.program_dir)?; let default_selection = if args.workspace { PackageSelection::All } else { PackageSelection::DefaultOrAll }; @@ -68,10 +63,6 @@ pub(crate) fn run( insert_all_files_for_workspace_into_file_manager(&workspace, &mut workspace_file_manager); let parsed_files = parse_all(&workspace_file_manager); - let expression_width = args - .compile_options - .expression_width - .unwrap_or_else(|| backend.get_backend_info_or_default()); let binary_packages = workspace.into_iter().filter(|package| package.is_binary()); for package in binary_packages { let compilation_result = compile_program( @@ -89,7 +80,8 @@ pub(crate) fn run( args.compile_options.silence_warnings, )?; - let compiled_program = nargo::ops::transform_program(compiled_program, expression_width); + let compiled_program = + nargo::ops::transform_program(compiled_program, args.compile_options.expression_width); let (return_value, witness_stack) = execute_program_and_decode( compiled_program, diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/export_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/export_cmd.rs index 044c2cb4ebb..324eed340ad 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/export_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/export_cmd.rs @@ -20,7 +20,6 @@ use noirc_frontend::graph::CrateName; use clap::Args; -use crate::backends::Backend; use crate::errors::CliError; use super::check_cmd::check_crate_and_report_errors; @@ -43,11 +42,7 @@ pub(crate) struct ExportCommand { compile_options: CompileOptions, } -pub(crate) fn run( - _backend: &Backend, - args: ExportCommand, - config: NargoConfig, -) -> Result<(), CliError> { +pub(crate) fn run(args: ExportCommand, config: NargoConfig) -> Result<(), CliError> { let toml_path = get_package_manifest(&config.program_dir)?; let default_selection = if args.workspace { PackageSelection::All } else { PackageSelection::DefaultOrAll }; @@ -94,6 +89,7 @@ fn compile_exported_functions( compile_options.deny_warnings, compile_options.disable_macros, compile_options.silence_warnings, + compile_options.use_elaborator, )?; let exported_functions = context.get_all_exported_functions_in_crate(&crate_id); diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/info_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/info_cmd.rs index 3695fb57d31..cac3c36f904 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/info_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/info_cmd.rs @@ -70,10 +70,6 @@ pub(crate) fn run( insert_all_files_for_workspace_into_file_manager(&workspace, &mut workspace_file_manager); let parsed_files = parse_all(&workspace_file_manager); - let expression_width = args - .compile_options - .expression_width - .unwrap_or_else(|| backend.get_backend_info_or_default()); let compiled_workspace = compile_workspace( &workspace_file_manager, &parsed_files, @@ -89,10 +85,10 @@ pub(crate) fn run( )?; let compiled_programs = vecmap(compiled_programs, |program| { - nargo::ops::transform_program(program, expression_width) + nargo::ops::transform_program(program, args.compile_options.expression_width) }); let compiled_contracts = vecmap(compiled_contracts, |contract| { - nargo::ops::transform_contract(contract, expression_width) + nargo::ops::transform_contract(contract, args.compile_options.expression_width) }); if args.profile_info { @@ -122,13 +118,24 @@ pub(crate) fn run( let program_info = binary_packages .par_bridge() .map(|(package, program)| { - count_opcodes_and_gates_in_program(backend, program, package, expression_width) + count_opcodes_and_gates_in_program( + backend, + program, + package, + args.compile_options.expression_width, + ) }) .collect::>()?; let contract_info = compiled_contracts .into_par_iter() - .map(|contract| count_opcodes_and_gates_in_contract(backend, contract, expression_width)) + .map(|contract| { + count_opcodes_and_gates_in_contract( + backend, + contract, + args.compile_options.expression_width, + ) + }) .collect::>()?; let info_report = InfoReport { programs: program_info, contracts: contract_info }; diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/lsp_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/lsp_cmd.rs index 1428b8070c8..45ac02ea552 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/lsp_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/lsp_cmd.rs @@ -8,7 +8,6 @@ use noir_lsp::NargoLspService; use tower::ServiceBuilder; use super::NargoConfig; -use crate::backends::Backend; use crate::errors::CliError; /// Starts the Noir LSP server @@ -19,12 +18,7 @@ use crate::errors::CliError; #[derive(Debug, Clone, Args)] pub(crate) struct LspCommand; -pub(crate) fn run( - // Backend is currently unused, but we might want to use it to inform the lsp in the future - _backend: &Backend, - _args: LspCommand, - _config: NargoConfig, -) -> Result<(), CliError> { +pub(crate) fn run(_args: LspCommand, _config: NargoConfig) -> Result<(), CliError> { use tokio::runtime::Builder; let runtime = Builder::new_current_thread().enable_all().build().unwrap(); diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/mod.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/mod.rs index e8e17893815..ad778549ac0 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/mod.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/mod.rs @@ -107,21 +107,21 @@ pub(crate) fn start_cli() -> eyre::Result<()> { let backend = crate::backends::Backend::new(active_backend); match command { - NargoCommand::New(args) => new_cmd::run(&backend, args, config), + NargoCommand::New(args) => new_cmd::run(args, config), NargoCommand::Init(args) => init_cmd::run(args, config), - NargoCommand::Check(args) => check_cmd::run(&backend, args, config), - NargoCommand::Compile(args) => compile_cmd::run(&backend, args, config), - NargoCommand::Debug(args) => debug_cmd::run(&backend, args, config), - NargoCommand::Execute(args) => execute_cmd::run(&backend, args, config), - NargoCommand::Export(args) => export_cmd::run(&backend, args, config), + NargoCommand::Check(args) => check_cmd::run(args, config), + NargoCommand::Compile(args) => compile_cmd::run(args, config), + NargoCommand::Debug(args) => debug_cmd::run(args, config), + NargoCommand::Execute(args) => execute_cmd::run(args, config), + NargoCommand::Export(args) => export_cmd::run(args, config), NargoCommand::Prove(args) => prove_cmd::run(&backend, args, config), NargoCommand::Verify(args) => verify_cmd::run(&backend, args, config), - NargoCommand::Test(args) => test_cmd::run(&backend, args, config), + NargoCommand::Test(args) => test_cmd::run(args, config), NargoCommand::Info(args) => info_cmd::run(&backend, args, config), NargoCommand::CodegenVerifier(args) => codegen_verifier_cmd::run(&backend, args, config), NargoCommand::Backend(args) => backend_cmd::run(args), - NargoCommand::Lsp(args) => lsp_cmd::run(&backend, args, config), - NargoCommand::Dap(args) => dap_cmd::run(&backend, args, config), + NargoCommand::Lsp(args) => lsp_cmd::run(args, config), + NargoCommand::Dap(args) => dap_cmd::run(args, config), NargoCommand::Fmt(args) => fmt_cmd::run(args, config), }?; diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/new_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/new_cmd.rs index b4c823d0c1e..21951f27260 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/new_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/new_cmd.rs @@ -1,4 +1,3 @@ -use crate::backends::Backend; use crate::errors::CliError; use super::{init_cmd::initialize_project, NargoConfig}; @@ -30,12 +29,7 @@ pub(crate) struct NewCommand { pub(crate) contract: bool, } -pub(crate) fn run( - // Backend is currently unused, but we might want to use it to inform the "new" template in the future - _backend: &Backend, - args: NewCommand, - config: NargoConfig, -) -> Result<(), CliError> { +pub(crate) fn run(args: NewCommand, config: NargoConfig) -> Result<(), CliError> { let package_dir = config.program_dir.join(&args.path); if package_dir.exists() { diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/prove_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/prove_cmd.rs index b9e4bca9e69..47c71527fd8 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/prove_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/prove_cmd.rs @@ -69,10 +69,6 @@ pub(crate) fn run( insert_all_files_for_workspace_into_file_manager(&workspace, &mut workspace_file_manager); let parsed_files = parse_all(&workspace_file_manager); - let expression_width = args - .compile_options - .expression_width - .unwrap_or_else(|| backend.get_backend_info_or_default()); let binary_packages = workspace.into_iter().filter(|package| package.is_binary()); for package in binary_packages { let compilation_result = compile_program( @@ -90,7 +86,8 @@ pub(crate) fn run( args.compile_options.silence_warnings, )?; - let compiled_program = nargo::ops::transform_program(compiled_program, expression_width); + let compiled_program = + nargo::ops::transform_program(compiled_program, args.compile_options.expression_width); prove_package( backend, diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/test_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/test_cmd.rs index 88a804d5cf4..51e21248afd 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/test_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/test_cmd.rs @@ -19,7 +19,7 @@ use noirc_frontend::{ use rayon::prelude::{IntoParallelIterator, ParallelBridge, ParallelIterator}; use termcolor::{Color, ColorChoice, ColorSpec, StandardStream, WriteColor}; -use crate::{backends::Backend, cli::check_cmd::check_crate_and_report_errors, errors::CliError}; +use crate::{cli::check_cmd::check_crate_and_report_errors, errors::CliError}; use super::NargoConfig; @@ -54,11 +54,7 @@ pub(crate) struct TestCommand { oracle_resolver: Option, } -pub(crate) fn run( - _backend: &Backend, - args: TestCommand, - config: NargoConfig, -) -> Result<(), CliError> { +pub(crate) fn run(args: TestCommand, config: NargoConfig) -> Result<(), CliError> { let toml_path = get_package_manifest(&config.program_dir)?; let default_selection = if args.workspace { PackageSelection::All } else { PackageSelection::DefaultOrAll }; @@ -179,6 +175,7 @@ fn run_test( crate_id, compile_options.deny_warnings, compile_options.disable_macros, + compile_options.use_elaborator, ) .expect("Any errors should have occurred when collecting test functions"); @@ -212,6 +209,7 @@ fn get_tests_in_package( compile_options.deny_warnings, compile_options.disable_macros, compile_options.silence_warnings, + compile_options.use_elaborator, )?; Ok(context diff --git a/noir/noir-repo/tooling/nargo_cli/src/cli/verify_cmd.rs b/noir/noir-repo/tooling/nargo_cli/src/cli/verify_cmd.rs index 7202a179aae..a6078f6c1d3 100644 --- a/noir/noir-repo/tooling/nargo_cli/src/cli/verify_cmd.rs +++ b/noir/noir-repo/tooling/nargo_cli/src/cli/verify_cmd.rs @@ -54,10 +54,6 @@ pub(crate) fn run( insert_all_files_for_workspace_into_file_manager(&workspace, &mut workspace_file_manager); let parsed_files = parse_all(&workspace_file_manager); - let expression_width = args - .compile_options - .expression_width - .unwrap_or_else(|| backend.get_backend_info_or_default()); let binary_packages = workspace.into_iter().filter(|package| package.is_binary()); for package in binary_packages { let compilation_result = compile_program( @@ -75,7 +71,8 @@ pub(crate) fn run( args.compile_options.silence_warnings, )?; - let compiled_program = nargo::ops::transform_program(compiled_program, expression_width); + let compiled_program = + nargo::ops::transform_program(compiled_program, args.compile_options.expression_width); verify_package(backend, &workspace, package, compiled_program, &args.verifier_name)?; } diff --git a/noir/noir-repo/tooling/nargo_cli/tests/stdlib-tests.rs b/noir/noir-repo/tooling/nargo_cli/tests/stdlib-tests.rs index 9d377cfaee9..70a9354f50a 100644 --- a/noir/noir-repo/tooling/nargo_cli/tests/stdlib-tests.rs +++ b/noir/noir-repo/tooling/nargo_cli/tests/stdlib-tests.rs @@ -10,8 +10,7 @@ use nargo::{ parse_all, prepare_package, }; -#[test] -fn stdlib_noir_tests() { +fn run_stdlib_tests(use_elaborator: bool) { let mut file_manager = file_manager_with_stdlib(&PathBuf::from(".")); file_manager.add_file_with_source_canonical_path(&PathBuf::from("main.nr"), "".to_owned()); let parsed_files = parse_all(&file_manager); @@ -30,7 +29,7 @@ fn stdlib_noir_tests() { let (mut context, dummy_crate_id) = prepare_package(&file_manager, &parsed_files, &dummy_package); - let result = check_crate(&mut context, dummy_crate_id, true, false); + let result = check_crate(&mut context, dummy_crate_id, true, false, use_elaborator); report_errors(result, &context.file_manager, true, false) .expect("Error encountered while compiling standard library"); @@ -60,3 +59,15 @@ fn stdlib_noir_tests() { assert!(!test_report.is_empty(), "Could not find any tests within the stdlib"); assert!(test_report.iter().all(|(_, status)| !status.failed())); } + +#[test] +fn stdlib_noir_tests() { + run_stdlib_tests(false) +} + +// Once this no longer panics we can use the elaborator by default and remove the old passes +#[test] +#[should_panic] +fn stdlib_elaborator_tests() { + run_stdlib_tests(true) +} diff --git a/noir/noir-repo/tooling/noir_js/test/node/execute.test.ts b/noir/noir-repo/tooling/noir_js/test/node/execute.test.ts index b2e76e54efc..d047e35035f 100644 --- a/noir/noir-repo/tooling/noir_js/test/node/execute.test.ts +++ b/noir/noir-repo/tooling/noir_js/test/node/execute.test.ts @@ -117,3 +117,15 @@ it('successfully executes a program with multiple acir circuits', async () => { expect(knownError.message).to.equal('Circuit execution failed: Error: Cannot satisfy constraint'); } }); + +it('successfully executes a program with multiple acir circuits', async () => { + const inputs = { + x: '10', + }; + try { + await new Noir(fold_fibonacci_program).execute(inputs); + } catch (error) { + const knownError = error as Error; + expect(knownError.message).to.equal('Circuit execution failed: Error: Cannot satisfy constraint'); + } +}); diff --git a/noir/noir-repo/tooling/noir_js_backend_barretenberg/package.json b/noir/noir-repo/tooling/noir_js_backend_barretenberg/package.json index c6985f4b037..3368dcd8a09 100644 --- a/noir/noir-repo/tooling/noir_js_backend_barretenberg/package.json +++ b/noir/noir-repo/tooling/noir_js_backend_barretenberg/package.json @@ -42,7 +42,7 @@ "lint": "NODE_NO_WARNINGS=1 eslint . --ext .ts --ignore-path ./.eslintignore --max-warnings 0" }, "dependencies": { - "@aztec/bb.js": "portal:../../../../barretenberg/ts", + "@aztec/bb.js": "0.38.0", "@noir-lang/types": "workspace:*", "fflate": "^0.8.0" }, diff --git a/noir/noir-repo/tooling/noirc_abi/src/lib.rs b/noir/noir-repo/tooling/noirc_abi/src/lib.rs index 7e89a102a98..7a1d1787ca5 100644 --- a/noir/noir-repo/tooling/noirc_abi/src/lib.rs +++ b/noir/noir-repo/tooling/noirc_abi/src/lib.rs @@ -471,7 +471,15 @@ impl Abi { .copied() }) { - Some(decode_value(&mut return_witness_values.into_iter(), &return_type.abi_type)?) + // We do not return value for the data bus. + if return_type.visibility == AbiVisibility::DataBus { + None + } else { + Some(decode_value( + &mut return_witness_values.into_iter(), + &return_type.abi_type, + )?) + } } else { // Unlike for the circuit inputs, we tolerate not being able to find the witness values for the return value. // This is because the user may be decoding a partial witness map for which is hasn't been calculated yet. diff --git a/noir/noir-repo/yarn.lock b/noir/noir-repo/yarn.lock index b45678f5d8b..85966ce3392 100644 --- a/noir/noir-repo/yarn.lock +++ b/noir/noir-repo/yarn.lock @@ -221,18 +221,19 @@ __metadata: languageName: node linkType: hard -"@aztec/bb.js@portal:../../../../barretenberg/ts::locator=%40noir-lang%2Fbackend_barretenberg%40workspace%3Atooling%2Fnoir_js_backend_barretenberg": - version: 0.0.0-use.local - resolution: "@aztec/bb.js@portal:../../../../barretenberg/ts::locator=%40noir-lang%2Fbackend_barretenberg%40workspace%3Atooling%2Fnoir_js_backend_barretenberg" +"@aztec/bb.js@npm:0.38.0": + version: 0.38.0 + resolution: "@aztec/bb.js@npm:0.38.0" dependencies: comlink: ^4.4.1 commander: ^10.0.1 debug: ^4.3.4 tslib: ^2.4.0 bin: - bb.js: ./dest/node/main.js + bb.js: dest/node/main.js + checksum: 5ebc2850f37993db1d0fe4306ec612e9df14c5d227e1451f1b2f96e63e61c64225c46b32d1e1d2a1a0c37795e50b2875362520e9eb49324312516ec9fd6de2c7 languageName: node - linkType: soft + linkType: hard "@babel/code-frame@npm:^7.0.0, @babel/code-frame@npm:^7.10.4, @babel/code-frame@npm:^7.12.11, @babel/code-frame@npm:^7.16.0, @babel/code-frame@npm:^7.22.13, @babel/code-frame@npm:^7.23.5, @babel/code-frame@npm:^7.8.3": version: 7.23.5 @@ -4395,7 +4396,7 @@ __metadata: version: 0.0.0-use.local resolution: "@noir-lang/backend_barretenberg@workspace:tooling/noir_js_backend_barretenberg" dependencies: - "@aztec/bb.js": "portal:../../../../barretenberg/ts" + "@aztec/bb.js": 0.38.0 "@noir-lang/types": "workspace:*" "@types/node": ^20.6.2 "@types/prettier": ^3 From f8d1e1399e04ec77e4483cf24bc55daf273db10e Mon Sep 17 00:00:00 2001 From: Tom French Date: Tue, 21 May 2024 11:48:51 +0100 Subject: [PATCH 2/3] chore: update bb.js install --- .../noir_js_backend_barretenberg/package.json | 2 +- noir/noir-repo/yarn.lock | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/noir/noir-repo/tooling/noir_js_backend_barretenberg/package.json b/noir/noir-repo/tooling/noir_js_backend_barretenberg/package.json index 3368dcd8a09..c6985f4b037 100644 --- a/noir/noir-repo/tooling/noir_js_backend_barretenberg/package.json +++ b/noir/noir-repo/tooling/noir_js_backend_barretenberg/package.json @@ -42,7 +42,7 @@ "lint": "NODE_NO_WARNINGS=1 eslint . --ext .ts --ignore-path ./.eslintignore --max-warnings 0" }, "dependencies": { - "@aztec/bb.js": "0.38.0", + "@aztec/bb.js": "portal:../../../../barretenberg/ts", "@noir-lang/types": "workspace:*", "fflate": "^0.8.0" }, diff --git a/noir/noir-repo/yarn.lock b/noir/noir-repo/yarn.lock index 85966ce3392..b45678f5d8b 100644 --- a/noir/noir-repo/yarn.lock +++ b/noir/noir-repo/yarn.lock @@ -221,19 +221,18 @@ __metadata: languageName: node linkType: hard -"@aztec/bb.js@npm:0.38.0": - version: 0.38.0 - resolution: "@aztec/bb.js@npm:0.38.0" +"@aztec/bb.js@portal:../../../../barretenberg/ts::locator=%40noir-lang%2Fbackend_barretenberg%40workspace%3Atooling%2Fnoir_js_backend_barretenberg": + version: 0.0.0-use.local + resolution: "@aztec/bb.js@portal:../../../../barretenberg/ts::locator=%40noir-lang%2Fbackend_barretenberg%40workspace%3Atooling%2Fnoir_js_backend_barretenberg" dependencies: comlink: ^4.4.1 commander: ^10.0.1 debug: ^4.3.4 tslib: ^2.4.0 bin: - bb.js: dest/node/main.js - checksum: 5ebc2850f37993db1d0fe4306ec612e9df14c5d227e1451f1b2f96e63e61c64225c46b32d1e1d2a1a0c37795e50b2875362520e9eb49324312516ec9fd6de2c7 + bb.js: ./dest/node/main.js languageName: node - linkType: hard + linkType: soft "@babel/code-frame@npm:^7.0.0, @babel/code-frame@npm:^7.10.4, @babel/code-frame@npm:^7.12.11, @babel/code-frame@npm:^7.16.0, @babel/code-frame@npm:^7.22.13, @babel/code-frame@npm:^7.23.5, @babel/code-frame@npm:^7.8.3": version: 7.23.5 @@ -4396,7 +4395,7 @@ __metadata: version: 0.0.0-use.local resolution: "@noir-lang/backend_barretenberg@workspace:tooling/noir_js_backend_barretenberg" dependencies: - "@aztec/bb.js": 0.38.0 + "@aztec/bb.js": "portal:../../../../barretenberg/ts" "@noir-lang/types": "workspace:*" "@types/node": ^20.6.2 "@types/prettier": ^3 From 830965c791b4a6fbb61d1eab55761905f9652da1 Mon Sep 17 00:00:00 2001 From: Tom French Date: Tue, 21 May 2024 11:49:59 +0100 Subject: [PATCH 3/3] chore: revert changes to u128 --- noir/noir-repo/noir_stdlib/src/uint128.nr | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/noir/noir-repo/noir_stdlib/src/uint128.nr b/noir/noir-repo/noir_stdlib/src/uint128.nr index 173fa54863a..0332c8ac865 100644 --- a/noir/noir-repo/noir_stdlib/src/uint128.nr +++ b/noir/noir-repo/noir_stdlib/src/uint128.nr @@ -330,6 +330,7 @@ mod tests { let not_not_num = not_num.not(); assert_eq(num, not_not_num); } + #[test] fn test_construction() { // Check little-endian u64 is inversed with big-endian u64 construction @@ -341,6 +342,7 @@ mod tests { let d = U128::from_u64s_le(0x0706050403020100, 0x0f0e0d0c0b0a0908); assert_eq(c, d); } + #[test] fn test_byte_decomposition() { let a = U128::from_u64s_le(0x0706050403020100, 0x0f0e0d0c0b0a0908); @@ -357,6 +359,7 @@ mod tests { // Check that it's the same element assert_eq(a, b); } + #[test] fn test_hex_constuction() { let a = U128::from_u64s_le(0x1, 0x2); @@ -397,6 +400,7 @@ mod tests { fn test_ascii_decode_range_less_than_48_fails_0() { crate::println(U128::decode_ascii(0)); } + #[test(should_fail)] fn test_ascii_decode_range_less_than_48_fails_1() { crate::println(U128::decode_ascii(47)); @@ -406,18 +410,22 @@ mod tests { fn test_ascii_decode_range_58_64_fails_0() { let _ = U128::decode_ascii(58); } + #[test(should_fail)] fn test_ascii_decode_range_58_64_fails_1() { let _ = U128::decode_ascii(64); } + #[test(should_fail)] fn test_ascii_decode_range_71_96_fails_0() { let _ = U128::decode_ascii(71); } + #[test(should_fail)] fn test_ascii_decode_range_71_96_fails_1() { let _ = U128::decode_ascii(96); } + #[test(should_fail)] fn test_ascii_decode_range_greater_than_102_fails() { let _ = U128::decode_ascii(103);