Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.sql.{Connection, Date, Timestamp}
import java.util.Properties
import java.math.BigDecimal

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.apache.spark.sql.execution.{WholeStageCodegenExec, RowDataSourceScanExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -52,7 +52,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
import testImplicits._

override val db = new DatabaseOnDocker {
override val imageName = "wnameless/oracle-xe-11g:14.04.4"
override val imageName = "wnameless/oracle-xe-11g:16.04"
override val env = Map(
"ORACLE_ROOT_PASSWORD" -> "oracle"
)
Expand Down Expand Up @@ -104,15 +104,18 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
""".stripMargin.replaceAll("\n", " "))


conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate();
conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate()
conn.prepareStatement(
"INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate();
conn.commit();
"INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate()
conn.commit()

conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)").executeUpdate()
conn.commit()
}


test("SPARK-16625 : Importing Oracle numeric types") {
val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties);
val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties)
val rows = df.collect()
assert(rows.size == 1)
val row = rows(0)
Expand Down Expand Up @@ -307,4 +310,32 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
assert(values.getInt(1).equals(1))
assert(values.getBoolean(2).equals(false))
}

test("SPARK-22303: handle BINARY_DOUBLE and BINARY_FLOAT as DoubleType and FloatType") {
  // The oracle_types table (d BINARY_DOUBLE, f BINARY_FLOAT) is created by this suite's setup SQL.
  val oracleTable = "oracle_types"
  val writeProps = new Properties()
  val expectedSchema = StructType(Seq(
    StructField("d", DoubleType, true),
    StructField("f", FloatType, true)))

  // Append a single (double, float) row through the JDBC writer.
  val inputRows = spark.sparkContext.parallelize(Seq(Row(1.1, 2.2f)))
  spark.createDataFrame(inputRows, expectedSchema)
    .write
    .mode(SaveMode.Append)
    .jdbc(jdbcUrl, oracleTable, writeProps)

  // Load the table back over JDBC; exactly the one appended row is expected.
  val loaded = sqlContext.read.jdbc(jdbcUrl, oracleTable, new Properties)
  val fetched = loaded.collect()
  assert(fetched.size == 1)

  // The two columns must surface as DoubleType and FloatType respectively.
  val columnTypes = loaded.schema.map(_.dataType)
  assert(columnTypes(0).equals(DoubleType))
  assert(columnTypes(1).equals(FloatType))

  // The stored values must round-trip unchanged.
  val onlyRow = fetched(0)
  assert(onlyRow.getDouble(0) === 1.1)
  assert(onlyRow.getFloat(1) === 2.2f)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ object JdbcUtils extends Logging {
case java.sql.Types.TIMESTAMP => TimestampType
case java.sql.Types.TIMESTAMP_WITH_TIMEZONE
=> TimestampType
case -101 => TimestampType // Value for Timestamp with Time Zone in Oracle
case java.sql.Types.TINYINT => IntegerType
case java.sql.Types.VARBINARY => BinaryType
case java.sql.Types.VARCHAR => StringType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,36 @@ import org.apache.spark.sql.types._


private case object OracleDialect extends JdbcDialect {
private[jdbc] val BINARY_FLOAT = 100
private[jdbc] val BINARY_DOUBLE = 101
private[jdbc] val TIMESTAMPTZ = -101

/** Accepts a JDBC URL only when it begins with the literal, case-sensitive prefix `jdbc:oracle`. */
override def canHandle(url: String): Boolean = {
  val oraclePrefix = "jdbc:oracle"
  url.startsWith(oraclePrefix)
}

/**
 * Maps Oracle-specific JDBC type codes to Catalyst types.
 *
 * @param sqlType  JDBC type code reported for the column (may be an Oracle vendor code)
 * @param typeName type name reported by the driver (not consulted here)
 * @param size     precision reported for the column
 * @param md       metadata builder expected to carry a "scale" entry; may be null
 * @return the Catalyst type to use, or None to fall back to the default JDBC mapping
 */
override def getCatalystType(
    sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
  sqlType match {
    case Types.NUMERIC =>
      // A null builder is treated as scale 0.
      val scale = Option(md).map(_.build().getLong("scale")).getOrElse(0L)
      size match {
        // Handle NUMBER fields that have no precision/scale in a special way
        // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale.
        // For more details, please see
        // https://github.com/apache/spark/pull/8780#issuecomment-145598968
        // and
        // https://github.com/apache/spark/pull/8780#issuecomment-144541760
        case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
        // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts
        // this to NUMERIC with -127 scale.
        // Not sure if there is a more robust way to identify the field as a float (or other
        // numeric types that do not specify a scale).
        case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
        case _ => None
      }
    case TIMESTAMPTZ => Some(TimestampType) // Value for Timestamp with Time Zone in Oracle
    case BINARY_FLOAT => Some(FloatType) // Value for OracleTypes.BINARY_FLOAT
    case BINARY_DOUBLE => Some(DoubleType) // Value for OracleTypes.BINARY_DOUBLE
    case _ => None
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,12 @@ class JDBCSuite extends SparkFunSuite
Some(DecimalType(DecimalType.MAX_PRECISION, 10)))
assert(oracleDialect.getCatalystType(java.sql.Types.NUMERIC, "numeric", 0, null) ==
Some(DecimalType(DecimalType.MAX_PRECISION, 10)))
assert(oracleDialect.getCatalystType(OracleDialect.BINARY_FLOAT, "BINARY_FLOAT", 0, null) ==
Some(FloatType))
assert(oracleDialect.getCatalystType(OracleDialect.BINARY_DOUBLE, "BINARY_DOUBLE", 0, null) ==
Some(DoubleType))
assert(oracleDialect.getCatalystType(OracleDialect.TIMESTAMPTZ, "TIMESTAMP", 0, null) ==
Some(TimestampType))
}

test("table exists query by jdbc dialect") {
Expand Down