From f060bd4520130c37ff81159acf0574d5f33cdcaf Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Tue, 21 Oct 2025 18:35:23 +0800 Subject: [PATCH] [SPARK-53934][CONNECT] Initial implement Connect JDBC driver --- sql/connect/client/jdbc/pom.xml | 7 + .../NonRegisteringSparkConnectDriver.scala | 8 +- .../client/jdbc/SparkConnectConnection.scala | 280 +++++++ .../jdbc/SparkConnectDatabaseMetaData.scala | 600 +++++++++++++++ .../client/jdbc/SparkConnectResultSet.scala | 682 ++++++++++++++++++ .../jdbc/SparkConnectResultSetMetaData.scala | 84 +++ .../client/jdbc/SparkConnectStatement.scala | 222 ++++++ .../client/jdbc/util/JdbcErrorUtils.scala | 40 + .../client/jdbc/util/JdbcTypeUtils.scala | 96 +++ .../client/jdbc/SparkConnectDriverSuite.scala | 67 +- .../jdbc/SparkConnectJdbcDataTypeSuite.scala | 218 ++++++ .../connect/client/jdbc/test/JdbcHelper.scala | 46 ++ 12 files changed, 2340 insertions(+), 10 deletions(-) create mode 100644 sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala create mode 100644 sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaData.scala create mode 100644 sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala create mode 100644 sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSetMetaData.scala create mode 100644 sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala create mode 100644 sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcErrorUtils.scala create mode 100644 sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcTypeUtils.scala create mode 100644 sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala create mode 100644 sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala diff --git a/sql/connect/client/jdbc/pom.xml b/sql/connect/client/jdbc/pom.xml index 9f2ba011004d4..c2dda12b1e639 100644 --- a/sql/connect/client/jdbc/pom.xml +++ b/sql/connect/client/jdbc/pom.xml @@ -111,6 +111,13 @@ tests test + + org.apache.spark + spark-connect-client-jvm_${scala.binary.version} + ${project.version} + tests + test + com.typesafe diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/NonRegisteringSparkConnectDriver.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/NonRegisteringSparkConnectDriver.scala index 1052f6d3e5605..09e386835b7e1 100644 --- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/NonRegisteringSparkConnectDriver.scala +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/NonRegisteringSparkConnectDriver.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.connect.client.jdbc -import java.sql.{Connection, Driver, DriverPropertyInfo, SQLFeatureNotSupportedException} +import java.sql.{Connection, Driver, DriverPropertyInfo, SQLException, SQLFeatureNotSupportedException} import java.util.Properties import java.util.logging.Logger @@ -29,7 +29,11 @@ class NonRegisteringSparkConnectDriver extends Driver { override def acceptsURL(url: String): Boolean = url.startsWith("jdbc:sc://") override def connect(url: String, info: Properties): Connection = { - throw new UnsupportedOperationException("TODO(SPARK-53934)") + if (url == null) { + throw new SQLException("url must not be null") + } + + if (this.acceptsURL(url)) new SparkConnectConnection(url, info) else null } override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala new file mode 100644 index 0000000000000..95ec956771dbb --- /dev/null +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client.jdbc + +import java.sql.{Array => JdbcArray, _} +import java.util +import java.util.Properties +import java.util.concurrent.Executor + +import org.apache.spark.sql.connect.SparkSession +import org.apache.spark.sql.connect.client.SparkConnectClient +import org.apache.spark.sql.connect.client.jdbc.util.JdbcErrorUtils._ + +class SparkConnectConnection(val url: String, val info: Properties) extends Connection { + + private[jdbc] val client = SparkConnectClient + .builder() + .loadFromEnvironment() + .userAgent("Spark Connect JDBC") + .connectionString(url.stripPrefix("jdbc:")) + .build() + + private[jdbc] val spark = SparkSession.builder().client(client).create() + + @volatile private var closed: Boolean = false + + override def isClosed: Boolean = closed + + override def close(): Unit = synchronized { + if (!closed) { + spark.close() + closed = true + } + } + + private[jdbc] def checkOpen(): Unit = { + if (closed) { + throw new SQLException("JDBC Connection is closed.") + } + if (!client.isSessionValid) { + throw new SQLException(s"Spark Connect Session ${client.sessionId} is invalid.") + } + } + + override def isValid(timeout: Int): Boolean = !closed && client.isSessionValid + + override def setCatalog(catalog: String): Unit = { + checkOpen() + spark.catalog.setCurrentCatalog(catalog) + } + + override def getCatalog: String = { + checkOpen() + spark.catalog.currentCatalog() + } + + override def setSchema(schema: String): Unit = { + checkOpen() + spark.catalog.setCurrentDatabase(schema) + } + + override def getSchema: String = { + checkOpen() + spark.catalog.currentDatabase + } + + override def getMetaData: DatabaseMetaData = { + checkOpen() + new SparkConnectDatabaseMetaData(this) + } + + override def createStatement(): Statement = { + checkOpen() + new SparkConnectStatement(this) + } + + override def prepareStatement(sql: String): PreparedStatement = + throw new SQLFeatureNotSupportedException + + override def prepareCall(sql: String): CallableStatement = + throw new SQLFeatureNotSupportedException + + override def createStatement( + resultSetType: Int, + resultSetConcurrency: Int, + resultSetHoldability: Int): Statement = + throw new SQLFeatureNotSupportedException + + override def prepareStatement( + sql: String, + resultSetType: Int, + resultSetConcurrency: Int, + resultSetHoldability: Int): PreparedStatement = + throw new SQLFeatureNotSupportedException + + override def prepareCall( + sql: String, + resultSetType: Int, + resultSetConcurrency: Int, + resultSetHoldability: Int): CallableStatement = + throw new SQLFeatureNotSupportedException + + override def prepareStatement( + sql: String, autoGeneratedKeys: Int): PreparedStatement = + throw new SQLFeatureNotSupportedException + + override def prepareStatement( + sql: String, columnIndexes: Array[Int]): PreparedStatement = + throw new SQLFeatureNotSupportedException + + override def prepareStatement( + sql: String, columnNames: Array[String]): PreparedStatement = + throw new SQLFeatureNotSupportedException + + override def createStatement( + resultSetType: Int, resultSetConcurrency: Int): Statement = + throw new SQLFeatureNotSupportedException + + override def prepareStatement( + sql: String, + resultSetType: Int, + resultSetConcurrency: Int): PreparedStatement = + throw new SQLFeatureNotSupportedException + + override def prepareCall( + sql: String, + resultSetType: Int, + resultSetConcurrency: Int): CallableStatement = + throw new SQLFeatureNotSupportedException + + override def nativeSQL(sql: String): String = + throw new SQLFeatureNotSupportedException + + override def setAutoCommit(autoCommit: Boolean): Unit = { + checkOpen() + if (!autoCommit) { + throw new SQLFeatureNotSupportedException("Only auto-commit mode is supported") + } + } + + override def getAutoCommit: Boolean = { + checkOpen() + true + } + + override def commit(): Unit = { + checkOpen() + throw new SQLException("Connection is in auto-commit mode") + } + + override def rollback(): Unit = { + checkOpen() + throw new SQLException("Connection is in auto-commit mode") + } + + override def setReadOnly(readOnly: Boolean): Unit = { + checkOpen() + if (readOnly) { + throw new SQLFeatureNotSupportedException("Read-only mode is not supported") + } + } + + override def isReadOnly: Boolean = { + checkOpen() + false + } + + override def setTransactionIsolation(level: Int): Unit = { + checkOpen() + if (level != Connection.TRANSACTION_NONE) { + throw new SQLFeatureNotSupportedException( + "Requested transaction isolation level " + + s"${stringfiyTransactionIsolationLevel(level)} is not supported") + } + } + + override def getTransactionIsolation: Int = { + checkOpen() + Connection.TRANSACTION_NONE + } + + override def getWarnings: SQLWarning = null + + override def clearWarnings(): Unit = {} + + override def getTypeMap: util.Map[String, Class[_]] = + throw new SQLFeatureNotSupportedException + + override def setTypeMap(map: util.Map[String, Class[_]]): Unit = + throw new SQLFeatureNotSupportedException + + override def setHoldability(holdability: Int): Unit = { + if (holdability != ResultSet.HOLD_CURSORS_OVER_COMMIT) { + throw new SQLFeatureNotSupportedException( + s"Holdability ${stringfiyHoldability(holdability)} is not supported") + } + } + + override def getHoldability: Int = ResultSet.HOLD_CURSORS_OVER_COMMIT + + override def setSavepoint(): Savepoint = + throw new SQLFeatureNotSupportedException + + override def setSavepoint(name: String): Savepoint = + throw new SQLFeatureNotSupportedException + + override def rollback(savepoint: Savepoint): Unit = + throw new SQLFeatureNotSupportedException + + override def releaseSavepoint(savepoint: Savepoint): Unit = + throw new SQLFeatureNotSupportedException + + override def createClob(): Clob = + throw new SQLFeatureNotSupportedException + + override def createBlob(): Blob = + throw new SQLFeatureNotSupportedException + + override def createNClob(): NClob = + throw new SQLFeatureNotSupportedException + + override def createSQLXML(): SQLXML = + throw new SQLFeatureNotSupportedException + + override def setClientInfo(name: String, value: String): Unit = + throw new SQLFeatureNotSupportedException + + override def setClientInfo(properties: Properties): Unit = + throw new SQLFeatureNotSupportedException + + override def getClientInfo(name: String): String = + throw new SQLFeatureNotSupportedException + + override def getClientInfo: Properties = + throw new SQLFeatureNotSupportedException + + override def createArrayOf(typeName: String, elements: Array[AnyRef]): JdbcArray = + throw new SQLFeatureNotSupportedException + + override def createStruct(typeName: String, attributes: Array[AnyRef]): Struct = + throw new SQLFeatureNotSupportedException + + override def abort(executor: Executor): Unit = { + if (executor == null) { + throw new SQLException("executor can not be null") + } + if (!closed) { + executor.execute { () => this.close() } + } + } + + override def setNetworkTimeout(executor: Executor, milliseconds: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def getNetworkTimeout: Int = + throw new SQLFeatureNotSupportedException + + override def unwrap[T](iface: Class[T]): T = if (isWrapperFor(iface)) { + iface.asInstanceOf[T] + } else { + throw new SQLException(s"${this.getClass.getName} not unwrappable from ${iface.getName}") + } + + override def isWrapperFor(iface: Class[_]): Boolean = iface.isInstance(this) +} diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaData.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaData.scala new file mode 100644 index 0000000000000..4efbd2b8f917f --- /dev/null +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaData.scala @@ -0,0 +1,600 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client.jdbc + +import java.sql.{Array => _, _} + +import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION} +import org.apache.spark.util.VersionUtils + +class SparkConnectDatabaseMetaData(conn: SparkConnectConnection) extends DatabaseMetaData { + + override def allProceduresAreCallable: Boolean = + throw new SQLFeatureNotSupportedException + + override def allTablesAreSelectable: Boolean = + throw new SQLFeatureNotSupportedException + + override def getURL: String = conn.url + + override def getUserName: String = conn.spark.client.configuration.userName + + override def isReadOnly: Boolean = false + + override def nullsAreSortedHigh: Boolean = + throw new SQLFeatureNotSupportedException + + override def nullsAreSortedLow: Boolean = + throw new SQLFeatureNotSupportedException + + override def nullsAreSortedAtStart: Boolean = + throw new SQLFeatureNotSupportedException + + override def nullsAreSortedAtEnd: Boolean = + throw new SQLFeatureNotSupportedException + + override def getDatabaseProductName: String = "Apache Spark Connect Server" + + override def getDatabaseProductVersion: String = conn.spark.version + + override def getDriverName: String = "Apache Spark Connect JDBC Driver" + + override def getDriverVersion: String = SPARK_VERSION + + override def getDriverMajorVersion: Int = VersionUtils.majorVersion(SPARK_VERSION) + + override def getDriverMinorVersion: Int = VersionUtils.minorVersion(SPARK_VERSION) + + override def usesLocalFiles: Boolean = + throw new SQLFeatureNotSupportedException + + override def usesLocalFilePerTable: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsMixedCaseIdentifiers: Boolean = + throw new SQLFeatureNotSupportedException + + override def storesUpperCaseIdentifiers: Boolean = + throw new SQLFeatureNotSupportedException + + override def storesLowerCaseIdentifiers: Boolean = + throw new SQLFeatureNotSupportedException + + override def storesMixedCaseIdentifiers: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsMixedCaseQuotedIdentifiers: Boolean = + throw new SQLFeatureNotSupportedException + + override def storesUpperCaseQuotedIdentifiers: Boolean = + throw new SQLFeatureNotSupportedException + + override def storesLowerCaseQuotedIdentifiers: Boolean = + throw new SQLFeatureNotSupportedException + + override def storesMixedCaseQuotedIdentifiers: Boolean = + throw new SQLFeatureNotSupportedException + + override def getIdentifierQuoteString: String = "`" + + override def getSQLKeywords: String = + throw new SQLFeatureNotSupportedException + + override def getNumericFunctions: String = + throw new SQLFeatureNotSupportedException + + override def getStringFunctions: String = + throw new SQLFeatureNotSupportedException + + override def getSystemFunctions: String = + throw new SQLFeatureNotSupportedException + + override def getTimeDateFunctions: String = + throw new SQLFeatureNotSupportedException + + override def getSearchStringEscape: String = + throw new SQLFeatureNotSupportedException + + override def getExtraNameCharacters: String = "" + + override def supportsAlterTableWithAddColumn: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsAlterTableWithDropColumn: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsColumnAliasing: Boolean = + throw new SQLFeatureNotSupportedException + + override def nullPlusNonNullIsNull: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsConvert: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsConvert(fromType: Int, toType: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsTableCorrelationNames: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsDifferentTableCorrelationNames: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsExpressionsInOrderBy: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsOrderByUnrelated: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsGroupBy: Boolean = true + + override def supportsGroupByUnrelated: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsGroupByBeyondSelect: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsLikeEscapeClause: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsMultipleResultSets: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsMultipleTransactions: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsNonNullableColumns: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsMinimumSQLGrammar: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsCoreSQLGrammar: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsExtendedSQLGrammar: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsANSI92EntryLevelSQL: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsANSI92IntermediateSQL: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsANSI92FullSQL: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsIntegrityEnhancementFacility: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsOuterJoins: Boolean = true + + override def supportsFullOuterJoins: Boolean = true + + override def supportsLimitedOuterJoins: Boolean = true + + override def getSchemaTerm: String = "schema" + + override def getProcedureTerm: String = "procedure" + + override def getCatalogTerm: String = "catalog" + + override def isCatalogAtStart: Boolean = true + + override def getCatalogSeparator: String = "." + + override def supportsSchemasInDataManipulation: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsSchemasInProcedureCalls: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsSchemasInTableDefinitions: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsSchemasInIndexDefinitions: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsSchemasInPrivilegeDefinitions: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsCatalogsInDataManipulation: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsCatalogsInProcedureCalls: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsCatalogsInTableDefinitions: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsCatalogsInIndexDefinitions: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsCatalogsInPrivilegeDefinitions: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsPositionedDelete: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsPositionedUpdate: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsSelectForUpdate: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsStoredProcedures: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsSubqueriesInComparisons: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsSubqueriesInExists: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsSubqueriesInIns: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsSubqueriesInQuantifieds: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsCorrelatedSubqueries: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsUnion: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsUnionAll: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsOpenCursorsAcrossCommit: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsOpenCursorsAcrossRollback: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsOpenStatementsAcrossCommit: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsOpenStatementsAcrossRollback: Boolean = + throw new SQLFeatureNotSupportedException + + override def getMaxBinaryLiteralLength: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxCharLiteralLength: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxColumnNameLength: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxColumnsInGroupBy: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxColumnsInIndex: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxColumnsInOrderBy: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxColumnsInSelect: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxColumnsInTable: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxConnections: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxCursorNameLength: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxIndexLength: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxSchemaNameLength: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxProcedureNameLength: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxCatalogNameLength: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxRowSize: Int = + throw new SQLFeatureNotSupportedException + + override def doesMaxRowSizeIncludeBlobs: Boolean = + throw new SQLFeatureNotSupportedException + + override def getMaxStatementLength: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxStatements: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxTableNameLength: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxTablesInSelect: Int = + throw new SQLFeatureNotSupportedException + + override def getMaxUserNameLength: Int = + throw new SQLFeatureNotSupportedException + + override def getDefaultTransactionIsolation: Int = + throw new SQLFeatureNotSupportedException + + override def supportsTransactions: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsTransactionIsolationLevel(level: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsDataDefinitionAndDataManipulationTransactions: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsDataManipulationTransactionsOnly: Boolean = + throw new SQLFeatureNotSupportedException + + override def dataDefinitionCausesTransactionCommit: Boolean = + throw new SQLFeatureNotSupportedException + + override def dataDefinitionIgnoredInTransactions: Boolean = + throw new SQLFeatureNotSupportedException + + override def getProcedures( + catalog: String, + schemaPattern: String, + procedureNamePattern: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getProcedureColumns( + catalog: String, + schemaPattern: String, + procedureNamePattern: String, + columnNamePattern: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getCatalogs: ResultSet = + throw new SQLFeatureNotSupportedException + + override def getSchemas: ResultSet = + throw new SQLFeatureNotSupportedException + + override def getSchemas(catalog: String, schemaPattern: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getTableTypes: ResultSet = + throw new SQLFeatureNotSupportedException + + override def getTables( + catalog: String, + schemaPattern: String, + tableNamePattern: String, + types: Array[String]): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getColumns( + catalog: String, + schemaPattern: String, + tableNamePattern: String, + columnNamePattern: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getColumnPrivileges( + catalog: String, + schema: String, + table: String, + columnNamePattern: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getTablePrivileges( + catalog: String, + schemaPattern: String, + tableNamePattern: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getBestRowIdentifier( + catalog: String, + schema: String, + table: String, + scope: Int, + nullable: Boolean): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getVersionColumns( + catalog: String, schema: String, table: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getPrimaryKeys(catalog: String, schema: String, table: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getImportedKeys(catalog: String, schema: String, table: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getExportedKeys(catalog: String, schema: String, table: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getCrossReference( + parentCatalog: String, + parentSchema: String, + parentTable: String, + foreignCatalog: String, + foreignSchema: String, + foreignTable: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getTypeInfo: ResultSet = + throw new SQLFeatureNotSupportedException + + override def getIndexInfo( + catalog: String, + schema: String, + table: String, + unique: Boolean, + approximate: Boolean): ResultSet = + throw new SQLFeatureNotSupportedException + + override def supportsResultSetType(`type`: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsResultSetConcurrency(`type`: Int, concurrency: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def ownUpdatesAreVisible(`type`: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def ownDeletesAreVisible(`type`: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def ownInsertsAreVisible(`type`: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def othersUpdatesAreVisible(`type`: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def othersDeletesAreVisible(`type`: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def othersInsertsAreVisible(`type`: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def updatesAreDetected(`type`: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def deletesAreDetected(`type`: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def insertsAreDetected(`type`: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsBatchUpdates: Boolean = + throw new SQLFeatureNotSupportedException + + override def getUDTs( + catalog: String, + schemaPattern: String, + typeNamePattern: String, + types: Array[Int]): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getConnection: Connection = conn + + override def supportsSavepoints: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsNamedParameters: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsMultipleOpenResults: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsGetGeneratedKeys: Boolean = + throw new SQLFeatureNotSupportedException + + override def getSuperTypes( + catalog: String, + schemaPattern: String, + typeNamePattern: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getSuperTables( + catalog: String, + schemaPattern: String, + tableNamePattern: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getAttributes( + catalog: String, + schemaPattern: String, + typeNamePattern: String, + attributeNamePattern: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def supportsResultSetHoldability(holdability: Int): Boolean = + holdability == ResultSet.CLOSE_CURSORS_AT_COMMIT + + override def getResultSetHoldability: Int = ResultSet.CLOSE_CURSORS_AT_COMMIT + + override def getDatabaseMajorVersion: Int = VersionUtils.majorVersion(conn.spark.version) + + override def getDatabaseMinorVersion: Int = VersionUtils.minorVersion(conn.spark.version) + + // JSR-221 defines JDBC 4.0 API Specification - https://jcp.org/en/jsr/detail?id=221 + // JDBC 4.3 is the latest Maintenance version of the JDBC 4.0 specification as of JDK 17 + // https://docs.oracle.com/en/java/javase/17/docs/api/java.sql/java/sql/package-summary.html + override def getJDBCMajorVersion: Int = 4 + + override def getJDBCMinorVersion: Int = 3 + + override def getSQLStateType: Int = + throw new SQLFeatureNotSupportedException + + override def locatorsUpdateCopy: Boolean = + throw new SQLFeatureNotSupportedException + + override def supportsStatementPooling: Boolean = false + + override def getRowIdLifetime: RowIdLifetime = RowIdLifetime.ROWID_UNSUPPORTED + + override def supportsStoredFunctionsUsingCallSyntax: Boolean = + throw new SQLFeatureNotSupportedException + + override def autoCommitFailureClosesAllResultSets: Boolean = + throw new SQLFeatureNotSupportedException + + override def getClientInfoProperties: ResultSet = + throw new SQLFeatureNotSupportedException + + override def getFunctions( + catalog: String, + schemaPattern: String, + functionNamePattern: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getFunctionColumns( + catalog: String, + schemaPattern: String, + functionNamePattern: String, + columnNamePattern: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def getPseudoColumns( + catalog: String, + schemaPattern: String, + tableNamePattern: String, + columnNamePattern: String): ResultSet = + throw new SQLFeatureNotSupportedException + + override def generatedKeyAlwaysReturned: Boolean = false + + override def getMaxLogicalLobSize: Long = 0 + + override def supportsRefCursors: Boolean = false + + override def supportsSharding: Boolean = false + + override def unwrap[T](iface: Class[T]): T = if (isWrapperFor(iface)) { + iface.asInstanceOf[T] + } else { + throw new SQLException(s"${this.getClass.getName} not unwrappable from ${iface.getName}") + } + + override def isWrapperFor(iface: Class[_]): Boolean = iface.isInstance(this) +} diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala new file mode 100644 index 0000000000000..38417b0de2173 --- /dev/null +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala @@ -0,0 +1,682 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client.jdbc + +import java.io.{InputStream, Reader} +import java.net.URL +import java.sql.{Array => JdbcArray, _} +import java.util +import java.util.Calendar + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connect.client.SparkResult + +class SparkConnectResultSet( + sparkResult: SparkResult[Row], + stmt: SparkConnectStatement = null) extends ResultSet { + + private val iterator = sparkResult.destructiveIterator + + private var currentRow: Row = _ + + private var _wasNull: Boolean = false + + override def wasNull: Boolean = _wasNull + + override def next(): Boolean = { + val hasNext = iterator.hasNext + if (hasNext) { + currentRow = iterator.next() + } else { + currentRow = null + } + hasNext + } + + @volatile protected var closed: Boolean = false + + override def isClosed: Boolean = closed + + override def close(): Unit = synchronized { + if (!closed) { + iterator.close() + sparkResult.close() + closed = true + } + } + + private[jdbc] def checkOpen(): Unit = { + if (closed) { + throw new SQLException("JDBC Statement is closed.") + } + } + + override def findColumn(columnLabel: String): Int = { + sparkResult.schema.getFieldIndex(columnLabel) match { + case Some(i) => i + 1 + case None => + throw new SQLException(s"Invalid column label: $columnLabel") + } + } + + override def getString(columnIndex: Int): String = { + if (currentRow.isNullAt(columnIndex - 1)) { + _wasNull = true + return null + } + _wasNull = false + String.valueOf(currentRow.get(columnIndex - 1)) + } + + override def getBoolean(columnIndex: Int): Boolean = { + if (currentRow.isNullAt(columnIndex - 1)) { + _wasNull = true + return false + } + _wasNull = false + currentRow.getBoolean(columnIndex - 1) + } + + override def getByte(columnIndex: Int): Byte = { + if (currentRow.isNullAt(columnIndex - 1)) { + _wasNull = true + return 0.toByte + } + _wasNull = false + currentRow.getByte(columnIndex - 1) + } + + override def getShort(columnIndex: Int): Short = { + if (currentRow.isNullAt(columnIndex - 1)) { + _wasNull = true + return 0.toShort + } + _wasNull = false + currentRow.getShort(columnIndex - 1) + } + + override def getInt(columnIndex: Int): Int = { + if (currentRow.isNullAt(columnIndex - 1)) { + _wasNull = true + return 0 + } + _wasNull = false + currentRow.getInt(columnIndex - 1) + } + + override def getLong(columnIndex: Int): Long = { + if (currentRow.isNullAt(columnIndex - 1)) { + _wasNull = true + return 0L + } + _wasNull = false + currentRow.getLong(columnIndex - 1) + } + + override def getFloat(columnIndex: Int): Float = { + if (currentRow.isNullAt(columnIndex - 1)) { + _wasNull = true + return 0.toFloat + } + _wasNull = false + currentRow.getFloat(columnIndex - 1) + } + + override def getDouble(columnIndex: Int): Double = { + if (currentRow.isNullAt(columnIndex - 1)) { + _wasNull = true + return 0.toDouble + } + _wasNull = false + currentRow.getDouble(columnIndex - 1) + } + + override def getBigDecimal(columnIndex: Int, scale: Int): java.math.BigDecimal = + throw new SQLFeatureNotSupportedException + + override def getBytes(columnIndex: Int): Array[Byte] = + throw new SQLFeatureNotSupportedException + + override def getDate(columnIndex: Int): Date = + throw new SQLFeatureNotSupportedException + + override def getTime(columnIndex: Int): Time = + throw new SQLFeatureNotSupportedException + + override def getTimestamp(columnIndex: Int): Timestamp = + throw new SQLFeatureNotSupportedException + + override def getAsciiStream(columnIndex: Int): InputStream = + throw new SQLFeatureNotSupportedException + + override def getUnicodeStream(columnIndex: Int): InputStream = + throw new SQLFeatureNotSupportedException + + override def getBinaryStream(columnIndex: Int): InputStream = + throw new SQLFeatureNotSupportedException + + override def getString(columnLabel: String): String = + getString(findColumn(columnLabel)) + + override def getBoolean(columnLabel: String): Boolean = + getBoolean(findColumn(columnLabel)) + + override def getByte(columnLabel: String): Byte = + getByte(findColumn(columnLabel)) + + override def getShort(columnLabel: String): Short = + getShort(findColumn(columnLabel)) + + override def getInt(columnLabel: String): Int = + getInt(findColumn(columnLabel)) + + override def getLong(columnLabel: String): Long = + getLong(findColumn(columnLabel)) + + override def getFloat(columnLabel: String): Float = + getFloat(findColumn(columnLabel)) + + override def getDouble(columnLabel: String): Double = + getDouble(findColumn(columnLabel)) + + override def getBigDecimal(columnLabel: String, scale: Int): java.math.BigDecimal = + throw new SQLFeatureNotSupportedException + + override def getBytes(columnLabel: String): Array[Byte] = + throw new SQLFeatureNotSupportedException + + override def getDate(columnLabel: String): Date = + throw new SQLFeatureNotSupportedException + + override def getTime(columnLabel: String): Time = + throw new SQLFeatureNotSupportedException + + override def getTimestamp(columnLabel: String): Timestamp = + throw new SQLFeatureNotSupportedException + + override def getAsciiStream(columnLabel: String): InputStream = + throw new SQLFeatureNotSupportedException + + override def getUnicodeStream(columnLabel: String): InputStream = + throw new SQLFeatureNotSupportedException + + override def getBinaryStream(columnLabel: String): InputStream = + throw new SQLFeatureNotSupportedException + + override def getWarnings: SQLWarning = null + + override def clearWarnings(): Unit = {} + + override def getCursorName: String = throw new SQLFeatureNotSupportedException + + override def getMetaData: ResultSetMetaData = { + checkOpen() + new SparkConnectResultSetMetaData(sparkResult.schema) + } + + override def getObject(columnIndex: Int): AnyRef = { + if (currentRow.isNullAt(columnIndex - 1)) { + _wasNull = true + return null + } + _wasNull = false + currentRow.get(columnIndex - 1).asInstanceOf[AnyRef] + } + + override def getObject(columnLabel: String): AnyRef = + getObject(findColumn(columnLabel)) + + override def getCharacterStream(columnIndex: Int): Reader = + throw new SQLFeatureNotSupportedException + + override def getCharacterStream(columnLabel: String): Reader = + throw new SQLFeatureNotSupportedException + + override def getBigDecimal(columnIndex: Int): java.math.BigDecimal = + throw new SQLFeatureNotSupportedException + + override def getBigDecimal(columnLabel: String): java.math.BigDecimal = + throw new SQLFeatureNotSupportedException + + override def isBeforeFirst: Boolean = throw new SQLFeatureNotSupportedException + + override def isAfterLast: Boolean = throw new SQLFeatureNotSupportedException + + override def isFirst: Boolean = throw new SQLFeatureNotSupportedException + + override def isLast: Boolean = throw new SQLFeatureNotSupportedException + + override def beforeFirst(): Unit = throw new SQLFeatureNotSupportedException + + override def afterLast(): Unit = throw new SQLFeatureNotSupportedException + + override def first(): Boolean = throw new SQLFeatureNotSupportedException + + override def last(): Boolean = throw new SQLFeatureNotSupportedException + + override def getRow: Int = throw new SQLFeatureNotSupportedException + + override def absolute(row: Int): Boolean = throw new SQLFeatureNotSupportedException + + override def relative(rows: Int): Boolean = throw new SQLFeatureNotSupportedException + + override def previous(): Boolean = throw new SQLFeatureNotSupportedException + + override def setFetchDirection(direction: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def getFetchDirection: Int = + throw new SQLFeatureNotSupportedException + + override def setFetchSize(rows: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def getFetchSize: Int = + throw new SQLFeatureNotSupportedException + + override def getType: Int = { + checkOpen() + ResultSet.TYPE_FORWARD_ONLY + } + + override def getConcurrency: Int = { + checkOpen() + ResultSet.CONCUR_READ_ONLY + } + + override def rowUpdated(): Boolean = + throw new SQLFeatureNotSupportedException + + override def rowInserted(): Boolean = + throw new SQLFeatureNotSupportedException + + override def rowDeleted(): Boolean = + throw new SQLFeatureNotSupportedException + + override def updateNull(columnIndex: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBoolean(columnIndex: Int, x: Boolean): Unit = + throw new SQLFeatureNotSupportedException + + override def updateByte(columnIndex: Int, x: Byte): Unit = + throw new SQLFeatureNotSupportedException + + override def updateShort(columnIndex: Int, x: Short): Unit = + throw new SQLFeatureNotSupportedException + + override def updateInt(columnIndex: Int, x: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def updateLong(columnIndex: Int, x: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateFloat(columnIndex: Int, x: Float): Unit = + throw new SQLFeatureNotSupportedException + + override def updateDouble(columnIndex: Int, x: Double): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBigDecimal(columnIndex: Int, x: java.math.BigDecimal): Unit = + throw new SQLFeatureNotSupportedException + + override def updateString(columnIndex: Int, x: String): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBytes(columnIndex: Int, x: scala.Array[Byte]): Unit = + throw new SQLFeatureNotSupportedException + + override def updateDate(columnIndex: Int, x: Date): Unit = + throw new SQLFeatureNotSupportedException + + override def updateTime(columnIndex: Int, x: Time): Unit = + throw new SQLFeatureNotSupportedException + + override def updateTimestamp(columnIndex: Int, x: Timestamp): Unit = + throw new SQLFeatureNotSupportedException + + override def updateAsciiStream(columnIndex: Int, x: InputStream, length: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBinaryStream(columnIndex: Int, x: InputStream, length: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def updateCharacterStream(columnIndex: Int, x: Reader, length: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def updateObject(columnIndex: Int, x: Any, scaleOrLength: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def updateObject(columnIndex: Int, x: Any): Unit = + throw new SQLFeatureNotSupportedException + + override def updateNull(columnLabel: String): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBoolean(columnLabel: String, x: Boolean): Unit = + throw new SQLFeatureNotSupportedException + + override def updateByte(columnLabel: String, x: Byte): Unit = + throw new SQLFeatureNotSupportedException + + override def updateShort(columnLabel: String, x: Short): Unit = + throw new SQLFeatureNotSupportedException + + override def updateInt(columnLabel: String, x: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def updateLong(columnLabel: String, x: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateFloat(columnLabel: String, x: Float): Unit = + throw new SQLFeatureNotSupportedException + + override def updateDouble(columnLabel: String, x: Double): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBigDecimal(columnLabel: String, x: java.math.BigDecimal): Unit = + throw new SQLFeatureNotSupportedException + + override def updateString(columnLabel: String, x: String): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBytes(columnLabel: String, x: Array[Byte]): Unit = + throw new SQLFeatureNotSupportedException + + override def updateDate(columnLabel: String, x: Date): Unit = + throw new SQLFeatureNotSupportedException + + override def updateTime(columnLabel: String, x: Time): Unit = + throw new SQLFeatureNotSupportedException + + override def updateTimestamp(columnLabel: String, x: Timestamp): Unit = + throw new SQLFeatureNotSupportedException + + override def updateAsciiStream(columnLabel: String, x: InputStream, length: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBinaryStream(columnLabel: String, x: InputStream, length: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def updateCharacterStream(columnLabel: String, reader: Reader, length: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def updateObject(columnLabel: String, x: Any, scaleOrLength: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def updateObject(columnLabel: String, x: Any): Unit = + throw new SQLFeatureNotSupportedException + + override def insertRow(): Unit = + throw new SQLFeatureNotSupportedException + + override def updateRow(): Unit = + throw new SQLFeatureNotSupportedException + + override def deleteRow(): Unit = + throw new SQLFeatureNotSupportedException + + override def refreshRow(): Unit = + throw new SQLFeatureNotSupportedException + + override def cancelRowUpdates(): Unit = + throw new SQLFeatureNotSupportedException + + override def moveToInsertRow(): Unit = + throw new SQLFeatureNotSupportedException + + override def moveToCurrentRow(): Unit = + throw new SQLFeatureNotSupportedException + + override def getStatement: Statement = { + checkOpen() + stmt + } + + override def getObject(columnIndex: Int, map: util.Map[String, Class[_]]): AnyRef = + throw new SQLFeatureNotSupportedException + + override def getRef(columnIndex: Int): Ref = + throw new SQLFeatureNotSupportedException + + override def getBlob(columnIndex: Int): Blob = + throw new SQLFeatureNotSupportedException + + override def getClob(columnIndex: Int): Clob = + throw new SQLFeatureNotSupportedException + + override def getArray(columnIndex: Int): JdbcArray = + throw new SQLFeatureNotSupportedException + + override def getObject(columnLabel: String, map: util.Map[String, Class[_]]): AnyRef = + throw new SQLFeatureNotSupportedException + + override def getRef(columnLabel: String): Ref = + throw new SQLFeatureNotSupportedException + + override def getBlob(columnLabel: String): Blob = + throw new SQLFeatureNotSupportedException + + override def getClob(columnLabel: String): Clob = + throw new SQLFeatureNotSupportedException + + override def getArray(columnLabel: String): JdbcArray = + throw new SQLFeatureNotSupportedException + + override def getDate(columnIndex: Int, cal: Calendar): Date = + throw new SQLFeatureNotSupportedException + + override def getDate(columnLabel: String, cal: Calendar): Date = + throw new SQLFeatureNotSupportedException + + override def getTime(columnIndex: Int, cal: Calendar): Time = + throw new SQLFeatureNotSupportedException + + override def getTime(columnLabel: String, cal: Calendar): Time = + throw new SQLFeatureNotSupportedException + + override def getTimestamp(columnIndex: Int, cal: Calendar): Timestamp = + throw new SQLFeatureNotSupportedException + + override def getTimestamp(columnLabel: String, cal: Calendar): Timestamp = + throw new SQLFeatureNotSupportedException + + override def getURL(columnIndex: Int): URL = + throw new SQLFeatureNotSupportedException + + override def getURL(columnLabel: String): URL = + throw new SQLFeatureNotSupportedException + + override def updateRef(columnIndex: Int, x: Ref): Unit = + throw new SQLFeatureNotSupportedException + + override def updateRef(columnLabel: String, x: Ref): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBlob(columnIndex: Int, x: Blob): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBlob(columnLabel: String, x: Blob): Unit = + throw new SQLFeatureNotSupportedException + + override def updateClob(columnIndex: Int, x: Clob): Unit = + throw new SQLFeatureNotSupportedException + + override def updateClob(columnLabel: String, x: Clob): Unit = + throw new SQLFeatureNotSupportedException + + override def updateArray(columnIndex: Int, x: JdbcArray): Unit = + throw new SQLFeatureNotSupportedException + + override def updateArray(columnLabel: String, x: JdbcArray): Unit = + throw new SQLFeatureNotSupportedException + + override def getRowId(columnIndex: Int): RowId = + throw new SQLFeatureNotSupportedException + + override def getRowId(columnLabel: String): RowId = + throw new SQLFeatureNotSupportedException + + override def updateRowId(columnIndex: Int, x: RowId): Unit = + throw new SQLFeatureNotSupportedException + + override def updateRowId(columnLabel: String, x: RowId): Unit = + throw new SQLFeatureNotSupportedException + + override def getHoldability: Int = ResultSet.HOLD_CURSORS_OVER_COMMIT + + override def updateNString(columnIndex: Int, nString: String): Unit = + throw new SQLFeatureNotSupportedException + + override def updateNString(columnLabel: String, nString: String): Unit = + throw new SQLFeatureNotSupportedException + + override def updateNClob(columnIndex: Int, nClob: NClob): Unit = + throw new SQLFeatureNotSupportedException + + override def updateNClob(columnLabel: String, nClob: NClob): Unit = + throw new SQLFeatureNotSupportedException + + override def getNClob(columnIndex: Int): NClob = + throw new SQLFeatureNotSupportedException + + override def getNClob(columnLabel: String): NClob = + throw new SQLFeatureNotSupportedException + + override def getSQLXML(columnIndex: Int): SQLXML = + throw new SQLFeatureNotSupportedException + + override def getSQLXML(columnLabel: String): SQLXML = + throw new SQLFeatureNotSupportedException + + override def updateSQLXML(columnIndex: Int, xmlObject: SQLXML): Unit = + throw new SQLFeatureNotSupportedException + + override def updateSQLXML(columnLabel: String, xmlObject: SQLXML): Unit = + throw new SQLFeatureNotSupportedException + + override def getNString(columnIndex: Int): String = + throw new SQLFeatureNotSupportedException + + override def getNString(columnLabel: String): String = + throw new SQLFeatureNotSupportedException + + override def getNCharacterStream(columnIndex: Int): Reader = + throw new SQLFeatureNotSupportedException + + override def getNCharacterStream(columnLabel: String): Reader = + throw new SQLFeatureNotSupportedException + + override def updateNCharacterStream(columnIndex: Int, x: Reader, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateNCharacterStream(columnLabel: String, reader: Reader, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateAsciiStream(columnIndex: Int, x: InputStream, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBinaryStream(columnIndex: Int, x: InputStream, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateCharacterStream(columnIndex: Int, x: Reader, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateAsciiStream(columnLabel: String, x: InputStream, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBinaryStream(columnLabel: String, x: InputStream, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateCharacterStream(columnLabel: String, reader: Reader, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBlob(columnIndex: Int, inputStream: InputStream, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBlob(columnLabel: String, inputStream: InputStream, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateClob(columnIndex: Int, reader: Reader, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateClob(columnLabel: String, reader: Reader, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateNClob(columnIndex: Int, reader: Reader, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateNClob(columnLabel: String, reader: Reader, length: Long): Unit = + throw new SQLFeatureNotSupportedException + + override def updateNCharacterStream(columnIndex: Int, x: Reader): Unit = + throw new SQLFeatureNotSupportedException + + override def updateNCharacterStream(columnLabel: String, reader: Reader): Unit = + throw new SQLFeatureNotSupportedException + + override def updateAsciiStream(columnIndex: Int, x: InputStream): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBinaryStream(columnIndex: Int, x: InputStream): Unit = + throw new SQLFeatureNotSupportedException + + override def updateCharacterStream(columnIndex: Int, x: Reader): Unit = + throw new SQLFeatureNotSupportedException + + override def updateAsciiStream(columnLabel: String, x: InputStream): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBinaryStream(columnLabel: String, x: InputStream): Unit = + throw new SQLFeatureNotSupportedException + + override def updateCharacterStream(columnLabel: String, reader: Reader): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBlob(columnIndex: Int, inputStream: InputStream): Unit = + throw new SQLFeatureNotSupportedException + + override def updateBlob(columnLabel: String, inputStream: InputStream): Unit = + throw new SQLFeatureNotSupportedException + + override def updateClob(columnIndex: Int, reader: Reader): Unit = + throw new SQLFeatureNotSupportedException + + override def updateClob(columnLabel: String, reader: Reader): Unit = + throw new SQLFeatureNotSupportedException + + override def updateNClob(columnIndex: Int, reader: Reader): Unit = + throw new SQLFeatureNotSupportedException + + override def updateNClob(columnLabel: String, reader: Reader): Unit = + throw new SQLFeatureNotSupportedException + + override def getObject[T](columnIndex: Int, `type`: Class[T]): T = + throw new SQLFeatureNotSupportedException + + override def getObject[T](columnLabel: String, `type`: Class[T]): T = + throw new SQLFeatureNotSupportedException + + override def unwrap[T](iface: Class[T]): T = if (isWrapperFor(iface)) { + iface.asInstanceOf[T] + } else { + throw new SQLException(s"${this.getClass.getName} not unwrappable from ${iface.getName}") + } + + override def isWrapperFor(iface: Class[_]): Boolean = iface.isInstance(this) +} diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSetMetaData.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSetMetaData.scala new file mode 100644 index 0000000000000..38e8f1d2e5d96 --- /dev/null +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSetMetaData.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client.jdbc + +import java.sql.{Array => _, _} +import java.sql.ResultSetMetaData.{columnNoNulls, columnNullable} + +import org.apache.spark.sql.connect.client.jdbc.util.JdbcTypeUtils +import org.apache.spark.sql.types._ + +class SparkConnectResultSetMetaData(schema: StructType) extends ResultSetMetaData { + + override def getColumnCount: Int = schema.length + + override def isAutoIncrement(column: Int): Boolean = false + + override def isCaseSensitive(column: Int): Boolean = false + + override def isSearchable(column: Int): Boolean = false + + override def isCurrency(column: Int): Boolean = false + + override def isNullable(column: Int): Int = + if (schema(column - 1).nullable) columnNullable else columnNoNulls + + override def isSigned(column: Int): Boolean = + JdbcTypeUtils.isSigned(schema(column - 1)) + + override def getColumnDisplaySize(column: Int): Int = + JdbcTypeUtils.getDisplaySize(schema(column - 1)) + + override def getColumnLabel(column: Int): String = getColumnName(column) + + override def getColumnName(column: Int): String = schema(column - 1).name + + override def getColumnType(column: Int): Int = + JdbcTypeUtils.getColumnType(schema(column - 1)) + + override def getColumnTypeName(column: Int): String = schema(column - 1).dataType.sql + + override def getColumnClassName(column: Int): String = + JdbcTypeUtils.getColumnTypeClassName(schema(column - 1)) + + override def getPrecision(column: Int): Int = + JdbcTypeUtils.getPrecision(schema(column - 1)) + + override def getScale(column: Int): Int = + JdbcTypeUtils.getScale(schema(column - 1)) + + override def getCatalogName(column: Int): String = "" + + override def getSchemaName(column: Int): String = "" + + override def getTableName(column: Int): String = "" + + override def isReadOnly(column: Int): Boolean = true + + override def isWritable(column: Int): Boolean = false + + override def isDefinitelyWritable(column: Int): Boolean = false + + override def unwrap[T](iface: Class[T]): T = if (isWrapperFor(iface)) { + iface.asInstanceOf[T] + } else { + throw new SQLException(s"${this.getClass.getName} not unwrappable from ${iface.getName}") + } + + override def isWrapperFor(iface: Class[_]): Boolean = iface.isInstance(this) +} diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala new file mode 100644 index 0000000000000..8de227f9d07c2 --- /dev/null +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client.jdbc + +import java.sql.{Array => _, _} + +class SparkConnectStatement(conn: SparkConnectConnection) extends Statement { + + private var operationId: String = _ + private var resultSet: SparkConnectResultSet = _ + + @volatile private var closed: Boolean = false + + override def isClosed: Boolean = closed + + override def close(): Unit = synchronized { + if (!closed) { + if (operationId != null) { + conn.spark.interruptOperation(operationId) + operationId = null + } + if (resultSet != null) { + resultSet.close() + resultSet = null + } + closed = false + } + } + + private[jdbc] def checkOpen(): Unit = { + if (closed) { + throw new SQLException("JDBC Statement is closed.") + } + } + + override def executeQuery(sql: String): ResultSet = { + checkOpen() + + val df = conn.spark.sql(sql) + val sparkResult = df.collectResult() + operationId = sparkResult.operationId + resultSet = new SparkConnectResultSet(sparkResult, this) + resultSet + } + + override def executeUpdate(sql: String): Int = { + checkOpen() + + val df = conn.spark.sql(sql) + val sparkResult = df.collectResult() + operationId = sparkResult.operationId + resultSet = null + + // always return 0 because affected rows is not supported yet + 0 + } + + override def execute(sql: String): Boolean = { + checkOpen() + + // always perform executeQuery and reture a ResultSet + executeQuery(sql) + true + } + + override def getResultSet: ResultSet = { + checkOpen() + resultSet + } + + override def getMaxFieldSize: Int = + throw new SQLFeatureNotSupportedException + + override def setMaxFieldSize(max: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def getMaxRows: Int = { + checkOpen() + 0 + } + + override def setMaxRows(max: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def setEscapeProcessing(enable: Boolean): Unit = + throw new SQLFeatureNotSupportedException + + override def getQueryTimeout: Int = { + checkOpen() + 0 + } + + override def setQueryTimeout(seconds: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def cancel(): Unit = { + checkOpen() + + if (operationId != null) { + conn.spark.interruptOperation(operationId) + } + } + + override def getWarnings: SQLWarning = null + + override def clearWarnings(): Unit = {} + + override def setCursorName(name: String): Unit = + throw new SQLFeatureNotSupportedException + + override def getUpdateCount: Int = + throw new SQLFeatureNotSupportedException + + override def getMoreResults: Boolean = + throw new SQLFeatureNotSupportedException + + override def setFetchDirection(direction: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def getFetchDirection: Int = + throw new SQLFeatureNotSupportedException + + override def setFetchSize(rows: Int): Unit = + throw new SQLFeatureNotSupportedException + + override def getFetchSize: Int = + throw new SQLFeatureNotSupportedException + + override def getResultSetConcurrency: Int = { + checkOpen() + ResultSet.CONCUR_READ_ONLY + } + + override def getResultSetType: Int = + throw new SQLFeatureNotSupportedException + + override def addBatch(sql: String): Unit = + throw new SQLFeatureNotSupportedException + + override def clearBatch(): Unit = + throw new SQLFeatureNotSupportedException + + override def executeBatch(): Array[Int] = + throw new SQLFeatureNotSupportedException + + override def getConnection: Connection = { + checkOpen() + conn + } + + override def getMoreResults(current: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def getGeneratedKeys: ResultSet = + throw new SQLFeatureNotSupportedException + + override def executeUpdate(sql: String, autoGeneratedKeys: Int): Int = + throw new SQLFeatureNotSupportedException + + override def executeUpdate(sql: String, columnIndexes: Array[Int]): Int = + throw new SQLFeatureNotSupportedException + + override def executeUpdate(sql: String, columnNames: Array[String]): Int = + throw new SQLFeatureNotSupportedException + + override def execute(sql: String, autoGeneratedKeys: Int): Boolean = + throw new SQLFeatureNotSupportedException + + override def execute(sql: String, columnIndexes: Array[Int]): Boolean = + throw new SQLFeatureNotSupportedException + + override def execute(sql: String, columnNames: Array[String]): Boolean = + throw new SQLFeatureNotSupportedException + + override def getResultSetHoldability: Int = + throw new SQLFeatureNotSupportedException + + override def setPoolable(poolable: Boolean): Unit = { + checkOpen() + + if (poolable) { + throw new SQLFeatureNotSupportedException("Poolable statement is not supported") + } + } + + override def isPoolable: Boolean = { + checkOpen() + false + } + + override def closeOnCompletion(): Unit = { + checkOpen() + } + + override def isCloseOnCompletion: Boolean = { + checkOpen() + false + } + + override def unwrap[T](iface: Class[T]): T = if (isWrapperFor(iface)) { + iface.asInstanceOf[T] + } else { + throw new SQLException(s"${this.getClass.getName} not unwrappable from ${iface.getName}") + } + + override def isWrapperFor(iface: Class[_]): Boolean = iface.isInstance(this) +} diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcErrorUtils.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcErrorUtils.scala new file mode 100644 index 0000000000000..cb941b7420a71 --- /dev/null +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcErrorUtils.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client.jdbc.util + +import java.sql.{Array => _, _} + +private[jdbc] object JdbcErrorUtils { + + def stringfiyTransactionIsolationLevel(level: Int): String = level match { + case Connection.TRANSACTION_NONE => "NONE" + case Connection.TRANSACTION_READ_UNCOMMITTED => "READ_UNCOMMITTED" + case Connection.TRANSACTION_READ_COMMITTED => "READ_COMMITTED" + case Connection.TRANSACTION_REPEATABLE_READ => "REPEATABLE_READ" + case Connection.TRANSACTION_SERIALIZABLE => "SERIALIZABLE" + case _ => + throw new IllegalArgumentException(s"Invalid transaction isolation level: $level") + } + + def stringfiyHoldability(holdability: Int): String = holdability match { + case ResultSet.HOLD_CURSORS_OVER_COMMIT => "HOLD_CURSORS_OVER_COMMIT" + case ResultSet.CLOSE_CURSORS_AT_COMMIT => "CLOSE_CURSORS_AT_COMMIT" + case _ => + throw new IllegalArgumentException(s"Invalid holdability: $holdability") + } +} diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcTypeUtils.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcTypeUtils.scala new file mode 100644 index 0000000000000..55e3d29c99a5e --- /dev/null +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcTypeUtils.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client.jdbc.util + +import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} +import java.sql.{Array => _, _} + +import org.apache.spark.sql.types._ + +private[jdbc] object JdbcTypeUtils { + + def getColumnType(field: StructField): Int = field.dataType match { + case NullType => Types.NULL + case BooleanType => Types.BOOLEAN + case ByteType => Types.TINYINT + case ShortType => Types.SMALLINT + case IntegerType => Types.INTEGER + case LongType => Types.BIGINT + case FloatType => Types.FLOAT + case DoubleType => Types.DOUBLE + case StringType => Types.VARCHAR + case other => + throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.") + } + + def getColumnTypeClassName(field: StructField): String = field.dataType match { + case NullType => "null" + case BooleanType => classOf[JBoolean].getName + case ByteType => classOf[JByte].getName + case ShortType => classOf[JShort].getName + case IntegerType => classOf[Integer].getName + case LongType => classOf[JLong].getName + case FloatType => classOf[JFloat].getName + case DoubleType => classOf[JDouble].getName + case StringType => classOf[String].getName + case other => + throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.") + } + + def isSigned(field: StructField): Boolean = field.dataType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case NullType | BooleanType | StringType => false + case other => + throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.") + } + + def getPrecision(field: StructField): Int = field.dataType match { + case NullType => 0 + case BooleanType => 1 + case ByteType => 3 + case ShortType => 5 + case IntegerType => 10 + case LongType => 19 + case FloatType => 7 + case DoubleType => 15 + case StringType => 255 + case other => + throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.") + } + + def getScale(field: StructField): Int = field.dataType match { + case FloatType => 7 + case DoubleType => 15 + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | StringType => 0 + case other => + throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.") + } + + def getDisplaySize(field: StructField): Int = field.dataType match { + case NullType => 4 // length of `NULL` + case BooleanType => 5 // `TRUE` or `FALSE` + case ByteType | ShortType | IntegerType | LongType => + getPrecision(field) + 1 // may have leading negative sign + case FloatType => 14 + case DoubleType => 24 + case StringType => + getPrecision(field) + case other => + throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.") + } +} diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDriverSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDriverSuite.scala index eb4ce76d2c0ab..09ba786297579 100644 --- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDriverSuite.scala +++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDriverSuite.scala @@ -17,18 +17,69 @@ package org.apache.spark.sql.connect.client.jdbc -import java.sql.DriverManager +import java.sql.{Array => _, _} +import java.util.Properties -import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite +import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION} +import org.apache.spark.sql.connect.client.jdbc.test.JdbcHelper +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} +import org.apache.spark.util.VersionUtils -class SparkConnectDriverSuite extends AnyFunSuite { // scalastyle:ignore funsuite +class SparkConnectDriverSuite extends ConnectFunSuite with RemoteSparkSession + with JdbcHelper { - // explicitly load the class to make it known to the DriverManager - classOf[SparkConnectDriver].getClassLoader + def jdbcUrl: String = s"jdbc:sc://localhost:$serverPort" - val jdbcUrl: String = s"jdbc:sc://localhost:15002" - - test("test SparkConnectDriver") { + test("get Connection from SparkConnectDriver") { assert(DriverManager.getDriver(jdbcUrl).isInstanceOf[SparkConnectDriver]) + + val cause = intercept[SQLException] { + new SparkConnectDriver().connect(null, new Properties()) + } + assert(cause.getMessage === "url must not be null") + + withConnection { conn => + assert(conn.isInstanceOf[SparkConnectConnection]) + } + } + + test("get DatabaseMetaData from SparkConnectConnection") { + withConnection { conn => + val spark = conn.asInstanceOf[SparkConnectConnection].spark + val metadata = conn.getMetaData + assert(metadata.getURL === jdbcUrl) + assert(metadata.isReadOnly === false) + assert(metadata.getUserName === spark.client.configuration.userName) + assert(metadata.getDatabaseProductName === "Apache Spark Connect Server") + assert(metadata.getDatabaseProductVersion === spark.version) + assert(metadata.getDriverVersion === SPARK_VERSION) + assert(metadata.getDriverMajorVersion === VersionUtils.majorVersion(SPARK_VERSION)) + assert(metadata.getDriverMinorVersion === VersionUtils.minorVersion(SPARK_VERSION)) + assert(metadata.getIdentifierQuoteString === "`") + assert(metadata.getExtraNameCharacters === "") + assert(metadata.supportsGroupBy === true) + assert(metadata.supportsOuterJoins === true) + assert(metadata.supportsFullOuterJoins === true) + assert(metadata.supportsLimitedOuterJoins === true) + assert(metadata.getSchemaTerm === "schema") + assert(metadata.getProcedureTerm === "procedure") + assert(metadata.getCatalogTerm === "catalog") + assert(metadata.isCatalogAtStart === true) + assert(metadata.getCatalogSeparator === ".") + assert(metadata.getConnection === conn) + assert(metadata.supportsResultSetHoldability(ResultSet.HOLD_CURSORS_OVER_COMMIT) === false) + assert(metadata.supportsResultSetHoldability(ResultSet.CLOSE_CURSORS_AT_COMMIT) === true) + assert(metadata.getResultSetHoldability === ResultSet.CLOSE_CURSORS_AT_COMMIT) + assert(metadata.getDatabaseMajorVersion === VersionUtils.majorVersion(spark.version)) + assert(metadata.getDatabaseMinorVersion === VersionUtils.minorVersion(spark.version)) + assert(metadata.getJDBCMajorVersion === 4) + assert(metadata.getJDBCMinorVersion === 3) + assert(metadata.supportsStatementPooling === false) + assert(metadata.getRowIdLifetime === RowIdLifetime.ROWID_UNSUPPORTED) + assert(metadata.generatedKeyAlwaysReturned === false) + assert(metadata.getMaxLogicalLobSize === 0) + assert(metadata.supportsRefCursors === false) + assert(metadata.supportsSharding === false) + } } } diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala new file mode 100644 index 0000000000000..619b279310eb3 --- /dev/null +++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client.jdbc + +import java.sql.Types + +import org.apache.spark.sql.connect.client.jdbc.test.JdbcHelper +import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} + +class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSession + with JdbcHelper { + + override def jdbcUrl: String = s"jdbc:sc://localhost:$serverPort" + + test("get null type") { + withExecuteQuery("SELECT null") { rs => + assert(rs.next()) + assert(rs.getString(1) === null) + assert(rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "NULL") + assert(metaData.getColumnLabel(1) === "NULL") + assert(metaData.getColumnType(1) === Types.NULL) + assert(metaData.getColumnTypeName(1) === "VOID") + assert(metaData.getColumnClassName(1) === "null") + assert(metaData.isSigned(1) === false) + assert(metaData.getPrecision(1) === 0) + assert(metaData.getScale(1) === 0) + assert(metaData.getColumnDisplaySize(1) === 4) + } + } + + test("get boolean type") { + withExecuteQuery("SELECT true") { rs => + assert(rs.next()) + assert(rs.getBoolean(1) === true) + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "true") + assert(metaData.getColumnLabel(1) === "true") + assert(metaData.getColumnType(1) === Types.BOOLEAN) + assert(metaData.getColumnTypeName(1) === "BOOLEAN") + assert(metaData.getColumnClassName(1) === "java.lang.Boolean") + assert(metaData.isSigned(1) === false) + assert(metaData.getPrecision(1) === 1) + assert(metaData.getScale(1) === 0) + assert(metaData.getColumnDisplaySize(1) === 5) + } + } + + test("get byte type") { + withExecuteQuery("SELECT cast(1 as byte)") { rs => + assert(rs.next()) + assert(rs.getByte(1) === 1.toByte) + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "CAST(1 AS TINYINT)") + assert(metaData.getColumnLabel(1) === "CAST(1 AS TINYINT)") + assert(metaData.getColumnType(1) === Types.TINYINT) + assert(metaData.getColumnTypeName(1) === "TINYINT") + assert(metaData.getColumnClassName(1) === "java.lang.Byte") + assert(metaData.isSigned(1) === true) + assert(metaData.getPrecision(1) === 3) + assert(metaData.getScale(1) === 0) + assert(metaData.getColumnDisplaySize(1) === 4) + } + } + + test("get short type") { + withExecuteQuery("SELECT cast(1 as short)") { rs => + assert(rs.next()) + assert(rs.getShort(1) === 1.toShort) + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "CAST(1 AS SMALLINT)") + assert(metaData.getColumnLabel(1) === "CAST(1 AS SMALLINT)") + assert(metaData.getColumnType(1) === Types.SMALLINT) + assert(metaData.getColumnTypeName(1) === "SMALLINT") + assert(metaData.getColumnClassName(1) === "java.lang.Short") + assert(metaData.isSigned(1) === true) + assert(metaData.getPrecision(1) === 5) + assert(metaData.getScale(1) === 0) + assert(metaData.getColumnDisplaySize(1) === 6) + } + } + + test("get int type") { + withExecuteQuery("SELECT 1") { rs => + assert(rs.next()) + assert(rs.getInt(1) === 1) + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "1") + assert(metaData.getColumnLabel(1) === "1") + assert(metaData.getColumnType(1) === Types.INTEGER) + assert(metaData.getColumnTypeName(1) === "INT") + assert(metaData.getColumnClassName(1) === "java.lang.Integer") + assert(metaData.isSigned(1) === true) + assert(metaData.getPrecision(1) === 10) + assert(metaData.getScale(1) === 0) + assert(metaData.getColumnDisplaySize(1) === 11) + } + } + + test("get bigint type") { + withExecuteQuery("SELECT cast(1 as bigint)") { rs => + assert(rs.next()) + assert(rs.getLong(1) === 1L) + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "CAST(1 AS BIGINT)") + assert(metaData.getColumnLabel(1) === "CAST(1 AS BIGINT)") + assert(metaData.getColumnType(1) === Types.BIGINT) + assert(metaData.getColumnTypeName(1) === "BIGINT") + assert(metaData.getColumnClassName(1) === "java.lang.Long") + assert(metaData.isSigned(1) === true) + assert(metaData.getPrecision(1) === 19) + assert(metaData.getScale(1) === 0) + assert(metaData.getColumnDisplaySize(1) === 20) + } + } + + test("get float type") { + withExecuteQuery("SELECT cast(1.2 as float)") { rs => + assert(rs.next()) + assert(rs.getFloat(1) === 1.2F) + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "CAST(1.2 AS FLOAT)") + assert(metaData.getColumnLabel(1) === "CAST(1.2 AS FLOAT)") + assert(metaData.getColumnType(1) === Types.FLOAT) + assert(metaData.getColumnTypeName(1) === "FLOAT") + assert(metaData.getColumnClassName(1) === "java.lang.Float") + assert(metaData.isSigned(1) === true) + assert(metaData.getPrecision(1) === 7) + assert(metaData.getScale(1) === 7) + assert(metaData.getColumnDisplaySize(1) === 14) + } + } + + test("get double type") { + withExecuteQuery("SELECT cast(1.2 as double)") { rs => + assert(rs.next()) + assert(rs.getDouble(1) === 1.2D) + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "CAST(1.2 AS DOUBLE)") + assert(metaData.getColumnLabel(1) === "CAST(1.2 AS DOUBLE)") + assert(metaData.getColumnType(1) === Types.DOUBLE) + assert(metaData.getColumnTypeName(1) === "DOUBLE") + assert(metaData.getColumnClassName(1) === "java.lang.Double") + assert(metaData.isSigned(1) === true) + assert(metaData.getPrecision(1) === 15) + assert(metaData.getScale(1) === 15) + assert(metaData.getColumnDisplaySize(1) === 24) + } + } + + test("get string type") { + withExecuteQuery("SELECT 'str'") { rs => + assert(rs.next()) + assert(rs.getString(1) === "str") + assert(!rs.wasNull) + assert(!rs.next()) + + val metaData = rs.getMetaData + assert(metaData.getColumnCount === 1) + assert(metaData.getColumnName(1) === "str") + assert(metaData.getColumnLabel(1) === "str") + assert(metaData.getColumnType(1) === Types.VARCHAR) + assert(metaData.getColumnTypeName(1) === "STRING") + assert(metaData.getColumnClassName(1) === "java.lang.String") + assert(metaData.isSigned(1) === false) + assert(metaData.getPrecision(1) === 255) + assert(metaData.getScale(1) === 0) + assert(metaData.getColumnDisplaySize(1) === 255) + } + } +} diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala new file mode 100644 index 0000000000000..9b3aa373e93ce --- /dev/null +++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client.jdbc.test + +import java.sql.{Connection, DriverManager, ResultSet, Statement} + +import scala.util.Using + +import org.apache.spark.sql.connect.client.jdbc.SparkConnectDriver + +trait JdbcHelper { + + def jdbcUrl: String + + // explicitly load the class to make it known to the DriverManager + classOf[SparkConnectDriver].getClassLoader + + def withConnection[T](f: Connection => T): T = { + Using.resource(DriverManager.getConnection(jdbcUrl)) { conn => f(conn) } + } + + def withStatement[T](f: Statement => T): T = withConnection { conn => + Using.resource(conn.createStatement()) { stmt => f(stmt) } + } + + def withExecuteQuery(query: String)(f: ResultSet => Unit): Unit = { + withStatement { stmt => + Using.resource { stmt.executeQuery(query) } { rs => f(rs) } + } + } +}