Skip to content

Commit

Permalink
Store pid of the backend when connecting to Postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
arybczak committed Feb 29, 2024
1 parent 8fb828c commit a7a136d
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 41 deletions.
4 changes: 4 additions & 0 deletions src/Database/PostgreSQL/PQTypes/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class (Applicative m, Monad m) => MonadDB m where
-- 'getLastQuery'.
withFrozenLastQuery :: m a -> m a

-- | Get ID of the server process attached to the current session.
getBackendPid :: m Int

-- | Get current connection statistics.
getConnectionStats :: HasCallStack => m ConnectionStats

Expand Down Expand Up @@ -89,6 +92,7 @@ instance
runPreparedQuery name = withFrozenCallStack $ lift . runPreparedQuery name
getLastQuery = lift getLastQuery
withFrozenLastQuery m = controlT $ \run -> withFrozenLastQuery (run m)
getBackendPid = lift getBackendPid
getConnectionStats = withFrozenCallStack $ lift getConnectionStats
getQueryResult = lift getQueryResult
clearQueryResult = lift clearQueryResult
Expand Down
43 changes: 35 additions & 8 deletions src/Database/PostgreSQL/PQTypes/Internal/Connection.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
{-# LANGUAGE TypeApplications #-}

module Database.PostgreSQL.PQTypes.Internal.Connection
( -- * Connection
Connection (..)
, getBackendPidIO
, ConnectionData (..)
, withConnectionData
, ConnectionStats (..)
Expand All @@ -26,10 +29,11 @@ import Control.Exception qualified as E
import Control.Monad
import Control.Monad.Base
import Control.Monad.Catch
import Data.Bifunctor
import Data.ByteString.Char8 qualified as BS
import Data.Foldable qualified as F
import Data.Functor.Identity
import Data.IORef
import Data.Int
import Data.Kind
import Data.Pool
import Data.Set qualified as S
Expand All @@ -48,6 +52,7 @@ import Database.PostgreSQL.PQTypes.Internal.Composite
import Database.PostgreSQL.PQTypes.Internal.Error
import Database.PostgreSQL.PQTypes.Internal.Error.Code
import Database.PostgreSQL.PQTypes.Internal.Exception
import Database.PostgreSQL.PQTypes.Internal.QueryResult
import Database.PostgreSQL.PQTypes.Internal.Utils
import Database.PostgreSQL.PQTypes.SQL.Class
import Database.PostgreSQL.PQTypes.SQL.Raw
Expand Down Expand Up @@ -114,6 +119,8 @@ initialStats =
data ConnectionData = ConnectionData
{ cdPtr :: !(Ptr PGconn)
-- ^ Pointer to connection object.
, cdBackendPid :: !Int
-- ^ Process ID of the server process attached to the current session.
, cdStats :: !ConnectionStats
-- ^ Statistics associated with the connection.
, cdPreparedQueries :: !(IORef (S.Set T.Text))
Expand All @@ -125,6 +132,11 @@ newtype Connection = Connection
{ unConnection :: MVar (Maybe ConnectionData)
}

getBackendPidIO :: Connection -> IO Int
getBackendPidIO conn = do
withConnectionData conn "connectionBackendPid" $ \cd -> do
pure (cd, cdBackendPid cd)

withConnectionData
:: Connection
-> String
Expand All @@ -133,7 +145,9 @@ withConnectionData
withConnectionData (Connection mvc) fname f =
modifyMVar mvc $ \mc -> case mc of
Nothing -> hpqTypesError $ fname ++ ": no connection"
Just cd -> first Just <$> f cd
Just cd -> do
(cd', r) <- f cd
cd' `seq` pure (Just cd', r)

-- | Database connection supplier.
newtype ConnectionSourceM m = ConnectionSourceM
Expand Down Expand Up @@ -215,10 +229,21 @@ connect ConnectionSettings {..} = mask $ \unmask -> do
Just
ConnectionData
{ cdPtr = connPtr
, cdBackendPid = 0
, cdStats = initialStats
, cdPreparedQueries = preparedQueries
}
F.forM_ csRole $ \role -> runQueryIO conn $ "SET ROLE " <> role

let selectPid = "SELECT pg_backend_pid()" :: RawSQL ()
(_, res) <- runQueryIO conn selectPid
case F.toList $ mkQueryResult @(Identity Int32) selectPid 0 res of
[pid] -> withConnectionData conn fname $ \cd -> do
pure (cd {cdBackendPid = fromIntegral pid}, ())
pids -> do
let err = HPQTypesError $ "unexpected backend pid: " ++ show pids
rethrowWithContext selectPid 0 $ toException err

pure conn
where
fname = "connect"
Expand Down Expand Up @@ -317,6 +342,7 @@ runPreparedQueryIO conn (QueryName queryName) sql = do
E.throwIO
DBException
{ dbeQueryContext = sql
, dbeBackendPid = cdBackendPid
, dbeError = HPQTypesError "runPreparedQueryIO: unnamed prepared query is not supported"
, dbeCallStack = callStack
}
Expand All @@ -329,7 +355,7 @@ runPreparedQueryIO conn (QueryName queryName) sql = do
-- succeeds, we need to reflect that fact in cdPreparedQueries since
-- you can't prepare a query with the same name more than once.
res <- c_PQparamPrepare cdPtr nullPtr param cname query
void . withForeignPtr res $ verifyResult sql cdPtr
void . withForeignPtr res $ verifyResult sql cdBackendPid cdPtr
modifyIORef' cdPreparedQueries $ S.insert queryName
(,)
<$> (fromIntegral <$> c_PQparamCount param)
Expand All @@ -353,7 +379,7 @@ runQueryImpl fname conn sql execSql = do
-- runtime system is used) and react appropriately.
queryRunner <- async . restore $ do
(paramCount, res) <- execSql cd
affected <- withForeignPtr res $ verifyResult sql cdPtr
affected <- withForeignPtr res $ verifyResult sql cdBackendPid cdPtr
stats' <- case affected of
Left _ ->
return
Expand All @@ -370,8 +396,7 @@ runQueryImpl fname conn sql execSql = do
, statsValues = statsValues cdStats + (rows * columns)
, statsParams = statsParams cdStats + paramCount
}
-- Force evaluation of modified stats to squash a space leak.
stats' `seq` return (cd {cdStats = stats'}, (either id id affected, res))
return (cd {cdStats = stats'}, (either id id affected, res))
-- If we receive an exception while waiting for the execution to complete,
-- we need to send a request to PostgreSQL for query cancellation and wait
-- for the query runner thread to terminate. It is paramount we make the
Expand Down Expand Up @@ -399,10 +424,11 @@ runQueryImpl fname conn sql execSql = do
verifyResult
:: (HasCallStack, IsSQL sql)
=> sql
-> Int
-> Ptr PGconn
-> Ptr PGresult
-> IO (Either Int Int)
verifyResult sql conn res = do
verifyResult sql pid conn res = do
-- works even if res is NULL
rst <- c_PQresultStatus res
case rst of
Expand All @@ -421,7 +447,7 @@ verifyResult sql conn res = do
_ | otherwise -> return . Left $ 0
where
throwSQLError =
rethrowWithContext sql
rethrowWithContext sql pid
=<< if res == nullPtr
then
return . E.toException . QueryError
Expand Down Expand Up @@ -451,6 +477,7 @@ verifyResult sql conn res = do
E.throwIO
DBException
{ dbeQueryContext = sql
, dbeBackendPid = pid
, dbeError = HPQTypesError ("verifyResult: string returned by PQcmdTuples is not a valid number: " ++ show sn)
, dbeCallStack = callStack
}
7 changes: 5 additions & 2 deletions src/Database/PostgreSQL/PQTypes/Internal/Exception.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import Database.PostgreSQL.PQTypes.SQL.Class
data DBException = forall e sql. (E.Exception e, Show sql) => DBException
{ dbeQueryContext :: !sql
-- ^ Last SQL query that was executed.
, dbeBackendPid :: !Int
-- ^ Process ID of the server process attached to the current session.
, dbeError :: !e
-- ^ Specific error.
, dbeCallStack :: CallStack
Expand All @@ -24,11 +26,12 @@ deriving instance Show DBException
instance E.Exception DBException

-- | Rethrow supplied exception enriched with given SQL.
rethrowWithContext :: (HasCallStack, IsSQL sql) => sql -> E.SomeException -> IO a
rethrowWithContext sql (E.SomeException e) =
rethrowWithContext :: (HasCallStack, IsSQL sql) => sql -> Int -> E.SomeException -> IO a
rethrowWithContext sql pid (E.SomeException e) =
E.throwIO
DBException
{ dbeQueryContext = sql
, dbeBackendPid = pid
, dbeError = e
, dbeCallStack = callStack
}
13 changes: 7 additions & 6 deletions src/Database/PostgreSQL/PQTypes/Internal/Monad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import Control.Monad.State.Strict
import Control.Monad.Trans.Control
import Control.Monad.Trans.State.Strict qualified as S
import Control.Monad.Writer.Class
import Data.Bifunctor
import GHC.Stack

import Database.PostgreSQL.PQTypes.Class
Expand Down Expand Up @@ -77,9 +76,9 @@ mapDBT f g m = DBT . StateT $ g . runStateT (unDBT m) . f

instance (m ~ n, MonadBase IO m, MonadMask m) => MonadDB (DBT_ m n) where
runQuery sql = withFrozenCallStack $ DBT . StateT $ \st -> liftBase $ do
second (updateStateWith st sql) <$> runQueryIO (dbConnection st) sql
updateStateWith st sql =<< runQueryIO (dbConnection st) sql
runPreparedQuery name sql = withFrozenCallStack $ DBT . StateT $ \st -> liftBase $ do
second (updateStateWith st sql) <$> runPreparedQueryIO (dbConnection st) name sql
updateStateWith st sql =<< runPreparedQueryIO (dbConnection st) name sql

getLastQuery = DBT . gets $ dbLastQuery

Expand All @@ -88,6 +87,9 @@ instance (m ~ n, MonadBase IO m, MonadMask m) => MonadDB (DBT_ m n) where
(x, st'') <- runStateT (unDBT callback) st'
pure (x, st'' {dbRecordLastQuery = dbRecordLastQuery st})

getBackendPid = DBT . StateT $ \st -> do
(,st) <$> liftBase (getBackendPidIO $ dbConnection st)

getConnectionStats = withFrozenCallStack $ do
mconn <- DBT $ liftBase . readMVar =<< gets (unConnection . dbConnection)
case mconn of
Expand All @@ -100,9 +102,8 @@ instance (m ~ n, MonadBase IO m, MonadMask m) => MonadDB (DBT_ m n) where
getTransactionSettings = DBT . gets $ dbTransactionSettings
setTransactionSettings ts = DBT . modify $ \st -> st {dbTransactionSettings = ts}

getNotification time = DBT . StateT $ \st ->
(,st)
<$> liftBase (getNotificationIO st time)
getNotification time = DBT . StateT $ \st -> do
(,st) <$> liftBase (getNotificationIO st time)

withNewConnection m = DBT . StateT $ \st -> do
let cs = dbConnectionSource st
Expand Down
23 changes: 20 additions & 3 deletions src/Database/PostgreSQL/PQTypes/Internal/QueryResult.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

module Database.PostgreSQL.PQTypes.Internal.QueryResult
( QueryResult (..)
, mkQueryResult
, ntuples
, nfields

Expand Down Expand Up @@ -35,12 +36,27 @@ import Database.PostgreSQL.PQTypes.SQL.Class
-- extraction appropriately.
data QueryResult t = forall row. FromRow row => QueryResult
{ qrSQL :: !SomeSQL
, qrBackendPid :: !Int
, qrResult :: !(ForeignPtr PGresult)
, qrFromRow :: !(row -> t)
}

mkQueryResult
:: (FromRow t, IsSQL sql)
=> sql
-> Int
-> ForeignPtr PGresult
-> QueryResult t
mkQueryResult sql pid res =
QueryResult
{ qrSQL = SomeSQL sql
, qrBackendPid = pid
, qrResult = res
, qrFromRow = id
}

instance Functor QueryResult where
f `fmap` QueryResult ctx fres g = QueryResult ctx fres (f . g)
f `fmap` QueryResult ctx pid fres g = QueryResult ctx pid fres (f . g)

instance Foldable QueryResult where
foldr f acc = runIdentity . foldrImpl False (coerce f) acc
Expand Down Expand Up @@ -77,7 +93,7 @@ foldImpl
-> acc
-> QueryResult t
-> m acc
foldImpl initCtr termCtr advCtr strict f iacc (QueryResult (SomeSQL ctx) fres g) =
foldImpl initCtr termCtr advCtr strict f iacc (QueryResult (SomeSQL ctx) pid fres g) =
unsafePerformIO $ withForeignPtr fres $ \res -> do
-- This bit is referentially transparent iff appropriate
-- FrowRow and FromSQL instances are (the ones provided
Expand All @@ -87,6 +103,7 @@ foldImpl initCtr termCtr advCtr strict f iacc (QueryResult (SomeSQL ctx) fres g)
E.throwIO
DBException
{ dbeQueryContext = ctx
, dbeBackendPid = pid
, dbeError =
RowLengthMismatch
{ lengthExpected = pqVariablesP rowp
Expand All @@ -101,7 +118,7 @@ foldImpl initCtr termCtr advCtr strict f iacc (QueryResult (SomeSQL ctx) fres g)
then return acc
else do
-- mask asynchronous exceptions so they won't be wrapped in DBException
obj <- E.mask_ (g <$> fromRow res err 0 i `E.catch` rethrowWithContext ctx)
obj <- E.mask_ (g <$> fromRow res err 0 i `E.catch` rethrowWithContext ctx pid)
worker `apply` (f obj =<< acc) $ advCtr i
worker (pure iacc) =<< initCtr res
where
Expand Down
30 changes: 18 additions & 12 deletions src/Database/PostgreSQL/PQTypes/Internal/State.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,21 @@ data DBState m = DBState
-- ^ Current query result.
}

updateStateWith :: IsSQL sql => DBState m -> sql -> ForeignPtr PGresult -> DBState m
updateStateWith st sql res =
st
{ dbLastQuery = if dbRecordLastQuery st then SomeSQL sql else dbLastQuery st
, dbQueryResult =
Just
QueryResult
{ qrSQL = SomeSQL sql
, qrResult = res
, qrFromRow = id
}
}
updateStateWith
:: IsSQL sql
=> DBState m
-> sql
-> (r, ForeignPtr PGresult)
-> IO (r, DBState m)
updateStateWith st sql (r, res) = do
pid <- getBackendPidIO $ dbConnection st
pure
( r
, st
{ dbLastQuery =
if dbRecordLastQuery st
then SomeSQL sql
else dbLastQuery st
, dbQueryResult = Just $ mkQueryResult sql pid res
}
)
23 changes: 13 additions & 10 deletions src/Database/PostgreSQL/PQTypes/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,19 @@ import Database.PostgreSQL.PQTypes.SQL.Raw
-- | When given 'DBException', throw it immediately. Otherwise
-- wrap it in 'DBException' with the current query context first.
throwDB :: (HasCallStack, Exception e, MonadDB m, MonadThrow m) => e -> m a
throwDB e = case fromException $ toException e of
Just (dbe :: DBException) -> throwM dbe
Nothing -> do
SomeSQL sql <- getLastQuery
throwM
DBException
{ dbeQueryContext = sql
, dbeError = e
, dbeCallStack = callStack
}
throwDB e = do
pid <- getBackendPid
case fromException $ toException e of
Just (dbe :: DBException) -> throwM dbe
Nothing -> do
SomeSQL sql <- getLastQuery
throwM
DBException
{ dbeQueryContext = sql
, dbeBackendPid = pid
, dbeError = e
, dbeCallStack = callStack
}

----------------------------------------

Expand Down

0 comments on commit a7a136d

Please sign in to comment.