1
- use std:: collections:: HashMap ;
2
- use std:: sync:: Arc ;
1
+ use crate :: ConnectionData :: { ConnectionPool , ConnectionString } ;
3
2
use actix_session:: storage:: { LoadError , SaveError , SessionKey , SessionStore , UpdateError } ;
4
3
use chrono:: Utc ;
5
- use sqlx:: { Pool , Postgres , Row } ;
6
- use sqlx:: postgres:: PgPoolOptions ;
7
4
use rand:: { distributions:: Alphanumeric , rngs:: OsRng , Rng as _} ;
5
+ use serde_json:: { self , Value } ;
6
+ use sqlx:: postgres:: PgPoolOptions ;
7
+ use sqlx:: { Pool , Postgres , Row } ;
8
+ use std:: collections:: HashMap ;
9
+ use std:: sync:: Arc ;
8
10
use time:: Duration ;
9
- use serde_json;
10
- use crate :: ConnectionData :: { ConnectionPool , ConnectionString } ;
11
11
12
12
/// Use Postgres via Sqlx as session storage backend.
13
13
///
@@ -107,25 +107,27 @@ impl SqlxPostgresqlSessionStore {
107
107
pub fn builder < S : Into < String > > ( connection_string : S ) -> SqlxPostgresqlSessionStoreBuilder {
108
108
SqlxPostgresqlSessionStoreBuilder {
109
109
connection_data : ConnectionString ( connection_string. into ( ) ) ,
110
- configuration : CacheConfiguration :: default ( )
110
+ configuration : CacheConfiguration :: default ( ) ,
111
111
}
112
112
}
113
113
114
- pub async fn new < S : Into < String > > ( connection_string : S ) -> Result < SqlxPostgresqlSessionStore , anyhow:: Error > {
114
+ pub async fn new < S : Into < String > > (
115
+ connection_string : S ,
116
+ ) -> Result < SqlxPostgresqlSessionStore , anyhow:: Error > {
115
117
Self :: builder ( connection_string) . build ( ) . await
116
118
}
117
119
118
120
pub async fn from_pool ( pool : Pool < Postgres > ) -> SqlxPostgresqlSessionStoreBuilder {
119
121
SqlxPostgresqlSessionStoreBuilder {
120
122
connection_data : ConnectionPool ( pool. clone ( ) ) ,
121
- configuration : CacheConfiguration :: default ( )
123
+ configuration : CacheConfiguration :: default ( ) ,
122
124
}
123
125
}
124
126
}
125
127
126
128
pub enum ConnectionData {
127
129
ConnectionString ( String ) ,
128
- ConnectionPool ( Pool < Postgres > )
130
+ ConnectionPool ( Pool < Postgres > ) ,
129
131
}
130
132
131
133
#[ must_use]
@@ -137,22 +139,19 @@ pub struct SqlxPostgresqlSessionStoreBuilder {
137
139
impl SqlxPostgresqlSessionStoreBuilder {
138
140
pub async fn build ( self ) -> Result < SqlxPostgresqlSessionStore , anyhow:: Error > {
139
141
match self . connection_data {
140
- ConnectionString ( conn_string) => {
141
- PgPoolOptions :: new ( )
142
- . max_connections ( 1 )
143
- . connect ( conn_string. as_str ( ) )
144
- . await
145
- . map_err ( Into :: into)
146
- . map ( |pool| {
147
- SqlxPostgresqlSessionStore {
148
- client_pool : pool,
149
- configuration : self . configuration
150
- }
151
- } )
152
- } ,
142
+ ConnectionString ( conn_string) => PgPoolOptions :: new ( )
143
+ . max_connections ( 1 )
144
+ . connect ( conn_string. as_str ( ) )
145
+ . await
146
+ . map_err ( Into :: into)
147
+ . map ( |pool| SqlxPostgresqlSessionStore {
148
+ client_pool : pool,
149
+ configuration : self . configuration ,
150
+ } ) ,
153
151
ConnectionPool ( pool) => Ok ( SqlxPostgresqlSessionStore {
154
- client_pool : pool, configuration : self . configuration
155
- } )
152
+ client_pool : pool,
153
+ configuration : self . configuration ,
154
+ } ) ,
156
155
}
157
156
}
158
157
}
@@ -162,61 +161,79 @@ pub(crate) type SessionState = HashMap<String, String>;
162
161
impl SessionStore for SqlxPostgresqlSessionStore {
163
162
async fn load ( & self , session_key : & SessionKey ) -> Result < Option < SessionState > , LoadError > {
164
163
let key = ( self . configuration . cache_keygen ) ( session_key. as_ref ( ) ) ;
165
- let row = sqlx:: query ( "SELECT session_state FROM sessions WHERE key = $1 AND expires > NOW()" )
166
- . bind ( key)
167
- . fetch_optional ( & self . client_pool )
168
- . await
169
- . map_err ( Into :: into)
170
- . map_err ( LoadError :: Other ) ?;
164
+ let row =
165
+ sqlx:: query ( "SELECT session_state FROM sessions WHERE key = $1 AND expires > NOW()" )
166
+ . bind ( key)
167
+ . fetch_optional ( & self . client_pool )
168
+ . await
169
+ . map_err ( Into :: into)
170
+ . map_err ( LoadError :: Other ) ?;
171
171
match row {
172
172
None => Ok ( None ) ,
173
173
Some ( r) => {
174
- let data: String = r. get ( "session_state" ) ;
175
- let state: SessionState = serde_json:: from_str ( & data) . map_err ( Into :: into) . map_err ( LoadError :: Deserialization ) ?;
174
+ let data: Value = r. get ( "session_state" ) ;
175
+ let state: SessionState = serde_json:: from_value ( data)
176
+ . map_err ( Into :: into)
177
+ . map_err ( LoadError :: Deserialization ) ?;
176
178
Ok ( Some ( state) )
177
179
}
178
180
}
179
181
}
180
182
181
- async fn save ( & self , session_state : SessionState , ttl : & Duration ) -> Result < SessionKey , SaveError > {
182
- let body = serde_json:: to_string ( & session_state)
183
+ async fn save (
184
+ & self ,
185
+ session_state : SessionState ,
186
+ ttl : & Duration ,
187
+ ) -> Result < SessionKey , SaveError > {
188
+ let body = serde_json:: to_value ( & session_state)
183
189
. map_err ( Into :: into)
184
190
. map_err ( SaveError :: Serialization ) ?;
185
191
let key = generate_session_key ( ) ;
186
192
let cache_key = ( self . configuration . cache_keygen ) ( key. as_ref ( ) ) ;
187
193
let expires = Utc :: now ( ) + chrono:: Duration :: seconds ( ttl. whole_seconds ( ) as i64 ) ;
188
194
sqlx:: query ( "INSERT INTO sessions(key, session_state, expires) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING" )
189
195
. bind ( cache_key)
190
- . bind ( body)
191
- . bind ( expires)
196
+ . bind ( body)
197
+ . bind ( expires)
192
198
. execute ( & self . client_pool )
193
199
. await
194
200
. map_err ( Into :: into)
195
201
. map_err ( SaveError :: Other ) ?;
196
202
Ok ( key)
197
203
}
198
204
199
- async fn update ( & self , session_key : SessionKey , session_state : SessionState , ttl : & Duration ) -> Result < SessionKey , UpdateError > {
200
- let body = serde_json:: to_string ( & session_state) . map_err ( Into :: into) . map_err ( UpdateError :: Serialization ) ?;
205
+ async fn update (
206
+ & self ,
207
+ session_key : SessionKey ,
208
+ session_state : SessionState ,
209
+ ttl : & Duration ,
210
+ ) -> Result < SessionKey , UpdateError > {
211
+ let body = serde_json:: to_value ( & session_state)
212
+ . map_err ( Into :: into)
213
+ . map_err ( UpdateError :: Serialization ) ?;
201
214
let cache_key = ( self . configuration . cache_keygen ) ( session_key. as_ref ( ) ) ;
202
215
let new_expires = Utc :: now ( ) + chrono:: Duration :: seconds ( ttl. whole_seconds ( ) ) ;
203
216
sqlx:: query ( "UPDATE sessions SET session_state = $1, expires = $2 WHERE key = $3" )
204
- . bind ( body)
205
- . bind ( new_expires)
206
- . bind ( cache_key)
217
+ . bind ( body)
218
+ . bind ( new_expires)
219
+ . bind ( cache_key)
207
220
. execute ( & self . client_pool )
208
221
. await
209
222
. map_err ( Into :: into)
210
223
. map_err ( UpdateError :: Other ) ?;
211
224
Ok ( session_key)
212
225
}
213
226
214
- async fn update_ttl ( & self , session_key : & SessionKey , ttl : & Duration ) -> Result < ( ) , anyhow:: Error > {
227
+ async fn update_ttl (
228
+ & self ,
229
+ session_key : & SessionKey ,
230
+ ttl : & Duration ,
231
+ ) -> Result < ( ) , anyhow:: Error > {
215
232
let new_expires = Utc :: now ( ) + chrono:: Duration :: seconds ( ttl. whole_seconds ( ) as i64 ) ;
216
233
let key = ( self . configuration . cache_keygen ) ( session_key. as_ref ( ) ) ;
217
234
sqlx:: query ( "UPDATE sessions SET expires = $1 WHERE key = $2" )
218
235
. bind ( new_expires)
219
- . bind ( key)
236
+ . bind ( key)
220
237
. execute ( & self . client_pool )
221
238
. await
222
239
. map_err ( Into :: into)
0 commit comments