Skip to content

Commit

Permalink
fix: Snapshot2 should untrack cached pages in case of repeated memory…
Browse files Browse the repository at this point in the history
… writes (#418) (#419)

Test case is provided by @mohanson

Co-authored-by: Xuejie Xiao <xxuejie@gmail.com>
  • Loading branch information
mohanson and xxuejie authored Mar 19, 2024
1 parent bd848a6 commit 49ef291
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 3 deletions.
19 changes: 17 additions & 2 deletions src/snapshot2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
bits::roundup,
elf::{LoadingAction, ProgramMetadata},
machine::SupportMachine,
memory::{Memory, FLAG_DIRTY},
memory::{get_page_indices, Memory, FLAG_DIRTY},
Error, Register, RISCV_GENERAL_REGISTER_NUMBER, RISCV_PAGESIZE,
};
use bytes::Bytes;
Expand Down Expand Up @@ -121,6 +121,7 @@ impl<I: Clone + PartialEq, D: DataSource<I>> Snapshot2Context<I, D> {
length: u64,
) -> Result<(u64, u64), Error> {
let (data, full_length) = self.data_source.load_data(id, offset, length)?;
self.untrack_pages(machine, addr, data.len() as u64)?;
machine.memory_mut().store_bytes(addr, &data)?;
self.track_pages(machine, addr, data.len() as u64, id, offset)?;
Ok((data.len() as u64, full_length))
Expand Down Expand Up @@ -227,7 +228,7 @@ impl<I: Clone + PartialEq, D: DataSource<I>> Snapshot2Context<I, D> {
self.track_pages(machine, start, length, id, offset + action.source.start)
}

/// This is only made public for advanced usages, but make sure to exercise more
/// The followings are only made public for advanced usages, but make sure to exercise more
/// cautions when calling it!
pub fn track_pages<M: SupportMachine>(
&mut self,
Expand Down Expand Up @@ -255,6 +256,20 @@ impl<I: Clone + PartialEq, D: DataSource<I>> Snapshot2Context<I, D> {
}
Ok(())
}

pub fn untrack_pages<M: SupportMachine>(
&mut self,
machine: &mut M,
start: u64,
length: u64,
) -> Result<(), Error> {
let page_indices = get_page_indices(start, length)?;
for page in page_indices.0..=page_indices.1 {
machine.memory_mut().set_flag(page, FLAG_DIRTY)?;
self.pages.remove(&page);
}
Ok(())
}
}

#[derive(Clone, Debug, Deserialize, Serialize)]
Expand Down
39 changes: 38 additions & 1 deletion tests/test_resume2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,11 @@ fn load_program(name: &str) -> TestSource {
file.read_to_end(&mut buffer).unwrap();
let program = buffer.into();

let data = vec![7; 16 * 4096];
let mut data = vec![0; 16 * 4096];
for i in 0..data.len() {
data[i] = i as u8;
}

let mut m = HashMap::default();
m.insert(DATA_ID, data.into());
m.insert(PROGRAM_ID, program);
Expand Down Expand Up @@ -622,3 +626,36 @@ pub fn test_sc_after_snapshot2() {
assert!(result2.is_ok());
assert_eq!(result2.unwrap(), 0);
}

#[cfg(not(feature = "enable-chaos-mode-by-default"))]
#[test]
pub fn test_store_bytes_twice() {
let data_source = load_program("tests/programs/sc_after_snapshot");

let mut machine = MachineTy::Asm.build(data_source.clone(), VERSION2);
machine.set_max_cycles(u64::MAX);
machine.load_program(&vec!["main".into()]).unwrap();

match machine {
Machine::Asm(ref mut inner, ref ctx) => {
ctx.lock()
.unwrap()
.store_bytes(&mut inner.machine, 0, &DATA_ID, 2, 29186)
.unwrap();
ctx.lock()
.unwrap()
.store_bytes(&mut inner.machine, 0, &DATA_ID, 0, 11008)
.unwrap();
}
_ => unimplemented!(),
}
let a = machine.full_memory().unwrap()[4096 * 2];

let snapshot = machine.snapshot().unwrap();
let mut machine2 = MachineTy::Asm.build(data_source.clone(), VERSION2);
machine2.resume(snapshot).unwrap();
machine2.set_max_cycles(u64::MAX);
let b = machine2.full_memory().unwrap()[4096 * 2];

assert_eq!(a, b);
}

0 comments on commit 49ef291

Please sign in to comment.