Skip to content

Commit

Permalink
feat: Add CMOV instruction to brillig and brillig gen (#5308)
Browse files Browse the repository at this point in the history
  • Loading branch information
sirasistant authored Mar 20, 2024
1 parent 5f0f4fc commit 208abbb
Show file tree
Hide file tree
Showing 14 changed files with 297 additions and 46 deletions.
18 changes: 18 additions & 0 deletions avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,24 @@ pub fn brillig_to_avm(brillig: &Brillig) -> Vec<u8> {
} => {
avm_instrs.push(generate_mov_instruction(Some(ALL_DIRECT), source.to_usize() as u32, destination.to_usize() as u32));
}
BrilligOpcode::ConditionalMov {
source_a,
source_b,
condition,
destination,
} => {
avm_instrs.push(AvmInstruction {
opcode: AvmOpcode::CMOV,
indirect: Some(ALL_DIRECT),
operands: vec![
AvmOperand::U32 { value: source_a.to_usize() as u32 },
AvmOperand::U32 { value: source_b.to_usize() as u32 },
AvmOperand::U32 { value: condition.to_usize() as u32 },
AvmOperand::U32 { value: destination.to_usize() as u32 },
],
..Default::default()
});
}
BrilligOpcode::Load {
destination,
source_pointer,
Expand Down
74 changes: 74 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,17 @@ struct BrilligOpcode {
static Mov bincodeDeserialize(std::vector<uint8_t>);
};

struct ConditionalMov {
Program::MemoryAddress destination;
Program::MemoryAddress source_a;
Program::MemoryAddress source_b;
Program::MemoryAddress condition;

friend bool operator==(const ConditionalMov&, const ConditionalMov&);
std::vector<uint8_t> bincodeSerialize() const;
static ConditionalMov bincodeDeserialize(std::vector<uint8_t>);
};

struct Load {
Program::MemoryAddress destination;
Program::MemoryAddress source_pointer;
Expand Down Expand Up @@ -644,6 +655,7 @@ struct BrilligOpcode {
Return,
ForeignCall,
Mov,
ConditionalMov,
Load,
Store,
BlackBox,
Expand Down Expand Up @@ -5832,6 +5844,68 @@ Program::BrilligOpcode::Mov serde::Deserializable<Program::BrilligOpcode::Mov>::

namespace Program {

inline bool operator==(const BrilligOpcode::ConditionalMov& lhs, const BrilligOpcode::ConditionalMov& rhs)
{
if (!(lhs.destination == rhs.destination)) {
return false;
}
if (!(lhs.source_a == rhs.source_a)) {
return false;
}
if (!(lhs.source_b == rhs.source_b)) {
return false;
}
if (!(lhs.condition == rhs.condition)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BrilligOpcode::ConditionalMov::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligOpcode::ConditionalMov>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligOpcode::ConditionalMov BrilligOpcode::ConditionalMov::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligOpcode::ConditionalMov>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BrilligOpcode::ConditionalMov>::serialize(
const Program::BrilligOpcode::ConditionalMov& obj, Serializer& serializer)
{
serde::Serializable<decltype(obj.destination)>::serialize(obj.destination, serializer);
serde::Serializable<decltype(obj.source_a)>::serialize(obj.source_a, serializer);
serde::Serializable<decltype(obj.source_b)>::serialize(obj.source_b, serializer);
serde::Serializable<decltype(obj.condition)>::serialize(obj.condition, serializer);
}

template <>
template <typename Deserializer>
Program::BrilligOpcode::ConditionalMov serde::Deserializable<Program::BrilligOpcode::ConditionalMov>::deserialize(
Deserializer& deserializer)
{
Program::BrilligOpcode::ConditionalMov obj;
obj.destination = serde::Deserializable<decltype(obj.destination)>::deserialize(deserializer);
obj.source_a = serde::Deserializable<decltype(obj.source_a)>::deserialize(deserializer);
obj.source_b = serde::Deserializable<decltype(obj.source_b)>::deserialize(deserializer);
obj.condition = serde::Deserializable<decltype(obj.condition)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BrilligOpcode::Load& lhs, const BrilligOpcode::Load& rhs)
{
if (!(lhs.destination == rhs.destination)) {
Expand Down
60 changes: 59 additions & 1 deletion noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,17 @@ namespace Program {
static Mov bincodeDeserialize(std::vector<uint8_t>);
};

struct ConditionalMov {
Program::MemoryAddress destination;
Program::MemoryAddress source_a;
Program::MemoryAddress source_b;
Program::MemoryAddress condition;

friend bool operator==(const ConditionalMov&, const ConditionalMov&);
std::vector<uint8_t> bincodeSerialize() const;
static ConditionalMov bincodeDeserialize(std::vector<uint8_t>);
};

struct Load {
Program::MemoryAddress destination;
Program::MemoryAddress source_pointer;
Expand Down Expand Up @@ -612,7 +623,7 @@ namespace Program {
static Stop bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<BinaryFieldOp, BinaryIntOp, Cast, JumpIfNot, JumpIf, Jump, CalldataCopy, Call, Const, Return, ForeignCall, Mov, Load, Store, BlackBox, Trap, Stop> value;
std::variant<BinaryFieldOp, BinaryIntOp, Cast, JumpIfNot, JumpIf, Jump, CalldataCopy, Call, Const, Return, ForeignCall, Mov, ConditionalMov, Load, Store, BlackBox, Trap, Stop> value;

friend bool operator==(const BrilligOpcode&, const BrilligOpcode&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4826,6 +4837,53 @@ Program::BrilligOpcode::Mov serde::Deserializable<Program::BrilligOpcode::Mov>::
return obj;
}

namespace Program {

inline bool operator==(const BrilligOpcode::ConditionalMov &lhs, const BrilligOpcode::ConditionalMov &rhs) {
if (!(lhs.destination == rhs.destination)) { return false; }
if (!(lhs.source_a == rhs.source_a)) { return false; }
if (!(lhs.source_b == rhs.source_b)) { return false; }
if (!(lhs.condition == rhs.condition)) { return false; }
return true;
}

inline std::vector<uint8_t> BrilligOpcode::ConditionalMov::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BrilligOpcode::ConditionalMov>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BrilligOpcode::ConditionalMov BrilligOpcode::ConditionalMov::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BrilligOpcode::ConditionalMov>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BrilligOpcode::ConditionalMov>::serialize(const Program::BrilligOpcode::ConditionalMov &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.destination)>::serialize(obj.destination, serializer);
serde::Serializable<decltype(obj.source_a)>::serialize(obj.source_a, serializer);
serde::Serializable<decltype(obj.source_b)>::serialize(obj.source_b, serializer);
serde::Serializable<decltype(obj.condition)>::serialize(obj.condition, serializer);
}

template <>
template <typename Deserializer>
Program::BrilligOpcode::ConditionalMov serde::Deserializable<Program::BrilligOpcode::ConditionalMov>::deserialize(Deserializer &deserializer) {
Program::BrilligOpcode::ConditionalMov obj;
obj.destination = serde::Deserializable<decltype(obj.destination)>::deserialize(deserializer);
obj.source_a = serde::Deserializable<decltype(obj.source_a)>::deserialize(deserializer);
obj.source_b = serde::Deserializable<decltype(obj.source_b)>::deserialize(deserializer);
obj.condition = serde::Deserializable<decltype(obj.condition)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BrilligOpcode::Load &lhs, const BrilligOpcode::Load &rhs) {
Expand Down
8 changes: 4 additions & 4 deletions noir/noir-repo/acvm-repo/acir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ mod reflection {
generator.output(&mut source, &registry).unwrap();

// Comment this out to write updated C++ code to file.
// if let Some(old_hash) = old_hash {
// let new_hash = fxhash::hash64(&source);
// assert_eq!(new_hash, old_hash, "Serialization format has changed");
// }
if let Some(old_hash) = old_hash {
let new_hash = fxhash::hash64(&source);
assert_eq!(new_hash, old_hash, "Serialization format has changed");
}

write_to_file(&source, &path);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,11 @@ fn simple_brillig_foreign_call() {
let bytes = Program::serialize_program(&program);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 65, 10, 192, 32, 12, 4, 77, 10, 165, 244, 212,
167, 216, 31, 244, 51, 61, 120, 241, 32, 226, 251, 85, 140, 176, 136, 122, 209, 129, 144,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 65, 10, 192, 32, 12, 4, 77, 10, 165, 244, 214,
159, 216, 31, 244, 51, 61, 120, 241, 32, 226, 251, 85, 140, 176, 136, 122, 209, 129, 144,
176, 9, 97, 151, 84, 225, 74, 69, 50, 31, 48, 35, 85, 251, 164, 235, 53, 94, 218, 247, 75,
163, 95, 150, 12, 153, 179, 227, 191, 114, 195, 222, 216, 240, 59, 63, 75, 221, 251, 208,
106, 207, 232, 150, 65, 100, 53, 33, 2, 9, 69, 91, 82, 144, 1, 0, 0,
106, 207, 232, 150, 65, 100, 53, 33, 2, 22, 232, 178, 27, 144, 1, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down Expand Up @@ -311,15 +311,15 @@ fn complex_brillig_foreign_call() {
let bytes = Program::serialize_program(&program);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 65, 14, 132, 32, 12, 108, 101, 117, 205, 158,
246, 9, 38, 187, 15, 96, 247, 5, 254, 197, 120, 211, 232, 209, 231, 139, 113, 136, 181, 65,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 65, 14, 132, 32, 12, 108, 101, 117, 205, 222,
246, 7, 38, 187, 15, 96, 247, 5, 254, 197, 120, 211, 232, 209, 231, 139, 113, 136, 181, 65,
47, 98, 162, 147, 52, 20, 24, 202, 164, 45, 48, 205, 200, 157, 49, 124, 227, 44, 129, 207,
152, 75, 120, 94, 137, 209, 30, 195, 143, 227, 197, 178, 103, 105, 76, 110, 160, 209, 156,
160, 209, 247, 195, 69, 235, 29, 179, 46, 81, 243, 103, 2, 239, 231, 225, 44, 117, 150, 97,
254, 196, 152, 99, 157, 176, 87, 168, 188, 147, 224, 121, 20, 209, 180, 254, 109, 70, 75,
47, 178, 186, 251, 37, 116, 86, 93, 219, 55, 245, 96, 20, 85, 75, 253, 8, 255, 171, 246,
121, 231, 220, 4, 249, 237, 132, 56, 28, 224, 109, 113, 223, 180, 164, 50, 165, 0, 137, 17,
72, 139, 88, 97, 4, 198, 90, 226, 196, 33, 5, 0, 0,
72, 139, 88, 97, 4, 173, 98, 132, 157, 33, 5, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ import { WitnessMap } from '@noir-lang/acvm_js';

// See `complex_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 65, 14, 132, 32, 12, 108, 101, 117, 205, 158,
246, 9, 38, 187, 15, 96, 247, 5, 254, 197, 120, 211, 232, 209, 231, 139, 113, 136, 181, 65,
47, 98, 162, 147, 52, 20, 24, 202, 164, 45, 48, 205, 200, 157, 49, 124, 227, 44, 129, 207,
152, 75, 120, 94, 137, 209, 30, 195, 143, 227, 197, 178, 103, 105, 76, 110, 160, 209, 156,
160, 209, 247, 195, 69, 235, 29, 179, 46, 81, 243, 103, 2, 239, 231, 225, 44, 117, 150, 97,
254, 196, 152, 99, 157, 176, 87, 168, 188, 147, 224, 121, 20, 209, 180, 254, 109, 70, 75,
47, 178, 186, 251, 37, 116, 86, 93, 219, 55, 245, 96, 20, 85, 75, 253, 8, 255, 171, 246,
121, 231, 220, 4, 249, 237, 132, 56, 28, 224, 109, 113, 223, 180, 164, 50, 165, 0, 137, 17,
72, 139, 88, 97, 4, 198, 90, 226, 196, 33, 5, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 213, 84, 65, 14, 132, 32, 12, 108, 101, 117, 205, 222, 246, 7, 38, 187, 15, 96,
247, 5, 254, 197, 120, 211, 232, 209, 231, 139, 113, 136, 181, 65, 47, 98, 162, 147, 52, 20, 24, 202, 164, 45, 48,
205, 200, 157, 49, 124, 227, 44, 129, 207, 152, 75, 120, 94, 137, 209, 30, 195, 143, 227, 197, 178, 103, 105, 76, 110,
160, 209, 156, 160, 209, 247, 195, 69, 235, 29, 179, 46, 81, 243, 103, 2, 239, 231, 225, 44, 117, 150, 97, 254, 196,
152, 99, 157, 176, 87, 168, 188, 147, 224, 121, 20, 209, 180, 254, 109, 70, 75, 47, 178, 186, 251, 37, 116, 86, 93,
219, 55, 245, 96, 20, 85, 75, 253, 8, 255, 171, 246, 121, 231, 220, 4, 249, 237, 132, 56, 28, 224, 109, 113, 223, 180,
164, 50, 165, 0, 137, 17, 72, 139, 88, 97, 4, 173, 98, 132, 157, 33, 5, 0, 0,
]);
export const initialWitnessMap: WitnessMap = new Map([
[1, '0x0000000000000000000000000000000000000000000000000000000000000001'],
Expand Down
9 changes: 4 additions & 5 deletions noir/noir-repo/acvm-repo/acvm_js/test/shared/foreign_call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ import { WitnessMap } from '@noir-lang/acvm_js';

// See `simple_brillig_foreign_call` integration test in `acir/tests/test_program_serialization.rs`.
export const bytecode = Uint8Array.from([
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 65, 10, 192, 32, 12, 4, 77, 10, 165, 244, 212,
167, 216, 31, 244, 51, 61, 120, 241, 32, 226, 251, 85, 140, 176, 136, 122, 209, 129, 144,
176, 9, 97, 151, 84, 225, 74, 69, 50, 31, 48, 35, 85, 251, 164, 235, 53, 94, 218, 247, 75,
163, 95, 150, 12, 153, 179, 227, 191, 114, 195, 222, 216, 240, 59, 63, 75, 221, 251, 208,
106, 207, 232, 150, 65, 100, 53, 33, 2, 9, 69, 91, 82, 144, 1, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 173, 143, 65, 10, 192, 32, 12, 4, 77, 10, 165, 244, 214, 159, 216, 31, 244, 51, 61,
120, 241, 32, 226, 251, 85, 140, 176, 136, 122, 209, 129, 144, 176, 9, 97, 151, 84, 225, 74, 69, 50, 31, 48, 35, 85,
251, 164, 235, 53, 94, 218, 247, 75, 163, 95, 150, 12, 153, 179, 227, 191, 114, 195, 222, 216, 240, 59, 63, 75, 221,
251, 208, 106, 207, 232, 150, 65, 100, 53, 33, 2, 22, 232, 178, 27, 144, 1, 0, 0,
]);
export const initialWitnessMap: WitnessMap = new Map([
[1, '0x0000000000000000000000000000000000000000000000000000000000000005'],
Expand Down
7 changes: 7 additions & 0 deletions noir/noir-repo/acvm-repo/brillig/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ pub enum BrilligOpcode {
destination: MemoryAddress,
source: MemoryAddress,
},
/// destination = condition > 0 ? source_a : source_b
ConditionalMov {
destination: MemoryAddress,
source_a: MemoryAddress,
source_b: MemoryAddress,
condition: MemoryAddress,
},
Load {
destination: MemoryAddress,
source_pointer: MemoryAddress,
Expand Down
55 changes: 55 additions & 0 deletions noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,15 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> {
self.memory.write(*destination_address, source_value);
self.increment_program_counter()
}
Opcode::ConditionalMov { destination, source_a, source_b, condition } => {
let condition_value = self.memory.read(*condition);
if condition_value.is_zero() {
self.memory.write(*destination, self.memory.read(*source_b));
} else {
self.memory.write(*destination, self.memory.read(*source_a));
}
self.increment_program_counter()
}
Opcode::Trap => self.fail("explicit trap hit in brillig".to_string()),
Opcode::Stop { return_data_offset, return_data_size } => {
self.finish(*return_data_offset, *return_data_size)
Expand Down Expand Up @@ -793,6 +802,52 @@ mod tests {
assert_eq!(source_value, Value::from(1u128));
}

#[test]
fn cmov_opcode() {
let calldata =
vec![Value::from(0u128), Value::from(1u128), Value::from(2u128), Value::from(3u128)];

let calldata_copy = Opcode::CalldataCopy {
destination_address: MemoryAddress::from(0),
size: 4,
offset: 0,
};

let opcodes = &[
calldata_copy,
Opcode::ConditionalMov {
destination: MemoryAddress(4), // Sets 3_u128 to memory address 4
source_a: MemoryAddress(2),
source_b: MemoryAddress(3),
condition: MemoryAddress(0),
},
Opcode::ConditionalMov {
destination: MemoryAddress(5), // Sets 2_u128 to memory address 5
source_a: MemoryAddress(2),
source_b: MemoryAddress(3),
condition: MemoryAddress(1),
},
];
let mut vm = VM::new(calldata, opcodes, vec![], &DummyBlackBoxSolver);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::Finished { return_data_offset: 0, return_data_size: 0 });

let VM { memory, .. } = vm;

let destination_value = memory.read(MemoryAddress::from(4));
assert_eq!(destination_value, Value::from(3_u128));

let source_value = memory.read(MemoryAddress::from(5));
assert_eq!(source_value, Value::from(2_u128));
}

#[test]
fn cmp_binary_ops() {
let bit_size = 32;
Expand Down
Loading

0 comments on commit 208abbb

Please sign in to comment.