Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add parameter to call_data attribute #5599

Merged
merged 6 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/noirc_driver/src/abi_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ fn to_abi_visibility(value: Visibility) -> AbiVisibility {
match value {
Visibility::Public => AbiVisibility::Public,
Visibility::Private => AbiVisibility::Private,
Visibility::DataBus => AbiVisibility::DataBus,
Visibility::CallData(_) | Visibility::ReturnData => AbiVisibility::DataBus,
}
}

Expand Down
58 changes: 35 additions & 23 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,12 @@ impl<'a> Context<'a> {
let (return_vars, return_warnings) =
self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?;

let call_data_arrays: Vec<ValueId> =
self.data_bus.call_data.iter().map(|cd| cd.array_id).collect();
for call_data_array in call_data_arrays {
self.ensure_array_is_initialized(call_data_array, dfg)?;
}

// TODO: This is a naive method of assigning the return values to their witnesses as
// we're likely to get a number of constraints which are asserting one witness to be equal to another.
//
Expand Down Expand Up @@ -1263,20 +1269,23 @@ impl<'a> Context<'a> {
let res_typ = dfg.type_of_value(results[0]);

// Get operations to call-data parameters are replaced by a get to the call-data-bus array
if let Some(call_data) = self.data_bus.call_data {
if self.data_bus.call_data_map.contains_key(&array) {
// TODO: the block_id of call-data must be notified to the backend
// TODO: should we do the same for return-data?
let type_size = res_typ.flattened_size();
let type_size =
self.acir_context.add_constant(FieldElement::from(type_size as i128));
let offset = self.acir_context.mul_var(var_index, type_size)?;
let bus_index = self
.acir_context
.add_constant(FieldElement::from(self.data_bus.call_data_map[&array] as i128));
let new_index = self.acir_context.add_var(offset, bus_index)?;
return self.array_get(instruction, call_data, new_index, dfg, index_side_effect);
}
if let Some(call_data) =
self.data_bus.call_data.iter().find(|cd| cd.index_map.contains_key(&array))
{
let type_size = res_typ.flattened_size();
let type_size = self.acir_context.add_constant(FieldElement::from(type_size as i128));
let offset = self.acir_context.mul_var(var_index, type_size)?;
let bus_index = self
.acir_context
.add_constant(FieldElement::from(call_data.index_map[&array] as i128));
let new_index = self.acir_context.add_var(offset, bus_index)?;
return self.array_get(
instruction,
call_data.array_id,
new_index,
dfg,
index_side_effect,
);
}

// Compiler sanity check
Expand Down Expand Up @@ -1707,17 +1716,20 @@ impl<'a> Context<'a> {
len: usize,
value: Option<AcirValue>,
) -> Result<(), InternalError> {
let databus = if self.data_bus.call_data.is_some()
&& self.block_id(&self.data_bus.call_data.unwrap()) == array
{
BlockType::CallData
} else if self.data_bus.return_data.is_some()
let mut databus = BlockType::Memory;
if self.data_bus.return_data.is_some()
&& self.block_id(&self.data_bus.return_data.unwrap()) == array
{
BlockType::ReturnData
} else {
BlockType::Memory
};
databus = BlockType::ReturnData;
}
for array_id in self.data_bus.call_data_array() {
if self.block_id(&array_id) == array {
assert!(databus == BlockType::Memory);
databus = BlockType::CallData;
break;
}
}

self.acir_context.initialize_array(array, len, value, databus)?;
self.initialized_arrays.insert(array);
Ok(())
Expand Down
100 changes: 71 additions & 29 deletions compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::rc::Rc;

use crate::ssa::ir::{types::Type, value::ValueId};
Expand All @@ -8,6 +9,12 @@ use noirc_frontend::hir_def::function::FunctionSignature;

use super::FunctionBuilder;

#[derive(Clone)]
pub(crate) enum DatabusVisibility {
None,
CallData(u32),
ReturnData,
}
/// Used to create a data bus, which is an array of private inputs
/// replacing public inputs
pub(crate) struct DataBusBuilder {
Expand All @@ -27,15 +34,16 @@ impl DataBusBuilder {
}
}

/// Generates a boolean vector telling which (ssa) parameter from the given function signature
/// Generates a vector telling which (ssa) parameters from the given function signature
/// are tagged with databus visibility
pub(crate) fn is_databus(main_signature: &FunctionSignature) -> Vec<bool> {
pub(crate) fn is_databus(main_signature: &FunctionSignature) -> Vec<DatabusVisibility> {
let mut params_is_databus = Vec::new();

for param in &main_signature.0 {
let is_databus = match param.2 {
ast::Visibility::Public | ast::Visibility::Private => false,
ast::Visibility::DataBus => true,
ast::Visibility::Public | ast::Visibility::Private => DatabusVisibility::None,
ast::Visibility::CallData(id) => DatabusVisibility::CallData(id),
ast::Visibility::ReturnData => DatabusVisibility::ReturnData,
};
let len = param.1.field_count() as usize;
params_is_databus.extend(vec![is_databus; len]);
Expand All @@ -44,34 +52,51 @@ impl DataBusBuilder {
}
}

#[derive(Clone, Debug)]
pub(crate) struct CallData {
pub(crate) array_id: ValueId,
pub(crate) index_map: HashMap<ValueId, usize>,
}

#[derive(Clone, Default, Debug)]
pub(crate) struct DataBus {
pub(crate) call_data: Option<ValueId>,
pub(crate) call_data_map: HashMap<ValueId, usize>,
pub(crate) call_data: Vec<CallData>,
pub(crate) return_data: Option<ValueId>,
}

impl DataBus {
/// Updates the databus values with the provided function
pub(crate) fn map_values(&self, mut f: impl FnMut(ValueId) -> ValueId) -> DataBus {
let mut call_data_map = HashMap::default();
for (k, v) in self.call_data_map.iter() {
call_data_map.insert(f(*k), *v);
}
DataBus {
call_data: self.call_data.map(&mut f),
call_data_map,
return_data: self.return_data.map(&mut f),
}
let call_data = self
.call_data
.iter()
.map(|cd| {
let mut call_data_map = HashMap::default();
for (k, v) in cd.index_map.iter() {
call_data_map.insert(f(*k), *v);
}
CallData { array_id: f(cd.array_id), index_map: call_data_map }
})
.collect();
DataBus { call_data, return_data: self.return_data.map(&mut f) }
}

pub(crate) fn call_data_array(&self) -> Vec<ValueId> {
self.call_data.iter().map(|cd| cd.array_id).collect()
}
/// Construct a databus from call_data and return_data data bus builders
pub(crate) fn get_data_bus(call_data: DataBusBuilder, return_data: DataBusBuilder) -> DataBus {
DataBus {
call_data: call_data.databus,
call_data_map: call_data.map,
return_data: return_data.databus,
pub(crate) fn get_data_bus(
call_data: Vec<DataBusBuilder>,
return_data: DataBusBuilder,
) -> DataBus {
let mut call_data_args = Vec::new();
for call_data_item in call_data {
if let Some(array_id) = call_data_item.databus {
call_data_args.push(CallData { array_id, index_map: call_data_item.map });
}
}

DataBus { call_data: call_data_args, return_data: return_data.databus }
}
}

Expand Down Expand Up @@ -129,19 +154,36 @@ impl FunctionBuilder {
}

/// Generate the data bus for call-data, based on the parameters of the entry block
/// and a boolean vector telling which ones are call-data
pub(crate) fn call_data_bus(&mut self, is_params_databus: Vec<bool>) -> DataBusBuilder {
/// and a vector telling which ones are call-data
pub(crate) fn call_data_bus(
&mut self,
is_params_databus: Vec<DatabusVisibility>,
) -> Vec<DataBusBuilder> {
//filter parameters of the first block that have call-data visibility
let first_block = self.current_function.entry_block();
let params = self.current_function.dfg[first_block].parameters();
let mut databus_param = Vec::new();
for (param, is_databus) in params.iter().zip(is_params_databus) {
if is_databus {
databus_param.push(param.to_owned());
let mut databus_param: BTreeMap<u32, Vec<ValueId>> = BTreeMap::new();
for (param, databus_attribute) in params.iter().zip(is_params_databus) {
match databus_attribute {
DatabusVisibility::None | DatabusVisibility::ReturnData => continue,
DatabusVisibility::CallData(call_data_id) => {
if let std::collections::btree_map::Entry::Vacant(e) =
databus_param.entry(call_data_id)
{
e.insert(vec![param.to_owned()]);
} else {
databus_param.get_mut(&call_data_id).unwrap().push(param.to_owned());
}
}
}
}
// create the call-data-bus from the filtered list
let call_data = DataBusBuilder::new();
self.initialize_data_bus(&databus_param, call_data)
// create the call-data-bus from the filtered lists
let mut result = Vec::new();
for id in databus_param.keys() {
let builder = DataBusBuilder::new();
let call_databus = self.initialize_data_bus(&databus_param[id], builder);
result.push(call_databus);
}
result
}
}
4 changes: 2 additions & 2 deletions compiler/noirc_evaluator/src/ssa/opt/die.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ impl Ssa {
/// of its instructions are needed elsewhere.
fn dead_instruction_elimination(function: &mut Function) {
let mut context = Context::default();
if let Some(call_data) = function.dfg.data_bus.call_data {
context.mark_used_instruction_results(&function.dfg, call_data);
for call_data in &function.dfg.data_bus.call_data {
context.mark_used_instruction_results(&function.dfg, call_data.array_id);
}

let blocks = PostOrder::with_function(function);
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
// see which parameter has call_data/return_data attribute
let is_databus = DataBusBuilder::is_databus(&program.main_function_signature);

let is_return_data = matches!(program.return_visibility, Visibility::DataBus);
let is_return_data = matches!(program.return_visibility, Visibility::ReturnData);

let return_location = program.return_location;
let context = SharedContext::new(program);
Expand Down Expand Up @@ -472,7 +472,7 @@
/// br loop_entry(v0)
/// loop_entry(i: Field):
/// v2 = lt i v1
/// brif v2, then: loop_body, else: loop_end

Check warning on line 475 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// loop_body():
/// v3 = ... codegen body ...
/// v4 = add 1, i
Expand Down Expand Up @@ -531,7 +531,7 @@
/// For example, the expression `if cond { a } else { b }` is codegen'd as:
///
/// v0 = ... codegen cond ...
/// brif v0, then: then_block, else: else_block

Check warning on line 534 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// then_block():
/// v1 = ... codegen a ...
/// br end_if(v1)
Expand All @@ -544,7 +544,7 @@
/// As another example, the expression `if cond { a }` is codegen'd as:
///
/// v0 = ... codegen cond ...
/// brif v0, then: then_block, else: end_block

Check warning on line 547 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// then_block:
/// v1 = ... codegen a ...
/// br end_if()
Expand Down
7 changes: 5 additions & 2 deletions compiler/noirc_frontend/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
/*env:*/ Box<UnresolvedType>,
),

// The type of quoted code for metaprogramming

Check warning on line 120 in compiler/noirc_frontend/src/ast/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (metaprogramming)
Quoted(crate::QuotedType),

/// An already resolved type. These can only be parsed if they were present in the token stream
Expand Down Expand Up @@ -390,15 +390,18 @@
Private,
/// DataBus is public input handled as private input. We use the fact that return values are properly computed by the program to avoid having them as public inputs
/// it is useful for recursion and is handled by the proving system.
DataBus,
/// The u32 value is used to group inputs having the same value.
CallData(u32),
ReturnData,
}

impl std::fmt::Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Public => write!(f, "pub"),
Self::Private => write!(f, "priv"),
Self::DataBus => write!(f, "databus"),
Self::CallData(id) => write!(f, "calldata{id}"),
Self::ReturnData => write!(f, "returndata"),

Check warning on line 404 in compiler/noirc_frontend/src/ast/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (returndata)
}
}
}
2 changes: 2 additions & 0 deletions compiler/noirc_frontend/src/parser/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ pub enum ParserErrorReason {
Lexer(LexerErrorKind),
#[error("The only supported numeric generic types are `u1`, `u8`, `u16`, and `u32`")]
ForbiddenNumericGenericType,
#[error("Invalid call data identifier, must be a number. E.g `call_data(0)`")]
InvalidCallDataIdentifier,
}

/// Represents a parsing error, or a parsing error in the making.
Expand Down
30 changes: 20 additions & 10 deletions compiler/noirc_frontend/src/parser/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use crate::ast::{
use crate::lexer::{lexer::from_spanned_token_result, Lexer};
use crate::parser::{force, ignore_then_commit, statement_recovery};
use crate::token::{Keyword, Token, TokenKind};
use acvm::AcirField;

use chumsky::prelude::*;
use iter_extended::vecmap;
Expand Down Expand Up @@ -645,19 +646,28 @@ where
})
}

fn call_data() -> impl NoirParser<Visibility> {
keyword(Keyword::CallData).then(parenthesized(literal())).validate(|token, span, emit| {
match token {
(_, ExpressionKind::Literal(Literal::Integer(x, _))) => {
let id = x.to_u128() as u32;
Visibility::CallData(id)
}
_ => {
emit(ParserError::with_reason(ParserErrorReason::InvalidCallDataIdentifier, span));
Visibility::CallData(0)
}
}
})
}

fn optional_visibility() -> impl NoirParser<Visibility> {
keyword(Keyword::Pub)
.or(keyword(Keyword::CallData))
.or(keyword(Keyword::ReturnData))
.map(|_| Visibility::Public)
.or(call_data())
.or(keyword(Keyword::ReturnData).map(|_| Visibility::ReturnData))
.or_not()
.map(|opt| match opt {
Some(Token::Keyword(Keyword::Pub)) => Visibility::Public,
Some(Token::Keyword(Keyword::CallData)) | Some(Token::Keyword(Keyword::ReturnData)) => {
Visibility::DataBus
}
None => Visibility::Private,
_ => unreachable!("unexpected token found"),
})
.map(|opt| opt.unwrap_or(Visibility::Private))
}

pub fn expression() -> impl ExprParser {
Expand Down
2 changes: 1 addition & 1 deletion test_programs/execution_success/databus/src/main.nr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
fn main(mut x: u32, y: call_data u32, z: call_data [u32; 4]) -> return_data u32 {
fn main(mut x: u32, y: call_data(0) u32, z: call_data(0) [u32; 4]) -> return_data u32 {
let a = z[x];
michaeljklein marked this conversation as resolved.
Show resolved Hide resolved
a + foo(y)
}
Expand Down
7 changes: 4 additions & 3 deletions tooling/nargo_fmt/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,10 @@ impl HasItem for Param {
fn format(self, visitor: &FmtVisitor, shape: Shape) -> String {
let pattern = visitor.slice(self.pattern.span());
let visibility = match self.visibility {
Visibility::Public => "pub",
Visibility::Private => "",
Visibility::DataBus => "call_data",
Visibility::Public => "pub".to_string(),
Visibility::Private => "".to_string(),
Visibility::CallData(x) => format!("call_data({x})"),
Visibility::ReturnData => "return_data".to_string(),
};

if self.pattern.is_synthesized() || self.typ.is_synthesized() {
Expand Down
5 changes: 4 additions & 1 deletion tooling/nargo_fmt/src/visitor/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,11 @@ impl super::FmtVisitor<'_> {

let visibility = match func.def.return_visibility {
Visibility::Public => "pub",
Visibility::DataBus => "return_data",
Visibility::ReturnData => "return_data",
Visibility::Private => "",
Visibility::CallData(_) => {
unreachable!("call_data cannot be used for return value")
}
};
result.push_str(&append_space_if_nonempty(visibility.into()));

Expand Down
Loading