From 40b2942edab70c8eabdb3c7d96b7d9789097e760 Mon Sep 17 00:00:00 2001 From: Andrzej Rybczak Date: Thu, 29 Feb 2024 18:56:48 +0100 Subject: [PATCH] Store pid of the backend when connecting to Postgres --- src/Database/PostgreSQL/PQTypes/Class.hs | 4 ++ .../PostgreSQL/PQTypes/Internal/Connection.hs | 43 +++++++++++++++---- .../PostgreSQL/PQTypes/Internal/Exception.hs | 7 ++- .../PostgreSQL/PQTypes/Internal/Monad.hs | 13 +++--- .../PQTypes/Internal/QueryResult.hs | 23 ++++++++-- .../PostgreSQL/PQTypes/Internal/State.hs | 30 +++++++------ src/Database/PostgreSQL/PQTypes/Utils.hs | 23 +++++----- 7 files changed, 102 insertions(+), 41 deletions(-) diff --git a/src/Database/PostgreSQL/PQTypes/Class.hs b/src/Database/PostgreSQL/PQTypes/Class.hs index 008170d..47c9249 100644 --- a/src/Database/PostgreSQL/PQTypes/Class.hs +++ b/src/Database/PostgreSQL/PQTypes/Class.hs @@ -32,6 +32,9 @@ class (Applicative m, Monad m) => MonadDB m where -- 'getLastQuery'. withFrozenLastQuery :: m a -> m a + -- | Get ID of the server process attached to the current session. + getBackendPid :: m Int + -- | Get current connection statistics. getConnectionStats :: HasCallStack => m ConnectionStats @@ -89,6 +92,7 @@ instance runPreparedQuery name = withFrozenCallStack $ lift . runPreparedQuery name getLastQuery = lift getLastQuery withFrozenLastQuery m = controlT $ \run -> withFrozenLastQuery (run m) + getBackendPid = lift getBackendPid getConnectionStats = withFrozenCallStack $ lift getConnectionStats getQueryResult = lift getQueryResult clearQueryResult = lift clearQueryResult diff --git a/src/Database/PostgreSQL/PQTypes/Internal/Connection.hs b/src/Database/PostgreSQL/PQTypes/Internal/Connection.hs index de26d92..5706c38 100644 --- a/src/Database/PostgreSQL/PQTypes/Internal/Connection.hs +++ b/src/Database/PostgreSQL/PQTypes/Internal/Connection.hs @@ -1,6 +1,9 @@ +{-# LANGUAGE TypeApplications #-} + module Database.PostgreSQL.PQTypes.Internal.Connection ( -- * Connection Connection (..) + , getBackendPidIO , ConnectionData (..) , withConnectionData , ConnectionStats (..) @@ -26,10 +29,11 @@ import Control.Exception qualified as E import Control.Monad import Control.Monad.Base import Control.Monad.Catch -import Data.Bifunctor import Data.ByteString.Char8 qualified as BS import Data.Foldable qualified as F +import Data.Functor.Identity import Data.IORef +import Data.Int import Data.Kind import Data.Pool import Data.Set qualified as S @@ -48,6 +52,7 @@ import Database.PostgreSQL.PQTypes.Internal.Composite import Database.PostgreSQL.PQTypes.Internal.Error import Database.PostgreSQL.PQTypes.Internal.Error.Code import Database.PostgreSQL.PQTypes.Internal.Exception +import Database.PostgreSQL.PQTypes.Internal.QueryResult import Database.PostgreSQL.PQTypes.Internal.Utils import Database.PostgreSQL.PQTypes.SQL.Class import Database.PostgreSQL.PQTypes.SQL.Raw @@ -114,6 +119,8 @@ initialStats = data ConnectionData = ConnectionData { cdPtr :: !(Ptr PGconn) -- ^ Pointer to connection object. + , cdBackendPid :: !Int + -- ^ Process ID of the server process attached to the current session. , cdStats :: !ConnectionStats -- ^ Statistics associated with the connection. , cdPreparedQueries :: !(IORef (S.Set T.Text)) @@ -125,6 +132,11 @@ newtype Connection = Connection { unConnection :: MVar (Maybe ConnectionData) } +getBackendPidIO :: Connection -> IO Int +getBackendPidIO conn = do + withConnectionData conn "connectionBackendPid" $ \cd -> do + pure (cd, cdBackendPid cd) + withConnectionData :: Connection -> String @@ -133,7 +145,9 @@ withConnectionData withConnectionData (Connection mvc) fname f = modifyMVar mvc $ \mc -> case mc of Nothing -> hpqTypesError $ fname ++ ": no connection" - Just cd -> first Just <$> f cd + Just cd0 -> do + (cd, r) <- f cd0 + cd `seq` pure (Just cd, r) -- | Database connection supplier. newtype ConnectionSourceM m = ConnectionSourceM @@ -215,10 +229,21 @@ connect ConnectionSettings {..} = mask $ \unmask -> do Just ConnectionData { cdPtr = connPtr + , cdBackendPid = 0 , cdStats = initialStats , cdPreparedQueries = preparedQueries } F.forM_ csRole $ \role -> runQueryIO conn $ "SET ROLE " <> role + + let selectPid = "SELECT pg_backend_pid()" :: RawSQL () + (_, res) <- runQueryIO conn selectPid + case F.toList $ mkQueryResult @(Identity Int32) selectPid 0 res of + [pid] -> withConnectionData conn fname $ \cd -> do + pure (cd {cdBackendPid = fromIntegral pid}, ()) + pids -> do + let err = HPQTypesError $ "unexpected backend pid: " ++ show pids + rethrowWithContext selectPid 0 $ toException err + pure conn where fname = "connect" @@ -317,6 +342,7 @@ runPreparedQueryIO conn (QueryName queryName) sql = do E.throwIO DBException { dbeQueryContext = sql + , dbeBackendPid = cdBackendPid , dbeError = HPQTypesError "runPreparedQueryIO: unnamed prepared query is not supported" , dbeCallStack = callStack } @@ -329,7 +355,7 @@ runPreparedQueryIO conn (QueryName queryName) sql = do -- succeeds, we need to reflect that fact in cdPreparedQueries since -- you can't prepare a query with the same name more than once. res <- c_PQparamPrepare cdPtr nullPtr param cname query - void . withForeignPtr res $ verifyResult sql cdPtr + void . withForeignPtr res $ verifyResult sql cdBackendPid cdPtr modifyIORef' cdPreparedQueries $ S.insert queryName (,) <$> (fromIntegral <$> c_PQparamCount param) @@ -353,7 +379,7 @@ runQueryImpl fname conn sql execSql = do -- runtime system is used) and react appropriately. queryRunner <- async . restore $ do (paramCount, res) <- execSql cd - affected <- withForeignPtr res $ verifyResult sql cdPtr + affected <- withForeignPtr res $ verifyResult sql cdBackendPid cdPtr stats' <- case affected of Left _ -> return @@ -370,8 +396,7 @@ runQueryImpl fname conn sql execSql = do , statsValues = statsValues cdStats + (rows * columns) , statsParams = statsParams cdStats + paramCount } - -- Force evaluation of modified stats to squash a space leak. - stats' `seq` return (cd {cdStats = stats'}, (either id id affected, res)) + return (cd {cdStats = stats'}, (either id id affected, res)) -- If we receive an exception while waiting for the execution to complete, -- we need to send a request to PostgreSQL for query cancellation and wait -- for the query runner thread to terminate. It is paramount we make the @@ -399,10 +424,11 @@ runQueryImpl fname conn sql execSql = do verifyResult :: (HasCallStack, IsSQL sql) => sql + -> Int -> Ptr PGconn -> Ptr PGresult -> IO (Either Int Int) -verifyResult sql conn res = do +verifyResult sql pid conn res = do -- works even if res is NULL rst <- c_PQresultStatus res case rst of @@ -421,7 +447,7 @@ verifyResult sql conn res = do _ | otherwise -> return . Left $ 0 where throwSQLError = - rethrowWithContext sql + rethrowWithContext sql pid =<< if res == nullPtr then return . E.toException . QueryError @@ -451,6 +477,7 @@ verifyResult sql conn res = do E.throwIO DBException { dbeQueryContext = sql + , dbeBackendPid = pid , dbeError = HPQTypesError ("verifyResult: string returned by PQcmdTuples is not a valid number: " ++ show sn) , dbeCallStack = callStack } diff --git a/src/Database/PostgreSQL/PQTypes/Internal/Exception.hs b/src/Database/PostgreSQL/PQTypes/Internal/Exception.hs index d0433c0..256ca10 100644 --- a/src/Database/PostgreSQL/PQTypes/Internal/Exception.hs +++ b/src/Database/PostgreSQL/PQTypes/Internal/Exception.hs @@ -14,6 +14,8 @@ import Database.PostgreSQL.PQTypes.SQL.Class data DBException = forall e sql. (E.Exception e, Show sql) => DBException { dbeQueryContext :: !sql -- ^ Last SQL query that was executed. + , dbeBackendPid :: !Int + -- ^ Process ID of the server process attached to the current session. , dbeError :: !e -- ^ Specific error. , dbeCallStack :: CallStack @@ -24,11 +26,12 @@ deriving instance Show DBException instance E.Exception DBException -- | Rethrow supplied exception enriched with given SQL. -rethrowWithContext :: (HasCallStack, IsSQL sql) => sql -> E.SomeException -> IO a -rethrowWithContext sql (E.SomeException e) = +rethrowWithContext :: (HasCallStack, IsSQL sql) => sql -> Int -> E.SomeException -> IO a +rethrowWithContext sql pid (E.SomeException e) = E.throwIO DBException { dbeQueryContext = sql + , dbeBackendPid = pid , dbeError = e , dbeCallStack = callStack } diff --git a/src/Database/PostgreSQL/PQTypes/Internal/Monad.hs b/src/Database/PostgreSQL/PQTypes/Internal/Monad.hs index 43fedda..01b5c97 100644 --- a/src/Database/PostgreSQL/PQTypes/Internal/Monad.hs +++ b/src/Database/PostgreSQL/PQTypes/Internal/Monad.hs @@ -17,7 +17,6 @@ import Control.Monad.State.Strict import Control.Monad.Trans.Control import Control.Monad.Trans.State.Strict qualified as S import Control.Monad.Writer.Class -import Data.Bifunctor import GHC.Stack import Database.PostgreSQL.PQTypes.Class @@ -77,9 +76,9 @@ mapDBT f g m = DBT . StateT $ g . runStateT (unDBT m) . f instance (m ~ n, MonadBase IO m, MonadMask m) => MonadDB (DBT_ m n) where runQuery sql = withFrozenCallStack $ DBT . StateT $ \st -> liftBase $ do - second (updateStateWith st sql) <$> runQueryIO (dbConnection st) sql + updateStateWith st sql =<< runQueryIO (dbConnection st) sql runPreparedQuery name sql = withFrozenCallStack $ DBT . StateT $ \st -> liftBase $ do - second (updateStateWith st sql) <$> runPreparedQueryIO (dbConnection st) name sql + updateStateWith st sql =<< runPreparedQueryIO (dbConnection st) name sql getLastQuery = DBT . gets $ dbLastQuery @@ -88,6 +87,9 @@ instance (m ~ n, MonadBase IO m, MonadMask m) => MonadDB (DBT_ m n) where (x, st'') <- runStateT (unDBT callback) st' pure (x, st'' {dbRecordLastQuery = dbRecordLastQuery st}) + getBackendPid = DBT . StateT $ \st -> do + (,st) <$> liftBase (getBackendPidIO $ dbConnection st) + getConnectionStats = withFrozenCallStack $ do mconn <- DBT $ liftBase . readMVar =<< gets (unConnection . dbConnection) case mconn of @@ -100,9 +102,8 @@ instance (m ~ n, MonadBase IO m, MonadMask m) => MonadDB (DBT_ m n) where getTransactionSettings = DBT . gets $ dbTransactionSettings setTransactionSettings ts = DBT . modify $ \st -> st {dbTransactionSettings = ts} - getNotification time = DBT . StateT $ \st -> - (,st) - <$> liftBase (getNotificationIO st time) + getNotification time = DBT . StateT $ \st -> do + (,st) <$> liftBase (getNotificationIO st time) withNewConnection m = DBT . StateT $ \st -> do let cs = dbConnectionSource st diff --git a/src/Database/PostgreSQL/PQTypes/Internal/QueryResult.hs b/src/Database/PostgreSQL/PQTypes/Internal/QueryResult.hs index f6906b9..719ead6 100644 --- a/src/Database/PostgreSQL/PQTypes/Internal/QueryResult.hs +++ b/src/Database/PostgreSQL/PQTypes/Internal/QueryResult.hs @@ -2,6 +2,7 @@ module Database.PostgreSQL.PQTypes.Internal.QueryResult ( QueryResult (..) + , mkQueryResult , ntuples , nfields @@ -35,12 +36,27 @@ import Database.PostgreSQL.PQTypes.SQL.Class -- extraction appropriately. data QueryResult t = forall row. FromRow row => QueryResult { qrSQL :: !SomeSQL + , qrBackendPid :: !Int , qrResult :: !(ForeignPtr PGresult) , qrFromRow :: !(row -> t) } +mkQueryResult + :: (FromRow t, IsSQL sql) + => sql + -> Int + -> ForeignPtr PGresult + -> QueryResult t +mkQueryResult sql pid res = + QueryResult + { qrSQL = SomeSQL sql + , qrBackendPid = pid + , qrResult = res + , qrFromRow = id + } + instance Functor QueryResult where - f `fmap` QueryResult ctx fres g = QueryResult ctx fres (f . g) + f `fmap` QueryResult ctx pid fres g = QueryResult ctx pid fres (f . g) instance Foldable QueryResult where foldr f acc = runIdentity . foldrImpl False (coerce f) acc @@ -77,7 +93,7 @@ foldImpl -> acc -> QueryResult t -> m acc -foldImpl initCtr termCtr advCtr strict f iacc (QueryResult (SomeSQL ctx) fres g) = +foldImpl initCtr termCtr advCtr strict f iacc (QueryResult (SomeSQL ctx) pid fres g) = unsafePerformIO $ withForeignPtr fres $ \res -> do -- This bit is referentially transparent iff appropriate -- FrowRow and FromSQL instances are (the ones provided @@ -87,6 +103,7 @@ foldImpl initCtr termCtr advCtr strict f iacc (QueryResult (SomeSQL ctx) fres g) E.throwIO DBException { dbeQueryContext = ctx + , dbeBackendPid = pid , dbeError = RowLengthMismatch { lengthExpected = pqVariablesP rowp @@ -101,7 +118,7 @@ foldImpl initCtr termCtr advCtr strict f iacc (QueryResult (SomeSQL ctx) fres g) then return acc else do -- mask asynchronous exceptions so they won't be wrapped in DBException - obj <- E.mask_ (g <$> fromRow res err 0 i `E.catch` rethrowWithContext ctx) + obj <- E.mask_ (g <$> fromRow res err 0 i `E.catch` rethrowWithContext ctx pid) worker `apply` (f obj =<< acc) $ advCtr i worker (pure iacc) =<< initCtr res where diff --git a/src/Database/PostgreSQL/PQTypes/Internal/State.hs b/src/Database/PostgreSQL/PQTypes/Internal/State.hs index 1d23b3b..9f8712f 100644 --- a/src/Database/PostgreSQL/PQTypes/Internal/State.hs +++ b/src/Database/PostgreSQL/PQTypes/Internal/State.hs @@ -29,15 +29,21 @@ data DBState m = DBState -- ^ Current query result. } -updateStateWith :: IsSQL sql => DBState m -> sql -> ForeignPtr PGresult -> DBState m -updateStateWith st sql res = - st - { dbLastQuery = if dbRecordLastQuery st then SomeSQL sql else dbLastQuery st - , dbQueryResult = - Just - QueryResult - { qrSQL = SomeSQL sql - , qrResult = res - , qrFromRow = id - } - } +updateStateWith + :: IsSQL sql + => DBState m + -> sql + -> (r, ForeignPtr PGresult) + -> IO (r, DBState m) +updateStateWith st sql (r, res) = do + pid <- getBackendPidIO $ dbConnection st + pure + ( r + , st + { dbLastQuery = + if dbRecordLastQuery st + then SomeSQL sql + else dbLastQuery st + , dbQueryResult = Just $ mkQueryResult sql pid res + } + ) diff --git a/src/Database/PostgreSQL/PQTypes/Utils.hs b/src/Database/PostgreSQL/PQTypes/Utils.hs index bfa291c..99840cf 100644 --- a/src/Database/PostgreSQL/PQTypes/Utils.hs +++ b/src/Database/PostgreSQL/PQTypes/Utils.hs @@ -34,16 +34,19 @@ import Database.PostgreSQL.PQTypes.SQL.Raw -- | When given 'DBException', throw it immediately. Otherwise -- wrap it in 'DBException' with the current query context first. throwDB :: (HasCallStack, Exception e, MonadDB m, MonadThrow m) => e -> m a -throwDB e = case fromException $ toException e of - Just (dbe :: DBException) -> throwM dbe - Nothing -> do - SomeSQL sql <- getLastQuery - throwM - DBException - { dbeQueryContext = sql - , dbeError = e - , dbeCallStack = callStack - } +throwDB e = do + pid <- getBackendPid + case fromException $ toException e of + Just (dbe :: DBException) -> throwM dbe + Nothing -> do + SomeSQL sql <- getLastQuery + throwM + DBException + { dbeQueryContext = sql + , dbeBackendPid = pid + , dbeError = e + , dbeCallStack = callStack + } ----------------------------------------