[KYUUBI #4316] Fix returned Timestamp values may lose precision #4318
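
Before the diff itself, a minimal illustration of the symptom in the title (the timestamp value and the millisecond-only formatter below are ours, purely for illustration, not code from Kyuubi): a java.sql.Timestamp can carry microsecond or nanosecond precision, but rendering it through a pattern that only knows milliseconds silently drops the extra digits. The RowSet.scala hunks below fix this by delegating string conversion to Spark's own HiveResult helpers.

```scala
import java.sql.Timestamp
import java.text.SimpleDateFormat

object TimestampPrecisionDemo extends App {
  // A microsecond-precision timestamp (illustrative value).
  val ts = Timestamp.valueOf("2018-11-17 13:33:33.123456")

  // A millisecond-only pattern truncates the fraction to three digits.
  val millisOnly = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS").format(ts)
  println(millisOnly)  // 2018-11-17 13:33:33.123

  // The Timestamp itself keeps the nanosecond part, so a precision-aware
  // renderer (such as Spark's HiveResult formatters) can preserve it.
  println(ts.toString) // 2018-11-17 13:33:33.123456
}
```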

Closed · wants to merge 9 commits

Changes from all commits
@@ -88,9 +88,9 @@ class ExecutePython(
val output = response.map(_.content.getOutput()).getOrElse("")
val ename = response.map(_.content.getEname()).getOrElse("")
val evalue = response.map(_.content.getEvalue()).getOrElse("")
val traceback = response.map(_.content.getTraceback()).getOrElse(Array.empty)
val traceback = response.map(_.content.getTraceback()).getOrElse(Seq.empty)
iter =
new ArrayFetchIterator[Row](Array(Row(output, status, ename, evalue, Row(traceback: _*))))
new ArrayFetchIterator[Row](Array(Row(output, status, ename, evalue, traceback)))
setState(OperationState.FINISHED)
} else {
throw KyuubiSQLException(s"Interpret error:\n$statement\n $response")
@@ -210,7 +210,7 @@ case class SessionPythonWorker(
stdin.flush()
val pythonResponse = Option(stdout.readLine()).map(ExecutePython.fromJson[PythonResponse](_))
// throw exception if internal python code fail
if (internal && pythonResponse.map(_.content.status) != Some(PythonResponse.OK_STATUS)) {
if (internal && !pythonResponse.map(_.content.status).contains(PythonResponse.OK_STATUS)) {
throw KyuubiSQLException(s"Internal python code $code failure: $pythonResponse")
}
pythonResponse
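
The guard rewrite above is behavior-preserving; a small sketch (ours, not from the PR) confirming that Option.contains agrees with the old Some(...) comparison in every case. The "ok" literal is a stand-in, not necessarily the value of PythonResponse.OK_STATUS.

```scala
object StatusGuardDemo extends App {
  val okStatus = "ok" // stand-in value for illustration only
  val cases: Seq[Option[String]] = Seq(None, Some("error"), Some(okStatus))
  cases.foreach { status =>
    // The old and new guards agree: both fire when the response is missing or not OK.
    assert((status != Some(okStatus)) == !status.contains(okStatus))
  }
  println("old and new guards are equivalent for: " + cases.mkString(", "))
}
```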
@@ -328,7 +328,7 @@ object ExecutePython extends Logging {
}

// for test
def defaultSparkHome(): String = {
def defaultSparkHome: String = {
val homeDirFilter: FilenameFilter = (dir: File, name: String) =>
dir.isDirectory && name.contains("spark-") && !name.contains("-engine")
// get from kyuubi-server/../externals/kyuubi-download/target
@@ -418,7 +418,7 @@ case class PythonResponseContent(
data: Map[String, String],
ename: String,
evalue: String,
traceback: Array[String],
traceback: Seq[String],
status: String) {
def getOutput(): String = {
Option(data)
@@ -431,7 +431,7 @@
def getEvalue(): String = {
Option(evalue).getOrElse("")
}
def getTraceback(): Array[String] = {
Option(traceback).getOrElse(Array.empty)
def getTraceback(): Seq[String] = {
Option(traceback).getOrElse(Seq.empty)
}
}
@@ -24,6 +24,7 @@ import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TProgressU
import org.apache.spark.kyuubi.{SparkProgressMonitor, SQLOperationListener}
import org.apache.spark.kyuubi.SparkUtilsHelper.redact
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

import org.apache.kyuubi.{KyuubiSQLException, Utils}
@@ -135,27 +136,35 @@ abstract class SparkOperation(session: Session)
spark.sparkContext.setLocalProperty

protected def withLocalProperties[T](f: => T): T = {
try {
spark.sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
spark.sparkContext.setLocalProperty(KYUUBI_SESSION_USER_KEY, session.user)
spark.sparkContext.setLocalProperty(KYUUBI_STATEMENT_ID_KEY, statementId)
schedulerPool match {
case Some(pool) =>
spark.sparkContext.setLocalProperty(SPARK_SCHEDULER_POOL_KEY, pool)
case None =>
}
if (isSessionUserSignEnabled) {
setSessionUserSign()
}
SQLConf.withExistingConf(spark.sessionState.conf) {
val originalSession = SparkSession.getActiveSession
try {
SparkSession.setActiveSession(spark)
spark.sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
spark.sparkContext.setLocalProperty(KYUUBI_SESSION_USER_KEY, session.user)
spark.sparkContext.setLocalProperty(KYUUBI_STATEMENT_ID_KEY, statementId)
schedulerPool match {
case Some(pool) =>
spark.sparkContext.setLocalProperty(SPARK_SCHEDULER_POOL_KEY, pool)
case None =>
}
if (isSessionUserSignEnabled) {
setSessionUserSign()
}

f
} finally {
spark.sparkContext.setLocalProperty(SPARK_SCHEDULER_POOL_KEY, null)
spark.sparkContext.setLocalProperty(KYUUBI_SESSION_USER_KEY, null)
spark.sparkContext.setLocalProperty(KYUUBI_STATEMENT_ID_KEY, null)
spark.sparkContext.clearJobGroup()
if (isSessionUserSignEnabled) {
clearSessionUserSign()
f
} finally {
spark.sparkContext.setLocalProperty(SPARK_SCHEDULER_POOL_KEY, null)
spark.sparkContext.setLocalProperty(KYUUBI_SESSION_USER_KEY, null)
spark.sparkContext.setLocalProperty(KYUUBI_STATEMENT_ID_KEY, null)
spark.sparkContext.clearJobGroup()
if (isSessionUserSignEnabled) {
clearSessionUserSign()
}
originalSession match {
case Some(session) => SparkSession.setActiveSession(session)
case None => SparkSession.clearActiveSession()
}
}
}
}
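
The rewritten withLocalProperties wraps the user code in SQLConf.withExistingConf and pins the active SparkSession for the duration of f, restoring the previous one afterwards. That way thread-local lookups through SQLConf.get, which Spark's HiveResult time formatters use to pick up the session time zone, see this session's configuration even on Kyuubi's operation threads. A minimal sketch of the wrap-and-restore pattern, with a helper name of our own choosing:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object ActiveSessionHelper {
  // Sketch only: run `f` with `spark`'s conf and active-session thread locals set,
  // then restore whatever was active before.
  def withActiveSession[T](spark: SparkSession)(f: => T): T =
    SQLConf.withExistingConf(spark.sessionState.conf) {
      val original = SparkSession.getActiveSession
      try {
        SparkSession.setActiveSession(spark)
        f
      } finally {
        original match {
          case Some(s) => SparkSession.setActiveSession(s)
          case None    => SparkSession.clearActiveSession()
        }
      }
    }
}
```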
@@ -246,7 +255,7 @@ abstract class SparkOperation(session: Session)
} else {
val taken = iter.take(rowSetSize)
RowSet.toTRowSet(
taken.toList.asInstanceOf[List[Row]],
taken.toSeq.asInstanceOf[Seq[Row]],
resultSchema,
getProtocolVersion,
timeZone)
@@ -18,22 +18,25 @@
package org.apache.kyuubi.engine.spark.schema

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.sql.Timestamp
import java.time._
import java.util.Date
import java.time.ZoneId

import scala.collection.JavaConverters._

import org.apache.hive.service.rpc.thrift._
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.types._

import org.apache.kyuubi.engine.spark.schema.SchemaHelper.TIMESTAMP_NTZ
import org.apache.kyuubi.util.RowSetUtils._

object RowSet {

def toHiveString(valueAndType: (Any, DataType), nested: Boolean = false): String = {
// compatible w/ Spark 3.1 and above
val timeFormatters = HiveResult.getTimeFormatters
HiveResult.toHiveString(valueAndType, nested, timeFormatters)
}

def toTRowSet(
bytes: Array[Byte],
protocolVersion: TProtocolVersion): TRowSet = {
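
A hedged usage sketch of the new toHiveString above (ours, not part of the PR), assuming the default session time zone. It shows the two behavior changes the updated tests lock in: fractional seconds are rendered as-is rather than padded or cut to milliseconds, and top-level nulls render as NULL.

```scala
import java.sql.Timestamp
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.types.{DateType, TimestampType}

object HiveStringDemo extends App {
  val formatters = HiveResult.getTimeFormatters
  println(HiveResult.toHiveString(
    Timestamp.valueOf("2018-11-17 13:33:33.6") -> TimestampType, false, formatters))
  // 2018-11-17 13:33:33.6   (not "2018-11-17 13:33:33.600")
  println(HiveResult.toHiveString((null, DateType), false, formatters))
  // NULL                    (top-level nulls, see the updated RowSetSuite assertions)
}
```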
@@ -68,9 +71,9 @@ object RowSet {
}

def toRowBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
var i = 0
val rowSize = rows.length
val tRows = new java.util.ArrayList[TRow](rowSize)
var i = 0
while (i < rowSize) {
val row = rows(i)
val tRow = new TRow()
@@ -151,13 +154,7 @@ object RowSet {
while (i < rowSize) {
val row = rows(i)
nulls.set(i, row.isNullAt(ordinal))
val value =
if (row.isNullAt(ordinal)) {
""
} else {
toHiveString((row.get(ordinal), typ), timeZone)
}
values.add(value)
values.add(toHiveString(row.get(ordinal) -> typ))
i += 1
}
TColumn.stringVal(new TStringColumn(values, nulls))
@@ -238,69 +235,12 @@ object RowSet {
case _ =>
val tStrValue = new TStringValue
if (!row.isNullAt(ordinal)) {
tStrValue.setValue(
toHiveString((row.get(ordinal), types(ordinal).dataType), timeZone))
tStrValue.setValue(toHiveString(row.get(ordinal) -> types(ordinal).dataType))
}
TColumnValue.stringVal(tStrValue)
}
}

/**
* A simpler impl of Spark's toHiveString
*/
def toHiveString(dataWithType: (Any, DataType), timeZone: ZoneId): String = {
dataWithType match {
case (null, _) =>
// Only match nulls in nested type values
"null"

case (d: Date, DateType) =>
formatDate(d)

case (ld: LocalDate, DateType) =>
formatLocalDate(ld)

case (t: Timestamp, TimestampType) =>
formatTimestamp(t, Option(timeZone))

case (t: LocalDateTime, ntz) if ntz.getClass.getSimpleName.equals(TIMESTAMP_NTZ) =>
formatLocalDateTime(t)

case (i: Instant, TimestampType) =>
formatInstant(i, Option(timeZone))

case (bin: Array[Byte], BinaryType) =>
new String(bin, StandardCharsets.UTF_8)

case (decimal: java.math.BigDecimal, DecimalType()) =>
decimal.toPlainString

case (s: String, StringType) =>
// Only match string in nested type values
"\"" + s + "\""

case (d: Duration, _) => toDayTimeIntervalString(d)

case (p: Period, _) => toYearMonthIntervalString(p)

case (seq: scala.collection.Seq[_], ArrayType(typ, _)) =>
seq.map(v => (v, typ)).map(e => toHiveString(e, timeZone)).mkString("[", ",", "]")

case (m: Map[_, _], MapType(kType, vType, _)) =>
m.map { case (key, value) =>
toHiveString((key, kType), timeZone) + ":" + toHiveString((value, vType), timeZone)
}.toSeq.sorted.mkString("{", ",", "}")

case (struct: Row, StructType(fields)) =>
struct.toSeq.zip(fields).map { case (v, t) =>
s""""${t.name}":${toHiveString((v, t.dataType), timeZone)}"""
}.mkString("{", ",", "}")

case (other, _) =>
other.toString
}
}

private def toTColumn(data: Array[Byte]): TColumn = {
val values = new java.util.ArrayList[ByteBuffer](1)
values.add(ByteBuffer.wrap(data))
@@ -22,7 +22,7 @@ import java.time.ZoneId
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
import org.apache.spark.sql.types._

import org.apache.kyuubi.engine.spark.schema.RowSet

@@ -41,11 +41,11 @@ object SparkDatasetHelper {
val dt = DataType.fromDDL(schemaDDL)
dt match {
case StructType(Array(StructField(_, st: StructType, _, _))) =>
RowSet.toHiveString((row, st), timeZone)
RowSet.toHiveString((row, st), nested = true)
case StructType(Array(StructField(_, at: ArrayType, _, _))) =>
RowSet.toHiveString((row.toSeq.head, at), timeZone)
RowSet.toHiveString((row.toSeq.head, at), nested = true)
case StructType(Array(StructField(_, mt: MapType, _, _))) =>
RowSet.toHiveString((row.toSeq.head, mt), timeZone)
RowSet.toHiveString((row.toSeq.head, mt), nested = true)
case _ =>
throw new UnsupportedOperationException
}
@@ -54,7 +54,7 @@
val cols = df.schema.map {
case sf @ StructField(name, _: StructType, _, _) =>
toHiveStringUDF(quotedCol(name), lit(sf.toDDL)).as(name)
case sf @ StructField(name, (_: MapType | _: ArrayType), _, _) =>
case sf @ StructField(name, _: MapType | _: ArrayType, _, _) =>
toHiveStringUDF(struct(quotedCol(name)), lit(sf.toDDL)).as(name)
case StructField(name, _, _, _) => quotedCol(name)
}
@@ -30,7 +30,6 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

import org.apache.kyuubi.KyuubiFunSuite
import org.apache.kyuubi.engine.spark.schema.RowSet.toHiveString

class RowSetSuite extends KyuubiFunSuite {

@@ -159,22 +158,22 @@

val decCol = cols.next().getStringVal
decCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b.isEmpty)
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === s"$i.$i")
}

val dateCol = cols.next().getStringVal
dateCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b.isEmpty)
case (b, 11) => assert(b === "NULL")
case (b, i) =>
assert(b === toHiveString((Date.valueOf(s"2018-11-${i + 1}"), DateType), zoneId))
assert(b === RowSet.toHiveString(Date.valueOf(s"2018-11-${i + 1}") -> DateType))
}

val tsCol = cols.next().getStringVal
tsCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b.isEmpty)
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b ===
toHiveString((Timestamp.valueOf(s"2018-11-17 13:33:33.$i"), TimestampType), zoneId))
RowSet.toHiveString(Timestamp.valueOf(s"2018-11-17 13:33:33.$i") -> TimestampType))
}

val binCol = cols.next().getBinaryVal
@@ -185,23 +184,21 @@

val arrCol = cols.next().getStringVal
arrCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "")
case (b, i) => assert(b === toHiveString(
(Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq, ArrayType(DoubleType)),
zoneId))
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === RowSet.toHiveString(
Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq -> ArrayType(DoubleType)))
}

val mapCol = cols.next().getStringVal
mapCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "")
case (b, i) => assert(b === toHiveString(
(Map(i -> java.lang.Double.valueOf(s"$i.$i")), MapType(IntegerType, DoubleType)),
zoneId))
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === RowSet.toHiveString(
Map(i -> java.lang.Double.valueOf(s"$i.$i")) -> MapType(IntegerType, DoubleType)))
}

val intervalCol = cols.next().getStringVal
intervalCol.getValues.asScala.zipWithIndex.foreach {
case (b, 11) => assert(b === "")
case (b, 11) => assert(b === "NULL")
case (b, i) => assert(b === new CalendarInterval(i, i, i).toString)
}
}
@@ -237,15 +234,15 @@
assert(r6.get(9).getStringVal.getValue === "2018-11-06")

val r7 = iter.next().getColVals
assert(r7.get(10).getStringVal.getValue === "2018-11-17 13:33:33.600")
assert(r7.get(10).getStringVal.getValue === "2018-11-17 13:33:33.6")
assert(r7.get(11).getStringVal.getValue === new String(
Array.fill[Byte](6)(6.toByte),
StandardCharsets.UTF_8))

val r8 = iter.next().getColVals
assert(r8.get(12).getStringVal.getValue === Array.fill(7)(7.7d).mkString("[", ",", "]"))
assert(r8.get(13).getStringVal.getValue ===
toHiveString((Map(7 -> 7.7d), MapType(IntegerType, DoubleType)), zoneId))
RowSet.toHiveString(Map(7 -> 7.7d) -> MapType(IntegerType, DoubleType)))

val r9 = iter.next().getColVals
assert(r9.get(14).getStringVal.getValue === new CalendarInterval(8, 8, 8).toString)