Skip to content

Commit

Permalink
Add tests and support for limited filters
Browse files Browse the repository at this point in the history
  • Loading branch information
laurenceisla committed Aug 6, 2022
1 parent c4ec461 commit 76b5a1e
Show file tree
Hide file tree
Showing 12 changed files with 351 additions and 26 deletions.
4 changes: 2 additions & 2 deletions src/PostgREST/Error.hs
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ instance JSON.ToJSON ApiRequestError where
"message" .= ("Cannot apply filter because '" <> resource <> "' is not an embedded resource in this request" :: Text),
"details" .= JSON.Null,
"hint" .= ("Verify that '" <> resource <> "' is included in the 'select' query parameter." :: Text)]
toJSON (BodyFilterNotAllowed message) = JSON.object [
toJSON (BodyFilterNotAllowed method isRpc) = JSON.object [
"code" .= ApiRequestErrorCode18,
"message" .= message,
"message" .= ("Body filter _eq is not allowed for " <> if isRpc then "RPC" else T.decodeUtf8 method),
"details" .= JSON.Null,
"hint" .= JSON.Null]

Expand Down
24 changes: 12 additions & 12 deletions src/PostgREST/Query/QueryBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ readRequestToQuery (Node (Select colSelects mainQi tblAlias implJoins logicFores
intercalateSnippet " " joins <> " " <>
(if null logicForest && null joinConditions_
then mempty
else "WHERE " <> intercalateSnippet " AND " (map (pgFmtLogicTree qi) logicForest ++ map pgFmtJoinCondition joinConditions_)) <> " " <>
else "WHERE " <> intercalateSnippet " AND " (map (pgFmtLogicTree qi Nothing) logicForest ++ map pgFmtJoinCondition joinConditions_)) <> " " <>
orderF qi ordts <> " " <>
limitOffsetF range
where
Expand Down Expand Up @@ -90,7 +90,7 @@ mutateRequestToQuery (Insert mainQi iCols body onConflct putConditions returning
"SELECT " <> SQL.sql cols <> " " <>
SQL.sql ("FROM json_populate_recordset (null::" <> fromQi mainQi <> ", " <> selectBody <> ") _ ") <>
-- Only used for PUT
(if null putConditions then mempty else "WHERE " <> intercalateSnippet " AND " (pgFmtLogicTree (QualifiedIdentifier mempty "_") <$> putConditions)) <>
(if null putConditions then mempty else "WHERE " <> intercalateSnippet " AND " (pgFmtLogicTree (QualifiedIdentifier mempty "_") Nothing <$> putConditions)) <>
SQL.sql (BS.unwords [
maybe "" (\(oncDo, oncCols) ->
if null oncCols then
Expand Down Expand Up @@ -120,16 +120,16 @@ mutateRequestToQuery (Update mainQi uCols body logicForest range ordts returning
| range == allRange =
"WITH " <> normalizedBody body <> " " <>
"UPDATE " <> mainTbl <> " SET " <> SQL.sql nonRangeCols <> " " <>
"FROM (SELECT * FROM json_populate_recordset (null::" <> mainTbl <> " , " <> SQL.sql selectBody <> " )) pgrst_recordset_body " <>
whereLogic <> " " <>
"FROM (SELECT * FROM json_populate_recordset (null::" <> mainTbl <> " , " <> SQL.sql selectBody <> " )) pgrst_update_body " <>
whereLogic (Just (BodyRecordset "pgrst_update_body" True)) <> " " <>
SQL.sql (returningF mainQi returnings)

| otherwise =
"WITH " <> normalizedBody body <> ", " <>
"pgrst_recordset_body AS (SELECT * FROM json_populate_recordset (null::" <> mainTbl <> " , " <> SQL.sql selectBody <> " ) LIMIT 1), " <>
"pgrst_update_body AS (SELECT * FROM json_populate_recordset (null::" <> mainTbl <> " , " <> SQL.sql selectBody <> " ) LIMIT 1), " <>
"pgrst_affected_rows AS (" <>
"SELECT " <> SQL.sql rangeIdF <> " FROM " <> mainTbl <> " " <>
whereLogic <> " " <>
whereLogic (Just (BodyRecordset "pgrst_update_body" False)) <> " " <>
orderF mainQi ordts <> " " <>
limitOffsetF range <>
") " <>
Expand All @@ -141,11 +141,11 @@ mutateRequestToQuery (Update mainQi uCols body logicForest range ordts returning
where

mainTbl = SQL.sql (fromQi mainQi)
logicForestF = intercalateSnippet " AND " (pgFmtLogicTree mainQi <$> logicForest)
whereLogic = if null logicForest then mempty else " WHERE " <> logicForestF
logicForestF recordset = intercalateSnippet " AND " (pgFmtLogicTree mainQi recordset <$> logicForest)
whereLogic recordset = if null logicForest then mempty else " WHERE " <> logicForestF recordset
emptyBodyReturnedColumns = if null returnings then "NULL" else BS.intercalate ", " (pgFmtColumn (QualifiedIdentifier mempty $ qiName mainQi) <$> returnings)
nonRangeCols = BS.intercalate ", " (pgFmtIdent <> const " = pgrst_recordset_body." <> pgFmtIdent <$> S.toList uCols)
rangeCols = BS.intercalate ", " ((\col -> pgFmtIdent col <> " = (SELECT " <> pgFmtIdent col <> " FROM pgrst_recordset_body) ") <$> S.toList uCols)
nonRangeCols = BS.intercalate ", " (pgFmtIdent <> const " = pgrst_update_body." <> pgFmtIdent <$> S.toList uCols)
rangeCols = BS.intercalate ", " ((\col -> pgFmtIdent col <> " = (SELECT " <> pgFmtIdent col <> " FROM pgrst_update_body) ") <$> S.toList uCols)
(whereRangeIdF, rangeIdF) = mutRangeF mainQi (fst . otTerm <$> ordts)

mutateRequestToQuery (Delete mainQi logicForest range ordts returnings)
Expand All @@ -168,7 +168,7 @@ mutateRequestToQuery (Delete mainQi logicForest range ordts returnings)
SQL.sql (returningF mainQi returnings)

where
whereLogic = if null logicForest then mempty else " WHERE " <> intercalateSnippet " AND " (pgFmtLogicTree mainQi <$> logicForest)
whereLogic = if null logicForest then mempty else " WHERE " <> intercalateSnippet " AND " (pgFmtLogicTree mainQi Nothing <$> logicForest)
(whereRangeIdF, rangeIdF) = mutRangeF mainQi (fst . otTerm <$> ordts)

requestToCallProcQuery :: CallRequest -> SQL.Snippet
Expand Down Expand Up @@ -233,7 +233,7 @@ readRequestToCountQuery (Node (Select{from=mainQi, fromAlias=tblAlias, implicitJ
then mempty
else " WHERE " ) <>
intercalateSnippet " AND " (
map (pgFmtLogicTree treeQi) logicForest ++
map (pgFmtLogicTree treeQi Nothing) logicForest ++
map pgFmtJoinCondition joinConditions_ ++
subQueries
)
Expand Down
20 changes: 13 additions & 7 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ import PostgREST.RangeQuery (NonnegRange, allRange,
rangeLimit, rangeOffset)
import PostgREST.Request.ReadQuery (SelectItem)
import PostgREST.Request.Types (Alias, BodyOperator (..),
Field, Filter (..),
BodyRecordset (..), Field,
Filter (..),
FtsOperator (..),
JoinCondition (..),
JsonOperand (..),
Expand Down Expand Up @@ -253,8 +254,8 @@ pgFmtOrderTerm qi ot =
nullOrder OrderNullsLast = "NULLS LAST"


pgFmtFilter :: QualifiedIdentifier -> Filter -> SQL.Snippet
pgFmtFilter table (Filter fld (OpExpr hasNot oper)) = notOp <> " " <> case oper of
pgFmtFilter :: QualifiedIdentifier -> Maybe BodyRecordset -> Filter -> SQL.Snippet
pgFmtFilter table bodyRec (Filter fld (OpExpr hasNot oper)) = notOp <> " " <> case oper of
Op op val -> pgFmtFieldOp op <> " " <> case op of
OpLike -> unknownLiteral (T.map star val)
OpILike -> unknownLiteral (T.map star val)
Expand All @@ -280,27 +281,32 @@ pgFmtFilter table (Filter fld (OpExpr hasNot oper)) = notOp <> " " <> case oper
Fts op lang val ->
pgFmtFieldFts op <> "(" <> ftsLang lang <> unknownLiteral val <> ") "

BodOp op val -> pgFmtFieldBodOp op <> " " <> SQL.sql (pgFmtColumn (QualifiedIdentifier mempty "pgrst_recordset_body") val)
BodOp op val -> pgFmtFieldBodOp op <> " " <> fmtBodOpFilter val
where
ftsLang = maybe mempty (\l -> unknownLiteral l <> ", ")
pgFmtFieldOp op = pgFmtField table fld <> " " <> SQL.sql (singleValOperator op)
pgFmtFieldFts op = pgFmtField table fld <> " " <> SQL.sql (ftsOperator op)
pgFmtFieldBodOp op = pgFmtField table fld <> " " <> SQL.sql (bodySingleOperator op)
fmtBodOpFilter val = case bodyRec of
Just (BodyRecordset bodName direct)
| direct -> SQL.sql (pgFmtColumn (QualifiedIdentifier mempty (decodeUtf8 bodName)) val)
| otherwise -> SQL.sql ("(SELECT " <> pgFmtIdent val <> " FROM " <> bodName <> ")")
Nothing -> mempty
notOp = if hasNot then "NOT" else mempty
star c = if c == '*' then '%' else c

pgFmtJoinCondition :: JoinCondition -> SQL.Snippet
pgFmtJoinCondition (JoinCondition (qi1, col1) (qi2, col2)) =
SQL.sql $ pgFmtColumn qi1 col1 <> " = " <> pgFmtColumn qi2 col2

pgFmtLogicTree :: QualifiedIdentifier -> LogicTree -> SQL.Snippet
pgFmtLogicTree qi (Expr hasNot op forest) = SQL.sql notOp <> " (" <> intercalateSnippet (opSql op) (pgFmtLogicTree qi <$> forest) <> ")"
pgFmtLogicTree :: QualifiedIdentifier -> Maybe BodyRecordset -> LogicTree -> SQL.Snippet
pgFmtLogicTree qi bodyRec (Expr hasNot op forest) = SQL.sql notOp <> " (" <> intercalateSnippet (opSql op) (pgFmtLogicTree qi bodyRec <$> forest) <> ")"
where
notOp = if hasNot then "NOT" else mempty

opSql And = " AND "
opSql Or = " OR "
pgFmtLogicTree qi (Stmnt flt) = pgFmtFilter qi flt
pgFmtLogicTree qi bodyRec (Stmnt flt) = pgFmtFilter qi bodyRec flt

pgFmtJsonPath :: JsonPath -> SQL.Snippet
pgFmtJsonPath = \case
Expand Down
2 changes: 1 addition & 1 deletion src/PostgREST/Request/ApiRequest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ apiRequest conf@AppConfig{..} dbStructure req reqBody queryparams@QueryParams{..
| isInvalidRange = Left InvalidRange
| shouldParsePayload && isLeft payload = either (Left . InvalidBody) witness payload
| not expectParams && not (L.null qsParams) = Left $ ParseRequestError "Unexpected param or filter missing operator" ("Failed to parse " <> show qsParams)
| bodyFilterNotAllowed = Left $ BodyFilterNotAllowed "Body filter _eq is not allowed for this method"
| bodyFilterNotAllowed = Left $ BodyFilterNotAllowed method pathIsProc
| method `elem` ["PATCH", "DELETE"] && not (null qsRanges) && null qsOrder = Left LimitNoOrderError
| method == "PUT" && topLevelRange /= allRange = Left PutRangeNotAllowedError
| otherwise = do
Expand Down
10 changes: 9 additions & 1 deletion src/PostgREST/Request/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ module PostgREST.Request.Types
, SimpleOperator(..)
, FtsOperator(..)
, BodyOperator(..)
, BodyRecordset(..)
) where

import qualified Data.ByteString.Lazy as LBS
Expand All @@ -50,7 +51,7 @@ import Protolude
data ApiRequestError
= AmbiguousRelBetween Text Text [Relationship]
| AmbiguousRpc [ProcDescription]
| BodyFilterNotAllowed Text
| BodyFilterNotAllowed ByteString Bool
| MediaTypeError [ByteString]
| InvalidBody ByteString
| InvalidFilters
Expand Down Expand Up @@ -233,3 +234,10 @@ data FtsOperator
data BodyOperator
= BodyOpEqual
deriving Eq

-- | Information of the transformed body using json_populate_recordset
-- to be used for filtering using body operators
data BodyRecordset = BodyRecordset
{ brName :: ByteString
, isDirectRef :: Bool
}
6 changes: 6 additions & 0 deletions test/spec/Feature/Query/DeleteSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ spec =
, matchHeaders = ["Content-Range" <:> "*/*"]
}

it "fails if a body filter operator is given" $
request methodDelete "/tasks?id=_eq.id" [] mempty
`shouldRespondWith`
[json|{"details":null,"message":"Body filter _eq is not allowed for DELETE","code":"PGRST118","hint":null} |]
{ matchStatus = 400 }

context "known route, no records matched" $
it "includes [] body if return=rep" $
request methodDelete "/items?id=eq.101"
Expand Down
6 changes: 6 additions & 0 deletions test/spec/Feature/Query/QuerySpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,12 @@ spec actualPgVersion = do
, matchHeaders = [matchContentTypeJson]
}

it "fails if a body filter operator is given" $
get "/ghostBusters?id=_eq.id" `shouldRespondWith` [json| {"details":null,"message":"Body filter _eq is not allowed for GET","code":"PGRST118","hint":null} |]
{ matchStatus = 400
, matchHeaders = [matchContentTypeJson]
}

it "will embed a collection" $
get "/Escap3e;?select=ghostBusters(*)" `shouldRespondWith`
[json| [{"ghostBusters":[{"escapeId":1}]},{"ghostBusters":[]},{"ghostBusters":[{"escapeId":3}]},{"ghostBusters":[]},{"ghostBusters":[{"escapeId":5}]}] |]
Expand Down
8 changes: 8 additions & 0 deletions test/spec/Feature/Query/RpcSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,14 @@ spec actualPgVersion =
`shouldRespondWith` "3"
{ matchHeaders = [matchContentTypeJson] }

it "fails if a body filter operator is given" $ do
get "/rpc/sayhello?name=_eq.John"
`shouldRespondWith` [json| {"details":null,"message":"Body filter _eq is not allowed for RPC","code":"PGRST118","hint":null} |]
{ matchStatus = 400 }
post "/rpc/sayhello?name=_eq.name" [json|{name: "John"}|]
`shouldRespondWith` [json| {"details":null,"message":"Body filter _eq is not allowed for RPC","code":"PGRST118","hint":null} |]
{ matchStatus = 400 }

context "bulk RPC with params=multiple-objects" $ do
it "works with a scalar function an returns a json array" $
request methodPost "/rpc/add_them" [("Prefer", "params=multiple-objects")]
Expand Down
Loading

0 comments on commit 76b5a1e

Please sign in to comment.