diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index c17b2f7856..21b395982b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -46,6 +46,7 @@ import org.apache.spark.util.io.ChunkedByteBuffer
 import com.google.common.base.Objects
 
 import org.apache.comet.CometRuntimeException
+import org.apache.comet.shims.ShimCometBroadcastExchangeExec
 
 /**
  * A [[CometBroadcastExchangeExec]] collects, transforms and finally broadcasts the result of a
@@ -64,8 +65,8 @@ case class CometBroadcastExchangeExec(
     mode: BroadcastMode,
     override val child: SparkPlan)
     extends BroadcastExchangeLike
+    with ShimCometBroadcastExchangeExec
     with CometPlan {
-  import CometBroadcastExchangeExec._
 
   override val runId: UUID = UUID.randomUUID
 
@@ -117,11 +118,7 @@ case class CometBroadcastExchangeExec(
         session,
         CometBroadcastExchangeExec.executionContext) {
         try {
-          // Setup a job group here so later it may get cancelled by groupId if necessary.
-          sparkContext.setJobGroup(
-            runId.toString,
-            s"broadcast exchange (runId $runId)",
-            interruptOnCancel = true)
+          setJobGroupOrTag(sparkContext, this)
           val beforeCollect = System.nanoTime()
 
           val countsAndBytes = child match {
@@ -167,9 +164,10 @@
           val dataSize = batches.map(_.size).sum
           longMetric("dataSize") += dataSize
 
-          if (dataSize >= MAX_BROADCAST_TABLE_BYTES) {
+          val maxBytes = maxBroadcastTableBytes(conf)
+          if (dataSize >= maxBytes) {
             throw QueryExecutionErrors.cannotBroadcastTableOverMaxTableBytesError(
-              MAX_BROADCAST_TABLE_BYTES,
+              maxBytes,
               dataSize)
           }
 
@@ -233,7 +231,7 @@ case class CometBroadcastExchangeExec(
       case ex: TimeoutException =>
         logError(s"Could not execute broadcast in $timeout secs.", ex)
         if (!relationFuture.isDone) {
-          sparkContext.cancelJobGroup(runId.toString)
+          cancelJobGroup(sparkContext, this)
           relationFuture.cancel(true)
         }
         throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
@@ -259,8 +257,6 @@ case class CometBroadcastExchangeExec(
 }
 
 object CometBroadcastExchangeExec {
-  val MAX_BROADCAST_TABLE_BYTES: Long = 8L << 30
-
   private[comet] val executionContext = ExecutionContext.fromExecutorService(
     ThreadUtils.newDaemonCachedThreadPool(
       "comet-broadcast-exchange",
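Note: the fixed cap `MAX_BROADCAST_TABLE_BYTES = 8L << 30` is replaced by a per-query lookup, so the limit now follows `spark.sql.maxBroadcastTableSize`. A minimal sketch of how the shim's `maxBroadcastTableBytes` resolves the value, assuming a bare `SQLConf` rather than the active session's conf:

    import org.apache.spark.network.util.JavaUtils
    import org.apache.spark.sql.internal.SQLConf

    // Resolve the broadcast cap the way the shim does; byteStringAsBytes
    // accepts suffixed strings such as "10B", "512MB" or "8GB".
    val conf = new SQLConf
    val maxBytes = JavaUtils.byteStringAsBytes(
      conf.getConfString("spark.sql.maxBroadcastTableSize", "8GB"))
    assert(maxBytes == 8L << 30) // the "8GB" default equals the old hardcoded limit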
diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
new file mode 100644
index 0000000000..98ac1e7d2d
--- /dev/null
+++ b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.comet.shims
+
+import org.apache.comet.shims.ShimCometBroadcastExchangeExec.SPARK_MAX_BROADCAST_TABLE_SIZE
+import org.apache.spark.SparkContext
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
+import org.apache.spark.sql.internal.SQLConf
+
+trait ShimCometBroadcastExchangeExec {
+
+  def setJobGroupOrTag(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
+    // Setup a job group here so later it may get cancelled by groupId if necessary.
+    sc.setJobGroup(
+      broadcastExchange.runId.toString,
+      s"broadcast exchange (runId ${broadcastExchange.runId})",
+      interruptOnCancel = true)
+  }
+
+  def cancelJobGroup(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
+    sc.cancelJobGroup(broadcastExchange.runId.toString)
+  }
+
+  def maxBroadcastTableBytes(conf: SQLConf): Long = {
+    JavaUtils.byteStringAsBytes(conf.getConfString(SPARK_MAX_BROADCAST_TABLE_SIZE, "8GB"))
+  }
+
+}
+
+object ShimCometBroadcastExchangeExec {
+  val SPARK_MAX_BROADCAST_TABLE_SIZE = "spark.sql.maxBroadcastTableSize"
+}
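On Spark 3.4 the shim keeps the original job-group mechanics: `setJobGroup` ties every job submitted from the broadcast thread to the `runId`, and a timeout cancels the whole group. A rough driver-side sketch of that lifecycle (the `try` body stands in for the real collect-and-broadcast work):

    import java.util.concurrent.TimeoutException
    import org.apache.spark.SparkContext
    import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike

    // Sketch only: mirrors how CometBroadcastExchangeExec drives the shim on Spark 3.4.
    def runWithCancellation(
        sc: SparkContext,
        exchange: BroadcastExchangeLike with ShimCometBroadcastExchangeExec): Unit = {
      exchange.setJobGroupOrTag(sc, exchange) // group id = runId, interruptOnCancel = true
      try {
        // ... collect child batches and build the broadcast relation here ...
      } catch {
        case ex: TimeoutException =>
          exchange.cancelJobGroup(sc, exchange) // tears down every job in the runId group
          throw ex
      }
    }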
diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
new file mode 100644
index 0000000000..81053ac410
--- /dev/null
+++ b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.comet.shims
+
+import org.apache.comet.shims.ShimCometBroadcastExchangeExec.SPARK_MAX_BROADCAST_TABLE_SIZE
+import org.apache.spark.SparkContext
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
+import org.apache.spark.sql.internal.SQLConf
+
+trait ShimCometBroadcastExchangeExec {
+
+  def setJobGroupOrTag(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
+    // Setup a job tag here so later it may get cancelled by tag if necessary.
+    sc.addJobTag(broadcastExchange.jobTag)
+    sc.setInterruptOnCancel(true)
+  }
+
+  def cancelJobGroup(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
+    sc.cancelJobsWithTag(broadcastExchange.jobTag)
+  }
+
+  def maxBroadcastTableBytes(conf: SQLConf): Long = {
+    JavaUtils.byteStringAsBytes(conf.getConfString(SPARK_MAX_BROADCAST_TABLE_SIZE, "8GB"))
+  }
+
+}
+
+object ShimCometBroadcastExchangeExec {
+  val SPARK_MAX_BROADCAST_TABLE_SIZE = "spark.sql.maxBroadcastTableSize"
+}
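Spark 3.5 added job tags (`SparkContext.addJobTag` / `cancelJobsWithTag`) and gave `BroadcastExchangeLike` a stable `jobTag`, so this shim tags the broadcast jobs instead of claiming the thread's single job group; unlike a group, tags are additive and do not clobber a group a caller may have set. A minimal sketch with a made-up tag name:

    import org.apache.spark.SparkContext

    // Tag-based cancellation on Spark 3.5+; "demo-broadcast" is an illustrative tag.
    def tagAndCancel(sc: SparkContext): Unit = {
      sc.addJobTag("demo-broadcast")   // jobs submitted from this thread now carry the tag
      sc.setInterruptOnCancel(true)    // cancellation will interrupt running task threads
      // ... submit jobs ...
      sc.cancelJobsWithTag("demo-broadcast") // cancels only jobs carrying this tag
      sc.removeJobTag("demo-broadcast")      // stop tagging subsequent jobs
    }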
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
new file mode 100644
index 0000000000..81053ac410
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.comet.shims
+
+import org.apache.comet.shims.ShimCometBroadcastExchangeExec.SPARK_MAX_BROADCAST_TABLE_SIZE
+import org.apache.spark.SparkContext
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
+import org.apache.spark.sql.internal.SQLConf
+
+trait ShimCometBroadcastExchangeExec {
+
+  def setJobGroupOrTag(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
+    // Setup a job tag here so later it may get cancelled by tag if necessary.
+    sc.addJobTag(broadcastExchange.jobTag)
+    sc.setInterruptOnCancel(true)
+  }
+
+  def cancelJobGroup(sc: SparkContext, broadcastExchange: BroadcastExchangeLike): Unit = {
+    sc.cancelJobsWithTag(broadcastExchange.jobTag)
+  }
+
+  def maxBroadcastTableBytes(conf: SQLConf): Long = {
+    JavaUtils.byteStringAsBytes(conf.getConfString(SPARK_MAX_BROADCAST_TABLE_SIZE, "8GB"))
+  }
+
+}
+
+object ShimCometBroadcastExchangeExec {
+  val SPARK_MAX_BROADCAST_TABLE_SIZE = "spark.sql.maxBroadcastTableSize"
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
index 325ef51f68..fd6d3ef535 100644
--- a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
@@ -22,7 +22,7 @@ package org.apache.comet
 import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.catalyst.expressions.PrettyAttribute
-import org.apache.spark.sql.comet.{CometExec, CometExecUtils}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometExec, CometExecUtils}
 import org.apache.spark.sql.types.LongType
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -97,4 +97,21 @@ class CometNativeSuite extends CometTestBase {
       }
     }
   }
+
+  test("test maxBroadcastTableSize") {
+    withSQLConf("spark.sql.maxBroadcastTableSize" -> "10B") {
+      spark.range(0, 1000).createOrReplaceTempView("t1")
+      spark.range(0, 100).createOrReplaceTempView("t2")
+      val df = spark.sql("select /*+ BROADCAST(t2) */ * from t1 join t2 on t1.id = t2.id")
+      val exception = intercept[SparkException] {
+        df.collect()
+      }
+      assert(
+        exception.getMessage.contains("Cannot broadcast the table that is larger than 10.0 B"))
+      val broadcasts = collect(df.queryExecution.executedPlan) {
+        case p: CometBroadcastExchangeExec => p
+      }
+      assert(broadcasts.size == 1)
+    }
+  }
 }
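The new test pins the limit to 10 bytes so any broadcast trips it, then checks both the error message and that a `CometBroadcastExchangeExec` actually appears in the plan. The same knob can be turned on a live session; a hypothetical spark-shell repro (assumes Comet is on the classpath and enabled, values illustrative):

    // Force the failure interactively; any byte string is accepted ("10B", "1GB", ...).
    spark.conf.set("spark.sql.maxBroadcastTableSize", "10B")
    spark.range(0, 1000).createOrReplaceTempView("t1")
    spark.range(0, 100).createOrReplaceTempView("t2")
    // The hint forces t2 to broadcast; even 100 rows exceed a 10-byte cap, so collect()
    // fails with "Cannot broadcast the table that is larger than 10.0 B".
    spark.sql("select /*+ BROADCAST(t2) */ * from t1 join t2 on t1.id = t2.id").collect()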