diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
index 7680ae383513..90343182712e 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
@@ -21,7 +21,7 @@ import java.sql.{Connection, Date, Timestamp}
 import java.util.Properties
 import java.math.BigDecimal
 
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Row, SaveMode}
 import org.apache.spark.sql.execution.{WholeStageCodegenExec, RowDataSourceScanExec}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -52,7 +52,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
   import testImplicits._
 
   override val db = new DatabaseOnDocker {
-    override val imageName = "wnameless/oracle-xe-11g:14.04.4"
+    override val imageName = "wnameless/oracle-xe-11g:16.04"
     override val env = Map(
       "ORACLE_ROOT_PASSWORD" -> "oracle"
     )
@@ -104,15 +104,18 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
       """.stripMargin.replaceAll("\n", " "))
 
 
-    conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate();
+    conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate()
     conn.prepareStatement(
-      "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate();
-    conn.commit();
+      "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate()
+    conn.commit()
+
+    conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)").executeUpdate()
+    conn.commit()
   }
 
   test("SPARK-16625 : Importing Oracle numeric types") {
-    val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties);
+    val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties)
     val rows = df.collect()
     assert(rows.size == 1)
     val row = rows(0)
 
@@ -307,4 +310,32 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
     assert(values.getInt(1).equals(1))
     assert(values.getBoolean(2).equals(false))
   }
+
+  test("SPARK-22303: handle BINARY_DOUBLE and BINARY_FLOAT as DoubleType and FloatType") {
+    val tableName = "oracle_types"
+    val schema = StructType(Seq(
+      StructField("d", DoubleType, true),
+      StructField("f", FloatType, true)))
+    val props = new Properties()
+
+    // write it back to the table (append mode)
+    val data = spark.sparkContext.parallelize(Seq(Row(1.1, 2.2f)))
+    val dfWrite = spark.createDataFrame(data, schema)
+    dfWrite.write.mode(SaveMode.Append).jdbc(jdbcUrl, tableName, props)
+
+    // read records from oracle_types
+    val dfRead = sqlContext.read.jdbc(jdbcUrl, tableName, new Properties)
+    val rows = dfRead.collect()
+    assert(rows.size == 1)
+
+    // check data types
+    val types = dfRead.schema.map(field => field.dataType)
+    assert(types(0).equals(DoubleType))
+    assert(types(1).equals(FloatType))
+
+    // check values
+    val values = rows(0)
+    assert(values.getDouble(0) === 1.1)
+    assert(values.getFloat(1) === 2.2f)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 71133666b324..9debc4ff8274 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -230,7 +230,6 @@ object JdbcUtils extends Logging {
       case java.sql.Types.TIMESTAMP     => TimestampType
       case java.sql.Types.TIMESTAMP_WITH_TIMEZONE
                                         => TimestampType
-      case -101                         => TimestampType // Value for Timestamp with Time Zone in Oracle
       case java.sql.Types.TINYINT       => IntegerType
       case java.sql.Types.VARBINARY     => BinaryType
       case java.sql.Types.VARCHAR       => StringType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index 3b44c1de93a6..e3f106c41c7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -23,30 +23,36 @@ import org.apache.spark.sql.types._
 
 
 private case object OracleDialect extends JdbcDialect {
+  private[jdbc] val BINARY_FLOAT = 100
+  private[jdbc] val BINARY_DOUBLE = 101
+  private[jdbc] val TIMESTAMPTZ = -101
 
   override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle")
 
   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
-    if (sqlType == Types.NUMERIC) {
-      val scale = if (null != md) md.build().getLong("scale") else 0L
-      size match {
-        // Handle NUMBER fields that have no precision/scale in special way
-        // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale
-        // For more details, please see
-        // https://github.com/apache/spark/pull/8780#issuecomment-145598968
-        // and
-        // https://github.com/apache/spark/pull/8780#issuecomment-144541760
-        case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
-        // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts
-        // this to NUMERIC with -127 scale
-        // Not sure if there is a more robust way to identify the field as a float (or other
-        // numeric types that do not specify a scale.
-        case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
-        case _ => None
-      }
-    } else {
-      None
+    sqlType match {
+      case Types.NUMERIC =>
+        val scale = if (null != md) md.build().getLong("scale") else 0L
+        size match {
+          // Handle NUMBER fields that have no precision/scale in special way
+          // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale
+          // For more details, please see
+          // https://github.com/apache/spark/pull/8780#issuecomment-145598968
+          // and
+          // https://github.com/apache/spark/pull/8780#issuecomment-144541760
+          case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
+          // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts
+          // this to NUMERIC with -127 scale
+          // Not sure if there is a more robust way to identify the field as a float (or other
+          // numeric types that do not specify a scale.
+          case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
+          case _ => None
+        }
+      case TIMESTAMPTZ => Some(TimestampType) // Value for Timestamp with Time Zone in Oracle
+      case BINARY_FLOAT => Some(FloatType) // Value for OracleTypes.BINARY_FLOAT
+      case BINARY_DOUBLE => Some(DoubleType) // Value for OracleTypes.BINARY_DOUBLE
+      case _ => None
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 34205e0b2bf0..167b3e019002 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -815,6 +815,12 @@ class JDBCSuite extends SparkFunSuite
       Some(DecimalType(DecimalType.MAX_PRECISION, 10)))
     assert(oracleDialect.getCatalystType(java.sql.Types.NUMERIC, "numeric", 0, null) ==
       Some(DecimalType(DecimalType.MAX_PRECISION, 10)))
+    assert(oracleDialect.getCatalystType(OracleDialect.BINARY_FLOAT, "BINARY_FLOAT", 0, null) ==
+      Some(FloatType))
+    assert(oracleDialect.getCatalystType(OracleDialect.BINARY_DOUBLE, "BINARY_DOUBLE", 0, null) ==
+      Some(DoubleType))
+    assert(oracleDialect.getCatalystType(OracleDialect.TIMESTAMPTZ, "TIMESTAMP", 0, null) ==
+      Some(TimestampType))
   }
 
   test("table exists query by jdbc dialect") {
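
Note on the change, with a usage sketch: the new `BINARY_FLOAT`/`BINARY_DOUBLE`/`TIMESTAMPTZ` constants mirror the values in `oracle.jdbc.OracleTypes` without adding a compile-time dependency on the Oracle driver, and the Oracle-specific `-101` case moves out of the generic `JdbcUtils` into `OracleDialect`, where vendor type codes belong. Below is a minimal, hypothetical sketch of how the mapping surfaces when reading the `oracle_types` table created above; the object name, JDBC URL, and credentials are placeholders, not part of the patch.

```scala
import java.util.Properties

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DoubleType, FloatType}

object OracleBinaryTypesExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("oracle-binary-types").getOrCreate()

    // Placeholder connection string; any reachable Oracle XE instance would do.
    val jdbcUrl = "jdbc:oracle:thin:system/oracle@//localhost:1521/xe"

    // With the OracleDialect change above, a column declared BINARY_DOUBLE
    // (vendor type code 101) maps to DoubleType and BINARY_FLOAT (code 100)
    // maps to FloatType; previously both fell through to JdbcUtils, which
    // rejected them as unsupported types.
    val df = spark.read.jdbc(jdbcUrl, "oracle_types", new Properties())

    assert(df.schema("d").dataType == DoubleType)
    assert(df.schema("f").dataType == FloatType)
  }
}
```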