From 168467fc222cd3fe3d2bb3091e6d9712ceec00e3 Mon Sep 17 00:00:00 2001 From: Vincent Masse Date: Wed, 13 Nov 2024 20:22:23 -0500 Subject: [PATCH 1/5] Add option to request manual quit on tui --- crates/burn-train/src/learner/builder.rs | 13 ++++++++++++- crates/burn-train/src/renderer/base.rs | 8 ++++++++ .../burn-train/src/renderer/tui/renderer.rs | 19 +++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/crates/burn-train/src/learner/builder.rs b/crates/burn-train/src/learner/builder.rs index 2298a41ee7..e88ae0712b 100644 --- a/crates/burn-train/src/learner/builder.rs +++ b/crates/burn-train/src/learner/builder.rs @@ -59,6 +59,7 @@ where early_stopping: Option>, summary_metrics: HashSet, summary: bool, + manual_quit: bool, } impl LearnerBuilder @@ -106,6 +107,7 @@ where early_stopping: None, summary_metrics: HashSet::new(), summary: false, + manual_quit: false, } } @@ -275,6 +277,12 @@ where self } + /// Enable manual quit mode for renderer. + pub fn with_manual_quit(mut self) -> Self { + self.manual_quit = true; + self + } + /// Enable the training summary report. /// /// The summary will be displayed at the end of `.fit()`. @@ -316,9 +324,12 @@ where log::warn!("Failed to install the experiment logger: {}", e); } } - let renderer = self + let mut renderer = self .renderer .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint)); + if self.manual_quit { + renderer.enable_manual_quit(); + } if self.num_loggers == 0 { self.event_store diff --git a/crates/burn-train/src/renderer/base.rs b/crates/burn-train/src/renderer/base.rs index 6cfc2a5eb0..f6b29fc1b6 100644 --- a/crates/burn-train/src/renderer/base.rs +++ b/crates/burn-train/src/renderer/base.rs @@ -31,6 +31,14 @@ pub trait MetricsRenderer: Send + Sync { /// /// * `item` - The validation progress. fn render_valid(&mut self, item: TrainingProgress); + + /// Enable manual quit. Default implementation warn that this feature is not implemented. + /// + fn enable_manual_quit(&mut self) { + log::warn!( + "Manual quit option will be ignored since it's not implemented for this renderer." + ) + } } /// The state of a metric. diff --git a/crates/burn-train/src/renderer/tui/renderer.rs b/crates/burn-train/src/renderer/tui/renderer.rs index 549dfdd1ba..1886f49cbd 100644 --- a/crates/burn-train/src/renderer/tui/renderer.rs +++ b/crates/burn-train/src/renderer/tui/renderer.rs @@ -44,6 +44,7 @@ pub struct TuiMetricsRenderer { interuptor: TrainingInterrupter, popup: PopupState, previous_panic_hook: Option>, + manual_quit: bool, } impl MetricsRenderer for TuiMetricsRenderer { @@ -84,6 +85,10 @@ impl MetricsRenderer for TuiMetricsRenderer { self.status.update_valid(item); self.render().unwrap(); } + + fn enable_manual_quit(&mut self) { + self.manual_quit = true; + } } impl TuiMetricsRenderer { @@ -116,6 +121,7 @@ impl TuiMetricsRenderer { interuptor, popup: PopupState::Empty, previous_panic_hook: Some(previous_panic_hook), + manual_quit: false, } } @@ -230,6 +236,19 @@ impl Drop for TuiMetricsRenderer { // Reset the terminal back to raw mode. This can be skipped during // panicking because the panic hook has already reset the terminal if !std::thread::panicking() { + if self.manual_quit { + // Wait for 'q' key press before closing + loop { + if let Ok(true) = event::poll(Duration::from_millis(100)) { + if let Ok(Event::Key(key)) = event::read() { + if let KeyCode::Char('q') = key.code { + break; + } + } + } + } + } + disable_raw_mode().ok(); execute!(self.terminal.backend_mut(), LeaveAlternateScreen).unwrap(); self.terminal.show_cursor().ok(); From 185de0ef6d7a18151b26c02aa395cc74e0f41b93 Mon Sep 17 00:00:00 2001 From: Vincent Masse Date: Thu, 14 Nov 2024 07:47:35 -0500 Subject: [PATCH 2/5] Update book --- burn-book/src/building-blocks/learner.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/burn-book/src/building-blocks/learner.md b/burn-book/src/building-blocks/learner.md index e9ceb38a7f..16b984c71a 100644 --- a/burn-book/src/building-blocks/learner.md +++ b/burn-book/src/building-blocks/learner.md @@ -28,10 +28,11 @@ The learner builder provides numerous options when it comes to configurations. | Renderer | Configure how to render metrics (default is CLI) | | Grad Accumulation | Configure the number of steps before applying gradients | | File Checkpointer | Configure how the model, optimizer and scheduler states are saved | -| Num Epochs | Set the number of epochs. | +| Num Epochs | Set the number of epochs | | Devices | Set the devices to be used | | Checkpoint | Restart training from a checkpoint | | Application logging | Configure the application logging installer (default is writing to `experiment.log`) | +| Manual quit | Configure the renderer to wait for user to quit after the training | When the builder is configured at your liking, you can then move forward to build the learner. The build method requires three inputs: the model, the optimizer and the learning rate scheduler. Note From 3d9653f644cac5c4d21d2168804753f2a682682c Mon Sep 17 00:00:00 2001 From: Vincent Masse Date: Mon, 18 Nov 2024 17:32:08 -0500 Subject: [PATCH 3/5] Update post training implementation for TUI --- burn-book/src/building-blocks/learner.md | 1 - crates/burn-train/src/learner/builder.rs | 13 +--- crates/burn-train/src/renderer/base.rs | 8 --- .../burn-train/src/renderer/tui/renderer.rs | 63 +++++++++++++++---- 4 files changed, 51 insertions(+), 34 deletions(-) diff --git a/burn-book/src/building-blocks/learner.md b/burn-book/src/building-blocks/learner.md index 16b984c71a..c46140e0f6 100644 --- a/burn-book/src/building-blocks/learner.md +++ b/burn-book/src/building-blocks/learner.md @@ -32,7 +32,6 @@ The learner builder provides numerous options when it comes to configurations. | Devices | Set the devices to be used | | Checkpoint | Restart training from a checkpoint | | Application logging | Configure the application logging installer (default is writing to `experiment.log`) | -| Manual quit | Configure the renderer to wait for user to quit after the training | When the builder is configured at your liking, you can then move forward to build the learner. The build method requires three inputs: the model, the optimizer and the learning rate scheduler. Note diff --git a/crates/burn-train/src/learner/builder.rs b/crates/burn-train/src/learner/builder.rs index e88ae0712b..2298a41ee7 100644 --- a/crates/burn-train/src/learner/builder.rs +++ b/crates/burn-train/src/learner/builder.rs @@ -59,7 +59,6 @@ where early_stopping: Option>, summary_metrics: HashSet, summary: bool, - manual_quit: bool, } impl LearnerBuilder @@ -107,7 +106,6 @@ where early_stopping: None, summary_metrics: HashSet::new(), summary: false, - manual_quit: false, } } @@ -277,12 +275,6 @@ where self } - /// Enable manual quit mode for renderer. - pub fn with_manual_quit(mut self) -> Self { - self.manual_quit = true; - self - } - /// Enable the training summary report. /// /// The summary will be displayed at the end of `.fit()`. @@ -324,12 +316,9 @@ where log::warn!("Failed to install the experiment logger: {}", e); } } - let mut renderer = self + let renderer = self .renderer .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint)); - if self.manual_quit { - renderer.enable_manual_quit(); - } if self.num_loggers == 0 { self.event_store diff --git a/crates/burn-train/src/renderer/base.rs b/crates/burn-train/src/renderer/base.rs index f6b29fc1b6..6cfc2a5eb0 100644 --- a/crates/burn-train/src/renderer/base.rs +++ b/crates/burn-train/src/renderer/base.rs @@ -31,14 +31,6 @@ pub trait MetricsRenderer: Send + Sync { /// /// * `item` - The validation progress. fn render_valid(&mut self, item: TrainingProgress); - - /// Enable manual quit. Default implementation warn that this feature is not implemented. - /// - fn enable_manual_quit(&mut self) { - log::warn!( - "Manual quit option will be ignored since it's not implemented for this renderer." - ) - } } /// The state of a metric. diff --git a/crates/burn-train/src/renderer/tui/renderer.rs b/crates/burn-train/src/renderer/tui/renderer.rs index 1886f49cbd..d3cd1ec32f 100644 --- a/crates/burn-train/src/renderer/tui/renderer.rs +++ b/crates/burn-train/src/renderer/tui/renderer.rs @@ -85,10 +85,6 @@ impl MetricsRenderer for TuiMetricsRenderer { self.status.update_valid(item); self.render().unwrap(); } - - fn enable_manual_quit(&mut self) { - self.manual_quit = true; - } } impl TuiMetricsRenderer { @@ -125,6 +121,11 @@ impl TuiMetricsRenderer { } } + /// Enable manual quit after training. + pub fn enable_manual_quit(&mut self) { + self.manual_quit = true; + } + fn render(&mut self) -> Result<(), Box> { let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS); if self.last_update.elapsed() < tick_rate { @@ -206,6 +207,49 @@ impl TuiMetricsRenderer { Ok(()) } + + fn handle_post_training(&mut self) -> Result<(), Box> { + self.popup = PopupState::Full( + "Training is done".to_string(), + vec![Callback::new( + "Training Done", + "Press 'x' to close this popup. Press 'q' to exit the application after the \ + popup is closed.", + 'x', + PopupCancel, + )], + ); + + self.draw().ok(); + + loop { + if let Ok(true) = event::poll(Duration::from_millis(MAX_REFRESH_RATE_MILLIS)) { + match event::read() { + Ok(event @ Event::Key(key)) => { + if self.popup.is_empty() { + self.metrics_numeric.on_event(&event); + if let KeyCode::Char('q') = key.code { + break; + } + } else { + self.popup.on_event(&event); + } + self.draw().ok(); + } + + Ok(Event::Resize(..)) => { + self.draw().ok(); + } + Err(err) => { + eprintln!("Error reading event: {}", err); + break; + } + _ => continue, + } + } + } + Ok(()) + } } struct QuitPopupAccept(TrainingInterrupter); @@ -237,15 +281,8 @@ impl Drop for TuiMetricsRenderer { // panicking because the panic hook has already reset the terminal if !std::thread::panicking() { if self.manual_quit { - // Wait for 'q' key press before closing - loop { - if let Ok(true) = event::poll(Duration::from_millis(100)) { - if let Ok(Event::Key(key)) = event::read() { - if let KeyCode::Char('q') = key.code { - break; - } - } - } + if let Err(err) = self.handle_post_training() { + eprintln!("Error in post-training handling: {}", err); } } From fa1d078db98d7f6cc13591f5daddfdeb1a884b47 Mon Sep 17 00:00:00 2001 From: Vincent Masse Date: Tue, 19 Nov 2024 19:27:06 -0500 Subject: [PATCH 4/5] Make tui public --- crates/burn-train/src/renderer/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/burn-train/src/renderer/mod.rs b/crates/burn-train/src/renderer/mod.rs index a123bf6944..427d8e5603 100644 --- a/crates/burn-train/src/renderer/mod.rs +++ b/crates/burn-train/src/renderer/mod.rs @@ -6,8 +6,9 @@ pub use base::*; mod cli; +/// The tui renderer #[cfg(feature = "tui")] -mod tui; +pub mod tui; use crate::TrainingInterrupter; /// Return the default metrics renderer. From b417f3a6f33d1baaf8a5f1d87a3832e3914eff94 Mon Sep 17 00:00:00 2001 From: Vincent Masse Date: Thu, 21 Nov 2024 18:43:56 -0500 Subject: [PATCH 5/5] Rename manual_quit to persistent --- crates/burn-train/src/renderer/tui/renderer.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/crates/burn-train/src/renderer/tui/renderer.rs b/crates/burn-train/src/renderer/tui/renderer.rs index d3cd1ec32f..7498f7980f 100644 --- a/crates/burn-train/src/renderer/tui/renderer.rs +++ b/crates/burn-train/src/renderer/tui/renderer.rs @@ -44,7 +44,7 @@ pub struct TuiMetricsRenderer { interuptor: TrainingInterrupter, popup: PopupState, previous_panic_hook: Option>, - manual_quit: bool, + persistent: bool, } impl MetricsRenderer for TuiMetricsRenderer { @@ -117,13 +117,14 @@ impl TuiMetricsRenderer { interuptor, popup: PopupState::Empty, previous_panic_hook: Some(previous_panic_hook), - manual_quit: false, + persistent: false, } } - /// Enable manual quit after training. - pub fn enable_manual_quit(&mut self) { - self.manual_quit = true; + /// Set the renderer to persistent mode. + pub fn persistent(mut self) -> Self { + self.persistent = true; + self } fn render(&mut self) -> Result<(), Box> { @@ -280,7 +281,7 @@ impl Drop for TuiMetricsRenderer { // Reset the terminal back to raw mode. This can be skipped during // panicking because the panic hook has already reset the terminal if !std::thread::panicking() { - if self.manual_quit { + if self.persistent { if let Err(err) = self.handle_post_training() { eprintln!("Error in post-training handling: {}", err); }