diff --git a/tensorboard/data/server/cli.rs b/tensorboard/data/server/cli.rs index 7c3986fc06..e592a20fd8 100644 --- a/tensorboard/data/server/cli.rs +++ b/tensorboard/data/server/cli.rs @@ -121,9 +121,10 @@ struct Opts { /// /// A comma separated list of `plugin_name=num_samples` pairs to explicitly specify how many /// samples to keep per tag for the specified plugin. For unspecified plugins, series are - /// randomly downsampled to reasonable values to prevent out-of-memory errors in long running - /// jobs. For instance, `--samples_per_plugin=scalars=500,images=0` keeps 500 events in each - /// scalar series and keeps none of the images. + /// randomly downsampled to reasonable values to prevent out-of-memory errors in long-running + /// jobs. Each `num_samples` may be the special token `all` to retain all data without + /// downsampling. For instance, `--samples_per_plugin=scalars=500,images=all,audio=0` keeps 500 + /// events in each scalar series, all of the images, and none of the audio. #[clap(long, default_value = "", setting(clap::ArgSettings::AllowEmptyValues))] samples_per_plugin: PluginSamplingHint, } diff --git a/tensorboard/data/server/reservoir.rs b/tensorboard/data/server/reservoir.rs index ba8131ad24..2935087506 100644 --- a/tensorboard/data/server/reservoir.rs +++ b/tensorboard/data/server/reservoir.rs @@ -81,8 +81,8 @@ pub struct StageReservoir { /// Total capacity of this reservoir. /// /// The combined physical capacities of `committed_steps` and `staged_items` may exceed this, - /// but their combined lengths will not. Behavior is undefined if `capacity == 0`. - capacity: usize, + /// but their combined lengths will not. + capacity: Capacity, /// Reservoir control, to determine whether and whither a given new record should be included. ctl: C, /// Estimate of the total number of non-preempted records passed in the stream so far, @@ -100,6 +100,24 @@ pub struct StageReservoir { seen: usize, } +/// Reservoir capacity, determining if and when items should start being evicted. +#[derive(PartialEq, Eq, Debug, Copy, Clone)] +pub enum Capacity { + /// The reservoir may have arbitrarily many records. + /// + /// An unbounded reservoir still supports preemption, but otherwise behaves like a normal + /// vector. + Unbounded, + /// The reservoir may have at most a fixed number of records. + Bounded(usize), +} + +impl From for Capacity { + fn from(n: usize) -> Self { + Capacity::Bounded(n) + } +} + /// A buffer of records that have been committed and not yet evicted from the reservoir. /// /// This is a snapshot of the reservoir contents at some point in time that is periodically updated @@ -154,8 +172,8 @@ impl StageReservoir { /// All reservoirs created by this function will use the same sequence of random numbers. /// /// This function does not allocate. Reservoir capacity is allocated as records are offered. - pub fn new(capacity: usize) -> Self { - Self::with_control(capacity, ChaCha20Rng::seed_from_u64(0)) + pub fn new(capacity: impl Into) -> Self { + Self::with_control(capacity.into(), ChaCha20Rng::seed_from_u64(0)) } } @@ -163,11 +181,11 @@ impl StageReservoir { /// Creates a new reservoir with the specified capacity and reservoir control. /// /// This function does not allocate. Reservoir capacity is allocated as records are offered. - pub fn with_control(capacity: usize, ctl: C) -> Self { + pub fn with_control(capacity: impl Into, ctl: C) -> Self { Self { committed_steps: Vec::new(), staged_items: Vec::new(), - capacity, + capacity: capacity.into(), ctl, seen: 0, } @@ -179,24 +197,26 @@ impl StageReservoir { /// records kept form a simple random sample of the stream (or at least approximately so in the /// case of preemptions). pub fn offer(&mut self, step: Step, v: T) { - if self.capacity == 0 { + if self.capacity == Capacity::Bounded(0) { return; } self.preempt(step); self.seen += 1; - // If we can hold every record that we've seen, we can add this record unconditionally. - // Otherwise, we need to roll a destination---even if there's available space, to avoid - // bias right after a preemption. - if self.seen > self.capacity { - let dst = self.ctl.destination(self.seen); - if dst >= self.capacity { - // Didn't make the cut? Keep-last only. - self.pop(); - } else if self.len() >= self.capacity { - // No room? Evict the destination. - // From `if`-guards, we know `dst < self.capacity <= self.len()`, so this is safe. - self.remove(dst); + if let Capacity::Bounded(capacity) = self.capacity { + // If we can hold every record that we've seen, we can add this record unconditionally. + // Otherwise, we need to roll a destination---even if there's available space, to avoid + // bias right after a preemption. + if self.seen > capacity { + let dst = self.ctl.destination(self.seen); + if dst >= capacity { + // Didn't make the cut? Keep-last only. + self.pop(); + } else if self.len() >= capacity { + // No room? Evict the destination. + // From `if`-guards, we know `dst < capacity <= self.len()`, so this is safe. + self.remove(dst); + } } } // In any case, add to end. @@ -542,6 +562,43 @@ mod tests { } } + #[test] + fn test_unbounded() { + let mut rsv = StageReservoir::new(Capacity::Unbounded); + let mut head = Basin::new(); + + rsv.commit(&mut head); + assert_eq!(head.as_slice(), &[]); + + rsv.offer(Step(0), "before"); + rsv.offer(Step(1), "before"); + rsv.offer(Step(2), "before"); + rsv.offer(Step(4), "before"); + rsv.commit(&mut head); + assert_eq!( + head.as_slice(), + &[ + (Step(0), "before"), + (Step(1), "before"), + (Step(2), "before"), + (Step(4), "before") + ] + ); + + rsv.offer(Step(2), "after"); + rsv.offer(Step(5), "after"); + rsv.commit(&mut head); + assert_eq!( + head.as_slice(), + &[ + (Step(0), "before"), + (Step(1), "before"), + (Step(2), "after"), + (Step(5), "after") + ] + ); + } + #[test] fn test_empty() { let mut rsv = StageReservoir::new(0); diff --git a/tensorboard/data/server/run.rs b/tensorboard/data/server/run.rs index 86398233d5..2acff3e4eb 100644 --- a/tensorboard/data/server/run.rs +++ b/tensorboard/data/server/run.rs @@ -26,7 +26,7 @@ use crate::data_compat::{EventValue, GraphDefValue, SummaryValue, TaggedRunMetad use crate::event_file::EventFileReader; use crate::logdir::{EventFileBuf, Logdir}; use crate::proto::tensorboard as pb; -use crate::reservoir::StageReservoir; +use crate::reservoir::{Capacity, StageReservoir}; use crate::types::{PluginSamplingHint, Run, Step, Tag, WallTime}; /// A loader to accumulate reservoir-sampled events in a single TensorBoard run. @@ -119,23 +119,21 @@ impl StageTimeSeries { fn capacity( metadata: &pb::SummaryMetadata, plugin_sampling_hint: Arc, - ) -> usize { + ) -> Capacity { let data_class = pb::DataClass::from_i32(metadata.data_class).unwrap_or(pb::DataClass::Unknown); - let mut capacity = match data_class { + let mut capacity = Capacity::Bounded(match data_class { pb::DataClass::Scalar => 1000, pb::DataClass::Tensor => 100, pb::DataClass::BlobSequence => 10, _ => 0, - }; + }); // Override the default capacity using the plugin-specific hint. if data_class != pb::DataClass::Unknown { if let Some(ref pd) = metadata.plugin_data { - if let Some(&num_samples) = plugin_sampling_hint.0.get(&pd.plugin_name) { - // TODO(psybuzz): if the hint prescribes 0 samples, the reservoir should ideally - // be unbounded. For now, it simply creates a reservoir with capacity 0. - capacity = num_samples; + if let Some(&hint) = plugin_sampling_hint.0.get(&pd.plugin_name) { + capacity = hint; } } } diff --git a/tensorboard/data/server/types.rs b/tensorboard/data/server/types.rs index 76529702a4..87027d75f1 100644 --- a/tensorboard/data/server/types.rs +++ b/tensorboard/data/server/types.rs @@ -15,11 +15,12 @@ limitations under the License. //! Core simple types. -use log::error; use std::borrow::Borrow; use std::collections::HashMap; use std::str::FromStr; +use crate::reservoir::Capacity; + /// A step associated with a record, strictly increasing over time within a record stream. #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Copy, Clone)] pub struct Step(pub i64); @@ -102,7 +103,7 @@ pub enum ParsePluginSamplingHintError { /// A map defining how many samples per plugin to keep. #[derive(Debug, Default)] -pub struct PluginSamplingHint(pub HashMap); +pub struct PluginSamplingHint(pub HashMap); impl FromStr for PluginSamplingHint { type Err = ParsePluginSamplingHintError; @@ -118,8 +119,12 @@ impl FromStr for PluginSamplingHint { part: pair_str.to_string(), }); } - let num_samples = pair[1].parse::()?; let plugin_name: String = pair[0].to_string(); + let num_samples = if pair[1] == "all" { + Capacity::Unbounded + } else { + Capacity::Bounded(pair[1].parse::()?) + }; result.insert(plugin_name, num_samples); } Ok(PluginSamplingHint(result)) @@ -176,17 +181,20 @@ mod tests { #[test] fn test_plugin_sampling_hint() { + use Capacity::{Bounded, Unbounded}; + // Parse from a valid hint with arbitrary plugin names. - let hint1 = "scalars=500,images=0,unknown=10".parse::(); - let mut expected1: HashMap = HashMap::new(); - expected1.insert("scalars".to_string(), 500); - expected1.insert("images".to_string(), 0); - expected1.insert("unknown".to_string(), 10); + let hint1 = "scalars=500,images=0,histograms=all,unknown=10".parse::(); + let mut expected1: HashMap = HashMap::new(); + expected1.insert("scalars".to_string(), Bounded(500)); + expected1.insert("images".to_string(), Bounded(0)); + expected1.insert("histograms".to_string(), Unbounded); + expected1.insert("unknown".to_string(), Bounded(10)); assert_eq!(hint1.unwrap().0, expected1); // Parse from an empty hint. let hint2 = "".parse::(); - let expected2: HashMap = HashMap::new(); + let expected2: HashMap = HashMap::new(); assert_eq!(hint2.unwrap().0, expected2); // Parse from an invalid hint. @@ -195,6 +203,11 @@ mod tests { other => panic!("expected ParseIntError, got {:?}", other), }; + match "x=wat".parse::().unwrap_err() { + ParsePluginSamplingHintError::ParseIntError(_) => (), + other => panic!("expected ParseIntError, got {:?}", other), + }; + match "=1".parse::().unwrap_err() { ParsePluginSamplingHintError::SyntaxError { part: _ } => (), other => panic!("expected SyntaxError, got {:?}", other), diff --git a/tensorboard/data/server_ingester.py b/tensorboard/data/server_ingester.py index a058888498..fda9132164 100644 --- a/tensorboard/data/server_ingester.py +++ b/tensorboard/data/server_ingester.py @@ -103,7 +103,8 @@ def start(self): reload = str(int(self._reload_interval)) sample_hint_pairs = [ - "%s=%s" % (k, v) for k, v in self._samples_per_plugin.items() + "%s=%s" % (k, "all" if v == 0 else v) + for k, v in self._samples_per_plugin.items() ] samples_per_plugin = ",".join(sample_hint_pairs) diff --git a/tensorboard/data/server_ingester_test.py b/tensorboard/data/server_ingester_test.py index a805ff129d..9a6538f3db 100644 --- a/tensorboard/data/server_ingester_test.py +++ b/tensorboard/data/server_ingester_test.py @@ -86,6 +86,10 @@ def target(): logdir=logdir, reload_interval=5, channel_creds_type=grpc_util.ChannelCredsType.LOCAL, + samples_per_plugin={ + "scalars": 500, + "images": 0, + }, ) ingester.start() self.assertIsInstance( @@ -99,6 +103,7 @@ def target(): "--port=0", "--port-file=%s" % port_file, "--die-after-stdin", + "--samples-per-plugin=scalars=500,images=all", "--verbose", # logging is enabled in tests ] popen.assert_called_once_with(expected_args, stdin=subprocess.PIPE)