Skip to content

Commit

Permalink
feat(avm)!: variants for CAST/NOT opcode (#8497)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcarreiro authored Sep 11, 2024
1 parent c1aa6f7 commit bc609fa
Show file tree
Hide file tree
Showing 18 changed files with 140 additions and 91 deletions.
12 changes: 8 additions & 4 deletions avm-transpiler/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ pub enum AvmOpcode {
OR_16,
XOR_8,
XOR_16,
NOT,
NOT_8,
NOT_16,
SHL_8,
SHL_16,
SHR_8,
SHR_16,
CAST,
CAST_8,
CAST_16,
// Execution environment
ADDRESS,
STORAGEADDRESS,
Expand Down Expand Up @@ -127,13 +129,15 @@ impl AvmOpcode {
AvmOpcode::OR_16 => "OR_16",
AvmOpcode::XOR_8 => "XOR_8",
AvmOpcode::XOR_16 => "XOR_16",
AvmOpcode::NOT => "NOT",
AvmOpcode::NOT_8 => "NOT_8",
AvmOpcode::NOT_16 => "NOT_16",
AvmOpcode::SHL_8 => "SHL_8",
AvmOpcode::SHL_16 => "SHL_16",
AvmOpcode::SHR_8 => "SHR_8",
AvmOpcode::SHR_16 => "SHR_16",
// Compute - Type Conversions
AvmOpcode::CAST => "CAST",
AvmOpcode::CAST_8 => "CAST_8",
AvmOpcode::CAST_16 => "CAST_16",

// Execution Environment
AvmOpcode::ADDRESS => "ADDRESS",
Expand Down
10 changes: 8 additions & 2 deletions avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,12 @@ fn generate_cast_instruction(
destination_indirect: bool,
dst_tag: AvmTypeTag,
) -> AvmInstruction {
let bits_needed = bits_needed_for(&source).max(bits_needed_for(&destination));
let avm_opcode = match bits_needed {
8 => AvmOpcode::CAST_8,
16 => AvmOpcode::CAST_16,
_ => panic!("CAST only supports 8 and 16 bit encodings, needed {}", bits_needed),
};
let mut indirect_flags = ALL_DIRECT;
if source_indirect {
indirect_flags |= ZEROTH_OPERAND_INDIRECT;
Expand All @@ -831,10 +837,10 @@ fn generate_cast_instruction(
indirect_flags |= FIRST_OPERAND_INDIRECT;
}
AvmInstruction {
opcode: AvmOpcode::CAST,
opcode: avm_opcode,
indirect: Some(indirect_flags),
tag: Some(dst_tag),
operands: vec![AvmOperand::U32 { value: source }, AvmOperand::U32 { value: destination }],
operands: vec![make_operand(bits_needed, &source), make_operand(bits_needed, &destination)],
}
}

Expand Down
60 changes: 30 additions & 30 deletions barretenberg/cpp/src/barretenberg/vm/avm/tests/execution.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,11 +730,11 @@ TEST_F(AvmExecutionTests, setAndCastOpcodes)
"02" // U16
"B813" // val 47123
"0011" // dst_offset 17
+ to_hex(OpCode::CAST) + // opcode CAST
+ to_hex(OpCode::CAST_8) + // opcode CAST
"00" // Indirect flag
"01" // U8
"00000011" // addr a
"00000012" // addr casted a
"11" // addr a
"12" // addr casted a
+ to_hex(OpCode::RETURN) + // opcode RETURN
"00" // Indirect flag
"00000000" // ret offset 0
Expand All @@ -747,12 +747,12 @@ TEST_F(AvmExecutionTests, setAndCastOpcodes)

// SUB
EXPECT_THAT(instructions.at(1),
AllOf(Field(&Instruction::op_code, OpCode::CAST),
AllOf(Field(&Instruction::op_code, OpCode::CAST_8),
Field(&Instruction::operands,
ElementsAre(VariantWith<uint8_t>(0),
VariantWith<AvmMemoryTag>(AvmMemoryTag::U8),
VariantWith<uint32_t>(17),
VariantWith<uint32_t>(18)))));
VariantWith<uint8_t>(17),
VariantWith<uint8_t>(18)))));

auto trace = gen_trace_from_instr(instructions);

Expand Down Expand Up @@ -1238,16 +1238,16 @@ TEST_F(AvmExecutionTests, embeddedCurveAddOpCode)
"00000000" // cd_offset
"00000001" // copy_size
"00000000" // dst_offset
+ to_hex(OpCode::CAST) + // opcode CAST inf to U8
+ to_hex(OpCode::CAST_8) + // opcode CAST inf to U8
"00" // Indirect flag
"01" // U8 tag field
"00000002" // a_is_inf
"00000002" // a_is_inf
+ to_hex(OpCode::CAST) + // opcode CAST inf to U8
"02" // a_is_inf
"02" // a_is_inf
+ to_hex(OpCode::CAST_8) + // opcode CAST inf to U8
"00" // Indirect flag
"01" // U8 tag field
"00000005" // b_is_inf
"00000005" // b_is_inf
"05" // b_is_inf
"05" // b_is_inf
+ to_hex(OpCode::SET_8) + // opcode SET for direct src_length
"00" // Indirect flag
"03" // U32
Expand Down Expand Up @@ -1314,16 +1314,16 @@ TEST_F(AvmExecutionTests, msmOpCode)
"00000000" // cd_offset 0
"00000001" // copy_size (10 elements)
"00000000" // dst_offset 0
+ to_hex(OpCode::CAST) + // opcode CAST inf to U8
+ to_hex(OpCode::CAST_8) + // opcode CAST inf to U8
"00" // Indirect flag
"01" // U8 tag field
"00000002" // a_is_inf
"00000002" //
+ to_hex(OpCode::CAST) + // opcode CAST inf to U8
"02" // a_is_inf
"02" //
+ to_hex(OpCode::CAST_8) + // opcode CAST inf to U8
"00" // Indirect flag
"01" // U8 tag field
"00000005" // b_is_inf
"00000005" //
"05" // b_is_inf
"05" //
+ to_hex(OpCode::SET_8) + // opcode SET for length
"00" // Indirect flag
"03" // U32
Expand Down Expand Up @@ -1758,11 +1758,11 @@ TEST_F(AvmExecutionTests, kernelOutputEmitOpcodes)
"01" // value 1
"01" // dst_offset 1
// Cast set to field
+ to_hex(OpCode::CAST) + // opcode CAST
+ to_hex(OpCode::CAST_8) + // opcode CAST
"00" // Indirect flag
"06" // tag field
"00000001" // dst 1
"00000001" // dst 1
"01" // dst 1
"01" // dst 1
+ to_hex(OpCode::EMITNOTEHASH) + // opcode EMITNOTEHASH
"00" // Indirect flag
"00000001" // src offset 1
Expand Down Expand Up @@ -1859,11 +1859,11 @@ TEST_F(AvmExecutionTests, kernelOutputStorageLoadOpcodeSimple)
"03" // U32
"09" // value 9
"01" // dst_offset 1
+ to_hex(OpCode::CAST) + // opcode CAST (Cast set to field)
+ to_hex(OpCode::CAST_8) + // opcode CAST (Cast set to field)
"00" // Indirect flag
"06" // tag field
"00000001" // dst 1
"00000001" // dst 1
"01" // dst 1
"01" // dst 1
+ to_hex(OpCode::SLOAD) + // opcode SLOAD
"00" // Indirect flag
"00000001" // slot offset 1
Expand Down Expand Up @@ -1972,11 +1972,11 @@ TEST_F(AvmExecutionTests, kernelOutputStorageOpcodes)
"09" // value 9
"01" // dst_offset 1
// Cast set to field
+ to_hex(OpCode::CAST) + // opcode CAST
+ to_hex(OpCode::CAST_8) + // opcode CAST
"00" // Indirect flag
"06" // tag field
"00000001" // dst 1
"00000001" // dst 1
"01" // dst 1
"01" // dst 1
+ to_hex(OpCode::SLOAD) + // opcode SLOAD
"00" // Indirect flag
"00000001" // slot offset 1
Expand Down Expand Up @@ -2047,11 +2047,11 @@ TEST_F(AvmExecutionTests, kernelOutputHashExistsOpcodes)
"01" // value 1
"01" // dst_offset 1
// Cast set to field
+ to_hex(OpCode::CAST) + // opcode CAST
+ to_hex(OpCode::CAST_8) + // opcode CAST
"00" // Indirect flag
"06" // tag field
"00000001" // dst 1
"00000001" // dst 1
"01" // dst 1
"01" // dst 1
+ to_hex(OpCode::NOTEHASHEXISTS) + // opcode NOTEHASHEXISTS
"00" // Indirect flag
"00000001" // slot offset 1
Expand Down
15 changes: 8 additions & 7 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/alu_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ FF AvmAluTraceBuilder::op_not(FF const& a, AvmMemoryTag in_tag, uint32_t const c

alu_trace.push_back(AvmAluTraceBuilder::AluTraceEntry{
.alu_clk = clk,
.opcode = OpCode::NOT,
.opcode = OpCode::NOT_8, // FIXME: take into account all opcodes.
.tag = in_tag,
.alu_ia = a,
.alu_ic = c,
Expand Down Expand Up @@ -585,7 +585,7 @@ FF AvmAluTraceBuilder::op_cast(FF const& a, AvmMemoryTag in_tag, uint32_t clk)
}
alu_trace.push_back(AvmAluTraceBuilder::AluTraceEntry{
.alu_clk = clk,
.opcode = OpCode::CAST,
.opcode = OpCode::CAST_8, // FIXME: take into account all opcodes.
.tag = in_tag,
.alu_ia = a,
.alu_ic = c,
Expand Down Expand Up @@ -618,9 +618,10 @@ bool AvmAluTraceBuilder::is_range_check_required() const
bool AvmAluTraceBuilder::is_alu_row_enabled(const AvmAluTraceBuilder::AluTraceEntry& r)
{
return (r.opcode == OpCode::ADD_8 || r.opcode == OpCode::SUB_8 || r.opcode == OpCode::MUL_8 ||
r.opcode == OpCode::EQ_8 || r.opcode == OpCode::NOT || r.opcode == OpCode::LT_8 ||
r.opcode == OpCode::LTE_8 || r.opcode == OpCode::SHR_8 || r.opcode == OpCode::SHL_8 ||
r.opcode == OpCode::CAST || r.opcode == OpCode::DIV_8);
r.opcode == OpCode::EQ_8 || r.opcode == OpCode::NOT_8 || r.opcode == OpCode::NOT_16 ||
r.opcode == OpCode::LT_8 || r.opcode == OpCode::LTE_8 || r.opcode == OpCode::SHR_8 ||
r.opcode == OpCode::SHL_8 || r.opcode == OpCode::CAST_8 || r.opcode == OpCode::CAST_8 ||
r.opcode == OpCode::CAST_16 || r.opcode == OpCode::DIV_8);
}

/**
Expand All @@ -640,11 +641,11 @@ void AvmAluTraceBuilder::finalize(std::vector<AvmFullRow<FF>>& main_trace)
dest.alu_op_add = FF(src.opcode == OpCode::ADD_8 || src.opcode == OpCode::ADD_16 ? 1 : 0);
dest.alu_op_sub = FF(src.opcode == OpCode::SUB_8 || src.opcode == OpCode::SUB_16 ? 1 : 0);
dest.alu_op_mul = FF(src.opcode == OpCode::MUL_8 || src.opcode == OpCode::MUL_16 ? 1 : 0);
dest.alu_op_not = FF(src.opcode == OpCode::NOT ? 1 : 0);
dest.alu_op_not = FF(src.opcode == OpCode::NOT_8 || src.opcode == OpCode::NOT_16 ? 1 : 0);
dest.alu_op_eq = FF(src.opcode == OpCode::EQ_8 || src.opcode == OpCode::EQ_16 ? 1 : 0);
dest.alu_op_lt = FF(src.opcode == OpCode::LT_8 || src.opcode == OpCode::LT_16 ? 1 : 0);
dest.alu_op_lte = FF(src.opcode == OpCode::LTE_8 || src.opcode == OpCode::LTE_16 ? 1 : 0);
dest.alu_op_cast = FF(src.opcode == OpCode::CAST ? 1 : 0);
dest.alu_op_cast = FF(src.opcode == OpCode::CAST_8 || src.opcode == OpCode::CAST_16 ? 1 : 0);
dest.alu_op_shr = FF(src.opcode == OpCode::SHR_8 || src.opcode == OpCode::SHR_16 ? 1 : 0);
dest.alu_op_shl = FF(src.opcode == OpCode::SHL_8 || src.opcode == OpCode::SHL_16 ? 1 : 0);
dest.alu_op_div = FF(src.opcode == OpCode::DIV_8 || src.opcode == OpCode::DIV_16 ? 1 : 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@ const std::unordered_map<OpCode, std::vector<OperandType>> OPCODE_WIRE_FORMAT =
{ OpCode::OR_16, three_operand_format16 },
{ OpCode::XOR_8, three_operand_format8 },
{ OpCode::XOR_16, three_operand_format16 },
{ OpCode::NOT, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT8, OperandType::UINT8 } },
{ OpCode::NOT_8, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT8, OperandType::UINT8 } },
{ OpCode::NOT_16, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT16, OperandType::UINT16 } },
{ OpCode::SHL_8, three_operand_format8 },
{ OpCode::SHL_16, three_operand_format16 },
{ OpCode::SHR_8, three_operand_format8 },
{ OpCode::SHR_16, three_operand_format16 },
// Compute - Type Conversions
{ OpCode::CAST, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT32, OperandType::UINT32 } },
{ OpCode::CAST_8, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT8, OperandType::UINT8 } },
{ OpCode::CAST_16, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT16, OperandType::UINT16 } },

// Execution Environment - Globals
{ OpCode::ADDRESS, getter_format },
Expand Down
24 changes: 18 additions & 6 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,10 +556,16 @@ std::vector<Row> Execution::gen_trace(std::vector<Instruction> const& instructio
std::get<uint16_t>(inst.operands.at(4)),
std::get<AvmMemoryTag>(inst.operands.at(1)));
break;
case OpCode::NOT:
case OpCode::NOT_8:
trace_builder.op_not(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint32_t>(inst.operands.at(2)),
std::get<uint32_t>(inst.operands.at(3)),
std::get<uint8_t>(inst.operands.at(2)),
std::get<uint8_t>(inst.operands.at(3)),
std::get<AvmMemoryTag>(inst.operands.at(1)));
break;
case OpCode::NOT_16:
trace_builder.op_not(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint16_t>(inst.operands.at(2)),
std::get<uint16_t>(inst.operands.at(3)),
std::get<AvmMemoryTag>(inst.operands.at(1)));
break;
case OpCode::SHL_8:
Expand Down Expand Up @@ -592,10 +598,16 @@ std::vector<Row> Execution::gen_trace(std::vector<Instruction> const& instructio
break;

// Compute - Type Conversions
case OpCode::CAST:
case OpCode::CAST_8:
trace_builder.op_cast(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint32_t>(inst.operands.at(2)),
std::get<uint32_t>(inst.operands.at(3)),
std::get<uint8_t>(inst.operands.at(2)),
std::get<uint8_t>(inst.operands.at(3)),
std::get<AvmMemoryTag>(inst.operands.at(1)));
break;
case OpCode::CAST_16:
trace_builder.op_cast(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint16_t>(inst.operands.at(2)),
std::get<uint16_t>(inst.operands.at(3)),
std::get<AvmMemoryTag>(inst.operands.at(1)));
break;

Expand Down
6 changes: 4 additions & 2 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_gas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ const std::unordered_map<OpCode, FixedGasTable::GasRow> GAS_COST_TABLE = {
{ OpCode::OR_16, make_cost(AVM_OR_BASE_L2_GAS, 0, AVM_OR_DYN_L2_GAS, 0) },
{ OpCode::XOR_8, make_cost(AVM_XOR_BASE_L2_GAS, 0, AVM_XOR_DYN_L2_GAS, 0) },
{ OpCode::XOR_16, make_cost(AVM_XOR_BASE_L2_GAS, 0, AVM_XOR_DYN_L2_GAS, 0) },
{ OpCode::NOT, make_cost(AVM_NOT_BASE_L2_GAS, 0, AVM_NOT_DYN_L2_GAS, 0) },
{ OpCode::NOT_8, make_cost(AVM_NOT_BASE_L2_GAS, 0, AVM_NOT_DYN_L2_GAS, 0) },
{ OpCode::NOT_16, make_cost(AVM_NOT_BASE_L2_GAS, 0, AVM_NOT_DYN_L2_GAS, 0) },
{ OpCode::SHL_8, make_cost(AVM_SHL_BASE_L2_GAS, 0, AVM_SHL_DYN_L2_GAS, 0) },
{ OpCode::SHL_16, make_cost(AVM_SHL_BASE_L2_GAS, 0, AVM_SHL_DYN_L2_GAS, 0) },
{ OpCode::SHR_8, make_cost(AVM_SHR_BASE_L2_GAS, 0, AVM_SHR_DYN_L2_GAS, 0) },
{ OpCode::SHR_16, make_cost(AVM_SHR_BASE_L2_GAS, 0, AVM_SHR_DYN_L2_GAS, 0) },
{ OpCode::CAST, make_cost(AVM_CAST_BASE_L2_GAS, 0, AVM_CAST_DYN_L2_GAS, 0) },
{ OpCode::CAST_8, make_cost(AVM_CAST_BASE_L2_GAS, 0, AVM_CAST_DYN_L2_GAS, 0) },
{ OpCode::CAST_16, make_cost(AVM_CAST_BASE_L2_GAS, 0, AVM_CAST_DYN_L2_GAS, 0) },
{ OpCode::ADDRESS, make_cost(AVM_ADDRESS_BASE_L2_GAS, 0, AVM_ADDRESS_DYN_L2_GAS, 0) },
{ OpCode::STORAGEADDRESS, make_cost(AVM_STORAGEADDRESS_BASE_L2_GAS, 0, AVM_STORAGEADDRESS_DYN_L2_GAS, 0) },
{ OpCode::SENDER, make_cost(AVM_SENDER_BASE_L2_GAS, 0, AVM_SENDER_DYN_L2_GAS, 0) },
Expand Down
Loading

0 comments on commit bc609fa

Please sign in to comment.