diff --git a/src/app/app_execution.rs b/src/app/app_execution.rs index b4b7c8f..ea1c1c8 100644 --- a/src/app/app_execution.rs +++ b/src/app/app_execution.rs @@ -17,7 +17,6 @@ //! [`AppExecution`]: Handles executing queries for the TUI application. -use crate::app::state::tabs::sql::Query; use crate::app::{AppEvent, ExecutionError, ExecutionResultsBatch}; use crate::execution::ExecutionContext; use color_eyre::eyre::Result; @@ -27,7 +26,6 @@ use datafusion::physical_plan::execute_stream; use futures::StreamExt; use log::{error, info}; use std::sync::Arc; -use std::time::Duration; use tokio::sync::mpsc::UnboundedSender; use tokio::sync::Mutex; @@ -62,23 +60,25 @@ impl AppExecution { /// /// Error handling: If an error occurs while executing a query, the error is /// logged and execution continues - pub async fn run_sqls(&self, sqls: Vec<&str>, sender: UnboundedSender) -> Result<()> { + pub async fn run_sqls( + self: Arc, + sqls: Vec, + sender: UnboundedSender, + ) -> Result<()> { // We need to filter out empty strings to correctly determine the last query for displaying // results. info!("Running sqls: {:?}", sqls); - let non_empty_sqls: Vec<&str> = sqls.into_iter().filter(|s| !s.is_empty()).collect(); + let non_empty_sqls: Vec = sqls.into_iter().filter(|s| !s.is_empty()).collect(); info!("Non empty SQLs: {:?}", non_empty_sqls); let statement_count = non_empty_sqls.len(); for (i, sql) in non_empty_sqls.into_iter().enumerate() { info!("Running query {}", i); let _sender = sender.clone(); - let mut query = - Query::new(sql.to_string(), None, None, None, Duration::default(), None); let start = std::time::Instant::now(); if i == statement_count - 1 { info!("Executing last query and display results"); sender.send(AppEvent::NewExecution)?; - match self.inner.create_physical_plan(sql).await { + match self.inner.create_physical_plan(&sql).await { Ok(plan) => match execute_stream(plan, self.inner.session_ctx().task_ctx()) { Ok(stream) => { self.set_result_stream(stream).await; @@ -105,7 +105,7 @@ impl AppExecution { } } Err(stream_err) => { - error!("Error creating physical plan: {:?}", stream_err); + error!("Error executing stream: {:?}", stream_err); let elapsed = start.elapsed(); let e = ExecutionError { query: sql.to_string(), @@ -127,11 +127,8 @@ impl AppExecution { } } } else { - match self.inner.execute_sql_and_discard_results(sql).await { - Ok(_) => { - let elapsed = start.elapsed(); - query.set_execution_time(elapsed); - } + match self.inner.execute_sql_and_discard_results(&sql).await { + Ok(_) => {} Err(e) => { // We only log failed queries, we don't want to stop the execution of the // remaining queries. Perhaps there should be a configuration option for @@ -140,7 +137,6 @@ impl AppExecution { } } } - _sender.send(AppEvent::QueryResult(query))?; // Send the query result to the UI } Ok(()) } diff --git a/src/app/handlers/mod.rs b/src/app/handlers/mod.rs index c77bba6..b91f37d 100644 --- a/src/app/handlers/mod.rs +++ b/src/app/handlers/mod.rs @@ -163,7 +163,7 @@ pub fn app_event_handler(app: &mut App, event: AppEvent) -> Result<()> { }) .collect(); let ctx = app.execution.session_ctx().clone(); - tokio::spawn(async move { + let handle = tokio::spawn(async move { for q in queries { info!("Executing DDL: {:?}", q); match ctx.sql(&q).await { @@ -178,6 +178,7 @@ pub fn app_event_handler(app: &mut App, event: AppEvent) -> Result<()> { } } }); + app.ddl_task = Some(handle); } AppEvent::NewExecution => { app.state.sql_tab.reset_execution_results(); diff --git a/src/app/handlers/sql.rs b/src/app/handlers/sql.rs index 4812dfd..00c7fbe 100644 --- a/src/app/handlers/sql.rs +++ b/src/app/handlers/sql.rs @@ -62,12 +62,9 @@ pub fn normal_mode_handler(app: &mut App, key: KeyEvent) { info!("Running query: {}", sql); let _event_tx = app.event_tx().clone(); let execution = Arc::clone(&app.execution); - // TODO: Extract this into function to be used in both normal and editable handler. - // Only useful if we get Ctrl / Cmd + Enter to work in editable mode though. - tokio::spawn(async move { - let sqls: Vec<&str> = sql.split(';').collect(); - let _ = execution.run_sqls(sqls, _event_tx).await; - }); + let sqls: Vec = sql.split(';').map(|s| s.to_string()).collect(); + let handle = tokio::spawn(execution.run_sqls(sqls, _event_tx)); + app.state.sql_tab.set_execution_task(handle); } KeyCode::Right => { let _event_tx = app.event_tx().clone(); diff --git a/src/app/mod.rs b/src/app/mod.rs index a95ae85..6cf177c 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -42,7 +42,6 @@ use tokio_util::sync::CancellationToken; use self::app_execution::AppExecution; use self::handlers::{app_event_handler, crossterm_event_handler}; -use self::state::tabs::sql::Query; use crate::execution::ExecutionContext; #[cfg(feature = "flightsql")] @@ -121,7 +120,6 @@ pub enum AppEvent { Resize(u16, u16), ExecuteDDL(String), NewExecution, - QueryResult(Query), ExecutionResultsNextPage(ExecutionResultsBatch), ExecutionResultsPreviousPage, ExecutionResultsError(ExecutionError), @@ -138,6 +136,7 @@ pub struct App<'app> { event_rx: UnboundedReceiver, cancellation_token: CancellationToken, task: JoinHandle<()>, + ddl_task: Option>, } impl<'app> App<'app> { @@ -145,6 +144,7 @@ impl<'app> App<'app> { let (event_tx, event_rx) = mpsc::unbounded_channel(); let cancellation_token = CancellationToken::new(); let task = tokio::spawn(async {}); + // let ddl_task = tokio::spawn(async {}); let app_execution = Arc::new(AppExecution::new(Arc::new(execution))); Self { @@ -154,6 +154,7 @@ impl<'app> App<'app> { event_tx, cancellation_token, execution: app_execution, + ddl_task: None, } } @@ -161,6 +162,10 @@ impl<'app> App<'app> { self.event_tx.clone() } + pub fn ddl_task(&mut self) -> &mut Option> { + &mut self.ddl_task + } + pub fn event_rx(&mut self) -> &mut UnboundedReceiver { &mut self.event_rx } @@ -181,6 +186,10 @@ impl<'app> App<'app> { &self.state } + pub fn state_mut(&mut self) -> &mut state::AppState<'app> { + &mut self.state + } + /// Enter app, optionally setup `crossterm` with UI settings such as alternative screen and /// mouse capture, then start event loop. pub fn enter(&mut self, ui: bool) -> Result<()> { @@ -194,7 +203,7 @@ impl<'app> App<'app> { ratatui::crossterm::execute!(std::io::stdout(), event::EnableBracketedPaste)?; } } - self.start_event_loop(); + self.start_app_event_loop(); Ok(()) } @@ -293,8 +302,10 @@ impl<'app> App<'app> { }); } + /// Execute DDL from users DDL file pub fn execute_ddl(&mut self) { if let Some(user_dirs) = directories::UserDirs::new() { + // TODO: Move to ~/.config/ddl let datafusion_rc_path = user_dirs .home_dir() .join(".datafusion") @@ -321,11 +332,6 @@ impl<'app> App<'app> { let _ = self.event_tx().send(AppEvent::EstablishFlightSQLConnection); } - /// Dispatch to the appropriate event loop based on the command - pub fn start_event_loop(&mut self) { - self.start_app_event_loop() - } - /// Get the next event from event loop pub async fn next(&mut self) -> Result { self.event_rx() @@ -349,6 +355,20 @@ impl<'app> App<'app> { .divider(" ") .render(area, buf); } + + pub async fn loop_without_render(&mut self) -> Result<()> { + self.enter(true)?; + // Main loop for handling events + loop { + let event = self.next().await?; + + self.handle_app_event(event)?; + + if self.state.should_quit { + break Ok(()); + } + } + } } impl Widget for &App<'_> { diff --git a/src/app/state/tabs/sql.rs b/src/app/state/tabs/sql.rs index 8155cc7..f897f8b 100644 --- a/src/app/state/tabs/sql.rs +++ b/src/app/state/tabs/sql.rs @@ -16,93 +16,19 @@ // under the License. use core::cell::RefCell; -use std::time::Duration; +use color_eyre::Result; use datafusion::arrow::array::RecordBatch; use datafusion::sql::sqlparser::keywords; use ratatui::crossterm::event::KeyEvent; use ratatui::style::palette::tailwind; use ratatui::style::{Modifier, Style}; use ratatui::widgets::TableState; +use tokio::task::JoinHandle; use tui_textarea::TextArea; use crate::app::ExecutionError; use crate::config::AppConfig; -use crate::execution::ExecutionStats; - -#[derive(Clone, Debug)] -pub struct Query { - sql: String, - results: Option>, - num_rows: Option, - error: Option, - execution_time: Duration, - execution_stats: Option, -} - -impl Query { - pub fn new( - sql: String, - results: Option>, - num_rows: Option, - error: Option, - execution_time: Duration, - execution_stats: Option, - ) -> Self { - Self { - sql, - results, - num_rows, - error, - execution_time, - execution_stats, - } - } - - pub fn sql(&self) -> &String { - &self.sql - } - - pub fn execution_time(&self) -> &Duration { - &self.execution_time - } - - pub fn set_results(&mut self, results: Option>) { - self.results = results; - } - - pub fn results(&self) -> &Option> { - &self.results - } - - pub fn set_num_rows(&mut self, num_rows: Option) { - self.num_rows = num_rows; - } - - pub fn num_rows(&self) -> &Option { - &self.num_rows - } - - pub fn set_error(&mut self, error: Option) { - self.error = error; - } - - pub fn error(&self) -> &Option { - &self.error - } - - pub fn set_execution_time(&mut self, elapsed_time: Duration) { - self.execution_time = elapsed_time; - } - - pub fn execution_stats(&self) -> &Option { - &self.execution_stats - } - - pub fn set_execution_stats(&mut self, stats: Option) { - self.execution_stats = stats; - } -} pub fn get_keywords() -> Vec { keywords::ALL_KEYWORDS @@ -129,11 +55,11 @@ pub fn keyword_style() -> Style { pub struct SQLTabState<'app> { editor: TextArea<'app>, editor_editable: bool, - query: Option, query_results_state: Option>, result_batches: Option>, results_page: Option, execution_error: Option, + execution_task: Option>>, } impl<'app> SQLTabState<'app> { @@ -149,11 +75,11 @@ impl<'app> SQLTabState<'app> { Self { editor: textarea, editor_editable: false, - query: None, query_results_state: None, result_batches: None, results_page: None, execution_error: None, + execution_task: None, } } @@ -215,14 +141,6 @@ impl<'app> SQLTabState<'app> { self.editor_editable } - pub fn set_query(&mut self, query: Query) { - self.query = Some(query); - } - - pub fn query(&self) -> &Option { - &self.query - } - // TODO: Create Editor struct and move this there pub fn next_word(&mut self) { self.editor @@ -288,4 +206,12 @@ impl<'app> SQLTabState<'app> { } } } + + pub fn execution_task(&mut self) -> &mut Option>> { + &mut self.execution_task + } + + pub fn set_execution_task(&mut self, task: JoinHandle>) { + self.execution_task = Some(task); + } } diff --git a/tests/tui.rs b/tests/tui.rs index b66ecb1..5f6f89b 100644 --- a/tests/tui.rs +++ b/tests/tui.rs @@ -16,6 +16,8 @@ // under the License. // +use datafusion::arrow::array::RecordBatch; +use datafusion::common::Result; use dft::{ app::{state::initialize, App, AppEvent}, execution::ExecutionContext, @@ -42,7 +44,8 @@ impl<'app> TestApp<'app> { let config_path = tempdir().unwrap(); let state = initialize(config_path.path().to_path_buf()); let execution = ExecutionContext::try_new(&state.config.execution).unwrap(); - let app = App::new(state, execution); + let mut app = App::new(state, execution); + app.enter(false).unwrap(); Self { config_path, app } } @@ -55,4 +58,15 @@ impl<'app> TestApp<'app> { pub fn state(&self) -> &dft::app::state::AppState { self.app.state() } + + pub async fn wait_for_ddl(&mut self) { + if let Some(handle) = self.app.ddl_task().take() { + handle.await.unwrap(); + } + } + + pub async fn execute_sql(&self, sql: &str) -> Result> { + let ctx = self.app.execution().session_ctx().clone(); + ctx.sql(sql).await.unwrap().collect().await + } } diff --git a/tests/tui_cases/ddl.rs b/tests/tui_cases/ddl.rs new file mode 100644 index 0000000..da9c223 --- /dev/null +++ b/tests/tui_cases/ddl.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::assert_batches_eq; +use dft::app::AppEvent; + +use crate::TestApp; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_create_table_ddl() { + let mut test_app = TestApp::new(); + + let ddl = "CREATE TABLE foo AS VALUES (1);"; + test_app + .handle_app_event(AppEvent::ExecuteDDL(ddl.to_string())) + .unwrap(); + test_app.wait_for_ddl().await; + + let sql = "SELECT * FROM foo;"; + let batches = test_app.execute_sql(sql).await.unwrap(); + + let expected = [ + "+---------+", + "| column1 |", + "+---------+", + "| 1 |", + "+---------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_create_table_in_new_schema() { + let mut test_app = TestApp::new(); + + let create_schema = "CREATE SCHEMA foo;"; + let create_table = "CREATE TABLE foo.bar AS VALUES (1);"; + let combined = [create_schema, create_table].join(";"); + test_app + .handle_app_event(AppEvent::ExecuteDDL(combined)) + .unwrap(); + test_app.wait_for_ddl().await; + + let sql = "SELECT * FROM foo.bar;"; + let batches = test_app.execute_sql(sql).await.unwrap(); + + let expected = [ + "+---------+", + "| column1 |", + "+---------+", + "| 1 |", + "+---------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_create_table_in_new_catalog() { + let mut test_app = TestApp::new(); + + let create_catalog = "CREATE DATABASE foo;"; + let create_schema = "CREATE SCHEMA foo.bar;"; + let create_table = "CREATE TABLE foo.bar.baz AS VALUES (1);"; + let combined = [create_catalog, create_schema, create_table].join(";"); + test_app + .handle_app_event(AppEvent::ExecuteDDL(combined)) + .unwrap(); + test_app.wait_for_ddl().await; + + let sql = "SELECT * FROM foo.bar.baz;"; + let batches = test_app.execute_sql(sql).await.unwrap(); + + let expected = [ + "+---------+", + "| column1 |", + "+---------+", + "| 1 |", + "+---------+", + ]; + assert_batches_eq!(expected, &batches); +} diff --git a/tests/tui_cases/mod.rs b/tests/tui_cases/mod.rs index f38f867..e9a8a34 100644 --- a/tests/tui_cases/mod.rs +++ b/tests/tui_cases/mod.rs @@ -14,7 +14,7 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -// +mod ddl; mod pagination; mod quit;