diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml
index f769067557b8..d51e922f025a 100644
--- a/common/network-shuffle/pom.xml
+++ b/common/network-shuffle/pom.xml
@@ -42,11 +42,6 @@
       <version>${project.version}</version>
     </dependency>
 
-    <dependency>
-      <groupId>org.apache.commons</groupId>
-      <artifactId>commons-lang3</artifactId>
-    </dependency>
-
     <dependency>
       <groupId>io.dropwizard.metrics</groupId>
      <artifactId>metrics-core</artifactId>
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 3d885ffdb02d..63484c23a920 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -212,6 +212,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   protected val hideTraceback: Boolean = false
   protected val simplifiedTraceback: Boolean = false
 
+  protected val runnerConf: Map[String, String] = Map.empty
+
   // All the Python functions should have the same exec, version and envvars.
   protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars
   protected val pythonExec: String = funcs.head.funcs.head.pythonExec
@@ -403,6 +405,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
    */
   protected def writeCommand(dataOut: DataOutputStream): Unit
 
+  /**
+   * Writes worker configuration to the stream connected to the Python worker.
+   */
+  protected def writeRunnerConf(dataOut: DataOutputStream): Unit = {
+    dataOut.writeInt(runnerConf.size)
+    for ((k, v) <- runnerConf) {
+      PythonWorkerUtils.writeUTF(k, dataOut)
+      PythonWorkerUtils.writeUTF(v, dataOut)
+    }
+  }
+
   /**
    * Writes input data to the stream connected to the Python worker.
    * Returns true if any data was written to the stream, false if the input is exhausted.
@@ -532,6 +545,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
 
       dataOut.writeInt(evalType)
+      writeRunnerConf(dataOut)
       writeCommand(dataOut)
 
       dataOut.flush()
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 53c791a3446c..aa17f954a7d9 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -407,10 +407,6 @@ private[spark] class Executor(
             TaskState.FAILED,
             env.closureSerializer.newInstance().serialize(new ExceptionFailure(t, Seq.empty)))
         } catch {
-          case oom: OutOfMemoryError =>
-            logError(log"Executor update launching task ${MDC(TASK_NAME, taskDescription.name)} " +
-              log"failed status failed, reason: ${MDC(REASON, oom.getMessage)}")
-            System.exit(SparkExitCode.OOM)
           case t: Throwable =>
             logError(log"Executor update launching task ${MDC(TASK_NAME, taskDescription.name)} " +
               log"failed status failed, reason: ${MDC(REASON, t.getMessage)}")
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index 470082330509..270af9c205b4 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -30,7 +30,7 @@ azure-storage/7.0.1//azure-storage-7.0.1.jar
 blas/3.0.4//blas-3.0.4.jar
 breeze-macros_2.13/2.1.0//breeze-macros_2.13-2.1.0.jar
 breeze_2.13/2.1.0//breeze_2.13-2.1.0.jar
-bundle/2.29.52//bundle-2.29.52.jar
+bundle/2.35.4//bundle-2.35.4.jar
 cats-kernel_2.13/2.8.0//cats-kernel_2.13-2.8.0.jar
 chill-java/0.10.0//chill-java-0.10.0.jar
 chill_2.13/0.10.0//chill_2.13-0.10.0.jar
diff --git a/pom.xml b/pom.xml
index ea23bbf20d0b..95968580cc21 100644
--- a/pom.xml
+++ b/pom.xml
@@ -162,7
+162,7 @@ 1.15.3 1.12.681 - 2.29.52 + 2.35.4 1.0.5 @@ -233,7 +233,7 @@ ./python/packaging/client/setup.py, and ./python/packaging/connect/setup.py too. --> 18.3.0 - 3.0.4 + 3.0.5 0.12.6 diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 409c33d1727a..7495569ecfda 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -163,14 +163,26 @@ def handle_sigterm(*args): # Initialization complete try: + poller = None + if os.name == "posix": + # select.select has a known limit on the number of file descriptors + # it can handle. We use select.poll instead to avoid this limit. + poller = select.poll() + fd_reverse_map = {0: 0, listen_sock.fileno(): listen_sock} + poller.register(0, select.POLLIN) + poller.register(listen_sock, select.POLLIN) + while True: - try: - ready_fds = select.select([0, listen_sock], [], [], 1)[0] - except select.error as ex: - if ex[0] == EINTR: - continue - else: - raise + if poller is not None: + ready_fds = [fd_reverse_map[fd] for fd, _ in poller.poll(1000)] + else: + try: + ready_fds = select.select([0, listen_sock], [], [], 1)[0] + except select.error as ex: + if ex[0] == EINTR: + continue + else: + raise if 0 in ready_fds: try: @@ -208,6 +220,9 @@ def handle_sigterm(*args): if pid == 0: # in child process + if poller is not None: + poller.unregister(0) + poller.unregister(listen_sock) listen_sock.close() # It should close the standard input in the child process so that @@ -256,6 +271,9 @@ def handle_sigterm(*args): sock.close() finally: + if poller is not None: + poller.unregister(0) + poller.unregister(listen_sock) shutdown(1) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index f5bfd59694af..abecbf90aa95 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -622,7 +622,7 @@ class _LinearSVCParams( ) def __init__(self, *args: Any) -> None: - super(_LinearSVCParams, self).__init__(*args) + super().__init__(*args) self._setDefault( maxIter=100, regParam=0.0, @@ -743,7 +743,7 @@ def __init__( fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \ aggregationDepth=2, maxBlockSizeInMB=0.0): """ - super(LinearSVC, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LinearSVC", self.uid ) @@ -1019,7 +1019,7 @@ class _LogisticRegressionParams( ) def __init__(self, *args: Any): - super(_LogisticRegressionParams, self).__init__(*args) + super().__init__(*args) self._setDefault( maxIter=100, regParam=0.0, tol=1e-6, threshold=0.5, family="auto", maxBlockSizeInMB=0.0 ) @@ -1328,7 +1328,7 @@ def __init__( maxBlockSizeInMB=0.0): If the threshold and thresholds Params are both set, they must be equivalent. 
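
[Editor's note] The daemon.py hunk above switches the manager loop from select.select, which on POSIX is bounded by FD_SETSIZE (commonly 1024 file descriptors), to select.poll, which has no such cap. The standalone sketch below reproduces the same register/poll/translate-back pattern; the socket setup and the bounded loop are illustrative scaffolding, not code taken from daemon.py.

import os
import select
import socket

# Stand-in for the daemon's listen_sock; the address is illustrative.
listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(5)

poller = None
if os.name == "posix":
    # select.poll() is keyed by integer file descriptors and is not subject
    # to the FD_SETSIZE limit that caps select.select().
    poller = select.poll()
    # Map raw descriptors back to the objects the rest of the loop expects.
    fd_reverse_map = {0: 0, listen_sock.fileno(): listen_sock}
    poller.register(0, select.POLLIN)            # stdin
    poller.register(listen_sock, select.POLLIN)  # incoming connections

for _ in range(3):  # a bounded loop instead of the daemon's `while True`
    if poller is not None:
        # poll() takes its timeout in milliseconds and returns (fd, event) pairs.
        ready_fds = [fd_reverse_map[fd] for fd, _ in poller.poll(1000)]
    else:
        # Fallback mirroring the original select-based path; timeout is in seconds.
        ready_fds = select.select([0, listen_sock], [], [], 1)[0]
    print("ready:", ready_fds)

listen_sock.close()
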
""" - super(LogisticRegression, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid ) @@ -1676,7 +1676,7 @@ class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams): """ def __init__(self, *args: Any): - super(_DecisionTreeClassifierParams, self).__init__(*args) + super().__init__(*args) self._setDefault( maxDepth=5, maxBins=32, @@ -1809,7 +1809,7 @@ def __init__( maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0) """ - super(DecisionTreeClassifier, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid ) @@ -1970,7 +1970,7 @@ class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams): """ def __init__(self, *args: Any): - super(_RandomForestClassifierParams, self).__init__(*args) + super().__init__(*args) self._setDefault( maxDepth=5, maxBins=32, @@ -2106,7 +2106,7 @@ def __init__( numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \ leafCol="", minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True) """ - super(RandomForestClassifier, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.RandomForestClassifier", self.uid ) @@ -2400,7 +2400,7 @@ class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity): ) def __init__(self, *args: Any): - super(_GBTClassifierParams, self).__init__(*args) + super().__init__(*args) self._setDefault( maxDepth=5, maxBins=32, @@ -2577,7 +2577,7 @@ def __init__( validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, \ weightCol=None) """ - super(GBTClassifier, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.GBTClassifier", self.uid ) @@ -2823,7 +2823,7 @@ class _NaiveBayesParams(_PredictorParams, HasWeightCol): ) def __init__(self, *args: Any): - super(_NaiveBayesParams, self).__init__(*args) + super().__init__(*args) self._setDefault(smoothing=1.0, modelType="multinomial") @since("1.5.0") @@ -2964,7 +2964,7 @@ def __init__( probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ modelType="multinomial", thresholds=None, weightCol=None) """ - super(NaiveBayes, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.NaiveBayes", self.uid ) @@ -3093,7 +3093,7 @@ class _MultilayerPerceptronParams( ) def __init__(self, *args: Any): - super(_MultilayerPerceptronParams, self).__init__(*args) + super().__init__(*args) self._setDefault(maxIter=100, tol=1e-6, blockSize=128, stepSize=0.03, solver="l-bfgs") @since("1.6.0") @@ -3219,7 +3219,7 @@ def __init__( solver="l-bfgs", initialWeights=None, probabilityCol="probability", \ rawPredictionCol="rawPrediction") """ - super(MultilayerPerceptronClassifier, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid ) @@ -3484,7 +3484,7 @@ def __init__( __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ rawPredictionCol="rawPrediction", classifier=None, weightCol=None, parallelism=1): """ - super(OneVsRest, self).__init__() + super().__init__() self._setDefault(parallelism=1) kwargs = self._input_kwargs self._set(**kwargs) @@ -3749,7 
+3749,7 @@ def validateParams(instance: Union[OneVsRest, "OneVsRestModel"]) -> None: @inherit_doc class OneVsRestReader(MLReader[OneVsRest]): def __init__(self, cls: Type[OneVsRest]) -> None: - super(OneVsRestReader, self).__init__() + super().__init__() self.cls = cls def load(self, path: str) -> OneVsRest: @@ -3765,7 +3765,7 @@ def load(self, path: str) -> OneVsRest: @inherit_doc class OneVsRestWriter(MLWriter): def __init__(self, instance: OneVsRest): - super(OneVsRestWriter, self).__init__() + super().__init__() self.instance = instance def saveImpl(self, path: str) -> None: @@ -3807,7 +3807,7 @@ def setRawPredictionCol(self, value: str) -> "OneVsRestModel": return self._set(rawPredictionCol=value) def __init__(self, models: List[ClassificationModel]): - super(OneVsRestModel, self).__init__() + super().__init__() self.models = models if is_remote() or not isinstance(models[0], JavaMLWritable): return @@ -3980,7 +3980,7 @@ def write(self) -> MLWriter: @inherit_doc class OneVsRestModelReader(MLReader[OneVsRestModel]): def __init__(self, cls: Type[OneVsRestModel]): - super(OneVsRestModelReader, self).__init__() + super().__init__() self.cls = cls def load(self, path: str) -> OneVsRestModel: @@ -4002,7 +4002,7 @@ def load(self, path: str) -> OneVsRestModel: @inherit_doc class OneVsRestModelWriter(MLWriter): def __init__(self, instance: OneVsRestModel): - super(OneVsRestModelWriter, self).__init__() + super().__init__() self.instance = instance def saveImpl(self, path: str) -> None: @@ -4119,7 +4119,7 @@ def __init__( miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \ tol=1e-6, solver="adamW", thresholds=None, seed=None) """ - super(FMClassifier, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.FMClassifier", self.uid ) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 0e26398de3c4..0fc2b34d1748 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -169,7 +169,7 @@ class _GaussianMixtureParams( ) def __init__(self, *args: Any): - super(_GaussianMixtureParams, self).__init__(*args) + super().__init__(*args) self._setDefault(k=2, tol=0.01, maxIter=100, aggregationDepth=2) @since("2.0.0") @@ -422,7 +422,7 @@ def __init__( probabilityCol="probability", tol=0.01, maxIter=100, seed=None, \ aggregationDepth=2, weightCol=None) """ - super(GaussianMixture, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.clustering.GaussianMixture", self.uid ) @@ -617,7 +617,7 @@ class _KMeansParams( ) def __init__(self, *args: Any): - super(_KMeansParams, self).__init__(*args) + super().__init__(*args) self._setDefault( k=2, initMode="k-means||", @@ -800,7 +800,7 @@ def __init__( distanceMeasure="euclidean", weightCol=None, solver="auto", \ maxBlockSizeInMB=0.0) """ - super(KMeans, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -952,7 +952,7 @@ class _BisectingKMeansParams( ) def __init__(self, *args: Any): - super(_BisectingKMeansParams, self).__init__(*args) + super().__init__(*args) self._setDefault(maxIter=20, k=4, minDivisibleClusterSize=1.0) @since("2.0.0") @@ -1146,7 +1146,7 @@ def __init__( seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean", \ weightCol=None) """ - super(BisectingKMeans, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( 
"org.apache.spark.ml.clustering.BisectingKMeans", self.uid ) @@ -1337,7 +1337,7 @@ class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval): ) def __init__(self, *args: Any): - super(_LDAParams, self).__init__(*args) + super().__init__(*args) self._setDefault( maxIter=20, checkpointInterval=10, @@ -1710,7 +1710,7 @@ def __init__( docConcentration=None, topicConcentration=None,\ topicDistributionCol="topicDistribution", keepLastCheckpoint=True) """ - super(LDA, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.LDA", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1948,7 +1948,7 @@ class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol): ) def __init__(self, *args: Any): - super(_PowerIterationClusteringParams, self).__init__(*args) + super().__init__(*args) self._setDefault(k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst") @since("2.4.0") @@ -2054,7 +2054,7 @@ def __init__( __init__(self, \\*, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\ weightCol=None) """ - super(PowerIterationClustering, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid ) diff --git a/python/pyspark/ml/connect/classification.py b/python/pyspark/ml/connect/classification.py index e9e7840b098b..3263f47e6135 100644 --- a/python/pyspark/ml/connect/classification.py +++ b/python/pyspark/ml/connect/classification.py @@ -64,7 +64,7 @@ class _LogisticRegressionParams( """ def __init__(self, *args: Any): - super(_LogisticRegressionParams, self).__init__(*args) + super().__init__(*args) self._setDefault( maxIter=100, tol=1e-6, @@ -215,7 +215,7 @@ def __init__( seed: int = 0, ) """ - super(LogisticRegression, self).__init__() + super().__init__() kwargs = self._input_kwargs self._set(**kwargs) diff --git a/python/pyspark/ml/connect/pipeline.py b/python/pyspark/ml/connect/pipeline.py index 9850f7a0fd0f..daba4e6fa218 100644 --- a/python/pyspark/ml/connect/pipeline.py +++ b/python/pyspark/ml/connect/pipeline.py @@ -144,7 +144,7 @@ def __init__(self, *, stages: Optional[List[Params]] = None): """ __init__(self, \\*, stages=None) """ - super(Pipeline, self).__init__() + super().__init__() kwargs = self._input_kwargs self.setParams(**kwargs) @@ -248,7 +248,7 @@ class PipelineModel(Model, _PipelineReadWrite): """ def __init__(self, stages: Optional[List[Params]] = None): - super(PipelineModel, self).__init__() + super().__init__() self.stages = stages # type: ignore[assignment] def _transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]: diff --git a/python/pyspark/ml/connect/tuning.py b/python/pyspark/ml/connect/tuning.py index 1ef055d25007..39aae6f17cc5 100644 --- a/python/pyspark/ml/connect/tuning.py +++ b/python/pyspark/ml/connect/tuning.py @@ -118,7 +118,7 @@ class _CrossValidatorParams(_ValidatorParams): ) def __init__(self, *args: Any): - super(_CrossValidatorParams, self).__init__(*args) + super().__init__(*args) self._setDefault(numFolds=3, foldCol="") @since("1.4.0") @@ -326,7 +326,7 @@ def __init__( __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ seed=None, parallelism=1, foldCol="") """ - super(CrossValidator, self).__init__() + super().__init__() self._setDefault(parallelism=1) kwargs = self._input_kwargs self._set(**kwargs) @@ -539,7 +539,7 @@ def __init__( avgMetrics: Optional[List[float]] = None, stdMetrics: 
Optional[List[float]] = None, ) -> None: - super(CrossValidatorModel, self).__init__() + super().__init__() #: best model from cross validation self.bestModel = bestModel #: Average cross-validation metrics for each paramMap in diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 568583eb08ec..56747d07441b 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -235,7 +235,7 @@ def __init__( __init__(self, \\*, rawPredictionCol="rawPrediction", labelCol="label", \ metricName="areaUnderROC", weightCol=None, numBins=1000) """ - super(BinaryClassificationEvaluator, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid ) @@ -397,7 +397,7 @@ def __init__( __init__(self, \\*, predictionCol="prediction", labelCol="label", \ metricName="rmse", weightCol=None, throughOrigin=False) """ - super(RegressionEvaluator, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid ) @@ -593,7 +593,7 @@ def __init__( metricName="f1", weightCol=None, metricLabel=0.0, beta=1.0, \ probabilityCol="probability", eps=1e-15) """ - super(MulticlassClassificationEvaluator, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid ) @@ -790,7 +790,7 @@ def __init__( __init__(self, \\*, predictionCol="prediction", labelCol="label", \ metricName="f1Measure", metricLabel=0.0) """ - super(MultilabelClassificationEvaluator, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator", self.uid ) @@ -947,7 +947,7 @@ def __init__( __init__(self, \\*, predictionCol="prediction", featuresCol="features", \ metricName="silhouette", distanceMeasure="squaredEuclidean", weightCol=None) """ - super(ClusteringEvaluator, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid ) @@ -1095,7 +1095,7 @@ def __init__( __init__(self, \\*, predictionCol="prediction", labelCol="label", \ metricName="meanAveragePrecision", k=10) """ - super(RankingEvaluator, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.RankingEvaluator", self.uid ) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 4d1551652028..c5a5033cd8f0 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -248,7 +248,7 @@ def __init__( __init__(self, \\*, threshold=0.0, inputCol=None, outputCol=None, thresholds=None, \ inputCols=None, outputCols=None) """ - super(Binarizer, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid) self._setDefault(threshold=0.0) kwargs = self._input_kwargs @@ -350,7 +350,7 @@ class _LSHParams(HasInputCol, HasOutputCol): ) def __init__(self, *args: Any): - super(_LSHParams, self).__init__(*args) + super().__init__(*args) self._setDefault(numHashTables=1) def getNumHashTables(self) -> int: @@ -603,7 +603,7 @@ def __init__( __init__(self, \\*, inputCol=None, outputCol=None, seed=None, numHashTables=1, \ bucketLength=None) """ - super(BucketedRandomProjectionLSH, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.BucketedRandomProjectionLSH", 
self.uid ) @@ -820,7 +820,7 @@ def __init__( __init__(self, \\*, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ splitsArray=None, inputCols=None, outputCols=None) """ - super(Bucketizer, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) self._setDefault(handleInvalid="error") kwargs = self._input_kwargs @@ -985,7 +985,7 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): ) def __init__(self, *args: Any): - super(_CountVectorizerParams, self).__init__(*args) + super().__init__(*args) self._setDefault(minTF=1.0, minDF=1.0, maxDF=2**63 - 1, vocabSize=1 << 18, binary=False) @since("1.6.0") @@ -1105,7 +1105,7 @@ def __init__( __init__(self, \\*, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18,\ binary=False, inputCol=None,outputCol=None) """ - super(CountVectorizer, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1339,7 +1339,7 @@ def __init__( """ __init__(self, \\*, inverse=False, inputCol=None, outputCol=None) """ - super(DCT, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid) self._setDefault(inverse=False) kwargs = self._input_kwargs @@ -1447,7 +1447,7 @@ def __init__( """ __init__(self, \\*, scalingVec=None, inputCol=None, outputCol=None) """ - super(ElementwiseProduct, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.ElementwiseProduct", self.uid ) @@ -1585,7 +1585,7 @@ def __init__( __init__(self, \\*, numFeatures=1 << 18, inputCols=None, outputCol=None, \ categoricalCols=None) """ - super(FeatureHasher, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.FeatureHasher", self.uid) self._setDefault(numFeatures=1 << 18) kwargs = self._input_kwargs @@ -1708,7 +1708,7 @@ def __init__( """ __init__(self, \\*, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None) """ - super(HashingTF, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid) self._setDefault(numFeatures=1 << 18, binary=False) kwargs = self._input_kwargs @@ -1795,7 +1795,7 @@ def getMinDocFreq(self) -> int: return self.getOrDefault(self.minDocFreq) def __init__(self, *args: Any): - super(_IDFParams, self).__init__(*args) + super().__init__(*args) self._setDefault(minDocFreq=0) @@ -1859,7 +1859,7 @@ def __init__( """ __init__(self, \\*, minDocFreq=0, inputCol=None, outputCol=None) """ - super(IDF, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1975,7 +1975,7 @@ class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, Has ) def __init__(self, *args: Any): - super(_ImputerParams, self).__init__(*args) + super().__init__(*args) self._setDefault(strategy="mean", missingValue=float("nan"), relativeError=0.001) @since("2.2.0") @@ -2152,7 +2152,7 @@ def __init__( __init__(self, \\*, strategy="mean", missingValue=float("nan"), inputCols=None, \ outputCols=None, inputCol=None, outputCol=None, relativeError=0.001): """ - super(Imputer, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Imputer", self.uid) kwargs = 
self._input_kwargs self.setParams(**kwargs) @@ -2351,7 +2351,7 @@ def __init__(self, *, inputCols: Optional[List[str]] = None, outputCol: Optional """ __init__(self, \\*, inputCols=None, outputCol=None): """ - super(Interaction, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Interaction", self.uid) self._setDefault() kwargs = self._input_kwargs @@ -2449,7 +2449,7 @@ def __init__(self, *, inputCol: Optional[str] = None, outputCol: Optional[str] = """ __init__(self, \\*, inputCol=None, outputCol=None) """ - super(MaxAbsScaler, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MaxAbsScaler", self.uid) self._setDefault() kwargs = self._input_kwargs @@ -2602,7 +2602,7 @@ def __init__( """ __init__(self, \\*, inputCol=None, outputCol=None, seed=None, numHashTables=1) """ - super(MinHashLSH, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinHashLSH", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -2672,7 +2672,7 @@ class _MinMaxScalerParams(HasInputCol, HasOutputCol): ) def __init__(self, *args: Any): - super(_MinMaxScalerParams, self).__init__(*args) + super().__init__(*args) self._setDefault(min=0.0, max=1.0) @since("1.6.0") @@ -2767,7 +2767,7 @@ def __init__( """ __init__(self, \\*, min=0.0, max=1.0, inputCol=None, outputCol=None) """ - super(MinMaxScaler, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -2934,7 +2934,7 @@ def __init__( """ __init__(self, \\*, n=2, inputCol=None, outputCol=None) """ - super(NGram, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) self._setDefault(n=2) kwargs = self._input_kwargs @@ -3029,7 +3029,7 @@ def __init__( """ __init__(self, \\*, p=2.0, inputCol=None, outputCol=None) """ - super(Normalizer, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid) self._setDefault(p=2.0) kwargs = self._input_kwargs @@ -3102,7 +3102,7 @@ class _OneHotEncoderParams( ) def __init__(self, *args: Any): - super(_OneHotEncoderParams, self).__init__(*args) + super().__init__(*args) self._setDefault(handleInvalid="error", dropLast=True) @since("2.3.0") @@ -3221,7 +3221,7 @@ def __init__( __init__(self, \\*, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, \ inputCol=None, outputCol=None) """ - super(OneHotEncoder, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -3430,7 +3430,7 @@ def __init__( """ __init__(self, \\*, degree=2, inputCol=None, outputCol=None) """ - super(PolynomialExpansion, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.PolynomialExpansion", self.uid ) @@ -3658,7 +3658,7 @@ def __init__( __init__(self, \\*, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \ handleInvalid="error", numBucketsArray=None, inputCols=None, outputCols=None) """ - super(QuantileDiscretizer, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.QuantileDiscretizer", self.uid ) @@ -3850,7 +3850,7 @@ class _RobustScalerParams(HasInputCol, HasOutputCol, 
HasRelativeError): ) def __init__(self, *args: Any): - super(_RobustScalerParams, self).__init__(*args) + super().__init__(*args) self._setDefault( lower=0.25, upper=0.75, withCentering=False, withScaling=True, relativeError=0.001 ) @@ -3958,7 +3958,7 @@ def __init__( __init__(self, \\*, lower=0.25, upper=0.75, withCentering=False, withScaling=True, \ inputCol=None, outputCol=None, relativeError=0.001) """ - super(RobustScaler, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RobustScaler", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -4170,7 +4170,7 @@ def __init__( __init__(self, \\*, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, \ outputCol=None, toLowercase=True) """ - super(RegexTokenizer, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid) self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+", toLowercase=True) kwargs = self._input_kwargs @@ -4301,7 +4301,7 @@ def __init__(self, *, statement: Optional[str] = None): """ __init__(self, \\*, statement=None) """ - super(SQLTransformer, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -4349,7 +4349,7 @@ class _StandardScalerParams(HasInputCol, HasOutputCol): ) def __init__(self, *args: Any): - super(_StandardScalerParams, self).__init__(*args) + super().__init__(*args) self._setDefault(withMean=False, withStd=True) @since("1.4.0") @@ -4436,7 +4436,7 @@ def __init__( """ __init__(self, \\*, withMean=False, withStd=True, inputCol=None, outputCol=None) """ - super(StandardScaler, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -4560,7 +4560,7 @@ class _StringIndexerParams( ) def __init__(self, *args: Any): - super(_StringIndexerParams, self).__init__(*args) + super().__init__(*args) self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") @since("2.3.0") @@ -4699,7 +4699,7 @@ def __init__( __init__(self, \\*, inputCol=None, outputCol=None, inputCols=None, outputCols=None, \ handleInvalid="error", stringOrderType="frequencyDesc") """ - super(StringIndexer, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -4978,7 +4978,7 @@ def __init__( """ __init__(self, \\*, inputCol=None, outputCol=None, labels=None) """ - super(IndexToString, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -5140,7 +5140,7 @@ def __init__( __init__(self, \\*, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \ locale=None, inputCols=None, outputCols=None) """ - super(StopWordsRemover, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.StopWordsRemover", self.uid ) @@ -5317,7 +5317,7 @@ class _TargetEncoderParams( ) def __init__(self, *args: Any): - super(_TargetEncoderParams, self).__init__(*args) + super().__init__(*args) self._setDefault(handleInvalid="error", targetType="binary", smoothing=0.0) @since("4.0.0") @@ -5409,7 +5409,7 @@ def __init__( __init__(self, \\*, 
inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, \ targetType="binary", smoothing=0.0, inputCol=None, outputCol=None) """ - super(TargetEncoder, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.TargetEncoder", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -5623,7 +5623,7 @@ def __init__(self, *, inputCol: Optional[str] = None, outputCol: Optional[str] = """ __init__(self, \\*, inputCol=None, outputCol=None) """ - super(Tokenizer, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -5734,7 +5734,7 @@ def __init__( """ __init__(self, \\*, inputCols=None, outputCol=None, handleInvalid="error") """ - super(VectorAssembler, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid) self._setDefault(handleInvalid="error") kwargs = self._input_kwargs @@ -5803,7 +5803,7 @@ class _VectorIndexerParams(HasInputCol, HasOutputCol, HasHandleInvalid): ) def __init__(self, *args: Any): - super(_VectorIndexerParams, self).__init__(*args) + super().__init__(*args) self._setDefault(maxCategories=20, handleInvalid="error") @since("1.4.0") @@ -5923,7 +5923,7 @@ def __init__( """ __init__(self, \\*, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error") """ - super(VectorIndexer, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -6114,7 +6114,7 @@ def __init__( """ __init__(self, \\*, inputCol=None, outputCol=None, indices=None, names=None) """ - super(VectorSlicer, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid) self._setDefault(indices=[], names=[]) kwargs = self._input_kwargs @@ -6220,7 +6220,7 @@ class _Word2VecParams(HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCo ) def __init__(self, *args: Any): - super(_Word2VecParams, self).__init__(*args) + super().__init__(*args) self._setDefault( vectorSize=100, minCount=5, @@ -6359,7 +6359,7 @@ def __init__( maxIter=1, seed=None, inputCol=None, outputCol=None, windowSize=5, \ maxSentenceLength=1000) """ - super(Word2Vec, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -6592,7 +6592,7 @@ def __init__( """ __init__(self, \\*, k=None, inputCol=None, outputCol=None) """ - super(PCA, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -6717,7 +6717,7 @@ class _RFormulaParams(HasFeaturesCol, HasLabelCol, HasHandleInvalid): ) def __init__(self, *args: Any): - super(_RFormulaParams, self).__init__(*args) + super().__init__(*args) self._setDefault( forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", handleInvalid="error" ) @@ -6841,7 +6841,7 @@ def __init__( forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \ handleInvalid="error") """ - super(RFormula, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -6980,7 +6980,7 @@ class 
_SelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol): ) def __init__(self, *args: Any): - super(_SelectorParams, self).__init__(*args) + super().__init__(*args) self._setDefault( numTopFeatures=50, selectorType="numTopFeatures", @@ -7220,7 +7220,7 @@ def __init__( labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05, \ fdr=0.05, fwe=0.05) """ - super(ChiSqSelector, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -7331,7 +7331,7 @@ def __init__( """ __init__(self, \\*, inputCol=None, size=None, handleInvalid="error") """ - super(VectorSizeHint, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSizeHint", self.uid) self._setDefault(handleInvalid="error") self.setParams(**self._input_kwargs) @@ -7463,7 +7463,7 @@ def __init__( """ __init__(self, \\*, featuresCol="features", outputCol=None, varianceThreshold=0.0) """ - super(VarianceThresholdSelector, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.VarianceThresholdSelector", self.uid ) @@ -7586,7 +7586,7 @@ class _UnivariateFeatureSelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol ) def __init__(self, *args: Any): - super(_UnivariateFeatureSelectorParams, self).__init__(*args) + super().__init__(*args) self._setDefault(selectionMode="numTopFeatures") @since("3.1.1") @@ -7712,7 +7712,7 @@ def __init__( __init__(self, \\*, featuresCol="features", outputCol=None, \ labelCol="label", selectionMode="numTopFeatures") """ - super(UnivariateFeatureSelector, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.UnivariateFeatureSelector", self.uid ) diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index 64b7e5eae556..857e9a97b154 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -71,7 +71,7 @@ class _FPGrowthParams(HasPredictionCol): ) def __init__(self, *args: Any): - super(_FPGrowthParams, self).__init__(*args) + super().__init__(*args) self._setDefault( minSupport=0.3, minConfidence=0.8, itemsCol="items", predictionCol="prediction" ) @@ -256,7 +256,7 @@ def __init__( __init__(self, \\*, minSupport=0.3, minConfidence=0.8, itemsCol="items", \ predictionCol="prediction", numPartitions=None) """ - super(FPGrowth, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.FPGrowth", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -409,7 +409,7 @@ def __init__( __init__(self, \\*, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ sequenceCol="sequence") """ - super(PrefixSpan, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid) self._setDefault( minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, sequenceCol="sequence" diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 067741b0a7f6..6fce132561db 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -255,7 +255,7 @@ class Params(Identifiable, metaclass=ABCMeta): """ def __init__(self) -> None: - super(Params, self).__init__() + super().__init__() #: internal param map for user-supplied values param map self._paramMap: "ParamMap" = {} diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py 
b/python/pyspark/ml/param/_shared_params_code_gen.py index ade96da0a4f1..bbcaa208a39d 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -72,7 +72,7 @@ def _gen_param_header( ) def __init__(self) -> None: - super({Name}, self).__init__()''' + super().__init__()''' if defaultValueStr is not None: template += f""" diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index bc444bf9cbf9..e60f2a7432f7 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -35,7 +35,7 @@ class HasMaxIter(Params): ) def __init__(self) -> None: - super(HasMaxIter, self).__init__() + super().__init__() def getMaxIter(self) -> int: """ @@ -57,7 +57,7 @@ class HasRegParam(Params): ) def __init__(self) -> None: - super(HasRegParam, self).__init__() + super().__init__() def getRegParam(self) -> float: """ @@ -79,7 +79,7 @@ class HasFeaturesCol(Params): ) def __init__(self) -> None: - super(HasFeaturesCol, self).__init__() + super().__init__() self._setDefault(featuresCol="features") def getFeaturesCol(self) -> str: @@ -102,7 +102,7 @@ class HasLabelCol(Params): ) def __init__(self) -> None: - super(HasLabelCol, self).__init__() + super().__init__() self._setDefault(labelCol="label") def getLabelCol(self) -> str: @@ -125,7 +125,7 @@ class HasPredictionCol(Params): ) def __init__(self) -> None: - super(HasPredictionCol, self).__init__() + super().__init__() self._setDefault(predictionCol="prediction") def getPredictionCol(self) -> str: @@ -148,7 +148,7 @@ class HasProbabilityCol(Params): ) def __init__(self) -> None: - super(HasProbabilityCol, self).__init__() + super().__init__() self._setDefault(probabilityCol="probability") def getProbabilityCol(self) -> str: @@ -171,7 +171,7 @@ class HasRawPredictionCol(Params): ) def __init__(self) -> None: - super(HasRawPredictionCol, self).__init__() + super().__init__() self._setDefault(rawPredictionCol="rawPrediction") def getRawPredictionCol(self) -> str: @@ -194,7 +194,7 @@ class HasInputCol(Params): ) def __init__(self) -> None: - super(HasInputCol, self).__init__() + super().__init__() def getInputCol(self) -> str: """ @@ -216,7 +216,7 @@ class HasInputCols(Params): ) def __init__(self) -> None: - super(HasInputCols, self).__init__() + super().__init__() def getInputCols(self) -> List[str]: """ @@ -238,7 +238,7 @@ class HasOutputCol(Params): ) def __init__(self) -> None: - super(HasOutputCol, self).__init__() + super().__init__() self._setDefault(outputCol=self.uid + "__output") def getOutputCol(self) -> str: @@ -261,7 +261,7 @@ class HasOutputCols(Params): ) def __init__(self) -> None: - super(HasOutputCols, self).__init__() + super().__init__() def getOutputCols(self) -> List[str]: """ @@ -283,7 +283,7 @@ class HasNumFeatures(Params): ) def __init__(self) -> None: - super(HasNumFeatures, self).__init__() + super().__init__() self._setDefault(numFeatures=262144) def getNumFeatures(self) -> int: @@ -306,7 +306,7 @@ class HasCheckpointInterval(Params): ) def __init__(self) -> None: - super(HasCheckpointInterval, self).__init__() + super().__init__() def getCheckpointInterval(self) -> int: """ @@ -328,7 +328,7 @@ class HasSeed(Params): ) def __init__(self) -> None: - super(HasSeed, self).__init__() + super().__init__() self._setDefault(seed=hash(type(self).__name__)) def getSeed(self) -> int: @@ -351,7 +351,7 @@ class HasTol(Params): ) def __init__(self) -> None: - super(HasTol, self).__init__() + super().__init__() def getTol(self) -> 
float: """ @@ -373,7 +373,7 @@ class HasRelativeError(Params): ) def __init__(self) -> None: - super(HasRelativeError, self).__init__() + super().__init__() self._setDefault(relativeError=0.001) def getRelativeError(self) -> float: @@ -396,7 +396,7 @@ class HasStepSize(Params): ) def __init__(self) -> None: - super(HasStepSize, self).__init__() + super().__init__() def getStepSize(self) -> float: """ @@ -418,7 +418,7 @@ class HasHandleInvalid(Params): ) def __init__(self) -> None: - super(HasHandleInvalid, self).__init__() + super().__init__() def getHandleInvalid(self) -> str: """ @@ -440,7 +440,7 @@ class HasElasticNetParam(Params): ) def __init__(self) -> None: - super(HasElasticNetParam, self).__init__() + super().__init__() self._setDefault(elasticNetParam=0.0) def getElasticNetParam(self) -> float: @@ -463,7 +463,7 @@ class HasFitIntercept(Params): ) def __init__(self) -> None: - super(HasFitIntercept, self).__init__() + super().__init__() self._setDefault(fitIntercept=True) def getFitIntercept(self) -> bool: @@ -486,7 +486,7 @@ class HasStandardization(Params): ) def __init__(self) -> None: - super(HasStandardization, self).__init__() + super().__init__() self._setDefault(standardization=True) def getStandardization(self) -> bool: @@ -509,7 +509,7 @@ class HasThresholds(Params): ) def __init__(self) -> None: - super(HasThresholds, self).__init__() + super().__init__() def getThresholds(self) -> List[float]: """ @@ -531,7 +531,7 @@ class HasThreshold(Params): ) def __init__(self) -> None: - super(HasThreshold, self).__init__() + super().__init__() self._setDefault(threshold=0.5) def getThreshold(self) -> float: @@ -554,7 +554,7 @@ class HasWeightCol(Params): ) def __init__(self) -> None: - super(HasWeightCol, self).__init__() + super().__init__() def getWeightCol(self) -> str: """ @@ -576,7 +576,7 @@ class HasSolver(Params): ) def __init__(self) -> None: - super(HasSolver, self).__init__() + super().__init__() self._setDefault(solver="auto") def getSolver(self) -> str: @@ -599,7 +599,7 @@ class HasVarianceCol(Params): ) def __init__(self) -> None: - super(HasVarianceCol, self).__init__() + super().__init__() def getVarianceCol(self) -> str: """ @@ -621,7 +621,7 @@ class HasAggregationDepth(Params): ) def __init__(self) -> None: - super(HasAggregationDepth, self).__init__() + super().__init__() self._setDefault(aggregationDepth=2) def getAggregationDepth(self) -> int: @@ -644,7 +644,7 @@ class HasParallelism(Params): ) def __init__(self) -> None: - super(HasParallelism, self).__init__() + super().__init__() self._setDefault(parallelism=1) def getParallelism(self) -> int: @@ -667,7 +667,7 @@ class HasCollectSubModels(Params): ) def __init__(self) -> None: - super(HasCollectSubModels, self).__init__() + super().__init__() self._setDefault(collectSubModels=False) def getCollectSubModels(self) -> bool: @@ -690,7 +690,7 @@ class HasLoss(Params): ) def __init__(self) -> None: - super(HasLoss, self).__init__() + super().__init__() def getLoss(self) -> str: """ @@ -712,7 +712,7 @@ class HasDistanceMeasure(Params): ) def __init__(self) -> None: - super(HasDistanceMeasure, self).__init__() + super().__init__() self._setDefault(distanceMeasure="euclidean") def getDistanceMeasure(self) -> str: @@ -735,7 +735,7 @@ class HasValidationIndicatorCol(Params): ) def __init__(self) -> None: - super(HasValidationIndicatorCol, self).__init__() + super().__init__() def getValidationIndicatorCol(self) -> str: """ @@ -757,7 +757,7 @@ class HasBlockSize(Params): ) def __init__(self) -> None: - 
super(HasBlockSize, self).__init__() + super().__init__() def getBlockSize(self) -> int: """ @@ -779,7 +779,7 @@ class HasMaxBlockSizeInMB(Params): ) def __init__(self) -> None: - super(HasMaxBlockSizeInMB, self).__init__() + super().__init__() self._setDefault(maxBlockSizeInMB=0.0) def getMaxBlockSizeInMB(self) -> float: @@ -802,7 +802,7 @@ class HasNumTrainWorkers(Params): ) def __init__(self) -> None: - super(HasNumTrainWorkers, self).__init__() + super().__init__() self._setDefault(numTrainWorkers=1) def getNumTrainWorkers(self) -> int: @@ -825,7 +825,7 @@ class HasBatchSize(Params): ) def __init__(self) -> None: - super(HasBatchSize, self).__init__() + super().__init__() def getBatchSize(self) -> int: """ @@ -847,7 +847,7 @@ class HasLearningRate(Params): ) def __init__(self) -> None: - super(HasLearningRate, self).__init__() + super().__init__() def getLearningRate(self) -> float: """ @@ -869,7 +869,7 @@ class HasMomentum(Params): ) def __init__(self) -> None: - super(HasMomentum, self).__init__() + super().__init__() def getMomentum(self) -> float: """ @@ -891,7 +891,7 @@ class HasFeatureSizes(Params): ) def __init__(self) -> None: - super(HasFeatureSizes, self).__init__() + super().__init__() def getFeatureSizes(self) -> List[int]: """ diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 5728b8c0c511..f5f0d12bc836 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -77,7 +77,7 @@ def __init__(self, *, stages: Optional[List["PipelineStage"]] = None): """ __init__(self, \\*, stages=None) """ - super(Pipeline, self).__init__() + super().__init__() kwargs = self._input_kwargs self.setParams(**kwargs) @@ -227,7 +227,7 @@ class PipelineWriter(MLWriter): """ def __init__(self, instance: Pipeline): - super(PipelineWriter, self).__init__() + super().__init__() self.instance = instance def saveImpl(self, path: str) -> None: @@ -243,7 +243,7 @@ class PipelineReader(MLReader[Pipeline]): """ def __init__(self, cls: Type[Pipeline]): - super(PipelineReader, self).__init__() + super().__init__() self.cls = cls def load(self, path: str) -> Pipeline: @@ -259,7 +259,7 @@ class PipelineModelWriter(MLWriter): """ def __init__(self, instance: "PipelineModel"): - super(PipelineModelWriter, self).__init__() + super().__init__() self.instance = instance def saveImpl(self, path: str) -> None: @@ -277,7 +277,7 @@ class PipelineModelReader(MLReader["PipelineModel"]): """ def __init__(self, cls: Type["PipelineModel"]): - super(PipelineModelReader, self).__init__() + super().__init__() self.cls = cls def load(self, path: str) -> "PipelineModel": @@ -295,7 +295,7 @@ class PipelineModel(Model, MLReadable["PipelineModel"], MLWritable): """ def __init__(self, stages: List[Transformer]): - super(PipelineModel, self).__init__() + super().__init__() self.stages = stages def _transform(self, dataset: DataFrame) -> DataFrame: diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index d11990634593..d40b231224dc 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -72,7 +72,7 @@ class _ALSModelParams(HasPredictionCol, HasBlockSize): ) def __init__(self, *args: Any): - super(_ALSModelParams, self).__init__(*args) + super().__init__(*args) self._setDefault(blockSize=4096) @since("1.4.0") @@ -159,7 +159,7 @@ class _ALSParams(_ALSModelParams, HasMaxIter, HasRegParam, HasCheckpointInterval ) def __init__(self, *args: Any): - super(_ALSParams, self).__init__(*args) + super().__init__(*args) 
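
[Editor's note] The bulk of the remaining hunks are a mechanical modernization: every Python 2-style super(ClassName, self).__init__() becomes the zero-argument Python 3 form. The two spellings resolve to the same method at runtime; the zero-argument form simply infers the class and instance, so it survives class renames and keeps working in the cooperatively inherited Has* param mixins above, where each __init__ must call the next class in the MRO exactly once. A toy sketch of that chaining (the class names mimic the real mixins, but the bodies are simplified stand-ins, not PySpark's implementation):

class Params:
    def __init__(self):
        self._defaults = {}

    def _setDefault(self, **kwargs):
        self._defaults.update(kwargs)


class HasMaxIter(Params):
    def __init__(self):
        # Zero-argument super() resolves to the next class in the MRO,
        # so every mixin in the chain is initialized exactly once.
        super().__init__()
        self._setDefault(maxIter=100)


class HasRegParam(Params):
    def __init__(self):
        super().__init__()
        self._setDefault(regParam=0.0)


class MyEstimatorParams(HasMaxIter, HasRegParam):
    def __init__(self):
        super().__init__()


p = MyEstimatorParams()
print(p._defaults)  # {'regParam': 0.0, 'maxIter': 100}
print([c.__name__ for c in MyEstimatorParams.__mro__])
# ['MyEstimatorParams', 'HasMaxIter', 'HasRegParam', 'Params', 'object']
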
self._setDefault( rank=10, maxIter=10, @@ -395,7 +395,7 @@ def __init__( intermediateStorageLevel="MEMORY_AND_DISK", \ finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan", blockSize=4096) """ - super(ALS, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index ce97b98f6665..5e72fe89b586 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -190,7 +190,7 @@ class _LinearRegressionParams( ) def __init__(self, *args: Any): - super(_LinearRegressionParams, self).__init__(*args) + super().__init__(*args) self._setDefault( maxIter=100, regParam=0.0, @@ -325,7 +325,7 @@ def __init__( standardization=True, solver="auto", weightCol=None, aggregationDepth=2, \ loss="squaredError", epsilon=1.35, maxBlockSizeInMB=0.0) """ - super(LinearRegression, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.LinearRegression", self.uid ) @@ -795,7 +795,7 @@ class _IsotonicRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol, H ) def __init__(self, *args: Any): - super(_IsotonicRegressionParams, self).__init__(*args) + super().__init__(*args) self._setDefault(isotonic=True, featureIndex=0) def getIsotonic(self) -> bool: @@ -873,7 +873,7 @@ def __init__( __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \ weightCol=None, isotonic=True, featureIndex=0): """ - super(IsotonicRegression, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.IsotonicRegression", self.uid ) @@ -1016,7 +1016,7 @@ class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, Ha """ def __init__(self, *args: Any): - super(_DecisionTreeRegressorParams, self).__init__(*args) + super().__init__(*args) self._setDefault( maxDepth=5, maxBins=32, @@ -1136,7 +1136,7 @@ def __init__( impurity="variance", seed=None, varianceCol=None, weightCol=None, \ leafCol="", minWeightFractionPerNode=0.0) """ - super(DecisionTreeRegressor, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid ) @@ -1317,7 +1317,7 @@ class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams): """ def __init__(self, *args: Any): - super(_RandomForestRegressorParams, self).__init__(*args) + super().__init__(*args) self._setDefault( maxDepth=5, maxBins=32, @@ -1440,7 +1440,7 @@ def __init__( featureSubsetStrategy="auto", leafCol=", minWeightFractionPerNode=0.0", \ weightCol=None, bootstrap=True) """ - super(RandomForestRegressor, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.RandomForestRegressor", self.uid ) @@ -1649,7 +1649,7 @@ class _GBTRegressorParams(_GBTParams, _TreeRegressorParams): ) def __init__(self, *args: Any): - super(_GBTRegressorParams, self).__init__(*args) + super().__init__(*args) self._setDefault( maxDepth=5, maxBins=32, @@ -1794,7 +1794,7 @@ def __init__( validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0, weightCol=None) """ - super(GBTRegressor, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -2058,7 +2058,7 @@ class 
_AFTSurvivalRegressionParams( ) def __init__(self, *args: Any): - super(_AFTSurvivalRegressionParams, self).__init__(*args) + super().__init__(*args) self._setDefault( censorCol="censor", quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], @@ -2191,7 +2191,7 @@ def __init__( quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ quantilesCol=None, aggregationDepth=2, maxBlockSizeInMB=0.0) """ - super(AFTSurvivalRegression, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid ) @@ -2422,7 +2422,7 @@ class _GeneralizedLinearRegressionParams( ) def __init__(self, *args: Any): - super(_GeneralizedLinearRegressionParams, self).__init__(*args) + super().__init__(*args) self._setDefault( family="gaussian", maxIter=25, @@ -2591,7 +2591,7 @@ def __init__( regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ variancePower=0.0, linkPower=None, offsetCol=None, aggregationDepth=2) """ - super(GeneralizedLinearRegression, self).__init__() + super().__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid ) @@ -3023,7 +3023,7 @@ class _FactorizationMachinesParams( ) def __init__(self, *args: Any): - super(_FactorizationMachinesParams, self).__init__(*args) + super().__init__(*args) self._setDefault( factorSize=8, fitIntercept=True, @@ -3159,7 +3159,7 @@ def __init__( miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \ tol=1e-6, solver="adamW", seed=None) """ - super(FMRegressor, self).__init__() + super().__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.FMRegressor", self.uid) kwargs = self._input_kwargs self.setParams(**kwargs) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index fd32a1515101..99c4b80018cf 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -499,7 +499,7 @@ class SummaryBuilder(JavaWrapper): def __init__(self, jSummaryBuilder: "JavaObject"): if not is_remote(): - super(SummaryBuilder, self).__init__(jSummaryBuilder) + super().__init__(jSummaryBuilder) def summary(self, featuresCol: Column, weightCol: Optional[Column] = None) -> Column: """ diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py index e9ccd7b0369e..4406bda09e9d 100644 --- a/python/pyspark/ml/tests/connect/test_connect_classification.py +++ b/python/pyspark/ml/tests/connect/test_connect_classification.py @@ -35,7 +35,7 @@ class ClassificationTestsOnConnect(ClassificationTestsMixin, ReusedConnectTestCase): @classmethod def conf(cls): - config = super(ClassificationTestsOnConnect, cls).conf() + config = super().conf() config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true") return config diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py index 2b408911fbd2..d88110f71531 100644 --- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py +++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py @@ -36,7 +36,7 @@ class PipelineTestsOnConnect(PipelineTestsMixin, ReusedConnectTestCase): @classmethod def conf(cls): - config = super(PipelineTestsOnConnect, cls).conf() + config = super().conf() config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true") return config diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py 
b/python/pyspark/ml/tests/connect/test_connect_tuning.py index 3b7f977b57ae..823a9b32e924 100644 --- a/python/pyspark/ml/tests/connect/test_connect_tuning.py +++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py @@ -36,7 +36,7 @@ class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, ReusedConnectTestCase): @classmethod def conf(cls): - config = super(CrossValidatorTestsOnConnect, cls).conf() + config = super().conf() config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true") return config diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py index f29a2c77e4ea..7b53aa0b7c7d 100644 --- a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py +++ b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py @@ -48,7 +48,7 @@ class HasInducedError(Params): def __init__(self): - super(HasInducedError, self).__init__() + super().__init__() self.inducedError = Param( self, "inducedError", "Uniformly-distributed error added to feature" ) @@ -58,7 +58,7 @@ def getInducedError(self): class InducedErrorModel(Model, HasInducedError): def __init__(self): - super(InducedErrorModel, self).__init__() + super().__init__() def _transform(self, dataset): return dataset.withColumn( @@ -67,7 +67,7 @@ def _transform(self, dataset): class InducedErrorEstimator(Estimator, HasInducedError): def __init__(self, inducedError=1.0): - super(InducedErrorEstimator, self).__init__() + super().__init__() self._set(inducedError=inducedError) def _fit(self, dataset): diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py index 0f5deab4e093..875718bb7b21 100644 --- a/python/pyspark/ml/tests/test_algorithms.py +++ b/python/pyspark/ml/tests/test_algorithms.py @@ -254,7 +254,7 @@ def test_persistence(self): class FPGrowthTests(SparkSessionTestCase): def setUp(self): - super(FPGrowthTests, self).setUp() + super().setUp() self.data = self.spark.createDataFrame( [([1, 2],), ([1, 2],), ([1, 2, 3],), ([1, 3],)], ["items"] ) diff --git a/python/pyspark/ml/tests/test_functions.py b/python/pyspark/ml/tests/test_functions.py index 02e6d0b7c424..4621ce63db49 100644 --- a/python/pyspark/ml/tests/test_functions.py +++ b/python/pyspark/ml/tests/test_functions.py @@ -70,7 +70,7 @@ class PredictBatchUDFTestsMixin: def setUp(self): import pandas as pd - super(PredictBatchUDFTestsMixin, self).setUp() + super().setUp() self.data = np.arange(0, 1000, dtype=np.float64).reshape(-1, 4) # 4 scalar columns diff --git a/python/pyspark/ml/tests/test_model_cache.py b/python/pyspark/ml/tests/test_model_cache.py index 9ad8ac544274..6880ace2574f 100644 --- a/python/pyspark/ml/tests/test_model_cache.py +++ b/python/pyspark/ml/tests/test_model_cache.py @@ -23,7 +23,7 @@ class ModelCacheTests(SparkSessionTestCase): def setUp(self): - super(ModelCacheTests, self).setUp() + super().setUp() def test_cache(self): def predict_fn(inputs): diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py index 0aa982712495..03d514a38519 100644 --- a/python/pyspark/ml/tests/test_param.py +++ b/python/pyspark/ml/tests/test_param.py @@ -137,7 +137,7 @@ class TestParams(HasMaxIter, HasInputCol, HasSeed): @keyword_only def __init__(self, seed=None): - super(TestParams, self).__init__() + super().__init__() self._setDefault(maxIter=10) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -159,7 +159,7 @@ class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): @keyword_only def 
__init__(self, seed=None): - super(OtherTestParams, self).__init__() + super().__init__() self._setDefault(maxIter=10) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -176,7 +176,7 @@ def setParams(self, seed=None): class HasThrowableProperty(Params): def __init__(self): - super(HasThrowableProperty, self).__init__() + super().__init__() self.p = Param(self, "none", "empty param") @property diff --git a/python/pyspark/ml/tests/tuning/test_tuning.py b/python/pyspark/ml/tests/tuning/test_tuning.py index cbbe7b6cca82..abe84d3c5955 100644 --- a/python/pyspark/ml/tests/tuning/test_tuning.py +++ b/python/pyspark/ml/tests/tuning/test_tuning.py @@ -40,7 +40,7 @@ class HasInducedError(Params): def __init__(self): - super(HasInducedError, self).__init__() + super().__init__() self.inducedError = Param( self, "inducedError", "Uniformly-distributed error added to feature" ) @@ -51,7 +51,7 @@ def getInducedError(self): class InducedErrorModel(Model, HasInducedError): def __init__(self): - super(InducedErrorModel, self).__init__() + super().__init__() def _transform(self, dataset): return dataset.withColumn( @@ -61,7 +61,7 @@ def _transform(self, dataset): class InducedErrorEstimator(Estimator, HasInducedError): def __init__(self, inducedError=1.0): - super(InducedErrorEstimator, self).__init__() + super().__init__() self._set(inducedError=inducedError) def _fit(self, dataset): diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py index e9bf1d784000..4ef2f63153af 100644 --- a/python/pyspark/ml/torch/tests/test_distributor.py +++ b/python/pyspark/ml/torch/tests/test_distributor.py @@ -69,7 +69,7 @@ def create_training_function(mnist_dir_path: str) -> Callable: class Net(nn.Module): def __init__(self) -> None: - super(Net, self).__init__() + super().__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.conv2_drop = nn.Dropout2d() diff --git a/python/pyspark/ml/tree.py b/python/pyspark/ml/tree.py index 449b67a5b089..86e8422702b1 100644 --- a/python/pyspark/ml/tree.py +++ b/python/pyspark/ml/tree.py @@ -154,7 +154,7 @@ class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol): ) def __init__(self) -> None: - super(_DecisionTreeParams, self).__init__() + super().__init__() def setLeafCol(self: "P", value: str) -> "P": """ @@ -285,7 +285,7 @@ class _TreeEnsembleParams(_DecisionTreeParams): ) def __init__(self) -> None: - super(_TreeEnsembleParams, self).__init__() + super().__init__() @since("1.4.0") def getSubsamplingRate(self) -> float: @@ -322,7 +322,7 @@ class _RandomForestParams(_TreeEnsembleParams): ) def __init__(self) -> None: - super(_RandomForestParams, self).__init__() + super().__init__() @since("1.4.0") def getNumTrees(self) -> int: @@ -387,7 +387,7 @@ class _HasVarianceImpurity(Params): ) def __init__(self) -> None: - super(_HasVarianceImpurity, self).__init__() + super().__init__() @since("1.4.0") def getImpurity(self) -> str: @@ -416,7 +416,7 @@ class _TreeClassifierParams(Params): ) def __init__(self) -> None: - super(_TreeClassifierParams, self).__init__() + super().__init__() @since("1.6.0") def getImpurity(self) -> str: diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index d922ea38f2d8..60bcf7ae0c2c 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -521,7 +521,7 @@ def getValidatorModelWriterPersistSubModelsParam(writer: MLWriter) -> bool: @inherit_doc class 
CrossValidatorReader(MLReader["CrossValidator"]): def __init__(self, cls: Type["CrossValidator"]): - super(CrossValidatorReader, self).__init__() + super().__init__() self.cls = cls def load(self, path: str) -> "CrossValidator": @@ -540,7 +540,7 @@ def load(self, path: str) -> "CrossValidator": @inherit_doc class CrossValidatorWriter(MLWriter): def __init__(self, instance: "CrossValidator"): - super(CrossValidatorWriter, self).__init__() + super().__init__() self.instance = instance def saveImpl(self, path: str) -> None: @@ -551,7 +551,7 @@ def saveImpl(self, path: str) -> None: @inherit_doc class CrossValidatorModelReader(MLReader["CrossValidatorModel"]): def __init__(self, cls: Type["CrossValidatorModel"]): - super(CrossValidatorModelReader, self).__init__() + super().__init__() self.cls = cls def load(self, path: str) -> "CrossValidatorModel": @@ -599,7 +599,7 @@ def load(self, path: str) -> "CrossValidatorModel": @inherit_doc class CrossValidatorModelWriter(MLWriter): def __init__(self, instance: "CrossValidatorModel"): - super(CrossValidatorModelWriter, self).__init__() + super().__init__() self.instance = instance def saveImpl(self, path: str) -> None: @@ -656,7 +656,7 @@ class _CrossValidatorParams(_ValidatorParams): ) def __init__(self, *args: Any): - super(_CrossValidatorParams, self).__init__(*args) + super().__init__(*args) self._setDefault(numFolds=3, foldCol="") @since("1.4.0") @@ -747,7 +747,7 @@ def __init__( __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ seed=None, parallelism=1, collectSubModels=False, foldCol="") """ - super(CrossValidator, self).__init__() + super().__init__() self._setDefault(parallelism=1) kwargs = self._input_kwargs self._set(**kwargs) @@ -973,7 +973,7 @@ def _from_java(cls, java_stage: "JavaObject") -> "CrossValidator": Used for ML persistence. """ - estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage) + estimator, epms, evaluator = super()._from_java_impl(java_stage) numFolds = java_stage.getNumFolds() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() @@ -1003,7 +1003,7 @@ def _to_java(self) -> "JavaObject": Java object equivalent to this instance. 
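# Illustrative sketch (hypothetical classes, not Spark code): the bulk of this patch
# swaps explicit super(SomeClass, self/cls) calls for Python 3's zero-argument
# super(). Both forms resolve the same method along the MRO, in instance methods and
# classmethods alike, so the rewrites above are behavior-preserving.
class Base:
    def __init__(self):
        self.trace = ["Base.__init__"]

    @classmethod
    def conf(cls):
        return {"base": True}


class Child(Base):
    def __init__(self):
        super().__init__()  # equivalent to super(Child, self).__init__()
        self.trace.append("Child.__init__")

    @classmethod
    def conf(cls):
        config = super().conf()  # equivalent to super(Child, cls).conf()
        config["child"] = True
        return config


assert Child().trace == ["Base.__init__", "Child.__init__"]
assert Child.conf() == {"base": True, "child": True}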
""" - estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl() + estimator, epms, evaluator = super()._to_java_impl() _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid) _java_obj.setEstimatorParamMaps(epms) @@ -1042,7 +1042,7 @@ def __init__( subModels: Optional[List[List[Model]]] = None, stdMetrics: Optional[List[float]] = None, ): - super(CrossValidatorModel, self).__init__() + super().__init__() #: best model from cross validation self.bestModel = bestModel #: Average cross-validation metrics for each paramMap in @@ -1119,7 +1119,7 @@ def _from_java(cls, java_stage: "JavaObject") -> "CrossValidatorModel": bestModel: Model = JavaParams._from_java(java_stage.bestModel()) avgMetrics = _java2py(sc, java_stage.avgMetrics()) - estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) + estimator, epms, evaluator = super()._from_java_impl(java_stage) py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics) params = { @@ -1162,7 +1162,7 @@ def _to_java(self) -> "JavaObject": cast(JavaParams, self.bestModel)._to_java(), _py2java(sc, self.avgMetrics), ) - estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() + estimator, epms, evaluator = super()._to_java_impl() params = { "evaluator": evaluator, @@ -1189,7 +1189,7 @@ def _to_java(self) -> "JavaObject": @inherit_doc class TrainValidationSplitReader(MLReader["TrainValidationSplit"]): def __init__(self, cls: Type["TrainValidationSplit"]): - super(TrainValidationSplitReader, self).__init__() + super().__init__() self.cls = cls def load(self, path: str) -> "TrainValidationSplit": @@ -1208,7 +1208,7 @@ def load(self, path: str) -> "TrainValidationSplit": @inherit_doc class TrainValidationSplitWriter(MLWriter): def __init__(self, instance: "TrainValidationSplit"): - super(TrainValidationSplitWriter, self).__init__() + super().__init__() self.instance = instance def saveImpl(self, path: str) -> None: @@ -1219,7 +1219,7 @@ def saveImpl(self, path: str) -> None: @inherit_doc class TrainValidationSplitModelReader(MLReader["TrainValidationSplitModel"]): def __init__(self, cls: Type["TrainValidationSplitModel"]): - super(TrainValidationSplitModelReader, self).__init__() + super().__init__() self.cls = cls def load(self, path: str) -> "TrainValidationSplitModel": @@ -1258,7 +1258,7 @@ def load(self, path: str) -> "TrainValidationSplitModel": @inherit_doc class TrainValidationSplitModelWriter(MLWriter): def __init__(self, instance: "TrainValidationSplitModel"): - super(TrainValidationSplitModelWriter, self).__init__() + super().__init__() self.instance = instance def saveImpl(self, path: str) -> None: @@ -1304,7 +1304,7 @@ class _TrainValidationSplitParams(_ValidatorParams): ) def __init__(self, *args: Any): - super(_TrainValidationSplitParams, self).__init__(*args) + super().__init__(*args) self._setDefault(trainRatio=0.75) @since("2.0.0") @@ -1385,7 +1385,7 @@ def __init__( __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, \ trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None) """ - super(TrainValidationSplit, self).__init__() + super().__init__() self._setDefault(parallelism=1) kwargs = self._input_kwargs self._set(**kwargs) @@ -1552,7 +1552,7 @@ def _from_java(cls, java_stage: "JavaObject") -> "TrainValidationSplit": Used for ML persistence. 
""" - estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage) + estimator, epms, evaluator = super()._from_java_impl(java_stage) trainRatio = java_stage.getTrainRatio() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() @@ -1580,7 +1580,7 @@ def _to_java(self) -> "JavaObject": Java object equivalent to this instance. """ - estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl() + estimator, epms, evaluator = super()._to_java_impl() _java_obj = JavaParams._new_java_obj( "org.apache.spark.ml.tuning.TrainValidationSplit", self.uid @@ -1610,7 +1610,7 @@ def __init__( validationMetrics: Optional[List[float]] = None, subModels: Optional[List[Model]] = None, ): - super(TrainValidationSplitModel, self).__init__() + super().__init__() #: best model from train validation split self.bestModel = bestModel #: evaluated validation metrics @@ -1681,9 +1681,7 @@ def _from_java(cls, java_stage: "JavaObject") -> "TrainValidationSplitModel": bestModel: Model = JavaParams._from_java(java_stage.bestModel()) validationMetrics = _java2py(sc, java_stage.validationMetrics()) - estimator, epms, evaluator = super(TrainValidationSplitModel, cls)._from_java_impl( - java_stage - ) + estimator, epms, evaluator = super()._from_java_impl(java_stage) # Create a new instance of this stage. py_stage = cls(bestModel=bestModel, validationMetrics=validationMetrics) params = { @@ -1724,7 +1722,7 @@ def _to_java(self) -> "JavaObject": cast(JavaParams, self.bestModel)._to_java(), _py2java(sc, self.validationMetrics), ) - estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl() + estimator, epms, evaluator = super()._to_java_impl() params = { "evaluator": evaluator, diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 06759bc25269..5849f4a6dad8 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -566,7 +566,7 @@ class MLWriter(BaseReadWrite): """ def __init__(self) -> None: - super(MLWriter, self).__init__() + super().__init__() self.shouldOverwrite: bool = False self.optionMap: Dict[str, Any] = {} @@ -630,7 +630,7 @@ class JavaMLWriter(MLWriter): _jwrite: "JavaObject" def __init__(self, instance: "JavaMLWritable"): - super(JavaMLWriter, self).__init__() + super().__init__() _java_obj = instance._to_java() # type: ignore[attr-defined] self._jwrite = _java_obj.write() @@ -662,7 +662,7 @@ class GeneralJavaMLWriter(JavaMLWriter): """ def __init__(self, instance: "JavaMLWritable"): - super(GeneralJavaMLWriter, self).__init__(instance) + super().__init__(instance) def format(self, source: str) -> "GeneralJavaMLWriter": """ @@ -723,7 +723,7 @@ class MLReader(BaseReadWrite, Generic[RL]): """ def __init__(self) -> None: - super(MLReader, self).__init__() + super().__init__() def load(self, path: str) -> RL: """Load the ML instance from the input path.""" @@ -737,7 +737,7 @@ class JavaMLReader(MLReader[RL]): """ def __init__(self, clazz: Type["JavaMLReadable[RL]"]) -> None: - super(JavaMLReader, self).__init__() + super().__init__() self._clazz = clazz self._jread = self._load_java_obj(clazz).read() @@ -850,7 +850,7 @@ class DefaultParamsWriter(MLWriter): """ def __init__(self, instance: "Params"): - super(DefaultParamsWriter, self).__init__() + super().__init__() self.instance = instance def saveImpl(self, path: str) -> None: @@ -976,7 +976,7 @@ class DefaultParamsReader(MLReader[RL]): """ def __init__(self, cls: Type[DefaultParamsReadable[RL]]): - super(DefaultParamsReader, self).__init__() + 
super().__init__() self.cls = cls @staticmethod diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index b8d86e9eab3b..abb7c7625f1c 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -52,7 +52,7 @@ class JavaWrapper: """ def __init__(self, java_obj: Optional["JavaObject"] = None): - super(JavaWrapper, self).__init__() + super().__init__() self._java_obj = java_obj @try_remote_del @@ -355,7 +355,7 @@ def copy(self: "JP", extra: Optional["ParamMap"] = None) -> "JP": """ if extra is None: extra = dict() - that = super(JavaParams, self).copy(extra) + that = super().copy(extra) if self._java_obj is not None: from pyspark.ml.util import RemoteModelRef @@ -374,7 +374,7 @@ def clear(self, param: Param) -> None: """ assert self._java_obj is not None - super(JavaParams, self).clear(param) + super().clear(param) java_param = self._java_obj.getParam(param.name) self._java_obj.clear(java_param) @@ -457,7 +457,7 @@ def __init__(self, java_model: Optional["JavaObject"] = None): these wrappers depend on pyspark.ml.util (both directly and via other ML classes). """ - super(JavaModel, self).__init__(java_model) + super().__init__(java_model) if is_remote() and java_model is not None: from pyspark.ml.util import RemoteModelRef diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 84b59eecbe9d..3a7c40d15fdf 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -58,7 +58,7 @@ class LinearClassificationModel(LinearModel): """ def __init__(self, weights: Vector, intercept: float) -> None: - super(LinearClassificationModel, self).__init__(weights, intercept) + super().__init__(weights, intercept) self._threshold: Optional[float] = None @since("1.4.0") @@ -198,7 +198,7 @@ class LogisticRegressionModel(LinearClassificationModel): def __init__( self, weights: Vector, intercept: float, numFeatures: int, numClasses: int ) -> None: - super(LogisticRegressionModel, self).__init__(weights, intercept) + super().__init__(weights, intercept) self._numFeatures = int(numFeatures) self._numClasses = int(numClasses) self._threshold = 0.5 @@ -584,7 +584,7 @@ class SVMModel(LinearClassificationModel): """ def __init__(self, weights: Vector, intercept: float) -> None: - super(SVMModel, self).__init__(weights, intercept) + super().__init__(weights, intercept) self._threshold = 0.0 @overload @@ -931,7 +931,7 @@ def __init__( self.miniBatchFraction = miniBatchFraction self.convergenceTol = convergenceTol self._model: Optional[LogisticRegressionModel] = None - super(StreamingLogisticRegressionWithSGD, self).__init__(model=self._model) + super().__init__(model=self._model) @since("1.5.0") def setInitialWeights( diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index ff5d69b084e2..4fac735d7ed7 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -76,7 +76,7 @@ class BisectingKMeansModel(JavaModelWrapper): """ def __init__(self, java_model: "JavaObject"): - super(BisectingKMeansModel, self).__init__(java_model) + super().__init__(java_model) self.centers = [c.toArray() for c in self.call("clusterCenters")] @property @@ -943,7 +943,7 @@ class StreamingKMeansModel(KMeansModel): """ def __init__(self, clusterCenters: List["VectorLike"], clusterWeights: "VectorLike"): - super(StreamingKMeansModel, self).__init__(centers=clusterCenters) + super().__init__(centers=clusterCenters) self._clusterWeights = list(clusterWeights) # type: 
ignore[arg-type] @property diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index dfcee167bea5..1c744e1ddb98 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -82,7 +82,7 @@ def __init__(self, scoreAndLabels: RDD[Tuple[float, float]]): assert sc._jvm is not None java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics java_model = java_class(df._jdf) - super(BinaryClassificationMetrics, self).__init__(java_model) + super().__init__(java_model) @property @since("1.4.0") @@ -158,7 +158,7 @@ def __init__(self, predictionAndObservations: RDD[Tuple[float, float]]): assert sc._jvm is not None java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics java_model = java_class(df._jdf) - super(RegressionMetrics, self).__init__(java_model) + super().__init__(java_model) @property @since("1.4.0") @@ -299,7 +299,7 @@ def __init__(self, predictionAndLabels: RDD[Tuple[float, float]]): assert sc._jvm is not None java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics java_model = java_class(df._jdf) - super(MulticlassMetrics, self).__init__(java_model) + super().__init__(java_model) @since("1.4.0") def confusionMatrix(self) -> Matrix: @@ -465,7 +465,7 @@ def __init__( predictionAndLabels, schema=sql_ctx.sparkSession._inferSchema(predictionAndLabels) ) java_model = callMLlibFunc("newRankingMetrics", df._jdf) - super(RankingMetrics, self).__init__(java_model) + super().__init__(java_model) @since("1.4.0") def precisionAt(self, k: int) -> float: @@ -581,7 +581,7 @@ def __init__(self, predictionAndLabels: RDD[Tuple[List[float], List[float]]]): assert sc._jvm is not None java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics java_model = java_class(df._jdf) - super(MultilabelMetrics, self).__init__(java_model) + super().__init__(java_model) @since("1.4.0") def precision(self, label: Optional[float] = None) -> float: diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index b4bd41492706..b384e0fa608e 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -1014,7 +1014,7 @@ def __init__( self.miniBatchFraction = miniBatchFraction self.convergenceTol = convergenceTol self._model: Optional[LinearModel] = None - super(StreamingLinearRegressionWithSGD, self).__init__(model=self._model) + super().__init__(model=self._model) @since("1.5.0") def setInitialWeights(self, initialWeights: "VectorLike") -> "StreamingLinearRegressionWithSGD": diff --git a/python/pyspark/pandas/data_type_ops/datetime_ops.py b/python/pyspark/pandas/data_type_ops/datetime_ops.py index 22bd7a6d329d..6679cb5783ae 100644 --- a/python/pyspark/pandas/data_type_ops/datetime_ops.py +++ b/python/pyspark/pandas/data_type_ops/datetime_ops.py @@ -166,4 +166,4 @@ def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> Ind ) return index_ops._with_new_scol(scol, field=InternalField(dtype=dtype)) else: - return super(DatetimeNTZOps, self).astype(index_ops, dtype) + return super().astype(index_ops, dtype) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 379d3698bc09..e5aaecbb64fd 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -2699,8 +2699,20 @@ def to_feather( # Make sure locals() call is at the top of the function so we don't capture local variables. 
args = locals() + pdf = self._to_internal_pandas() + # SPARK-54068: PyArrow >= 22.0.0 serializes DataFrame.attrs to JSON metadata, + # but PlanMetrics/PlanObservedMetrics objects from Spark Connect are not + # JSON serializable. We filter these internal attrs only for affected versions. + import pyarrow as pa + from pyspark.loose_version import LooseVersion + + if LooseVersion(pa.__version__) >= LooseVersion("22.0.0"): + pdf.attrs = { + k: v for k, v in pdf.attrs.items() if k not in ("metrics", "observed_metrics") + } + return validate_arguments_and_invoke_function( - self._to_internal_pandas(), self.to_feather, pd.DataFrame.to_feather, args + pdf, self.to_feather, pd.DataFrame.to_feather, args ) def to_stata( diff --git a/python/pyspark/pandas/sql_formatter.py b/python/pyspark/pandas/sql_formatter.py index fe500c0bf207..98f085b39717 100644 --- a/python/pyspark/pandas/sql_formatter.py +++ b/python/pyspark/pandas/sql_formatter.py @@ -237,7 +237,7 @@ def __init__(self, session: SparkSession) -> None: self._ref_sers: List[Tuple[Series, str]] = [] def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> str: - ret = super(PandasSQLStringFormatter, self).vformat(format_string, args, kwargs) + ret = super().vformat(format_string, args, kwargs) for ref, n in self._ref_sers: if not any((ref is v for v in df._pssers.values()) for df, _ in self._temp_views): @@ -246,7 +246,7 @@ def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, return ret def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any: - obj, first = super(PandasSQLStringFormatter, self).get_field(field_name, args, kwargs) + obj, first = super().get_field(field_name, args, kwargs) return self._convert_value(obj, field_name), first def _convert_value(self, val: Any, name: str) -> Optional[str]: diff --git a/python/pyspark/pandas/tests/computation/test_describe.py b/python/pyspark/pandas/tests/computation/test_describe.py index 8df07f1945d2..1a45c3e572e4 100644 --- a/python/pyspark/pandas/tests/computation/test_describe.py +++ b/python/pyspark/pandas/tests/computation/test_describe.py @@ -28,7 +28,7 @@ class FrameDescribeMixin: @classmethod def setUpClass(cls): - super(FrameDescribeMixin, cls).setUpClass() + super().setUpClass() # Some nanosecond->microsecond conversions throw loss of precision errors cls.spark.conf.set("spark.sql.execution.pandas.convertToArrowArraySafely", "false") diff --git a/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py index f98f2011dde0..228061cac804 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py @@ -248,7 +248,7 @@ class DatetimeOpsTests( class DatetimeNTZOpsTest(DatetimeOpsTests): @classmethod def setUpClass(cls): - super(DatetimeOpsTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.timestampType", "timestamp_ntz") diff --git a/python/pyspark/pandas/tests/data_type_ops/testing_utils.py b/python/pyspark/pandas/tests/data_type_ops/testing_utils.py index 17ac2bc5c474..04d03a05e02d 100644 --- a/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +++ b/python/pyspark/pandas/tests/data_type_ops/testing_utils.py @@ -43,7 +43,7 @@ class OpsTestBase: @classmethod def setUpClass(cls): - super(OpsTestBase, cls).setUpClass() + super().setUpClass() # Some nanosecond->microsecond conversions throw loss of precision errors 
cls.spark.conf.set("spark.sql.execution.pandas.convertToArrowArraySafely", "false") diff --git a/python/pyspark/pandas/tests/io/test_feather.py b/python/pyspark/pandas/tests/io/test_feather.py index 10638d915c0e..74fa6bc7d7b6 100644 --- a/python/pyspark/pandas/tests/io/test_feather.py +++ b/python/pyspark/pandas/tests/io/test_feather.py @@ -17,10 +17,8 @@ import unittest import pandas as pd -import sys from pyspark import pandas as ps -from pyspark.loose_version import LooseVersion from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils @@ -36,16 +34,6 @@ def pdf(self): def psdf(self): return ps.from_pandas(self.pdf) - has_arrow_21_or_below = False - try: - import pyarrow as pa - - if LooseVersion(pa.__version__) < LooseVersion("22.0.0"): - has_arrow_21_or_below = True - except ImportError: - pass - - @unittest.skipIf(not has_arrow_21_or_below, "SPARK-54068") def test_to_feather(self): with self.temp_dir() as dirpath: path1 = f"{dirpath}/file1.feather" diff --git a/python/pyspark/pandas/tests/resample/test_timezone.py b/python/pyspark/pandas/tests/resample/test_timezone.py index 17c46dd26b35..ad01a6413709 100644 --- a/python/pyspark/pandas/tests/resample/test_timezone.py +++ b/python/pyspark/pandas/tests/resample/test_timezone.py @@ -33,11 +33,11 @@ class ResampleTimezoneMixin: def setUpClass(cls): cls.timezone = os.environ.get("TZ", None) os.environ["TZ"] = "America/New_York" - super(ResampleTimezoneMixin, cls).setUpClass() + super().setUpClass() @classmethod def tearDownClass(cls): - super(ResampleTimezoneMixin, cls).tearDownClass() + super().tearDownClass() if cls.timezone is not None: os.environ["TZ"] = cls.timezone diff --git a/python/pyspark/pandas/tests/test_numpy_compat.py b/python/pyspark/pandas/tests/test_numpy_compat.py index d961a433e181..97df2bc6eb8d 100644 --- a/python/pyspark/pandas/tests/test_numpy_compat.py +++ b/python/pyspark/pandas/tests/test_numpy_compat.py @@ -28,7 +28,7 @@ class NumPyCompatTestsMixin: @classmethod def setUpClass(cls): - super(NumPyCompatTestsMixin, cls).setUpClass() + super().setUpClass() # Some nanosecond->microsecond conversions throw loss of precision errors cls.spark.conf.set("spark.sql.execution.pandas.convertToArrowArraySafely", "false") diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 0c9a6b45bd2f..a3f188302296 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -502,6 +502,20 @@ def pairs(self) -> dict[str, Any]: def keys(self) -> List[str]: return self._keys + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dictionary representation of this observed metrics. + + Returns + ------- + dict + A dictionary with keys 'name', 'keys', and 'pairs'. 
+ """ + return { + "name": self._name, + "keys": self._keys, + "pairs": self.pairs, + } + class AnalyzeResult: def __init__( diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 6630d96f21de..80b6b562369e 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -1930,7 +1930,7 @@ def command(self, session: "SparkConnectClient") -> proto.Command: class WriteOperation(LogicalPlan): def __init__(self, child: "LogicalPlan") -> None: - super(WriteOperation, self).__init__(child) + super().__init__(child) self.source: Optional[str] = None self.path: Optional[str] = None self.table_name: Optional[str] = None @@ -2037,7 +2037,7 @@ def _repr_html_(self) -> str: class WriteOperationV2(LogicalPlan): def __init__(self, child: "LogicalPlan", table_name: str) -> None: - super(WriteOperationV2, self).__init__(child) + super().__init__(child) self.table_name: Optional[str] = table_name self.provider: Optional[str] = None self.partitioning_columns: List[Column] = [] @@ -2101,7 +2101,7 @@ def command(self, session: "SparkConnectClient") -> proto.Command: class WriteStreamOperation(LogicalPlan): def __init__(self, child: "LogicalPlan") -> None: - super(WriteStreamOperation, self).__init__(child) + super().__init__(child) self.write_op = proto.WriteStreamOperationStart() def command(self, session: "SparkConnectClient") -> proto.Command: diff --git a/python/pyspark/sql/connect/sql_formatter.py b/python/pyspark/sql/connect/sql_formatter.py index 38b94bbaf205..8fced80081ad 100644 --- a/python/pyspark/sql/connect/sql_formatter.py +++ b/python/pyspark/sql/connect/sql_formatter.py @@ -39,7 +39,7 @@ def __init__(self, session: "SparkSession") -> None: self._temp_views: List[Tuple[DataFrame, str]] = [] def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any: - obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs) + obj, first = super().get_field(field_name, args, kwargs) return self._convert_value(obj, field_name), first def _convert_value(self, val: Any, field_name: str) -> Optional[str]: diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py index 4ab9b041e313..50d8ae444c47 100644 --- a/python/pyspark/sql/metrics.py +++ b/python/pyspark/sql/metrics.py @@ -68,6 +68,20 @@ def value(self) -> Union[int, float]: def metric_type(self) -> str: return self._type + def to_dict(self) -> Dict[str, Any]: + """Return a JSON-serializable dictionary representation of this metric value. + + Returns + ------- + dict + A dictionary with keys 'name', 'value', and 'type'. + """ + return { + "name": self._name, + "value": self._value, + "type": self._type, + } + class PlanMetrics: """Represents a particular plan node and the associated metrics of this node.""" @@ -97,6 +111,21 @@ def parent_plan_id(self) -> int: def metrics(self) -> List[MetricValue]: return self._metrics + def to_dict(self) -> Dict[str, Any]: + """Return a JSON-serializable dictionary representation of this plan metrics. + + Returns + ------- + dict + A dictionary with keys 'name', 'plan_id', 'parent_plan_id', and 'metrics'. 
+ """ + return { + "name": self._name, + "plan_id": self._id, + "parent_plan_id": self._parent_id, + "metrics": [m.to_dict() for m in self._metrics], + } + class CollectedMetrics: @dataclasses.dataclass diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 768160087032..dc854cb1985d 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -149,7 +149,7 @@ def load_stream(self, stream): """ import pyarrow as pa - batches = super(ArrowStreamUDFSerializer, self).load_stream(stream) + batches = super().load_stream(stream) for batch in batches: struct = batch.column(0) yield [pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))] @@ -184,7 +184,7 @@ def wrap_and_init_stream(): should_write_start_length = False yield batch - return super(ArrowStreamUDFSerializer, self).dump_stream(wrap_and_init_stream(), stream) + return super().dump_stream(wrap_and_init_stream(), stream) class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer): @@ -304,7 +304,7 @@ class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer): """ def __init__(self, assign_cols_by_name): - super(ArrowStreamGroupUDFSerializer, self).__init__() + super().__init__() self._assign_cols_by_name = assign_cols_by_name def dump_stream(self, iterator, stream): @@ -330,7 +330,7 @@ def dump_stream(self, iterator, stream): for batch, arrow_type in batch_iter ) - super(ArrowStreamGroupUDFSerializer, self).dump_stream(batch_iter, stream) + super().dump_stream(batch_iter, stream) class ArrowStreamPandasSerializer(ArrowStreamSerializer): @@ -351,7 +351,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer): """ def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled): - super(ArrowStreamPandasSerializer, self).__init__() + super().__init__() self._timezone = timezone self._safecheck = safecheck self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled @@ -528,13 +528,13 @@ def dump_stream(self, iterator, stream): a list of series accompanied by an optional pyarrow type to coerce the data to. """ batches = (self._create_batch(series) for series in iterator) - super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream) + super().dump_stream(batches, stream) def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. 
""" - batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + batches = super().load_stream(stream) import pyarrow as pa import pandas as pd @@ -569,9 +569,7 @@ def __init__( input_types=None, int_to_decimal_coercion_enabled=False, ): - super(ArrowStreamPandasUDFSerializer, self).__init__( - timezone, safecheck, int_to_decimal_coercion_enabled - ) + super().__init__(timezone, safecheck, int_to_decimal_coercion_enabled) self._assign_cols_by_name = assign_cols_by_name self._df_for_struct = df_for_struct self._struct_in_pandas = struct_in_pandas @@ -593,6 +591,7 @@ def arrow_to_pandas(self, arrow_column, idx): import pandas as pd series = [ + # Need to be explicit here because it's in a comprehension super(ArrowStreamPandasUDFSerializer, self) .arrow_to_pandas( column, @@ -610,7 +609,7 @@ def arrow_to_pandas(self, arrow_column, idx): ] s = pd.concat(series, axis=1) else: - s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas( + s = super().arrow_to_pandas( arrow_column, idx, self._struct_in_pandas, @@ -768,7 +767,7 @@ def __init__( assign_cols_by_name, arrow_cast, ): - super(ArrowStreamArrowUDFSerializer, self).__init__() + super().__init__() self._timezone = timezone self._safecheck = safecheck self._assign_cols_by_name = assign_cols_by_name @@ -941,7 +940,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer): """ def __init__(self, timezone, safecheck, input_types, int_to_decimal_coercion_enabled): - super(ArrowStreamPandasUDTFSerializer, self).__init__( + super().__init__( timezone=timezone, safecheck=safecheck, # The output pandas DataFrame's columns are unnamed. @@ -1089,7 +1088,7 @@ def __repr__(self): class GroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer): def __init__(self, assign_cols_by_name): - super(GroupArrowUDFSerializer, self).__init__( + super().__init__( assign_cols_by_name=assign_cols_by_name, ) @@ -1138,13 +1137,9 @@ def __init__( super().__init__( timezone=timezone, safecheck=safecheck, - assign_cols_by_name=False, - arrow_cast=True, + assign_cols_by_name=assign_cols_by_name, + arrow_cast=arrow_cast, ) - self._timezone = timezone - self._safecheck = safecheck - self._assign_cols_by_name = assign_cols_by_name - self._arrow_cast = arrow_cast def load_stream(self, stream): """ @@ -1189,13 +1184,9 @@ def __init__( super().__init__( timezone=timezone, safecheck=safecheck, - assign_cols_by_name=False, - arrow_cast=True, + assign_cols_by_name=assign_cols_by_name, + arrow_cast=arrow_cast, ) - self._timezone = timezone - self._safecheck = safecheck - self._assign_cols_by_name = assign_cols_by_name - self._arrow_cast = arrow_cast def load_stream(self, stream): """ @@ -1238,10 +1229,10 @@ def __init__( assign_cols_by_name, int_to_decimal_coercion_enabled, ): - super(ArrowStreamAggPandasUDFSerializer, self).__init__( + super().__init__( timezone=timezone, safecheck=safecheck, - assign_cols_by_name=False, + assign_cols_by_name=assign_cols_by_name, df_for_struct=False, struct_in_pandas="dict", ndarray_as_list=False, @@ -1249,9 +1240,6 @@ def __init__( input_types=None, int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, ) - self._timezone = timezone - self._safecheck = safecheck - self._assign_cols_by_name = assign_cols_by_name def load_stream(self, stream): """ @@ -1295,7 +1283,7 @@ def __init__( assign_cols_by_name, int_to_decimal_coercion_enabled, ): - super(GroupPandasUDFSerializer, self).__init__( + super().__init__( timezone=timezone, safecheck=safecheck, assign_cols_by_name=assign_cols_by_name, @@ -1353,7 +1341,7 
@@ def dump_stream(self, iterator, stream): """ # Flatten: Iterator[Iterator[[(df, arrow_type)]]] -> Iterator[[(df, arrow_type)]] flattened_iter = (batch for generator in iterator for batch in generator) - super(GroupPandasUDFSerializer, self).dump_stream(flattened_iter, stream) + super().dump_stream(flattened_iter, stream) def __repr__(self): return "GroupPandasUDFSerializer" @@ -1373,7 +1361,7 @@ class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer): """ def __init__(self, assign_cols_by_name): - super(CogroupArrowUDFSerializer, self).__init__(assign_cols_by_name) + super().__init__(assign_cols_by_name) def load_stream(self, stream): """ @@ -1458,7 +1446,7 @@ def __init__( prefers_large_var_types, int_to_decimal_coercion_enabled, ): - super(ApplyInPandasWithStateSerializer, self).__init__( + super().__init__( timezone, safecheck, assign_cols_by_name, @@ -1842,7 +1830,7 @@ def __init__( arrow_max_bytes_per_batch, int_to_decimal_coercion_enabled, ): - super(TransformWithStateInPandasSerializer, self).__init__( + super().__init__( timezone, safecheck, assign_cols_by_name, @@ -1967,7 +1955,7 @@ def __init__( arrow_max_bytes_per_batch, int_to_decimal_coercion_enabled, ): - super(TransformWithStateInPandasInitStateSerializer, self).__init__( + super().__init__( timezone, safecheck, assign_cols_by_name, @@ -2108,7 +2096,7 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer): """ def __init__(self, arrow_max_records_per_batch): - super(TransformWithStateInPySparkRowSerializer, self).__init__() + super().__init__() self.arrow_max_records_per_batch = ( arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1 ) @@ -2197,9 +2185,7 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp """ def __init__(self, arrow_max_records_per_batch): - super(TransformWithStateInPySparkRowInitStateSerializer, self).__init__( - arrow_max_records_per_batch - ) + super().__init__(arrow_max_records_per_batch) self.init_key_offsets = None def load_stream(self, stream): diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 4e45972d79b3..0d5353f0fb32 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -1624,12 +1624,12 @@ def createDataFrame( # type: ignore[misc] if has_pandas and isinstance(data, pd.DataFrame): # Create a DataFrame from pandas DataFrame. - return super(SparkSession, self).createDataFrame( # type: ignore[call-overload] + return super().createDataFrame( # type: ignore[call-overload] data, schema, samplingRatio, verifySchema ) if has_pyarrow and isinstance(data, pa.Table): # Create a DataFrame from PyArrow Table. 
- return super(SparkSession, self).createDataFrame( # type: ignore[call-overload] + return super().createDataFrame( # type: ignore[call-overload] data, schema, samplingRatio, verifySchema ) return self._create_dataframe( diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/sql_formatter.py index 011563d7006e..1366ef277c47 100644 --- a/python/pyspark/sql/sql_formatter.py +++ b/python/pyspark/sql/sql_formatter.py @@ -38,7 +38,7 @@ def __init__(self, session: "SparkSession") -> None: self._temp_views: List[Tuple[DataFrame, str]] = [] def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any: - obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs) + obj, first = super().get_field(field_name, args, kwargs) return self._convert_value(obj, field_name), first def _convert_value(self, val: Any, field_name: str) -> Optional[str]: diff --git a/python/pyspark/sql/tests/arrow/test_arrow.py b/python/pyspark/sql/tests/arrow/test_arrow.py index 19e579cb6778..e410f7df711b 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow.py +++ b/python/pyspark/sql/tests/arrow/test_arrow.py @@ -1925,14 +1925,14 @@ def test_exception_by_max_results(self): class EncryptionArrowTests(ArrowTests): @classmethod def conf(cls): - return super(EncryptionArrowTests, cls).conf().set("spark.io.encryption.enabled", "true") + return super().conf().set("spark.io.encryption.enabled", "true") class RDDBasedArrowTests(ArrowTests): @classmethod def conf(cls): return ( - super(RDDBasedArrowTests, cls) + super() .conf() .set("spark.sql.execution.arrow.localRelationThreshold", "0") # to test multiple partitions diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py index c315151d4d75..be46939b351f 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py @@ -50,15 +50,15 @@ class ArrowPythonUDFTestsMixin(BaseUDFTestsMixin): @unittest.skip("Unrelated test, and it fails when it runs duplicatedly.") def test_broadcast_in_udf(self): - super(ArrowPythonUDFTests, self).test_broadcast_in_udf() + super().test_broadcast_in_udf() @unittest.skip("Unrelated test, and it fails when it runs duplicatedly.") def test_register_java_function(self): - super(ArrowPythonUDFTests, self).test_register_java_function() + super().test_register_java_function() @unittest.skip("Unrelated test, and it fails when it runs duplicatedly.") def test_register_java_udaf(self): - super(ArrowPythonUDFTests, self).test_register_java_udaf() + super().test_register_java_udaf() def test_complex_input_types(self): row = ( @@ -485,7 +485,7 @@ def tearDownClass(cls): class ArrowPythonUDFTests(ArrowPythonUDFTestsMixin, ReusedSQLTestCase): @classmethod def setUpClass(cls): - super(ArrowPythonUDFTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") @classmethod @@ -493,21 +493,21 @@ def tearDownClass(cls): try: cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") finally: - super(ArrowPythonUDFTests, cls).tearDownClass() + super().tearDownClass() @unittest.skip("Duplicate test; it is tested separately in legacy and non-legacy tests") def test_udf_binary_type(self): - super(ArrowPythonUDFTests, self).test_udf_binary_type() + super().test_udf_binary_type() @unittest.skip("Duplicate test; it is tested separately in legacy and non-legacy tests") def 
test_udf_binary_type_in_nested_structures(self): - super(ArrowPythonUDFTests, self).test_udf_binary_type_in_nested_structures() + super().test_udf_binary_type_in_nested_structures() class ArrowPythonUDFLegacyTests(ArrowPythonUDFLegacyTestsMixin, ReusedSQLTestCase): @classmethod def setUpClass(cls): - super(ArrowPythonUDFLegacyTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.concurrency.level", "4") cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") @@ -517,13 +517,13 @@ def tearDownClass(cls): cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.concurrency.level") cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") finally: - super(ArrowPythonUDFLegacyTests, cls).tearDownClass() + super().tearDownClass() class ArrowPythonUDFNonLegacyTests(ArrowPythonUDFNonLegacyTestsMixin, ReusedSQLTestCase): @classmethod def setUpClass(cls): - super(ArrowPythonUDFNonLegacyTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") @classmethod @@ -531,7 +531,7 @@ def tearDownClass(cls): try: cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") finally: - super(ArrowPythonUDFNonLegacyTests, cls).tearDownClass() + super().tearDownClass() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py index 74a81be37f80..844c7f111db4 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py @@ -1153,13 +1153,14 @@ def test_iterator_grouped_agg_partial_consumption(self): # Create a dataset with multiple batches per group # Use small batch size to ensure multiple batches per group + # Use same value for all data points to avoid ordering issues with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}): df = self.spark.createDataFrame( - [(1, 1.0), (1, 2.0), (1, 3.0), (1, 4.0), (2, 5.0), (2, 6.0)], ("id", "v") + [(1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (2, 1.0), (2, 1.0)], ("id", "v") ) - @arrow_udf("double") - def arrow_sum_partial(it: Iterator[pa.Array]) -> float: + @arrow_udf("struct") + def arrow_count_sum_partial(it: Iterator[pa.Array]) -> dict: # Only consume first two batches, then return # This tests that partial consumption works correctly total = 0.0 @@ -1171,32 +1172,44 @@ def arrow_sum_partial(it: Iterator[pa.Array]) -> float: else: # Stop early - partial consumption break - return total / count if count > 0 else 0.0 + return {"count": count, "sum": total} - result = df.groupby("id").agg(arrow_sum_partial(df["v"]).alias("mean")).sort("id") + result = ( + df.groupby("id").agg(arrow_count_sum_partial(df["v"]).alias("result")).sort("id") + ) # Verify results are correct for partial consumption # With batch size = 2: # Group 1 (id=1): 4 values in 2 batches -> processes both batches - # Batch 1: [1.0, 2.0], Batch 2: [3.0, 4.0] - # Result: (1.0+2.0+3.0+4.0)/4 = 2.5 + # Batch 1: [1.0, 1.0], Batch 2: [1.0, 1.0] + # Result: count=4, sum=4.0 # Group 2 (id=2): 2 values in 1 batch -> processes 1 batch (only 1 batch available) - # Batch 1: [5.0, 6.0] - # Result: (5.0+6.0)/2 = 5.5 + # Batch 1: [1.0, 1.0] + # Result: count=2, sum=2.0 actual = result.collect() self.assertEqual(len(actual), 2, "Should have results for both groups") # Verify both groups were processed correctly # Group 1: processes 2 batches (all available) group1_result = next(row 
for row in actual if row["id"] == 1) + self.assertEqual( + group1_result["result"]["count"], + 4, + msg="Group 1 should process 4 values (2 batches)", + ) self.assertAlmostEqual( - group1_result["mean"], 2.5, places=5, msg="Group 1 should process 2 batches" + group1_result["result"]["sum"], 4.0, places=5, msg="Group 1 should sum to 4.0" ) # Group 2: processes 1 batch (only batch available) group2_result = next(row for row in actual if row["id"] == 2) + self.assertEqual( + group2_result["result"]["count"], + 2, + msg="Group 2 should process 2 values (1 batch)", + ) self.assertAlmostEqual( - group2_result["mean"], 5.5, places=5, msg="Group 2 should process 1 batch" + group2_result["result"]["sum"], 2.0, places=5, msg="Group 2 should sum to 2.0" ) diff --git a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py index 0f879087e7bb..46ba9e0e2e51 100644 --- a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py +++ b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py @@ -24,7 +24,7 @@ class ArrowPythonUDFParityTests(UDFParityTests, ArrowPythonUDFTestsMixin): @classmethod def setUpClass(cls): - super(ArrowPythonUDFParityTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") @classmethod @@ -32,7 +32,7 @@ def tearDownClass(cls): try: cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") finally: - super(ArrowPythonUDFParityTests, cls).tearDownClass() + super().tearDownClass() class ArrowPythonUDFParityLegacyTestsMixin(ArrowPythonUDFTestsMixin): @@ -84,7 +84,7 @@ def test_udf_binary_type_in_nested_structures(self): class ArrowPythonUDFParityLegacyTests(UDFParityTests, ArrowPythonUDFParityLegacyTestsMixin): @classmethod def setUpClass(cls): - super(ArrowPythonUDFParityLegacyTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") @classmethod @@ -92,13 +92,13 @@ def tearDownClass(cls): try: cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") finally: - super(ArrowPythonUDFParityLegacyTests, cls).tearDownClass() + super().tearDownClass() class ArrowPythonUDFParityNonLegacyTests(UDFParityTests, ArrowPythonUDFParityNonLegacyTestsMixin): @classmethod def setUpClass(cls): - super(ArrowPythonUDFParityNonLegacyTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") @classmethod @@ -106,7 +106,7 @@ def tearDownClass(cls): try: cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") finally: - super(ArrowPythonUDFParityNonLegacyTests, cls).tearDownClass() + super().tearDownClass() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py index 22865be4f42a..528fc10bbb28 100644 --- a/python/pyspark/sql/tests/connect/client/test_artifact.py +++ b/python/pyspark/sql/tests/connect/client/test_artifact.py @@ -203,7 +203,7 @@ def root(cls): @classmethod def setUpClass(cls): - super(ArtifactTests, cls).setUpClass() + super().setUpClass() cls.artifact_manager: ArtifactManager = cls.spark._client._artifact_manager cls.base_resource_dir = os.path.join(SPARK_HOME, "data") cls.artifact_file_path = os.path.join( diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 
b789d7919c94..08e912a446e3 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -65,7 +65,7 @@ class SparkConnectSQLTestCase(ReusedMixedTestCase, PandasOnSparkTestUtils): @classmethod def setUpClass(cls): - super(SparkConnectSQLTestCase, cls).setUpClass() + super().setUpClass() cls.testData = [Row(key=i, value=str(i)) for i in range(100)] cls.testDataStr = [Row(key=str(i)) for i in range(100)] @@ -88,7 +88,7 @@ def tearDownClass(cls): try: cls.spark_connect_clean_up_test_data() finally: - super(SparkConnectSQLTestCase, cls).tearDownClass() + super().tearDownClass() @classmethod def spark_connect_load_test_data(cls): @@ -1480,11 +1480,11 @@ class SparkConnectGCTests(SparkConnectSQLTestCase): def setUpClass(cls): cls.origin = os.getenv("USER", None) os.environ["USER"] = "SparkConnectGCTests" - super(SparkConnectGCTests, cls).setUpClass() + super().setUpClass() @classmethod def tearDownClass(cls): - super(SparkConnectGCTests, cls).tearDownClass() + super().tearDownClass() if cls.origin is not None: os.environ["USER"] = cls.origin else: diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py index 208f9ae53898..d3dfbc221041 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udtf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -41,7 +41,7 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase): @classmethod def setUpClass(cls): - super(UDTFParityTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "false") @classmethod @@ -49,7 +49,7 @@ def tearDownClass(cls): try: cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled") finally: - super(UDTFParityTests, cls).tearDownClass() + super().tearDownClass() def test_struct_output_type_casting_row(self): self.check_struct_output_type_casting_row(PickleException) @@ -94,7 +94,7 @@ def _add_file(self, path): class LegacyArrowUDTFParityTests(LegacyUDTFArrowTestsMixin, UDTFParityTests): @classmethod def setUpClass(cls): - super(LegacyArrowUDTFParityTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true") cls.spark.conf.set( "spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "true" @@ -106,7 +106,7 @@ def tearDownClass(cls): cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled") cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled") finally: - super(LegacyArrowUDTFParityTests, cls).tearDownClass() + super().tearDownClass() def test_udtf_access_spark_session_connect(self): df = self.spark.range(10) @@ -128,7 +128,7 @@ def eval(self): class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests): @classmethod def setUpClass(cls): - super(ArrowUDTFParityTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true") cls.spark.conf.set( "spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "false" @@ -140,7 +140,7 @@ def tearDownClass(cls): cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled") cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled") finally: - super(ArrowUDTFParityTests, cls).tearDownClass() + super().tearDownClass() def test_udtf_access_spark_session_connect(self): df = self.spark.range(10) diff --git a/python/pyspark/sql/tests/test_resources.py 
b/python/pyspark/sql/tests/test_resources.py index 4ce61e9f763d..9b99672077fe 100644 --- a/python/pyspark/sql/tests/test_resources.py +++ b/python/pyspark/sql/tests/test_resources.py @@ -90,7 +90,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - super(ResourceProfileTests, cls).tearDownClass() + super().tearDownClass() cls.spark.stop() diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py index e6a7bf40b945..4d565ecfd939 100644 --- a/python/pyspark/sql/tests/test_udf_profiler.py +++ b/python/pyspark/sql/tests/test_udf_profiler.py @@ -138,6 +138,23 @@ def iter_to_iter(iter: Iterator[pa.Array]) -> Iterator[pa.Array]: self.spark.range(10).select(iter_to_iter("id")).collect() + def exec_arrow_udf_grouped_agg_iter(self): + import pyarrow as pa + + @arrow_udf("double") + def arrow_mean_iter(it: Iterator[pa.Array]) -> float: + sum_val = 0.0 + cnt = 0 + for v in it: + sum_val += pa.compute.sum(v).as_py() + cnt += len(v) + return sum_val / cnt if cnt > 0 else 0.0 + + df = self.spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v") + ) + df.groupby("id").agg(arrow_mean_iter(df["v"])).collect() + # Unsupported def exec_map(self): import pandas as pd @@ -169,6 +186,15 @@ def test_unsupported(self): "Profiling UDFs with iterators input/output is not supported" in str(user_warns[0]) ) + with warnings.catch_warnings(record=True) as warns: + warnings.simplefilter("always") + self.exec_arrow_udf_grouped_agg_iter() + user_warns = [warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Profiling UDFs with iterators input/output is not supported" in str(user_warns[0]) + ) + with warnings.catch_warnings(record=True) as warns: warnings.simplefilter("always") self.exec_map() @@ -486,6 +512,31 @@ def min_udf(v: pa.Array) -> float: for id in self.profile_results: self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2) + @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) + def test_perf_profiler_arrow_udf_grouped_agg_iter(self): + import pyarrow as pa + from typing import Iterator + + @arrow_udf("double") + def arrow_mean_iter(it: Iterator[pa.Array]) -> float: + sum_val = 0.0 + cnt = 0 + for v in it: + sum_val += pa.compute.sum(v).as_py() + cnt += len(v) + return sum_val / cnt if cnt > 0 else 0.0 + + with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}): + df = self.spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v") + ) + df.groupBy(df.id).agg(arrow_mean_iter(df["v"])).show() + + self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys())) + + for id in self.profile_results: + self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2) + @unittest.skipIf( not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 389df5b5a6cf..eeb07400b060 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -812,6 +812,111 @@ def _check_result_or_exception( with self.assertRaisesRegex(err_type, expected): func().collect() + def test_udtf_nullable_check(self): + for ret_type, value, expected in [ + ( + StructType([StructField("value", ArrayType(IntegerType(), False))]), + ([None],), + "PySparkRuntimeError", + ), + ( + StructType([StructField("value", 
ArrayType(IntegerType(), True))]), + ([None],), + [Row(value=[None])], + ), + ( + StructType([StructField("value", MapType(StringType(), IntegerType(), False))]), + ({"a": None},), + "PySparkRuntimeError", + ), + ( + StructType([StructField("value", MapType(StringType(), IntegerType(), True))]), + ({"a": None},), + [Row(value={"a": None})], + ), + ( + StructType([StructField("value", MapType(StringType(), IntegerType(), True))]), + ({None: 1},), + "PySparkRuntimeError", + ), + ( + StructType([StructField("value", MapType(StringType(), IntegerType(), False))]), + ({None: 1},), + "PySparkRuntimeError", + ), + ( + StructType( + [ + StructField( + "value", MapType(StringType(), ArrayType(IntegerType(), False), False) + ) + ] + ), + ({"s": [None]},), + "PySparkRuntimeError", + ), + ( + StructType( + [ + StructField( + "value", + MapType( + StructType([StructField("value", StringType(), False)]), + IntegerType(), + False, + ), + ) + ] + ), + ({(None,): 1},), + "PySparkRuntimeError", + ), + ( + StructType( + [ + StructField( + "value", + MapType( + StructType([StructField("value", StringType(), False)]), + IntegerType(), + True, + ), + ) + ] + ), + ({(None,): 1},), + "PySparkRuntimeError", + ), + ( + StructType( + [StructField("value", StructType([StructField("value", StringType(), False)]))] + ), + ((None,),), + "PySparkRuntimeError", + ), + ( + StructType( + [ + StructField( + "value", + StructType( + [StructField("value", ArrayType(StringType(), False), False)] + ), + ) + ] + ), + (([None],),), + "PySparkRuntimeError", + ), + ]: + + class TestUDTF: + def eval(self): + yield value + + with self.subTest(ret_type=ret_type, value=value): + self._check_result_or_exception(TestUDTF, ret_type, expected) + def test_numeric_output_type_casting(self): class TestUDTF: def eval(self): @@ -3183,7 +3288,7 @@ def eval(self, x: int): class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase): @classmethod def setUpClass(cls): - super(UDTFTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "false") @classmethod @@ -3191,7 +3296,7 @@ def tearDownClass(cls): try: cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled") finally: - super(UDTFTests, cls).tearDownClass() + super().tearDownClass() @unittest.skipIf( @@ -3521,7 +3626,7 @@ def eval(self): class LegacyUDTFArrowTests(LegacyUDTFArrowTestsMixin, ReusedSQLTestCase): @classmethod def setUpClass(cls): - super(LegacyUDTFArrowTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true") cls.spark.conf.set( "spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "true" @@ -3533,7 +3638,7 @@ def tearDownClass(cls): cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled") cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled") finally: - super(LegacyUDTFArrowTests, cls).tearDownClass() + super().tearDownClass() class UDTFArrowTestsMixin(LegacyUDTFArrowTestsMixin): @@ -3780,7 +3885,7 @@ def eval(self, v: float): class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase): @classmethod def setUpClass(cls): - super(UDTFArrowTests, cls).setUpClass() + super().setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true") cls.spark.conf.set( "spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "false" @@ -3792,7 +3897,7 @@ def tearDownClass(cls): cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled") 
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled") finally: - super(UDTFArrowTests, cls).tearDownClass() + super().tearDownClass() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py b/python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py index c25f35814b43..99127c8bbe90 100644 --- a/python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py +++ b/python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py @@ -70,10 +70,10 @@ class UDFInputTypeTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): - super(UDFInputTypeTests, cls).setUpClass() + super().setUpClass() def setUp(self): - super(UDFInputTypeTests, self).setUp() + super().setUp() def test_udf_input_types_arrow_disabled(self): golden_file = os.path.join( diff --git a/python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py b/python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py index 5e307110559b..9b69d6b20ba2 100644 --- a/python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py +++ b/python/pyspark/sql/tests/udf_type_tests/test_udf_return_types.py @@ -73,10 +73,10 @@ class UDFReturnTypeTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): - super(UDFReturnTypeTests, cls).setUpClass() + super().setUpClass() def setUp(self): - super(UDFReturnTypeTests, self).setUp() + super().setUp() self.test_data = [ None, True, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 675d25f23aea..4c36a2db5323 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -242,9 +242,7 @@ class DataTypeSingleton(type): def __call__(cls: Type[T]) -> T: if cls not in cls._instances: # type: ignore[attr-defined] - cls._instances[cls] = super( # type: ignore[misc, attr-defined] - DataTypeSingleton, cls - ).__call__() + cls._instances[cls] = super().__call__() # type: ignore[misc, attr-defined] return cls._instances[cls] # type: ignore[attr-defined] @@ -3624,7 +3622,7 @@ def __contains__(self, item: Any) -> bool: if hasattr(self, "__fields__"): return item in self.__fields__ else: - return super(Row, self).__contains__(item) + return super().__contains__(item) # let object acts like class def __call__(self, *args: Any) -> "Row": @@ -3642,12 +3640,12 @@ def __call__(self, *args: Any) -> "Row": def __getitem__(self, item: Any) -> Any: if isinstance(item, (int, slice)): - return super(Row, self).__getitem__(item) + return super().__getitem__(item) try: # it will be slow when it has many fields, # but this will not be used in normal cases idx = self.__fields__.index(item) - return super(Row, self).__getitem__(idx) + return super().__getitem__(idx) except IndexError: raise PySparkKeyError(errorClass="KEY_NOT_EXISTS", messageParameters={"key": str(item)}) except ValueError: diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c7471d19f7d6..75bcb66efdb8 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -438,6 +438,7 @@ def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column: PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, + PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF, ]: warnings.warn( "Profiling UDFs with iterators input/output is not supported.", diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index ee86f5c03974..6e3c282a41ee 100644 --- a/python/pyspark/testing/connectutils.py +++ 
b/python/pyspark/testing/connectutils.py @@ -234,7 +234,7 @@ def quiet(self): class ReusedMixedTestCase(ReusedConnectTestCase, SQLTestUtils): @classmethod def setUpClass(cls): - super(ReusedMixedTestCase, cls).setUpClass() + super().setUpClass() # Disable the shared namespace so pyspark.sql.functions, etc point the regular # PySpark libraries. os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1" @@ -250,7 +250,7 @@ def tearDownClass(cls): cls.spark = cls.connect del os.environ["PYSPARK_NO_NAMESPACE_SHARE"] finally: - super(ReusedMixedTestCase, cls).tearDownClass() + super().tearDownClass() def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20): from pyspark.sql.classic.dataframe import DataFrame as SDF diff --git a/python/pyspark/testing/mlutils.py b/python/pyspark/testing/mlutils.py index aa3e23bccb19..491c539acc0e 100644 --- a/python/pyspark/testing/mlutils.py +++ b/python/pyspark/testing/mlutils.py @@ -110,7 +110,7 @@ def __init__(self): class HasFake(Params): def __init__(self): - super(HasFake, self).__init__() + super().__init__() self.fake = Param(self, "fake", "fake param") def getFake(self): @@ -119,7 +119,7 @@ def getFake(self): class MockTransformer(Transformer, HasFake): def __init__(self): - super(MockTransformer, self).__init__() + super().__init__() self.dataset_index = None def _transform(self, dataset): @@ -137,7 +137,7 @@ class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParam ) def __init__(self, shiftVal=1): - super(MockUnaryTransformer, self).__init__() + super().__init__() self._setDefault(shift=1) self._set(shift=shiftVal) @@ -161,7 +161,7 @@ def validateInputType(self, inputType): class MockEstimator(Estimator, HasFake): def __init__(self): - super(MockEstimator, self).__init__() + super().__init__() self.dataset_index = None def _fit(self, dataset): @@ -198,7 +198,7 @@ def __init__( regParam=0.0, rawPredictionCol="rawPrediction", ): - super(DummyLogisticRegression, self).__init__() + super().__init__() kwargs = self._input_kwargs self.setParams(**kwargs) @@ -229,7 +229,7 @@ class DummyLogisticRegressionModel( DefaultParamsWritable, ): def __init__(self): - super(DummyLogisticRegressionModel, self).__init__() + super().__init__() def _transform(self, dataset): # A dummy transform impl which always predict label 1 diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py index e53240586d59..11b741947604 100644 --- a/python/pyspark/testing/pandasutils.py +++ b/python/pyspark/testing/pandasutils.py @@ -473,7 +473,7 @@ def _to_pandas(obj: Any): class PandasOnSparkTestCase(ReusedSQLTestCase, PandasOnSparkTestUtils): @classmethod def setUpClass(cls): - super(PandasOnSparkTestCase, cls).setUpClass() + super().setUpClass() cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True) def setUp(self): diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py index 22f75bb931b1..b63c98f96f4e 100644 --- a/python/pyspark/testing/sqlutils.py +++ b/python/pyspark/testing/sqlutils.py @@ -208,7 +208,7 @@ def assert_close(a, b): class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils, PySparkErrorTestUtils): @classmethod def setUpClass(cls): - super(ReusedSQLTestCase, cls).setUpClass() + super().setUpClass() cls._legacy_sc = cls.sc cls.spark = SparkSession(cls.sc) cls.tempdir = tempfile.NamedTemporaryFile(delete=False) @@ -218,7 +218,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - super(ReusedSQLTestCase, cls).tearDownClass() + super().tearDownClass() cls.spark.stop() 
shutil.rmtree(cls.tempdir.name, ignore_errors=True) diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index 3961997120ed..eb50a105cc77 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -227,7 +227,7 @@ def tearDown(self): class WorkerSegfaultTest(ReusedPySparkTestCase): @classmethod def conf(cls): - _conf = super(WorkerSegfaultTest, cls).conf() + _conf = super().conf() _conf.set("spark.python.worker.faulthandler.enabled", "true") return _conf @@ -252,7 +252,7 @@ def f(): class WorkerSegfaultNonDaemonTest(WorkerSegfaultTest): @classmethod def conf(cls): - _conf = super(WorkerSegfaultNonDaemonTest, cls).conf() + _conf = super().conf() _conf.set("spark.python.use.daemon", "false") return _conf diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 1cf913287d56..f8750fbbec2e 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -564,9 +564,7 @@ def copy_local_properties(*a: Any, **k: Any) -> Any: thread_local.tags = self._tags # type: ignore[has-type] return target(*a, **k) - super(InheritableThread, self).__init__( - target=copy_local_properties, *args, **kwargs # type: ignore[misc] - ) + super().__init__(target=copy_local_properties, *args, **kwargs) # type: ignore[misc] else: # Non Spark Connect from pyspark import SparkContext @@ -585,13 +583,11 @@ def copy_local_properties(*a: Any, **k: Any) -> Any: SparkContext._active_spark_context._jsc.sc().setLocalProperties(self._props) return target(*a, **k) - super(InheritableThread, self).__init__( + super().__init__( target=copy_local_properties, *args, **kwargs # type: ignore[misc] ) else: - super(InheritableThread, self).__init__( - target=target, *args, **kwargs # type: ignore[misc] - ) + super().__init__(target=target, *args, **kwargs) # type: ignore[misc] def start(self) -> None: from pyspark.sql import is_remote @@ -619,7 +615,7 @@ def start(self) -> None: if self._session is not None: self._tags = self._session.getTags() - return super(InheritableThread, self).start() + return super().start() class PythonEvalType: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 109157e2c339..65dcbbbf23e6 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1514,10 +1514,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil # It expects the UDTF to be in a specific format and performs various checks to # ensure the UDTF is valid. This function also prepares a mapper function for applying # the UDTF logic to input rows. -def read_udtf(pickleSer, infile, eval_type): +def read_udtf(pickleSer, infile, eval_type, runner_conf): if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF: - # Load conf used for arrow evaluation. - runner_conf = RunnerConf(infile) input_types = [ field.dataType for field in _parse_datatype_json_string(utf8_deserializer.loads(infile)) ] @@ -1532,7 +1530,6 @@ def read_udtf(pickleSer, infile, eval_type): else: ser = ArrowStreamUDTFSerializer() elif eval_type == PythonEvalType.SQL_ARROW_UDTF: - runner_conf = RunnerConf(infile) # Read the table argument offsets num_table_arg_offsets = read_int(infile) table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)] @@ -1540,7 +1537,6 @@ def read_udtf(pickleSer, infile, eval_type): ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets) else: # Each row is a group so do not batch but send one by one. 
- runner_conf = RunnerConf() ser = BatchedSerializer(CPickleSerializer(), 1) # See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand' @@ -2688,7 +2684,7 @@ def mapper(_, it): return mapper, None, ser, ser -def read_udfs(pickleSer, infile, eval_type): +def read_udfs(pickleSer, infile, eval_type, runner_conf): state_server_port = None key_schema = None if eval_type in ( @@ -2716,9 +2712,6 @@ def read_udfs(pickleSer, infile, eval_type): PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF, PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF, ): - # Load conf used for pandas_udf evaluation - runner_conf = RunnerConf(infile) - state_object_schema = None if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) @@ -2870,7 +2863,6 @@ def read_udfs(pickleSer, infile, eval_type): int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) else: - runner_conf = RunnerConf() batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100")) ser = BatchedSerializer(CPickleSerializer(), batch_size) @@ -3353,6 +3345,7 @@ def main(infile, outfile): _accumulatorRegistry.clear() eval_type = read_int(infile) + runner_conf = RunnerConf(infile) if eval_type == PythonEvalType.NON_UDF: func, profiler, deserializer, serializer = read_command(pickleSer, infile) elif eval_type in ( @@ -3360,9 +3353,13 @@ def main(infile, outfile): PythonEvalType.SQL_ARROW_TABLE_UDF, PythonEvalType.SQL_ARROW_UDTF, ): - func, profiler, deserializer, serializer = read_udtf(pickleSer, infile, eval_type) + func, profiler, deserializer, serializer = read_udtf( + pickleSer, infile, eval_type, runner_conf + ) else: - func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type) + func, profiler, deserializer, serializer = read_udfs( + pickleSer, infile, eval_type, runner_conf + ) init_time = time.time() diff --git a/python/test_support/test_pytorch_training_file.py b/python/test_support/test_pytorch_training_file.py index 4107197acfd8..150246563f09 100644 --- a/python/test_support/test_pytorch_training_file.py +++ b/python/test_support/test_pytorch_training_file.py @@ -32,7 +32,7 @@ class Net(nn.Module): def __init__(self): - super(Net, self).__init__() + super().__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.conv2_drop = nn.Dropout2d() diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/ClusteredDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/ClusteredDistribution.java index dcc3d191461c..0fa77d259e9d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/ClusteredDistribution.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/ClusteredDistribution.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.connector.distributions; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Expression; /** @@ -26,7 +26,7 @@ * * @since 3.2.0 */ -@Experimental +@Evolving public interface ClusteredDistribution extends Distribution { /** * Returns clustering expressions. 
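
Note on the worker.py hunks above: main() now reads the runner configuration once, right after the eval type, and passes it down to read_udfs()/read_udtf() instead of each of them constructing it. A minimal, self-contained sketch of decoding such a conf map follows, assuming it is framed as a big-endian int count followed by length-prefixed UTF-8 key/value pairs; the helper names are hypothetical and this is not the actual pyspark.worker code.

    import struct
    from typing import BinaryIO, Dict

    def _read_int(stream: BinaryIO) -> int:
        # Big-endian signed 32-bit int, matching Java's DataOutputStream.writeInt.
        data = stream.read(4)
        if len(data) != 4:
            raise EOFError("stream closed while reading an int")
        return struct.unpack(">i", data)[0]

    def _read_utf(stream: BinaryIO) -> str:
        # Assumed framing: 4-byte length prefix followed by UTF-8 bytes.
        length = _read_int(stream)
        return stream.read(length).decode("utf-8")

    def read_runner_conf(stream: BinaryIO) -> Dict[str, str]:
        """Read a conf map sent as: count, then (key, value) UTF-8 pairs."""
        conf = {}
        for _ in range(_read_int(stream)):
            key = _read_utf(stream)
            value = _read_utf(stream)
            conf[key] = value
        return conf

    if __name__ == "__main__":
        # Round-trip check with an in-memory stream and one illustrative entry.
        import io

        def _write_utf(buf: io.BytesIO, s: str) -> None:
            data = s.encode("utf-8")
            buf.write(struct.pack(">i", len(data)))
            buf.write(data)

        buf = io.BytesIO()
        buf.write(struct.pack(">i", 1))                    # one entry
        _write_utf(buf, "spark.sql.session.timeZone")      # key
        _write_utf(buf, "UTC")                             # value
        buf.seek(0)
        assert read_runner_conf(buf) == {"spark.sql.session.timeZone": "UTC"}
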
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distribution.java index 95d68ea2d1ab..0a2f982fce07 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distribution.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distribution.java @@ -17,12 +17,12 @@ package org.apache.spark.sql.connector.distributions; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Evolving; /** * An interface that defines how data is distributed across partitions. * * @since 3.2.0 */ -@Experimental +@Evolving public interface Distribution {} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distributions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distributions.java index da5d6f8c81a3..6a346a25424f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distributions.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distributions.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.connector.distributions; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Expression; import org.apache.spark.sql.connector.expressions.SortOrder; @@ -26,7 +26,7 @@ * * @since 3.2.0 */ -@Experimental +@Evolving public class Distributions { private Distributions() { } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/OrderedDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/OrderedDistribution.java index 3456178d8e64..f959cc2e00ce 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/OrderedDistribution.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/OrderedDistribution.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.connector.distributions; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.SortOrder; /** @@ -26,7 +26,7 @@ * * @since 3.2.0 */ -@Experimental +@Evolving public interface OrderedDistribution extends Distribution { /** * Returns ordering expressions. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/UnspecifiedDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/UnspecifiedDistribution.java index ea18d8906cfd..4749701e348e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/UnspecifiedDistribution.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/UnspecifiedDistribution.java @@ -17,12 +17,12 @@ package org.apache.spark.sql.connector.distributions; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Evolving; /** * A distribution where no promises are made about co-location of data. 
* * @since 3.2.0 */ -@Experimental +@Evolving public interface UnspecifiedDistribution extends Distribution {} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NullOrdering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NullOrdering.java index 4aca199c11c0..7d457a62b313 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NullOrdering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NullOrdering.java @@ -17,14 +17,14 @@ package org.apache.spark.sql.connector.expressions; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Evolving; /** * A null order used in sorting expressions. * * @since 3.2.0 */ -@Experimental +@Evolving public enum NullOrdering { NULLS_FIRST, NULLS_LAST; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortDirection.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortDirection.java index 7e3a29945cc9..385154e0fd83 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortDirection.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortDirection.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.connector.expressions; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Evolving; import static org.apache.spark.sql.connector.expressions.NullOrdering.NULLS_FIRST; import static org.apache.spark.sql.connector.expressions.NullOrdering.NULLS_LAST; @@ -30,7 +30,7 @@ * * @since 3.2.0 */ -@Experimental +@Evolving public enum SortDirection { ASCENDING(NULLS_FIRST), DESCENDING(NULLS_LAST); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java index 51401786ca5d..45f06e17de6d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java @@ -17,14 +17,14 @@ package org.apache.spark.sql.connector.expressions; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Evolving; /** * Represents a sort order in the public expression API. * * @since 3.2.0 */ -@Experimental +@Evolving public interface SortOrder extends Expression { /** * Returns the sort expression. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java index 2adfe75f7d80..dbef9dd6146a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.connector.write; -import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.distributions.Distribution; import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution; import org.apache.spark.sql.connector.expressions.SortOrder; @@ -27,7 +27,7 @@ * * @since 3.2.0 */ -@Experimental +@Evolving public interface RequiresDistributionAndOrdering extends Write { /** * Returns the distribution required by this write. 
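
The Analyzer.scala hunk just below changes how UPDATE SET * and INSERT * are expanded for MERGE INTO: assignments are now generated from the target table's columns, one unresolved attribute per target column name, rather than from the source columns that happen to match the target. A toy Python illustration of that expansion follows; the names are hypothetical and no Spark APIs are used.

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Assignment:
        target_column: str
        source_expression: str  # stands in for an unresolved attribute of the same name

    def expand_star(target_columns: List[str]) -> List[Assignment]:
        # Previously the expansion iterated over source columns and kept only those
        # with a matching target column; now every target column gets an assignment
        # keyed by its own name, to be resolved against the source later.
        return [Assignment(col, col) for col in target_columns]

    print(expand_star(["id", "name", "ts"]))
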
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 67d25296a1e2..311c1c946fbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1736,9 +1736,8 @@ class Analyzer( Assignment(key, sourceAttr) } } else { - sourceTable.output.flatMap { sourceAttr => - findAttrInTarget(sourceAttr.name).map( - targetAttr => Assignment(targetAttr, sourceAttr)) + targetTable.output.map { attr => + Assignment(attr, UnresolvedAttribute(Seq(attr.name))) } } UpdateAction( @@ -1775,9 +1774,8 @@ class Analyzer( Assignment(key, sourceAttr) } } else { - sourceTable.output.flatMap { sourceAttr => - findAttrInTarget(sourceAttr.name).map( - targetAttr => Assignment(targetAttr, sourceAttr)) + targetTable.output.map { attr => + Assignment(attr, UnresolvedAttribute(Seq(attr.name))) } } InsertAction( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala index d1b8eab13191..bf1016ba8268 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala @@ -53,7 +53,7 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] { case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && m.rewritable && !m.aligned && !m.needSchemaEvolution => validateStoreAssignmentPolicy() - val coerceNestedTypes = SQLConf.get.coerceMergeNestedTypes + val coerceNestedTypes = SQLConf.get.coerceMergeNestedTypes && m.withSchemaEvolution m.copy( targetTable = cleanAttrMetadata(m.targetTable), matchedActions = alignActions(m.targetTable.output, m.matchedActions, diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala index 9a1db36514da..fdce056e664d 100644 --- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala +++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.connect.client.jdbc -import java.sql.{ResultSet, SQLException, Types} +import java.math.{BigDecimal => JBigDecimal} +import java.nio.charset.StandardCharsets +import java.sql.{Date, ResultSet, SQLException, Time, Timestamp, Types} +import java.util.{Calendar, TimeZone} import scala.util.Using @@ -223,344 +226,359 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess } test("get decimal type") { - Seq( - ("123.45", 37, 2, 39), - ("-0.12345", 5, 5, 8), - ("-0.12345", 6, 5, 8), - ("-123.45", 5, 2, 7), - ("12345", 5, 0, 6), - ("-12345", 5, 0, 6) - ).foreach { - case (value, precision, scale, expectedColumnDisplaySize) => - val decimalType = s"DECIMAL($precision,$scale)" - withExecuteQuery(s"SELECT cast('$value' as $decimalType)") { rs => - assert(rs.next()) - assert(rs.getBigDecimal(1) === new java.math.BigDecimal(value)) - assert(!rs.wasNull) - assert(!rs.next()) - - val metaData = 
rs.getMetaData - assert(metaData.getColumnCount === 1) - assert(metaData.getColumnName(1) === s"CAST($value AS $decimalType)") - assert(metaData.getColumnLabel(1) === s"CAST($value AS $decimalType)") - assert(metaData.getColumnType(1) === Types.DECIMAL) - assert(metaData.getColumnTypeName(1) === decimalType) - assert(metaData.getColumnClassName(1) === "java.math.BigDecimal") - assert(metaData.isSigned(1) === true) - assert(metaData.getPrecision(1) === precision) - assert(metaData.getScale(1) === scale) - assert(metaData.getColumnDisplaySize(1) === expectedColumnDisplaySize) - assert(metaData.getColumnDisplaySize(1) >= value.size) - } + withStatement { stmt => + Seq( + ("123.45", 37, 2, 39), + ("-0.12345", 5, 5, 8), + ("-0.12345", 6, 5, 8), + ("-123.45", 5, 2, 7), + ("12345", 5, 0, 6), + ("-12345", 5, 0, 6) + ).foreach { + case (value, precision, scale, expectedColumnDisplaySize) => + val decimalType = s"DECIMAL($precision,$scale)" + withExecuteQuery(stmt, s"SELECT cast('$value' as $decimalType)") { rs => + assert(rs.next()) + assert(rs.getBigDecimal(1) === new JBigDecimal(value)) + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === s"CAST($value AS $decimalType)") + assert(metaData.getColumnLabel(1) === s"CAST($value AS $decimalType)") + assert(metaData.getColumnType(1) === Types.DECIMAL) + assert(metaData.getColumnTypeName(1) === decimalType) + assert(metaData.getColumnClassName(1) === "java.math.BigDecimal") + assert(metaData.isSigned(1) === true) + assert(metaData.getPrecision(1) === precision) + assert(metaData.getScale(1) === scale) + assert(metaData.getColumnDisplaySize(1) === expectedColumnDisplaySize) + assert(metaData.getColumnDisplaySize(1) >= value.size) + } + } } } test("getter functions column index out of bound") { - Seq( - ("'foo'", (rs: ResultSet) => rs.getString(999)), - ("true", (rs: ResultSet) => rs.getBoolean(999)), - ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(999)), - ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(999)), - ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(999)), - ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(999)), - ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(999)), - ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(999)), - ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(999)), - ("CAST(X'0A0B0C' AS BINARY)", (rs: ResultSet) => rs.getBytes(999)), - ("date '2025-11-15'", (rs: ResultSet) => rs.getBytes(999)), - ("time '12:34:56.123456'", (rs: ResultSet) => rs.getBytes(999)), - ("timestamp '2025-11-15 10:30:45.123456'", (rs: ResultSet) => rs.getTimestamp(999)), - ("timestamp_ntz '2025-11-15 10:30:45.789012'", (rs: ResultSet) => rs.getTimestamp(999)) - ).foreach { - case (query, getter) => - withExecuteQuery(s"SELECT $query") { rs => - assert(rs.next()) - val exception = intercept[SQLException] { - getter(rs) + withStatement { stmt => + Seq( + ("'foo'", (rs: ResultSet) => rs.getString(999)), + ("true", (rs: ResultSet) => rs.getBoolean(999)), + ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(999)), + ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(999)), + ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(999)), + ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(999)), + ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(999)), + ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(999)), + ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(999)), + ("CAST(X'0A0B0C' AS BINARY)", (rs: 
ResultSet) => rs.getBytes(999)), + ("date '2025-11-15'", (rs: ResultSet) => rs.getBytes(999)), + ("time '12:34:56.123456'", (rs: ResultSet) => rs.getBytes(999)), + ("timestamp '2025-11-15 10:30:45.123456'", (rs: ResultSet) => rs.getTimestamp(999)), + ("timestamp_ntz '2025-11-15 10:30:45.789012'", (rs: ResultSet) => rs.getTimestamp(999)) + ).foreach { + case (query, getter) => + withExecuteQuery(stmt, s"SELECT $query") { rs => + assert(rs.next()) + val exception = intercept[SQLException] { + getter(rs) + } + assert(exception.getMessage() === + "The column index is out of range: 999, number of columns: 1.") } - assert(exception.getMessage() === - "The column index is out of range: 999, number of columns: 1.") - } + } } } test("getter functions called after statement closed") { - Seq( - ("'foo'", (rs: ResultSet) => rs.getString(1), "foo"), - ("true", (rs: ResultSet) => rs.getBoolean(1), true), - ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(1), 1.toByte), - ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(1), 1.toShort), - ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(1), 1.toInt), - ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(1), 1.toLong), - ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(1), 1.toFloat), - ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(1), 1.toDouble), - ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(1), - new java.math.BigDecimal("1.00000")), - ("CAST(X'0A0B0C' AS BINARY)", (rs: ResultSet) => rs.getBytes(1), - Array[Byte](0x0A, 0x0B, 0x0C)), - ("date '2023-11-15'", (rs: ResultSet) => rs.getDate(1), - java.sql.Date.valueOf("2023-11-15")), - ("time '12:34:56.123456'", (rs: ResultSet) => rs.getTime(1), { - val millis = timeToMillis(12, 34, 56, 123) - new java.sql.Time(millis) - }) - ).foreach { - case (query, getter, expectedValue) => - var resultSet: Option[ResultSet] = None - withExecuteQuery(s"SELECT $query") { rs => - assert(rs.next()) - expectedValue match { - case arr: Array[Byte] => assert(getter(rs).asInstanceOf[Array[Byte]].sameElements(arr)) - case other => assert(getter(rs) === other) + withStatement { stmt => + Seq( + ("'foo'", (rs: ResultSet) => rs.getString(1), "foo"), + ("true", (rs: ResultSet) => rs.getBoolean(1), true), + ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(1), 1.toByte), + ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(1), 1.toShort), + ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(1), 1.toInt), + ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(1), 1.toLong), + ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(1), 1.toFloat), + ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(1), 1.toDouble), + ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(1), + new JBigDecimal("1.00000")), + ("CAST(X'0A0B0C' AS BINARY)", (rs: ResultSet) => rs.getBytes(1), + Array[Byte](0x0A, 0x0B, 0x0C)), + ("date '2023-11-15'", (rs: ResultSet) => rs.getDate(1), + Date.valueOf("2023-11-15")), + ("time '12:34:56.123456'", (rs: ResultSet) => rs.getTime(1), { + val millis = timeToMillis(12, 34, 56, 123) + new Time(millis) + }) + ).foreach { + case (query, getter, expectedValue) => + var resultSet: Option[ResultSet] = None + withExecuteQuery(stmt, s"SELECT $query") { rs => + assert(rs.next()) + expectedValue match { + case arr: Array[Byte] => + assert(getter(rs).asInstanceOf[Array[Byte]].sameElements(arr)) + case other => assert(getter(rs) === other) + } + assert(!rs.wasNull) + resultSet = Some(rs) } - assert(!rs.wasNull) - resultSet = Some(rs) - } - assert(resultSet.isDefined) - val exception = 
intercept[SQLException] { - getter(resultSet.get) - } - assert(exception.getMessage() === "JDBC Statement is closed.") + assert(resultSet.isDefined) + val exception = intercept[SQLException] { + getter(resultSet.get) + } + assert(exception.getMessage() === "JDBC Statement is closed.") + } } } test("get date type") { - withExecuteQuery("SELECT date '2023-11-15'") { rs => - assert(rs.next()) - assert(rs.getDate(1) === java.sql.Date.valueOf("2023-11-15")) - assert(!rs.wasNull) - assert(!rs.next()) - - val metaData = rs.getMetaData - assert(metaData.getColumnCount === 1) - assert(metaData.getColumnName(1) === "DATE '2023-11-15'") - assert(metaData.getColumnLabel(1) === "DATE '2023-11-15'") - assert(metaData.getColumnType(1) === Types.DATE) - assert(metaData.getColumnTypeName(1) === "DATE") - assert(metaData.getColumnClassName(1) === "java.sql.Date") - assert(metaData.isSigned(1) === false) - assert(metaData.getPrecision(1) === 10) - assert(metaData.getScale(1) === 0) - assert(metaData.getColumnDisplaySize(1) === 10) - } - } - - test("get date type with null") { - withExecuteQuery("SELECT cast(null as date)") { rs => - assert(rs.next()) - assert(rs.getDate(1) === null) - assert(rs.wasNull) - assert(!rs.next()) + withStatement { stmt => + // Test basic date type + withExecuteQuery(stmt, "SELECT date '2023-11-15'") { rs => + assert(rs.next()) + assert(rs.getDate(1) === Date.valueOf("2023-11-15")) + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "DATE '2023-11-15'") + assert(metaData.getColumnLabel(1) === "DATE '2023-11-15'") + assert(metaData.getColumnType(1) === Types.DATE) + assert(metaData.getColumnTypeName(1) === "DATE") + assert(metaData.getColumnClassName(1) === "java.sql.Date") + assert(metaData.isSigned(1) === false) + assert(metaData.getPrecision(1) === 10) + assert(metaData.getScale(1) === 0) + assert(metaData.getColumnDisplaySize(1) === 10) + } - val metaData = rs.getMetaData - assert(metaData.getColumnCount === 1) - assert(metaData.getColumnName(1) === "CAST(NULL AS DATE)") - assert(metaData.getColumnLabel(1) === "CAST(NULL AS DATE)") - assert(metaData.getColumnType(1) === Types.DATE) - assert(metaData.getColumnTypeName(1) === "DATE") - assert(metaData.getColumnClassName(1) === "java.sql.Date") - assert(metaData.isSigned(1) === false) - assert(metaData.getPrecision(1) === 10) - assert(metaData.getScale(1) === 0) - assert(metaData.getColumnDisplaySize(1) === 10) - } - } + // Test date type with null + withExecuteQuery(stmt, "SELECT cast(null as date)") { rs => + assert(rs.next()) + assert(rs.getDate(1) === null) + assert(rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "CAST(NULL AS DATE)") + assert(metaData.getColumnLabel(1) === "CAST(NULL AS DATE)") + assert(metaData.getColumnType(1) === Types.DATE) + assert(metaData.getColumnTypeName(1) === "DATE") + assert(metaData.getColumnClassName(1) === "java.sql.Date") + assert(metaData.isSigned(1) === false) + assert(metaData.getPrecision(1) === 10) + assert(metaData.getScale(1) === 0) + assert(metaData.getColumnDisplaySize(1) === 10) + } - test("get date type by column label") { - withExecuteQuery("SELECT date '2025-11-15' as test_date") { rs => - assert(rs.next()) - assert(rs.getDate("test_date") === java.sql.Date.valueOf("2025-11-15")) - assert(!rs.wasNull) - assert(!rs.next()) + // Test date type by column label + withExecuteQuery(stmt, 
"SELECT date '2025-11-15' as test_date") { rs => + assert(rs.next()) + assert(rs.getDate("test_date") === Date.valueOf("2025-11-15")) + assert(!rs.wasNull) + assert(!rs.next()) + } } } test("get binary type") { - val testBytes = Array[Byte](0x01, 0x02, 0x03, 0x04, 0x05) - val hexString = testBytes.map(b => "%02X".format(b)).mkString - withExecuteQuery(s"SELECT CAST(X'$hexString' AS BINARY)") { rs => - assert(rs.next()) - val bytes = rs.getBytes(1) - assert(bytes !== null) - assert(bytes.length === testBytes.length) - assert(bytes.sameElements(testBytes)) - assert(!rs.wasNull) + withStatement { stmt => + // Test basic binary type + val testBytes = Array[Byte](0x01, 0x02, 0x03, 0x04, 0x05) + val hexString = testBytes.map(b => "%02X".format(b)).mkString + withExecuteQuery(stmt, s"SELECT CAST(X'$hexString' AS BINARY)") { rs => + assert(rs.next()) + val bytes = rs.getBytes(1) + assert(bytes !== null) + assert(bytes.length === testBytes.length) + assert(bytes.sameElements(testBytes)) + assert(!rs.wasNull) val stringValue = rs.getString(1) - val expectedString = new String(testBytes, java.nio.charset.StandardCharsets.UTF_8) + val expectedString = new String(testBytes, StandardCharsets.UTF_8) assert(stringValue === expectedString) - assert(!rs.next()) - - val metaData = rs.getMetaData - assert(metaData.getColumnCount === 1) - assert(metaData.getColumnType(1) === Types.VARBINARY) - assert(metaData.getColumnTypeName(1) === "BINARY") - assert(metaData.getColumnClassName(1) === "[B") - assert(metaData.isSigned(1) === false) - } - } + assert(!rs.next()) - test("get binary type with UTF-8 text") { - val textBytes = "\\xDeAdBeEf".getBytes(java.nio.charset.StandardCharsets.UTF_8) - val hexString = textBytes.map(b => "%02X".format(b)).mkString - withExecuteQuery(s"SELECT CAST(X'$hexString' AS BINARY)") { rs => - assert(rs.next()) - val bytes = rs.getBytes(1) - assert(bytes !== null) - assert(bytes.sameElements(textBytes)) - - val stringValue = rs.getString(1) - assert(stringValue === "\\xDeAdBeEf") - - assert(!rs.next()) - } - } + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnType(1) === Types.VARBINARY) + assert(metaData.getColumnTypeName(1) === "BINARY") + assert(metaData.getColumnClassName(1) === "[B") + assert(metaData.isSigned(1) === false) + } - test("get binary type with null") { - withExecuteQuery("SELECT cast(null as binary)") { rs => - assert(rs.next()) - assert(rs.getBytes(1) === null) - assert(rs.wasNull) - assert(!rs.next()) + // Test binary type with UTF-8 text + val textBytes = "\\xDeAdBeEf".getBytes(StandardCharsets.UTF_8) + val hexString2 = textBytes.map(b => "%02X".format(b)).mkString + withExecuteQuery(stmt, s"SELECT CAST(X'$hexString2' AS BINARY)") { rs => + assert(rs.next()) + val bytes = rs.getBytes(1) + assert(bytes !== null) + assert(bytes.sameElements(textBytes)) - val metaData = rs.getMetaData - assert(metaData.getColumnCount === 1) - assert(metaData.getColumnType(1) === Types.VARBINARY) - assert(metaData.getColumnTypeName(1) === "BINARY") - assert(metaData.getColumnClassName(1) === "[B") - } - } + val stringValue = rs.getString(1) + assert(stringValue === "\\xDeAdBeEf") - test("get binary type by column label") { - val testBytes = Array[Byte](0x0A, 0x0B, 0x0C) - val hexString = testBytes.map(b => "%02X".format(b)).mkString - withExecuteQuery(s"SELECT CAST(X'$hexString' AS BINARY) as test_binary") { rs => - assert(rs.next()) - val bytes = rs.getBytes("test_binary") - assert(bytes !== null) - assert(bytes.length === 
testBytes.length) - assert(bytes.sameElements(testBytes)) - assert(!rs.wasNull) - assert(!rs.next()) + assert(!rs.next()) + } - val metaData = rs.getMetaData - assert(metaData.getColumnCount === 1) - assert(metaData.getColumnName(1) === "test_binary") - assert(metaData.getColumnLabel(1) === "test_binary") - } - } + // Test binary type with null + withExecuteQuery(stmt, "SELECT cast(null as binary)") { rs => + assert(rs.next()) + assert(rs.getBytes(1) === null) + assert(rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnType(1) === Types.VARBINARY) + assert(metaData.getColumnTypeName(1) === "BINARY") + assert(metaData.getColumnClassName(1) === "[B") + } - test("get empty binary") { - withExecuteQuery("SELECT CAST(X'' AS BINARY)") { rs => - assert(rs.next()) - val bytes = rs.getBytes(1) - assert(bytes !== null) - assert(bytes.length === 0) - assert(!rs.wasNull) + // Test binary type by column label + val testBytes2 = Array[Byte](0x0A, 0x0B, 0x0C) + val hexString3 = testBytes2.map(b => "%02X".format(b)).mkString + withExecuteQuery(stmt, s"SELECT CAST(X'$hexString3' AS BINARY) as test_binary") { rs => + assert(rs.next()) + val bytes = rs.getBytes("test_binary") + assert(bytes !== null) + assert(bytes.length === testBytes2.length) + assert(bytes.sameElements(testBytes2)) + assert(!rs.wasNull) + + val stringValue = rs.getString("test_binary") + val expectedString = new String(testBytes2, StandardCharsets.UTF_8) + assert(stringValue === expectedString) + + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "test_binary") + assert(metaData.getColumnLabel(1) === "test_binary") + } - val stringValue = rs.getString(1) - assert(stringValue === "") - assert(!rs.next()) + // Test empty binary + withExecuteQuery(stmt, "SELECT CAST(X'' AS BINARY)") { rs => + assert(rs.next()) + val bytes = rs.getBytes(1) + assert(bytes !== null) + assert(bytes.length === 0) + assert(!rs.wasNull) + + val stringValue = rs.getString(1) + assert(stringValue === "") + assert(!rs.next()) + } } } test("get time type") { - withExecuteQuery("SELECT time '12:34:56.123456'") { rs => - assert(rs.next()) - val time = rs.getTime(1) - // Verify milliseconds are preserved (123 from 123456 microseconds) - val expectedMillis = timeToMillis(12, 34, 56, 123) - assert(time.getTime === expectedMillis) - assert(!rs.wasNull) - assert(!rs.next()) - - val metaData = rs.getMetaData - assert(metaData.getColumnCount === 1) - assert(metaData.getColumnName(1) === "TIME '12:34:56.123456'") - assert(metaData.getColumnLabel(1) === "TIME '12:34:56.123456'") - assert(metaData.getColumnType(1) === Types.TIME) - assert(metaData.getColumnTypeName(1) === "TIME(6)") - assert(metaData.getColumnClassName(1) === "java.sql.Time") - assert(metaData.isSigned(1) === false) - assert(metaData.getPrecision(1) === 6) - assert(metaData.getScale(1) === 0) - assert(metaData.getColumnDisplaySize(1) === 15) - } - } - - test("get time type with null") { - withExecuteQuery("SELECT cast(null as time)") { rs => - assert(rs.next()) - assert(rs.getTime(1) === null) - assert(rs.wasNull) - assert(!rs.next()) + withStatement { stmt => + // Test basic time type + withExecuteQuery(stmt, "SELECT time '12:34:56.123456'") { rs => + assert(rs.next()) + val time = rs.getTime(1) + // Verify milliseconds are preserved (123 from 123456 microseconds) + val expectedMillis = timeToMillis(12, 34, 56, 123) + assert(time.getTime === 
expectedMillis) + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "TIME '12:34:56.123456'") + assert(metaData.getColumnLabel(1) === "TIME '12:34:56.123456'") + assert(metaData.getColumnType(1) === Types.TIME) + assert(metaData.getColumnTypeName(1) === "TIME(6)") + assert(metaData.getColumnClassName(1) === "java.sql.Time") + assert(metaData.isSigned(1) === false) + assert(metaData.getPrecision(1) === 6) + assert(metaData.getScale(1) === 0) + assert(metaData.getColumnDisplaySize(1) === 15) + } - val metaData = rs.getMetaData - assert(metaData.getColumnCount === 1) - assert(metaData.getColumnName(1) === "CAST(NULL AS TIME(6))") - assert(metaData.getColumnLabel(1) === "CAST(NULL AS TIME(6))") - assert(metaData.getColumnType(1) === Types.TIME) - assert(metaData.getColumnTypeName(1) === "TIME(6)") - assert(metaData.getColumnClassName(1) === "java.sql.Time") - assert(metaData.isSigned(1) === false) - assert(metaData.getPrecision(1) === 6) - assert(metaData.getScale(1) === 0) - assert(metaData.getColumnDisplaySize(1) === 15) - } - } + // Test time type with null + withExecuteQuery(stmt, "SELECT cast(null as time)") { rs => + assert(rs.next()) + assert(rs.getTime(1) === null) + assert(rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "CAST(NULL AS TIME(6))") + assert(metaData.getColumnLabel(1) === "CAST(NULL AS TIME(6))") + assert(metaData.getColumnType(1) === Types.TIME) + assert(metaData.getColumnTypeName(1) === "TIME(6)") + assert(metaData.getColumnClassName(1) === "java.sql.Time") + assert(metaData.isSigned(1) === false) + assert(metaData.getPrecision(1) === 6) + assert(metaData.getScale(1) === 0) + assert(metaData.getColumnDisplaySize(1) === 15) + } - test("get time type by column label") { - withExecuteQuery("SELECT time '09:15:30.456789' as test_time") { rs => - assert(rs.next()) - val time = rs.getTime("test_time") - // Verify milliseconds are preserved (456 from 456789 microseconds) - val expectedMillis = timeToMillis(9, 15, 30, 456) - assert(time.getTime === expectedMillis) - assert(!rs.wasNull) - assert(!rs.next()) + // Test time type by column label + withExecuteQuery(stmt, "SELECT time '09:15:30.456789' as test_time") { rs => + assert(rs.next()) + val time = rs.getTime("test_time") + // Verify milliseconds are preserved (456 from 456789 microseconds) + val expectedMillis = timeToMillis(9, 15, 30, 456) + assert(time.getTime === expectedMillis) + assert(!rs.wasNull) + assert(!rs.next()) + } } } test("get time type with different precisions") { - Seq( - // (timeValue, precision, expectedDisplaySize, expectedMillis) - // HH:MM:SS (no fractional) - ("15:45:30.123456", 0, 8, timeToMillis(15, 45, 30, 0)), - // HH:MM:SS.f (100ms from .1) - ("10:20:30.123456", 1, 10, timeToMillis(10, 20, 30, 100)), - // HH:MM:SS.fff (123ms) - ("08:15:45.123456", 3, 12, timeToMillis(8, 15, 45, 123)), - // HH:MM:SS.fff (999ms) . 
Spark TIME values can have microsecond precision, - // but java.sql.Time can only store up to millisecond precision - ("23:59:59.999999", 6, 15, timeToMillis(23, 59, 59, 999)) - ).foreach { - case (timeValue, precision, expectedDisplaySize, expectedMillis) => - withExecuteQuery(s"SELECT cast(time '$timeValue' as time($precision))") { rs => - assert(rs.next(), s"Failed to get next row for precision $precision") - val time = rs.getTime(1) - assert(time.getTime === expectedMillis, - s"Time millis mismatch for precision" + - s" $precision: expected $expectedMillis, got ${time.getTime}") - assert(!rs.wasNull, s"wasNull should be false for precision $precision") - assert(!rs.next(), s"Should have no more rows for precision $precision") - - val metaData = rs.getMetaData - assert(metaData.getColumnCount === 1) - assert(metaData.getColumnType(1) === Types.TIME, - s"Column type mismatch for precision $precision") - assert(metaData.getColumnTypeName(1) === s"TIME($precision)", - s"Column type name mismatch for precision $precision") - assert(metaData.getColumnClassName(1) === "java.sql.Time", - s"Column class name mismatch for precision $precision") - assert(metaData.getPrecision(1) === precision, - s"Precision mismatch for precision $precision") - assert(metaData.getScale(1) === 0, - s"Scale should be 0 for precision $precision") - assert(metaData.getColumnDisplaySize(1) === expectedDisplaySize, - s"Display size mismatch for precision $precision: " + - s"expected $expectedDisplaySize, got ${metaData.getColumnDisplaySize(1)}") - } + withStatement { stmt => + Seq( + // (timeValue, precision, expectedDisplaySize, expectedMillis) + // HH:MM:SS (no fractional) + ("15:45:30.123456", 0, 8, timeToMillis(15, 45, 30, 0)), + // HH:MM:SS.f (100ms from .1) + ("10:20:30.123456", 1, 10, timeToMillis(10, 20, 30, 100)), + // HH:MM:SS.fff (123ms) + ("08:15:45.123456", 3, 12, timeToMillis(8, 15, 45, 123)), + // HH:MM:SS.fff (999ms) . 
Spark TIME values can have microsecond precision, + // but java.sql.Time can only store up to millisecond precision + ("23:59:59.999999", 6, 15, timeToMillis(23, 59, 59, 999)) + ).foreach { + case (timeValue, precision, expectedDisplaySize, expectedMillis) => + withExecuteQuery(stmt, s"SELECT cast(time '$timeValue' as time($precision))") { rs => + assert(rs.next(), s"Failed to get next row for precision $precision") + val time = rs.getTime(1) + assert(time.getTime === expectedMillis, + s"Time millis mismatch for precision" + + s" $precision: expected $expectedMillis, got ${time.getTime}") + assert(!rs.wasNull, s"wasNull should be false for precision $precision") + assert(!rs.next(), s"Should have no more rows for precision $precision") + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnType(1) === Types.TIME, + s"Column type mismatch for precision $precision") + assert(metaData.getColumnTypeName(1) === s"TIME($precision)", + s"Column type name mismatch for precision $precision") + assert(metaData.getColumnClassName(1) === "java.sql.Time", + s"Column class name mismatch for precision $precision") + assert(metaData.getPrecision(1) === precision, + s"Precision mismatch for precision $precision") + assert(metaData.getScale(1) === 0, + s"Scale should be 0 for precision $precision") + assert(metaData.getColumnDisplaySize(1) === expectedDisplaySize, + s"Display size mismatch for precision $precision: " + + s"expected $expectedDisplaySize, got ${metaData.getColumnDisplaySize(1)}") + } + } } } @@ -570,7 +588,7 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess stmt.execute(s"set spark.sql.datetime.java8API.enabled=$java8APIEnabled") Using.resource(stmt.executeQuery("SELECT date '2025-11-15'")) { rs => assert(rs.next()) - assert(rs.getDate(1) === java.sql.Date.valueOf("2025-11-15")) + assert(rs.getDate(1) === Date.valueOf("2025-11-15")) assert(!rs.wasNull) assert(!rs.next()) } @@ -595,150 +613,153 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess } test("get timestamp type") { - withExecuteQuery("SELECT timestamp '2025-11-15 10:30:45.123456'") { rs => - assert(rs.next()) - val timestamp = rs.getTimestamp(1) - assert(timestamp !== null) - assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 10:30:45.123456")) - assert(!rs.wasNull) - assert(!rs.next()) - - val metaData = rs.getMetaData - assert(metaData.getColumnCount === 1) - assert(metaData.getColumnName(1) === "TIMESTAMP '2025-11-15 10:30:45.123456'") - assert(metaData.getColumnLabel(1) === "TIMESTAMP '2025-11-15 10:30:45.123456'") - assert(metaData.getColumnType(1) === Types.TIMESTAMP) - assert(metaData.getColumnTypeName(1) === "TIMESTAMP") - assert(metaData.getColumnClassName(1) === "java.sql.Timestamp") - assert(metaData.isSigned(1) === false) - assert(metaData.getPrecision(1) === 29) - assert(metaData.getScale(1) === 6) - assert(metaData.getColumnDisplaySize(1) === 29) - } - } - - test("get timestamp type with null") { - withExecuteQuery("SELECT cast(null as timestamp)") { rs => - assert(rs.next()) - assert(rs.getTimestamp(1) === null) - assert(rs.wasNull) - assert(!rs.next()) - - val metaData = rs.getMetaData - assert(metaData.getColumnCount === 1) - assert(metaData.getColumnName(1) === "CAST(NULL AS TIMESTAMP)") - assert(metaData.getColumnLabel(1) === "CAST(NULL AS TIMESTAMP)") - assert(metaData.getColumnType(1) === Types.TIMESTAMP) - assert(metaData.getColumnTypeName(1) === "TIMESTAMP") - 
assert(metaData.getColumnClassName(1) === "java.sql.Timestamp") - assert(metaData.isSigned(1) === false) - assert(metaData.getPrecision(1) === 29) - assert(metaData.getScale(1) === 6) - assert(metaData.getColumnDisplaySize(1) === 29) - } - } - - test("get timestamp type by column label and with calendar") { - withExecuteQuery("SELECT timestamp '2025-11-15 10:30:45.987654' as test_timestamp") { rs => - assert(rs.next()) - - // Test by column label - val timestamp = rs.getTimestamp("test_timestamp") - assert(timestamp !== null) - assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 10:30:45.987654")) - assert(!rs.wasNull) + withStatement { stmt => + // Test basic timestamp type + withExecuteQuery(stmt, "SELECT timestamp '2025-11-15 10:30:45.123456'") { rs => + assert(rs.next()) + val timestamp = rs.getTimestamp(1) + assert(timestamp !== null) + assert(timestamp === Timestamp.valueOf("2025-11-15 10:30:45.123456")) + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "TIMESTAMP '2025-11-15 10:30:45.123456'") + assert(metaData.getColumnLabel(1) === "TIMESTAMP '2025-11-15 10:30:45.123456'") + assert(metaData.getColumnType(1) === Types.TIMESTAMP) + assert(metaData.getColumnTypeName(1) === "TIMESTAMP") + assert(metaData.getColumnClassName(1) === "java.sql.Timestamp") + assert(metaData.isSigned(1) === false) + assert(metaData.getPrecision(1) === 29) + assert(metaData.getScale(1) === 6) + assert(metaData.getColumnDisplaySize(1) === 29) + } - // Test with calendar - should return same value (Calendar is ignored) - // Note: Spark Connect handles timezone at server, Calendar param is for API compliance - val calUTC = java.util.Calendar.getInstance(java.util.TimeZone.getTimeZone("UTC")) - val timestampUTC = rs.getTimestamp(1, calUTC) - assert(timestampUTC !== null) - assert(timestampUTC.getTime === timestamp.getTime) - - val calPST = java.util.Calendar.getInstance( - java.util.TimeZone.getTimeZone("America/Los_Angeles")) - val timestampPST = rs.getTimestamp(1, calPST) - assert(timestampPST !== null) - // Same value regardless of calendar - assert(timestampPST.getTime === timestamp.getTime) - assert(timestampUTC.getTime === timestampPST.getTime) - - // Test with calendar by label - val timestampLabel = rs.getTimestamp("test_timestamp", calUTC) - assert(timestampLabel !== null) - assert(timestampLabel.getTime === timestamp.getTime) - - // Test with null calendar - returns same value - val timestampNullCal = rs.getTimestamp(1, null) - assert(timestampNullCal !== null) - assert(timestampNullCal.getTime === timestamp.getTime) + // Test timestamp type with null + withExecuteQuery(stmt, "SELECT cast(null as timestamp)") { rs => + assert(rs.next()) + assert(rs.getTimestamp(1) === null) + assert(rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "CAST(NULL AS TIMESTAMP)") + assert(metaData.getColumnLabel(1) === "CAST(NULL AS TIMESTAMP)") + assert(metaData.getColumnType(1) === Types.TIMESTAMP) + assert(metaData.getColumnTypeName(1) === "TIMESTAMP") + assert(metaData.getColumnClassName(1) === "java.sql.Timestamp") + assert(metaData.isSigned(1) === false) + assert(metaData.getPrecision(1) === 29) + assert(metaData.getScale(1) === 6) + assert(metaData.getColumnDisplaySize(1) === 29) + } - assert(!rs.next()) - } - } + // Test timestamp type by column label and with calendar + val tsString = "2025-11-15 
10:30:45.987654" + withExecuteQuery(stmt, s"SELECT timestamp '$tsString' as test_timestamp") { rs => + assert(rs.next()) + + // Test by column label + val timestamp = rs.getTimestamp("test_timestamp") + assert(timestamp !== null) + assert(timestamp === Timestamp.valueOf(tsString)) + assert(!rs.wasNull) + + // Test with calendar - should return same value (Calendar is ignored) + // Note: Spark Connect handles timezone at server, Calendar param is for API compliance + val calUTC = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + val timestampUTC = rs.getTimestamp(1, calUTC) + assert(timestampUTC !== null) + assert(timestampUTC.getTime === timestamp.getTime) + + val calPST = Calendar.getInstance( + TimeZone.getTimeZone("America/Los_Angeles")) + val timestampPST = rs.getTimestamp(1, calPST) + assert(timestampPST !== null) + // Same value regardless of calendar + assert(timestampPST.getTime === timestamp.getTime) + assert(timestampUTC.getTime === timestampPST.getTime) + + // Test with calendar by label + val timestampLabel = rs.getTimestamp("test_timestamp", calUTC) + assert(timestampLabel !== null) + assert(timestampLabel.getTime === timestamp.getTime) + + // Test with null calendar - returns same value + val timestampNullCal = rs.getTimestamp(1, null) + assert(timestampNullCal !== null) + assert(timestampNullCal.getTime === timestamp.getTime) + + assert(!rs.next()) + } - test("get timestamp type with calendar for null value") { - withExecuteQuery("SELECT cast(null as timestamp)") { rs => - assert(rs.next()) + // Test timestamp type with calendar for null value + withExecuteQuery(stmt, "SELECT cast(null as timestamp)") { rs => + assert(rs.next()) - // Calendar parameter should not affect null handling - val cal = java.util.Calendar.getInstance(java.util.TimeZone.getTimeZone("UTC")) - val timestamp = rs.getTimestamp(1, cal) - assert(timestamp === null) - assert(rs.wasNull) - assert(!rs.next()) + // Calendar parameter should not affect null handling + val cal = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + val timestamp = rs.getTimestamp(1, cal) + assert(timestamp === null) + assert(rs.wasNull) + assert(!rs.next()) + } } } test("get timestamp_ntz type") { - withExecuteQuery("SELECT timestamp_ntz '2025-11-15 10:30:45.123456'") { rs => - assert(rs.next()) - val timestamp = rs.getTimestamp(1) - assert(timestamp !== null) - assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 10:30:45.123456")) - assert(!rs.wasNull) - assert(!rs.next()) - - val metaData = rs.getMetaData - assert(metaData.getColumnCount === 1) - assert(metaData.getColumnName(1) === "TIMESTAMP_NTZ '2025-11-15 10:30:45.123456'") - assert(metaData.getColumnLabel(1) === "TIMESTAMP_NTZ '2025-11-15 10:30:45.123456'") - assert(metaData.getColumnType(1) === Types.TIMESTAMP) - assert(metaData.getColumnTypeName(1) === "TIMESTAMP_NTZ") - assert(metaData.getColumnClassName(1) === "java.sql.Timestamp") - assert(metaData.isSigned(1) === false) - assert(metaData.getPrecision(1) === 29) - assert(metaData.getScale(1) === 6) - assert(metaData.getColumnDisplaySize(1) === 29) - } - } + withStatement { stmt => + // Test basic timestamp_ntz type + withExecuteQuery(stmt, "SELECT timestamp_ntz '2025-11-15 10:30:45.123456'") { rs => + assert(rs.next()) + val timestamp = rs.getTimestamp(1) + assert(timestamp !== null) + assert(timestamp === Timestamp.valueOf("2025-11-15 10:30:45.123456")) + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === 
"TIMESTAMP_NTZ '2025-11-15 10:30:45.123456'") + assert(metaData.getColumnLabel(1) === "TIMESTAMP_NTZ '2025-11-15 10:30:45.123456'") + assert(metaData.getColumnType(1) === Types.TIMESTAMP) + assert(metaData.getColumnTypeName(1) === "TIMESTAMP_NTZ") + assert(metaData.getColumnClassName(1) === "java.sql.Timestamp") + assert(metaData.isSigned(1) === false) + assert(metaData.getPrecision(1) === 29) + assert(metaData.getScale(1) === 6) + assert(metaData.getColumnDisplaySize(1) === 29) + } - test("get timestamp_ntz type by label, null, and with calendar") { - // Test with non-null value - withExecuteQuery("SELECT timestamp_ntz '2025-11-15 14:22:33.789456' as test_ts_ntz") { rs => - assert(rs.next()) + // Test timestamp_ntz by label, null, and with calendar - non-null value + val tsString = "2025-11-15 14:22:33.789456" + withExecuteQuery(stmt, s"SELECT timestamp_ntz '$tsString' as test_ts_ntz") { rs => + assert(rs.next()) - // Test by column label - val timestamp = rs.getTimestamp("test_ts_ntz") - assert(timestamp !== null) - assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 14:22:33.789456")) - assert(!rs.wasNull) + // Test by column label + val timestamp = rs.getTimestamp("test_ts_ntz") + assert(timestamp !== null) + assert(timestamp === Timestamp.valueOf(tsString)) + assert(!rs.wasNull) - // Test with calendar - should return same value (Calendar is ignored) - val calUTC = java.util.Calendar.getInstance(java.util.TimeZone.getTimeZone("UTC")) - val timestampCal = rs.getTimestamp(1, calUTC) - assert(timestampCal !== null) - assert(timestampCal.getTime === timestamp.getTime) + // Test with calendar - should return same value (Calendar is ignored) + val calUTC = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + val timestampCal = rs.getTimestamp(1, calUTC) + assert(timestampCal !== null) + assert(timestampCal.getTime === timestamp.getTime) - assert(!rs.next()) - } + assert(!rs.next()) + } - // Test with null value - withExecuteQuery("SELECT cast(null as timestamp_ntz)") { rs => - assert(rs.next()) - assert(rs.getTimestamp(1) === null) - assert(rs.wasNull) - assert(!rs.next()) + // Test timestamp_ntz with null value + withExecuteQuery(stmt, "SELECT cast(null as timestamp_ntz)") { rs => + assert(rs.next()) + assert(rs.getTimestamp(1) === null) + assert(rs.wasNull) + assert(!rs.next()) + } } } @@ -757,13 +778,13 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess // Test TIMESTAMP type val timestamp = rs.getTimestamp(1) assert(timestamp !== null) - assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 10:30:45.123456")) + assert(timestamp === Timestamp.valueOf("2025-11-15 10:30:45.123456")) assert(!rs.wasNull) // Test TIMESTAMP_NTZ type val timestampNtz = rs.getTimestamp(2) assert(timestampNtz !== null) - assert(timestampNtz === java.sql.Timestamp.valueOf("2025-11-15 14:22:33.789012")) + assert(timestampNtz === Timestamp.valueOf("2025-11-15 14:22:33.789012")) assert(!rs.wasNull) assert(!rs.next()) diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala index 9b3aa373e93c..a512a44cac3b 100644 --- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala +++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala @@ -39,8 +39,10 @@ trait JdbcHelper { } def withExecuteQuery(query: String)(f: ResultSet 
=> Unit): Unit = { - withStatement { stmt => - Using.resource { stmt.executeQuery(query) } { rs => f(rs) } - } + withStatement { stmt => withExecuteQuery(stmt, query)(f) } + } + + def withExecuteQuery(stmt: Statement, query: String)(f: ResultSet => Unit): Unit = { + Using.resource { stmt.executeQuery(query) } { rs => f(rs) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 47f5f180789e..8534a24d0110 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -305,8 +305,7 @@ class JDBCRDD( val inputMetrics = context.taskMetrics().inputMetrics val part = thePart.asInstanceOf[JDBCPartition] conn = getConnection(part.idx) - import scala.jdk.CollectionConverters._ - dialect.beforeFetch(conn, options.asProperties.asScala.toMap) + dialect.beforeFetch(conn, options) // This executes a generic SQL statement (or PL/SQL block) before reading // the table/query via JDBC. Use this feature to initialize the database diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index f5f968ee9522..499fa99a2444 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -35,7 +35,6 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef]( _schema: StructType, _timeZoneId: String, protected override val largeVarTypes: Boolean, - protected override val workerConf: Map[String, String], override val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String], sessionUUID: Option[String]) @@ -86,12 +85,11 @@ abstract class RowInputArrowPythonRunner( _schema: StructType, _timeZoneId: String, largeVarTypes: Boolean, - workerConf: Map[String, String], pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String], sessionUUID: Option[String]) extends BaseArrowPythonRunner[Iterator[InternalRow], ColumnarBatch]( - funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf, + funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, pythonMetrics, jobArtifactUUID, sessionUUID) with BasicPythonArrowInput with BasicPythonArrowOutput @@ -106,13 +104,13 @@ class ArrowPythonRunner( _schema: StructType, _timeZoneId: String, largeVarTypes: Boolean, - workerConf: Map[String, String], + protected override val runnerConf: Map[String, String], pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String], sessionUUID: Option[String], profiler: Option[String]) extends RowInputArrowPythonRunner( - funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf, + funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, pythonMetrics, jobArtifactUUID, sessionUUID) { override protected def writeUDF(dataOut: DataOutputStream): Unit = @@ -130,13 +128,13 @@ class ArrowPythonWithNamedArgumentRunner( _schema: StructType, _timeZoneId: String, largeVarTypes: Boolean, - workerConf: Map[String, String], + protected override val runnerConf: Map[String, String], pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String], sessionUUID: Option[String], profiler: Option[String]) extends RowInputArrowPythonRunner( - funcs, 
evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf, + funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, pythonMetrics, jobArtifactUUID, sessionUUID) { override protected def writeUDF(dataOut: DataOutputStream): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala index 1d5df9bad924..979d91205d5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala @@ -39,7 +39,7 @@ class ArrowPythonUDTFRunner( protected override val schema: StructType, protected override val timeZoneId: String, protected override val largeVarTypes: Boolean, - protected override val workerConf: Map[String, String], + protected override val runnerConf: Map[String, String], override val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String], sessionUUID: Option[String]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index 7f6efbae8881..b5986be9214a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -25,7 +25,7 @@ import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader} import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} import org.apache.spark.{SparkEnv, SparkException, TaskContext} -import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker} +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonWorker} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriterWrapper import org.apache.spark.sql.execution.metric.SQLMetric @@ -45,7 +45,7 @@ class CoGroupedArrowPythonRunner( rightSchema: StructType, timeZoneId: String, largeVarTypes: Boolean, - conf: Map[String, String], + protected override val runnerConf: Map[String, String], override val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String], sessionUUID: Option[String], @@ -119,14 +119,6 @@ class CoGroupedArrowPythonRunner( private var rightGroupArrowWriter: ArrowWriterWrapper = null protected override def writeCommand(dataOut: DataOutputStream): Unit = { - - // Write config for the worker as a number of key -> value pairs of strings - dataOut.writeInt(conf.size) - for ((k, v) <- conf) { - PythonRDD.writeUTF(k, dataOut) - PythonRDD.writeUTF(v, dataOut) - } - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index f77b0a9342b0..d2d16b0c9623 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -27,7 +27,7 @@ import org.apache.arrow.vector.ipc.WriteChannel import org.apache.arrow.vector.ipc.message.MessageSerializer import org.apache.spark.{SparkEnv, SparkException, TaskContext} 
-import org.apache.spark.api.python.{BasePythonRunner, PythonRDD, PythonWorker} +import org.apache.spark.api.python.{BasePythonRunner, PythonWorker} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow import org.apache.spark.sql.execution.arrow.{ArrowWriter, ArrowWriterWrapper} @@ -42,8 +42,6 @@ import org.apache.spark.util.Utils * JVM (an iterator of internal rows + additional data if required) to Python (Arrow). */ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => - protected val workerConf: Map[String, String] - protected val schema: StructType protected val timeZoneId: String @@ -62,14 +60,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => protected def writeUDF(dataOut: DataOutputStream): Unit - protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { - // Write config for the worker as a number of key -> value pairs of strings - stream.writeInt(workerConf.size) - for ((k, v) <- workerConf) { - PythonRDD.writeUTF(k, stream) - PythonRDD.writeUTF(v, stream) - } - } + protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {} + private val arrowSchema = ArrowUtils.toArrowSchema( schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) protected val allocator = @@ -301,7 +293,6 @@ private[python] trait GroupedPythonArrowInput { self: RowInputArrowPythonRunner context: TaskContext): Writer = { new Writer(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { - handleMetadataBeforeExec(dataOut) writeUDF(dataOut) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala index 14054ba89a94..ae89ff1637ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala @@ -58,7 +58,7 @@ class ApplyInPandasWithStatePythonRunner( argOffsets: Array[Array[Int]], inputSchema: StructType, _timeZoneId: String, - initialWorkerConf: Map[String, String], + initialRunnerConf: Map[String, String], stateEncoder: ExpressionEncoder[Row], keySchema: StructType, outputSchema: StructType, @@ -113,7 +113,7 @@ class ApplyInPandasWithStatePythonRunner( // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance. // Configurations are both applied to executor and Python worker, set them to the worker conf // to let Python worker read the config properly. 
- override protected val workerConf: Map[String, String] = initialWorkerConf + + override protected val runnerConf: Map[String, String] = initialRunnerConf + (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) + (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala index 3eb7c7e64d64..bbf7b9387526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala @@ -52,7 +52,7 @@ class TransformWithStateInPySparkPythonRunner( _schema: StructType, processorHandle: StatefulProcessorHandleImpl, _timeZoneId: String, - initialWorkerConf: Map[String, String], + initialRunnerConf: Map[String, String], override val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String], groupingKeySchema: StructType, @@ -60,7 +60,7 @@ class TransformWithStateInPySparkPythonRunner( eventTimeWatermarkForEviction: Option[Long]) extends TransformWithStateInPySparkPythonBaseRunner[InType]( funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId, - initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema, + initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema, batchTimestampMs, eventTimeWatermarkForEviction) with PythonArrowInput[InType] { @@ -126,7 +126,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner( initStateSchema: StructType, processorHandle: StatefulProcessorHandleImpl, _timeZoneId: String, - initialWorkerConf: Map[String, String], + initialRunnerConf: Map[String, String], override val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String], groupingKeySchema: StructType, @@ -134,7 +134,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner( eventTimeWatermarkForEviction: Option[Long]) extends TransformWithStateInPySparkPythonBaseRunner[GroupedInType]( funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId, - initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema, + initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema, batchTimestampMs, eventTimeWatermarkForEviction) with PythonArrowInput[GroupedInType] { @@ -221,7 +221,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I]( _schema: StructType, processorHandle: StatefulProcessorHandleImpl, _timeZoneId: String, - initialWorkerConf: Map[String, String], + initialRunnerConf: Map[String, String], override val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String], groupingKeySchema: StructType, @@ -238,7 +238,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I]( protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch protected val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch - override protected val workerConf: Map[String, String] = initialWorkerConf + + override protected val runnerConf: Map[String, String] = initialRunnerConf + (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) + (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString) @@ -251,7 +251,7 @@ abstract class 
TransformWithStateInPySparkPythonBaseRunner[I]( override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { super.handleMetadataBeforeExec(stream) - // Also write the port/path number for state server + // Write the port/path number for state server if (isUnixDomainSock) { stream.writeInt(-1) PythonWorkerUtils.writeUTF(stateServerSocketPath, stream) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index ce4c347cad34..875bfeb011bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -342,9 +342,20 @@ abstract class JdbcDialect extends Serializable with Logging { * @param connection The connection object * @param properties The connection properties. This is passed through from the relation. */ + @deprecated("Use beforeFetch(Connection, JDBCOptions) instead", "4.2.0") def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { } + /** + * Override connection specific properties to run before a select is made. This is in place to + * allow dialects that need special treatment to optimize behavior. + * @param connection The connection object + * @param options The JDBC options for the connection. + */ + def beforeFetch(connection: Connection, options: JDBCOptions): Unit = { + beforeFetch(connection, options.parameters) + } + /** * Escape special characters in SQL string literals. * @param value The string to be escaped. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index e59b4435c408..5d1173b5a1a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -2404,7 +2404,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sourceDF.createOrReplaceTempView("source") val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" - sql(s"""MERGE $schemaEvolutionClause + val mergeStmt = s"""MERGE $schemaEvolutionClause |INTO $tableNameAsString t |USING source s |ON t.pk = s.pk @@ -2412,8 +2412,9 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase | UPDATE SET * |WHEN NOT MATCHED THEN | INSERT * - |""".stripMargin) + |""".stripMargin if (withSchemaEvolution && schemaEvolutionEnabled) { + sql(mergeStmt) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Seq( @@ -2424,15 +2425,12 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase Row(5, 250, "executive", true), Row(6, 350, null, false))) } else { - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, 100, "hr"), - Row(2, 200, "software"), - Row(3, 300, "hr"), - Row(4, 150, "marketing"), - Row(5, 250, "executive"), - Row(6, 350, null))) + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.message.contains("A column, variable, or function parameter with name " + + "`dep` cannot be resolved")) } sql(s"DROP TABLE $tableNameAsString") } @@ -2463,7 +2461,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sourceDF.createOrReplaceTempView("source") val schemaEvolutionClause = if 
(withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" - sql(s"""MERGE $schemaEvolutionClause + val mergeStmt = s"""MERGE $schemaEvolutionClause |INTO $tableNameAsString t |USING source s |ON t.pk = s.pk @@ -2471,8 +2469,9 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase | UPDATE SET * |WHEN NOT MATCHED THEN | INSERT * - |""".stripMargin) + |""".stripMargin if (withSchemaEvolution && schemaEvolutionEnabled) { + sql(mergeStmt) checkAnswer( sql(s"SELECT * FROM $tableNameAsString"), Seq( @@ -2483,15 +2482,12 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase Row(5, 250, "executive", true), Row(6, 350, "unknown", false))) } else { - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, 100, "hr"), - Row(2, 200, "software"), - Row(3, 300, "hr"), - Row(4, 150, "marketing"), - Row(5, 250, "executive"), - Row(6, 350, "unknown"))) + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(e.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(e.getMessage.contains("A column, variable, or function parameter with name " + + "`dep` cannot be resolved")) } sql(s"DROP TABLE $tableNameAsString") } @@ -3239,23 +3235,14 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase |WHEN NOT MATCHED THEN | INSERT * |""".stripMargin - if (coerceNestedTypes) { - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) - } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `s`.`c2`")) - } + if (coerceNestedTypes && withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) + } else { val exception = intercept[org.apache.spark.sql.AnalysisException] { sql(mergeStmt) @@ -3335,30 +3322,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase |WHEN NOT MATCHED THEN | INSERT * |""".stripMargin - if (coerceNestedTypes) { - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"), - Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) - } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `s`.`c2`")) - } + if (coerceNestedTypes && withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) } else { val exception = intercept[org.apache.spark.sql.AnalysisException] { sql(mergeStmt) } assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot find data for the output column `s`.`c2`.`a`")) } } 
sql(s"DROP TABLE IF EXISTS $tableNameAsString") @@ -3534,30 +3509,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase | INSERT * |""".stripMargin - if (coerceNestedTypes) { - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"), - Row(1, Map(Row(10, null, true) -> Row("y", null, false)), "sales"), - Row(2, Map(Row(20, null, false) -> Row("z", null, true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) - } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `m`.`key`")) - } + if (coerceNestedTypes && withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"), + Row(1, Map(Row(10, null, true) -> Row("y", null, false)), "sales"), + Row(2, Map(Row(20, null, false) -> Row("z", null, true)), "engineering"))) } else { val exception = intercept[org.apache.spark.sql.AnalysisException] { sql(mergeStmt) } assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot find data for the output column `m`.`key`.`c2`")) } } sql(s"DROP TABLE IF EXISTS $tableNameAsString") @@ -3612,30 +3575,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase | INSERT (pk, m, dep) VALUES (src.pk, src.m, 'my_new_dep') |""".stripMargin - if (coerceNestedTypes) { - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"), - Row(1, Map(Row(10, null, true) -> Row("y", null, false)), "my_old_dep"), - Row(2, Map(Row(20, null, false) -> Row("z", null, true)), "my_new_dep"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) - } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `m`.`key`")) - } + if (coerceNestedTypes && withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"), + Row(1, Map(Row(10, null, true) -> Row("y", null, false)), "my_old_dep"), + Row(2, Map(Row(20, null, false) -> Row("z", null, true)), "my_new_dep"))) } else { val exception = intercept[org.apache.spark.sql.AnalysisException] { sql(mergeStmt) } assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot find data for the output column `m`.`key`.`c2`")) } } sql(s"DROP TABLE IF EXISTS $tableNameAsString") @@ -3688,30 +3639,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase | INSERT * |""".stripMargin - if (coerceNestedTypes) { - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(0, Array(Row(10, 10, null)), "hr"), - Row(1, Array(Row(10, null, true)), "sales"), - Row(2, Array(Row(20, null, false)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) - } - assert(exception.errorClass.get == - 
"INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `a`.`element`")) - } + if (coerceNestedTypes && withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(0, Array(Row(10, 10, null)), "hr"), + Row(1, Array(Row(10, null, true)), "sales"), + Row(2, Array(Row(20, null, false)), "engineering"))) } else { val exception = intercept[org.apache.spark.sql.AnalysisException] { sql(mergeStmt) } assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot find data for the output column `a`.`element`.`c2`")) } } sql(s"DROP TABLE IF EXISTS $tableNameAsString") @@ -3764,30 +3703,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase | INSERT (pk, a, dep) VALUES (src.pk, src.a, 'my_new_dep') |""".stripMargin - if (coerceNestedTypes) { - if (withSchemaEvolution) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq(Row(0, Array(Row(10, 10, null)), "hr"), - Row(1, Array(Row(10, null, true)), "my_old_dep"), - Row(2, Array(Row(20, null, false)), "my_new_dep"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) - } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `a`.`element`")) - } + if (coerceNestedTypes && withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(0, Array(Row(10, 10, null)), "hr"), + Row(1, Array(Row(10, null, true)), "my_old_dep"), + Row(2, Array(Row(20, null, false)), "my_new_dep"))) } else { val exception = intercept[org.apache.spark.sql.AnalysisException] { sql(mergeStmt) } assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot find data for the output column `a`.`element`.`c2`")) } } sql(s"DROP TABLE IF EXISTS $tableNameAsString") @@ -4051,6 +3978,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase sql(mergeStmt) } assert(exception.errorClass.get == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(exception.message.contains(" A column, variable, or function parameter with name " + + "`bonus` cannot be resolved")) } sql(s"DROP TABLE $tableNameAsString") @@ -4484,270 +4413,324 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } test("merge into with source missing fields in struct nested in array") { - Seq(true, false).foreach { coerceNestedTypes => - withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> - coerceNestedTypes.toString) { - withTempView("source") { - // Target table has struct with 3 fields (c1, c2, c3) in array - createAndInitTable( - s"""pk INT NOT NULL, - |a ARRAY>, - |dep STRING""".stripMargin, - """{ "pk": 0, "a": [ { "c1": 1, "c2": "a", "c3": true } ], "dep": "sales" } - |{ "pk": 1, "a": [ { "c1": 2, "c2": "b", "c3": false } ], "dep": "sales" }""" - .stripMargin) + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Target table has struct with 3 fields (c1, c2, c3) in array + createAndInitTable( + s"""pk INT NOT NULL, + |a ARRAY>, + |dep STRING""".stripMargin, + 
"""{ "pk": 0, "a": [ { "c1": 1, "c2": "a", "c3": true } ], "dep": "sales" } + |{ "pk": 1, "a": [ { "c1": 2, "c2": "b", "c3": false } ], "dep": "sales" }""" + .stripMargin) - // Source table has struct with only 2 fields (c1, c2) - missing c3 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("a", ArrayType( - StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType))))), // missing c3 field - StructField("dep", StringType))) - val data = Seq( - Row(1, Array(Row(10, "c")), "hr"), - Row(2, Array(Row(30, "e")), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") + // Source table has struct with only 2 fields (c1, c2) - missing c3 + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("a", ArrayType( + StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType))))), // missing c3 field + StructField("dep", StringType))) + val data = Seq( + Row(1, Array(Row(10, "c")), "hr"), + Row(2, Array(Row(30, "e")), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") - val mergeStmt = - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin - if (coerceNestedTypes) { - sql(mergeStmt) - // Missing field c3 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Array(Row(1, "a", true)), "sales"), - Row(1, Array(Row(10, "c", null)), "hr"), - Row(2, Array(Row(30, "e", null)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { + if (coerceNestedTypes && withSchemaEvolution) { sql(mergeStmt) + // Missing field c3 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Array(Row(1, "a", true)), "sales"), + Row(1, Array(Row(10, "c", null)), "hr"), + Row(2, Array(Row(30, "e", null)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot write incompatible data for the table ``: " + - "Cannot find data for the output column `a`.`element`.`c3`.")) } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } } test("merge into with source missing fields in struct nested in map key") { - Seq(true, false).foreach { coerceNestedTypes => - withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> - coerceNestedTypes.toString) { - withTempView("source") { - // Target table has struct with 2 fields in map key - val targetSchema = - StructType(Seq( - StructField("pk", IntegerType, nullable = false), + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + 
withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Target table has struct with 2 fields in map key + val targetSchema = + StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType), StructField("c2", BooleanType))), + StructType(Seq(StructField("c3", StringType))))), + StructField("dep", StringType))) + createTable(CatalogV2Util.structTypeToV2Columns(targetSchema)) + + val targetData = Seq( + Row(0, Map(Row(10, true) -> Row("x")), "hr"), + Row(1, Map(Row(20, false) -> Row("y")), "sales")) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() + + // Source table has struct with only 1 field (c1) in map key - missing c2 + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType), StructField("c2", BooleanType))), + StructType(Seq(StructField("c1", IntegerType))), // missing c2 StructType(Seq(StructField("c3", StringType))))), StructField("dep", StringType))) - createTable(CatalogV2Util.structTypeToV2Columns(targetSchema)) - - val targetData = Seq( - Row(0, Map(Row(10, true) -> Row("x")), "hr"), - Row(1, Map(Row(20, false) -> Row("y")), "sales")) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() - - // Source table has struct with only 1 field (c1) in map key - missing c2 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType))), // missing c2 - StructType(Seq(StructField("c3", StringType))))), - StructField("dep", StringType))) - val sourceData = Seq( - Row(1, Map(Row(10) -> Row("z")), "sales"), - Row(2, Map(Row(20) -> Row("w")), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) - .createOrReplaceTempView("source") + val sourceData = Seq( + Row(1, Map(Row(10) -> Row("z")), "sales"), + Row(2, Map(Row(20) -> Row("w")), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") - val mergeStmt = - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin - if (coerceNestedTypes) { - sql(mergeStmt) - // Missing field c2 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Map(Row(10, true) -> Row("x")), "hr"), - Row(1, Map(Row(10, null) -> Row("z")), "sales"), - Row(2, Map(Row(20, null) -> Row("w")), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { + if (coerceNestedTypes && withSchemaEvolution) { sql(mergeStmt) + // Missing field c2 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Map(Row(10, true) -> Row("x")), "hr"), + Row(1, Map(Row(10, null) -> Row("z")), "sales"), + Row(2, Map(Row(20, null) -> Row("w")), 
"engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot write incompatible data for the table ``: " + - "Cannot find data for the output column `m`.`key`.`c2`.")) } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } } test("merge into with source missing fields in struct nested in map value") { - Seq(true, false).foreach { coerceNestedTypes => - withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> - coerceNestedTypes.toString) { - withTempView("source") { - // Target table has struct with 2 fields in map value - val targetSchema = - StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType))), - StructType(Seq(StructField("c1", StringType), StructField("c2", BooleanType))))), - StructField("dep", StringType))) - createTable(CatalogV2Util.structTypeToV2Columns(targetSchema)) - - val targetData = Seq( - Row(0, Map(Row(10) -> Row("x", true)), "hr"), - Row(1, Map(Row(20) -> Row("y", false)), "sales")) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() - - // Source table has struct with only 1 field (c1) in map value - missing c2 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("m", MapType( - StructType(Seq(StructField("c1", IntegerType))), - StructType(Seq(StructField("c1", StringType))))), // missing c2 - StructField("dep", StringType))) - val sourceData = Seq( - Row(1, Map(Row(10) -> Row("z")), "sales"), - Row(2, Map(Row(20) -> Row("w")), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) - .createOrReplaceTempView("source") + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Target table has struct with 2 fields in map value + val targetSchema = + StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType))), + StructType(Seq(StructField("c1", StringType), StructField("c2", BooleanType))))), + StructField("dep", StringType))) + createTable(CatalogV2Util.structTypeToV2Columns(targetSchema)) - val mergeStmt = - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin + val targetData = Seq( + Row(0, Map(Row(10) -> Row("x", true)), "hr"), + Row(1, Map(Row(20) -> Row("y", false)), "sales")) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() - if (coerceNestedTypes) { - sql(mergeStmt) - // Missing field c2 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Map(Row(10) -> Row("x", true)), "hr"), - Row(1, Map(Row(10) -> Row("z", null)), "sales"), - Row(2, Map(Row(20) -> Row("w", null)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { + // Source table has 
struct with only 1 field (c1) in map value - missing c2 + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("m", MapType( + StructType(Seq(StructField("c1", IntegerType))), + StructType(Seq(StructField("c1", StringType))))), // missing c2 + StructField("dep", StringType))) + val sourceData = Seq( + Row(1, Map(Row(10) -> Row("z")), "sales"), + Row(2, Map(Row(20) -> Row("w")), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(sourceData), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin + + if (coerceNestedTypes && withSchemaEvolution) { sql(mergeStmt) + // Missing field c2 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Map(Row(10) -> Row("x", true)), "hr"), + Row(1, Map(Row(10) -> Row("z", null)), "sales"), + Row(2, Map(Row(20) -> Row("w", null)), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot write incompatible data for the table ``: " + - "Cannot find data for the output column `m`.`value`.`c2`.")) } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } } test("merge into with source missing fields in top-level struct") { - Seq(true, false).foreach { coerceNestedTypes => - withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> - coerceNestedTypes.toString) { - withTempView("source") { - // Target table has struct with 3 fields at top level - createAndInitTable( - s"""pk INT NOT NULL, - |s STRUCT, - |dep STRING""".stripMargin, - """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep": "sales"}""") + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Target table has struct with 3 fields at top level + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT, + |dep STRING""".stripMargin, + """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep": "sales"}""") - // Source table has struct with only 2 fields (c1, c2) - missing c3 - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType)))), // missing c3 field - StructField("dep", StringType))) - val data = Seq( - Row(1, Row(10, "b"), "hr"), - Row(2, Row(20, "c"), "engineering") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") + // Source table has struct with only 2 fields (c1, c2) - missing c3 + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType)))), // missing c3 field + StructField("dep", StringType))) + val data = Seq( + Row(1, 
Row(10, "b"), "hr"), + Row(2, Row(20, "c"), "engineering") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") - val mergeStmt = - s"""MERGE INTO $tableNameAsString t - |USING source src - |ON t.pk = src.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin - if (coerceNestedTypes) { - sql(mergeStmt) - // Missing field c3 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, "a", true), "sales"), - Row(1, Row(10, "b", null), "hr"), - Row(2, Row(20, "c", null), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { + if (coerceNestedTypes && withSchemaEvolution) { sql(mergeStmt) + // Missing field c3 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, "a", true), "sales"), + Row(1, Row(10, "b", null), "hr"), + Row(2, Row(20, "c", null), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot write incompatible data for the table ``: " + - "Cannot find data for the output column `s`.`c3`.")) } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } } + test("merge into with source missing top-level column") { + Seq(true, false).foreach { withSchemaEvolution => + withTempView("source") { + // Target table has 3 columns: pk, salary, dep + createAndInitTable( + s"""pk INT NOT NULL, + |salary INT, + |dep STRING""".stripMargin, + """{ "pk": 0, "salary": 100, "dep": "sales" } + |{ "pk": 1, "salary": 200, "dep": "hr" }""" + .stripMargin) + + // Source table has only 2 columns: pk, dep (missing salary) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("dep", StringType))) + val data = Seq( + Row(1, "engineering"), + Row(2, "finance") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") + + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t + |USING source src + |ON t.pk = src.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin + + if (withSchemaEvolution) { + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, 100, "sales"), + Row(1, 200, "engineering"), + Row(2, null, "finance"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "UNRESOLVED_COLUMN.WITH_SUGGESTION") + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } + } + test("merge with null struct") { withTempView("source") { createAndInitTable( @@ -4884,70 +4867,69 @@ abstract class MergeIntoTableSuiteBase extends 
RowLevelOperationSuiteBase } test("merge with with null struct with missing nested field") { - Seq(true, false).foreach { coerceNestedTypes => - withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> - coerceNestedTypes.toString) { - withTempView("source") { - // Target table has nested struct with fields c1 and c2 - createAndInitTable( - s"""pk INT NOT NULL, - |s STRUCT>, - |dep STRING""".stripMargin, - """{ "pk": 0, "s": { "c1": 1, "c2": { "a": 10, "b": "x" } }, "dep": "sales" } - |{ "pk": 1, "s": { "c1": 2, "c2": { "a": 20, "b": "y" } }, "dep": "hr" }""" - .stripMargin) + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + // Target table has nested struct with fields c1 and c2 + createAndInitTable( + s"""pk INT NOT NULL, + |s STRUCT>, + |dep STRING""".stripMargin, + """{ "pk": 0, "s": { "c1": 1, "c2": { "a": 10, "b": "x" } }, "dep": "sales" } + |{ "pk": 1, "s": { "c1": 2, "c2": { "a": 20, "b": "y" } }, "dep": "hr" }""" + .stripMargin) - // Source table has null for the nested struct - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", IntegerType) - // missing field 'b' - ))) - ))), - StructField("dep", StringType) - )) + // Source table has null for the nested struct + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType) + // missing field 'b' + ))) + ))), + StructField("dep", StringType) + )) - val data = Seq( - Row(1, null, "engineering"), - Row(2, null, "finance") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") + val data = Seq( + Row(1, null, "engineering"), + Row(2, null, "finance") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") - val mergeStmt = - s"""MERGE INTO $tableNameAsString t USING source - |ON t.pk = source.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin - if (coerceNestedTypes) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, Row(10, "x")), "sales"), - Row(1, null, "engineering"), - Row(2, null, "finance"))) - } else { - // Without coercion, the merge should fail due to missing field - val exception = intercept[org.apache.spark.sql.AnalysisException] { + if (coerceNestedTypes && withSchemaEvolution) { sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") } - assert(exception.errorClass.get == - 
"INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot write incompatible data for the table ``: " + - "Cannot find data for the output column `s`.`c2`.`b`.")) } } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } @@ -4998,37 +4980,21 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase | INSERT * |""".stripMargin - if (coerceNestedTypes) { - if (withSchemaEvolution) { - // extra nested field is added - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, Row(10, "x", null)), "sales"), - Row(1, null, "engineering"), - Row(2, null, "finance"))) - } else { - // extra nested field is not added - val exception = intercept[org.apache.spark.sql.AnalysisException] { - sql(mergeStmt) - } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write incompatible data for the table ``: " + - "Cannot write extra fields `c` to the struct `s`.`c2`")) - } + if (coerceNestedTypes && withSchemaEvolution) { + // extra nested field is added + sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, Row(10, "x", null)), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) } else { - // Without source struct coercion, the merge should fail val exception = intercept[org.apache.spark.sql.AnalysisException] { sql(mergeStmt) } assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot write incompatible data for the table ``: " + - "Cannot find data for the output column `s`.`c2`.`b`.")) } } } @@ -5097,82 +5063,83 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } test("merge with null struct using default value") { - Seq(true, false).foreach { coerceNestedTypes => - withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> - coerceNestedTypes.toString) { - withTempView("source") { - sql( - s"""CREATE TABLE $tableNameAsString ( - | pk INT NOT NULL, - | s STRUCT> DEFAULT - | named_struct('c1', 999, 'c2', named_struct('a', 999, 'b', 'default')), - | dep STRING) - |PARTITIONED BY (dep) - |""".stripMargin) + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + withTempView("source") { + sql( + s"""CREATE TABLE $tableNameAsString ( + | pk INT NOT NULL, + | s STRUCT> DEFAULT + | named_struct('c1', 999, 'c2', named_struct('a', 999, 'b', 'default')), + | dep STRING) + |PARTITIONED BY (dep) + |""".stripMargin) - val initialSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", IntegerType), - StructField("b", StringType) - ))) - ))), - StructField("dep", StringType) - )) - val initialData = Seq( - Row(0, Row(1, Row(10, "x")), "sales"), - Row(1, Row(2, Row(20, "y")), "hr") - ) - spark.createDataFrame(spark.sparkContext.parallelize(initialData), initialSchema) - .writeTo(tableNameAsString).append() + val initialSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType), + 
StructField("b", StringType) + ))) + ))), + StructField("dep", StringType) + )) + val initialData = Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, Row(2, Row(20, "y")), "hr") + ) + spark.createDataFrame(spark.sparkContext.parallelize(initialData), initialSchema) + .writeTo(tableNameAsString).append() - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", IntegerType) - ))) - ))), - StructField("dep", StringType) - )) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType) + ))) + ))), + StructField("dep", StringType) + )) - val data = Seq( - Row(1, null, "engineering"), - Row(2, null, "finance") - ) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source") + val data = Seq( + Row(1, null, "engineering"), + Row(2, null, "finance") + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source") - val mergeStmt = - s"""MERGE INTO $tableNameAsString t USING source - |ON t.pk = source.pk - |WHEN MATCHED THEN - | UPDATE SET * - |WHEN NOT MATCHED THEN - | INSERT * - |""".stripMargin + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET * + |WHEN NOT MATCHED THEN + | INSERT * + |""".stripMargin - if (coerceNestedTypes) { - sql(mergeStmt) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, Row(10, "x")), "sales"), - Row(1, null, "engineering"), - Row(2, null, "finance"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { + if (coerceNestedTypes && withSchemaEvolution) { sql(mergeStmt) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + sql(mergeStmt) + } + assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") } - assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot find data for the output column `s`.`c2`.`b`")) } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } } } @@ -5243,7 +5210,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } test("merge into with source missing fields in nested struct") { - Seq(true, false).foreach { nestedTypeCoercion => + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { nestedTypeCoercion => withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> nestedTypeCoercion.toString) { withTempView("source") { @@ -5275,7 +5243,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase .createOrReplaceTempView("source") // Missing field b should be filled with NULL - val mergeStmt = s"""MERGE INTO $tableNameAsString t + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t 
|USING source src
|ON t.pk = src.pk
|WHEN MATCHED THEN
@@ -5284,7 +5253,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
| INSERT *
|""".stripMargin
- if (nestedTypeCoercion) {
+ if (nestedTypeCoercion && withSchemaEvolution) {
sql(mergeStmt)
checkAnswer(
sql(s"SELECT * FROM $tableNameAsString"),
@@ -5292,16 +5261,17 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
Row(1, Row(10, Row(20, null)), "sales"),
Row(2, Row(20, Row(30, null)), "engineering")))
} else {
- val exception = intercept[Exception] {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
sql(mergeStmt)
}
- assert(exception.getMessage.contains(
- """Cannot write incompatible data for the table ``""".stripMargin))
+ assert(exception.errorClass.get ==
+ "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
}
}
sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
}
+ }
}
test("merge with named_struct missing non-nullable field") {
@@ -5355,51 +5325,68 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
}
test("merge with struct missing nested field with check constraint") {
- withTempView("source") {
- // Target table has struct with nested field c2
- createAndInitTable(
- s"""pk INT NOT NULL,
- |s STRUCT<c1 INT, c2 INT>,
- |dep STRING""".stripMargin,
- """{ "pk": 0, "s": { "c1": 1, "c2": 10 }, "dep": "sales" }
- |{ "pk": 1, "s": { "c1": 2, "c2": 20 }, "dep": "hr" }"""
- .stripMargin)
+ Seq(true, false).foreach { withSchemaEvolution =>
+ Seq(true, false).foreach { coercionEnabled =>
+ withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key ->
+ coercionEnabled.toString) {
+ withTempView("source") {
+ // Target table has struct with nested field c2
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |s STRUCT<c1 INT, c2 INT>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "s": { "c1": 1, "c2": 10 }, "dep": "sales" }
+ |{ "pk": 1, "s": { "c1": 2, "c2": 20 }, "dep": "hr" }"""
+ .stripMargin)
- // Add CHECK constraint on nested field c2 using ALTER TABLE
- sql(s"ALTER TABLE $tableNameAsString ADD CONSTRAINT check_c2 CHECK " +
- s"(s.c2 IS NOT NULL AND s.c2 > 1)")
+ // Add CHECK constraint on nested field c2 using ALTER TABLE
+ sql(s"ALTER TABLE $tableNameAsString ADD CONSTRAINT check_c2 CHECK " +
+ s"(s.c2 IS NOT NULL AND s.c2 > 1)")
- // Source table schema with struct missing the c2 field
- val sourceTableSchema = StructType(Seq(
- StructField("pk", IntegerType),
- StructField("s", StructType(Seq(
- StructField("c1", IntegerType)
- // missing field 'c2' which has CHECK constraint IS NOT NULL AND > 1
- ))),
- StructField("dep", StringType)
- ))
+ // Source table schema with struct missing the c2 field
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType)
+ // missing field 'c2' which has CHECK constraint IS NOT NULL AND > 1
+ ))),
+ StructField("dep", StringType)
+ ))
- val data = Seq(
- Row(1, Row(100), "engineering"),
- Row(2, Row(200), "finance")
- )
- spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
- .createOrReplaceTempView("source")
+ val data = Seq(
+ Row(1, Row(100), "engineering"),
+ Row(2, Row(200), "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema)
+ .createOrReplaceTempView("source")
- withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> "true") {
- val error = intercept[SparkRuntimeException] {
- sql(
- s"""MERGE INTO $tableNameAsString t USING source
- |ON t.pk = source.pk
- |WHEN
MATCHED THEN - | UPDATE SET s = source.s, dep = source.dep - |""".stripMargin)} - assert(error.getCondition == "CHECK_CONSTRAINT_VIOLATION") - assert(error.getMessage.contains("CHECK constraint check_c2 s.c2 IS NOT NULL AND " + - "s.c2 > 1 violated by row with values:\n - s.c2 : null")) + val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA EVOLUTION" else "" + val mergeStmt = + s"""MERGE $schemaEvolutionClause INTO $tableNameAsString t USING source + |ON t.pk = source.pk + |WHEN MATCHED THEN + | UPDATE SET s = source.s, dep = source.dep + |""".stripMargin + + if (withSchemaEvolution && coercionEnabled) { + val error = intercept[SparkRuntimeException] { + sql(mergeStmt) + } + assert(error.getCondition == "CHECK_CONSTRAINT_VIOLATION") + assert(error.getMessage.contains("CHECK constraint check_c2 s.c2 IS NOT NULL AND " + + "s.c2 > 1 violated by row with values:\n - s.c2 : null")) + } else { + // Without schema evolution or coercion, the schema mismatch is rejected + val error = intercept[AnalysisException] { + sql(mergeStmt) + } + assert(error.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") + } + } + sql(s"DROP TABLE IF EXISTS $tableNameAsString") + } } } - sql(s"DROP TABLE IF EXISTS $tableNameAsString") } test("merge with schema evolution using dataframe API: add new column and set all") { @@ -5926,30 +5913,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase .whenNotMatched() .insertAll() - if (coerceNestedTypes) { - if (withSchemaEvolution) { - mergeBuilder.withSchemaEvolution().merge() - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"), - Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - mergeBuilder.merge() - } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `s`.`c2`")) - } + if (coerceNestedTypes && withSchemaEvolution) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) } else { val exception = intercept[org.apache.spark.sql.AnalysisException] { mergeBuilder.merge() } assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot find data for the output column `s`.`c2`.`a`")) } sql(s"DROP TABLE $tableNameAsString") @@ -6038,30 +6013,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase .whenNotMatched() .insertAll() - if (coerceNestedTypes) { - if (withSchemaEvolution) { - mergeBuilder.withSchemaEvolution().merge() - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"), - Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - mergeBuilder.merge() - } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write extra fields `c3` to the struct `s`.`c2`")) - } + if (coerceNestedTypes && withSchemaEvolution) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM 
$tableNameAsString"), + Seq( + Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "hr"), + Row(2, Row(20, Row(null, Map("e" -> "f"), true)), "engineering"))) } else { val exception = intercept[org.apache.spark.sql.AnalysisException] { mergeBuilder.merge() } assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot find data for the output column `s`.`c2`.`a`")) } sql(s"DROP TABLE $tableNameAsString") @@ -6132,198 +6095,190 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } test("merge into with source missing fields in top-level struct using dataframe API") { - Seq(true, false).foreach { coerceNestedTypes => - withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> - coerceNestedTypes.toString) { - val sourceTable = "cat.ns1.source_table" - withTable(sourceTable) { - // Target table has struct with 3 fields at top level - sql( - s"""CREATE TABLE $tableNameAsString ( - |pk INT NOT NULL, - |s STRUCT, - |dep STRING)""".stripMargin) + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + // Target table has struct with 3 fields at top level + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT, + |dep STRING)""".stripMargin) - val targetData = Seq( - Row(0, Row(1, "a", true), "sales") - ) - val targetSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType), - StructField("c3", BooleanType) - ))), - StructField("dep", StringType) - )) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - .writeTo(tableNameAsString).append() + val targetData = Seq( + Row(0, Row(1, "a", true), "sales") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType), + StructField("c3", BooleanType) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() - // Create source table with struct having only 2 fields (c1, c2) - missing c3 - val sourceIdent = Identifier.of(Array("ns1"), "source_table") - val columns = Array( - Column.create("pk", IntegerType, false), - Column.create("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType)))), // missing c3 field - Column.create("dep", StringType)) - val tableInfo = new TableInfo.Builder() - .withColumns(columns) - .withProperties(extraTableProps) - .build() - catalog.createTable(sourceIdent, tableInfo) + // Create source table with struct having only 2 fields (c1, c2) - missing c3 + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType)))), // missing c3 field + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) - val data = Seq( - Row(1, Row(10, "b"), "hr"), - Row(2, Row(20, "c"), "engineering") 
- ) - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StringType)))), - StructField("dep", StringType))) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source_temp") + val data = Seq( + Row(1, Row(10, "b"), "hr"), + Row(2, Row(20, "c"), "engineering") + ) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StringType)))), + StructField("dep", StringType))) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source_temp") - sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") - if (coerceNestedTypes) { - spark.table(sourceTable) + val mergeBuilder = spark.table(sourceTable) .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) .whenMatched() .updateAll() .whenNotMatched() .insertAll() - .merge() - // Missing field c3 should be filled with NULL - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, "a", true), "sales"), - Row(1, Row(10, "b", null), "hr"), - Row(2, Row(20, "c", null), "engineering"))) - } else { - val exception = intercept[org.apache.spark.sql.AnalysisException] { - spark.table(sourceTable) - .mergeInto(tableNameAsString, $"source_table.pk" === col(tableNameAsString + ".pk")) - .whenMatched() - .updateAll() - .whenNotMatched() - .insertAll() - .merge() + if (coerceNestedTypes && withSchemaEvolution) { + mergeBuilder.withSchemaEvolution().merge() + + // Missing field c3 should be filled with NULL + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, "a", true), "sales"), + Row(1, Row(10, "b", null), "hr"), + Row(2, Row(20, "c", null), "engineering"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot write incompatible data for the table ``: " + - "Cannot find data for the output column `s`.`c3`.")) - } - sql(s"DROP TABLE $tableNameAsString") + sql(s"DROP TABLE $tableNameAsString") + } } } } } test("merge with null struct with missing nested field using dataframe API") { - Seq(true, false).foreach { coerceNestedTypes => - withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> - coerceNestedTypes.toString) { - val sourceTable = "cat.ns1.source_table" - withTable(sourceTable) { - // Target table has nested struct with fields c1 and c2 - sql( - s"""CREATE TABLE $tableNameAsString ( - |pk INT NOT NULL, - |s STRUCT>, - |dep STRING)""".stripMargin) - - val targetData = Seq( - Row(0, Row(1, Row(10, "x")), "sales"), - Row(1, Row(2, Row(20, "y")), "hr") - ) - val targetSchema = StructType(Seq( - StructField("pk", IntegerType, nullable = false), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", IntegerType), - StructField("b", StringType) - ))) - ))), - StructField("dep", StringType) - )) - spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) - 
.writeTo(tableNameAsString).append() + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + val sourceTable = "cat.ns1.source_table" + withTable(sourceTable) { + // Target table has nested struct with fields c1 and c2 + sql( + s"""CREATE TABLE $tableNameAsString ( + |pk INT NOT NULL, + |s STRUCT>, + |dep STRING)""".stripMargin) - // Create source table with missing nested field 'b' - val sourceIdent = Identifier.of(Array("ns1"), "source_table") - val columns = Array( - Column.create("pk", IntegerType, false), - Column.create("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", IntegerType) - // missing field 'b' - ))) - ))), - Column.create("dep", StringType)) - val tableInfo = new TableInfo.Builder() - .withColumns(columns) - .withProperties(extraTableProps) - .build() - catalog.createTable(sourceIdent, tableInfo) + val targetData = Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, Row(2, Row(20, "y")), "hr") + ) + val targetSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(targetData), targetSchema) + .writeTo(tableNameAsString).append() - // Source table has null for the nested struct - val data = Seq( - Row(1, null, "engineering"), - Row(2, null, "finance") - ) - val sourceTableSchema = StructType(Seq( - StructField("pk", IntegerType), - StructField("s", StructType(Seq( - StructField("c1", IntegerType), - StructField("c2", StructType(Seq( - StructField("a", IntegerType) - ))) - ))), - StructField("dep", StringType) - )) - spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) - .createOrReplaceTempView("source_temp") + // Create source table with missing nested field 'b' + val sourceIdent = Identifier.of(Array("ns1"), "source_table") + val columns = Array( + Column.create("pk", IntegerType, false), + Column.create("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType) + // missing field 'b' + ))) + ))), + Column.create("dep", StringType)) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withProperties(extraTableProps) + .build() + catalog.createTable(sourceIdent, tableInfo) - sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") - val mergeBuilder = spark.table(sourceTable) - .mergeInto(tableNameAsString, - $"source_table.pk" === col(tableNameAsString + ".pk")) - .whenMatched() - .updateAll() - .whenNotMatched() - .insertAll() + // Source table has null for the nested struct + val data = Seq( + Row(1, null, "engineering"), + Row(2, null, "finance") + ) + val sourceTableSchema = StructType(Seq( + StructField("pk", IntegerType), + StructField("s", StructType(Seq( + StructField("c1", IntegerType), + StructField("c2", StructType(Seq( + StructField("a", IntegerType) + ))) + ))), + StructField("dep", StringType) + )) + spark.createDataFrame(spark.sparkContext.parallelize(data), sourceTableSchema) + .createOrReplaceTempView("source_temp") - if (coerceNestedTypes) { - mergeBuilder.merge() - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, 
Row(10, "x")), "sales"), - Row(1, null, "engineering"), - Row(2, null, "finance"))) - } else { - // Without coercion, the merge should fail due to missing field - val exception = intercept[org.apache.spark.sql.AnalysisException] { - mergeBuilder.merge() + sql(s"INSERT INTO $sourceTable SELECT * FROM source_temp") + val mergeBuilder = spark.table(sourceTable) + .mergeInto(tableNameAsString, + $"source_table.pk" === col(tableNameAsString + ".pk")) + .whenMatched() + .updateAll() + .whenNotMatched() + .insertAll() + + if (coerceNestedTypes && withSchemaEvolution) { + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, Row(10, "x")), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) + } else { + val exception = intercept[org.apache.spark.sql.AnalysisException] { + mergeBuilder.merge() + } + assert(exception.errorClass.get == + "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot write incompatible data for the table ``: " + - "Cannot find data for the output column `s`.`c2`.`b`.")) - } - sql(s"DROP TABLE $tableNameAsString") + sql(s"DROP TABLE $tableNameAsString") + } } } } @@ -6409,37 +6364,21 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase .whenNotMatched() .insertAll() - if (coerceNestedTypes) { - if (withSchemaEvolution) { - // extra nested field is added - mergeBuilder.withSchemaEvolution().merge() - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Seq( - Row(0, Row(1, Row(10, "x", null)), "sales"), - Row(1, null, "engineering"), - Row(2, null, "finance"))) - } else { - // extra nested field is not added - val exception = intercept[org.apache.spark.sql.AnalysisException] { - mergeBuilder.merge() - } - assert(exception.errorClass.get == - "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS") - assert(exception.getMessage.contains( - "Cannot write incompatible data for the table ``: " + - "Cannot write extra fields `c` to the struct `s`.`c2`")) - } + if (coerceNestedTypes && withSchemaEvolution) { + // extra nested field is added + mergeBuilder.withSchemaEvolution().merge() + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(0, Row(1, Row(10, "x", null)), "sales"), + Row(1, null, "engineering"), + Row(2, null, "finance"))) } else { - // Without source struct coercion, the merge should fail val exception = intercept[org.apache.spark.sql.AnalysisException] { mergeBuilder.merge() } assert(exception.errorClass.get == "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA") - assert(exception.getMessage.contains( - "Cannot write incompatible data for the table ``: " + - "Cannot find data for the output column `s`.`c2`.`b`.")) } sql(s"DROP TABLE $tableNameAsString") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala index f635131dc3f7..ddc864447141 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala @@ -690,28 +690,36 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { |""".stripMargin) assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i")) - Seq(true, false).foreach { coerceNestedTypes => - 
withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> - coerceNestedTypes.toString) { - val mergeStmt = - s"""MERGE INTO nested_struct_table t USING nested_struct_table src - |ON t.i = src.i - |$clause THEN - | UPDATE SET s.n_s = named_struct('dn_i', 2L) - |""".stripMargin - if (coerceNestedTypes) { - val plan5 = parseAndResolve(mergeStmt) - // No null check for dn_i as it is explicitly set - assertNoNullCheckExists(plan5) - } else { - val e = intercept[AnalysisException] { - parseAndResolve(mergeStmt) + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + val schemaEvolutionString = if (withSchemaEvolution) { + "WITH SCHEMA EVOLUTION" + } else { + "" + } + val mergeStmt = + s"""MERGE $schemaEvolutionString INTO nested_struct_table t + |USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_s = named_struct('dn_i', 2L) + |""".stripMargin + if (coerceNestedTypes && withSchemaEvolution) { + val plan5 = parseAndResolve(mergeStmt) + // No null check for dn_i as it is explicitly set + assertNoNullCheckExists(plan5) + } else { + val e = intercept[AnalysisException] { + parseAndResolve(mergeStmt) + } + checkError( + exception = e, + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + parameters = Map("tableName" -> "``", "colName" -> "`s`.`n_s`.`dn_l`") + ) } - checkError( - exception = e, - condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", - parameters = Map("tableName" -> "``", "colName" -> "`s`.`n_s`.`dn_l`") - ) } } } @@ -857,27 +865,35 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuiteBase { |""".stripMargin) assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i")) - Seq(true, false).foreach { coerceNestedTypes => - withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> - coerceNestedTypes.toString) { - val mergeStmt = - s"""MERGE INTO nested_struct_table t USING nested_struct_table src - |ON t.i = src.i - |$clause THEN - | UPDATE SET s.n_s = named_struct('dn_i', 1) - |""".stripMargin - if (coerceNestedTypes) { - val plan5 = parseAndResolve(mergeStmt) - // No null check for dn_i as it is explicitly set - assertNoNullCheckExists(plan5) - } else { - val e = intercept[AnalysisException] { - parseAndResolve(mergeStmt) + Seq(true, false).foreach { withSchemaEvolution => + Seq(true, false).foreach { coerceNestedTypes => + withSQLConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED.key -> + coerceNestedTypes.toString) { + val schemaEvolutionString = if (withSchemaEvolution) { + "WITH SCHEMA EVOLUTION" + } else { + "" + } + val mergeStmt = + s"""MERGE $schemaEvolutionString INTO nested_struct_table t + |USING nested_struct_table src + |ON t.i = src.i + |$clause THEN + | UPDATE SET s.n_s = named_struct('dn_i', 1) + |""".stripMargin + if (coerceNestedTypes && withSchemaEvolution) { + val plan5 = parseAndResolve(mergeStmt) + // No null check for dn_i as it is explicitly set + assertNoNullCheckExists(plan5) + } else { + val e = intercept[AnalysisException] { + parseAndResolve(mergeStmt) + } + checkError( + exception = e, + condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", + parameters = Map("tableName" -> "``", "colName" -> "`s`.`n_s`.`dn_l`")) } - checkError( - exception = e, - condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA", - parameters = Map("tableName" -> "``", "colName" -> "`s`.`n_s`.`dn_l`")) } } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 18042bf73adf..8e5ee1644f9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -1633,10 +1633,10 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { if (starInUpdate) { assert(updateAssigns.size == 2) - assert(updateAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ts)) - assert(updateAssigns(0).value.asInstanceOf[AttributeReference].sameRef(ss)) - assert(updateAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ti)) - assert(updateAssigns(1).value.asInstanceOf[AttributeReference].sameRef(si)) + assert(updateAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ti)) + assert(updateAssigns(0).value.asInstanceOf[AttributeReference].sameRef(si)) + assert(updateAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ts)) + assert(updateAssigns(1).value.asInstanceOf[AttributeReference].sameRef(ss)) } else { assert(updateAssigns.size == 1) assert(updateAssigns.head.key.asInstanceOf[AttributeReference].sameRef(ts)) @@ -1656,10 +1656,10 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { if (starInInsert) { assert(insertAssigns.size == 2) - assert(insertAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ts)) - assert(insertAssigns(0).value.asInstanceOf[AttributeReference].sameRef(ss)) - assert(insertAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ti)) - assert(insertAssigns(1).value.asInstanceOf[AttributeReference].sameRef(si)) + assert(insertAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ti)) + assert(insertAssigns(0).value.asInstanceOf[AttributeReference].sameRef(si)) + assert(insertAssigns(1).key.asInstanceOf[AttributeReference].sameRef(ts)) + assert(insertAssigns(1).value.asInstanceOf[AttributeReference].sameRef(ss)) } else { assert(insertAssigns.size == 2) assert(insertAssigns(0).key.asInstanceOf[AttributeReference].sameRef(ti)) @@ -1720,8 +1720,40 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString) } + // star with schema evolution + val sqlStarSchemaEvolution = + s""" + |MERGE WITH SCHEMA EVOLUTION INTO $target AS target + |USING $source AS source + |ON target.i = source.i + |WHEN MATCHED AND (target.s='delete') THEN DELETE + |WHEN MATCHED AND (target.s='update') THEN UPDATE SET * + |WHEN NOT MATCHED AND (source.s='insert') THEN INSERT * + """.stripMargin + parseAndResolve(sqlStarSchemaEvolution) match { + case MergeIntoTable( + SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), + SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(source)), + mergeCondition, + Seq(DeleteAction(Some(EqualTo(dl: AttributeReference, StringLiteral("delete")))), + UpdateAction(Some(EqualTo(ul: AttributeReference, + StringLiteral("update"))), updateAssigns, _)), + Seq(InsertAction(Some(EqualTo(il: AttributeReference, StringLiteral("insert"))), + insertAssigns)), + Seq(), + withSchemaEvolution) => + checkMergeConditionResolution(target, source, mergeCondition) + checkMatchedClausesResolution(target, source, Some(dl), Some(ul), updateAssigns, + starInUpdate = true) + checkNotMatchedClausesResolution(target, source, Some(il), insertAssigns, + starInInsert = 
true) + assert(withSchemaEvolution === true) + + case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString) + } + // star - val sql2 = + val sqlStarWithoutSchemaEvolution = s""" |MERGE INTO $target AS target |USING $source AS source @@ -1730,7 +1762,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { |WHEN MATCHED AND (target.s='update') THEN UPDATE SET * |WHEN NOT MATCHED AND (source.s='insert') THEN INSERT * """.stripMargin - parseAndResolve(sql2) match { + parseAndResolve(sqlStarWithoutSchemaEvolution) match { case MergeIntoTable( SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(target)), SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(source)), @@ -2336,24 +2368,11 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { |USING testcat.tab2 |ON 1 = 1 |WHEN MATCHED THEN UPDATE SET *""".stripMargin - val parsed2 = parseAndResolve(sql2) - parsed2 match { - case MergeIntoTable( - AsDataSourceV2Relation(target), - AsDataSourceV2Relation(source), - EqualTo(IntegerLiteral(1), IntegerLiteral(1)), - Seq(UpdateAction(None, updateAssigns, _)), // Matched actions - Seq(), // Not matched actions - Seq(), // Not matched by source actions - withSchemaEvolution) => - val ti = target.output.find(_.name == "i").get - val si = source.output.find(_.name == "i").get - assert(updateAssigns.size == 1) - assert(updateAssigns.head.key.asInstanceOf[AttributeReference].sameRef(ti)) - assert(updateAssigns.head.value.asInstanceOf[AttributeReference].sameRef(si)) - assert(withSchemaEvolution === false) - case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString) - } + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql2)), + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`s`", "proposal" -> "`i`, `x`"), + context = ExpectedContext(fragment = sql2, start = 0, stop = 80)) // INSERT * with incompatible schema between source and target tables. 
val sql3 = @@ -2361,24 +2380,11 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { |USING testcat.tab2 |ON 1 = 1 |WHEN NOT MATCHED THEN INSERT *""".stripMargin - val parsed3 = parseAndResolve(sql3) - parsed3 match { - case MergeIntoTable( - AsDataSourceV2Relation(target), - AsDataSourceV2Relation(source), - EqualTo(IntegerLiteral(1), IntegerLiteral(1)), - Seq(), // Matched action - Seq(InsertAction(None, insertAssigns)), // Not matched actions - Seq(), // Not matched by source actions - withSchemaEvolution) => - val ti = target.output.find(_.name == "i").get - val si = source.output.find(_.name == "i").get - assert(insertAssigns.size == 1) - assert(insertAssigns.head.key.asInstanceOf[AttributeReference].sameRef(ti)) - assert(insertAssigns.head.value.asInstanceOf[AttributeReference].sameRef(si)) - assert(withSchemaEvolution === false) - case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString) - } + checkError( + exception = intercept[AnalysisException](parseAndResolve(sql3)), + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`s`", "proposal" -> "`i`, `x`"), + context = ExpectedContext(fragment = sql3, start = 0, stop = 80)) val sql4 = """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala new file mode 100644 index 000000000000..15682bcf68f1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresDialectSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +package org.apache.spark.sql.jdbc + +import java.sql.Connection + +import org.mockito.Mockito._ +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions + +class PostgresDialectSuite extends SparkFunSuite with MockitoSugar { + + private def createJDBCOptions(extraOptions: Map[String, String]): JDBCOptions = { + new JDBCOptions(Map( + "url" -> "jdbc:postgresql://localhost:5432/test", + "dbtable" -> "test_table" + ) ++ extraOptions) + } + + test("beforeFetch sets autoCommit=false with lowercase fetchsize") { + val conn = mock[Connection] + val dialect = PostgresDialect() + dialect.beforeFetch(conn, createJDBCOptions(Map("fetchsize" -> "100"))) + verify(conn).setAutoCommit(false) + } + + test("beforeFetch sets autoCommit=false with camelCase fetchSize") { + val conn = mock[Connection] + val dialect = PostgresDialect() + dialect.beforeFetch(conn, createJDBCOptions(Map("fetchSize" -> "100"))) + verify(conn).setAutoCommit(false) + } + + test("beforeFetch sets autoCommit=false with uppercase FETCHSIZE") { + val conn = mock[Connection] + val dialect = PostgresDialect() + dialect.beforeFetch(conn, createJDBCOptions(Map("FETCHSIZE" -> "100"))) + verify(conn).setAutoCommit(false) + } + + test("beforeFetch does not set autoCommit when fetchSize is 0") { + val conn = mock[Connection] + val dialect = PostgresDialect() + dialect.beforeFetch(conn, createJDBCOptions(Map("fetchsize" -> "0"))) + verify(conn, never()).setAutoCommit(false) + } +}
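Note on the new PostgresDialectSuite: the tests above verify that beforeFetch disables auto-commit only when a positive fetch size is configured, regardless of the casing of the "fetchsize" option, since the PostgreSQL JDBC driver only streams rows through a server-side cursor when auto-commit is off. The snippet below is a minimal illustrative sketch of that logic, not the actual Spark implementation; the object name FetchSizeHook and the plain Map parameter are hypothetical stand-ins for the dialect and its JDBCOptions.

import java.sql.Connection

// Hypothetical sketch of the behavior exercised by the tests; not the Spark source.
object FetchSizeHook {
  // `options` stands in for JDBCOptions; the key lookup ignores case, which is
  // what the lowercase/camelCase/uppercase tests check.
  def beforeFetch(connection: Connection, options: Map[String, String]): Unit = {
    val fetchSize = options
      .collectFirst { case (key, value) if key.equalsIgnoreCase("fetchsize") => value.toInt }
      .getOrElse(0)
    // Turning off auto-commit lets the PostgreSQL driver honor the fetch size
    // with a cursor instead of buffering the entire result set in memory.
    if (fetchSize > 0) {
      connection.setAutoCommit(false)
    }
  }
}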