Skip to content

Commit

Permalink
feat(preimage): Async client handles
Browse files Browse the repository at this point in the history
## Overview

Makes the `HintWriterClient` + `PreimageOracleClient` traits
asynchronous to prevent blocking of the host program when executing a
client program natively.

Previously, since the preimage oracle bindings for the client were
entirely synchronous, the loops in `PipeHandle` could cause a deadlock.
Now that oracle IO is asynchronous, the runtime can interrupt a future
when it yields execution (i.e. `tokio::select` works.)

In the client program, synchronous execution is still guaranteed. It can
run async colored functions in a minimal runtime, such as the `block_on`
runtime in `kona_common`. `simple-revm` had to be changed as a part of
this PR, which has an example of this.
  • Loading branch information
clabby committed Jun 2, 2024
1 parent 044570b commit 6e475ec
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 67 deletions.
39 changes: 21 additions & 18 deletions bin/programs/simple-revm/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,45 +33,48 @@ static CLIENT_HINT_PIPE: PipeHandle =

#[client_entry(0xFFFFFFF)]
fn main() {
let mut oracle = OracleReader::new(CLIENT_PREIMAGE_PIPE);
let hint_writer = HintWriter::new(CLIENT_HINT_PIPE);

io::print("Booting EVM and checking hash...\n");
let (digest, code) = boot(&mut oracle).expect("Failed to boot");

match run_evm(&mut oracle, &hint_writer, digest, code) {
Ok(_) => io::print("Success, hashes matched!\n"),
Err(e) => {
io::print_err(alloc::format!("Error: {}\n", e).as_ref());
io::exit(1);
kona_common::block_on(async {
let mut oracle = OracleReader::new(CLIENT_PREIMAGE_PIPE);
let hint_writer = HintWriter::new(CLIENT_HINT_PIPE);

io::print("Booting EVM and checking hash...\n");
let (digest, code) = boot(&mut oracle).await.expect("Failed to boot");

match run_evm(&mut oracle, &hint_writer, digest, code).await {
Ok(_) => io::print("Success, hashes matched!\n"),
Err(e) => {
io::print_err(alloc::format!("Error: {}\n", e).as_ref());
io::exit(1);
}
}
}
})
}

/// Boot the program and load bootstrap information.
#[inline]
fn boot(oracle: &mut OracleReader) -> Result<([u8; 32], Vec<u8>)> {
async fn boot(oracle: &mut OracleReader) -> Result<([u8; 32], Vec<u8>)> {
let digest = oracle
.get(PreimageKey::new_local(DIGEST_IDENT))?
.get(PreimageKey::new_local(DIGEST_IDENT))
.await?
.try_into()
.map_err(|_| anyhow!("Failed to convert digest to [u8; 32]"))?;
let code = oracle.get(PreimageKey::new_local(CODE_IDENT))?;
let code = oracle.get(PreimageKey::new_local(CODE_IDENT)).await?;

Ok((digest, code))
}

/// Call the SHA-256 precompile and assert that the input and output match the expected values
#[inline]
fn run_evm(
async fn run_evm(
oracle: &mut OracleReader,
hint_writer: &HintWriter,
digest: [u8; 32],
code: Vec<u8>,
) -> Result<()> {
// Send a hint for the preimage of the digest to the host so that it can prepare the preimage.
hint_writer.write(&alloc::format!("sha2-preimage {}", hex::encode(digest)))?;
hint_writer.write(&alloc::format!("sha2-preimage {}", hex::encode(digest))).await?;
// Get the preimage of `digest` from the host.
let input = oracle.get(PreimageKey::new_local(INPUT_IDENT))?;
let input = oracle.get(PreimageKey::new_local(INPUT_IDENT)).await?;

let mut cache_db = CacheDB::new(EmptyDB::default());

Expand Down
20 changes: 11 additions & 9 deletions crates/preimage/src/hint.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{traits::HintWriterClient, HintReaderServer, PipeHandle};
use alloc::{boxed::Box, string::String, vec};
use anyhow::Result;
use async_trait::async_trait;
use core::future::Future;
use tracing::{debug, error};

Expand All @@ -18,10 +19,11 @@ impl HintWriter {
}
}

#[async_trait]
impl HintWriterClient for HintWriter {
/// Write a hint to the host. This will overwrite any existing hint in the pipe, and block until
/// all data has been written.
fn write(&self, hint: &str) -> Result<()> {
async fn write(&self, hint: &str) -> Result<()> {
// Form the hint into a byte buffer. The format is a 4-byte big-endian length prefix
// followed by the hint string.
let mut hint_bytes = vec![0u8; hint.len() + 4];
Expand All @@ -31,13 +33,13 @@ impl HintWriterClient for HintWriter {
debug!(target: "hint_writer", "Writing hint \"{hint}\"");

// Write the hint to the host.
self.pipe_handle.write(&hint_bytes)?;
self.pipe_handle.write(&hint_bytes).await?;

debug!(target: "hint_writer", "Successfully wrote hint");

// Read the hint acknowledgement from the host.
let mut hint_ack = [0u8; 1];
self.pipe_handle.read_exact(&mut hint_ack)?;
self.pipe_handle.read_exact(&mut hint_ack).await?;

debug!(target: "hint_writer", "Received hint acknowledgement");

Expand All @@ -59,7 +61,7 @@ impl HintReader {
}
}

#[async_trait::async_trait]
#[async_trait]
impl HintReaderServer for HintReader {
async fn next_hint<F, Fut>(&self, mut route_hint: F) -> Result<()>
where
Expand All @@ -68,12 +70,12 @@ impl HintReaderServer for HintReader {
{
// Read the length of the raw hint payload.
let mut len_buf = [0u8; 4];
self.pipe_handle.read_exact(&mut len_buf)?;
self.pipe_handle.read_exact(&mut len_buf).await?;
let len = u32::from_be_bytes(len_buf);

// Read the raw hint payload.
let mut raw_payload = vec![0u8; len as usize];
self.pipe_handle.read_exact(raw_payload.as_mut_slice())?;
self.pipe_handle.read_exact(raw_payload.as_mut_slice()).await?;
let payload = String::from_utf8(raw_payload)
.map_err(|e| anyhow::anyhow!("Failed to decode hint payload: {e}"))?;

Expand All @@ -82,14 +84,14 @@ impl HintReaderServer for HintReader {
// Route the hint
if let Err(e) = route_hint(payload).await {
// Write back on error to prevent blocking the client.
self.pipe_handle.write(&[0x00])?;
self.pipe_handle.write(&[0x00]).await?;

error!("Failed to route hint: {e}");
anyhow::bail!("Failed to rout hint: {e}");
}

// Write back an acknowledgement to the client to unblock their process.
self.pipe_handle.write(&[0x00])?;
self.pipe_handle.write(&[0x00]).await?;

debug!(target: "hint_reader", "Successfully routed and acknowledged hint");

Expand Down Expand Up @@ -144,7 +146,7 @@ mod test {
let (hint_writer, hint_reader) = (sys.hint_writer, sys.hint_reader);
let incoming_hints = Arc::new(Mutex::new(Vec::new()));

let client = tokio::task::spawn(async move { hint_writer.write(MOCK_DATA) });
let client = tokio::task::spawn(async move { hint_writer.write(MOCK_DATA).await });
let host = tokio::task::spawn({
let incoming_hints_ref = Arc::clone(&incoming_hints);
async move {
Expand Down
27 changes: 14 additions & 13 deletions crates/preimage/src/oracle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,32 @@ impl OracleReader {
/// Set the preimage key for the global oracle reader. This will overwrite any existing key, and
/// block until the host has prepared the preimage and responded with the length of the
/// preimage.
fn write_key(&self, key: PreimageKey) -> Result<usize> {
async fn write_key(&self, key: PreimageKey) -> Result<usize> {
// Write the key to the host so that it can prepare the preimage.
let key_bytes: [u8; 32] = key.into();
self.pipe_handle.write(&key_bytes)?;
self.pipe_handle.write(&key_bytes).await?;

// Read the length prefix and reset the cursor.
let mut length_buffer = [0u8; 8];
self.pipe_handle.read_exact(&mut length_buffer)?;
self.pipe_handle.read_exact(&mut length_buffer).await?;
Ok(u64::from_be_bytes(length_buffer) as usize)
}
}

#[async_trait::async_trait]
impl PreimageOracleClient for OracleReader {
/// Get the data corresponding to the currently set key from the host. Return the data in a new
/// heap allocated `Vec<u8>`
fn get(&self, key: PreimageKey) -> Result<Vec<u8>> {
async fn get(&self, key: PreimageKey) -> Result<Vec<u8>> {
debug!(target: "oracle_client", "Requesting data from preimage oracle. Key {key}");

let length = self.write_key(key)?;
let length = self.write_key(key).await?;
let mut data_buffer = alloc::vec![0; length];

debug!(target: "oracle_client", "Reading data from preimage oracle. Key {key}");

// Grab a read lock on the preimage pipe to read the data.
self.pipe_handle.read_exact(&mut data_buffer)?;
self.pipe_handle.read_exact(&mut data_buffer).await?;

debug!(target: "oracle_client", "Successfully read data from preimage oracle. Key: {key}");

Expand All @@ -52,11 +53,11 @@ impl PreimageOracleClient for OracleReader {

/// Get the data corresponding to the currently set key from the host. Write the data into the
/// provided buffer
fn get_exact(&self, key: PreimageKey, buf: &mut [u8]) -> Result<()> {
async fn get_exact(&self, key: PreimageKey, buf: &mut [u8]) -> Result<()> {
debug!(target: "oracle_client", "Requesting data from preimage oracle. Key {key}");

// Write the key to the host and read the length of the preimage.
let length = self.write_key(key)?;
let length = self.write_key(key).await?;

debug!(target: "oracle_client", "Reading data from preimage oracle. Key {key}");

Expand All @@ -65,7 +66,7 @@ impl PreimageOracleClient for OracleReader {
bail!("Buffer size {} does not match preimage size {}", buf.len(), length);
}

self.pipe_handle.read_exact(buf)?;
self.pipe_handle.read_exact(buf).await?;

debug!(target: "oracle_client", "Successfully read data from preimage oracle. Key: {key}");

Expand Down Expand Up @@ -95,7 +96,7 @@ impl PreimageOracleServer for OracleServer {
{
// Read the preimage request from the client, and throw early if there isn't is any.
let mut buf = [0u8; 32];
self.pipe_handle.read_exact(&mut buf)?;
self.pipe_handle.read_exact(&mut buf).await?;
let preimage_key = PreimageKey::try_from(buf)?;

debug!(target: "oracle_server", "Fetching preimage for key {preimage_key}");
Expand All @@ -109,7 +110,7 @@ impl PreimageOracleServer for OracleServer {
.flatten()
.copied()
.collect::<Vec<_>>();
self.pipe_handle.write(data.as_slice())?;
self.pipe_handle.write(data.as_slice()).await?;

debug!(target: "oracle_server", "Successfully wrote preimage data for key {preimage_key}");

Expand Down Expand Up @@ -184,8 +185,8 @@ mod test {
let (oracle_reader, oracle_server) = (sys.oracle_reader, sys.oracle_server);

let client = tokio::task::spawn(async move {
let contents_a = oracle_reader.get(key_a).unwrap();
let contents_b = oracle_reader.get(key_b).unwrap();
let contents_a = oracle_reader.get(key_a).await.unwrap();
let contents_b = oracle_reader.get(key_b).await.unwrap();

// Drop the file descriptors to close the pipe, stopping the host's blocking loop on
// waiting for client requests.
Expand Down
101 changes: 79 additions & 22 deletions crates/preimage/src/pipe.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
//! This module contains a rudamentary pipe between two file descriptors, using [kona_common::io]
//! for reading and writing from the file descriptors.

use anyhow::{bail, Result};
use anyhow::{anyhow, Result};
use core::{
cell::RefCell,
cmp::Ordering,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use kona_common::{io, FileDescriptor};

/// [PipeHandle] is a handle for one end of a bidirectional pipe.
Expand All @@ -24,30 +31,14 @@ impl PipeHandle {
io::read(self.read_handle, buf)
}

/// Reads exactly `buf.len()` bytes into `buf`, blocking until all bytes are read.
pub fn read_exact(&self, buf: &mut [u8]) -> Result<usize> {
let mut read = 0;
while read < buf.len() {
let chunk_read = self.read(&mut buf[read..])?;
read += chunk_read;
}
Ok(read)
/// Reads exactly `buf.len()` bytes into `buf`.
pub fn read_exact<'a>(&self, buf: &'a mut [u8]) -> impl Future<Output = Result<usize>> + 'a {
ReadFuture { pipe_handle: *self, buf: RefCell::new(buf), read: 0 }
}

/// Write the given buffer to the pipe.
pub fn write(&self, buf: &[u8]) -> Result<usize> {
let mut written = 0;
loop {
match io::write(self.write_handle, &buf[written..]) {
Ok(0) => break,
Ok(n) => {
written += n;
continue;
}
Err(e) => bail!("Failed to write preimage key: {}", e),
}
}
Ok(written)
pub fn write<'a>(&self, buf: &'a [u8]) -> impl Future<Output = Result<usize>> + 'a {
WriteFuture { pipe_handle: *self, buf, written: 0 }
}

/// Returns the read handle for the pipe.
Expand All @@ -60,3 +51,69 @@ impl PipeHandle {
self.write_handle
}
}

/// A future that reads from a pipe, returning [Poll::Ready] when the buffer is full.
struct ReadFuture<'a> {
/// The pipe handle to read from
pipe_handle: PipeHandle,
/// The buffer to read into
buf: RefCell<&'a mut [u8]>,
/// The number of bytes read so far
read: usize,
}

impl Future for ReadFuture<'_> {
type Output = Result<usize>;

fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let mut buf = self.buf.borrow_mut();
let buf_len = buf.len();
let chunk_read = self.pipe_handle.read(&mut buf[self.read..])?;

// Drop the borrow on self.
drop(buf);

self.read += chunk_read;

match self.read.cmp(&buf_len) {
Ordering::Equal => Poll::Ready(Ok(self.read)),
Ordering::Greater => Poll::Ready(Err(anyhow!("Read more bytes than buffer size"))),
Ordering::Less => {
// Register the current task to be woken up when it can make progress
ctx.waker().wake_by_ref();
Poll::Pending
}
}
}
}

/// A future that writes to a pipe, returning [Poll::Ready] when the full buffer has been written.
struct WriteFuture<'a> {
/// The pipe handle to write to
pipe_handle: PipeHandle,
/// The buffer to write
buf: &'a [u8],
/// The number of bytes written so far
written: usize,
}

impl Future for WriteFuture<'_> {
type Output = Result<usize>;

fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match io::write(self.pipe_handle.write_handle(), &self.buf[self.written..]) {
Ok(0) => return Poll::Ready(Ok(self.written)), // Finished writing
Ok(n) => {
self.written += n;
continue;
}
Err(_) => {
// Register the current task to be woken up when it can make progress
ctx.waker().wake_by_ref();
return Poll::Pending;
}
}
}
}
}
Loading

0 comments on commit 6e475ec

Please sign in to comment.