diff --git a/src/client.rs b/src/client.rs index 64dfa8ac..e72dbf79 100644 --- a/src/client.rs +++ b/src/client.rs @@ -601,7 +601,7 @@ where // in case the client is sending some custom protocol messages, e.g. // SET SHARDING KEY TO 'bigint'; - let mut message = tokio::select! { + let message = tokio::select! { _ = self.shutdown.recv() => { if !self.admin { error_response_terminal( @@ -792,6 +792,8 @@ where // Set application_name. server.set_name(&self.application_name).await?; + let mut initial_message = Some(message); + // Transaction loop. Multiple queries can be issued by the client here. // The connection belongs to the client until the transaction is over, // or until the client disconnects if we are in session mode. @@ -799,40 +801,42 @@ where // If the client is in session mode, no more custom protocol // commands will be accepted. loop { - let mut message = if message.len() == 0 { - trace!("Waiting for message inside transaction or in session mode"); + let message = match initial_message { + None => { + trace!("Waiting for message inside transaction or in session mode"); - match read_message(&mut self.read).await { - Ok(message) => message, - Err(err) => { - // Client disconnected inside a transaction. - // Clean up the server and re-use it. - server.checkin_cleanup().await?; + match read_message(&mut self.read).await { + Ok(message) => message, + Err(err) => { + // Client disconnected inside a transaction. + // Clean up the server and re-use it. + server.checkin_cleanup().await?; - return Err(err); + return Err(err); + } } } - } else { - let msg = message.clone(); - message.clear(); - msg + Some(message) => { + initial_message = None; + message + } }; // The message will be forwarded to the server intact. We still would like to // parse it below to figure out what to do with it. - let original = message.clone(); - let code = message.get_u8() as char; - let _len = message.get_i32() as usize; + // Safe to unwrap because we know this message has a certain length and has the code + // This reads the first byte without advancing the internal pointer and mutating the bytes + let code = *message.get(0).unwrap() as char; trace!("Message: {}", code); match code { - // ReadyForQuery + // Query 'Q' => { debug!("Sending query to server"); - self.send_and_receive_loop(code, original, server, &address, &pool) + self.send_and_receive_loop(code, message, server, &address, &pool) .await?; if !server.in_transaction() { @@ -858,25 +862,25 @@ where // Parse // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. 'P' => { - self.buffer.put(&original[..]); + self.buffer.put(&message[..]); } // Bind // The placeholder's replacements are here, e.g. 'user@email.com' and 'true' 'B' => { - self.buffer.put(&original[..]); + self.buffer.put(&message[..]); } // Describe // Command a client can issue to describe a previously prepared named statement. 'D' => { - self.buffer.put(&original[..]); + self.buffer.put(&message[..]); } // Execute // Execute a prepared statement prepared in `P` and bound in `B`. 'E' => { - self.buffer.put(&original[..]); + self.buffer.put(&message[..]); } // Sync @@ -884,9 +888,8 @@ where 'S' => { debug!("Sending query to server"); - self.buffer.put(&original[..]); + self.buffer.put(&message[..]); - // Clone after freeze does not allocate let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; // Almost certainly true @@ -929,14 +932,14 @@ where 'd' => { // Forward the data to the server, // don't buffer it since it can be rather large. - self.send_server_message(server, original, &address, &pool) + self.send_server_message(server, message, &address, &pool) .await?; } // CopyDone or CopyFail // Copy is done, successfully or not. 'c' | 'f' => { - self.send_server_message(server, original, &address, &pool) + self.send_server_message(server, message, &address, &pool) .await?; let response = self.receive_server_message(server, &address, &pool).await?; diff --git a/src/server.rs b/src/server.rs index dbac9bc0..d191eb74 100644 --- a/src/server.rs +++ b/src/server.rs @@ -457,7 +457,17 @@ impl Server { // which can leak between clients. This is a best effort to block bad clients // from poisoning a transaction-mode pool by setting inappropriate session variables match command_tag.as_str() { - "SET\0" | "PREPARE\0" => { + "SET\0" => { + // We don't detect set statements in transactions + // No great way to differentiate between set and set local + // As a result, we will miss cases when set statements are used in transactions + // This will reduce amount of discard statements sent + if !self.in_transaction { + debug!("Server connection marked for clean up"); + self.needs_cleanup = true; + } + } + "PREPARE\0" => { debug!("Server connection marked for clean up"); self.needs_cleanup = true; } @@ -595,7 +605,7 @@ impl Server { self.query("ROLLBACK").await?; } - // Client disconnected but it perfromed session-altering operations such as + // Client disconnected but it performed session-altering operations such as // SET statement_timeout to 1 or create a prepared statement. We clear that // to avoid leaking state between clients. For performance reasons we only // send `DISCARD ALL` if we think the session is altered instead of just sending diff --git a/tests/ruby/misc_spec.rb b/tests/ruby/misc_spec.rb index 6e79e1a4..1f5bf421 100644 --- a/tests/ruby/misc_spec.rb +++ b/tests/ruby/misc_spec.rb @@ -189,5 +189,30 @@ expect(processes.primary.count_query("DISCARD ALL")).to eq(10) end end + + context "transaction mode with transactions" do + let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 5, "transaction") } + it "Does not clear set statement state when declared in a transaction" do + 10.times do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.async_exec("SET SERVER ROLE to 'primary'") + conn.async_exec("BEGIN") + conn.async_exec("SET statement_timeout to 1000") + conn.async_exec("COMMIT") + conn.close + end + expect(processes.primary.count_query("DISCARD ALL")).to eq(0) + + 10.times do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.async_exec("SET SERVER ROLE to 'primary'") + conn.async_exec("BEGIN") + conn.async_exec("SET LOCAL statement_timeout to 1000") + conn.async_exec("COMMIT") + conn.close + end + expect(processes.primary.count_query("DISCARD ALL")).to eq(0) + end + end end end