diff --git a/integration_tests/src/main/python/spark_session.py b/integration_tests/src/main/python/spark_session.py
index 17ba470f168..78e0b08a651 100644
--- a/integration_tests/src/main/python/spark_session.py
+++ b/integration_tests/src/main/python/spark_session.py
@@ -260,6 +260,9 @@ def is_databricks113_or_later():
 def is_databricks122_or_later():
     return is_databricks_version_or_later(12, 2)
 
+def is_databricks133_or_later():
+    return is_databricks_version_or_later(13, 3)
+
 def supports_delta_lake_deletion_vectors():
     if is_databricks_runtime():
         return is_databricks122_or_later()
diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py
index b788a9b13c9..af8bbbb55b3 100644
--- a/integration_tests/src/main/python/window_function_test.py
+++ b/integration_tests/src/main/python/window_function_test.py
@@ -21,7 +21,7 @@
 from pyspark.sql.types import DateType, TimestampType, NumericType
 from pyspark.sql.window import Window
 import pyspark.sql.functions as f
-from spark_session import is_before_spark_320, is_before_spark_350, is_databricks113_or_later, spark_version, with_cpu_session
+from spark_session import is_before_spark_320, is_databricks113_or_later, is_databricks133_or_later, is_spark_350_or_later, spark_version, with_cpu_session
 import warnings
 
 _grpkey_longs_with_no_nulls = [
@@ -2042,8 +2042,9 @@ def assert_query_runs_on(exec, conf):
     assert_query_runs_on(exec='GpuBatchedBoundedWindowExec', conf=conf_200)
 
 
-@pytest.mark.skipif(condition=is_before_spark_350(),
-                    reason="WindowGroupLimit not available for spark.version < 3.5")
+@pytest.mark.skipif(condition=not (is_spark_350_or_later() or is_databricks133_or_later()),
+                    reason="WindowGroupLimit not available for spark.version < 3.5 "
+                           "and Databricks version < 13.3")
 @ignore_order(local=True)
 @approximate_float
 @pytest.mark.parametrize('batch_size', ['1k', '1g'], ids=idfn)
@@ -2087,12 +2088,13 @@ def test_window_group_limits_for_ranking_functions(data_gen, batch_size, rank_cl
         lambda spark: gen_df(spark, data_gen, length=4096),
         "window_agg_table",
         query,
-        conf = conf)
+        conf=conf)
 
 
 @allow_non_gpu('WindowGroupLimitExec')
-@pytest.mark.skipif(condition=is_before_spark_350(),
-                    reason="WindowGroupLimit not available for spark.version < 3.5")
+@pytest.mark.skipif(condition=not (is_spark_350_or_later() or is_databricks133_or_later()),
+                    reason="WindowGroupLimit not available for spark.version < 3.5 "
+                           "and Databricks version < 13.3")
 @ignore_order(local=True)
 @approximate_float
 def test_window_group_limits_fallback_for_row_number():
diff --git a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/GpuWindowGroupLimitExec.scala b/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/GpuWindowGroupLimitExec.scala
similarity index 99%
rename from sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/GpuWindowGroupLimitExec.scala
rename to sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/GpuWindowGroupLimitExec.scala
index 5d879283a38..3406186a9d0 100644
--- a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/GpuWindowGroupLimitExec.scala
+++ b/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/GpuWindowGroupLimitExec.scala
@@ -15,6 +15,7 @@
  */
 
 /*** spark-rapids-shim-json-lines
+{"spark": "341db"}
 {"spark": "350"}
 {"spark": "351"}
 spark-rapids-shim-json-lines ***/
diff --git a/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/Spark341PlusDBShims.scala b/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/Spark341PlusDBShims.scala
index d5f554adcee..667a6912abc 100644
--- a/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/Spark341PlusDBShims.scala
+++ b/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/Spark341PlusDBShims.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive._
 import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
+import org.apache.spark.sql.execution.window.WindowGroupLimitExec
 import org.apache.spark.sql.rapids.GpuV1WriteUtils.GpuEmpty2Null
 import org.apache.spark.sql.rapids.execution.python.GpuPythonUDAF
 import org.apache.spark.sql.types.StringType
@@ -167,7 +168,15 @@ trait Spark341PlusDBShims extends Spark332PlusDBShims {
       }
     ).disabledByDefault("Collect Limit replacement can be slower on the GPU, if huge number " +
       "of rows in a batch it could help by limiting the number of rows transferred from " +
-      "GPU to CPU")
+      "GPU to CPU"),
+    GpuOverrides.exec[WindowGroupLimitExec](
+      "Apply group-limits for row groups destined for rank-based window functions like " +
+        "row_number(), rank(), and dense_rank()",
+      ExecChecks( // Similar to WindowExec.
+        (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+          TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(),
+        TypeSig.all),
+      (limit, conf, p, r) => new GpuWindowGroupLimitExecMeta(limit, conf, p, r))
   ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
 
   override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =