[KYUUBI apache#4316] Fix returned Timestamp values may lose precision
pan3793 committed Feb 13, 2023
1 parent 41d9444 commit c66ad22
Showing 7 changed files with 101 additions and 141 deletions.
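What changed, in short: the Spark engine previously rendered DATE/TIMESTAMP result values with its own formatters, and the legacy timestamp path used a millisecond-only pattern ("yyyy-MM-dd HH:mm:ss.SSS"), so sub-millisecond digits were silently dropped before reaching the client. This commit removes the hand-rolled toHiveString and delegates to Spark's HiveResult.toHiveString, backed by DateFormatter and TimestampFormatter.getFractionFormatter, which keep the fractional seconds instead of truncating them to milliseconds. A minimal sketch of the difference (illustrative only, not part of the diff; it assumes Spark's catalyst TimestampFormatter and commons-lang3 FastDateFormat are on the classpath, and the commented outputs are the expected results):

import java.sql.Timestamp
import java.time.ZoneId
import org.apache.commons.lang3.time.FastDateFormat
import org.apache.spark.sql.catalyst.util.TimestampFormatter

val ts = Timestamp.valueOf("2018-11-17 13:33:33.12345")

// Legacy path: millisecond pattern, sub-millisecond digits are lost.
FastDateFormat.getInstance("yyyy-MM-dd HH:mm:ss.SSS").format(ts)
// expected: "2018-11-17 13:33:33.123"

// New path: Spark's fraction formatter keeps the full fractional part.
TimestampFormatter.getFractionFormatter(ZoneId.systemDefault()).format(ts)
// expected: "2018-11-17 13:33:33.12345"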
@@ -18,22 +18,27 @@
package org.apache.kyuubi.engine.spark.schema

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.sql.Timestamp
import java.time._
import java.util.Date
import java.time.ZoneId

import scala.collection.JavaConverters._

import org.apache.hive.service.rpc.thrift._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter}
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.execution.HiveResult.TimeFormatters
import org.apache.spark.sql.types._

import org.apache.kyuubi.engine.spark.schema.SchemaHelper.TIMESTAMP_NTZ
import org.apache.kyuubi.util.RowSetUtils._

object RowSet {

def getTimeFormatters(timeZone: ZoneId): TimeFormatters = {
val dateFormatter = DateFormatter()
val timestampFormatter = TimestampFormatter.getFractionFormatter(timeZone)
TimeFormatters(dateFormatter, timestampFormatter)
}

def toTRowSet(
bytes: Array[Byte],
protocolVersion: TProtocolVersion): TRowSet = {
@@ -68,9 +73,9 @@ object RowSet {
}

def toRowBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
var i = 0
val rowSize = rows.length
val tRows = new java.util.ArrayList[TRow](rowSize)
var i = 0
while (i < rowSize) {
val row = rows(i)
val tRow = new TRow()
@@ -151,13 +156,8 @@ object RowSet {
while (i < rowSize) {
val row = rows(i)
nulls.set(i, row.isNullAt(ordinal))
val value =
if (row.isNullAt(ordinal)) {
""
} else {
toHiveString((row.get(ordinal), typ), timeZone)
}
values.add(value)
values.add(
HiveResult.toHiveString((row.get(ordinal), typ), false, getTimeFormatters(timeZone)))
i += 1
}
TColumn.stringVal(new TStringColumn(values, nulls))
@@ -239,68 +239,15 @@ object RowSet {
val tStrValue = new TStringValue
if (!row.isNullAt(ordinal)) {
tStrValue.setValue(
toHiveString((row.get(ordinal), types(ordinal).dataType), timeZone))
HiveResult.toHiveString(
(row.get(ordinal), types(ordinal).dataType),
false,
getTimeFormatters(timeZone)))
}
TColumnValue.stringVal(tStrValue)
}
}

/**
* A simpler impl of Spark's toHiveString
*/
def toHiveString(dataWithType: (Any, DataType), timeZone: ZoneId): String = {
dataWithType match {
case (null, _) =>
// Only match nulls in nested type values
"null"

case (d: Date, DateType) =>
formatDate(d)

case (ld: LocalDate, DateType) =>
formatLocalDate(ld)

case (t: Timestamp, TimestampType) =>
formatTimestamp(t)

case (t: LocalDateTime, ntz) if ntz.getClass.getSimpleName.equals(TIMESTAMP_NTZ) =>
formatLocalDateTime(t)

case (i: Instant, TimestampType) =>
formatInstant(i, Option(timeZone))

case (bin: Array[Byte], BinaryType) =>
new String(bin, StandardCharsets.UTF_8)

case (decimal: java.math.BigDecimal, DecimalType()) =>
decimal.toPlainString

case (s: String, StringType) =>
// Only match string in nested type values
"\"" + s + "\""

case (d: Duration, _) => toDayTimeIntervalString(d)

case (p: Period, _) => toYearMonthIntervalString(p)

case (seq: scala.collection.Seq[_], ArrayType(typ, _)) =>
seq.map(v => (v, typ)).map(e => toHiveString(e, timeZone)).mkString("[", ",", "]")

case (m: Map[_, _], MapType(kType, vType, _)) =>
m.map { case (key, value) =>
toHiveString((key, kType), timeZone) + ":" + toHiveString((value, vType), timeZone)
}.toSeq.sorted.mkString("{", ",", "}")

case (struct: Row, StructType(fields)) =>
struct.toSeq.zip(fields).map { case (v, t) =>
s""""${t.name}":${toHiveString((v, t.dataType), timeZone)}"""
}.mkString("{", ",", "}")

case (other, _) =>
other.toString
}
}

private def toTColumn(data: Array[Byte]): TColumn = {
val values = new java.util.ArrayList[ByteBuffer](1)
values.add(ByteBuffer.wrap(data))
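In RowSet, both the row-based and the column-based converters now call Spark's HiveResult.toHiveString with nested = false and the formatters built by getTimeFormatters, instead of Kyuubi's own toHiveString removed above. A visible side effect (see the test updates below) is that a NULL value in a string-encoded column is now rendered as "NULL", Spark's top-level convention, rather than the empty string used before. A hedged sketch of the new call path, with the commented output being the expected result:

import java.sql.Timestamp
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.types.TimestampType

HiveResult.toHiveString(
  (Timestamp.valueOf("2018-11-17 13:33:33.12345"), TimestampType),
  false,
  HiveResult.getTimeFormatters)
// expected: "2018-11-17 13:33:33.12345"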
@@ -21,8 +21,9 @@ import java.time.ZoneId

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
import org.apache.spark.sql.types._

import org.apache.kyuubi.engine.spark.schema.RowSet

@@ -41,11 +42,11 @@ object SparkDatasetHelper {
val dt = DataType.fromDDL(schemaDDL)
dt match {
case StructType(Array(StructField(_, st: StructType, _, _))) =>
RowSet.toHiveString((row, st), timeZone)
HiveResult.toHiveString((row, st), true, RowSet.getTimeFormatters(timeZone))
case StructType(Array(StructField(_, at: ArrayType, _, _))) =>
RowSet.toHiveString((row.toSeq.head, at), timeZone)
HiveResult.toHiveString((row.toSeq.head, at), true, RowSet.getTimeFormatters(timeZone))
case StructType(Array(StructField(_, mt: MapType, _, _))) =>
RowSet.toHiveString((row.toSeq.head, mt), timeZone)
HiveResult.toHiveString((row.toSeq.head, mt), true, RowSet.getTimeFormatters(timeZone))
case _ =>
throw new UnsupportedOperationException
}
@@ -54,7 +55,7 @@
val cols = df.schema.map {
case sf @ StructField(name, _: StructType, _, _) =>
toHiveStringUDF(quotedCol(name), lit(sf.toDDL)).as(name)
case sf @ StructField(name, (_: MapType | _: ArrayType), _, _) =>
case sf @ StructField(name, _: MapType | _: ArrayType, _, _) =>
toHiveStringUDF(struct(quotedCol(name)), lit(sf.toDDL)).as(name)
case StructField(name, _, _, _) => quotedCol(name)
}
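SparkDatasetHelper (the Arrow-based result path) also switches to HiveResult.toHiveString, but passes nested = true, while the Thrift converters in RowSet pass false. Judging from Spark's HiveResult and from the removed Kyuubi toHiveString (which likewise quoted strings and printed "null" only for values inside complex types), the flag mainly controls how nulls and strings are rendered; an illustrative, non-authoritative sketch:

import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.types.StringType

HiveResult.toHiveString((null, StringType), false, HiveResult.getTimeFormatters) // expected: "NULL" (top level)
HiveResult.toHiveString((null, StringType), true, HiveResult.getTimeFormatters)  // expected: "null" (inside a complex value)
HiveResult.toHiveString(("a", StringType), true, HiveResult.getTimeFormatters)   // expected: "\"a\"" (nested strings are quoted)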
@@ -26,11 +26,11 @@ import scala.collection.JavaConverters._

import org.apache.hive.service.rpc.thrift.TProtocolVersion
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

import org.apache.kyuubi.KyuubiFunSuite
import org.apache.kyuubi.engine.spark.schema.RowSet.toHiveString

class RowSetSuite extends KyuubiFunSuite {

@@ -159,22 +159,28 @@ class RowSetSuite extends KyuubiFunSuite {

val decCol = cols.next().getStringVal
decCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b.isEmpty)
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === s"$i.$i")
}

val dateCol = cols.next().getStringVal
dateCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b.isEmpty)
case (b, 11) => assert(b === "NULL")
case (b, i) =>
assert(b === toHiveString((Date.valueOf(s"2018-11-${i + 1}"), DateType), zoneId))
assert(b === HiveResult.toHiveString(
(Date.valueOf(s"2018-11-${i + 1}"), DateType),
false,
HiveResult.getTimeFormatters))
}

val tsCol = cols.next().getStringVal
tsCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b.isEmpty)
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b ===
toHiveString((Timestamp.valueOf(s"2018-11-17 13:33:33.$i"), TimestampType), zoneId))
HiveResult.toHiveString(
(Timestamp.valueOf(s"2018-11-17 13:33:33.$i"), TimestampType),
false,
HiveResult.getTimeFormatters))
}

val binCol = cols.next().getBinaryVal
@@ -185,23 +191,25 @@

val arrCol = cols.next().getStringVal
arrCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "")
case (b, i) => assert(b === toHiveString(
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === HiveResult.toHiveString(
(Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq, ArrayType(DoubleType)),
zoneId))
false,
HiveResult.getTimeFormatters))
}

val mapCol = cols.next().getStringVal
mapCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "")
case (b, i) => assert(b === toHiveString(
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === HiveResult.toHiveString(
(Map(i -> java.lang.Double.valueOf(s"$i.$i")), MapType(IntegerType, DoubleType)),
zoneId))
false,
HiveResult.getTimeFormatters))
}

val intervalCol = cols.next().getStringVal
intervalCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "")
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === new CalendarInterval(i, i, i).toString)
}
}
@@ -237,15 +245,18 @@ class RowSetSuite extends KyuubiFunSuite {
assert(r6.get(9).getStringVal.getValue === "2018-11-06")

val r7 = iter.next().getColVals
assert(r7.get(10).getStringVal.getValue === "2018-11-17 13:33:33.600")
assert(r7.get(10).getStringVal.getValue === "2018-11-17 13:33:33.6")
assert(r7.get(11).getStringVal.getValue === new String(
Array.fill[Byte](6)(6.toByte),
StandardCharsets.UTF_8))

val r8 = iter.next().getColVals
assert(r8.get(12).getStringVal.getValue === Array.fill(7)(7.7d).mkString("[", ",", "]"))
assert(r8.get(13).getStringVal.getValue ===
toHiveString((Map(7 -> 7.7d), MapType(IntegerType, DoubleType)), zoneId))
HiveResult.toHiveString(
(Map(7 -> 7.7d), MapType(IntegerType, DoubleType)),
false,
HiveResult.getTimeFormatters))

val r9 = iter.next().getColVals
assert(r9.get(14).getStringVal.getValue === new CalendarInterval(8, 8, 8).toString)
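The RowSetSuite expectations are updated accordingly: null cells in string-encoded columns are now asserted as "NULL" instead of the empty string, the reference values are computed with HiveResult.toHiveString and HiveResult.getTimeFormatters, and the row-based timestamp expectation changes from "2018-11-17 13:33:33.600" to "2018-11-17 13:33:33.6" because the fraction formatter trims trailing zeros.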
@@ -18,14 +18,11 @@
package org.apache.kyuubi.util

import java.nio.ByteBuffer
import java.sql.Timestamp
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneId}
import java.time.{Instant, LocalDate, LocalDateTime, ZoneId}
import java.time.chrono.IsoChronology
import java.time.format.DateTimeFormatter
import java.time.format.DateTimeFormatterBuilder
import java.time.temporal.ChronoField
import java.util.{Date, Locale}
import java.util.concurrent.TimeUnit

import scala.language.implicitConversions

@@ -37,24 +34,18 @@ private[kyuubi] object RowSetUtils {
final private val SECOND_PER_HOUR: Long = SECOND_PER_MINUTE * 60L
final private val SECOND_PER_DAY: Long = SECOND_PER_HOUR * 24L

private lazy val dateFormatter = {
createDateTimeFormatterBuilder().appendPattern("yyyy-MM-dd")
.toFormatter(Locale.US)
.withChronology(IsoChronology.INSTANCE)
}
private lazy val dateFormatter = createDateTimeFormatterBuilder()
.appendPattern("yyyy-MM-dd")
.toFormatter(Locale.US)
.withChronology(IsoChronology.INSTANCE)

private lazy val legacyDateFormatter = FastDateFormat.getInstance("yyyy-MM-dd", Locale.US)

private lazy val timestampFormatter: DateTimeFormatter = {
createDateTimeFormatterBuilder().appendPattern("yyyy-MM-dd HH:mm:ss")
.appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true)
.toFormatter(Locale.US)
.withChronology(IsoChronology.INSTANCE)
}

private lazy val legacyTimestampFormatter = {
FastDateFormat.getInstance("yyyy-MM-dd HH:mm:ss.SSS", Locale.US)
}
private lazy val timestampFormatter = createDateTimeFormatterBuilder()
.appendPattern("yyyy-MM-dd HH:mm:ss")
.appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true)
.toFormatter(Locale.US)
.withChronology(IsoChronology.INSTANCE)

private def createDateTimeFormatterBuilder(): DateTimeFormatterBuilder = {
new DateTimeFormatterBuilder().parseCaseInsensitive()
@@ -77,34 +68,7 @@ private[kyuubi] object RowSetUtils {
.getOrElse(timestampFormatter.format(i))
}

def formatTimestamp(t: Timestamp): String = {
legacyTimestampFormatter.format(t)
}

implicit def bitSetToBuffer(bitSet: java.util.BitSet): ByteBuffer = {
ByteBuffer.wrap(bitSet.toByteArray)
}

def toDayTimeIntervalString(d: Duration): String = {
var rest = d.getSeconds
var sign = ""
if (d.getSeconds < 0) {
sign = "-"
rest = -rest
}
val days = TimeUnit.SECONDS.toDays(rest)
rest %= SECOND_PER_DAY
val hours = TimeUnit.SECONDS.toHours(rest)
rest %= SECOND_PER_HOUR
val minutes = TimeUnit.SECONDS.toMinutes(rest)
val seconds = rest % SECOND_PER_MINUTE
f"$sign$days $hours%02d:$minutes%02d:$seconds%02d.${d.getNano}%09d"
}

def toYearMonthIntervalString(d: Period): String = {
val years = d.getYears
val months = d.getMonths
val sign = if (years < 0 || months < 0) "-" else ""
s"$sign${Math.abs(years)}-${Math.abs(months)}"
}
}
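With formatting delegated to Spark, RowSetUtils sheds its legacy helpers: the millisecond-only Timestamp formatter (FastDateFormat with "yyyy-MM-dd HH:mm:ss.SSS") and the hand-rolled Duration/Period interval renderers are deleted, since HiveResult.toHiveString is expected to cover timestamps and day-time/year-month intervals on its own.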
@@ -159,7 +159,7 @@ trait SparkDataTypeTests extends HiveJDBCTestHelper {
}
}

test("execute statement - select timestamp") {
test("execute statement - select timestamp - second") {
withJdbcStatement() { statement =>
val resultSet = statement.executeQuery("SELECT TIMESTAMP '2018-11-17 13:33:33' AS col")
assert(resultSet.next())
Expand All @@ -171,6 +171,18 @@ trait SparkDataTypeTests extends HiveJDBCTestHelper {
}
}

test("execute statement - select timestamp - millisecond") {
withJdbcStatement() { statement =>
val resultSet = statement.executeQuery("SELECT TIMESTAMP '2018-11-17 13:33:33.12345' AS col")
assert(resultSet.next())
assert(resultSet.getTimestamp("col") === Timestamp.valueOf("2018-11-17 13:33:33.12345"))
val metaData = resultSet.getMetaData
assert(metaData.getColumnType(1) === java.sql.Types.TIMESTAMP)
assert(metaData.getPrecision(1) === 29)
assert(metaData.getScale(1) === 9)
}
}

test("execute statement - select timestamp_ntz") {
assume(SPARK_ENGINE_VERSION >= "3.4")
withJdbcStatement() { statement =>
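The new test covers the fix end to end: TIMESTAMP '2018-11-17 13:33:33.12345' must now round-trip through the JDBC result set unchanged, whereas the legacy millisecond formatter would have returned it truncated to "2018-11-17 13:33:33.123".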