Skip to content

Commit

Permalink
Merge branch 'early_check' into hpc
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyBF committed Dec 21, 2024
2 parents 2133f89 + 014eec5 commit 669a9d7
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 9 deletions.
59 changes: 52 additions & 7 deletions ext/src/save.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,37 @@ impl<T: io::Read> std::ops::Drop for ChecksumReader<T> {
}

/// Open the file pointed to by `path` as a `Box<dyn Read>`. If the file does not exist, look for
/// compressed versions.
fn open_file(path: PathBuf) -> Option<Box<dyn io::Read>> {
/// compressed versions. If `early_check` is true, we check the checksum before returning the file.
fn open_file(path: PathBuf, early_check: bool) -> Option<Box<dyn io::Read>> {
use io::BufRead;

/// Read all of `reader` into memory and verify its trailing checksum before handing it out.
///
/// The final four bytes are read as a little-endian `u32` and compared against the Adler-32
/// checksum of everything that precedes them. On a match, the complete contents (trailer
/// included) are returned as an in-memory cursor. If the data is shorter than four bytes or the
/// checksum does not match, a warning is logged, the file at `path` is deleted and `None` is
/// returned.
///
/// # Panics
///
/// Panics if reading from `reader` fails or if the file at `path` cannot be deleted.
fn do_early_check(path: PathBuf, mut reader: impl io::Read) -> Option<Box<dyn io::Read>> {
    // Number of trailing bytes that hold the checksum.
    const CHECKSUM_LEN: usize = 4;

    // Both failure paths discard the file in exactly the same way, so share the logic.
    let delete_file = || {
        std::fs::remove_file(&path)
            .unwrap_or_else(|e| panic!("Error when deleting {path:?}: {e}"));
    };

    let mut file_contents = Vec::new();
    std::io::copy(&mut reader, &mut file_contents)
        .unwrap_or_else(|e| panic!("Error when reading from {path:?}: {e}"));

    // Derive the split point from the buffer's own length rather than casting the `u64` byte
    // count with `as`, which silently truncates where `usize` is narrower than 64 bits.
    // `checked_sub` doubles as the "too short" test.
    let checksum_pos = match file_contents.len().checked_sub(CHECKSUM_LEN) {
        Some(pos) => pos,
        None => {
            tracing::warn!("File {path:?} is too short to contain a checksum. Deleting file.");
            delete_file();
            return None;
        }
    };

    let (content_bytes, mut checksum_bytes) = file_contents.split_at(checksum_pos);
    let mut adler = adler::Adler32::new();
    adler.write_slice(content_bytes);
    let checksum = checksum_bytes
        .read_u32::<LittleEndian>()
        .expect("the checksum trailer is exactly four bytes long");

    if adler.checksum() == checksum {
        Some(Box::new(io::Cursor::new(file_contents)))
    } else {
        tracing::warn!("Checksum mismatch for {path:?}. Deleting file.");
        delete_file();
        None
    }
}

// We should try in decreasing order of access speed.
match File::open(&path) {
Ok(f) => {
Expand All @@ -312,7 +339,11 @@ fn open_file(path: PathBuf) -> Option<Box<dyn io::Read>> {
.unwrap_or_else(|e| panic!("Error when deleting empty file {path:?}: {e}"));
return None;
}
return Some(Box::new(ChecksumReader::new(reader)));
return if early_check {
do_early_check(path, reader)
} else {
Some(Box::new(ChecksumReader::new(reader)))
};
}
Err(e) => {
if e.kind() != io::ErrorKind::NotFound {
Expand All @@ -327,9 +358,12 @@ fn open_file(path: PathBuf) -> Option<Box<dyn io::Read>> {
path.set_extension("zst");
match File::open(&path) {
Ok(f) => {
return Some(Box::new(ChecksumReader::new(
zstd::stream::Decoder::new(f).unwrap(),
)))
let reader = zstd::stream::Decoder::new(f).unwrap();
return if early_check {
do_early_check(path, reader)
} else {
Some(Box::new(ChecksumReader::new(reader)))
};
}
Err(e) => {
if e.kind() != io::ErrorKind::NotFound {
Expand Down Expand Up @@ -395,6 +429,17 @@ impl<A: Algebra> SaveFile<A> {
Ok(())
}

/// Decides whether the file should be loaded into memory and have its checksum verified before
/// it is returned. The answer is `false` only for quasi-inverses, since those are by far our
/// largest files. This lives on `SaveFile` rather than on `SaveKind` so that the decision can
/// later take the stem, or some other heuristic, into account.
fn should_check_early(&self) -> bool {
    let is_quasi_inverse = matches!(
        self.kind,
        SaveKind::AugmentationQi | SaveKind::NassauQi | SaveKind::ResQi
    );
    !is_quasi_inverse
}

/// This panics if there is no save dir
fn get_save_path(&self, mut dir: PathBuf) -> PathBuf {
if let Some(idx) = self.idx {
Expand All @@ -418,7 +463,7 @@ impl<A: Algebra> SaveFile<A> {
pub fn open_file(&self, dir: PathBuf) -> Option<Box<dyn io::Read>> {
let file_path = self.get_save_path(dir);
let path_string = file_path.to_string_lossy().into_owned();
if let Some(mut f) = open_file(file_path) {
if let Some(mut f) = open_file(file_path, self.should_check_early()) {
self.validate_header(&mut f).unwrap();
tracing::info!("success open_read: {}", path_string);
Some(f)
Expand Down
70 changes: 68 additions & 2 deletions ext/tests/save_load_resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,7 @@ fn test_load_secondary() {
}

#[test]
#[should_panic(expected = "Invalid file checksum")]
fn test_checksum() {
fn test_checksum_early() {
use std::{
fs::OpenOptions,
io::{Seek, SeekFrom, Write},
Expand All @@ -300,6 +299,73 @@ fn test_checksum() {
file.seek(SeekFrom::Start(41)).unwrap();
file.write_all(&[1]).unwrap();

// Differentials are checked early for integrity, and silently replaced if they are malformed
construct_standard::<false, _, _>("S_2", Some(tempdir.path().into()))
.unwrap()
.compute_through_bidegree(Bidegree::s_t(2, 2));
}

/// A differential file is modified on disk and the save directory is then locked; resolving
/// again is expected to panic with an "Error when deleting" message.
#[test]
#[should_panic(expected = "Error when deleting")]
fn test_checksum_early_locked() {
    use std::{
        fs::OpenOptions,
        io::{Seek, SeekFrom, Write},
    };

    let tempdir = tempfile::TempDir::new().unwrap();
    let save_dir = tempdir.path();

    // First pass: compute the resolution so that the save directory gets populated.
    construct_standard::<false, _, _>("S_2", Some(save_dir.into()))
        .unwrap()
        .compute_through_bidegree(Bidegree::s_t(2, 2));

    // Overwrite the byte at offset 41 of one saved differential.
    let differential_path = save_dir.join("differentials/2_2_differential");
    let mut differential = OpenOptions::new()
        .read(true)
        .write(true)
        .open(differential_path)
        .unwrap();
    differential.seek(SeekFrom::Start(41)).unwrap();
    differential.write_all(&[1]).unwrap();

    // NOTE(review): `lock_tempdir` is defined elsewhere; presumably it prevents files in the
    // directory from being removed -- confirm against its definition.
    lock_tempdir(save_dir);

    // Second pass: loading should try to delete the modified file and panic.
    construct_standard::<false, _, _>("S_2", Some(save_dir.into()))
        .unwrap()
        .compute_through_bidegree(Bidegree::s_t(2, 2));
}

#[test]
#[should_panic(expected = "Invalid file checksum")]
fn test_checksum_late() {
use std::{
fs::OpenOptions,
io::{Seek, SeekFrom, Write},
};

let tempdir = tempfile::TempDir::new().unwrap();

construct_standard::<false, _, _>("S_2", Some(tempdir.path().into()))
.unwrap()
.compute_through_bidegree(Bidegree::s_t(2, 2));

let mut path = tempdir.path().to_owned();
path.push("res_qis/1_2_res_qi");

let mut file = OpenOptions::new()
.read(true)
.write(true)
.open(path)
.unwrap();

file.seek(SeekFrom::Start(41)).unwrap();
file.write_all(&[1]).unwrap();

// Quasi-inverses are checked after using them, and we panic if the check fails
construct_standard::<false, _, _>("S_2", Some(tempdir.path().into()))
.unwrap()
.compute_through_bidegree(Bidegree::s_t(2, 2));
Expand Down

0 comments on commit 669a9d7

Please sign in to comment.