diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml
index f769067557b8..d51e922f025a 100644
--- a/common/network-shuffle/pom.xml
+++ b/common/network-shuffle/pom.xml
@@ -42,11 +42,6 @@
       <version>${project.version}</version>
     </dependency>
-    <dependency>
-      <groupId>org.apache.commons</groupId>
-      <artifactId>commons-lang3</artifactId>
-    </dependency>
-
     <dependency>
       <groupId>io.dropwizard.metrics</groupId>
       <artifactId>metrics-core</artifactId>
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 3d885ffdb02d..63484c23a920 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -212,6 +212,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected val hideTraceback: Boolean = false
protected val simplifiedTraceback: Boolean = false
+ protected val runnerConf: Map[String, String] = Map.empty
+
// All the Python functions should have the same exec, version and envvars.
protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars
protected val pythonExec: String = funcs.head.funcs.head.pythonExec
@@ -403,6 +405,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
*/
protected def writeCommand(dataOut: DataOutputStream): Unit
+ /**
+ * Writes worker configuration to the stream connected to the Python worker.
+ */
+ protected def writeRunnerConf(dataOut: DataOutputStream): Unit = {
+ dataOut.writeInt(runnerConf.size)
+ for ((k, v) <- runnerConf) {
+ PythonWorkerUtils.writeUTF(k, dataOut)
+ PythonWorkerUtils.writeUTF(v, dataOut)
+ }
+ }
+
/**
* Writes input data to the stream connected to the Python worker.
* Returns true if any data was written to the stream, false if the input is exhausted.
@@ -532,6 +545,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
dataOut.writeInt(evalType)
+ writeRunnerConf(dataOut)
writeCommand(dataOut)
dataOut.flush()
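
For context on the framing above: writeRunnerConf emits an int count followed by the key/value pairs, each written with PythonWorkerUtils.writeUTF (a 4-byte length plus UTF-8 bytes). A minimal sketch of how the Python worker side could read that back, assuming the existing pyspark.serializers helpers; read_runner_conf and its call site are illustrative only and not part of this patch:

```python
from pyspark.serializers import read_int, UTF8Deserializer

utf8_deserializer = UTF8Deserializer()


def read_runner_conf(infile):
    # Mirrors writeRunnerConf: an int count, then each key and value written
    # as a 4-byte length followed by UTF-8 bytes.
    runner_conf = {}
    for _ in range(read_int(infile)):
        key = utf8_deserializer.loads(infile)
        value = utf8_deserializer.loads(infile)
        runner_conf[key] = value
    return runner_conf
```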
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 53c791a3446c..aa17f954a7d9 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -407,10 +407,6 @@ private[spark] class Executor(
TaskState.FAILED,
env.closureSerializer.newInstance().serialize(new ExceptionFailure(t, Seq.empty)))
} catch {
- case oom: OutOfMemoryError =>
- logError(log"Executor update launching task ${MDC(TASK_NAME, taskDescription.name)} " +
- log"failed status failed, reason: ${MDC(REASON, oom.getMessage)}")
- System.exit(SparkExitCode.OOM)
case t: Throwable =>
logError(log"Executor update launching task ${MDC(TASK_NAME, taskDescription.name)} " +
log"failed status failed, reason: ${MDC(REASON, t.getMessage)}")
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index 470082330509..270af9c205b4 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -30,7 +30,7 @@ azure-storage/7.0.1//azure-storage-7.0.1.jar
blas/3.0.4//blas-3.0.4.jar
breeze-macros_2.13/2.1.0//breeze-macros_2.13-2.1.0.jar
breeze_2.13/2.1.0//breeze_2.13-2.1.0.jar
-bundle/2.29.52//bundle-2.29.52.jar
+bundle/2.35.4//bundle-2.35.4.jar
cats-kernel_2.13/2.8.0//cats-kernel_2.13-2.8.0.jar
chill-java/0.10.0//chill-java-0.10.0.jar
chill_2.13/0.10.0//chill_2.13-0.10.0.jar
diff --git a/pom.xml b/pom.xml
index ea23bbf20d0b..95968580cc21 100644
--- a/pom.xml
+++ b/pom.xml
@@ -162,7 +162,7 @@
1.15.3
1.12.681
- 2.29.52
+ 2.35.4
1.0.5
@@ -233,7 +233,7 @@
./python/packaging/client/setup.py, and ./python/packaging/connect/setup.py too.
-->
18.3.0
- 3.0.4
+ 3.0.5
0.12.6
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 409c33d1727a..7495569ecfda 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -163,14 +163,26 @@ def handle_sigterm(*args):
# Initialization complete
try:
+ poller = None
+ if os.name == "posix":
+ # select.select cannot monitor file descriptors whose values are at or above
+ # FD_SETSIZE (typically 1024). We use select.poll instead, which has no such limit.
+ poller = select.poll()
+ fd_reverse_map = {0: 0, listen_sock.fileno(): listen_sock}
+ poller.register(0, select.POLLIN)
+ poller.register(listen_sock, select.POLLIN)
+
while True:
- try:
- ready_fds = select.select([0, listen_sock], [], [], 1)[0]
- except select.error as ex:
- if ex[0] == EINTR:
- continue
- else:
- raise
+ if poller is not None:
+ ready_fds = [fd_reverse_map[fd] for fd, _ in poller.poll(1000)]
+ else:
+ try:
+ ready_fds = select.select([0, listen_sock], [], [], 1)[0]
+ except select.error as ex:
+ if ex[0] == EINTR:
+ continue
+ else:
+ raise
if 0 in ready_fds:
try:
@@ -208,6 +220,9 @@ def handle_sigterm(*args):
if pid == 0:
# in child process
+ if poller is not None:
+ poller.unregister(0)
+ poller.unregister(listen_sock)
listen_sock.close()
# It should close the standard input in the child process so that
@@ -256,6 +271,9 @@ def handle_sigterm(*args):
sock.close()
finally:
+ if poller is not None:
+ poller.unregister(0)
+ poller.unregister(listen_sock)
shutdown(1)
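
A standalone sketch of the poll-based pattern used in the daemon change above, assuming a POSIX platform; listen_sock here is a throwaway local socket and fd_to_obj is a stand-in name for illustration only:

```python
import select
import socket

# Throwaway listening socket standing in for daemon.py's listen_sock.
listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(1)

poller = select.poll()
# poll() reports raw file descriptors, so keep a reverse map back to the
# objects the surrounding loop works with (stdin is fd 0).
fd_to_obj = {0: 0, listen_sock.fileno(): listen_sock}
poller.register(0, select.POLLIN)
poller.register(listen_sock, select.POLLIN)

# The timeout is in milliseconds (select.select takes seconds); poll() is not
# constrained by FD_SETSIZE, which is the point of the daemon change.
ready = [fd_to_obj[fd] for fd, _ in poller.poll(1000)]
print(ready)
```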
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index f5bfd59694af..abecbf90aa95 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -622,7 +622,7 @@ class _LinearSVCParams(
)
def __init__(self, *args: Any) -> None:
- super(_LinearSVCParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
maxIter=100,
regParam=0.0,
@@ -743,7 +743,7 @@ def __init__(
fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \
aggregationDepth=2, maxBlockSizeInMB=0.0):
"""
- super(LinearSVC, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.LinearSVC", self.uid
)
@@ -1019,7 +1019,7 @@ class _LogisticRegressionParams(
)
def __init__(self, *args: Any):
- super(_LogisticRegressionParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
maxIter=100, regParam=0.0, tol=1e-6, threshold=0.5, family="auto", maxBlockSizeInMB=0.0
)
@@ -1328,7 +1328,7 @@ def __init__(
maxBlockSizeInMB=0.0):
If the threshold and thresholds Params are both set, they must be equivalent.
"""
- super(LogisticRegression, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.LogisticRegression", self.uid
)
@@ -1676,7 +1676,7 @@ class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):
"""
def __init__(self, *args: Any):
- super(_DecisionTreeClassifierParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
maxDepth=5,
maxBins=32,
@@ -1809,7 +1809,7 @@ def __init__(
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
"""
- super(DecisionTreeClassifier, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid
)
@@ -1970,7 +1970,7 @@ class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):
"""
def __init__(self, *args: Any):
- super(_RandomForestClassifierParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
maxDepth=5,
maxBins=32,
@@ -2106,7 +2106,7 @@ def __init__(
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \
leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)
"""
- super(RandomForestClassifier, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.RandomForestClassifier", self.uid
)
@@ -2400,7 +2400,7 @@ class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
)
def __init__(self, *args: Any):
- super(_GBTClassifierParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
maxDepth=5,
maxBins=32,
@@ -2577,7 +2577,7 @@ def __init__(
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \
weightCol=None)
"""
- super(GBTClassifier, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.GBTClassifier", self.uid
)
@@ -2823,7 +2823,7 @@ class _NaiveBayesParams(_PredictorParams, HasWeightCol):
)
def __init__(self, *args: Any):
- super(_NaiveBayesParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(smoothing=1.0, modelType="multinomial")
@since("1.5.0")
@@ -2964,7 +2964,7 @@ def __init__(
probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
modelType="multinomial", thresholds=None, weightCol=None)
"""
- super(NaiveBayes, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.NaiveBayes", self.uid
)
@@ -3093,7 +3093,7 @@ class _MultilayerPerceptronParams(
)
def __init__(self, *args: Any):
- super(_MultilayerPerceptronParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(maxIter=100, tol=1e-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
@since("1.6.0")
@@ -3219,7 +3219,7 @@ def __init__(
solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
rawPredictionCol="rawPrediction")
"""
- super(MultilayerPerceptronClassifier, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid
)
@@ -3484,7 +3484,7 @@ def __init__(
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1):
"""
- super(OneVsRest, self).__init__()
+ super().__init__()
self._setDefault(parallelism=1)
kwargs = self._input_kwargs
self._set(**kwargs)
@@ -3749,7 +3749,7 @@ def validateParams(instance: Union[OneVsRest, "OneVsRestModel"]) -> None:
@inherit_doc
class OneVsRestReader(MLReader[OneVsRest]):
def __init__(self, cls: Type[OneVsRest]) -> None:
- super(OneVsRestReader, self).__init__()
+ super().__init__()
self.cls = cls
def load(self, path: str) -> OneVsRest:
@@ -3765,7 +3765,7 @@ def load(self, path: str) -> OneVsRest:
@inherit_doc
class OneVsRestWriter(MLWriter):
def __init__(self, instance: OneVsRest):
- super(OneVsRestWriter, self).__init__()
+ super().__init__()
self.instance = instance
def saveImpl(self, path: str) -> None:
@@ -3807,7 +3807,7 @@ def setRawPredictionCol(self, value: str) -> "OneVsRestModel":
return self._set(rawPredictionCol=value)
def __init__(self, models: List[ClassificationModel]):
- super(OneVsRestModel, self).__init__()
+ super().__init__()
self.models = models
if is_remote() or not isinstance(models[0], JavaMLWritable):
return
@@ -3980,7 +3980,7 @@ def write(self) -> MLWriter:
@inherit_doc
class OneVsRestModelReader(MLReader[OneVsRestModel]):
def __init__(self, cls: Type[OneVsRestModel]):
- super(OneVsRestModelReader, self).__init__()
+ super().__init__()
self.cls = cls
def load(self, path: str) -> OneVsRestModel:
@@ -4002,7 +4002,7 @@ def load(self, path: str) -> OneVsRestModel:
@inherit_doc
class OneVsRestModelWriter(MLWriter):
def __init__(self, instance: OneVsRestModel):
- super(OneVsRestModelWriter, self).__init__()
+ super().__init__()
self.instance = instance
def saveImpl(self, path: str) -> None:
@@ -4119,7 +4119,7 @@ def __init__(
miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
tol=1e-6, solver="adamW", thresholds=None, seed=None)
"""
- super(FMClassifier, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.FMClassifier", self.uid
)
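
The pyspark.ml changes above, and the analogous ones in the files that follow, are a mechanical modernization: in Python 3 the zero-argument super() resolves the class and instance implicitly, so the explicit two-argument form is redundant. A small self-contained illustration with made-up class names:

```python
class Base:
    def __init__(self):
        self.initialized = True


class OldStyle(Base):
    def __init__(self):
        # Python 2 compatible spelling, the form the diff removes.
        super(OldStyle, self).__init__()


class NewStyle(Base):
    def __init__(self):
        # Equivalent zero-argument form in Python 3, the form the diff adds.
        super().__init__()


assert OldStyle().initialized and NewStyle().initialized
```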
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 0e26398de3c4..0fc2b34d1748 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -169,7 +169,7 @@ class _GaussianMixtureParams(
)
def __init__(self, *args: Any):
- super(_GaussianMixtureParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(k=2, tol=0.01, maxIter=100, aggregationDepth=2)
@since("2.0.0")
@@ -422,7 +422,7 @@ def __init__(
probabilityCol="probability", tol=0.01, maxIter=100, seed=None, \
aggregationDepth=2, weightCol=None)
"""
- super(GaussianMixture, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.clustering.GaussianMixture", self.uid
)
@@ -617,7 +617,7 @@ class _KMeansParams(
)
def __init__(self, *args: Any):
- super(_KMeansParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
k=2,
initMode="k-means||",
@@ -800,7 +800,7 @@ def __init__(
distanceMeasure="euclidean", weightCol=None, solver="auto", \
maxBlockSizeInMB=0.0)
"""
- super(KMeans, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -952,7 +952,7 @@ class _BisectingKMeansParams(
)
def __init__(self, *args: Any):
- super(_BisectingKMeansParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(maxIter=20, k=4, minDivisibleClusterSize=1.0)
@since("2.0.0")
@@ -1146,7 +1146,7 @@ def __init__(
seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean", \
weightCol=None)
"""
- super(BisectingKMeans, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.clustering.BisectingKMeans", self.uid
)
@@ -1337,7 +1337,7 @@ class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval):
)
def __init__(self, *args: Any):
- super(_LDAParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
maxIter=20,
checkpointInterval=10,
@@ -1710,7 +1710,7 @@ def __init__(
docConcentration=None, topicConcentration=None,\
topicDistributionCol="topicDistribution", keepLastCheckpoint=True)
"""
- super(LDA, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.LDA", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1948,7 +1948,7 @@ class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol):
)
def __init__(self, *args: Any):
- super(_PowerIterationClusteringParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst")
@since("2.4.0")
@@ -2054,7 +2054,7 @@ def __init__(
__init__(self, \\*, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\
weightCol=None)
"""
- super(PowerIterationClustering, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.clustering.PowerIterationClustering", self.uid
)
diff --git a/python/pyspark/ml/connect/classification.py b/python/pyspark/ml/connect/classification.py
index e9e7840b098b..3263f47e6135 100644
--- a/python/pyspark/ml/connect/classification.py
+++ b/python/pyspark/ml/connect/classification.py
@@ -64,7 +64,7 @@ class _LogisticRegressionParams(
"""
def __init__(self, *args: Any):
- super(_LogisticRegressionParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
maxIter=100,
tol=1e-6,
@@ -215,7 +215,7 @@ def __init__(
seed: int = 0,
)
"""
- super(LogisticRegression, self).__init__()
+ super().__init__()
kwargs = self._input_kwargs
self._set(**kwargs)
diff --git a/python/pyspark/ml/connect/pipeline.py b/python/pyspark/ml/connect/pipeline.py
index 9850f7a0fd0f..daba4e6fa218 100644
--- a/python/pyspark/ml/connect/pipeline.py
+++ b/python/pyspark/ml/connect/pipeline.py
@@ -144,7 +144,7 @@ def __init__(self, *, stages: Optional[List[Params]] = None):
"""
__init__(self, \\*, stages=None)
"""
- super(Pipeline, self).__init__()
+ super().__init__()
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -248,7 +248,7 @@ class PipelineModel(Model, _PipelineReadWrite):
"""
def __init__(self, stages: Optional[List[Params]] = None):
- super(PipelineModel, self).__init__()
+ super().__init__()
self.stages = stages # type: ignore[assignment]
def _transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
diff --git a/python/pyspark/ml/connect/tuning.py b/python/pyspark/ml/connect/tuning.py
index 1ef055d25007..39aae6f17cc5 100644
--- a/python/pyspark/ml/connect/tuning.py
+++ b/python/pyspark/ml/connect/tuning.py
@@ -118,7 +118,7 @@ class _CrossValidatorParams(_ValidatorParams):
)
def __init__(self, *args: Any):
- super(_CrossValidatorParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(numFolds=3, foldCol="")
@since("1.4.0")
@@ -326,7 +326,7 @@ def __init__(
__init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
seed=None, parallelism=1, foldCol="")
"""
- super(CrossValidator, self).__init__()
+ super().__init__()
self._setDefault(parallelism=1)
kwargs = self._input_kwargs
self._set(**kwargs)
@@ -539,7 +539,7 @@ def __init__(
avgMetrics: Optional[List[float]] = None,
stdMetrics: Optional[List[float]] = None,
) -> None:
- super(CrossValidatorModel, self).__init__()
+ super().__init__()
#: best model from cross validation
self.bestModel = bestModel
#: Average cross-validation metrics for each paramMap in
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 568583eb08ec..56747d07441b 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -235,7 +235,7 @@ def __init__(
__init__(self, \\*, rawPredictionCol="rawPrediction", labelCol="label", \
metricName="areaUnderROC", weightCol=None, numBins=1000)
"""
- super(BinaryClassificationEvaluator, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid
)
@@ -397,7 +397,7 @@ def __init__(
__init__(self, \\*, predictionCol="prediction", labelCol="label", \
metricName="rmse", weightCol=None, throughOrigin=False)
"""
- super(RegressionEvaluator, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid
)
@@ -593,7 +593,7 @@ def __init__(
metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0, \
probabilityCol="probability", eps=1e-15)
"""
- super(MulticlassClassificationEvaluator, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid
)
@@ -790,7 +790,7 @@ def __init__(
__init__(self, \\*, predictionCol="prediction", labelCol="label", \
metricName="f1Measure", metricLabel=0.0)
"""
- super(MultilabelClassificationEvaluator, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator", self.uid
)
@@ -947,7 +947,7 @@ def __init__(
__init__(self, \\*, predictionCol="prediction", featuresCol="features", \
metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None)
"""
- super(ClusteringEvaluator, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid
)
@@ -1095,7 +1095,7 @@ def __init__(
__init__(self, \\*, predictionCol="prediction", labelCol="label", \
metricName="meanAveragePrecision", k=10)
"""
- super(RankingEvaluator, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.evaluation.RankingEvaluator", self.uid
)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 4d1551652028..c5a5033cd8f0 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -248,7 +248,7 @@ def __init__(
__init__(self, \\*, threshold=0.0, inputCol=None, outputCol=None, thresholds=None, \
inputCols=None, outputCols=None)
"""
- super(Binarizer, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid)
self._setDefault(threshold=0.0)
kwargs = self._input_kwargs
@@ -350,7 +350,7 @@ class _LSHParams(HasInputCol, HasOutputCol):
)
def __init__(self, *args: Any):
- super(_LSHParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(numHashTables=1)
def getNumHashTables(self) -> int:
@@ -603,7 +603,7 @@ def __init__(
__init__(self, \\*, inputCol=None, outputCol=None, seed=None, numHashTables=1, \
bucketLength=None)
"""
- super(BucketedRandomProjectionLSH, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.feature.BucketedRandomProjectionLSH", self.uid
)
@@ -820,7 +820,7 @@ def __init__(
__init__(self, \\*, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
splitsArray=None, inputCols=None, outputCols=None)
"""
- super(Bucketizer, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
self._setDefault(handleInvalid="error")
kwargs = self._input_kwargs
@@ -985,7 +985,7 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
)
def __init__(self, *args: Any):
- super(_CountVectorizerParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(minTF=1.0, minDF=1.0, maxDF=2**63 - 1, vocabSize=1 << 18, binary=False)
@since("1.6.0")
@@ -1105,7 +1105,7 @@ def __init__(
__init__(self, \\*, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18,\
binary=False, inputCol=None, outputCol=None)
"""
- super(CountVectorizer, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1339,7 +1339,7 @@ def __init__(
"""
__init__(self, \\*, inverse=False, inputCol=None, outputCol=None)
"""
- super(DCT, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid)
self._setDefault(inverse=False)
kwargs = self._input_kwargs
@@ -1447,7 +1447,7 @@ def __init__(
"""
__init__(self, \\*, scalingVec=None, inputCol=None, outputCol=None)
"""
- super(ElementwiseProduct, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.feature.ElementwiseProduct", self.uid
)
@@ -1585,7 +1585,7 @@ def __init__(
__init__(self, \\*, numFeatures=1 << 18, inputCols=None, outputCol=None, \
categoricalCols=None)
"""
- super(FeatureHasher, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.FeatureHasher", self.uid)
self._setDefault(numFeatures=1 << 18)
kwargs = self._input_kwargs
@@ -1708,7 +1708,7 @@ def __init__(
"""
__init__(self, \\*, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None)
"""
- super(HashingTF, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid)
self._setDefault(numFeatures=1 << 18, binary=False)
kwargs = self._input_kwargs
@@ -1795,7 +1795,7 @@ def getMinDocFreq(self) -> int:
return self.getOrDefault(self.minDocFreq)
def __init__(self, *args: Any):
- super(_IDFParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(minDocFreq=0)
@@ -1859,7 +1859,7 @@ def __init__(
"""
__init__(self, \\*, minDocFreq=0, inputCol=None, outputCol=None)
"""
- super(IDF, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1975,7 +1975,7 @@ class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, Has
)
def __init__(self, *args: Any):
- super(_ImputerParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(strategy="mean", missingValue=float("nan"), relativeError=0.001)
@since("2.2.0")
@@ -2152,7 +2152,7 @@ def __init__(
__init__(self, \\*, strategy="mean", missingValue=float("nan"), inputCols=None, \
outputCols=None, inputCol=None, outputCol=None, relativeError=0.001):
"""
- super(Imputer, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Imputer", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -2351,7 +2351,7 @@ def __init__(self, *, inputCols: Optional[List[str]] = None, outputCol: Optional
"""
__init__(self, \\*, inputCols=None, outputCol=None):
"""
- super(Interaction, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Interaction", self.uid)
self._setDefault()
kwargs = self._input_kwargs
@@ -2449,7 +2449,7 @@ def __init__(self, *, inputCol: Optional[str] = None, outputCol: Optional[str] =
"""
__init__(self, \\*, inputCol=None, outputCol=None)
"""
- super(MaxAbsScaler, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MaxAbsScaler", self.uid)
self._setDefault()
kwargs = self._input_kwargs
@@ -2602,7 +2602,7 @@ def __init__(
"""
__init__(self, \\*, inputCol=None, outputCol=None, seed=None, numHashTables=1)
"""
- super(MinHashLSH, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinHashLSH", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -2672,7 +2672,7 @@ class _MinMaxScalerParams(HasInputCol, HasOutputCol):
)
def __init__(self, *args: Any):
- super(_MinMaxScalerParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(min=0.0, max=1.0)
@since("1.6.0")
@@ -2767,7 +2767,7 @@ def __init__(
"""
__init__(self, \\*, min=0.0, max=1.0, inputCol=None, outputCol=None)
"""
- super(MinMaxScaler, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -2934,7 +2934,7 @@ def __init__(
"""
__init__(self, \\*, n=2, inputCol=None, outputCol=None)
"""
- super(NGram, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid)
self._setDefault(n=2)
kwargs = self._input_kwargs
@@ -3029,7 +3029,7 @@ def __init__(
"""
__init__(self, \\*, p=2.0, inputCol=None, outputCol=None)
"""
- super(Normalizer, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid)
self._setDefault(p=2.0)
kwargs = self._input_kwargs
@@ -3102,7 +3102,7 @@ class _OneHotEncoderParams(
)
def __init__(self, *args: Any):
- super(_OneHotEncoderParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(handleInvalid="error", dropLast=True)
@since("2.3.0")
@@ -3221,7 +3221,7 @@ def __init__(
__init__(self, \\*, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, \
inputCol=None, outputCol=None)
"""
- super(OneHotEncoder, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -3430,7 +3430,7 @@ def __init__(
"""
__init__(self, \\*, degree=2, inputCol=None, outputCol=None)
"""
- super(PolynomialExpansion, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.feature.PolynomialExpansion", self.uid
)
@@ -3658,7 +3658,7 @@ def __init__(
__init__(self, \\*, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None)
"""
- super(QuantileDiscretizer, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.feature.QuantileDiscretizer", self.uid
)
@@ -3850,7 +3850,7 @@ class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError):
)
def __init__(self, *args: Any):
- super(_RobustScalerParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
lower=0.25, upper=0.75, withCentering=False, withScaling=True, relativeError=0.001
)
@@ -3958,7 +3958,7 @@ def __init__(
__init__(self, \\*, lower=0.25, upper=0.75, withCentering=False, withScaling=True, \
inputCol=None, outputCol=None, relativeError=0.001)
"""
- super(RobustScaler, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RobustScaler", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -4170,7 +4170,7 @@ def __init__(
__init__(self, \\*, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, \
outputCol=None, toLowercase=True)
"""
- super(RegexTokenizer, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid)
self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+", toLowercase=True)
kwargs = self._input_kwargs
@@ -4301,7 +4301,7 @@ def __init__(self, *, statement: Optional[str] = None):
"""
__init__(self, \\*, statement=None)
"""
- super(SQLTransformer, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -4349,7 +4349,7 @@ class _StandardScalerParams(HasInputCol, HasOutputCol):
)
def __init__(self, *args: Any):
- super(_StandardScalerParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(withMean=False, withStd=True)
@since("1.4.0")
@@ -4436,7 +4436,7 @@ def __init__(
"""
__init__(self, \\*, withMean=False, withStd=True, inputCol=None, outputCol=None)
"""
- super(StandardScaler, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -4560,7 +4560,7 @@ class _StringIndexerParams(
)
def __init__(self, *args: Any):
- super(_StringIndexerParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc")
@since("2.3.0")
@@ -4699,7 +4699,7 @@ def __init__(
__init__(self, \\*, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \
handleInvalid="error", stringOrderType="frequencyDesc")
"""
- super(StringIndexer, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -4978,7 +4978,7 @@ def __init__(
"""
__init__(self, \\*, inputCol=None, outputCol=None, labels=None)
"""
- super(IndexToString, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -5140,7 +5140,7 @@ def __init__(
__init__(self, \\*, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, \
locale=None, inputCols=None, outputCols=None)
"""
- super(StopWordsRemover, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.feature.StopWordsRemover", self.uid
)
@@ -5317,7 +5317,7 @@ class _TargetEncoderParams(
)
def __init__(self, *args: Any):
- super(_TargetEncoderParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(handleInvalid="error", targetType="binary", smoothing=0.0)
@since("4.0.0")
@@ -5409,7 +5409,7 @@ def __init__(
__init__(self, \\*, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, \
targetType="binary", smoothing=0.0, inputCol=None, outputCol=None)
"""
- super(TargetEncoder, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.TargetEncoder", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -5623,7 +5623,7 @@ def __init__(self, *, inputCol: Optional[str] = None, outputCol: Optional[str] =
"""
__init__(self, \\*, inputCol=None, outputCol=None)
"""
- super(Tokenizer, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -5734,7 +5734,7 @@ def __init__(
"""
__init__(self, \\*, inputCols=None, outputCol=None, handleInvalid="error")
"""
- super(VectorAssembler, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid)
self._setDefault(handleInvalid="error")
kwargs = self._input_kwargs
@@ -5803,7 +5803,7 @@ class _VectorIndexerParams(HasInputCol, HasOutputCol, HasHandleInvalid):
)
def __init__(self, *args: Any):
- super(_VectorIndexerParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(maxCategories=20, handleInvalid="error")
@since("1.4.0")
@@ -5923,7 +5923,7 @@ def __init__(
"""
__init__(self, \\*, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
"""
- super(VectorIndexer, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -6114,7 +6114,7 @@ def __init__(
"""
__init__(self, \\*, inputCol=None, outputCol=None, indices=None, names=None)
"""
- super(VectorSlicer, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid)
self._setDefault(indices=[], names=[])
kwargs = self._input_kwargs
@@ -6220,7 +6220,7 @@ class _Word2VecParams(HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCo
)
def __init__(self, *args: Any):
- super(_Word2VecParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
vectorSize=100,
minCount=5,
@@ -6359,7 +6359,7 @@ def __init__(
maxIter=1, seed=None, inputCol=None, outputCol=None, windowSize=5, \
maxSentenceLength=1000)
"""
- super(Word2Vec, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -6592,7 +6592,7 @@ def __init__(
"""
__init__(self, \\*, k=None, inputCol=None, outputCol=None)
"""
- super(PCA, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -6717,7 +6717,7 @@ class _RFormulaParams(HasFeaturesCol, HasLabelCol, HasHandleInvalid):
)
def __init__(self, *args: Any):
- super(_RFormulaParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", handleInvalid="error"
)
@@ -6841,7 +6841,7 @@ def __init__(
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
handleInvalid="error")
"""
- super(RFormula, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -6980,7 +6980,7 @@ class _SelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol):
)
def __init__(self, *args: Any):
- super(_SelectorParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
numTopFeatures=50,
selectorType="numTopFeatures",
@@ -7220,7 +7220,7 @@ def __init__(
labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, \
fdr=0.05, fwe=0.05)
"""
- super(ChiSqSelector, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -7331,7 +7331,7 @@ def __init__(
"""
__init__(self, \\*, inputCol=None, size=None, handleInvalid="error")
"""
- super(VectorSizeHint, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSizeHint", self.uid)
self._setDefault(handleInvalid="error")
self.setParams(**self._input_kwargs)
@@ -7463,7 +7463,7 @@ def __init__(
"""
__init__(self, \\*, featuresCol="features", outputCol=None, varianceThreshold=0.0)
"""
- super(VarianceThresholdSelector, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.feature.VarianceThresholdSelector", self.uid
)
@@ -7586,7 +7586,7 @@ class _UnivariateFeatureSelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol
)
def __init__(self, *args: Any):
- super(_UnivariateFeatureSelectorParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(selectionMode="numTopFeatures")
@since("3.1.1")
@@ -7712,7 +7712,7 @@ def __init__(
__init__(self, \\*, featuresCol="features", outputCol=None, \
labelCol="label", selectionMode="numTopFeatures")
"""
- super(UnivariateFeatureSelector, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.feature.UnivariateFeatureSelector", self.uid
)
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 64b7e5eae556..857e9a97b154 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -71,7 +71,7 @@ class _FPGrowthParams(HasPredictionCol):
)
def __init__(self, *args: Any):
- super(_FPGrowthParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
minSupport=0.3, minConfidence=0.8, itemsCol="items", predictionCol="prediction"
)
@@ -256,7 +256,7 @@ def __init__(
__init__(self, \\*, minSupport=0.3, minConfidence=0.8, itemsCol="items", \
predictionCol="prediction", numPartitions=None)
"""
- super(FPGrowth, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.FPGrowth", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -409,7 +409,7 @@ def __init__(
__init__(self, \\*, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
sequenceCol="sequence")
"""
- super(PrefixSpan, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid)
self._setDefault(
minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, sequenceCol="sequence"
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 067741b0a7f6..6fce132561db 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -255,7 +255,7 @@ class Params(Identifiable, metaclass=ABCMeta):
"""
def __init__(self) -> None:
- super(Params, self).__init__()
+ super().__init__()
#: internal param map for user-supplied values param map
self._paramMap: "ParamMap" = {}
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index ade96da0a4f1..bbcaa208a39d 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -72,7 +72,7 @@ def _gen_param_header(
)
def __init__(self) -> None:
- super({Name}, self).__init__()'''
+ super().__init__()'''
if defaultValueStr is not None:
template += f"""
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index bc444bf9cbf9..e60f2a7432f7 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -35,7 +35,7 @@ class HasMaxIter(Params):
)
def __init__(self) -> None:
- super(HasMaxIter, self).__init__()
+ super().__init__()
def getMaxIter(self) -> int:
"""
@@ -57,7 +57,7 @@ class HasRegParam(Params):
)
def __init__(self) -> None:
- super(HasRegParam, self).__init__()
+ super().__init__()
def getRegParam(self) -> float:
"""
@@ -79,7 +79,7 @@ class HasFeaturesCol(Params):
)
def __init__(self) -> None:
- super(HasFeaturesCol, self).__init__()
+ super().__init__()
self._setDefault(featuresCol="features")
def getFeaturesCol(self) -> str:
@@ -102,7 +102,7 @@ class HasLabelCol(Params):
)
def __init__(self) -> None:
- super(HasLabelCol, self).__init__()
+ super().__init__()
self._setDefault(labelCol="label")
def getLabelCol(self) -> str:
@@ -125,7 +125,7 @@ class HasPredictionCol(Params):
)
def __init__(self) -> None:
- super(HasPredictionCol, self).__init__()
+ super().__init__()
self._setDefault(predictionCol="prediction")
def getPredictionCol(self) -> str:
@@ -148,7 +148,7 @@ class HasProbabilityCol(Params):
)
def __init__(self) -> None:
- super(HasProbabilityCol, self).__init__()
+ super().__init__()
self._setDefault(probabilityCol="probability")
def getProbabilityCol(self) -> str:
@@ -171,7 +171,7 @@ class HasRawPredictionCol(Params):
)
def __init__(self) -> None:
- super(HasRawPredictionCol, self).__init__()
+ super().__init__()
self._setDefault(rawPredictionCol="rawPrediction")
def getRawPredictionCol(self) -> str:
@@ -194,7 +194,7 @@ class HasInputCol(Params):
)
def __init__(self) -> None:
- super(HasInputCol, self).__init__()
+ super().__init__()
def getInputCol(self) -> str:
"""
@@ -216,7 +216,7 @@ class HasInputCols(Params):
)
def __init__(self) -> None:
- super(HasInputCols, self).__init__()
+ super().__init__()
def getInputCols(self) -> List[str]:
"""
@@ -238,7 +238,7 @@ class HasOutputCol(Params):
)
def __init__(self) -> None:
- super(HasOutputCol, self).__init__()
+ super().__init__()
self._setDefault(outputCol=self.uid + "__output")
def getOutputCol(self) -> str:
@@ -261,7 +261,7 @@ class HasOutputCols(Params):
)
def __init__(self) -> None:
- super(HasOutputCols, self).__init__()
+ super().__init__()
def getOutputCols(self) -> List[str]:
"""
@@ -283,7 +283,7 @@ class HasNumFeatures(Params):
)
def __init__(self) -> None:
- super(HasNumFeatures, self).__init__()
+ super().__init__()
self._setDefault(numFeatures=262144)
def getNumFeatures(self) -> int:
@@ -306,7 +306,7 @@ class HasCheckpointInterval(Params):
)
def __init__(self) -> None:
- super(HasCheckpointInterval, self).__init__()
+ super().__init__()
def getCheckpointInterval(self) -> int:
"""
@@ -328,7 +328,7 @@ class HasSeed(Params):
)
def __init__(self) -> None:
- super(HasSeed, self).__init__()
+ super().__init__()
self._setDefault(seed=hash(type(self).__name__))
def getSeed(self) -> int:
@@ -351,7 +351,7 @@ class HasTol(Params):
)
def __init__(self) -> None:
- super(HasTol, self).__init__()
+ super().__init__()
def getTol(self) -> float:
"""
@@ -373,7 +373,7 @@ class HasRelativeError(Params):
)
def __init__(self) -> None:
- super(HasRelativeError, self).__init__()
+ super().__init__()
self._setDefault(relativeError=0.001)
def getRelativeError(self) -> float:
@@ -396,7 +396,7 @@ class HasStepSize(Params):
)
def __init__(self) -> None:
- super(HasStepSize, self).__init__()
+ super().__init__()
def getStepSize(self) -> float:
"""
@@ -418,7 +418,7 @@ class HasHandleInvalid(Params):
)
def __init__(self) -> None:
- super(HasHandleInvalid, self).__init__()
+ super().__init__()
def getHandleInvalid(self) -> str:
"""
@@ -440,7 +440,7 @@ class HasElasticNetParam(Params):
)
def __init__(self) -> None:
- super(HasElasticNetParam, self).__init__()
+ super().__init__()
self._setDefault(elasticNetParam=0.0)
def getElasticNetParam(self) -> float:
@@ -463,7 +463,7 @@ class HasFitIntercept(Params):
)
def __init__(self) -> None:
- super(HasFitIntercept, self).__init__()
+ super().__init__()
self._setDefault(fitIntercept=True)
def getFitIntercept(self) -> bool:
@@ -486,7 +486,7 @@ class HasStandardization(Params):
)
def __init__(self) -> None:
- super(HasStandardization, self).__init__()
+ super().__init__()
self._setDefault(standardization=True)
def getStandardization(self) -> bool:
@@ -509,7 +509,7 @@ class HasThresholds(Params):
)
def __init__(self) -> None:
- super(HasThresholds, self).__init__()
+ super().__init__()
def getThresholds(self) -> List[float]:
"""
@@ -531,7 +531,7 @@ class HasThreshold(Params):
)
def __init__(self) -> None:
- super(HasThreshold, self).__init__()
+ super().__init__()
self._setDefault(threshold=0.5)
def getThreshold(self) -> float:
@@ -554,7 +554,7 @@ class HasWeightCol(Params):
)
def __init__(self) -> None:
- super(HasWeightCol, self).__init__()
+ super().__init__()
def getWeightCol(self) -> str:
"""
@@ -576,7 +576,7 @@ class HasSolver(Params):
)
def __init__(self) -> None:
- super(HasSolver, self).__init__()
+ super().__init__()
self._setDefault(solver="auto")
def getSolver(self) -> str:
@@ -599,7 +599,7 @@ class HasVarianceCol(Params):
)
def __init__(self) -> None:
- super(HasVarianceCol, self).__init__()
+ super().__init__()
def getVarianceCol(self) -> str:
"""
@@ -621,7 +621,7 @@ class HasAggregationDepth(Params):
)
def __init__(self) -> None:
- super(HasAggregationDepth, self).__init__()
+ super().__init__()
self._setDefault(aggregationDepth=2)
def getAggregationDepth(self) -> int:
@@ -644,7 +644,7 @@ class HasParallelism(Params):
)
def __init__(self) -> None:
- super(HasParallelism, self).__init__()
+ super().__init__()
self._setDefault(parallelism=1)
def getParallelism(self) -> int:
@@ -667,7 +667,7 @@ class HasCollectSubModels(Params):
)
def __init__(self) -> None:
- super(HasCollectSubModels, self).__init__()
+ super().__init__()
self._setDefault(collectSubModels=False)
def getCollectSubModels(self) -> bool:
@@ -690,7 +690,7 @@ class HasLoss(Params):
)
def __init__(self) -> None:
- super(HasLoss, self).__init__()
+ super().__init__()
def getLoss(self) -> str:
"""
@@ -712,7 +712,7 @@ class HasDistanceMeasure(Params):
)
def __init__(self) -> None:
- super(HasDistanceMeasure, self).__init__()
+ super().__init__()
self._setDefault(distanceMeasure="euclidean")
def getDistanceMeasure(self) -> str:
@@ -735,7 +735,7 @@ class HasValidationIndicatorCol(Params):
)
def __init__(self) -> None:
- super(HasValidationIndicatorCol, self).__init__()
+ super().__init__()
def getValidationIndicatorCol(self) -> str:
"""
@@ -757,7 +757,7 @@ class HasBlockSize(Params):
)
def __init__(self) -> None:
- super(HasBlockSize, self).__init__()
+ super().__init__()
def getBlockSize(self) -> int:
"""
@@ -779,7 +779,7 @@ class HasMaxBlockSizeInMB(Params):
)
def __init__(self) -> None:
- super(HasMaxBlockSizeInMB, self).__init__()
+ super().__init__()
self._setDefault(maxBlockSizeInMB=0.0)
def getMaxBlockSizeInMB(self) -> float:
@@ -802,7 +802,7 @@ class HasNumTrainWorkers(Params):
)
def __init__(self) -> None:
- super(HasNumTrainWorkers, self).__init__()
+ super().__init__()
self._setDefault(numTrainWorkers=1)
def getNumTrainWorkers(self) -> int:
@@ -825,7 +825,7 @@ class HasBatchSize(Params):
)
def __init__(self) -> None:
- super(HasBatchSize, self).__init__()
+ super().__init__()
def getBatchSize(self) -> int:
"""
@@ -847,7 +847,7 @@ class HasLearningRate(Params):
)
def __init__(self) -> None:
- super(HasLearningRate, self).__init__()
+ super().__init__()
def getLearningRate(self) -> float:
"""
@@ -869,7 +869,7 @@ class HasMomentum(Params):
)
def __init__(self) -> None:
- super(HasMomentum, self).__init__()
+ super().__init__()
def getMomentum(self) -> float:
"""
@@ -891,7 +891,7 @@ class HasFeatureSizes(Params):
)
def __init__(self) -> None:
- super(HasFeatureSizes, self).__init__()
+ super().__init__()
def getFeatureSizes(self) -> List[int]:
"""
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 5728b8c0c511..f5f0d12bc836 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -77,7 +77,7 @@ def __init__(self, *, stages: Optional[List["PipelineStage"]] = None):
"""
__init__(self, \\*, stages=None)
"""
- super(Pipeline, self).__init__()
+ super().__init__()
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -227,7 +227,7 @@ class PipelineWriter(MLWriter):
"""
def __init__(self, instance: Pipeline):
- super(PipelineWriter, self).__init__()
+ super().__init__()
self.instance = instance
def saveImpl(self, path: str) -> None:
@@ -243,7 +243,7 @@ class PipelineReader(MLReader[Pipeline]):
"""
def __init__(self, cls: Type[Pipeline]):
- super(PipelineReader, self).__init__()
+ super().__init__()
self.cls = cls
def load(self, path: str) -> Pipeline:
@@ -259,7 +259,7 @@ class PipelineModelWriter(MLWriter):
"""
def __init__(self, instance: "PipelineModel"):
- super(PipelineModelWriter, self).__init__()
+ super().__init__()
self.instance = instance
def saveImpl(self, path: str) -> None:
@@ -277,7 +277,7 @@ class PipelineModelReader(MLReader["PipelineModel"]):
"""
def __init__(self, cls: Type["PipelineModel"]):
- super(PipelineModelReader, self).__init__()
+ super().__init__()
self.cls = cls
def load(self, path: str) -> "PipelineModel":
@@ -295,7 +295,7 @@ class PipelineModel(Model, MLReadable["PipelineModel"], MLWritable):
"""
def __init__(self, stages: List[Transformer]):
- super(PipelineModel, self).__init__()
+ super().__init__()
self.stages = stages
def _transform(self, dataset: DataFrame) -> DataFrame:
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index d11990634593..d40b231224dc 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -72,7 +72,7 @@ class _ALSModelParams(HasPredictionCol, HasBlockSize):
)
def __init__(self, *args: Any):
- super(_ALSModelParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(blockSize=4096)
@since("1.4.0")
@@ -159,7 +159,7 @@ class _ALSParams(_ALSModelParams, HasMaxIter, HasRegParam, HasCheckpointInterval
)
def __init__(self, *args: Any):
- super(_ALSParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
rank=10,
maxIter=10,
@@ -395,7 +395,7 @@ def __init__(
intermediateStorageLevel="MEMORY_AND_DISK", \
finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096)
"""
- super(ALS, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index ce97b98f6665..5e72fe89b586 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -190,7 +190,7 @@ class _LinearRegressionParams(
)
def __init__(self, *args: Any):
- super(_LinearRegressionParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
maxIter=100,
regParam=0.0,
@@ -325,7 +325,7 @@ def __init__(
standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \
loss="squaredError", epsilon=1.35, maxBlockSizeInMB=0.0)
"""
- super(LinearRegression, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.LinearRegression", self.uid
)
@@ -795,7 +795,7 @@ class _IsotonicRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol, H
)
def __init__(self, *args: Any):
- super(_IsotonicRegressionParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(isotonic=True, featureIndex=0)
def getIsotonic(self) -> bool:
@@ -873,7 +873,7 @@ def __init__(
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
weightCol=None, isotonic=True, featureIndex=0):
"""
- super(IsotonicRegression, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.IsotonicRegression", self.uid
)
@@ -1016,7 +1016,7 @@ class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, Ha
"""
def __init__(self, *args: Any):
- super(_DecisionTreeRegressorParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
maxDepth=5,
maxBins=32,
@@ -1136,7 +1136,7 @@ def __init__(
impurity="variance", seed=None, varianceCol=None, weightCol=None, \
leafCol="", minWeightFractionPerNode=0.0)
"""
- super(DecisionTreeRegressor, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid
)
@@ -1317,7 +1317,7 @@ class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams):
"""
def __init__(self, *args: Any):
- super(_RandomForestRegressorParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
maxDepth=5,
maxBins=32,
@@ -1440,7 +1440,7 @@ def __init__(
featureSubsetStrategy="auto", leafCol=", minWeightFractionPerNode=0.0", \
weightCol=None, bootstrap=True)
"""
- super(RandomForestRegressor, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.RandomForestRegressor", self.uid
)
@@ -1649,7 +1649,7 @@ class _GBTRegressorParams(_GBTParams, _TreeRegressorParams):
)
def __init__(self, *args: Any):
- super(_GBTRegressorParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
maxDepth=5,
maxBins=32,
@@ -1794,7 +1794,7 @@ def __init__(
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0,
weightCol=None)
"""
- super(GBTRegressor, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -2058,7 +2058,7 @@ class _AFTSurvivalRegressionParams(
)
def __init__(self, *args: Any):
- super(_AFTSurvivalRegressionParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
censorCol="censor",
quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
@@ -2191,7 +2191,7 @@ def __init__(
quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
quantilesCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0)
"""
- super(AFTSurvivalRegression, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid
)
@@ -2422,7 +2422,7 @@ class _GeneralizedLinearRegressionParams(
)
def __init__(self, *args: Any):
- super(_GeneralizedLinearRegressionParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
family="gaussian",
maxIter=25,
@@ -2591,7 +2591,7 @@ def __init__(
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2)
"""
- super(GeneralizedLinearRegression, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid
)
@@ -3023,7 +3023,7 @@ class _FactorizationMachinesParams(
)
def __init__(self, *args: Any):
- super(_FactorizationMachinesParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(
factorSize=8,
fitIntercept=True,
@@ -3159,7 +3159,7 @@ def __init__(
miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \
tol=1e-6, solver="adamW", seed=None)
"""
- super(FMRegressor, self).__init__()
+ super().__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.FMRegressor", self.uid)
kwargs = self._input_kwargs
self.setParams(**kwargs)
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index fd32a1515101..99c4b80018cf 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -499,7 +499,7 @@ class SummaryBuilder(JavaWrapper):
def __init__(self, jSummaryBuilder: "JavaObject"):
if not is_remote():
- super(SummaryBuilder, self).__init__(jSummaryBuilder)
+ super().__init__(jSummaryBuilder)
def summary(self, featuresCol: Column, weightCol: Optional[Column] = None) -> Column:
"""
diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py
index e9ccd7b0369e..4406bda09e9d 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -35,7 +35,7 @@
class ClassificationTestsOnConnect(ClassificationTestsMixin, ReusedConnectTestCase):
@classmethod
def conf(cls):
- config = super(ClassificationTestsOnConnect, cls).conf()
+ config = super().conf()
config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
return config
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index 2b408911fbd2..d88110f71531 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -36,7 +36,7 @@
class PipelineTestsOnConnect(PipelineTestsMixin, ReusedConnectTestCase):
@classmethod
def conf(cls):
- config = super(PipelineTestsOnConnect, cls).conf()
+ config = super().conf()
config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
return config
diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py
index 3b7f977b57ae..823a9b32e924 100644
--- a/python/pyspark/ml/tests/connect/test_connect_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -36,7 +36,7 @@
class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, ReusedConnectTestCase):
@classmethod
def conf(cls):
- config = super(CrossValidatorTestsOnConnect, cls).conf()
+ config = super().conf()
config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
return config
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
index f29a2c77e4ea..7b53aa0b7c7d 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
@@ -48,7 +48,7 @@
class HasInducedError(Params):
def __init__(self):
- super(HasInducedError, self).__init__()
+ super().__init__()
self.inducedError = Param(
self, "inducedError", "Uniformly-distributed error added to feature"
)
@@ -58,7 +58,7 @@ def getInducedError(self):
class InducedErrorModel(Model, HasInducedError):
def __init__(self):
- super(InducedErrorModel, self).__init__()
+ super().__init__()
def _transform(self, dataset):
return dataset.withColumn(
@@ -67,7 +67,7 @@ def _transform(self, dataset):
class InducedErrorEstimator(Estimator, HasInducedError):
def __init__(self, inducedError=1.0):
- super(InducedErrorEstimator, self).__init__()
+ super().__init__()
self._set(inducedError=inducedError)
def _fit(self, dataset):
diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py
index 0f5deab4e093..875718bb7b21 100644
--- a/python/pyspark/ml/tests/test_algorithms.py
+++ b/python/pyspark/ml/tests/test_algorithms.py
@@ -254,7 +254,7 @@ def test_persistence(self):
class FPGrowthTests(SparkSessionTestCase):
def setUp(self):
- super(FPGrowthTests, self).setUp()
+ super().setUp()
self.data = self.spark.createDataFrame(
[([1, 2],), ([1, 2],), ([1, 2, 3],), ([1, 3],)], ["items"]
)
diff --git a/python/pyspark/ml/tests/test_functions.py b/python/pyspark/ml/tests/test_functions.py
index 02e6d0b7c424..4621ce63db49 100644
--- a/python/pyspark/ml/tests/test_functions.py
+++ b/python/pyspark/ml/tests/test_functions.py
@@ -70,7 +70,7 @@ class PredictBatchUDFTestsMixin:
def setUp(self):
import pandas as pd
- super(PredictBatchUDFTestsMixin, self).setUp()
+ super().setUp()
self.data = np.arange(0, 1000, dtype=np.float64).reshape(-1, 4)
# 4 scalar columns
diff --git a/python/pyspark/ml/tests/test_model_cache.py b/python/pyspark/ml/tests/test_model_cache.py
index 9ad8ac544274..6880ace2574f 100644
--- a/python/pyspark/ml/tests/test_model_cache.py
+++ b/python/pyspark/ml/tests/test_model_cache.py
@@ -23,7 +23,7 @@
class ModelCacheTests(SparkSessionTestCase):
def setUp(self):
- super(ModelCacheTests, self).setUp()
+ super().setUp()
def test_cache(self):
def predict_fn(inputs):
diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py
index 0aa982712495..03d514a38519 100644
--- a/python/pyspark/ml/tests/test_param.py
+++ b/python/pyspark/ml/tests/test_param.py
@@ -137,7 +137,7 @@ class TestParams(HasMaxIter, HasInputCol, HasSeed):
@keyword_only
def __init__(self, seed=None):
- super(TestParams, self).__init__()
+ super().__init__()
self._setDefault(maxIter=10)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -159,7 +159,7 @@ class OtherTestParams(HasMaxIter, HasInputCol, HasSeed):
@keyword_only
def __init__(self, seed=None):
- super(OtherTestParams, self).__init__()
+ super().__init__()
self._setDefault(maxIter=10)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -176,7 +176,7 @@ def setParams(self, seed=None):
class HasThrowableProperty(Params):
def __init__(self):
- super(HasThrowableProperty, self).__init__()
+ super().__init__()
self.p = Param(self, "none", "empty param")
@property
diff --git a/python/pyspark/ml/tests/tuning/test_tuning.py b/python/pyspark/ml/tests/tuning/test_tuning.py
index cbbe7b6cca82..abe84d3c5955 100644
--- a/python/pyspark/ml/tests/tuning/test_tuning.py
+++ b/python/pyspark/ml/tests/tuning/test_tuning.py
@@ -40,7 +40,7 @@
class HasInducedError(Params):
def __init__(self):
- super(HasInducedError, self).__init__()
+ super().__init__()
self.inducedError = Param(
self, "inducedError", "Uniformly-distributed error added to feature"
)
@@ -51,7 +51,7 @@ def getInducedError(self):
class InducedErrorModel(Model, HasInducedError):
def __init__(self):
- super(InducedErrorModel, self).__init__()
+ super().__init__()
def _transform(self, dataset):
return dataset.withColumn(
@@ -61,7 +61,7 @@ def _transform(self, dataset):
class InducedErrorEstimator(Estimator, HasInducedError):
def __init__(self, inducedError=1.0):
- super(InducedErrorEstimator, self).__init__()
+ super().__init__()
self._set(inducedError=inducedError)
def _fit(self, dataset):
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py
index e9bf1d784000..4ef2f63153af 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -69,7 +69,7 @@ def create_training_function(mnist_dir_path: str) -> Callable:
class Net(nn.Module):
def __init__(self) -> None:
- super(Net, self).__init__()
+ super().__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
diff --git a/python/pyspark/ml/tree.py b/python/pyspark/ml/tree.py
index 449b67a5b089..86e8422702b1 100644
--- a/python/pyspark/ml/tree.py
+++ b/python/pyspark/ml/tree.py
@@ -154,7 +154,7 @@ class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):
)
def __init__(self) -> None:
- super(_DecisionTreeParams, self).__init__()
+ super().__init__()
def setLeafCol(self: "P", value: str) -> "P":
"""
@@ -285,7 +285,7 @@ class _TreeEnsembleParams(_DecisionTreeParams):
)
def __init__(self) -> None:
- super(_TreeEnsembleParams, self).__init__()
+ super().__init__()
@since("1.4.0")
def getSubsamplingRate(self) -> float:
@@ -322,7 +322,7 @@ class _RandomForestParams(_TreeEnsembleParams):
)
def __init__(self) -> None:
- super(_RandomForestParams, self).__init__()
+ super().__init__()
@since("1.4.0")
def getNumTrees(self) -> int:
@@ -387,7 +387,7 @@ class _HasVarianceImpurity(Params):
)
def __init__(self) -> None:
- super(_HasVarianceImpurity, self).__init__()
+ super().__init__()
@since("1.4.0")
def getImpurity(self) -> str:
@@ -416,7 +416,7 @@ class _TreeClassifierParams(Params):
)
def __init__(self) -> None:
- super(_TreeClassifierParams, self).__init__()
+ super().__init__()
@since("1.6.0")
def getImpurity(self) -> str:
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index d922ea38f2d8..60bcf7ae0c2c 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -521,7 +521,7 @@ def getValidatorModelWriterPersistSubModelsParam(writer: MLWriter) -> bool:
@inherit_doc
class CrossValidatorReader(MLReader["CrossValidator"]):
def __init__(self, cls: Type["CrossValidator"]):
- super(CrossValidatorReader, self).__init__()
+ super().__init__()
self.cls = cls
def load(self, path: str) -> "CrossValidator":
@@ -540,7 +540,7 @@ def load(self, path: str) -> "CrossValidator":
@inherit_doc
class CrossValidatorWriter(MLWriter):
def __init__(self, instance: "CrossValidator"):
- super(CrossValidatorWriter, self).__init__()
+ super().__init__()
self.instance = instance
def saveImpl(self, path: str) -> None:
@@ -551,7 +551,7 @@ def saveImpl(self, path: str) -> None:
@inherit_doc
class CrossValidatorModelReader(MLReader["CrossValidatorModel"]):
def __init__(self, cls: Type["CrossValidatorModel"]):
- super(CrossValidatorModelReader, self).__init__()
+ super().__init__()
self.cls = cls
def load(self, path: str) -> "CrossValidatorModel":
@@ -599,7 +599,7 @@ def load(self, path: str) -> "CrossValidatorModel":
@inherit_doc
class CrossValidatorModelWriter(MLWriter):
def __init__(self, instance: "CrossValidatorModel"):
- super(CrossValidatorModelWriter, self).__init__()
+ super().__init__()
self.instance = instance
def saveImpl(self, path: str) -> None:
@@ -656,7 +656,7 @@ class _CrossValidatorParams(_ValidatorParams):
)
def __init__(self, *args: Any):
- super(_CrossValidatorParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(numFolds=3, foldCol="")
@since("1.4.0")
@@ -747,7 +747,7 @@ def __init__(
__init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
seed=None, parallelism=1, collectSubModels=False, foldCol="")
"""
- super(CrossValidator, self).__init__()
+ super().__init__()
self._setDefault(parallelism=1)
kwargs = self._input_kwargs
self._set(**kwargs)
@@ -973,7 +973,7 @@ def _from_java(cls, java_stage: "JavaObject") -> "CrossValidator":
Used for ML persistence.
"""
- estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
+ estimator, epms, evaluator = super()._from_java_impl(java_stage)
numFolds = java_stage.getNumFolds()
seed = java_stage.getSeed()
parallelism = java_stage.getParallelism()
@@ -1003,7 +1003,7 @@ def _to_java(self) -> "JavaObject":
Java object equivalent to this instance.
"""
- estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
+ estimator, epms, evaluator = super()._to_java_impl()
_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
_java_obj.setEstimatorParamMaps(epms)
@@ -1042,7 +1042,7 @@ def __init__(
subModels: Optional[List[List[Model]]] = None,
stdMetrics: Optional[List[float]] = None,
):
- super(CrossValidatorModel, self).__init__()
+ super().__init__()
#: best model from cross validation
self.bestModel = bestModel
#: Average cross-validation metrics for each paramMap in
@@ -1119,7 +1119,7 @@ def _from_java(cls, java_stage: "JavaObject") -> "CrossValidatorModel":
bestModel: Model = JavaParams._from_java(java_stage.bestModel())
avgMetrics = _java2py(sc, java_stage.avgMetrics())
- estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
+ estimator, epms, evaluator = super()._from_java_impl(java_stage)
py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics)
params = {
@@ -1162,7 +1162,7 @@ def _to_java(self) -> "JavaObject":
cast(JavaParams, self.bestModel)._to_java(),
_py2java(sc, self.avgMetrics),
)
- estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
+ estimator, epms, evaluator = super()._to_java_impl()
params = {
"evaluator": evaluator,
@@ -1189,7 +1189,7 @@ def _to_java(self) -> "JavaObject":
@inherit_doc
class TrainValidationSplitReader(MLReader["TrainValidationSplit"]):
def __init__(self, cls: Type["TrainValidationSplit"]):
- super(TrainValidationSplitReader, self).__init__()
+ super().__init__()
self.cls = cls
def load(self, path: str) -> "TrainValidationSplit":
@@ -1208,7 +1208,7 @@ def load(self, path: str) -> "TrainValidationSplit":
@inherit_doc
class TrainValidationSplitWriter(MLWriter):
def __init__(self, instance: "TrainValidationSplit"):
- super(TrainValidationSplitWriter, self).__init__()
+ super().__init__()
self.instance = instance
def saveImpl(self, path: str) -> None:
@@ -1219,7 +1219,7 @@ def saveImpl(self, path: str) -> None:
@inherit_doc
class TrainValidationSplitModelReader(MLReader["TrainValidationSplitModel"]):
def __init__(self, cls: Type["TrainValidationSplitModel"]):
- super(TrainValidationSplitModelReader, self).__init__()
+ super().__init__()
self.cls = cls
def load(self, path: str) -> "TrainValidationSplitModel":
@@ -1258,7 +1258,7 @@ def load(self, path: str) -> "TrainValidationSplitModel":
@inherit_doc
class TrainValidationSplitModelWriter(MLWriter):
def __init__(self, instance: "TrainValidationSplitModel"):
- super(TrainValidationSplitModelWriter, self).__init__()
+ super().__init__()
self.instance = instance
def saveImpl(self, path: str) -> None:
@@ -1304,7 +1304,7 @@ class _TrainValidationSplitParams(_ValidatorParams):
)
def __init__(self, *args: Any):
- super(_TrainValidationSplitParams, self).__init__(*args)
+ super().__init__(*args)
self._setDefault(trainRatio=0.75)
@since("2.0.0")
@@ -1385,7 +1385,7 @@ def __init__(
__init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, \
trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None)
"""
- super(TrainValidationSplit, self).__init__()
+ super().__init__()
self._setDefault(parallelism=1)
kwargs = self._input_kwargs
self._set(**kwargs)
@@ -1552,7 +1552,7 @@ def _from_java(cls, java_stage: "JavaObject") -> "TrainValidationSplit":
Used for ML persistence.
"""
- estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
+ estimator, epms, evaluator = super()._from_java_impl(java_stage)
trainRatio = java_stage.getTrainRatio()
seed = java_stage.getSeed()
parallelism = java_stage.getParallelism()
@@ -1580,7 +1580,7 @@ def _to_java(self) -> "JavaObject":
Java object equivalent to this instance.
"""
- estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
+ estimator, epms, evaluator = super()._to_java_impl()
_java_obj = JavaParams._new_java_obj(
"org.apache.spark.ml.tuning.TrainValidationSplit", self.uid
@@ -1610,7 +1610,7 @@ def __init__(
validationMetrics: Optional[List[float]] = None,
subModels: Optional[List[Model]] = None,
):
- super(TrainValidationSplitModel, self).__init__()
+ super().__init__()
#: best model from train validation split
self.bestModel = bestModel
#: evaluated validation metrics
@@ -1681,9 +1681,7 @@ def _from_java(cls, java_stage: "JavaObject") -> "TrainValidationSplitModel":
bestModel: Model = JavaParams._from_java(java_stage.bestModel())
validationMetrics = _java2py(sc, java_stage.validationMetrics())
- estimator, epms, evaluator = super(TrainValidationSplitModel, cls)._from_java_impl(
- java_stage
- )
+ estimator, epms, evaluator = super()._from_java_impl(java_stage)
# Create a new instance of this stage.
py_stage = cls(bestModel=bestModel, validationMetrics=validationMetrics)
params = {
@@ -1724,7 +1722,7 @@ def _to_java(self) -> "JavaObject":
cast(JavaParams, self.bestModel)._to_java(),
_py2java(sc, self.validationMetrics),
)
- estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
+ estimator, epms, evaluator = super()._to_java_impl()
params = {
"evaluator": evaluator,
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 06759bc25269..5849f4a6dad8 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -566,7 +566,7 @@ class MLWriter(BaseReadWrite):
"""
def __init__(self) -> None:
- super(MLWriter, self).__init__()
+ super().__init__()
self.shouldOverwrite: bool = False
self.optionMap: Dict[str, Any] = {}
@@ -630,7 +630,7 @@ class JavaMLWriter(MLWriter):
_jwrite: "JavaObject"
def __init__(self, instance: "JavaMLWritable"):
- super(JavaMLWriter, self).__init__()
+ super().__init__()
_java_obj = instance._to_java() # type: ignore[attr-defined]
self._jwrite = _java_obj.write()
@@ -662,7 +662,7 @@ class GeneralJavaMLWriter(JavaMLWriter):
"""
def __init__(self, instance: "JavaMLWritable"):
- super(GeneralJavaMLWriter, self).__init__(instance)
+ super().__init__(instance)
def format(self, source: str) -> "GeneralJavaMLWriter":
"""
@@ -723,7 +723,7 @@ class MLReader(BaseReadWrite, Generic[RL]):
"""
def __init__(self) -> None:
- super(MLReader, self).__init__()
+ super().__init__()
def load(self, path: str) -> RL:
"""Load the ML instance from the input path."""
@@ -737,7 +737,7 @@ class JavaMLReader(MLReader[RL]):
"""
def __init__(self, clazz: Type["JavaMLReadable[RL]"]) -> None:
- super(JavaMLReader, self).__init__()
+ super().__init__()
self._clazz = clazz
self._jread = self._load_java_obj(clazz).read()
@@ -850,7 +850,7 @@ class DefaultParamsWriter(MLWriter):
"""
def __init__(self, instance: "Params"):
- super(DefaultParamsWriter, self).__init__()
+ super().__init__()
self.instance = instance
def saveImpl(self, path: str) -> None:
@@ -976,7 +976,7 @@ class DefaultParamsReader(MLReader[RL]):
"""
def __init__(self, cls: Type[DefaultParamsReadable[RL]]):
- super(DefaultParamsReader, self).__init__()
+ super().__init__()
self.cls = cls
@staticmethod
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index b8d86e9eab3b..abb7c7625f1c 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -52,7 +52,7 @@ class JavaWrapper:
"""
def __init__(self, java_obj: Optional["JavaObject"] = None):
- super(JavaWrapper, self).__init__()
+ super().__init__()
self._java_obj = java_obj
@try_remote_del
@@ -355,7 +355,7 @@ def copy(self: "JP", extra: Optional["ParamMap"] = None) -> "JP":
"""
if extra is None:
extra = dict()
- that = super(JavaParams, self).copy(extra)
+ that = super().copy(extra)
if self._java_obj is not None:
from pyspark.ml.util import RemoteModelRef
@@ -374,7 +374,7 @@ def clear(self, param: Param) -> None:
"""
assert self._java_obj is not None
- super(JavaParams, self).clear(param)
+ super().clear(param)
java_param = self._java_obj.getParam(param.name)
self._java_obj.clear(java_param)
@@ -457,7 +457,7 @@ def __init__(self, java_model: Optional["JavaObject"] = None):
these wrappers depend on pyspark.ml.util (both directly and via
other ML classes).
"""
- super(JavaModel, self).__init__(java_model)
+ super().__init__(java_model)
if is_remote() and java_model is not None:
from pyspark.ml.util import RemoteModelRef
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 84b59eecbe9d..3a7c40d15fdf 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -58,7 +58,7 @@ class LinearClassificationModel(LinearModel):
"""
def __init__(self, weights: Vector, intercept: float) -> None:
- super(LinearClassificationModel, self).__init__(weights, intercept)
+ super().__init__(weights, intercept)
self._threshold: Optional[float] = None
@since("1.4.0")
@@ -198,7 +198,7 @@ class LogisticRegressionModel(LinearClassificationModel):
def __init__(
self, weights: Vector, intercept: float, numFeatures: int, numClasses: int
) -> None:
- super(LogisticRegressionModel, self).__init__(weights, intercept)
+ super().__init__(weights, intercept)
self._numFeatures = int(numFeatures)
self._numClasses = int(numClasses)
self._threshold = 0.5
@@ -584,7 +584,7 @@ class SVMModel(LinearClassificationModel):
"""
def __init__(self, weights: Vector, intercept: float) -> None:
- super(SVMModel, self).__init__(weights, intercept)
+ super().__init__(weights, intercept)
self._threshold = 0.0
@overload
@@ -931,7 +931,7 @@ def __init__(
self.miniBatchFraction = miniBatchFraction
self.convergenceTol = convergenceTol
self._model: Optional[LogisticRegressionModel] = None
- super(StreamingLogisticRegressionWithSGD, self).__init__(model=self._model)
+ super().__init__(model=self._model)
@since("1.5.0")
def setInitialWeights(
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index ff5d69b084e2..4fac735d7ed7 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -76,7 +76,7 @@ class BisectingKMeansModel(JavaModelWrapper):
"""
def __init__(self, java_model: "JavaObject"):
- super(BisectingKMeansModel, self).__init__(java_model)
+ super().__init__(java_model)
self.centers = [c.toArray() for c in self.call("clusterCenters")]
@property
@@ -943,7 +943,7 @@ class StreamingKMeansModel(KMeansModel):
"""
def __init__(self, clusterCenters: List["VectorLike"], clusterWeights: "VectorLike"):
- super(StreamingKMeansModel, self).__init__(centers=clusterCenters)
+ super().__init__(centers=clusterCenters)
self._clusterWeights = list(clusterWeights) # type: ignore[arg-type]
@property
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index dfcee167bea5..1c744e1ddb98 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -82,7 +82,7 @@ def __init__(self, scoreAndLabels: RDD[Tuple[float, float]]):
assert sc._jvm is not None
java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
java_model = java_class(df._jdf)
- super(BinaryClassificationMetrics, self).__init__(java_model)
+ super().__init__(java_model)
@property
@since("1.4.0")
@@ -158,7 +158,7 @@ def __init__(self, predictionAndObservations: RDD[Tuple[float, float]]):
assert sc._jvm is not None
java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics
java_model = java_class(df._jdf)
- super(RegressionMetrics, self).__init__(java_model)
+ super().__init__(java_model)
@property
@since("1.4.0")
@@ -299,7 +299,7 @@ def __init__(self, predictionAndLabels: RDD[Tuple[float, float]]):
assert sc._jvm is not None
java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
java_model = java_class(df._jdf)
- super(MulticlassMetrics, self).__init__(java_model)
+ super().__init__(java_model)
@since("1.4.0")
def confusionMatrix(self) -> Matrix:
@@ -465,7 +465,7 @@ def __init__(
predictionAndLabels, schema=sql_ctx.sparkSession._inferSchema(predictionAndLabels)
)
java_model = callMLlibFunc("newRankingMetrics", df._jdf)
- super(RankingMetrics, self).__init__(java_model)
+ super().__init__(java_model)
@since("1.4.0")
def precisionAt(self, k: int) -> float:
@@ -581,7 +581,7 @@ def __init__(self, predictionAndLabels: RDD[Tuple[List[float], List[float]]]):
assert sc._jvm is not None
java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics
java_model = java_class(df._jdf)
- super(MultilabelMetrics, self).__init__(java_model)
+ super().__init__(java_model)
@since("1.4.0")
def precision(self, label: Optional[float] = None) -> float:
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index b4bd41492706..b384e0fa608e 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -1014,7 +1014,7 @@ def __init__(
self.miniBatchFraction = miniBatchFraction
self.convergenceTol = convergenceTol
self._model: Optional[LinearModel] = None
- super(StreamingLinearRegressionWithSGD, self).__init__(model=self._model)
+ super().__init__(model=self._model)
@since("1.5.0")
def setInitialWeights(self, initialWeights: "VectorLike") -> "StreamingLinearRegressionWithSGD":
diff --git a/python/pyspark/pandas/data_type_ops/datetime_ops.py b/python/pyspark/pandas/data_type_ops/datetime_ops.py
index 22bd7a6d329d..6679cb5783ae 100644
--- a/python/pyspark/pandas/data_type_ops/datetime_ops.py
+++ b/python/pyspark/pandas/data_type_ops/datetime_ops.py
@@ -166,4 +166,4 @@ def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> Ind
)
return index_ops._with_new_scol(scol, field=InternalField(dtype=dtype))
else:
- return super(DatetimeNTZOps, self).astype(index_ops, dtype)
+ return super().astype(index_ops, dtype)
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 379d3698bc09..e5aaecbb64fd 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -2699,8 +2699,20 @@ def to_feather(
# Make sure locals() call is at the top of the function so we don't capture local variables.
args = locals()
+ pdf = self._to_internal_pandas()
+ # SPARK-54068: PyArrow >= 22.0.0 serializes DataFrame.attrs to JSON metadata,
+ # but PlanMetrics/PlanObservedMetrics objects from Spark Connect are not
+ # JSON serializable. We filter these internal attrs only for affected versions.
+ import pyarrow as pa
+ from pyspark.loose_version import LooseVersion
+
+ if LooseVersion(pa.__version__) >= LooseVersion("22.0.0"):
+ pdf.attrs = {
+ k: v for k, v in pdf.attrs.items() if k not in ("metrics", "observed_metrics")
+ }
+
return validate_arguments_and_invoke_function(
- self._to_internal_pandas(), self.to_feather, pd.DataFrame.to_feather, args
+ pdf, self.to_feather, pd.DataFrame.to_feather, args
)
def to_stata(
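A minimal standalone sketch of the to_feather workaround above, assuming only a pandas DataFrame whose attrs carry the non-JSON-serializable Spark Connect entries; the helper name below is illustrative, not part of the patch:

import pandas as pd
import pyarrow as pa
from pyspark.loose_version import LooseVersion

def _drop_connect_attrs(pdf: pd.DataFrame) -> pd.DataFrame:
    # PyArrow >= 22.0.0 serializes DataFrame.attrs into Feather metadata as JSON,
    # so drop the internal, non-serializable entries on affected versions only.
    if LooseVersion(pa.__version__) >= LooseVersion("22.0.0"):
        pdf.attrs = {
            k: v for k, v in pdf.attrs.items() if k not in ("metrics", "observed_metrics")
        }
    return pdf

pdf = pd.DataFrame({"a": [1, 2]})
pdf.attrs["metrics"] = object()  # stand-in for a PlanMetrics object attached by Spark Connect
_drop_connect_attrs(pdf).to_feather("/tmp/example.feather")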
diff --git a/python/pyspark/pandas/sql_formatter.py b/python/pyspark/pandas/sql_formatter.py
index fe500c0bf207..98f085b39717 100644
--- a/python/pyspark/pandas/sql_formatter.py
+++ b/python/pyspark/pandas/sql_formatter.py
@@ -237,7 +237,7 @@ def __init__(self, session: SparkSession) -> None:
self._ref_sers: List[Tuple[Series, str]] = []
def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> str:
- ret = super(PandasSQLStringFormatter, self).vformat(format_string, args, kwargs)
+ ret = super().vformat(format_string, args, kwargs)
for ref, n in self._ref_sers:
if not any((ref is v for v in df._pssers.values()) for df, _ in self._temp_views):
@@ -246,7 +246,7 @@ def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str,
return ret
def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
- obj, first = super(PandasSQLStringFormatter, self).get_field(field_name, args, kwargs)
+ obj, first = super().get_field(field_name, args, kwargs)
return self._convert_value(obj, field_name), first
def _convert_value(self, val: Any, name: str) -> Optional[str]:
diff --git a/python/pyspark/pandas/tests/computation/test_describe.py b/python/pyspark/pandas/tests/computation/test_describe.py
index 8df07f1945d2..1a45c3e572e4 100644
--- a/python/pyspark/pandas/tests/computation/test_describe.py
+++ b/python/pyspark/pandas/tests/computation/test_describe.py
@@ -28,7 +28,7 @@
class FrameDescribeMixin:
@classmethod
def setUpClass(cls):
- super(FrameDescribeMixin, cls).setUpClass()
+ super().setUpClass()
# Some nanosecond->microsecond conversions throw loss of precision errors
cls.spark.conf.set("spark.sql.execution.pandas.convertToArrowArraySafely", "false")
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py
index f98f2011dde0..228061cac804 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py
@@ -248,7 +248,7 @@ class DatetimeOpsTests(
class DatetimeNTZOpsTest(DatetimeOpsTests):
@classmethod
def setUpClass(cls):
- super(DatetimeOpsTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.timestampType", "timestamp_ntz")
diff --git a/python/pyspark/pandas/tests/data_type_ops/testing_utils.py b/python/pyspark/pandas/tests/data_type_ops/testing_utils.py
index 17ac2bc5c474..04d03a05e02d 100644
--- a/python/pyspark/pandas/tests/data_type_ops/testing_utils.py
+++ b/python/pyspark/pandas/tests/data_type_ops/testing_utils.py
@@ -43,7 +43,7 @@ class OpsTestBase:
@classmethod
def setUpClass(cls):
- super(OpsTestBase, cls).setUpClass()
+ super().setUpClass()
# Some nanosecond->microsecond conversions throw loss of precision errors
cls.spark.conf.set("spark.sql.execution.pandas.convertToArrowArraySafely", "false")
diff --git a/python/pyspark/pandas/tests/io/test_feather.py b/python/pyspark/pandas/tests/io/test_feather.py
index 10638d915c0e..74fa6bc7d7b6 100644
--- a/python/pyspark/pandas/tests/io/test_feather.py
+++ b/python/pyspark/pandas/tests/io/test_feather.py
@@ -17,10 +17,8 @@
import unittest
import pandas as pd
-import sys
from pyspark import pandas as ps
-from pyspark.loose_version import LooseVersion
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
@@ -36,16 +34,6 @@ def pdf(self):
def psdf(self):
return ps.from_pandas(self.pdf)
- has_arrow_21_or_below = False
- try:
- import pyarrow as pa
-
- if LooseVersion(pa.__version__) < LooseVersion("22.0.0"):
- has_arrow_21_or_below = True
- except ImportError:
- pass
-
- @unittest.skipIf(not has_arrow_21_or_below, "SPARK-54068")
def test_to_feather(self):
with self.temp_dir() as dirpath:
path1 = f"{dirpath}/file1.feather"
diff --git a/python/pyspark/pandas/tests/resample/test_timezone.py b/python/pyspark/pandas/tests/resample/test_timezone.py
index 17c46dd26b35..ad01a6413709 100644
--- a/python/pyspark/pandas/tests/resample/test_timezone.py
+++ b/python/pyspark/pandas/tests/resample/test_timezone.py
@@ -33,11 +33,11 @@ class ResampleTimezoneMixin:
def setUpClass(cls):
cls.timezone = os.environ.get("TZ", None)
os.environ["TZ"] = "America/New_York"
- super(ResampleTimezoneMixin, cls).setUpClass()
+ super().setUpClass()
@classmethod
def tearDownClass(cls):
- super(ResampleTimezoneMixin, cls).tearDownClass()
+ super().tearDownClass()
if cls.timezone is not None:
os.environ["TZ"] = cls.timezone
diff --git a/python/pyspark/pandas/tests/test_numpy_compat.py b/python/pyspark/pandas/tests/test_numpy_compat.py
index d961a433e181..97df2bc6eb8d 100644
--- a/python/pyspark/pandas/tests/test_numpy_compat.py
+++ b/python/pyspark/pandas/tests/test_numpy_compat.py
@@ -28,7 +28,7 @@
class NumPyCompatTestsMixin:
@classmethod
def setUpClass(cls):
- super(NumPyCompatTestsMixin, cls).setUpClass()
+ super().setUpClass()
# Some nanosecond->microsecond conversions throw loss of precision errors
cls.spark.conf.set("spark.sql.execution.pandas.convertToArrowArraySafely", "false")
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 0c9a6b45bd2f..a3f188302296 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -502,6 +502,20 @@ def pairs(self) -> dict[str, Any]:
def keys(self) -> List[str]:
return self._keys
+ def to_dict(self) -> dict[str, Any]:
+ """Return a JSON-serializable dictionary representation of this observed metrics.
+
+ Returns
+ -------
+ dict
+ A dictionary with keys 'name', 'keys', and 'pairs'.
+ """
+ return {
+ "name": self._name,
+ "keys": self._keys,
+ "pairs": self.pairs,
+ }
+
class AnalyzeResult:
def __init__(
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 6630d96f21de..80b6b562369e 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1930,7 +1930,7 @@ def command(self, session: "SparkConnectClient") -> proto.Command:
class WriteOperation(LogicalPlan):
def __init__(self, child: "LogicalPlan") -> None:
- super(WriteOperation, self).__init__(child)
+ super().__init__(child)
self.source: Optional[str] = None
self.path: Optional[str] = None
self.table_name: Optional[str] = None
@@ -2037,7 +2037,7 @@ def _repr_html_(self) -> str:
class WriteOperationV2(LogicalPlan):
def __init__(self, child: "LogicalPlan", table_name: str) -> None:
- super(WriteOperationV2, self).__init__(child)
+ super().__init__(child)
self.table_name: Optional[str] = table_name
self.provider: Optional[str] = None
self.partitioning_columns: List[Column] = []
@@ -2101,7 +2101,7 @@ def command(self, session: "SparkConnectClient") -> proto.Command:
class WriteStreamOperation(LogicalPlan):
def __init__(self, child: "LogicalPlan") -> None:
- super(WriteStreamOperation, self).__init__(child)
+ super().__init__(child)
self.write_op = proto.WriteStreamOperationStart()
def command(self, session: "SparkConnectClient") -> proto.Command:
diff --git a/python/pyspark/sql/connect/sql_formatter.py b/python/pyspark/sql/connect/sql_formatter.py
index 38b94bbaf205..8fced80081ad 100644
--- a/python/pyspark/sql/connect/sql_formatter.py
+++ b/python/pyspark/sql/connect/sql_formatter.py
@@ -39,7 +39,7 @@ def __init__(self, session: "SparkSession") -> None:
self._temp_views: List[Tuple[DataFrame, str]] = []
def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
- obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs)
+ obj, first = super().get_field(field_name, args, kwargs)
return self._convert_value(obj, field_name), first
def _convert_value(self, val: Any, field_name: str) -> Optional[str]:
diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py
index 4ab9b041e313..50d8ae444c47 100644
--- a/python/pyspark/sql/metrics.py
+++ b/python/pyspark/sql/metrics.py
@@ -68,6 +68,20 @@ def value(self) -> Union[int, float]:
def metric_type(self) -> str:
return self._type
+ def to_dict(self) -> Dict[str, Any]:
+ """Return a JSON-serializable dictionary representation of this metric value.
+
+ Returns
+ -------
+ dict
+ A dictionary with keys 'name', 'value', and 'type'.
+ """
+ return {
+ "name": self._name,
+ "value": self._value,
+ "type": self._type,
+ }
+
class PlanMetrics:
"""Represents a particular plan node and the associated metrics of this node."""
@@ -97,6 +111,21 @@ def parent_plan_id(self) -> int:
def metrics(self) -> List[MetricValue]:
return self._metrics
+ def to_dict(self) -> Dict[str, Any]:
+ """Return a JSON-serializable dictionary representation of this plan metrics.
+
+ Returns
+ -------
+ dict
+ A dictionary with keys 'name', 'plan_id', 'parent_plan_id', and 'metrics'.
+ """
+ return {
+ "name": self._name,
+ "plan_id": self._id,
+ "parent_plan_id": self._parent_id,
+ "metrics": [m.to_dict() for m in self._metrics],
+ }
+
class CollectedMetrics:
@dataclasses.dataclass
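A hedged usage note for the to_dict() helpers added here: they make collected plan metrics safe to hand to json.dumps, which is what the DataFrame.attrs handling in to_feather relies on. The constructor argument order below is inferred from the fields shown in this diff, not spelled out by it:

import json

from pyspark.sql.metrics import MetricValue, PlanMetrics

# Assumed shapes: MetricValue(name, value, type), PlanMetrics(name, plan_id, parent_id, metrics).
metric = MetricValue("numOutputRows", 42, "sum")
plan = PlanMetrics("HashAggregate", 3, 1, [metric])

print(json.dumps(plan.to_dict(), indent=2))
# {"name": "HashAggregate", "plan_id": 3, "parent_plan_id": 1,
#  "metrics": [{"name": "numOutputRows", "value": 42, "type": "sum"}]}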
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 768160087032..dc854cb1985d 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -149,7 +149,7 @@ def load_stream(self, stream):
"""
import pyarrow as pa
- batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
+ batches = super().load_stream(stream)
for batch in batches:
struct = batch.column(0)
yield [pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))]
@@ -184,7 +184,7 @@ def wrap_and_init_stream():
should_write_start_length = False
yield batch
- return super(ArrowStreamUDFSerializer, self).dump_stream(wrap_and_init_stream(), stream)
+ return super().dump_stream(wrap_and_init_stream(), stream)
class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer):
@@ -304,7 +304,7 @@ class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
"""
def __init__(self, assign_cols_by_name):
- super(ArrowStreamGroupUDFSerializer, self).__init__()
+ super().__init__()
self._assign_cols_by_name = assign_cols_by_name
def dump_stream(self, iterator, stream):
@@ -330,7 +330,7 @@ def dump_stream(self, iterator, stream):
for batch, arrow_type in batch_iter
)
- super(ArrowStreamGroupUDFSerializer, self).dump_stream(batch_iter, stream)
+ super().dump_stream(batch_iter, stream)
class ArrowStreamPandasSerializer(ArrowStreamSerializer):
@@ -351,7 +351,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
"""
def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
- super(ArrowStreamPandasSerializer, self).__init__()
+ super().__init__()
self._timezone = timezone
self._safecheck = safecheck
self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
@@ -528,13 +528,13 @@ def dump_stream(self, iterator, stream):
a list of series accompanied by an optional pyarrow type to coerce the data to.
"""
batches = (self._create_batch(series) for series in iterator)
- super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)
+ super().dump_stream(batches, stream)
def load_stream(self, stream):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
- batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
+ batches = super().load_stream(stream)
import pyarrow as pa
import pandas as pd
@@ -569,9 +569,7 @@ def __init__(
input_types=None,
int_to_decimal_coercion_enabled=False,
):
- super(ArrowStreamPandasUDFSerializer, self).__init__(
- timezone, safecheck, int_to_decimal_coercion_enabled
- )
+ super().__init__(timezone, safecheck, int_to_decimal_coercion_enabled)
self._assign_cols_by_name = assign_cols_by_name
self._df_for_struct = df_for_struct
self._struct_in_pandas = struct_in_pandas
@@ -593,6 +591,7 @@ def arrow_to_pandas(self, arrow_column, idx):
import pandas as pd
series = [
+        # Explicit super(Class, self) is needed: zero-arg super() fails inside a comprehension's scope
super(ArrowStreamPandasUDFSerializer, self)
.arrow_to_pandas(
column,
@@ -610,7 +609,7 @@ def arrow_to_pandas(self, arrow_column, idx):
]
s = pd.concat(series, axis=1)
else:
- s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(
+ s = super().arrow_to_pandas(
arrow_column,
idx,
self._struct_in_pandas,
@@ -768,7 +767,7 @@ def __init__(
assign_cols_by_name,
arrow_cast,
):
- super(ArrowStreamArrowUDFSerializer, self).__init__()
+ super().__init__()
self._timezone = timezone
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name
@@ -941,7 +940,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
"""
def __init__(self, timezone, safecheck, input_types, int_to_decimal_coercion_enabled):
- super(ArrowStreamPandasUDTFSerializer, self).__init__(
+ super().__init__(
timezone=timezone,
safecheck=safecheck,
# The output pandas DataFrame's columns are unnamed.
@@ -1089,7 +1088,7 @@ def __repr__(self):
class GroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
def __init__(self, assign_cols_by_name):
- super(GroupArrowUDFSerializer, self).__init__(
+ super().__init__(
assign_cols_by_name=assign_cols_by_name,
)
@@ -1138,13 +1137,9 @@ def __init__(
super().__init__(
timezone=timezone,
safecheck=safecheck,
- assign_cols_by_name=False,
- arrow_cast=True,
+ assign_cols_by_name=assign_cols_by_name,
+ arrow_cast=arrow_cast,
)
- self._timezone = timezone
- self._safecheck = safecheck
- self._assign_cols_by_name = assign_cols_by_name
- self._arrow_cast = arrow_cast
def load_stream(self, stream):
"""
@@ -1189,13 +1184,9 @@ def __init__(
super().__init__(
timezone=timezone,
safecheck=safecheck,
- assign_cols_by_name=False,
- arrow_cast=True,
+ assign_cols_by_name=assign_cols_by_name,
+ arrow_cast=arrow_cast,
)
- self._timezone = timezone
- self._safecheck = safecheck
- self._assign_cols_by_name = assign_cols_by_name
- self._arrow_cast = arrow_cast
def load_stream(self, stream):
"""
@@ -1238,10 +1229,10 @@ def __init__(
assign_cols_by_name,
int_to_decimal_coercion_enabled,
):
- super(ArrowStreamAggPandasUDFSerializer, self).__init__(
+ super().__init__(
timezone=timezone,
safecheck=safecheck,
- assign_cols_by_name=False,
+ assign_cols_by_name=assign_cols_by_name,
df_for_struct=False,
struct_in_pandas="dict",
ndarray_as_list=False,
@@ -1249,9 +1240,6 @@ def __init__(
input_types=None,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
- self._timezone = timezone
- self._safecheck = safecheck
- self._assign_cols_by_name = assign_cols_by_name
def load_stream(self, stream):
"""
@@ -1295,7 +1283,7 @@ def __init__(
assign_cols_by_name,
int_to_decimal_coercion_enabled,
):
- super(GroupPandasUDFSerializer, self).__init__(
+ super().__init__(
timezone=timezone,
safecheck=safecheck,
assign_cols_by_name=assign_cols_by_name,
@@ -1353,7 +1341,7 @@ def dump_stream(self, iterator, stream):
"""
# Flatten: Iterator[Iterator[[(df, arrow_type)]]] -> Iterator[[(df, arrow_type)]]
flattened_iter = (batch for generator in iterator for batch in generator)
- super(GroupPandasUDFSerializer, self).dump_stream(flattened_iter, stream)
+ super().dump_stream(flattened_iter, stream)
def __repr__(self):
return "GroupPandasUDFSerializer"
@@ -1373,7 +1361,7 @@ class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
"""
def __init__(self, assign_cols_by_name):
- super(CogroupArrowUDFSerializer, self).__init__(assign_cols_by_name)
+ super().__init__(assign_cols_by_name)
def load_stream(self, stream):
"""
@@ -1458,7 +1446,7 @@ def __init__(
prefers_large_var_types,
int_to_decimal_coercion_enabled,
):
- super(ApplyInPandasWithStateSerializer, self).__init__(
+ super().__init__(
timezone,
safecheck,
assign_cols_by_name,
@@ -1842,7 +1830,7 @@ def __init__(
arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled,
):
- super(TransformWithStateInPandasSerializer, self).__init__(
+ super().__init__(
timezone,
safecheck,
assign_cols_by_name,
@@ -1967,7 +1955,7 @@ def __init__(
arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled,
):
- super(TransformWithStateInPandasInitStateSerializer, self).__init__(
+ super().__init__(
timezone,
safecheck,
assign_cols_by_name,
@@ -2108,7 +2096,7 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
"""
def __init__(self, arrow_max_records_per_batch):
- super(TransformWithStateInPySparkRowSerializer, self).__init__()
+ super().__init__()
self.arrow_max_records_per_batch = (
arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
)
@@ -2197,9 +2185,7 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
"""
def __init__(self, arrow_max_records_per_batch):
- super(TransformWithStateInPySparkRowInitStateSerializer, self).__init__(
- arrow_max_records_per_batch
- )
+ super().__init__(arrow_max_records_per_batch)
self.init_key_offsets = None
def load_stream(self, stream):
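The one place above where the explicit super(ArrowStreamPandasUDFSerializer, self) call is kept, inside arrow_to_pandas's list comprehension, reflects a genuine Python restriction; a small self-contained illustration (version behaviour hedged, since newer interpreters inline list comprehensions):

class Base:
    def scale(self, v):
        return v * 10

class Child(Base):
    def scaled_explicit(self):
        # Safe on every supported Python version: class and instance are named.
        return [super(Child, self).scale(v) for v in (1, 2, 3)]

    def scaled_implicit(self):
        # May fail at runtime (e.g. on Python < 3.12) because the element
        # expression runs in the comprehension's own scope, where zero-argument
        # super() cannot infer `self`.
        return [super().scale(v) for v in (1, 2, 3)]

print(Child().scaled_explicit())  # [10, 20, 30]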
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 4e45972d79b3..0d5353f0fb32 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1624,12 +1624,12 @@ def createDataFrame( # type: ignore[misc]
if has_pandas and isinstance(data, pd.DataFrame):
# Create a DataFrame from pandas DataFrame.
- return super(SparkSession, self).createDataFrame( # type: ignore[call-overload]
+ return super().createDataFrame( # type: ignore[call-overload]
data, schema, samplingRatio, verifySchema
)
if has_pyarrow and isinstance(data, pa.Table):
# Create a DataFrame from PyArrow Table.
- return super(SparkSession, self).createDataFrame( # type: ignore[call-overload]
+ return super().createDataFrame( # type: ignore[call-overload]
data, schema, samplingRatio, verifySchema
)
return self._create_dataframe(
diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/sql_formatter.py
index 011563d7006e..1366ef277c47 100644
--- a/python/pyspark/sql/sql_formatter.py
+++ b/python/pyspark/sql/sql_formatter.py
@@ -38,7 +38,7 @@ def __init__(self, session: "SparkSession") -> None:
self._temp_views: List[Tuple[DataFrame, str]] = []
def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
- obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs)
+ obj, first = super().get_field(field_name, args, kwargs)
return self._convert_value(obj, field_name), first
def _convert_value(self, val: Any, field_name: str) -> Optional[str]:
diff --git a/python/pyspark/sql/tests/arrow/test_arrow.py b/python/pyspark/sql/tests/arrow/test_arrow.py
index 19e579cb6778..e410f7df711b 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow.py
@@ -1925,14 +1925,14 @@ def test_exception_by_max_results(self):
class EncryptionArrowTests(ArrowTests):
@classmethod
def conf(cls):
- return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true")
+ return super().conf().set("spark.io.encryption.enabled", "true")
class RDDBasedArrowTests(ArrowTests):
@classmethod
def conf(cls):
return (
- super(RDDBasedArrowTests, cls)
+ super()
.conf()
.set("spark.sql.execution.arrow.localRelationThreshold", "0")
# to test multiple partitions
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
index c315151d4d75..be46939b351f 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -50,15 +50,15 @@
class ArrowPythonUDFTestsMixin(BaseUDFTestsMixin):
@unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
def test_broadcast_in_udf(self):
- super(ArrowPythonUDFTests, self).test_broadcast_in_udf()
+ super().test_broadcast_in_udf()
@unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
def test_register_java_function(self):
- super(ArrowPythonUDFTests, self).test_register_java_function()
+ super().test_register_java_function()
@unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
def test_register_java_udaf(self):
- super(ArrowPythonUDFTests, self).test_register_java_udaf()
+ super().test_register_java_udaf()
def test_complex_input_types(self):
row = (
@@ -485,7 +485,7 @@ def tearDownClass(cls):
class ArrowPythonUDFTests(ArrowPythonUDFTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
- super(ArrowPythonUDFTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
@classmethod
@@ -493,21 +493,21 @@ def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
- super(ArrowPythonUDFTests, cls).tearDownClass()
+ super().tearDownClass()
@unittest.skip("Duplicate test; it is tested separately in legacy and non-legacy tests")
def test_udf_binary_type(self):
- super(ArrowPythonUDFTests, self).test_udf_binary_type()
+ super().test_udf_binary_type()
@unittest.skip("Duplicate test; it is tested separately in legacy and non-legacy tests")
def test_udf_binary_type_in_nested_structures(self):
- super(ArrowPythonUDFTests, self).test_udf_binary_type_in_nested_structures()
+ super().test_udf_binary_type_in_nested_structures()
class ArrowPythonUDFLegacyTests(ArrowPythonUDFLegacyTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
- super(ArrowPythonUDFLegacyTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.concurrency.level", "4")
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
@@ -517,13 +517,13 @@ def tearDownClass(cls):
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.concurrency.level")
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
- super(ArrowPythonUDFLegacyTests, cls).tearDownClass()
+ super().tearDownClass()
class ArrowPythonUDFNonLegacyTests(ArrowPythonUDFNonLegacyTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
- super(ArrowPythonUDFNonLegacyTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
@classmethod
@@ -531,7 +531,7 @@ def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
- super(ArrowPythonUDFNonLegacyTests, cls).tearDownClass()
+ super().tearDownClass()
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index 74a81be37f80..844c7f111db4 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -1153,13 +1153,14 @@ def test_iterator_grouped_agg_partial_consumption(self):
# Create a dataset with multiple batches per group
# Use small batch size to ensure multiple batches per group
+        # Use the same value for every point so the partial-consumption result does not depend on batch order
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}):
df = self.spark.createDataFrame(
- [(1, 1.0), (1, 2.0), (1, 3.0), (1, 4.0), (2, 5.0), (2, 6.0)], ("id", "v")
+ [(1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (2, 1.0), (2, 1.0)], ("id", "v")
)
- @arrow_udf("double")
- def arrow_sum_partial(it: Iterator[pa.Array]) -> float:
+ @arrow_udf("struct")
+ def arrow_count_sum_partial(it: Iterator[pa.Array]) -> dict:
# Only consume first two batches, then return
# This tests that partial consumption works correctly
total = 0.0
@@ -1171,32 +1172,44 @@ def arrow_sum_partial(it: Iterator[pa.Array]) -> float:
else:
# Stop early - partial consumption
break
- return total / count if count > 0 else 0.0
+ return {"count": count, "sum": total}
- result = df.groupby("id").agg(arrow_sum_partial(df["v"]).alias("mean")).sort("id")
+ result = (
+ df.groupby("id").agg(arrow_count_sum_partial(df["v"]).alias("result")).sort("id")
+ )
# Verify results are correct for partial consumption
# With batch size = 2:
# Group 1 (id=1): 4 values in 2 batches -> processes both batches
- # Batch 1: [1.0, 2.0], Batch 2: [3.0, 4.0]
- # Result: (1.0+2.0+3.0+4.0)/4 = 2.5
+ # Batch 1: [1.0, 1.0], Batch 2: [1.0, 1.0]
+ # Result: count=4, sum=4.0
# Group 2 (id=2): 2 values in 1 batch -> processes 1 batch (only 1 batch available)
- # Batch 1: [5.0, 6.0]
- # Result: (5.0+6.0)/2 = 5.5
+ # Batch 1: [1.0, 1.0]
+ # Result: count=2, sum=2.0
actual = result.collect()
self.assertEqual(len(actual), 2, "Should have results for both groups")
# Verify both groups were processed correctly
# Group 1: processes 2 batches (all available)
group1_result = next(row for row in actual if row["id"] == 1)
+ self.assertEqual(
+ group1_result["result"]["count"],
+ 4,
+ msg="Group 1 should process 4 values (2 batches)",
+ )
self.assertAlmostEqual(
- group1_result["mean"], 2.5, places=5, msg="Group 1 should process 2 batches"
+ group1_result["result"]["sum"], 4.0, places=5, msg="Group 1 should sum to 4.0"
)
# Group 2: processes 1 batch (only batch available)
group2_result = next(row for row in actual if row["id"] == 2)
+ self.assertEqual(
+ group2_result["result"]["count"],
+ 2,
+ msg="Group 2 should process 2 values (1 batch)",
+ )
self.assertAlmostEqual(
- group2_result["mean"], 5.5, places=5, msg="Group 2 should process 1 batch"
+ group2_result["result"]["sum"], 2.0, places=5, msg="Group 2 should sum to 2.0"
)
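For readers of the partial-consumption test above, a hedged sketch of the same pattern outside the test harness: an iterator-based grouped-aggregate arrow_udf that returns a struct, so the result does not depend on how many batches are consumed. The arrow_udf import path mirrors the test module, and the SparkSession setup is an assumption of this sketch:

from typing import Iterator

import pyarrow as pa
from pyspark.sql import SparkSession
from pyspark.sql.functions import arrow_udf  # assumed import path, as used in the test module

spark = SparkSession.builder.getOrCreate()

@arrow_udf("struct<count: long, sum: double>")
def count_and_sum(it: Iterator[pa.Array]) -> dict:
    total, count = 0.0, 0
    for batch in it:  # batches arrive per group; their order is not guaranteed
        total += pa.compute.sum(batch).as_py()
        count += len(batch)
    return {"count": count, "sum": total}

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))
df.groupBy("id").agg(count_and_sum(df["v"]).alias("stats")).show()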
diff --git a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
index 0f879087e7bb..46ba9e0e2e51 100644
--- a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
@@ -24,7 +24,7 @@
class ArrowPythonUDFParityTests(UDFParityTests, ArrowPythonUDFTestsMixin):
@classmethod
def setUpClass(cls):
- super(ArrowPythonUDFParityTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
@classmethod
@@ -32,7 +32,7 @@ def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
- super(ArrowPythonUDFParityTests, cls).tearDownClass()
+ super().tearDownClass()
class ArrowPythonUDFParityLegacyTestsMixin(ArrowPythonUDFTestsMixin):
@@ -84,7 +84,7 @@ def test_udf_binary_type_in_nested_structures(self):
class ArrowPythonUDFParityLegacyTests(UDFParityTests, ArrowPythonUDFParityLegacyTestsMixin):
@classmethod
def setUpClass(cls):
- super(ArrowPythonUDFParityLegacyTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
@classmethod
@@ -92,13 +92,13 @@ def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
- super(ArrowPythonUDFParityLegacyTests, cls).tearDownClass()
+ super().tearDownClass()
class ArrowPythonUDFParityNonLegacyTests(UDFParityTests, ArrowPythonUDFParityNonLegacyTestsMixin):
@classmethod
def setUpClass(cls):
- super(ArrowPythonUDFParityNonLegacyTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
@classmethod
@@ -106,7 +106,7 @@ def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
- super(ArrowPythonUDFParityNonLegacyTests, cls).tearDownClass()
+ super().tearDownClass()
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index 22865be4f42a..528fc10bbb28 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -203,7 +203,7 @@ def root(cls):
@classmethod
def setUpClass(cls):
- super(ArtifactTests, cls).setUpClass()
+ super().setUpClass()
cls.artifact_manager: ArtifactManager = cls.spark._client._artifact_manager
cls.base_resource_dir = os.path.join(SPARK_HOME, "data")
cls.artifact_file_path = os.path.join(
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index b789d7919c94..08e912a446e3 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -65,7 +65,7 @@ class SparkConnectSQLTestCase(ReusedMixedTestCase, PandasOnSparkTestUtils):
@classmethod
def setUpClass(cls):
- super(SparkConnectSQLTestCase, cls).setUpClass()
+ super().setUpClass()
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
cls.testDataStr = [Row(key=str(i)) for i in range(100)]
@@ -88,7 +88,7 @@ def tearDownClass(cls):
try:
cls.spark_connect_clean_up_test_data()
finally:
- super(SparkConnectSQLTestCase, cls).tearDownClass()
+ super().tearDownClass()
@classmethod
def spark_connect_load_test_data(cls):
@@ -1480,11 +1480,11 @@ class SparkConnectGCTests(SparkConnectSQLTestCase):
def setUpClass(cls):
cls.origin = os.getenv("USER", None)
os.environ["USER"] = "SparkConnectGCTests"
- super(SparkConnectGCTests, cls).setUpClass()
+ super().setUpClass()
@classmethod
def tearDownClass(cls):
- super(SparkConnectGCTests, cls).tearDownClass()
+ super().tearDownClass()
if cls.origin is not None:
os.environ["USER"] = cls.origin
else:
diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py
index 208f9ae53898..d3dfbc221041 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udtf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -41,7 +41,7 @@
class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase):
@classmethod
def setUpClass(cls):
- super(UDTFParityTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "false")
@classmethod
@@ -49,7 +49,7 @@ def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
finally:
- super(UDTFParityTests, cls).tearDownClass()
+ super().tearDownClass()
def test_struct_output_type_casting_row(self):
self.check_struct_output_type_casting_row(PickleException)
@@ -94,7 +94,7 @@ def _add_file(self, path):
class LegacyArrowUDTFParityTests(LegacyUDTFArrowTestsMixin, UDTFParityTests):
@classmethod
def setUpClass(cls):
- super(LegacyArrowUDTFParityTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")
cls.spark.conf.set(
"spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "true"
@@ -106,7 +106,7 @@ def tearDownClass(cls):
cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled")
finally:
- super(LegacyArrowUDTFParityTests, cls).tearDownClass()
+ super().tearDownClass()
def test_udtf_access_spark_session_connect(self):
df = self.spark.range(10)
@@ -128,7 +128,7 @@ def eval(self):
class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests):
@classmethod
def setUpClass(cls):
- super(ArrowUDTFParityTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")
cls.spark.conf.set(
"spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "false"
@@ -140,7 +140,7 @@ def tearDownClass(cls):
cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled")
finally:
- super(ArrowUDTFParityTests, cls).tearDownClass()
+ super().tearDownClass()
def test_udtf_access_spark_session_connect(self):
df = self.spark.range(10)
diff --git a/python/pyspark/sql/tests/test_resources.py b/python/pyspark/sql/tests/test_resources.py
index 4ce61e9f763d..9b99672077fe 100644
--- a/python/pyspark/sql/tests/test_resources.py
+++ b/python/pyspark/sql/tests/test_resources.py
@@ -90,7 +90,7 @@ def setUpClass(cls):
@classmethod
def tearDownClass(cls):
- super(ResourceProfileTests, cls).tearDownClass()
+ super().tearDownClass()
cls.spark.stop()
diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py
index e6a7bf40b945..4d565ecfd939 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -138,6 +138,23 @@ def iter_to_iter(iter: Iterator[pa.Array]) -> Iterator[pa.Array]:
self.spark.range(10).select(iter_to_iter("id")).collect()
+ def exec_arrow_udf_grouped_agg_iter(self):
+ import pyarrow as pa
+
+ @arrow_udf("double")
+ def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
+ sum_val = 0.0
+ cnt = 0
+ for v in it:
+ sum_val += pa.compute.sum(v).as_py()
+ cnt += len(v)
+ return sum_val / cnt if cnt > 0 else 0.0
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+ df.groupby("id").agg(arrow_mean_iter(df["v"])).collect()
+
# Unsupported
def exec_map(self):
import pandas as pd
@@ -169,6 +186,15 @@ def test_unsupported(self):
"Profiling UDFs with iterators input/output is not supported" in str(user_warns[0])
)
+ with warnings.catch_warnings(record=True) as warns:
+ warnings.simplefilter("always")
+ self.exec_arrow_udf_grouped_agg_iter()
+ user_warns = [warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+ self.assertTrue(len(user_warns) > 0)
+ self.assertTrue(
+ "Profiling UDFs with iterators input/output is not supported" in str(user_warns[0])
+ )
+
with warnings.catch_warnings(record=True) as warns:
warnings.simplefilter("always")
self.exec_map()
@@ -486,6 +512,31 @@ def min_udf(v: pa.Array) -> float:
for id in self.profile_results:
self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
+ @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
+ def test_perf_profiler_arrow_udf_grouped_agg_iter(self):
+ import pyarrow as pa
+ from typing import Iterator
+
+ @arrow_udf("double")
+ def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
+ sum_val = 0.0
+ cnt = 0
+ for v in it:
+ sum_val += pa.compute.sum(v).as_py()
+ cnt += len(v)
+ return sum_val / cnt if cnt > 0 else 0.0
+
+ with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+ df.groupBy(df.id).agg(arrow_mean_iter(df["v"])).show()
+
+ self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
+
+ for id in self.profile_results:
+ self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
+
@unittest.skipIf(
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
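The additions above exercise grouped-aggregate Arrow UDFs that consume an Iterator[pa.Array], both through the warning path in exec_arrow_udf_grouped_agg_iter and under the perf profiler conf. A minimal standalone sketch of the same pattern, assuming a running SparkSession named spark and that arrow_udf is importable as in this test module:

from typing import Iterator
import warnings

import pyarrow as pa
from pyspark.sql.functions import arrow_udf  # assumed import path, as used elsewhere in these tests

@arrow_udf("double")
def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
    # Accumulate a running sum and row count over the incoming Arrow batches.
    total, cnt = 0.0, 0
    for batch in it:
        total += pa.compute.sum(batch).as_py()
        cnt += len(batch)
    return total / cnt if cnt > 0 else 0.0

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ("id", "v"))

# Under the "perf" profiler the aggregation still runs; the iterator-based form
# is the case covered by the new tests, including the UserWarning asserted above.
spark.conf.set("spark.sql.pyspark.udf.profiler", "perf")
with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    df.groupBy("id").agg(arrow_mean_iter(df["v"])).show()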
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 389df5b5a6cf..eeb07400b060 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -812,6 +812,111 @@ def _check_result_or_exception(
with self.assertRaisesRegex(err_type, expected):
func().collect()
+ def test_udtf_nullable_check(self):
+ for ret_type, value, expected in [
+ (
+ StructType([StructField("value", ArrayType(IntegerType(), False))]),
+ ([None],),
+ "PySparkRuntimeError",
+ ),
+ (
+ StructType([StructField("value", ArrayType(IntegerType(), True))]),
+ ([None],),
+ [Row(value=[None])],
+ ),
+ (
+ StructType([StructField("value", MapType(StringType(), IntegerType(), False))]),
+ ({"a": None},),
+ "PySparkRuntimeError",
+ ),
+ (
+ StructType([StructField("value", MapType(StringType(), IntegerType(), True))]),
+ ({"a": None},),
+ [Row(value={"a": None})],
+ ),
+ (
+ StructType([StructField("value", MapType(StringType(), IntegerType(), True))]),
+ ({None: 1},),
+ "PySparkRuntimeError",
+ ),
+ (
+ StructType([StructField("value", MapType(StringType(), IntegerType(), False))]),
+ ({None: 1},),
+ "PySparkRuntimeError",
+ ),
+ (
+ StructType(
+ [
+ StructField(
+ "value", MapType(StringType(), ArrayType(IntegerType(), False), False)
+ )
+ ]
+ ),
+ ({"s": [None]},),
+ "PySparkRuntimeError",
+ ),
+ (
+ StructType(
+ [
+ StructField(
+ "value",
+ MapType(
+ StructType([StructField("value", StringType(), False)]),
+ IntegerType(),
+ False,
+ ),
+ )
+ ]
+ ),
+ ({(None,): 1},),
+ "PySparkRuntimeError",
+ ),
+ (
+ StructType(
+ [
+ StructField(
+ "value",
+ MapType(
+ StructType([StructField("value", StringType(), False)]),
+ IntegerType(),
+ True,
+ ),
+ )
+ ]
+ ),
+ ({(None,): 1},),
+ "PySparkRuntimeError",
+ ),
+ (
+ StructType(
+ [StructField("value", StructType([StructField("value", StringType(), False)]))]
+ ),
+ ((None,),),
+ "PySparkRuntimeError",
+ ),
+ (
+ StructType(
+ [
+ StructField(
+ "value",
+ StructType(
+ [StructField("value", ArrayType(StringType(), False), False)]
+ ),
+ )
+ ]
+ ),
+ (([None],),),
+ "PySparkRuntimeError",
+ ),
+ ]:
+
+ class TestUDTF:
+ def eval(self):
+ yield value
+
+ with self.subTest(ret_type=ret_type, value=value):
+ self._check_result_or_exception(TestUDTF, ret_type, expected)
+
def test_numeric_output_type_casting(self):
class TestUDTF:
def eval(self):
@@ -3183,7 +3288,7 @@ def eval(self, x: int):
class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
- super(UDTFTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "false")
@classmethod
@@ -3191,7 +3296,7 @@ def tearDownClass(cls):
try:
cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
finally:
- super(UDTFTests, cls).tearDownClass()
+ super().tearDownClass()
@unittest.skipIf(
@@ -3521,7 +3626,7 @@ def eval(self):
class LegacyUDTFArrowTests(LegacyUDTFArrowTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
- super(LegacyUDTFArrowTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")
cls.spark.conf.set(
"spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "true"
@@ -3533,7 +3638,7 @@ def tearDownClass(cls):
cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled")
finally:
- super(LegacyUDTFArrowTests, cls).tearDownClass()
+ super().tearDownClass()
class UDTFArrowTestsMixin(LegacyUDTFArrowTestsMixin):
@@ -3780,7 +3885,7 @@ def eval(self, v: float):
class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
- super(UDTFArrowTests, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")
cls.spark.conf.set(
"spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "false"
@@ -3792,7 +3897,7 @@ def tearDownClass(cls):
cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled")
finally:
- super(UDTFArrowTests, cls).tearDownClass()
+ super().tearDownClass()
if __name__ == "__main__":
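test_udtf_nullable_check above feeds a single eval() result through a matrix of non-nullable array, map, and struct return schemas via _check_result_or_exception. A minimal standalone sketch of one failing case, assuming a running SparkSession and the pyspark.sql.functions.udtf decorator (the test registers the class through its own helper instead):

from pyspark.sql.functions import udtf
from pyspark.sql.types import ArrayType, IntegerType, StructField, StructType

# The return schema declares a non-nullable array element (containsNull=False),
# but eval() yields [None]; collecting the result is expected to raise
# PySparkRuntimeError, matching the first case in test_udtf_nullable_check.
schema = StructType([StructField("value", ArrayType(IntegerType(), False))])

@udtf(returnType=schema)
class NullInNonNullableArray:
    def eval(self):
        yield ([None],)

NullInNonNullableArray().collect()  # expected: PySparkRuntimeError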
diff --git a/python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py b/python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py
index c25f35814b43..99127c8bbe90 100644
--- a/python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py
+++ b/python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py
@@ -70,10 +70,10 @@
class UDFInputTypeTests(ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
- super(UDFInputTypeTests, cls).setUpClass()
+ super().setUpClass()
def setUp(self):
- super(UDFInputTypeTests, self).setUp()
+ super().setUp()
def test_udf_input_types_arrow_disabled(self):
golden_file = os.path.join(
diff --git a/python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py b/python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py
index 5e307110559b..9b69d6b20ba2 100644
--- a/python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py
+++ b/python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py
@@ -73,10 +73,10 @@
class UDFReturnTypeTests(ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
- super(UDFReturnTypeTests, cls).setUpClass()
+ super().setUpClass()
def setUp(self):
- super(UDFReturnTypeTests, self).setUp()
+ super().setUp()
self.test_data = [
None,
True,
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 675d25f23aea..4c36a2db5323 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -242,9 +242,7 @@ class DataTypeSingleton(type):
def __call__(cls: Type[T]) -> T:
if cls not in cls._instances: # type: ignore[attr-defined]
- cls._instances[cls] = super( # type: ignore[misc, attr-defined]
- DataTypeSingleton, cls
- ).__call__()
+ cls._instances[cls] = super().__call__() # type: ignore[misc, attr-defined]
return cls._instances[cls] # type: ignore[attr-defined]
@@ -3624,7 +3622,7 @@ def __contains__(self, item: Any) -> bool:
if hasattr(self, "__fields__"):
return item in self.__fields__
else:
- return super(Row, self).__contains__(item)
+ return super().__contains__(item)
# let object act like a class
def __call__(self, *args: Any) -> "Row":
@@ -3642,12 +3640,12 @@ def __call__(self, *args: Any) -> "Row":
def __getitem__(self, item: Any) -> Any:
if isinstance(item, (int, slice)):
- return super(Row, self).__getitem__(item)
+ return super().__getitem__(item)
try:
# it will be slow when it has many fields,
# but this will not be used in normal cases
idx = self.__fields__.index(item)
- return super(Row, self).__getitem__(idx)
+ return super().__getitem__(idx)
except IndexError:
raise PySparkKeyError(errorClass="KEY_NOT_EXISTS", messageParameters={"key": str(item)})
except ValueError:
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index c7471d19f7d6..75bcb66efdb8 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -438,6 +438,7 @@ def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
+ PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
]:
warnings.warn(
"Profiling UDFs with iterators input/output is not supported.",
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index ee86f5c03974..6e3c282a41ee 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -234,7 +234,7 @@ def quiet(self):
class ReusedMixedTestCase(ReusedConnectTestCase, SQLTestUtils):
@classmethod
def setUpClass(cls):
- super(ReusedMixedTestCase, cls).setUpClass()
+ super().setUpClass()
# Disable the shared namespace so pyspark.sql.functions, etc. point to the regular
# PySpark libraries.
os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
@@ -250,7 +250,7 @@ def tearDownClass(cls):
cls.spark = cls.connect
del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
finally:
- super(ReusedMixedTestCase, cls).tearDownClass()
+ super().tearDownClass()
def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
from pyspark.sql.classic.dataframe import DataFrame as SDF
diff --git a/python/pyspark/testing/mlutils.py b/python/pyspark/testing/mlutils.py
index aa3e23bccb19..491c539acc0e 100644
--- a/python/pyspark/testing/mlutils.py
+++ b/python/pyspark/testing/mlutils.py
@@ -110,7 +110,7 @@ def __init__(self):
class HasFake(Params):
def __init__(self):
- super(HasFake, self).__init__()
+ super().__init__()
self.fake = Param(self, "fake", "fake param")
def getFake(self):
@@ -119,7 +119,7 @@ def getFake(self):
class MockTransformer(Transformer, HasFake):
def __init__(self):
- super(MockTransformer, self).__init__()
+ super().__init__()
self.dataset_index = None
def _transform(self, dataset):
@@ -137,7 +137,7 @@ class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParam
)
def __init__(self, shiftVal=1):
- super(MockUnaryTransformer, self).__init__()
+ super().__init__()
self._setDefault(shift=1)
self._set(shift=shiftVal)
@@ -161,7 +161,7 @@ def validateInputType(self, inputType):
class MockEstimator(Estimator, HasFake):
def __init__(self):
- super(MockEstimator, self).__init__()
+ super().__init__()
self.dataset_index = None
def _fit(self, dataset):
@@ -198,7 +198,7 @@ def __init__(
regParam=0.0,
rawPredictionCol="rawPrediction",
):
- super(DummyLogisticRegression, self).__init__()
+ super().__init__()
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -229,7 +229,7 @@ class DummyLogisticRegressionModel(
DefaultParamsWritable,
):
def __init__(self):
- super(DummyLogisticRegressionModel, self).__init__()
+ super().__init__()
def _transform(self, dataset):
# A dummy transform impl which always predicts label 1
diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py
index e53240586d59..11b741947604 100644
--- a/python/pyspark/testing/pandasutils.py
+++ b/python/pyspark/testing/pandasutils.py
@@ -473,7 +473,7 @@ def _to_pandas(obj: Any):
class PandasOnSparkTestCase(ReusedSQLTestCase, PandasOnSparkTestUtils):
@classmethod
def setUpClass(cls):
- super(PandasOnSparkTestCase, cls).setUpClass()
+ super().setUpClass()
cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True)
def setUp(self):
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index 22f75bb931b1..b63c98f96f4e 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -208,7 +208,7 @@ def assert_close(a, b):
class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils, PySparkErrorTestUtils):
@classmethod
def setUpClass(cls):
- super(ReusedSQLTestCase, cls).setUpClass()
+ super().setUpClass()
cls._legacy_sc = cls.sc
cls.spark = SparkSession(cls.sc)
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
@@ -218,7 +218,7 @@ def setUpClass(cls):
@classmethod
def tearDownClass(cls):
- super(ReusedSQLTestCase, cls).tearDownClass()
+ super().tearDownClass()
cls.spark.stop()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py
index 3961997120ed..eb50a105cc77 100644
--- a/python/pyspark/tests/test_worker.py
+++ b/python/pyspark/tests/test_worker.py
@@ -227,7 +227,7 @@ def tearDown(self):
class WorkerSegfaultTest(ReusedPySparkTestCase):
@classmethod
def conf(cls):
- _conf = super(WorkerSegfaultTest, cls).conf()
+ _conf = super().conf()
_conf.set("spark.python.worker.faulthandler.enabled", "true")
return _conf
@@ -252,7 +252,7 @@ def f():
class WorkerSegfaultNonDaemonTest(WorkerSegfaultTest):
@classmethod
def conf(cls):
- _conf = super(WorkerSegfaultNonDaemonTest, cls).conf()
+ _conf = super().conf()
_conf.set("spark.python.use.daemon", "false")
return _conf
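These conf() overrides, like the other Python changes in this patch, drop the explicit super(Class, cls) / super(Class, self) spelling in favour of the zero-argument super() that Python 3 resolves from the enclosing class. A minimal sketch of the equivalence, using hypothetical classes:

class Base:
    @classmethod
    def conf(cls):
        return {}

class LegacyStyle(Base):
    @classmethod
    def conf(cls):
        # Python 2 compatible spelling; still valid, but redundant on Python 3.
        return super(LegacyStyle, cls).conf()

class ModernStyle(Base):
    @classmethod
    def conf(cls):
        # Zero-argument form; __class__ and cls are supplied implicitly.
        return super().conf()

assert LegacyStyle.conf() == ModernStyle.conf() == {}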
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 1cf913287d56..f8750fbbec2e 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -564,9 +564,7 @@ def copy_local_properties(*a: Any, **k: Any) -> Any:
thread_local.tags = self._tags # type: ignore[has-type]
return target(*a, **k)
- super(InheritableThread, self).__init__(
- target=copy_local_properties, *args, **kwargs # type: ignore[misc]
- )
+ super().__init__(target=copy_local_properties, *args, **kwargs) # type: ignore[misc]
else:
# Non Spark Connect
from pyspark import SparkContext
@@ -585,13 +583,11 @@ def copy_local_properties(*a: Any, **k: Any) -> Any:
SparkContext._active_spark_context._jsc.sc().setLocalProperties(self._props)
return target(*a, **k)
- super(InheritableThread, self).__init__(
+ super().__init__(
target=copy_local_properties, *args, **kwargs # type: ignore[misc]
)
else:
- super(InheritableThread, self).__init__(
- target=target, *args, **kwargs # type: ignore[misc]
- )
+ super().__init__(target=target, *args, **kwargs) # type: ignore[misc]
def start(self) -> None:
from pyspark.sql import is_remote
@@ -619,7 +615,7 @@ def start(self) -> None:
if self._session is not None:
self._tags = self._session.getTags()
- return super(InheritableThread, self).start()
+ return super().start()
class PythonEvalType:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 109157e2c339..65dcbbbf23e6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1514,10 +1514,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
# It expects the UDTF to be in a specific format and performs various checks to
# ensure the UDTF is valid. This function also prepares a mapper function for applying
# the UDTF logic to input rows.
-def read_udtf(pickleSer, infile, eval_type):
+def read_udtf(pickleSer, infile, eval_type, runner_conf):
if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
- # Load conf used for arrow evaluation.
- runner_conf = RunnerConf(infile)
input_types = [
field.dataType for field in _parse_datatype_json_string(utf8_deserializer.loads(infile))
]
@@ -1532,7 +1530,6 @@ def read_udtf(pickleSer, infile, eval_type):
else:
ser = ArrowStreamUDTFSerializer()
elif eval_type == PythonEvalType.SQL_ARROW_UDTF:
- runner_conf = RunnerConf(infile)
# Read the table argument offsets
num_table_arg_offsets = read_int(infile)
table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)]
@@ -1540,7 +1537,6 @@ def read_udtf(pickleSer, infile, eval_type):
ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets)
else:
# Each row is a group so do not batch but send one by one.
- runner_conf = RunnerConf()
ser = BatchedSerializer(CPickleSerializer(), 1)
# See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
@@ -2688,7 +2684,7 @@ def mapper(_, it):
return mapper, None, ser, ser
-def read_udfs(pickleSer, infile, eval_type):
+def read_udfs(pickleSer, infile, eval_type, runner_conf):
state_server_port = None
key_schema = None
if eval_type in (
@@ -2716,9 +2712,6 @@ def read_udfs(pickleSer, infile, eval_type):
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
):
- # Load conf used for pandas_udf evaluation
- runner_conf = RunnerConf(infile)
-
state_object_schema = None
if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
@@ -2870,7 +2863,6 @@ def read_udfs(pickleSer, infile, eval_type):
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
)
else:
- runner_conf = RunnerConf()
batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100"))
ser = BatchedSerializer(CPickleSerializer(), batch_size)
@@ -3353,6 +3345,7 @@ def main(infile, outfile):
_accumulatorRegistry.clear()
eval_type = read_int(infile)
+ runner_conf = RunnerConf(infile)
if eval_type == PythonEvalType.NON_UDF:
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
elif eval_type in (
@@ -3360,9 +3353,13 @@ def main(infile, outfile):
PythonEvalType.SQL_ARROW_TABLE_UDF,
PythonEvalType.SQL_ARROW_UDTF,
):
- func, profiler, deserializer, serializer = read_udtf(pickleSer, infile, eval_type)
+ func, profiler, deserializer, serializer = read_udtf(
+ pickleSer, infile, eval_type, runner_conf
+ )
else:
- func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
+ func, profiler, deserializer, serializer = read_udfs(
+ pickleSer, infile, eval_type, runner_conf
+ )
init_time = time.time()
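With this change the worker reads the runner configuration exactly once in main(), right after the eval type, and threads it into read_udtf and read_udfs instead of each branch constructing its own RunnerConf (or an empty one). A minimal sketch of the resulting read order, assuming RunnerConf consumes the conf entries from infile as before:

# Sketch of the dispatch in main() after this patch (see the hunks above).
eval_type = read_int(infile)        # written first by the JVM-side runner
runner_conf = RunnerConf(infile)    # now read once, for every eval type
if eval_type == PythonEvalType.NON_UDF:
    func, profiler, deserializer, serializer = read_command(pickleSer, infile)
elif eval_type in (
    PythonEvalType.SQL_ARROW_TABLE_UDF,
    PythonEvalType.SQL_ARROW_UDTF,
    # ... plus the remaining UDTF eval type listed in the hunk above
):
    func, profiler, deserializer, serializer = read_udtf(pickleSer, infile, eval_type, runner_conf)
else:
    func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type, runner_conf)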
diff --git a/python/test_support/test_pytorch_training_file.py b/python/test_support/test_pytorch_training_file.py
index 4107197acfd8..150246563f09 100644
--- a/python/test_support/test_pytorch_training_file.py
+++ b/python/test_support/test_pytorch_training_file.py
@@ -32,7 +32,7 @@
class Net(nn.Module):
def __init__(self):
- super(Net, self).__init__()
+ super().__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/ClusteredDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/ClusteredDistribution.java
index dcc3d191461c..0fa77d259e9d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/ClusteredDistribution.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/ClusteredDistribution.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.connector.distributions;
-import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
/**
@@ -26,7 +26,7 @@
*
* @since 3.2.0
*/
-@Experimental
+@Evolving
public interface ClusteredDistribution extends Distribution {
/**
* Returns clustering expressions.
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distribution.java
index 95d68ea2d1ab..0a2f982fce07 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distribution.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distribution.java
@@ -17,12 +17,12 @@
package org.apache.spark.sql.connector.distributions;
-import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.Evolving;
/**
* An interface that defines how data is distributed across partitions.
*
* @since 3.2.0
*/
-@Experimental
+@Evolving
public interface Distribution {}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distributions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distributions.java
index da5d6f8c81a3..6a346a25424f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distributions.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distributions.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.connector.distributions;
-import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.SortOrder;
@@ -26,7 +26,7 @@
*
* @since 3.2.0
*/
-@Experimental
+@Evolving
public class Distributions {
private Distributions() {
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/OrderedDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/OrderedDistribution.java
index 3456178d8e64..f959cc2e00ce 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/OrderedDistribution.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/OrderedDistribution.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.connector.distributions;
-import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.SortOrder;
/**
@@ -26,7 +26,7 @@
*
* @since 3.2.0
*/
-@Experimental
+@Evolving
public interface OrderedDistribution extends Distribution {
/**
* Returns ordering expressions.
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/UnspecifiedDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/UnspecifiedDistribution.java
index ea18d8906cfd..4749701e348e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/UnspecifiedDistribution.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/UnspecifiedDistribution.java
@@ -17,12 +17,12 @@
package org.apache.spark.sql.connector.distributions;
-import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.Evolving;
/**
* A distribution where no promises are made about co-location of data.
*
* @since 3.2.0
*/
-@Experimental
+@Evolving
public interface UnspecifiedDistribution extends Distribution {}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NullOrdering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NullOrdering.java
index 4aca199c11c0..7d457a62b313 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NullOrdering.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NullOrdering.java
@@ -17,14 +17,14 @@
package org.apache.spark.sql.connector.expressions;
-import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.Evolving;
/**
* A null order used in sorting expressions.
*
* @since 3.2.0
*/
-@Experimental
+@Evolving
public enum NullOrdering {
NULLS_FIRST, NULLS_LAST;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortDirection.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortDirection.java
index 7e3a29945cc9..385154e0fd83 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortDirection.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortDirection.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.connector.expressions;
-import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.Evolving;
import static org.apache.spark.sql.connector.expressions.NullOrdering.NULLS_FIRST;
import static org.apache.spark.sql.connector.expressions.NullOrdering.NULLS_LAST;
@@ -30,7 +30,7 @@
*
* @since 3.2.0
*/
-@Experimental
+@Evolving
public enum SortDirection {
ASCENDING(NULLS_FIRST), DESCENDING(NULLS_LAST);
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java
index 51401786ca5d..45f06e17de6d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java
@@ -17,14 +17,14 @@
package org.apache.spark.sql.connector.expressions;
-import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.Evolving;
/**
* Represents a sort order in the public expression API.
*
* @since 3.2.0
*/
-@Experimental
+@Evolving
public interface SortOrder extends Expression {
/**
* Returns the sort expression.
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java
index 2adfe75f7d80..dbef9dd6146a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.connector.write;
-import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.distributions.Distribution;
import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution;
import org.apache.spark.sql.connector.expressions.SortOrder;
@@ -27,7 +27,7 @@
*
* @since 3.2.0
*/
-@Experimental
+@Evolving
public interface RequiresDistributionAndOrdering extends Write {
/**
* Returns the distribution required by this write.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 67d25296a1e2..311c1c946fbe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1736,9 +1736,8 @@ class Analyzer(
Assignment(key, sourceAttr)
}
} else {
- sourceTable.output.flatMap { sourceAttr =>
- findAttrInTarget(sourceAttr.name).map(
- targetAttr => Assignment(targetAttr, sourceAttr))
+ targetTable.output.map { attr =>
+ Assignment(attr, UnresolvedAttribute(Seq(attr.name)))
}
}
UpdateAction(
@@ -1775,9 +1774,8 @@ class Analyzer(
Assignment(key, sourceAttr)
}
} else {
- sourceTable.output.flatMap { sourceAttr =>
- findAttrInTarget(sourceAttr.name).map(
- targetAttr => Assignment(targetAttr, sourceAttr))
+ targetTable.output.map { attr =>
+ Assignment(attr, UnresolvedAttribute(Seq(attr.name)))
}
}
InsertAction(
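The two rewritten branches above change how UPDATE SET * and INSERT * are expanded when a MERGE clause carries no explicit assignments: every target column now receives an assignment whose value is an UnresolvedAttribute of the same name, resolved later, rather than pairing only the source columns whose names happen to match a target column. A hedged example of the statements this rule serves, issued from PySpark and assuming target and source tables already exist:

# spark: an existing SparkSession; target/source: existing tables.
spark.sql("""
    MERGE INTO target t
    USING source s
    ON t.id = s.id
    WHEN MATCHED THEN UPDATE SET *      -- expands to one assignment per target column
    WHEN NOT MATCHED THEN INSERT *      -- likewise, resolved by column name
""")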
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
index d1b8eab13191..bf1016ba8268 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
@@ -53,7 +53,7 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && m.rewritable && !m.aligned &&
!m.needSchemaEvolution =>
validateStoreAssignmentPolicy()
- val coerceNestedTypes = SQLConf.get.coerceMergeNestedTypes
+ val coerceNestedTypes = SQLConf.get.coerceMergeNestedTypes && m.withSchemaEvolution
m.copy(
targetTable = cleanAttrMetadata(m.targetTable),
matchedActions = alignActions(m.targetTable.output, m.matchedActions,
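The one-line change above narrows nested-type coercion during assignment alignment to MERGE statements that explicitly opt into schema evolution, instead of applying it whenever the coerceMergeNestedTypes conf is set. A hedged example of the opt-in spelling, assuming the WITH SCHEMA EVOLUTION clause for DSv2 MERGE:

# spark: an existing SparkSession; target/source: existing tables.
spark.sql("""
    MERGE WITH SCHEMA EVOLUTION INTO target t
    USING source s
    ON t.id = s.id
    WHEN MATCHED THEN UPDATE SET *
    WHEN NOT MATCHED THEN INSERT *
""")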
diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
index 9a1db36514da..fdce056e664d 100644
--- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.connect.client.jdbc
-import java.sql.{ResultSet, SQLException, Types}
+import java.math.{BigDecimal => JBigDecimal}
+import java.nio.charset.StandardCharsets
+import java.sql.{Date, ResultSet, SQLException, Time, Timestamp, Types}
+import java.util.{Calendar, TimeZone}
import scala.util.Using
@@ -223,344 +226,359 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
}
test("get decimal type") {
- Seq(
- ("123.45", 37, 2, 39),
- ("-0.12345", 5, 5, 8),
- ("-0.12345", 6, 5, 8),
- ("-123.45", 5, 2, 7),
- ("12345", 5, 0, 6),
- ("-12345", 5, 0, 6)
- ).foreach {
- case (value, precision, scale, expectedColumnDisplaySize) =>
- val decimalType = s"DECIMAL($precision,$scale)"
- withExecuteQuery(s"SELECT cast('$value' as $decimalType)") { rs =>
- assert(rs.next())
- assert(rs.getBigDecimal(1) === new java.math.BigDecimal(value))
- assert(!rs.wasNull)
- assert(!rs.next())
-
- val metaData = rs.getMetaData
- assert(metaData.getColumnCount === 1)
- assert(metaData.getColumnName(1) === s"CAST($value AS $decimalType)")
- assert(metaData.getColumnLabel(1) === s"CAST($value AS $decimalType)")
- assert(metaData.getColumnType(1) === Types.DECIMAL)
- assert(metaData.getColumnTypeName(1) === decimalType)
- assert(metaData.getColumnClassName(1) === "java.math.BigDecimal")
- assert(metaData.isSigned(1) === true)
- assert(metaData.getPrecision(1) === precision)
- assert(metaData.getScale(1) === scale)
- assert(metaData.getColumnDisplaySize(1) === expectedColumnDisplaySize)
- assert(metaData.getColumnDisplaySize(1) >= value.size)
- }
+ withStatement { stmt =>
+ Seq(
+ ("123.45", 37, 2, 39),
+ ("-0.12345", 5, 5, 8),
+ ("-0.12345", 6, 5, 8),
+ ("-123.45", 5, 2, 7),
+ ("12345", 5, 0, 6),
+ ("-12345", 5, 0, 6)
+ ).foreach {
+ case (value, precision, scale, expectedColumnDisplaySize) =>
+ val decimalType = s"DECIMAL($precision,$scale)"
+ withExecuteQuery(stmt, s"SELECT cast('$value' as $decimalType)") { rs =>
+ assert(rs.next())
+ assert(rs.getBigDecimal(1) === new JBigDecimal(value))
+ assert(!rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === s"CAST($value AS $decimalType)")
+ assert(metaData.getColumnLabel(1) === s"CAST($value AS $decimalType)")
+ assert(metaData.getColumnType(1) === Types.DECIMAL)
+ assert(metaData.getColumnTypeName(1) === decimalType)
+ assert(metaData.getColumnClassName(1) === "java.math.BigDecimal")
+ assert(metaData.isSigned(1) === true)
+ assert(metaData.getPrecision(1) === precision)
+ assert(metaData.getScale(1) === scale)
+ assert(metaData.getColumnDisplaySize(1) === expectedColumnDisplaySize)
+ assert(metaData.getColumnDisplaySize(1) >= value.size)
+ }
+ }
}
}
test("getter functions column index out of bound") {
- Seq(
- ("'foo'", (rs: ResultSet) => rs.getString(999)),
- ("true", (rs: ResultSet) => rs.getBoolean(999)),
- ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(999)),
- ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(999)),
- ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(999)),
- ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(999)),
- ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(999)),
- ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(999)),
- ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(999)),
- ("CAST(X'0A0B0C' AS BINARY)", (rs: ResultSet) => rs.getBytes(999)),
- ("date '2025-11-15'", (rs: ResultSet) => rs.getBytes(999)),
- ("time '12:34:56.123456'", (rs: ResultSet) => rs.getBytes(999)),
- ("timestamp '2025-11-15 10:30:45.123456'", (rs: ResultSet) => rs.getTimestamp(999)),
- ("timestamp_ntz '2025-11-15 10:30:45.789012'", (rs: ResultSet) => rs.getTimestamp(999))
- ).foreach {
- case (query, getter) =>
- withExecuteQuery(s"SELECT $query") { rs =>
- assert(rs.next())
- val exception = intercept[SQLException] {
- getter(rs)
+ withStatement { stmt =>
+ Seq(
+ ("'foo'", (rs: ResultSet) => rs.getString(999)),
+ ("true", (rs: ResultSet) => rs.getBoolean(999)),
+ ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(999)),
+ ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(999)),
+ ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(999)),
+ ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(999)),
+ ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(999)),
+ ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(999)),
+ ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(999)),
+ ("CAST(X'0A0B0C' AS BINARY)", (rs: ResultSet) => rs.getBytes(999)),
+ ("date '2025-11-15'", (rs: ResultSet) => rs.getBytes(999)),
+ ("time '12:34:56.123456'", (rs: ResultSet) => rs.getBytes(999)),
+ ("timestamp '2025-11-15 10:30:45.123456'", (rs: ResultSet) => rs.getTimestamp(999)),
+ ("timestamp_ntz '2025-11-15 10:30:45.789012'", (rs: ResultSet) => rs.getTimestamp(999))
+ ).foreach {
+ case (query, getter) =>
+ withExecuteQuery(stmt, s"SELECT $query") { rs =>
+ assert(rs.next())
+ val exception = intercept[SQLException] {
+ getter(rs)
+ }
+ assert(exception.getMessage() ===
+ "The column index is out of range: 999, number of columns: 1.")
}
- assert(exception.getMessage() ===
- "The column index is out of range: 999, number of columns: 1.")
- }
+ }
}
}
test("getter functions called after statement closed") {
- Seq(
- ("'foo'", (rs: ResultSet) => rs.getString(1), "foo"),
- ("true", (rs: ResultSet) => rs.getBoolean(1), true),
- ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(1), 1.toByte),
- ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(1), 1.toShort),
- ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(1), 1.toInt),
- ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(1), 1.toLong),
- ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(1), 1.toFloat),
- ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(1), 1.toDouble),
- ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(1),
- new java.math.BigDecimal("1.00000")),
- ("CAST(X'0A0B0C' AS BINARY)", (rs: ResultSet) => rs.getBytes(1),
- Array[Byte](0x0A, 0x0B, 0x0C)),
- ("date '2023-11-15'", (rs: ResultSet) => rs.getDate(1),
- java.sql.Date.valueOf("2023-11-15")),
- ("time '12:34:56.123456'", (rs: ResultSet) => rs.getTime(1), {
- val millis = timeToMillis(12, 34, 56, 123)
- new java.sql.Time(millis)
- })
- ).foreach {
- case (query, getter, expectedValue) =>
- var resultSet: Option[ResultSet] = None
- withExecuteQuery(s"SELECT $query") { rs =>
- assert(rs.next())
- expectedValue match {
- case arr: Array[Byte] => assert(getter(rs).asInstanceOf[Array[Byte]].sameElements(arr))
- case other => assert(getter(rs) === other)
+ withStatement { stmt =>
+ Seq(
+ ("'foo'", (rs: ResultSet) => rs.getString(1), "foo"),
+ ("true", (rs: ResultSet) => rs.getBoolean(1), true),
+ ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(1), 1.toByte),
+ ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(1), 1.toShort),
+ ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(1), 1.toInt),
+ ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(1), 1.toLong),
+ ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(1), 1.toFloat),
+ ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(1), 1.toDouble),
+ ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(1),
+ new JBigDecimal("1.00000")),
+ ("CAST(X'0A0B0C' AS BINARY)", (rs: ResultSet) => rs.getBytes(1),
+ Array[Byte](0x0A, 0x0B, 0x0C)),
+ ("date '2023-11-15'", (rs: ResultSet) => rs.getDate(1),
+ Date.valueOf("2023-11-15")),
+ ("time '12:34:56.123456'", (rs: ResultSet) => rs.getTime(1), {
+ val millis = timeToMillis(12, 34, 56, 123)
+ new Time(millis)
+ })
+ ).foreach {
+ case (query, getter, expectedValue) =>
+ var resultSet: Option[ResultSet] = None
+ withExecuteQuery(stmt, s"SELECT $query") { rs =>
+ assert(rs.next())
+ expectedValue match {
+ case arr: Array[Byte] =>
+ assert(getter(rs).asInstanceOf[Array[Byte]].sameElements(arr))
+ case other => assert(getter(rs) === other)
+ }
+ assert(!rs.wasNull)
+ resultSet = Some(rs)
}
- assert(!rs.wasNull)
- resultSet = Some(rs)
- }
- assert(resultSet.isDefined)
- val exception = intercept[SQLException] {
- getter(resultSet.get)
- }
- assert(exception.getMessage() === "JDBC Statement is closed.")
+ assert(resultSet.isDefined)
+ val exception = intercept[SQLException] {
+ getter(resultSet.get)
+ }
+ assert(exception.getMessage() === "JDBC Statement is closed.")
+ }
}
}
test("get date type") {
- withExecuteQuery("SELECT date '2023-11-15'") { rs =>
- assert(rs.next())
- assert(rs.getDate(1) === java.sql.Date.valueOf("2023-11-15"))
- assert(!rs.wasNull)
- assert(!rs.next())
-
- val metaData = rs.getMetaData
- assert(metaData.getColumnCount === 1)
- assert(metaData.getColumnName(1) === "DATE '2023-11-15'")
- assert(metaData.getColumnLabel(1) === "DATE '2023-11-15'")
- assert(metaData.getColumnType(1) === Types.DATE)
- assert(metaData.getColumnTypeName(1) === "DATE")
- assert(metaData.getColumnClassName(1) === "java.sql.Date")
- assert(metaData.isSigned(1) === false)
- assert(metaData.getPrecision(1) === 10)
- assert(metaData.getScale(1) === 0)
- assert(metaData.getColumnDisplaySize(1) === 10)
- }
- }
-
- test("get date type with null") {
- withExecuteQuery("SELECT cast(null as date)") { rs =>
- assert(rs.next())
- assert(rs.getDate(1) === null)
- assert(rs.wasNull)
- assert(!rs.next())
+ withStatement { stmt =>
+ // Test basic date type
+ withExecuteQuery(stmt, "SELECT date '2023-11-15'") { rs =>
+ assert(rs.next())
+ assert(rs.getDate(1) === Date.valueOf("2023-11-15"))
+ assert(!rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "DATE '2023-11-15'")
+ assert(metaData.getColumnLabel(1) === "DATE '2023-11-15'")
+ assert(metaData.getColumnType(1) === Types.DATE)
+ assert(metaData.getColumnTypeName(1) === "DATE")
+ assert(metaData.getColumnClassName(1) === "java.sql.Date")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 10)
+ assert(metaData.getScale(1) === 0)
+ assert(metaData.getColumnDisplaySize(1) === 10)
+ }
- val metaData = rs.getMetaData
- assert(metaData.getColumnCount === 1)
- assert(metaData.getColumnName(1) === "CAST(NULL AS DATE)")
- assert(metaData.getColumnLabel(1) === "CAST(NULL AS DATE)")
- assert(metaData.getColumnType(1) === Types.DATE)
- assert(metaData.getColumnTypeName(1) === "DATE")
- assert(metaData.getColumnClassName(1) === "java.sql.Date")
- assert(metaData.isSigned(1) === false)
- assert(metaData.getPrecision(1) === 10)
- assert(metaData.getScale(1) === 0)
- assert(metaData.getColumnDisplaySize(1) === 10)
- }
- }
+ // Test date type with null
+ withExecuteQuery(stmt, "SELECT cast(null as date)") { rs =>
+ assert(rs.next())
+ assert(rs.getDate(1) === null)
+ assert(rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "CAST(NULL AS DATE)")
+ assert(metaData.getColumnLabel(1) === "CAST(NULL AS DATE)")
+ assert(metaData.getColumnType(1) === Types.DATE)
+ assert(metaData.getColumnTypeName(1) === "DATE")
+ assert(metaData.getColumnClassName(1) === "java.sql.Date")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 10)
+ assert(metaData.getScale(1) === 0)
+ assert(metaData.getColumnDisplaySize(1) === 10)
+ }
- test("get date type by column label") {
- withExecuteQuery("SELECT date '2025-11-15' as test_date") { rs =>
- assert(rs.next())
- assert(rs.getDate("test_date") === java.sql.Date.valueOf("2025-11-15"))
- assert(!rs.wasNull)
- assert(!rs.next())
+ // Test date type by column label
+ withExecuteQuery(stmt, "SELECT date '2025-11-15' as test_date") { rs =>
+ assert(rs.next())
+ assert(rs.getDate("test_date") === Date.valueOf("2025-11-15"))
+ assert(!rs.wasNull)
+ assert(!rs.next())
+ }
}
}
test("get binary type") {
- val testBytes = Array[Byte](0x01, 0x02, 0x03, 0x04, 0x05)
- val hexString = testBytes.map(b => "%02X".format(b)).mkString
- withExecuteQuery(s"SELECT CAST(X'$hexString' AS BINARY)") { rs =>
- assert(rs.next())
- val bytes = rs.getBytes(1)
- assert(bytes !== null)
- assert(bytes.length === testBytes.length)
- assert(bytes.sameElements(testBytes))
- assert(!rs.wasNull)
+ withStatement { stmt =>
+ // Test basic binary type
+ val testBytes = Array[Byte](0x01, 0x02, 0x03, 0x04, 0x05)
+ val hexString = testBytes.map(b => "%02X".format(b)).mkString
+ withExecuteQuery(stmt, s"SELECT CAST(X'$hexString' AS BINARY)") { rs =>
+ assert(rs.next())
+ val bytes = rs.getBytes(1)
+ assert(bytes !== null)
+ assert(bytes.length === testBytes.length)
+ assert(bytes.sameElements(testBytes))
+ assert(!rs.wasNull)
val stringValue = rs.getString(1)
- val expectedString = new String(testBytes, java.nio.charset.StandardCharsets.UTF_8)
+ val expectedString = new String(testBytes, StandardCharsets.UTF_8)
assert(stringValue === expectedString)
- assert(!rs.next())
-
- val metaData = rs.getMetaData
- assert(metaData.getColumnCount === 1)
- assert(metaData.getColumnType(1) === Types.VARBINARY)
- assert(metaData.getColumnTypeName(1) === "BINARY")
- assert(metaData.getColumnClassName(1) === "[B")
- assert(metaData.isSigned(1) === false)
- }
- }
+ assert(!rs.next())
- test("get binary type with UTF-8 text") {
- val textBytes = "\\xDeAdBeEf".getBytes(java.nio.charset.StandardCharsets.UTF_8)
- val hexString = textBytes.map(b => "%02X".format(b)).mkString
- withExecuteQuery(s"SELECT CAST(X'$hexString' AS BINARY)") { rs =>
- assert(rs.next())
- val bytes = rs.getBytes(1)
- assert(bytes !== null)
- assert(bytes.sameElements(textBytes))
-
- val stringValue = rs.getString(1)
- assert(stringValue === "\\xDeAdBeEf")
-
- assert(!rs.next())
- }
- }
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnType(1) === Types.VARBINARY)
+ assert(metaData.getColumnTypeName(1) === "BINARY")
+ assert(metaData.getColumnClassName(1) === "[B")
+ assert(metaData.isSigned(1) === false)
+ }
- test("get binary type with null") {
- withExecuteQuery("SELECT cast(null as binary)") { rs =>
- assert(rs.next())
- assert(rs.getBytes(1) === null)
- assert(rs.wasNull)
- assert(!rs.next())
+ // Test binary type with UTF-8 text
+ val textBytes = "\\xDeAdBeEf".getBytes(StandardCharsets.UTF_8)
+ val hexString2 = textBytes.map(b => "%02X".format(b)).mkString
+ withExecuteQuery(stmt, s"SELECT CAST(X'$hexString2' AS BINARY)") { rs =>
+ assert(rs.next())
+ val bytes = rs.getBytes(1)
+ assert(bytes !== null)
+ assert(bytes.sameElements(textBytes))
- val metaData = rs.getMetaData
- assert(metaData.getColumnCount === 1)
- assert(metaData.getColumnType(1) === Types.VARBINARY)
- assert(metaData.getColumnTypeName(1) === "BINARY")
- assert(metaData.getColumnClassName(1) === "[B")
- }
- }
+ val stringValue = rs.getString(1)
+ assert(stringValue === "\\xDeAdBeEf")
- test("get binary type by column label") {
- val testBytes = Array[Byte](0x0A, 0x0B, 0x0C)
- val hexString = testBytes.map(b => "%02X".format(b)).mkString
- withExecuteQuery(s"SELECT CAST(X'$hexString' AS BINARY) as test_binary") { rs =>
- assert(rs.next())
- val bytes = rs.getBytes("test_binary")
- assert(bytes !== null)
- assert(bytes.length === testBytes.length)
- assert(bytes.sameElements(testBytes))
- assert(!rs.wasNull)
- assert(!rs.next())
+ assert(!rs.next())
+ }
- val metaData = rs.getMetaData
- assert(metaData.getColumnCount === 1)
- assert(metaData.getColumnName(1) === "test_binary")
- assert(metaData.getColumnLabel(1) === "test_binary")
- }
- }
+ // Test binary type with null
+ withExecuteQuery(stmt, "SELECT cast(null as binary)") { rs =>
+ assert(rs.next())
+ assert(rs.getBytes(1) === null)
+ assert(rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnType(1) === Types.VARBINARY)
+ assert(metaData.getColumnTypeName(1) === "BINARY")
+ assert(metaData.getColumnClassName(1) === "[B")
+ }
- test("get empty binary") {
- withExecuteQuery("SELECT CAST(X'' AS BINARY)") { rs =>
- assert(rs.next())
- val bytes = rs.getBytes(1)
- assert(bytes !== null)
- assert(bytes.length === 0)
- assert(!rs.wasNull)
+ // Test binary type by column label
+ val testBytes2 = Array[Byte](0x0A, 0x0B, 0x0C)
+ val hexString3 = testBytes2.map(b => "%02X".format(b)).mkString
+ withExecuteQuery(stmt, s"SELECT CAST(X'$hexString3' AS BINARY) as test_binary") { rs =>
+ assert(rs.next())
+ val bytes = rs.getBytes("test_binary")
+ assert(bytes !== null)
+ assert(bytes.length === testBytes2.length)
+ assert(bytes.sameElements(testBytes2))
+ assert(!rs.wasNull)
+
+ val stringValue = rs.getString("test_binary")
+ val expectedString = new String(testBytes2, StandardCharsets.UTF_8)
+ assert(stringValue === expectedString)
+
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "test_binary")
+ assert(metaData.getColumnLabel(1) === "test_binary")
+ }
- val stringValue = rs.getString(1)
- assert(stringValue === "")
- assert(!rs.next())
+ // Test empty binary
+ withExecuteQuery(stmt, "SELECT CAST(X'' AS BINARY)") { rs =>
+ assert(rs.next())
+ val bytes = rs.getBytes(1)
+ assert(bytes !== null)
+ assert(bytes.length === 0)
+ assert(!rs.wasNull)
+
+ val stringValue = rs.getString(1)
+ assert(stringValue === "")
+ assert(!rs.next())
+ }
}
}
test("get time type") {
- withExecuteQuery("SELECT time '12:34:56.123456'") { rs =>
- assert(rs.next())
- val time = rs.getTime(1)
- // Verify milliseconds are preserved (123 from 123456 microseconds)
- val expectedMillis = timeToMillis(12, 34, 56, 123)
- assert(time.getTime === expectedMillis)
- assert(!rs.wasNull)
- assert(!rs.next())
-
- val metaData = rs.getMetaData
- assert(metaData.getColumnCount === 1)
- assert(metaData.getColumnName(1) === "TIME '12:34:56.123456'")
- assert(metaData.getColumnLabel(1) === "TIME '12:34:56.123456'")
- assert(metaData.getColumnType(1) === Types.TIME)
- assert(metaData.getColumnTypeName(1) === "TIME(6)")
- assert(metaData.getColumnClassName(1) === "java.sql.Time")
- assert(metaData.isSigned(1) === false)
- assert(metaData.getPrecision(1) === 6)
- assert(metaData.getScale(1) === 0)
- assert(metaData.getColumnDisplaySize(1) === 15)
- }
- }
-
- test("get time type with null") {
- withExecuteQuery("SELECT cast(null as time)") { rs =>
- assert(rs.next())
- assert(rs.getTime(1) === null)
- assert(rs.wasNull)
- assert(!rs.next())
+ withStatement { stmt =>
+ // Test basic time type
+ withExecuteQuery(stmt, "SELECT time '12:34:56.123456'") { rs =>
+ assert(rs.next())
+ val time = rs.getTime(1)
+ // Verify milliseconds are preserved (123 from 123456 microseconds)
+ val expectedMillis = timeToMillis(12, 34, 56, 123)
+ assert(time.getTime === expectedMillis)
+ assert(!rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "TIME '12:34:56.123456'")
+ assert(metaData.getColumnLabel(1) === "TIME '12:34:56.123456'")
+ assert(metaData.getColumnType(1) === Types.TIME)
+ assert(metaData.getColumnTypeName(1) === "TIME(6)")
+ assert(metaData.getColumnClassName(1) === "java.sql.Time")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 6)
+ assert(metaData.getScale(1) === 0)
+ assert(metaData.getColumnDisplaySize(1) === 15)
+ }
- val metaData = rs.getMetaData
- assert(metaData.getColumnCount === 1)
- assert(metaData.getColumnName(1) === "CAST(NULL AS TIME(6))")
- assert(metaData.getColumnLabel(1) === "CAST(NULL AS TIME(6))")
- assert(metaData.getColumnType(1) === Types.TIME)
- assert(metaData.getColumnTypeName(1) === "TIME(6)")
- assert(metaData.getColumnClassName(1) === "java.sql.Time")
- assert(metaData.isSigned(1) === false)
- assert(metaData.getPrecision(1) === 6)
- assert(metaData.getScale(1) === 0)
- assert(metaData.getColumnDisplaySize(1) === 15)
- }
- }
+ // Test time type with null
+ withExecuteQuery(stmt, "SELECT cast(null as time)") { rs =>
+ assert(rs.next())
+ assert(rs.getTime(1) === null)
+ assert(rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "CAST(NULL AS TIME(6))")
+ assert(metaData.getColumnLabel(1) === "CAST(NULL AS TIME(6))")
+ assert(metaData.getColumnType(1) === Types.TIME)
+ assert(metaData.getColumnTypeName(1) === "TIME(6)")
+ assert(metaData.getColumnClassName(1) === "java.sql.Time")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 6)
+ assert(metaData.getScale(1) === 0)
+ assert(metaData.getColumnDisplaySize(1) === 15)
+ }
- test("get time type by column label") {
- withExecuteQuery("SELECT time '09:15:30.456789' as test_time") { rs =>
- assert(rs.next())
- val time = rs.getTime("test_time")
- // Verify milliseconds are preserved (456 from 456789 microseconds)
- val expectedMillis = timeToMillis(9, 15, 30, 456)
- assert(time.getTime === expectedMillis)
- assert(!rs.wasNull)
- assert(!rs.next())
+ // Test time type by column label
+ withExecuteQuery(stmt, "SELECT time '09:15:30.456789' as test_time") { rs =>
+ assert(rs.next())
+ val time = rs.getTime("test_time")
+ // Verify milliseconds are preserved (456 from 456789 microseconds)
+ val expectedMillis = timeToMillis(9, 15, 30, 456)
+ assert(time.getTime === expectedMillis)
+ assert(!rs.wasNull)
+ assert(!rs.next())
+ }
}
}
test("get time type with different precisions") {
- Seq(
- // (timeValue, precision, expectedDisplaySize, expectedMillis)
- // HH:MM:SS (no fractional)
- ("15:45:30.123456", 0, 8, timeToMillis(15, 45, 30, 0)),
- // HH:MM:SS.f (100ms from .1)
- ("10:20:30.123456", 1, 10, timeToMillis(10, 20, 30, 100)),
- // HH:MM:SS.fff (123ms)
- ("08:15:45.123456", 3, 12, timeToMillis(8, 15, 45, 123)),
- // HH:MM:SS.fff (999ms) . Spark TIME values can have microsecond precision,
- // but java.sql.Time can only store up to millisecond precision
- ("23:59:59.999999", 6, 15, timeToMillis(23, 59, 59, 999))
- ).foreach {
- case (timeValue, precision, expectedDisplaySize, expectedMillis) =>
- withExecuteQuery(s"SELECT cast(time '$timeValue' as time($precision))") { rs =>
- assert(rs.next(), s"Failed to get next row for precision $precision")
- val time = rs.getTime(1)
- assert(time.getTime === expectedMillis,
- s"Time millis mismatch for precision" +
- s" $precision: expected $expectedMillis, got ${time.getTime}")
- assert(!rs.wasNull, s"wasNull should be false for precision $precision")
- assert(!rs.next(), s"Should have no more rows for precision $precision")
-
- val metaData = rs.getMetaData
- assert(metaData.getColumnCount === 1)
- assert(metaData.getColumnType(1) === Types.TIME,
- s"Column type mismatch for precision $precision")
- assert(metaData.getColumnTypeName(1) === s"TIME($precision)",
- s"Column type name mismatch for precision $precision")
- assert(metaData.getColumnClassName(1) === "java.sql.Time",
- s"Column class name mismatch for precision $precision")
- assert(metaData.getPrecision(1) === precision,
- s"Precision mismatch for precision $precision")
- assert(metaData.getScale(1) === 0,
- s"Scale should be 0 for precision $precision")
- assert(metaData.getColumnDisplaySize(1) === expectedDisplaySize,
- s"Display size mismatch for precision $precision: " +
- s"expected $expectedDisplaySize, got ${metaData.getColumnDisplaySize(1)}")
- }
+ withStatement { stmt =>
+ Seq(
+ // (timeValue, precision, expectedDisplaySize, expectedMillis)
+ // HH:MM:SS (no fractional)
+ ("15:45:30.123456", 0, 8, timeToMillis(15, 45, 30, 0)),
+ // HH:MM:SS.f (100ms from .1)
+ ("10:20:30.123456", 1, 10, timeToMillis(10, 20, 30, 100)),
+ // HH:MM:SS.fff (123ms)
+ ("08:15:45.123456", 3, 12, timeToMillis(8, 15, 45, 123)),
+      // HH:MM:SS.ffffff (999ms). Spark TIME values can have microsecond precision,
+ // but java.sql.Time can only store up to millisecond precision
+ ("23:59:59.999999", 6, 15, timeToMillis(23, 59, 59, 999))
+ ).foreach {
+ case (timeValue, precision, expectedDisplaySize, expectedMillis) =>
+ withExecuteQuery(stmt, s"SELECT cast(time '$timeValue' as time($precision))") { rs =>
+ assert(rs.next(), s"Failed to get next row for precision $precision")
+ val time = rs.getTime(1)
+ assert(time.getTime === expectedMillis,
+ s"Time millis mismatch for precision" +
+ s" $precision: expected $expectedMillis, got ${time.getTime}")
+ assert(!rs.wasNull, s"wasNull should be false for precision $precision")
+ assert(!rs.next(), s"Should have no more rows for precision $precision")
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnType(1) === Types.TIME,
+ s"Column type mismatch for precision $precision")
+ assert(metaData.getColumnTypeName(1) === s"TIME($precision)",
+ s"Column type name mismatch for precision $precision")
+ assert(metaData.getColumnClassName(1) === "java.sql.Time",
+ s"Column class name mismatch for precision $precision")
+ assert(metaData.getPrecision(1) === precision,
+ s"Precision mismatch for precision $precision")
+ assert(metaData.getScale(1) === 0,
+ s"Scale should be 0 for precision $precision")
+ assert(metaData.getColumnDisplaySize(1) === expectedDisplaySize,
+ s"Display size mismatch for precision $precision: " +
+ s"expected $expectedDisplaySize, got ${metaData.getColumnDisplaySize(1)}")
+ }
+ }
}
}
@@ -570,7 +588,7 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
stmt.execute(s"set spark.sql.datetime.java8API.enabled=$java8APIEnabled")
Using.resource(stmt.executeQuery("SELECT date '2025-11-15'")) { rs =>
assert(rs.next())
- assert(rs.getDate(1) === java.sql.Date.valueOf("2025-11-15"))
+ assert(rs.getDate(1) === Date.valueOf("2025-11-15"))
assert(!rs.wasNull)
assert(!rs.next())
}
@@ -595,150 +613,153 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
}
test("get timestamp type") {
- withExecuteQuery("SELECT timestamp '2025-11-15 10:30:45.123456'") { rs =>
- assert(rs.next())
- val timestamp = rs.getTimestamp(1)
- assert(timestamp !== null)
- assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 10:30:45.123456"))
- assert(!rs.wasNull)
- assert(!rs.next())
-
- val metaData = rs.getMetaData
- assert(metaData.getColumnCount === 1)
- assert(metaData.getColumnName(1) === "TIMESTAMP '2025-11-15 10:30:45.123456'")
- assert(metaData.getColumnLabel(1) === "TIMESTAMP '2025-11-15 10:30:45.123456'")
- assert(metaData.getColumnType(1) === Types.TIMESTAMP)
- assert(metaData.getColumnTypeName(1) === "TIMESTAMP")
- assert(metaData.getColumnClassName(1) === "java.sql.Timestamp")
- assert(metaData.isSigned(1) === false)
- assert(metaData.getPrecision(1) === 29)
- assert(metaData.getScale(1) === 6)
- assert(metaData.getColumnDisplaySize(1) === 29)
- }
- }
-
- test("get timestamp type with null") {
- withExecuteQuery("SELECT cast(null as timestamp)") { rs =>
- assert(rs.next())
- assert(rs.getTimestamp(1) === null)
- assert(rs.wasNull)
- assert(!rs.next())
-
- val metaData = rs.getMetaData
- assert(metaData.getColumnCount === 1)
- assert(metaData.getColumnName(1) === "CAST(NULL AS TIMESTAMP)")
- assert(metaData.getColumnLabel(1) === "CAST(NULL AS TIMESTAMP)")
- assert(metaData.getColumnType(1) === Types.TIMESTAMP)
- assert(metaData.getColumnTypeName(1) === "TIMESTAMP")
- assert(metaData.getColumnClassName(1) === "java.sql.Timestamp")
- assert(metaData.isSigned(1) === false)
- assert(metaData.getPrecision(1) === 29)
- assert(metaData.getScale(1) === 6)
- assert(metaData.getColumnDisplaySize(1) === 29)
- }
- }
-
- test("get timestamp type by column label and with calendar") {
- withExecuteQuery("SELECT timestamp '2025-11-15 10:30:45.987654' as test_timestamp") { rs =>
- assert(rs.next())
-
- // Test by column label
- val timestamp = rs.getTimestamp("test_timestamp")
- assert(timestamp !== null)
- assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 10:30:45.987654"))
- assert(!rs.wasNull)
+ withStatement { stmt =>
+ // Test basic timestamp type
+ withExecuteQuery(stmt, "SELECT timestamp '2025-11-15 10:30:45.123456'") { rs =>
+ assert(rs.next())
+ val timestamp = rs.getTimestamp(1)
+ assert(timestamp !== null)
+ assert(timestamp === Timestamp.valueOf("2025-11-15 10:30:45.123456"))
+ assert(!rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "TIMESTAMP '2025-11-15 10:30:45.123456'")
+ assert(metaData.getColumnLabel(1) === "TIMESTAMP '2025-11-15 10:30:45.123456'")
+ assert(metaData.getColumnType(1) === Types.TIMESTAMP)
+ assert(metaData.getColumnTypeName(1) === "TIMESTAMP")
+ assert(metaData.getColumnClassName(1) === "java.sql.Timestamp")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 29)
+ assert(metaData.getScale(1) === 6)
+ assert(metaData.getColumnDisplaySize(1) === 29)
+ }
- // Test with calendar - should return same value (Calendar is ignored)
- // Note: Spark Connect handles timezone at server, Calendar param is for API compliance
- val calUTC = java.util.Calendar.getInstance(java.util.TimeZone.getTimeZone("UTC"))
- val timestampUTC = rs.getTimestamp(1, calUTC)
- assert(timestampUTC !== null)
- assert(timestampUTC.getTime === timestamp.getTime)
-
- val calPST = java.util.Calendar.getInstance(
- java.util.TimeZone.getTimeZone("America/Los_Angeles"))
- val timestampPST = rs.getTimestamp(1, calPST)
- assert(timestampPST !== null)
- // Same value regardless of calendar
- assert(timestampPST.getTime === timestamp.getTime)
- assert(timestampUTC.getTime === timestampPST.getTime)
-
- // Test with calendar by label
- val timestampLabel = rs.getTimestamp("test_timestamp", calUTC)
- assert(timestampLabel !== null)
- assert(timestampLabel.getTime === timestamp.getTime)
-
- // Test with null calendar - returns same value
- val timestampNullCal = rs.getTimestamp(1, null)
- assert(timestampNullCal !== null)
- assert(timestampNullCal.getTime === timestamp.getTime)
+ // Test timestamp type with null
+ withExecuteQuery(stmt, "SELECT cast(null as timestamp)") { rs =>
+ assert(rs.next())
+ assert(rs.getTimestamp(1) === null)
+ assert(rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "CAST(NULL AS TIMESTAMP)")
+ assert(metaData.getColumnLabel(1) === "CAST(NULL AS TIMESTAMP)")
+ assert(metaData.getColumnType(1) === Types.TIMESTAMP)
+ assert(metaData.getColumnTypeName(1) === "TIMESTAMP")
+ assert(metaData.getColumnClassName(1) === "java.sql.Timestamp")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 29)
+ assert(metaData.getScale(1) === 6)
+ assert(metaData.getColumnDisplaySize(1) === 29)
+ }
- assert(!rs.next())
- }
- }
+ // Test timestamp type by column label and with calendar
+ val tsString = "2025-11-15 10:30:45.987654"
+ withExecuteQuery(stmt, s"SELECT timestamp '$tsString' as test_timestamp") { rs =>
+ assert(rs.next())
+
+ // Test by column label
+ val timestamp = rs.getTimestamp("test_timestamp")
+ assert(timestamp !== null)
+ assert(timestamp === Timestamp.valueOf(tsString))
+ assert(!rs.wasNull)
+
+ // Test with calendar - should return same value (Calendar is ignored)
+ // Note: Spark Connect handles timezone at server, Calendar param is for API compliance
+ val calUTC = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
+ val timestampUTC = rs.getTimestamp(1, calUTC)
+ assert(timestampUTC !== null)
+ assert(timestampUTC.getTime === timestamp.getTime)
+
+ val calPST = Calendar.getInstance(
+ TimeZone.getTimeZone("America/Los_Angeles"))
+ val timestampPST = rs.getTimestamp(1, calPST)
+ assert(timestampPST !== null)
+ // Same value regardless of calendar
+ assert(timestampPST.getTime === timestamp.getTime)
+ assert(timestampUTC.getTime === timestampPST.getTime)
+
+ // Test with calendar by label
+ val timestampLabel = rs.getTimestamp("test_timestamp", calUTC)
+ assert(timestampLabel !== null)
+ assert(timestampLabel.getTime === timestamp.getTime)
+
+ // Test with null calendar - returns same value
+ val timestampNullCal = rs.getTimestamp(1, null)
+ assert(timestampNullCal !== null)
+ assert(timestampNullCal.getTime === timestamp.getTime)
+
+ assert(!rs.next())
+ }
- test("get timestamp type with calendar for null value") {
- withExecuteQuery("SELECT cast(null as timestamp)") { rs =>
- assert(rs.next())
+ // Test timestamp type with calendar for null value
+ withExecuteQuery(stmt, "SELECT cast(null as timestamp)") { rs =>
+ assert(rs.next())
- // Calendar parameter should not affect null handling
- val cal = java.util.Calendar.getInstance(java.util.TimeZone.getTimeZone("UTC"))
- val timestamp = rs.getTimestamp(1, cal)
- assert(timestamp === null)
- assert(rs.wasNull)
- assert(!rs.next())
+ // Calendar parameter should not affect null handling
+ val cal = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
+ val timestamp = rs.getTimestamp(1, cal)
+ assert(timestamp === null)
+ assert(rs.wasNull)
+ assert(!rs.next())
+ }
}
}
test("get timestamp_ntz type") {
- withExecuteQuery("SELECT timestamp_ntz '2025-11-15 10:30:45.123456'") { rs =>
- assert(rs.next())
- val timestamp = rs.getTimestamp(1)
- assert(timestamp !== null)
- assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 10:30:45.123456"))
- assert(!rs.wasNull)
- assert(!rs.next())
-
- val metaData = rs.getMetaData
- assert(metaData.getColumnCount === 1)
- assert(metaData.getColumnName(1) === "TIMESTAMP_NTZ '2025-11-15 10:30:45.123456'")
- assert(metaData.getColumnLabel(1) === "TIMESTAMP_NTZ '2025-11-15 10:30:45.123456'")
- assert(metaData.getColumnType(1) === Types.TIMESTAMP)
- assert(metaData.getColumnTypeName(1) === "TIMESTAMP_NTZ")
- assert(metaData.getColumnClassName(1) === "java.sql.Timestamp")
- assert(metaData.isSigned(1) === false)
- assert(metaData.getPrecision(1) === 29)
- assert(metaData.getScale(1) === 6)
- assert(metaData.getColumnDisplaySize(1) === 29)
- }
- }
+ withStatement { stmt =>
+ // Test basic timestamp_ntz type
+ withExecuteQuery(stmt, "SELECT timestamp_ntz '2025-11-15 10:30:45.123456'") { rs =>
+ assert(rs.next())
+ val timestamp = rs.getTimestamp(1)
+ assert(timestamp !== null)
+ assert(timestamp === Timestamp.valueOf("2025-11-15 10:30:45.123456"))
+ assert(!rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "TIMESTAMP_NTZ '2025-11-15 10:30:45.123456'")
+ assert(metaData.getColumnLabel(1) === "TIMESTAMP_NTZ '2025-11-15 10:30:45.123456'")
+ assert(metaData.getColumnType(1) === Types.TIMESTAMP)
+ assert(metaData.getColumnTypeName(1) === "TIMESTAMP_NTZ")
+ assert(metaData.getColumnClassName(1) === "java.sql.Timestamp")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 29)
+ assert(metaData.getScale(1) === 6)
+ assert(metaData.getColumnDisplaySize(1) === 29)
+ }
- test("get timestamp_ntz type by label, null, and with calendar") {
- // Test with non-null value
- withExecuteQuery("SELECT timestamp_ntz '2025-11-15 14:22:33.789456' as test_ts_ntz") { rs =>
- assert(rs.next())
+ // Test timestamp_ntz by label, null, and with calendar - non-null value
+ val tsString = "2025-11-15 14:22:33.789456"
+ withExecuteQuery(stmt, s"SELECT timestamp_ntz '$tsString' as test_ts_ntz") { rs =>
+ assert(rs.next())
- // Test by column label
- val timestamp = rs.getTimestamp("test_ts_ntz")
- assert(timestamp !== null)
- assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 14:22:33.789456"))
- assert(!rs.wasNull)
+ // Test by column label
+ val timestamp = rs.getTimestamp("test_ts_ntz")
+ assert(timestamp !== null)
+ assert(timestamp === Timestamp.valueOf(tsString))
+ assert(!rs.wasNull)
- // Test with calendar - should return same value (Calendar is ignored)
- val calUTC = java.util.Calendar.getInstance(java.util.TimeZone.getTimeZone("UTC"))
- val timestampCal = rs.getTimestamp(1, calUTC)
- assert(timestampCal !== null)
- assert(timestampCal.getTime === timestamp.getTime)
+ // Test with calendar - should return same value (Calendar is ignored)
+ val calUTC = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
+ val timestampCal = rs.getTimestamp(1, calUTC)
+ assert(timestampCal !== null)
+ assert(timestampCal.getTime === timestamp.getTime)
- assert(!rs.next())
- }
+ assert(!rs.next())
+ }
- // Test with null value
- withExecuteQuery("SELECT cast(null as timestamp_ntz)") { rs =>
- assert(rs.next())
- assert(rs.getTimestamp(1) === null)
- assert(rs.wasNull)
- assert(!rs.next())
+ // Test timestamp_ntz with null value
+ withExecuteQuery(stmt, "SELECT cast(null as timestamp_ntz)") { rs =>
+ assert(rs.next())
+ assert(rs.getTimestamp(1) === null)
+ assert(rs.wasNull)
+ assert(!rs.next())
+ }
}
}
@@ -757,13 +778,13 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
// Test TIMESTAMP type
val timestamp = rs.getTimestamp(1)
assert(timestamp !== null)
- assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 10:30:45.123456"))
+ assert(timestamp === Timestamp.valueOf("2025-11-15 10:30:45.123456"))
assert(!rs.wasNull)
// Test TIMESTAMP_NTZ type
val timestampNtz = rs.getTimestamp(2)
assert(timestampNtz !== null)
- assert(timestampNtz === java.sql.Timestamp.valueOf("2025-11-15 14:22:33.789012"))
+ assert(timestampNtz === Timestamp.valueOf("2025-11-15 14:22:33.789012"))
assert(!rs.wasNull)
assert(!rs.next())
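
A minimal usage sketch of the behavior the tests above pin down: the Calendar overloads of getTimestamp return the same instant as the plain accessor, since the Spark Connect server resolves the session time zone. The jdbc:sc:// URL and the object name are assumptions for a locally running Connect server and are not taken from this patch.

import java.sql.DriverManager
import java.util.{Calendar, TimeZone}

object CalendarIgnoredSketch {
  def main(args: Array[String]): Unit = {
    // Assumed connection string for a locally running Spark Connect server.
    val conn = DriverManager.getConnection("jdbc:sc://localhost:15002")
    try {
      val rs = conn.createStatement().executeQuery(
        "SELECT timestamp '2025-11-15 10:30:45.123456' AS ts")
      rs.next()
      val plain = rs.getTimestamp("ts")
      val withCal = rs.getTimestamp("ts", Calendar.getInstance(TimeZone.getTimeZone("UTC")))
      // Same instant either way: the Calendar argument is accepted for API compliance only.
      assert(plain.getTime == withCal.getTime)
    } finally {
      conn.close()
    }
  }
}
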
diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala
index 9b3aa373e93c..a512a44cac3b 100644
--- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala
@@ -39,8 +39,10 @@ trait JdbcHelper {
}
def withExecuteQuery(query: String)(f: ResultSet => Unit): Unit = {
- withStatement { stmt =>
- Using.resource { stmt.executeQuery(query) } { rs => f(rs) }
- }
+ withStatement { stmt => withExecuteQuery(stmt, query)(f) }
+ }
+
+ def withExecuteQuery(stmt: Statement, query: String)(f: ResultSet => Unit): Unit = {
+ Using.resource { stmt.executeQuery(query) } { rs => f(rs) }
}
}
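
A short usage sketch of the new statement-scoped overload, mirroring how the refactored timestamp tests above reuse one Statement for several queries. It assumes a test that mixes in the JdbcHelper trait; the queries and expected values are illustrative.

// Inside a suite mixing in JdbcHelper: run several queries on a single Statement.
withStatement { stmt =>
  withExecuteQuery(stmt, "SELECT 1") { rs =>
    assert(rs.next() && rs.getInt(1) == 1)
  }
  withExecuteQuery(stmt, "SELECT 'spark'") { rs =>
    assert(rs.next() && rs.getString(1) == "spark")
  }
}
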
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 47f5f180789e..8534a24d0110 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -305,8 +305,7 @@ class JDBCRDD(
val inputMetrics = context.taskMetrics().inputMetrics
val part = thePart.asInstanceOf[JDBCPartition]
conn = getConnection(part.idx)
- import scala.jdk.CollectionConverters._
- dialect.beforeFetch(conn, options.asProperties.asScala.toMap)
+ dialect.beforeFetch(conn, options)
// This executes a generic SQL statement (or PL/SQL block) before reading
// the table/query via JDBC. Use this feature to initialize the database
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index f5f968ee9522..499fa99a2444 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -35,7 +35,6 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef](
_schema: StructType,
_timeZoneId: String,
protected override val largeVarTypes: Boolean,
- protected override val workerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
@@ -86,12 +85,11 @@ abstract class RowInputArrowPythonRunner(
_schema: StructType,
_timeZoneId: String,
largeVarTypes: Boolean,
- workerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
extends BaseArrowPythonRunner[Iterator[InternalRow], ColumnarBatch](
- funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
+ funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
pythonMetrics, jobArtifactUUID, sessionUUID)
with BasicPythonArrowInput
with BasicPythonArrowOutput
@@ -106,13 +104,13 @@ class ArrowPythonRunner(
_schema: StructType,
_timeZoneId: String,
largeVarTypes: Boolean,
- workerConf: Map[String, String],
+ protected override val runnerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String],
profiler: Option[String])
extends RowInputArrowPythonRunner(
- funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
+ funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
pythonMetrics, jobArtifactUUID, sessionUUID) {
override protected def writeUDF(dataOut: DataOutputStream): Unit =
@@ -130,13 +128,13 @@ class ArrowPythonWithNamedArgumentRunner(
_schema: StructType,
_timeZoneId: String,
largeVarTypes: Boolean,
- workerConf: Map[String, String],
+ protected override val runnerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String],
profiler: Option[String])
extends RowInputArrowPythonRunner(
- funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf,
+ funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes,
pythonMetrics, jobArtifactUUID, sessionUUID) {
override protected def writeUDF(dataOut: DataOutputStream): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 1d5df9bad924..979d91205d5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -39,7 +39,7 @@ class ArrowPythonUDTFRunner(
protected override val schema: StructType,
protected override val timeZoneId: String,
protected override val largeVarTypes: Boolean,
- protected override val workerConf: Map[String, String],
+ protected override val runnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 7f6efbae8881..b5986be9214a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -25,7 +25,7 @@ import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec}
import org.apache.spark.{SparkEnv, SparkException, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker}
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonWorker}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriterWrapper
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -45,7 +45,7 @@ class CoGroupedArrowPythonRunner(
rightSchema: StructType,
timeZoneId: String,
largeVarTypes: Boolean,
- conf: Map[String, String],
+ protected override val runnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String],
@@ -119,14 +119,6 @@ class CoGroupedArrowPythonRunner(
private var rightGroupArrowWriter: ArrowWriterWrapper = null
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-
- // Write config for the worker as a number of key -> value pairs of strings
- dataOut.writeInt(conf.size)
- for ((k, v) <- conf) {
- PythonRDD.writeUTF(k, dataOut)
- PythonRDD.writeUTF(v, dataOut)
- }
-
PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
}
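
The block removed above wrote the runner configuration by hand. Below is a self-contained sketch of that wire format, an int count followed by UTF-encoded key/value pairs; the local writeUTF helper and the object name are stand-ins for illustration, approximating the PythonRDD/PythonWorkerUtils helpers used in the real code.

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

object RunnerConfWireFormatSketch {
  // Stand-in for PythonRDD.writeUTF / PythonWorkerUtils.writeUTF:
  // a length prefix followed by the UTF-8 bytes of the string.
  private def writeUTF(s: String, out: DataOutputStream): Unit = {
    val bytes = s.getBytes(StandardCharsets.UTF_8)
    out.writeInt(bytes.length)
    out.write(bytes)
  }

  def main(args: Array[String]): Unit = {
    val runnerConf = Map("spark.sql.session.timeZone" -> "UTC")
    val buffer = new ByteArrayOutputStream()
    val dataOut = new DataOutputStream(buffer)
    // Number of key -> value pairs, then each pair as two UTF strings.
    dataOut.writeInt(runnerConf.size)
    for ((k, v) <- runnerConf) {
      writeUTF(k, dataOut)
      writeUTF(v, dataOut)
    }
    dataOut.flush()
    println(s"wrote ${buffer.size()} bytes for ${runnerConf.size} conf pair(s)")
  }
}
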
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index f77b0a9342b0..d2d16b0c9623 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -27,7 +27,7 @@ import org.apache.arrow.vector.ipc.WriteChannel
import org.apache.arrow.vector.ipc.message.MessageSerializer
import org.apache.spark.{SparkEnv, SparkException, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, PythonRDD, PythonWorker}
+import org.apache.spark.api.python.{BasePythonRunner, PythonWorker}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow
import org.apache.spark.sql.execution.arrow.{ArrowWriter, ArrowWriterWrapper}
@@ -42,8 +42,6 @@ import org.apache.spark.util.Utils
* JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
*/
private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
- protected val workerConf: Map[String, String]
-
protected val schema: StructType
protected val timeZoneId: String
@@ -62,14 +60,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected def writeUDF(dataOut: DataOutputStream): Unit
- protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
- // Write config for the worker as a number of key -> value pairs of strings
- stream.writeInt(workerConf.size)
- for ((k, v) <- workerConf) {
- PythonRDD.writeUTF(k, stream)
- PythonRDD.writeUTF(v, stream)
- }
- }
+ protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {}
+
private val arrowSchema = ArrowUtils.toArrowSchema(
schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
protected val allocator =
@@ -301,7 +293,6 @@ private[python] trait GroupedPythonArrowInput { self: RowInputArrowPythonRunner
context: TaskContext): Writer = {
new Writer(env, worker, inputIterator, partitionIndex, context) {
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
- handleMetadataBeforeExec(dataOut)
writeUDF(dataOut)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
index 14054ba89a94..ae89ff1637ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
@@ -58,7 +58,7 @@ class ApplyInPandasWithStatePythonRunner(
argOffsets: Array[Array[Int]],
inputSchema: StructType,
_timeZoneId: String,
- initialWorkerConf: Map[String, String],
+ initialRunnerConf: Map[String, String],
stateEncoder: ExpressionEncoder[Row],
keySchema: StructType,
outputSchema: StructType,
@@ -113,7 +113,7 @@ class ApplyInPandasWithStatePythonRunner(
// applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
// Configurations are both applied to executor and Python worker, set them to the worker conf
// to let Python worker read the config properly.
- override protected val workerConf: Map[String, String] = initialWorkerConf +
+ override protected val runnerConf: Map[String, String] = initialRunnerConf +
(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
(SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 3eb7c7e64d64..bbf7b9387526 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -52,7 +52,7 @@ class TransformWithStateInPySparkPythonRunner(
_schema: StructType,
processorHandle: StatefulProcessorHandleImpl,
_timeZoneId: String,
- initialWorkerConf: Map[String, String],
+ initialRunnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
@@ -60,7 +60,7 @@ class TransformWithStateInPySparkPythonRunner(
eventTimeWatermarkForEviction: Option[Long])
extends TransformWithStateInPySparkPythonBaseRunner[InType](
funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
- initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
batchTimestampMs, eventTimeWatermarkForEviction)
with PythonArrowInput[InType] {
@@ -126,7 +126,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
initStateSchema: StructType,
processorHandle: StatefulProcessorHandleImpl,
_timeZoneId: String,
- initialWorkerConf: Map[String, String],
+ initialRunnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
@@ -134,7 +134,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
eventTimeWatermarkForEviction: Option[Long])
extends TransformWithStateInPySparkPythonBaseRunner[GroupedInType](
funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
- initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
batchTimestampMs, eventTimeWatermarkForEviction)
with PythonArrowInput[GroupedInType] {
@@ -221,7 +221,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
_schema: StructType,
processorHandle: StatefulProcessorHandleImpl,
_timeZoneId: String,
- initialWorkerConf: Map[String, String],
+ initialRunnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
@@ -238,7 +238,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
protected val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch
- override protected val workerConf: Map[String, String] = initialWorkerConf +
+ override protected val runnerConf: Map[String, String] = initialRunnerConf +
(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
(SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)
@@ -251,7 +251,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
super.handleMetadataBeforeExec(stream)
- // Also write the port/path number for state server
+ // Write the port/path number for state server
if (isUnixDomainSock) {
stream.writeInt(-1)
PythonWorkerUtils.writeUTF(stateServerSocketPath, stream)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index ce4c347cad34..875bfeb011bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -342,9 +342,20 @@ abstract class JdbcDialect extends Serializable with Logging {
* @param connection The connection object
* @param properties The connection properties. This is passed through from the relation.
*/
+ @deprecated("Use beforeFetch(Connection, JDBCOptions) instead", "4.2.0")
def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
}
+ /**
+ * Override connection specific properties to run before a select is made. This is in place to
+ * allow dialects that need special treatment to optimize behavior.
+ * @param connection The connection object
+ * @param options The JDBC options for the connection.
+ */
+ def beforeFetch(connection: Connection, options: JDBCOptions): Unit = {
+ beforeFetch(connection, options.parameters)
+ }
+
/**
* Escape special characters in SQL string literals.
* @param value The string to be escaped.
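
A hedged sketch of how a dialect can adopt the new JDBCOptions-based hook while dialects that only override the deprecated Map-based method keep working through the default delegation. The dialect name, URL prefix, and fetch-size handling are illustrative assumptions, not part of this patch.

import java.sql.Connection

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.JdbcDialect

// Illustrative dialect only.
case object ExampleDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:example")

  override def beforeFetch(connection: Connection, options: JDBCOptions): Unit = {
    // Disable auto-commit so a positive fetch size can stream rows instead of
    // materializing the whole result set up front.
    if (options.fetchSize > 0) {
      connection.setAutoCommit(false)
    }
  }
}
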
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index e59b4435c408..5d1173b5a1a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -2404,7 +2404,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
sourceDF.createOrReplaceTempView("source")
val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
- sql(s"""MERGE $schemaEvolutionClause
+ val mergeStmt = s"""MERGE $schemaEvolutionClause
|INTO $tableNameAsString t
|USING source s
|ON t.pk = s.pk
@@ -2412,8 +2412,9 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
| UPDATE SET *
|WHEN NOT MATCHED THEN
| INSERT *
- |""".stripMargin)
+ |""".stripMargin
if (withSchemaEvolution && schemaEvolutionEnabled) {
+ sql(mergeStmt)
checkAnswer(
sql(s"SELECT * FROM $tableNameAsString"),
Seq(
@@ -2424,15 +2425,12 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
Row(5, 250, "executive", true),
Row(6, 350, null, false)))
} else {
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(1, 100, "hr"),
- Row(2, 200, "software"),
- Row(3, 300, "hr"),
- Row(4, 150, "marketing"),
- Row(5, 250, "executive"),
- Row(6, 350, null)))
+ val e = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION")
+ assert(e.message.contains("A column, variable, or function parameter with name " +
+ "`dep` cannot be resolved"))
}
sql(s"DROP TABLE $tableNameAsString")
}
@@ -2463,7 +2461,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
sourceDF.createOrReplaceTempView("source")
val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
- sql(s"""MERGE $schemaEvolutionClause
+ val mergeStmt = s"""MERGE $schemaEvolutionClause
|INTO $tableNameAsString t
|USING source s
|ON t.pk = s.pk
@@ -2471,8 +2469,9 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
| UPDATE SET *
|WHEN NOT MATCHED THEN
| INSERT *
- |""".stripMargin)
+ |""".stripMargin
if (withSchemaEvolution && schemaEvolutionEnabled) {
+ sql(mergeStmt)
checkAnswer(
sql(s"SELECT * FROM $tableNameAsString"),
Seq(
@@ -2483,15 +2482,12 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
Row(5, 250, "executive", true),
Row(6, 350, "unknown", false)))
} else {
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(1, 100, "hr"),
- Row(2, 200, "software"),
- Row(3, 300, "hr"),
- Row(4, 150, "marketing"),
- Row(5, 250, "executive"),
- Row(6, 350, "unknown")))
+ val e = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION")
+ assert(e.getMessage.contains("A column, variable, or function parameter with name " +
+ "`dep` cannot be resolved"))
}
sql(s"DROP TABLE $tableNameAsString")
}
@@ -3239,23 +3235,14 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
|WHEN NOT MATCHED THEN
| INSERT *
|""".stripMargin
- if (coerceNestedTypes) {
- if (withSchemaEvolution) {
- sql(mergeStmt)
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"),
- Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- sql(mergeStmt)
- }
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
- assert(exception.getMessage.contains(
- "Cannot write extra fields `c3` to the struct `s`.`c2`"))
- }
+ if (coerceNestedTypes && withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"),
+ Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
} else {
val exception = intercept[org.apache.spark.sql.AnalysisException] {
sql(mergeStmt)
@@ -3335,30 +3322,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
|WHEN NOT MATCHED THEN
| INSERT *
|""".stripMargin
- if (coerceNestedTypes) {
- if (withSchemaEvolution) {
- sql(mergeStmt)
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"),
- Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- sql(mergeStmt)
- }
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
- assert(exception.getMessage.contains(
- "Cannot write extra fields `c3` to the struct `s`.`c2`"))
- }
+ if (coerceNestedTypes && withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"),
+ Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
} else {
val exception = intercept[org.apache.spark.sql.AnalysisException] {
sql(mergeStmt)
}
assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot find data for the output column `s`.`c2`.`a`"))
}
}
sql(s"DROP TABLE IF EXISTS $tableNameAsString")
@@ -3534,30 +3509,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
| INSERT *
|""".stripMargin
- if (coerceNestedTypes) {
- if (withSchemaEvolution) {
- sql(mergeStmt)
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"),
- Row(1, Map(Row(10, null, true) -> Row("y", null, false)), "sales"),
- Row(2, Map(Row(20, null, false) -> Row("z", null, true)), "engineering")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- sql(mergeStmt)
- }
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
- assert(exception.getMessage.contains(
- "Cannot write extra fields `c3` to the struct `m`.`key`"))
- }
+ if (coerceNestedTypes && withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"),
+ Row(1, Map(Row(10, null, true) -> Row("y", null, false)), "sales"),
+ Row(2, Map(Row(20, null, false) -> Row("z", null, true)), "engineering")))
} else {
val exception = intercept[org.apache.spark.sql.AnalysisException] {
sql(mergeStmt)
}
assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot find data for the output column `m`.`key`.`c2`"))
}
}
sql(s"DROP TABLE IF EXISTS $tableNameAsString")
@@ -3612,30 +3575,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
| INSERT (pk, m, dep) VALUES (src.pk, src.m, 'my_new_dep')
|""".stripMargin
- if (coerceNestedTypes) {
- if (withSchemaEvolution) {
- sql(mergeStmt)
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"),
- Row(1, Map(Row(10, null, true) -> Row("y", null, false)), "my_old_dep"),
- Row(2, Map(Row(20, null, false) -> Row("z", null, true)), "my_new_dep")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- sql(mergeStmt)
- }
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
- assert(exception.getMessage.contains(
- "Cannot write extra fields `c3` to the struct `m`.`key`"))
- }
+ if (coerceNestedTypes && withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"),
+ Row(1, Map(Row(10, null, true) -> Row("y", null, false)), "my_old_dep"),
+ Row(2, Map(Row(20, null, false) -> Row("z", null, true)), "my_new_dep")))
} else {
val exception = intercept[org.apache.spark.sql.AnalysisException] {
sql(mergeStmt)
}
assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot find data for the output column `m`.`key`.`c2`"))
}
}
sql(s"DROP TABLE IF EXISTS $tableNameAsString")
@@ -3688,30 +3639,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
| INSERT *
|""".stripMargin
- if (coerceNestedTypes) {
- if (withSchemaEvolution) {
- sql(mergeStmt)
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(Row(0, Array(Row(10, 10, null)), "hr"),
- Row(1, Array(Row(10, null, true)), "sales"),
- Row(2, Array(Row(20, null, false)), "engineering")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- sql(mergeStmt)
- }
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
- assert(exception.getMessage.contains(
- "Cannot write extra fields `c3` to the struct `a`.`element`"))
- }
+ if (coerceNestedTypes && withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(0, Array(Row(10, 10, null)), "hr"),
+ Row(1, Array(Row(10, null, true)), "sales"),
+ Row(2, Array(Row(20, null, false)), "engineering")))
} else {
val exception = intercept[org.apache.spark.sql.AnalysisException] {
sql(mergeStmt)
}
assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot find data for the output column `a`.`element`.`c2`"))
}
}
sql(s"DROP TABLE IF EXISTS $tableNameAsString")
@@ -3764,30 +3703,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
| INSERT (pk, a, dep) VALUES (src.pk, src.a, 'my_new_dep')
|""".stripMargin
- if (coerceNestedTypes) {
- if (withSchemaEvolution) {
- sql(mergeStmt)
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(Row(0, Array(Row(10, 10, null)), "hr"),
- Row(1, Array(Row(10, null, true)), "my_old_dep"),
- Row(2, Array(Row(20, null, false)), "my_new_dep")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- sql(mergeStmt)
- }
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
- assert(exception.getMessage.contains(
- "Cannot write extra fields `c3` to the struct `a`.`element`"))
- }
+ if (coerceNestedTypes && withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(0, Array(Row(10, 10, null)), "hr"),
+ Row(1, Array(Row(10, null, true)), "my_old_dep"),
+ Row(2, Array(Row(20, null, false)), "my_new_dep")))
} else {
val exception = intercept[org.apache.spark.sql.AnalysisException] {
sql(mergeStmt)
}
assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot find data for the output column `a`.`element`.`c2`"))
}
}
sql(s"DROP TABLE IF EXISTS $tableNameAsString")
@@ -4051,6 +3978,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
sql(mergeStmt)
}
assert(exception.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION")
+ assert(exception.message.contains(" A column, variable, or function parameter with name "
+ + "`bonus` cannot be resolved"))
}
sql(s"DROP TABLE $tableNameAsString")
@@ -4484,270 +4413,324 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
}
test("merge into with source missing fields in struct nested in array") {
- Seq(true, false).foreach { coerceNestedTypes =>
- withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
- coerceNestedTypes.toString) {
- withTempView("source") {
- // Target table has struct with 3 fields (c1, c2, c3) in array
- createAndInitTable(
- s"""pk INT NOT NULL,
-             |a ARRAY<STRUCT<c1 INT, c2 STRING, c3 BOOLEAN>>,
- |dep STRING""".stripMargin,
- """{ "pk": 0, "a": [ { "c1": 1, "c2": "a", "c3": true } ], "dep": "sales" }
- |{ "pk": 1, "a": [ { "c1": 2, "c2": "b", "c3": false } ], "dep": "sales" }"""
- .stripMargin)
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ withTempView("source") {
+ // Target table has struct with 3 fields (c1, c2, c3) in array
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+               |a ARRAY<STRUCT<c1 INT, c2 STRING, c3 BOOLEAN>>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "a": [ { "c1": 1, "c2": "a", "c3": true } ], "dep": "sales" }
+ |{ "pk": 1, "a": [ { "c1": 2, "c2": "b", "c3": false } ], "dep": "sales" }"""
+ .stripMargin)
- // Source table has struct with only 2 fields (c1, c2) - missing c3
- val sourceTableSchema = StructType(Seq(
- StructField("pk", IntegerType, nullable = false),
- StructField("a", ArrayType(
- StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StringType))))), // missing c3 field
- StructField("dep", StringType)))
- val data = Seq(
- Row(1, Array(Row(10, "c")), "hr"),
- Row(2, Array(Row(30, "e")), "engineering")
- )
- spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
- .createOrReplaceTempView("source")
+ // Source table has struct with only 2 fields (c1, c2) - missing c3
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("a", ArrayType(
+ StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType))))), // missing c3 field
+ StructField("dep", StringType)))
+ val data = Seq(
+ Row(1, Array(Row(10, "c")), "hr"),
+ Row(2, Array(Row(30, "e")), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+ .createOrReplaceTempView("source")
- val mergeStmt =
- s"""MERGE INTO $tableNameAsString t
- |USING source src
- |ON t.pk = src.pk
- |WHEN MATCHED THEN
- | UPDATE SET *
- |WHEN NOT MATCHED THEN
- | INSERT *
- |""".stripMargin
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
- if (coerceNestedTypes) {
- sql(mergeStmt)
- // Missing field c3 should be filled with NULL
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(0, Array(Row(1, "a", true)), "sales"),
- Row(1, Array(Row(10, "c", null)), "hr"),
- Row(2, Array(Row(30, "e", null)), "engineering")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ if (coerceNestedTypes && withSchemaEvolution) {
sql(mergeStmt)
+ // Missing field c3 should be filled with NULL
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Array(Row(1, "a", true)), "sales"),
+ Row(1, Array(Row(10, "c", null)), "hr"),
+ Row(2, Array(Row(30, "e", null)), "engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
}
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot write incompatible data for the table ``: " +
- "Cannot find data for the output column `a`.`element`.`c3`."))
}
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
- sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
}
}
test("merge into with source missing fields in struct nested in map key") {
- Seq(true, false).foreach { coerceNestedTypes =>
- withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
- coerceNestedTypes.toString) {
- withTempView("source") {
- // Target table has struct with 2 fields in map key
- val targetSchema =
- StructType(Seq(
- StructField("pk", IntegerType, nullable = false),
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ withTempView("source") {
+ // Target table has struct with 2 fields in map key
+ val targetSchema =
+ StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("m", MapType(
+ StructType(Seq(StructField("c1", IntegerType), StructField("c2", BooleanType))),
+ StructType(Seq(StructField("c3", StringType))))),
+ StructField("dep", StringType)))
+ createTable(CatalogV2Util.structTypeToV2Columns(targetSchema))
+
+ val targetData = Seq(
+ Row(0, Map(Row(10, true) -> Row("x")), "hr"),
+ Row(1, Map(Row(20, false) -> Row("y")), "sales"))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
+ .writeTo(tableNameAsString).append()
+
+ // Source table has struct with only 1 field (c1) in map key - missing c2
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
StructField("m", MapType(
- StructType(Seq(StructField("c1", IntegerType), StructField("c2", BooleanType))),
+ StructType(Seq(StructField("c1", IntegerType))), // missing c2
StructType(Seq(StructField("c3", StringType))))),
StructField("dep", StringType)))
- createTable(CatalogV2Util.structTypeToV2Columns(targetSchema))
-
- val targetData = Seq(
- Row(0, Map(Row(10, true) -> Row("x")), "hr"),
- Row(1, Map(Row(20, false) -> Row("y")), "sales"))
- spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
- .writeTo(tableNameAsString).append()
-
- // Source table has struct with only 1 field (c1) in map key - missing c2
- val sourceTableSchema = StructType(Seq(
- StructField("pk", IntegerType),
- StructField("m", MapType(
- StructType(Seq(StructField("c1", IntegerType))), // missing c2
- StructType(Seq(StructField("c3", StringType))))),
- StructField("dep", StringType)))
- val sourceData = Seq(
- Row(1, Map(Row(10) -> Row("z")), "sales"),
- Row(2, Map(Row(20) -> Row("w")), "engineering")
- )
- spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema)
- .createOrReplaceTempView("source")
+ val sourceData = Seq(
+ Row(1, Map(Row(10) -> Row("z")), "sales"),
+ Row(2, Map(Row(20) -> Row("w")), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema)
+ .createOrReplaceTempView("source")
- val mergeStmt =
- s"""MERGE INTO $tableNameAsString t
- |USING source src
- |ON t.pk = src.pk
- |WHEN MATCHED THEN
- | UPDATE SET *
- |WHEN NOT MATCHED THEN
- | INSERT *
- |""".stripMargin
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
- if (coerceNestedTypes) {
- sql(mergeStmt)
- // Missing field c2 should be filled with NULL
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(0, Map(Row(10, true) -> Row("x")), "hr"),
- Row(1, Map(Row(10, null) -> Row("z")), "sales"),
- Row(2, Map(Row(20, null) -> Row("w")), "engineering")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ if (coerceNestedTypes && withSchemaEvolution) {
sql(mergeStmt)
+ // Missing field c2 should be filled with NULL
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Map(Row(10, true) -> Row("x")), "hr"),
+ Row(1, Map(Row(10, null) -> Row("z")), "sales"),
+ Row(2, Map(Row(20, null) -> Row("w")), "engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
}
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot write incompatible data for the table ``: " +
- "Cannot find data for the output column `m`.`key`.`c2`."))
}
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
- sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
}
}
test("merge into with source missing fields in struct nested in map value") {
- Seq(true, false).foreach { coerceNestedTypes =>
- withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
- coerceNestedTypes.toString) {
- withTempView("source") {
- // Target table has struct with 2 fields in map value
- val targetSchema =
- StructType(Seq(
- StructField("pk", IntegerType, nullable = false),
- StructField("m", MapType(
- StructType(Seq(StructField("c1", IntegerType))),
- StructType(Seq(StructField("c1", StringType), StructField("c2", BooleanType))))),
- StructField("dep", StringType)))
- createTable(CatalogV2Util.structTypeToV2Columns(targetSchema))
-
- val targetData = Seq(
- Row(0, Map(Row(10) -> Row("x", true)), "hr"),
- Row(1, Map(Row(20) -> Row("y", false)), "sales"))
- spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
- .writeTo(tableNameAsString).append()
-
- // Source table has struct with only 1 field (c1) in map value - missing c2
- val sourceTableSchema = StructType(Seq(
- StructField("pk", IntegerType),
- StructField("m", MapType(
- StructType(Seq(StructField("c1", IntegerType))),
- StructType(Seq(StructField("c1", StringType))))), // missing c2
- StructField("dep", StringType)))
- val sourceData = Seq(
- Row(1, Map(Row(10) -> Row("z")), "sales"),
- Row(2, Map(Row(20) -> Row("w")), "engineering")
- )
- spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema)
- .createOrReplaceTempView("source")
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ withTempView("source") {
+ // Target table has struct with 2 fields in map value
+ val targetSchema =
+ StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("m", MapType(
+ StructType(Seq(StructField("c1", IntegerType))),
+ StructType(Seq(StructField("c1", StringType), StructField("c2", BooleanType))))),
+ StructField("dep", StringType)))
+ createTable(CatalogV2Util.structTypeToV2Columns(targetSchema))
- val mergeStmt =
- s"""MERGE INTO $tableNameAsString t
- |USING source src
- |ON t.pk = src.pk
- |WHEN MATCHED THEN
- | UPDATE SET *
- |WHEN NOT MATCHED THEN
- | INSERT *
- |""".stripMargin
+ val targetData = Seq(
+ Row(0, Map(Row(10) -> Row("x", true)), "hr"),
+ Row(1, Map(Row(20) -> Row("y", false)), "sales"))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
+ .writeTo(tableNameAsString).append()
- if (coerceNestedTypes) {
- sql(mergeStmt)
- // Missing field c2 should be filled with NULL
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(0, Map(Row(10) -> Row("x", true)), "hr"),
- Row(1, Map(Row(10) -> Row("z", null)), "sales"),
- Row(2, Map(Row(20) -> Row("w", null)), "engineering")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ // Source table has struct with only 1 field (c1) in map value - missing c2
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("m", MapType(
+ StructType(Seq(StructField("c1", IntegerType))),
+ StructType(Seq(StructField("c1", StringType))))), // missing c2
+ StructField("dep", StringType)))
+ val sourceData = Seq(
+ Row(1, Map(Row(10) -> Row("z")), "sales"),
+ Row(2, Map(Row(20) -> Row("w")), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
+
+ if (coerceNestedTypes && withSchemaEvolution) {
sql(mergeStmt)
+ // Missing field c2 should be filled with NULL
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Map(Row(10) -> Row("x", true)), "hr"),
+ Row(1, Map(Row(10) -> Row("z", null)), "sales"),
+ Row(2, Map(Row(20) -> Row("w", null)), "engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
}
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot write incompatible data for the table ``: " +
- "Cannot find data for the output column `m`.`value`.`c2`."))
}
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
- sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
}
}
test("merge into with source missing fields in top-level struct") {
- Seq(true, false).foreach { coerceNestedTypes =>
- withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
- coerceNestedTypes.toString) {
- withTempView("source") {
- // Target table has struct with 3 fields at top level
- createAndInitTable(
- s"""pk INT NOT NULL,
-             |s STRUCT<c1 INT, c2 STRING, c3 BOOLEAN>,
- |dep STRING""".stripMargin,
- """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep": "sales"}""")
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ withTempView("source") {
+ // Target table has struct with 3 fields at top level
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+               |s STRUCT<c1 INT, c2 STRING, c3 BOOLEAN>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep": "sales"}""")
- // Source table has struct with only 2 fields (c1, c2) - missing c3
- val sourceTableSchema = StructType(Seq(
- StructField("pk", IntegerType, nullable = false),
- StructField("s", StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StringType)))), // missing c3 field
- StructField("dep", StringType)))
- val data = Seq(
- Row(1, Row(10, "b"), "hr"),
- Row(2, Row(20, "c"), "engineering")
- )
- spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
- .createOrReplaceTempView("source")
+ // Source table has struct with only 2 fields (c1, c2) - missing c3
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType)))), // missing c3 field
+ StructField("dep", StringType)))
+ val data = Seq(
+ Row(1, Row(10, "b"), "hr"),
+ Row(2, Row(20, "c"), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+ .createOrReplaceTempView("source")
- val mergeStmt =
- s"""MERGE INTO $tableNameAsString t
- |USING source src
- |ON t.pk = src.pk
- |WHEN MATCHED THEN
- | UPDATE SET *
- |WHEN NOT MATCHED THEN
- | INSERT *
- |""".stripMargin
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
- if (coerceNestedTypes) {
- sql(mergeStmt)
- // Missing field c3 should be filled with NULL
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(0, Row(1, "a", true), "sales"),
- Row(1, Row(10, "b", null), "hr"),
- Row(2, Row(20, "c", null), "engineering")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ if (coerceNestedTypes && withSchemaEvolution) {
sql(mergeStmt)
+ // Missing field c3 should be filled with NULL
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, "a", true), "sales"),
+ Row(1, Row(10, "b", null), "hr"),
+ Row(2, Row(20, "c", null), "engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
}
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot write incompatible data for the table ``: " +
- "Cannot find data for the output column `s`.`c3`."))
}
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
- sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
}
}
+ test("merge into with source missing top-level column") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ withTempView("source") {
+ // Target table has 3 columns: pk, salary, dep
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |salary INT,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "salary": 100, "dep": "sales" }
+ |{ "pk": 1, "salary": 200, "dep": "hr" }"""
+ .stripMargin)
+
+ // Source table has only 2 columns: pk, dep (missing salary)
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("dep", StringType)))
+ val data = Seq(
+ Row(1, "engineering"),
+ Row(2, "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
+
+ if (withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, 100, "sales"),
+ Row(1, 200, "engineering"),
+ Row(2, null, "finance")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
+ "UNRESOLVED_COLUMN.WITH_SUGGESTION")
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+
test("merge with null struct") {
withTempView("source") {
createAndInitTable(
@@ -4884,70 +4867,69 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
}
test("merge with with null struct with missing nested field") {
- Seq(true, false).foreach { coerceNestedTypes =>
- withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
- coerceNestedTypes.toString) {
- withTempView("source") {
- // Target table has nested struct with fields c1 and c2
- createAndInitTable(
- s"""pk INT NOT NULL,
-             |s STRUCT<c1 INT, c2 STRUCT<a INT, b STRING>>,
- |dep STRING""".stripMargin,
- """{ "pk": 0, "s": { "c1": 1, "c2": { "a": 10, "b": "x" } }, "dep": "sales" }
- |{ "pk": 1, "s": { "c1": 2, "c2": { "a": 20, "b": "y" } }, "dep": "hr" }"""
- .stripMargin)
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ withTempView("source") {
+ // Target table has nested struct with fields c1 and c2
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+               |s STRUCT<c1 INT, c2 STRUCT<a INT, b STRING>>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "s": { "c1": 1, "c2": { "a": 10, "b": "x" } }, "dep": "sales" }
+ |{ "pk": 1, "s": { "c1": 2, "c2": { "a": 20, "b": "y" } }, "dep": "hr" }"""
+ .stripMargin)
- // Source table has null for the nested struct
- val sourceTableSchema = StructType(Seq(
- StructField("pk", IntegerType),
- StructField("s", StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StructType(Seq(
- StructField("a", IntegerType)
- // missing field 'b'
- )))
- ))),
- StructField("dep", StringType)
- ))
+ // Source table has null for the nested struct
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType)
+ // missing field 'b'
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
- val data = Seq(
- Row(1, null, "engineering"),
- Row(2, null, "finance")
- )
- spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
- .createOrReplaceTempView("source")
+ val data = Seq(
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+ .createOrReplaceTempView("source")
- val mergeStmt =
- s"""MERGE INTO $tableNameAsString t USING source
- |ON t.pk = source.pk
- |WHEN MATCHED THEN
- | UPDATE SET *
- |WHEN NOT MATCHED THEN
- | INSERT *
- |""".stripMargin
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t USING source
+ |ON t.pk = source.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
- if (coerceNestedTypes) {
- sql(mergeStmt)
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(0, Row(1, Row(10, "x")), "sales"),
- Row(1, null, "engineering"),
- Row(2, null, "finance")))
- } else {
- // Without coercion, the merge should fail due to missing field
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ if (coerceNestedTypes && withSchemaEvolution) {
sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
}
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot write incompatible data for the table ``: " +
- "Cannot find data for the output column `s`.`c2`.`b`."))
}
}
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
- sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
}
@@ -4998,37 +4980,21 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
| INSERT *
|""".stripMargin
- if (coerceNestedTypes) {
- if (withSchemaEvolution) {
- // extra nested field is added
- sql(mergeStmt)
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(0, Row(1, Row(10, "x", null)), "sales"),
- Row(1, null, "engineering"),
- Row(2, null, "finance")))
- } else {
- // extra nested field is not added
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- sql(mergeStmt)
- }
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
- assert(exception.getMessage.contains(
- "Cannot write incompatible data for the table ``: " +
- "Cannot write extra fields `c` to the struct `s`.`c2`"))
- }
+ if (coerceNestedTypes && withSchemaEvolution) {
+ // extra nested field is added
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, Row(10, "x", null)), "sales"),
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")))
} else {
- // Without source struct coercion, the merge should fail
val exception = intercept[org.apache.spark.sql.AnalysisException] {
sql(mergeStmt)
}
assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot write incompatible data for the table ``: " +
- "Cannot find data for the output column `s`.`c2`.`b`."))
}
}
}
@@ -5097,82 +5063,83 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
}
test("merge with null struct using default value") {
- Seq(true, false).foreach { coerceNestedTypes =>
- withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
- coerceNestedTypes.toString) {
- withTempView("source") {
- sql(
- s"""CREATE TABLE $tableNameAsString (
- | pk INT NOT NULL,
- | s STRUCT<c1 INT, c2 STRUCT<a INT, b STRING>> DEFAULT
- | named_struct('c1', 999, 'c2', named_struct('a', 999, 'b', 'default')),
- | dep STRING)
- |PARTITIONED BY (dep)
- |""".stripMargin)
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ withTempView("source") {
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ | pk INT NOT NULL,
+ | s STRUCT<c1 INT, c2 STRUCT<a INT, b STRING>> DEFAULT
+ | named_struct('c1', 999, 'c2', named_struct('a', 999, 'b', 'default')),
+ | dep STRING)
+ |PARTITIONED BY (dep)
+ |""".stripMargin)
- val initialSchema = StructType(Seq(
- StructField("pk", IntegerType, nullable = false),
- StructField("s", StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StructType(Seq(
- StructField("a", IntegerType),
- StructField("b", StringType)
- )))
- ))),
- StructField("dep", StringType)
- ))
- val initialData = Seq(
- Row(0, Row(1, Row(10, "x")), "sales"),
- Row(1, Row(2, Row(20, "y")), "hr")
- )
- spark.createDataFrame(spark.sparkContext.parallelize(initialData), initialSchema)
- .writeTo(tableNameAsString).append()
+ val initialSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ val initialData = Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, Row(2, Row(20, "y")), "hr")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(initialData), initialSchema)
+ .writeTo(tableNameAsString).append()
- val sourceTableSchema = StructType(Seq(
- StructField("pk", IntegerType),
- StructField("s", StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StructType(Seq(
- StructField("a", IntegerType)
- )))
- ))),
- StructField("dep", StringType)
- ))
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
- val data = Seq(
- Row(1, null, "engineering"),
- Row(2, null, "finance")
- )
- spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
- .createOrReplaceTempView("source")
+ val data = Seq(
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+ .createOrReplaceTempView("source")
- val mergeStmt =
- s"""MERGE INTO $tableNameAsString t USING source
- |ON t.pk = source.pk
- |WHEN MATCHED THEN
- | UPDATE SET *
- |WHEN NOT MATCHED THEN
- | INSERT *
- |""".stripMargin
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t USING source
+ |ON t.pk = source.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
- if (coerceNestedTypes) {
- sql(mergeStmt)
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(0, Row(1, Row(10, "x")), "sales"),
- Row(1, null, "engineering"),
- Row(2, null, "finance")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ if (coerceNestedTypes && withSchemaEvolution) {
sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
}
- assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot find data for the output column `s`.`c2`.`b`"))
}
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
- sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
}
}
@@ -5243,7 +5210,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
}
test("merge into with source missing fields in nested struct") {
- Seq(true, false).foreach { nestedTypeCoercion =>
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { nestedTypeCoercion =>
withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key
-> nestedTypeCoercion.toString) {
withTempView("source") {
@@ -5275,7 +5243,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
.createOrReplaceTempView("source")
// Missing field b should be filled with NULL
- val mergeStmt = s"""MERGE INTO $tableNameAsString t
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+ val mergeStmt = s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t
|USING source src
|ON t.pk = src.pk
|WHEN MATCHED THEN
@@ -5284,7 +5253,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
| INSERT *
|""".stripMargin
- if (nestedTypeCoercion) {
+ if (nestedTypeCoercion && withSchemaEvolution) {
sql(mergeStmt)
checkAnswer(
sql(s"SELECT * FROM $tableNameAsString"),
@@ -5292,16 +5261,17 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
Row(1, Row(10, Row(20, null)), "sales"),
Row(2, Row(20, Row(30, null)), "engineering")))
} else {
- val exception = intercept[Exception] {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
sql(mergeStmt)
}
- assert(exception.getMessage.contains(
- """Cannot write incompatible data for the table ``""".stripMargin))
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
}
}
sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
}
+ }
}
test("merge with named_struct missing non-nullable field") {
@@ -5355,51 +5325,68 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
}
test("merge with struct missing nested field with check constraint") {
- withTempView("source") {
- // Target table has struct with nested field c2
- createAndInitTable(
- s"""pk INT NOT NULL,
- |s STRUCT<c1 INT, c2 INT>,
- |dep STRING""".stripMargin,
- """{ "pk": 0, "s": { "c1": 1, "c2": 10 }, "dep": "sales" }
- |{ "pk": 1, "s": { "c1": 2, "c2": 20 }, "dep": "hr" }"""
- .stripMargin)
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coercionEnabled =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coercionEnabled.toString) {
+ withTempView("source") {
+ // Target table has struct with nested field c2
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |s STRUCT<c1 INT, c2 INT>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "s": { "c1": 1, "c2": 10 }, "dep": "sales" }
+ |{ "pk": 1, "s": { "c1": 2, "c2": 20 }, "dep": "hr" }"""
+ .stripMargin)
- // Add CHECK constraint on nested field c2 using ALTER TABLE
- sql(s"ALTER TABLE $tableNameAsString ADD CONSTRAINT check_c2 CHECK " +
- s"(s.c2 IS NOT NULL AND s.c2 > 1)")
+ // Add CHECK constraint on nested field c2 using ALTER TABLE
+ sql(s"ALTER TABLE $tableNameAsString ADD CONSTRAINT check_c2 CHECK " +
+ s"(s.c2 IS NOT NULL AND s.c2 > 1)")
- // Source table schema with struct missing the c2 field
- val sourceTableSchema = StructType(Seq(
- StructField("pk", IntegerType),
- StructField("s", StructType(Seq(
- StructField("c1", IntegerType)
- // missing field 'c2' which has CHECK constraint IS NOT NULL AND > 1
- ))),
- StructField("dep", StringType)
- ))
+ // Source table schema with struct missing the c2 field
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType)
+ // missing field 'c2' which has CHECK constraint IS NOT NULL AND > 1
+ ))),
+ StructField("dep", StringType)
+ ))
- val data = Seq(
- Row(1, Row(100), "engineering"),
- Row(2, Row(200), "finance")
- )
- spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
- .createOrReplaceTempView("source")
+ val data = Seq(
+ Row(1, Row(100), "engineering"),
+ Row(2, Row(200), "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+ .createOrReplaceTempView("source")
- withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> "true") {
- val error = intercept[SparkRuntimeException] {
- sql(
- s"""MERGE INTO $tableNameAsString t USING source
- |ON t.pk = source.pk
- |WHEN MATCHED THEN
- | UPDATE SET s = source.s, dep = source.dep
- |""".stripMargin)}
- assert(error.getCondition == "CHECK_CONSTRAINT_VIOLATION")
- assert(error.getMessage.contains("CHECK constraint check_c2 s.c2 IS NOT NULL AND " +
- "s.c2 > 1 violated by row with values:\n - s.c2 : null"))
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t USING source
+ |ON t.pk = source.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET s = source.s, dep = source.dep
+ |""".stripMargin
+
+ if (withSchemaEvolution && coercionEnabled) {
+ val error = intercept[SparkRuntimeException] {
+ sql(mergeStmt)
+ }
+ assert(error.getCondition == "CHECK_CONSTRAINT_VIOLATION")
+ assert(error.getMessage.contains("CHECK constraint check_c2 s.c2 IS NOT NULL AND " +
+ "s.c2 > 1 violated by row with values:\n - s.c2 : null"))
+ } else {
+ // Without schema evolution or coercion, the schema mismatch is rejected
+ val error = intercept[AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(error.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
}
}
- sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
test("merge with schema evolution using dataframe API: add new column and set all") {
@@ -5926,30 +5913,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
.whenNotMatched()
.insertAll()
- if (coerceNestedTypes) {
- if (withSchemaEvolution) {
- mergeBuilder.withSchemaEvolution().merge()
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"),
- Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- mergeBuilder.merge()
- }
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
- assert(exception.getMessage.contains(
- "Cannot write extra fields `c3` to the struct `s`.`c2`"))
- }
+ if (coerceNestedTypes && withSchemaEvolution) {
+ mergeBuilder.withSchemaEvolution().merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"),
+ Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
} else {
val exception = intercept[org.apache.spark.sql.AnalysisException] {
mergeBuilder.merge()
}
assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot find data for the output column `s`.`c2`.`a`"))
}
sql(s"DROP TABLE $tableNameAsString")
@@ -6038,30 +6013,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
.whenNotMatched()
.insertAll()
- if (coerceNestedTypes) {
- if (withSchemaEvolution) {
- mergeBuilder.withSchemaEvolution().merge()
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"),
- Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- mergeBuilder.merge()
- }
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
- assert(exception.getMessage.contains(
- "Cannot write extra fields `c3` to the struct `s`.`c2`"))
- }
+ if (coerceNestedTypes && withSchemaEvolution) {
+ mergeBuilder.withSchemaEvolution().merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"),
+ Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering")))
} else {
val exception = intercept[org.apache.spark.sql.AnalysisException] {
mergeBuilder.merge()
}
assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot find data for the output column `s`.`c2`.`a`"))
}
sql(s"DROP TABLE $tableNameAsString")
@@ -6132,198 +6095,190 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
}
test("merge into with source missing fields in top-level struct using dataframe API") {
- Seq(true, false).foreach { coerceNestedTypes =>
- withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
- coerceNestedTypes.toString) {
- val sourceTable = "cat.ns1.source_table"
- withTable(sourceTable) {
- // Target table has struct with 3 fields at top level
- sql(
- s"""CREATE TABLE $tableNameAsString (
- |pk INT NOT NULL,
- |s STRUCT<c1 INT, c2 STRING, c3 BOOLEAN>,
- |dep STRING)""".stripMargin)
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ // Target table has struct with 3 fields at top level
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ |pk INT NOT NULL,
+ |s STRUCT<c1 INT, c2 STRING, c3 BOOLEAN>,
+ |dep STRING)""".stripMargin)
- val targetData = Seq(
- Row(0, Row(1, "a", true), "sales")
- )
- val targetSchema = StructType(Seq(
- StructField("pk", IntegerType, nullable = false),
- StructField("s", StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StringType),
- StructField("c3", BooleanType)
- ))),
- StructField("dep", StringType)
- ))
- spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
- .writeTo(tableNameAsString).append()
+ val targetData = Seq(
+ Row(0, Row(1, "a", true), "sales")
+ )
+ val targetSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType),
+ StructField("c3", BooleanType)
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
+ .writeTo(tableNameAsString).append()
- // Create source table with struct having only 2 fields (c1, c2) - missing c3
- val sourceIdent = Identifier.of(Array("ns1"), "source_table")
- val columns = Array(
- Column.create("pk", IntegerType, false),
- Column.create("s", StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StringType)))), // missing c3 field
- Column.create("dep", StringType))
- val tableInfo = new TableInfo.Builder()
- .withColumns(columns)
- .withProperties(extraTableProps)
- .build()
- catalog.createTable(sourceIdent, tableInfo)
+ // Create source table with struct having only 2 fields (c1, c2) - missing c3
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType)))), // missing c3 field
+ Column.create("dep", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
- val data = Seq(
- Row(1, Row(10, "b"), "hr"),
- Row(2, Row(20, "c"), "engineering")
- )
- val sourceTableSchema = StructType(Seq(
- StructField("pk", IntegerType, nullable = false),
- StructField("s", StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StringType)))),
- StructField("dep", StringType)))
- spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
- .createOrReplaceTempView("source_temp")
+ val data = Seq(
+ Row(1, Row(10, "b"), "hr"),
+ Row(2, Row(20, "c"), "engineering")
+ )
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType)))),
+ StructField("dep", StringType)))
+ spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+ .createOrReplaceTempView("source_temp")
- sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+ sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
- if (coerceNestedTypes) {
- spark.table(sourceTable)
+ val mergeBuilder = spark.table(sourceTable)
.mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk"))
.whenMatched()
.updateAll()
.whenNotMatched()
.insertAll()
- .merge()
- // Missing field c3 should be filled with NULL
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(0, Row(1, "a", true), "sales"),
- Row(1, Row(10, "b", null), "hr"),
- Row(2, Row(20, "c", null), "engineering")))
- } else {
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- spark.table(sourceTable)
- .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk"))
- .whenMatched()
- .updateAll()
- .whenNotMatched()
- .insertAll()
- .merge()
+ if (coerceNestedTypes && withSchemaEvolution) {
+ mergeBuilder.withSchemaEvolution().merge()
+
+ // Missing field c3 should be filled with NULL
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, "a", true), "sales"),
+ Row(1, Row(10, "b", null), "hr"),
+ Row(2, Row(20, "c", null), "engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ mergeBuilder.merge()
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
}
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot write incompatible data for the table ``: " +
- "Cannot find data for the output column `s`.`c3`."))
- }
- sql(s"DROP TABLE $tableNameAsString")
+ sql(s"DROP TABLE $tableNameAsString")
+ }
}
}
}
}
test("merge with null struct with missing nested field using dataframe API") {
- Seq(true, false).foreach { coerceNestedTypes =>
- withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
- coerceNestedTypes.toString) {
- val sourceTable = "cat.ns1.source_table"
- withTable(sourceTable) {
- // Target table has nested struct with fields c1 and c2
- sql(
- s"""CREATE TABLE $tableNameAsString (
- |pk INT NOT NULL,
- |s STRUCT<c1 INT, c2 STRUCT<a INT, b STRING>>,
- |dep STRING)""".stripMargin)
-
- val targetData = Seq(
- Row(0, Row(1, Row(10, "x")), "sales"),
- Row(1, Row(2, Row(20, "y")), "hr")
- )
- val targetSchema = StructType(Seq(
- StructField("pk", IntegerType, nullable = false),
- StructField("s", StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StructType(Seq(
- StructField("a", IntegerType),
- StructField("b", StringType)
- )))
- ))),
- StructField("dep", StringType)
- ))
- spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
- .writeTo(tableNameAsString).append()
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ val sourceTable = "cat.ns1.source_table"
+ withTable(sourceTable) {
+ // Target table has nested struct with fields c1 and c2
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ |pk INT NOT NULL,
+ |s STRUCT<c1 INT, c2 STRUCT<a INT, b STRING>>,
+ |dep STRING)""".stripMargin)
- // Create source table with missing nested field 'b'
- val sourceIdent = Identifier.of(Array("ns1"), "source_table")
- val columns = Array(
- Column.create("pk", IntegerType, false),
- Column.create("s", StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StructType(Seq(
- StructField("a", IntegerType)
- // missing field 'b'
- )))
- ))),
- Column.create("dep", StringType))
- val tableInfo = new TableInfo.Builder()
- .withColumns(columns)
- .withProperties(extraTableProps)
- .build()
- catalog.createTable(sourceIdent, tableInfo)
+ val targetData = Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, Row(2, Row(20, "y")), "hr")
+ )
+ val targetSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema)
+ .writeTo(tableNameAsString).append()
- // Source table has null for the nested struct
- val data = Seq(
- Row(1, null, "engineering"),
- Row(2, null, "finance")
- )
- val sourceTableSchema = StructType(Seq(
- StructField("pk", IntegerType),
- StructField("s", StructType(Seq(
- StructField("c1", IntegerType),
- StructField("c2", StructType(Seq(
- StructField("a", IntegerType)
- )))
- ))),
- StructField("dep", StringType)
- ))
- spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
- .createOrReplaceTempView("source_temp")
+ // Create source table with missing nested field 'b'
+ val sourceIdent = Identifier.of(Array("ns1"), "source_table")
+ val columns = Array(
+ Column.create("pk", IntegerType, false),
+ Column.create("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType)
+ // missing field 'b'
+ )))
+ ))),
+ Column.create("dep", StringType))
+ val tableInfo = new TableInfo.Builder()
+ .withColumns(columns)
+ .withProperties(extraTableProps)
+ .build()
+ catalog.createTable(sourceIdent, tableInfo)
- sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
- val mergeBuilder = spark.table(sourceTable)
- .mergeInto(tableNameAsString,
- $"source_table.pk" === col(tableNameAsString + ".pk"))
- .whenMatched()
- .updateAll()
- .whenNotMatched()
- .insertAll()
+ // Source table has null for the nested struct
+ val data = Seq(
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")
+ )
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType)
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+ .createOrReplaceTempView("source_temp")
- if (coerceNestedTypes) {
- mergeBuilder.merge()
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(0, Row(1, Row(10, "x")), "sales"),
- Row(1, null, "engineering"),
- Row(2, null, "finance")))
- } else {
- // Without coercion, the merge should fail due to missing field
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- mergeBuilder.merge()
+ sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp")
+ val mergeBuilder = spark.table(sourceTable)
+ .mergeInto(tableNameAsString,
+ $"source_table.pk" === col(tableNameAsString + ".pk"))
+ .whenMatched()
+ .updateAll()
+ .whenNotMatched()
+ .insertAll()
+
+ if (coerceNestedTypes && withSchemaEvolution) {
+ mergeBuilder.withSchemaEvolution().merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, Row(10, "x")), "sales"),
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ mergeBuilder.merge()
+ }
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
}
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot write incompatible data for the table ``: " +
- "Cannot find data for the output column `s`.`c2`.`b`."))
- }
- sql(s"DROP TABLE $tableNameAsString")
+ sql(s"DROP TABLE $tableNameAsString")
+ }
}
}
}
@@ -6409,37 +6364,21 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
.whenNotMatched()
.insertAll()
- if (coerceNestedTypes) {
- if (withSchemaEvolution) {
- // extra nested field is added
- mergeBuilder.withSchemaEvolution().merge()
- checkAnswer(
- sql(s"SELECT * FROM $tableNameAsString"),
- Seq(
- Row(0, Row(1, Row(10, "x", null)), "sales"),
- Row(1, null, "engineering"),
- Row(2, null, "finance")))
- } else {
- // extra nested field is not added
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- mergeBuilder.merge()
- }
- assert(exception.errorClass.get ==
- "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
- assert(exception.getMessage.contains(
- "Cannot write incompatible data for the table ``: " +
- "Cannot write extra fields `c` to the struct `s`.`c2`"))
- }
+ if (coerceNestedTypes && withSchemaEvolution) {
+ // extra nested field is added
+ mergeBuilder.withSchemaEvolution().merge()
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, Row(10, "x", null)), "sales"),
+ Row(1, null, "engineering"),
+ Row(2, null, "finance")))
} else {
- // Without source struct coercion, the merge should fail
val exception = intercept[org.apache.spark.sql.AnalysisException] {
mergeBuilder.merge()
}
assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains(
- "Cannot write incompatible data for the table ``: " +
- "Cannot find data for the output column `s`.`c2`.`b`."))
}
sql(s"DROP TABLE $tableNameAsString")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
index f635131dc3f7..ddc864447141 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
@@ -690,28 +690,36 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase {
|""".stripMargin)
assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i"))
- Seq(true, false).foreach { coerceNestedTypes =>
- withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
- coerceNestedTypes.toString) {
- val mergeStmt =
- s"""MERGE INTO nested_struct_table t USING nested_struct_table src
- |ON t.i = src.i
- |$clause THEN
- | UPDATE SET s.n_s = named_struct('dn_i', 2L)
- |""".stripMargin
- if (coerceNestedTypes) {
- val plan5 = parseAndResolve(mergeStmt)
- // No null check for dn_i as it is explicitly set
- assertNoNullCheckExists(plan5)
- } else {
- val e = intercept[AnalysisException] {
- parseAndResolve(mergeStmt)
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ val schemaEvolutionString = if (withSchemaEvolution) {
+ "WITH SCHEMA EVOLUTION"
+ } else {
+ ""
+ }
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionString INTO nested_struct_table t
+ |USING nested_struct_table src
+ |ON t.i = src.i
+ |$clause THEN
+ | UPDATE SET s.n_s = named_struct('dn_i', 2L)
+ |""".stripMargin
+ if (coerceNestedTypes && withSchemaEvolution) {
+ val plan5 = parseAndResolve(mergeStmt)
+ // No null check for dn_i as it is explicitly set
+ assertNoNullCheckExists(plan5)
+ } else {
+ val e = intercept[AnalysisException] {
+ parseAndResolve(mergeStmt)
+ }
+ checkError(
+ exception = e,
+ condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+ parameters = Map("tableName" -> "``", "colName" -> "`s`.`n_s`.`dn_l`")
+ )
}
- checkError(
- exception = e,
- condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- parameters = Map("tableName" -> "``", "colName" -> "`s`.`n_s`.`dn_l`")
- )
}
}
}
@@ -857,27 +865,35 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase {
|""".stripMargin)
assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i"))
- Seq(true, false).foreach { coerceNestedTypes =>
- withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
- coerceNestedTypes.toString) {
- val mergeStmt =
- s"""MERGE INTO nested_struct_table t USING nested_struct_table src
- |ON t.i = src.i
- |$clause THEN
- | UPDATE SET s.n_s = named_struct('dn_i', 1)
- |""".stripMargin
- if (coerceNestedTypes) {
- val plan5 = parseAndResolve(mergeStmt)
- // No null check for dn_i as it is explicitly set
- assertNoNullCheckExists(plan5)
- } else {
- val e = intercept[AnalysisException] {
- parseAndResolve(mergeStmt)
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ val schemaEvolutionString = if (withSchemaEvolution) {
+ "WITH SCHEMA EVOLUTION"
+ } else {
+ ""
+ }
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionString INTO nested_struct_table t
+ |USING nested_struct_table src
+ |ON t.i = src.i
+ |$clause THEN
+ | UPDATE SET s.n_s = named_struct('dn_i', 1)
+ |""".stripMargin
+ if (coerceNestedTypes && withSchemaEvolution) {
+ val plan5 = parseAndResolve(mergeStmt)
+ // No null check for dn_i as it is explicitly set
+ assertNoNullCheckExists(plan5)
+ } else {
+ val e = intercept[AnalysisException] {
+ parseAndResolve(mergeStmt)
+ }
+ checkError(
+ exception = e,
+ condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+ parameters = Map("tableName" -> "``", "colName" -> "`s`.`n_s`.`dn_l`"))
}
- checkError(
- exception = e,
- condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- parameters = Map("tableName" -> "``", "colName" -> "`s`.`n_s`.`dn_l`"))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index 18042bf73adf..8e5ee1644f9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -1633,10 +1633,10 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
if (starInUpdate) {
assert(updateAssigns.size == 2)
- assert(updateAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ts))
- assert(updateAssigns(0).value.asInstanceOf[AttributeReference].sameRef(ss))
- assert(updateAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ti))
- assert(updateAssigns(1).value.asInstanceOf[AttributeReference].sameRef(si))
+ assert(updateAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ti))
+ assert(updateAssigns(0).value.asInstanceOf[AttributeReference].sameRef(si))
+ assert(updateAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ts))
+ assert(updateAssigns(1).value.asInstanceOf[AttributeReference].sameRef(ss))
} else {
assert(updateAssigns.size == 1)
assert(updateAssigns.head.key.asInstanceOf[AttributeReference].sameRef(ts))
@@ -1656,10 +1656,10 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
if (starInInsert) {
assert(insertAssigns.size == 2)
- assert(insertAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ts))
- assert(insertAssigns(0).value.asInstanceOf[AttributeReference].sameRef(ss))
- assert(insertAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ti))
- assert(insertAssigns(1).value.asInstanceOf[AttributeReference].sameRef(si))
+ assert(insertAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ti))
+ assert(insertAssigns(0).value.asInstanceOf[AttributeReference].sameRef(si))
+ assert(insertAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ts))
+ assert(insertAssigns(1).value.asInstanceOf[AttributeReference].sameRef(ss))
} else {
assert(insertAssigns.size == 2)
assert(insertAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ti))
@@ -1720,8 +1720,40 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString)
}
+ // star with schema evolution
+ val sqlStarSchemaEvolution =
+ s"""
+ |MERGE WITH SCHEMA EVOLUTION INTO $target AS target
+ |USING $source AS source
+ |ON target.i = source.i
+ |WHEN MATCHED AND (target.s='delete') THEN DELETE
+ |WHEN MATCHED AND (target.s='update') THEN UPDATE SET *
+ |WHEN NOT MATCHED AND (source.s='insert') THEN INSERT *
+ """.stripMargin
+ parseAndResolve(sqlStarSchemaEvolution) match {
+ case MergeIntoTable(
+ SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)),
+ SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(source)),
+ mergeCondition,
+ Seq(DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("delete")))),
+ UpdateAction(Some(EqualTo(ul: AttributeReference,
+ StringLiteral("update"))), updateAssigns, _)),
+ Seq(InsertAction(Some(EqualTo(il: AttributeReference, StringLiteral("insert"))),
+ insertAssigns)),
+ Seq(),
+ withSchemaEvolution) =>
+ checkMergeConditionResolution(target, source, mergeCondition)
+ checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns,
+ starInUpdate = true)
+ checkNotMatchedClausesResolution(target, source, Some(il), insertAssigns,
+ starInInsert = true)
+ assert(withSchemaEvolution === true)
+
+ case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString)
+ }
+
// star
- val sql2 =
+ val sqlStarWithoutSchemaEvolution =
s"""
|MERGE INTO $target AS target
|USING $source AS source
@@ -1730,7 +1762,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
|WHEN MATCHED AND (target.s='update') THEN UPDATE SET *
|WHEN NOT MATCHED AND (source.s='insert') THEN INSERT *
""".stripMargin
- parseAndResolve(sql2) match {
+ parseAndResolve(sqlStarWithoutSchemaEvolution) match {
case MergeIntoTable(
SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)),
SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(source)),
@@ -2336,24 +2368,11 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
|USING testcat.tab2
|ON 1 = 1
|WHEN MATCHED THEN UPDATE SET *""".stripMargin
- val parsed2 = parseAndResolve(sql2)
- parsed2 match {
- case MergeIntoTable(
- AsDataSourceV2Relation(target),
- AsDataSourceV2Relation(source),
- EqualTo(IntegerLiteral(1), IntegerLiteral(1)),
- Seq(UpdateAction(None, updateAssigns, _)), // Matched actions
- Seq(), // Not matched actions
- Seq(), // Not matched by source actions
- withSchemaEvolution) =>
- val ti = target.output.find(_.name == "i").get
- val si = source.output.find(_.name == "i").get
- assert(updateAssigns.size == 1)
- assert(updateAssigns.head.key.asInstanceOf[AttributeReference].sameRef(ti))
- assert(updateAssigns.head.value.asInstanceOf[AttributeReference].sameRef(si))
- assert(withSchemaEvolution === false)
- case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString)
- }
+ checkError(
+ exception = intercept[AnalysisException](parseAndResolve(sql2)),
+ condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ parameters = Map("objectName" -> "`s`", "proposal" -> "`i`, `x`"),
+ context = ExpectedContext(fragment = sql2, start = 0, stop = 80))
// INSERT * with incompatible schema between source and target tables.
val sql3 =
@@ -2361,24 +2380,11 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest {
|USING testcat.tab2
|ON 1 = 1
|WHEN NOT MATCHED THEN INSERT *""".stripMargin
- val parsed3 = parseAndResolve(sql3)
- parsed3 match {
- case MergeIntoTable(
- AsDataSourceV2Relation(target),
- AsDataSourceV2Relation(source),
- EqualTo(IntegerLiteral(1), IntegerLiteral(1)),
- Seq(), // Matched action
- Seq(InsertAction(None, insertAssigns)), // Not matched actions
- Seq(), // Not matched by source actions
- withSchemaEvolution) =>
- val ti = target.output.find(_.name == "i").get
- val si = source.output.find(_.name == "i").get
- assert(insertAssigns.size == 1)
- assert(insertAssigns.head.key.asInstanceOf[AttributeReference].sameRef(ti))
- assert(insertAssigns.head.value.asInstanceOf[AttributeReference].sameRef(si))
- assert(withSchemaEvolution === false)
- case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString)
- }
+ checkError(
+ exception = intercept[AnalysisException](parseAndResolve(sql3)),
+ condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ parameters = Map("objectName" -> "`s`", "proposal" -> "`i`, `x`"),
+ context = ExpectedContext(fragment = sql3, start = 0, stop = 80))
val sql4 =
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala
new file mode 100644
index 000000000000..15682bcf68f1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Connection
+
+import org.mockito.Mockito._
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+
+class PostgresDialectSuite extends SparkFunSuite with MockitoSugar {
+
+ private def createJDBCOptions(extraOptions: Map[String, String]): JDBCOptions = {
+ new JDBCOptions(Map(
+ "url" -> "jdbc:postgresql://localhost:5432/test",
+ "dbtable" -> "test_table"
+ ) ++ extraOptions)
+ }
+
+ test("beforeFetch sets autoCommit=false with lowercase fetchsize") {
+ val conn = mock[Connection]
+ val dialect = PostgresDialect()
+ dialect.beforeFetch(conn, createJDBCOptions(Map("fetchsize" -> "100")))
+ verify(conn).setAutoCommit(false)
+ }
+
+ test("beforeFetch sets autoCommit=false with camelCase fetchSize") {
+ val conn = mock[Connection]
+ val dialect = PostgresDialect()
+ dialect.beforeFetch(conn, createJDBCOptions(Map("fetchSize" -> "100")))
+ verify(conn).setAutoCommit(false)
+ }
+
+ test("beforeFetch sets autoCommit=false with uppercase FETCHSIZE") {
+ val conn = mock[Connection]
+ val dialect = PostgresDialect()
+ dialect.beforeFetch(conn, createJDBCOptions(Map("FETCHSIZE" -> "100")))
+ verify(conn).setAutoCommit(false)
+ }
+
+ test("beforeFetch does not set autoCommit when fetchSize is 0") {
+ val conn = mock[Connection]
+ val dialect = PostgresDialect()
+ dialect.beforeFetch(conn, createJDBCOptions(Map("fetchsize" -> "0")))
+ verify(conn, never()).setAutoCommit(false)
+ }
+}
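
The new suite above pins down the beforeFetch contract: whenever a non-zero fetch size is configured, under any casing of the option key, PostgresDialect disables autoCommit so the Postgres driver can stream results with a cursor. A minimal read-side sketch, with an assumed local connection string and table name:

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[*]").getOrCreate()

  // With a non-zero fetchsize, PostgresDialect.beforeFetch switches the JDBC
  // connection to autoCommit = false; otherwise the Postgres driver would
  // materialize the whole result set in memory instead of using a cursor.
  val df = spark.read
    .format("jdbc")
    .option("url", "jdbc:postgresql://localhost:5432/test") // assumed endpoint
    .option("dbtable", "test_table")                        // assumed table
    .option("fetchsize", "100") // "fetchSize" / "FETCHSIZE" behave the same
    .load()

  df.show()
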