Skip to content

Commit

Permalink
version 0.1.3.5: OIDC & role-based auth (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
oddsome authored Jun 8, 2021
1 parent 5545b9f commit e0665b6
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 107 deletions.
8 changes: 5 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- added OpenID Connect authorization support for servant;
- added role based authentication.

## [0.1.3.5] - 2021-06-08
### Changed
- CBDINFRA-318: added OpenID Connect authorization support for servant;
- CBDINFRA-318: added role based authentication on top of OIDC.

## [0.1.3.4] - 2021-04-18
### Changed
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# web-template

This is library that encapsulate settings and error-catching for REST-services.
This library encapsulates settings and error-catching for REST-services.

Convention, that are inside:

* every route has the following structure: `HOST:PORT/v{PATH VERSION}/PATH`;
* every path can be under authorization. Authorization means that server will look for the field `id` in Cookies.
* every path can be under authorization. Currently, there are cookie-based auth, OpenID-Connect auth and role-based auth on top of OpenID-Connect.

## Example

Expand Down
15 changes: 8 additions & 7 deletions app/ServantApp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,25 @@
{-# LANGUAGE TypeOperators #-}

import Data.Aeson (encode)
import Data.Maybe (fromJust)
import Data.OpenApi (OpenApi)
import Data.Proxy (Proxy (..))
import Data.Text (Text)
import Network.URI (parseURI)
import Servant (Description, Get, Handler, JSON, PlainText, Post, ReqBody,
Summary, (:<|>) (..), (:>))
import Servant.OpenApi (toOpenApi)
import Servant.Server.Internal.Context (Context (..))

import Web.Template.Servant (OIDCAuth, OIDCConfig (..), SwaggerSchemaUI, UserId (..), Version,
defaultOIDCCfg, runServantServerWithContext, swaggerSchemaUIServer)
import Web.Template.Servant (OIDCAuth, OIDCConfig (..), Permit, SwaggerSchemaUI, UserId (..),
Version, defaultOIDCCfg, runServantServerWithContext,
swaggerSchemaUIServer)
import Web.Template.Wai (defaultHandleLog, defaultHeaderCORS)


type API = Version "1" :>
( Summary "ping route" :> Description "Returns pong" :> "ping" :> Get '[PlainText] Text
:<|> OIDCAuth :>
( Summary "hello route" :> Description "Returns hello + user id" :> "hello" :> Get '[PlainText] Text
:<|> "post" :> ReqBody '[JSON] Int :> Post '[JSON] Text
:<|> Permit '["set role here", "or here"] :> "post" :> ReqBody '[JSON] Int :> Post '[JSON] Text
)
)

Expand All @@ -46,7 +46,8 @@ main = do
id
(defaultHeaderCORS . defaultHandleLog)
5000
(cfg {oidcWorkaroundUri = uri} :. EmptyContext )
(cfg {oidcIssuer = uri, oidcClientId = cId} :. EmptyContext )
$ swaggerSchemaUIServer swagger :<|> (pingH :<|> (\userId -> helloH userId :<|> postH userId))
where
uri = fromJust $ parseURI "https:// . "
uri = error "set uri here"
cId = error "set client id here"
5 changes: 2 additions & 3 deletions cabal.project
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
packages: *.cabal
packages:
*.cabal

package web-template
ghc-options: -Wall

allow-newer: openid-connect:aeson, http-client
200 changes: 113 additions & 87 deletions src/Web/Template/Servant/Auth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,25 @@ module Web.Template.Servant.Auth
-- after https://www.stackage.org/haddock/lts-15.15/servant-server-0.16.2/src/Servant.Server.Experimental.Auth.html

import Control.Applicative ((<|>))
import Control.Lens (At (at), ix, (&), (.~), (<&>), (?~), (^?!), (^?), (^..))
import Control.Lens (At (at), ix, (&), (.~), (<&>), (?~), (^..), (^?))
import Control.Monad.Except (runExceptT, unless)
import Control.Monad.IO.Class (liftIO)
import Data.IORef (writeIORef, readIORef)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.IORef (readIORef, writeIORef)
import Data.Maybe (catMaybes)
import Data.Proxy (Proxy (..))
import Data.Text (Text, pack, intercalate)
import Data.Text (Text, intercalate, pack)
import Data.Text.Encoding (decodeUtf8)
import qualified Data.Vault.Lazy as V
import GHC.Generics (Generic)
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)

import Crypto.JOSE.JWK (JWKSet)
import Crypto.JWT (JWTError, JWTValidationSettings, audiencePredicate,
decodeCompact, defaultJWTValidationSettings,
issuerPredicate, string, unregisteredClaims, uri,
verifyClaims)
import Crypto.JWT (ClaimsSet, JWTError, JWTValidationSettings,
SignedJWT, audiencePredicate, decodeCompact,
defaultJWTValidationSettings, issuerPredicate,
string, unregisteredClaims, uri, verifyClaims)
import Data.Aeson.Lens (AsPrimitive (_String), key, values)
import Data.ByteString (stripPrefix)
import Data.ByteString (ByteString, stripPrefix)
import qualified Data.ByteString.Lazy as LB
import Data.Cache (Cache)
import qualified Data.Cache as Cache
Expand All @@ -47,14 +47,13 @@ import Data.Time.Clock (UTCTime, diffUTCTime, getCurren
import Network.HTTP.Client (Manager, httpLbs)
import Network.HTTP.Client.TLS (newTlsManager)
import Network.HTTP.Types.Header (hContentType)
import Network.URI (URI)
import Network.URI (URI (..))
import Network.Wai (Request, requestHeaders, vault)
import OpenID.Connect.Client.Provider (Discovery (Discovery, jwksUri), keysFromDiscovery)
import qualified OpenID.Connect.Client.Provider as OIDC
import OpenID.Connect.Client.Provider (Discovery, discovery, keysFromDiscovery)
import Servant.API ((:>))
import Servant.OpenApi (HasOpenApi (..))
import Servant.Server (HasContextEntry (getContextEntry), HasServer (..),
ServerError (..), err401, err403)
ServerError (..), err401, err403, err500)
import Servant.Server.Internal (DelayedIO, addAuthCheck, delayedFailFatal,
withRequest)
import System.Clock (TimeSpec (..))
Expand Down Expand Up @@ -122,28 +121,30 @@ data OIDCAuth

-- | Info needed for OIDC authorization & key cache
data OIDCConfig = OIDCConfig
{ oidcManager :: Manager -- ^ https manager
, oidcClientId :: Text -- ^ audience
, oidcIssuer :: URI -- ^ discovery uri
, oidcWorkaroundUri :: URI -- ^ temporary solution to openid-connect issue
, oidcKeyCache :: Cache Text JWKSet -- ^ cache - storing validation keys
{ oidcManager :: Manager -- ^ https manager
, oidcClientId :: Text -- ^ audience
, oidcIssuer :: URI -- ^ discovery uri
, oidcDiscoCache :: Cache () Discovery -- ^ cache - storing discovery information
, oidcKeyCache :: Cache () JWKSet -- ^ cache - storing validation keys
}

defaultOIDCCfg :: IO OIDCConfig
defaultOIDCCfg :: MonadIO m => m OIDCConfig
defaultOIDCCfg = do
cache <- Cache.newCache (Just 0)
discoCache <- liftIO $ Cache.newCache $ Just 0
keyCache <- liftIO $ Cache.newCache $ Just 0
mgr <- newTlsManager
return $ OIDCConfig
{ oidcManager = mgr
, oidcKeyCache = cache
, oidcWorkaroundUri = error "workaround uri not set"
, oidcDiscoCache = discoCache
, oidcKeyCache = keyCache
, oidcIssuer = error "discovery uri not set"
, oidcClientId = error "client id not set"
}

instance ( HasServer api context
, HasContextEntry context OIDCConfig
) => HasServer (OIDCAuth :> api) context where

type ServerT (OIDCAuth :> api) m = UserId -> ServerT api m

hoistServerWithContext _ pc nt s = hoistServerWithContext @api Proxy pc nt . s
Expand All @@ -153,51 +154,22 @@ instance ( HasServer api context
$ addAuthCheck sub
$ withRequest $ \req -> do

token <- maybe unauth401 return (getToken req)
token <- maybe unauth401 return $ getToken req

jws <- case decodeToken token of
Left jwtErr -> do
logWarn $ show jwtErr
unauth401
Right jws -> return jws
jwt <- getJWT token

let OIDCConfig {..} = getContextEntry context
let cfg = getContextEntry context

jwkSet <- liftIO (Cache.lookup oidcKeyCache "jwkSet") >>= \case
Nothing -> liftIO
( keysFromDiscovery
(https oidcManager)
(Discovery {jwksUri = OIDC.URI $ oidcWorkaroundUri})
) >>= \case
Left jwtErr -> do
logWarn $ show jwtErr
unauth401
Right (jwkSet, mbKeysExp) -> liftIO $ do
now <- getCurrentTime
Cache.insert' oidcKeyCache
(diffTime <$> (mbKeysExp <|> pure now) <*> pure now)
"jwkSet"
jwkSet
return jwkSet
Just jwkSet -> return jwkSet

claims <- liftIO
( runExceptT $
verifyClaims @_ @_ @JWTError
(jwtValidation oidcIssuer oidcClientId)
jwkSet
jws
) >>= \case
Left jwtErr -> do
logWarn $ show jwtErr
unauth401
Right claims -> return claims

uid <- case claims ^? unregisteredClaims . ix "object_guid" . _String of
Nothing -> do
logErr ("No object_guid found" :: Text)
unauth401
Just uid -> return uid
disco <- getDisco cfg

jwkSet <- getJWKSet cfg disco

claims <- getClaims cfg jwt jwkSet

uid <- maybe
(die ERROR unauth401 ("No object_guid found" :: Text))
return
$ claims ^? unregisteredClaims . ix "object_guid" . _String

liftIO $ sequence_ $ catMaybes
[ userIdVaultKey <?> req <&> flip writeIORef (Just uid)
Expand All @@ -209,27 +181,77 @@ instance ( HasServer api context
where
https mgr = (`httpLbs` mgr)

die :: Show err => Level -> DelayedIO b -> err -> DelayedIO b
die lvl fin err = liftIO (log' lvl ("web-template" :: Text) $ show err) >> fin

getToken :: Request -> Maybe ByteString
getToken r = lookup "Authorization" (requestHeaders r) >>= stripPrefix "Bearer "

decodeToken = decodeCompact @_ @JWTError . LB.fromStrict
expiration :: UTCTime -> Maybe UTCTime -> Maybe TimeSpec
expiration now ex = diffTime
<$> (ex <|> pure now)
<*> pure now
where
tTreshold = 60 -- consider token expired 'tTreshold' seconds earlier

logWarn = liftIO . log' WARNING ("web-template" :: Text)
diffTime :: UTCTime -> UTCTime -> TimeSpec
diffTime from to = let
diff = diffUTCTime from to - tTreshold
in max
TimeSpec {sec = 0, nsec = 0}
TimeSpec {sec = floor $ nominalDiffTimeToSeconds diff, nsec = 0}

logErr = liftIO . log' ERROR ("web-template" :: Text)
getJWT :: ByteString -> DelayedIO SignedJWT
getJWT = either (die WARNING unauth401) return . decodeToken
where
decodeToken = decodeCompact @_ @JWTError . LB.fromStrict

diffTime :: UTCTime -> UTCTime -> TimeSpec
diffTime from to = let
diff = diffUTCTime from to - tTreshold
in max
TimeSpec {sec = 0, nsec = 0}
TimeSpec {sec = floor $ nominalDiffTimeToSeconds diff, nsec = 0}
getDisco :: OIDCConfig -> DelayedIO Discovery
getDisco OIDCConfig {..} = liftIO (Cache.lookup oidcDiscoCache ())
>>= maybe
fetchDisco
return
where
tTreshold = 60 -- consider token expired 'tTreshold' seconds earlier
fetchDisco = liftIO (discovery (https oidcManager) (appWellKnown oidcIssuer))
>>= either
(die ERROR unauth500)
(uncurry discoSuccess)
where
appWellKnown u@URI {..} = u {uriPath = uriPath <> "/.well-known/openid-configuration"}

discoSuccess disco mbDiscoExp = liftIO $ do
now <- getCurrentTime
Cache.insert' oidcDiscoCache (expiration now mbDiscoExp) () disco
return disco

getJWKSet :: OIDCConfig -> Discovery -> DelayedIO JWKSet
getJWKSet OIDCConfig {..} disco = liftIO (Cache.lookup oidcKeyCache ())
>>= maybe
fetchKeys
return
where
fetchKeys = liftIO (keysFromDiscovery (https oidcManager) disco)
>>= either
(die ERROR unauth500)
(uncurry keysSuccess)
where
keysSuccess jwkSet mbKeysExp = liftIO $ do
now <- getCurrentTime
Cache.insert' oidcKeyCache (expiration now mbKeysExp) () jwkSet
return jwkSet

jwtValidation :: URI -> Text -> JWTValidationSettings
jwtValidation issuer audience = defaultJWTValidationSettings (const True)
& issuerPredicate .~ (\iss -> iss ^? uri == Just issuer)
& audiencePredicate .~ (\aud -> aud ^? string == Just audience)
getClaims :: OIDCConfig -> SignedJWT -> JWKSet -> DelayedIO ClaimsSet
getClaims OIDCConfig {..} jwt jwkSet = liftIO
(runExceptT $
verifyClaims @_ @_ @JWTError (jwtValidation oidcIssuer oidcClientId) jwkSet jwt
) >>= either
(die ERROR unauth401)
return
where
jwtValidation :: URI -> Text -> JWTValidationSettings
jwtValidation issuer audience = defaultJWTValidationSettings (const True)
& issuerPredicate .~ (\iss -> iss ^? uri == Just issuer)
& audiencePredicate .~ (\aud -> aud ^? string == Just audience)

instance HasOpenApi api => HasOpenApi (OIDCAuth :> api) where
toOpenApi _ = toOpenApi @api Proxy
Expand All @@ -245,12 +267,14 @@ instance HasOpenApi api => HasOpenApi (OIDCAuth :> api) where
--
-- Usage:
--
-- > type API = Permit '["User", "Owner"] :> (....)
-- route access permitted if user has at least 1 role from specified list
-- > type API = Permit '["user", "owner"] :> (....)
-- Route access permitted if user has at least 1 role from specified list
-- Takes user roles from vault (originated from jwt token)
data Permit (rs :: [Symbol])

instance ( HasServer api context
, KnownSymbols roles
, HasContextEntry context OIDCConfig
) => HasServer (Permit roles :> api) context where

type ServerT (Permit roles :> api) m = ServerT api m
Expand All @@ -265,14 +289,12 @@ instance ( HasServer api context

claims <- liftIO (readIORef pTokenRef) >>= maybe unauth401 return

let azp = claims
^?! unregisteredClaims
. ix "azp" . _String
let OIDCConfig {..} = getContextEntry context

let haveRoles = claims
^.. unregisteredClaims
. ix "resource_access"
. key azp
. key oidcClientId
. key "roles"
. values . _String

Expand All @@ -290,8 +312,6 @@ instance ( HasOpenApi api
descr = "Action not permitted. Allowed for: "
<> intercalate ", " (symbolsVal (Proxy :: Proxy roles))

data Permit (rs :: [Symbol])

class KnownSymbols (rs :: [Symbol]) where
symbolsVal :: p rs -> [Text]

Expand All @@ -315,3 +335,9 @@ unauth403 = delayedFailFatal $ err403
{ errBody = "{\"error\": \"Action not permitted\"}"
, errHeaders = [(hContentType, "application/json")]
}

unauth500 :: DelayedIO a
unauth500 = delayedFailFatal $ err500
{ errBody = "{\"error\": \"Internal server error\"}"
, errHeaders = [(hContentType, "application/json")]
}
4 changes: 1 addition & 3 deletions stack.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ packages:
- '.'

# necessary extra-deps that are not included in the BCD-LTS
extra-deps:
- openid-connect-0.1.0.0

extra-deps: []
allow-newer: true

flags: {}
Expand Down
Loading

0 comments on commit e0665b6

Please sign in to comment.