Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Avoid casting to table type when select= and media type handler are used #3224

Merged
merged 2 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- #3149, Misleading "Starting PostgREST.." logs on schema cache reloading - @steve-chavez
- #2815, Build static executable with GSSAPI support - @wolfgangwalther
- #3205, Fix wrong subquery error returning a status of 400 Bad Request - @steve-chavez
- #3224, Return status code 406 for non-accepted media type instead of code 415 - @wolfgangwalther
- #3160, Fix using select= query parameter for custom media type handlers - @wolfgangwalther

### Deprecated

Expand Down
2 changes: 1 addition & 1 deletion src/PostgREST/Error.hs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ instance PgrstError ApiRequestError where
status AggregatesNotAllowed{} = HTTP.status400
status AmbiguousRelBetween{} = HTTP.status300
status AmbiguousRpc{} = HTTP.status300
status MediaTypeError{} = HTTP.status415
status MediaTypeError{} = HTTP.status406
status InvalidBody{} = HTTP.status400
status InvalidFilters = HTTP.status405
status InvalidPreferences{} = HTTP.status400
Expand Down
32 changes: 18 additions & 14 deletions src/PostgREST/Plan.hs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ data WrappedReadPlan = WrappedReadPlan {
, wrTxMode :: SQL.Mode
, wrHandler :: MediaHandler
, wrMedia :: MediaType
, wrIdent :: QualifiedIdentifier
}

data MutateReadPlan = MutateReadPlan {
Expand All @@ -106,7 +105,6 @@ data MutateReadPlan = MutateReadPlan {
, mrTxMode :: SQL.Mode
, mrHandler :: MediaHandler
, mrMedia :: MediaType
, mrIdent :: QualifiedIdentifier
}

data CallReadPlan = CallReadPlan {
Expand All @@ -116,7 +114,6 @@ data CallReadPlan = CallReadPlan {
, crProc :: Routine
, crHandler :: MediaHandler
, crMedia :: MediaType
, crIdent :: QualifiedIdentifier
}

data InspectPlan = InspectPlan {
Expand All @@ -127,17 +124,17 @@ data InspectPlan = InspectPlan {
wrappedReadPlan :: QualifiedIdentifier -> AppConfig -> SchemaCache -> ApiRequest -> Either Error WrappedReadPlan
wrappedReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferences{..},..} = do
rPlan <- readPlan identifier conf sCache apiRequest
(hdler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaHandlers sCache)
(handler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaHandlers sCache) (hasDefaultSelect rPlan)
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
return $ WrappedReadPlan rPlan SQL.Read hdler mediaType identifier
return $ WrappedReadPlan rPlan SQL.Read handler mediaType

mutateReadPlan :: Mutation -> ApiRequest -> QualifiedIdentifier -> AppConfig -> SchemaCache -> Either Error MutateReadPlan
mutateReadPlan mutation apiRequest@ApiRequest{iPreferences=Preferences{..},..} identifier conf sCache = do
rPlan <- readPlan identifier conf sCache apiRequest
mPlan <- mutatePlan mutation identifier apiRequest sCache rPlan
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
(hdler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaHandlers sCache)
return $ MutateReadPlan rPlan mPlan SQL.Write hdler mediaType identifier
(handler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest identifier iAcceptMediaType (dbMediaHandlers sCache) (hasDefaultSelect rPlan)
return $ MutateReadPlan rPlan mPlan SQL.Write handler mediaType

callReadPlan :: QualifiedIdentifier -> AppConfig -> SchemaCache -> ApiRequest -> InvokeMethod -> Either Error CallReadPlan
callReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferences{..},..} invMethod = do
Expand All @@ -161,12 +158,16 @@ callReadPlan identifier conf sCache apiRequest@ApiRequest{iPreferences=Preferenc
(InvPost, Routine.Immutable) -> SQL.Read
(InvPost, Routine.Volatile) -> SQL.Write
cPlan = callPlan proc apiRequest paramKeys args rPlan
(hdler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest relIdentifier iAcceptMediaType (dbMediaHandlers sCache)
(handler, mediaType) <- mapLeft ApiRequestError $ negotiateContent conf apiRequest relIdentifier iAcceptMediaType (dbMediaHandlers sCache) (hasDefaultSelect rPlan)
if not (null invalidPrefs) && preferHandling == Just Strict then Left $ ApiRequestError $ InvalidPreferences invalidPrefs else Right ()
return $ CallReadPlan rPlan cPlan txMode proc hdler mediaType relIdentifier
return $ CallReadPlan rPlan cPlan txMode proc handler mediaType
where
qsParams' = QueryParams.qsParams iQueryParams

hasDefaultSelect :: ReadPlanTree -> Bool
hasDefaultSelect (Node ReadPlan{select=[CoercibleSelectField{csField=CoercibleField{cfName}}]} []) = cfName == "*"
hasDefaultSelect _ = False

inspectPlan :: ApiRequest -> Either Error InspectPlan
inspectPlan apiRequest = do
let producedMTs = [MTOpenAPI, MTApplicationJSON, MTAny]
Expand Down Expand Up @@ -993,8 +994,8 @@ addFilterToLogicForest :: CoercibleFilter -> [CoercibleLogicTree] -> [CoercibleL
addFilterToLogicForest flt lf = CoercibleStmnt flt : lf

-- | Do content negotiation. i.e. choose a media type based on the intersection of accepted/produced media types.
negotiateContent :: AppConfig -> ApiRequest -> QualifiedIdentifier -> [MediaType] -> MediaHandlerMap -> Either ApiRequestError ResolvedHandler
negotiateContent conf ApiRequest{iAction=act, iPreferences=Preferences{preferRepresentation=rep}} identifier accepts produces =
negotiateContent :: AppConfig -> ApiRequest -> QualifiedIdentifier -> [MediaType] -> MediaHandlerMap -> Bool -> Either ApiRequestError ResolvedHandler
negotiateContent conf ApiRequest{iAction=act, iPreferences=Preferences{preferRepresentation=rep}} identifier accepts produces defaultSelect =
case (act, firstAcceptedPick) of
(_, Nothing) -> Left . MediaTypeError $ map MediaType.toMime accepts
(ActionMutate _, Just (x, mt)) -> Right (if rep == Just Full then x else NoAgg, mt)
Expand All @@ -1017,6 +1018,9 @@ negotiateContent conf ApiRequest{iAction=act, iPreferences=Preferences{preferRep
x -> lookupHandler x
mtPlanToNothing x = if configDbPlanEnabled conf then x else Nothing -- don't find anything if the plan media type is not allowed
lookupHandler mt =
HM.lookup (RelId identifier, MTAny) produces <|> -- lookup for identifier and `*/*`
HM.lookup (RelId identifier, mt) produces <|> -- lookup for identifier and a particular media type
HM.lookup (RelAnyElement, mt) produces -- lookup for anyelement and a particular media type
when' defaultSelect (HM.lookup (RelId identifier, MTAny) produces) <|> -- lookup for identifier and `*/*`
when' defaultSelect (HM.lookup (RelId identifier, mt) produces) <|> -- lookup for identifier and a particular media type
HM.lookup (RelAnyElement, mt) produces -- lookup for anyelement and a particular media type
when' :: Bool -> Maybe a -> Maybe a
when' True (Just a) = Just a
when' _ _ = Nothing
3 changes: 0 additions & 3 deletions src/PostgREST/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ readQuery WrappedReadPlan{..} conf@AppConfig{..} apiReq@ApiRequest{iPreferences=
resultSet <-
lift . SQL.statement mempty $
Statements.prepareRead
wrIdent
(QueryBuilder.readPlanToQuery wrReadPlan)
(if preferCount == Just EstimatedCount then
-- LIMIT maxRows + 1 so we can determine below that maxRows was surpassed
Expand Down Expand Up @@ -157,7 +156,6 @@ invokeQuery rout CallReadPlan{..} apiReq@ApiRequest{iPreferences=Preferences{..}
resultSet <-
lift . SQL.statement mempty $
Statements.prepareCall
crIdent
rout
(QueryBuilder.callPlanToQuery crCallPlan pgVer)
(QueryBuilder.readPlanToQuery crReadPlan)
Expand Down Expand Up @@ -196,7 +194,6 @@ writeQuery MutateReadPlan{..} ApiRequest{iPreferences=Preferences{..}} conf =
in
lift . SQL.statement mempty $
Statements.prepareWrite
mrIdent
(QueryBuilder.readPlanToQuery mrReadPlan)
(QueryBuilder.mutatePlanToQuery mrMutatePlan)
isInsert
Expand Down
16 changes: 9 additions & 7 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ import PostgREST.Plan.Types (CoercibleField (..),
import PostgREST.RangeQuery (NonnegRange, allRange,
rangeLimit, rangeOffset)
import PostgREST.SchemaCache.Identifiers (FieldName,
QualifiedIdentifier (..))
QualifiedIdentifier (..),
RelIdentifier (..))
import PostgREST.SchemaCache.Routine (MediaHandler (..),
Routine (..),
funcReturnsScalar,
Expand Down Expand Up @@ -221,10 +222,11 @@ asJsonF rout strip
asGeoJsonF :: SQL.Snippet
asGeoJsonF = "json_build_object('type', 'FeatureCollection', 'features', coalesce(json_agg(ST_AsGeoJSON(_postgrest_t)::json), '[]'))"

customFuncF :: Maybe Routine -> QualifiedIdentifier -> QualifiedIdentifier -> SQL.Snippet
customFuncF rout funcQi target
customFuncF :: Maybe Routine -> QualifiedIdentifier -> RelIdentifier -> SQL.Snippet
customFuncF rout funcQi _
| (funcReturnsScalar <$> rout) == Just True = fromQi funcQi <> "(_postgrest_t.pgrst_scalar)"
| otherwise = fromQi funcQi <> "(_postgrest_t::" <> fromQi target <> ")"
customFuncF _ funcQi RelAnyElement = fromQi funcQi <> "(_postgrest_t)"
customFuncF _ funcQi (RelId target) = fromQi funcQi <> "(_postgrest_t::" <> fromQi target <> ")"

locationF :: [Text] -> SQL.Snippet
locationF pKeys = [qc|(
Expand Down Expand Up @@ -559,12 +561,12 @@ setConfigWithConstantNameJSON prefix keyVals = [setConfigWithConstantName (prefi
arrayByteStringToText :: [(ByteString, ByteString)] -> [(Text,Text)]
arrayByteStringToText keyVal = (T.decodeUtf8 *** T.decodeUtf8) <$> keyVal

handlerF :: Maybe Routine -> QualifiedIdentifier -> MediaHandler -> SQL.Snippet
handlerF rout target = \case
handlerF :: Maybe Routine -> MediaHandler -> SQL.Snippet
handlerF rout = \case
BuiltinAggArrayJsonStrip -> asJsonF rout True
BuiltinAggSingleJson strip -> asJsonSingleF rout strip
BuiltinOvAggJson -> asJsonF rout False
BuiltinOvAggGeoJson -> asGeoJsonF
BuiltinOvAggCsv -> asCsvF
CustomFunc funcQi -> customFuncF rout funcQi target
CustomFunc funcQi target -> customFuncF rout funcQi target
NoAgg -> "''::text"
27 changes: 13 additions & 14 deletions src/PostgREST/Query/Statements.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ import qualified Hasql.Statement as SQL
import Control.Lens ((^?))

import PostgREST.ApiRequest.Preferences
import PostgREST.MediaType (MTVndPlanFormat (..),
MediaType (..))
import PostgREST.MediaType (MTVndPlanFormat (..),
MediaType (..))
import PostgREST.Query.SqlFragment
import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier)
import PostgREST.SchemaCache.Routine (MediaHandler (..), Routine,
funcReturnsSingle)
import PostgREST.SchemaCache.Routine (MediaHandler (..), Routine,
funcReturnsSingle)

import Protolude

Expand All @@ -56,9 +55,9 @@ data ResultSet
| RSPlan BS.ByteString -- ^ the plan of the query


prepareWrite :: QualifiedIdentifier -> SQL.Snippet -> SQL.Snippet -> Bool -> Bool -> MediaType -> MediaHandler ->
prepareWrite :: SQL.Snippet -> SQL.Snippet -> Bool -> Bool -> MediaType -> MediaHandler ->
Maybe PreferRepresentation -> Maybe PreferResolution -> [Text] -> Bool -> SQL.Statement () ResultSet
prepareWrite qi selectQuery mutateQuery isInsert isPut mt handler rep resolution pKeys =
prepareWrite selectQuery mutateQuery isInsert isPut mt handler rep resolution pKeys =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
checkUpsert snip = if isInsert && (isPut || resolution == Just MergeDuplicates) then snip else "''"
Expand All @@ -69,7 +68,7 @@ prepareWrite qi selectQuery mutateQuery isInsert isPut mt handler rep resolution
"'' AS total_result_set, " <>
"pg_catalog.count(_postgrest_t) AS page_total, " <>
locF <> " AS header, " <>
handlerF Nothing qi handler <> " AS body, " <>
handlerF Nothing handler <> " AS body, " <>
responseHeadersF <> " AS response_headers, " <>
responseStatusF <> " AS response_status, " <>
pgrstInsertedF <> " AS response_inserted " <>
Expand All @@ -94,8 +93,8 @@ prepareWrite qi selectQuery mutateQuery isInsert isPut mt handler rep resolution
MTVndPlan{} -> planRow
_ -> fromMaybe (RSStandard Nothing 0 mempty mempty Nothing Nothing Nothing) <$> HD.rowMaybe (standardRow False)

prepareRead :: QualifiedIdentifier -> SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> MediaHandler -> Bool -> SQL.Statement () ResultSet
prepareRead qi selectQuery countQuery countTotal mt handler =
prepareRead :: SQL.Snippet -> SQL.Snippet -> Bool -> MediaType -> MediaHandler -> Bool -> SQL.Statement () ResultSet
prepareRead selectQuery countQuery countTotal mt handler =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
snippet =
Expand All @@ -104,7 +103,7 @@ prepareRead qi selectQuery countQuery countTotal mt handler =
"SELECT " <>
countResultF <> " AS total_result_set, " <>
"pg_catalog.count(_postgrest_t) AS page_total, " <>
handlerF Nothing qi handler <> " AS body, " <>
handlerF Nothing handler <> " AS body, " <>
responseHeadersF <> " AS response_headers, " <>
responseStatusF <> " AS response_status, " <>
"''" <> " AS response_inserted " <>
Expand All @@ -117,10 +116,10 @@ prepareRead qi selectQuery countQuery countTotal mt handler =
MTVndPlan{} -> planRow
_ -> HD.singleRow $ standardRow True

prepareCall :: QualifiedIdentifier -> Routine -> SQL.Snippet -> SQL.Snippet -> SQL.Snippet -> Bool ->
prepareCall :: Routine -> SQL.Snippet -> SQL.Snippet -> SQL.Snippet -> Bool ->
MediaType -> MediaHandler -> Bool ->
SQL.Statement () ResultSet
prepareCall qi rout callProcQuery selectQuery countQuery countTotal mt handler =
prepareCall rout callProcQuery selectQuery countQuery countTotal mt handler =
SQL.dynamicallyParameterized (mtSnippet mt snippet) decodeIt
where
snippet =
Expand All @@ -131,7 +130,7 @@ prepareCall qi rout callProcQuery selectQuery countQuery countTotal mt handler =
(if funcReturnsSingle rout
then "1"
else "pg_catalog.count(_postgrest_t)") <> " AS page_total, " <>
handlerF (Just rout) qi handler <> " AS body, " <>
handlerF (Just rout) handler <> " AS body, " <>
responseHeadersF <> " AS response_headers, " <>
responseStatusF <> " AS response_status, " <>
"''" <> " AS response_inserted " <>
Expand Down
4 changes: 3 additions & 1 deletion src/PostgREST/SchemaCache.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1207,7 +1207,9 @@ mediaHandlers pgVer =

decodeMediaHandlers :: HD.Result MediaHandlerMap
decodeMediaHandlers =
HM.fromList . fmap (\(x, y, z, w) -> ((if isAnyElement y then RelAnyElement else RelId y, z), (CustomFunc x, w)) ) <$> HD.rowList caggRow
HM.fromList . fmap (\(x, y, z, w) ->
let rel = if isAnyElement y then RelAnyElement else RelId y
in ((rel, z), (CustomFunc x rel, w)) ) <$> HD.rowList caggRow
where
caggRow = (,,,)
<$> (QualifiedIdentifier <$> column HD.text <*> column HD.text)
Expand Down
2 changes: 1 addition & 1 deletion src/PostgREST/SchemaCache/Identifiers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import qualified Data.Text as T
import Protolude

data RelIdentifier = RelId QualifiedIdentifier | RelAnyElement
deriving (Eq, Ord, Generic, JSON.ToJSON, JSON.ToJSONKey)
deriving (Eq, Ord, Generic, JSON.ToJSON, JSON.ToJSONKey, Show)
instance Hashable RelIdentifier

-- | Represents a pg identifier with a prepended schema name "schema.table".
Expand Down
2 changes: 1 addition & 1 deletion src/PostgREST/SchemaCache/Routine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ data MediaHandler
| BuiltinOvAggGeoJson
| BuiltinOvAggCsv
-- custom
| CustomFunc QualifiedIdentifier
| CustomFunc QualifiedIdentifier RelIdentifier
| NoAgg
deriving (Eq, Show)

Expand Down
8 changes: 4 additions & 4 deletions test/spec/Feature/OpenApi/OpenApiSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ spec actualPgVersion = describe "OpenAPI" $ do
, matchHeaders = ["Content-Type" <:> "application/openapi+json; charset=utf-8"]
}

it "should respond to openapi request on none root path with 415" $
it "should respond to openapi request on none root path with 406" $
request methodGet "/items"
(acceptHdrs "application/openapi+json") ""
`shouldRespondWith` 415
`shouldRespondWith` 406

it "should respond to openapi request with unsupported media type with 415" $
it "should respond to openapi request with unsupported media type with 406" $
request methodGet "/"
(acceptHdrs "text/csv") ""
`shouldRespondWith` 415
`shouldRespondWith` 406

it "includes postgrest.org current version api docs" $ do
r <- simpleBody <$> get "/"
Expand Down
Loading
Loading