# Patch summary (commentary only; this text precedes the first `diff --git` header).
#
# Makes record checksumming configurable in the data server.
#
# * bench.rs, cli.rs: add a `--checksum` / `--no-checksum` flag pair to `Opts`. Each flag is
#   declared with `overrides_with` naming the other. Only `opts.checksum` is read, so when
#   neither flag is given the loader is configured with `false`.
# * event_file.rs, run.rs, logdir.rs: `EventFileReader`, `RunLoader` and `LogdirLoader` each
#   gain a `checksum: bool` field, initialised to `true` by the constructor, plus a
#   `checksum(&mut self, yes: bool)` setter. `LogdirLoader` copies its setting onto every
#   `RunLoader` it creates, and `RunLoader` copies its setting onto every `EventFileReader`
#   it opens.
# * `EventFileReader::read_event`: with `checksum == true` the record is checksummed before it
#   is decoded (unchanged behaviour). With `checksum == false` the record is decoded first and
#   checksummed only if decoding fails; a checksum failure is then returned in preference to
#   the decode error.
# * cli.rs: the `LogdirLoader` is now built and configured before the reloader closure, which
#   becomes `move || loop { ... }` and takes ownership of the already-configured loader.
# * event_file.rs: new test `test_no_checksum` covers the four combinations of valid/invalid
#   proto and valid/invalid CRC with checksumming disabled.
#
# NOTE(review): several context lines look as if angle-bracketed type parameters were stripped
# during extraction (`Option,`, `Box>`, `Result {`, `HashMap,`, `TfRecordReader,`). They are
# reproduced as found; confirm against the target files before applying this patch.
diff --git a/tensorboard/data/server/bench.rs b/tensorboard/data/server/bench.rs
index 809edfefb6..0ea320e84f 100644
--- a/tensorboard/data/server/bench.rs
+++ b/tensorboard/data/server/bench.rs
@@ -29,6 +29,12 @@ struct Opts {
     logdir: PathBuf,
     #[clap(long, default_value = "info")]
     log_level: String,
+    // Pair of `--no-checksum` and `--checksum` flags, defaulting to "no checksum".
+    #[clap(long, multiple_occurrences = true, overrides_with = "checksum")]
+    #[allow(unused)] // never read; present only so that it can override `--checksum`
+    no_checksum: bool,
+    #[clap(long, multiple_occurrences = true, overrides_with = "no_checksum")]
+    checksum: bool,
 }
 
 fn main() {
@@ -37,6 +43,7 @@ fn main() {
 
     let commit = Commit::new();
     let mut loader = LogdirLoader::new(&commit, opts.logdir);
+    loader.checksum(opts.checksum); // if neither `--[no-]checksum` given, defaults to false
 
     info!("Starting load cycle");
     let start = Instant::now();
diff --git a/tensorboard/data/server/cli.rs b/tensorboard/data/server/cli.rs
index 73fa10dc4e..5fc70650cc 100644
--- a/tensorboard/data/server/cli.rs
+++ b/tensorboard/data/server/cli.rs
@@ -86,6 +86,27 @@ struct Opts {
     /// port file is specified but cannot be written, the server will die.
     #[clap(long)]
     port_file: Option,
+
+    /// Checksum all records (negate with `--no-checksum`)
+    ///
+    /// With `--checksum`, every record will be checksummed before being parsed. With
+    /// `--no-checksum` (the default), records are only checksummed if parsing fails. Skipping
+    /// checksums for records that successfully parse can be significantly faster, but also means
+    /// that some bit flips may not be detected.
+    #[clap(long, multiple_occurrences = true, overrides_with = "no_checksum")]
+    checksum: bool,
+
+    /// Only checksum records that fail to parse
+    ///
+    /// Negates `--checksum`. This is the default.
+    #[clap(
+        long,
+        multiple_occurrences = true,
+        overrides_with = "checksum",
+        hidden = true
+    )]
+    #[allow(unused)] // never read; present only so that it can override `--checksum`
+    no_checksum: bool,
 }
 
 /// A duration in seconds.
@@ -148,16 +169,16 @@ pub async fn main() -> Result<(), Box> {
         .spawn({
             let logdir = opts.logdir;
             let reload_interval = opts.reload_interval;
-            move || {
-                let mut loader = LogdirLoader::new(commit, logdir);
-                loop {
-                    info!("Starting load cycle");
-                    let start = Instant::now();
-                    loader.reload();
-                    let end = Instant::now();
-                    info!("Finished load cycle ({:?})", end - start);
-                    thread::sleep(reload_interval.duration());
-                }
+            let mut loader = LogdirLoader::new(commit, logdir);
+            // Checksum only if `--checksum` given (i.e., off by default).
+            loader.checksum(opts.checksum);
+            move || loop {
+                info!("Starting load cycle");
+                let start = Instant::now();
+                loader.reload();
+                let end = Instant::now();
+                info!("Finished load cycle ({:?})", end - start);
+                thread::sleep(reload_interval.duration());
             }
         })
         .expect("failed to spawn reloader thread");
diff --git a/tensorboard/data/server/event_file.rs b/tensorboard/data/server/event_file.rs
index 1d863dff27..822636f94c 100644
--- a/tensorboard/data/server/event_file.rs
+++ b/tensorboard/data/server/event_file.rs
@@ -32,6 +32,8 @@ pub struct EventFileReader {
     last_wall_time: Option,
     /// Underlying record reader owned by this event file.
     reader: TfRecordReader,
+    /// Whether to compute CRCs for records before parsing as protos.
+    checksum: bool,
 }
 
 /// Error returned by [`EventFileReader::read_event`].
@@ -68,14 +70,30 @@ impl EventFileReader {
         Self {
             last_wall_time: None,
             reader: TfRecordReader::new(reader),
+            checksum: true,
         }
     }
 
+    /// Sets whether to compute checksums for records before parsing them as protos.
+    pub fn checksum(&mut self, yes: bool) {
+        self.checksum = yes;
+    }
+
     /// Reads the next event from the file.
     pub fn read_event(&mut self) -> Result {
         let record = self.reader.read_record()?;
-        record.checksum()?;
-        let event = Event::decode(&record.data[..])?;
+        let event = if self.checksum {
+            record.checksum()?;
+            Event::decode(&record.data[..])?
+        } else {
+            match Event::decode(&record.data[..]) {
+                Ok(proto) => proto,
+                Err(e) => {
+                    record.checksum()?;
+                    return Err(e.into());
+                }
+            }
+        };
         let wall_time = event.wall_time;
         if wall_time.is_nan() {
             return Err(ReadEventError::NanWallTime(event));
@@ -177,6 +195,62 @@ mod tests {
         assert_eq!(reader.last_wall_time(), &Some(1234.5));
     }
 
+    #[test]
+    fn test_no_checksum() {
+        let event = Event {
+            what: Some(pb::event::What::FileVersion("hello".to_string())),
+            ..Event::default()
+        };
+        let records = vec![
+            TfRecord::from_data(encode_event(&event)),
+            {
+                let mut record = TfRecord::from_data(encode_event(&event));
+                record.data_crc.0 ^= 0x1; // invalidate checksum
+                record
+            },
+            {
+                let mut record = TfRecord::from_data(b"failed proto, failed checksum".to_vec());
+                record.data_crc.0 ^= 0x1; // invalidate checksum
+                record
+            },
+            TfRecord::from_data(b"failed proto, okay checksum".to_vec()),
+        ];
+        let mut file = Vec::new();
+        for record in records {
+            record.write(&mut file).expect("writing record");
+        }
+        let mut reader = EventFileReader::new(Cursor::new(file));
+        reader.checksum(false);
+
+        // First record is genuinely okay.
+        match reader.read_event() {
+            Ok(_) => (),
+            other => panic!("first record: {:?}", other),
+        };
+        // Second record is a valid proto, but invalid record checksum; with `checksum(false)`,
+        // this should not be caught.
+        match reader.read_event() {
+            Ok(_) => (),
+            other => panic!("second record: {:?}", other),
+        };
+        // Third record is an invalid proto with an invalid checksum, so the checksum error should
+        // be caught.
+        match reader.read_event() {
+            Err(ReadEventError::InvalidRecord(_)) => (),
+            other => panic!("third record: {:?}", other),
+        };
+        // Fourth record is an invalid proto with valid checksum, which should still be caught.
+        match reader.read_event() {
+            Err(ReadEventError::InvalidProto(_)) => (),
+            other => panic!("fourth record: {:?}", other),
+        };
+        // After four records, should be done.
+        match reader.read_event() {
+            Err(ReadEventError::ReadRecordError(ReadRecordError::Truncated)) => (),
+            other => panic!("eof: {:?}", other),
+        };
+    }
+
     #[test]
     fn test_resume() {
         let event = Event {
diff --git a/tensorboard/data/server/logdir.rs b/tensorboard/data/server/logdir.rs
index 7b6355678f..10cab8e93e 100644
--- a/tensorboard/data/server/logdir.rs
+++ b/tensorboard/data/server/logdir.rs
@@ -29,6 +29,7 @@ pub struct LogdirLoader<'a> {
     commit: &'a Commit,
     logdir: PathBuf,
     runs: HashMap,
+    checksum: bool,
 }
 
 struct RunState {
@@ -80,9 +81,15 @@ impl<'a> LogdirLoader<'a> {
             commit,
             logdir,
             runs: HashMap::new(),
+            checksum: true,
         }
     }
 
+    /// Sets whether to compute checksums for records before parsing them as protos.
+    pub fn checksum(&mut self, yes: bool) {
+        self.checksum = yes;
+    }
+
     /// Performs a complete load cycle: finds all event files and reads data from all runs,
     /// updating the shared commit.
     ///
@@ -198,6 +205,7 @@ impl<'a> LogdirLoader<'a> {
 
         // Add new runs, and warn on any path collisions for existing runs.
         for (run_name, event_files) in discoveries {
+            let checksum = self.checksum;
             let run = self
                 .runs
                 .entry(run_name.clone())
@@ -205,7 +213,11 @@ impl<'a> LogdirLoader<'a> {
                     // Values of `discoveries` are non-empty by construction, so it's safe to take the
                     // first relpath.
                     relpath: event_files[0].run_relpath.clone(),
-                    loader: RunLoader::new(),
+                    loader: {
+                        let mut loader = RunLoader::new();
+                        loader.checksum(checksum);
+                        loader
+                    },
                     collided_relpaths: HashSet::new(),
                 });
             for ef in event_files {
diff --git a/tensorboard/data/server/run.rs b/tensorboard/data/server/run.rs
index e28a179f9d..d1153d3f77 100644
--- a/tensorboard/data/server/run.rs
+++ b/tensorboard/data/server/run.rs
@@ -52,6 +52,9 @@ pub struct RunLoader {
 
     /// Reservoir-sampled data and metadata for each time series.
     time_series: HashMap,
+
+    /// Whether to compute CRCs for records before parsing as protos.
+    checksum: bool,
 }
 
 #[derive(Debug)]
@@ -148,9 +151,15 @@ impl RunLoader {
             start_time: None,
             files: BTreeMap::new(),
             time_series: HashMap::new(),
+            checksum: true,
         }
     }
 
+    /// Sets whether to compute checksums for records before parsing them as protos.
+    pub fn checksum(&mut self, yes: bool) {
+        self.checksum = yes;
+    }
+
     /// Loads new data given the current set of event files.
     ///
     /// The provided filenames should correspond to the entire set of event files currently part of
@@ -188,7 +197,8 @@ impl RunLoader {
                 Entry::Vacant(v) => {
                     let event_file = match File::open(v.key()) {
                         Ok(file) => {
-                            let reader = EventFileReader::new(BufReader::new(file));
+                            let mut reader = EventFileReader::new(BufReader::new(file));
+                            reader.checksum(self.checksum);
                             EventFile::Active(reader)
                         }
                         // TODO(@wchargin): Improve error handling?