Skip to content

Commit

Permalink
contracts: Add more MIPS2 tests (#12003)
Browse files Browse the repository at this point in the history
* contracts: Add more MIPS2 tests

* remove unused var
  • Loading branch information
Inphi authored Sep 19, 2024
1 parent ff338bc commit 4806d83
Showing 1 changed file with 203 additions and 2 deletions.
205 changes: 203 additions & 2 deletions packages/contracts-bedrock/test/cannon/MIPS2.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2463,8 +2463,209 @@ contract MIPS2_Test is CommonTest {
assertEq(postState, outputState(expect), "unexpected post state");
}

// TODO(client-pod#959): Port over the remaining single-threaded tests from MIPS.t.sol
// TODO(client-pod#959): Assert unimplemented syscalls
function test_sll_succeeds() external {
uint8 shiftamt = 4;
uint32 insn = encodespec(0x0, 0x9, 0x8, uint16(shiftamt) << 6); // sll t0, t1, 3
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
thread.registers[9] = 0x20; // t1
updateThreadStacks(state, thread);

uint32 result = thread.registers[9] << shiftamt;
MIPS2.State memory expect = arithmeticPostState(state, thread, 8, /* t0 */ result);

bytes32 postState = mips.step(
encodeState(state), bytes.concat(abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT), memProof), 0
);
assertEq(postState, outputState(expect), "unexpected post state");
}

function test_srl_succeeds() external {
uint8 shiftamt = 4;
uint32 insn = encodespec(0x0, 0x9, 0x8, uint16(shiftamt) << 6 | 2); // srl t0, t1, 3
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
thread.registers[9] = 0x20; // t1
updateThreadStacks(state, thread);

uint32 result = thread.registers[9] >> shiftamt;
MIPS2.State memory expect = arithmeticPostState(state, thread, 8, /* t0 */ result);

bytes32 postState = mips.step(
encodeState(state), bytes.concat(abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT), memProof), 0
);
assertEq(postState, outputState(expect), "unexpected post state");
}

function test_sra_succeeds() external {
uint8 shiftamt = 4;
uint32 insn = encodespec(0x0, 0x9, 0x8, uint16(shiftamt) << 6 | 3); // sra t0, t1, 3
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
thread.registers[9] = 0x80_00_00_20; // t1
updateThreadStacks(state, thread);

uint32 result = 0xF8_00_00_02; // 4 shifts while preserving sign bit
MIPS2.State memory expect = arithmeticPostState(state, thread, 8, /* t0 */ result);

bytes32 postState = mips.step(
encodeState(state), bytes.concat(abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT), memProof), 0
);
assertEq(postState, outputState(expect), "unexpected post state");
}

function test_sllv_succeeds() external {
uint32 insn = encodespec(0xa, 0x9, 0x8, 4); // sllv t0, t1, t2
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
thread.registers[9] = 0x20; // t1
thread.registers[10] = 4; // t2
updateThreadStacks(state, thread);

uint32 result = thread.registers[9] << thread.registers[10];
MIPS2.State memory expect = arithmeticPostState(state, thread, 8, /* t0 */ result);

bytes32 postState = mips.step(
encodeState(state), bytes.concat(abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT), memProof), 0
);
assertEq(postState, outputState(expect), "unexpected post state");
}

function test_srlv_succeeds() external {
uint32 insn = encodespec(0xa, 0x9, 0x8, 6); // srlv t0, t1, t2
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
thread.registers[9] = 0x20_00; // t1
thread.registers[10] = 4; // t2
updateThreadStacks(state, thread);

uint32 result = thread.registers[9] >> thread.registers[10];
MIPS2.State memory expect = arithmeticPostState(state, thread, 8, /* t0 */ result);

bytes32 postState = mips.step(
encodeState(state), bytes.concat(abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT), memProof), 0
);
assertEq(postState, outputState(expect), "unexpected post state");
}

function test_lui_succeeds() external {
uint32 insn = encodeitype(0xf, 0x0, 0x8, 0x4); // lui $t0, 0x04
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
updateThreadStacks(state, thread);

MIPS2.State memory expect = arithmeticPostState(state, thread, 8, /* t0 */ 0x00_04_00_00);
bytes32 postState = mips.step(
encodeState(state), bytes.concat(abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT), memProof), 0
);
assertEq(postState, outputState(expect), "unexpected post state");
}

function test_clo_succeeds() external {
uint32 insn = encodespec2(0x9, 0x0, 0x8, 0x21); // clo t0, t1
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
thread.registers[9] = 0xFF_00_00_00; // t1
updateThreadStacks(state, thread);

MIPS2.State memory expect = arithmeticPostState(state, thread, 8, /* t0 */ 8);
bytes32 postState = mips.step(
encodeState(state), bytes.concat(abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT), memProof), 0
);
assertEq(postState, outputState(expect), "unexpected post state");
}

function test_clz_succeeds() external {
uint32 insn = encodespec2(0x9, 0x0, 0x8, 0x20); // clz t0, t1
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
thread.registers[9] = 0x00_00_F0_00; // t1
updateThreadStacks(state, thread);

MIPS2.State memory expect = arithmeticPostState(state, thread, 8, /* t0 */ 16);
bytes32 postState = mips.step(
encodeState(state), bytes.concat(abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT), memProof), 0
);
assertEq(postState, outputState(expect), "unexpected post state");
}

function test_preimage_read_succeeds() external {
uint32 pc = 0x0;
uint32 insn = 0x0000000c; // syscall
uint32 a1 = 0x4;
uint32 a1_val = 0x0000abba;
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, a1, a1_val);
state.preimageKey = bytes32(uint256(1) << 248 | 0x01);
state.preimageOffset = 8; // start reading past the pre-image length prefix
thread.registers[2] = 4003; // read syscall
thread.registers[4] = 5; // fd
thread.registers[5] = a1; // addr
thread.registers[6] = 4; // count
threading.createThread();
threading.replaceCurrent(thread);
bytes memory threadWitness = threading.witness();
finalizeThreadingState(threading, state);

MIPS2.ThreadState memory expectThread = copyThread(thread);
expectThread.pc = thread.nextPC;
expectThread.nextPC = thread.nextPC + 4;
expectThread.registers[2] = 4; // return
expectThread.registers[7] = 0; // errno
threading.replaceCurrent(expectThread);

// prime the pre-image oracle
bytes32 word = bytes32(uint256(0xdeadbeef) << 224);
uint8 size = 4;
uint8 partOffset = 8;
oracle.loadLocalData(uint256(state.preimageKey), 0, word, size, partOffset);

MIPS2.State memory expect = copyState(state);
expect.preimageOffset += 4;
expect.step = state.step + 1;
expect.stepsSinceLastContextSwitch = state.stepsSinceLastContextSwitch + 1;
// recompute merkle root of written pre-image
(expect.memRoot,) = ffi.getCannonMemoryProof(pc, insn, a1, 0xdeadbeef);
finalizeThreadingState(threading, expect);

bytes32 postState = mips.step(encodeState(state), bytes.concat(threadWitness, memProof), 0);
assertEq(postState, outputState(expect), "unexpected post state");
}

function test_preimage_write_succeeds() external {
uint32 insn = 0x0000000c; // syscall
uint32 a1 = 0x4;
uint32 a1_val = 0x0000abba;
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, a1, a1_val);
state.preimageKey = bytes32(0);
state.preimageOffset = 1;
thread.registers[2] = 4004; // write syscall
thread.registers[4] = 6; // fd
thread.registers[5] = a1; // addr
thread.registers[6] = 4; // count
threading.createThread();
threading.replaceCurrent(thread);
bytes memory threadWitness = threading.witness();
finalizeThreadingState(threading, state);

MIPS2.ThreadState memory expectThread = copyThread(thread);
expectThread.pc = thread.nextPC;
expectThread.nextPC = thread.nextPC + 4;
expectThread.registers[2] = 4; // return
expectThread.registers[7] = 0; // errno
threading.replaceCurrent(expectThread);

MIPS2.State memory expect = copyState(state);
expect.preimageKey = bytes32(uint256(0xabba));
expect.preimageOffset = 0;
expect.step = state.step + 1;
expect.stepsSinceLastContextSwitch = state.stepsSinceLastContextSwitch + 1;
finalizeThreadingState(threading, expect);

bytes32 postState = mips.step(encodeState(state), bytes.concat(threadWitness, memProof), 0);
assertEq(postState, outputState(expect), "unexpected post state");
}

/// @dev Modifies the MIPS2 State based on threading state
function finalizeThreadingState(Threading _threading, MIPS2.State memory _state) internal view {
Expand Down

0 comments on commit 4806d83

Please sign in to comment.