Skip to content

Commit

Permalink
use LATERAL for RPC
Browse files Browse the repository at this point in the history
  • Loading branch information
steve-chavez committed Feb 25, 2023
1 parent 2470e2e commit 0c7a9ca
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 48 deletions.
56 changes: 16 additions & 40 deletions src/PostgREST/Query/QueryBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ getSelectsJoins rr@(Node ReadPlan{select, relName, relToParent=Just rel, relAggA
mutatePlanToQuery :: MutatePlan -> SQL.Snippet
mutatePlanToQuery (Insert mainQi iCols body onConflct putConditions returnings _) =
"INSERT INTO " <> SQL.sql (fromQi mainQi) <> SQL.sql (if null iCols then " " else "(" <> cols <> ") ") <>
fromJsonBodyF body iCols True <>
fromJsonBodyF body iCols True False <>
-- Only used for PUT
(if null putConditions then mempty else "WHERE " <> intercalateSnippet " AND " (pgFmtLogicTree (QualifiedIdentifier mempty "pgrst_body") <$> putConditions)) <>
SQL.sql (BS.unwords [
Expand Down Expand Up @@ -114,13 +114,13 @@ mutatePlanToQuery (Update mainQi uCols body logicForest range ordts returnings)

| range == allRange =
"UPDATE " <> mainTbl <> " SET " <> SQL.sql nonRangeCols <> " " <>
fromJsonBodyF body uCols False <>
fromJsonBodyF body uCols False False <>
whereLogic <> " " <>
SQL.sql (returningF mainQi returnings)

| otherwise =
"WITH " <>
"pgrst_update_body AS (" <> fromJsonBodyF body uCols True <> " LIMIT 1), " <>
"pgrst_update_body AS (" <> fromJsonBodyF body uCols True True <> "), " <>
"pgrst_affected_rows AS (" <>
"SELECT " <> SQL.sql rangeIdF <> " FROM " <> mainTbl <>
whereLogic <> " " <>
Expand Down Expand Up @@ -165,50 +165,26 @@ mutatePlanToQuery (Delete mainQi logicForest range ordts returnings)

callPlanToQuery :: CallPlan -> SQL.Snippet
callPlanToQuery (FunctionCall qi params args returnsScalar multipleCall returnings) =
prmsCTE <> argsBody
"SELECT " <> (if returnsScalar then "pgrst_call AS pgrst_scalar " else returnedColumns) <> " " <>
fromCall
where
(prmsCTE, argFrag) = case params of
OnePosParam prm -> ("WITH pgrst_args AS (SELECT NULL)", singleParameter args (encodeUtf8 $ ppType prm))
KeyParams [] -> (mempty, mempty)
KeyParams prms -> (
"WITH " <> normalizedBody args <> ", " <>
SQL.sql (
BS.unwords [
"pgrst_args AS (",
"SELECT * FROM json_to_recordset(" <> selectBody <> ") AS _(" <> fmtParams prms (const mempty) (\a -> " " <> encodeUtf8 (ppType a)) <> ")",
")"])
, SQL.sql $ if multipleCall
then fmtParams prms varadicPrefix (\a -> " := pgrst_args." <> pgFmtIdent (ppName a))
else fmtParams prms varadicPrefix (\a -> " := (SELECT " <> pgFmtIdent (ppName a) <> " FROM pgrst_args LIMIT 1)")
)
fromCall = case params of
OnePosParam prm -> "FROM " <> callIt (singleParameter args $ encodeUtf8 $ ppType prm)
KeyParams [] -> "FROM " <> callIt mempty
KeyParams prms -> fromJsonBodyF args ((\p -> TypedField (ppName p) (ppType p)) <$> prms) False (not multipleCall) <> ", " <>
"LATERAL " <> callIt (fmtParams prms)

fmtParams :: [ProcParam] -> (ProcParam -> SqlFragment) -> (ProcParam -> SqlFragment) -> SqlFragment
fmtParams prms prmFragPre prmFragSuf = BS.intercalate ", "
((\a -> prmFragPre a <> pgFmtIdent (ppName a) <> prmFragSuf a) <$> prms)
callIt :: SQL.Snippet -> SQL.Snippet
callIt argument = SQL.sql (fromQi qi) <> "(" <> argument <> ") pgrst_call"

varadicPrefix :: ProcParam -> SqlFragment
varadicPrefix a = if ppVar a then "VARIADIC " else mempty

argsBody :: SQL.Snippet
argsBody
| multipleCall =
if returnsScalar
then "SELECT " <> callIt <> " AS pgrst_scalar FROM pgrst_args"
else "SELECT pgrst_lat_args.* FROM pgrst_args, " <>
"LATERAL ( SELECT " <> returnedColumns <> " FROM " <> callIt <> " ) pgrst_lat_args"
| otherwise =
if returnsScalar
then "SELECT " <> callIt <> " AS pgrst_scalar"
else "SELECT " <> returnedColumns <> " FROM " <> callIt

callIt :: SQL.Snippet
callIt = SQL.sql (fromQi qi) <> "(" <> argFrag <> ")"
fmtParams :: [ProcParam] -> SQL.Snippet
fmtParams prms = SQL.sql $ BS.intercalate ", "
((\a -> (if ppVar a then "VARIADIC " else mempty) <> pgFmtIdent (ppName a) <> " := pgrst_body." <> pgFmtIdent (ppName a)) <$> prms)

returnedColumns :: SQL.Snippet
returnedColumns
| null returnings = "*"
| otherwise = SQL.sql $ BS.intercalate ", " (pgFmtColumn (QualifiedIdentifier mempty $ qiName qi) <$> returnings)

| otherwise = SQL.sql $ BS.intercalate ", " (pgFmtColumn (QualifiedIdentifier mempty "pgrst_call") <$> returnings)

-- | SQL query meant for COUNTing the root node of the Tree.
-- It only takes WHERE into account and doesn't include LIMIT/OFFSET because it would reduce the COUNT.
Expand Down
6 changes: 3 additions & 3 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ pgFmtSelectItem table (f@(fName, jp), Nothing, alias) = pgFmtField table f <> SQ
pgFmtSelectItem table (f@(fName, jp), Just cast, alias) = "CAST (" <> pgFmtField table f <> " AS " <> SQL.sql (encodeUtf8 cast) <> " )" <> SQL.sql (pgFmtAs fName jp alias)

-- TODO: At this stage there shouldn't be a Maybe since ApiRequest should ensure that an INSERT/UPDATE has a body
fromJsonBodyF :: Maybe LBS.ByteString -> [TypedField] -> Bool -> SQL.Snippet
fromJsonBodyF body fields includeSelect =
fromJsonBodyF :: Maybe LBS.ByteString -> [TypedField] -> Bool -> Bool -> SQL.Snippet
fromJsonBodyF body fields includeSelect includeLimitOne =
SQL.sql
(if includeSelect then "SELECT " <> parsedCols <> " " else mempty) <>
"FROM (SELECT " <> jsonPlaceHolder <> " AS json_data) pgrst_payload, " <>
Expand All @@ -260,7 +260,7 @@ fromJsonBodyF body fields includeSelect =
-- because it can't extract records with no columns (there's no valid syntax for the `AS (colName colType,...)`
-- part). But we still need to ensure as many rows are created as there are array elements.
then SQL.sql "json_array_elements(pgrst_uniform_json.val) _ "
else SQL.sql ("json_to_recordset(pgrst_uniform_json.val) AS _(" <> typedCols <> ") ")
else SQL.sql ("json_to_recordset(pgrst_uniform_json.val) AS _(" <> typedCols <> ") " <> if includeLimitOne then "LIMIT 1" else mempty)
) <>
") pgrst_body "
where
Expand Down
8 changes: 4 additions & 4 deletions test/pgbench/1652/new.sql
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
WITH pgrst_source AS (
SELECT "get_projects_below".*
SELECT pgrst_call.*
FROM (
SELECT '[{"id": 4}]'::json as json_data
SELECT '{"id": 4}'::json as json_data
) pgrst_payload,
LATERAL (
SELECT CASE WHEN json_typeof(pgrst_payload.json_data) = 'array' THEN pgrst_payload.json_data ELSE json_build_array(pgrst_payload.json_data) END AS val
) pgrst_uniform_json,
LATERAL (
SELECT * FROM json_to_recordset(pgrst_uniform_json.val) AS _("id" integer) LIMIT 1
) pgrst_args,
LATERAL "test"."get_projects_below"("id" := pgrst_args.id)
) pgrst_body,
LATERAL "test"."get_projects_below"("id" := pgrst_body.id) pgrst_call
)
SELECT
null::bigint AS total_result_set,
Expand Down
2 changes: 1 addition & 1 deletion test/spec/Feature/Query/PlanSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ spec actualPgVersion = do
r <- request methodGet "/rpc/get_projects_below?id=3"
[planHdr] ""

liftIO $ planCost r `shouldSatisfy` (< 36.4)
liftIO $ planCost r `shouldSatisfy` (< 45.4)

it "should not exceed cost when calling setof composite proc with empty params" $ do
r <- request methodGet "/rpc/getallprojects"
Expand Down

0 comments on commit 0c7a9ca

Please sign in to comment.