Skip to content

Commit

Permalink
feat(avm/public): user space PublicContext::get_args_hash (#8292)
Browse files Browse the repository at this point in the history
This PR implements `PublicContext::get_args_hash` in user space. We are
still passing the calldata length as a runtime variable until we can get
it at compile time. This requires @Thunkar 's work on `aztec(public)` as
a macro.

Once that is done, we'll pass the hasher as a closure when creating the
PublicContext, i.e.:
```
struct PublicContext {
    hash_getter: fn[(Field,)]() -> Field,
    // ...
}

impl PublicContext {
    pub fn new(..., hash_getter) -> Self {
        // ...
    }

    fn get_args_hash(self) -> Field {
        (self.hash_getter)()
    }
}

// In the aztec(public) macro
comptime let N = get_calldata_length();
let hash_getter = || {
    let mut hasher = ArgsHasher::new();
    let mut fields = std::meta::unquote!(quote { [0; $N] });
    fields = calldata_copy(2 /*or 1*/, N);
    hasher.add_many(fields);
    hasher.hash()
};
let context = PublicContext::new(..., hash_getter);
```
  • Loading branch information
fcarreiro authored Sep 10, 2024
1 parent 684d962 commit 56ce16a
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 13 deletions.
42 changes: 42 additions & 0 deletions avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ fn handle_foreign_call(
"avmOpcodeGetContractInstance" => {
handle_get_contract_instance(avm_instrs, destinations, inputs);
}
"avmOpcodeCalldataCopy" => handle_calldata_copy(avm_instrs, destinations, inputs),
"avmOpcodeStorageRead" => handle_storage_read(avm_instrs, destinations, inputs),
"avmOpcodeStorageWrite" => handle_storage_write(avm_instrs, destinations, inputs),
"debugLog" => handle_debug_log(avm_instrs, destinations, inputs),
Expand Down Expand Up @@ -973,6 +974,47 @@ fn handle_debug_log(
});
}

// #[oracle(avmOpcodeCalldataCopy)]
// unconstrained fn calldata_copy_opcode<let N: u32>(cdoffset: Field) -> [Field; N] {}
fn handle_calldata_copy(
avm_instrs: &mut Vec<AvmInstruction>,
destinations: &Vec<ValueOrArray>,
inputs: &Vec<ValueOrArray>,
) {
assert!(inputs.len() == 2);
assert!(destinations.len() == 1);

let cd_offset = match inputs[0] {
ValueOrArray::MemoryAddress(address) => address.0,
_ => panic!("CalldataCopy offset should be a memory address"),
};

let copy_size_offset = match inputs[1] {
ValueOrArray::MemoryAddress(address) => address.0,
_ => panic!("CalldataCopy size should be a memory address"),
};

let (dest_offset, ..) = match destinations[0] {
ValueOrArray::HeapArray(HeapArray { pointer, size }) => (pointer.0, size),
_ => panic!("CalldataCopy destination should be an array"),
};

avm_instrs.push(AvmInstruction {
opcode: AvmOpcode::CALLDATACOPY,
indirect: Some(SECOND_OPERAND_INDIRECT),
operands: vec![
AvmOperand::U32 {
value: cd_offset as u32, // cdOffset (calldata offset)
},
AvmOperand::U32 { value: copy_size_offset as u32 }, // copy size
AvmOperand::U32 {
value: dest_offset as u32, // dstOffset
},
],
..Default::default()
});
}

/// Emit a storage write opcode
/// The current implementation writes an array of values into storage ( contiguous slots in memory )
fn handle_storage_write(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ use dep::protocol_types::traits::Empty;

// These inputs will likely go away once the AVM processes 1 public kernel per enqueued call.
struct PublicContextInputs {
args_hash: Field,
// TODO: Remove this structure and get calldata size at compile time.
calldata_length: Field,
is_static_call: bool
}

impl Empty for PublicContextInputs {
fn empty() -> Self {
PublicContextInputs {
args_hash: 0,
calldata_length: 0,
is_static_call: false
}
}
Expand Down
27 changes: 24 additions & 3 deletions noir-projects/aztec-nr/aztec/src/context/public_context.nr
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ use dep::protocol_types::traits::{Serialize, Deserialize, Empty};
use dep::protocol_types::abis::function_selector::FunctionSelector;
use crate::context::inputs::public_context_inputs::PublicContextInputs;
use crate::context::gas::GasOpts;
use crate::hash::ArgsHasher;

struct PublicContext {
inputs: PublicContextInputs,
args_hash: Option<Field>
}

impl PublicContext {
pub fn new(inputs: PublicContextInputs) -> Self {
PublicContext { inputs }
PublicContext { inputs, args_hash: Option::none() }
}

pub fn emit_unencrypted_log<T, let N: u32>(_self: &mut Self, log: T) where T: Serialize<N> {
Expand Down Expand Up @@ -130,8 +132,20 @@ impl PublicContext {
fn selector(_self: Self) -> FunctionSelector {
FunctionSelector::from_u32(function_selector())
}
fn get_args_hash(self) -> Field {
self.inputs.args_hash
fn get_args_hash(mut self) -> Field {
if !self.args_hash.is_some() {
let mut hasher = ArgsHasher::new();

// TODO: this should be replaced with the compile-time calldata size.
for i in 0..self.inputs.calldata_length as u32 {
let argn: [Field; 1] = calldata_copy((2 + i) as u32, 1);
hasher.add(argn[0]);
}

self.args_hash = Option::some(hasher.hash());
}

self.args_hash.unwrap()
}
fn transaction_fee(_self: Self) -> Field {
transaction_fee()
Expand Down Expand Up @@ -278,6 +292,10 @@ unconstrained fn call_static<let RET_SIZE: u32>(
call_static_opcode(gas, address, args, function_selector)
}

unconstrained fn calldata_copy<let N: u32>(cdoffset: u32, copy_size: u32) -> [Field; N] {
calldata_copy_opcode(cdoffset, copy_size)
}

unconstrained fn storage_read(storage_slot: Field) -> Field {
storage_read_opcode(storage_slot)
}
Expand Down Expand Up @@ -356,6 +374,9 @@ unconstrained fn l1_to_l2_msg_exists_opcode(msg_hash: Field, msg_leaf_index: Fie
#[oracle(avmOpcodeSendL2ToL1Msg)]
unconstrained fn send_l2_to_l1_msg_opcode(recipient: EthAddress, content: Field) {}

#[oracle(avmOpcodeCalldataCopy)]
unconstrained fn calldata_copy_opcode<let N: u32>(cdoffset: u32, copy_size: u32) -> [Field; N] {}

#[oracle(avmOpcodeCall)]
unconstrained fn call_opcode<let RET_SIZE: u32>(
gas: [Field; 2], // gas allocation: [l2_gas, da_gas]
Expand Down
7 changes: 7 additions & 0 deletions noir-projects/aztec-nr/aztec/src/test/helpers/cheatcodes.nr
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ unconstrained pub fn set_msg_sender(msg_sender: AztecAddress) {
oracle_set_msg_sender(msg_sender)
}

unconstrained pub fn set_calldata(calldata: [Field]) {
oracle_set_calldata(calldata)
}

unconstrained pub fn get_msg_sender() -> AztecAddress {
oracle_get_msg_sender()
}
Expand Down Expand Up @@ -187,3 +191,6 @@ unconstrained fn oracle_get_function_selector() -> FunctionSelector {}

#[oracle(setFunctionSelector)]
unconstrained fn oracle_set_function_selector(selector: FunctionSelector) {}

#[oracle(setCalldata)]
unconstrained fn oracle_set_calldata(calldata: [Field]) {}
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,22 @@ impl TestEnvironment {
let original_fn_selector = cheatcodes::get_function_selector();
let target_address = call_interface.get_contract_address();
let fn_selector = call_interface.get_selector();
let calldata = call_interface.get_args();

cheatcodes::set_fn_selector(fn_selector);
cheatcodes::set_contract_address(target_address);
cheatcodes::set_msg_sender(original_contract_address);
let mut inputs = cheatcodes::get_public_context_inputs();
inputs.args_hash = hash_args(call_interface.get_args());
inputs.calldata_length = call_interface.get_args().len() as Field;
inputs.is_static_call = call_interface.get_is_static();
cheatcodes::set_calldata(calldata);

let result = original_fn(inputs);

cheatcodes::set_fn_selector(original_fn_selector);
cheatcodes::set_contract_address(original_contract_address);
cheatcodes::set_msg_sender(original_msg_sender);
cheatcodes::set_calldata(calldata);
result
}

Expand Down
7 changes: 6 additions & 1 deletion noir-projects/aztec-nr/aztec/src/test/helpers/utils.nr
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,22 @@ impl<let N: u32, let M: u32> Deployer<N, M> {
let original_msg_sender = cheatcodes::get_msg_sender();
let original_contract_address = get_contract_address();
let original_fn_selector = cheatcodes::get_function_selector();
let calldata = call_interface.get_args();

cheatcodes::set_fn_selector(call_interface.get_selector());
cheatcodes::set_contract_address(instance.to_address());
cheatcodes::set_msg_sender(original_contract_address);
let mut inputs = cheatcodes::get_public_context_inputs();
inputs.args_hash = hash_args(call_interface.get_args());
inputs.calldata_length = call_interface.get_args().len() as Field;
inputs.is_static_call = call_interface.get_is_static();
cheatcodes::set_calldata(calldata);

let _result: T = original_fn(inputs);

cheatcodes::set_fn_selector(original_fn_selector);
cheatcodes::set_contract_address(original_contract_address);
cheatcodes::set_msg_sender(original_msg_sender);
cheatcodes::set_calldata(calldata);
instance
}

Expand Down
7 changes: 3 additions & 4 deletions yarn-project/simulator/src/avm/avm_execution_environment.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import { FunctionSelector, type GlobalVariables, type Header } from '@aztec/circuits.js';
import { computeVarArgsHash } from '@aztec/circuits.js/hash';
import { type AztecAddress } from '@aztec/foundation/aztec-address';
import { Fr } from '@aztec/foundation/fields';

export class AvmContextInputs {
static readonly SIZE = 2;

constructor(private argsHash: Fr, private isStaticCall: boolean) {}
constructor(private calldataSize: Fr, private isStaticCall: boolean) {}

public toFields(): Fr[] {
return [this.argsHash, new Fr(this.isStaticCall)];
return [this.calldataSize, new Fr(this.isStaticCall)];
}
}

Expand All @@ -33,7 +32,7 @@ export class AvmExecutionEnvironment {
) {
// We encode some extra inputs (AvmContextInputs) in calldata.
// This will have to go once we move away from one proof per call.
const inputs = new AvmContextInputs(computeVarArgsHash(calldata), isStaticCall).toFields();
const inputs = new AvmContextInputs(new Fr(calldata.length), isStaticCall).toFields();
this.calldata = [...inputs, ...calldata];
}

Expand Down
18 changes: 16 additions & 2 deletions yarn-project/txe/src/oracle/txe_oracle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ export class TXE implements TypedOracle {
private contractAddress: AztecAddress;
private msgSender: AztecAddress;
private functionSelector = FunctionSelector.fromField(new Fr(0));
// This will hold the _real_ calldata. That is, the one without the PublicContextInputs.
// TODO: Remove this comment once PublicContextInputs are removed.
private calldata: Fr[] = [];

private contractDataOracle: ContractDataOracle;

Expand Down Expand Up @@ -128,10 +131,20 @@ export class TXE implements TypedOracle {
return this.functionSelector;
}

getCalldata() {
// TODO: Remove this once PublicContextInputs are removed.
const inputs = this.getPublicContextInputs();
return [...inputs.toFields(), ...this.calldata];
}

setMsgSender(msgSender: Fr) {
this.msgSender = msgSender;
}

setCalldata(calldata: Fr[]) {
this.calldata = calldata;
}

setFunctionSelector(functionSelector: FunctionSelector) {
this.functionSelector = functionSelector;
}
Expand Down Expand Up @@ -204,10 +217,10 @@ export class TXE implements TypedOracle {

getPublicContextInputs() {
const inputs = {
argsHash: new Fr(0),
calldataLength: new Fr(this.calldata.length),
isStaticCall: false,
toFields: function () {
return [this.argsHash, new Fr(this.isStaticCall)];
return [this.calldataLength, new Fr(this.isStaticCall)];
},
};
return inputs;
Expand Down Expand Up @@ -738,6 +751,7 @@ export class TXE implements TypedOracle {
this.setMsgSender(this.contractAddress);
this.setContractAddress(targetContractAddress);
this.setFunctionSelector(functionSelector);
this.setCalldata(args);

const callContext = CallContext.empty();
callContext.msgSender = this.msgSender;
Expand Down
16 changes: 16 additions & 0 deletions yarn-project/txe/src/txe_service/txe_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ export class TXEService {
return toForeignCallResult([]);
}

setCalldata(_length: ForeignCallSingle, calldata: ForeignCallArray) {
(this.typedOracle as TXE).setCalldata(fromArray(calldata));
return toForeignCallResult([]);
}

getFunctionSelector() {
const functionSelector = (this.typedOracle as TXE).getFunctionSelector();
return toForeignCallResult([toSingle(functionSelector.toField())]);
Expand Down Expand Up @@ -582,6 +587,17 @@ export class TXEService {
return toForeignCallResult([]);
}

//unconstrained fn calldata_copy_opcode<let N: u32>(cdoffset: u32, copy_size: u32) -> [Field; N] {}
avmOpcodeCalldataCopy(cdOffsetInput: ForeignCallSingle, copySizeInput: ForeignCallSingle) {
const cdOffset = fromSingle(cdOffsetInput).toNumber();
const copySize = fromSingle(copySizeInput).toNumber();

const calldata = (this.typedOracle as TXE).getCalldata();
const calldataSlice = calldata.slice(cdOffset, cdOffset + copySize);

return toForeignCallResult([toArray(calldataSlice)]);
}

async getPublicKeysAndPartialAddress(address: ForeignCallSingle) {
const parsedAddress = AztecAddress.fromField(fromSingle(address));
const { publicKeys, partialAddress } = await this.typedOracle.getCompleteAddress(parsedAddress);
Expand Down

0 comments on commit 56ce16a

Please sign in to comment.