Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow empty contexts. #1

Open
wants to merge 21 commits into
base: maksbotan/configurable-combinator-errors
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Make error messages from combinators configurable
Currently there is no way for Servant users to customize formatting of
error messages that arise when combinators can't parse URL or request
body, apart from reimplementing those combinators for themselves or
using middlewares.

This commit adds a possibility to specify custom error formatters
through Context.

Fixes haskell-servant#685
  • Loading branch information
maksbotan committed Jun 25, 2020
commit 2a39394a70543ca3cab4e5f0fc8abb23e5780f97
2 changes: 1 addition & 1 deletion servant-client/test/Servant/WrappedApiSpec.hs
Original file line number Diff line number Diff line change
@@ -40,7 +40,7 @@ spec = describe "Servant.WrappedApiSpec" $ do
wrappedApiSpec

data WrappedApi where
WrappedApi :: (HasServer (api :: *) '[], Server api ~ Handler a,
WrappedApi :: (HasServer (api :: *) DefaultErrorFormatters, Server api ~ Handler a,
HasClient ClientM api, Client ClientM api ~ ClientM ()) =>
Proxy api -> WrappedApi

2 changes: 1 addition & 1 deletion servant-http-streams/test/Servant/ClientSpec.hs
Original file line number Diff line number Diff line change
@@ -400,7 +400,7 @@ failSpec = beforeAll (startWaiApp failServer) $ afterAll endWaiApp $ do
_ -> fail $ "expected InvalidContentTypeHeader, but got " <> show res

data WrappedApi where
WrappedApi :: (HasServer (api :: *) '[], Server api ~ Handler a,
WrappedApi :: (HasServer (api :: *) DefaultErrorFormatters, Server api ~ Handler a,
HasClient ClientM api, Client ClientM api ~ ClientM ()) =>
Proxy api -> WrappedApi

3 changes: 2 additions & 1 deletion servant-server/servant-server.cabal
Original file line number Diff line number Diff line change
@@ -50,9 +50,10 @@ library
Servant.Server.Internal.Context
Servant.Server.Internal.Delayed
Servant.Server.Internal.DelayedIO
Servant.Server.Internal.ErrorFormatter
Servant.Server.Internal.Handler
Servant.Server.Internal.Router
Servant.Server.Internal.RouteResult
Servant.Server.Internal.Router
Servant.Server.Internal.RoutingApplication
Servant.Server.Internal.ServerError
Servant.Server.StaticFiles
24 changes: 18 additions & 6 deletions servant-server/src/Servant/Server.hs
Original file line number Diff line number Diff line change
@@ -86,6 +86,18 @@ module Servant.Server
, err504
, err505

-- * Formatting of errors from combinators
, ErrorFormatter
, BodyParseErrorFormatter (..)
, defaulyBodyParseErrorFormatter
, URLParseErrorFormatter (..)
, defaultURLParseErrorFormatter
, HeaderParseErrorFormatter (..)
, defaultHeaderParseErrorFormatter

, DefaultErrorFormatters
, defaultErrorFormatters

-- * Re-exports
, Application
, Tagged (..)
@@ -126,8 +138,8 @@ import Servant.Server.Internal
-- > main :: IO ()
-- > main = Network.Wai.Handler.Warp.run 8080 app
--
serve :: (HasServer api '[]) => Proxy api -> Server api -> Application
serve p = serveWithContext p EmptyContext
serve :: (HasServer api DefaultErrorFormatters) => Proxy api -> Server api -> Application
serve p = serveWithContext p defaultErrorFormatters

serveWithContext :: (HasServer api context)
=> Proxy api -> Context context -> Server api -> Application
@@ -154,9 +166,9 @@ serveWithContext p context server =
-- >>> let nt x = return (runReader x "hi")
-- >>> let mainServer = hoistServer readerApi nt readerServer :: Server ReaderAPI
--
hoistServer :: (HasServer api '[]) => Proxy api
hoistServer :: (HasServer api DefaultErrorFormatters) => Proxy api
-> (forall x. m x -> n x) -> ServerT api m -> ServerT api n
hoistServer p = hoistServerWithContext p (Proxy :: Proxy '[])
hoistServer p = hoistServerWithContext p (Proxy :: Proxy DefaultErrorFormatters)

-- | The function 'layout' produces a textual description of the internal
-- router layout for debugging purposes. Note that the router layout is
@@ -209,8 +221,8 @@ hoistServer p = hoistServerWithContext p (Proxy :: Proxy '[])
-- that one takes precedence. If both parts fail, the \"better\" error
-- code will be returned.
--
layout :: (HasServer api '[]) => Proxy api -> Text
layout p = layoutWithContext p EmptyContext
layout :: (HasServer api DefaultErrorFormatters) => Proxy api -> Text
layout p = layoutWithContext p defaultErrorFormatters

-- | Variant of 'layout' that takes an additional 'Context'.
layoutWithContext :: (HasServer api context)
4 changes: 2 additions & 2 deletions servant-server/src/Servant/Server/Generic.hs
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ type AsServer = AsServerT Handler
-- | Transform a record of routes into a WAI 'Application'.
genericServe
:: forall routes.
( HasServer (ToServantApi routes) '[]
( HasServer (ToServantApi routes) DefaultErrorFormatters
, GenericServant routes AsServer
, Server (ToServantApi routes) ~ ToServant routes AsServer
)
@@ -48,7 +48,7 @@ genericServeT
:: forall (routes :: * -> *) (m :: * -> *).
( GenericServant routes (AsServerT m)
, GenericServant routes AsApi
, HasServer (ToServantApi routes) '[]
, HasServer (ToServantApi routes) DefaultErrorFormatters
, ServerT (ToServantApi routes) m ~ ToServant routes (AsServerT m)
)
=> (forall a. m a -> Handler a) -- ^ 'hoistServer' argument to come back to 'Handler'
86 changes: 55 additions & 31 deletions servant-server/src/Servant/Server/Internal.hs
Original file line number Diff line number Diff line change
@@ -24,9 +24,10 @@ module Servant.Server.Internal
, module Servant.Server.Internal.Context
, module Servant.Server.Internal.Delayed
, module Servant.Server.Internal.DelayedIO
, module Servant.Server.Internal.ErrorFormatter
, module Servant.Server.Internal.Handler
, module Servant.Server.Internal.Router
, module Servant.Server.Internal.RouteResult
, module Servant.Server.Internal.Router
, module Servant.Server.Internal.RoutingApplication
, module Servant.Server.Internal.ServerError
) where
@@ -95,6 +96,7 @@ import Servant.Server.Internal.BasicAuth
import Servant.Server.Internal.Context
import Servant.Server.Internal.Delayed
import Servant.Server.Internal.DelayedIO
import Servant.Server.Internal.ErrorFormatter
import Servant.Server.Internal.Handler
import Servant.Server.Internal.Router
import Servant.Server.Internal.RouteResult
@@ -168,7 +170,10 @@ instance (HasServer a context, HasServer b context) => HasServer (a :<|> b) cont
-- > server = getBook
-- > where getBook :: Text -> Handler Book
-- > getBook isbn = ...
instance (KnownSymbol capture, FromHttpApiData a, HasServer api context, SBoolI (FoldLenient mods))
instance (KnownSymbol capture, FromHttpApiData a
, HasServer api context, SBoolI (FoldLenient mods)
, HasContextEntry context URLParseErrorFormatter
)
=> HasServer (Capture' mods capture a :> api) context where

type ServerT (Capture' mods capture a :> api) m =
@@ -182,10 +187,13 @@ instance (KnownSymbol capture, FromHttpApiData a, HasServer api context, SBoolI
context
(addCapture d $ \ txt -> case ( sbool :: SBool (FoldLenient mods)
, parseUrlPiece txt :: Either T.Text a) of
(SFalse, Left e) -> delayedFail err400 { errBody = cs e }
(SFalse, Left e) -> delayedFail $ urlParseErrorFormatter rep $ cs e
(SFalse, Right v) -> return v
(STrue, piece) -> return $ (either (Left . cs) Right) piece
)
where
rep = typeRep (Proxy :: Proxy Capture')
urlParseErrorFormatter = getUrlParseErrorFormatter $ getContextEntry context

-- | If you use 'CaptureAll' in one of the endpoints for your API,
-- this automatically requires your server-side handler to be a
@@ -204,7 +212,10 @@ instance (KnownSymbol capture, FromHttpApiData a, HasServer api context, SBoolI
-- > server = getSourceFile
-- > where getSourceFile :: [Text] -> Handler Book
-- > getSourceFile pathSegments = ...
instance (KnownSymbol capture, FromHttpApiData a, HasServer api context)
instance (KnownSymbol capture, FromHttpApiData a
, HasServer api context
, HasContextEntry context URLParseErrorFormatter
)
=> HasServer (CaptureAll capture a :> api) context where

type ServerT (CaptureAll capture a :> api) m =
@@ -217,10 +228,12 @@ instance (KnownSymbol capture, FromHttpApiData a, HasServer api context)
route (Proxy :: Proxy api)
context
(addCapture d $ \ txts -> case parseUrlPieces txts of
Left _ -> delayedFail err400
Left e -> delayedFail $ urlParseErrorFormatter rep $ cs e
Right v -> return v
)

where
rep = typeRep (Proxy :: Proxy CaptureAll)
urlParseErrorFormatter = getUrlParseErrorFormatter $ getContextEntry context

allowedMethodHead :: Method -> Request -> Bool
allowedMethodHead method request = method == methodGet && requestMethod request == methodHead
@@ -388,6 +401,7 @@ streamRouter splitHeaders method status framingproxy ctypeproxy action = leafRou
instance
(KnownSymbol sym, FromHttpApiData a, HasServer api context
, SBoolI (FoldRequired mods), SBoolI (FoldLenient mods)
, HasContextEntry context HeaderParseErrorFormatter
)
=> HasServer (Header' mods sym a :> api) context where
------
@@ -399,6 +413,9 @@ instance
route Proxy context subserver = route (Proxy :: Proxy api) context $
subserver `addHeaderCheck` withRequest headerCheck
where
rep = typeRep (Proxy :: Proxy Header')
headerParseErrorFormatter = getHeaderParseErrorFormatter $ getContextEntry context

headerName :: IsString n => n
headerName = fromString $ symbolVal (Proxy :: Proxy sym)

@@ -409,15 +426,13 @@ instance
mev :: Maybe (Either T.Text a)
mev = fmap parseHeader $ lookup headerName (requestHeaders req)

errReq = delayedFailFatal err400
{ errBody = "Header " <> headerName <> " is required"
}
errReq = delayedFailFatal $ headerParseErrorFormatter rep
$ "Header " <> headerName <> " is required"

errSt e = delayedFailFatal err400
{ errBody = cs $ "Error parsing header "
<> headerName
<> " failed: " <> e
}
errSt e = delayedFailFatal $ headerParseErrorFormatter rep
$ cs $ "Error parsing header "
<> headerName
<> " failed: " <> e

-- | If you use @'QueryParam' "author" Text@ in one of the endpoints for your API,
-- this automatically requires your server-side handler to be a function
@@ -443,6 +458,7 @@ instance
instance
( KnownSymbol sym, FromHttpApiData a, HasServer api context
, SBoolI (FoldRequired mods), SBoolI (FoldLenient mods)
, HasContextEntry context URLParseErrorFormatter
)
=> HasServer (QueryParam' mods sym a :> api) context where
------
@@ -455,21 +471,22 @@ instance
let querytext = queryToQueryText . queryString
paramname = cs $ symbolVal (Proxy :: Proxy sym)

rep = typeRep (Proxy :: Proxy QueryParam')
urlParseErrorFormatter = getUrlParseErrorFormatter $ getContextEntry context

parseParam :: Request -> DelayedIO (RequestArgument mods a)
parseParam req =
unfoldRequestArgument (Proxy :: Proxy mods) errReq errSt mev
where
mev :: Maybe (Either T.Text a)
mev = fmap parseQueryParam $ join $ lookup paramname $ querytext req

errReq = delayedFailFatal err400
{ errBody = cs $ "Query parameter " <> paramname <> " is required"
}
errReq = delayedFailFatal $ urlParseErrorFormatter rep
$ cs $ "Query parameter " <> paramname <> " is required"

errSt e = delayedFailFatal err400
{ errBody = cs $ "Error parsing query parameter "
<> paramname <> " failed: " <> e
}
errSt e = delayedFailFatal $ urlParseErrorFormatter rep
$ cs $ "Error parsing query parameter "
<> paramname <> " failed: " <> e

delayed = addParameterCheck subserver . withRequest $ \req ->
parseParam req
@@ -495,7 +512,8 @@ instance
-- > server = getBooksBy
-- > where getBooksBy :: [Text] -> Handler [Book]
-- > getBooksBy authors = ...return all books by these authors...
instance (KnownSymbol sym, FromHttpApiData a, HasServer api context)
instance (KnownSymbol sym, FromHttpApiData a, HasServer api context
, HasContextEntry context URLParseErrorFormatter)
=> HasServer (QueryParams sym a :> api) context where

type ServerT (QueryParams sym a :> api) m =
@@ -506,21 +524,23 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer api context)
route Proxy context subserver = route (Proxy :: Proxy api) context $
subserver `addParameterCheck` withRequest paramsCheck
where
rep = typeRep (Proxy :: Proxy QueryParams)
urlParseErrorFormatter = getUrlParseErrorFormatter $ getContextEntry context

paramname = cs $ symbolVal (Proxy :: Proxy sym)
paramsCheck req =
case partitionEithers $ fmap parseQueryParam params of
([], parsed) -> return parsed
(errs, _) -> delayedFailFatal err400
{ errBody = cs $ "Error parsing query parameter(s) "
<> paramname <> " failed: "
<> T.intercalate ", " errs
}
(errs, _) -> delayedFailFatal $ urlParseErrorFormatter rep
$ cs $ "Error parsing query parameter(s) "
<> paramname <> " failed: "
<> T.intercalate ", " errs
where
params :: [T.Text]
params = mapMaybe snd
. filter (looksLikeParam . fst)
. queryToQueryText
. queryString
. queryToQueryText
. queryString
$ req

looksLikeParam name = name == paramname || name == (paramname <> "[]")
@@ -588,7 +608,7 @@ instance HasServer Raw context where
-- The @Content-Type@ header is inspected, and the list provided is used to
-- attempt deserialization. If the request does not have a @Content-Type@
-- header, it is treated as @application/octet-stream@ (as specified in
-- <http://tools.ietf.org/html/rfc7231#section-3.1.1.5 RFC7231>.
-- [RFC 7231 section 3.1.1.5](http://tools.ietf.org/html/rfc7231#section-3.1.1.5 RFC7231).
-- This lets servant worry about extracting it from the request and turning
-- it into a value of the type you specify.
--
@@ -604,6 +624,7 @@ instance HasServer Raw context where
-- > where postBook :: Book -> Handler Book
-- > postBook book = ...insert into your db...
instance ( AllCTUnrender list a, HasServer api context, SBoolI (FoldLenient mods)
, HasContextEntry context BodyParseErrorFormatter
) => HasServer (ReqBody' mods list a :> api) context where

type ServerT (ReqBody' mods list a :> api) m =
@@ -615,6 +636,9 @@ instance ( AllCTUnrender list a, HasServer api context, SBoolI (FoldLenient mods
= route (Proxy :: Proxy api) context $
addBodyCheck subserver ctCheck bodyCheck
where
rep = typeRep (Proxy :: Proxy ReqBody')
bodyParserErrorFormatter = getBodyParseErrorFormatter $ getContextEntry context

-- Content-Type check, we only lookup we can try to parse the request body
ctCheck = withRequest $ \ request -> do
-- See HTTP RFC 2616, section 7.2.1
@@ -633,7 +657,7 @@ instance ( AllCTUnrender list a, HasServer api context, SBoolI (FoldLenient mods
case sbool :: SBool (FoldLenient mods) of
STrue -> return mrqbody
SFalse -> case mrqbody of
Left e -> delayedFailFatal err400 { errBody = cs e }
Left e -> delayedFailFatal $ bodyParserErrorFormatter rep e
Right v -> return v

instance
63 changes: 63 additions & 0 deletions servant-server/src/Servant/Server/Internal/ErrorFormatter.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{-# LANGUAGE DataKinds #-}

-- | TODO: more documentation for this module
module Servant.Server.Internal.ErrorFormatter
where

import Data.String.Conversions
(cs)
import Data.Typeable

import Servant.Server.Internal.Context
import Servant.Server.Internal.ServerError

-- | 'Context' that contains default formatters for all error types.
--
-- Default formatters will just return HTTP 400 status code with error
-- message as response body.
type DefaultErrorFormatters = '[BodyParseErrorFormatter, URLParseErrorFormatter, HeaderParseErrorFormatter]

defaultErrorFormatters :: Context DefaultErrorFormatters
defaultErrorFormatters =
defaulyBodyParseErrorFormatter
:. defaultURLParseErrorFormatter
:. defaultHeaderParseErrorFormatter
:. EmptyContext

-- | A custom formatter for errors produced by parsing combinators like
-- 'Servant.API.ReqBody' or 'Servant.API.Capture'.
--
-- A 'TypeRep' argument described the concrete combinator that raised
-- the error, allowing formatter to customize the message for different
-- combinators.
type ErrorFormatter = TypeRep -> String -> ServerError

-- | Formatter for errors that occur while parsing request body.
newtype BodyParseErrorFormatter = BodyParseErrorFormatter
{ getBodyParseErrorFormatter :: ErrorFormatter
}

defaulyBodyParseErrorFormatter :: BodyParseErrorFormatter
defaulyBodyParseErrorFormatter = BodyParseErrorFormatter defaultErrorFormatter

-- | Formatter for errors that occur while parsing URL parts, like 'Servant.API.Capture' or
-- 'Servant.API.QueryParam'.
newtype URLParseErrorFormatter = URLParseErrorFormatter
{ getUrlParseErrorFormatter :: ErrorFormatter
}

defaultURLParseErrorFormatter :: URLParseErrorFormatter
defaultURLParseErrorFormatter = URLParseErrorFormatter defaultErrorFormatter

-- | Formatter for errors that occur while parsing HTTP headers.
newtype HeaderParseErrorFormatter = HeaderParseErrorFormatter
{ getHeaderParseErrorFormatter :: ErrorFormatter
}

defaultHeaderParseErrorFormatter :: HeaderParseErrorFormatter
defaultHeaderParseErrorFormatter = HeaderParseErrorFormatter defaultErrorFormatter

-- Internal

defaultErrorFormatter :: ErrorFormatter
defaultErrorFormatter _ e = err400 { errBody = cs e }
Loading