diff --git a/Cargo.lock b/Cargo.lock index 9b2634e34..0721d6016 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2181,6 +2181,40 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "deadpool" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ed5957ff93768adf7a65ab167a17835c3d2c3c50d084fe305174c112f468e2f" +dependencies = [ + "deadpool-runtime", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-postgres" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d697d376cbfa018c23eb4caab1fd1883dd9c906a8c034e8d9a3cb06a7e0bef9" +dependencies = [ + "async-trait", + "deadpool", + "getrandom 0.2.15", + "tokio", + "tokio-postgres", + "tracing", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +dependencies = [ + "tokio", +] + [[package]] name = "debugid" version = "0.8.0" @@ -7969,6 +8003,7 @@ version = "3.2.0-pre0" dependencies = [ "anyhow", "chrono", + "deadpool-postgres", "native-tls", "postgres-native-tls", "spin-core", diff --git a/crates/factor-outbound-pg/Cargo.toml b/crates/factor-outbound-pg/Cargo.toml index 47899aee9..3340ad80b 100644 --- a/crates/factor-outbound-pg/Cargo.toml +++ b/crates/factor-outbound-pg/Cargo.toml @@ -7,6 +7,7 @@ edition = { workspace = true } [dependencies] anyhow = { workspace = true } chrono = "0.4" +deadpool-postgres = { version = "0.14", features = ["rt_tokio_1"] } native-tls = "0.2" postgres-native-tls = "0.5" spin-core = { path = "../core" } diff --git a/crates/factor-outbound-pg/src/client.rs b/crates/factor-outbound-pg/src/client.rs index 3f0a890a9..53bdd8c55 100644 --- a/crates/factor-outbound-pg/src/client.rs +++ b/crates/factor-outbound-pg/src/client.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use native_tls::TlsConnector; 
use postgres_native_tls::MakeTlsConnector; use spin_world::async_trait; @@ -6,15 +6,78 @@ use spin_world::spin::postgres::postgres::{ self as v3, Column, DbDataType, DbValue, ParameterValue, RowSet, }; use tokio_postgres::types::Type; -use tokio_postgres::{config::SslMode, types::ToSql, Row}; -use tokio_postgres::{Client as TokioClient, NoTls, Socket}; +use tokio_postgres::{config::SslMode, types::ToSql, NoTls, Row}; + +const CONNECTION_POOL_SIZE: usize = 64; #[async_trait] -pub trait Client { - async fn build_client(address: &str) -> Result - where - Self: Sized; +pub trait ClientFactory: Send + Sync { + type Client: Client + Send + Sync + 'static; + fn new() -> Self; + async fn build_client(&mut self, address: &str) -> Result; +} + +pub struct PooledTokioClientFactory { + pools: std::collections::HashMap, +} + +#[async_trait] +impl ClientFactory for PooledTokioClientFactory { + type Client = deadpool_postgres::Object; + fn new() -> Self { + Self { + pools: Default::default(), + } + } + async fn build_client(&mut self, address: &str) -> Result { + let pool_entry = self.pools.entry(address.to_owned()); + let pool = match pool_entry { + std::collections::hash_map::Entry::Occupied(entry) => entry.into_mut(), + std::collections::hash_map::Entry::Vacant(entry) => { + let pool = create_connection_pool(address) + .context("establishing PostgreSQL connection pool")?; + entry.insert(pool) + } + }; + + Ok(pool.get().await?) + } +} + +fn create_connection_pool(address: &str) -> Result { + let config = address + .parse::() + .context("parsing Postgres connection string")?; + + tracing::debug!("Build new connection: {}", address); + + // TODO: This is slower but safer. Is it the right tradeoff? 
+ // https://docs.rs/deadpool-postgres/latest/deadpool_postgres/enum.RecyclingMethod.html + let mgr_config = deadpool_postgres::ManagerConfig { + recycling_method: deadpool_postgres::RecyclingMethod::Clean, + }; + + let mgr = if config.get_ssl_mode() == SslMode::Disable { + deadpool_postgres::Manager::from_config(config, NoTls, mgr_config) + } else { + let builder = TlsConnector::builder(); + let connector = MakeTlsConnector::new(builder.build()?); + deadpool_postgres::Manager::from_config(config, connector, mgr_config) + }; + // TODO: what is our max size heuristic? Should this be passed in so that different + // hosts can manage it according to their needs? Will a plain number suffice for + // sophisticated hosts anyway? + let pool = deadpool_postgres::Pool::builder(mgr) + .max_size(CONNECTION_POOL_SIZE) + .build() + .context("building Postgres connection pool")?; + + Ok(pool) +} + +#[async_trait] +pub trait Client { async fn execute( &self, statement: String, @@ -29,28 +92,7 @@ pub trait Client { } #[async_trait] -impl Client for TokioClient { - async fn build_client(address: &str) -> Result - where - Self: Sized, - { - let config = address.parse::()?; - - tracing::debug!("Build new connection: {}", address); - - if config.get_ssl_mode() == SslMode::Disable { - let (client, connection) = config.connect(NoTls).await?; - spawn_connection(connection); - Ok(client) - } else { - let builder = TlsConnector::builder(); - let connector = MakeTlsConnector::new(builder.build()?); - let (client, connection) = config.connect(connector).await?; - spawn_connection(connection); - Ok(client) - } - } - +impl Client for deadpool_postgres::Object { async fn execute( &self, statement: String, @@ -67,7 +109,8 @@ impl Client for TokioClient { .map(|b| b.as_ref() as &(dyn ToSql + Sync)) .collect(); - self.execute(&statement, params_refs.as_slice()) + self.as_ref() + .execute(&statement, params_refs.as_slice()) .await .map_err(|e| v3::Error::QueryFailed(format!("{:?}", e))) } @@ -89,6 
+132,7 @@ impl Client for TokioClient { .collect(); let results = self + .as_ref() .query(&statement, params_refs.as_slice()) .await .map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))?; @@ -111,17 +155,6 @@ impl Client for TokioClient { } } -fn spawn_connection(connection: tokio_postgres::Connection) -where - T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static, -{ - tokio::spawn(async move { - if let Err(e) = connection.await { - tracing::error!("Postgres connection error: {}", e); - } - }); -} - fn to_sql_parameter(value: &ParameterValue) -> Result> { match value { ParameterValue::Boolean(v) => Ok(Box::new(*v)), diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs index 289934446..92047f9e8 100644 --- a/crates/factor-outbound-pg/src/host.rs +++ b/crates/factor-outbound-pg/src/host.rs @@ -9,17 +9,20 @@ use tracing::field::Empty; use tracing::instrument; use tracing::Level; -use crate::client::Client; +use crate::client::{Client, ClientFactory}; use crate::InstanceState; -impl InstanceState { +impl InstanceState { async fn open_connection( &mut self, address: &str, ) -> Result, v3::Error> { self.connections .push( - C::build_client(address) + self.client_factory + .write() + .await + .build_client(address) .await .map_err(|e| v3::Error::ConnectionFailed(format!("{e:?}")))?, ) @@ -30,7 +33,7 @@ impl InstanceState { async fn get_client( &mut self, connection: Resource, - ) -> Result<&C, v3::Error> { + ) -> Result<&CF::Client, v3::Error> { self.connections .get(connection.rep()) .ok_or_else(|| v3::Error::ConnectionFailed("no connection found".into())) @@ -71,8 +74,8 @@ fn v2_params_to_v3( params.into_iter().map(|p| p.try_into()).collect() } -impl spin_world::spin::postgres::postgres::HostConnection - for InstanceState +impl spin_world::spin::postgres::postgres::HostConnection + for InstanceState { #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), 
fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))] async fn open(&mut self, address: String) -> Result, v3::Error> { @@ -122,13 +125,13 @@ impl spin_world::spin::postgres::postgres::HostConnecti } } -impl v2_types::Host for InstanceState { +impl v2_types::Host for InstanceState { fn convert_error(&mut self, error: v2::Error) -> Result { Ok(error) } } -impl v3::Host for InstanceState { +impl v3::Host for InstanceState { fn convert_error(&mut self, error: v3::Error) -> Result { Ok(error) } @@ -152,9 +155,9 @@ macro_rules! delegate { }}; } -impl v2::Host for InstanceState {} +impl v2::Host for InstanceState {} -impl v2::HostConnection for InstanceState { +impl v2::HostConnection for InstanceState { #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))] async fn open(&mut self, address: String) -> Result, v2::Error> { spin_factor_outbound_networking::record_address_fields(&address); @@ -206,7 +209,7 @@ impl v2::HostConnection for InstanceState { } } -impl v1::Host for InstanceState { +impl v1::Host for InstanceState { async fn execute( &mut self, address: String, diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs index 4ca366353..2becb330e 100644 --- a/crates/factor-outbound-pg/src/lib.rs +++ b/crates/factor-outbound-pg/src/lib.rs @@ -1,21 +1,23 @@ pub mod client; mod host; -use client::Client; +use std::sync::Arc; + +use client::ClientFactory; use spin_factor_outbound_networking::{OutboundAllowedHosts, OutboundNetworkingFactor}; use spin_factors::{ anyhow, ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder, }; -use tokio_postgres::Client as PgClient; +use tokio::sync::RwLock; -pub struct OutboundPgFactor { - _phantom: std::marker::PhantomData, +pub struct OutboundPgFactor { + 
_phantom: std::marker::PhantomData, } -impl Factor for OutboundPgFactor { +impl Factor for OutboundPgFactor { type RuntimeConfig = (); - type AppState = (); - type InstanceBuilder = InstanceState; + type AppState = Arc>; + type InstanceBuilder = InstanceState; fn init( &mut self, @@ -31,7 +33,7 @@ impl Factor for OutboundPgFactor { &self, _ctx: ConfigureAppContext, ) -> anyhow::Result { - Ok(()) + Ok(Arc::new(RwLock::new(CF::new()))) } fn prepare( @@ -43,6 +45,7 @@ impl Factor for OutboundPgFactor { .allowed_hosts(); Ok(InstanceState { allowed_hosts, + client_factory: ctx.app_state().clone(), connections: Default::default(), }) } @@ -62,9 +65,10 @@ impl OutboundPgFactor { } } -pub struct InstanceState { +pub struct InstanceState { allowed_hosts: OutboundAllowedHosts, - connections: spin_resource_table::Table, + client_factory: Arc>, + connections: spin_resource_table::Table, } -impl SelfInstanceBuilder for InstanceState {} +impl SelfInstanceBuilder for InstanceState {} diff --git a/crates/factor-outbound-pg/tests/factor_test.rs b/crates/factor-outbound-pg/tests/factor_test.rs index ae0ab2876..7ba1f6e76 100644 --- a/crates/factor-outbound-pg/tests/factor_test.rs +++ b/crates/factor-outbound-pg/tests/factor_test.rs @@ -1,6 +1,7 @@ use anyhow::{bail, Result}; use spin_factor_outbound_networking::OutboundNetworkingFactor; use spin_factor_outbound_pg::client::Client; +use spin_factor_outbound_pg::client::ClientFactory; use spin_factor_outbound_pg::OutboundPgFactor; use spin_factor_variables::VariablesFactor; use spin_factors::{anyhow, RuntimeFactors}; @@ -15,14 +16,14 @@ use spin_world::spin::postgres::postgres::{ParameterValue, RowSet}; struct TestFactors { variables: VariablesFactor, networking: OutboundNetworkingFactor, - pg: OutboundPgFactor, + pg: OutboundPgFactor, } fn factors() -> TestFactors { TestFactors { variables: VariablesFactor::default(), networking: OutboundNetworkingFactor::new(), - pg: OutboundPgFactor::::new(), + pg: OutboundPgFactor::::new(), } } @@ 
-104,17 +105,22 @@ async fn exercise_query() -> anyhow::Result<()> { } // TODO: We can expand this mock to track calls and simulate return values +pub struct MockClientFactory {} pub struct MockClient {} #[async_trait] -impl Client for MockClient { - async fn build_client(_address: &str) -> anyhow::Result - where - Self: Sized, - { +impl ClientFactory for MockClientFactory { + type Client = MockClient; + fn new() -> Self { + Self {} + } + async fn build_client(&mut self, _address: &str) -> Result { Ok(MockClient {}) } +} +#[async_trait] +impl Client for MockClient { async fn execute( &self, _statement: String,