diff --git a/tensorboard/data/server/BUILD b/tensorboard/data/server/BUILD index 073b741aad..41d3cdf994 100644 --- a/tensorboard/data/server/BUILD +++ b/tensorboard/data/server/BUILD @@ -55,6 +55,7 @@ rust_library( "//third_party/rust:prost", "//third_party/rust:rand", "//third_party/rust:rand_chacha", + "//third_party/rust:rayon", "//third_party/rust:serde", "//third_party/rust:serde_json", "//third_party/rust:thiserror", @@ -82,6 +83,7 @@ rust_binary( "//third_party/rust:clap", "//third_party/rust:env_logger", "//third_party/rust:log", + "//third_party/rust:rayon", ], ) diff --git a/tensorboard/data/server/bench.rs b/tensorboard/data/server/bench.rs index 0ea320e84f..dd409b2cb8 100644 --- a/tensorboard/data/server/bench.rs +++ b/tensorboard/data/server/bench.rs @@ -29,6 +29,8 @@ struct Opts { logdir: PathBuf, #[clap(long, default_value = "info")] log_level: String, + #[clap(long)] + reload_threads: Option, // Pair of `--no-checksum` and `--checksum` flags, defaulting to "no checksum". #[clap(long, multiple_occurrences = true, overrides_with = "checksum")] #[allow(unused)] @@ -42,7 +44,7 @@ fn main() { init_logging(&opts); let commit = Commit::new(); - let mut loader = LogdirLoader::new(&commit, opts.logdir); + let mut loader = LogdirLoader::new(&commit, opts.logdir, opts.reload_threads.unwrap_or(0)); loader.checksum(opts.checksum); // if neither `--[no-]checksum` given, defaults to false info!("Starting load cycle"); diff --git a/tensorboard/data/server/cli.rs b/tensorboard/data/server/cli.rs index fc6c4be55e..8c23e95852 100644 --- a/tensorboard/data/server/cli.rs +++ b/tensorboard/data/server/cli.rs @@ -177,7 +177,7 @@ pub async fn main() -> Result<(), Box> { .spawn({ let logdir = opts.logdir; let reload_strategy = opts.reload; - let mut loader = LogdirLoader::new(commit, logdir); + let mut loader = LogdirLoader::new(commit, logdir, 0); // Checksum only if `--checksum` given (i.e., off by default). loader.checksum(opts.checksum); move || loop { diff --git a/tensorboard/data/server/logdir.rs b/tensorboard/data/server/logdir.rs index dfac6fb760..23ff658933 100644 --- a/tensorboard/data/server/logdir.rs +++ b/tensorboard/data/server/logdir.rs @@ -16,6 +16,7 @@ limitations under the License. //! Loader for many runs under a directory. use log::{error, warn}; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; use std::collections::{HashMap, HashSet}; use std::path::{Path, PathBuf}; use walkdir::WalkDir; @@ -26,6 +27,7 @@ use crate::types::Run; /// A loader for a log directory, connecting a filesystem to a [`Commit`] via [`RunLoader`]s. pub struct LogdirLoader<'a> { + thread_pool: rayon::ThreadPool, commit: &'a Commit, logdir: PathBuf, runs: HashMap, @@ -76,8 +78,23 @@ const EVENT_FILE_BASENAME_INFIX: &str = "tfevents"; impl<'a> LogdirLoader<'a> { /// Creates a new, empty logdir loader. Does not load any data. - pub fn new(commit: &'a Commit, logdir: PathBuf) -> Self { + /// + /// This constructor is heavyweight: it builds a new thread-pool. The thread pool will be + /// reused for all calls to [`Self::reload`]. If `reload_threads` is `0`, the number of threads + /// will be determined automatically, per [`rayon`] semantics. + /// + /// # Panics + /// + /// If [`rayon::ThreadPoolBuilder::build`] returns an error; should only happen if there is a + /// failure to create threads at the OS level. + pub fn new(commit: &'a Commit, logdir: PathBuf, reload_threads: usize) -> Self { + let thread_pool = rayon::ThreadPoolBuilder::new() + .num_threads(reload_threads) + .thread_name(|i| format!("Reloader-{:03}", i)) + .build() + .expect("failed to construct Rayon thread pool"); LogdirLoader { + thread_pool, commit, logdir, runs: HashMap::new(), @@ -246,22 +263,27 @@ impl<'a> LogdirLoader<'a> { .runs .read() .expect("could not acquire runs.data"); + + let mut work_items = Vec::new(); for (run, run_state) in self.runs.iter_mut() { let event_files = discoveries .remove(run) .unwrap_or_else(|| panic!("run in self.runs but not discovered: {:?}", run)); let filenames: Vec = event_files.into_iter().map(|d| d.event_file).collect(); - run_state.loader.reload( - filenames, - commit_runs.get(run).unwrap_or_else(|| { - panic!( - "run in self.runs but not in commit.runs \ + let run_data = commit_runs.get(run).unwrap_or_else(|| { + panic!( + "run in self.runs but not in commit.runs \ (is another client mutating this commit?): {:?}", - run - ) - }), - ); + run + ) + }); + work_items.push((&mut run_state.loader, filenames, run_data)); } + self.thread_pool.install(|| { + work_items + .into_par_iter() + .for_each(|(loader, filenames, run_data)| loader.reload(filenames, run_data)); + }); } } @@ -319,7 +341,7 @@ mod tests { let test_relpath: PathBuf = ["mnist", "test"].iter().collect(); let commit = Commit::new(); - let mut loader = LogdirLoader::new(&commit, logdir.path().to_path_buf()); + let mut loader = LogdirLoader::new(&commit, logdir.path().to_path_buf(), 1); // Check that we persist the right run states in the loader. loader.reload(); @@ -379,7 +401,7 @@ mod tests { )?; let commit = Commit::new(); - let mut loader = LogdirLoader::new(&commit, logdir.path().to_path_buf()); + let mut loader = LogdirLoader::new(&commit, logdir.path().to_path_buf(), 1); let get_run_names = || { let runs_store = commit.runs.read().unwrap(); @@ -442,7 +464,7 @@ mod tests { } let commit = Commit::new(); - let mut loader = LogdirLoader::new(&commit, logdir.path().to_path_buf()); + let mut loader = LogdirLoader::new(&commit, logdir.path().to_path_buf(), 1); loader.reload(); assert_eq!(loader.runs.keys().collect::>(), vec![&run]); @@ -467,7 +489,7 @@ mod tests { File::create(train_dir.join(EVENT_FILE_BASENAME_INFIX))?; let commit = Commit::new(); - let mut loader = LogdirLoader::new(&commit, logdir.path().to_path_buf()); + let mut loader = LogdirLoader::new(&commit, logdir.path().to_path_buf(), 1); loader.reload(); assert_eq!( @@ -489,7 +511,7 @@ mod tests { std::os::unix::fs::symlink(&dir2, &dir1)?; let commit = Commit::new(); - let mut loader = LogdirLoader::new(&commit, logdir.path().to_path_buf()); + let mut loader = LogdirLoader::new(&commit, logdir.path().to_path_buf(), 1); loader.reload(); // should not hang Ok(()) }