1- use anyhow:: { anyhow, Result } ;
1+ use anyhow:: { anyhow, Context , Result } ;
22use native_tls:: TlsConnector ;
33use postgres_native_tls:: MakeTlsConnector ;
44use spin_world:: async_trait;
55use spin_world:: spin:: postgres:: postgres:: {
66 self as v3, Column , DbDataType , DbValue , ParameterValue , RowSet ,
77} ;
88use tokio_postgres:: types:: Type ;
9- use tokio_postgres:: { config:: SslMode , types:: ToSql , Row } ;
10- use tokio_postgres:: { Client as TokioClient , NoTls , Socket } ;
9+ use tokio_postgres:: { config:: SslMode , types:: ToSql , NoTls , Row } ;
10+
11+ const CONNECTION_POOL_SIZE : usize = 64 ;
1112
1213#[ async_trait]
13- pub trait Client {
14- async fn build_client ( address : & str ) -> Result < Self >
15- where
16- Self : Sized ;
14+ pub trait ClientFactory : Send + Sync {
15+ type Client : Client + Send + Sync + ' static ;
16+ fn new ( ) -> Self ;
17+ async fn build_client ( & mut self , address : & str ) -> Result < Self :: Client > ;
18+ }
19+
20+ pub struct PooledTokioClientFactory {
21+ pools : std:: collections:: HashMap < String , deadpool_postgres:: Pool > ,
22+ }
23+
24+ #[ async_trait]
25+ impl ClientFactory for PooledTokioClientFactory {
26+ type Client = deadpool_postgres:: Object ;
27+ fn new ( ) -> Self {
28+ Self {
29+ pools : Default :: default ( ) ,
30+ }
31+ }
32+ async fn build_client ( & mut self , address : & str ) -> Result < Self :: Client > {
33+ let pool_entry = self . pools . entry ( address. to_owned ( ) ) ;
34+ let pool = match pool_entry {
35+ std:: collections:: hash_map:: Entry :: Occupied ( entry) => entry. into_mut ( ) ,
36+ std:: collections:: hash_map:: Entry :: Vacant ( entry) => {
37+ let pool = create_connection_pool ( address)
38+ . context ( "establishing PostgreSQL connection pool" ) ?;
39+ entry. insert ( pool)
40+ }
41+ } ;
42+
43+ Ok ( pool. get ( ) . await ?)
44+ }
45+ }
46+
47+ fn create_connection_pool ( address : & str ) -> Result < deadpool_postgres:: Pool > {
48+ let config = address
49+ . parse :: < tokio_postgres:: Config > ( )
50+ . context ( "parsing Postgres connection string" ) ?;
51+
52+ tracing:: debug!( "Build new connection: {}" , address) ;
53+
54+ // TODO: This is slower but safer. Is it the right tradeoff?
55+ // https://docs.rs/deadpool-postgres/latest/deadpool_postgres/enum.RecyclingMethod.html
56+ let mgr_config = deadpool_postgres:: ManagerConfig {
57+ recycling_method : deadpool_postgres:: RecyclingMethod :: Clean ,
58+ } ;
59+
60+ let mgr = if config. get_ssl_mode ( ) == SslMode :: Disable {
61+ deadpool_postgres:: Manager :: from_config ( config, NoTls , mgr_config)
62+ } else {
63+ let builder = TlsConnector :: builder ( ) ;
64+ let connector = MakeTlsConnector :: new ( builder. build ( ) ?) ;
65+ deadpool_postgres:: Manager :: from_config ( config, connector, mgr_config)
66+ } ;
1767
68+ // TODO: what is our max size heuristic? Should this be passed in soe that different
69+ // hosts can manage it according to their needs? Will a plain number suffice for
70+ // sophisticated hosts anyway?
71+ let pool = deadpool_postgres:: Pool :: builder ( mgr)
72+ . max_size ( CONNECTION_POOL_SIZE )
73+ . build ( )
74+ . context ( "building Postgres connection pool" ) ?;
75+
76+ Ok ( pool)
77+ }
78+
79+ #[ async_trait]
80+ pub trait Client {
1881 async fn execute (
1982 & self ,
2083 statement : String ,
@@ -29,28 +92,7 @@ pub trait Client {
2992}
3093
3194#[ async_trait]
32- impl Client for TokioClient {
33- async fn build_client ( address : & str ) -> Result < Self >
34- where
35- Self : Sized ,
36- {
37- let config = address. parse :: < tokio_postgres:: Config > ( ) ?;
38-
39- tracing:: debug!( "Build new connection: {}" , address) ;
40-
41- if config. get_ssl_mode ( ) == SslMode :: Disable {
42- let ( client, connection) = config. connect ( NoTls ) . await ?;
43- spawn_connection ( connection) ;
44- Ok ( client)
45- } else {
46- let builder = TlsConnector :: builder ( ) ;
47- let connector = MakeTlsConnector :: new ( builder. build ( ) ?) ;
48- let ( client, connection) = config. connect ( connector) . await ?;
49- spawn_connection ( connection) ;
50- Ok ( client)
51- }
52- }
53-
95+ impl Client for deadpool_postgres:: Object {
5496 async fn execute (
5597 & self ,
5698 statement : String ,
@@ -67,7 +109,8 @@ impl Client for TokioClient {
67109 . map ( |b| b. as_ref ( ) as & ( dyn ToSql + Sync ) )
68110 . collect ( ) ;
69111
70- self . execute ( & statement, params_refs. as_slice ( ) )
112+ self . as_ref ( )
113+ . execute ( & statement, params_refs. as_slice ( ) )
71114 . await
72115 . map_err ( |e| v3:: Error :: QueryFailed ( format ! ( "{e:?}" ) ) )
73116 }
@@ -89,6 +132,7 @@ impl Client for TokioClient {
89132 . collect ( ) ;
90133
91134 let results = self
135+ . as_ref ( )
92136 . query ( & statement, params_refs. as_slice ( ) )
93137 . await
94138 . map_err ( |e| v3:: Error :: QueryFailed ( format ! ( "{e:?}" ) ) ) ?;
@@ -111,17 +155,6 @@ impl Client for TokioClient {
111155 }
112156}
113157
114- fn spawn_connection < T > ( connection : tokio_postgres:: Connection < Socket , T > )
115- where
116- T : tokio_postgres:: tls:: TlsStream + std:: marker:: Unpin + std:: marker:: Send + ' static ,
117- {
118- tokio:: spawn ( async move {
119- if let Err ( e) = connection. await {
120- tracing:: error!( "Postgres connection error: {}" , e) ;
121- }
122- } ) ;
123- }
124-
125158fn to_sql_parameter ( value : & ParameterValue ) -> Result < Box < dyn ToSql + Send + Sync > > {
126159 match value {
127160 ParameterValue :: Boolean ( v) => Ok ( Box :: new ( * v) ) ,
0 commit comments