Skip to content

Commit

Permalink
feat(avm)!: remove tag in NOT (#8606)
Browse files Browse the repository at this point in the history
Case study to see how it goes.
  • Loading branch information
fcarreiro authored Sep 18, 2024
1 parent 5f5ec20 commit d5695fc
Show file tree
Hide file tree
Showing 11 changed files with 64 additions and 74 deletions.
2 changes: 1 addition & 1 deletion avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ pub fn brillig_to_avm(
make_operand(bits_needed, &source.0),
make_operand(bits_needed, &destination.0),
],
tag: Some(tag_from_bit_size(BitSize::Integer(*bit_size))),
tag: None,
});
if let IntegerBitSize::U1 = bit_size {
// We need to cast the result back to u1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ class AvmBitwiseTests : public ::testing::Test {
std::vector<Row> gen_mutated_trace_not(FF const& a, FF const& c_mutated, avm_trace::AvmMemoryTag tag)
{
trace_builder.op_set(0, a, 0, tag);
trace_builder.op_not(0, 0, 1, tag);
trace_builder.op_not(0, 0, 1);
trace_builder.op_return(0, 0, 0);
auto trace = trace_builder.finalize();

Expand Down Expand Up @@ -467,7 +467,7 @@ TEST_P(AvmBitwiseTestsNot, ParamTest)
const auto [operands, mem_tag] = GetParam();
const auto [a, output] = operands;
trace_builder.op_set(0, a, 0, mem_tag);
trace_builder.op_not(0, 0, 1, mem_tag); // [1,254,0,0,....]
trace_builder.op_not(0, 0, 1); // [1,254,0,0,....]
trace_builder.op_return(0, 0, 0);
auto trace = trace_builder.finalize();
common_validate_op_not(trace, a, output, FF(0), FF(1), mem_tag);
Expand Down Expand Up @@ -745,7 +745,7 @@ TEST_F(AvmBitwiseNegativeTestsFF, UndefinedOverFF)
// Triggers a write row 1 of mem_trace and alu_trace
trace_builder.op_set(0, 10, 0, AvmMemoryTag::U8);
// Triggers a write in row 2 of alu_trace
trace_builder.op_not(0, 0, 1, AvmMemoryTag::U8);
trace_builder.op_not(0, 0, 1);
// Finally, we will have a write in row 3 of the mem_trace to copy the result
// from the op_not operation.
trace_builder.op_return(0, 0, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ const std::unordered_map<OpCode, std::vector<OperandType>> OPCODE_WIRE_FORMAT =
{ OpCode::OR_16, three_operand_format16 },
{ OpCode::XOR_8, three_operand_format8 },
{ OpCode::XOR_16, three_operand_format16 },
{ OpCode::NOT_8, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT8, OperandType::UINT8 } },
{ OpCode::NOT_16, { OperandType::INDIRECT, OperandType::TAG, OperandType::UINT16, OperandType::UINT16 } },
{ OpCode::NOT_8, { OperandType::INDIRECT, OperandType::UINT8, OperandType::UINT8 } },
{ OpCode::NOT_16, { OperandType::INDIRECT, OperandType::UINT16, OperandType::UINT16 } },
{ OpCode::SHL_8, three_operand_format8 },
{ OpCode::SHL_16, three_operand_format16 },
{ OpCode::SHR_8, three_operand_format8 },
Expand Down
10 changes: 4 additions & 6 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,15 +558,13 @@ std::vector<Row> Execution::gen_trace(std::vector<Instruction> const& instructio
break;
case OpCode::NOT_8:
trace_builder.op_not(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint8_t>(inst.operands.at(2)),
std::get<uint8_t>(inst.operands.at(3)),
std::get<AvmMemoryTag>(inst.operands.at(1)));
std::get<uint8_t>(inst.operands.at(1)),
std::get<uint8_t>(inst.operands.at(2)));
break;
case OpCode::NOT_16:
trace_builder.op_not(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint16_t>(inst.operands.at(2)),
std::get<uint16_t>(inst.operands.at(3)),
std::get<AvmMemoryTag>(inst.operands.at(1)));
std::get<uint16_t>(inst.operands.at(1)),
std::get<uint16_t>(inst.operands.at(2)));
break;
case OpCode::SHL_8:
trace_builder.op_shl(std::get<uint8_t>(inst.operands.at(0)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class AvmMemTraceBuilder {

// DO NOT USE FOR REAL OPERATIONS
FF unconstrained_read(uint8_t space_id, uint32_t addr) { return memory[space_id][addr].val; }
AvmMemoryTag unconstrained_get_memory_tag(uint8_t space_id, uint32_t addr) { return memory[space_id][addr].tag; }

private:
std::vector<MemoryTraceEntry> mem_trace; // Entries will be sorted by m_clk, m_sub_clk after finalize().
Expand Down
14 changes: 12 additions & 2 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,15 @@ AvmTraceBuilder::MemOp AvmTraceBuilder::constrained_write_to_memory(uint8_t spac
.val = value };
}

AvmMemoryTag AvmTraceBuilder::unconstrained_get_memory_tag(AddressWithMode addr)
{
auto offset = addr.offset;
if (addr.mode == AddressingMode::INDIRECT) {
offset = static_cast<decltype(offset)>(mem_trace_builder.unconstrained_read(call_ptr, offset));
}
return mem_trace_builder.unconstrained_get_memory_tag(call_ptr, offset);
}

FF AvmTraceBuilder::unconstrained_read_from_memory(AddressWithMode addr)
{
auto offset = addr.offset;
Expand Down Expand Up @@ -989,15 +998,16 @@ void AvmTraceBuilder::op_xor(
* @param indirect A byte encoding information about indirect/direct memory access.
* @param a_offset An index in memory pointing to the only operand of Not.
* @param dst_offset An index in memory pointing to the output of Not.
* @param in_tag The instruction memory tag of the operands.
*/
void AvmTraceBuilder::op_not(uint8_t indirect, uint32_t a_offset, uint32_t dst_offset, AvmMemoryTag in_tag)
void AvmTraceBuilder::op_not(uint8_t indirect, uint32_t a_offset, uint32_t dst_offset)
{
auto clk = static_cast<uint32_t>(main_trace.size()) + 1;

// Resolve any potential indirects in the order they are encoded in the indirect byte.
auto [resolved_a, resolved_c] = unpack_indirects<2>(indirect, { a_offset, dst_offset });

// We get our representative memory tag from the resolved_a memory address.
AvmMemoryTag in_tag = unconstrained_get_memory_tag(resolved_a);
// Reading from memory and loading into ia
auto read_a = constrained_read_from_memory(call_ptr, clk, resolved_a, in_tag, in_tag, IntermRegister::IA);

Expand Down
3 changes: 2 additions & 1 deletion barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class AvmTraceBuilder {
void op_and(uint8_t indirect, uint32_t a_offset, uint32_t b_offset, uint32_t dst_offset, AvmMemoryTag in_tag);
void op_or(uint8_t indirect, uint32_t a_offset, uint32_t b_offset, uint32_t dst_offset, AvmMemoryTag in_tag);
void op_xor(uint8_t indirect, uint32_t a_offset, uint32_t b_offset, uint32_t dst_offset, AvmMemoryTag in_tag);
void op_not(uint8_t indirect, uint32_t a_offset, uint32_t dst_offset, AvmMemoryTag in_tag);
void op_not(uint8_t indirect, uint32_t a_offset, uint32_t dst_offset);
void op_shl(uint8_t indirect, uint32_t a_offset, uint32_t b_offset, uint32_t dst_offset, AvmMemoryTag in_tag);
void op_shr(uint8_t indirect, uint32_t a_offset, uint32_t b_offset, uint32_t dst_offset, AvmMemoryTag in_tag);

Expand Down Expand Up @@ -270,6 +270,7 @@ class AvmTraceBuilder {
AvmMemTraceBuilder::MemOpOwner mem_op_owner = AvmMemTraceBuilder::MAIN);

// TODO: remove these once everything is constrained.
AvmMemoryTag unconstrained_get_memory_tag(AddressWithMode addr);
FF unconstrained_read_from_memory(AddressWithMode addr);
template <typename T> void read_slice_from_memory(AddressWithMode addr, size_t slice_len, std::vector<T>& slice);
void write_to_memory(AddressWithMode addr, FF val, AvmMemoryTag w_tag);
Expand Down
13 changes: 5 additions & 8 deletions yarn-project/simulator/src/avm/opcodes/bitwise.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -365,16 +365,13 @@ describe('Bitwise instructions', () => {
const buf = Buffer.from([
Opcode.NOT_16, // opcode
0x01, // indirect
TypeTag.UINT64, // inTag
...Buffer.from('1234', 'hex'), // aOffset
...Buffer.from('3456', 'hex'), // dstOffset
]);
const inst = new Not(
/*indirect=*/ 0x01,
/*inTag=*/ TypeTag.UINT64,
/*aOffset=*/ 0x1234,
/*dstOffset=*/ 0x3456,
).as(Opcode.NOT_16, Not.wireFormat16);
const inst = new Not(/*indirect=*/ 0x01, /*aOffset=*/ 0x1234, /*dstOffset=*/ 0x3456).as(
Opcode.NOT_16,
Not.wireFormat16,
);

expect(Not.as(Not.wireFormat16).deserialize(buf)).toEqual(inst);
expect(inst.serialize()).toEqual(buf);
Expand All @@ -385,7 +382,7 @@ describe('Bitwise instructions', () => {

context.machineState.memory.set(0, a);

await new Not(/*indirect=*/ 0, /*inTag=*/ TypeTag.UINT16, /*aOffset=*/ 0, /*dstOffset=*/ 1).execute(context);
await new Not(/*indirect=*/ 0, /*aOffset=*/ 0, /*dstOffset=*/ 1).execute(context);

const expected = new Uint16(0b1001101100011011n); // high bits!
const actual = context.machineState.memory.get(1);
Expand Down
24 changes: 14 additions & 10 deletions yarn-project/simulator/src/avm/opcodes/bitwise.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import type { AvmContext } from '../avm_context.js';
import { type IntegralValue, type TaggedMemoryInterface, TypeTag } from '../avm_memory_types.js';
import { Opcode } from '../serialization/instruction_serialization.js';
import { type IntegralValue, TaggedMemory, type TaggedMemoryInterface, TypeTag } from '../avm_memory_types.js';
import { Opcode, OperandType } from '../serialization/instruction_serialization.js';
import { Addressing } from './addressing_mode.js';
import { ThreeOperandInstruction, TwoOperandInstruction } from './instruction_impl.js';
import { Instruction } from './instruction.js';
import { ThreeOperandInstruction } from './instruction_impl.js';

abstract class ThreeOperandBitwiseInstruction extends ThreeOperandInstruction {
public async execute(context: AvmContext): Promise<void> {
Expand Down Expand Up @@ -85,24 +86,27 @@ export class Shr extends ThreeOperandBitwiseInstruction {
}
}

export class Not extends TwoOperandInstruction {
export class Not extends Instruction {
static readonly type: string = 'NOT';
static readonly opcode = Opcode.NOT_8;

constructor(indirect: number, inTag: number, aOffset: number, dstOffset: number) {
super(indirect, inTag, aOffset, dstOffset);
static readonly wireFormat8 = [OperandType.UINT8, OperandType.UINT8, OperandType.UINT8, OperandType.UINT8];
static readonly wireFormat16 = [OperandType.UINT8, OperandType.UINT8, OperandType.UINT16, OperandType.UINT16];

constructor(private indirect: number, private srcOffset: number, private dstOffset: number) {
super();
}

public async execute(context: AvmContext): Promise<void> {
const memoryOperations = { reads: 1, writes: 1, indirect: this.indirect };
const memory = context.machineState.memory.track(this.type);
context.machineState.consumeGas(this.gasCost(memoryOperations));

const [aOffset, dstOffset] = Addressing.fromWire(this.indirect).resolve([this.aOffset, this.dstOffset], memory);
memory.checkTags(this.inTag, aOffset);
const a = memory.getAs<IntegralValue>(aOffset);
const [srcOffset, dstOffset] = Addressing.fromWire(this.indirect).resolve([this.srcOffset, this.dstOffset], memory);
TaggedMemory.checkIsIntegralTag(memory.getTag(srcOffset));
const value = memory.getAs<IntegralValue>(srcOffset);

const res = a.not();
const res = value.not();
memory.set(dstOffset, res);

memory.assert(memoryOperations);
Expand Down
35 changes: 0 additions & 35 deletions yarn-project/simulator/src/avm/opcodes/instruction_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,6 @@ import { OperandType } from '../serialization/instruction_serialization.js';
import { Addressing } from './addressing_mode.js';
import { Instruction } from './instruction.js';

/** Wire format that informs deserialization for instructions with two operands. */
export const TwoOperandWireFormat8 = [
OperandType.UINT8,
OperandType.UINT8,
OperandType.UINT8,
OperandType.UINT8,
OperandType.UINT8,
];
export const TwoOperandWireFormat16 = [
OperandType.UINT8,
OperandType.UINT8,
OperandType.UINT8,
OperandType.UINT16,
OperandType.UINT16,
];

/** Wire format that informs deserialization for instructions with three operands. */
export const ThreeOperandWireFormat8 = [
OperandType.UINT8,
Expand All @@ -38,25 +22,6 @@ export const ThreeOperandWireFormat16 = [
OperandType.UINT16,
];

/**
* Covers (de)serialization for an instruction with:
* indirect, inTag, and two operands.
*/
export abstract class TwoOperandInstruction extends Instruction {
// Informs (de)serialization. See Instruction.deserialize.
static readonly wireFormat8: OperandType[] = TwoOperandWireFormat8;
static readonly wireFormat16: OperandType[] = TwoOperandWireFormat16;

constructor(
protected indirect: number,
protected inTag: number,
protected aOffset: number,
protected dstOffset: number,
) {
super();
}
}

/**
* Covers (de)serialization for an instruction with:
* indirect, inTag, and three operands.
Expand Down
26 changes: 20 additions & 6 deletions yarn-project/simulator/src/avm/opcodes/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import { Field, TaggedMemory } from '../avm_memory_types.js';
import { Opcode, OperandType } from '../serialization/instruction_serialization.js';
import { Addressing } from './addressing_mode.js';
import { Instruction } from './instruction.js';
import { TwoOperandInstruction } from './instruction_impl.js';

export class Set extends Instruction {
static readonly type: string = 'SET';
Expand Down Expand Up @@ -121,23 +120,38 @@ export class CMov extends Instruction {
}
}

export class Cast extends TwoOperandInstruction {
export class Cast extends Instruction {
static readonly type: string = 'CAST';
static readonly opcode = Opcode.CAST_8;

constructor(indirect: number, dstTag: number, srcOffset: number, dstOffset: number) {
super(indirect, dstTag, srcOffset, dstOffset);
static readonly wireFormat8 = [
OperandType.UINT8,
OperandType.UINT8,
OperandType.UINT8,
OperandType.UINT8,
OperandType.UINT8,
];
static readonly wireFormat16 = [
OperandType.UINT8,
OperandType.UINT8,
OperandType.UINT8,
OperandType.UINT16,
OperandType.UINT16,
];

constructor(private indirect: number, private dstTag: number, private srcOffset: number, private dstOffset: number) {
super();
}

public async execute(context: AvmContext): Promise<void> {
const memoryOperations = { reads: 1, writes: 1, indirect: this.indirect };
const memory = context.machineState.memory.track(this.type);
context.machineState.consumeGas(this.gasCost(memoryOperations));

const [srcOffset, dstOffset] = Addressing.fromWire(this.indirect).resolve([this.aOffset, this.dstOffset], memory);
const [srcOffset, dstOffset] = Addressing.fromWire(this.indirect).resolve([this.srcOffset, this.dstOffset], memory);

const a = memory.get(srcOffset);
const casted = TaggedMemory.buildFromTagTruncating(a.toBigInt(), this.inTag);
const casted = TaggedMemory.buildFromTagTruncating(a.toBigInt(), this.dstTag);

memory.set(dstOffset, casted);

Expand Down

0 comments on commit d5695fc

Please sign in to comment.