diff --git a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala
index c927ebb6..28bbdd3d 100644
--- a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala
+++ b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala
@@ -698,19 +698,20 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase {
   test("test read write Date Time") {
     setupDateTimeTable
-    val createTableSql =
-      s"""create or replace table $test_table_write (
-         | int_c int, date_c date, time_c0 time(0), time_c1 time(1), time_c2 time(2),
-         | time_c3 time(3), time_c4 time(4), time_c5 time(5), time_c6 time(6),
-         | time_c7 time(7), time_c8 time(8), time_c9 time(9)
+    if (!params.useCopyUnload) {
+      val createTableSql =
+        s"""create or replace table $test_table_write (
+           | int_c int, date_c date, time_c0 time(0), time_c1 time(1), time_c2 time(2),
+           | time_c3 time(3), time_c4 time(4), time_c5 time(5), time_c6 time(6),
+           | time_c7 time(7), time_c8 time(8), time_c9 time(9)
         )""".stripMargin
-    writeAndCheckForOneTable(sparkSession, thisConnectorOptionsNoTable,
-      test_table_date_time, "", test_table_write, Some(createTableSql), true)
+      writeAndCheckForOneTable(sparkSession, thisConnectorOptionsNoTable,
+        test_table_date_time, "", test_table_write, Some(createTableSql), true)
+    }
   }
 
   test("testTimestamp") {
     setupTimestampTable
-    // COPY UNLOAD can't be run because it only supports millisecond(0.001s).
     if (!params.useCopyUnload) {
       val result = sparkSession.sql("select * from test_table_timestamp")
@@ -722,10 +723,25 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase {
     }
   }
 
+  test("testTimestamp with copyUnload") {
+    thisConnectorOptionsNoTable += ("timestamp_ntz_output_format" -> "YYYY-MM-DD HH24:MI:SS.FF6")
+    thisConnectorOptionsNoTable += ("timestamp_ltz_output_format" ->
+      "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6")
+    thisConnectorOptionsNoTable += ("timestamp_tz_output_format" ->
+      "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6")
+    setupTimestampTable
+    val result = sparkSession.sql("select * from test_table_timestamp")
+
+    testPushdown(
+      s""" SELECT * FROM ( $test_table_timestamp ) AS "SF_CONNECTOR_QUERY_ALIAS" """.stripMargin,
+      result,
+      test_table_timestamp_rows
+    )
+  }
+
   // Most simple case for timestamp write
   test("testTimestamp write") {
     setupTimestampTable
-    // COPY UNLOAD can't be run because it only supports millisecond(0.001s).
     if (!params.useCopyUnload) {
       val createTableSql =
         s"""create or replace table $test_table_write (
@@ -744,10 +760,36 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase {
     }
   }
 
+  test("testTimestamp write with copy unload") {
+    thisConnectorOptionsNoTable += ("timestamp_ntz_output_format" -> "YYYY-MM-DD HH24:MI:SS.FF6")
+    thisConnectorOptionsNoTable += ("timestamp_ltz_output_format" ->
+      "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6")
+    thisConnectorOptionsNoTable += ("timestamp_tz_output_format" ->
+      "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6")
+    setupTimestampTable
+    if (params.supportMicroSecondDuringUnload) {
+      val createTableSql =
+        s"""create or replace table $test_table_write (
+           | int_c int,
+           | ts_ltz_c timestamp_ltz(9), ts_ltz_c0 timestamp_ltz(0),
+           | ts_ltz_c3 timestamp_ltz(3), ts_ltz_c6 timestamp_ltz(6),
+           |
+           | ts_ntz_c timestamp_ntz(9), ts_ntz_c0 timestamp_ntz(0),
+           | ts_ntz_c3 timestamp_ntz(3), ts_ntz_c6 timestamp_ntz(6),
+           |
+           | ts_tz_c timestamp_tz(9), ts_tz_c0 timestamp_tz(0),
+           | ts_tz_c3 timestamp_tz(3), ts_tz_c6 timestamp_tz(6)
+           | )""".stripMargin
+      writeAndCheckForOneTable(sparkSession, thisConnectorOptionsNoTable,
+        test_table_timestamp, "", test_table_write, Some(createTableSql), true)
+    }
+  }
+
+
+  // test timestamp write with timezone
   test("testTimestamp write with timezone") {
     setupTimestampTable
-    // COPY UNLOAD can't be run because it only supports millisecond(0.001s).
     if (!params.useCopyUnload) {
       var oldValue: Option[String] = None
       if (thisConnectorOptionsNoTable.contains("sftimezone")) {
@@ -771,15 +813,108 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase {
       // Test conditions with (sfTimezone, sparkTimezone)
       val testConditions: List[(String, String)] = List(
-        (null, "GMT")
-        , (null, "America/Los_Angeles")
-        , ("America/New_York", "America/Los_Angeles")
+        (null, "GMT"),
+        (null, "America/Los_Angeles"),
+        ("America/New_York", "America/Los_Angeles")
+      )
+
+      for ((sfTimezone, sparkTimezone) <- testConditions) {
+        // set spark timezone
+        val thisSparkSession = if (sparkTimezone != null) {
+          TimeZone.setDefault(TimeZone.getTimeZone(sparkTimezone))
+          // Avoid interference from any active sessions
+          val activeSessions = SparkSession.builder
+            .master("local")
+            .appName("SnowflakeSourceSuite")
+            .getOrCreate()
+          activeSessions.stop()
+
+          SparkSession.builder
+            .master("local")
+            .appName("SnowflakeSourceSuite")
+            .config("spark.sql.shuffle.partitions", "6")
+            .config("spark.driver.extraJavaOptions", s"-Duser.timezone=$sparkTimezone")
+            .config("spark.executor.extraJavaOptions", s"-Duser.timezone=$sparkTimezone")
+            .config("spark.sql.session.timeZone", sparkTimezone)
+            .getOrCreate()
+        } else {
+          sparkSession
+        }
+
+        // Set timezone option
+        if (sfTimezone != null) {
+          if (thisConnectorOptionsNoTable.contains("sftimezone")) {
+            thisConnectorOptionsNoTable -= "sftimezone"
+          }
+          thisConnectorOptionsNoTable += ("sftimezone" -> sfTimezone)
+        } else {
+          if (thisConnectorOptionsNoTable.contains("sftimezone")) {
+            thisConnectorOptionsNoTable -= "sftimezone"
+          }
+        }
+
+        writeAndCheckForOneTable(thisSparkSession, thisConnectorOptionsNoTable,
+          test_table_timestamp, "", test_table_write, Some(createTableSql), true)
+      }
+
+      // restore options for further test
+      thisConnectorOptionsNoTable -= "sftimezone"
+      if (oldValue.isDefined) {
+        thisConnectorOptionsNoTable += ("sftimezone" -> oldValue.get)
+      }
+      TimeZone.setDefault(oldTimezone)
+    }
+  }
+
+  // test timestamp write with timezone
+  test("testTimestamp write with timezone with copy unload") {
+    thisConnectorOptionsNoTable += ("timestamp_ntz_output_format" -> "YYYY-MM-DD HH24:MI:SS.FF6")
+    thisConnectorOptionsNoTable += ("timestamp_ltz_output_format" ->
+      "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6")
+    thisConnectorOptionsNoTable += ("timestamp_tz_output_format" ->
+      "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6")
+    setupTimestampTable
+    if (params.supportMicroSecondDuringUnload) {
+
+
+      var oldValue: Option[String] = None
+      if (thisConnectorOptionsNoTable.contains("sftimezone")) {
+        oldValue = Some(thisConnectorOptionsNoTable("sftimezone"))
+        thisConnectorOptionsNoTable -= "sftimezone"
+      }
+      val oldTimezone = TimeZone.getDefault
+
+      val createTableSql =
+        s"""create or replace table $test_table_write (
+           | int_c int,
+           | ts_ltz_c timestamp_ltz(9), ts_ltz_c0 timestamp_ltz(0),
+           | ts_ltz_c3 timestamp_ltz(3), ts_ltz_c6 timestamp_ltz(6),
+           |
+           | ts_ntz_c timestamp_ntz(9), ts_ntz_c0 timestamp_ntz(0),
+           | ts_ntz_c3 timestamp_ntz(3), ts_ntz_c6 timestamp_ntz(6),
+           |
+           | ts_tz_c timestamp_tz(9), ts_tz_c0 timestamp_tz(0),
+           | ts_tz_c3 timestamp_tz(3), ts_tz_c6 timestamp_tz(6)
+           | )""".stripMargin
+
+      // Test conditions with (sfTimezone, sparkTimezone)
+      val testConditions: List[(String, String)] = List(
+        (null, "GMT"),
+        (null, "America/Los_Angeles"),
+        ("America/New_York", "America/Los_Angeles")
       )
 
       for ((sfTimezone, sparkTimezone) <- testConditions) {
         // set spark timezone
         val thisSparkSession = if (sparkTimezone != null) {
           TimeZone.setDefault(TimeZone.getTimeZone(sparkTimezone))
+          // Avoid interference from any active sessions
+          val activeSessions = SparkSession.builder
+            .master("local")
+            .appName("SnowflakeSourceSuite")
+            .getOrCreate()
+          activeSessions.stop()
+
           SparkSession.builder
             .master("local")
             .appName("SnowflakeSourceSuite")
@@ -2515,6 +2650,7 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase {
 
   override def beforeEach(): Unit = {
     super.beforeEach()
+    sparkSession = createDefaultSparkSession
   }
 
   override def afterAll(): Unit = {
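Note on the copy-unload tests above: they rely on two pieces introduced by this patch — the three `timestamp_*_output_format` options that force Snowflake to unload six fractional digits, and the internal `internal_support_micro_second_during_unload` flag added in Parameters.scala further down. A minimal, hedged sketch of a read configured the same way (the `net.snowflake.spark.snowflake` source name and the account values are placeholders, not taken from this patch):

```scala
// Sketch only: option keys mirror the ones used in the tests above.
val sfOptions = Map(
  "sfurl" -> "account.snowflakecomputing.com:443",
  "sfuser" -> "username",
  "sfpassword" -> "password",
  "dbtable" -> "test_table_timestamp",
  // Ask Snowflake to unload timestamps with microsecond precision.
  "timestamp_ntz_output_format" -> "YYYY-MM-DD HH24:MI:SS.FF6",
  "timestamp_ltz_output_format" -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6",
  "timestamp_tz_output_format" -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6",
  // Internal option introduced by this change; it defaults to "false".
  "internal_support_micro_second_during_unload" -> "true"
)

val df = sparkSession.read
  .format("net.snowflake.spark.snowflake")
  .options(sfOptions)
  .load()
```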
diff --git a/src/main/scala/net/snowflake/spark/snowflake/CSVConverter.scala b/src/main/scala/net/snowflake/spark/snowflake/CSVConverter.scala
index 26f39647..4086ffab 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/CSVConverter.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/CSVConverter.scala
@@ -17,7 +17,9 @@
 package net.snowflake.spark.snowflake
 
+import net.snowflake.spark.snowflake.Parameters.MergedParameters
 import org.apache.spark.sql.types.StructType
+
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
@@ -28,9 +30,10 @@ object CSVConverter {
   private[snowflake] def convert[T: ClassTag](
     partition: Iterator[String],
-    resultSchema: StructType
+    resultSchema: StructType,
+    parameters: MergedParameters
   ): Iterator[T] = {
-    val converter = Conversions.createRowConverter[T](resultSchema)
+    val converter = Conversions.createRowConverter[T](resultSchema, parameters)
     partition.map(s => {
       val fields = ArrayBuffer.empty[String]
       var buff = new StringBuilder
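The `CSVConverter.convert` signature change above threads `MergedParameters` through to the row converter so it knows whether microsecond output is expected. A hedged sketch of the new call shape, mirroring the `ConversionsSuite` changes at the end of this patch (option values are placeholders; `Parameters` and `Conversions` are internal connector APIs):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructField, StructType, TimestampType}

// Build MergedParameters the same way the unit test below does.
val params = Parameters.mergeParameters(Map(
  "dbtable" -> "test_table",
  "sfurl" -> "account.snowflakecomputing.com:443",
  "sfuser" -> "username",
  "sfpassword" -> "password",
  "internal_support_micro_second_during_unload" -> "true"
))

val schema = StructType(Seq(StructField("ts", TimestampType)))
val convert = Conversions.createRowConverter[Row](schema, params)
convert(Array("2014-03-01 00:00:01.123456")) // keeps the microsecond part
```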
diff --git a/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala b/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala
index c00c2417..ac70f68d 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala
@@ -22,8 +22,8 @@ import java.text._
 import java.time.ZonedDateTime
 import java.time.format.DateTimeFormatter
 import java.util.{Date, TimeZone}
-
 import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode
+import net.snowflake.spark.snowflake.Parameters.MergedParameters
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -41,15 +41,19 @@ private[snowflake] object Conversions {
   // Note - we use a pattern with timezone in the beginning, to make sure
   // parsing with PATTERN_NTZ fails for PATTERN_TZLTZ strings.
   // Note - for JDK 1.6, we use Z ipo XX for SimpleDateFormat
+  // Because SimpleDateFormat only supports milliseconds,
+  // we need to refactor this and handle the nanoseconds field separately
+
   private val PATTERN_TZLTZ =
     if (System.getProperty("java.version").startsWith("1.6.")) {
-      "Z yyyy-MM-dd HH:mm:ss.SSS"
+      "Z yyyy-MM-dd HH:mm:ss."
     } else {
-      "XX yyyy-MM-dd HH:mm:ss.SSS"
+      "XX yyyy-MM-dd HH:mm:ss."
     }
 
   // For NTZ, Snowflake serializes w/o timezone
-  private val PATTERN_NTZ = "yyyy-MM-dd HH:mm:ss.SSS"
+  // and the nanoseconds field is handled separately during parsing
+  private val PATTERN_NTZ = "yyyy-MM-dd HH:mm:ss."
 
   // For DATE, simple ISO format
   private val PATTERN_DATE = "yyyy-MM-dd"
@@ -142,9 +146,10 @@ private[snowflake] object Conversions {
    * the given schema to Row instances
    */
   def createRowConverter[T: ClassTag](
-    schema: StructType
+    schema: StructType,
+    parameters: MergedParameters
   ): Array[String] => T = {
-    convertRow[T](schema, _: Array[String])
+    convertRow[T](schema, _: Array[String], parameters)
   }
 
   /**
@@ -152,7 +157,9 @@ private[snowflake] object Conversions {
    * The schema will be used for type mappings.
    */
   private def convertRow[T: ClassTag](schema: StructType,
-                                      fields: Array[String]): T = {
+                                      fields: Array[String],
+                                      parameters: MergedParameters
+                                     ): T = {
 
     val isIR: Boolean = isInternalRow[T]()
 
@@ -176,7 +183,8 @@ private[snowflake] object Conversions {
         case ShortType => data.toShort
         case StringType =>
           if (isIR) UTF8String.fromString(data) else data
-        case TimestampType => parseTimestamp(data, isIR)
+        case TimestampType => parseTimestamp(data, isIR,
+          parameters.supportMicroSecondDuringUnload)
         case _ => data
       }
     }
@@ -189,11 +197,38 @@ private[snowflake] object Conversions {
     }
   }
 
+
   /**
    * Parse a string exported from a Snowflake TIMESTAMP column
    */
-  private def parseTimestamp(s: String, isInternalRow: Boolean): Any = {
+  private def parseTimestamp(s: String,
+                             isInternalRow: Boolean,
+                             supportMicroSeconds: Boolean = true): Any = {
+
+
+    // Need to handle the nanoseconds field separately.
+    // valueOf only works with yyyy-[m]m-[d]d hh:mm:ss[.f...],
+    // so we need to do a little parsing.
+    // When supportMicroSeconds is disabled, we should only use the milliseconds field.
+    val timestampRegex = if (supportMicroSeconds) {
+      """\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3,9}""".r
+    }
+    else {
+      """\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3}""".r
+    }
+
+    val parsedTS = timestampRegex.findFirstMatchIn(s) match {
+      case Some(ts) => ts.toString()
+      case None => throw new IllegalArgumentException(s"Malformed timestamp $s")
+    }
+
+    val ts = java.sql.Timestamp.valueOf(parsedTS)
+
+
     val res = new Timestamp(snowflakeTimestampFormat.parse(s).getTime)
+
+    res.setNanos(ts.getNanos)
+
     if (isInternalRow) DateTimeUtils.fromJavaTimestamp(res) else res
   }
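The `parseTimestamp` change above is the core of the patch: `SimpleDateFormat` cannot carry more than millisecond precision, so the fractional seconds are re-extracted with a regex, parsed through `java.sql.Timestamp.valueOf`, and grafted back onto the result via `setNanos`. A standalone sketch of the same idea, assuming the NTZ pattern only (the real code also handles the timezone-prefixed patterns through `snowflakeTimestampFormat`):

```scala
import java.sql.Timestamp
import java.text.SimpleDateFormat

// Illustrative re-implementation of the approach used in parseTimestamp above.
def parseWithNanos(s: String, supportMicroSeconds: Boolean): Timestamp = {
  // Keep 3 fractional digits when the flag is off, up to 9 when it is on.
  val regex =
    if (supportMicroSeconds) """\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3,9}""".r
    else """\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3}""".r

  val narrowed = regex.findFirstIn(s)
    .getOrElse(throw new IllegalArgumentException(s"Malformed timestamp $s"))

  // Timestamp.valueOf accepts yyyy-[m]m-[d]d hh:mm:ss[.f...], so it preserves
  // the sub-millisecond digits that SimpleDateFormat would silently drop.
  val nanoCarrier = Timestamp.valueOf(narrowed)

  // Seconds-level parse (NTZ pattern assumed here), then graft the nanos back on.
  val secondsOnly = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.").parse(s)
  val result = new Timestamp(secondsOnly.getTime)
  result.setNanos(nanoCarrier.getNanos)
  result
}

// parseWithNanos("2014-03-01 00:00:01.123456789", supportMicroSeconds = true)
//   => 2014-03-01 00:00:01.123456789
// parseWithNanos("2014-03-01 00:00:01.123456789", supportMicroSeconds = false)
//   => 2014-03-01 00:00:01.123
```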
diff --git a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
index 8b0be5f6..7b2813f2 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
@@ -229,6 +229,11 @@ object Parameters {
     "support_share_connection"
   )
 
+  // Internal option to support microsecond-precision timestamps during COPY UNLOAD
+  val PARAM_INTERNAL_SUPPORT_MICRO_SECOND_DURING_UNLOAD: String = knownParam(
+    "internal_support_micro_second_during_unload"
+  )
+
   // preactions and postactions may affect the session level setting, so connection sharing
   // may be enabled only when the queries in preactions and postactions are in a white list.
   // force_skip_pre_post_action_check_for_session_sharing is introduced if users are sure that
@@ -714,6 +719,9 @@ object Parameters {
     def supportShareConnection: Boolean = {
       isTrue(parameters.getOrElse(PARAM_SUPPORT_SHARE_CONNECTION, "true"))
     }
+    def supportMicroSecondDuringUnload: Boolean = {
+      isTrue(parameters.getOrElse(PARAM_INTERNAL_SUPPORT_MICRO_SECOND_DURING_UNLOAD, "false"))
+    }
     def forceSkipPrePostActionsCheck: Boolean = {
       isTrue(parameters.getOrElse(
         PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING, "false"))
diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeRelation.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeRelation.scala
index 833ed528..7a3cf57b 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeRelation.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeRelation.scala
@@ -154,7 +154,7 @@ private[snowflake] case class SnowflakeRelation(
   private def getRDD[T: ClassTag](statement: SnowflakeSQLStatement,
                                   resultSchema: StructType): RDD[T] = {
     if (params.useCopyUnload) {
-      getSnowflakeRDD(statement, resultSchema)
+      getSnowflakeRDD(statement, resultSchema, params)
     } else {
       getSnowflakeResultSetRDD(statement, resultSchema)
     }
@@ -162,7 +162,8 @@ private[snowflake] case class SnowflakeRelation(
 
   // Get an RDD with COPY Unload
   private def getSnowflakeRDD[T: ClassTag](statement: SnowflakeSQLStatement,
-                                           resultSchema: StructType): RDD[T] = {
+                                           resultSchema: StructType,
+                                           params: MergedParameters): RDD[T] = {
     val format: SupportedFormat =
       if (Utils.containVariant(resultSchema)) SupportedFormat.JSON
       else SupportedFormat.CSV
@@ -171,7 +172,9 @@ private[snowflake] case class SnowflakeRelation(
 
     format match {
       case SupportedFormat.CSV =>
-        rdd.mapPartitions(CSVConverter.convert[T](_, resultSchema))
+        // Need to explicitly define params in the function's scope,
+        // otherwise we would hit a "Task not serializable" error
+        rdd.mapPartitions(CSVConverter.convert[T](_, resultSchema, params))
       case SupportedFormat.JSON =>
         rdd.mapPartitions(JsonConverter.convert[T](_, resultSchema))
     }
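The "Task not serializable" comment in SnowflakeRelation.scala above refers to a standard Spark closure pitfall: `mapPartitions` serializes whatever its closure captures, and referencing a field of the enclosing relation would drag the whole (non-serializable) object into the closure. Passing `params` in as a method argument limits the capture to the `MergedParameters` value. A generic illustration under invented names (not connector code):

```scala
import org.apache.spark.rdd.RDD

class UnserializableHolder(val marker: String) {

  // Risky: `marker` is a field, so the lambda captures `this`,
  // and UnserializableHolder does not extend Serializable.
  def tagLinesCapturingThis(rdd: RDD[String]): RDD[String] =
    rdd.mapPartitions(iter => iter.map(line => s"$marker:$line"))

  // Safe: copy the value into a local first, as the patch does by handing
  // `params` to getSnowflakeRDD; only the String travels to the executors.
  def tagLinesCapturingValue(rdd: RDD[String]): RDD[String] = {
    val localMarker = marker
    rdd.mapPartitions(iter => iter.map(line => s"$localMarker:$line"))
  }
}
```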
diff --git a/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala b/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala
index 148f6e0d..b1ea5b6b 100644
--- a/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala
+++ b/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala
@@ -18,12 +18,14 @@ package net.snowflake.spark.snowflake
 
 import java.sql.{Date, Timestamp}
-
 import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper
 import org.scalatest.FunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.types._
 
+import scala.:+
+import scala.collection.immutable.HashMap
+
 /**
   * Unit test for data type conversions
   */
@@ -31,8 +33,19 @@ class ConversionsSuite extends FunSuite {
 
   val mapper = new ObjectMapper()
 
+  val commonOptions = Map("dbtable" -> "test_table",
+    "sfurl" -> "account.snowflakecomputing.com:443",
+    "sfuser" -> "username",
+    "sfpassword" -> "password")
+
+  val notSupportMicro: Parameters.MergedParameters = Parameters.mergeParameters(commonOptions ++
+    Map(Parameters.PARAM_INTERNAL_SUPPORT_MICRO_SECOND_DURING_UNLOAD -> "false"))
+
+  val supportMicro: Parameters.MergedParameters = Parameters.mergeParameters(commonOptions ++
+    Map(Parameters.PARAM_INTERNAL_SUPPORT_MICRO_SECOND_DURING_UNLOAD -> "true"))
+
   test("Data should be correctly converted") {
-    val convertRow = Conversions.createRowConverter[Row](TestUtils.testSchema)
+    val convertRow = Conversions.createRowConverter[Row](TestUtils.testSchema, notSupportMicro)
     val doubleMin = Double.MinValue.toString
     val longMax = Long.MaxValue.toString
     // scalastyle:off
@@ -110,7 +123,7 @@ class ConversionsSuite extends FunSuite {
   }
 
   test("Row conversion handles null values") {
-    val convertRow = Conversions.createRowConverter[Row](TestUtils.testSchema)
+    val convertRow = Conversions.createRowConverter[Row](TestUtils.testSchema, notSupportMicro)
     val emptyRow = List.fill(TestUtils.testSchema.length)(null).toArray[String]
     val nullsRow = List.fill(TestUtils.testSchema.length)(null).toArray[String]
     assert(convertRow(emptyRow) === Row(nullsRow: _*))
@@ -118,7 +131,7 @@ class ConversionsSuite extends FunSuite {
 
   test("Dates are correctly converted") {
     val convertRow = Conversions.createRowConverter[Row](
-      StructType(Seq(StructField("a", DateType)))
+      StructType(Seq(StructField("a", DateType))), notSupportMicro
     )
     assert(
       convertRow(Array("2015-07-09")) === Row(TestUtils.toDate(2015, 6, 9))
@@ -193,4 +206,88 @@ class ConversionsSuite extends FunSuite {
     assert(expect == result.toString())
   }
 
+
+  test("Data with micro-seconds and nano-seconds precision should be correctly converted") {
+    val convertRow = Conversions.createRowConverter[Row](TestUtils.testSchema, supportMicro)
+    val doubleMin = Double.MinValue.toString
+    val longMax = Long.MaxValue.toString
+    // scalastyle:off
+    val unicodeString = "Unicode是樂趣"
+    // scalastyle:on
+
+    val timestampString = "2014-03-01 00:00:01.123456"
+
+    val expectedTimestampMicro: Timestamp = java.sql.Timestamp.valueOf(timestampString)
+
+    val dateString = "2015-07-01"
+    val expectedDate = TestUtils.toMillis(2015, 6, 1, 0, 0, 0)
+
+
+
+    val timestampString2 = "2014-03-01 00:00:01.123456789"
+
+    val expectedTimestampMicro2: Timestamp = java.sql.Timestamp.valueOf(timestampString2)
+
+    val dateString2 = "2015-07-01"
+    val expectedDate2 = TestUtils.toMillis(2015, 6, 1, 0, 0, 0)
+
+    val convertedRow = convertRow(
+      Array(
+        "1",
+        dateString,
+        "123.45",
+        doubleMin,
+        "1.0",
+        "42",
+        longMax,
+        "23",
+        unicodeString,
+        timestampString
+      )
+    )
+
+    val expectedRow = Row(
+      1.asInstanceOf[Byte],
+      new Date(expectedDate),
+      new java.math.BigDecimal("123.45"),
+      Double.MinValue,
+      1.0f,
+      42,
+      Long.MaxValue,
+      23.toShort,
+      unicodeString,
+      expectedTimestampMicro
+    )
+
+    val convertedRow2 = convertRow(
+      Array(
+        "1",
+        dateString2,
+        "123.45",
+        doubleMin,
+        "1.0",
+        "42",
+        longMax,
+        "23",
+        unicodeString,
+        timestampString2
+      )
+    )
+
+    val expectedRow2 = Row(
+      1.asInstanceOf[Byte],
+      new Date(expectedDate2),
+      new java.math.BigDecimal("123.45"),
+      Double.MinValue,
+      1.0f,
+      42,
+      Long.MaxValue,
+      23.toShort,
+      unicodeString,
+      expectedTimestampMicro2
+    )
+
+    assert(convertedRow == expectedRow)
+    assert(convertedRow2 == expectedRow2)
+  }
 }
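For reference, the two regexes exercised by the new unit test differ only in how much of the fractional part survives before `Timestamp.valueOf` is applied; this is why `expectedTimestampMicro2` keeps all nine digits when `supportMicro` is in effect:

```scala
val microRegex = """\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3,9}""".r
val milliRegex = """\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3}""".r

val raw = "2014-03-01 00:00:01.123456789"
microRegex.findFirstIn(raw) // Some(2014-03-01 00:00:01.123456789) -> nanoseconds survive
milliRegex.findFirstIn(raw) // Some(2014-03-01 00:00:01.123)       -> cut at milliseconds
```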