diff --git a/postgrest.cabal b/postgrest.cabal index 290618a04d..dbd74ac0b6 100644 --- a/postgrest.cabal +++ b/postgrest.cabal @@ -80,8 +80,10 @@ library , auto-update >= 0.1.4 && < 0.2 , base64-bytestring >= 1 && < 1.3 , bytestring >= 0.10.8 && < 0.12 + , cache >= 0.1.3 && < 0.2.0 , case-insensitive >= 1.2 && < 1.3 , cassava >= 0.4.5 && < 0.6 + , clock >= 0.8.3 && < 0.9.0 , configurator-pg >= 0.2 && < 0.3 , containers >= 0.5.7 && < 0.7 , contravariant-extras >= 0.3.3 && < 0.4 diff --git a/src/PostgREST/AppState.hs b/src/PostgREST/AppState.hs index 89e3cb69d1..ed20481f60 100644 --- a/src/PostgREST/AppState.hs +++ b/src/PostgREST/AppState.hs @@ -4,6 +4,7 @@ module PostgREST.AppState ( AppState + , AuthResult(..) , destroy , getConfig , getSchemaCache @@ -12,6 +13,7 @@ module PostgREST.AppState , getPgVersion , getRetryNextIn , getTime + , getJwtCache , init , initWithPool , logWithZTime @@ -24,8 +26,11 @@ module PostgREST.AppState , runListener ) where +import qualified Data.Aeson as JSON +import qualified Data.Aeson.KeyMap as KM import qualified Data.ByteString.Char8 as BS import qualified Data.ByteString.Lazy as LBS +import qualified Data.Cache as C import Data.Either.Combinators (whenLeft) import qualified Data.Text.Encoding as T import Hasql.Connection (acquire) @@ -62,6 +67,11 @@ import PostgREST.SchemaCache.Identifiers (dumpQi) import Protolude +data AuthResult = AuthResult + { authClaims :: KM.KeyMap JSON.Value + , authRole :: BS.ByteString + } + data AppState = AppState -- | Database connection pool { statePool :: SQL.Pool @@ -87,6 +97,8 @@ data AppState = AppState , stateRetryNextIn :: IORef Int -- | Logs a pool error with a debounce , debounceLogAcquisitionTimeout :: IO () + -- | JWT Cache + , jwtCache :: C.Cache ByteString AuthResult } init :: AppConfig -> IO AppState @@ -108,6 +120,7 @@ initWithPool pool conf = do <*> myThreadId <*> newIORef 0 <*> pure (pure ()) + <*> C.newCache Nothing debLogTimeout <- @@ -188,6 +201,9 @@ putConfig = atomicWriteIORef . stateConf getTime :: AppState -> IO UTCTime getTime = stateGetTime +getJwtCache :: AppState -> C.Cache ByteString AuthResult +getJwtCache = jwtCache + -- | Log to stderr with local time logWithZTime :: AppState -> Text -> IO () logWithZTime appState txt = do diff --git a/src/PostgREST/Auth.hs b/src/PostgREST/Auth.hs index fb07daa600..fcd1bc78da 100644 --- a/src/PostgREST/Auth.hs +++ b/src/PostgREST/Auth.hs @@ -26,6 +26,8 @@ import qualified Data.Aeson.KeyMap as KM import qualified Data.Aeson.Types as JSON import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy.Char8 as LBS +import qualified Data.Cache as C +import qualified Data.Scientific as Sci import qualified Data.Vault.Lazy as Vault import qualified Data.Vector as V import qualified Network.HTTP.Types.Header as HTTP @@ -37,21 +39,18 @@ import Control.Monad.Except (liftEither) import Data.Either.Combinators (mapLeft) import Data.List (lookup) import Data.Time.Clock (UTCTime) +import System.Clock (TimeSpec (..)) import System.IO.Unsafe (unsafePerformIO) import System.TimeIt (timeItT) -import PostgREST.AppState (AppState, getConfig, getTime) +import PostgREST.AppState (AppState, AuthResult (..), getConfig, + getJwtCache, getTime) import PostgREST.Config (AppConfig (..), JSPath, JSPathExp (..)) import PostgREST.Error (Error (..)) import Protolude -data AuthResult = AuthResult - { authClaims :: KM.KeyMap JSON.Value - , authRole :: BS.ByteString - } - -- | Receives the JWT secret and audience (from config) and a JWT and returns a -- JSON object of JWT claims. parseToken :: Monad m => @@ -107,16 +106,49 @@ middleware appState app req respond = do let token = fromMaybe "" $ Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req) parseJwt = runExceptT $ parseToken conf (LBS.fromStrict token) time >>= parseClaims conf - if configDbPlanEnabled conf - then do - (dur,authResult) <- timeItT parseJwt - let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur } - app req' respond - else do - authResult <- parseJwt - let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult } - app req' respond - +-- If DbPlanEnabled -> calculate JWT validation time +-- If JwtCaching -> cache JWT validation result + case (configDbPlanEnabled conf, configJwtCaching conf) of + (True, True) -> do + (dur, authResult) <- timeItT $ getJWTFromCache appState token parseJwt + let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur } + app req' respond + + (True, False) -> do + (dur, authResult) <- timeItT parseJwt + let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur } + app req' respond + + (False, True) -> do + authResult <- getJWTFromCache appState token parseJwt + let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult } + app req' respond + + (False, False) -> do + authResult <- parseJwt + let req' = req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult } + app req' respond + +-- Used to extract JWT exp claim and add to JWT Cache +getTimeSpec :: AuthResult -> Maybe TimeSpec +getTimeSpec res = do + let sciToInt = fromMaybe 0 . Sci.toBoundedInteger + expireJSON <- KM.lookup "exp" (authClaims res) + case expireJSON of + JSON.Number seconds -> Just $ TimeSpec (sciToInt seconds) 0 + _ -> Just $ TimeSpec 0 0 -- set timeSpec to 0 so it expires immediately, hence not cached + +-- | Used to retrieve and insert JWT to JWT Cache +getJWTFromCache :: AppState -> ByteString -> IO (Either Error AuthResult) -> IO (Either Error AuthResult) +getJWTFromCache appState token parseJwt = do + checkCache <- C.lookup (getJwtCache appState) token + authResult <- maybe parseJwt (pure . Right) checkCache + + case (authResult,checkCache) of + (Right res, Nothing) -> C.insert' (getJwtCache appState) (getTimeSpec res) token res + _ -> pure () + + return authResult authResultKey :: Vault.Key (Either Error AuthResult) authResultKey = unsafePerformIO Vault.newKey diff --git a/test/io/test_io.py b/test/io/test_io.py index 55a0712a23..120c467d1a 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -1095,3 +1095,55 @@ def test_fail_with_automatic_recovery_disabled_and_terminated_using_query(defaul exitCode = wait_until_exit(postgrest) assert exitCode == 1 + + +def test_server_timing_jwt_should_decrease_on_subsequent_requests(defaultenv): + "assert that server-timing duration for JWT should decrease on subsequent requests" + + env = { + **defaultenv, + "PGRST_DB_PLAN_ENABLED": "true", + "PGRST_JWT_CACHING": "true", + "PGRST_JWT_SECRET": "@/dev/stdin", + "PGRST_DB_CONFIG": "false", + } + + headers = jwtauthheader({"role": "postgrest_test_author"}, SECRET) + + with run(stdin=SECRET.encode(), env=env) as postgrest: + first_dur_text = postgrest.session.get( + "/authors_only", headers=headers + ).headers["Server-Timing"] + second_dur_text = postgrest.session.get( + "/authors_only", headers=headers + ).headers["Server-Timing"] + + first_dur = float(first_dur_text[8:]) # skip "jwt;dur=" + second_dur = float(second_dur_text[8:]) + + # their difference should be atleast 200, implying + # that JWT Caching is working as expected + assert (first_dur - second_dur) > 200.0 + + +# just added to complete code coverage +def test_jwt_caching_works_with_db_plan_disabled(defaultenv): + "assert that JWT caching words even when Server-Timing header is not returned" + + env = { + **defaultenv, + "PGRST_DB_PLAN_ENABLED": "false", + "PGRST_JWT_CACHING": "true", + "PGRST_JWT_SECRET": "@/dev/stdin", + "PGRST_DB_CONFIG": "false", + } + + headers = jwtauthheader({"role": "postgrest_test_author"}, SECRET) + + with run(stdin=SECRET.encode(), env=env) as postgrest: + first_request = postgrest.session.get("/authors_only", headers=headers) + second_request = postgrest.session.get("/authors_only", headers=headers) + + # in this case we don't get server-timing in response headers + # so we can't compare durations, we just check if request succeeds + assert first_request.status_code == 200 and second_request.status_code == 200 diff --git a/test/spec/SpecHelper.hs b/test/spec/SpecHelper.hs index 8b9e49edae..7cef61bead 100644 --- a/test/spec/SpecHelper.hs +++ b/test/spec/SpecHelper.hs @@ -206,9 +206,6 @@ testCfgAsymJWKSet = , configJWKS = parseSecret <$> secret } -testCfgJwtCaching :: AppConfig -testCfgJwtCaching = baseCfg { configJwtCaching = True , configDbPlanEnabled = True } - testNonexistentSchemaCfg :: AppConfig testNonexistentSchemaCfg = baseCfg { configDbSchemas = fromList ["nonexistent"] }