Skip to content

Commit

Permalink
Make error messages from combinators configurable
Browse files Browse the repository at this point in the history
Currently there is no way for Servant users to customize formatting of
error messages that arise when combinators can't parse URL or request
body, apart from reimplementing those combinators for themselves or
using middlewares.

This commit adds a possibility to specify custom error formatters
through Context.

Fixes haskell-servant#685
  • Loading branch information
maksbotan committed Jul 17, 2020
1 parent 7f4ae61 commit df2f164
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 75 deletions.
3 changes: 2 additions & 1 deletion servant-server/servant-server.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ library
Servant.Server.Internal.Context
Servant.Server.Internal.Delayed
Servant.Server.Internal.DelayedIO
Servant.Server.Internal.ErrorFormatter
Servant.Server.Internal.Handler
Servant.Server.Internal.Router
Servant.Server.Internal.RouteResult
Servant.Server.Internal.Router
Servant.Server.Internal.RoutingApplication
Servant.Server.Internal.ServerError
Servant.Server.StaticFiles
Expand Down
32 changes: 30 additions & 2 deletions servant-server/src/Servant/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | This module lets you implement 'Server's for defined APIs. You'll
-- most likely just need 'serve'.
Expand Down Expand Up @@ -35,6 +36,8 @@ module Servant.Server
-- * Context
, Context(..)
, HasContextEntry(getContextEntry)
, type (.++)
, (.++)
-- ** NamedContext
, NamedContext(..)
, descendIntoNamedContext
Expand Down Expand Up @@ -86,6 +89,24 @@ module Servant.Server
, err504
, err505

-- * Formatting of errors from combinators
--
-- | You can configure how Servant will render errors that occur while parsing the request.

, ErrorFormatter
, NotFoundErrorFormatter
, ErrorFormatters

, bodyParserErrorFormatter
, urlParseErrorFormatter
, headerParseErrorFormatter
, notFoundErrorFormatter

, DefaultErrorFormatters
, defaultErrorFormatters

, getAcceptHeader

-- * Re-exports
, Application
, Tagged (..)
Expand Down Expand Up @@ -129,10 +150,17 @@ import Servant.Server.Internal
serve :: (HasServer api '[]) => Proxy api -> Server api -> Application
serve p = serveWithContext p EmptyContext

serveWithContext :: (HasServer api context)
-- | Like 'serve', but allows you to pass custom context.
--
-- 'defaultErrorFormatters' will always be appended to the end of the passed context,
-- but if you pass your own formatter, it will override the default one.
serveWithContext :: ( HasServer api context
, HasContextEntry (context .++ DefaultErrorFormatters) ErrorFormatters )
=> Proxy api -> Context context -> Server api -> Application
serveWithContext p context server =
toApplication (runRouter (route p context (emptyDelayed (Route server))))
toApplication (runRouter format404 (route p context (emptyDelayed (Route server))))
where
format404 = notFoundErrorFormatter . getContextEntry . mkContextWithErrorFormatter $ context

-- | Hoist server implementation.
--
Expand Down
1 change: 1 addition & 0 deletions servant-server/src/Servant/Server/Generic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ genericServeTWithContext
( GenericServant routes (AsServerT m)
, GenericServant routes AsApi
, HasServer (ToServantApi routes) ctx
, HasContextEntry (ctx .++ DefaultErrorFormatters) ErrorFormatters
, ServerT (ToServantApi routes) m ~ ToServant routes (AsServerT m)
)
=> (forall a. m a -> Handler a) -- ^ 'hoistServer' argument to come back to 'Handler'
Expand Down
114 changes: 71 additions & 43 deletions servant-server/src/Servant/Server/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ module Servant.Server.Internal
, module Servant.Server.Internal.Context
, module Servant.Server.Internal.Delayed
, module Servant.Server.Internal.DelayedIO
, module Servant.Server.Internal.ErrorFormatter
, module Servant.Server.Internal.Handler
, module Servant.Server.Internal.Router
, module Servant.Server.Internal.RouteResult
Expand Down Expand Up @@ -95,6 +96,7 @@ import Servant.Server.Internal.BasicAuth
import Servant.Server.Internal.Context
import Servant.Server.Internal.Delayed
import Servant.Server.Internal.DelayedIO
import Servant.Server.Internal.ErrorFormatter
import Servant.Server.Internal.Handler
import Servant.Server.Internal.Router
import Servant.Server.Internal.RouteResult
Expand Down Expand Up @@ -168,7 +170,10 @@ instance (HasServer a context, HasServer b context) => HasServer (a :<|> b) cont
-- > server = getBook
-- > where getBook :: Text -> Handler Book
-- > getBook isbn = ...
instance (KnownSymbol capture, FromHttpApiData a, HasServer api context, SBoolI (FoldLenient mods))
instance (KnownSymbol capture, FromHttpApiData a
, HasServer api context, SBoolI (FoldLenient mods)
, HasContextEntry (MkContextWithErrorFormatter context) ErrorFormatters
)
=> HasServer (Capture' mods capture a :> api) context where

type ServerT (Capture' mods capture a :> api) m =
Expand All @@ -180,12 +185,15 @@ instance (KnownSymbol capture, FromHttpApiData a, HasServer api context, SBoolI
CaptureRouter $
route (Proxy :: Proxy api)
context
(addCapture d $ \ txt -> case ( sbool :: SBool (FoldLenient mods)
, parseUrlPiece txt :: Either T.Text a) of
(SFalse, Left e) -> delayedFail err400 { errBody = cs e }
(SFalse, Right v) -> return v
(STrue, piece) -> return $ (either (Left . cs) Right) piece
)
(addCapture d $ \ txt -> withRequest $ \ request ->
case ( sbool :: SBool (FoldLenient mods)
, parseUrlPiece txt :: Either T.Text a) of
(SFalse, Left e) -> delayedFail $ formatError rep request $ cs e
(SFalse, Right v) -> return v
(STrue, piece) -> return $ (either (Left . cs) Right) piece)
where
rep = typeRep (Proxy :: Proxy Capture')
formatError = urlParseErrorFormatter $ getContextEntry (mkContextWithErrorFormatter context)

-- | If you use 'CaptureAll' in one of the endpoints for your API,
-- this automatically requires your server-side handler to be a
Expand All @@ -204,7 +212,10 @@ instance (KnownSymbol capture, FromHttpApiData a, HasServer api context, SBoolI
-- > server = getSourceFile
-- > where getSourceFile :: [Text] -> Handler Book
-- > getSourceFile pathSegments = ...
instance (KnownSymbol capture, FromHttpApiData a, HasServer api context)
instance (KnownSymbol capture, FromHttpApiData a
, HasServer api context
, HasContextEntry (MkContextWithErrorFormatter context) ErrorFormatters
)
=> HasServer (CaptureAll capture a :> api) context where

type ServerT (CaptureAll capture a :> api) m =
Expand All @@ -216,11 +227,14 @@ instance (KnownSymbol capture, FromHttpApiData a, HasServer api context)
CaptureAllRouter $
route (Proxy :: Proxy api)
context
(addCapture d $ \ txts -> case parseUrlPieces txts of
Left _ -> delayedFail err400
Right v -> return v
(addCapture d $ \ txts -> withRequest $ \ request ->
case parseUrlPieces txts of
Left e -> delayedFail $ formatError rep request $ cs e
Right v -> return v
)

where
rep = typeRep (Proxy :: Proxy CaptureAll)
formatError = urlParseErrorFormatter $ getContextEntry (mkContextWithErrorFormatter context)

allowedMethodHead :: Method -> Request -> Bool
allowedMethodHead method request = method == methodGet && requestMethod request == methodHead
Expand All @@ -240,10 +254,10 @@ methodCheck method request
-- body check is no longer an option. However, we now run the accept
-- check before the body check and can therefore afford to make it
-- recoverable.
acceptCheck :: (AllMime list) => Proxy list -> B.ByteString -> DelayedIO ()
acceptCheck :: (AllMime list) => Proxy list -> AcceptHeader -> DelayedIO ()
acceptCheck proxy accH
| canHandleAcceptH proxy (AcceptHeader accH) = return ()
| otherwise = delayedFail err406
| canHandleAcceptH proxy accH = return ()
| otherwise = delayedFail err406

methodRouter :: (AllCTRender ctypes a)
=> (b -> ([(HeaderName, B.ByteString)], a))
Expand All @@ -253,12 +267,12 @@ methodRouter :: (AllCTRender ctypes a)
methodRouter splitHeaders method proxy status action = leafRouter route'
where
route' env request respond =
let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request
let accH = getAcceptHeader request
in runAction (action `addMethodCheck` methodCheck method request
`addAcceptCheck` acceptCheck proxy accH
) env request respond $ \ output -> do
let (headers, b) = splitHeaders output
case handleAcceptH proxy (AcceptHeader accH) b of
case handleAcceptH proxy accH b of
Nothing -> FailFatal err406 -- this should not happen (checked before), so we make it fatal if it does
Just (contentT, body) ->
let bdy = if allowedMethodHead method request then "" else body
Expand Down Expand Up @@ -343,7 +357,7 @@ streamRouter :: forall ctype a c chunk env framing. (MimeRender ctype chunk, Fra
-> Delayed env (Handler c)
-> Router env
streamRouter splitHeaders method status framingproxy ctypeproxy action = leafRouter $ \env request respond ->
let accH = fromMaybe ct_wildcard $ lookup hAccept $ requestHeaders request
let AcceptHeader accH = getAcceptHeader request
cmediatype = NHM.matchAccept [contentType ctypeproxy] accH
accCheck = when (isNothing cmediatype) $ delayedFail err406
contentHeader = (hContentType, NHM.renderHeader . maybeToList $ cmediatype)
Expand Down Expand Up @@ -388,6 +402,7 @@ streamRouter splitHeaders method status framingproxy ctypeproxy action = leafRou
instance
(KnownSymbol sym, FromHttpApiData a, HasServer api context
, SBoolI (FoldRequired mods), SBoolI (FoldLenient mods)
, HasContextEntry (MkContextWithErrorFormatter context) ErrorFormatters
)
=> HasServer (Header' mods sym a :> api) context where
------
Expand All @@ -399,6 +414,9 @@ instance
route Proxy context subserver = route (Proxy :: Proxy api) context $
subserver `addHeaderCheck` withRequest headerCheck
where
rep = typeRep (Proxy :: Proxy Header')
formatError = headerParseErrorFormatter $ getContextEntry (mkContextWithErrorFormatter context)

headerName :: IsString n => n
headerName = fromString $ symbolVal (Proxy :: Proxy sym)

Expand All @@ -409,15 +427,13 @@ instance
mev :: Maybe (Either T.Text a)
mev = fmap parseHeader $ lookup headerName (requestHeaders req)

errReq = delayedFailFatal err400
{ errBody = "Header " <> headerName <> " is required"
}
errReq = delayedFailFatal $ formatError rep req
$ "Header " <> headerName <> " is required"

errSt e = delayedFailFatal err400
{ errBody = cs $ "Error parsing header "
<> headerName
<> " failed: " <> e
}
errSt e = delayedFailFatal $ formatError rep req
$ cs $ "Error parsing header "
<> headerName
<> " failed: " <> e

-- | If you use @'QueryParam' "author" Text@ in one of the endpoints for your API,
-- this automatically requires your server-side handler to be a function
Expand All @@ -443,6 +459,7 @@ instance
instance
( KnownSymbol sym, FromHttpApiData a, HasServer api context
, SBoolI (FoldRequired mods), SBoolI (FoldLenient mods)
, HasContextEntry (MkContextWithErrorFormatter context) ErrorFormatters
)
=> HasServer (QueryParam' mods sym a :> api) context where
------
Expand All @@ -455,21 +472,22 @@ instance
let querytext = queryToQueryText . queryString
paramname = cs $ symbolVal (Proxy :: Proxy sym)

rep = typeRep (Proxy :: Proxy QueryParam')
formatError = urlParseErrorFormatter $ getContextEntry (mkContextWithErrorFormatter context)

parseParam :: Request -> DelayedIO (RequestArgument mods a)
parseParam req =
unfoldRequestArgument (Proxy :: Proxy mods) errReq errSt mev
where
mev :: Maybe (Either T.Text a)
mev = fmap parseQueryParam $ join $ lookup paramname $ querytext req

errReq = delayedFailFatal err400
{ errBody = cs $ "Query parameter " <> paramname <> " is required"
}
errReq = delayedFailFatal $ formatError rep req
$ cs $ "Query parameter " <> paramname <> " is required"

errSt e = delayedFailFatal err400
{ errBody = cs $ "Error parsing query parameter "
<> paramname <> " failed: " <> e
}
errSt e = delayedFailFatal $ formatError rep req
$ cs $ "Error parsing query parameter "
<> paramname <> " failed: " <> e

delayed = addParameterCheck subserver . withRequest $ \req ->
parseParam req
Expand All @@ -495,7 +513,8 @@ instance
-- > server = getBooksBy
-- > where getBooksBy :: [Text] -> Handler [Book]
-- > getBooksBy authors = ...return all books by these authors...
instance (KnownSymbol sym, FromHttpApiData a, HasServer api context)
instance (KnownSymbol sym, FromHttpApiData a, HasServer api context
, HasContextEntry (MkContextWithErrorFormatter context) ErrorFormatters)
=> HasServer (QueryParams sym a :> api) context where

type ServerT (QueryParams sym a :> api) m =
Expand All @@ -506,21 +525,23 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context)
route Proxy context subserver = route (Proxy :: Proxy api) context $
subserver `addParameterCheck` withRequest paramsCheck
where
rep = typeRep (Proxy :: Proxy QueryParams)
formatError = urlParseErrorFormatter $ getContextEntry (mkContextWithErrorFormatter context)

paramname = cs $ symbolVal (Proxy :: Proxy sym)
paramsCheck req =
case partitionEithers $ fmap parseQueryParam params of
([], parsed) -> return parsed
(errs, _) -> delayedFailFatal err400
{ errBody = cs $ "Error parsing query parameter(s) "
<> paramname <> " failed: "
<> T.intercalate ", " errs
}
(errs, _) -> delayedFailFatal $ formatError rep req
$ cs $ "Error parsing query parameter(s) "
<> paramname <> " failed: "
<> T.intercalate ", " errs
where
params :: [T.Text]
params = mapMaybe snd
. filter (looksLikeParam . fst)
. queryToQueryText
. queryString
. queryToQueryText
. queryString
$ req

looksLikeParam name = name == paramname || name == (paramname <> "[]")
Expand Down Expand Up @@ -588,7 +609,7 @@ instance HasServer Raw context where
-- The @Content-Type@ header is inspected, and the list provided is used to
-- attempt deserialization. If the request does not have a @Content-Type@
-- header, it is treated as @application/octet-stream@ (as specified in
-- <http://tools.ietf.org/html/rfc7231#section-3.1.1.5 RFC7231>.
-- [RFC 7231 section 3.1.1.5](http://tools.ietf.org/html/rfc7231#section-3.1.1.5)).
-- This lets servant worry about extracting it from the request and turning
-- it into a value of the type you specify.
--
Expand All @@ -604,6 +625,7 @@ instance HasServer Raw context where
-- > where postBook :: Book -> Handler Book
-- > postBook book = ...insert into your db...
instance ( AllCTUnrender list a, HasServer api context, SBoolI (FoldLenient mods)
, HasContextEntry (MkContextWithErrorFormatter context) ErrorFormatters
) => HasServer (ReqBody' mods list a :> api) context where

type ServerT (ReqBody' mods list a :> api) m =
Expand All @@ -615,6 +637,9 @@ instance ( AllCTUnrender list a, HasServer api context, SBoolI (FoldLenient mods
= route (Proxy :: Proxy api) context $
addBodyCheck subserver ctCheck bodyCheck
where
rep = typeRep (Proxy :: Proxy ReqBody')
formatError = bodyParserErrorFormatter $ getContextEntry (mkContextWithErrorFormatter context)

-- Content-Type check, we only lookup we can try to parse the request body
ctCheck = withRequest $ \ request -> do
-- See HTTP RFC 2616, section 7.2.1
Expand All @@ -633,7 +658,7 @@ instance ( AllCTUnrender list a, HasServer api context, SBoolI (FoldLenient mods
case sbool :: SBool (FoldLenient mods) of
STrue -> return mrqbody
SFalse -> case mrqbody of
Left e -> delayedFailFatal err400 { errBody = cs e }
Left e -> delayedFailFatal $ formatError rep request e
Right v -> return v

instance
Expand Down Expand Up @@ -761,6 +786,9 @@ instance ( KnownSymbol realm
ct_wildcard :: B.ByteString
ct_wildcard = "*" <> "/" <> "*" -- Because CPP

getAcceptHeader :: Request -> AcceptHeader
getAcceptHeader = AcceptHeader . fromMaybe ct_wildcard . lookup hAccept . requestHeaders

-- * General Authentication


Expand Down
Loading

0 comments on commit df2f164

Please sign in to comment.