Skip to content

Commit

Permalink
Add support for aggregate functions
Browse files Browse the repository at this point in the history
The aggregate functions SUM(), MAX(), MIN(), AVG(),
and COUNT() are now supported.
  • Loading branch information
timabdulla committed Sep 12, 2023
1 parent 3f5e840 commit 997289e
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 29 deletions.
28 changes: 22 additions & 6 deletions src/PostgREST/ApiRequest/QueryParams.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ import Data.Tree (Tree (..))
import Text.Parsec.Error (errorMessages,
showErrorMessages)
import Text.ParserCombinators.Parsec (GenParser, ParseError, Parser,
anyChar, between, char, digit,
eof, errorPos, letter,
anyChar, between, char, choice,
digit, eof, errorPos, letter,
lookAhead, many1, noneOf,
notFollowedBy, oneOf,
optionMaybe, sepBy, sepBy1,
Expand All @@ -43,7 +43,8 @@ import PostgREST.RangeQuery (NonnegRange, allRange,
rangeOffset, restrictRange)
import PostgREST.SchemaCache.Identifiers (FieldName)

import PostgREST.ApiRequest.Types (EmbedParam (..), EmbedPath, Field,
import PostgREST.ApiRequest.Types (AggregateFunction(..),
EmbedParam (..), EmbedPath, Field,
Filter (..), FtsOperator (..),
Hint, JoinType (..),
JsonOperand (..),
Expand All @@ -58,7 +59,7 @@ import PostgREST.ApiRequest.Types (EmbedParam (..), EmbedPath, Field,
SimpleOperator (..), SingleVal,
TrileanVal (..))

import Protolude hiding (try)
import Protolude hiding (try, Sum)

data QueryParams =
QueryParams
Expand Down Expand Up @@ -452,10 +453,12 @@ pRelationSelect :: Parser SelectItem
pRelationSelect = lexeme $ do
alias <- optionMaybe ( try(pFieldName <* aliasSeparator) )
name <- pFieldName
guard (name /= "count")
(hint, jType) <- pEmbedParams
try (void $ lookAhead (string "("))
return $ SelectRelation name alias hint jType


-- |
-- Parse regular fields in select
--
Expand Down Expand Up @@ -495,18 +498,31 @@ pFieldSelect :: Parser SelectItem
pFieldSelect = lexeme $ try (do
s <- pStar
pEnd
return $ SelectField (s, []) Nothing Nothing)
return $ SelectField (s, []) Nothing Nothing Nothing)
<|> try (do
alias <- optionMaybe ( try(pFieldName <* aliasSeparator) )
_ <- string "count()"
pEnd
return $ SelectField ("*", []) (Just Count) Nothing alias)
<|> do
alias <- optionMaybe ( try(pFieldName <* aliasSeparator) )
fld <- pField
cast' <- optionMaybe (string "::" *> pIdentifier)
agg <- optionMaybe (try (char '.' *> pAggregation <* string "()"))
pEnd
return $ SelectField fld (toS <$> cast') alias
return $ SelectField fld agg (toS <$> cast') alias
where
pEnd = try (void $ lookAhead (string ")")) <|>
try (void $ lookAhead (string ",")) <|>
try eof
pStar = string "*" $> "*"
pAggregation = choice
[ string "sum" $> Sum
, string "avg" $> Avg
, string "max" $> Max
, string "min" $> Min
, string "count" $> Count
]


-- |
Expand Down
15 changes: 10 additions & 5 deletions src/PostgREST/ApiRequest/Types.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
module PostgREST.ApiRequest.Types
( Alias
( AggregateFunction(..)
, Alias
, Cast
, Depth
, EmbedParam(..)
Expand Down Expand Up @@ -42,12 +43,13 @@ import PostgREST.SchemaCache.Routine (Routine (..))

import Protolude

-- | The value in `/tbl?select=alias:field::cast`
-- | The value in `/tbl?select=alias:field.aggregateFunction()::cast`
data SelectItem
= SelectField
{ selField :: Field
, selCast :: Maybe Cast
, selAlias :: Maybe Alias
{ selField :: Field
, selAggregateFunction :: Maybe AggregateFunction
, selCast :: Maybe Cast
, selAlias :: Maybe Alias
}
-- | The value in `/tbl?select=alias:another_tbl(*)`
| SelectRelation
Expand Down Expand Up @@ -128,6 +130,9 @@ type Cast = Text
type Alias = Text
type Hint = Text

data AggregateFunction = Sum | Avg | Max | Min | Count
deriving (Show, Eq)

data EmbedParam
-- | Disambiguates an embedding operation when there's multiple relationships
-- between two tables. Can be the name of a foreign key constraint, column
Expand Down
18 changes: 9 additions & 9 deletions src/PostgREST/Plan.hs
Original file line number Diff line number Diff line change
Expand Up @@ -317,15 +317,15 @@ initReadRequest ctx@ResolverContext{qi=QualifiedIdentifier{..}} =
(Node defReadPlan{from=QualifiedIdentifier qiSchema selRelation, relName=selRelation, relHint=selHint, relJoinType=selJoinType, depth=nxtDepth, relIsSpread=True} [])
fldForest:rForest
SelectField{..} ->
Node q{select=(resolveOutputField ctx{qi=from q} selField, selCast, selAlias):select q} rForest
Node q{select=(resolveOutputField ctx{qi=from q} selField, selAggregateFunction, selCast, selAlias):select q} rForest

-- | Preserve the original field name if data representation is used to coerce the value.
addDataRepresentationAliases :: ReadPlanTree -> Either ApiRequestError ReadPlanTree
addDataRepresentationAliases rPlanTree = Right $ fmap (\rPlan@ReadPlan{select=sel} -> rPlan{select=map aliasSelectItem sel}) rPlanTree
where
aliasSelectItem :: (CoercibleField, Maybe Cast, Maybe Alias) -> (CoercibleField, Maybe Cast, Maybe Alias)
aliasSelectItem :: (CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias) -> (CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias)
-- If there already is an alias, don't overwrite it.
aliasSelectItem (fld@(CoercibleField{cfName=fieldName, cfTransform=(Just _)}), Nothing, Nothing) = (fld, Nothing, Just fieldName)
aliasSelectItem (fld@(CoercibleField{cfName=fieldName, cfTransform=(Just _)}), Nothing, Nothing, Nothing) = (fld, Nothing, Nothing, Just fieldName)
aliasSelectItem fld = fld

knownColumnsInContext :: ResolverContext -> [Column]
Expand All @@ -348,7 +348,7 @@ expandStarsForDataRepresentations ctx@ResolverContext{qi} rPlanTree = Right $ fm
expandStarsForTable :: ResolverContext -> ReadPlan -> ReadPlan
expandStarsForTable ctx@ResolverContext{representations, outputType} rplan@ReadPlan{select=selectItems} =
-- If we have a '*' select AND the target table has at least one data representation, expand.
if ("*" `elem` map (\(field, _, _) -> cfName field) selectItems) && any hasOutputRep knownColumns
if ("*" `elem` map (\(field, _, _, _) -> cfName field) selectItems) && any hasOutputRep knownColumns
then rplan{select=concatMap (expandStarSelectItem knownColumns) selectItems}
else rplan
where
Expand All @@ -357,8 +357,8 @@ expandStarsForTable ctx@ResolverContext{representations, outputType} rplan@ReadP
hasOutputRep :: Column -> Bool
hasOutputRep col = HM.member (colNominalType col, outputType) representations

expandStarSelectItem :: [Column] -> (CoercibleField, Maybe Cast, Maybe Alias) -> [(CoercibleField, Maybe Cast, Maybe Alias)]
expandStarSelectItem columns (CoercibleField{cfName="*", cfJsonPath=[]}, b, c) = map (\col -> (withOutputFormat ctx $ resolveColumnField col, b, c)) columns
expandStarSelectItem :: [Column] -> (CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias) -> [(CoercibleField, Maybe AggregateFunction,Maybe Cast, Maybe Alias)]
expandStarSelectItem columns (CoercibleField{cfName="*", cfJsonPath=[]}, b, c, d) = map (\col -> (withOutputFormat ctx $ resolveColumnField col, b, c, d)) columns
expandStarSelectItem _ selectItem = [selectItem]

-- | Enforces the `max-rows` config on the result
Expand Down Expand Up @@ -770,7 +770,7 @@ inferColsEmbedNeeds (Node ReadPlan{select} forest) pkCols
| "*" `elem` fldNames = ["*"]
| otherwise = returnings
where
fldNames = cfName . (\(f, _, _) -> f) <$> select
fldNames = cfName . (\(f, _, _, _) -> f) <$> select
-- Without fkCols, when a mutatePlan to
-- /projects?select=name,clients(name) occurs, the RETURNING SQL part would
-- be `RETURNING name`(see QueryBuilder). This would make the embedding
Expand Down Expand Up @@ -839,8 +839,8 @@ binaryField AppConfig{configRawMediaTypes} acceptMediaType proc rpTree
_ -> False

fstFieldName :: ReadPlanTree -> Maybe FieldName
fstFieldName (Node ReadPlan{select=(CoercibleField{cfName="*", cfJsonPath=[]}, _, _):_} []) = Nothing
fstFieldName (Node ReadPlan{select=[(CoercibleField{cfName=fld, cfJsonPath=[]}, _, _)]} []) = Just fld
fstFieldName (Node ReadPlan{select=(CoercibleField{cfName="*", cfJsonPath=[]}, _, _, _):_} []) = Nothing
fstFieldName (Node ReadPlan{select=[(CoercibleField{cfName=fld, cfJsonPath=[]}, _, _, _)]} []) = Just fld
fstFieldName _ = Nothing


Expand Down
6 changes: 3 additions & 3 deletions src/PostgREST/Plan/ReadPlan.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ module PostgREST.Plan.ReadPlan

import Data.Tree (Tree (..))

import PostgREST.ApiRequest.Types (Alias, Cast, Depth, Hint,
JoinType, NodeName)
import PostgREST.ApiRequest.Types (AggregateFunction, Alias, Cast, Depth,
Hint, JoinType, NodeName)
import PostgREST.Plan.Types (CoercibleField (..),
CoercibleLogicTree,
CoercibleOrderTerm)
Expand All @@ -28,7 +28,7 @@ data JoinCondition =
deriving (Eq, Show)

data ReadPlan = ReadPlan
{ select :: [(CoercibleField, Maybe Cast, Maybe Alias)]
{ select :: [(CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias)]
, from :: QualifiedIdentifier
, fromAlias :: Maybe Alias
, where_ :: [CoercibleLogicTree]
Expand Down
3 changes: 2 additions & 1 deletion src/PostgREST/Query/QueryBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ readPlanToQuery (Node ReadPlan{select,from=mainQi,fromAlias,where_=logicForest,o
(if null logicForest && null relJoinConds
then mempty
else "WHERE " <> intercalateSnippet " AND " (map (pgFmtLogicTree qi) logicForest ++ map pgFmtJoinCondition relJoinConds)) <> " " <>
groupF qi select <> " " <>
orderF qi order <> " " <>
limitOffsetF readRange
where
fromFrag = fromF relToParent mainQi fromAlias
qi = getQualifiedIdentifier relToParent mainQi fromAlias
defSelect = [(unknownField "*" [], Nothing, Nothing)] -- gets all the columns in case of an empty select, ignoring/obtaining these columns is done at the aggregation stage
defSelect = [(unknownField "*" [], Nothing, Nothing, Nothing)] -- gets all the columns in case of an empty select, ignoring/obtaining these columns is done at the aggregation stage
(selects, joins) = foldr getSelectsJoins ([],[]) forest

getSelectsJoins :: ReadPlanTree -> ([SQL.Snippet], [SQL.Snippet]) -> ([SQL.Snippet], [SQL.Snippet])
Expand Down
35 changes: 30 additions & 5 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ module PostgREST.Query.SqlFragment
( noLocationF
, aggF
, countF
, groupF
, fromQi
, limitOffsetF
, locationF
Expand Down Expand Up @@ -50,7 +51,8 @@ import Control.Arrow ((***))
import Data.Foldable (foldr1)
import Text.InterpolatedString.Perl6 (qc)

import PostgREST.ApiRequest.Types (Alias, Cast,
import PostgREST.ApiRequest.Types (AggregateFunction (..),
Alias, Cast,
FtsOperator (..),
JsonOperand (..),
JsonOperation (..),
Expand Down Expand Up @@ -82,7 +84,7 @@ import PostgREST.SchemaCache.Routine (ResultAggregate (..),
funcReturnsSetOfScalar,
funcReturnsSingleComposite)

import Protolude hiding (cast)
import Protolude hiding (cast, Sum)

sourceCTEName :: Text
sourceCTEName = "pgrst_source"
Expand Down Expand Up @@ -260,12 +262,21 @@ pgFmtCoerceNamed :: CoercibleField -> SQL.Snippet
pgFmtCoerceNamed CoercibleField{cfName=fn, cfTransform=(Just formatterProc)} = pgFmtCallUnary formatterProc (pgFmtIdent fn) <> " AS " <> pgFmtIdent fn
pgFmtCoerceNamed CoercibleField{cfName=fn} = pgFmtIdent fn

pgFmtSelectItem :: QualifiedIdentifier -> (CoercibleField, Maybe Cast, Maybe Alias) -> SQL.Snippet
pgFmtSelectItem table (fld, Nothing, alias) = pgFmtTableCoerce table fld <> pgFmtAs (cfName fld) (cfJsonPath fld) alias
pgFmtSelectItem :: QualifiedIdentifier -> (CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias) -> SQL.Snippet
pgFmtSelectItem table (fld, agg, Nothing, alias) = pgFmtApplyAggregate agg (pgFmtTableCoerce table fld) <> pgFmtAs (cfName fld) (cfJsonPath fld) alias
-- Ideally we'd quote the cast with "pgFmtIdent cast". However, that would invalidate common casts such as "int", "bigint", etc.
-- Try doing: `select 1::"bigint"` - it'll err, using "int8" will work though. There's some parser magic that pg does that's invalidated when quoting.
-- Not quoting should be fine, we validate the input on Parsers.
pgFmtSelectItem table (fld, Just cast, alias) = "CAST (" <> pgFmtTableCoerce table fld <> " AS " <> SQL.sql (encodeUtf8 cast) <> " )" <> pgFmtAs (cfName fld) (cfJsonPath fld) alias
pgFmtSelectItem table (fld, agg, Just cast, alias) = pgFmtApplyAggregate agg ("CAST (" <> pgFmtTableCoerce table fld <> " AS " <> SQL.sql (encodeUtf8 cast) <> " )") <> pgFmtAs (cfName fld) (cfJsonPath fld) alias

pgFmtApplyAggregate :: Maybe AggregateFunction -> SQL.Snippet -> SQL.Snippet
pgFmtApplyAggregate Nothing snippet = snippet
pgFmtApplyAggregate (Just agg) snippet = case agg of
Sum -> "SUM( " <> snippet <> " )"
Max -> "MAX( " <> snippet <> " )"
Min -> "MIN( " <> snippet <> " )"
Avg -> "AVG( " <> snippet <> " )"
Count -> "COUNT( " <> snippet <> " )"

-- TODO: At this stage there shouldn't be a Maybe since ApiRequest should ensure that an INSERT/UPDATE has a body
fromJsonBodyF :: Maybe LBS.ByteString -> [CoercibleField] -> Bool -> Bool -> Bool -> SQL.Snippet
Expand Down Expand Up @@ -409,6 +420,19 @@ pgFmtAs fName jp Nothing = case jOp <$> lastMay jp of
Nothing -> mempty
pgFmtAs _ _ (Just alias) = " AS " <> pgFmtIdent alias

groupF :: QualifiedIdentifier -> [(CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias)] -> SQL.Snippet
groupF _ [] = mempty
groupF qi fields =
if all (\(_, agg, _, _) -> isNothing agg) fields || all (\(_, agg, _, _) -> isJust agg) fields
then
mempty
else
" GROUP BY " <> intercalateSnippet ", " (pgFmtGroup qi <$> (filter (\(_, agg, _, _) -> isNothing agg) fields))

pgFmtGroup :: QualifiedIdentifier -> (CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias) -> SQL.Snippet
pgFmtGroup _ (_, Just _, _, _) = mempty
pgFmtGroup qi (fld, _, _, _) = pgFmtField qi fld

countF :: SQL.Snippet -> Bool -> (SQL.Snippet, SQL.Snippet)
countF countQuery shouldCount =
if shouldCount
Expand Down Expand Up @@ -496,6 +520,7 @@ setConfigLocalJson prefix keyVals = [setConfigLocal mempty (prefix, gucJsonVal k
arrayByteStringToText :: [(ByteString, ByteString)] -> [(Text,Text)]
arrayByteStringToText keyVal = (T.decodeUtf8 *** T.decodeUtf8) <$> keyVal

-- Investigate this
aggF :: Maybe Routine -> ResultAggregate -> SQL.Snippet
aggF rout = \case
BuiltinAggJson -> asJsonF rout False
Expand Down

0 comments on commit 997289e

Please sign in to comment.