Skip to content

Commit

Permalink
feat(test): optimize set stmts in simulation to avoid duplicate replay (
Browse files Browse the repository at this point in the history
  • Loading branch information
yezizp2012 authored Mar 10, 2023
1 parent b6244d7 commit 64d80d2
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 25 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/tests/simulation/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ etcd-client = { version = "0.2.17", package = "madsim-etcd-client" }
futures = { version = "0.3", default-features = false, features = ["alloc"] }
glob = "0.3"
itertools = "0.10"
lru = { git = "https://github.com/risingwavelabs/lru-rs.git", branch = "evict_by_timestamp" }
madsim = "0.2.17"
paste = "1"
pretty_assertions = "1"
Expand All @@ -32,6 +33,7 @@ risingwave_ctl = { path = "../../ctl" }
risingwave_frontend = { path = "../../frontend" }
risingwave_meta = { path = "../../meta" }
risingwave_pb = { path = "../../prost" }
risingwave_sqlparser = { path = "../../sqlparser" }
risingwave_sqlsmith = { path = "../sqlsmith" }
serde = "1.0.152"
serde_derive = "1.0.152"
Expand Down
123 changes: 98 additions & 25 deletions src/tests/simulation/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@

use std::time::Duration;

use itertools::Itertools;
use lru::{Iter, LruCache};
use risingwave_sqlparser::ast::Statement;
use risingwave_sqlparser::parser::Parser;

/// A RisingWave client.
pub struct RisingWave {
client: tokio_postgres::Client,
Expand All @@ -22,26 +27,98 @@ pub struct RisingWave {
dbname: String,
/// The `SET` statements that have been executed on this client.
/// We need to replay them when reconnecting.
set_stmts: Vec<String>,
set_stmts: SetStmts,
}

/// `SetStmts` stores and compacts all `SET` statements that have been executed in the client
/// history.
pub struct SetStmts {
stmts_cache: LruCache<String, String>,
}

impl Default for SetStmts {
fn default() -> Self {
Self {
stmts_cache: LruCache::unbounded(),
}
}
}

struct SetStmtsIterator<'a, 'b>
where
'a: 'b,
{
_stmts: &'a SetStmts,
stmts_iter: core::iter::Rev<Iter<'b, String, String>>,
}

impl<'a, 'b> SetStmtsIterator<'a, 'b> {
fn new(stmts: &'a SetStmts) -> Self {
Self {
_stmts: stmts,
stmts_iter: stmts.stmts_cache.iter().rev(),
}
}
}

impl SetStmts {
fn push(&mut self, sql: &str) {
let ast = Parser::parse_sql(&sql).expect("a set statement should be parsed successfully");
match ast
.into_iter()
.exactly_one()
.expect("should contain only one statement")
{
// record `local` for variable and `SetTransaction` if supported in the future.
Statement::SetVariable {
local: _,
variable,
value: _,
} => {
let key = variable.real_value().to_lowercase();
// store complete sql as value.
self.stmts_cache.put(key, sql.to_string());
}
_ => unreachable!(),
}
}
}

impl Iterator for SetStmtsIterator<'_, '_> {
type Item = String;

fn next(&mut self) -> Option<Self::Item> {
let (_, stmt) = self.stmts_iter.next()?;
Some(stmt.clone())
}
}

impl RisingWave {
pub async fn connect(
host: String,
dbname: String,
) -> Result<Self, tokio_postgres::error::Error> {
Self::reconnect(host, dbname, vec![]).await
let set_stmts = SetStmts::default();
let (client, task) = Self::connect_inner(&host, &dbname, &set_stmts).await?;
Ok(Self {
client,
task,
host,
dbname,
set_stmts,
})
}

pub async fn reconnect(
host: String,
dbname: String,
set_stmts: Vec<String>,
) -> Result<Self, tokio_postgres::error::Error> {
pub async fn connect_inner(
host: &str,
dbname: &str,
set_stmts: &SetStmts,
) -> Result<(tokio_postgres::Client, tokio::task::JoinHandle<()>), tokio_postgres::error::Error>
{
let (client, connection) = tokio_postgres::Config::new()
.host(&host)
.host(host)
.port(4566)
.dbname(&dbname)
.dbname(dbname)
.user("root")
.connect_timeout(Duration::from_secs(5))
.connect(tokio_postgres::NoTls)
Expand All @@ -64,16 +141,17 @@ impl RisingWave {
.simple_query("SET VISIBILITY_MODE TO checkpoint;")
.await?;
// replay all SET statements
for stmt in &set_stmts {
client.simple_query(stmt).await?;
for stmt in SetStmtsIterator::new(&set_stmts) {
client.simple_query(&stmt).await?;
}
Ok(RisingWave {
client,
task,
host,
dbname,
set_stmts,
})
Ok((client, task))
}

pub async fn reconnect(&mut self) -> Result<(), tokio_postgres::error::Error> {
let (client, task) = Self::connect_inner(&self.host, &self.dbname, &self.set_stmts).await?;
self.client = client;
self.task = task;
Ok(())
}

/// Returns a reference of the inner Postgres client.
Expand All @@ -97,16 +175,11 @@ impl sqllogictest::AsyncDB for RisingWave {

if self.client.is_closed() {
// connection error, reset the client
*self = Self::reconnect(
self.host.clone(),
self.dbname.clone(),
self.set_stmts.clone(),
)
.await?;
self.reconnect().await?;
}

if sql.trim_start().to_lowercase().starts_with("set") {
self.set_stmts.push(sql.to_string());
self.set_stmts.push(sql);
}

let mut output = vec![];
Expand Down

0 comments on commit 64d80d2

Please sign in to comment.