From 91398eb49e5650c995ae64e26a4d31e75934860a Mon Sep 17 00:00:00 2001
From: Arthur-Li
Date: Wed, 8 Mar 2023 11:50:48 -0500
Subject: [PATCH 01/15] Support microsecond precision during copy unload

---
 .../spark/snowflake/Conversions.scala        | 28 ++++++-
 .../spark/snowflake/ConversionsSuite.scala   | 84 +++++++++++++++++++
 2 files changed, 108 insertions(+), 4 deletions(-)

diff --git a/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala b/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala
index c00c2417..69d01aae 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala
@@ -41,15 +41,18 @@ private[snowflake] object Conversions {
   // Note - we use a pattern with timezone in the beginning, to make sure
   // parsing with PATTERN_NTZ fails for PATTERN_TZLTZ strings.
   // Note - for JDK 1.6, we use Z ipo XX for SimpleDateFormat
+  // Because simpleDateFormat only support milliseconds,
+  // we need to refactor this and handle nano seconds field separately
   private val PATTERN_TZLTZ =
     if (System.getProperty("java.version").startsWith("1.6.")) {
-      "Z yyyy-MM-dd HH:mm:ss.SSS"
+      "Z yyyy-MM-dd HH:mm:ss."
     } else {
-      "XX yyyy-MM-dd HH:mm:ss.SSS"
+      "XX yyyy-MM-dd HH:mm:ss."
     }

   // For NTZ, Snowflake serializes w/o timezone
-  private val PATTERN_NTZ = "yyyy-MM-dd HH:mm:ss.SSS"
+  // and handle nano seconds field separately during parsing
+  private val PATTERN_NTZ = "yyyy-MM-dd HH:mm:ss."

   // For DATE, simple ISO format
   private val PATTERN_DATE = "yyyy-MM-dd"
@@ -193,8 +196,25 @@ private[snowflake] object Conversions {
    * Parse a string exported from a Snowflake TIMESTAMP column
    */
   private def parseTimestamp(s: String, isInternalRow: Boolean): Any = {
+    // Need to handle the nano seconds filed separately
+    // valueOf only works with yyyy-[m]m-[d]d hh:mm:ss[.f...]
+ // so we need to do a little parsing + val timestampRegex = """\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3,9}""".r + + val parsedTS = timestampRegex.findFirstMatchIn(s) match { + case Some(ts) => ts.toString() + case None => throw new IllegalArgumentException(s"Malformed timestamp $s") + } + + val ts = java.sql.Timestamp.valueOf(parsedTS) + val nanoFraction = ts.getNanos + val res = new Timestamp(snowflakeTimestampFormat.parse(s).getTime) - if (isInternalRow) DateTimeUtils.fromJavaTimestamp(res) + + res.setNanos(nanoFraction) + // Since fromJavaTimestamp and spark only support microsecond + // level precision so have to divide the nano field by 1000 + if (isInternalRow) (DateTimeUtils.fromJavaTimestamp(res) + nanoFraction/1000) else res } diff --git a/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala b/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala index 148f6e0d..eb227c2b 100644 --- a/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala +++ b/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala @@ -193,4 +193,88 @@ class ConversionsSuite extends FunSuite { assert(expect == result.toString()) } + + test("Data with micro-seconds and nano-seconds precision should be correctly converted"){ + val convertRow = Conversions.createRowConverter[Row](TestUtils.testSchema) + val doubleMin = Double.MinValue.toString + val longMax = Long.MaxValue.toString + // scalastyle:off + val unicodeString = "instacart是独角兽" + // scalastyle:on + + val timestampString = "2014-03-01 00:00:01.123456" + + val expectedTimestampMicro: Timestamp = java.sql.Timestamp.valueOf(timestampString) + + val dateString = "2015-07-01" + val expectedDate = TestUtils.toMillis(2015, 6, 1, 0, 0, 0) + + + + val timestampString2 = "2014-03-01 00:00:01.123456789" + + val expectedTimestampMicro2: Timestamp = java.sql.Timestamp.valueOf(timestampString2) + + val dateString2 = "2015-07-01" + val expectedDate2 = TestUtils.toMillis(2015, 6, 1, 0, 0, 0) + + val convertedRow = convertRow( + Array( + "1", + dateString, + "123.45", + doubleMin, + "1.0", + "42", + longMax, + "23", + unicodeString, + timestampString + ) + ) + + val expectedRow = Row( + 1.asInstanceOf[Byte], + new Date(expectedDate), + new java.math.BigDecimal("123.45"), + Double.MinValue, + 1.0f, + 42, + Long.MaxValue, + 23.toShort, + unicodeString, + expectedTimestampMicro + ) + + val convertedRow2 = convertRow( + Array( + "1", + dateString2, + "123.45", + doubleMin, + "1.0", + "42", + longMax, + "23", + unicodeString, + timestampString2 + ) + ) + + val expectedRow2 = Row( + 1.asInstanceOf[Byte], + new Date(expectedDate2), + new java.math.BigDecimal("123.45"), + Double.MinValue, + 1.0f, + 42, + Long.MaxValue, + 23.toShort, + unicodeString, + expectedTimestampMicro2 + ) + + assert(convertedRow == expectedRow) + assert(convertedRow2 == expectedRow2) + } } From 52adb3e4982e89dab26f1639f4191b892275203e Mon Sep 17 00:00:00 2001 From: Arthur-Li Date: Wed, 8 Mar 2023 12:48:10 -0500 Subject: [PATCH 02/15] Fixed typo --- .../scala/net/snowflake/spark/snowflake/ConversionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala b/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala index eb227c2b..04fd1c3b 100644 --- a/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala +++ b/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala @@ -199,7 +199,7 @@ class ConversionsSuite extends FunSuite { val 
doubleMin = Double.MinValue.toString val longMax = Long.MaxValue.toString // scalastyle:off - val unicodeString = "instacart是独角兽" + val unicodeString = "Unicode是樂趣" // scalastyle:on val timestampString = "2014-03-01 00:00:01.123456" From 285795a05e265408c1b8b484029f7590c7815ba0 Mon Sep 17 00:00:00 2001 From: Arthur-Li Date: Mon, 13 Mar 2023 01:24:12 -0400 Subject: [PATCH 03/15] Support copy unload in IT test timestamp --- .../SnowflakeResultSetRDDSuite.scala | 165 +++++++++--------- 1 file changed, 78 insertions(+), 87 deletions(-) diff --git a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala index c927ebb6..3a55e5cc 100644 --- a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala +++ b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala @@ -710,111 +710,102 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase { test("testTimestamp") { setupTimestampTable - // COPY UNLOAD can't be run because it only supports millisecond(0.001s). - if (!params.useCopyUnload) { - val result = sparkSession.sql("select * from test_table_timestamp") + val result = sparkSession.sql("select * from test_table_timestamp") - testPushdown( - s""" SELECT * FROM ( $test_table_timestamp ) AS "SF_CONNECTOR_QUERY_ALIAS" """.stripMargin, - result, - test_table_timestamp_rows - ) - } + testPushdown( + s""" SELECT * FROM ( $test_table_timestamp ) AS "SF_CONNECTOR_QUERY_ALIAS" """.stripMargin, + result, + test_table_timestamp_rows + ) } // Most simple case for timestamp write test("testTimestamp write") { setupTimestampTable - // COPY UNLOAD can't be run because it only supports millisecond(0.001s). - if (!params.useCopyUnload) { - val createTableSql = - s"""create or replace table $test_table_write ( - | int_c int, - | ts_ltz_c timestamp_ltz(9), ts_ltz_c0 timestamp_ltz(0), - | ts_ltz_c3 timestamp_ltz(3), ts_ltz_c6 timestamp_ltz(6), - | - | ts_ntz_c timestamp_ntz(9), ts_ntz_c0 timestamp_ntz(0), - | ts_ntz_c3 timestamp_ntz(3), ts_ntz_c6 timestamp_ntz(6), - | - | ts_tz_c timestamp_tz(9), ts_tz_c0 timestamp_tz(0), - | ts_tz_c3 timestamp_tz(3), ts_tz_c6 timestamp_tz(6) - | )""".stripMargin - writeAndCheckForOneTable(sparkSession, thisConnectorOptionsNoTable, - test_table_timestamp, "", test_table_write, Some(createTableSql), true) - } + val createTableSql = + s"""create or replace table $test_table_write ( + | int_c int, + | ts_ltz_c timestamp_ltz(9), ts_ltz_c0 timestamp_ltz(0), + | ts_ltz_c3 timestamp_ltz(3), ts_ltz_c6 timestamp_ltz(6), + | + | ts_ntz_c timestamp_ntz(9), ts_ntz_c0 timestamp_ntz(0), + | ts_ntz_c3 timestamp_ntz(3), ts_ntz_c6 timestamp_ntz(6), + | + | ts_tz_c timestamp_tz(9), ts_tz_c0 timestamp_tz(0), + | ts_tz_c3 timestamp_tz(3), ts_tz_c6 timestamp_tz(6) + | )""".stripMargin + writeAndCheckForOneTable(sparkSession, thisConnectorOptionsNoTable, + test_table_timestamp, "", test_table_write, Some(createTableSql), true) } // test timestamp write with timezone test("testTimestamp write with timezone") { setupTimestampTable - // COPY UNLOAD can't be run because it only supports millisecond(0.001s). 
- if (!params.useCopyUnload) { - var oldValue: Option[String] = None - if (thisConnectorOptionsNoTable.contains("sftimezone")) { - oldValue = Some(thisConnectorOptionsNoTable("sftimezone")) - thisConnectorOptionsNoTable -= "sftimezone" + var oldValue: Option[String] = None + if (thisConnectorOptionsNoTable.contains("sftimezone")) { + oldValue = Some(thisConnectorOptionsNoTable("sftimezone")) + thisConnectorOptionsNoTable -= "sftimezone" + } + val oldTimezone = TimeZone.getDefault + + val createTableSql = + s"""create or replace table $test_table_write ( + | int_c int, + | ts_ltz_c timestamp_ltz(9), ts_ltz_c0 timestamp_ltz(0), + | ts_ltz_c3 timestamp_ltz(3), ts_ltz_c6 timestamp_ltz(6), + | + | ts_ntz_c timestamp_ntz(9), ts_ntz_c0 timestamp_ntz(0), + | ts_ntz_c3 timestamp_ntz(3), ts_ntz_c6 timestamp_ntz(6), + | + | ts_tz_c timestamp_tz(9), ts_tz_c0 timestamp_tz(0), + | ts_tz_c3 timestamp_tz(3), ts_tz_c6 timestamp_tz(6) + | )""".stripMargin + + // Test conditions with (sfTimezone, sparkTimezone) + val testConditions: List[(String, String)] = List( + (null, "GMT") + , (null, "America/Los_Angeles") + , ("America/New_York", "America/Los_Angeles") + ) + + for ((sfTimezone, sparkTimezone) <- testConditions) { + // set spark timezone + val thisSparkSession = if (sparkTimezone != null) { + TimeZone.setDefault(TimeZone.getTimeZone(sparkTimezone)) + SparkSession.builder + .master("local") + .appName("SnowflakeSourceSuite") + .config("spark.sql.shuffle.partitions", "6") + .config("spark.driver.extraJavaOptions", s"-Duser.timezone=$sparkTimezone") + .config("spark.executor.extraJavaOptions", s"-Duser.timezone=$sparkTimezone") + .config("spark.sql.session.timeZone", sparkTimezone) + .getOrCreate() + } else { + sparkSession } - val oldTimezone = TimeZone.getDefault - - val createTableSql = - s"""create or replace table $test_table_write ( - | int_c int, - | ts_ltz_c timestamp_ltz(9), ts_ltz_c0 timestamp_ltz(0), - | ts_ltz_c3 timestamp_ltz(3), ts_ltz_c6 timestamp_ltz(6), - | - | ts_ntz_c timestamp_ntz(9), ts_ntz_c0 timestamp_ntz(0), - | ts_ntz_c3 timestamp_ntz(3), ts_ntz_c6 timestamp_ntz(6), - | - | ts_tz_c timestamp_tz(9), ts_tz_c0 timestamp_tz(0), - | ts_tz_c3 timestamp_tz(3), ts_tz_c6 timestamp_tz(6) - | )""".stripMargin - - // Test conditions with (sfTimezone, sparkTimezone) - val testConditions: List[(String, String)] = List( - (null, "GMT") - , (null, "America/Los_Angeles") - , ("America/New_York", "America/Los_Angeles") - ) - for ((sfTimezone, sparkTimezone) <- testConditions) { - // set spark timezone - val thisSparkSession = if (sparkTimezone != null) { - TimeZone.setDefault(TimeZone.getTimeZone(sparkTimezone)) - SparkSession.builder - .master("local") - .appName("SnowflakeSourceSuite") - .config("spark.sql.shuffle.partitions", "6") - .config("spark.driver.extraJavaOptions", s"-Duser.timezone=$sparkTimezone") - .config("spark.executor.extraJavaOptions", s"-Duser.timezone=$sparkTimezone") - .config("spark.sql.session.timeZone", sparkTimezone) - .getOrCreate() - } else { - sparkSession + // Set timezone option + if (sfTimezone != null) { + if (thisConnectorOptionsNoTable.contains("sftimezone")) { + thisConnectorOptionsNoTable -= "sftimezone" } - - // Set timezone option - if (sfTimezone != null) { - if (thisConnectorOptionsNoTable.contains("sftimezone")) { - thisConnectorOptionsNoTable -= "sftimezone" - } - thisConnectorOptionsNoTable += ("sftimezone" -> sfTimezone) - } else { - if (thisConnectorOptionsNoTable.contains("sftimezone")) { - thisConnectorOptionsNoTable -= "sftimezone" - } + 
thisConnectorOptionsNoTable += ("sftimezone" -> sfTimezone) + } else { + if (thisConnectorOptionsNoTable.contains("sftimezone")) { + thisConnectorOptionsNoTable -= "sftimezone" } - - writeAndCheckForOneTable(thisSparkSession, thisConnectorOptionsNoTable, - test_table_timestamp, "", test_table_write, Some(createTableSql), true) } - // restore options for further test - thisConnectorOptionsNoTable -= "sftimezone" - if (oldValue.isDefined) { - thisConnectorOptionsNoTable += ("sftimezone" -> oldValue.get) - } - TimeZone.setDefault(oldTimezone) + writeAndCheckForOneTable(thisSparkSession, thisConnectorOptionsNoTable, + test_table_timestamp, "", test_table_write, Some(createTableSql), true) + } + + // restore options for further test + thisConnectorOptionsNoTable -= "sftimezone" + if (oldValue.isDefined) { + thisConnectorOptionsNoTable += ("sftimezone" -> oldValue.get) } + TimeZone.setDefault(oldTimezone) } test("testLargeResult") { From 0dedb0b37bdea5e14688dcf1de7d617dfeda003f Mon Sep 17 00:00:00 2001 From: Arthur-Li Date: Mon, 20 Mar 2023 01:08:00 -0400 Subject: [PATCH 04/15] Fixed IT test --- .../snowflake/SnowflakeResultSetRDDSuite.scala | 15 ++++++++++++--- .../snowflake/spark/snowflake/Conversions.scala | 5 ++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala index 3a55e5cc..f751a5c0 100644 --- a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala +++ b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala @@ -583,6 +583,9 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase { thisConnectorOptionsNoTable += ("time_output_format" -> "HH24:MI:SS.FF") thisConnectorOptionsNoTable += ("s3maxfilesize" -> "1000001") thisConnectorOptionsNoTable += ("jdbc_query_result_format" -> "arrow") + thisConnectorOptionsNoTable += ("timestamp_ntz_output_format" -> "YYYY-MM-DD HH24:MI:SS.FF6") + thisConnectorOptionsNoTable += ("timestamp_ltz_output_format" -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6") + thisConnectorOptionsNoTable += ("timestamp_tz_output_format" -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6") } test("testNumber") { @@ -763,15 +766,21 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase { // Test conditions with (sfTimezone, sparkTimezone) val testConditions: List[(String, String)] = List( - (null, "GMT") - , (null, "America/Los_Angeles") - , ("America/New_York", "America/Los_Angeles") + (null, "GMT"), + (null, "America/Los_Angeles"), + ("America/New_York", "America/Los_Angeles") ) for ((sfTimezone, sparkTimezone) <- testConditions) { // set spark timezone val thisSparkSession = if (sparkTimezone != null) { TimeZone.setDefault(TimeZone.getTimeZone(sparkTimezone)) + // Avoid interference from any active sessions + val activeSessions = SparkSession.builder + .master("local") + .appName("SnowflakeSourceSuite") + .getOrCreate() + activeSessions.stop() SparkSession.builder .master("local") .appName("SnowflakeSourceSuite") diff --git a/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala b/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala index 69d01aae..d67fd2f3 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala @@ -212,9 +212,8 @@ private[snowflake] object Conversions { val res = new Timestamp(snowflakeTimestampFormat.parse(s).getTime) res.setNanos(nanoFraction) 
- // Since fromJavaTimestamp and spark only support microsecond - // level precision so have to divide the nano field by 1000 - if (isInternalRow) (DateTimeUtils.fromJavaTimestamp(res) + nanoFraction/1000) + + if (isInternalRow) DateTimeUtils.fromJavaTimestamp(res) else res } From 371a675e60ebf2aeae5be93b9cf460cceeae37e1 Mon Sep 17 00:00:00 2001 From: Arthur-Li Date: Mon, 20 Mar 2023 02:11:35 -0400 Subject: [PATCH 05/15] Making sure the spark session exists for other test cases --- .../snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala index f751a5c0..c0998f86 100644 --- a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala +++ b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala @@ -781,6 +781,7 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase { .appName("SnowflakeSourceSuite") .getOrCreate() activeSessions.stop() + SparkSession.builder .master("local") .appName("SnowflakeSourceSuite") @@ -2515,6 +2516,7 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase { override def beforeEach(): Unit = { super.beforeEach() + sparkSession = createDefaultSparkSession } override def afterAll(): Unit = { From 615128582a4516f7d5d11bcb7952fe08a2cbc3fe Mon Sep 17 00:00:00 2001 From: Arthur-Li Date: Mon, 27 Mar 2023 00:58:25 -0400 Subject: [PATCH 06/15] Added supportMicroSecs param and refactor IT test --- .../SnowflakeResultSetRDDSuite.scala | 304 +++++++++++++----- .../spark/snowflake/CSVConverter.scala | 7 +- .../spark/snowflake/Conversions.scala | 34 +- .../spark/snowflake/Parameters.scala | 8 + .../spark/snowflake/SnowflakeRelation.scala | 9 +- .../spark/snowflake/ConversionsSuite.scala | 23 +- 6 files changed, 281 insertions(+), 104 deletions(-) diff --git a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala index c0998f86..28bbdd3d 100644 --- a/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala +++ b/src/it/scala/net/snowflake/spark/snowflake/SnowflakeResultSetRDDSuite.scala @@ -583,9 +583,6 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase { thisConnectorOptionsNoTable += ("time_output_format" -> "HH24:MI:SS.FF") thisConnectorOptionsNoTable += ("s3maxfilesize" -> "1000001") thisConnectorOptionsNoTable += ("jdbc_query_result_format" -> "arrow") - thisConnectorOptionsNoTable += ("timestamp_ntz_output_format" -> "YYYY-MM-DD HH24:MI:SS.FF6") - thisConnectorOptionsNoTable += ("timestamp_ltz_output_format" -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6") - thisConnectorOptionsNoTable += ("timestamp_tz_output_format" -> "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6") } test("testNumber") { @@ -701,18 +698,38 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase { test("test read write Date Time") { setupDateTimeTable - val createTableSql = - s"""create or replace table $test_table_write ( - | int_c int, date_c date, time_c0 time(0), time_c1 time(1), time_c2 time(2), - | time_c3 time(3), time_c4 time(4), time_c5 time(5), time_c6 time(6), - | time_c7 time(7), time_c8 time(8), time_c9 time(9) + if (!params.useCopyUnload) { + val createTableSql = + s"""create or replace table $test_table_write ( + | int_c int, date_c date, time_c0 time(0), time_c1 time(1), time_c2 time(2), + | time_c3 time(3), 
time_c4 time(4), time_c5 time(5), time_c6 time(6), + | time_c7 time(7), time_c8 time(8), time_c9 time(9) )""".stripMargin - writeAndCheckForOneTable(sparkSession, thisConnectorOptionsNoTable, - test_table_date_time, "", test_table_write, Some(createTableSql), true) + writeAndCheckForOneTable(sparkSession, thisConnectorOptionsNoTable, + test_table_date_time, "", test_table_write, Some(createTableSql), true) + } } test("testTimestamp") { setupTimestampTable + if (!params.useCopyUnload) { + val result = sparkSession.sql("select * from test_table_timestamp") + + testPushdown( + s""" SELECT * FROM ( $test_table_timestamp ) AS "SF_CONNECTOR_QUERY_ALIAS" """.stripMargin, + result, + test_table_timestamp_rows + ) + } + } + + test("testTimestamp with copyUnload") { + thisConnectorOptionsNoTable += ("timestamp_ntz_output_format" -> "YYYY-MM-DD HH24:MI:SS.FF6") + thisConnectorOptionsNoTable += ("timestamp_ltz_output_format" -> + "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6") + thisConnectorOptionsNoTable += ("timestamp_tz_output_format" -> + "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6") + setupTimestampTable val result = sparkSession.sql("select * from test_table_timestamp") testPushdown( @@ -725,97 +742,214 @@ class SnowflakeResultSetRDDSuite extends IntegrationSuiteBase { // Most simple case for timestamp write test("testTimestamp write") { setupTimestampTable - val createTableSql = - s"""create or replace table $test_table_write ( - | int_c int, - | ts_ltz_c timestamp_ltz(9), ts_ltz_c0 timestamp_ltz(0), - | ts_ltz_c3 timestamp_ltz(3), ts_ltz_c6 timestamp_ltz(6), - | - | ts_ntz_c timestamp_ntz(9), ts_ntz_c0 timestamp_ntz(0), - | ts_ntz_c3 timestamp_ntz(3), ts_ntz_c6 timestamp_ntz(6), - | - | ts_tz_c timestamp_tz(9), ts_tz_c0 timestamp_tz(0), - | ts_tz_c3 timestamp_tz(3), ts_tz_c6 timestamp_tz(6) - | )""".stripMargin - writeAndCheckForOneTable(sparkSession, thisConnectorOptionsNoTable, - test_table_timestamp, "", test_table_write, Some(createTableSql), true) + if (!params.useCopyUnload) { + val createTableSql = + s"""create or replace table $test_table_write ( + | int_c int, + | ts_ltz_c timestamp_ltz(9), ts_ltz_c0 timestamp_ltz(0), + | ts_ltz_c3 timestamp_ltz(3), ts_ltz_c6 timestamp_ltz(6), + | + | ts_ntz_c timestamp_ntz(9), ts_ntz_c0 timestamp_ntz(0), + | ts_ntz_c3 timestamp_ntz(3), ts_ntz_c6 timestamp_ntz(6), + | + | ts_tz_c timestamp_tz(9), ts_tz_c0 timestamp_tz(0), + | ts_tz_c3 timestamp_tz(3), ts_tz_c6 timestamp_tz(6) + | )""".stripMargin + writeAndCheckForOneTable(sparkSession, thisConnectorOptionsNoTable, + test_table_timestamp, "", test_table_write, Some(createTableSql), true) + } + } + + test("testTimestamp write with copy unload") { + thisConnectorOptionsNoTable += ("timestamp_ntz_output_format" -> "YYYY-MM-DD HH24:MI:SS.FF6") + thisConnectorOptionsNoTable += ("timestamp_ltz_output_format" -> + "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6") + thisConnectorOptionsNoTable += ("timestamp_tz_output_format" -> + "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6") + setupTimestampTable + if (params.supportMicroSecondDuringUnload) { + val createTableSql = + s"""create or replace table $test_table_write ( + | int_c int, + | ts_ltz_c timestamp_ltz(9), ts_ltz_c0 timestamp_ltz(0), + | ts_ltz_c3 timestamp_ltz(3), ts_ltz_c6 timestamp_ltz(6), + | + | ts_ntz_c timestamp_ntz(9), ts_ntz_c0 timestamp_ntz(0), + | ts_ntz_c3 timestamp_ntz(3), ts_ntz_c6 timestamp_ntz(6), + | + | ts_tz_c timestamp_tz(9), ts_tz_c0 timestamp_tz(0), + | ts_tz_c3 timestamp_tz(3), ts_tz_c6 timestamp_tz(6) + | )""".stripMargin + writeAndCheckForOneTable(sparkSession, 
thisConnectorOptionsNoTable, + test_table_timestamp, "", test_table_write, Some(createTableSql), true) + } } + + // test timestamp write with timezone test("testTimestamp write with timezone") { setupTimestampTable - var oldValue: Option[String] = None - if (thisConnectorOptionsNoTable.contains("sftimezone")) { - oldValue = Some(thisConnectorOptionsNoTable("sftimezone")) + if (!params.useCopyUnload) { + var oldValue: Option[String] = None + if (thisConnectorOptionsNoTable.contains("sftimezone")) { + oldValue = Some(thisConnectorOptionsNoTable("sftimezone")) + thisConnectorOptionsNoTable -= "sftimezone" + } + val oldTimezone = TimeZone.getDefault + + val createTableSql = + s"""create or replace table $test_table_write ( + | int_c int, + | ts_ltz_c timestamp_ltz(9), ts_ltz_c0 timestamp_ltz(0), + | ts_ltz_c3 timestamp_ltz(3), ts_ltz_c6 timestamp_ltz(6), + | + | ts_ntz_c timestamp_ntz(9), ts_ntz_c0 timestamp_ntz(0), + | ts_ntz_c3 timestamp_ntz(3), ts_ntz_c6 timestamp_ntz(6), + | + | ts_tz_c timestamp_tz(9), ts_tz_c0 timestamp_tz(0), + | ts_tz_c3 timestamp_tz(3), ts_tz_c6 timestamp_tz(6) + | )""".stripMargin + + // Test conditions with (sfTimezone, sparkTimezone) + val testConditions: List[(String, String)] = List( + (null, "GMT"), + (null, "America/Los_Angeles"), + ("America/New_York", "America/Los_Angeles") + ) + + for ((sfTimezone, sparkTimezone) <- testConditions) { + // set spark timezone + val thisSparkSession = if (sparkTimezone != null) { + TimeZone.setDefault(TimeZone.getTimeZone(sparkTimezone)) + // Avoid interference from any active sessions + val activeSessions = SparkSession.builder + .master("local") + .appName("SnowflakeSourceSuite") + .getOrCreate() + activeSessions.stop() + + SparkSession.builder + .master("local") + .appName("SnowflakeSourceSuite") + .config("spark.sql.shuffle.partitions", "6") + .config("spark.driver.extraJavaOptions", s"-Duser.timezone=$sparkTimezone") + .config("spark.executor.extraJavaOptions", s"-Duser.timezone=$sparkTimezone") + .config("spark.sql.session.timeZone", sparkTimezone) + .getOrCreate() + } else { + sparkSession + } + + // Set timezone option + if (sfTimezone != null) { + if (thisConnectorOptionsNoTable.contains("sftimezone")) { + thisConnectorOptionsNoTable -= "sftimezone" + } + thisConnectorOptionsNoTable += ("sftimezone" -> sfTimezone) + } else { + if (thisConnectorOptionsNoTable.contains("sftimezone")) { + thisConnectorOptionsNoTable -= "sftimezone" + } + } + + writeAndCheckForOneTable(thisSparkSession, thisConnectorOptionsNoTable, + test_table_timestamp, "", test_table_write, Some(createTableSql), true) + } + + // restore options for further test thisConnectorOptionsNoTable -= "sftimezone" + if (oldValue.isDefined) { + thisConnectorOptionsNoTable += ("sftimezone" -> oldValue.get) + } + TimeZone.setDefault(oldTimezone) } - val oldTimezone = TimeZone.getDefault + } - val createTableSql = - s"""create or replace table $test_table_write ( - | int_c int, - | ts_ltz_c timestamp_ltz(9), ts_ltz_c0 timestamp_ltz(0), - | ts_ltz_c3 timestamp_ltz(3), ts_ltz_c6 timestamp_ltz(6), - | - | ts_ntz_c timestamp_ntz(9), ts_ntz_c0 timestamp_ntz(0), - | ts_ntz_c3 timestamp_ntz(3), ts_ntz_c6 timestamp_ntz(6), - | - | ts_tz_c timestamp_tz(9), ts_tz_c0 timestamp_tz(0), - | ts_tz_c3 timestamp_tz(3), ts_tz_c6 timestamp_tz(6) - | )""".stripMargin - - // Test conditions with (sfTimezone, sparkTimezone) - val testConditions: List[(String, String)] = List( - (null, "GMT"), - (null, "America/Los_Angeles"), - ("America/New_York", "America/Los_Angeles") - ) + // test 
timestamp write with timezone + test("testTimestamp write with timezone with copy unload") { + thisConnectorOptionsNoTable += ("timestamp_ntz_output_format" -> "YYYY-MM-DD HH24:MI:SS.FF6") + thisConnectorOptionsNoTable += ("timestamp_ltz_output_format" -> + "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6") + thisConnectorOptionsNoTable += ("timestamp_tz_output_format" -> + "TZHTZM YYYY-MM-DD HH24:MI:SS.FF6") + setupTimestampTable + if (params.supportMicroSecondDuringUnload) { - for ((sfTimezone, sparkTimezone) <- testConditions) { - // set spark timezone - val thisSparkSession = if (sparkTimezone != null) { - TimeZone.setDefault(TimeZone.getTimeZone(sparkTimezone)) - // Avoid interference from any active sessions - val activeSessions = SparkSession.builder - .master("local") - .appName("SnowflakeSourceSuite") - .getOrCreate() - activeSessions.stop() - - SparkSession.builder - .master("local") - .appName("SnowflakeSourceSuite") - .config("spark.sql.shuffle.partitions", "6") - .config("spark.driver.extraJavaOptions", s"-Duser.timezone=$sparkTimezone") - .config("spark.executor.extraJavaOptions", s"-Duser.timezone=$sparkTimezone") - .config("spark.sql.session.timeZone", sparkTimezone) - .getOrCreate() - } else { - sparkSession + + var oldValue: Option[String] = None + if (thisConnectorOptionsNoTable.contains("sftimezone")) { + oldValue = Some(thisConnectorOptionsNoTable("sftimezone")) + thisConnectorOptionsNoTable -= "sftimezone" } + val oldTimezone = TimeZone.getDefault + + val createTableSql = + s"""create or replace table $test_table_write ( + | int_c int, + | ts_ltz_c timestamp_ltz(9), ts_ltz_c0 timestamp_ltz(0), + | ts_ltz_c3 timestamp_ltz(3), ts_ltz_c6 timestamp_ltz(6), + | + | ts_ntz_c timestamp_ntz(9), ts_ntz_c0 timestamp_ntz(0), + | ts_ntz_c3 timestamp_ntz(3), ts_ntz_c6 timestamp_ntz(6), + | + | ts_tz_c timestamp_tz(9), ts_tz_c0 timestamp_tz(0), + | ts_tz_c3 timestamp_tz(3), ts_tz_c6 timestamp_tz(6) + | )""".stripMargin + + // Test conditions with (sfTimezone, sparkTimezone) + val testConditions: List[(String, String)] = List( + (null, "GMT"), + (null, "America/Los_Angeles"), + ("America/New_York", "America/Los_Angeles") + ) - // Set timezone option - if (sfTimezone != null) { - if (thisConnectorOptionsNoTable.contains("sftimezone")) { - thisConnectorOptionsNoTable -= "sftimezone" + for ((sfTimezone, sparkTimezone) <- testConditions) { + // set spark timezone + val thisSparkSession = if (sparkTimezone != null) { + TimeZone.setDefault(TimeZone.getTimeZone(sparkTimezone)) + // Avoid interference from any active sessions + val activeSessions = SparkSession.builder + .master("local") + .appName("SnowflakeSourceSuite") + .getOrCreate() + activeSessions.stop() + + SparkSession.builder + .master("local") + .appName("SnowflakeSourceSuite") + .config("spark.sql.shuffle.partitions", "6") + .config("spark.driver.extraJavaOptions", s"-Duser.timezone=$sparkTimezone") + .config("spark.executor.extraJavaOptions", s"-Duser.timezone=$sparkTimezone") + .config("spark.sql.session.timeZone", sparkTimezone) + .getOrCreate() + } else { + sparkSession } - thisConnectorOptionsNoTable += ("sftimezone" -> sfTimezone) - } else { - if (thisConnectorOptionsNoTable.contains("sftimezone")) { - thisConnectorOptionsNoTable -= "sftimezone" + + // Set timezone option + if (sfTimezone != null) { + if (thisConnectorOptionsNoTable.contains("sftimezone")) { + thisConnectorOptionsNoTable -= "sftimezone" + } + thisConnectorOptionsNoTable += ("sftimezone" -> sfTimezone) + } else { + if (thisConnectorOptionsNoTable.contains("sftimezone")) 
{ + thisConnectorOptionsNoTable -= "sftimezone" + } } - } - writeAndCheckForOneTable(thisSparkSession, thisConnectorOptionsNoTable, - test_table_timestamp, "", test_table_write, Some(createTableSql), true) - } + writeAndCheckForOneTable(thisSparkSession, thisConnectorOptionsNoTable, + test_table_timestamp, "", test_table_write, Some(createTableSql), true) + } - // restore options for further test - thisConnectorOptionsNoTable -= "sftimezone" - if (oldValue.isDefined) { - thisConnectorOptionsNoTable += ("sftimezone" -> oldValue.get) + // restore options for further test + thisConnectorOptionsNoTable -= "sftimezone" + if (oldValue.isDefined) { + thisConnectorOptionsNoTable += ("sftimezone" -> oldValue.get) + } + TimeZone.setDefault(oldTimezone) } - TimeZone.setDefault(oldTimezone) } test("testLargeResult") { diff --git a/src/main/scala/net/snowflake/spark/snowflake/CSVConverter.scala b/src/main/scala/net/snowflake/spark/snowflake/CSVConverter.scala index 26f39647..4086ffab 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/CSVConverter.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/CSVConverter.scala @@ -17,7 +17,9 @@ package net.snowflake.spark.snowflake +import net.snowflake.spark.snowflake.Parameters.MergedParameters import org.apache.spark.sql.types.StructType + import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -28,9 +30,10 @@ object CSVConverter { private[snowflake] def convert[T: ClassTag]( partition: Iterator[String], - resultSchema: StructType + resultSchema: StructType, + parameters: MergedParameters ): Iterator[T] = { - val converter = Conversions.createRowConverter[T](resultSchema) + val converter = Conversions.createRowConverter[T](resultSchema, parameters) partition.map(s => { val fields = ArrayBuffer.empty[String] var buff = new StringBuilder diff --git a/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala b/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala index d67fd2f3..ac70f68d 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/Conversions.scala @@ -22,8 +22,8 @@ import java.text._ import java.time.ZonedDateTime import java.time.format.DateTimeFormatter import java.util.{Date, TimeZone} - import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode +import net.snowflake.spark.snowflake.Parameters.MergedParameters import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} @@ -43,6 +43,7 @@ private[snowflake] object Conversions { // Note - for JDK 1.6, we use Z ipo XX for SimpleDateFormat // Because simpleDateFormat only support milliseconds, // we need to refactor this and handle nano seconds field separately + private val PATTERN_TZLTZ = if (System.getProperty("java.version").startsWith("1.6.")) { "Z yyyy-MM-dd HH:mm:ss." @@ -145,9 +146,10 @@ private[snowflake] object Conversions { * the given schema to Row instances */ def createRowConverter[T: ClassTag]( - schema: StructType + schema: StructType, + parameters: MergedParameters ): Array[String] => T = { - convertRow[T](schema, _: Array[String]) + convertRow[T](schema, _: Array[String], parameters) } /** @@ -155,7 +157,9 @@ private[snowflake] object Conversions { * The schema will be used for type mappings. 
*/ private def convertRow[T: ClassTag](schema: StructType, - fields: Array[String]): T = { + fields: Array[String], + parameters: MergedParameters + ): T = { val isIR: Boolean = isInternalRow[T]() @@ -179,7 +183,8 @@ private[snowflake] object Conversions { case ShortType => data.toShort case StringType => if (isIR) UTF8String.fromString(data) else data - case TimestampType => parseTimestamp(data, isIR) + case TimestampType => parseTimestamp(data, isIR, + parameters.supportMicroSecondDuringUnload) case _ => data } } @@ -192,14 +197,25 @@ private[snowflake] object Conversions { } } + /** * Parse a string exported from a Snowflake TIMESTAMP column */ - private def parseTimestamp(s: String, isInternalRow: Boolean): Any = { + private def parseTimestamp(s: String, + isInternalRow: Boolean, + supportMicroSeconds: Boolean = true): Any = { + + // Need to handle the nano seconds filed separately // valueOf only works with yyyy-[m]m-[d]d hh:mm:ss[.f...] // so we need to do a little parsing - val timestampRegex = """\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3,9}""".r + // When supportMicroSeconds is disabled, we should only use milliseconds field + val timestampRegex = if (supportMicroSeconds) { + """\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3,9}""".r + } + else { + """\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3}""".r + } val parsedTS = timestampRegex.findFirstMatchIn(s) match { case Some(ts) => ts.toString() @@ -207,11 +223,11 @@ private[snowflake] object Conversions { } val ts = java.sql.Timestamp.valueOf(parsedTS) - val nanoFraction = ts.getNanos + val res = new Timestamp(snowflakeTimestampFormat.parse(s).getTime) - res.setNanos(nanoFraction) + res.setNanos(ts.getNanos) if (isInternalRow) DateTimeUtils.fromJavaTimestamp(res) else res diff --git a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala index 4e081100..29886d60 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala @@ -229,6 +229,11 @@ object Parameters { "support_share_connection" ) + // Internal option support micro second precision timestamp during copy unload + val PARAM_INTERNAL_SUPPORT_MICRO_SECOND_DURING_UNLOAD: String = knownParam( + "internal_support_micro_second_during_unload" + ) + val DEFAULT_S3_MAX_FILE_SIZE: String = (10 * 1000 * 1000).toString val MIN_S3_MAX_FILE_SIZE = 1000000 @@ -706,6 +711,9 @@ object Parameters { def supportShareConnection: Boolean = { isTrue(parameters.getOrElse(PARAM_SUPPORT_SHARE_CONNECTION, "true")) } + def supportMicroSecondDuringUnload: Boolean = { + isTrue(parameters.getOrElse(PARAM_INTERNAL_SUPPORT_MICRO_SECOND_DURING_UNLOAD, "true")) + } def stagingTableNameRemoveQuotesOnly: Boolean = { isTrue(parameters.getOrElse(PARAM_INTERNAL_STAGING_TABLE_NAME_REMOVE_QUOTES_ONLY, "false")) } diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeRelation.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeRelation.scala index 833ed528..7a3cf57b 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeRelation.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeRelation.scala @@ -154,7 +154,7 @@ private[snowflake] case class SnowflakeRelation( private def getRDD[T: ClassTag](statement: SnowflakeSQLStatement, resultSchema: StructType): RDD[T] = { if (params.useCopyUnload) { - getSnowflakeRDD(statement, resultSchema) + getSnowflakeRDD(statement, resultSchema, params) } else { getSnowflakeResultSetRDD(statement, 
resultSchema) } @@ -162,7 +162,8 @@ private[snowflake] case class SnowflakeRelation( // Get an RDD with COPY Unload private def getSnowflakeRDD[T: ClassTag](statement: SnowflakeSQLStatement, - resultSchema: StructType): RDD[T] = { + resultSchema: StructType, + params: MergedParameters): RDD[T] = { val format: SupportedFormat = if (Utils.containVariant(resultSchema)) SupportedFormat.JSON else SupportedFormat.CSV @@ -171,7 +172,9 @@ private[snowflake] case class SnowflakeRelation( format match { case SupportedFormat.CSV => - rdd.mapPartitions(CSVConverter.convert[T](_, resultSchema)) + // Need to explicitly define params in the functions scope + // otherwise would encounter task not serializable error + rdd.mapPartitions(CSVConverter.convert[T](_, resultSchema, params)) case SupportedFormat.JSON => rdd.mapPartitions(JsonConverter.convert[T](_, resultSchema)) } diff --git a/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala b/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala index 04fd1c3b..b1ea5b6b 100644 --- a/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala +++ b/src/test/scala/net/snowflake/spark/snowflake/ConversionsSuite.scala @@ -18,12 +18,14 @@ package net.snowflake.spark.snowflake import java.sql.{Date, Timestamp} - import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.ObjectMapper import org.scalatest.FunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.types._ +import scala.:+ +import scala.collection.immutable.HashMap + /** * Unit test for data type conversions */ @@ -31,8 +33,19 @@ class ConversionsSuite extends FunSuite { val mapper = new ObjectMapper() + val commonOptions = Map("dbtable" -> "test_table", + "sfurl" -> "account.snowflakecomputing.com:443", + "sfuser" -> "username", + "sfpassword" -> "password") + + val notSupportMicro: Parameters.MergedParameters = Parameters.mergeParameters(commonOptions ++ + Map(Parameters.PARAM_INTERNAL_SUPPORT_MICRO_SECOND_DURING_UNLOAD-> "false")) + + val supportMicro: Parameters.MergedParameters = Parameters.mergeParameters(commonOptions ++ + Map(Parameters.PARAM_INTERNAL_SUPPORT_MICRO_SECOND_DURING_UNLOAD-> "true")) + test("Data should be correctly converted") { - val convertRow = Conversions.createRowConverter[Row](TestUtils.testSchema) + val convertRow = Conversions.createRowConverter[Row](TestUtils.testSchema, notSupportMicro) val doubleMin = Double.MinValue.toString val longMax = Long.MaxValue.toString // scalastyle:off @@ -110,7 +123,7 @@ class ConversionsSuite extends FunSuite { } test("Row conversion handles null values") { - val convertRow = Conversions.createRowConverter[Row](TestUtils.testSchema) + val convertRow = Conversions.createRowConverter[Row](TestUtils.testSchema, notSupportMicro) val emptyRow = List.fill(TestUtils.testSchema.length)(null).toArray[String] val nullsRow = List.fill(TestUtils.testSchema.length)(null).toArray[String] assert(convertRow(emptyRow) === Row(nullsRow: _*)) @@ -118,7 +131,7 @@ class ConversionsSuite extends FunSuite { test("Dates are correctly converted") { val convertRow = Conversions.createRowConverter[Row]( - StructType(Seq(StructField("a", DateType))) + StructType(Seq(StructField("a", DateType))), notSupportMicro ) assert( convertRow(Array("2015-07-09")) === Row(TestUtils.toDate(2015, 6, 9)) @@ -195,7 +208,7 @@ class ConversionsSuite extends FunSuite { } test("Data with micro-seconds and nano-seconds precision should be correctly converted"){ - val convertRow = 
Conversions.createRowConverter[Row](TestUtils.testSchema) + val convertRow = Conversions.createRowConverter[Row](TestUtils.testSchema, supportMicro) val doubleMin = Double.MinValue.toString val longMax = Long.MaxValue.toString // scalastyle:off From afed2e86f366b86d536b2a05639ba1299d222dbd Mon Sep 17 00:00:00 2001 From: Mingli Rui <63472932+sfc-gh-mrui@users.noreply.github.com> Date: Wed, 15 Mar 2023 13:08:52 -0700 Subject: [PATCH 07/15] SNOW-760569 Bump spark connector and depenencies versions (#493) Spark connector: 2.11.2 JDBC: 2.13.28 --- .github/workflows/ClusterTest.yml | 4 ++-- ClusterTest/build.sbt | 4 ++-- build.sbt | 6 +++--- src/main/scala/net/snowflake/spark/snowflake/Utils.scala | 4 ++-- .../scala/net/snowflake/spark/snowflake/UtilsSuite.scala | 3 +++ 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ClusterTest.yml b/.github/workflows/ClusterTest.yml index cdee8a04..9319289e 100644 --- a/.github/workflows/ClusterTest.yml +++ b/.github/workflows/ClusterTest.yml @@ -22,8 +22,8 @@ jobs: DOCKER_IMAGE_TAG: 'snowflakedb/spark-base:3.3.0' TEST_SCALA_VERSION: '2.12' TEST_COMPILE_SCALA_VERSION: '2.12.11' - TEST_SPARK_CONNECTOR_VERSION: '2.11.1' - TEST_JDBC_VERSION: '3.13.24' + TEST_SPARK_CONNECTOR_VERSION: '2.11.2' + TEST_JDBC_VERSION: '3.13.28' steps: - uses: actions/checkout@v2 diff --git a/ClusterTest/build.sbt b/ClusterTest/build.sbt index b7552abc..ff6d3b44 100644 --- a/ClusterTest/build.sbt +++ b/ClusterTest/build.sbt @@ -14,7 +14,7 @@ * limitations under the License. */ -val sparkConnectorVersion = "2.11.1" +val sparkConnectorVersion = "2.11.2" val scalaVersionMajor = "2.12" val sparkVersionMajor = "3.3" val sparkVersion = s"${sparkVersionMajor}.0" @@ -37,7 +37,7 @@ lazy val root = project.withId("spark-snowflake").in(file(".")) "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", libraryDependencies ++= Seq( "net.snowflake" % "snowflake-ingest-sdk" % "0.10.8", - "net.snowflake" % "snowflake-jdbc" % "3.13.24", + "net.snowflake" % "snowflake-jdbc" % "3.13.28", // "net.snowflake" %% "spark-snowflake" % "2.8.0-spark_3.0", // "com.google.guava" % "guava" % "14.0.1" % Test, // "org.scalatest" %% "scalatest" % "3.0.5" % Test, diff --git a/build.sbt b/build.sbt index 6dee4e2e..8ef9b115 100644 --- a/build.sbt +++ b/build.sbt @@ -26,7 +26,7 @@ val testSparkVersion = sys.props.get("spark.testVersion").getOrElse("3.3.0") * Tests/jenkins/BumpUpSparkConnectorVersion/run.sh * in snowflake repository. 
*/ -val sparkConnectorVersion = "2.11.1" +val sparkConnectorVersion = "2.11.2" lazy val ItTest = config("it") extend Test @@ -60,12 +60,12 @@ lazy val root = project.withId("spark-snowflake").in(file(".")) "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", libraryDependencies ++= Seq( "net.snowflake" % "snowflake-ingest-sdk" % "0.10.8", - "net.snowflake" % "snowflake-jdbc" % "3.13.24", + "net.snowflake" % "snowflake-jdbc" % "3.13.28", "org.scalatest" %% "scalatest" % "3.1.1" % Test, "org.mockito" % "mockito-core" % "1.10.19" % Test, "org.apache.commons" % "commons-lang3" % "3.5" % "provided", // For test to read/write from postgresql - "org.postgresql" % "postgresql" % "42.4.1" % Test, + "org.postgresql" % "postgresql" % "42.5.4" % Test, // Below is for Spark Streaming from Kafka test only // "org.apache.spark" %% "spark-sql-kafka-0-10" % "2.4.0", "org.apache.spark" %% "spark-core" % testSparkVersion % "provided, test", diff --git a/src/main/scala/net/snowflake/spark/snowflake/Utils.scala b/src/main/scala/net/snowflake/spark/snowflake/Utils.scala index 95f58abf..0a222352 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/Utils.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/Utils.scala @@ -55,12 +55,12 @@ object Utils { */ val SNOWFLAKE_SOURCE_SHORT_NAME = "snowflake" - val VERSION = "2.11.1" + val VERSION = "2.11.2" /** * The certified JDBC version to work with this spark connector version. */ - val CERTIFIED_JDBC_VERSION = "3.13.24" + val CERTIFIED_JDBC_VERSION = "3.13.28" /** * Important: diff --git a/src/test/scala/net/snowflake/spark/snowflake/UtilsSuite.scala b/src/test/scala/net/snowflake/spark/snowflake/UtilsSuite.scala index 433cd97d..86f8e6d2 100644 --- a/src/test/scala/net/snowflake/spark/snowflake/UtilsSuite.scala +++ b/src/test/scala/net/snowflake/spark/snowflake/UtilsSuite.scala @@ -180,4 +180,7 @@ class UtilsSuite extends FunSuite with Matchers { } } + test("verify JDBC version is updated for release") { + assert(Utils.CERTIFIED_JDBC_VERSION.equals(Utils.jdbcVersion)) + } } From ded0e0d31523407605d9d7dc534d2c9d73c0a61b Mon Sep 17 00:00:00 2001 From: Mingli Rui <63472932+sfc-gh-mrui@users.noreply.github.com> Date: Mon, 20 Mar 2023 08:58:48 -0700 Subject: [PATCH 08/15] SNOW-760569 Upgrade to use JDBC 3.13.29 (#497) --- .github/workflows/ClusterTest.yml | 2 +- ClusterTest/build.sbt | 2 +- build.sbt | 2 +- src/main/scala/net/snowflake/spark/snowflake/Utils.scala | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ClusterTest.yml b/.github/workflows/ClusterTest.yml index 9319289e..6a25eb50 100644 --- a/.github/workflows/ClusterTest.yml +++ b/.github/workflows/ClusterTest.yml @@ -23,7 +23,7 @@ jobs: TEST_SCALA_VERSION: '2.12' TEST_COMPILE_SCALA_VERSION: '2.12.11' TEST_SPARK_CONNECTOR_VERSION: '2.11.2' - TEST_JDBC_VERSION: '3.13.28' + TEST_JDBC_VERSION: '3.13.29' steps: - uses: actions/checkout@v2 diff --git a/ClusterTest/build.sbt b/ClusterTest/build.sbt index ff6d3b44..32ab61da 100644 --- a/ClusterTest/build.sbt +++ b/ClusterTest/build.sbt @@ -37,7 +37,7 @@ lazy val root = project.withId("spark-snowflake").in(file(".")) "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", libraryDependencies ++= Seq( "net.snowflake" % "snowflake-ingest-sdk" % "0.10.8", - "net.snowflake" % "snowflake-jdbc" % "3.13.28", + "net.snowflake" % "snowflake-jdbc" % "3.13.29", // "net.snowflake" %% "spark-snowflake" % "2.8.0-spark_3.0", // "com.google.guava" % "guava" % "14.0.1" % Test, 
// "org.scalatest" %% "scalatest" % "3.0.5" % Test, diff --git a/build.sbt b/build.sbt index 8ef9b115..f0c73607 100644 --- a/build.sbt +++ b/build.sbt @@ -60,7 +60,7 @@ lazy val root = project.withId("spark-snowflake").in(file(".")) "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", libraryDependencies ++= Seq( "net.snowflake" % "snowflake-ingest-sdk" % "0.10.8", - "net.snowflake" % "snowflake-jdbc" % "3.13.28", + "net.snowflake" % "snowflake-jdbc" % "3.13.29", "org.scalatest" %% "scalatest" % "3.1.1" % Test, "org.mockito" % "mockito-core" % "1.10.19" % Test, "org.apache.commons" % "commons-lang3" % "3.5" % "provided", diff --git a/src/main/scala/net/snowflake/spark/snowflake/Utils.scala b/src/main/scala/net/snowflake/spark/snowflake/Utils.scala index 0a222352..7380269e 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/Utils.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/Utils.scala @@ -60,7 +60,7 @@ object Utils { /** * The certified JDBC version to work with this spark connector version. */ - val CERTIFIED_JDBC_VERSION = "3.13.28" + val CERTIFIED_JDBC_VERSION = "3.13.29" /** * Important: From f6b8e6c335482e6ec30a215753811f755e4a322f Mon Sep 17 00:00:00 2001 From: Mingli Rui <63472932+sfc-gh-mrui@users.noreply.github.com> Date: Wed, 19 Apr 2023 11:38:22 -0700 Subject: [PATCH 09/15] SNOW-763124 Support uploading files with down-scoped token for Gcs (#502) * SNOW-763124 Support uploading files with down-scoped token for Snowflake Gcs accounts * Update JDBC and SC version --- .github/workflows/ClusterTest.yml | 4 +- ClusterTest/build.sbt | 4 +- build.sbt | 4 +- .../spark/snowflake/CloudStorageSuite.scala | 281 ++++++++++++++++++ .../net/snowflake/spark/snowflake/Utils.scala | 4 +- .../snowflake/io/CloudStorageOperations.scala | 54 +++- 6 files changed, 327 insertions(+), 24 deletions(-) create mode 100644 src/it/scala/net/snowflake/spark/snowflake/CloudStorageSuite.scala diff --git a/.github/workflows/ClusterTest.yml b/.github/workflows/ClusterTest.yml index 6a25eb50..7f39d470 100644 --- a/.github/workflows/ClusterTest.yml +++ b/.github/workflows/ClusterTest.yml @@ -22,8 +22,8 @@ jobs: DOCKER_IMAGE_TAG: 'snowflakedb/spark-base:3.3.0' TEST_SCALA_VERSION: '2.12' TEST_COMPILE_SCALA_VERSION: '2.12.11' - TEST_SPARK_CONNECTOR_VERSION: '2.11.2' - TEST_JDBC_VERSION: '3.13.29' + TEST_SPARK_CONNECTOR_VERSION: '2.11.3' + TEST_JDBC_VERSION: '3.13.30' steps: - uses: actions/checkout@v2 diff --git a/ClusterTest/build.sbt b/ClusterTest/build.sbt index 32ab61da..48505a7b 100644 --- a/ClusterTest/build.sbt +++ b/ClusterTest/build.sbt @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -val sparkConnectorVersion = "2.11.2" +val sparkConnectorVersion = "2.11.3" val scalaVersionMajor = "2.12" val sparkVersionMajor = "3.3" val sparkVersion = s"${sparkVersionMajor}.0" @@ -37,7 +37,7 @@ lazy val root = project.withId("spark-snowflake").in(file(".")) "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", libraryDependencies ++= Seq( "net.snowflake" % "snowflake-ingest-sdk" % "0.10.8", - "net.snowflake" % "snowflake-jdbc" % "3.13.29", + "net.snowflake" % "snowflake-jdbc" % "3.13.30", // "net.snowflake" %% "spark-snowflake" % "2.8.0-spark_3.0", // "com.google.guava" % "guava" % "14.0.1" % Test, // "org.scalatest" %% "scalatest" % "3.0.5" % Test, diff --git a/build.sbt b/build.sbt index f0c73607..d1ee789f 100644 --- a/build.sbt +++ b/build.sbt @@ -26,7 +26,7 @@ val testSparkVersion = sys.props.get("spark.testVersion").getOrElse("3.3.0") * Tests/jenkins/BumpUpSparkConnectorVersion/run.sh * in snowflake repository. */ -val sparkConnectorVersion = "2.11.2" +val sparkConnectorVersion = "2.11.3" lazy val ItTest = config("it") extend Test @@ -60,7 +60,7 @@ lazy val root = project.withId("spark-snowflake").in(file(".")) "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", libraryDependencies ++= Seq( "net.snowflake" % "snowflake-ingest-sdk" % "0.10.8", - "net.snowflake" % "snowflake-jdbc" % "3.13.29", + "net.snowflake" % "snowflake-jdbc" % "3.13.30", "org.scalatest" %% "scalatest" % "3.1.1" % Test, "org.mockito" % "mockito-core" % "1.10.19" % Test, "org.apache.commons" % "commons-lang3" % "3.5" % "provided", diff --git a/src/it/scala/net/snowflake/spark/snowflake/CloudStorageSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/CloudStorageSuite.scala new file mode 100644 index 00000000..3023014d --- /dev/null +++ b/src/it/scala/net/snowflake/spark/snowflake/CloudStorageSuite.scala @@ -0,0 +1,281 @@ +/* + * Copyright 2015-2019 Snowflake Computing + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package net.snowflake.spark.snowflake + +import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_NAME +import net.snowflake.spark.snowflake.test.TestHook +import org.apache.spark.sql.{DataFrame, SaveMode} + +// scalastyle:off println +class CloudStorageSuite extends IntegrationSuiteBase { + + import testImplicits._ + + private val test_table1: String = s"test_temp_table_$randomSuffix" + private val test_table_large_result: String = + s"test_table_large_result_$randomSuffix" + private val test_table_write: String = s"test_table_write_$randomSuffix" + private lazy val localDF: DataFrame = Seq((1000, "str11"), (2000, "str22")).toDF("c1", "c2") + + private val largeStringValue = + s"""spark_connector_test_large_result_1234567890 + |spark_connector_test_large_result_1234567890 + |spark_connector_test_large_result_1234567890 + |spark_connector_test_large_result_1234567890 + |spark_connector_test_large_result_1234567890 + |spark_connector_test_large_result_1234567890 + |spark_connector_test_large_result_1234567890 + |spark_connector_test_large_result_1234567890 + |""".stripMargin.filter(_ >= ' ') + private val LARGE_TABLE_ROW_COUNT = 900000 + lazy val setupLargeResultTable = { + jdbcUpdate( + s"""create or replace table $test_table_large_result ( + | int_c int, c_string string(1024) )""".stripMargin) + + jdbcUpdate( + s"""insert into $test_table_large_result select + | row_number() over (order by seq4()) - 1, '$largeStringValue' + | from table(generator(rowcount => $LARGE_TABLE_ROW_COUNT))""".stripMargin) + true + } + + override def afterAll(): Unit = { + try { + jdbcUpdate(s"drop table if exists $test_table1") + jdbcUpdate(s"drop table if exists $test_table_large_result") + jdbcUpdate(s"drop table if exists $test_table_write") + } finally { + TestHook.disableTestHook() + super.afterAll() + } + } + + override def beforeAll(): Unit = { + super.beforeAll() + + jdbcUpdate(s"create or replace table $test_table1(c1 int, c2 string)") + jdbcUpdate(s"insert into $test_table1 values (100, 'str1'),(200, 'str2')") + } + + private def getHashAgg(tableName: String): java.math.BigDecimal = + sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("query", s"select HASH_AGG(*) from $tableName") + .load() + .collect()(0).getDecimal(0) + + private def getRowCount(tableName: String): Long = + sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", tableName) + .load() + .count() + + test("write a small DataFrame to GCS with down-scoped-token") { + // Only run this test on GCS + if ("gcp".equals(System.getenv("SNOWFLAKE_TEST_ACCOUNT"))) { + val df = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_table1) + .load() + + // write a small DataFrame to a snowflake table + df.write + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_table_write) + // GCS_USE_DOWNSCOPED_CREDENTIAL is not a public parameter, user can't set it. + // The default value of GCS_USE_DOWNSCOPED_CREDENTIAL will be true from Dec 2023 + // The option set can be removed after Dec 2023 + .option("GCS_USE_DOWNSCOPED_CREDENTIAL", "true") + .mode(SaveMode.Overwrite) + .save() + + // Check the source table and target table has same agg_hash. 
+ assert(getHashAgg(test_table1) == getHashAgg(test_table_write)) + } else { + println("skip test for non-GCS platform: " + + "write a small DataFrame to GCS with down-scoped-token") + } + } + + test("write a big DataFrame to GCS with down-scoped-token") { + // Only run this test on GCS + if ("gcp".equals(System.getenv("SNOWFLAKE_TEST_ACCOUNT"))) { + setupLargeResultTable + val df = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("partition_size_in_mb", 1) // generate multiple partitions + .option("dbtable", test_table_large_result) + .load() + + // write a small DataFrame to a snowflake table + df.write + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_table_write) + // GCS_USE_DOWNSCOPED_CREDENTIAL is not a public parameter, user can't set it. + // The default value of GCS_USE_DOWNSCOPED_CREDENTIAL will be true from Dec 2023 + // The option set can be removed after Dec 2023 + .option("GCS_USE_DOWNSCOPED_CREDENTIAL", "true") + .mode(SaveMode.Overwrite) + .save() + + // Check the source table and target table has same agg_hash. + assert(getHashAgg(test_table_large_result) == getHashAgg(test_table_write)) + } else { + println("skip test for non-GCS platform: " + + "write a big DataFrame to GCS with down-scoped-token") + } + } + + test("write a empty DataFrame to GCS with down-scoped-token") { + // Only run this test on GCS + if ("gcp".equals(System.getenv("SNOWFLAKE_TEST_ACCOUNT"))) { + val df = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("query", s"select * from $test_table1 where 1 = 2") + .load() + + // write a small DataFrame to a snowflake table + df.write + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_table_write) + // GCS_USE_DOWNSCOPED_CREDENTIAL is not a public parameter, user can't set it. + // The default value of GCS_USE_DOWNSCOPED_CREDENTIAL will be true from Dec 2023 + // The option set can be removed after Dec 2023 + .option("GCS_USE_DOWNSCOPED_CREDENTIAL", "true") + .mode(SaveMode.Overwrite) + .save() + + // Check the source table and target table has same agg_hash. + assert(getRowCount(test_table_write) == 0) + } else { + println("skip test for non-GCS platform: " + + "write a empty DataFrame to GCS with down-scoped-token") + } + } + + // GCS_USE_DOWNSCOPED_CREDENTIAL is not a public parameter, user can't set it. + // Only the snowflake test account can set it for testing purpose. + // From Dec 2023, GCS_USE_DOWNSCOPED_CREDENTIAL may be configured as true for all deployments + // and this test case can be removed at that time. + test("write a small DataFrame to GCS with presigned-url (Can be removed by Dec 2023)") { + // Only run this test on GCS + if ("gcp".equals(System.getenv("SNOWFLAKE_TEST_ACCOUNT"))) { + val df = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_table1) + .load() + + // write a small DataFrame to a snowflake table + df.write + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_table_write) + // GCS_USE_DOWNSCOPED_CREDENTIAL is not a public parameter, user can't set it. + // The default value of GCS_USE_DOWNSCOPED_CREDENTIAL will be true from Dec 2023 + // The option set can be removed after Dec 2023 + .option("GCS_USE_DOWNSCOPED_CREDENTIAL", "false") + .mode(SaveMode.Overwrite) + .save() + + // Check the source table and target table has same agg_hash. 
+ assert(getHashAgg(test_table1) == getHashAgg(test_table_write)) + } else { + println("skip test for non-GCS platform: " + + "write a small DataFrame to GCS with presigned-url (Can be removed by Dec 2023)") + } + } + + // GCS_USE_DOWNSCOPED_CREDENTIAL is not a public parameter, user can't set it. + // Only the snowflake test account can set it for testing purpose. + // From Dec 2023, GCS_USE_DOWNSCOPED_CREDENTIAL may be configured as true for all deployments + // and this test case can be removed at that time. + test("write a big DataFrame to GCS with presigned-url (Can be removed by Dec 2023)") { + // Only run this test on GCS + if ("gcp".equals(System.getenv("SNOWFLAKE_TEST_ACCOUNT"))) { + setupLargeResultTable + val df = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("partition_size_in_mb", 1) // generate multiple partitions + .option("dbtable", test_table_large_result) + .load() + + // write a small DataFrame to a snowflake table + df.write + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_table_write) + // GCS_USE_DOWNSCOPED_CREDENTIAL is not a public parameter, user can't set it. + // The default value of GCS_USE_DOWNSCOPED_CREDENTIAL will be true from Dec 2023 + // The option set can be removed after Dec 2023 + .option("GCS_USE_DOWNSCOPED_CREDENTIAL", "false") + .mode(SaveMode.Overwrite) + .save() + + // Check the source table and target table has same agg_hash. + assert(getHashAgg(test_table_large_result) == getHashAgg(test_table_write)) + } else { + println("skip test for non-GCS platform: " + + "write a big DataFrame to GCS with presigned-url (Can be removed by Dec 2023)") + } + } + + // GCS_USE_DOWNSCOPED_CREDENTIAL is not a public parameter, user can't set it. + // Only the snowflake test account can set it for testing purpose. + // From Dec 2023, GCS_USE_DOWNSCOPED_CREDENTIAL may be configured as true for all deployments + // and this test case can be removed at that time. + test("write a empty DataFrame to GCS with presigned-url (Can be removed by Dec 2023)") { + // Only run this test on GCS + if ("gcp".equals(System.getenv("SNOWFLAKE_TEST_ACCOUNT"))) { + val df = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("query", s"select * from $test_table1 where 1 = 2") + .load() + + // write a small DataFrame to a snowflake table + df.write + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_table_write) + // GCS_USE_DOWNSCOPED_CREDENTIAL is not a public parameter, user can't set it. + // The default value of GCS_USE_DOWNSCOPED_CREDENTIAL will be true from Dec 2023 + // The option set can be removed after Dec 2023 + .option("GCS_USE_DOWNSCOPED_CREDENTIAL", "false") + .mode(SaveMode.Overwrite) + .save() + + // Check the source table and target table has same agg_hash. 
+ assert(getRowCount(test_table_write) == 0) + } else { + println("skip test for non-GCS platform: " + + "write a empty DataFrame to GCS with presigned-url (Can be removed by Dec 2023)") + } + } +} +// scalastyle:on println diff --git a/src/main/scala/net/snowflake/spark/snowflake/Utils.scala b/src/main/scala/net/snowflake/spark/snowflake/Utils.scala index 7380269e..b867f690 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/Utils.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/Utils.scala @@ -55,12 +55,12 @@ object Utils { */ val SNOWFLAKE_SOURCE_SHORT_NAME = "snowflake" - val VERSION = "2.11.2" + val VERSION = "2.11.3" /** * The certified JDBC version to work with this spark connector version. */ - val CERTIFIED_JDBC_VERSION = "3.13.29" + val CERTIFIED_JDBC_VERSION = "3.13.30" /** * Important: diff --git a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala index 8005ab8f..5ad5bcce 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/io/CloudStorageOperations.scala @@ -738,6 +738,7 @@ sealed trait CloudStorage { .setSnowflakeFileTransferMetadata(fileTransferMetadata.get) .setUploadStream(inStream) .setRequireCompress(compress) + .setDestFileName(fileName) .setOcspMode(OCSPMode.FAIL_OPEN) .setProxyProperties(proxyProperties) .build()) @@ -1745,19 +1746,27 @@ case class InternalGcsStorage(param: MergedParameters, : List[SnowflakeFileTransferMetadata] = { CloudStorageOperations.log.info( s"""${SnowflakeResultSetRDD.MASTER_LOG_PREFIX}: - | Begin to retrieve pre-signed URL for + | Begin to retrieve pre-signed URL or down-scoped token for | ${data.getNumPartitions} files by calling - | PUT command for each file. + | PUT command. |""".stripMargin.filter(_ >= ' ')) - var result = new ListBuffer[SnowflakeFileTransferMetadata]() + val result = new ListBuffer[SnowflakeFileTransferMetadata]() val startTime = System.currentTimeMillis() val printStep = 1000 - // Loop to execute one PUT command for one pre-signed URL. - // This is because GCS doesn't support to generate pre-signed for - // prefix(path). If GCS supports it, this part can be enhanced. - for (index <- 0 until fileCount) { + // If pre-signed URL is used to upload file, need to execute a dummy PUT command per file + // because file needs a pre-signed URL for upload. If down-scoped token is used to upload file, + // the token can be used to upload multiple files, so only need to execute one dummy + // PUT command to get the down-scoped token. NOTE: + // 1. The pre-signed URL or down-scoped token is encapsulated in SnowflakeFileTransferMetadata + // Spark connector doesn't need to touch it directly. + // 2. If down-scoped token is used, SnowflakeFileTransferMetadata.isForOneFile will be false. + // 3. Spark connector can know whether pre-signed URL is used after getting the first + // SnowflakeFileTransferMetadata. 
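For illustration only, the sketch below shows roughly how the metadata fetched by the dummy-PUT loop that follows is later consumed on the executor side through the JDBC driver's file-transfer agent, whether it carries a per-file pre-signed URL or a reusable down-scoped token. This is a hedged sketch, not code from this patch: the method name, the input stream, and the omitted proxy settings are placeholders, and the builder/agent calls follow the JDBC driver's documented upload-without-connection pattern rather than this repository's exact code.

import java.io.InputStream
import net.snowflake.client.core.OCSPMode
import net.snowflake.client.jdbc.{SnowflakeFileTransferAgent, SnowflakeFileTransferConfig, SnowflakeFileTransferMetadata}

// Illustrative sketch: upload one partition's file using previously fetched metadata.
def uploadOnePartition(metadata: SnowflakeFileTransferMetadata,
                       fileName: String,
                       inStream: InputStream,
                       compress: Boolean): Unit = {
  SnowflakeFileTransferAgent.uploadWithoutConnection(
    SnowflakeFileTransferConfig.Builder.newInstance()
      .setSnowflakeFileTransferMetadata(metadata)
      .setUploadStream(inStream)
      .setRequireCompress(compress)
      // Explicit destination name, matching the setDestFileName call added by this patch;
      // presumably needed when one down-scoped-token metadata is reused for many files.
      .setDestFileName(fileName)
      .setOcspMode(OCSPMode.FAIL_OPEN)
      .build())
}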
+ var index = 0 + var useDownScopedToken = false + while (index < fileCount && !useDownScopedToken) { val fileName = getFileName(index, format, compress) val dummyDir = s"/dummy_put_${index}_of_$fileCount" val putCommand = s"put file://$dummyDir/$fileName @$stageName/$dir" @@ -1769,18 +1778,29 @@ case class InternalGcsStorage(param: MergedParameters, new SFStatement(connection.getSfSession) ).getFileTransferMetadatas .asScala - .map(oneMetadata => result += oneMetadata) + .foreach(oneMetadata => { + if (!oneMetadata.isForOneFile) { + CloudStorageOperations.log.info( + s"""${SnowflakeResultSetRDD.MASTER_LOG_PREFIX}: + | Upload file to GCP with down-scoped token instead of pre-signed URL. + |""".stripMargin.filter(_ >= ' ')) + useDownScopedToken = true + } + result.append(oneMetadata) + }) // Output time for retrieving every 1000 pre-signed URLs // to indicate the progress for big data. if ((index % printStep) == (printStep - 1)) { - StorageUtils.logPresignedUrlGenerateProgress(data.getNumPartitions, index + 1, startTime) + StorageUtils.logPresignedUrlGenerateProgress( + data.getNumPartitions, index + 1, startTime, useDownScopedToken) } + index += 1 } // Output the total time for retrieving pre-signed URLs - StorageUtils.logPresignedUrlGenerateProgress(data.getNumPartitions, - data.getNumPartitions, startTime) + StorageUtils.logPresignedUrlGenerateProgress( + data.getNumPartitions, index, startTime, useDownScopedToken) result.toList } @@ -1839,14 +1859,15 @@ case class InternalGcsStorage(param: MergedParameters, /////////////////////////////////////////////////////////////////////// } + val result = fileUploadResults.collect().toList + val endTime = System.currentTimeMillis() CloudStorageOperations.log.info( s"""${SnowflakeResultSetRDD.MASTER_LOG_PREFIX}: | Finish uploading data for ${data.getNumPartitions} partitions in | ${Utils.getTimeString(endTime - startTime)}. |""".stripMargin.filter(_ >= ' ')) - - fileUploadResults.collect().toList + result } // GCS doesn't support streaming yet @@ -1950,12 +1971,13 @@ object StorageUtils { private[io] def logPresignedUrlGenerateProgress(total: Int, index: Int, - startTime: Long): Unit = { + startTime: Long, + useDownScopedToken: Boolean): Unit = { val endTime = System.currentTimeMillis() CloudStorageOperations.log.info( s"""${SnowflakeResultSetRDD.MASTER_LOG_PREFIX}: - | Time to retrieve pre-signed URL for - | $index/$total files is + | Time to retrieve ${if (useDownScopedToken) "down-scoped token" else "pre-signed URL"} + | for $index/$total files is | ${Utils.getTimeString(endTime - startTime)}. 
|""".stripMargin.filter(_ >= ' ')) } From 86a0fa9509e621c04171e1072decec7ebcc7f011 Mon Sep 17 00:00:00 2001 From: Mingli Rui <63472932+sfc-gh-mrui@users.noreply.github.com> Date: Wed, 19 Apr 2023 11:38:53 -0700 Subject: [PATCH 10/15] SNOW-796952 Add option to disable pre- and post-action validation for session sharing (#503) --- .../snowflake/ShareConnectionSuite.scala | 76 +++++++++++++++++++ .../spark/snowflake/Parameters.scala | 11 ++- .../spark/snowflake/ServerConnection.scala | 9 ++- .../spark/snowflake/ParametersSuite.scala | 16 ++++ 4 files changed, 107 insertions(+), 5 deletions(-) diff --git a/src/it/scala/net/snowflake/spark/snowflake/ShareConnectionSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/ShareConnectionSuite.scala index 5aa6ecdf..e8de8c31 100644 --- a/src/it/scala/net/snowflake/spark/snowflake/ShareConnectionSuite.scala +++ b/src/it/scala/net/snowflake/spark/snowflake/ShareConnectionSuite.scala @@ -310,4 +310,80 @@ class ShareConnectionSuite extends IntegrationSuiteBase { assert(ServerConnection.jdbcConnectionCount.get() == oldJdbcConnectionCount) assert(ServerConnection.serverConnectionCount.get() > oldServerConnectionCount) } + + test("test force_skip_pre_post_action_check_for_session_sharing") { + sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option(Parameters.PARAM_PREACTIONS, s"use schema ${params.sfSchema}") + .option(Parameters.PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING, "false") + .option("dbtable", test_table1) + .load() + .collect() + + // case 1: READ with force_skip_pre_post_action_check_for_session_sharing = false + var oldJdbcConnectionCount = ServerConnection.jdbcConnectionCount.get() + var oldServerConnectionCount = ServerConnection.serverConnectionCount.get() + // Read with the same SfOptions in 2nd time. + sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option(Parameters.PARAM_PREACTIONS, s"use schema ${params.sfSchema}") + .option(Parameters.PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING, "false") + .option("dbtable", test_table1) + .load() + .collect() + // JDBC connection count increases. + assert(ServerConnection.jdbcConnectionCount.get() > oldJdbcConnectionCount) + assert(ServerConnection.serverConnectionCount.get() > oldServerConnectionCount) + + // case 2: WRITE with force_skip_pre_post_action_check_for_session_sharing = false + oldJdbcConnectionCount = ServerConnection.jdbcConnectionCount.get() + oldServerConnectionCount = ServerConnection.serverConnectionCount.get() + localDF.write + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + // preactions are not in white list + .option(Parameters.PARAM_PREACTIONS, s"use schema ${params.sfSchema}") + .option(Parameters.PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING, "false") + .option("dbtable", test_table_write) + .mode(SaveMode.Append) + .save() + // JDBC connection count increases. 
+    assert(ServerConnection.jdbcConnectionCount.get() > oldJdbcConnectionCount)
+    assert(ServerConnection.serverConnectionCount.get() > oldServerConnectionCount)
+
+    // case 3: WRITE (preactions) with force_skip_pre_post_action_check_for_session_sharing = true
+    oldJdbcConnectionCount = ServerConnection.jdbcConnectionCount.get()
+    oldServerConnectionCount = ServerConnection.serverConnectionCount.get()
+    localDF.write
+      .format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      // preactions are not in white list
+      .option(Parameters.PARAM_PREACTIONS, s"use schema ${params.sfSchema}")
+      .option(Parameters.PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING, "true")
+      .option("dbtable", test_table_write)
+      .mode(SaveMode.Append)
+      .save()
+    // With force_skip_pre_post_action_check_for_session_sharing = true,
+    // JDBC connection count is the same.
+    assert(ServerConnection.jdbcConnectionCount.get() == oldJdbcConnectionCount)
+    assert(ServerConnection.serverConnectionCount.get() > oldServerConnectionCount)
+
+    // case 4: WRITE (postactions) with force_skip_pre_post_action_check_for_session_sharing = true
+    oldJdbcConnectionCount = ServerConnection.jdbcConnectionCount.get()
+    oldServerConnectionCount = ServerConnection.serverConnectionCount.get()
+    localDF.write
+      .format(SNOWFLAKE_SOURCE_NAME)
+      .options(connectorOptionsNoTable)
+      // post actions are not in white list
+      .option(Parameters.PARAM_POSTACTIONS, s"use schema ${params.sfSchema}")
+      .option(Parameters.PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING, "true")
+      .option("dbtable", test_table_write)
+      .mode(SaveMode.Append)
+      .save()
+    // The connection is still shared, so the JDBC connection count stays the same.
+    assert(ServerConnection.jdbcConnectionCount.get() == oldJdbcConnectionCount)
+    assert(ServerConnection.serverConnectionCount.get() > oldServerConnectionCount)
+  }
 }
diff --git a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
index 29886d60..4722c92c 100644
--- a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
+++ b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
@@ -234,6 +234,14 @@ object Parameters {
     "internal_support_micro_second_during_unload"
   )
+
+  // preactions and postactions may affect session-level settings, so connection sharing
+  // may be enabled only when the queries in preactions and postactions are in a white list.
+  // force_skip_pre_post_action_check_for_session_sharing is for users who are sure that their
+  // preactions and postactions do not affect other jobs sharing the connection; it defaults to false.
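For illustration, the sketch below shows roughly how a job that is certain its pre/post actions leave the session state untouched could opt in to connection sharing with the option declared just below. It is a hedged sketch, not part of this patch: sfOptions, the schema name, and the target table name are placeholders, and the full source name corresponds to SNOWFLAKE_SOURCE_NAME.

import org.apache.spark.sql.{DataFrame, SaveMode}

// Minimal usage sketch with placeholder names.
def appendWithSharedConnection(df: DataFrame, sfOptions: Map[String, String]): Unit = {
  df.write
    .format("net.snowflake.spark.snowflake")
    .options(sfOptions)
    // "use schema" is not in the connection-sharing white list...
    .option("preactions", "use schema my_schema")
    // ...but the check is skipped, so the cached JDBC connection can still be reused.
    .option("force_skip_pre_post_action_check_for_session_sharing", "true")
    .option("dbtable", "target_table")
    .mode(SaveMode.Append)
    .save()
}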
+ val PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING: String = knownParam( + "force_skip_pre_post_action_check_for_session_sharing" + ) + val DEFAULT_S3_MAX_FILE_SIZE: String = (10 * 1000 * 1000).toString val MIN_S3_MAX_FILE_SIZE = 1000000 @@ -711,9 +719,6 @@ object Parameters { def supportShareConnection: Boolean = { isTrue(parameters.getOrElse(PARAM_SUPPORT_SHARE_CONNECTION, "true")) } - def supportMicroSecondDuringUnload: Boolean = { - isTrue(parameters.getOrElse(PARAM_INTERNAL_SUPPORT_MICRO_SECOND_DURING_UNLOAD, "true")) - } def stagingTableNameRemoveQuotesOnly: Boolean = { isTrue(parameters.getOrElse(PARAM_INTERNAL_STAGING_TABLE_NAME_REMOVE_QUOTES_ONLY, "false")) } diff --git a/src/main/scala/net/snowflake/spark/snowflake/ServerConnection.scala b/src/main/scala/net/snowflake/spark/snowflake/ServerConnection.scala index 9ebd448e..4965f556 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/ServerConnection.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/ServerConnection.scala @@ -23,6 +23,7 @@ sealed class ConnectionCacheKey(private val parameters: MergedParameters) { "overwrite", Parameters.PARAM_POSTACTIONS, Parameters.PARAM_PREACTIONS, + Parameters.PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING, Parameters.PARAM_SF_QUERY, Parameters.PARAM_SF_DBTABLE ) @@ -69,12 +70,16 @@ sealed class ConnectionCacheKey(private val parameters: MergedParameters) { } } + private[snowflake] def isPrePostActionsQualifiedForConnectionShare: Boolean = + parameters.forceSkipPrePostActionsCheck || + (parameters.preActions.forall(isQueryInWhiteList) && + parameters.postActions.forall(isQueryInWhiteList)) + def isConnectionCacheSupported: Boolean = { // Support sharing connection if pre/post actions doesn't change context ServerConnection.supportSharingJDBCConnection && parameters.supportShareConnection && - parameters.preActions.forall(isQueryInWhiteList) && - parameters.postActions.forall(isQueryInWhiteList) + isPrePostActionsQualifiedForConnectionShare } } diff --git a/src/test/scala/net/snowflake/spark/snowflake/ParametersSuite.scala b/src/test/scala/net/snowflake/spark/snowflake/ParametersSuite.scala index 9bca54bf..3a0bddca 100644 --- a/src/test/scala/net/snowflake/spark/snowflake/ParametersSuite.scala +++ b/src/test/scala/net/snowflake/spark/snowflake/ParametersSuite.scala @@ -284,4 +284,20 @@ class ParametersSuite extends FunSuite with Matchers { connectionCacheKey = new ConnectionCacheKey(mergedParams) assert(!connectionCacheKey.isConnectionCacheSupported) } + + test("test ConnectionCacheKey.isPrePostActionsQualifiedForConnectionShare()") { + val listQuery1 = "create database db1" + val mergedParams = Parameters.mergeParameters( + minParams ++ Map(Parameters.PARAM_POSTACTIONS -> listQuery1)) + assert(!mergedParams.forceSkipPrePostActionsCheck) + val connectionCacheKey = new ConnectionCacheKey(mergedParams) + assert(!connectionCacheKey.isPrePostActionsQualifiedForConnectionShare) + + val mergedParamsEnabled = Parameters.mergeParameters(minParams ++ + Map(Parameters.PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING -> "true", + Parameters.PARAM_POSTACTIONS -> listQuery1)) + assert(mergedParamsEnabled.forceSkipPrePostActionsCheck) + val forceEnabledConnectionCacheKey = new ConnectionCacheKey(mergedParamsEnabled) + assert(forceEnabledConnectionCacheKey.isPrePostActionsQualifiedForConnectionShare) + } } From 4698777486218036a3bce8d61f972039379a7580 Mon Sep 17 00:00:00 2001 From: Mingli Rui <63472932+sfc-gh-mrui@users.noreply.github.com> Date: Thu, 20 Apr 2023 
18:03:39 -0700 Subject: [PATCH 11/15] SNOW-770051 Fix a wrong result issue for crossing schema join/union (#504) * SNOW-770051 Fix a potential wrong result issue for crossing schema join/union * Revise error message * Simplify canUseSameConnection() * Use toSet.size == 1 for duplication check. --- .../snowflake/PushdownJoinAndUnion.scala | 484 ++++++++++++++++++ .../snowflake/SnowflakeConnectorUtils.scala | 2 + .../querygeneration/QueryBuilder.scala | 23 +- .../querygeneration/SnowflakeQuery.scala | 34 +- 4 files changed, 534 insertions(+), 9 deletions(-) create mode 100644 src/it/scala/net/snowflake/spark/snowflake/PushdownJoinAndUnion.scala diff --git a/src/it/scala/net/snowflake/spark/snowflake/PushdownJoinAndUnion.scala b/src/it/scala/net/snowflake/spark/snowflake/PushdownJoinAndUnion.scala new file mode 100644 index 00000000..f845e550 --- /dev/null +++ b/src/it/scala/net/snowflake/spark/snowflake/PushdownJoinAndUnion.scala @@ -0,0 +1,484 @@ +/* + * Copyright 2015-2019 Snowflake Computing + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package net.snowflake.spark.snowflake + +import java.sql._ +import java.util.TimeZone +import net.snowflake.spark.snowflake.Utils.{SNOWFLAKE_SOURCE_NAME, SNOWFLAKE_SOURCE_SHORT_NAME} +import net.snowflake.spark.snowflake.test.TestHook +import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.logical.Expand +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{IntegerType, MapType, StringType, StructField, StructType} + +import scala.reflect.internal.util.TableDef + +// scalastyle:off println +class PushdownJoinAndUnion extends IntegrationSuiteBase { + private var thisConnectorOptionsNoTable: Map[String, String] = Map() + private val test_left_table = s"test_table_left_$randomSuffix" + private val test_right_table = s"test_table_right_$randomSuffix" + private val test_left_table_2 = s"test_table_left_2_$randomSuffix" + private val test_right_table_2 = s"test_table_right_2_$randomSuffix" + private val test_temp_schema = s"test_temp_schema_$randomSuffix" + + lazy val localDataFrame = { + val partitionCount = 1 + val rowCountPerPartition = 1 + // Create RDD which generates data with multiple partitions + val testRDD: RDD[Row] = sparkSession.sparkContext + .parallelize(Seq[Int](), partitionCount) + .mapPartitions { _ => { + (1 to rowCountPerPartition).map { i => { + Row(i, "local_value") + } + }.iterator + } + } + + val schema = StructType(List( + StructField("key", IntegerType), + StructField("local_value", StringType) + )) + + // Convert RDD to DataFrame + sparkSession.createDataFrame(testRDD, schema) + } + + override def afterAll(): Unit = { + try { + jdbcUpdate(s"drop table if exists $test_left_table") + jdbcUpdate(s"drop table if exists $test_right_table") + jdbcUpdate(s"drop table if exists $test_left_table_2") + jdbcUpdate(s"drop table if exists $test_right_table_2") + jdbcUpdate(s"drop schema if exists $test_temp_schema") + } finally { + TestHook.disableTestHook() 
+ SnowflakeConnectorUtils.disablePushdownSession(sparkSession) + super.afterAll() + } + } + + override def beforeAll(): Unit = { + super.beforeAll() + + // There is bug for Date.equals() to compare Date with different timezone, + // so set up the timezone to work around it. + val gmtTimezone = TimeZone.getTimeZone("GMT") + TimeZone.setDefault(gmtTimezone) + + connectorOptionsNoTable.foreach(tup => { + thisConnectorOptionsNoTable += tup + }) + + // Create test tables in sfSchema + jdbcUpdate(s"create or replace table $test_left_table (key int, left_value string)") + jdbcUpdate(s"insert into $test_left_table values (1, 'left_in_current_schema')") + jdbcUpdate(s"create or replace table $test_left_table_2 (key int, left_value string)") + jdbcUpdate(s"insert into $test_left_table_2 values (1, 'left_in_current_schema')") + + jdbcUpdate(s"create or replace table $test_right_table (key int, right_value string)") + jdbcUpdate(s"insert into $test_right_table values (1, 'right_in_current_schema')") + jdbcUpdate(s"create or replace table $test_right_table_2 (key int, right_value string)") + jdbcUpdate(s"insert into $test_right_table_2 values (1, 'right_in_current_schema')") + + // Create test schema for crossing schema join/union test + jdbcUpdate(s"create or replace schema $test_temp_schema") + + jdbcUpdate(s"create or replace table $test_temp_schema.$test_right_table" + + s" (key int, right_value string)") + jdbcUpdate(s"insert into $test_temp_schema.$test_right_table" + + s" values (1, 'right_in_another_schema')") + jdbcUpdate(s"create or replace table $test_temp_schema.$test_right_table_2" + + s" (key int, right_value string)") + jdbcUpdate(s"insert into $test_temp_schema.$test_right_table_2" + + s" values (1, 'right_in_another_schema')") + + jdbcUpdate(s"create or replace table $test_temp_schema.$test_left_table" + + s" (key int, right_value string)") + jdbcUpdate(s"insert into $test_temp_schema.$test_left_table" + + s" values (1, 'left_in_another_schema')") + jdbcUpdate(s"create or replace table $test_temp_schema.$test_left_table_2" + + s" (key int, right_value string)") + jdbcUpdate(s"insert into $test_temp_schema.$test_left_table_2" + + s" values (1, 'left_in_another_schema')") + } + + private lazy val sfOptionsWithAnotherSchema = + connectorOptionsNoTable ++ Map(Parameters.PARAM_SF_SCHEMA -> test_temp_schema) + + test("same schema JOIN is pushdown") { + // Left DataFrame reads test_left_table in current schema + val dfLeft = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_left_table) + .load() + .select("*") + + // Right DataFrame reads test_right_table in another schema (not current schema) + val dfRight = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_right_table) + .load() + .select("*") + + val dfJoin = dfLeft.join(dfRight, "key") + + val expectedResult = Seq(Row(1, "left_in_current_schema", "right_in_current_schema")) + + testPushdown( + s"""SELECT ( "SUBQUERY_4"."SUBQUERY_4_COL_0" ) AS "SUBQUERY_5_COL_0" , + | ( "SUBQUERY_4"."SUBQUERY_4_COL_1" ) AS "SUBQUERY_5_COL_1" , + | ( "SUBQUERY_4"."SUBQUERY_4_COL_3" ) AS "SUBQUERY_5_COL_2" + | FROM ( + | SELECT ( "SUBQUERY_1"."KEY" ) AS "SUBQUERY_4_COL_0" , + | ( "SUBQUERY_1"."LEFT_VALUE" ) AS "SUBQUERY_4_COL_1" , + | ( "SUBQUERY_3"."KEY" ) AS "SUBQUERY_4_COL_2" , + | ( "SUBQUERY_3"."RIGHT_VALUE" ) AS "SUBQUERY_4_COL_3" + | FROM ( + | SELECT * FROM ( SELECT * FROM ( $test_left_table ) + | AS "SF_CONNECTOR_QUERY_ALIAS" 
) AS "SUBQUERY_0" + | WHERE ( "SUBQUERY_0"."KEY" IS NOT NULL ) ) AS "SUBQUERY_1" + | INNER JOIN ( + | SELECT * FROM ( SELECT * FROM ( $test_right_table ) + | AS "SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_2" + | WHERE ( "SUBQUERY_2"."KEY" IS NOT NULL ) ) AS "SUBQUERY_3" + | ON ( "SUBQUERY_1"."KEY" = "SUBQUERY_3"."KEY" ) ) AS "SUBQUERY_4" + |""".stripMargin, + dfJoin, + expectedResult + ) + } + + test("same schema UNION is pushdown") { + // Left DataFrame reads test_left_table in current schema + val dfLeft = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_left_table) + .load() + .select("*") + + // Right DataFrame reads test_right_table in another schema (not current schema) + val dfRight = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_right_table) + .load() + .select("*") + + // union of 2 DataFrame + testPushdown( + s"""( SELECT * FROM ( $test_left_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) + | UNION ALL + |( SELECT * FROM ( $test_right_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) + |""".stripMargin, + dfLeft.union(dfRight), + Seq(Row(1, "left_in_current_schema"), Row(1, "right_in_current_schema")) + ) + + // union of 3 DataFrame + testPushdown( + s"""( SELECT * FROM ( $test_left_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) + | UNION ALL + |( SELECT * FROM ( $test_right_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) + | UNION ALL + |( SELECT * FROM ( $test_right_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) + |""".stripMargin, + dfLeft.union(dfRight).union(dfRight), + Seq(Row(1, "left_in_current_schema"), + Row(1, "right_in_current_schema"), + Row(1, "right_in_current_schema")) + ) + } + + test("self UNION is pushdown") { + // Left DataFrame reads test_left_table in current schema + val dfLeft = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_left_table) + .load() + .select("*") + + // union of 2 DataFrame + testPushdown( + s"""( SELECT * FROM ( $test_left_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) + | UNION ALL + |( SELECT * FROM ( $test_left_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) + |""".stripMargin, + dfLeft.union(dfLeft), + Seq(Row(1, "left_in_current_schema"), Row(1, "left_in_current_schema")) + ) + + // union of 3 DataFrame + testPushdown( + s"""( SELECT * FROM ( $test_left_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) + | UNION ALL + |( SELECT * FROM ( $test_left_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) + | UNION ALL + |( SELECT * FROM ( $test_left_table ) AS "SF_CONNECTOR_QUERY_ALIAS" ) + |""".stripMargin, + dfLeft.union(dfLeft).union(dfLeft), + Seq(Row(1, "left_in_current_schema"), + Row(1, "left_in_current_schema"), + Row(1, "left_in_current_schema")) + ) + } + + test("test snowflake DataFrame JOIN with non-snowflake DataFrame - not pushdown") { + // Left DataFrame reads test_left_table in current schema + val dfSnowflake = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_left_table) + .load() + .select("*") + + // Snowflake DataFrame join local DataFrame + testPushdown( + s"don't_check_query", + dfSnowflake.join(localDataFrame, "key"), + Seq(Row(1, "left_in_current_schema", "local_value")), + bypass = true + ) + + // Snowflake DataFrame join local DataFrame + testPushdown( + s"don't_check_query", + localDataFrame.join(dfSnowflake, "key"), + Seq(Row(1, "local_value", "left_in_current_schema")), + bypass = true + ) + } + + test("test snowflake DataFrame UNION 
with non-snowflake DataFrame - not pushdown") { + // Left DataFrame reads test_left_table in current schema + val dfSnowflake = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_left_table) + .load() + .select("*") + + dfSnowflake.union(localDataFrame).show(truncate = false) + // Snowflake DataFrame union local DataFrame + testPushdown( + s"don't_check_query", + dfSnowflake.union(localDataFrame), + Seq(Row(1, "left_in_current_schema"), Row(1, "local_value")), + bypass = true + ) + + // Snowflake DataFrame union local DataFrame + testPushdown( + s"don't_check_query", + localDataFrame.union(dfSnowflake), + Seq(Row(1, "left_in_current_schema"), Row(1, "local_value")), + bypass = true + ) + } + + test("test crossing schema join: SNOW-770051") { + // Left DataFrame reads test_left_table in current schema + val dfLeft = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_left_table) + .load() + .select("*") + + // Right DataFrame reads test_right_table in another schema (not current schema) + val dfRight = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(sfOptionsWithAnotherSchema) + .option("dbtable", s"$test_right_table") + .load() + .select("*") + + val dfJoin = dfLeft.join(dfRight, "key") + + val expectedResult = Seq( + Row(1, "left_in_current_schema", + "right_in_another_schema")) + + testPushdown( + s"don't_check_query", + dfJoin, + expectedResult, bypass = true + ) + } + + test("test crossing schema union") { + // Left DataFrame reads test_left_table in current schema + val dfLeft = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_left_table) + .load() + + // Right DataFrame reads test_right_table in another schema (not current schema) + val dfRight = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(sfOptionsWithAnotherSchema) + .option("dbtable", s"$test_right_table") + .load() + + val dfUnion = dfLeft.union(dfRight) // .union(dfRight).union(dfRight) + + val expectedResult = Seq( + Row(1, "left_in_current_schema"), + Row(1, "right_in_another_schema")) + + testPushdown( + s"don't_check_query", + dfUnion, + expectedResult, bypass = true + ) + } + + test("test crossing schema union for multiple DataFrames") { + // Left DataFrame reads test_left_table in current schema + val dfLeft = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_left_table) + .load() + + // Right DataFrame reads test_right_table in another schema (not current schema) + val dfRight = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(sfOptionsWithAnotherSchema) + .option("dbtable", s"$test_right_table") + .load() + + // df1 UNION df1 UNION df2 + testPushdown( + s"don't_check_query", + dfLeft.union(dfLeft).union(dfRight), + Seq( + Row(1, "left_in_current_schema"), + Row(1, "left_in_current_schema"), + Row(1, "right_in_another_schema")), + bypass = true + ) + + // df1 UNION df2 UNION df2 + testPushdown( + s"don't_check_query", + dfLeft.union(dfRight).union(dfRight), + Seq( + Row(1, "left_in_current_schema"), + Row(1, "right_in_another_schema"), + Row(1, "right_in_another_schema")), + bypass = true + ) + + // df1 UNION df1 UNION df2 UNION df2 + testPushdown( + s"don't_check_query", + dfLeft.union(dfLeft).union(dfRight).union(dfRight), + Seq( + Row(1, "left_in_current_schema"), + Row(1, "left_in_current_schema"), + Row(1, 
"right_in_another_schema"), + Row(1, "right_in_another_schema")), + bypass = true + ) + + // df1 UNION df2 UNION df1 UNION df2 + testPushdown( + s"don't_check_query", + dfLeft.union(dfRight).union(dfLeft).union(dfRight), + Seq( + Row(1, "left_in_current_schema"), + Row(1, "right_in_another_schema"), + Row(1, "left_in_current_schema"), + Row(1, "right_in_another_schema")), + bypass = true + ) + } + + test("test crossing schema union + join") { + // dfLeftA and dfRightA in current schema + val dfLeftA = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_left_table) + .load() + val dfRightA = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(connectorOptionsNoTable) + .option("dbtable", test_right_table) + .load() + + // dfLeftB and dfRightB in another test schema + val dfLeftB = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(sfOptionsWithAnotherSchema) + .option("dbtable", test_left_table) + .load() + val dfRightB = sparkSession.read + .format(SNOWFLAKE_SOURCE_NAME) + .options(sfOptionsWithAnotherSchema) + .option("dbtable", test_right_table) + .load() + + // (dfLeftA JOIN dfRightA) UNION (dfLeftB JOIN dfRightB) + testPushdown( + s"don't_check_query", + (dfLeftA.join(dfRightA, "key")).union(dfLeftB.join(dfRightB, "key")), + Seq( + Row(1, "left_in_current_schema", "right_in_current_schema"), + Row(1, "left_in_another_schema", "right_in_another_schema")), + bypass = true + ) + + // (dfLeftA UNION dfRightA) JOIN (dfLeftB UNION dfRightB) + testPushdown( + s"don't_check_query", + (dfLeftA.union(dfRightA)).join(dfLeftB.union(dfRightB), "key"), + Seq( + Row(1, "left_in_current_schema", "left_in_another_schema"), + Row(1, "left_in_current_schema", "right_in_another_schema"), + Row(1, "right_in_current_schema", "left_in_another_schema"), + Row(1, "right_in_current_schema", "right_in_another_schema")), + bypass = true + ) + + // (dfLeftA UNION dfRightB) JOIN (dfLeftB UNION dfRightA) + testPushdown( + s"don't_check_query", + (dfLeftA.union(dfRightB)).join(dfLeftB.union(dfRightA), "key"), + Seq( + Row(1, "left_in_current_schema", "left_in_another_schema"), + Row(1, "left_in_current_schema", "right_in_current_schema"), + Row(1, "right_in_another_schema", "left_in_another_schema"), + Row(1, "right_in_another_schema", "right_in_current_schema")), + bypass = true + ) + } + + override def beforeEach(): Unit = { + super.beforeEach() + } +} +// scalastyle:on println + diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeConnectorUtils.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeConnectorUtils.scala index af40cda0..90dd5e6f 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeConnectorUtils.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeConnectorUtils.scala @@ -145,6 +145,8 @@ object SnowflakeFailMessage { final val FAIL_PUSHDOWN_AGGREGATE_EXPRESSION = "pushdown failed for aggregate expression" final val FAIL_PUSHDOWN_UNSUPPORTED_CONVERSION = "pushdown failed for unsupported conversion" final val FAIL_PUSHDOWN_UNSUPPORTED_UNION = "pushdown failed for Spark feature: UNION by name" + final val FAIL_PUSHDOWN_CANNOT_UNION = + "pushdown failed for UNION because the spark connector options are not compatible" } class SnowflakePushdownUnsupportedException(message: String, diff --git a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/QueryBuilder.scala b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/QueryBuilder.scala 
index 704bb8c0..1a1b847a 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/QueryBuilder.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/QueryBuilder.scala @@ -4,6 +4,7 @@ import java.io.{PrintWriter, StringWriter} import java.util.NoSuchElementException import net.snowflake.spark.snowflake.{ + ConnectionCacheKey, SnowflakeFailMessage, SnowflakePushdownException, SnowflakePushdownUnsupportedException, @@ -163,6 +164,13 @@ private[querygeneration] class QueryBuilder(plan: LogicalPlan) { expression.children.foreach(processExpression(_, statisticSet)) } + private def canUseSameConnection(snowflakeQueries: Seq[SnowflakeQuery]): Boolean = + snowflakeQueries + .flatMap(_.getSourceQueries) + .map(x => new ConnectionCacheKey(x.relation.params)) + .toSet + .size == 1 + /** Attempts to generate the query from the LogicalPlan. The queries are constructed from * the bottom up, but the validation of supported nodes for translation happens on the way down. * @@ -211,7 +219,7 @@ private[querygeneration] class QueryBuilder(plan: LogicalPlan) { generateQueries(left).flatMap { l => generateQueries(right) map { r => plan match { - case Join(_, _, joinType, condition, _) => + case Join(_, _, joinType, condition, _) if canUseSameConnection(Seq(l, r)) => joinType match { case Inner | LeftOuter | RightOuter | FullOuter => JoinQuery(l, r, condition, joinType, alias.next) @@ -240,7 +248,18 @@ private[querygeneration] class QueryBuilder(plan: LogicalPlan) { false ) } else { - Some(UnionQuery(children, alias.next)) + val unionQuery = UnionQuery(children, alias.next) + val sourceQueries = unionQuery.getSourceQueries + if (canUseSameConnection(sourceQueries)) { + Some(unionQuery) + } else { + throw new SnowflakePushdownUnsupportedException( + SnowflakeFailMessage.FAIL_PUSHDOWN_CANNOT_UNION, + s"${plan.nodeName} with source query count: ${sourceQueries.size}", + plan.getClass.getName, + false + ) + } } case Expand(projections, output, child) => diff --git a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/SnowflakeQuery.scala b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/SnowflakeQuery.scala index 3d19fa00..a8371a41 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/SnowflakeQuery.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/SnowflakeQuery.scala @@ -9,6 +9,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** Building blocks of a translated query, with nested subqueries. */ private[querygeneration] abstract sealed class SnowflakeQuery { + /** Get all children snowflake SourceQuery. */ + def getSourceQueries: Seq[SourceQuery] + /** Output columns. */ lazy val output: Seq[Attribute] = if (helper == null) Seq.empty @@ -89,6 +92,18 @@ private[querygeneration] abstract sealed class SnowflakeQuery { } } +private[querygeneration] +abstract sealed class UnarySnowflakeQuery(child: SnowflakeQuery) extends SnowflakeQuery { + override def getSourceQueries: Seq[SourceQuery] = child.getSourceQueries +} + +private[querygeneration] +abstract sealed class BinarySnowflakeQuery(left: SnowflakeQuery, right: SnowflakeQuery) + extends SnowflakeQuery { + override def getSourceQueries: Seq[SourceQuery] = + left.getSourceQueries ++ right.getSourceQueries +} + /** The query for a base type (representing a table or view). 
* * @constructor @@ -103,6 +118,8 @@ case class SourceQuery(relation: SnowflakeRelation, alias: String) extends SnowflakeQuery { + override def getSourceQueries: Seq[SourceQuery] = Seq(this) + override val helper: QueryHelper = QueryHelper( children = Seq.empty, projections = None, @@ -141,7 +158,7 @@ case class FilterQuery(conditions: Seq[Expression], child: SnowflakeQuery, alias: String, fields: Option[Seq[Attribute]] = None) - extends SnowflakeQuery { + extends UnarySnowflakeQuery(child) { override val helper: QueryHelper = QueryHelper( @@ -169,7 +186,7 @@ case class FilterQuery(conditions: Seq[Expression], case class ProjectQuery(columns: Seq[NamedExpression], child: SnowflakeQuery, alias: String) - extends SnowflakeQuery { + extends UnarySnowflakeQuery(child) { override val helper: QueryHelper = QueryHelper( @@ -192,7 +209,7 @@ case class AggregateQuery(columns: Seq[NamedExpression], groups: Seq[Expression], child: SnowflakeQuery, alias: String) - extends SnowflakeQuery { + extends UnarySnowflakeQuery(child) { override val helper: QueryHelper = QueryHelper( @@ -225,7 +242,7 @@ case class SortLimitQuery(limit: Option[Expression], orderBy: Seq[Expression], child: SnowflakeQuery, alias: String) - extends SnowflakeQuery { + extends UnarySnowflakeQuery(child) { override val helper: QueryHelper = QueryHelper( @@ -266,7 +283,7 @@ case class JoinQuery(left: SnowflakeQuery, conditions: Option[Expression], joinType: JoinType, alias: String) - extends SnowflakeQuery { + extends BinarySnowflakeQuery(left, right) { val conj: String = joinType match { case Inner => @@ -317,7 +334,7 @@ case class LeftSemiJoinQuery(left: SnowflakeQuery, conditions: Option[Expression], isAntiJoin: Boolean = false, alias: Iterator[String]) - extends SnowflakeQuery { + extends BinarySnowflakeQuery(left, right) { override val helper: QueryHelper = QueryHelper( @@ -362,6 +379,9 @@ case class UnionQuery(children: Seq[LogicalPlan], new QueryBuilder(child).treeRoot } + override def getSourceQueries: Seq[SourceQuery] = + queries.flatMap(_.getSourceQueries) + override val helper: QueryHelper = QueryHelper( children = queries, @@ -417,7 +437,7 @@ case class WindowQuery(windowExpressions: Seq[NamedExpression], child: SnowflakeQuery, alias: String, fields: Option[Seq[Attribute]]) - extends SnowflakeQuery { + extends UnarySnowflakeQuery(child) { val projectionVector: Seq[NamedExpression] = windowExpressions ++ child.helper.outputWithQualifier From 21f4a508707ea3c42397ff9a7d03e0ffac57f44d Mon Sep 17 00:00:00 2001 From: Mingli Rui <63472932+sfc-gh-mrui@users.noreply.github.com> Date: Mon, 22 May 2023 16:47:50 -0700 Subject: [PATCH 12/15] SNOW-824475 Support Spark 3.4 (#510) --- .github/workflows/ClusterTest.yml | 6 +- .github/workflows/IntegrationTest_2.12.yml | 2 +- .github/workflows/IntegrationTest_2.13.yml | 2 +- .../workflows/IntegrationTest_gcp_2.12.yml | 2 +- .../workflows/IntegrationTest_gcp_2.13.yml | 2 +- ClusterTest/build.sbt | 2 +- build.sbt | 6 +- .../SimpleNewPushdownIntegrationSuite.scala | 134 ++++++++++++------ .../sql/SFDataFrameWindowFramesSuite.scala | 35 +++-- .../sql/SFDataFrameWindowFunctionsSuite.scala | 2 +- .../spark/sql/SFDateFunctionsSuite.scala | 2 +- .../spark/sql/SnowflakeSparkUtilsSuite.scala | 2 +- .../snowflake/SnowflakeConnectorUtils.scala | 2 +- .../querygeneration/MiscStatement.scala | 14 +- .../querygeneration/NumericStatement.scala | 13 +- .../snowflake/spark/snowflake/TestUtils.scala | 14 ++ 16 files changed, 168 insertions(+), 72 deletions(-) diff --git a/.github/workflows/ClusterTest.yml 
b/.github/workflows/ClusterTest.yml index 7f39d470..6eb14abb 100644 --- a/.github/workflows/ClusterTest.yml +++ b/.github/workflows/ClusterTest.yml @@ -13,13 +13,13 @@ jobs: strategy: matrix: scala_version: [ '2.12.11' ] - spark_version: [ '3.3.0' ] + spark_version: [ '3.4.0' ] use_copy_unload: [ 'true' ] cloud_provider: [ 'gcp' ] env: SNOWFLAKE_TEST_CONFIG_SECRET: ${{ secrets.SNOWFLAKE_TEST_CONFIG_SECRET }} - TEST_SPARK_VERSION: '3.3' - DOCKER_IMAGE_TAG: 'snowflakedb/spark-base:3.3.0' + TEST_SPARK_VERSION: '3.4' + DOCKER_IMAGE_TAG: 'snowflakedb/spark-base:3.4.0' TEST_SCALA_VERSION: '2.12' TEST_COMPILE_SCALA_VERSION: '2.12.11' TEST_SPARK_CONNECTOR_VERSION: '2.11.3' diff --git a/.github/workflows/IntegrationTest_2.12.yml b/.github/workflows/IntegrationTest_2.12.yml index ee28f034..80235ce8 100644 --- a/.github/workflows/IntegrationTest_2.12.yml +++ b/.github/workflows/IntegrationTest_2.12.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: scala_version: [ '2.12.11' ] - spark_version: [ '3.3.0' ] + spark_version: [ '3.4.0' ] use_copy_unload: [ 'true', 'false' ] cloud_provider: [ 'aws', 'azure' ] # run_query_in_async can be removed after async mode is stable diff --git a/.github/workflows/IntegrationTest_2.13.yml b/.github/workflows/IntegrationTest_2.13.yml index 72f17be7..99999310 100644 --- a/.github/workflows/IntegrationTest_2.13.yml +++ b/.github/workflows/IntegrationTest_2.13.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: scala_version: [ '2.13.9' ] - spark_version: [ '3.3.0' ] + spark_version: [ '3.4.0' ] use_copy_unload: [ 'true', 'false' ] cloud_provider: [ 'aws', 'azure' ] # run_query_in_async can be removed after async mode is stable diff --git a/.github/workflows/IntegrationTest_gcp_2.12.yml b/.github/workflows/IntegrationTest_gcp_2.12.yml index cbc6a657..674e0b8f 100644 --- a/.github/workflows/IntegrationTest_gcp_2.12.yml +++ b/.github/workflows/IntegrationTest_gcp_2.12.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: scala_version: [ '2.12.11' ] - spark_version: [ '3.3.0' ] + spark_version: [ '3.4.0' ] use_copy_unload: [ 'false' ] cloud_provider: [ 'gcp' ] # run_query_in_async can be removed after async mode is stable diff --git a/.github/workflows/IntegrationTest_gcp_2.13.yml b/.github/workflows/IntegrationTest_gcp_2.13.yml index acd43db5..4a0f2b64 100644 --- a/.github/workflows/IntegrationTest_gcp_2.13.yml +++ b/.github/workflows/IntegrationTest_gcp_2.13.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: scala_version: [ '2.13.9' ] - spark_version: [ '3.3.0' ] + spark_version: [ '3.4.0' ] use_copy_unload: [ 'false' ] cloud_provider: [ 'gcp' ] # run_query_in_async can be removed after async mode is stable diff --git a/ClusterTest/build.sbt b/ClusterTest/build.sbt index 48505a7b..475c4c0a 100644 --- a/ClusterTest/build.sbt +++ b/ClusterTest/build.sbt @@ -16,7 +16,7 @@ val sparkConnectorVersion = "2.11.3" val scalaVersionMajor = "2.12" -val sparkVersionMajor = "3.3" +val sparkVersionMajor = "3.4" val sparkVersion = s"${sparkVersionMajor}.0" val testSparkVersion = sys.props.get("spark.testVersion").getOrElse(sparkVersion) diff --git a/build.sbt b/build.sbt index d1ee789f..10ef4a2b 100644 --- a/build.sbt +++ b/build.sbt @@ -16,8 +16,8 @@ import scala.util.Properties -val sparkVersion = "3.3" -val testSparkVersion = sys.props.get("spark.testVersion").getOrElse("3.3.0") +val sparkVersion = "3.4" +val testSparkVersion = sys.props.get("spark.testVersion").getOrElse("3.4.0") /* * Don't change the variable name "sparkConnectorVersion" because @@ -41,7 +41,7 @@ lazy val root = 
project.withId("spark-snowflake").in(file(".")) .settings( name := "spark-snowflake", organization := "net.snowflake", - version := s"${sparkConnectorVersion}-spark_3.3", + version := s"${sparkConnectorVersion}-spark_3.4", scalaVersion := sys.props.getOrElse("SPARK_SCALA_VERSION", default = "2.12.11"), // Spark 3.2 supports scala 2.12 and 2.13 crossScalaVersions := Seq("2.12.11", "2.13.9"), diff --git a/src/it/scala/net/snowflake/spark/snowflake/SimpleNewPushdownIntegrationSuite.scala b/src/it/scala/net/snowflake/spark/snowflake/SimpleNewPushdownIntegrationSuite.scala index 4ef35d70..bce1f162 100644 --- a/src/it/scala/net/snowflake/spark/snowflake/SimpleNewPushdownIntegrationSuite.scala +++ b/src/it/scala/net/snowflake/spark/snowflake/SimpleNewPushdownIntegrationSuite.scala @@ -334,15 +334,22 @@ class SimpleNewPushdownIntegrationSuite extends IntegrationSuiteBase { var result = sparkSession.sql(s"select o $operator p from df2 where o IS NOT NULL") - testPushdown( + val expectedAdditionQueries = Seq( s"""SELECT ( CAST ( ( "SUBQUERY_1"."O" $operator "SUBQUERY_1"."P" ) - |AS DECIMAL(38, 0) ) ) AS "SUBQUERY_2_COL_0" FROM - |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS - |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE - |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" + |AS DECIMAL(38, 0) ) ) AS "SUBQUERY_2_COL_0" FROM + |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS + |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE + |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" """.stripMargin, - result, - // Data in df2 (o, p) values(null, 1), (2, 2), (3, 2), (4, 3) + // From spark 3.4, the CAST operation is not presented in the plan + s"""SELECT ( ( "SUBQUERY_1"."O" $operator "SUBQUERY_1"."P" ) ) AS "SUBQUERY_2_COL_0" FROM + |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS + |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE + |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" + """.stripMargin + ) + + testPushdownMultiplefQueries(expectedAdditionQueries, result, Seq(Row(4), Row(5), Row(7)), disablePushDown ) @@ -352,14 +359,22 @@ class SimpleNewPushdownIntegrationSuite extends IntegrationSuiteBase { result = sparkSession.sql(s"select o $operator p from df2 where o IS NOT NULL") - testPushdown( + val expectedSubtractQueries = Seq( s"""SELECT ( CAST ( ( "SUBQUERY_1"."O" $operator "SUBQUERY_1"."P" ) - |AS DECIMAL(38, 0) ) ) AS "SUBQUERY_2_COL_0" FROM - |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS - |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE - |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" + |AS DECIMAL(38, 0) ) ) AS "SUBQUERY_2_COL_0" FROM + |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS + |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE + |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" """.stripMargin, - result, + // From spark 3.4, the CAST operation is not presented in the plan + s"""SELECT ( ( "SUBQUERY_1"."O" $operator "SUBQUERY_1"."P" ) ) AS "SUBQUERY_2_COL_0" FROM + |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS + |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE + |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" + """.stripMargin + ) + + testPushdownMultiplefQueries(expectedSubtractQueries, result, // Data in df2 (o, p) values(null, 1), (2, 2), (3, 2), (4, 3) Seq(Row(0), Row(1), Row(1)), disablePushDown @@ -370,15 +385,22 @@ class SimpleNewPushdownIntegrationSuite extends IntegrationSuiteBase { result = sparkSession.sql(s"select o $operator p from df2 where o IS NOT NULL") - testPushdown( + val expectedMultiplyQueries = 
Seq( s"""SELECT ( CAST ( ( "SUBQUERY_1"."O" $operator "SUBQUERY_1"."P" ) - |AS DECIMAL(38, 0) ) ) AS "SUBQUERY_2_COL_0" FROM - |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS - |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE - |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" + |AS DECIMAL(38, 0) ) ) AS "SUBQUERY_2_COL_0" FROM + |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS + |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE + |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" """.stripMargin, - result, - // Data in df2 (o, p) values(null, 1), (2, 2), (3, 2), (4, 3) + // From spark 3.4, the CAST operation is not presented in the plan + s"""SELECT ( ( "SUBQUERY_1"."O" $operator "SUBQUERY_1"."P" ) ) AS "SUBQUERY_2_COL_0" FROM + |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS + |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE + |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" + """.stripMargin + ) + + testPushdownMultiplefQueries(expectedMultiplyQueries, result, Seq(Row(4), Row(6), Row(12)), disablePushDown ) @@ -388,15 +410,22 @@ class SimpleNewPushdownIntegrationSuite extends IntegrationSuiteBase { result = sparkSession.sql(s"select o $operator p from df2 where o IS NOT NULL") - testPushdown( + val expectedDivisionQueries = Seq( s"""SELECT ( CAST ( ( "SUBQUERY_1"."O" $operator "SUBQUERY_1"."P" ) - |AS DECIMAL(38, 6) ) ) AS "SUBQUERY_2_COL_0" FROM - |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS - |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE - |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" + |AS DECIMAL(38, 6) ) ) AS "SUBQUERY_2_COL_0" FROM + |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS + |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE + |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" """.stripMargin, - result, - // Data in df2 (o, p) values(null, 1), (2, 2), (3, 2), (4, 3) + // From spark 3.4, the CAST operation is not presented in the plan + s"""SELECT ( ( "SUBQUERY_1"."O" $operator "SUBQUERY_1"."P" ) ) AS "SUBQUERY_2_COL_0" FROM + |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS + |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE + |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" + """.stripMargin + ) + + testPushdownMultiplefQueries(expectedDivisionQueries, result, Seq(Row(1.000000), Row(1.333333), Row(1.500000)), disablePushDown ) @@ -406,15 +435,22 @@ class SimpleNewPushdownIntegrationSuite extends IntegrationSuiteBase { result = sparkSession.sql(s"select o $operator p from df2 where o IS NOT NULL") - testPushdown( + val expectedModQueries = Seq( s"""SELECT ( CAST ( ( "SUBQUERY_1"."O" $operator "SUBQUERY_1"."P" ) - |AS DECIMAL(38, 0) ) ) AS "SUBQUERY_2_COL_0" FROM - |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS - |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE - |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" + |AS DECIMAL(38, 0) ) ) AS "SUBQUERY_2_COL_0" FROM + |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS + |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE + |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" """.stripMargin, - result, - // Data in df2 (o, p) values(null, 1), (2, 2), (3, 2), (4, 3) + // From spark 3.4, the CAST operation is not presented in the plan + s"""SELECT ( ( "SUBQUERY_1"."O" $operator "SUBQUERY_1"."P" ) ) AS "SUBQUERY_2_COL_0" FROM + |( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS + |"SF_CONNECTOR_QUERY_ALIAS" ) AS "SUBQUERY_0" WHERE + |( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" + """.stripMargin + ) + + 
testPushdownMultiplefQueries(expectedModQueries, result, Seq(Row(0), Row(1), Row(1)), disablePushDown ) @@ -425,18 +461,28 @@ class SimpleNewPushdownIntegrationSuite extends IntegrationSuiteBase { s"select -o, - o + p, - o - p, - ( o + p ), - 3 + o from df2 where o IS NOT NULL" ) - testPushdown( + val expectedUnaryMinusQueries = Seq( s"""SELECT ( - ( "SUBQUERY_1"."O" ) ) AS "SUBQUERY_2_COL_0" , ( CAST ( - |( - ( "SUBQUERY_1"."O" ) + "SUBQUERY_1"."P" ) AS DECIMAL(38, 0) ) ) - |AS "SUBQUERY_2_COL_1" , ( CAST ( ( - ( "SUBQUERY_1"."O" ) - "SUBQUERY_1"."P" ) - |AS DECIMAL(38, 0) ) ) AS "SUBQUERY_2_COL_2" , ( - ( CAST ( ( "SUBQUERY_1"."O" - |+ "SUBQUERY_1"."P" ) AS DECIMAL(38, 0) ) ) ) AS "SUBQUERY_2_COL_3" , ( CAST ( ( - |-3 + "SUBQUERY_1"."O" ) AS DECIMAL(38, 0) ) ) AS "SUBQUERY_2_COL_4" FROM ( SELECT - |* FROM ( SELECT * FROM ( $test_table2 ) AS "SF_CONNECTOR_QUERY_ALIAS" ) AS - |"SUBQUERY_0" WHERE ( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" + |( - ( "SUBQUERY_1"."O" ) + "SUBQUERY_1"."P" ) AS DECIMAL(38, 0) ) ) + |AS "SUBQUERY_2_COL_1" , ( CAST ( ( - ( "SUBQUERY_1"."O" ) - "SUBQUERY_1"."P" ) + |AS DECIMAL(38, 0) ) ) AS "SUBQUERY_2_COL_2" , ( - ( CAST ( ( "SUBQUERY_1"."O" + |+ "SUBQUERY_1"."P" ) AS DECIMAL(38, 0) ) ) ) AS "SUBQUERY_2_COL_3" , ( CAST ( ( + |-3 + "SUBQUERY_1"."O" ) AS DECIMAL(38, 0) ) ) AS "SUBQUERY_2_COL_4" FROM ( SELECT + |* FROM ( SELECT * FROM ( $test_table2 ) AS "SF_CONNECTOR_QUERY_ALIAS" ) AS + |"SUBQUERY_0" WHERE ( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" """.stripMargin, - result, - // Data in df2 (o, p) values(2, 2), (3, 2), (4, 3) + // From spark 3.4, the CAST operation is not presented in the plan + s"""SELECT ( - ( "SUBQUERY_1"."O" ) ) AS "SUBQUERY_2_COL_0" , + | ( ( - ( "SUBQUERY_1"."O" ) + "SUBQUERY_1"."P" ) ) AS "SUBQUERY_2_COL_1" , + | ( ( - ( "SUBQUERY_1"."O" ) - "SUBQUERY_1"."P" ) ) AS "SUBQUERY_2_COL_2" , + | ( - ( ( "SUBQUERY_1"."O" + "SUBQUERY_1"."P" ) ) ) AS "SUBQUERY_2_COL_3" , + | ( ( -3 + "SUBQUERY_1"."O" ) ) AS "SUBQUERY_2_COL_4" + | FROM ( SELECT * FROM ( SELECT * FROM ( $test_table2 ) AS "SF_CONNECTOR_QUERY_ALIAS" ) AS + | "SUBQUERY_0" WHERE ( "SUBQUERY_0"."O" IS NOT NULL ) ) AS "SUBQUERY_1" + """.stripMargin + ) + + testPushdownMultiplefQueries(expectedUnaryMinusQueries, result, Seq( Row(-2, 0, -4, -4, -1), Row(-3, -1, -5, -5, 0), diff --git a/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFramesSuite.scala b/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFramesSuite.scala index e4a54c55..96c6f403 100644 --- a/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFramesSuite.scala +++ b/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFramesSuite.scala @@ -1,5 +1,6 @@ package org.apache.spark.sql +import net.snowflake.spark.snowflake.{SnowflakeConnectorUtils, TestUtils} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.snowflake.{SFQueryTest, SFTestData, SFTestSessionBase} @@ -167,23 +168,41 @@ class SFDataFrameWindowFramesSuite window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), Row("non_numeric", "non_numeric") :: Nil) + // The error message for 3.4 is different. 
+ val expectedErrorMessage1 = + if (TestUtils.compareVersion(SnowflakeConnectorUtils.SUPPORT_SPARK_VERSION, "3.4") >= 0) { + "The data type of the upper bound \"STRING\" does not match the expected data type" + } else { + "The data type of the upper bound 'string' does not match the expected data type" + } val e1 = intercept[AnalysisException]( df.select( min("value").over(window.rangeBetween(Window.unboundedPreceding, 1)))) - assert(e1.message.contains("The data type of the upper bound 'string' " + - "does not match the expected data type")) - + assert(e1.message.contains(expectedErrorMessage1)) + + // The error message for 3.4 is different. + val expectedErrorMessage2 = + if (TestUtils.compareVersion(SnowflakeConnectorUtils.SUPPORT_SPARK_VERSION, "3.4") >= 0) { + "The data type of the lower bound \"STRING\" does not match the expected data type" + } else { + "The data type of the lower bound 'string' does not match the expected data type" + } val e2 = intercept[AnalysisException]( df.select( min("value").over(window.rangeBetween(-1, Window.unboundedFollowing)))) - assert(e2.message.contains("The data type of the lower bound 'string' " + - "does not match the expected data type")) - + assert(e2.message.contains(expectedErrorMessage2)) + + // The error message for 3.4 is different. + val expectedErrorMessage3 = + if (TestUtils.compareVersion(SnowflakeConnectorUtils.SUPPORT_SPARK_VERSION, "3.4") >= 0) { + "The data type of the lower bound \"STRING\" does not match the expected data type" + } else { + "The data type of the lower bound 'string' does not match the expected data type" + } val e3 = intercept[AnalysisException]( df.select( min("value").over(window.rangeBetween(-1, 1)))) - assert(e3.message.contains("The data type of the lower bound 'string' " + - "does not match the expected data type")) + assert(e3.message.contains(expectedErrorMessage3)) } test("range between should accept int/long values as boundary") { diff --git a/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFunctionsSuite.scala b/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFunctionsSuite.scala index f70a949b..77d10e01 100644 --- a/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFunctionsSuite.scala +++ b/src/it/scala/org/apache/spark/sql/SFDataFrameWindowFunctionsSuite.scala @@ -33,7 +33,7 @@ class SFDataFrameWindowFunctionsSuite } } - protected lazy val sql = spark.sql _ + protected def sql(sqlText: String) = spark.sql(sqlText) override def spark: SparkSession = getSnowflakeSession() diff --git a/src/it/scala/org/apache/spark/sql/SFDateFunctionsSuite.scala b/src/it/scala/org/apache/spark/sql/SFDateFunctionsSuite.scala index d1520db0..382358dc 100644 --- a/src/it/scala/org/apache/spark/sql/SFDateFunctionsSuite.scala +++ b/src/it/scala/org/apache/spark/sql/SFDateFunctionsSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.unsafe.types.CalendarInterval class SFDateFunctionsSuite extends SFQueryTest with SFTestSessionBase { import SFTestImplicits._ - protected lazy val sql = spark.sql _ + protected def sql(sqlText: String) = spark.sql(sqlText) test("function current_date") { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") diff --git a/src/it/scala/org/apache/spark/sql/SnowflakeSparkUtilsSuite.scala b/src/it/scala/org/apache/spark/sql/SnowflakeSparkUtilsSuite.scala index 767cb268..c315b706 100644 --- a/src/it/scala/org/apache/spark/sql/SnowflakeSparkUtilsSuite.scala +++ b/src/it/scala/org/apache/spark/sql/SnowflakeSparkUtilsSuite.scala @@ -9,7 +9,7 @@ class SnowflakeSparkUtilsSuite extends SFQueryTest with 
SFTestSessionBase { import SFTestImplicits._ - protected lazy val sql = spark.sql _ + protected def sql(sqlText: String) = spark.sql(sqlText) test("unit test: SnowflakeSparkUtils.getJDBCProviderName") { assert(SnowflakeSparkUtils.getJDBCProviderName( diff --git a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeConnectorUtils.scala b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeConnectorUtils.scala index 90dd5e6f..f1b50b53 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/SnowflakeConnectorUtils.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/SnowflakeConnectorUtils.scala @@ -34,7 +34,7 @@ object SnowflakeConnectorUtils { * Check Spark version, if Spark version matches SUPPORT_SPARK_VERSION enable PushDown, * otherwise disable it. */ - val SUPPORT_SPARK_VERSION = "3.3" + val SUPPORT_SPARK_VERSION = "3.4" def checkVersionAndEnablePushdown(session: SparkSession): Boolean = if (session.version.startsWith(SUPPORT_SPARK_VERSION)) { diff --git a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/MiscStatement.scala b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/MiscStatement.scala index cbcc240a..30d3c50e 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/MiscStatement.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/MiscStatement.scala @@ -8,6 +8,7 @@ import net.snowflake.spark.snowflake.{ SnowflakePushdownUnsupportedException, SnowflakeSQLStatement } +import org.apache.spark.sql.catalyst.expressions.EvalMode.LEGACY import org.apache.spark.sql.catalyst.expressions.{ Alias, Ascending, @@ -47,7 +48,13 @@ private[querygeneration] object MiscStatement { // override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled // So support to pushdown, if ansiEnabled is false. // https://github.com/apache/spark/commit/6f51e37eb52f21b50c8d7b15c68bf9969fee3567 - case Cast(child, t, _, ansiEnabled) if !ansiEnabled => + // Spark 3.4 changed the last argument type: + // https://github.com/apache/spark/commit/f8d51b9940b5f1f7c1f37693b10931cbec0a4741 + // - Old type: ansiEnabled: Boolean = SQLConf.get.ansiEnabled + // - New Type: evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get) + // Currently, there are 3 modes: LEGACY, ANSI, TRY + // support to pushdown, if the mode is LEGACY. + case Cast(child, t, _, evalMode) if evalMode == LEGACY => getCastType(t) match { case Some(cast) => // For known unsupported data conversion, raise exception to break the @@ -112,7 +119,10 @@ private[querygeneration] object MiscStatement { // joinCond: Seq[Expression] = Seq.empty // So support to pushdown, if joinCond is empty. // https://github.com/apache/spark/commit/806da9d6fae403f88aac42213a58923cf6c2cb05 - case ScalarSubquery(subquery, _, _, joinCond) if joinCond.isEmpty => + // Spark 3.4 introduce join hint. The join hint doesn't affect correctness. 
+ // So it can be ignored in the pushdown process + // https://github.com/apache/spark/commit/0fa9c554fc0b3940a47c3d1c6a5a17ca9a8cee8e + case ScalarSubquery(subquery, _, _, joinCond, _) if joinCond.isEmpty => blockStatement(new QueryBuilder(subquery).statement) case UnscaledValue(child) => diff --git a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/NumericStatement.scala b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/NumericStatement.scala index 91386867..86ace904 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/NumericStatement.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/pushdowns/querygeneration/NumericStatement.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.expressions.{ Log, Pi, Pow, - PromotePrecision, + // PromotePrecision is removed from Spark 3.4 + // PromotePrecision, Rand, Round, Sin, @@ -75,7 +76,9 @@ private[querygeneration] object NumericStatement { ) ) - case PromotePrecision(child) => convertStatement(child, fields) + // PromotePrecision is removed from Spark 3.4 + // https://github.com/apache/spark/pull/36698 + // case PromotePrecision(child) => convertStatement(child, fields) case CheckOverflow(child, t, _) => MiscStatement.getCastType(t) match { @@ -94,7 +97,11 @@ private[querygeneration] object NumericStatement { ConstantString("RANDOM") + blockStatement( LongVariable(Option(seed).map(_.asInstanceOf[Long])) ! ) - case Round(child, scale) => + + // Spark 3.4 adds a new argument: ansiEnabled + // https://github.com/apache/spark/commit/42721120f3c7206a9fc22db5d0bb7cf40f0cacfd + // The pushdown is supported for non-ANSI mode. + case Round(child, scale, ansiEnabled) if !ansiEnabled => ConstantString("ROUND") + blockStatement( convertStatements(fields, child, scale) ) diff --git a/src/test/scala/net/snowflake/spark/snowflake/TestUtils.scala b/src/test/scala/net/snowflake/spark/snowflake/TestUtils.scala index bcd9007c..b3ac6530 100644 --- a/src/test/scala/net/snowflake/spark/snowflake/TestUtils.scala +++ b/src/test/scala/net/snowflake/spark/snowflake/TestUtils.scala @@ -229,4 +229,18 @@ object TestUtils { def getServerConnection(connection: Connection): ServerConnection = ServerConnection(connection) + + def compareVersion(leftVersion: String, rightVersion: String): Int = { + val leftVersionParts = leftVersion.split("\\.").map(_.toLong) + val rightVersionParts = rightVersion.split("\\.").map(_.toLong) + for (i <- 0 until Math.min(leftVersionParts.length, rightVersionParts.length)) { + if (leftVersionParts(i) > rightVersionParts(i)) { + return 1 + } else if (leftVersionParts(i) < rightVersionParts(i)) { + return -1 + } + } + // 3.1 < 3.1.1 + leftVersionParts.length - rightVersionParts.length + } } From 0be43cc1b4d15680e68cee630df7a3122d4c5a8c Mon Sep 17 00:00:00 2001 From: Mingli Rui <63472932+sfc-gh-mrui@users.noreply.github.com> Date: Mon, 22 May 2023 17:59:25 -0700 Subject: [PATCH 13/15] SNOW-824420 Update SC version to 2.12.0 (#512) * SNOW-824420 Update SC version to 2.12.0 * Update spark version in build_image.sh --- .github/docker/build_image.sh | 4 ++-- .github/workflows/ClusterTest.yml | 2 +- ClusterTest/build.sbt | 2 +- build.sbt | 2 +- src/main/scala/net/snowflake/spark/snowflake/Utils.scala | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/docker/build_image.sh b/.github/docker/build_image.sh index 78eeb452..a126fb02 100755 --- a/.github/docker/build_image.sh +++ b/.github/docker/build_image.sh @@ -33,8 +33,8 @@ cd ../.. 
# Build docker image docker build \ ---build-arg SPARK_URL=https://archive.apache.org/dist/spark/spark-3.3.0/spark-3.3.0-bin-hadoop3.tgz \ ---build-arg SPARK_BINARY_NAME=spark-3.3.0-bin-hadoop3.tgz \ +--build-arg SPARK_URL=https://archive.apache.org/dist/spark/spark-3.4.0/spark-3.4.0-bin-hadoop3.tgz \ +--build-arg SPARK_BINARY_NAME=spark-3.4.0-bin-hadoop3.tgz \ --build-arg JDBC_URL=https://repo1.maven.org/maven2/net/snowflake/snowflake-jdbc/${TEST_JDBC_VERSION}/$JDBC_JAR_NAME \ --build-arg JDBC_BINARY_NAME=$JDBC_JAR_NAME \ --build-arg SPARK_CONNECTOR_LOCATION=target/scala-${TEST_SCALA_VERSION}/$SPARK_CONNECTOR_JAR_NAME \ diff --git a/.github/workflows/ClusterTest.yml b/.github/workflows/ClusterTest.yml index 6eb14abb..b5364f26 100644 --- a/.github/workflows/ClusterTest.yml +++ b/.github/workflows/ClusterTest.yml @@ -22,7 +22,7 @@ jobs: DOCKER_IMAGE_TAG: 'snowflakedb/spark-base:3.4.0' TEST_SCALA_VERSION: '2.12' TEST_COMPILE_SCALA_VERSION: '2.12.11' - TEST_SPARK_CONNECTOR_VERSION: '2.11.3' + TEST_SPARK_CONNECTOR_VERSION: '2.12.0' TEST_JDBC_VERSION: '3.13.30' steps: diff --git a/ClusterTest/build.sbt b/ClusterTest/build.sbt index 475c4c0a..cdaf0797 100644 --- a/ClusterTest/build.sbt +++ b/ClusterTest/build.sbt @@ -14,7 +14,7 @@ * limitations under the License. */ -val sparkConnectorVersion = "2.11.3" +val sparkConnectorVersion = "2.12.0" val scalaVersionMajor = "2.12" val sparkVersionMajor = "3.4" val sparkVersion = s"${sparkVersionMajor}.0" diff --git a/build.sbt b/build.sbt index 10ef4a2b..72601c66 100644 --- a/build.sbt +++ b/build.sbt @@ -26,7 +26,7 @@ val testSparkVersion = sys.props.get("spark.testVersion").getOrElse("3.4.0") * Tests/jenkins/BumpUpSparkConnectorVersion/run.sh * in snowflake repository. */ -val sparkConnectorVersion = "2.11.3" +val sparkConnectorVersion = "2.12.0" lazy val ItTest = config("it") extend Test diff --git a/src/main/scala/net/snowflake/spark/snowflake/Utils.scala b/src/main/scala/net/snowflake/spark/snowflake/Utils.scala index b867f690..e7b738df 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/Utils.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/Utils.scala @@ -55,7 +55,7 @@ object Utils { */ val SNOWFLAKE_SOURCE_SHORT_NAME = "snowflake" - val VERSION = "2.11.3" + val VERSION = "2.12.0" /** * The certified JDBC version to work with this spark connector version. 
From 5e7bdddc984103a3d9c302fd7c322990e3dab4f2 Mon Sep 17 00:00:00 2001 From: Arthur Li Date: Sun, 23 Jul 2023 11:27:16 -0400 Subject: [PATCH 14/15] pulled latest change from master --- .../scala/net/snowflake/spark/snowflake/Parameters.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala index 4722c92c..363a2744 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala @@ -719,6 +719,13 @@ object Parameters { def supportShareConnection: Boolean = { isTrue(parameters.getOrElse(PARAM_SUPPORT_SHARE_CONNECTION, "true")) } + def supportMicroSecondDuringUnload: Boolean = { + isTrue(parameters.getOrElse(PARAM_INTERNAL_SUPPORT_MICRO_SECOND_DURING_UNLOAD, "false")) + } + def forceSkipPrePostActionsCheck: Boolean = { + isTrue(parameters.getOrElse( + PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING, "false")) + } def stagingTableNameRemoveQuotesOnly: Boolean = { isTrue(parameters.getOrElse(PARAM_INTERNAL_STAGING_TABLE_NAME_REMOVE_QUOTES_ONLY, "false")) } From aaff19f6107f7f652d2bac2158aba9bf0fc07b3e Mon Sep 17 00:00:00 2001 From: Arthur Li Date: Sun, 23 Jul 2023 11:41:24 -0400 Subject: [PATCH 15/15] Fixed merge --- .../net/snowflake/spark/snowflake/Parameters.scala | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala index 3b6f4f13..7b2813f2 100644 --- a/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala +++ b/src/main/scala/net/snowflake/spark/snowflake/Parameters.scala @@ -242,14 +242,6 @@ object Parameters { "force_skip_pre_post_action_check_for_session_sharing" ) - // preactions and postactions may affect the session level setting, so connection sharing - // may be enabled only when the queries in preactions and postactions are in a white list. - // force_skip_pre_post_action_check_for_session_sharing is introduced if users are sure that - // the queries in preactions and postactions don't affect others. Its default value is false. - val PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING: String = knownParam( - "force_skip_pre_post_action_check_for_session_sharing" - ) - val DEFAULT_S3_MAX_FILE_SIZE: String = (10 * 1000 * 1000).toString val MIN_S3_MAX_FILE_SIZE = 1000000 @@ -734,10 +726,6 @@ object Parameters { isTrue(parameters.getOrElse( PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING, "false")) } - def forceSkipPrePostActionsCheck: Boolean = { - isTrue(parameters.getOrElse( - PARAM_FORCE_SKIP_PRE_POST_ACTION_CHECK_FOR_SESSION_SHARING, "false")) - } def stagingTableNameRemoveQuotesOnly: Boolean = { isTrue(parameters.getOrElse(PARAM_INTERNAL_STAGING_TABLE_NAME_REMOVE_QUOTES_ONLY, "false")) }
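
Note (not part of the patches above): several of the Spark 3.4 test changes in this series gate their expected error messages on the connector's supported Spark version, using the new TestUtils.compareVersion helper together with SnowflakeConnectorUtils.SUPPORT_SPARK_VERSION. The standalone Scala sketch below illustrates that pattern; it copies the comparison logic from the TestUtils hunk, and the object and variable names are local to this example rather than part of the connector.

    // Minimal, self-contained sketch of the version-gating pattern used in the
    // window-frame test changes. Assumes only the standard library.
    object VersionGateExample {
      // Compare dotted version strings numerically ("3.4" vs "3.10").
      // Returns >0, 0, or <0; a shorter prefix compares lower ("3.1" < "3.1.1"),
      // matching the behavior of TestUtils.compareVersion added in this series.
      def compareVersion(left: String, right: String): Int = {
        val l = left.split("\\.").map(_.toLong)
        val r = right.split("\\.").map(_.toLong)
        for (i <- 0 until Math.min(l.length, r.length)) {
          if (l(i) > r(i)) return 1
          else if (l(i) < r(i)) return -1
        }
        l.length - r.length
      }

      def main(args: Array[String]): Unit = {
        // Mirrors SnowflakeConnectorUtils.SUPPORT_SPARK_VERSION after this series.
        val supportedSparkVersion = "3.4"
        // Pick the expected AnalysisException wording by Spark version,
        // the same way the SFDataFrameWindowFramesSuite changes do.
        val expected =
          if (compareVersion(supportedSparkVersion, "3.4") >= 0) {
            "The data type of the upper bound \"STRING\" does not match the expected data type"
          } else {
            "The data type of the upper bound 'string' does not match the expected data type"
          }
        println(expected)
      }
    }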