Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Oracle mocker for nargo test #2928

Merged
merged 5 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions compiler/noirc_frontend/src/hir/def_map/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ impl CrateDefMap {
.value_definitions()
.filter_map(|id| {
id.as_function().map(|function_id| {
let is_entry_point = !interner
.function_attributes(&function_id)
.has_contract_library_method();
let attributes = interner.function_attributes(&function_id);
let is_entry_point = !attributes.has_contract_library_method()
&& !attributes.is_test_function();
ContractFunctionMeta { function_id, is_entry_point }
})
})
Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_frontend/src/lexer/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@
match string.trim() {
"should_fail" => Some(TestScope::ShouldFailWith { reason: None }),
s if s.starts_with("should_fail_with") => {
let parts: Vec<&str> = s.splitn(2, '=').collect();

Check warning on line 335 in compiler/noirc_frontend/src/lexer/token.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (splitn)
if parts.len() == 2 {
let reason = parts[1].trim();
let reason = reason.trim_matches('"');
Expand Down Expand Up @@ -384,6 +384,10 @@
.any(|attribute| attribute == &SecondaryAttribute::ContractLibraryMethod)
}

pub fn is_test_function(&self) -> bool {
matches!(self.function, Some(FunctionAttribute::Test(_)))
}

/// Returns note if a deprecated secondary attribute is found
pub fn get_deprecated_note(&self) -> Option<Option<String>> {
self.secondary.iter().find_map(|attr| match attr {
Expand Down Expand Up @@ -646,7 +650,7 @@
Keyword::Field => write!(f, "Field"),
Keyword::Fn => write!(f, "fn"),
Keyword::For => write!(f, "for"),
Keyword::FormatString => write!(f, "fmtstr"),

Check warning on line 653 in compiler/noirc_frontend/src/lexer/token.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (fmtstr)
Keyword::Global => write!(f, "global"),
Keyword::If => write!(f, "if"),
Keyword::Impl => write!(f, "impl"),
Expand Down Expand Up @@ -689,7 +693,7 @@
"Field" => Keyword::Field,
"fn" => Keyword::Fn,
"for" => Keyword::For,
"fmtstr" => Keyword::FormatString,

Check warning on line 696 in compiler/noirc_frontend/src/lexer/token.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (fmtstr)
"global" => Keyword::Global,
"if" => Keyword::If,
"impl" => Keyword::Impl,
Expand Down
1 change: 1 addition & 0 deletions noir_stdlib/src/lib.nr
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod collections;
mod compat;
mod option;
mod string;
mod test;

// Oracle calls are required to be wrapped in an unconstrained function
// Thus, the only argument to the `println` oracle is expected to always be an ident
Expand Down
45 changes: 45 additions & 0 deletions noir_stdlib/src/test.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#[oracle(create_mock)]
unconstrained fn create_mock_oracle<N>(_name: str<N>) -> Field {}

#[oracle(set_mock_params)]
unconstrained fn set_mock_params_oracle<P>(_id: Field, _params: P) {}

#[oracle(set_mock_returns)]
unconstrained fn set_mock_returns_oracle<R>(_id: Field, _returns: R) {}

#[oracle(set_mock_times)]
unconstrained fn set_mock_times_oracle(_id: Field, _times: u64) {}

#[oracle(clear_mock)]
unconstrained fn clear_mock_oracle(_id: Field) {}

struct OracleMock {
id: Field,
}

impl OracleMock {
unconstrained pub fn mock<N>(name: str<N>) -> Self {
Self {
id: create_mock_oracle(name),
}
}

unconstrained pub fn with_params<P>(self, params: P) -> Self {
set_mock_params_oracle(self.id, params);
self
}

unconstrained pub fn returns<R>(self, returns: R) -> Self {
set_mock_returns_oracle(self.id, returns);
self
}

unconstrained pub fn times(self, times: u64) -> Self {
set_mock_times_oracle(self.id, times);
self
}

unconstrained pub fn clear(self) {
clear_mock_oracle(self.id);
}
}
7 changes: 5 additions & 2 deletions tooling/nargo/src/ops/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use acvm::{acir::circuit::Circuit, acir::native_types::WitnessMap};
use crate::errors::ExecutionError;
use crate::NargoError;

use super::foreign_calls::ForeignCall;
use super::foreign_calls::ForeignCallExecutor;

pub fn execute_circuit<B: BlackBoxFunctionSolver>(
blackbox_solver: &B,
Expand All @@ -24,6 +24,8 @@ pub fn execute_circuit<B: BlackBoxFunctionSolver>(
.map(|(_, message)| message.clone())
};

let mut foreign_call_executor = ForeignCallExecutor::default();

loop {
let solver_status = acvm.solve();

Expand Down Expand Up @@ -57,7 +59,8 @@ pub fn execute_circuit<B: BlackBoxFunctionSolver>(
}));
}
ACVMStatus::RequiresForeignCall(foreign_call) => {
let foreign_call_result = ForeignCall::execute(&foreign_call, show_output)?;
let foreign_call_result =
foreign_call_executor.execute(&foreign_call, show_output)?;
acvm.resolve_pending_foreign_call(foreign_call_result);
}
}
Expand Down
150 changes: 145 additions & 5 deletions tooling/nargo/src/ops/foreign_calls.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use acvm::{
acir::brillig::{ForeignCallResult, Value},
brillig_vm::brillig::ForeignCallParam,
acir::brillig::{ForeignCallParam, ForeignCallResult, Value},
pwg::ForeignCallWaitInfo,
};
use iter_extended::vecmap;
use noirc_printable_type::PrintableValueDisplay;
use noirc_printable_type::{decode_string_value, ForeignCallError, PrintableValueDisplay};

use crate::NargoError;

Expand All @@ -14,6 +13,11 @@ pub(crate) enum ForeignCall {
Println,
Sequence,
ReverseSequence,
CreateMock,
SetMockParams,
SetMockReturns,
SetMockTimes,
ClearMock,
}

impl std::fmt::Display for ForeignCall {
Expand All @@ -28,6 +32,11 @@ impl ForeignCall {
ForeignCall::Println => "println",
ForeignCall::Sequence => "get_number_sequence",
ForeignCall::ReverseSequence => "get_reverse_number_sequence",
ForeignCall::CreateMock => "create_mock",
ForeignCall::SetMockParams => "set_mock_params",
ForeignCall::SetMockReturns => "set_mock_returns",
ForeignCall::SetMockTimes => "set_mock_times",
ForeignCall::ClearMock => "clear_mock",
}
}

Expand All @@ -36,16 +45,65 @@ impl ForeignCall {
"println" => Some(ForeignCall::Println),
"get_number_sequence" => Some(ForeignCall::Sequence),
"get_reverse_number_sequence" => Some(ForeignCall::ReverseSequence),
"create_mock" => Some(ForeignCall::CreateMock),
"set_mock_params" => Some(ForeignCall::SetMockParams),
"set_mock_returns" => Some(ForeignCall::SetMockReturns),
"set_mock_times" => Some(ForeignCall::SetMockTimes),
"clear_mock" => Some(ForeignCall::ClearMock),
_ => None,
}
}
}

/// This struct represents an oracle mock. It can be used for testing programs that use oracles.
#[derive(Debug, PartialEq, Eq, Clone)]
struct MockedCall {
/// The id of the mock, used to update or remove it
id: usize,
/// The oracle it's mocking
name: String,
/// Optionally match the parameters
params: Option<Vec<ForeignCallParam>>,
/// The result to return when this mock is called
result: ForeignCallResult,
/// How many times should this mock be called before it is removed
times_left: Option<u64>,
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
}

impl MockedCall {
fn new(id: usize, name: String) -> Self {
Self {
id,
name,
params: None,
result: ForeignCallResult { values: vec![] },
times_left: None,
}
}
}

impl MockedCall {
fn matches(&self, name: &str, params: &Vec<ForeignCallParam>) -> bool {
self.name == name && (self.params.is_none() || self.params.as_ref() == Some(params))
}
}

#[derive(Debug, Default)]
pub(crate) struct ForeignCallExecutor {
/// Mocks have unique ids used to identify them in Noir, allowing to update or remove them.
last_mock_id: usize,
/// The registered mocks
mocked_responses: Vec<MockedCall>,
}

impl ForeignCallExecutor {
pub(crate) fn execute(
&mut self,
foreign_call: &ForeignCallWaitInfo,
show_output: bool,
) -> Result<ForeignCallResult, NargoError> {
let foreign_call_name = foreign_call.function.as_str();
match Self::lookup(foreign_call_name) {
match ForeignCall::lookup(foreign_call_name) {
Some(ForeignCall::Println) => {
if show_output {
Self::execute_println(&foreign_call.inputs)?;
Expand Down Expand Up @@ -76,10 +134,92 @@ impl ForeignCall {
],
})
}
None => panic!("unexpected foreign call {foreign_call_name:?}"),
Some(ForeignCall::CreateMock) => {
let mock_oracle_name = Self::parse_string(&foreign_call.inputs[0]);
assert!(ForeignCall::lookup(&mock_oracle_name).is_none());
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
let id = self.last_mock_id;
self.mocked_responses.push(MockedCall::new(id, mock_oracle_name));
self.last_mock_id += 1;

Ok(ForeignCallResult { values: vec![Value::from(id).into()] })
}
Some(ForeignCall::SetMockParams) => {
let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?;
self.find_mock_by_id(id)
.unwrap_or_else(|| panic!("Unknown mock id {}", id))
.params = Some(params.to_vec());

Ok(ForeignCallResult { values: vec![] })
}
Some(ForeignCall::SetMockReturns) => {
let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?;
self.find_mock_by_id(id)
.unwrap_or_else(|| panic!("Unknown mock id {}", id))
.result = ForeignCallResult { values: params.to_vec() };

Ok(ForeignCallResult { values: vec![] })
}
Some(ForeignCall::SetMockTimes) => {
let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?;
let times = params[0]
.unwrap_value()
.to_field()
.try_to_u64()
.expect("Invalid bit size of times");

self.find_mock_by_id(id)
.unwrap_or_else(|| panic!("Unknown mock id {}", id))
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
.times_left = Some(times);

Ok(ForeignCallResult { values: vec![] })
}
Some(ForeignCall::ClearMock) => {
let (id, _) = Self::extract_mock_id(&foreign_call.inputs)?;
self.mocked_responses.retain(|response| response.id != id);
Ok(ForeignCallResult { values: vec![] })
}
None => {
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
let response_position = self
.mocked_responses
.iter()
.position(|response| response.matches(foreign_call_name, &foreign_call.inputs))
.unwrap_or_else(|| panic!("Unknown foreign call {}", foreign_call_name));

let mock = self
.mocked_responses
.get_mut(response_position)
.expect("Invalid position of mocked response");
let result = mock.result.values.clone();

if let Some(times_left) = &mut mock.times_left {
*times_left -= 1;
if *times_left == 0 {
self.mocked_responses.remove(response_position);
}
}

Ok(ForeignCallResult { values: result })
}
}
}

fn extract_mock_id(
foreign_call_inputs: &[ForeignCallParam],
) -> Result<(usize, &[ForeignCallParam]), ForeignCallError> {
let (id, params) =
foreign_call_inputs.split_first().ok_or(ForeignCallError::MissingForeignCallInputs)?;
Ok((id.unwrap_value().to_usize(), params))
}

fn find_mock_by_id(&mut self, id: usize) -> Option<&mut MockedCall> {
self.mocked_responses.iter_mut().find(|response| response.id == id)
}

fn parse_string(param: &ForeignCallParam) -> String {
let fields: Vec<_> = param.values().into_iter().map(|value| value.to_field()).collect();
decode_string_value(&fields)
}

fn execute_println(foreign_call_inputs: &[ForeignCallParam]) -> Result<(), NargoError> {
let display_values: PrintableValueDisplay = foreign_call_inputs.try_into()?;
println!("{display_values}");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "mock_oracle"
type = "bin"
authors = [""]
compiler_version = "0.1"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = "10"

30 changes: 30 additions & 0 deletions tooling/nargo_cli/tests/execution_success/mock_oracle/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use dep::std::test::OracleMock;

struct Point {
x: Field,
y: Field,
}

#[oracle(foo)]
unconstrained fn foo_oracle(_point: Point, _array: [Field; 4]) -> Field {}

unconstrained fn main() {
let array = [1,2,3,4];
let another_array = [4,3,2,1];
let point = Point {
x: 14,
y: 27,
};

OracleMock::mock("foo").returns(42).times(1);
let mock = OracleMock::mock("foo").returns(0);
assert_eq(42, foo_oracle(point, array));
assert_eq(0, foo_oracle(point, array));
mock.clear();

OracleMock::mock("foo").with_params((point, array)).returns(10);
OracleMock::mock("foo").with_params((point, another_array)).returns(20);
assert_eq(10, foo_oracle(point, array));
assert_eq(20, foo_oracle(point, another_array));
}

Loading