[SPARK-47545][CONNECT] Dataset observe support for the Scala client
#45701
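This PR adds `Dataset.observe` support to the Spark Connect Scala client: a client-side `Observation` class, observed-metrics handling in `SparkResult` (including an optional callback that unblocks `Observation.get`), and end-to-end tests covering both named observations and the blocking `get` path.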
Observation.scala (new file)

@@ -0,0 +1,46 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import java.util.UUID

class Observation(name: String) extends ObservationBase(name) {

  /**
   * Create an Observation instance without providing a name. This generates a random name.
   */
  def this() = this(UUID.randomUUID().toString)
}

/**
 * (Scala-specific) Create instances of Observation via Scala `apply`.
 * @since 4.0.0
 */
object Observation {

  /**
   * Observation constructor for creating an anonymous observation.
   */
  def apply(): Observation = new Observation()

  /**
   * Observation constructor for creating a named observation.
   */
  def apply(name: String): Observation = new Observation(name)

}
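A minimal usage sketch of the new class (hedged: assumes an active `spark` session and `org.apache.spark.sql.functions._` in scope, as in the test suite below):

```scala
import org.apache.spark.sql.Observation
import org.apache.spark.sql.functions._

// Named observation via the companion apply; new Observation() generates a random name.
val observation = Observation("stats")
val observed = spark.range(100).observe(observation, min("id"), max("id"))
observed.collect()            // running an action populates the metrics
val metrics = observation.get // Map[String, Any]; blocks until the query finishes
```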
ClientE2ETestSuite.scala

@@ -22,6 +22,8 @@ import java.time.DateTimeException

import java.util.Properties

import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._

import org.apache.commons.io.FileUtils

@@ -41,6 +43,7 @@ import org.apache.spark.sql.internal.SqlApiConf

import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession, SQLHelper}
import org.apache.spark.sql.test.SparkConnectServerUtils.port
import org.apache.spark.sql.types._
import org.apache.spark.util.SparkThreadUtils

class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {

@@ -1511,6 +1514,46 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM

      (0 until 5).foreach(i => assert(row.get(i * 2) === row.get(i * 2 + 1)))
    }
  }

  test("Observable metrics") {
    val df = spark.range(99).withColumn("extra", col("id") - 1)
    val ob1 = new Observation("ob1")
    val observedDf = df.observe(ob1, min("id"), avg("id"), max("id"))
    val observedObservedDf = observedDf.observe("ob2", min("extra"), avg("extra"), max("extra"))

    val ob1Schema = new StructType()
      .add("min(id)", LongType)
      .add("avg(id)", DoubleType)
      .add("max(id)", LongType)
    val ob2Schema = new StructType()
      .add("min(extra)", LongType)
      .add("avg(extra)", DoubleType)
      .add("max(extra)", LongType)
    val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema))
    val ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), ob2Schema))

    assert(df.collectResult().getObservedMetrics === Map.empty)
    assert(observedDf.collectResult().getObservedMetrics === ob1Metrics)
    assert(observedObservedDf.collectResult().getObservedMetrics === ob1Metrics ++ ob2Metrics)
  }

  test("Observation.get is blocked until the query is finished") {
    val df = spark.range(99).withColumn("extra", col("id") - 1)
    val observation = new Observation("ob1")
    val observedDf = df.observe(observation, min("id"), avg("id"), max("id"))

    // Start a new thread to get the observation
    val future = Future(observation.get)(ExecutionContext.global)

    // Make sure the thread is blocked right now
    val e = intercept[java.util.concurrent.TimeoutException] {
      SparkThreadUtils.awaitResult(future, 2.seconds)
    }
    assert(e.getMessage.contains("Future timed out"))
    observedDf.collect()
    // Make sure the thread is unblocked after the query is finished
    val metrics = SparkThreadUtils.awaitResult(future, 2.seconds)
    assert(metrics === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98))
  }
}

private[sql] case class ClassData(a: String, b: Int)

Contributor comment on the `val future = Future(observation.get)(ExecutionContext.global)` line: "For the record. IMO the observation class should have been using a future from the get go."
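The second test pins down the blocking contract of `Observation.get`. A hedged sketch of the same pattern in application code, mirroring the test (assumes `spark` and `org.apache.spark.sql.functions._` in scope):

```scala
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.DurationInt

import org.apache.spark.sql.Observation
import org.apache.spark.sql.functions._
import org.apache.spark.util.SparkThreadUtils

val rowCount = Observation("row_count")
val df = spark.range(1000).observe(rowCount, count(lit(1)).as("rows"))

// Request the metrics on another thread; get blocks until an action runs.
val metricsFuture = Future(rowCount.get)(ExecutionContext.global)
df.collect()

// The future completes once the server streams back the observed metrics.
val metrics = SparkThreadUtils.awaitResult(metricsFuture, 10.seconds)
println(metrics("rows"))
```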
SparkResult.scala

@@ -27,18 +27,22 @@ import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}

import org.apache.arrow.vector.types.pojo

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ArrowUtils

private[sql] class SparkResult[T](
    responses: CloseableIterator[proto.ExecutePlanResponse],
    allocator: BufferAllocator,
    encoder: AgnosticEncoder[T],
    timeZoneId: String,
    setObservationMetricsOpt: Option[(Long, Map[String, Any]) => Unit] = None)
    extends AutoCloseable { self =>

  case class StageInfo(

@@ -79,6 +83,7 @@ private[sql] class SparkResult[T](

  private[this] var arrowSchema: pojo.Schema = _
  private[this] var nextResultIndex: Int = 0
  private val resultMap = mutable.Map.empty[Int, (Long, Seq[ArrowMessage])]
  private val observedMetrics = mutable.Map.empty[String, Row]
  private val cleanable =
    SparkResult.cleaner.register(this, new SparkResultCloseable(resultMap, responses))

@@ -117,6 +122,9 @@ private[sql] class SparkResult[T](

    while (!stop && responses.hasNext) {
      val response = responses.next()

      // Collect metrics for this response
      observedMetrics ++= processObservedMetrics(response.getObservedMetricsList)

      // Save and validate operationId
      if (opId == null) {
        opId = response.getOperationId

@@ -198,6 +206,29 @@ private[sql] class SparkResult[T](

    nonEmpty
  }

  private def processObservedMetrics(
      metrics: java.util.List[ObservedMetrics]): Iterable[(String, Row)] = {
    metrics.asScala.map { metric =>
      assert(metric.getKeysCount == metric.getValuesCount)
      var schema = new StructType()
      val keys = mutable.ListBuffer.empty[String]
      val values = mutable.ListBuffer.empty[Any]
      (0 until metric.getKeysCount).map { i =>
        val key = metric.getKeys(i)
        val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
        schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
        keys += key
        values += value
      }
      // If the metrics are registered by an Observation object, attach them and unblock any
      // blocked thread.
      setObservationMetricsOpt.foreach { setObservationMetrics =>
        setObservationMetrics(metric.getPlanId, keys.zip(values).toMap)
      }
      metric.getName -> new GenericRowWithSchema(values.toArray, schema)
    }
  }

  /**
   * Returns the number of elements in the result.
   */

Contributor comment on the `schema = schema.add(...)` line: "There is a bit of a twist here. So, LiteralValueProtoConverter returns a Tuple for a nested struct. This is not really expected in a Row. We can address this in a follow-up."

@@ -248,6 +279,15 @@ private[sql] class SparkResult[T](

    result
  }

  /**
   * Returns all observed metrics in the result.
   */
  def getObservedMetrics: Map[String, Row] = {
    // We need to process all responses to get all metrics.
    processResponses()
    observedMetrics.toMap
  }

  /**
   * Returns an iterator over the contents of the result.
   */
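The new `setObservationMetricsOpt` parameter is what connects `SparkResult` back to a blocked `Observation.get`: whenever a response carries observed metrics, the callback is invoked with the plan ID and the decoded key/value map. The PR wires this callback up elsewhere in the client; purely as a hedged illustration of the shape of that wiring (the `ObservationRegistry` below and its `Future`-based notification are hypothetical, not the PR's code):

```scala
import java.util.concurrent.ConcurrentHashMap

import scala.concurrent.{Future, Promise}

// Hypothetical session-side registry keyed by plan ID. Its setObservationMetrics
// method has the (Long, Map[String, Any]) => Unit shape that SparkResult accepts
// as setObservationMetricsOpt, so metrics arriving in ExecutePlanResponse
// messages complete the matching observer.
class ObservationRegistry {
  private val pending = new ConcurrentHashMap[Long, Promise[Map[String, Any]]]()

  // Called when a plan with an attached observation is submitted for execution.
  def register(planId: Long): Future[Map[String, Any]] = {
    val promise = Promise[Map[String, Any]]()
    pending.put(planId, promise)
    promise.future
  }

  // Invoked via the optional callback from processObservedMetrics.
  def setObservationMetrics(planId: Long, metrics: Map[String, Any]): Unit =
    Option(pending.remove(planId)).foreach(_.success(metrics))
}
```

Modeling the pending result as a `Promise`/`Future` pair also echoes the contributor comment above that the observation class "should have been using a future from the get go".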