feat(frontend): change query_handle to return data stream (#5556)
* change query_handle to return a data stream

* transfer Vec<Row> instead of Row

* use Option<i32> in row_cnt

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
ZENOTME and mergify[bot] authored Sep 28, 2022
1 parent bc3974a commit 960e11b
Showing 12 changed files with 140 additions and 91 deletions.
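
Taken together, the three bullets move the frontend from "collect every row, then respond" to "stream batches of rows as they are produced". Below is a minimal, self-contained sketch of the consumer-facing shape (the Row, PgResultSet, and PgResponse definitions are simplified stand-ins, not the actual pgwire types; only the futures crate is assumed):

use futures::stream::{self, BoxStream, StreamExt};

// Simplified stand-ins for the pgwire types this commit touches.
type Row = Vec<Option<String>>;
// Handlers now hand back a stream of Vec<Row> batches instead of one Vec<Row>.
type PgResultSet = BoxStream<'static, Result<Vec<Row>, String>>;

struct PgResponse {
    // Some(n) for DML, where n is parsed from the first batch;
    // None for SELECT, whose count is only known once the stream is drained.
    row_cnt: Option<i32>,
    values: PgResultSet,
}

fn main() {
    futures::executor::block_on(async {
        let batches: Vec<Result<Vec<Row>, String>> = vec![
            Ok(vec![vec![Some("1".into())], vec![Some("2".into())]]),
            Ok(vec![vec![Some("3".into())]]),
        ];
        let mut resp = PgResponse {
            row_cnt: None, // a SELECT: no up-front count
            values: stream::iter(batches).boxed(),
        };
        assert!(resp.row_cnt.is_none());
        let mut seen = 0;
        while let Some(batch) = resp.values.next().await {
            seen += batch.unwrap().len();
        }
        println!("streamed {seen} rows"); // prints: streamed 3 rows
    });
}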
24 changes: 13 additions & 11 deletions src/frontend/src/handler/describe.rs
@@ -114,7 +114,7 @@ pub fn handle_describe(context: OptimizerContext, table_name: ObjectName) -> Res
     // TODO: recover the original user statement
     Ok(PgResponse::new(
         StatementType::DESCRIBE_TABLE,
-        rows.len() as i32,
+        Some(rows.len() as i32),
         rows,
         vec![
             PgFieldDescriptor::new("Name".to_owned(), TypeOid::Varchar),
@@ -150,16 +150,18 @@ mod tests {

         let mut columns = HashMap::new();
         #[for_await]
-        for row in pg_response.values_stream() {
-            let row = row.unwrap();
-            columns.insert(
-                std::str::from_utf8(row.index(0).as_ref().unwrap())
-                    .unwrap()
-                    .to_string(),
-                std::str::from_utf8(row.index(1).as_ref().unwrap())
-                    .unwrap()
-                    .to_string(),
-            );
+        for row_set in pg_response.values_stream() {
+            let row_set = row_set.unwrap();
+            for row in row_set {
+                columns.insert(
+                    std::str::from_utf8(row.index(0).as_ref().unwrap())
+                        .unwrap()
+                        .to_string(),
+                    std::str::from_utf8(row.index(1).as_ref().unwrap())
+                        .unwrap()
+                        .to_string(),
+                );
+            }
         }

         let expected_columns: HashMap<String, String> = maplit::hashmap! {
2 changes: 1 addition & 1 deletion src/frontend/src/handler/explain.rs
@@ -157,7 +157,7 @@ pub(super) fn handle_explain(

     Ok(PgResponse::new(
         StatementType::EXPLAIN,
-        rows.len() as i32,
+        Some(rows.len() as i32),
         rows,
         vec![PgFieldDescriptor::new(
             "QUERY PLAN".to_owned(),
37 changes: 23 additions & 14 deletions src/frontend/src/handler/query.rs
@@ -12,10 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+use futures::StreamExt;
 use pgwire::pg_field_descriptor::PgFieldDescriptor;
-use pgwire::pg_response::{PgResponse, StatementType};
-use pgwire::types::Row;
-use risingwave_common::error::Result;
+use pgwire::pg_response::{PgResponse, PgResultSet, StatementType};
+use risingwave_common::error::{ErrorCode, Result, RwError};
 use risingwave_common::session_config::QueryMode;
 use risingwave_sqlparser::ast::Statement;
 use tracing::debug;
@@ -29,7 +29,7 @@ use crate::scheduler::{
 };
 use crate::session::{OptimizerContext, SessionImpl};

-pub type QueryResultSet = Vec<Row>;
+pub type QueryResultSet = PgResultSet;

 pub async fn handle_query(
     context: OptimizerContext,
@@ -55,7 +55,7 @@ };
     };
     debug!("query_mode:{:?}", query_mode);

-    let (rows, pg_descs) = match query_mode {
+    let (mut row_stream, pg_descs) = match query_mode {
         QueryMode::Local => {
             if stmt_type.is_dml() {
                 // DML do not support local mode yet.
@@ -69,16 +69,23 @@ };
     };

     let rows_count = match stmt_type {
-        StatementType::SELECT => rows.len() as i32,
+        StatementType::SELECT => None,
         StatementType::INSERT | StatementType::DELETE | StatementType::UPDATE => {
-            let first_row = rows[0].values();
-            let affected_rows_str = first_row[0]
+            // Get the row from the row_stream.
+            let first_row_set = row_stream
+                .next()
+                .await
+                .expect("compute node should return affected rows in output")
+                .map_err(|err| RwError::from(ErrorCode::InternalError(format!("{}", err))))?;
+            let affected_rows_str = first_row_set[0].values()[0]
                 .as_ref()
                 .expect("compute node should return affected rows in output");
-            String::from_utf8(affected_rows_str.to_vec())
-                .unwrap()
-                .parse()
-                .unwrap_or_default()
+            Some(
+                String::from_utf8(affected_rows_str.to_vec())
+                    .unwrap()
+                    .parse()
+                    .unwrap_or_default(),
+            )
         }
         _ => unreachable!(),
     };
@@ -88,7 +95,9 @@ flush_for_write(&session, stmt_type).await?;
         flush_for_write(&session, stmt_type).await?;
     }

-    Ok(PgResponse::new(stmt_type, rows_count, rows, pg_descs))
+    Ok(PgResponse::new_for_stream(
+        stmt_type, rows_count, row_stream, pg_descs,
+    ))
 }

fn to_statement_type(stmt: &Statement) -> StatementType {
@@ -195,7 +204,7 @@ async fn local_execute(
     // TODO: Passing sql here
     let execution =
         LocalQueryExecution::new(query, front_env.clone(), "", epoch, session.auth_context());
-    let rsp = Ok((execution.collect_rows(format).await?, pg_descs));
+    let rsp = Ok((execution.stream_rows(format), pg_descs));

     // Release hummock snapshot for local execution.
     hummock_snapshot_manager.release(epoch, &query_id).await;
26 changes: 14 additions & 12 deletions src/frontend/src/handler/show.rs
@@ -95,7 +95,7 @@ pub fn handle_show_object(context: OptimizerContext, command: ShowObject) -> Res

         return Ok(PgResponse::new(
             StatementType::SHOW_COMMAND,
-            rows.len() as i32,
+            Some(rows.len() as i32),
             rows,
             vec![
                 PgFieldDescriptor::new("Name".to_owned(), TypeOid::Varchar),
@@ -112,7 +112,7 @@

     Ok(PgResponse::new(
         StatementType::SHOW_COMMAND,
-        rows.len() as i32,
+        Some(rows.len() as i32),
         rows,
         vec![PgFieldDescriptor::new("Name".to_owned(), TypeOid::Varchar)],
     ))
@@ -174,16 +174,18 @@ mod tests {

         let mut columns = HashMap::new();
         #[for_await]
-        for row in pg_response.values_stream() {
-            let row = row.unwrap();
-            columns.insert(
-                std::str::from_utf8(row.index(0).as_ref().unwrap())
-                    .unwrap()
-                    .to_string(),
-                std::str::from_utf8(row.index(1).as_ref().unwrap())
-                    .unwrap()
-                    .to_string(),
-            );
+        for row_set in pg_response.values_stream() {
+            let row_set = row_set.unwrap();
+            for row in row_set {
+                columns.insert(
+                    std::str::from_utf8(row.index(0).as_ref().unwrap())
+                        .unwrap()
+                        .to_string(),
+                    std::str::from_utf8(row.index(1).as_ref().unwrap())
+                        .unwrap()
+                        .to_string(),
+                );
+            }
         }

         let expected_columns: HashMap<String, String> = maplit::hashmap! {
4 changes: 2 additions & 2 deletions src/frontend/src/handler/variable.rs
@@ -52,7 +52,7 @@ pub(super) fn handle_show(context: OptimizerContext, variable: Vec<Ident>) -> Re

     Ok(PgResponse::new(
         StatementType::SHOW_COMMAND,
-        1,
+        Some(1),
         vec![row],
         vec![PgFieldDescriptor::new(
             name.to_ascii_lowercase(),
@@ -79,7 +79,7 @@ pub(super) fn handle_show_all(context: &OptimizerContext) -> Result<PgResponse>

     Ok(PgResponse::new(
         StatementType::SHOW_COMMAND,
-        all_variables.len() as i32,
+        Some(all_variables.len() as i32),
         rows,
         vec![
             PgFieldDescriptor::new("Name".to_string(), TypeOid::Varchar),
17 changes: 11 additions & 6 deletions src/frontend/src/scheduler/distributed/query_manager.rs
@@ -18,7 +18,8 @@ use std::sync::Arc;

 use futures::StreamExt;
 use futures_async_stream::try_stream;
-use pgwire::pg_server::{Session, SessionId};
+use pgwire::pg_server::{BoxedError, Session, SessionId};
+use pgwire::types::Row;
 use risingwave_batch::executor::BoxedDataChunkStream;
 use risingwave_common::array::DataChunk;
 use risingwave_common::error::RwError;
@@ -122,7 +123,7 @@

         // TODO: Clean up queries status when ends. This should be done lazily.

-        query_result_fetcher.collect_rows_from_channel(format).await
+        Ok(query_result_fetcher.stream_from_channel(format))
     }

     pub fn cancel_queries_in_session(&self, session_id: SessionId) {
@@ -190,13 +191,17 @@ impl QueryResultFetcher {
         Box::pin(self.run_inner())
     }

-    async fn collect_rows_from_channel(mut self, format: bool) -> SchedulerResult<QueryResultSet> {
-        let mut result_sets = vec![];
+    #[try_stream(ok = Vec<Row>, error = BoxedError)]
+    async fn stream_from_channel_inner(mut self, format: bool) {
         while let Some(chunk_inner) = self.chunk_rx.recv().await {
             let chunk = chunk_inner?;
-            result_sets.extend(to_pg_rows(chunk, format));
+            let rows = to_pg_rows(chunk, format);
+            yield rows;
         }
-        Ok(result_sets)
     }
+
+    fn stream_from_channel(self, format: bool) -> QueryResultSet {
+        Box::pin(self.stream_from_channel_inner(format))
+    }
 }

16 changes: 10 additions & 6 deletions src/frontend/src/scheduler/local.rs
@@ -16,8 +16,10 @@
 use std::collections::HashMap;
 use std::sync::Arc;

-use futures_async_stream::{for_await, try_stream};
+use futures_async_stream::try_stream;
 use itertools::Itertools;
+use pgwire::pg_server::BoxedError;
+use pgwire::types::Row;
 use risingwave_batch::executor::{BoxedDataChunkStream, ExecutorBuilder};
 use risingwave_batch::task::TaskId;
 use risingwave_common::array::DataChunk;
@@ -99,15 +101,17 @@ impl LocalQueryExecution {
         Box::pin(self.run_inner())
     }

-    pub async fn collect_rows(self, format: bool) -> SchedulerResult<QueryResultSet> {
-        let data_stream = self.run();
-        let mut rows = vec![];
+    #[try_stream(ok = Vec<Row>, error = BoxedError)]
+    async fn stream_row_inner(data_stream: BoxedDataChunkStream, format: bool) {
         #[for_await]
         for chunk in data_stream {
-            rows.extend(to_pg_rows(chunk?, format));
+            let rows = to_pg_rows(chunk?, format);
+            yield rows;
         }
+    }

-        Ok(rows)
+    pub fn stream_rows(self, format: bool) -> QueryResultSet {
+        Box::pin(Self::stream_row_inner(self.run(), format))
     }

     /// Convert query to plan fragment.
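
Both scheduler changes above (in query_manager.rs and local.rs) lean on the same pattern: a #[try_stream] generator function that yields one Vec<Row> batch per data chunk instead of collecting all rows before returning. A minimal standalone sketch of that pattern follows; the Chunk and Row types are simplified stand-ins, the error type is arbitrary, and the futures plus futures-async-stream crates (nightly Rust) are assumed:

// Nightly-only: futures-async-stream expands into generators.
#![feature(generators)]

use futures::StreamExt;
use futures_async_stream::try_stream;

// Simplified stand-ins for DataChunk and pgwire's Row.
type Chunk = Vec<i64>;
type Row = String;

// Mirrors the shape of stream_from_channel_inner / stream_row_inner:
// convert each chunk to rows and yield the whole batch immediately.
#[try_stream(ok = Vec<Row>, error = std::io::Error)]
async fn rows_from_chunks(chunks: Vec<Chunk>) {
    for chunk in chunks {
        yield chunk.into_iter().map(|v| v.to_string()).collect::<Vec<Row>>();
    }
}

fn main() {
    futures::executor::block_on(async {
        // Callers Box::pin the generator, just as stream_rows does above.
        let mut stream = Box::pin(rows_from_chunks(vec![vec![1, 2], vec![3]]));
        while let Some(batch) = stream.next().await {
            println!("{:?}", batch.unwrap()); // ["1", "2"], then ["3"]
        }
    });
}

The practical effect is that the frontend no longer buffers the entire result set in memory; downstream consumers pull batches as the client is ready for them.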
6 changes: 4 additions & 2 deletions src/frontend/src/test_utils.rs
@@ -110,8 +110,10 @@ impl LocalFrontend {
         let mut rsp = self.run_sql(sql).await.unwrap();
         let mut res = vec![];
         #[for_await]
-        for row in rsp.values_stream() {
-            res.push(format!("{:?}", row.unwrap()));
+        for row_set in rsp.values_stream() {
+            for row in row_set.unwrap() {
+                res.push(format!("{:?}", row))
+            }
         }
         res
     }
41 changes: 28 additions & 13 deletions src/utils/pgwire/src/pg_extended.rs
@@ -15,6 +15,7 @@
 use std::collections::HashMap;
 use std::str::FromStr;
 use std::sync::Arc;
+use std::vec::IntoIter;

 use bytes::Bytes;
 use futures::stream::FusedStream;
@@ -30,6 +31,7 @@ use crate::pg_message::{BeCommandCompleteMessage, BeMessage};
 use crate::pg_protocol::{cstr_to_str, PgStream};
 use crate::pg_response::PgResponse;
 use crate::pg_server::{Session, SessionManager};
+use crate::types::Row;

 #[derive(Default)]
 pub struct PgStatement {
@@ -82,6 +84,7 @@ impl PgStatement {
             is_query: self.is_query,
             row_description: self.row_description.clone(),
             result: None,
+            row_cache: vec![].into_iter(),
         })
     }

@@ -92,14 +95,14 @@
     }
 }

-#[derive(Default)]
 pub struct PgPortal {
     name: String,
     query_string: String,
     result_format: bool,
     is_query: bool,
     row_description: Vec<PgFieldDescriptor>,
     result: Option<PgResponse>,
+    row_cache: IntoIter<Row>,
 }

impl PgPortal {
@@ -135,6 +138,7 @@ impl PgPortal {
             self.result.as_mut().unwrap()
         };

+        // Indicate all data from stream have been completely consumed.
         let mut query_end = false;
         let mut query_row_count = 0;
         if result.is_empty() {
@@ -143,23 +147,32 @@
             // fetch row data
             // if row_limit is 0, fetch all rows
             // if row_limit > 0, fetch row_limit rows
-            let stream = result.values_stream();
             while row_limit == 0 || query_row_count < row_limit {
-                if let Some(row) = stream
-                    .try_next()
-                    .await
-                    .map_err(|err| PsqlError::ExecuteError(err))?
-                {
-                    msg_stream.write_no_flush(&BeMessage::DataRow(&row))?;
-                    query_row_count += 1;
+                if self.row_cache.len() > 0 {
+                    for row in self.row_cache.by_ref() {
+                        msg_stream.write_no_flush(&BeMessage::DataRow(&row))?;
+                        query_row_count += 1;
+                        if row_limit > 0 && query_row_count >= row_limit {
+                            break;
+                        }
+                    }
                 } else {
-                    query_end = true;
-                    break;
+                    self.row_cache = if let Some(rows) = result
+                        .values_stream()
+                        .try_next()
+                        .await
+                        .map_err(|err| PsqlError::ExecuteError(err))?
+                    {
+                        rows.into_iter()
+                    } else {
+                        query_end = true;
+                        break;
+                    };
                 }
             }
-            if stream.peekable().is_terminated() {
+            // Check if the result is consumed completely.
+            // If not, cache the result.
+            if self.row_cache.len() == 0 && result.values_stream().peekable().is_terminated() {
                 query_end = true;
             }
             if query_end {
@@ -177,7 +190,9 @@
             msg_stream.write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage {
                 stmt_type: result.get_stmt_type(),
                 notice: result.get_notice(),
-                rows_cnt: result.get_effected_rows_cnt(),
+                rows_cnt: result
+                    .get_effected_rows_cnt()
+                    .expect("row count should be set"),
             }))?;
         }

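The PgPortal change above amounts to a batch cursor: drain the cached IntoIter<Row> first, refill it from the next Vec<Row> batch of the stream when it runs dry, and report end-of-data only when both cache and stream are exhausted. A minimal, self-contained sketch of that pattern (simplified types, not the pgwire implementation; futures crate assumed):

use std::vec::IntoIter;

use futures::stream::{self, BoxStream, StreamExt};

type Row = Vec<String>; // stand-in for pgwire::types::Row

struct BatchCursor {
    // Rows already pulled from the stream but not yet sent to the client.
    cache: IntoIter<Row>,
    // The remaining Vec<Row> batches.
    stream: BoxStream<'static, Vec<Row>>,
}

impl BatchCursor {
    async fn next_row(&mut self) -> Option<Row> {
        loop {
            // Serve from the cache while it has rows.
            if let Some(row) = self.cache.next() {
                return Some(row);
            }
            // Cache drained: refill from the next batch, or signal the end.
            match self.stream.next().await {
                Some(batch) => self.cache = batch.into_iter(),
                None => return None,
            }
        }
    }
}

fn main() {
    futures::executor::block_on(async {
        let batches = vec![
            vec![vec!["a".to_string()], vec!["b".to_string()]],
            vec![vec!["c".to_string()]],
        ];
        let mut cursor = BatchCursor {
            cache: vec![].into_iter(), // same empty-cache start as PgStatement::instance
            stream: stream::iter(batches).boxed(),
        };
        while let Some(row) = cursor.next_row().await {
            println!("{row:?}"); // ["a"], ["b"], ["c"]
        }
    });
}

Keeping the partially consumed batch inside the portal is what lets the extended protocol's Execute message honor row_limit across calls: leftover rows stay in row_cache, and the stream is polled again only once the cache is empty.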