diff --git a/R/check-cran.sh b/R/check-cran.sh
index 22c8f423cfd12..4123361f5e285 100755
--- a/R/check-cran.sh
+++ b/R/check-cran.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/R/create-docs.sh b/R/create-docs.sh
index 4867fd99e647c..3deaefd0659dc 100755
--- a/R/create-docs.sh
+++ b/R/create-docs.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/R/create-rd.sh b/R/create-rd.sh
index 72a932c175c95..1f0527458f2f0 100755
--- a/R/create-rd.sh
+++ b/R/create-rd.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/R/find-r.sh b/R/find-r.sh
index 690acc083af91..f1a5026911a7f 100755
--- a/R/find-r.sh
+++ b/R/find-r.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/R/install-dev.sh b/R/install-dev.sh
index 9fbc999f2e805..7df21c6c5ec9a 100755
--- a/R/install-dev.sh
+++ b/R/install-dev.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/R/install-source-package.sh b/R/install-source-package.sh
index 8de3569d1d482..0a2a5fe00f31f 100755
--- a/R/install-source-package.sh
+++ b/R/install-source-package.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/R/run-tests.sh b/R/run-tests.sh
index ca5b661127b53..90a60eda03871 100755
--- a/R/run-tests.sh
+++ b/R/run-tests.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/bin/spark-class b/bin/spark-class
index c1461a7712289..df6f6d8d1fcab 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -77,7 +77,8 @@ set +o posix
 CMD=()
 DELIM=$'\n'
 CMD_START_FLAG="false"
-while IFS= read -d "$DELIM" -r ARG; do
+while IFS= read -d "$DELIM" -r _ARG; do
+  ARG=${_ARG//$'\r'} # if windows, args can have trailing CR
   if [ "$CMD_START_FLAG" == "true" ]; then
     CMD+=("$ARG")
   else
diff --git a/bin/spark-class.cmd b/bin/spark-class.cmd
index b22536ab6f458..cc916ff8f0c16 100644
--- a/bin/spark-class.cmd
+++ b/bin/spark-class.cmd
@@ -22,4 +22,7 @@ rem the environment, it just launches a new cmd to do the real work.
 rem The outermost quotes are used to prevent Windows command line parse error
 rem when there are some quotes in parameters, see SPARK-21877.
+ +rem SHELL must be unset in non-SHELL environment +set SHELL= cmd /V /E /C ""%~dp0spark-class2.cmd" %*" diff --git a/bin/sparkR b/bin/sparkR index 29ab10df8ab6d..8ecc755839fe3 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/binder/postBuild b/binder/postBuild index 733eafe175ef0..34ead09f692f9 100644 --- a/binder/postBuild +++ b/binder/postBuild @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/build/spark-build-info b/build/spark-build-info index eb0e3d730e23e..26157e8cf8cb1 100755 --- a/build/spark-build-info +++ b/build/spark-build-info @@ -24,7 +24,7 @@ RESOURCE_DIR="$1" mkdir -p "$RESOURCE_DIR" -SPARK_BUILD_INFO="${RESOURCE_DIR}"/spark-version-info.properties +SPARK_BUILD_INFO="${RESOURCE_DIR%/}"/spark-version-info.properties echo_build_properties() { echo version=$1 diff --git a/connector/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java b/connector/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java index df5d6d73f2f14..b2a57060fc2d9 100644 --- a/connector/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java +++ b/connector/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java @@ -25,6 +25,7 @@ import org.apache.avro.file.CodecFactory; import org.apache.avro.file.DataFileWriter; import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; import org.apache.avro.generic.GenericRecord; import org.apache.avro.mapred.AvroKey; import org.apache.avro.mapreduce.AvroKeyOutputFormat; @@ -53,7 +54,7 @@ protected RecordWriter, NullWritable> create( CodecFactory compressionCodec, OutputStream outputStream, int syncInterval) throws IOException { - return new SparkAvroKeyRecordWriter( + return new SparkAvroKeyRecordWriter<>( writerSchema, dataModel, compressionCodec, outputStream, syncInterval, metadata); } } @@ -72,7 +73,7 @@ class SparkAvroKeyRecordWriter extends RecordWriter, NullWritable> OutputStream outputStream, int syncInterval, Map metadata) throws IOException { - this.mAvroFileWriter = new DataFileWriter(dataModel.createDatumWriter(writerSchema)); + this.mAvroFileWriter = new DataFileWriter<>(new GenericDatumWriter<>(writerSchema, dataModel)); for (Map.Entry entry : metadata.entrySet()) { this.mAvroFileWriter.setMeta(entry.getKey(), entry.getValue()); } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index e065bbce27094..95001bb81508c 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode} import org.apache.spark.sql.internal.SQLConf @@ -37,6 +37,8 @@ private[sql] class AvroOptions( @transient val conf: Configuration) extends FileSourceOptions(parameters) with Logging { + import AvroOptions._ + def this(parameters: Map[String, String], conf: Configuration) = { 
this(CaseInsensitiveMap(parameters), conf) } @@ -54,8 +56,8 @@ private[sql] class AvroOptions( * instead of "string" type in the default converted schema. */ val schema: Option[Schema] = { - parameters.get("avroSchema").map(new Schema.Parser().setValidateDefaults(false).parse).orElse({ - val avroUrlSchema = parameters.get("avroSchemaUrl").map(url => { + parameters.get(AVRO_SCHEMA).map(new Schema.Parser().setValidateDefaults(false).parse).orElse({ + val avroUrlSchema = parameters.get(AVRO_SCHEMA_URL).map(url => { log.debug("loading avro schema from url: " + url) val fs = FileSystem.get(new URI(url), conf) val in = fs.open(new Path(url)) @@ -75,20 +77,20 @@ private[sql] class AvroOptions( * whose field names do not match. Defaults to false. */ val positionalFieldMatching: Boolean = - parameters.get("positionalFieldMatching").exists(_.toBoolean) + parameters.get(POSITIONAL_FIELD_MATCHING).exists(_.toBoolean) /** * Top level record name in write result, which is required in Avro spec. * See https://avro.apache.org/docs/1.11.1/specification/#schema-record . * Default value is "topLevelRecord" */ - val recordName: String = parameters.getOrElse("recordName", "topLevelRecord") + val recordName: String = parameters.getOrElse(RECORD_NAME, "topLevelRecord") /** * Record namespace in write result. Default value is "". * See Avro spec for details: https://avro.apache.org/docs/1.11.1/specification/#schema-record . */ - val recordNamespace: String = parameters.getOrElse("recordNamespace", "") + val recordNamespace: String = parameters.getOrElse(RECORD_NAMESPACE, "") /** * The `ignoreExtension` option controls ignoring of files without `.avro` extensions in read. @@ -104,7 +106,7 @@ private[sql] class AvroOptions( ignoreFilesWithoutExtensionByDefault) parameters - .get(AvroOptions.ignoreExtensionKey) + .get(IGNORE_EXTENSION) .map(_.toBoolean) .getOrElse(!ignoreFilesWithoutExtension) } @@ -116,21 +118,21 @@ private[sql] class AvroOptions( * taken into account. If the former one is not set too, the `snappy` codec is used by default. */ val compression: String = { - parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec) + parameters.get(COMPRESSION).getOrElse(SQLConf.get.avroCompressionCodec) } val parseMode: ParseMode = - parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode) + parameters.get(MODE).map(ParseMode.fromString).getOrElse(FailFastMode) /** * The rebasing mode for the DATE and TIMESTAMP_MICROS, TIMESTAMP_MILLIS values in reads. 
*/ val datetimeRebaseModeInRead: String = parameters - .get(AvroOptions.DATETIME_REBASE_MODE) + .get(DATETIME_REBASE_MODE) .getOrElse(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_READ)) } -private[sql] object AvroOptions { +private[sql] object AvroOptions extends DataSourceOptions { def apply(parameters: Map[String, String]): AvroOptions = { val hadoopConf = SparkSession .getActiveSession @@ -139,11 +141,17 @@ private[sql] object AvroOptions { new AvroOptions(CaseInsensitiveMap(parameters), hadoopConf) } - val ignoreExtensionKey = "ignoreExtension" - + val IGNORE_EXTENSION = newOption("ignoreExtension") + val MODE = newOption("mode") + val RECORD_NAME = newOption("recordName") + val COMPRESSION = newOption("compression") + val AVRO_SCHEMA = newOption("avroSchema") + val AVRO_SCHEMA_URL = newOption("avroSchemaUrl") + val RECORD_NAMESPACE = newOption("recordNamespace") + val POSITIONAL_FIELD_MATCHING = newOption("positionalFieldMatching") // The option controls rebasing of the DATE and TIMESTAMP values between // Julian and Proleptic Gregorian calendars. It impacts on the behaviour of the Avro // datasource similarly to the SQL config `spark.sql.avro.datetimeRebaseModeInRead`, // and can be set to the same values: `EXCEPTION`, `LEGACY` or `CORRECTED`. - val DATETIME_REBASE_MODE = "datetimeRebaseMode" + val DATETIME_REBASE_MODE = newOption("datetimeRebaseMode") } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 56d177da14369..45fa7450e4522 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.mapreduce.Job import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.avro.AvroOptions.ignoreExtensionKey +import org.apache.spark.sql.avro.AvroOptions.IGNORE_EXTENSION import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.OutputWriterFactory @@ -50,8 +50,8 @@ private[sql] object AvroUtils extends Logging { val conf = spark.sessionState.newHadoopConfWithOptions(options) val parsedOptions = new AvroOptions(options, conf) - if (parsedOptions.parameters.contains(ignoreExtensionKey)) { - logWarning(s"Option $ignoreExtensionKey is deprecated. Please use the " + + if (parsedOptions.parameters.contains(IGNORE_EXTENSION)) { + logWarning(s"Option $IGNORE_EXTENSION is deprecated. Please use the " + "general data source option pathGlobFilter for filtering file names.") } // User can specify an optional avro json schema. 
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index cf4a490b90273..f8d0ac08d0073 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -1075,7 +1075,7 @@ abstract class AvroSuite .save(s"$tempDir/${UUID.randomUUID()}") }.getMessage assert(message.contains("Caused by: java.lang.NullPointerException: ")) - assert(message.contains("null in string in field Name")) + assert(message.contains("null value for (non-nullable) string at test_schema.Name")) } } @@ -1804,13 +1804,13 @@ abstract class AvroSuite spark .read .format("avro") - .option(AvroOptions.ignoreExtensionKey, false) + .option(AvroOptions.IGNORE_EXTENSION, false) .load(dir.getCanonicalPath) .count() } val deprecatedEvents = logAppender.loggingEvents .filter(_.getMessage.getFormattedMessage.contains( - s"Option ${AvroOptions.ignoreExtensionKey} is deprecated")) + s"Option ${AvroOptions.IGNORE_EXTENSION} is deprecated")) assert(deprecatedEvents.size === 1) } } @@ -2272,6 +2272,20 @@ abstract class AvroSuite checkAnswer(df2, df.collect().toSeq) } } + + test("SPARK-40667: validate Avro Options") { + assert(AvroOptions.getAllOptions.size == 9) + // Please add validation on any new Avro options here + assert(AvroOptions.isValidOption("ignoreExtension")) + assert(AvroOptions.isValidOption("mode")) + assert(AvroOptions.isValidOption("recordName")) + assert(AvroOptions.isValidOption("compression")) + assert(AvroOptions.isValidOption("avroSchema")) + assert(AvroOptions.isValidOption("avroSchemaUrl")) + assert(AvroOptions.isValidOption("recordNamespace")) + assert(AvroOptions.isValidOption("positionalFieldMatching")) + assert(AvroOptions.isValidOption("datetimeRebaseMode")) + } } class AvroV1Suite extends AvroSuite { diff --git a/connector/connect/dev/generate_protos.sh b/connector/connect/dev/generate_protos.sh index 204beda6aa971..9457e7b33edd5 100755 --- a/connector/connect/dev/generate_protos.sh +++ b/connector/connect/dev/generate_protos.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with diff --git a/connector/connect/pom.xml b/connector/connect/pom.xml index 62786e6d5b6a9..66c2f07cfb6ba 100644 --- a/connector/connect/pom.xml +++ b/connector/connect/pom.xml @@ -268,11 +268,13 @@ as assembly build. 
--> com.google.android:annotations - com.google.api.grpc:proto-google-common-proto + com.google.api.grpc:proto-google-common-protos io.perfmark:perfmark-api org.codehaus.mojo:animal-sniffer-annotations com.google.errorprone:error_prone_annotations com.google.j2objc:j2objc-annotations + org.checkerframework:checker-qual + com.google.code.gson:gson @@ -303,28 +305,66 @@ - com.google.android - ${spark.shade.packageName}.connect.android + android.annotation + ${spark.shade.packageName}.connect.android_annotation - com.google.api.grpc - ${spark.shade.packageName}.connect.api + io.perfmark + ${spark.shade.packageName}.connect.io_perfmark - io.perfmark - ${spark.shade.packageName}.connect.perfmark + org.codehaus.mojo.animal_sniffer + ${spark.shade.packageName}.connect.animal_sniffer + + + com.google.j2objc.annotations + ${spark.shade.packageName}.connect.j2objc_annotations + + + com.google.errorprone.annotations + ${spark.shade.packageName}.connect.errorprone_annotations + + + org.checkerframework + ${spark.shade.packageName}.connect.checkerframework + + + com.google.gson + ${spark.shade.packageName}.connect.gson + + + + + com.google.api + ${spark.shade.packageName}.connect.google_protos.api + + + com.google.cloud + ${spark.shade.packageName}.connect.google_protos.cloud + + + com.google.geo + ${spark.shade.packageName}.connect.google_protos.geo + + + com.google.logging + ${spark.shade.packageName}.connect.google_protos.logging - org.codehaus.mojo - ${spark.shade.packageName}.connect.mojo + com.google.longrunning + ${spark.shade.packageName}.connect.google_protos.longrunning - com.google.errorprone - ${spark.shade.packageName}.connect.errorprone + com.google.rpc + ${spark.shade.packageName}.connect.google_protos.rpc - com.com.google.j2objc - ${spark.shade.packageName}.connect.j2objc + com.google.type + ${spark.shade.packageName}.connect.google_protos.type diff --git a/connector/connect/src/main/protobuf/spark/connect/commands.proto b/connector/connect/src/main/protobuf/spark/connect/commands.proto index 425857b842e56..0a83e4543f5ec 100644 --- a/connector/connect/src/main/protobuf/spark/connect/commands.proto +++ b/connector/connect/src/main/protobuf/spark/connect/commands.proto @@ -44,8 +44,8 @@ message CreateScalarFunction { repeated string parts = 1; FunctionLanguage language = 2; bool temporary = 3; - repeated Type argument_types = 4; - Type return_type = 5; + repeated DataType argument_types = 4; + DataType return_type = 5; // How the function body is defined: oneof function_definition { diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto index 9b3029a32b0a7..4b5a81d2a568c 100644 --- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto @@ -35,6 +35,7 @@ message Expression { UnresolvedFunction unresolved_function = 3; ExpressionString expression_string = 4; UnresolvedStar unresolved_star = 5; + Alias alias = 6; } message Literal { @@ -65,10 +66,10 @@ message Expression { // Timestamp in units of microseconds since the UNIX epoch. int64 timestamp_tz = 27; bytes uuid = 28; - Type null = 29; // a typed null literal + DataType null = 29; // a typed null literal List list = 30; - Type.List empty_list = 31; - Type.Map empty_map = 32; + DataType.List empty_list = 31; + DataType.Map empty_map = 32; UserDefined user_defined = 33; } @@ -164,5 +165,11 @@ message Expression { // by the analyzer. 
message QualifiedAttribute { string name = 1; + DataType type = 2; + } + + message Alias { + Expression expr = 1; + string name = 2; } } diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto index 087a4dca2f70a..30f36fa6ceb52 100644 --- a/connector/connect/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto @@ -94,16 +94,17 @@ message Filter { message Join { Relation left = 1; Relation right = 2; - Expression on = 3; - JoinType how = 4; + Expression join_condition = 3; + JoinType join_type = 4; enum JoinType { JOIN_TYPE_UNSPECIFIED = 0; JOIN_TYPE_INNER = 1; - JOIN_TYPE_OUTER = 2; + JOIN_TYPE_FULL_OUTER = 2; JOIN_TYPE_LEFT_OUTER = 3; JOIN_TYPE_RIGHT_OUTER = 4; - JOIN_TYPE_ANTI = 5; + JOIN_TYPE_LEFT_ANTI = 5; + JOIN_TYPE_LEFT_SEMI = 6; } } @@ -129,22 +130,8 @@ message Fetch { // Relation of type [[Aggregate]]. message Aggregate { Relation input = 1; - - // Grouping sets are used in rollups - repeated GroupingSet grouping_sets = 2; - - // Measures - repeated Measure measures = 3; - - message GroupingSet { - repeated Expression aggregate_expressions = 1; - } - - message Measure { - AggregateFunction function = 1; - // Conditional filter for SUM(x FILTER WHERE x < 10) - Expression filter = 2; - } + repeated Expression grouping_expressions = 2; + repeated AggregateFunction result_expressions = 3; message AggregateFunction { string name = 1; diff --git a/connector/connect/src/main/protobuf/spark/connect/types.proto b/connector/connect/src/main/protobuf/spark/connect/types.proto index c46afa2afc651..98b0c48b1e016 100644 --- a/connector/connect/src/main/protobuf/spark/connect/types.proto +++ b/connector/connect/src/main/protobuf/spark/connect/types.proto @@ -22,9 +22,9 @@ package spark.connect; option java_multiple_files = true; option java_package = "org.apache.spark.connect.proto"; -// This message describes the logical [[Type]] of something. It does not carry the value +// This message describes the logical [[DataType]] of something. It does not carry the value // itself but only describes it. 
-message Type { +message DataType { oneof kind { Boolean bool = 1; I8 i8 = 2; @@ -168,20 +168,20 @@ message Type { } message Struct { - repeated Type types = 1; + repeated DataType types = 1; uint32 type_variation_reference = 2; Nullability nullability = 3; } message List { - Type type = 1; + DataType DataType = 1; uint32 type_variation_reference = 2; Nullability nullability = 3; } message Map { - Type key = 1; - Type value = 2; + DataType key = 1; + DataType value = 2; uint32 type_variation_reference = 3; Nullability nullability = 4; } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index d54d5b404410e..80d6e77c9fc45 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.connect import scala.collection.JavaConverters._ import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.Join.JoinType import org.apache.spark.sql.catalyst.parser.CatalystSqlParser /** @@ -39,6 +40,11 @@ package object dsl { .build()) .build() } + + implicit class DslExpression(val expr: proto.Expression) { + def as(alias: String): proto.Expression = proto.Expression.newBuilder().setAlias( + proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)).build() + } } object plans { // scalastyle:ignore @@ -51,6 +57,34 @@ package object dsl { .build() ).build() } + + def join( + otherPlan: proto.Relation, + joinType: JoinType = JoinType.JOIN_TYPE_INNER, + condition: Option[proto.Expression] = None): proto.Relation = { + val relation = proto.Relation.newBuilder() + val join = proto.Join.newBuilder() + join.setLeft(logicalPlan) + .setRight(otherPlan) + .setJoinType(joinType) + if (condition.isDefined) { + join.setJoinCondition(condition.get) + } + relation.setJoin(join).build() + } + + def groupBy( + groupingExprs: proto.Expression*)(aggregateExprs: proto.Expression*): proto.Relation = { + val agg = proto.Aggregate.newBuilder() + agg.setInput(logicalPlan) + + for (groupingExpr <- groupingExprs) { + agg.addGroupingExpressions(groupingExpr) + } + // TODO: support aggregateExprs, which is blocked by supporting any builtin function + // resolution only by name in the analyzer. + proto.Relation.newBuilder().setAggregate(agg.build()).build() + } } } } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala new file mode 100644 index 0000000000000..b31855bfca993 --- /dev/null +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.planner + +import org.apache.spark.connect.proto +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} + +/** + * This object offers methods to convert to/from connect proto to catalyst types. + */ +object DataTypeProtoConverter { + def toCatalystType(t: proto.DataType): DataType = { + t.getKindCase match { + case proto.DataType.KindCase.I32 => IntegerType + case proto.DataType.KindCase.STRING => StringType + case _ => + throw InvalidPlanInput(s"Does not support convert ${t.getKindCase} to catalyst types.") + } + } + + def toConnectProtoType(t: DataType): proto.DataType = { + t match { + case IntegerType => + proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build() + case StringType => + proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build() + case _ => + throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.") + } + } +} diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index fa9dd18d3bfa5..5ad95a6b516ab 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -22,10 +22,10 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Since, Unstable} import org.apache.spark.connect.proto import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.{expressions, plans} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.types._ @@ -77,8 +77,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { } private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = { - // TODO: use data type from the proto. 
- AttributeReference(exp.getName, IntegerType)() + AttributeReference(exp.getName, DataTypeProtoConverter.toCatalystType(exp.getType))() } private def transformReadRel( @@ -133,6 +132,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { transformUnresolvedExpression(exp) case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION => transformScalarFunction(exp.getUnresolvedFunction) + case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias) case _ => throw InvalidPlanInput() } } @@ -209,6 +209,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { } } + private def transformAlias(alias: proto.Expression.Alias): Expression = { + Alias(transformExpression(alias.getExpr), alias.getName)() + } + private def transformUnion(u: proto.Union): LogicalPlan = { assert(u.getInputsCount == 2, "Union must have 2 inputs") val plan = logical.Union(transformRelation(u.getInputs(0)), transformRelation(u.getInputs(1))) @@ -223,15 +227,30 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { private def transformJoin(rel: proto.Join): LogicalPlan = { assert(rel.hasLeft && rel.hasRight, "Both join sides must be present") + val joinCondition = + if (rel.hasJoinCondition) Some(transformExpression(rel.getJoinCondition)) else None + logical.Join( left = transformRelation(rel.getLeft), right = transformRelation(rel.getRight), - // TODO(SPARK-40534) Support additional join types and configuration. - joinType = plans.Inner, - condition = Some(transformExpression(rel.getOn)), + joinType = transformJoinType( + if (rel.getJoinType != null) rel.getJoinType else proto.Join.JoinType.JOIN_TYPE_INNER), + condition = joinCondition, hint = logical.JoinHint.NONE) } + private def transformJoinType(t: proto.Join.JoinType): JoinType = { + t match { + case proto.Join.JoinType.JOIN_TYPE_INNER => Inner + case proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI => LeftAnti + case proto.Join.JoinType.JOIN_TYPE_FULL_OUTER => FullOuter + case proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER => LeftOuter + case proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER => RightOuter + case proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI => LeftSemi + case _ => throw InvalidPlanInput(s"Join type ${t} is not supported") + } + } + private def transformSort(rel: proto.Sort): LogicalPlan = { assert(rel.getSortFieldsCount > 0, "'sort_fields' must be present and contain elements.") logical.Sort( @@ -256,11 +275,9 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { private def transformAggregate(rel: proto.Aggregate): LogicalPlan = { assert(rel.hasInput) - assert(rel.getGroupingSetsCount == 1, "Only one grouping set is supported") - val groupingSet = rel.getGroupingSetsList.asScala.take(1) - val ge = groupingSet - .flatMap(f => f.getAggregateExpressionsList.asScala) + val groupingExprs = + rel.getGroupingExpressionsList.asScala .map(transformExpression) .map { case x @ UnresolvedAttribute(_) => x @@ -269,18 +286,18 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { logical.Aggregate( child = transformRelation(rel.getInput), - groupingExpressions = ge.toSeq, + groupingExpressions = groupingExprs.toSeq, aggregateExpressions = - (rel.getMeasuresList.asScala.map(transformAggregateExpression) ++ ge).toSeq) + rel.getResultExpressionsList.asScala.map(transformAggregateExpression).toSeq) } private def transformAggregateExpression( - exp: proto.Aggregate.Measure): expressions.NamedExpression = { - val fun = exp.getFunction.getName + exp: 
proto.Aggregate.AggregateFunction): expressions.NamedExpression = { + val fun = exp.getName UnresolvedAlias( UnresolvedFunction( name = fun, - arguments = exp.getFunction.getArgumentsList.asScala.map(transformExpression).toSeq, + arguments = exp.getArgumentsList.asScala.map(transformExpression).toSeq, isDistinct = false)) } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index e1a658fb57b27..10e17f121f0e5 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -45,11 +45,6 @@ trait SparkConnectPlanTest { .build() } -trait SparkConnectSessionTest { - protected var spark: SparkSession - -} - /** * This is a rudimentary test class for SparkConnect. The main goal of these basic tests is to * ensure that the transformation from Proto to LogicalPlan works and that the right nodes are @@ -161,7 +156,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { proto.Relation.newBuilder.setJoin(proto.Join.newBuilder.setLeft(readRel)).build() intercept[AssertionError](transform(incompleteJoin)) - // Cartesian Product not supported. + // Join type JOIN_TYPE_UNSPECIFIED is not supported. intercept[InvalidPlanInput] { val simpleJoin = proto.Relation.newBuilder .setJoin(proto.Join.newBuilder.setLeft(readRel).setRight(readRel)) @@ -185,7 +180,12 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { val simpleJoin = proto.Relation.newBuilder .setJoin( - proto.Join.newBuilder.setLeft(readRel).setRight(readRel).setOn(joinCondition).build()) + proto.Join.newBuilder + .setLeft(readRel) + .setRight(readRel) + .setJoinType(proto.Join.JoinType.JOIN_TYPE_INNER) + .setJoinCondition(joinCondition) + .build()) .build() val res = transform(simpleJoin) @@ -217,16 +217,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { val agg = proto.Aggregate.newBuilder .setInput(readRel) - .addAllMeasures( - Seq( - proto.Aggregate.Measure.newBuilder - .setFunction(proto.Aggregate.AggregateFunction.newBuilder - .setName("sum") - .addArguments(unresolvedAttribute)) - .build()).asJava) - .addGroupingSets(proto.Aggregate.GroupingSet.newBuilder - .addAggregateExpressions(unresolvedAttribute) - .build()) + .addResultExpressions( + proto.Aggregate.AggregateFunction.newBuilder + .setName("sum") + .addArguments(unresolvedAttribute)) + .addGroupingExpressions(unresolvedAttribute) .build() val res = transform(proto.Relation.newBuilder.setAggregate(agg).build()) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 6eab50a0a2bdc..510b54cd25084 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.connect.planner import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.Join.JoinType import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import 
org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation /** @@ -30,9 +31,13 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation */ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { - lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int)) + lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int, $"name".string)) - lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int) + lazy val connectTestRelation2 = createLocalRelationProto(Seq($"key".int, $"value".int)) + + lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int, $"name".string) + + lazy val sparkTestRelation2: LocalRelation = LocalRelation($"key".int, $"value".int) test("Basic select") { val connectPlan = { @@ -46,12 +51,62 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { comparePlans(connectPlan.analyze, sparkPlan.analyze, false) } + test("Basic joins with different join types") { + val connectPlan = { + import org.apache.spark.sql.connect.dsl.plans._ + transform(connectTestRelation.join(connectTestRelation2)) + } + val sparkPlan = sparkTestRelation.join(sparkTestRelation2) + comparePlans(connectPlan.analyze, sparkPlan.analyze, false) + + val connectPlan2 = { + import org.apache.spark.sql.connect.dsl.plans._ + transform(connectTestRelation.join(connectTestRelation2, condition = None)) + } + val sparkPlan2 = sparkTestRelation.join(sparkTestRelation2, condition = None) + comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false) + for ((t, y) <- Seq( + (JoinType.JOIN_TYPE_LEFT_OUTER, LeftOuter), + (JoinType.JOIN_TYPE_RIGHT_OUTER, RightOuter), + (JoinType.JOIN_TYPE_FULL_OUTER, FullOuter), + (JoinType.JOIN_TYPE_LEFT_ANTI, LeftAnti), + (JoinType.JOIN_TYPE_LEFT_SEMI, LeftSemi), + (JoinType.JOIN_TYPE_INNER, Inner))) { + val connectPlan3 = { + import org.apache.spark.sql.connect.dsl.plans._ + transform(connectTestRelation.join(connectTestRelation2, t)) + } + val sparkPlan3 = sparkTestRelation.join(sparkTestRelation2, y) + comparePlans(connectPlan3.analyze, sparkPlan3.analyze, false) + } + } + + test("column alias") { + val connectPlan = { + import org.apache.spark.sql.connect.dsl.expressions._ + import org.apache.spark.sql.connect.dsl.plans._ + transform(connectTestRelation.select("id".protoAttr.as("id2"))) + } + val sparkPlan = sparkTestRelation.select($"id".as("id2")) + } + + test("Aggregate with more than 1 grouping expressions") { + val connectPlan = { + import org.apache.spark.sql.connect.dsl.expressions._ + import org.apache.spark.sql.connect.dsl.plans._ + transform(connectTestRelation.groupBy("id".protoAttr, "name".protoAttr)()) + } + val sparkPlan = sparkTestRelation.groupBy($"id", $"name")() + comparePlans(connectPlan.analyze, sparkPlan.analyze, false) + } + private def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = { val localRelationBuilder = proto.LocalRelation.newBuilder() - // TODO: set data types for each local relation attribute one proto supports data type. 
for (attr <- attrs) { localRelationBuilder.addAttributes( - proto.Expression.QualifiedAttribute.newBuilder().setName(attr.name).build() + proto.Expression.QualifiedAttribute.newBuilder() + .setName(attr.name) + .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType)) ) } proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() diff --git a/connector/docker/build b/connector/docker/build index 253a2fc8dd8e7..de83c7d7611dc 100755 --- a/connector/docker/build +++ b/connector/docker/build @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/connector/docker/spark-test/build b/connector/docker/spark-test/build index 6f9e19743370b..55dff4754b000 100755 --- a/connector/docker/spark-test/build +++ b/connector/docker/spark-test/build @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/connector/docker/spark-test/master/default_cmd b/connector/docker/spark-test/master/default_cmd index 96a36cd0bb682..6865ca41b894f 100755 --- a/connector/docker/spark-test/master/default_cmd +++ b/connector/docker/spark-test/master/default_cmd @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/connector/docker/spark-test/worker/default_cmd b/connector/docker/spark-test/worker/default_cmd index 2401f5565aa0b..1f2aac95ed699 100755 --- a/connector/docker/spark-test/worker/default_cmd +++ b/connector/docker/spark-test/worker/default_cmd @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/core/src/main/java/org/apache/spark/SparkThrowable.java b/core/src/main/java/org/apache/spark/SparkThrowable.java index 7fb693d9c5569..e1235b2982ba0 100644 --- a/core/src/main/java/org/apache/spark/SparkThrowable.java +++ b/core/src/main/java/org/apache/spark/SparkThrowable.java @@ -51,7 +51,7 @@ default boolean isInternalError() { } default Map getMessageParameters() { - return new HashMap(); + return new HashMap<>(); } default QueryContext[] getQueryContext() { return new QueryContext[0]; } diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 3e1655d80e4e7..197ab6aa1a7b3 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -86,6 +86,11 @@ "Cannot resolve due to data type mismatch:" ], "subClass" : { + "BINARY_ARRAY_DIFF_TYPES" : { + "message" : [ + "Input to function should have been two with same element type, but it's [, ]." + ] + }, "BINARY_OP_DIFF_TYPES" : { "message" : [ "the left and right operands of the binary operator have incompatible types ( and )." @@ -118,6 +123,11 @@ "To convert values from to , you can use the functions instead." ] }, + "DATA_DIFF_TYPES" : { + "message" : [ + "Input to should all be the same type, but it's ." + ] + }, "FRAME_LESS_OFFSET_WITHOUT_FOLDABLE" : { "message" : [ "Offset expression must be a literal." @@ -133,6 +143,31 @@ "Input schema must be a struct, an array or a map." ] }, + "INVALID_MAP_KEY_TYPE" : { + "message" : [ + "The key of map cannot be/contain ." + ] + }, + "INVALID_ORDERING_TYPE" : { + "message" : [ + "The does not support ordering on type ." + ] + }, + "MAP_CONCAT_DIFF_TYPES" : { + "message" : [ + "The should all be of type map, but it's ." 
+ ] + }, + "MAP_CONTAINS_KEY_DIFF_TYPES" : { + "message" : [ + "Input to should have been followed by a value with same key type, but it's [, ]." + ] + }, + "MAP_FROM_ENTRIES_WRONG_TYPE" : { + "message" : [ + "The accepts only arrays of pair structs, but is of ." + ] + }, "NON_FOLDABLE_INPUT" : { "message" : [ "the input should be a foldable string expression and not null; however, got ." @@ -143,6 +178,11 @@ "all arguments must be strings." ] }, + "NULL_TYPE" : { + "message" : [ + "Null typed values cannot be used as arguments of ." + ] + }, "RANGE_FRAME_INVALID_TYPE" : { "message" : [ "The data type used in the order specification does not match the data type which is used in the range frame." @@ -3393,5 +3433,130 @@ "message" : [ "Write is not supported for binary file data source" ] + }, + "_LEGACY_ERROR_TEMP_2076" : { + "message" : [ + "The length of is , which exceeds the max length allowed: ." + ] + }, + "_LEGACY_ERROR_TEMP_2077" : { + "message" : [ + "Unsupported field name: " + ] + }, + "_LEGACY_ERROR_TEMP_2078" : { + "message" : [ + "Both '' and '' can not be specified at the same time." + ] + }, + "_LEGACY_ERROR_TEMP_2079" : { + "message" : [ + "Option '' or '' is required." + ] + }, + "_LEGACY_ERROR_TEMP_2080" : { + "message" : [ + "Option `` can not be empty." + ] + }, + "_LEGACY_ERROR_TEMP_2081" : { + "message" : [ + "Invalid value `` for parameter ``. This can be `NONE`, `READ_UNCOMMITTED`, `READ_COMMITTED`, `REPEATABLE_READ` or `SERIALIZABLE`." + ] + }, + "_LEGACY_ERROR_TEMP_2082" : { + "message" : [ + "Can't get JDBC type for " + ] + }, + "_LEGACY_ERROR_TEMP_2083" : { + "message" : [ + "Unsupported type " + ] + }, + "_LEGACY_ERROR_TEMP_2084" : { + "message" : [ + "Unsupported array element type based on binary" + ] + }, + "_LEGACY_ERROR_TEMP_2085" : { + "message" : [ + "Nested arrays unsupported" + ] + }, + "_LEGACY_ERROR_TEMP_2086" : { + "message" : [ + "Can't translate non-null value for field " + ] + }, + "_LEGACY_ERROR_TEMP_2087" : { + "message" : [ + "Invalid value `` for parameter `` in table writing via JDBC. The minimum value is 1." + ] + }, + "_LEGACY_ERROR_TEMP_2088" : { + "message" : [ + " is not supported yet." + ] + }, + "_LEGACY_ERROR_TEMP_2089" : { + "message" : [ + "DataType: " + ] + }, + "_LEGACY_ERROR_TEMP_2090" : { + "message" : [ + "The input filter of should be fully convertible." + ] + }, + "_LEGACY_ERROR_TEMP_2091" : { + "message" : [ + "Could not read footer for file: " + ] + }, + "_LEGACY_ERROR_TEMP_2092" : { + "message" : [ + "Could not read footer for file: " + ] + }, + "_LEGACY_ERROR_TEMP_2093" : { + "message" : [ + "Found duplicate field(s) \"\": in case-insensitive mode" + ] + }, + "_LEGACY_ERROR_TEMP_2094" : { + "message" : [ + "Found duplicate field(s) \"\": in id mapping mode" + ] + }, + "_LEGACY_ERROR_TEMP_2095" : { + "message" : [ + "Failed to merge incompatible schemas and " + ] + }, + "_LEGACY_ERROR_TEMP_2096" : { + "message" : [ + " is not supported temporarily." + ] + }, + "_LEGACY_ERROR_TEMP_2097" : { + "message" : [ + "Could not execute broadcast in secs. You can increase the timeout for broadcasts via or disable broadcast join by setting to -1" + ] + }, + "_LEGACY_ERROR_TEMP_2098" : { + "message" : [ + "Could not compare cost with " + ] + }, + "_LEGACY_ERROR_TEMP_2099" : { + "message" : [ + "Unsupported data type:
" + ] + }, + "_LEGACY_ERROR_TEMP_2100" : { + "message" : [ + "not support type: " + ] } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c55298513824c..7ad53b8f9f877 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -2949,7 +2949,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case ExecutorLost(execId, reason) => val workerHost = reason match { case ExecutorProcessLost(_, workerHost, _) => workerHost - case ExecutorDecommission(workerHost) => workerHost + case ExecutorDecommission(workerHost, _) => workerHost case _ => None } dagScheduler.handleExecutorLost(execId, workerHost) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index f333c01bb890d..fb6a62551fa44 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -77,6 +77,13 @@ case class ExecutorProcessLost( * If you update this code make sure to re-run the K8s integration tests. * * @param workerHost it is defined when the worker is decommissioned too + * @param reason detailed decommission message */ -private [spark] case class ExecutorDecommission(workerHost: Option[String] = None) - extends ExecutorLossReason("Executor decommission.") +private [spark] case class ExecutorDecommission( + workerHost: Option[String] = None, + reason: String = "") + extends ExecutorLossReason(ExecutorDecommission.msgPrefix + reason) + +private[spark] object ExecutorDecommission { + val msgPrefix = "Executor decommission: " +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 1d157f51fe678..943d1e53df44b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -1071,7 +1071,7 @@ private[spark] class TaskSetManager( for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { val exitCausedByApp: Boolean = reason match { case ExecutorExited(_, false, _) => false - case ExecutorKilled | ExecutorDecommission(_) => false + case ExecutorKilled | ExecutorDecommission(_, _) => false case ExecutorProcessLost(_, _, false) => false // If the task is launching, this indicates that Driver has sent LaunchTask to Executor, // but Executor has not sent StatusUpdate(TaskState.RUNNING) to Driver. Hence, we assume diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index e37abd76296c9..225dd1d75bfaf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -99,8 +99,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors that have been lost, but for which we don't yet know the real exit reason. protected val executorsPendingLossReason = new HashSet[String] - // Executors which are being decommissioned. Maps from executorId to workerHost. 
- protected val executorsPendingDecommission = new HashMap[String, Option[String]] + // Executors which are being decommissioned. Maps from executorId to ExecutorDecommissionInfo. + protected val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo] // A map of ResourceProfile id to map of hostname with its possible task number running on it @GuardedBy("CoarseGrainedSchedulerBackend.this") @@ -447,11 +447,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorDataMap -= executorId executorsPendingLossReason -= executorId val killedByDriver = executorsPendingToRemove.remove(executorId).getOrElse(false) - val workerHostOpt = executorsPendingDecommission.remove(executorId) + val decommissionInfoOpt = executorsPendingDecommission.remove(executorId) if (killedByDriver) { ExecutorKilled - } else if (workerHostOpt.isDefined) { - ExecutorDecommission(workerHostOpt.get) + } else if (decommissionInfoOpt.isDefined) { + val decommissionInfo = decommissionInfoOpt.get + ExecutorDecommission(decommissionInfo.workerHost, decommissionInfo.message) } else { reason } @@ -535,7 +536,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Only bother decommissioning executors which are alive. if (isExecutorActive(executorId)) { scheduler.executorDecommission(executorId, decomInfo) - executorsPendingDecommission(executorId) = decomInfo.workerHost + executorsPendingDecommission(executorId) = decomInfo Some(executorId) } else { None diff --git a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala index 129428963696d..fc9248de7ee05 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala @@ -356,7 +356,7 @@ private[spark] class ExecutorMonitor( if (removed != null) { decrementExecResourceProfileCount(removed.resourceProfileId) if (event.reason == ExecutorLossMessage.decommissionFinished || - event.reason == ExecutorDecommission().message) { + (event.reason != null && event.reason.startsWith(ExecutorDecommission.msgPrefix))) { metrics.gracefullyDecommissioned.inc() } else if (removed.decommissioning) { metrics.decommissionUnfinished.inc() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala index e004c334dee73..d9d2e6102f120 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala @@ -183,9 +183,14 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS taskEndEvents.asScala.filter(_.taskInfo.successful).map(_.taskInfo.executorId).headOption } - sc.addSparkListener(new SparkListener { + val listener = new SparkListener { + var removeReasonValidated = false + override def onExecutorRemoved(execRemoved: SparkListenerExecutorRemoved): Unit = { executorRemovedSem.release() + if (execRemoved.reason == ExecutorDecommission.msgPrefix + "test msg 0") { + removeReasonValidated = true + } } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { @@ -211,7 +216,8 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS } } } - }) + } + 
sc.addSparkListener(listener) // Cache the RDD lazily if (persist) { @@ -247,7 +253,7 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS // Decommission executor and ensure it is not relaunched by setting adjustTargetNumExecutors sched.decommissionExecutor( execToDecommission, - ExecutorDecommissionInfo("", None), + ExecutorDecommissionInfo("test msg 0", None), adjustTargetNumExecutors = true) val decomTime = new SystemClock().getTimeMillis() @@ -343,5 +349,7 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS // should have same value like before assert(testRdd.count() === numParts) assert(accum.value === numParts) + import scala.language.reflectiveCalls + assert(listener.removeReasonValidated) } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 3a1f0862e94e1..56fb5bf6c6cfe 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1023,7 +1023,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties { file.deleteOnExit() val cmd = s""" - |#!/bin/bash + |#!/usr/bin/env bash |trap "" SIGTERM |sleep 10 """.stripMargin diff --git a/dev/requirements.txt b/dev/requirements.txt index 4b47c1f6e834a..651fc280627ee 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -46,7 +46,9 @@ PyGithub # pandas API on Spark Code formatter. black==22.6.0 -# Spark Connect +# Spark Connect (required) grpcio==1.48.1 protobuf==4.21.6 + +# Spark Connect python proto generation plugin (optional) mypy-protobuf diff --git a/docs/sql-data-sources-orc.md b/docs/sql-data-sources-orc.md index 28e237a382df8..200037a7dea17 100644 --- a/docs/sql-data-sources-orc.md +++ b/docs/sql-data-sources-orc.md @@ -153,6 +153,24 @@ When reading from Hive metastore ORC tables and inserting to Hive metastore ORC 2.3.0 + + spark.sql.orc.columnarReaderBatchSize + 4096 + + The number of rows to include in an orc vectorized reader batch. The number should + be carefully chosen to minimize overhead and avoid OOMs in reading data. + + 2.4.0 + + + spark.sql.orc.columnarWriterBatchSize + 1024 + + The number of rows to include in an orc vectorized writer batch. The number should + be carefully chosen to minimize overhead and avoid OOMs in writing data. + + 3.4.0 + spark.sql.orc.enableNestedColumnVectorizedReader false @@ -163,6 +181,25 @@ When reading from Hive metastore ORC tables and inserting to Hive metastore ORC 3.2.0 + + spark.sql.orc.filterPushdown + true + + When true, enable filter pushdown for ORC files. + + 1.4.0 + + + spark.sql.orc.aggregatePushdown + false + + If true, aggregates will be pushed down to ORC for optimization. Support MIN, MAX and + COUNT as aggregate expression. For MIN/MAX, support boolean, integer, float and date + type. For COUNT, support all data types. If statistics is missing from any ORC file + footer, exception would be thrown. + + 3.3.0 + spark.sql.orc.mergeSchema false diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index bc7f17fd5cb13..18cc579e4f9ea 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -33,6 +33,7 @@ license: | - Valid Base64 string should include symbols from in base64 alphabet (A-Za-z0-9+/), optional padding (`=`), and optional whitespaces. Whitespaces are skipped in conversion except when they are preceded by padding symbol(s). 
If padding is present it should conclude the string and follow rules described in RFC 4648 ยง 4. - Valid hexadecimal strings should include only allowed symbols (0-9A-Fa-f). - Valid values for `fmt` are case-insensitive `hex`, `base64`, `utf-8`, `utf8`. + - Since Spark 3.4, Spark throws only `PartitionsAlreadyExistException` when it creates partitions but some of them exist already. In Spark 3.3 or earlier, Spark can throw either `PartitionsAlreadyExistException` or `PartitionAlreadyExistsException`. ## Upgrading from Spark SQL 3.2 to 3.3 diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index e1054c7060f12..622805e7c6649 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -87,16 +87,18 @@ public static void main(String[] argsArray) throws Exception { cmd = buildCommand(builder, env, printLaunchCommand); } - if (isWindows()) { + // test for shell environments, to enable non-Windows treatment of command line prep + boolean shellflag = !isEmpty(System.getenv("SHELL")); + if (isWindows() && !shellflag) { System.out.println(prepareWindowsCommand(cmd, env)); } else { // A sequence of NULL character and newline separates command-strings and others. - System.out.println('\0'); + System.out.printf("%c\n",'\0'); // In bash, use NULL as the arg separator since it cannot be used in an argument. List bashCmd = prepareBashCommand(cmd, env); for (String c : bashCmd) { - System.out.print(c); + System.out.print(c.replaceFirst("\r$","")); System.out.print('\0'); } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index d8ac407bb30d3..6a55b54904dff 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -596,11 +596,20 @@ object SparkParallelTestGrouping { object Core { import scala.sys.process.Process + def buildenv = Process(Seq("uname")).!!.trim.replaceFirst("[^A-Za-z0-9].*", "").toLowerCase + def bashpath = Process(Seq("where", "bash")).!!.split("[\r\n]+").head.replace('\\', '/') lazy val settings = Seq( (Compile / resourceGenerators) += Def.task { val buildScript = baseDirectory.value + "/../build/spark-build-info" val targetDir = baseDirectory.value + "/target/extra-resources/" - val command = Seq("bash", buildScript, targetDir, version.value) + // support Windows build under cygwin/mingw64, etc + val bash = buildenv match { + case "cygwin" | "msys2" | "mingw64" | "clang64" => + bashpath + case _ => + "bash" + } + val command = Seq(bash, buildScript, targetDir, version.value) Process(command).!! val propsFile = baseDirectory.value / "target" / "extra-resources" / "spark-version-info.properties" Seq(propsFile) @@ -655,11 +664,39 @@ object SparkConnect { (assembly / logLevel) := Level.Info, + // Exclude `scala-library` from assembly. + (assembly / assemblyPackageScala / assembleArtifact) := false, + + // Exclude `pmml-model-*.jar`, `scala-collection-compat_*.jar`,`jsr305-*.jar` and + // `netty-*.jar` and `unused-1.0.0.jar` from assembly. 
+ (assembly / assemblyExcludedJars) := { + val cp = (assembly / fullClasspath).value + cp filter { v => + val name = v.data.getName + name.startsWith("pmml-model-") || name.startsWith("scala-collection-compat_") || + name.startsWith("jsr305-") || name.startsWith("netty-") || name == "unused-1.0.0.jar" + } + }, + (assembly / assemblyShadeRules) := Seq( ShadeRule.rename("io.grpc.**" -> "org.sparkproject.connect.grpc.@0").inAll, ShadeRule.rename("com.google.common.**" -> "org.sparkproject.connect.guava.@1").inAll, ShadeRule.rename("com.google.thirdparty.**" -> "org.sparkproject.connect.guava.@1").inAll, ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connect.protobuf.@1").inAll, + ShadeRule.rename("android.annotation.**" -> "org.sparkproject.connect.android_annotation.@1").inAll, + ShadeRule.rename("io.perfmark.**" -> "org.sparkproject.connect.io_perfmark.@1").inAll, + ShadeRule.rename("org.codehaus.mojo.animal_sniffer.**" -> "org.sparkproject.connect.animal_sniffer.@1").inAll, + ShadeRule.rename("com.google.j2objc.annotations.**" -> "org.sparkproject.connect.j2objc_annotations.@1").inAll, + ShadeRule.rename("com.google.errorprone.annotations.**" -> "org.sparkproject.connect.errorprone_annotations.@1").inAll, + ShadeRule.rename("org.checkerframework.**" -> "org.sparkproject.connect.checkerframework.@1").inAll, + ShadeRule.rename("com.google.gson.**" -> "org.sparkproject.connect.gson.@1").inAll, + ShadeRule.rename("com.google.api.**" -> "org.sparkproject.connect.google_protos.api.@1").inAll, + ShadeRule.rename("com.google.cloud.**" -> "org.sparkproject.connect.google_protos.cloud.@1").inAll, + ShadeRule.rename("com.google.geo.**" -> "org.sparkproject.connect.google_protos.geo.@1").inAll, + ShadeRule.rename("com.google.logging.**" -> "org.sparkproject.connect.google_protos.logging.@1").inAll, + ShadeRule.rename("com.google.longrunning.**" -> "org.sparkproject.connect.google_protos.longrunning.@1").inAll, + ShadeRule.rename("com.google.rpc.**" -> "org.sparkproject.connect.google_protos.rpc.@1").inAll, + ShadeRule.rename("com.google.type.**" -> "org.sparkproject.connect.google_protos.type.@1").inAll ), (assembly / assemblyMergeStrategy) := { @@ -667,7 +704,7 @@ object SparkConnect { // Drop all proto files that are not needed as artifacts of the build. 
case m if m.toLowerCase(Locale.ROOT).endsWith(".proto") => MergeStrategy.discard case _ => MergeStrategy.first - }, + } ) } diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py index 5b94da4412551..9db688d913462 100644 --- a/python/pyspark/pandas/generic.py +++ b/python/pyspark/pandas/generic.py @@ -45,7 +45,6 @@ from pyspark.sql.types import ( BooleanType, DoubleType, - IntegralType, LongType, NumericType, ) @@ -1421,32 +1420,16 @@ def product( def prod(psser: "Series") -> Column: spark_type = psser.spark.data_type spark_column = psser.spark.column - - if not skipna: - spark_column = F.when(spark_column.isNull(), np.nan).otherwise(spark_column) - if isinstance(spark_type, BooleanType): - scol = F.min(F.coalesce(spark_column, F.lit(True))).cast(LongType()) - elif isinstance(spark_type, NumericType): - num_zeros = F.sum(F.when(spark_column == 0, 1).otherwise(0)) - sign = F.when( - F.sum(F.when(spark_column < 0, 1).otherwise(0)) % 2 == 0, 1 - ).otherwise(-1) - - scol = F.when(num_zeros > 0, 0).otherwise( - sign * F.exp(F.sum(F.log(F.abs(spark_column)))) - ) - - if isinstance(spark_type, IntegralType): - scol = F.round(scol).cast(LongType()) - else: + spark_column = spark_column.cast(LongType()) + elif not isinstance(spark_type, NumericType): raise TypeError( "Could not convert {} ({}) to numeric".format( spark_type_to_pandas_dtype(spark_type), spark_type.simpleString() ) ) - return F.coalesce(scol, F.lit(1)) + return SF.product(spark_column, skipna) return self._reduce_for_stat_function( prod, diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index 08a136aa26891..c5dbcb79710a5 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -62,7 +62,6 @@ StructField, StructType, StringType, - IntegralType, ) from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. 
@@ -469,21 +468,10 @@ def first(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fr if not isinstance(min_count, int): raise TypeError("min_count must be integer") - if min_count > 0: - - def first(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.first(col, ignorenulls=True)) - - else: - - def first(col: Column) -> Column: - return F.first(col, ignorenulls=True) - return self._reduce_for_stat_function( - first, + lambda col: F.first(col, ignorenulls=True), accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + min_count=min_count, ) def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike: @@ -550,21 +538,10 @@ def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fra if not isinstance(min_count, int): raise TypeError("min_count must be integer") - if min_count > 0: - - def last(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.last(col, ignorenulls=True)) - - else: - - def last(col: Column) -> Column: - return F.last(col, ignorenulls=True) - return self._reduce_for_stat_function( - last, + lambda col: F.last(col, ignorenulls=True), accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + min_count=min_count, ) def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike: @@ -625,20 +602,10 @@ def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fram if not isinstance(min_count, int): raise TypeError("min_count must be integer") - if min_count > 0: - - def max(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.max(col)) - - else: - - def max(col: Column) -> Column: - return F.max(col) - return self._reduce_for_stat_function( - max, accepted_spark_types=(NumericType, BooleanType) if numeric_only else None + F.max, + accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + min_count=min_count, ) def mean(self, numeric_only: Optional[bool] = True) -> FrameLike: @@ -803,20 +770,10 @@ def min(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fram if not isinstance(min_count, int): raise TypeError("min_count must be integer") - if min_count > 0: - - def min(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.min(col)) - - else: - - def min(col: Column) -> Column: - return F.min(col) - return self._reduce_for_stat_function( - min, accepted_spark_types=(NumericType, BooleanType) if numeric_only else None + F.min, + accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + min_count=min_count, ) # TODO: sync the doc. @@ -945,20 +902,11 @@ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameL f"numeric_only=False, skip unsupported columns: {unsupported}" ) - if min_count > 0: - - def sum(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.sum(col)) - - else: - - def sum(col: Column) -> Column: - return F.sum(col) - return self._reduce_for_stat_function( - sum, accepted_spark_types=(NumericType,), bool_to_numeric=True + F.sum, + accepted_spark_types=(NumericType, BooleanType), + bool_to_numeric=True, + min_count=min_count, ) # TODO: sync the doc. 
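The per-function `min_count` wrappers removed in the hunks above are replaced by a single mask applied inside `_reduce_for_stat_function` (see the hunk for that method further below). As a minimal sketch of the masking rule, expressed here with the Scala DataFrame API rather than the PySpark code in this patch — `withMinCount` is an illustrative helper name, not something the patch adds:

import org.apache.spark.sql.{Column, functions => F}

// Return `agg` only when at least `minCount` non-null inputs were observed;
// otherwise yield null, which is the pandas `min_count` contract.
def withMinCount(input: Column, agg: Column, minCount: Int): Column =
  if (minCount > 0) {
    F.when(F.count(F.when(input.isNotNull, F.lit(0))) >= minCount, agg)
  } else {
    agg
  }

// e.g. df.groupBy("A").agg(withMinCount(F.col("B"), F.sum("B"), 3).alias("B"))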
@@ -1320,53 +1268,18 @@ def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> Frame 1 NaN 2.0 0.0 2 NaN NaN NaN """ + if not isinstance(min_count, int): + raise TypeError("min_count must be integer") self._validate_agg_columns(numeric_only=numeric_only, function_name="prod") - groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))] - internal, agg_columns, sdf = self._prepare_reduce( - groupkey_names=groupkey_names, + return self._reduce_for_stat_function( + lambda col: SF.product(col, True), accepted_spark_types=(NumericType, BooleanType), bool_to_numeric=True, + min_count=min_count, ) - psdf: DataFrame = DataFrame(internal) - if len(psdf._internal.column_labels) > 0: - - stat_exprs = [] - for label in psdf._internal.column_labels: - psser = psdf._psser_for(label) - column = psser._dtype_op.nan_to_null(psser).spark.column - data_type = psser.spark.data_type - aggregating = ( - F.product(column).cast("long") - if isinstance(data_type, IntegralType) - else F.product(column) - ) - - if min_count > 0: - prod_scol = F.when( - F.count(F.when(~F.isnull(column), F.lit(0))) < min_count, F.lit(None) - ).otherwise(aggregating) - else: - prod_scol = aggregating - - stat_exprs.append(prod_scol.alias(psser._internal.data_spark_column_names[0])) - - sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs) - - else: - sdf = sdf.select(*groupkey_names).distinct() - - internal = internal.copy( - spark_frame=sdf, - index_spark_columns=[scol_for(sdf, col) for col in groupkey_names], - data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names], - data_fields=None, - ) - - return self._prepare_return(DataFrame(internal)) - def all(self, skipna: bool = True) -> FrameLike: """ Returns True if all values in the group are truthful, else False. @@ -3621,6 +3534,7 @@ def _reduce_for_stat_function( sfun: Callable[[Column], Column], accepted_spark_types: Optional[Tuple[Type[DataType], ...]] = None, bool_to_numeric: bool = False, + **kwargs: Any, ) -> FrameLike: """Apply an aggregate function `sfun` per column and reduce to a FrameLike. 
@@ -3640,14 +3554,19 @@ def _reduce_for_stat_function( psdf: DataFrame = DataFrame(internal) if len(psdf._internal.column_labels) > 0: + min_count = kwargs.get("min_count", 0) stat_exprs = [] for label in psdf._internal.column_labels: psser = psdf._psser_for(label) - stat_exprs.append( - sfun(psser._dtype_op.nan_to_null(psser).spark.column).alias( - psser._internal.data_spark_column_names[0] + input_scol = psser._dtype_op.nan_to_null(psser).spark.column + output_scol = sfun(input_scol) + + if min_count > 0: + output_scol = F.when( + F.count(F.when(~F.isnull(input_scol), F.lit(0))) >= min_count, output_scol ) - ) + + stat_exprs.append(output_scol.alias(psser._internal.data_spark_column_names[0])) sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs) else: sdf = sdf.select(*groupkey_names).distinct() diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index f9311296a5724..658d3459b24f3 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -27,6 +27,11 @@ ) +def product(col: Column, dropna: bool) -> Column: + sc = SparkContext._active_spark_context + return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna)) + + def stddev(col: Column, ddof: int) -> Column: sc = SparkContext._active_spark_context return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof)) diff --git a/python/pyspark/pandas/tests/test_generic_functions.py b/python/pyspark/pandas/tests/test_generic_functions.py index 7c252c8356d80..d476302205938 100644 --- a/python/pyspark/pandas/tests/test_generic_functions.py +++ b/python/pyspark/pandas/tests/test_generic_functions.py @@ -200,6 +200,22 @@ def test_stat_functions(self): self.assert_eq(pdf.b.kurtosis(), psdf.b.kurtosis()) self.assert_eq(pdf.c.kurtosis(), psdf.c.kurtosis()) + def test_prod_precision(self): + pdf = pd.DataFrame( + { + "a": [np.nan, np.nan, np.nan, np.nan], + "b": [1, np.nan, np.nan, -4], + "c": [1, -2, 3, -4], + "d": [55108, 55108, 55108, 55108], + } + ) + psdf = ps.from_pandas(pdf) + + self.assert_eq(pdf.prod(), psdf.prod()) + self.assert_eq(pdf.prod(skipna=False), psdf.prod(skipna=False)) + self.assert_eq(pdf.prod(min_count=3), psdf.prod(min_count=3)) + self.assert_eq(pdf.prod(skipna=False, min_count=3), psdf.prod(skipna=False, min_count=3)) + if __name__ == "__main__": import unittest diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py index 33ba815564145..a203f77717e9d 100644 --- a/python/pyspark/pandas/tests/test_groupby.py +++ b/python/pyspark/pandas/tests/test_groupby.py @@ -1493,13 +1493,35 @@ def test_nth(self): self.psdf.groupby("B").nth("x") def test_prod(self): + pdf = pd.DataFrame( + { + "A": [1, 2, 1, 2, 1], + "B": [3.1, 4.1, 4.1, 3.1, 0.1], + "C": ["a", "b", "b", "a", "c"], + "D": [True, False, False, True, False], + "E": [-1, -2, 3, -4, -2], + "F": [-1.5, np.nan, -3.2, 0.1, 0], + "G": [np.nan, np.nan, np.nan, np.nan, np.nan], + } + ) + psdf = ps.from_pandas(pdf) + for n in [0, 1, 2, 128, -1, -2, -128]: - self._test_stat_func(lambda groupby_obj: groupby_obj.prod(min_count=n)) self._test_stat_func( - lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n) + lambda groupby_obj: groupby_obj.prod(min_count=n), check_exact=False ) self._test_stat_func( - lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n) + lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n), + check_exact=False, + ) + self._test_stat_func( + lambda groupby_obj: 
groupby_obj.prod(numeric_only=True, min_count=n), + check_exact=False, + ) + self.assert_eq( + pdf.groupby("A").prod(min_count=n).sort_index(), + psdf.groupby("A").prod(min_count=n).sort_index(), + almost=True, ) def test_cumcount(self): diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 5691011795dcf..780cfdfba8e9b 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagement.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagement.java index e2c693f2d0a92..6c9e5ac577a7b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagement.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagement.java @@ -22,7 +22,6 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException; -import org.apache.spark.sql.catalyst.analysis.PartitionAlreadyExistsException; import org.apache.spark.sql.catalyst.analysis.PartitionsAlreadyExistException; /** @@ -46,15 +45,16 @@ @Experimental public interface SupportsAtomicPartitionManagement extends SupportsPartitionManagement { + @SuppressWarnings("unchecked") @Override default void createPartition( InternalRow ident, Map properties) - throws PartitionAlreadyExistsException, UnsupportedOperationException { + throws PartitionsAlreadyExistException, UnsupportedOperationException { try { createPartitions(new InternalRow[]{ident}, new Map[]{properties}); } catch (PartitionsAlreadyExistException e) { - throw new PartitionAlreadyExistsException(e.getMessage()); + throw new PartitionsAlreadyExistException(e.getMessage()); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java index ec2b61a766499..4830e193222fc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException; -import org.apache.spark.sql.catalyst.analysis.PartitionAlreadyExistsException; +import org.apache.spark.sql.catalyst.analysis.PartitionsAlreadyExistException; import org.apache.spark.sql.types.StructType; /** @@ -59,13 +59,13 @@ public interface SupportsPartitionManagement extends Table { * * @param ident a new partition identifier * @param properties the metadata of a partition - * @throws PartitionAlreadyExistsException If a partition already exists for the identifier + * @throws PartitionsAlreadyExistException If a partition already exists for the identifier * @throws UnsupportedOperationException If partition 
property is not supported */ void createPartition( InternalRow ident, Map properties) - throws PartitionAlreadyExistsException, UnsupportedOperationException; + throws PartitionsAlreadyExistException, UnsupportedOperationException; /** * Drop a partition from table. @@ -147,14 +147,14 @@ Map loadPartitionMetadata(InternalRow ident) * @param to new partition identifier * @return true if renaming completes successfully otherwise false * @throws UnsupportedOperationException If partition renaming is not supported - * @throws PartitionAlreadyExistsException If the `to` partition exists already + * @throws PartitionsAlreadyExistException If the `to` partition exists already * @throws NoSuchPartitionException If the `from` partition does not exist * * @since 3.2.0 */ default boolean renamePartition(InternalRow from, InternalRow to) throws UnsupportedOperationException, - PartitionAlreadyExistsException, + PartitionsAlreadyExistException, NoSuchPartitionException { throw new UnsupportedOperationException("Partition renaming is not supported"); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index b32958d13daf1..fe16174586bad 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -47,7 +47,7 @@ public class V2ExpressionSQLBuilder { public String build(Expression expr) { if (expr instanceof Literal) { - return visitLiteral((Literal) expr); + return visitLiteral((Literal) expr); } else if (expr instanceof NamedReference) { return visitNamedReference((NamedReference) expr); } else if (expr instanceof Cast) { @@ -213,7 +213,7 @@ public String build(Expression expr) { } } - protected String visitLiteral(Literal literal) { + protected String visitLiteral(Literal literal) { return literal.toString(); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java b/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java index 444263f31113e..283258ecb0a55 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.Random; @@ -53,20 +54,20 @@ public class NumericHistogram { * * @since 3.3.0 */ - public static class Coord implements Comparable { + public static class Coord implements Comparable { public double x; public double y; @Override - public int compareTo(Object other) { - return Double.compare(x, ((Coord) other).x); + public int compareTo(Coord other) { + return Double.compare(x, other.x); } } // Class variables private int nbins; private int nusedbins; - private ArrayList bins; + private List bins; private Random prng; /** @@ -146,7 +147,7 @@ public void addBin(double x, double y, int b) { */ public void allocate(int num_bins) { nbins = num_bins; - bins = new ArrayList(); + bins = new ArrayList<>(); nusedbins = 0; } @@ -163,7 +164,7 @@ public void merge(NumericHistogram other) { // by deserializing the ArrayList of (x,y) pairs into an array of Coord objects nbins = other.nbins; nusedbins = other.nusedbins; - bins = new ArrayList(nusedbins); + bins = new ArrayList<>(nusedbins); for (int i = 0; i < other.nusedbins; i += 1) { Coord bin = new 
Coord(); bin.x = other.getBin(i).x; @@ -174,7 +175,7 @@ public void merge(NumericHistogram other) { // The aggregation buffer already contains a partial histogram. Therefore, we need // to merge histograms using Algorithm #2 from the Ben-Haim and Tom-Tov paper. - ArrayList tmp_bins = new ArrayList(nusedbins + other.nusedbins); + List tmp_bins = new ArrayList<>(nusedbins + other.nusedbins); // Copy all the histogram bins from us and 'other' into an overstuffed histogram for (int i = 0; i < nusedbins; i++) { Coord bin = new Coord(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 4f6c9a8c703e3..72e1dd94c94da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -584,6 +584,8 @@ trait Row extends Serializable { case (i: CalendarInterval, _) => JString(i.toString) case (a: Array[_], ArrayType(elementType, _)) => iteratorToJsonArray(a.iterator, elementType) + case (a: mutable.ArraySeq[_], ArrayType(elementType, _)) => + iteratorToJsonArray(a.iterator, elementType) case (s: Seq[_], ArrayType(elementType, _)) => iteratorToJsonArray(s.iterator, elementType) case (m: Map[String @unchecked, _], MapType(StringType, valueType, _)) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DataSourceOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DataSourceOptions.scala new file mode 100644 index 0000000000000..5348d1054d5d4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DataSourceOptions.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +/** + * Interface that defines the following methods for a data source: + * - register a new option name + * - retrieve all registered option names + * - validate a given option name + * - get the alternative option name, if any + */ +trait DataSourceOptions { + // Option -> Alternative Option if any + private val validOptions = collection.mutable.Map[String, Option[String]]() + + /** + * Register a new Option. + */ + protected def newOption(name: String): String = { + validOptions += (name -> None) + name + } + + /** + * Register a new Option with an alternative name.
+ * @param name Option name + * @param alternative Alternative option name + */ + protected def newOption(name: String, alternative: String): Unit = { + // Register both of the options + validOptions += (name -> Some(alternative)) + validOptions += (alternative -> Some(name)) + } + + /** + * @return All data source options and their alternatives if any + */ + def getAllOptions: scala.collection.Set[String] = validOptions.keySet + + /** + * @param name Option name to be validated + * @return if the given Option name is valid + */ + def isValidOption(name: String): Boolean = validOptions.contains(name) + + /** + * @param name Option name + * @return Alternative option name if any + */ + def getAlternativeOption(name: String): Option[String] = validOptions.get(name).flatten +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index f65c29a06cc65..c1dd80e3f77e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -59,29 +59,23 @@ class TableAlreadyExistsException(message: String, cause: Option[Throwable] = No class TempTableAlreadyExistsException(table: String) extends TableAlreadyExistsException(s"Temporary view '$table' already exists") -class PartitionAlreadyExistsException(message: String) extends AnalysisException(message) { - def this(db: String, table: String, spec: TablePartitionSpec) = { - this(s"Partition already exists in table '$table' database '$db':\n" + spec.mkString("\n")) - } - - def this(tableName: String, partitionIdent: InternalRow, partitionSchema: StructType) = { - this(s"Partition already exists in table $tableName:" + - partitionIdent.toSeq(partitionSchema).zip(partitionSchema.map(_.name)) - .map( kv => s"${kv._1} -> ${kv._2}").mkString(",")) - } -} - class PartitionsAlreadyExistException(message: String) extends AnalysisException(message) { def this(db: String, table: String, specs: Seq[TablePartitionSpec]) = { - this(s"The following partitions already exists in table '$table' database '$db':\n" + this(s"The following partitions already exist in table '$table' database '$db':\n" + specs.mkString("\n===\n")) } + def this(db: String, table: String, spec: TablePartitionSpec) = + this(db, table, Seq(spec)) + def this(tableName: String, partitionIdents: Seq[InternalRow], partitionSchema: StructType) = { - this(s"The following partitions already exists in table $tableName:" + + this(s"The following partitions already exist in table $tableName:" + partitionIdents.map(id => partitionSchema.map(_.name).zip(id.toSeq(partitionSchema)) .map( kv => s"${kv._1} -> ${kv._2}").mkString(",")).mkString("\n===\n")) } + + def this(tableName: String, partitionIdent: InternalRow, partitionSchema: StructType) = + this(tableName, Seq(partitionIdent), partitionSchema) } class FunctionAlreadyExistsException(db: String, func: String) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 218a342e669bd..90e824284bdbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -90,7 +90,7 @@ class InMemoryCatalog( 
specs: Seq[TablePartitionSpec]): Unit = { specs.foreach { s => if (partitionExists(db, table, s)) { - throw new PartitionAlreadyExistsException(db = db, table = table, spec = s) + throw new PartitionsAlreadyExistException(db = db, table = table, spec = s) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index cc2ba8ac7e4bb..bc01986afdb14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1784,11 +1784,11 @@ class SessionCatalog( } /** - * List all registered functions in a database with the given pattern. + * List all built-in and temporary functions with the given pattern. */ - private def listRegisteredFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = { + private def listBuiltinAndTempFunctions(pattern: String): Seq[FunctionIdentifier] = { val functions = (functionRegistry.listFunction() ++ tableFunctionRegistry.listFunction()) - .filter(_.database.forall(_ == db)) + .filter(_.database.isEmpty) StringUtils.filterPattern(functions.map(_.unquotedString), pattern).map { f => // In functionRegistry, function names are stored as an unquoted format. Try(parser.parseFunctionIdentifier(f)) match { @@ -1817,7 +1817,7 @@ class SessionCatalog( requireDbExists(dbName) val dbFunctions = externalCatalog.listFunctions(dbName, pattern).map { f => FunctionIdentifier(f, Some(dbName)) } - val loadedFunctions = listRegisteredFunctions(db, pattern) + val loadedFunctions = listBuiltinAndTempFunctions(pattern) val functions = dbFunctions ++ loadedFunctions // The session catalog caches some persistent functions in the FunctionRegistry // so there can be duplicates. 
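The `CSVOptions` and `JSONOptions` changes that follow adopt the `DataSourceOptions` trait introduced earlier in this patch. A small sketch of how a format's companion options object can use it — `FooOptions` and the "foo" source are made-up names for illustration, not part of the patch:

import org.apache.spark.sql.catalyst.DataSourceOptions

// Hypothetical options object for an imaginary "foo" data source.
object FooOptions extends DataSourceOptions {
  val MULTI_LINE = newOption("multiLine")
  // An option and its alternative name are registered as a pair.
  val ENCODING = "encoding"
  val CHARSET = "charset"
  newOption(ENCODING, CHARSET)
}

// FooOptions.isValidOption("charset")         // true
// FooOptions.getAlternativeOption("charset")  // Some("encoding")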
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index 88396c65cc070..a66070aa853d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -24,7 +24,7 @@ import java.util.Locale import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, UnescapedQuoteHandling} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -37,6 +37,8 @@ class CSVOptions( defaultColumnNameOfCorruptRecord: String) extends FileSourceOptions(parameters) with Logging { + import CSVOptions._ + def this( parameters: Map[String, String], columnPruning: Boolean, @@ -99,46 +101,46 @@ class CSVOptions( } val delimiter = CSVExprUtils.toDelimiterStr( - parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) + parameters.getOrElse(SEP, parameters.getOrElse(DELIMITER, ","))) val parseMode: ParseMode = - parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) - val charset = parameters.getOrElse("encoding", - parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) + parameters.get(MODE).map(ParseMode.fromString).getOrElse(PermissiveMode) + val charset = parameters.getOrElse(ENCODING, + parameters.getOrElse(CHARSET, StandardCharsets.UTF_8.name())) - val quote = getChar("quote", '\"') - val escape = getChar("escape", '\\') - val charToEscapeQuoteEscaping = parameters.get("charToEscapeQuoteEscaping") match { + val quote = getChar(QUOTE, '\"') + val escape = getChar(ESCAPE, '\\') + val charToEscapeQuoteEscaping = parameters.get(CHAR_TO_ESCAPE_QUOTE_ESCAPING) match { case None => None case Some(null) => None case Some(value) if value.length == 0 => None case Some(value) if value.length == 1 => Some(value.charAt(0)) - case _ => throw QueryExecutionErrors.paramExceedOneCharError("charToEscapeQuoteEscaping") + case _ => throw QueryExecutionErrors.paramExceedOneCharError(CHAR_TO_ESCAPE_QUOTE_ESCAPING) } - val comment = getChar("comment", '\u0000') + val comment = getChar(COMMENT, '\u0000') - val headerFlag = getBool("header") - val inferSchemaFlag = getBool("inferSchema") - val ignoreLeadingWhiteSpaceInRead = getBool("ignoreLeadingWhiteSpace", default = false) - val ignoreTrailingWhiteSpaceInRead = getBool("ignoreTrailingWhiteSpace", default = false) + val headerFlag = getBool(HEADER) + val inferSchemaFlag = getBool(INFER_SCHEMA) + val ignoreLeadingWhiteSpaceInRead = getBool(IGNORE_LEADING_WHITESPACE, default = false) + val ignoreTrailingWhiteSpaceInRead = getBool(IGNORE_TRAILING_WHITESPACE, default = false) // For write, both options were `true` by default. We leave it as `true` for // backwards compatibility. 
- val ignoreLeadingWhiteSpaceFlagInWrite = getBool("ignoreLeadingWhiteSpace", default = true) - val ignoreTrailingWhiteSpaceFlagInWrite = getBool("ignoreTrailingWhiteSpace", default = true) + val ignoreLeadingWhiteSpaceFlagInWrite = getBool(IGNORE_LEADING_WHITESPACE, default = true) + val ignoreTrailingWhiteSpaceFlagInWrite = getBool(IGNORE_TRAILING_WHITESPACE, default = true) val columnNameOfCorruptRecord = - parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + parameters.getOrElse(COLUMN_NAME_OF_CORRUPT_RECORD, defaultColumnNameOfCorruptRecord) - val nullValue = parameters.getOrElse("nullValue", "") + val nullValue = parameters.getOrElse(NULL_VALUE, "") - val nanValue = parameters.getOrElse("nanValue", "NaN") + val nanValue = parameters.getOrElse(NAN_VALUE, "NaN") - val positiveInf = parameters.getOrElse("positiveInf", "Inf") - val negativeInf = parameters.getOrElse("negativeInf", "-Inf") + val positiveInf = parameters.getOrElse(POSITIVE_INF, "Inf") + val negativeInf = parameters.getOrElse(NEGATIVE_INF, "-Inf") val compressionCodec: Option[String] = { - val name = parameters.get("compression").orElse(parameters.get("codec")) + val name = parameters.get(COMPRESSION).orElse(parameters.get(CODEC)) name.map(CompressionCodecs.getCodecClassName) } @@ -146,7 +148,7 @@ class CSVOptions( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // A language tag in IETF BCP 47 format - val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) + val locale: Locale = parameters.get(LOCALE).map(Locale.forLanguageTag).getOrElse(Locale.US) /** * Infer columns with all valid date entries as date type (otherwise inferred as string or @@ -161,11 +163,11 @@ class CSVOptions( if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { false } else { - getBool("prefersDate", true) + getBool(PREFERS_DATE, true) } } - val dateFormatOption: Option[String] = parameters.get("dateFormat") + val dateFormatOption: Option[String] = parameters.get(DATE_FORMAT) // Provide a default value for dateFormatInRead when prefersDate. 
This ensures that the // Iso8601DateFormatter (with strict date parsing) is used for date inference val dateFormatInRead: Option[String] = @@ -174,24 +176,24 @@ class CSVOptions( } else { dateFormatOption } - val dateFormatInWrite: String = parameters.getOrElse("dateFormat", DateFormatter.defaultPattern) + val dateFormatInWrite: String = parameters.getOrElse(DATE_FORMAT, DateFormatter.defaultPattern) val timestampFormatInRead: Option[String] = if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { - Some(parameters.getOrElse("timestampFormat", + Some(parameters.getOrElse(TIMESTAMP_FORMAT, s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX")) } else { - parameters.get("timestampFormat") + parameters.get(TIMESTAMP_FORMAT) } - val timestampFormatInWrite: String = parameters.getOrElse("timestampFormat", + val timestampFormatInWrite: String = parameters.getOrElse(TIMESTAMP_FORMAT, if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX" } else { s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS][XXX]" }) - val timestampNTZFormatInRead: Option[String] = parameters.get("timestampNTZFormat") - val timestampNTZFormatInWrite: String = parameters.getOrElse("timestampNTZFormat", + val timestampNTZFormatInRead: Option[String] = parameters.get(TIMESTAMP_NTZ_FORMAT) + val timestampNTZFormatInWrite: String = parameters.getOrElse(TIMESTAMP_NTZ_FORMAT, s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]") // SPARK-39731: Enables the backward compatible parsing behavior. @@ -203,17 +205,17 @@ class CSVOptions( // Otherwise, depending on the parser policy and a custom pattern, an exception may be thrown and // the value will be parsed as null. val enableDateTimeParsingFallback: Option[Boolean] = - parameters.get("enableDateTimeParsingFallback").map(_.toBoolean) + parameters.get(ENABLE_DATETIME_PARSING_FALLBACK).map(_.toBoolean) - val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + val multiLine = parameters.get(MULTI_LINE).map(_.toBoolean).getOrElse(false) - val maxColumns = getInt("maxColumns", 20480) + val maxColumns = getInt(MAX_COLUMNS, 20480) - val maxCharsPerColumn = getInt("maxCharsPerColumn", -1) + val maxCharsPerColumn = getInt(MAX_CHARS_PER_COLUMN, -1) - val escapeQuotes = getBool("escapeQuotes", true) + val escapeQuotes = getBool(ESCAPE_QUOTES, true) - val quoteAll = getBool("quoteAll", false) + val quoteAll = getBool(QUOTE_ALL, false) /** * The max error content length in CSV parser/writer exception message. @@ -223,18 +225,18 @@ class CSVOptions( val isCommentSet = this.comment != '\u0000' val samplingRatio = - parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + parameters.get(SAMPLING_RATIO).map(_.toDouble).getOrElse(1.0) /** * Forcibly apply the specified or inferred schema to datasource files. * If the option is enabled, headers of CSV files will be ignored. */ - val enforceSchema = getBool("enforceSchema", default = true) + val enforceSchema = getBool(ENFORCE_SCHEMA, default = true) /** * String representation of an empty value in read and in write. */ - val emptyValue = parameters.get("emptyValue") + val emptyValue = parameters.get(EMPTY_VALUE) /** * The string is returned when CSV reader doesn't have any characters for input value, * or an empty quoted string `""`. Default value is empty string. @@ -248,7 +250,7 @@ class CSVOptions( /** * A string between two consecutive JSON records. 
*/ - val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + val lineSeparator: Option[String] = parameters.get(LINE_SEP).map { sep => require(sep.nonEmpty, "'lineSep' cannot be an empty string.") // Intentionally allow it up to 2 for Window's CRLF although multiple // characters have an issue with quotes. This is intentionally undocumented. @@ -263,14 +265,14 @@ class CSVOptions( } val lineSeparatorInWrite: Option[String] = lineSeparator - val inputBufferSize: Option[Int] = parameters.get("inputBufferSize").map(_.toInt) + val inputBufferSize: Option[Int] = parameters.get(INPUT_BUFFER_SIZE).map(_.toInt) .orElse(SQLConf.get.getConf(SQLConf.CSV_INPUT_BUFFER_SIZE)) /** * The handling method to be used when unescaped quotes are found in the input. */ val unescapedQuoteHandling: UnescapedQuoteHandling = UnescapedQuoteHandling.valueOf(parameters - .getOrElse("unescapedQuoteHandling", "STOP_AT_DELIMITER").toUpperCase(Locale.ROOT)) + .getOrElse(UNESCAPED_QUOTE_HANDLING, "STOP_AT_DELIMITER").toUpperCase(Locale.ROOT)) def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() @@ -327,3 +329,48 @@ class CSVOptions( settings } } + +object CSVOptions extends DataSourceOptions { + val HEADER = newOption("header") + val INFER_SCHEMA = newOption("inferSchema") + val IGNORE_LEADING_WHITESPACE = newOption("ignoreLeadingWhiteSpace") + val IGNORE_TRAILING_WHITESPACE = newOption("ignoreTrailingWhiteSpace") + val PREFERS_DATE = newOption("prefersDate") + val ESCAPE_QUOTES = newOption("escapeQuotes") + val QUOTE_ALL = newOption("quoteAll") + val ENFORCE_SCHEMA = newOption("enforceSchema") + val QUOTE = newOption("quote") + val ESCAPE = newOption("escape") + val COMMENT = newOption("comment") + val MAX_COLUMNS = newOption("maxColumns") + val MAX_CHARS_PER_COLUMN = newOption("maxCharsPerColumn") + val MODE = newOption("mode") + val CHAR_TO_ESCAPE_QUOTE_ESCAPING = newOption("charToEscapeQuoteEscaping") + val LOCALE = newOption("locale") + val DATE_FORMAT = newOption("dateFormat") + val TIMESTAMP_FORMAT = newOption("timestampFormat") + val TIMESTAMP_NTZ_FORMAT = newOption("timestampNTZFormat") + val ENABLE_DATETIME_PARSING_FALLBACK = newOption("enableDateTimeParsingFallback") + val MULTI_LINE = newOption("multiLine") + val SAMPLING_RATIO = newOption("samplingRatio") + val EMPTY_VALUE = newOption("emptyValue") + val LINE_SEP = newOption("lineSep") + val INPUT_BUFFER_SIZE = newOption("inputBufferSize") + val COLUMN_NAME_OF_CORRUPT_RECORD = newOption("columnNameOfCorruptRecord") + val NULL_VALUE = newOption("nullValue") + val NAN_VALUE = newOption("nanValue") + val POSITIVE_INF = newOption("positiveInf") + val NEGATIVE_INF = newOption("negativeInf") + val TIME_ZONE = newOption("timeZone") + val UNESCAPED_QUOTE_HANDLING = newOption("unescapedQuoteHandling") + // Options with alternative + val ENCODING = "encoding" + val CHARSET = "charset" + newOption(ENCODING, CHARSET) + val COMPRESSION = "compression" + val CODEC = "codec" + newOption(COMPRESSION, CODEC) + val SEP = "sep" + val DELIMITER = "delimiter" + newOption(SEP, DELIMITER) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala index 3af3944fd47d7..3325c8f16a4f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala 
@@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ImplicitCastInputTypes, Literal} +import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, Exp, Expression, If, ImplicitCastInputTypes, IsNull, Literal, Log} import org.apache.spark.sql.catalyst.trees.UnaryLike -import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType} +import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, DoubleType, IntegralType, LongType, NumericType} /** Multiply numerical values within an aggregation group */ @@ -63,3 +63,114 @@ case class Product(child: Expression) override protected def withNewChildInternal(newChild: Expression): Product = copy(child = newChild) } + +/** + * Product in Pandas' fashion. This expression is dedicated only for Pandas API on Spark. + * It has three main differences from `Product`: + * 1, it computes the product of `Fractional` inputs in a more numerically stable way; + * 2, it computes the product of `Integral` inputs with LongType variables internally; + * 3, it accepts NULLs when `ignoreNA` is false; + */ +case class PandasProduct( + child: Expression, + ignoreNA: Boolean) + extends DeclarativeAggregate with ImplicitCastInputTypes with UnaryLike[Expression] { + + override def nullable: Boolean = !ignoreNA + + override def dataType: DataType = child.dataType match { + case _: IntegralType => LongType + case _ => DoubleType + } + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + private lazy val product = + AttributeReference("product", LongType, nullable = false)() + private lazy val logSum = + AttributeReference("logSum", DoubleType, nullable = false)() + private lazy val positive = + AttributeReference("positive", BooleanType, nullable = false)() + private lazy val containsZero = + AttributeReference("containsZero", BooleanType, nullable = false)() + private lazy val containsNull = + AttributeReference("containsNull", BooleanType, nullable = false)() + + override lazy val aggBufferAttributes = child.dataType match { + case _: IntegralType => + Seq(product, containsNull) + case _ => + Seq(logSum, positive, containsZero, containsNull) + } + + override lazy val initialValues: Seq[Expression] = child.dataType match { + case _: IntegralType => + Seq(Literal(1L), Literal(false)) + case _ => + Seq(Literal(0.0), Literal(true), Literal(false), Literal(false)) + } + + override lazy val updateExpressions: Seq[Expression] = child.dataType match { + case _: IntegralType => + Seq( + If(IsNull(child), product, product * child), + containsNull || IsNull(child) + ) + case _ => + val newLogSum = logSum + Log(Abs(child)) + val newPositive = If(child < Literal(0.0), !positive, positive) + val newContainsZero = containsZero || child <=> Literal(0.0) + val newContainsNull = containsNull || IsNull(child) + if (ignoreNA) { + Seq( + If(IsNull(child) || newContainsZero, logSum, newLogSum), + newPositive, + newContainsZero, + newContainsNull + ) + } else { + Seq( + If(newContainsNull || newContainsZero, logSum, newLogSum), + newPositive, + newContainsZero, + newContainsNull + ) + } + } + + override lazy val mergeExpressions: Seq[Expression] = child.dataType match { + case _: IntegralType => + Seq( + product.left * product.right, + containsNull.left || containsNull.right + ) + case _ => + Seq( + logSum.left + logSum.right, + positive.left === positive.right,
containsZero.left || containsZero.right, + containsNull.left || containsNull.right + ) + } + + override lazy val evaluateExpression: Expression = child.dataType match { + case _: IntegralType => + if (ignoreNA) { + product + } else { + If(containsNull, Literal(null, LongType), product) + } + case _ => + val product = If(positive, Exp(logSum), -Exp(logSum)) + if (ignoreNA) { + If(containsZero, Literal(0.0), product) + } else { + If(containsNull, Literal(null, DoubleType), + If(containsZero, Literal(0.0), product)) + } + } + + override def prettyName: String = "pandas_product" + override protected def withNewChildInternal(newChild: Expression): PandasProduct = + copy(child = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index cad833d66f68b..efaadac6ed1c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -24,7 +24,9 @@ import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder +import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext, UnaryLike} @@ -66,9 +68,16 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression (left.dataType, right.dataType) match { case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess - case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + - s"been two ${ArrayType.simpleString}s with same element type, but it's " + - s"[${left.dataType.catalogString}, ${right.dataType.catalogString}]") + case _ => + DataTypeMismatch( + errorSubClass = "BINARY_ARRAY_DIFF_TYPES", + messageParameters = Map( + "functionName" -> prettyName, + "arrayType" -> toSQLType(ArrayType), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType) + ) + ) } } @@ -229,12 +238,21 @@ case class MapContainsKey(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { case (_, NullType) => - TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") + DataTypeMismatch( + errorSubClass = "NULL_TYPE", + Map("functionName" -> prettyName)) case (MapType(kt, _, _), dt) if kt.sameType(dt) => TypeUtils.checkForOrderingExpr(kt, s"function $prettyName") - case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + - s"been ${MapType.simpleString} followed by a value with same key type, but it's " + - s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") + case _ => + DataTypeMismatch( + errorSubClass = "MAP_CONTAINS_KEY_DIFF_TYPES", + messageParameters = Map( + "functionName" -> prettyName, + "dataType" -> toSQLType(MapType), + "leftType" -> toSQLType(left.dataType), + "rightType" -> toSQLType(right.dataType) + ) + ) } } 
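For reference, the `PandasProduct` expression added above computes fractional products in log space: it tracks the sign and any zero or null inputs separately and only exponentiates at the end. A standalone sketch of that idea in plain Scala (not the Catalyst expression itself, just the arithmetic it encodes):

// product(xs) = 0 if any input is zero, otherwise
//   sign * exp(sum(log|x|)), where sign is +1 for an even number of
//   negative inputs and -1 otherwise.
def stableProduct(xs: Seq[Double]): Double =
  if (xs.contains(0.0)) {
    0.0
  } else {
    val logSum = xs.iterator.map(x => math.log(math.abs(x))).sum
    val sign = if (xs.count(_ < 0.0) % 2 == 0) 1.0 else -1.0
    sign * math.exp(logSum)
  }

// Summing logarithms keeps the intermediate value small, so a long run of
// large (or tiny) factors does not overflow or underflow the way a naive
// running product can.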
@@ -663,9 +681,13 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres override def checkInputDataTypes(): TypeCheckResult = { val funcName = s"function $prettyName" if (children.exists(!_.dataType.isInstanceOf[MapType])) { - TypeCheckResult.TypeCheckFailure( - s"input to $funcName should all be of type map, but it's " + - children.map(_.dataType.catalogString).mkString("[", ", ", "]")) + DataTypeMismatch( + errorSubClass = "MAP_CONCAT_DIFF_TYPES", + messageParameters = Map( + "functionName" -> funcName, + "dataType" -> children.map(_.dataType).map(toSQLType).mkString("[", ", ", "]") + ) + ) } else { val sameTypeCheck = TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) if (sameTypeCheck.isFailure) { @@ -801,8 +823,15 @@ case class MapFromEntries(child: Expression) extends UnaryExpression with NullIn override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { case Some((mapType, _, _)) => TypeUtils.checkForMapKeyType(mapType.keyType) - case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + - s"${child.dataType.catalogString} type. $prettyName accepts only arrays of pair structs.") + case None => + DataTypeMismatch( + errorSubClass = "MAP_FROM_ENTRIES_WRONG_TYPE", + messageParameters = Map( + "functionName" -> prettyName, + "childExpr" -> toSQLExpr(child), + "childType" -> toSQLType(child.dataType) + ) + ) } private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 9679a60622bc9..bf5b83e9df0f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -25,7 +25,7 @@ import com.fasterxml.jackson.core.{JsonFactory, JsonFactoryBuilder} import com.fasterxml.jackson.core.json.JsonReadFeature import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy @@ -41,6 +41,8 @@ private[sql] class JSONOptions( defaultColumnNameOfCorruptRecord: String) extends FileSourceOptions(parameters) with Logging { + import JSONOptions._ + def this( parameters: Map[String, String], defaultTimeZoneId: String, @@ -52,36 +54,36 @@ private[sql] class JSONOptions( } val samplingRatio = - parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + parameters.get(SAMPLING_RATIO).map(_.toDouble).getOrElse(1.0) val primitivesAsString = - parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) + parameters.get(PRIMITIVES_AS_STRING).map(_.toBoolean).getOrElse(false) val prefersDecimal = - parameters.get("prefersDecimal").map(_.toBoolean).getOrElse(false) + parameters.get(PREFERS_DECIMAL).map(_.toBoolean).getOrElse(false) val allowComments = - parameters.get("allowComments").map(_.toBoolean).getOrElse(false) + parameters.get(ALLOW_COMMENTS).map(_.toBoolean).getOrElse(false) val allowUnquotedFieldNames = - parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false) + parameters.get(ALLOW_UNQUOTED_FIELD_NAMES).map(_.toBoolean).getOrElse(false) val allowSingleQuotes = - 
parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true) + parameters.get(ALLOW_SINGLE_QUOTES).map(_.toBoolean).getOrElse(true) val allowNumericLeadingZeros = - parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false) + parameters.get(ALLOW_NUMERIC_LEADING_ZEROS).map(_.toBoolean).getOrElse(false) val allowNonNumericNumbers = - parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + parameters.get(ALLOW_NON_NUMERIC_NUMBERS).map(_.toBoolean).getOrElse(true) val allowBackslashEscapingAnyCharacter = - parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) + parameters.get(ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER).map(_.toBoolean).getOrElse(false) private val allowUnquotedControlChars = - parameters.get("allowUnquotedControlChars").map(_.toBoolean).getOrElse(false) - val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) + parameters.get(ALLOW_UNQUOTED_CONTROL_CHARS).map(_.toBoolean).getOrElse(false) + val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName) val parseMode: ParseMode = - parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) + parameters.get(MODE).map(ParseMode.fromString).getOrElse(PermissiveMode) val columnNameOfCorruptRecord = - parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + parameters.getOrElse(COLUMN_NAME_OF_CORRUPTED_RECORD, defaultColumnNameOfCorruptRecord) // Whether to ignore column of all null values or empty array/struct during schema inference - val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + val dropFieldIfAllNull = parameters.get(DROP_FIELD_IF_ALL_NULL).map(_.toBoolean).getOrElse(false) // Whether to ignore null fields during json generating - val ignoreNullFields = parameters.get("ignoreNullFields").map(_.toBoolean) + val ignoreNullFields = parameters.get(IGNORE_NULL_FIELDS).map(_.toBoolean) .getOrElse(SQLConf.get.jsonGeneratorIgnoreNullFields) // If this is true, when writing NULL values to columns of JSON tables with explicit DEFAULT @@ -91,31 +93,31 @@ private[sql] class JSONOptions( val writeNullIfWithDefaultValue = SQLConf.get.jsonWriteNullIfWithDefaultValue // A language tag in IETF BCP 47 format - val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) + val locale: Locale = parameters.get(LOCALE).map(Locale.forLanguageTag).getOrElse(Locale.US) val zoneId: ZoneId = DateTimeUtils.getZoneId( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) - val dateFormatInRead: Option[String] = parameters.get("dateFormat") - val dateFormatInWrite: String = parameters.getOrElse("dateFormat", DateFormatter.defaultPattern) + val dateFormatInRead: Option[String] = parameters.get(DATE_FORMAT) + val dateFormatInWrite: String = parameters.getOrElse(DATE_FORMAT, DateFormatter.defaultPattern) val timestampFormatInRead: Option[String] = if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { - Some(parameters.getOrElse("timestampFormat", + Some(parameters.getOrElse(TIMESTAMP_FORMAT, s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX")) } else { - parameters.get("timestampFormat") + parameters.get(TIMESTAMP_FORMAT) } - val timestampFormatInWrite: String = parameters.getOrElse("timestampFormat", + val timestampFormatInWrite: String = parameters.getOrElse(TIMESTAMP_FORMAT, if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { 
s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX" } else { s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS][XXX]" }) - val timestampNTZFormatInRead: Option[String] = parameters.get("timestampNTZFormat") + val timestampNTZFormatInRead: Option[String] = parameters.get(TIMESTAMP_NTZ_FORMAT) val timestampNTZFormatInWrite: String = - parameters.getOrElse("timestampNTZFormat", s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]") + parameters.getOrElse(TIMESTAMP_NTZ_FORMAT, s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]") // SPARK-39731: Enables the backward compatible parsing behavior. // Generally, this config should be set to false to avoid producing potentially incorrect results @@ -126,14 +128,14 @@ private[sql] class JSONOptions( // Otherwise, depending on the parser policy and a custom pattern, an exception may be thrown and // the value will be parsed as null. val enableDateTimeParsingFallback: Option[Boolean] = - parameters.get("enableDateTimeParsingFallback").map(_.toBoolean) + parameters.get(ENABLE_DATETIME_PARSING_FALLBACK).map(_.toBoolean) - val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + val multiLine = parameters.get(MULTI_LINE).map(_.toBoolean).getOrElse(false) /** * A string between two consecutive JSON records. */ - val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + val lineSeparator: Option[String] = parameters.get(LINE_SEP).map { sep => require(sep.nonEmpty, "'lineSep' cannot be an empty string.") sep } @@ -146,8 +148,8 @@ private[sql] class JSONOptions( * when the multiLine option is set to `true`. If encoding is not specified in write, * UTF-8 is used by default. */ - val encoding: Option[String] = parameters.get("encoding") - .orElse(parameters.get("charset")).map(checkedEncoding) + val encoding: Option[String] = parameters.get(ENCODING) + .orElse(parameters.get(CHARSET)).map(checkedEncoding) val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => lineSep.getBytes(encoding.getOrElse(StandardCharsets.UTF_8.name())) @@ -157,20 +159,20 @@ private[sql] class JSONOptions( /** * Generating JSON strings in pretty representation if the parameter is enabled. */ - val pretty: Boolean = parameters.get("pretty").map(_.toBoolean).getOrElse(false) + val pretty: Boolean = parameters.get(PRETTY).map(_.toBoolean).getOrElse(false) /** * Enables inferring of TimestampType and TimestampNTZType from strings matched to the * corresponding timestamp pattern defined by the timestampFormat and timestampNTZFormat options * respectively. */ - val inferTimestamp: Boolean = parameters.get("inferTimestamp").map(_.toBoolean).getOrElse(false) + val inferTimestamp: Boolean = parameters.get(INFER_TIMESTAMP).map(_.toBoolean).getOrElse(false) /** * Generating \u0000 style codepoints for non-ASCII characters if the parameter is enabled. */ val writeNonAsciiCharacterAsCodePoint: Boolean = - parameters.get("writeNonAsciiCharacterAsCodePoint").map(_.toBoolean).getOrElse(false) + parameters.get(WRITE_NON_ASCII_CHARACTER_AS_CODEPOINT).map(_.toBoolean).getOrElse(false) /** Build a Jackson [[JsonFactory]] using JSON options. 
*/ def buildJsonFactory(): JsonFactory = { @@ -230,3 +232,36 @@ private[sql] object JSONOptionsInRead { Charset.forName("UTF-32") ) } + +object JSONOptions extends DataSourceOptions { + val SAMPLING_RATIO = newOption("samplingRatio") + val PRIMITIVES_AS_STRING = newOption("primitivesAsString") + val PREFERS_DECIMAL = newOption("prefersDecimal") + val ALLOW_COMMENTS = newOption("allowComments") + val ALLOW_UNQUOTED_FIELD_NAMES = newOption("allowUnquotedFieldNames") + val ALLOW_SINGLE_QUOTES = newOption("allowSingleQuotes") + val ALLOW_NUMERIC_LEADING_ZEROS = newOption("allowNumericLeadingZeros") + val ALLOW_NON_NUMERIC_NUMBERS = newOption("allowNonNumericNumbers") + val ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER = newOption("allowBackslashEscapingAnyCharacter") + val ALLOW_UNQUOTED_CONTROL_CHARS = newOption("allowUnquotedControlChars") + val COMPRESSION = newOption("compression") + val MODE = newOption("mode") + val DROP_FIELD_IF_ALL_NULL = newOption("dropFieldIfAllNull") + val IGNORE_NULL_FIELDS = newOption("ignoreNullFields") + val LOCALE = newOption("locale") + val DATE_FORMAT = newOption("dateFormat") + val TIMESTAMP_FORMAT = newOption("timestampFormat") + val TIMESTAMP_NTZ_FORMAT = newOption("timestampNTZFormat") + val ENABLE_DATETIME_PARSING_FALLBACK = newOption("enableDateTimeParsingFallback") + val MULTI_LINE = newOption("multiLine") + val LINE_SEP = newOption("lineSep") + val PRETTY = newOption("pretty") + val INFER_TIMESTAMP = newOption("inferTimestamp") + val COLUMN_NAME_OF_CORRUPTED_RECORD = newOption("columnNameOfCorruptRecord") + val TIME_ZONE = newOption("timeZone") + val WRITE_NON_ASCII_CHARACTER_AS_CODEPOINT = newOption("writeNonAsciiCharacterAsCodePoint") + // Options with alternative + val ENCODING = "encoding" + val CHARSET = "charset" + newOption(ENCODING, CHARSET) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index c7e42692d163c..7cb471d14bda4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ @@ -31,8 +33,13 @@ object TypeUtils { if (RowOrdering.isOrderable(dt)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure( - s"$caller does not support ordering on type ${dt.catalogString}") + DataTypeMismatch( + errorSubClass = "INVALID_ORDERING_TYPE", + Map( + "functionName" -> caller, + "dataType" -> toSQLType(dt) + ) + ) } } @@ -40,15 +47,24 @@ object TypeUtils { if (TypeCoercion.haveSameType(types)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure( - s"input to $caller should all be the same type, but it's " + - types.map(_.catalogString).mkString("[", ", ", "]")) + DataTypeMismatch( + errorSubClass = "DATA_DIFF_TYPES", + messageParameters = Map( + "functionName" -> caller, + "dataType" -> types.map(toSQLType).mkString("(", " or ", ")") + ) + ) } } def checkForMapKeyType(keyType: DataType): TypeCheckResult = { if 
(keyType.existsRecursively(_.isInstanceOf[MapType])) { - TypeCheckResult.TypeCheckFailure("The key of map cannot be/contain map.") + DataTypeMismatch( + errorSubClass = "INVALID_MAP_KEY_TYPE", + messageParameters = Map( + "keyType" -> toSQLType(keyType) + ) + ) } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index a3e1b980d1fa7..9b043957d2cf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.errors import java.io.{FileNotFoundException, IOException} import java.lang.reflect.InvocationTargetException import java.net.{URISyntaxException, URL} -import java.sql.{SQLException, SQLFeatureNotSupportedException} +import java.sql.{SQLFeatureNotSupportedException} import java.time.{DateTimeException, LocalDate} import java.time.temporal.ChronoField import java.util.ConcurrentModificationException @@ -967,39 +967,55 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { def fileLengthExceedsMaxLengthError(status: FileStatus, maxLength: Int): Throwable = { new SparkException( - s"The length of ${status.getPath} is ${status.getLen}, " + - s"which exceeds the max length allowed: ${maxLength}.") + errorClass = "_LEGACY_ERROR_TEMP_2076", + messageParameters = Map( + "path" -> status.getPath.toString(), + "len" -> status.getLen.toString(), + "maxLength" -> maxLength.toString()), + cause = null) } - def unsupportedFieldNameError(fieldName: String): Throwable = { - new RuntimeException(s"Unsupported field name: ${fieldName}") + def unsupportedFieldNameError(fieldName: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2077", + messageParameters = Map("fieldName" -> fieldName)) } def cannotSpecifyBothJdbcTableNameAndQueryError( - jdbcTableName: String, jdbcQueryString: String): Throwable = { - new IllegalArgumentException( - s"Both '$jdbcTableName' and '$jdbcQueryString' can not be specified at the same time.") + jdbcTableName: String, jdbcQueryString: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2078", + messageParameters = Map( + "jdbcTableName" -> jdbcTableName, + "jdbcQueryString" -> jdbcQueryString)) } def missingJdbcTableNameAndQueryError( - jdbcTableName: String, jdbcQueryString: String): Throwable = { - new IllegalArgumentException( - s"Option '$jdbcTableName' or '$jdbcQueryString' is required." - ) + jdbcTableName: String, jdbcQueryString: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2079", + messageParameters = Map( + "jdbcTableName" -> jdbcTableName, + "jdbcQueryString" -> jdbcQueryString)) } - def emptyOptionError(optionName: String): Throwable = { - new IllegalArgumentException(s"Option `$optionName` can not be empty.") + def emptyOptionError(optionName: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2080", + messageParameters = Map("optionName" -> optionName)) } - def invalidJdbcTxnIsolationLevelError(jdbcTxnIsolationLevel: String, value: String): Throwable = { - new IllegalArgumentException( - s"Invalid value `$value` for parameter `$jdbcTxnIsolationLevel`. 
This can be " + - "`NONE`, `READ_UNCOMMITTED`, `READ_COMMITTED`, `REPEATABLE_READ` or `SERIALIZABLE`.") + def invalidJdbcTxnIsolationLevelError( + jdbcTxnIsolationLevel: String, value: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2081", + messageParameters = Map("value" -> value, "jdbcTxnIsolationLevel" -> jdbcTxnIsolationLevel)) } - def cannotGetJdbcTypeError(dt: DataType): Throwable = { - new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}") + def cannotGetJdbcTypeError(dt: DataType): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2082", + messageParameters = Map("catalogString" -> dt.catalogString)) } def unrecognizedSqlTypeError(sqlType: Int): Throwable = { @@ -1008,27 +1024,35 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { messageParameters = Map("typeName" -> sqlType.toString)) } - def unsupportedJdbcTypeError(content: String): Throwable = { - new SQLException(s"Unsupported type $content") + def unsupportedJdbcTypeError(content: String): SparkSQLException = { + new SparkSQLException( + errorClass = "_LEGACY_ERROR_TEMP_2083", + messageParameters = Map("content" -> content)) } - def unsupportedArrayElementTypeBasedOnBinaryError(dt: DataType): Throwable = { - new IllegalArgumentException(s"Unsupported array element " + - s"type ${dt.catalogString} based on binary") + def unsupportedArrayElementTypeBasedOnBinaryError(dt: DataType): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2084", + messageParameters = Map("catalogString" -> dt.catalogString)) } - def nestedArraysUnsupportedError(): Throwable = { - new IllegalArgumentException("Nested arrays unsupported") + def nestedArraysUnsupportedError(): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2085", + messageParameters = Map.empty) } - def cannotTranslateNonNullValueForFieldError(pos: Int): Throwable = { - new IllegalArgumentException(s"Can't translate non-null value for field $pos") + def cannotTranslateNonNullValueForFieldError(pos: Int): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2086", + messageParameters = Map("pos" -> pos.toString())) } - def invalidJdbcNumPartitionsError(n: Int, jdbcNumPartitions: String): Throwable = { - new IllegalArgumentException( - s"Invalid value `$n` for parameter `$jdbcNumPartitions` in table writing " + - "via JDBC. 
The minimum value is 1.") + def invalidJdbcNumPartitionsError( + n: Int, jdbcNumPartitions: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2087", + messageParameters = Map("n" -> n.toString(), "jdbcNumPartitions" -> jdbcNumPartitions)) } def transactionUnsupportedByJdbcServerError(): Throwable = { @@ -1037,72 +1061,99 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { messageParameters = Map.empty[String, String]) } - def dataTypeUnsupportedYetError(dataType: DataType): Throwable = { - new UnsupportedOperationException(s"$dataType is not supported yet.") + def dataTypeUnsupportedYetError(dataType: DataType): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2088", + messageParameters = Map("dataType" -> dataType.toString())) } - def unsupportedOperationForDataTypeError(dataType: DataType): Throwable = { - new UnsupportedOperationException(s"DataType: ${dataType.catalogString}") + def unsupportedOperationForDataTypeError( + dataType: DataType): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2089", + messageParameters = Map("catalogString" -> dataType.catalogString)) } def inputFilterNotFullyConvertibleError(owner: String): Throwable = { - new SparkException(s"The input filter of $owner should be fully convertible.") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2090", + messageParameters = Map("owner" -> owner), + cause = null) } def cannotReadFooterForFileError(file: Path, e: IOException): Throwable = { - new SparkException(s"Could not read footer for file: $file", e) + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2091", + messageParameters = Map("file" -> file.toString()), + cause = e) } def cannotReadFooterForFileError(file: FileStatus, e: RuntimeException): Throwable = { - new IOException(s"Could not read footer for file: $file", e) + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2092", + messageParameters = Map("file" -> file.toString()), + cause = e) } def foundDuplicateFieldInCaseInsensitiveModeError( - requiredFieldName: String, matchedOrcFields: String): Throwable = { - new RuntimeException( - s""" - |Found duplicate field(s) "$requiredFieldName": $matchedOrcFields - |in case-insensitive mode - """.stripMargin.replaceAll("\n", " ")) + requiredFieldName: String, matchedOrcFields: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2093", + messageParameters = Map( + "requiredFieldName" -> requiredFieldName, + "matchedOrcFields" -> matchedOrcFields)) } def foundDuplicateFieldInFieldIdLookupModeError( - requiredId: Int, matchedFields: String): Throwable = { - new RuntimeException( - s""" - |Found duplicate field(s) "$requiredId": $matchedFields - |in id mapping mode - """.stripMargin.replaceAll("\n", " ")) + requiredId: Int, matchedFields: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2094", + messageParameters = Map( + "requiredId" -> requiredId.toString(), + "matchedFields" -> matchedFields)) } def failedToMergeIncompatibleSchemasError( left: StructType, right: StructType, e: Throwable): Throwable = { - new SparkException(s"Failed to merge incompatible schemas $left and $right", e) + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2095", + messageParameters = Map("left" -> left.toString(), "right" -> right.toString()), + cause = 
e) } - def ddlUnsupportedTemporarilyError(ddl: String): Throwable = { - new UnsupportedOperationException(s"$ddl is not supported temporarily.") + def ddlUnsupportedTemporarilyError(ddl: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2096", + messageParameters = Map("ddl" -> ddl)) } def executeBroadcastTimeoutError(timeout: Long, ex: Option[TimeoutException]): Throwable = { new SparkException( - s""" - |Could not execute broadcast in $timeout secs. You can increase the timeout - |for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or disable broadcast join - |by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 - """.stripMargin.replaceAll("\n", " "), ex.orNull) + errorClass = "_LEGACY_ERROR_TEMP_2097", + messageParameters = Map( + "timeout" -> timeout.toString(), + "broadcastTimeout" -> toSQLConf(SQLConf.BROADCAST_TIMEOUT.key), + "autoBroadcastJoinThreshold" -> toSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key)), + cause = ex.orNull) } - def cannotCompareCostWithTargetCostError(cost: String): Throwable = { - new IllegalArgumentException(s"Could not compare cost with $cost") + def cannotCompareCostWithTargetCostError(cost: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2098", + messageParameters = Map("cost" -> cost)) } - def unsupportedDataTypeError(dt: String): Throwable = { - new UnsupportedOperationException(s"Unsupported data type: ${dt}") + def unsupportedDataTypeError(dt: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2099", + messageParameters = Map("dt" -> dt)) } def notSupportTypeError(dataType: DataType): Throwable = { - new Exception(s"not support type: $dataType") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2100", + messageParameters = Map("dataType" -> dataType.toString()), + cause = null) } def notSupportNonPrimitiveTypeError(): Throwable = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index 385f749736846..82731cdb220a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql +import scala.collection.mutable.ArraySeq + +import org.json4s.JsonAST.{JArray, JObject, JString} import org.scalatest.funspec.AnyFunSpec import org.scalatest.matchers.must.Matchers import org.scalatest.matchers.should.Matchers._ @@ -91,6 +94,14 @@ class RowTest extends AnyFunSpec with Matchers { it("getAs() on type extending AnyVal does not throw exception when value is null") { sampleRowWithoutCol3.getAs[String](sampleRowWithoutCol3.fieldIndex("col1")) shouldBe null } + + it("json should convert a mutable array to JSON") { + val schema = new StructType().add(StructField("list", ArrayType(StringType))) + val values = ArraySeq("1", "2", "3") + val row = new GenericRowWithSchema(Array(values), schema) + val expectedList = JArray(JString("1") :: JString("2") :: JString("3") :: Nil) + row.jsonValue shouldBe new JObject(("list", expectedList) :: Nil) + } } describe("row equals") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index dc007b5238648..d6a63dd0d0bf5 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -719,7 +719,17 @@ class AnalysisErrorSuite extends AnalysisTest { right, joinType = Cross, condition = Some($"b" === $"d")) - assertAnalysisError(plan2, "EqualTo does not support ordering on type map" :: Nil) + + assertAnalysisErrorClass( + inputPlan = plan2, + expectedErrorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + expectedMessageParameters = Map( + "functionName" -> "EqualTo", + "dataType" -> "\"MAP\"", + "sqlExpr" -> "\"(b = d)\"" + ), + caseSensitive = true + ) } test("PredicateSubQuery is used outside of a allowed nodes") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 933645a413540..633f9a648157d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -25,10 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper { +class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with QueryErrorsBase { val testRelation = LocalRelation( $"intField".int, @@ -52,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper { SimpleAnalyzer.checkAnalysis(analyzed) } - def assertErrorForDifferingTypes( + def assertErrorForBinaryDifferingTypes( expr: Expression, messageParameters: Map[String, String]): Unit = { checkError( exception = intercept[AnalysisException] { @@ -62,6 +63,26 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper { parameters = messageParameters) } + def assertErrorForOrderingTypes( + expr: Expression, messageParameters: Map[String, String]): Unit = { + checkError( + exception = intercept[AnalysisException] { + assertSuccess(expr) + }, + errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + parameters = messageParameters) + } + + def assertErrorForDataDifferingTypes( + expr: Expression, messageParameters: Map[String, String]): Unit = { + checkError( + exception = intercept[AnalysisException] { + assertSuccess(expr) + }, + errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + parameters = messageParameters) + } + def assertForWrongType(expr: Expression, messageParameters: Map[String, String]): Unit = { checkError( exception = intercept[AnalysisException] { @@ -94,49 +115,49 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper { assertSuccess(Remainder($"intField", $"stringField")) // checkAnalysis(BitwiseAnd($"intField", $"stringField")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = Add($"intField", $"booleanField"), messageParameters = Map( "sqlExpr" -> "\"(intField + booleanField)\"", "left" -> "\"INT\"", "right" -> "\"BOOLEAN\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = Subtract($"intField", $"booleanField"), 
messageParameters = Map( "sqlExpr" -> "\"(intField - booleanField)\"", "left" -> "\"INT\"", "right" -> "\"BOOLEAN\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = Multiply($"intField", $"booleanField"), messageParameters = Map( "sqlExpr" -> "\"(intField * booleanField)\"", "left" -> "\"INT\"", "right" -> "\"BOOLEAN\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = Divide($"intField", $"booleanField"), messageParameters = Map( "sqlExpr" -> "\"(intField / booleanField)\"", "left" -> "\"INT\"", "right" -> "\"BOOLEAN\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = Remainder($"intField", $"booleanField"), messageParameters = Map( "sqlExpr" -> "\"(intField % booleanField)\"", "left" -> "\"INT\"", "right" -> "\"BOOLEAN\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = BitwiseAnd($"intField", $"booleanField"), messageParameters = Map( "sqlExpr" -> "\"(intField & booleanField)\"", "left" -> "\"INT\"", "right" -> "\"BOOLEAN\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = BitwiseOr($"intField", $"booleanField"), messageParameters = Map( "sqlExpr" -> "\"(intField | booleanField)\"", "left" -> "\"INT\"", "right" -> "\"BOOLEAN\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = BitwiseXor($"intField", $"booleanField"), messageParameters = Map( "sqlExpr" -> "\"(intField ^ booleanField)\"", @@ -211,13 +232,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper { assertSuccess(EqualNullSafe($"intField", $"booleanField")) } withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = EqualTo($"intField", $"booleanField"), messageParameters = Map( "sqlExpr" -> "\"(intField = booleanField)\"", "left" -> "\"INT\"", "right" -> "\"BOOLEAN\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = EqualNullSafe($"intField", $"booleanField"), messageParameters = Map( "sqlExpr" -> "\"(intField <=> booleanField)\"", @@ -225,55 +246,99 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper { "right" -> "\"BOOLEAN\"")) } - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = EqualTo($"intField", $"mapField"), messageParameters = Map( "sqlExpr" -> "\"(intField = mapField)\"", "left" -> "\"INT\"", "right" -> "\"MAP\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = EqualNullSafe($"intField", $"mapField"), messageParameters = Map( "sqlExpr" -> "\"(intField <=> mapField)\"", "left" -> "\"INT\"", "right" -> "\"MAP\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = LessThan($"intField", $"booleanField"), messageParameters = Map( "sqlExpr" -> "\"(intField < booleanField)\"", "left" -> "\"INT\"", "right" -> "\"BOOLEAN\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = LessThanOrEqual($"intField", $"booleanField"), messageParameters = Map( "sqlExpr" -> "\"(intField <= booleanField)\"", "left" -> "\"INT\"", "right" -> "\"BOOLEAN\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = GreaterThan($"intField", $"booleanField"), messageParameters = Map( "sqlExpr" -> "\"(intField > booleanField)\"", "left" -> "\"INT\"", "right" -> "\"BOOLEAN\"")) - assertErrorForDifferingTypes( + assertErrorForBinaryDifferingTypes( expr = GreaterThanOrEqual($"intField", $"booleanField"), 
messageParameters = Map( "sqlExpr" -> "\"(intField >= booleanField)\"", "left" -> "\"INT\"", "right" -> "\"BOOLEAN\"")) - assertError(EqualTo($"mapField", $"mapField"), - "EqualTo does not support ordering on type map") - assertError(EqualNullSafe($"mapField", $"mapField"), - "EqualNullSafe does not support ordering on type map") - assertError(LessThan($"mapField", $"mapField"), - "LessThan does not support ordering on type map") - assertError(LessThanOrEqual($"mapField", $"mapField"), - "LessThanOrEqual does not support ordering on type map") - assertError(GreaterThan($"mapField", $"mapField"), - "GreaterThan does not support ordering on type map") - assertError(GreaterThanOrEqual($"mapField", $"mapField"), - "GreaterThanOrEqual does not support ordering on type map") + assertErrorForOrderingTypes( + expr = EqualTo($"mapField", $"mapField"), + messageParameters = Map( + "sqlExpr" -> "\"(mapField = mapField)\"", + "functionName" -> "EqualTo", + "dataType" -> "\"MAP\"" + ) + ) + assertErrorForOrderingTypes( + expr = EqualNullSafe($"mapField", $"mapField"), + messageParameters = Map( + "sqlExpr" -> "\"(mapField <=> mapField)\"", + "functionName" -> "EqualNullSafe", + "dataType" -> "\"MAP\"" + ) + ) + assertErrorForOrderingTypes( + expr = LessThan($"mapField", $"mapField"), + messageParameters = Map( + "sqlExpr" -> "\"(mapField < mapField)\"", + "functionName" -> "LessThan", + "dataType" -> "\"MAP\"" + ) + ) + assertErrorForOrderingTypes( + expr = LessThanOrEqual($"mapField", $"mapField"), + messageParameters = Map( + "sqlExpr" -> "\"(mapField <= mapField)\"", + "functionName" -> "LessThanOrEqual", + "dataType" -> "\"MAP\"" + ) + ) + assertErrorForOrderingTypes( + expr = GreaterThan($"mapField", $"mapField"), + messageParameters = Map( + "sqlExpr" -> "\"(mapField > mapField)\"", + "functionName" -> "GreaterThan", + "dataType" -> "\"MAP\"" + ) + ) + assertErrorForOrderingTypes( + expr = GreaterThanOrEqual($"mapField", $"mapField"), + messageParameters = Map( + "sqlExpr" -> "\"(mapField >= mapField)\"", + "functionName" -> "GreaterThanOrEqual", + "dataType" -> "\"MAP\"" + ) + ) assertError(If($"intField", $"stringField", $"stringField"), "type of predicate expression in If should be boolean") @@ -305,18 +370,45 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper { assertSuccess(new BoolAnd($"booleanField")) assertSuccess(new BoolOr($"booleanField")) - assertError(Min($"mapField"), "min does not support ordering on type") - assertError(Max($"mapField"), "max does not support ordering on type") + assertErrorForOrderingTypes( + expr = Min($"mapField"), + messageParameters = Map( + "sqlExpr" -> "\"min(mapField)\"", + "functionName" -> "function min", + "dataType" -> "\"MAP\"" + ) + ) + assertErrorForOrderingTypes( + expr = Max($"mapField"), + messageParameters = Map( + "sqlExpr" -> "\"max(mapField)\"", + "functionName" -> "function max", + "dataType" -> "\"MAP\"" + ) + ) assertError(Sum($"booleanField"), "function sum requires numeric or interval types") assertError(Average($"booleanField"), "function average requires numeric or interval types") } test("check types for others") { - assertError(CreateArray(Seq($"intField", $"booleanField")), - "input to function array should all be the same type") - assertError(Coalesce(Seq($"intField", $"booleanField")), - "input to
function coalesce should all be the same type") + assertErrorForDataDifferingTypes( + expr = CreateArray(Seq($"intField", $"booleanField")), + messageParameters = Map( + "sqlExpr" -> "\"array(intField, booleanField)\"", + "functionName" -> "function array", + "dataType" -> "(\"INT\" or \"BOOLEAN\")" + ) + ) + assertErrorForDataDifferingTypes( + expr = Coalesce(Seq($"intField", $"booleanField")), + messageParameters = Map( + "sqlExpr" -> "\"coalesce(intField, booleanField)\"", + "functionName" -> "function coalesce", + "dataType" -> "(\"INT\" or \"BOOLEAN\")" + ) + ) + assertError(Coalesce(Nil), "function coalesce requires at least one argument") assertError(new Murmur3Hash(Nil), "function hash requires at least one argument") assertError(new XxHash64(Nil), "function xxhash64 requires at least one argument") @@ -437,8 +529,15 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper { assertError(operator(Seq($"booleanField")), "requires at least two arguments") assertError(operator(Seq($"intField", $"stringField")), "should all have the same type") - assertError(operator(Seq($"mapField", $"mapField")), - "does not support ordering") + val expr3 = operator(Seq($"mapField", $"mapField")) + assertErrorForOrderingTypes( + expr = expr3, + messageParameters = Map( + "sqlExpr" -> toSQLExpr(expr3), + "functionName" -> s"function ${expr3.prettyName}", + "dataType" -> "\"MAP\"" + ) + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 9a6caea59bf03..9839b784e6059 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -263,8 +263,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val map = MapConcat(Seq(mapOfMap, mapOfMap2)) map.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") - case TypeCheckResult.TypeCheckFailure(msg) => - assert(msg.contains("The key of map cannot be/contain map")) + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass === "INVALID_MAP_KEY_TYPE") + assert(messageParameters === Map("keyType" -> "\"MAP\"")) } } @@ -341,8 +342,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper arrayType(keyType = MapType(IntegerType, IntegerType), valueType = IntegerType))) map.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") - case TypeCheckResult.TypeCheckFailure(msg) => - assert(msg.contains("The key of map cannot be/contain map")) + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass === "INVALID_MAP_KEY_TYPE") + assert(messageParameters === Map("keyType" -> "\"MAP\"")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 3d9416fda4596..fb6a23e3d776c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -310,8 +310,9 @@ class ComplexTypeSuite 
extends SparkFunSuite with ExpressionEvalHelper { )) map2.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") - case TypeCheckResult.TypeCheckFailure(msg) => - assert(msg.contains("The key of map cannot be/contain map")) + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass == "INVALID_MAP_KEY_TYPE") + assert(messageParameters === Map("keyType" -> "\"MAP\"")) } } @@ -371,8 +372,9 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.create(Seq(1), ArrayType(IntegerType))) map.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") - case TypeCheckResult.TypeCheckFailure(msg) => - assert(msg.contains("The key of map cannot be/contain map")) + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass == "INVALID_MAP_KEY_TYPE") + assert(messageParameters === Map("keyType" -> "\"MAP\"")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index b1c4c4414274c..a6546d8a5dbb0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -524,8 +524,9 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val map = transformKeys(ai0, makeMap) map.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") - case TypeCheckResult.TypeCheckFailure(msg) => - assert(msg.contains("The key of map cannot be/contain map")) + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass == "INVALID_MAP_KEY_TYPE") + assert(messageParameters === Map("keyType" -> "\"MAP\"")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 0d1464ffdb029..5e5d0f7445e37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -239,7 +239,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { In(map, Seq(map)).checkInputDataTypes() match { case TypeCheckResult.TypeCheckFailure(msg) => assert(msg.contains("function in does not support ordering on type map")) - case _ => fail("In should not work on map type") + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass == "INVALID_ORDERING_TYPE") + assert(messageParameters === Map( + "functionName" -> "function in", "dataType" -> "\"MAP\"")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala index bc6852ca7e1fd..b209b93ce4d1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite -import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.types._ class TypeUtilsSuite extends SparkFunSuite { @@ -28,7 +28,8 @@ class TypeUtilsSuite extends SparkFunSuite { } private def typeCheckFail(types: Seq[DataType]): Unit = { - assert(TypeUtils.checkForSameTypeInputExpr(types, "a").isInstanceOf[TypeCheckFailure]) + assert(TypeUtils.checkForSameTypeInputExpr(types, "a") + .isInstanceOf[DataTypeMismatch]) } test("checkForSameTypeInputExpr") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala index a48eb04a98806..dd3d77f26cdd3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.connector.catalog import java.util import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException, PartitionsAlreadyExistException} +import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionsAlreadyExistException} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.types.StructType @@ -39,7 +39,7 @@ class InMemoryAtomicPartitionTable ( ident: InternalRow, properties: util.Map[String, String]): Unit = { if (memoryTablePartitions.containsKey(ident)) { - throw new PartitionAlreadyExistsException(name, ident, partitionSchema) + throw new PartitionsAlreadyExistException(name, ident, partitionSchema) } else { createPartitionKey(ident.toSeq(schema)) memoryTablePartitions.put(ident, properties) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala index 660140e282ecb..7280d6a5b0776 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala @@ -23,7 +23,7 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionsAlreadyExistException} import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.types.StructType @@ -51,7 +51,7 @@ class InMemoryPartitionTable( ident: InternalRow, properties: util.Map[String, String]): Unit = { if (memoryTablePartitions.containsKey(ident)) { - throw new PartitionAlreadyExistsException(name, ident, partitionSchema) + throw new PartitionsAlreadyExistException(name, ident, partitionSchema) } else { createPartitionKey(ident.toSeq(schema)) memoryTablePartitions.put(ident, properties) @@ -111,7 +111,7 @@ class InMemoryPartitionTable( override def renamePartition(from: InternalRow, to: InternalRow): Boolean = { if 
(memoryTablePartitions.containsKey(to)) { - throw new PartitionAlreadyExistsException(name, to, partitionSchema) + throw new PartitionsAlreadyExistException(name, to, partitionSchema) } else { val partValue = memoryTablePartitions.remove(from) if (partValue == null) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala index e5aeb90b841a6..7f7c529944501 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionsAlreadyExistException} import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference} import org.apache.spark.sql.types.{IntegerType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -218,10 +218,10 @@ class SupportsPartitionManagementSuite extends SparkFunSuite { test("renamePartition") { val partTable = createMultiPartTable() - val errMsg1 = intercept[PartitionAlreadyExistsException] { + val errMsg1 = intercept[PartitionsAlreadyExistException] { partTable.renamePartition(InternalRow(0, "abc"), InternalRow(1, "abc")) }.getMessage - assert(errMsg1.contains("Partition already exists")) + assert(errMsg1.contains("partitions already exist")) val newPart = InternalRow(2, "xyz") val errMsg2 = intercept[NoSuchPartitionException] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index d43a906067770..70474f4d5c43b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -155,6 +155,10 @@ private[sql] object PythonSQLUtils extends Logging { Column(TimestampDiff(unit, start.expr, end.expr)) } + def pandasProduct(e: Column, ignoreNA: Boolean): Column = { + Column(PandasProduct(e.expr, ignoreNA).toAggregateExpression(false)) + } + def pandasStddev(e: Column, ddof: Int): Column = { Column(PandasStddev(e.expr, ddof).toAggregateExpression(false)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndexOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndexOptions.scala new file mode 100644 index 0000000000000..1c352e3748f21 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndexOptions.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} +import org.apache.spark.sql.catalyst.util.DateTimeUtils + +object FileIndexOptions extends DataSourceOptions { + val IGNORE_MISSING_FILES = newOption(FileSourceOptions.IGNORE_MISSING_FILES) + val TIME_ZONE = newOption(DateTimeUtils.TIMEZONE_OPTION) + val RECURSIVE_FILE_LOOKUP = newOption("recursiveFileLookup") + val BASE_PATH_PARAM = newOption("basePath") + val MODIFIED_BEFORE = newOption("modifiedbefore") + val MODIFIED_AFTER = newOption("modifiedafter") + val PATH_GLOB_FILTER = newOption("pathglobfilter") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index d70c4b11bc0d7..53be85ad44844 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.FileFormat.createMetadataInternalRow import org.apache.spark.sql.types.StructType @@ -43,7 +43,6 @@ abstract class PartitioningAwareFileIndex( parameters: Map[String, String], userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends FileIndex with Logging { - import PartitioningAwareFileIndex.BASE_PATH_PARAM /** Returns the specification of the partitions inferred from the data. */ def partitionSpec(): PartitionSpec @@ -64,7 +63,7 @@ abstract class PartitioningAwareFileIndex( pathFilters.forall(_.accept(file)) protected lazy val recursiveFileLookup: Boolean = { - caseInsensitiveMap.getOrElse("recursiveFileLookup", "false").toBoolean + caseInsensitiveMap.getOrElse(FileIndexOptions.RECURSIVE_FILE_LOOKUP, "false").toBoolean } override def listFiles( @@ -178,7 +177,7 @@ abstract class PartitioningAwareFileIndex( }.keys.toSeq val caseInsensitiveOptions = CaseInsensitiveMap(parameters) - val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) + val timeZoneId = caseInsensitiveOptions.get(FileIndexOptions.TIME_ZONE) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) PartitioningUtils.parsePartitions( @@ -248,11 +247,12 @@ abstract class PartitioningAwareFileIndex( * and the returned DataFrame will have the column of `something`. 
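The basePath option registered above in FileIndexOptions controls the partition-discovery behaviour described in this comment. A hedged usage sketch follows; `spark` is an existing SparkSession, and the directory layout and partition column `something` are placeholders rather than part of the patch.

// Reading a single partition directory normally drops `something` from the
// schema; pointing basePath at the dataset root keeps it as a column.
val partitioned = spark.read
  .option("basePath", "/data/table")            // FileIndexOptions.BASE_PATH_PARAM
  .parquet("/data/table/something=true")

partitioned.printSchema()  // schema still contains the partition column `something`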
*/ private def basePaths: Set[Path] = { - caseInsensitiveMap.get(BASE_PATH_PARAM).map(new Path(_)) match { + caseInsensitiveMap.get(FileIndexOptions.BASE_PATH_PARAM).map(new Path(_)) match { case Some(userDefinedBasePath) => val fs = userDefinedBasePath.getFileSystem(hadoopConf) if (!fs.isDirectory(userDefinedBasePath)) { - throw new IllegalArgumentException(s"Option '$BASE_PATH_PARAM' must be a directory") + throw new IllegalArgumentException(s"Option '${FileIndexOptions.BASE_PATH_PARAM}' " + + s"must be a directory") } val qualifiedBasePath = fs.makeQualified(userDefinedBasePath) val qualifiedBasePathStr = qualifiedBasePath.toString @@ -279,7 +279,3 @@ abstract class PartitioningAwareFileIndex( !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } } - -object PartitioningAwareFileIndex { - val BASE_PATH_PARAM = "basePath" -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index ef1c2bb5b4104..1c819f07038ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -21,7 +21,7 @@ import java.util.Locale import org.apache.orc.OrcConf.COMPRESS -import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf @@ -45,9 +45,9 @@ class OrcOptions( val compressionCodec: String = { // `compression`, `orc.compress`(i.e., OrcConf.COMPRESS), and `spark.sql.orc.compression.codec` // are in order of precedence from highest to lowest. 
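A hedged illustration of the precedence spelled out in the comment above, written with the public DataFrameWriter API; `df` is an existing DataFrame, and the codec names and output path are arbitrary examples rather than part of the patch.

// "zlib" wins because the generic `compression` option has the highest
// precedence; "snappy" supplied via orc.compress is ignored here.
df.write
  .option("compression", "zlib")
  .option("orc.compress", "snappy")
  .orc("/tmp/orc-demo")
// With neither option set, spark.sql.orc.compression.codec supplies the codec.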
- val orcCompressionConf = parameters.get(COMPRESS.getAttribute) + val orcCompressionConf = parameters.get(ORC_COMPRESSION) val codecName = parameters - .get("compression") + .get(COMPRESSION) .orElse(orcCompressionConf) .getOrElse(sqlConf.orcCompressionCodec) .toLowerCase(Locale.ROOT) @@ -69,8 +69,10 @@ class OrcOptions( .getOrElse(sqlConf.isOrcSchemaMergingEnabled) } -object OrcOptions { - val MERGE_SCHEMA = "mergeSchema" +object OrcOptions extends DataSourceOptions { + val MERGE_SCHEMA = newOption("mergeSchema") + val ORC_COMPRESSION = newOption(COMPRESS.getAttribute) + val COMPRESSION = newOption("compression") // The ORC compression short names private val shortOrcCompressionCodecNames = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 07ed55b0b8f84..d20edbde00be5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf @@ -47,9 +47,9 @@ class ParquetOptions( // `compression`, `parquet.compression`(i.e., ParquetOutputFormat.COMPRESSION), and // `spark.sql.parquet.compression.codec` // are in order of precedence from highest to lowest. - val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION) + val parquetCompressionConf = parameters.get(PARQUET_COMPRESSION) val codecName = parameters - .get("compression") + .get(COMPRESSION) .orElse(parquetCompressionConf) .getOrElse(sqlConf.parquetCompressionCodec) .toLowerCase(Locale.ROOT) @@ -86,9 +86,7 @@ class ParquetOptions( } -object ParquetOptions { - val MERGE_SCHEMA = "mergeSchema" - +object ParquetOptions extends DataSourceOptions { // The parquet compression short names private val shortParquetCompressionCodecNames = Map( "none" -> CompressionCodecName.UNCOMPRESSED, @@ -104,15 +102,19 @@ object ParquetOptions { shortParquetCompressionCodecNames(name).name() } + val MERGE_SCHEMA = newOption("mergeSchema") + val PARQUET_COMPRESSION = newOption(ParquetOutputFormat.COMPRESSION) + val COMPRESSION = newOption("compression") + // The option controls rebasing of the DATE and TIMESTAMP values between // Julian and Proleptic Gregorian calendars. It impacts on the behaviour of the Parquet // datasource similarly to the SQL config `spark.sql.parquet.datetimeRebaseModeInRead`, // and can be set to the same values: `EXCEPTION`, `LEGACY` or `CORRECTED`. - val DATETIME_REBASE_MODE = "datetimeRebaseMode" + val DATETIME_REBASE_MODE = newOption("datetimeRebaseMode") // The option controls rebasing of the INT96 timestamp values between Julian and Proleptic // Gregorian calendars. It impacts on the behaviour of the Parquet datasource similarly to // the SQL config `spark.sql.parquet.int96RebaseModeInRead`. // The valid option values are: `EXCEPTION`, `LEGACY` or `CORRECTED`. 
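Both rebase options described above can also be supplied per read; they accept EXCEPTION, LEGACY or CORRECTED and act as per-read counterparts of the spark.sql.parquet.datetimeRebaseModeInRead and int96RebaseModeInRead configs. A hedged example follows; `spark` is an existing SparkSession and the input path is a placeholder.

// Per-read override of the rebase behaviour: treat DATE/TIMESTAMP values as
// already proleptic Gregorian (CORRECTED) but rebase INT96 timestamps from
// the hybrid Julian calendar (LEGACY).
val legacy = spark.read
  .option("datetimeRebaseMode", "CORRECTED")   // ParquetOptions.DATETIME_REBASE_MODE
  .option("int96RebaseMode", "LEGACY")         // ParquetOptions.INT96_REBASE_MODE
  .parquet("/tmp/legacy-parquet")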
- val INT96_REBASE_MODE = "int96RebaseMode" + val INT96_REBASE_MODE = newOption("int96RebaseMode") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala index d07e1957e8c6f..303129b4d576f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala @@ -43,10 +43,8 @@ class PathGlobFilter(filePatten: String) extends PathFilterStrategy { } object PathGlobFilter extends StrategyBuilder { - val PARAM_NAME = "pathglobfilter" - override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = { - parameters.get(PARAM_NAME).map(new PathGlobFilter(_)) + parameters.get(FileIndexOptions.PATH_GLOB_FILTER).map(new PathGlobFilter(_)) } } @@ -111,12 +109,10 @@ class ModifiedBeforeFilter(thresholdTime: Long, val timeZoneId: String) object ModifiedBeforeFilter extends StrategyBuilder { import ModifiedDateFilter._ - val PARAM_NAME = "modifiedbefore" - override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = { - parameters.get(PARAM_NAME).map { value => + parameters.get(FileIndexOptions.MODIFIED_BEFORE).map { value => val timeZoneId = getTimeZoneId(parameters) - val thresholdTime = toThreshold(value, timeZoneId, PARAM_NAME) + val thresholdTime = toThreshold(value, timeZoneId, FileIndexOptions.MODIFIED_BEFORE) new ModifiedBeforeFilter(thresholdTime, timeZoneId) } } @@ -137,12 +133,10 @@ class ModifiedAfterFilter(thresholdTime: Long, val timeZoneId: String) object ModifiedAfterFilter extends StrategyBuilder { import ModifiedDateFilter._ - val PARAM_NAME = "modifiedafter" - override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = { - parameters.get(PARAM_NAME).map { value => + parameters.get(FileIndexOptions.MODIFIED_AFTER).map { value => val timeZoneId = getTimeZoneId(parameters) - val thresholdTime = toThreshold(value, timeZoneId, PARAM_NAME) + val thresholdTime = toThreshold(value, timeZoneId, FileIndexOptions.MODIFIED_AFTER) new ModifiedAfterFilter(thresholdTime, timeZoneId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index f1a1d465d1b8c..f26f05cbe1c55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.text import java.nio.charset.{Charset, StandardCharsets} -import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} /** @@ -44,8 +44,8 @@ class TextOptions(@transient private val parameters: CaseInsensitiveMap[String]) val encoding: Option[String] = parameters.get(ENCODING) - val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { lineSep => - require(lineSep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + val lineSeparator: Option[String] = parameters.get(LINE_SEP).map { lineSep => + require(lineSep.nonEmpty, s"'$LINE_SEP' cannot be an empty string.") lineSep } @@ -58,9 +58,9 @@ class TextOptions(@transient private 
val parameters: CaseInsensitiveMap[String]) lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } -private[datasources] object TextOptions { - val COMPRESSION = "compression" - val WHOLETEXT = "wholetext" - val ENCODING = "encoding" - val LINE_SEPARATOR = "lineSep" +private[sql] object TextOptions extends DataSourceOptions { + val COMPRESSION = newOption("compression") + val WHOLETEXT = newOption("wholetext") + val ENCODING = newOption("encoding") + val LINE_SEP = newOption("lineSep") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala index a5c1c735cbd7b..ae09095590865 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -23,7 +23,7 @@ import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution.datasources.{ModifiedAfterFilter, ModifiedBeforeFilter} +import org.apache.spark.sql.execution.datasources.FileIndexOptions import org.apache.spark.util.Utils /** @@ -36,7 +36,7 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging checkDisallowedOptions() private def checkDisallowedOptions(): Unit = { - Seq(ModifiedBeforeFilter.PARAM_NAME, ModifiedAfterFilter.PARAM_NAME).foreach { param => + Seq(FileIndexOptions.MODIFIED_BEFORE, FileIndexOptions.MODIFIED_AFTER).foreach { param => if (parameters.contains(param)) { throw new IllegalArgumentException(s"option '$param' is not allowed in file stream sources") } diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 93fe7cf5956ac..2078d3d8eb686 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -3605,7 +3605,21 @@ SELECT array(INTERVAL 1 MONTH, INTERVAL 20 DAYS) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'array(INTERVAL '1' MONTH, INTERVAL '20' DAY)' due to data type mismatch: input to function array should all be the same type, but it's [interval month, interval day]; line 1 pos 7 +{ + "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + "messageParameters" : { + "dataType" : "(\"INTERVAL MONTH\" or \"INTERVAL DAY\")", + "functionName" : "function array", + "sqlExpr" : "\"array(INTERVAL '1' MONTH, INTERVAL '20' DAY)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 48, + "fragment" : "array(INTERVAL 1 MONTH, INTERVAL 20 DAYS)" + } ] +} -- !query @@ -3630,7 +3644,21 @@ SELECT coalesce(INTERVAL 1 MONTH, INTERVAL 20 DAYS) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'coalesce(INTERVAL '1' MONTH, INTERVAL '20' DAY)' due to data type mismatch: input to function coalesce should all be the same type, but it's [interval month, interval day]; line 1 pos 7 +{ + "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + "messageParameters" : { + "dataType" : "(\"INTERVAL MONTH\" or \"INTERVAL DAY\")", + "functionName" : "function coalesce", + "sqlExpr" : "\"coalesce(INTERVAL '1' MONTH, INTERVAL '20' DAY)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + 
"stopIndex" : 51, + "fragment" : "coalesce(INTERVAL 1 MONTH, INTERVAL 20 DAYS)" + } ] +} -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out index cd7cf9a60ce37..c7c6f5578e76d 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out @@ -69,7 +69,23 @@ select map_contains_key(map('1', 'a', '2', 'b'), 1) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_contains_key(map('1', 'a', '2', 'b'), 1)' due to data type mismatch: Input to function map_contains_key should have been map followed by a value with same key type, but it's [map, int].; line 1 pos 7 +{ + "errorClass" : "DATATYPE_MISMATCH.MAP_CONTAINS_KEY_DIFF_TYPES", + "messageParameters" : { + "dataType" : "\"MAP\"", + "functionName" : "map_contains_key", + "leftType" : "\"MAP\"", + "rightType" : "\"INT\"", + "sqlExpr" : "\"map_contains_key(map(1, a, 2, b), 1)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 51, + "fragment" : "map_contains_key(map('1', 'a', '2', 'b'), 1)" + } ] +} -- !query @@ -78,4 +94,20 @@ select map_contains_key(map(1, 'a', 2, 'b'), '1') struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_contains_key(map(1, 'a', 2, 'b'), '1')' due to data type mismatch: Input to function map_contains_key should have been map followed by a value with same key type, but it's [map, string].; line 1 pos 7 +{ + "errorClass" : "DATATYPE_MISMATCH.MAP_CONTAINS_KEY_DIFF_TYPES", + "messageParameters" : { + "dataType" : "\"MAP\"", + "functionName" : "map_contains_key", + "leftType" : "\"MAP\"", + "rightType" : "\"STRING\"", + "sqlExpr" : "\"map_contains_key(map(1, a, 2, b), 1)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 49, + "fragment" : "map_contains_key(map(1, 'a', 2, 'b'), '1')" + } ] +} \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 88fd0f538514e..6eb5fb4ce8447 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -3418,7 +3418,21 @@ SELECT array(INTERVAL 1 MONTH, INTERVAL 20 DAYS) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'array(INTERVAL '1' MONTH, INTERVAL '20' DAY)' due to data type mismatch: input to function array should all be the same type, but it's [interval month, interval day]; line 1 pos 7 +{ + "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + "messageParameters" : { + "dataType" : "(\"INTERVAL MONTH\" or \"INTERVAL DAY\")", + "functionName" : "function array", + "sqlExpr" : "\"array(INTERVAL '1' MONTH, INTERVAL '20' DAY)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 48, + "fragment" : "array(INTERVAL 1 MONTH, INTERVAL 20 DAYS)" + } ] +} -- !query @@ -3443,7 +3457,21 @@ SELECT coalesce(INTERVAL 1 MONTH, INTERVAL 20 DAYS) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'coalesce(INTERVAL '1' MONTH, INTERVAL '20' DAY)' due to data type mismatch: input to function coalesce should all be the same type, but it's [interval month, interval day]; line 1 pos 7 +{ + "errorClass" : 
"DATATYPE_MISMATCH.DATA_DIFF_TYPES", + "messageParameters" : { + "dataType" : "(\"INTERVAL MONTH\" or \"INTERVAL DAY\")", + "functionName" : "function coalesce", + "sqlExpr" : "\"coalesce(INTERVAL '1' MONTH, INTERVAL '20' DAY)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 51, + "fragment" : "coalesce(INTERVAL 1 MONTH, INTERVAL 20 DAYS)" + } ] +} -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/map.sql.out b/sql/core/src/test/resources/sql-tests/results/map.sql.out index cd7cf9a60ce37..c7c6f5578e76d 100644 --- a/sql/core/src/test/resources/sql-tests/results/map.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/map.sql.out @@ -69,7 +69,23 @@ select map_contains_key(map('1', 'a', '2', 'b'), 1) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_contains_key(map('1', 'a', '2', 'b'), 1)' due to data type mismatch: Input to function map_contains_key should have been map followed by a value with same key type, but it's [map, int].; line 1 pos 7 +{ + "errorClass" : "DATATYPE_MISMATCH.MAP_CONTAINS_KEY_DIFF_TYPES", + "messageParameters" : { + "dataType" : "\"MAP\"", + "functionName" : "map_contains_key", + "leftType" : "\"MAP\"", + "rightType" : "\"INT\"", + "sqlExpr" : "\"map_contains_key(map(1, a, 2, b), 1)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 51, + "fragment" : "map_contains_key(map('1', 'a', '2', 'b'), 1)" + } ] +} -- !query @@ -78,4 +94,20 @@ select map_contains_key(map(1, 'a', 2, 'b'), '1') struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_contains_key(map(1, 'a', 2, 'b'), '1')' due to data type mismatch: Input to function map_contains_key should have been map followed by a value with same key type, but it's [map, string].; line 1 pos 7 +{ + "errorClass" : "DATATYPE_MISMATCH.MAP_CONTAINS_KEY_DIFF_TYPES", + "messageParameters" : { + "dataType" : "\"MAP\"", + "functionName" : "map_contains_key", + "leftType" : "\"MAP\"", + "rightType" : "\"STRING\"", + "sqlExpr" : "\"map_contains_key(map(1, a, 2, b), 1)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 49, + "fragment" : "map_contains_key(map(1, 'a', 2, 'b'), '1')" + } ] +} \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out index 916d32c5e35c7..5c1f5d4d917f5 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -40,7 +40,6 @@ struct<> -- !query output - -- !query SELECT map_concat(boolean_map1, boolean_map2) boolean_map, @@ -91,7 +90,21 @@ FROM various_maps struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.tinyint_map1, various_maps.array_map1)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,array>]; line 2 pos 4 +{ + "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + "messageParameters" : { + "dataType" : "(\"MAP\" or \"MAP, ARRAY>\")", + "functionName" : "function map_concat", + "sqlExpr" : "\"map_concat(tinyint_map1, array_map1)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 47, + 
"fragment" : "map_concat(tinyint_map1, array_map1)" + } ] +} -- !query @@ -102,7 +115,21 @@ FROM various_maps struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.boolean_map1, various_maps.int_map2)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map]; line 2 pos 4 +{ + "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + "messageParameters" : { + "dataType" : "(\"MAP\" or \"MAP\")", + "functionName" : "function map_concat", + "sqlExpr" : "\"map_concat(boolean_map1, int_map2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 45, + "fragment" : "map_concat(boolean_map1, int_map2)" + } ] +} -- !query @@ -113,7 +140,21 @@ FROM various_maps struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.int_map1, various_maps.struct_map2)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,struct>]; line 2 pos 4 +{ + "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + "messageParameters" : { + "dataType" : "(\"MAP\" or \"MAP, STRUCT>\")", + "functionName" : "function map_concat", + "sqlExpr" : "\"map_concat(int_map1, struct_map2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 44, + "fragment" : "map_concat(int_map1, struct_map2)" + } ] +} -- !query @@ -124,7 +165,21 @@ FROM various_maps struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.struct_map1, various_maps.array_map2)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,struct>, map,array>]; line 2 pos 4 +{ + "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + "messageParameters" : { + "dataType" : "(\"MAP, STRUCT>\" or \"MAP, ARRAY>\")", + "functionName" : "function map_concat", + "sqlExpr" : "\"map_concat(struct_map1, array_map2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 46, + "fragment" : "map_concat(struct_map1, array_map2)" + } ] +} -- !query @@ -135,4 +190,18 @@ FROM various_maps struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_concat(various_maps.int_map1, various_maps.array_map2)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,array>]; line 2 pos 4 +{ + "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + "messageParameters" : { + "dataType" : "(\"MAP\" or \"MAP, ARRAY>\")", + "functionName" : "function map_concat", + "sqlExpr" : "\"map_concat(int_map1, array_map2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 43, + "fragment" : "map_concat(int_map1, array_map2)" + } ] +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f58e30c71188f..90e2acfe5d688 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -913,8 +913,17 @@ class DataFrameAggregateSuite extends QueryTest val error = intercept[AnalysisException] { sql("SELECT max_by(x, y) FROM tempView").show } - assert( - error.message.contains("function max_by does not 
support ordering on type map")) + checkError( + exception = error, + errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + sqlState = None, + parameters = Map( + "functionName" -> "function max_by", + "dataType" -> "\"MAP\"", + "sqlExpr" -> "\"max_by(x, y)\"" + ), + context = ExpectedContext(fragment = "max_by(x, y)", start = 7, stop = 18) + ) } } @@ -974,8 +983,17 @@ class DataFrameAggregateSuite extends QueryTest val error = intercept[AnalysisException] { sql("SELECT min_by(x, y) FROM tempView").show } - assert( - error.message.contains("function min_by does not support ordering on type map")) + checkError( + exception = error, + errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + sqlState = None, + parameters = Map( + "functionName" -> "function min_by", + "dataType" -> "\"MAP\"", + "sqlExpr" -> "\"min_by(x, y)\"" + ), + context = ExpectedContext(fragment = "min_by(x, y)", start = 7, stop = 18) + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 2b6ac9f6580d1..41af747a83e41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1003,25 +1003,61 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3) checkAnswer(df3.select(map_concat($"map1", $"map2")), expected3) - val expectedMessage1 = "input to function map_concat should all be the same type" - - assert(intercept[AnalysisException] { - df2.selectExpr("map_concat(map1, map2)").collect() - }.getMessage().contains(expectedMessage1)) - - assert(intercept[AnalysisException] { - df2.select(map_concat($"map1", $"map2")).collect() - }.getMessage().contains(expectedMessage1)) + checkError( + exception = intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, map2)").collect() + }, + errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"map_concat(map1, map2)\"", + "dataType" -> "(\"MAP, INT>\" or \"MAP\")", + "functionName" -> "function map_concat"), + context = ExpectedContext( + fragment = "map_concat(map1, map2)", + start = 0, + stop = 21) + ) - val expectedMessage2 = "input to function map_concat should all be of type map" + checkError( + exception = intercept[AnalysisException] { + df2.select(map_concat($"map1", $"map2")).collect() + }, + errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"map_concat(map1, map2)\"", + "dataType" -> "(\"MAP, INT>\" or \"MAP\")", + "functionName" -> "function map_concat") + ) - assert(intercept[AnalysisException] { - df2.selectExpr("map_concat(map1, 12)").collect() - }.getMessage().contains(expectedMessage2)) + checkError( + exception = intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, 12)").collect() + }, + errorClass = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"map_concat(map1, 12)\"", + "dataType" -> "[\"MAP, INT>\", \"INT\"]", + "functionName" -> "function map_concat"), + context = ExpectedContext( + fragment = "map_concat(map1, 12)", + start = 0, + stop = 19) + ) - assert(intercept[AnalysisException] { - df2.select(map_concat($"map1", lit(12))).collect() - }.getMessage().contains(expectedMessage2)) + checkError( + exception = intercept[AnalysisException] { + 
df2.select(map_concat($"map1", lit(12))).collect() + }, + errorClass = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"map_concat(map1, 12)\"", + "dataType" -> "[\"MAP, INT>\", \"INT\"]", + "functionName" -> "function map_concat") + ) } test("map_from_entries function") { @@ -3606,10 +3642,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\"")) // scalastyle:on line.size.limit - val ex5 = intercept[AnalysisException] { - df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") - } - assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map")) + checkError( + exception = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") + }, + errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"map_zip_with(mmi, mmi, lambdafunction(x, x, y, z))\"", + "dataType" -> "\"MAP\"", + "functionName" -> "function map_zip_with"), + context = ExpectedContext( + fragment = "map_zip_with(mmi, mmi, (x, y, z) -> x)", + start = 0, + stop = 37)) } test("transform keys function - primitive data types") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionSuiteBase.scala index 080cd89c4a209..6e67946a557ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionSuiteBase.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionsAlreadyExistException} import org.apache.spark.sql.internal.SQLConf /** @@ -75,15 +75,15 @@ trait AlterTableRenamePartitionSuiteBase extends QueryTest with DDLCommandTestUt } } - test("target partition exists") { + test("target partitions exist") { withNamespaceAndTable("ns", "tbl") { t => createSinglePartTable(t) sql(s"INSERT INTO $t PARTITION (id = 2) SELECT 'def'") checkPartitions(t, Map("id" -> "1"), Map("id" -> "2")) - val errMsg = intercept[PartitionAlreadyExistsException] { + val errMsg = intercept[PartitionsAlreadyExistException] { sql(s"ALTER TABLE $t PARTITION (id = 1) RENAME TO PARTITION (id = 2)") }.getMessage - assert(errMsg.contains("Partition already exists")) + assert(errMsg.contains("partitions already exist")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala index 6b2308766f6c8..54287cc6a47bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala @@ -148,7 +148,7 @@ trait AlterTableAddPartitionSuiteBase extends command.AlterTableAddPartitionSuit " PARTITION (id=2) LOCATION 'loc1'") }.getMessage assert(errMsg === - """The following partitions already exists in table 'tbl' database 'ns': + """The following partitions already exist 
in table 'tbl' database 'ns': |Map(id -> 2)""".stripMargin) sql(s"ALTER TABLE $t ADD IF NOT EXISTS PARTITION (id=1) LOCATION 'loc'" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala index a238dfcf2dd9c..dc6e5a2909da0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala @@ -111,7 +111,7 @@ class AlterTableAddPartitionSuite sql(s"ALTER TABLE $t ADD PARTITION (id=1) LOCATION 'loc'" + " PARTITION (id=2) LOCATION 'loc1'") }.getMessage - assert(errMsg === s"The following partitions already exists in table $t:id -> 2") + assert(errMsg === s"The following partitions already exist in table $t:id -> 2") sql(s"ALTER TABLE $t ADD IF NOT EXISTS PARTITION (id=1) LOCATION 'loc'" + " PARTITION (id=2) LOCATION 'loc1'") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 1897a347ef175..07018508b91cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -539,6 +539,18 @@ class FileIndexSuite extends SharedSparkSession { } } } + + test("SPARK-40667: validate FileIndex Options") { + assert(FileIndexOptions.getAllOptions.size == 7) + // Please add validation on any new FileIndex options here + assert(FileIndexOptions.isValidOption("ignoreMissingFiles")) + assert(FileIndexOptions.isValidOption("timeZone")) + assert(FileIndexOptions.isValidOption("recursiveFileLookup")) + assert(FileIndexOptions.isValidOption("basePath")) + assert(FileIndexOptions.isValidOption("modifiedbefore")) + assert(FileIndexOptions.isValidOption("modifiedafter")) + assert(FileIndexOptions.isValidOption("pathglobfilter")) + } } object DeletionRaceFileSystem { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 8f2e1c76e4980..dbeadeb949abc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -38,6 +38,7 @@ import org.apache.logging.log4j.Level import org.apache.spark.{SparkConf, SparkException, SparkUpgradeException, TestUtils} import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Encoders, QueryTest, Row} +import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} import org.apache.spark.sql.execution.datasources.CommonFileDataSourceSuite import org.apache.spark.sql.internal.SQLConf @@ -3079,6 +3080,57 @@ abstract class CSVSuite } } } + + test("SPARK-40667: validate CSV Options") { + assert(CSVOptions.getAllOptions.size == 38) + // Please add validation on any new CSV options here + assert(CSVOptions.isValidOption("header")) + assert(CSVOptions.isValidOption("inferSchema")) + assert(CSVOptions.isValidOption("ignoreLeadingWhiteSpace")) + assert(CSVOptions.isValidOption("ignoreTrailingWhiteSpace")) + assert(CSVOptions.isValidOption("prefersDate")) + 
assert(CSVOptions.isValidOption("escapeQuotes")) + assert(CSVOptions.isValidOption("quoteAll")) + assert(CSVOptions.isValidOption("enforceSchema")) + assert(CSVOptions.isValidOption("quote")) + assert(CSVOptions.isValidOption("escape")) + assert(CSVOptions.isValidOption("comment")) + assert(CSVOptions.isValidOption("maxColumns")) + assert(CSVOptions.isValidOption("maxCharsPerColumn")) + assert(CSVOptions.isValidOption("mode")) + assert(CSVOptions.isValidOption("charToEscapeQuoteEscaping")) + assert(CSVOptions.isValidOption("locale")) + assert(CSVOptions.isValidOption("dateFormat")) + assert(CSVOptions.isValidOption("timestampFormat")) + assert(CSVOptions.isValidOption("timestampNTZFormat")) + assert(CSVOptions.isValidOption("enableDateTimeParsingFallback")) + assert(CSVOptions.isValidOption("multiLine")) + assert(CSVOptions.isValidOption("samplingRatio")) + assert(CSVOptions.isValidOption("emptyValue")) + assert(CSVOptions.isValidOption("lineSep")) + assert(CSVOptions.isValidOption("inputBufferSize")) + assert(CSVOptions.isValidOption("columnNameOfCorruptRecord")) + assert(CSVOptions.isValidOption("nullValue")) + assert(CSVOptions.isValidOption("nanValue")) + assert(CSVOptions.isValidOption("positiveInf")) + assert(CSVOptions.isValidOption("negativeInf")) + assert(CSVOptions.isValidOption("timeZone")) + assert(CSVOptions.isValidOption("unescapedQuoteHandling")) + assert(CSVOptions.isValidOption("encoding")) + assert(CSVOptions.isValidOption("charset")) + assert(CSVOptions.isValidOption("compression")) + assert(CSVOptions.isValidOption("codec")) + assert(CSVOptions.isValidOption("sep")) + assert(CSVOptions.isValidOption("delimiter")) + // Please add validation on any new parquet options with alternative here + assert(CSVOptions.getAlternativeOption("sep").contains("delimiter")) + assert(CSVOptions.getAlternativeOption("delimiter").contains("sep")) + assert(CSVOptions.getAlternativeOption("encoding").contains("charset")) + assert(CSVOptions.getAlternativeOption("charset").contains("encoding")) + assert(CSVOptions.getAlternativeOption("compression").contains("codec")) + assert(CSVOptions.getAlternativeOption("codec").contains("compression")) + assert(CSVOptions.getAlternativeOption("prefersDate").isEmpty) + } } class CSVv1Suite extends CSVSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 0826a6d126779..f3210b049d483 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -3380,6 +3380,43 @@ abstract class JsonSuite } } } + + test("SPARK-40667: validate JSON Options") { + assert(JSONOptions.getAllOptions.size == 28) + // Please add validation on any new Json options here + assert(JSONOptions.isValidOption("samplingRatio")) + assert(JSONOptions.isValidOption("primitivesAsString")) + assert(JSONOptions.isValidOption("prefersDecimal")) + assert(JSONOptions.isValidOption("allowComments")) + assert(JSONOptions.isValidOption("allowUnquotedFieldNames")) + assert(JSONOptions.isValidOption("allowSingleQuotes")) + assert(JSONOptions.isValidOption("allowNumericLeadingZeros")) + assert(JSONOptions.isValidOption("allowNonNumericNumbers")) + assert(JSONOptions.isValidOption("allowBackslashEscapingAnyCharacter")) + assert(JSONOptions.isValidOption("allowUnquotedControlChars")) + 
assert(JSONOptions.isValidOption("compression")) + assert(JSONOptions.isValidOption("mode")) + assert(JSONOptions.isValidOption("dropFieldIfAllNull")) + assert(JSONOptions.isValidOption("ignoreNullFields")) + assert(JSONOptions.isValidOption("locale")) + assert(JSONOptions.isValidOption("dateFormat")) + assert(JSONOptions.isValidOption("timestampFormat")) + assert(JSONOptions.isValidOption("timestampNTZFormat")) + assert(JSONOptions.isValidOption("enableDateTimeParsingFallback")) + assert(JSONOptions.isValidOption("multiLine")) + assert(JSONOptions.isValidOption("lineSep")) + assert(JSONOptions.isValidOption("pretty")) + assert(JSONOptions.isValidOption("inferTimestamp")) + assert(JSONOptions.isValidOption("columnNameOfCorruptRecord")) + assert(JSONOptions.isValidOption("timeZone")) + assert(JSONOptions.isValidOption("writeNonAsciiCharacterAsCodePoint")) + assert(JSONOptions.isValidOption("encoding")) + assert(JSONOptions.isValidOption("charset")) + // Please add validation on any new Json options with alternative here + assert(JSONOptions.getAlternativeOption("encoding").contains("charset")) + assert(JSONOptions.getAlternativeOption("charset").contains("encoding")) + assert(JSONOptions.getAlternativeOption("dateFormat").isEmpty) + } } class JsonV1Suite extends JsonSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 19477adec3960..94ce3d77962ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -580,6 +580,14 @@ abstract class OrcSuite "ORC sources shall write an empty file contains meta if necessary") } } + + test("SPARK-40667: validate Orc Options") { + assert(OrcOptions.getAllOptions.size == 3) + // Please add validation on any new Orc options here + assert(OrcOptions.isValidOption("mergeSchema")) + assert(OrcOptions.isValidOption("orc.compress")) + assert(OrcOptions.isValidOption("compression")) + } } abstract class OrcSourceSuite extends OrcSuite with SharedSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 8b8c4918e5bb0..fea986cc8e2de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -1485,6 +1485,16 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession } } + test("SPARK-40667: validate Parquet Options") { + assert(ParquetOptions.getAllOptions.size == 5) + // Please add validation on any new parquet options here + assert(ParquetOptions.isValidOption("mergeSchema")) + assert(ParquetOptions.isValidOption("compression")) + assert(ParquetOptions.isValidOption("parquet.compression")) + assert(ParquetOptions.isValidOption("datetimeRebaseMode")) + assert(ParquetOptions.isValidOption("int96RebaseMode")) + } + test("SPARK-23173 Writing a file with data converted from JSON with and incorrect user schema") { withTempPath { file => val jsonData = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 1eb32ed285799..ff6b9aadf7cfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -236,6 +236,15 @@ abstract class TextSuite extends QueryTest with SharedSparkSession with CommonFi assert(data(3) == Row("\"doh\"")) assert(data.length == 4) } + + test("SPARK-40667: validate Text Options") { + assert(TextOptions.getAllOptions.size == 4) + // Please add validation on any new Text options here + assert(TextOptions.isValidOption("compression")) + assert(TextOptions.isValidOption("wholetext")) + assert(TextOptions.isValidOption("encoding")) + assert(TextOptions.isValidOption("lineSep")) + } } class TextV1Suite extends TextSuite { diff --git a/sql/create-docs.sh b/sql/create-docs.sh index 8721df874ee73..c5a36e0474eb0 100755 --- a/sql/create-docs.sh +++ b/sql/create-docs.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java index 79426e0e3de18..8ee606be314c2 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java @@ -17,6 +17,7 @@ package org.apache.hive.service.cli.operation; import java.io.CharArrayWriter; +import java.io.Serializable; import java.util.Map; import java.util.regex.Pattern; @@ -265,7 +266,7 @@ private static StringLayout initLayout(OperationLog.LoggingLevel loggingMode) { Map<String, Appender> appenders = root.getAppenders(); for (Appender ap : appenders.values()) { if (ap.getClass().equals(ConsoleAppender.class)) { - Layout l = ap.getLayout(); + Layout<? extends Serializable> l = ap.getLayout(); if (l instanceof StringLayout) { layout = (StringLayout) l; break; diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java index a261a54581828..6ee48186e7ea8 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java @@ -39,7 +39,7 @@ import org.apache.hive.service.cli.session.HiveSession; import org.apache.hive.service.rpc.thrift.TRowSet; import org.apache.hive.service.rpc.thrift.TTableSchema; -import org.apache.logging.log4j.core.appender.AbstractWriterAppender; +import org.apache.logging.log4j.core.Appender; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -82,7 +82,7 @@ public synchronized void stop() { private void initOperationLogCapture(String loggingMode) { // Register another Appender (with the same layout) that talks to us.
- AbstractWriterAppender ap = new LogDivertAppender(this, OperationLog.getLoggingLevel(loggingMode)); + Appender ap = new LogDivertAppender(this, OperationLog.getLoggingLevel(loggingMode)); ((org.apache.logging.log4j.core.Logger)org.apache.logging.log4j.LogManager.getRootLogger()).addAppender(ap); ap.start(); } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index bef320174ecd7..db600bcd3d459 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -50,7 +50,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, NoSuchDatabaseException, NoSuchPartitionException, NoSuchPartitionsException, NoSuchTableException, PartitionAlreadyExistsException, PartitionsAlreadyExistException} +import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, NoSuchDatabaseException, NoSuchPartitionException, NoSuchPartitionsException, NoSuchTableException, PartitionsAlreadyExistException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.CatalogUtils.stringToURI @@ -707,7 +707,7 @@ private[hive] class HiveClientImpl( hiveTable.setOwner(userName) specs.zip(newSpecs).foreach { case (oldSpec, newSpec) => if (shim.getPartition(client, hiveTable, newSpec.asJava, false) != null) { - throw new PartitionAlreadyExistsException(db, table, newSpec) + throw new PartitionsAlreadyExistException(db, table, newSpec) } val hivePart = getPartitionOption(rawHiveTable, oldSpec) .map { p => toHivePartition(p.copy(spec = newSpec), hiveTable) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 184e03d088cc5..e6abc7b96c17d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -528,7 +528,7 @@ class HiveClientSuite(version: String, allVersions: Seq[String]) val errMsg = intercept[PartitionsAlreadyExistException] { client.createPartitions("default", "src_part", partitions, ignoreIfExists = false) }.getMessage - assert(errMsg.contains("partitions already exists")) + assert(errMsg.contains("partitions already exist")) } finally { client.dropPartitions( "default",