diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 77c5f187f..1aaf82020 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -197,6 +197,17 @@ impl Client { /// /// Panics if the number of parameters provided does not match the number expected. pub fn execute(&mut self, statement: &Statement, params: &[&dyn ToSql]) -> impls::Execute { + self.execute_iter(statement, params.iter().cloned()) + } + + /// Like [`execute`], but takes an iterator of parameters rather than a slice. + /// + /// [`execute`]: #method.execute + pub fn execute_iter<'a, I>(&mut self, statement: &Statement, params: I) -> impls::Execute + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { impls::Execute(self.0.execute(&statement.0, params)) } @@ -206,6 +217,17 @@ impl Client { /// /// Panics if the number of parameters provided does not match the number expected. pub fn query(&mut self, statement: &Statement, params: &[&dyn ToSql]) -> impls::Query { + self.query_iter(statement, params.iter().cloned()) + } + + /// Like [`query`], but takes an iterator of parameters rather than a slice. + /// + /// [`query`]: #method.query + pub fn query_iter<'a, I>(&mut self, statement: &Statement, params: I) -> impls::Query + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { impls::Query(self.0.query(&statement.0, params)) } @@ -214,10 +236,22 @@ impl Client { /// Portals only last for the duration of the transaction in which they are created - in particular, a portal /// created outside of a transaction is immediately destroyed. Portals can only be used on the connection that /// created them. + /// /// # Panics /// /// Panics if the number of parameters provided does not match the number expected. pub fn bind(&mut self, statement: &Statement, params: &[&dyn ToSql]) -> impls::Bind { + self.bind_iter(statement, params.iter().cloned()) + } + + /// Like [`bind`], but takes an iterator of parameters rather than a slice. + /// + /// [`bind`]: #method.bind + pub fn bind_iter<'a, I>(&mut self, statement: &Statement, params: I) -> impls::Bind + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { impls::Bind(self.0.bind(&statement.0, next_portal(), params)) } @@ -233,6 +267,10 @@ impl Client { /// /// The data in the provided stream is passed along to the server verbatim; it is the caller's responsibility to /// ensure it uses the proper format. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. pub fn copy_in( &mut self, statement: &Statement, @@ -245,12 +283,48 @@ impl Client { ::Buf: 'static + Send, // FIXME error type? S::Error: Into>, + { + self.copy_in_iter(statement, params.iter().cloned(), stream) + } + + /// Like [`copy_in`], except that it takes an iterator of parameters rather than a slice. + /// + /// [`copy_in`]: #method.copy_in + pub fn copy_in_iter<'a, I, S>( + &mut self, + statement: &Statement, + params: I, + stream: S, + ) -> impls::CopyIn + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + S: Stream, + S::Item: IntoBuf, + ::Buf: 'static + Send, + // FIXME error type? + S::Error: Into>, { impls::CopyIn(self.0.copy_in(&statement.0, params, stream)) } /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. pub fn copy_out(&mut self, statement: &Statement, params: &[&dyn ToSql]) -> impls::CopyOut { + self.copy_out_iter(statement, params.iter().cloned()) + } + + /// Like [`copy_out`], except that it takes an iterator of parameters rather than a slice. + /// + /// [`copy_out`]: #method.copy_out + pub fn copy_out_iter<'a, I>(&mut self, statement: &Statement, params: I) -> impls::CopyOut + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { impls::CopyOut(self.0.copy_out(&statement.0, params)) } diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs index 7fb070e90..46184bb9d 100644 --- a/tokio-postgres/src/proto/client.rs +++ b/tokio-postgres/src/proto/client.rs @@ -165,7 +165,11 @@ impl Client { PrepareFuture::new(self.clone(), pending, name) } - pub fn execute(&self, statement: &Statement, params: &[&dyn ToSql]) -> ExecuteFuture { + pub fn execute<'a, I>(&self, statement: &Statement, params: I) -> ExecuteFuture + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { let pending = PendingRequest( self.excecute_message(statement, params) .map(|m| (RequestMessages::Single(m), self.0.idle.guard())), @@ -173,7 +177,11 @@ impl Client { ExecuteFuture::new(self.clone(), pending, statement.clone()) } - pub fn query(&self, statement: &Statement, params: &[&dyn ToSql]) -> QueryStream { + pub fn query<'a, I>(&self, statement: &Statement, params: I) -> QueryStream + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { let pending = PendingRequest( self.excecute_message(statement, params) .map(|m| (RequestMessages::Single(m), self.0.idle.guard())), @@ -181,7 +189,11 @@ impl Client { QueryStream::new(self.clone(), pending, statement.clone()) } - pub fn bind(&self, statement: &Statement, name: String, params: &[&dyn ToSql]) -> BindFuture { + pub fn bind<'a, I>(&self, statement: &Statement, name: String, params: I) -> BindFuture + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { let mut buf = self.bind_message(statement, &name, params); if let Ok(ref mut buf) = buf { frontend::sync(buf); @@ -204,17 +216,14 @@ impl Client { QueryStream::new(self.clone(), pending, portal.clone()) } - pub fn copy_in( - &self, - statement: &Statement, - params: &[&dyn ToSql], - stream: S, - ) -> CopyInFuture + pub fn copy_in<'a, S, I>(&self, statement: &Statement, params: I, stream: S) -> CopyInFuture where S: Stream, S::Item: IntoBuf, ::Buf: 'static + Send, S::Error: Into>, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, { let (mut sender, receiver) = mpsc::channel(1); let pending = PendingRequest(self.excecute_message(statement, params).map(|data| { @@ -233,7 +242,11 @@ impl Client { CopyInFuture::new(self.clone(), pending, statement.clone(), stream, sender) } - pub fn copy_out(&self, statement: &Statement, params: &[&dyn ToSql]) -> CopyOutStream { + pub fn copy_out<'a, I>(&self, statement: &Statement, params: I) -> CopyOutStream + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { let pending = PendingRequest( self.excecute_message(statement, params) .map(|m| (RequestMessages::Single(m), self.0.idle.guard())), @@ -289,12 +302,18 @@ impl Client { }); } - fn bind_message( + fn bind_message<'a, I>( &self, statement: &Statement, name: &str, - params: &[&dyn ToSql], - ) -> Result, Error> { + params: I, + ) -> Result, Error> + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let params = params.into_iter(); + assert!( statement.params().len() == params.len(), "expected {} parameters but got {}", @@ -308,7 +327,7 @@ impl Client { name, statement.name(), Some(1), - params.iter().zip(statement.params()).enumerate(), + params.zip(statement.params()).enumerate(), |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) { Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), @@ -327,11 +346,15 @@ impl Client { } } - fn excecute_message( + fn excecute_message<'a, I>( &self, statement: &Statement, - params: &[&dyn ToSql], - ) -> Result { + params: I, + ) -> Result + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { let mut buf = self.bind_message(statement, "", params)?; frontend::execute("", 0, &mut buf).map_err(Error::parse)?; frontend::sync(&mut buf); diff --git a/tokio-postgres/src/proto/typeinfo.rs b/tokio-postgres/src/proto/typeinfo.rs index 02e35eaeb..15a657b6c 100644 --- a/tokio-postgres/src/proto/typeinfo.rs +++ b/tokio-postgres/src/proto/typeinfo.rs @@ -10,7 +10,7 @@ use crate::proto::query::QueryStream; use crate::proto::statement::Statement; use crate::proto::typeinfo_composite::TypeinfoCompositeFuture; use crate::proto::typeinfo_enum::TypeinfoEnumFuture; -use crate::types::{Kind, Oid, Type}; +use crate::types::{Kind, Oid, ToSql, Type}; const TYPEINFO_QUERY: &str = " SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid @@ -114,7 +114,10 @@ impl PollTypeinfo for Typeinfo { match state.client.typeinfo_query() { Some(statement) => transition!(QueryingTypeinfo { - future: state.client.query(&statement, &[&state.oid]).collect(), + future: state + .client + .query(&statement, [&state.oid as &dyn ToSql].iter().cloned()) + .collect(), oid: state.oid, client: state.client, }), @@ -149,7 +152,10 @@ impl PollTypeinfo for Typeinfo { }; let state = state.take(); - let future = state.client.query(&statement, &[&state.oid]).collect(); + let future = state + .client + .query(&statement, [&state.oid as &dyn ToSql].iter().cloned()) + .collect(); state.client.set_typeinfo_query(&statement); transition!(QueryingTypeinfo { future, @@ -164,7 +170,10 @@ impl PollTypeinfo for Typeinfo { let statement = try_ready!(state.future.poll()); let state = state.take(); - let future = state.client.query(&statement, &[&state.oid]).collect(); + let future = state + .client + .query(&statement, [&state.oid as &dyn ToSql].iter().cloned()) + .collect(); state.client.set_typeinfo_query(&statement); transition!(QueryingTypeinfo { future, diff --git a/tokio-postgres/src/proto/typeinfo_composite.rs b/tokio-postgres/src/proto/typeinfo_composite.rs index f424fabcb..31398389d 100644 --- a/tokio-postgres/src/proto/typeinfo_composite.rs +++ b/tokio-postgres/src/proto/typeinfo_composite.rs @@ -11,7 +11,7 @@ use crate::proto::prepare::PrepareFuture; use crate::proto::query::QueryStream; use crate::proto::statement::Statement; use crate::proto::typeinfo::TypeinfoFuture; -use crate::types::{Field, Oid}; +use crate::types::{Field, Oid, ToSql}; const TYPEINFO_COMPOSITE_QUERY: &str = " SELECT attname, atttypid @@ -59,7 +59,10 @@ impl PollTypeinfoComposite for TypeinfoComposite { match state.client.typeinfo_composite_query() { Some(statement) => transition!(QueryingCompositeFields { - future: state.client.query(&statement, &[&state.oid]).collect(), + future: state + .client + .query(&statement, [&state.oid as &dyn ToSql].iter().cloned()) + .collect(), client: state.client, }), None => transition!(PreparingTypeinfoComposite { @@ -82,7 +85,10 @@ impl PollTypeinfoComposite for TypeinfoComposite { state.client.set_typeinfo_composite_query(&statement); transition!(QueryingCompositeFields { - future: state.client.query(&statement, &[&state.oid]).collect(), + future: state + .client + .query(&statement, [&state.oid as &dyn ToSql].iter().cloned()) + .collect(), client: state.client, }) } diff --git a/tokio-postgres/src/proto/typeinfo_enum.rs b/tokio-postgres/src/proto/typeinfo_enum.rs index d264d3ab7..dbc391070 100644 --- a/tokio-postgres/src/proto/typeinfo_enum.rs +++ b/tokio-postgres/src/proto/typeinfo_enum.rs @@ -8,7 +8,7 @@ use crate::proto::client::Client; use crate::proto::prepare::PrepareFuture; use crate::proto::query::QueryStream; use crate::proto::statement::Statement; -use crate::types::Oid; +use crate::types::{Oid, ToSql}; const TYPEINFO_ENUM_QUERY: &str = " SELECT enumlabel @@ -58,7 +58,10 @@ impl PollTypeinfoEnum for TypeinfoEnum { match state.client.typeinfo_enum_query() { Some(statement) => transition!(QueryingEnumVariants { - future: state.client.query(&statement, &[&state.oid]).collect(), + future: state + .client + .query(&statement, [&state.oid as &dyn ToSql].iter().cloned()) + .collect(), client: state.client, }), None => transition!(PreparingTypeinfoEnum { @@ -98,7 +101,10 @@ impl PollTypeinfoEnum for TypeinfoEnum { state.client.set_typeinfo_enum_query(&statement); transition!(QueryingEnumVariants { - future: state.client.query(&statement, &[&state.oid]).collect(), + future: state + .client + .query(&statement, [&state.oid as &dyn ToSql].iter().cloned()) + .collect(), client: state.client, }) } @@ -111,7 +117,10 @@ impl PollTypeinfoEnum for TypeinfoEnum { state.client.set_typeinfo_enum_query(&statement); transition!(QueryingEnumVariants { - future: state.client.query(&statement, &[&state.oid]).collect(), + future: state + .client + .query(&statement, [&state.oid as &dyn ToSql].iter().cloned()) + .collect(), client: state.client, }) }