diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 9fb5b40575e..1784c11892a 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -556,7 +556,7 @@ Accelerator supports are described below.
S |
PS not allowed for grouping expressions |
NS |
-PS not allowed for grouping expressions; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
+PS not allowed for grouping expressions if containing Struct as child; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
PS not allowed for grouping expressions; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
PS not allowed for grouping expressions if containing Array, Map, or Binary as child; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
NS |
@@ -748,7 +748,7 @@ Accelerator supports are described below.
S |
S |
NS |
-PS Round-robin partitioning is not supported if spark.sql.execution.sortBeforeRepartition is true; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
PS Round-robin partitioning is not supported if spark.sql.execution.sortBeforeRepartition is true; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
PS Round-robin partitioning is not supported for nested structs if spark.sql.execution.sortBeforeRepartition is true; UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT |
NS |
@@ -8042,45 +8042,45 @@ are limited.
None |
project |
input |
- |
- |
- |
- |
- |
S |
S |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
+S |
+S |
+S |
+S |
+S |
+S |
+PS UTC is only supported TZ for TIMESTAMP |
+S |
+S |
+S |
+S |
+S |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+S |
result |
- |
- |
- |
- |
- |
S |
S |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
- |
+S |
+S |
+S |
+S |
+S |
+S |
+PS UTC is only supported TZ for TIMESTAMP |
+S |
+S |
+S |
+S |
+S |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+PS UTC is only supported TZ for child TIMESTAMP |
+S |
KnownNotNull |
@@ -10046,9 +10046,9 @@ are limited.
S |
NS |
NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, MAP, UDT |
NS |
-NS |
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, MAP, UDT |
NS |
@@ -19270,9 +19270,9 @@ as `a` don't show up in the table. They are controlled by the rules for
S |
NS |
NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, MAP, UDT |
NS |
-NS |
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, MAP, UDT |
NS |
diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py
index b9413f18b7d..6e488486077 100644
--- a/integration_tests/src/main/python/hash_aggregate_test.py
+++ b/integration_tests/src/main/python/hash_aggregate_test.py
@@ -116,6 +116,22 @@
('b', FloatGen(nullable=(True, 10.0), special_cases=[(float('nan'), 10.0)])),
('c', LongGen())]
+# grouping single-level lists
+# StringGen for the value being aggregated will force CUDF to use a sort-based aggregation internally.
+_grpkey_list_with_non_nested_children = [[('a', RepeatSeqGen(ArrayGen(data_gen), length=3)),
+ ('b', IntegerGen())] for data_gen in all_basic_gens + decimal_gens] + \
+ [[('a', RepeatSeqGen(ArrayGen(data_gen), length=3)),
+ ('b', StringGen())] for data_gen in all_basic_gens + decimal_gens]
+
+# grouping multiple-level structs with arrays
+_grpkey_nested_structs_with_array_basic_child = [[
+ ('a', RepeatSeqGen(StructGen([
+ ['aa', IntegerGen()],
+ ['ab', ArrayGen(IntegerGen())]]),
+ length=20)),
+ ('b', IntegerGen()),
+ ('c', NullGen())]]
+
_nan_zero_float_special_cases = [
(float('nan'), 5.0),
(NEG_FLOAT_NAN_MIN_VALUE, 5.0),
@@ -324,6 +340,14 @@ def test_hash_grpby_sum_count_action(data_gen):
lambda spark: gen_df(spark, data_gen, length=100).groupby('a').agg(f.sum('b'))
)
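+# Grouping by arrays (and by structs containing arrays) is newly supported on the GPU; the
+# execs below are allowed to stay on the CPU where a sort-based aggregation is planned.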
+@allow_non_gpu("SortAggregateExec", "SortExec", "ShuffleExchangeExec")
+@ignore_order
+@pytest.mark.parametrize('data_gen', _grpkey_nested_structs_with_array_basic_child + _grpkey_list_with_non_nested_children, ids=idfn)
+def test_hash_grpby_list_min_max(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: gen_df(spark, data_gen, length=100).coalesce(1).groupby('a').agg(f.min('b'), f.max('b'))
+ )
+
@pytest.mark.parametrize('data_gen', [_longs_with_nulls], ids=idfn)
def test_hash_reduction_sum_count_action(data_gen):
assert_gpu_and_cpu_row_counts_equal(
@@ -1199,7 +1223,9 @@ def test_agg_count(data_gen, count_func):
@ignore_order(local=True)
@allow_non_gpu('HashAggregateExec', 'Alias', 'AggregateExpression', 'Cast',
'HashPartitioning', 'ShuffleExchangeExec', 'Count')
-@pytest.mark.parametrize('data_gen', array_gens_sample + [binary_gen], ids=idfn)
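+# Arrays of basic types no longer fall back when used as grouping keys, so only an array
+# of structs (plus binary) is expected to fall back to the CPU here.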
+@pytest.mark.parametrize('data_gen',
+    [ArrayGen(StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]])),
+     binary_gen], ids=idfn)
@pytest.mark.parametrize('count_func', [f.count, f.countDistinct])
def test_groupby_list_types_fallback(data_gen, count_func):
assert_gpu_fallback_collect(
@@ -1718,7 +1744,6 @@ def do_it(spark):
assert_gpu_and_cpu_are_equal_collect(do_it,
conf={'spark.sql.ansi.enabled': 'true'})
-
# Tests for standard deviation and variance aggregations.
@ignore_order(local=True)
@approximate_float
diff --git a/integration_tests/src/main/python/repart_test.py b/integration_tests/src/main/python/repart_test.py
index 7b77b7be426..8ae795e9348 100644
--- a/integration_tests/src/main/python/repart_test.py
+++ b/integration_tests/src/main/python/repart_test.py
@@ -214,10 +214,23 @@ def test_round_robin_sort_fallback(data_gen):
lambda spark : gen_df(spark, data_gen).withColumn('extra', lit(1)).repartition(13),
'ShuffleExchangeExec')
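+# Hashing arrays of structs is not supported on the GPU, so repartitioning on such a key
+# is expected to fall back to the CPU ShuffleExchangeExec.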
+@allow_non_gpu("ProjectExec", "ShuffleExchangeExec")
+@ignore_order(local=True) # To avoid extra data shuffle by 'sort on Spark' for this repartition test.
+@pytest.mark.parametrize('num_parts', [2, 10, 17, 19, 32], ids=idfn)
+@pytest.mark.parametrize('gen', [([('ag', ArrayGen(StructGen([('b1', long_gen)])))], ['ag'])], ids=idfn)
+def test_hash_repartition_exact_fallback(gen, num_parts):
+ data_gen = gen[0]
+ part_on = gen[1]
+ assert_gpu_fallback_collect(
+ lambda spark : gen_df(spark, data_gen, length=1024) \
+ .repartition(num_parts, *part_on) \
+ .withColumn('id', f.spark_partition_id()) \
+ .selectExpr('*'), "ShuffleExchangeExec")
+
@ignore_order(local=True) # To avoid extra data shuffle by 'sort on Spark' for this repartition test.
@pytest.mark.parametrize('num_parts', [1, 2, 10, 17, 19, 32], ids=idfn)
@pytest.mark.parametrize('gen', [
- ([('a', boolean_gen)], ['a']),
+ ([('a', boolean_gen)], ['a']),
([('a', byte_gen)], ['a']),
([('a', short_gen)], ['a']),
([('a', int_gen)], ['a']),
@@ -235,7 +248,9 @@ def test_round_robin_sort_fallback(data_gen):
([('a', long_gen), ('b', StructGen([('b1', long_gen)]))], ['a']),
([('a', long_gen), ('b', ArrayGen(long_gen, max_length=2))], ['a']),
([('a', byte_gen)], [f.col('a') - 5]),
- ([('a', long_gen)], [f.col('a') + 15]),
+ ([('a', long_gen)], [f.col('a') + 15]),
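+    # Arrays, including arrays nested inside structs, are now supported as hash-partition keys.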
+ ([('a', ArrayGen(long_gen, max_length=2)), ('b', long_gen)], ['a']),
+ ([('a', StructGen([('aa', ArrayGen(long_gen, max_length=2))])), ('b', long_gen)], ['a']),
([('a', byte_gen), ('b', boolean_gen)], ['a', 'b']),
([('a', short_gen), ('b', string_gen)], ['a', 'b']),
([('a', int_gen), ('b', byte_gen)], ['a', 'b']),
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 1c446af9c04..387f5f21645 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1565,9 +1565,7 @@ object GpuOverrides extends Logging {
}),
expr[KnownFloatingPointNormalized](
"Tag to prevent redundant normalization",
- ExprChecks.unaryProjectInputMatchesOutput(
- TypeSig.DOUBLE + TypeSig.FLOAT,
- TypeSig.DOUBLE + TypeSig.FLOAT),
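+      // KnownFloatingPointNormalized is only a tag around its child expression, so it can
+      // accept and pass through any type.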
+ ExprChecks.unaryProjectInputMatchesOutput(TypeSig.all, TypeSig.all),
(a, conf, p, r) => new UnaryExprMeta[KnownFloatingPointNormalized](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
GpuKnownFloatingPointNormalized(child)
@@ -3070,7 +3068,7 @@ object GpuOverrides extends Logging {
ExprChecks.projectOnly(TypeSig.INT, TypeSig.INT,
repeatingParamCheck = Some(RepeatingParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
- TypeSig.STRUCT).nested(), TypeSig.all))),
+ TypeSig.STRUCT + TypeSig.ARRAY).nested(), TypeSig.all))),
(a, conf, p, r) => new ExprMeta[Murmur3Hash](a, conf, p, r) {
override val childExprs: Seq[BaseExprMeta[_]] = a.children
.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
@@ -3592,11 +3590,26 @@ object GpuOverrides extends Logging {
// This needs to match what murmur3 supports.
PartChecks(RepeatingParamCheck("hash_key",
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
- TypeSig.STRUCT).nested(), TypeSig.all)),
+            TypeSig.STRUCT + TypeSig.ARRAY).nested(), TypeSig.all)),
(hp, conf, p, r) => new PartMeta[HashPartitioning](hp, conf, p, r) {
override val childExprs: Seq[BaseExprMeta[_]] =
hp.expressions.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+ override def tagPartForGpu(): Unit = {
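+        // Hashing arrays of structs is not supported on the GPU, so check the partition
+        // keys recursively, mirroring the grouping-expression check in GpuBaseAggregateMeta.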
+ val arrayWithStructsHashing = hp.expressions.exists(e =>
+ TrampolineUtil.dataTypeExistsRecursively(e.dataType,
+ dt => dt match {
+ case ArrayType(_: StructType, _) => true
+ case _ => false
+ })
+ )
+ if (arrayWithStructsHashing) {
+ willNotWorkOnGpu("hashing arrays with structs is not supported")
+ }
+ }
+
override def convertToGpu(): GpuPartitioning =
GpuHashPartitioning(childExprs.map(_.convertToGpu()), hp.numPartitions)
}),
@@ -3844,7 +3857,7 @@ object GpuOverrides extends Logging {
.withPsNote(TypeEnum.STRUCT, "Round-robin partitioning is not supported for nested " +
s"structs if ${SQLConf.SORT_BEFORE_REPARTITION.key} is true")
.withPsNote(
- Seq(TypeEnum.ARRAY, TypeEnum.MAP),
+ Seq(TypeEnum.MAP),
"Round-robin partitioning is not supported if " +
s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true"),
TypeSig.all),
@@ -3909,8 +3922,10 @@ object GpuOverrides extends Logging {
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.BINARY +
TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT)
.nested()
- .withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP, TypeEnum.BINARY),
+ .withPsNote(Seq(TypeEnum.MAP, TypeEnum.BINARY),
"not allowed for grouping expressions")
+ .withPsNote(TypeEnum.ARRAY,
+ "not allowed for grouping expressions if containing Struct as child")
.withPsNote(TypeEnum.STRUCT,
"not allowed for grouping expressions if containing Array, Map, or Binary as child"),
TypeSig.all),
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
index df20db39c75..eb54007a798 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
@@ -1036,17 +1036,30 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan](
groupingExpressions ++ aggregateExpressions ++ aggregateAttributes ++ resultExpressions
override def tagPlanForGpu(): Unit = {
- // We don't support Arrays and Maps as GroupBy keys yet, even they are nested in Structs. So,
+    // We don't support Maps or Binary as GroupBy keys yet, even if they are nested in Structs. So,
// we need to run recursive type check on the structs.
- val listTypeGroupings = agg.groupingExpressions.exists(e =>
+ val mapOrBinaryGroupings = agg.groupingExpressions.exists(e =>
TrampolineUtil.dataTypeExistsRecursively(e.dataType,
- dt => dt.isInstanceOf[ArrayType] || dt.isInstanceOf[MapType]
- || dt.isInstanceOf[BinaryType]))
- if (listTypeGroupings) {
- willNotWorkOnGpu("ArrayType, MapType, or BinaryType " +
+ dt => dt.isInstanceOf[MapType] || dt.isInstanceOf[BinaryType]))
+ if (mapOrBinaryGroupings) {
+ willNotWorkOnGpu("MapType, or BinaryType " +
"in grouping expressions are not supported")
}
+    // We support Arrays as grouping expressions, but not if the element type is a Struct.
+    // So we need to run a recursive type check for arrays of structs.
+ val arrayWithStructsGroupings = agg.groupingExpressions.exists(e =>
+ TrampolineUtil.dataTypeExistsRecursively(e.dataType,
+ dt => dt match {
+ case ArrayType(_: StructType, _) => true
+ case _ => false
+ })
+ )
+ if (arrayWithStructsGroupings) {
+ willNotWorkOnGpu("ArrayTypes with Struct children in grouping expressions are not " +
+ "supported")
+ }
+
tagForReplaceMode()
if (agg.aggregateExpressions.exists(expr => expr.isDistinct)
diff --git a/tools/generated_files/supportedExprs.csv b/tools/generated_files/supportedExprs.csv
index 5481157c85a..391e4c199bd 100644
--- a/tools/generated_files/supportedExprs.csv
+++ b/tools/generated_files/supportedExprs.csv
@@ -271,8 +271,8 @@ JsonToStructs,NS,`from_json`,This is disabled by default because parsing JSON fr
JsonTuple,S,`json_tuple`,None,project,json,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
JsonTuple,S,`json_tuple`,None,project,field,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA
JsonTuple,S,`json_tuple`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA
-KnownFloatingPointNormalized,S, ,None,project,input,NA,NA,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
-KnownFloatingPointNormalized,S, ,None,project,result,NA,NA,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
+KnownFloatingPointNormalized,S, ,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,S
+KnownFloatingPointNormalized,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,S
KnownNotNull,S, ,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,NS,S,S,PS,PS,PS,NS
KnownNotNull,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,S,S,PS,PS,PS,NS
Lag,S,`lag`,None,window,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NS,PS,NS
@@ -352,7 +352,7 @@ Multiply,S,`*`,None,project,result,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,N
Multiply,S,`*`,None,AST,lhs,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NA,NA,NA,NA,NA
Multiply,S,`*`,None,AST,rhs,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NA,NA,NA,NA,NA
Multiply,S,`*`,None,AST,result,NA,NS,NS,S,S,S,S,NA,NA,NA,NS,NA,NA,NA,NA,NA,NA,NA
-Murmur3Hash,S,`hash`,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,PS,NS
+Murmur3Hash,S,`hash`,None,project,input,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,NS,PS,NS
Murmur3Hash,S,`hash`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
NaNvl,S,`nanvl`,None,project,lhs,NA,NA,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
NaNvl,S,`nanvl`,None,project,rhs,NA,NA,NA,NA,NA,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA