Skip to content

Commit

Permalink
Fix expired JWTs starting an empty transaction
Browse files Browse the repository at this point in the history
Fixes #1094.

Expired JWTs were doing an empty BEGIN/COMMIT in the db.
  • Loading branch information
steve-chavez committed Jul 3, 2020
1 parent a5bc293 commit 55b4f4f
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 55 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
### Fixed

- #1530, Fix how the PostgREST version is shown in the help text when the `.git` directory is not available - @monacoremo
- #1094, Fix expired JWTs starting an empty transaction on the db - @steve-chavez

### Changed

Expand Down
38 changes: 21 additions & 17 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ import Network.Wai
import PostgREST.ApiRequest (Action (..), ApiRequest (..),
InvokeMethod (..), Target (..),
mutuallyAgreeable, userApiRequest)
import PostgREST.Auth (containsRole, jwtClaims,
parseSecret)
import PostgREST.Auth (attemptJwtClaims, containsRole,
jwtClaims, parseSecret)
import PostgREST.Config (AppConfig (..))
import PostgREST.DbRequestBuilder (mutateRequest, readRequest,
returningCols)
Expand Down Expand Up @@ -79,26 +79,30 @@ postgrest conf refDbStructure pool getTime worker =
Nothing -> respond . errorResponseFor $ ConnectionLostError
Just dbStructure -> do
response <- do
-- Need to parse ?columns early because findProc needs it to solve overloaded functions.
-- TODO: move this logic to the app function
let apiReq = userApiRequest (configSchemas conf) (configRootSpec conf) req body
-- Need to parse ?columns early because findProc needs it to solve overloaded functions.
apiReqCols = (,) <$> apiReq <*> (pRequestColumns . iColumns =<< apiReq)
case apiReqCols of
Left err -> return . errorResponseFor $ err
Right (apiRequest, maybeCols) -> do
eClaims <- jwtClaims jwtSecret (configJwtAudience conf) (toS $ iJWT apiRequest) time (rightToMaybe $ configRoleClaimKey conf)
let authed = containsRole eClaims
cols = case (iPayload apiRequest, maybeCols) of
(Just ProcessedJSON{pjKeys}, _) -> pjKeys
(Just RawJSON{}, Just cls) -> cls
_ -> S.empty
proc = case iTarget apiRequest of
TargetProc qi _ -> findProc qi cols (iPreferParameters apiRequest == Just SingleObject) $ dbProcs dbStructure
_ -> Nothing
handleReq = runWithClaims conf eClaims (app dbStructure proc cols conf) apiRequest
txMode = transactionMode proc (iAction apiRequest)
response <- P.use pool $ HT.transaction HT.ReadCommitted txMode handleReq
return $ either (errorResponseFor . PgError authed) identity response
-- The jwt must be checked before touching the db.
attempt <- attemptJwtClaims jwtSecret (configJwtAudience conf) (toS $ iJWT apiRequest) time (rightToMaybe $ configRoleClaimKey conf)
case jwtClaims attempt of
Left errJwt -> return . errorResponseFor $ errJwt
Right claims -> do
let
authed = containsRole claims
cols = case (iPayload apiRequest, maybeCols) of
(Just ProcessedJSON{pjKeys}, _) -> pjKeys
(Just RawJSON{}, Just cls) -> cls
_ -> S.empty
proc = case iTarget apiRequest of
TargetProc qi _ -> findProc qi cols (iPreferParameters apiRequest == Just SingleObject) $ dbProcs dbStructure
_ -> Nothing
handleReq = runPgLocals conf claims (app dbStructure proc cols conf) apiRequest
txMode = transactionMode proc (iAction apiRequest)
dbResp <- P.use pool $ HT.transaction HT.ReadCommitted txMode handleReq
return $ either (errorResponseFor . PgError authed) identity dbResp
when (responseStatus response == status503) worker
respond response

Expand Down
23 changes: 16 additions & 7 deletions src/PostgREST/Auth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ very simple authentication system inside the PostgreSQL database.
module PostgREST.Auth (
containsRole
, jwtClaims
, JWTAttempt(..)
, attemptJwtClaims
, parseSecret
) where

Expand All @@ -30,6 +30,7 @@ import Data.Time.Clock (UTCTime)
import Control.Lens.Operators
import Crypto.JWT

import PostgREST.Error (SimpleError (..))
import PostgREST.Types
import Protolude hiding (toS)
import Protolude.Conv (toS)
Expand All @@ -42,13 +43,22 @@ data JWTAttempt = JWTInvalid JWTError
| JWTClaims (M.HashMap Text JSON.Value)
deriving (Eq, Show)


jwtClaims :: JWTAttempt -> Either SimpleError (M.HashMap Text JSON.Value)
jwtClaims attempt =
case attempt of
JWTMissingSecret -> Left JwtTokenMissing
JWTInvalid JWTExpired -> Left $ JwtTokenInvalid "JWT expired"
JWTInvalid e -> Left $ JwtTokenInvalid $ show e
JWTClaims claims -> Right claims

{-|
Receives the JWT secret and audience (from config) and a JWT and returns a map
of JWT claims.
-}
jwtClaims :: Maybe JWKSet -> Maybe StringOrURI -> LByteString -> UTCTime -> Maybe JSPath -> IO JWTAttempt
jwtClaims _ _ "" _ _ = return $ JWTClaims M.empty
jwtClaims secret audience payload time jspath =
attemptJwtClaims :: Maybe JWKSet -> Maybe StringOrURI -> LByteString -> UTCTime -> Maybe JSPath -> IO JWTAttempt
attemptJwtClaims _ _ "" _ _ = return $ JWTClaims M.empty
attemptJwtClaims secret audience payload time jspath =
case secret of
Nothing -> return JWTMissingSecret
Just s -> do
Expand Down Expand Up @@ -82,9 +92,8 @@ walkJSPath _ _ = Nothing
{-|
Whether a response from jwtClaims contains a role claim
-}
containsRole :: JWTAttempt -> Bool
containsRole (JWTClaims claims) = M.member "role" claims
containsRole _ = False
containsRole :: M.HashMap Text JSON.Value -> Bool
containsRole = M.member "role"

{-|
Parse `jwt-secret` configuration option and turn into a JWKSet.
Expand Down
53 changes: 22 additions & 31 deletions src/PostgREST/Middleware.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,43 +19,34 @@ import Network.Wai.Middleware.Cors (cors)
import Network.Wai.Middleware.Gzip (def, gzip)
import Network.Wai.Middleware.Static (only, staticPolicy)

import Crypto.JWT

import PostgREST.ApiRequest (ApiRequest (..))
import PostgREST.Auth (JWTAttempt (..))
import PostgREST.Config (AppConfig (..), corsPolicy)
import PostgREST.Error (SimpleError (JwtTokenInvalid, JwtTokenMissing),
errorResponseFor)
import PostgREST.QueryBuilder (setLocalQuery, setLocalSearchPathQuery)
import Protolude hiding (head, toS)
import Protolude.Conv (toS)

runWithClaims :: AppConfig -> JWTAttempt ->
(ApiRequest -> H.Transaction Response) ->
ApiRequest -> H.Transaction Response
runWithClaims conf eClaims app req =
case eClaims of
JWTMissingSecret -> return . errorResponseFor $ JwtTokenMissing
JWTInvalid JWTExpired -> return . errorResponseFor . JwtTokenInvalid $ "JWT expired"
JWTInvalid e -> return . errorResponseFor . JwtTokenInvalid . show $ e
JWTClaims claims -> do
H.sql $ toS . mconcat $ setSearchPathSql : setRoleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ appSettingsSql
mapM_ H.sql customReqCheck
app req
where
methodSql = setLocalQuery mempty ("request.method", toS $ iMethod req)
pathSql = setLocalQuery mempty ("request.path", toS $ iPath req)
headersSql = setLocalQuery "request.header." <$> iHeaders req
cookiesSql = setLocalQuery "request.cookie." <$> iCookies req
claimsSql = setLocalQuery "request.jwt.claim." <$> [(c,unquoted v) | (c,v) <- M.toList claimsWithRole]
appSettingsSql = setLocalQuery mempty <$> configSettings conf
setRoleSql = maybeToList $ (\x ->
setLocalQuery mempty ("role", unquoted x)) <$> M.lookup "role" claimsWithRole
setSearchPathSql = setLocalSearchPathQuery (iSchema req : configExtraSearchPath conf)
-- role claim defaults to anon if not specified in jwt
claimsWithRole = M.union claims (M.singleton "role" anon)
anon = JSON.String . toS $ configAnonRole conf
customReqCheck = (\f -> "select " <> toS f <> "();") <$> configReqCheck conf
-- | Runs local(transaction scoped) GUCs for every request, plus the pre-request function
runPgLocals :: AppConfig -> M.HashMap Text JSON.Value ->
(ApiRequest -> H.Transaction Response) ->
ApiRequest -> H.Transaction Response
runPgLocals conf claims app req = do
H.sql $ toS . mconcat $ setSearchPathSql : setRoleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ appSettingsSql
traverse_ H.sql customReqCheck
app req
where
methodSql = setLocalQuery mempty ("request.method", toS $ iMethod req)
pathSql = setLocalQuery mempty ("request.path", toS $ iPath req)
headersSql = setLocalQuery "request.header." <$> iHeaders req
cookiesSql = setLocalQuery "request.cookie." <$> iCookies req
claimsSql = setLocalQuery "request.jwt.claim." <$> [(c,unquoted v) | (c,v) <- M.toList claimsWithRole]
appSettingsSql = setLocalQuery mempty <$> configSettings conf
setRoleSql = maybeToList $ (\x ->
setLocalQuery mempty ("role", unquoted x)) <$> M.lookup "role" claimsWithRole
setSearchPathSql = setLocalSearchPathQuery (iSchema req : configExtraSearchPath conf)
-- role claim defaults to anon if not specified in jwt
claimsWithRole = M.union claims (M.singleton "role" anon)
anon = JSON.String . toS $ configAnonRole conf
customReqCheck = (\f -> "select " <> toS f <> "();") <$> configReqCheck conf

defaultMiddle :: Application -> Application
defaultMiddle =
Expand Down

0 comments on commit 55b4f4f

Please sign in to comment.