@@ -3,7 +3,7 @@ use std::{
33 time:: SystemTime ,
44} ;
55
6- use super :: extended:: PreparedStatement ;
6+ use super :: { extended:: PreparedStatement , pg_auth_service :: AuthenticationStatus } ;
77use crate :: {
88 compile:: {
99 convert_statement_to_cube_query,
@@ -24,8 +24,11 @@ use crate::{
2424use futures:: { pin_mut, FutureExt , StreamExt } ;
2525use log:: { debug, error, trace} ;
2626use pg_srv:: {
27- buffer, protocol,
28- protocol:: { ErrorCode , ErrorResponse , Format , InitialMessage , PortalCompletion } ,
27+ buffer,
28+ protocol:: {
29+ self , AuthenticationRequest , ErrorCode , ErrorResponse , Format , InitialMessage ,
30+ PortalCompletion ,
31+ } ,
2932 PgType , PgTypeId , ProtocolError ,
3033} ;
3134use sqlparser:: ast:: { self , CloseCursor , FetchDirection , Query , SetExpr , Statement , Value } ;
@@ -46,10 +49,9 @@ pub struct AsyncPostgresShim {
4649 logger : Arc < dyn ContextLogger > ,
4750}
4851
49- #[ derive( PartialEq , Eq ) ]
5052pub enum StartupState {
5153 // Initial parameters which client sends in the first message, we use it later in auth method
52- Success ( HashMap < String , String > ) ,
54+ Success ( HashMap < String , String > , AuthenticationRequest ) ,
5355 SslRequested ,
5456 Denied ,
5557 CancelRequest ,
@@ -313,25 +315,23 @@ impl AsyncPostgresShim {
313315 }
314316
315317 pub async fn run ( & mut self ) -> Result < ( ) , ConnectionError > {
316- let initial_parameters = match self . process_initial_message ( ) . await ? {
317- StartupState :: Success ( parameters) => parameters,
318+ let ( initial_parameters, auth_method ) = match self . process_initial_message ( ) . await ? {
319+ StartupState :: Success ( parameters, auth_method ) => ( parameters, auth_method ) ,
318320 StartupState :: SslRequested => match self . process_initial_message ( ) . await ? {
319- StartupState :: Success ( parameters) => parameters,
321+ StartupState :: Success ( parameters, auth_method ) => ( parameters, auth_method ) ,
320322 _ => return Ok ( ( ) ) ,
321323 } ,
322324 StartupState :: Denied | StartupState :: CancelRequest => return Ok ( ( ) ) ,
323325 } ;
324326
325- match buffer:: read_message ( & mut self . socket ) . await ? {
326- protocol:: FrontendMessage :: PasswordMessage ( password_message) => {
327- if !self
328- . authenticate ( password_message, initial_parameters)
329- . await ?
330- {
331- return Ok ( ( ) ) ;
332- }
333- }
334- _ => return Ok ( ( ) ) ,
327+ let message_tag_parser = self . session . server . pg_auth . get_pg_message_tag_parser ( ) ;
328+ let auth_secret =
329+ buffer:: read_message ( & mut self . socket , Arc :: clone ( & message_tag_parser) ) . await ?;
330+ if !self
331+ . authenticate ( auth_method, auth_secret, initial_parameters)
332+ . await ?
333+ {
334+ return Ok ( ( ) ) ;
335335 }
336336
337337 self . ready ( ) . await ?;
@@ -351,7 +351,7 @@ impl AsyncPostgresShim {
351351 true = async { semifast_shutdownable && { semifast_shutdown_interruptor. cancelled( ) . await ; true } } => {
352352 return Self :: flush_and_write_admin_shutdown_fatal_message( self ) . await ;
353353 }
354- message_result = buffer:: read_message( & mut self . socket) => message_result?
354+ message_result = buffer:: read_message( & mut self . socket, Arc :: clone ( & message_tag_parser ) ) => message_result?
355355 } ;
356356
357357 let result = match message {
@@ -716,73 +716,62 @@ impl AsyncPostgresShim {
716716 return Ok ( StartupState :: Denied ) ;
717717 }
718718
719- self . write ( protocol:: Authentication :: new (
720- protocol:: AuthenticationRequest :: CleartextPassword ,
721- ) )
722- . await ?;
719+ let auth_method = self . session . server . pg_auth . get_auth_method ( & parameters) ;
720+ self . write ( protocol:: Authentication :: new ( auth_method. clone ( ) ) )
721+ . await ?;
723722
724- Ok ( StartupState :: Success ( parameters) )
723+ Ok ( StartupState :: Success ( parameters, auth_method ) )
725724 }
726725
727726 pub async fn authenticate (
728727 & mut self ,
729- password_message : protocol:: PasswordMessage ,
728+ auth_request : AuthenticationRequest ,
729+ auth_secret : protocol:: FrontendMessage ,
730730 parameters : HashMap < String , String > ,
731731 ) -> Result < bool , ConnectionError > {
732- let user = parameters . get ( "user" ) . unwrap ( ) . clone ( ) ;
733- let authenticate_response = self
732+ let auth_service = self . session . server . auth . clone ( ) ;
733+ let auth_status = self
734734 . session
735735 . server
736- . auth
737- . authenticate ( Some ( user . clone ( ) ) , Some ( password_message . password . clone ( ) ) )
736+ . pg_auth
737+ . authenticate ( auth_service , auth_request , auth_secret , & parameters )
738738 . await ;
739+ let result = match auth_status {
740+ AuthenticationStatus :: UnexpectedFrontendMessage => Err ( (
741+ "invalid authorization specification" . to_string ( ) ,
742+ protocol:: ErrorCode :: InvalidAuthorizationSpecification ,
743+ ) ) ,
744+ AuthenticationStatus :: Failed ( err) => Err ( ( err, protocol:: ErrorCode :: InvalidPassword ) ) ,
745+ AuthenticationStatus :: Success ( user, auth_context) => Ok ( ( user, auth_context) ) ,
746+ } ;
739747
740- let mut auth_context: Option < AuthContextRef > = None ;
748+ match result {
749+ Err ( ( message, code) ) => {
750+ let error_response = protocol:: ErrorResponse :: fatal ( code, message) ;
751+ buffer:: write_message (
752+ & mut self . partial_write_buf ,
753+ & mut self . socket ,
754+ error_response,
755+ )
756+ . await ?;
741757
742- let auth_success = match authenticate_response {
743- Ok ( authenticate_response) => {
744- auth_context = Some ( authenticate_response. context ) ;
745- if !authenticate_response. skip_password_check {
746- match authenticate_response. password {
747- None => false ,
748- Some ( password) => password == password_message. password ,
749- }
750- } else {
751- true
752- }
758+ Ok ( false )
753759 }
754- _ => false ,
755- } ;
756-
757- if !auth_success {
758- let error_response = protocol:: ErrorResponse :: fatal (
759- protocol:: ErrorCode :: InvalidPassword ,
760- format ! ( "password authentication failed for user \" {}\" " , & user) ,
761- ) ;
762- buffer:: write_message (
763- & mut self . partial_write_buf ,
764- & mut self . socket ,
765- error_response,
766- )
767- . await ?;
760+ Ok ( ( user, auth_context) ) => {
761+ let database = parameters
762+ . get ( "database" )
763+ . map ( |v| v. clone ( ) )
764+ . unwrap_or ( "db" . to_string ( ) ) ;
765+ self . session . state . set_database ( Some ( database) ) ;
766+ self . session . state . set_user ( Some ( user) ) ;
767+ self . session . state . set_auth_context ( Some ( auth_context) ) ;
768+
769+ self . write ( protocol:: Authentication :: new ( AuthenticationRequest :: Ok ) )
770+ . await ?;
768771
769- return Ok ( false ) ;
772+ Ok ( true )
773+ }
770774 }
771-
772- let database = parameters
773- . get ( "database" )
774- . map ( |v| v. clone ( ) )
775- . unwrap_or ( "db" . to_string ( ) ) ;
776- self . session . state . set_database ( Some ( database) ) ;
777- self . session . state . set_user ( Some ( user) ) ;
778- self . session . state . set_auth_context ( auth_context) ;
779-
780- self . write ( protocol:: Authentication :: new (
781- protocol:: AuthenticationRequest :: Ok ,
782- ) )
783- . await ?;
784-
785- Ok ( true )
786775 }
787776
788777 pub async fn ready ( & mut self ) -> Result < ( ) , ConnectionError > {
0 commit comments