Skip to content

Commit

Permalink
Broadcast progress
Browse files Browse the repository at this point in the history
This adds a `Client::subscribe_progress` method that allows subscribing to query progress events sent by the server.
This is very useful to track the progress of larger queries.
The implementation uses a `tokio::sync::broadcast`, which only keeps the most recent events and sends them to subscribers.

This was tested through the "basic" example, adapted to display progress, and against a real database.

Note that the `Client::query*` interfaces currently do not provide the ID of the query, so it is not yet possible to match progress events against exact queries.
Nevertheless, the commit already populates the ID of queries, and returns it as part of the progress stream.

Later, The `Client::query*` should likely be adapted to return the query ID in addition to the Block stream. Then either:
- We keep the same `subscribe_progress` interface
- We return a progress stream for this exact query along with the query ID and the block stream.
- We change the stream be a stream of enum representing either `Block` or `Progress`.

An alternative is to add a method to `Client` to return the ID of the query currently executing; this would be almost trivial.
  • Loading branch information
cpg314 committed Dec 28, 2023
1 parent f91c3af commit a33c479
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 21 deletions.
66 changes: 53 additions & 13 deletions klickhouse/examples/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,6 @@ use chrono::Utc;
use futures::StreamExt;
use klickhouse::*;

/*
create table my_user_data (
id UUID,
user_data String,
created_at DateTime('UTC')
) Engine=Memory;
*/
#[derive(Row, Debug, Default)]
pub struct MyUserData {
id: Uuid,
Expand All @@ -25,23 +18,70 @@ async fn main() {
.await
.unwrap();

let mut row = MyUserData::default();
row.id = Uuid::new_v4();
row.user_data = "some important stuff!".to_string();
row.created_at = Utc::now().try_into().unwrap();
// Retrieve and display query progress events
let mut progress = client.subscribe_progress();
let progress_task = tokio::task::spawn(async move {
let mut current_query = Uuid::nil();
let mut progress_total = Progress::default();
while let Ok((query, progress)) = progress.recv().await {
if query != current_query {
progress_total = Progress::default();
current_query = query;
}
progress_total += progress;
println!(
"Progress on query {}: {}/{} {:.2}%",
query,
progress_total.read_rows,
progress_total.new_total_rows_to_read,
100.0 * progress_total.read_rows as f64
/ progress_total.new_total_rows_to_read as f64
);
}
});

// Prepare table
client
.insert_native_block("INSERT INTO my_user_data FORMAT native", vec![row])
.execute("DROP TABLE IF EXISTS klickhouse_example")
.await
.unwrap();
client
.execute(
"
CREATE TABLE klickhouse_example (
id UUID,
user_data String,
created_at DateTime('UTC'))
Engine=MergeTree() ORDER BY created_at;",
)
.await
.unwrap();

// Insert rows
let rows = (0..5)
.map(|_| MyUserData {
id: Uuid::new_v4(),
user_data: "some important stuff!".to_string(),
created_at: Utc::now().try_into().unwrap(),
})
.collect();
client
.insert_native_block("INSERT INTO klickhouse_example FORMAT native", rows)
.await
.unwrap();

// Read back rows
let mut all_rows = client
.query::<MyUserData>("select * from my_user_data;")
.query::<MyUserData>("SELECT * FROM klickhouse_example;")
.await
.unwrap();

while let Some(row) = all_rows.next().await {
let row = row.unwrap();
println!("row received '{}': {:?}", row.id, row);
}

// Drop the client so that the progress task finishes.
drop(client);
progress_task.await.unwrap();
}
40 changes: 33 additions & 7 deletions klickhouse/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ use tokio::{
net::{TcpStream, ToSocketAddrs},
select,
sync::{
broadcast,
mpsc::{self, Receiver},
oneshot,
},
};
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;

use crate::{
block::{Block, BlockInfo},
Expand All @@ -22,17 +24,22 @@ use crate::{
ClientHello, ClientInfo, InternalClientOut, Query, QueryKind, QueryProcessingStage,
},
io::{ClickhouseRead, ClickhouseWrite},
progress::Progress,
protocol::{self, ServerPacket},
KlickhouseError, ParsedQuery, RawRow, Result,
};
use log::*;

// Maximum number of progress statuses to keep in memory. New statuses evict old ones.
const PROGRESS_CAPACITY: usize = 100;

struct InnerClient<R: ClickhouseRead, W: ClickhouseWrite> {
input: InternalClientIn<R>,
output: InternalClientOut<W>,
options: ClientOptions,
pending_queries: VecDeque<PendingQuery>,
executing_query: Option<mpsc::Sender<Result<Block>>>,
executing_query: Option<(Uuid, mpsc::Sender<Result<Block>>)>,
progress: broadcast::Sender<(Uuid, Progress)>,
}

struct PendingQuery {
Expand All @@ -48,13 +55,15 @@ impl<R: ClickhouseRead + 'static, W: ClickhouseWrite> InnerClient<R, W> {
options,
pending_queries: VecDeque::new(),
executing_query: None,
progress: broadcast::channel(PROGRESS_CAPACITY).0,
}
}

async fn dispatch_query(&mut self, query: PendingQuery) -> Result<()> {
let id = Uuid::new_v4();
self.output
.send_query(Query {
id: "",
id: &id.to_string(),
info: ClientInfo {
kind: QueryKind::InitialQuery,
initial_user: "",
Expand All @@ -79,7 +88,7 @@ impl<R: ClickhouseRead + 'static, W: ClickhouseWrite> InnerClient<R, W> {

let (sender, receiver) = mpsc::channel(32);
query.response.send(receiver).ok();
self.executing_query = Some(sender);
self.executing_query = Some((id, sender));
self.output
.send_data(
Block {
Expand Down Expand Up @@ -124,7 +133,7 @@ impl<R: ClickhouseRead + 'static, W: ClickhouseWrite> InnerClient<R, W> {
))
}
ServerPacket::Data(block) => {
if let Some(current) = self.executing_query.as_ref() {
if let Some((_, current)) = self.executing_query.as_ref() {
current.send(Ok(block.block)).await.ok();
} else {
return Err(KlickhouseError::ProtocolError(
Expand All @@ -133,7 +142,7 @@ impl<R: ClickhouseRead + 'static, W: ClickhouseWrite> InnerClient<R, W> {
}
}
ServerPacket::Exception(e) => {
if let Some(current) = self.executing_query.take() {
if let Some((_, current)) = self.executing_query.take() {
current.send(Err(e.emit())).await.ok();
if let Some(query) = self.pending_queries.pop_front() {
self.dispatch_query(query).await?;
Expand All @@ -142,7 +151,11 @@ impl<R: ClickhouseRead + 'static, W: ClickhouseWrite> InnerClient<R, W> {
return Err(e.emit());
}
}
ServerPacket::Progress(_) => {}
ServerPacket::Progress(progress) => {
if let Some((id, _)) = &self.executing_query {
let _ = self.progress.send((*id, progress));
}
}
ServerPacket::Pong => {}
ServerPacket::EndOfStream => {
if self.executing_query.take().is_none() {
Expand Down Expand Up @@ -220,6 +233,7 @@ struct ClientRequest {
#[derive(Clone)]
pub struct Client {
sender: mpsc::Sender<ClientRequest>,
progress: broadcast::Sender<(Uuid, Progress)>,
}

/// Options set for a Clickhouse connection.
Expand Down Expand Up @@ -278,9 +292,11 @@ impl Client {
async fn start<R: ClickhouseRead + 'static, W: ClickhouseWrite>(
inner: InnerClient<R, W>,
) -> Result<Self> {
let progress = inner.progress.clone();
let (sender, receiver) = mpsc::channel(1024);

tokio::spawn(inner.run(receiver));
let client = Client { sender };
let client = Client { sender, progress };
client
.execute("SET date_time_input_format='best_effort'")
.await?;
Expand Down Expand Up @@ -526,4 +542,14 @@ impl Client {
pub fn is_closed(&self) -> bool {
self.sender.is_closed()
}

/// Receive progress on the queries as they execute.
///
/// TODO: There is currently no way to retrieve the ID of a query launched
/// with `query` or `execute.`
/// The signature of these functions should be modified to also return
/// an ID (and possibly directly the streaming broadcast).
pub fn subscribe_progress(&self) -> broadcast::Receiver<(Uuid, Progress)> {
self.progress.subscribe()
}
}
1 change: 1 addition & 0 deletions klickhouse/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod migrate;
#[cfg(feature = "refinery")]
pub use migrate::*;
mod progress;
pub use progress::*;
mod protocol;
mod query;
pub mod query_parser;
Expand Down
31 changes: 30 additions & 1 deletion klickhouse/src/progress.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,37 @@
#[derive(Debug, Clone, Copy)]
/// Query execution progress.
/// Values are delta and must be summed.
///
/// See https://clickhouse.com/codebrowser/ClickHouse/src/IO/Progress.h.html
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct Progress {
pub read_rows: u64,
pub read_bytes: u64,
pub new_total_rows_to_read: u64,
pub new_written_rows: Option<u64>,
pub new_written_bytes: Option<u64>,
}
impl std::ops::Add for Progress {
type Output = Progress;

fn add(self, rhs: Self) -> Self::Output {
let sum_opt = |opt1, opt2| match (opt1, opt2) {
(Some(a), Some(b)) => Some(a + b),
(Some(a), None) => Some(a),
(None, Some(b)) => Some(b),
(None, None) => None,
};
Self::Output {
read_rows: self.read_rows + rhs.read_rows,
read_bytes: self.read_bytes + rhs.read_bytes,
new_total_rows_to_read: self.new_total_rows_to_read + rhs.new_total_rows_to_read,
new_written_rows: sum_opt(self.new_written_rows, rhs.new_written_rows),
new_written_bytes: sum_opt(self.new_written_bytes, rhs.new_written_bytes),
}
}
}

impl std::ops::AddAssign for Progress {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}

0 comments on commit a33c479

Please sign in to comment.