Skip to content

Commit

Permalink
Add UT and IT for 2+ level aggregations PPL command (#603) (#614)
Browse files Browse the repository at this point in the history
* Add UT and IT for 2+ level aggregations PPL command



* doc



---------


(cherry picked from commit eb2bbb6)

Signed-off-by: Lantao Jin <ltjin@amazon.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent ad5697b commit 1f2472d
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, LessThan, Literal, Not, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, GreaterThan, LessThan, Literal, Not, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

Expand Down Expand Up @@ -919,4 +919,207 @@ class FlintSparkPPLAggregationsITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

// Two chained `stats` commands: average age per (state, country), then the average of
// those averages per country. Checks both the collected rows and the unresolved plan.
test("two-level stats") {
  val frame = sql(s"""
       | source = $testTable| stats avg(age) as avg_age by state, country | stats avg(avg_age) as avg_state_age by country
       | """.stripMargin)

  // Row order is not guaranteed, so both sides are sorted on the aggregate (column 0).
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
  val actualRows: Array[Row] = frame.collect()
  val expectedRows: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA"))
  assert(actualRows.sorted.sameElements(expectedRows.sorted))

  // Shared building blocks of the expected plan.
  val stateAlias = Alias(UnresolvedAttribute("state"), "state")()
  val countryAlias = Alias(UnresolvedAttribute("country"), "country")()
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

  // Level 1: avg(age) grouped by state and country.
  val avgAge = Alias(
    UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("age")), isDistinct = false),
    "avg_age")()
  val firstLevel =
    Aggregate(Seq(stateAlias, countryAlias), Seq(avgAge, stateAlias, countryAlias), relation)

  // Level 2: avg(avg_age) grouped by country only.
  val avgStateAge = Alias(
    UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("avg_age")), isDistinct = false),
    "avg_state_age")()
  val secondLevel = Aggregate(Seq(countryAlias), Seq(avgStateAge, countryAlias), firstLevel)

  val expectedPlan = Project(Seq(UnresolvedStar(None)), secondLevel)
  comparePlans(frame.queryExecution.logical, expectedPlan, checkAnalysis = false)
}

// Two `stats` commands with an `eval` in between: the per-(state, country) average age
// is shifted by 10 before being averaged again per country.
test("two-level stats with eval") {
  val frame = sql(s"""
       | source = $testTable| stats avg(age) as avg_age by state, country | eval new_avg_age = avg_age - 10 | stats avg(new_avg_age) as avg_state_age by country
       | """.stripMargin)

  // Row order is not guaranteed, so both sides are sorted on the aggregate (column 0).
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
  val actualRows: Array[Row] = frame.collect()
  val expectedRows: Array[Row] = Array(Row(12.5, "Canada"), Row(40.0, "USA"))
  assert(actualRows.sorted.sameElements(expectedRows.sorted))

  // Shared building blocks of the expected plan.
  val stateAlias = Alias(UnresolvedAttribute("state"), "state")()
  val countryAlias = Alias(UnresolvedAttribute("country"), "country")()
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

  // Level 1: avg(age) grouped by state and country.
  val avgAge = Alias(
    UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("age")), isDistinct = false),
    "avg_age")()
  val firstLevel =
    Aggregate(Seq(stateAlias, countryAlias), Seq(avgAge, stateAlias, countryAlias), relation)

  // eval: projects every existing column plus new_avg_age = avg_age - 10.
  val newAvgAge = Alias(
    UnresolvedFunction(
      Seq("-"),
      Seq(UnresolvedAttribute("avg_age"), Literal(10)),
      isDistinct = false),
    "new_avg_age")()
  val evalStep = Project(Seq(UnresolvedStar(None), newAvgAge), firstLevel)

  // Level 2: avg(new_avg_age) grouped by country only.
  val avgStateAge = Alias(
    UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("new_avg_age")), isDistinct = false),
    "avg_state_age")()
  val secondLevel = Aggregate(Seq(countryAlias), Seq(avgStateAge, countryAlias), evalStep)

  val expectedPlan = Project(Seq(UnresolvedStar(None)), secondLevel)
  comparePlans(frame.queryExecution.logical, expectedPlan, checkAnalysis = false)
}

// Two `stats` commands with a `where` in between: groups with a positive average age
// are kept and then counted per country.
test("two-level stats with filter") {
  val frame = sql(s"""
       | source = $testTable| stats avg(age) as avg_age by country, state | where avg_age > 0 | stats count(avg_age) as count_state_age by country
       | """.stripMargin)

  // Both expected counts are equal, so rows are sorted on the country name (column 1).
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
  val actualRows: Array[Row] = frame.collect()
  val expectedRows: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA"))
  assert(actualRows.sorted.sameElements(expectedRows.sorted))

  // Shared building blocks of the expected plan.
  val avgAgeRef = UnresolvedAttribute("avg_age")
  val stateAlias = Alias(UnresolvedAttribute("state"), "state")()
  val countryAlias = Alias(UnresolvedAttribute("country"), "country")()
  val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

  // Level 1: avg(age) grouped by country and state (same order as in the query).
  val avgAge = Alias(
    UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("age")), isDistinct = false),
    "avg_age")()
  val firstLevel =
    Aggregate(Seq(countryAlias, stateAlias), Seq(avgAge, countryAlias, stateAlias), relation)

  // where avg_age > 0
  val positiveOnly = Filter(GreaterThan(avgAgeRef, Literal(0)), firstLevel)

  // Level 2: count(avg_age) grouped by country only.
  val countStateAge = Alias(
    UnresolvedFunction(Seq("COUNT"), Seq(avgAgeRef), isDistinct = false),
    "count_state_age")()
  val secondLevel = Aggregate(Seq(countryAlias), Seq(countStateAge, countryAlias), positiveOnly)

  val expectedPlan = Project(Seq(UnresolvedStar(None)), secondLevel)
  comparePlans(frame.queryExecution.logical, expectedPlan, checkAnalysis = false)
}

// Three chained `stats` commands with an `eval` after the first and a `where` after the
// second:
//   1. avg(age) per (country, state, name)
//   2. eval shifts that average down by 20
//   3. avg of the shifted value per (country, state)
//   4. where keeps groups whose shifted average is positive
//   5. count of the remaining groups per country
// The eval field is named avg_age_minus_20 because the expression is a subtraction.
test("three-level stats with eval and filter") {
  val frame = sql(s"""
       | source = $testTable| stats avg(age) as avg_age by country, state, name | eval avg_age_minus_20 = avg_age - 20 | stats avg(avg_age_minus_20)
       | as avg_state_age by country, state | where avg_state_age > 0 | stats count(avg_state_age) as count_country_age_greater_20 by country
       | """.stripMargin)

  val results: Array[Row] = frame.collect()
  val expectedResults: Array[Row] = Array(Row(1L, "Canada"), Row(2L, "USA"))

  // Row order is not guaranteed, so both sides are sorted on the country name (column 1).
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
  assert(results.sorted.sameElements(expectedResults.sorted))

  val logicalPlan: LogicalPlan = frame.queryExecution.logical
  val star = Seq(UnresolvedStar(None))
  val ageField = UnresolvedAttribute("age")
  val avgAgeField = UnresolvedAttribute("avg_age")
  val avgAgeMinus20Field = UnresolvedAttribute("avg_age_minus_20")
  val avgStateAgeField = UnresolvedAttribute("avg_state_age")
  val nameAlias = Alias(UnresolvedAttribute("name"), "name")()
  val stateAlias = Alias(UnresolvedAttribute("state"), "state")()
  val countryAlias = Alias(UnresolvedAttribute("country"), "country")()
  val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

  // Level 1: avg(age) grouped by country, state and name.
  val groupByAttributes1 = Seq(countryAlias, stateAlias, nameAlias)
  val aggregateExpressions1 =
    Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg_age")()
  val aggregatePlan1 =
    Aggregate(
      groupByAttributes1,
      Seq(aggregateExpressions1, countryAlias, stateAlias, nameAlias),
      table)

  // eval: projects every existing column plus avg_age_minus_20 = avg_age - 20.
  val avgAgeMinus20Alias =
    Alias(
      UnresolvedFunction(Seq("-"), Seq(avgAgeField, Literal(20)), isDistinct = false),
      "avg_age_minus_20")()
  val evalProject = Project(Seq(UnresolvedStar(None), avgAgeMinus20Alias), aggregatePlan1)

  // Level 2: avg(avg_age_minus_20) grouped by country and state.
  val groupByAttributes2 = Seq(countryAlias, stateAlias)
  val aggregateExpressions2 =
    Alias(
      UnresolvedFunction(Seq("AVG"), Seq(avgAgeMinus20Field), isDistinct = false),
      "avg_state_age")()
  val aggregatePlan2 =
    Aggregate(
      groupByAttributes2,
      Seq(aggregateExpressions2, countryAlias, stateAlias),
      evalProject)

  // where avg_state_age > 0
  val filterExpr = GreaterThan(avgStateAgeField, Literal(0))
  val filterPlan = Filter(filterExpr, aggregatePlan2)

  // Level 3: count(avg_state_age) grouped by country only.
  val groupByAttributes3 = Seq(countryAlias)
  val aggregateExpressions3 =
    Alias(
      UnresolvedFunction(Seq("COUNT"), Seq(avgStateAgeField), isDistinct = false),
      "count_country_age_greater_20")()

  val aggregatePlan3 =
    Aggregate(groupByAttributes3, Seq(aggregateExpressions3, countryAlias), filterPlan)
  val expectedPlan = Project(star, aggregatePlan3)

  comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}
4 changes: 4 additions & 0 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
- `source = table | stats sum(productsAmount) by span(transactionDate, 1d) as age_date | sort age_date`
- `source = table | stats sum(productsAmount) by span(transactionDate, 1w) as age_date, productId`

**Aggregations Group by Multiple Levels**
- `source = table | stats avg(age) as avg_state_age by country, state | stats avg(avg_state_age) as avg_country_age by country`
- `source = table | stats avg(age) as avg_city_age by country, state, city | eval new_avg_city_age = avg_city_age - 1 | stats avg(new_avg_city_age) as avg_state_age by country, state | where avg_state_age > 18 | stats avg(avg_state_age) as avg_adult_country_age by country`

**Dedup**
- `source = table | dedup a | fields a,b,c`
- `source = table | dedup a,b | fields a,b,c`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -876,4 +876,128 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
comparePlans(expectedPlan, logPlan, false)
}

// Two un-grouped `stats` commands without aliases: each becomes an Aggregate with an
// empty grouping list, and the output column is named after the call text.
test("multiple stats - test average price and average age") {
  val query = "source = table | stats avg(price) | stats avg(age)"
  val logPlan = planTransformer.visit(plan(pplParser, query, false), new CatalystPlanContext)

  // First stats: avg(price) over the source relation.
  val avgPrice = Alias(
    UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("price")), isDistinct = false),
    "avg(price)")()
  val firstStats = Aggregate(Seq(), Seq(avgPrice), UnresolvedRelation(Seq("table")))

  // Second stats: avg(age) over the output of the first.
  val avgAge = Alias(
    UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("age")), isDistinct = false),
    "avg(age)")()
  val secondStats = Aggregate(Seq(), Seq(avgAge), firstStats)

  val expectedPlan = Project(Seq(UnresolvedStar(None)), secondStats)
  comparePlans(expectedPlan, logPlan, false)
}

// Same shape as the un-aliased case, but each aggregate carries an explicit `as` alias
// that becomes the output column name.
test("multiple stats - test average price and average age with Alias") {
  val query = "source = table | stats avg(price) as avg_price | stats avg(age) as avg_age"
  val logPlan = planTransformer.visit(plan(pplParser, query, false), new CatalystPlanContext)

  // First stats: avg(price) as avg_price over the source relation.
  val avgPrice = Alias(
    UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("price")), isDistinct = false),
    "avg_price")()
  val firstStats = Aggregate(Seq(), Seq(avgPrice), UnresolvedRelation(Seq("table")))

  // Second stats: avg(age) as avg_age over the output of the first.
  val avgAge = Alias(
    UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("age")), isDistinct = false),
    "avg_age")()
  val secondStats = Aggregate(Seq(), Seq(avgAge), firstStats)

  val expectedPlan = Project(Seq(UnresolvedStar(None)), secondStats)
  comparePlans(expectedPlan, logPlan, false)
}

// Two `stats` commands with different groupings: the first groups by a plain field, the
// second by a span expression that is expected as floor(age / 10) * 10.
test(
  "multiple stats - test average price group by product and average age by span of interval of 10 years") {
  val query =
    "source = table | stats avg(price) by product | stats avg(age) by span(age, 10) as age_span"
  val logPlan = planTransformer.visit(plan(pplParser, query, false), new CatalystPlanContext)

  val productRef = UnresolvedAttribute("product")
  val ageRef = UnresolvedAttribute("age")

  // First stats: avg(price) grouped by product.
  val productGrouping = Alias(productRef, "product")()
  val productOutput = Alias(productRef, "product")()
  val avgPrice = Alias(
    UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("price")), isDistinct = false),
    "avg(price)")()
  val firstStats =
    Aggregate(Seq(productGrouping), Seq(avgPrice, productOutput), UnresolvedRelation(Seq("table")))

  // Second stats: avg(age) grouped by the 10-wide age span.
  val ageSpan = Alias(Multiply(Floor(Divide(ageRef, Literal(10))), Literal(10)), "age_span")()
  val avgAge =
    Alias(UnresolvedFunction(Seq("AVG"), Seq(ageRef), isDistinct = false), "avg(age)")()
  val secondStats = Aggregate(Seq(ageSpan), Seq(avgAge, ageSpan), firstStats)

  val expectedPlan = Project(Seq(UnresolvedStar(None)), secondStats)
  comparePlans(expectedPlan, logPlan, false)
}

// Two-level aggregation: average response time per (host, service), then the average of
// those averages per service. Only the translated (unresolved) plan is checked.
test("multiple levels stats") {
  val context = new CatalystPlanContext
  val logPlan = planTransformer.visit(
    plan(
      pplParser,
      "source = table | stats avg(response_time) as avg_response_time by host, service | stats avg(avg_response_time) as avg_host_response_time by service",
      false),
    context)
  val star = Seq(UnresolvedStar(None))
  val responseTimeField = UnresolvedAttribute("response_time")
  val avgResponseTimeField = UnresolvedAttribute("avg_response_time")
  val tableRelation = UnresolvedRelation(Seq("table"))
  // The same alias instances serve as grouping keys and output columns, matching how the
  // sibling aggregation tests build their expected plans.
  val hostAlias = Alias(UnresolvedAttribute("host"), "host")()
  val serviceAlias = Alias(UnresolvedAttribute("service"), "service")()

  // Level 1: avg(response_time) grouped by host and service.
  val aggregateExpressions1 =
    Alias(
      UnresolvedFunction(Seq("AVG"), Seq(responseTimeField), isDistinct = false),
      "avg_response_time")()
  val aggregatePlan1 =
    Aggregate(
      Seq(hostAlias, serviceAlias),
      Seq(aggregateExpressions1, hostAlias, serviceAlias),
      tableRelation)

  // Level 2: avg(avg_response_time) grouped by service only.
  val aggregateExpressions2 =
    Alias(
      UnresolvedFunction(Seq("AVG"), Seq(avgResponseTimeField), isDistinct = false),
      "avg_host_response_time")()
  val aggregatePlan2 =
    Aggregate(Seq(serviceAlias), Seq(aggregateExpressions2, serviceAlias), aggregatePlan1)

  val expectedPlan = Project(star, aggregatePlan2)

  comparePlans(expectedPlan, logPlan, false)
}
}

0 comments on commit 1f2472d

Please sign in to comment.