Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow reading chunks from the input stream #267

Merged
merged 4 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 11 additions & 41 deletions core/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@ use crate::utils::Buffer;
use serde::{de::DeserializeOwned, Deserialize, Serialize};

/// Standard input.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone)]
pub struct AthenaStdin {
/// Input stored as a vec of vec of bytes. It's stored this way because the read syscall reads
/// a vec of bytes at a time.
pub buffer: Vec<Vec<u8>>,
pub ptr: usize,
buffer: Vec<u8>,
}

/// Public values for the runner.
Expand All @@ -17,49 +14,26 @@ pub struct AthenaPublicValues {
}

impl AthenaStdin {
/// Create a new `AthenaStdin`.
pub const fn new() -> Self {
Self {
buffer: Vec::new(),
ptr: 0,
}
}

/// Create an `AthenaStdin` from a slice of bytes.
pub fn from(data: &[u8]) -> Self {
Self {
buffer: vec![data.to_vec()],
ptr: 0,
}
}

/// Read a value from the buffer.
pub fn read<T: DeserializeOwned>(&mut self) -> T {
let result: T = bincode::deserialize(&self.buffer[self.ptr]).expect("failed to deserialize");
self.ptr += 1;
result
}

/// Read a slice of bytes from the buffer.
pub fn read_slice(&mut self, slice: &mut [u8]) {
slice.copy_from_slice(&self.buffer[self.ptr]);
self.ptr += 1;
Self { buffer: Vec::new() }
}

/// Write a value to the buffer.
pub fn write<T: Serialize>(&mut self, data: &T) {
let mut tmp = Vec::new();
bincode::serialize_into(&mut tmp, data).expect("serialization failed");
self.buffer.push(tmp);
bincode::serialize_into(&mut self.buffer, data).expect("serialization failed");
}

/// Write a slice of bytes to the buffer.
pub fn write_slice(&mut self, slice: &[u8]) {
self.buffer.push(slice.to_vec());
self.buffer.extend_from_slice(slice);
}

pub fn write_vec(&mut self, vec: Vec<u8>) {
self.buffer.push(vec);
pub fn write_vec(&mut self, mut vec: Vec<u8>) {
self.buffer.append(&mut vec);
}

pub fn to_vec(self) -> Vec<u8> {
self.buffer
}
}

Expand All @@ -71,10 +45,6 @@ impl AthenaPublicValues {
}
}

pub fn raw(&self) -> String {
format!("0x{}", hex::encode(self.buffer.data.clone()))
}

/// Create an `AthenaPublicValues` from a slice of bytes.
pub fn from(data: &[u8]) -> Self {
Self {
Expand Down
12 changes: 4 additions & 8 deletions core/src/runtime/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,15 @@ impl Read for Runtime<'_> {

impl Runtime<'_> {
pub fn write_stdin<U: Serialize>(&mut self, input: &U) {
let mut buf = Vec::new();
bincode::serialize_into(&mut buf, input).expect("serialization failed");
self.state.input_stream.push(buf);
bincode::serialize_into(&mut self.state.input_stream, input).expect("serialization failed");
}

pub fn write_stdin_slice(&mut self, input: &[u8]) {
self.state.input_stream.push(input.to_vec());
self.write_from(input.iter().copied());
}

pub fn write_vecs(&mut self, inputs: &[Vec<u8>]) {
for input in inputs {
self.state.input_stream.push(input.clone());
}
pub fn write_from<T: IntoIterator<Item = u8>>(&mut self, input: T) {
self.state.input_stream.extend(input);
}

pub fn read_public_values<U: DeserializeOwned>(&mut self) -> U {
Expand Down
2 changes: 1 addition & 1 deletion core/src/runtime/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub struct ExecutionState {
pub uninitialized_memory: HashMap<u32, u32, BuildNoHashHasher<u32>>,

/// A stream of input values (global to the entire program).
pub input_stream: Vec<Vec<u8>>,
pub input_stream: Vec<u8>,

/// A ptr to the current position in the input stream incremented by HINT_READ opcode.
pub input_stream_ptr: usize,
Expand Down
15 changes: 11 additions & 4 deletions core/src/runtime/syscall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,19 @@ impl<'a, 'h> SyscallContext<'a, 'h> {
#[tracing::instrument(skip(self))]
pub fn bytes(&self, mut addr: u32, len: usize) -> Vec<u8> {
let mut bytes = Vec::new();
let mut bytes_to_read = len;

// handle case when addr is not aligned to 4B
let addr_offset = (addr % 4) as usize;
let addr_offset = addr % 4;
if addr_offset != 0 {
let word = self.word(addr - addr_offset as u32).to_le_bytes();
bytes.extend_from_slice(&word[addr_offset..]);
tracing::debug!(addr, len, addr_offset, "addr not aligned");
let word = self.word(addr - addr_offset).to_le_bytes();
bytes.extend_from_slice(&word[addr_offset as usize..]);
addr += bytes.len() as u32;
bytes_to_read = bytes_to_read.saturating_sub(bytes.len());
}

for addr in (addr..addr + (len - bytes.len()) as u32).step_by(4) {
for addr in (addr..addr + bytes_to_read as u32).step_by(4) {
bytes.extend_from_slice(&self.word(addr).to_le_bytes());
}
bytes.truncate(len); // handle case when len is not a multiple of 4
Expand Down Expand Up @@ -290,5 +294,8 @@ mod tests {
// address not aligned and length not a multiple of 4
let read = ctx.bytes(0x103, 59);
assert_eq!(read, memory[3..3 + 59]);

let read = ctx.bytes(0x1, 2);
assert_eq!(read, memory[1..1 + 2]);
}
}
174 changes: 122 additions & 52 deletions core/src/syscall/hint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ pub(crate) struct SyscallHintLen;

impl Syscall for SyscallHintLen {
fn execute(&self, ctx: &mut SyscallContext, _: u32, _: u32) -> SyscallResult {
if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() {
tracing::debug!(
"no more data to read in stdin: input_stream_ptr={}, input_stream_len={}",
ctx.rt.state.input_stream_ptr,
ctx.rt.state.input_stream.len(),
);
return Ok(Outcome::Result(Some(0)));
}
Ok(Outcome::Result(Some(
ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr].len() as u32,
)))
let len = if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() {
0
} else {
ctx.rt.state.input_stream.len() - ctx.rt.state.input_stream_ptr
};
tracing::debug!(
ptr = ctx.rt.state.input_stream_ptr,
total = ctx.rt.state.input_stream.len(),
len,
"hinted remaning data in the input stream"
);
Ok(Outcome::Result(Some(len as u32)))
}
}

Expand All @@ -26,59 +27,82 @@ pub(crate) struct SyscallHintRead;

impl Syscall for SyscallHintRead {
fn execute(&self, ctx: &mut SyscallContext, ptr: u32, len: u32) -> SyscallResult {
if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() {
if ctx.rt.unconstrained {
tracing::error!("hint read should not be used in a unconstrained block");
return Err(StatusCode::StaticModeViolation);
}
let data = ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr..].to_vec();
if len as usize > data.len() {
tracing::debug!(
"failed reading stdin due to insufficient input data: input_stream_ptr={}, input_stream_len={}",
ctx.rt.state.input_stream_ptr,
ctx.rt.state.input_stream.len()
ptr = ctx.rt.state.input_stream_ptr,
total = ctx.rt.state.input_stream.len(),
available = data.len(),
len,
"failed reading stdin due to insufficient input data",
);
return Err(StatusCode::InsufficientInput);
}
let vec = &ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr];
ctx.rt.state.input_stream_ptr += 1;
assert!(
!ctx.rt.unconstrained,
"hint read should not be used in a unconstrained block"
);
if vec.len() != len as usize {
let mut data = &data[..len as usize];
let mut address = ptr;

// Handle unaligned start
if address % 4 != 0 {
let aligned_addr = address & !3; // Round down to aligned address
let offset = (address % 4) as usize;
let bytes_to_write = std::cmp::min(4 - offset, data.len());
tracing::debug!(
"hint input stream read length mismatch: expected={}, actual={}",
len,
vec.len()
address,
aligned_addr,
offset,
bytes_to_write,
"hint read address not aligned to 4 bytes"
);
return Err(StatusCode::InvalidSyscallArgument);

let mut word_bytes = ctx.rt.mr(aligned_addr).to_le_bytes();
tracing::debug!(word = hex::encode(word_bytes), "read existing word");

word_bytes[offset..offset + bytes_to_write].copy_from_slice(&data[..bytes_to_write]);

ctx.rt.mw(aligned_addr, u32::from_le_bytes(word_bytes));
tracing::debug!(word = hex::encode(word_bytes), "written updated word");

address = aligned_addr + 4;
data = &data[bytes_to_write..];
}
if ptr % 4 != 0 {
tracing::debug!("hint read address not aligned to 4 bytes");
return Err(StatusCode::InvalidSyscallArgument);

// Iterate through the remaining data in 4-byte chunks
let mut chunks = data.chunks_exact(4);
for chunk in &mut chunks {
// unwrap() won't panic: chunks_exact(4) guarantees every chunk is exactly 4 bytes long
let word = u32::from_le_bytes(chunk.try_into().unwrap());
ctx.rt.mw(address, word);
address += 4;
}
// Iterate through the vec in 4-byte chunks
for i in (0..len).step_by(4) {
// Get each byte in the chunk
let b1 = vec[i as usize];
// In case the vec is not a multiple of 4, right-pad with 0s. This is fine because we
// are assuming the word is uninitialized, so filling it with 0s makes sense.
let b2 = vec.get(i as usize + 1).copied().unwrap_or(0);
let b3 = vec.get(i as usize + 2).copied().unwrap_or(0);
let b4 = vec.get(i as usize + 3).copied().unwrap_or(0);
let word = u32::from_le_bytes([b1, b2, b3, b4]);

// Save the data into runtime state so the runtime will use the desired data instead of
// 0 when first reading/writing from this address.
ctx
.rt
.state
.uninitialized_memory
.entry(ptr + i)
.and_modify(|_| panic!("hint read address is initialized already"))
.or_insert(word);
// In case the vec is not a multiple of 4, right-pad with 0s. This is fine because we
// are assuming the word is uninitialized, so filling it with 0s makes sense.
let remainder = chunks.remainder();
if !remainder.is_empty() {
let mut word_array = [0u8; 4];
let len = remainder.len();
word_array[..len].copy_from_slice(remainder);
ctx.rt.mw(address, u32::from_le_bytes(word_array));
}
Ok(Outcome::Result(None))
tracing::debug!(
from = ptr,
to = address as usize + remainder.len(),
read = len,
"HintRead syscall finished"
);
tracing::trace!(data = hex::encode(data));
ctx.rt.state.input_stream_ptr += len as usize;
Ok(Outcome::Result(Some(len)))
}
}

#[cfg(test)]
mod tests {
use athena_interface::StatusCode;

use crate::{
runtime::{Outcome, Program, Runtime, Syscall, SyscallContext},
utils::AthenaCoreOpts,
Expand All @@ -96,9 +120,55 @@ mod tests {

// with inputs
let data = [vec![1, 2, 3, 4, 5], vec![6, 7]];
ctx.rt.write_vecs(&data);
ctx.rt.write_stdin_slice(&data[0]);
ctx.rt.write_stdin_slice(&data[1]);

let result = syscall.execute(&mut ctx, 0, 0).unwrap();
assert_eq!(Outcome::Result(Some(data[0].len() as u32)), result);
assert_eq!(
Outcome::Result(Some((data[0].len() + data[1].len()) as u32)),
result
);
}

#[test]
fn hint_read_cant_run_in_unconstrained() {
let mut rt = Runtime::new(Program::default(), None, AthenaCoreOpts::default(), None);
rt.unconstrained = true;
let mut ctx = SyscallContext::new(&mut rt);
let syscall = super::SyscallHintRead {};

let result = syscall.execute(&mut ctx, 0, 0);
assert_eq!(Err(StatusCode::StaticModeViolation), result);
}

#[test]
fn hint_read() {
let mut rt = Runtime::new(Program::default(), None, AthenaCoreOpts::default(), None);
let mut ctx = SyscallContext::new(&mut rt);
let syscall = super::SyscallHintRead {};

// no inputs
let result = syscall.execute(&mut ctx, 0, 10);
assert_eq!(Err(StatusCode::InsufficientInput), result);

let data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
ctx.rt.write_stdin(&data);

// can't read more than available
let result = syscall.execute(&mut ctx, 0, data.len() as u32 + 1);
assert_eq!(Err(StatusCode::InsufficientInput), result);

// read only up to `len`
let len = 3;
let result = syscall.execute(&mut ctx, 0, len as u32);
assert_eq!(Ok(Outcome::Result(Some(len as u32))), result);
assert_eq!(&data[..len], ctx.bytes(0, len).as_slice());

// read the rest
let address = len;
let len = data.len() - len;
let result = syscall.execute(&mut ctx, address as u32, len as u32);
assert_eq!(Ok(Outcome::Result(Some(len as u32))), result);
assert_eq!(data, ctx.bytes(0, data.len()).as_slice());
}
}
9 changes: 4 additions & 5 deletions core/src/syscall/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ impl Syscall for SyscallWrite {
}
}
3 => {
rt.state.public_values_stream.extend_from_slice(&bytes);
rt.state.public_values_stream.extend(bytes);
}
4 => {
rt.state.input_stream.push(bytes);
rt.state.input_stream.extend(bytes);
}
fd => {
tracing::debug!(fd, "executing hook");
match rt.execute_hook(fd, &bytes) {
Ok(result) => {
rt.state.input_stream.push(result);
rt.state.input_stream.extend(result);
}
Err(err) => {
tracing::debug!(fd, ?err, "hook failed");
Expand Down Expand Up @@ -115,7 +115,6 @@ mod tests {

let result = SyscallWrite {}.execute(&mut SyscallContext { rt: &mut runtime }, 7, 0);
result.unwrap();
let result = runtime.state.input_stream.pop().unwrap();
assert_eq!(vec![1, 2, 3, 4, 5], result);
assert_eq!(vec![1, 2, 3, 4, 5], runtime.state.input_stream);
}
}
Loading
Loading