
Add fuzz tests for cast from string to other types #2898

Merged · 10 commits · Jul 13, 2021
119 changes: 119 additions & 0 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
@@ -24,6 +24,7 @@ import java.util.TimeZone

import ai.rapids.cudf.ColumnVector
import scala.collection.JavaConverters._
import scala.util.Random

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
@@ -56,6 +57,124 @@ class CastOpSuite extends GpuExpressionTestSuite {
for (from <- supportedTypes; to <- supportedTypes) yield (from, to)
}

private val BOOL_CHARS = " \t\r\nfalseTRUE01yesNO"
private val NUMERIC_CHARS = "inf \t\r\n0123456789.+-eE"
private val DATE_CHARS = " \t\r\n0123456789:-/TZ"

ignore("Cast from string to boolean using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2902
castRandomStrings(DataTypes.BooleanType, BOOL_CHARS, maxStringLen = 1)
castRandomStrings(DataTypes.BooleanType, BOOL_CHARS, maxStringLen = 3)
castRandomStrings(DataTypes.BooleanType, BOOL_CHARS)
}

ignore("Cast from string to byte using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2899
castRandomStrings(DataTypes.ByteType, NUMERIC_CHARS)
}

ignore("Cast from string to short using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2899
castRandomStrings(DataTypes.ShortType, NUMERIC_CHARS)
}

ignore("Cast from string to int using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2899
castRandomStrings(DataTypes.IntegerType, NUMERIC_CHARS)
}

ignore("Cast from string to long using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2899
castRandomStrings(DataTypes.LongType, NUMERIC_CHARS)
}

ignore("Cast from string to float using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2900
castRandomStrings(DataTypes.FloatType, NUMERIC_CHARS)
}

ignore("Cast from string to double using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2900
castRandomStrings(DataTypes.DoubleType, NUMERIC_CHARS)
}

test("Cast from string to date using random inputs") {
castRandomStrings(DataTypes.DateType, DATE_CHARS, maxStringLen = 8)
}

ignore("Cast from string to date using random inputs with valid year prefix") {
// this will fail until https://github.com/NVIDIA/spark-rapids/pull/2890 is merged
castRandomStrings(DataTypes.DateType, DATE_CHARS, maxStringLen = 8, Some("2021"))
}

ignore("Cast from string to timestamp using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2889
castRandomStrings(DataTypes.TimestampType, DATE_CHARS, maxStringLen = 32, None)
}

ignore("Cast from string to timestamp using random inputs with valid year prefix") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2889
castRandomStrings(DataTypes.TimestampType, DATE_CHARS, maxStringLen = 32, Some("2021-"))
}

private def castRandomStrings(
toType: DataType,
validChars: String,
maxStringLen: Int = 12,
prefix: Option[String] = None) {

val randomValueCount = 8192

val random = new Random(0)
val r = new EnhancedRandom(random,
new FuzzerOptions(validChars, maxStringLen))

val randomStrings = (0 until randomValueCount)
.map(n => (n, prefix.getOrElse("") + r.nextString()))

def castDf(spark: SparkSession): Seq[Row] = {
import spark.implicits._
val df = randomStrings.toDF("id", "c0").repartition(2)
val castDf = df.withColumn("c1", col("c0").cast(toType))
println(castDf.queryExecution.executedPlan)
castDf.collect()
}

val cpu = withCpuSparkSession(castDf)
.sortBy(_.getInt(0))

val conf = new SparkConf()
.set(RapidsConf.EXPLAIN.key, "ALL")
.set(RapidsConf.INCOMPATIBLE_DATE_FORMATS.key, "true")
.set(RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP.key, "true")
.set(RapidsConf.ENABLE_CAST_STRING_TO_FLOAT.key, "true")
.set(RapidsConf.ENABLE_CAST_STRING_TO_DECIMAL.key, "true")
.set(RapidsConf.ENABLE_CAST_STRING_TO_INTEGER.key, "true")

val gpu = withGpuSparkSession(castDf, conf)
.sortBy(_.getInt(0))

for ((cpuRow, gpuRow) <- cpu.zip(gpu)) {
assert(cpuRow.getInt(0) === gpuRow.getInt(0))
assert(cpuRow.getString(1) === gpuRow.getString(1))
val cpuValue = cpuRow.get(2)
val gpuValue = gpuRow.get(2)
if (!compare(cpuValue, gpuValue)) {
fail(s"Mismatch casting string [${cpuRow.getString(1)}] " +
s"to $toType. CPU: $cpuValue; GPU: $gpuValue")
}
}
}

test("Test all supported casts with in-range values") {
// test cast() and ansi_cast()
Seq(false, true).foreach { ansiEnabled =>
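For context on the approach: each fuzz test above draws fixed-length strings from a small, cast-relevant alphabet, so most inputs are near-valid rather than pure noise. A minimal standalone sketch of that generation step (plain Scala, no Spark; this nextString is a local stand-in mirroring the EnhancedRandom.nextString introduced below):

import scala.util.Random

// Sketch only: generate candidate inputs the way castRandomStrings does.
// Every string is exactly maxStringLen characters, drawn uniformly from a
// restricted alphabet; the tests optionally prepend a valid prefix.
object CastFuzzSketch {
  def nextString(r: Random, validChars: String, maxStringLen: Int): String = {
    val b = new StringBuilder(maxStringLen)
    for (_ <- 0 until maxStringLen) {
      b.append(validChars.charAt(r.nextInt(validChars.length)))
    }
    b.toString
  }

  def main(args: Array[String]): Unit = {
    val r = new Random(0) // fixed seed, as in the tests, so failures reproduce
    val numericChars = "inf \t\r\n0123456789.+-eE"
    (0 until 5).foreach(_ => println(nextString(r, numericChars, 12)))
  }
}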
47 changes: 17 additions & 30 deletions tests/src/test/scala/com/nvidia/spark/rapids/FuzzerUtils.scala
@@ -38,10 +38,7 @@ object FuzzerUtils {
/**
* Default options when generating random data.
*/
-  private val DEFAULT_OPTIONS = FuzzerOptions(
-    numbersAsStrings = true,
-    asciiStringsOnly = false,
-    maxStringLen = 64)
+  private val DEFAULT_OPTIONS = FuzzerOptions()

/**
* Create a schema with the specified data types.
@@ -331,20 +328,6 @@ class EnhancedRandom(protected val r: Random, protected val options: FuzzerOptions)
}
}

-  def nextString(): String = {
-    if (options.numbersAsStrings) {
-      r.nextInt(5) match {
-        case 0 => String.valueOf(r.nextInt())
-        case 1 => String.valueOf(r.nextLong())
-        case 2 => String.valueOf(r.nextFloat())
-        case 3 => String.valueOf(r.nextDouble())
-        case 4 => generateString()
-      }
-    } else {
-      generateString()
-    }
-  }

def nextDate(): Date = {
val futureDate = 6321706291000L // Upper limit Sunday, April 29, 2170 9:31:31 PM
new Date((futureDate * r.nextDouble()).toLong);
@@ -355,22 +338,26 @@ class EnhancedRandom(protected val r: Random, protected val options: FuzzerOptions)
new Timestamp((futureDate * r.nextDouble()).toLong)
}

-  private def generateString(): String = {
-    if (options.asciiStringsOnly) {
-      val b = new StringBuilder()
-      for (_ <- 0 until options.maxStringLen) {
-        b.append(ASCII_CHARS.charAt(r.nextInt(ASCII_CHARS.length)))
-      }
-      b.toString
-    } else {
-      r.nextString(r.nextInt(options.maxStringLen))
-    }
-  }
+  def nextString(): String = {
+    val b = new StringBuilder(options.maxStringLen)
+    for (_ <- 0 until options.maxStringLen) {
+      b.append(options.validStringChars.charAt(r.nextInt(options.validStringChars.length)))
+    }
+    b.toString
+  }

-  private val ASCII_CHARS = "abcdefghijklmnopqrstuvwxyz"
}

+object FuzzerOptions {
+  val ALPHABET_CHARS: String = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+  val NUMERIC_CHARS: String = "0123456789"
+  val ALPHANUMERIC_CHARS: String = ALPHABET_CHARS + NUMERIC_CHARS
+  val WHITESPACE_CHARS: String = " \t\r\n"
+  val SPECIAL_CHARS: String = "!@#$%^&*()-+=/?,.<>\\|[]{}~;:`\"'"
+  val ALL_CHARS: String = ALPHABET_CHARS + NUMERIC_CHARS + SPECIAL_CHARS + WHITESPACE_CHARS
+}

case class FuzzerOptions(
-    numbersAsStrings: Boolean = true,
-    asciiStringsOnly: Boolean = false,
+    validStringChars: String = FuzzerOptions.ALL_CHARS,
    maxStringLen: Int = 64)
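
A quick usage sketch of the reworked API (assuming only the classes shown in this diff): callers now pick a character space by value instead of via the removed numbersAsStrings/asciiStringsOnly flags, and the generator always emits exactly maxStringLen characters, where the old non-ASCII path produced variable-length arbitrary Unicode.

import scala.util.Random

// Sketch: digits-only fuzz strings of length 8.
val digitsOnly = new EnhancedRandom(new Random(0),
  FuzzerOptions(validStringChars = FuzzerOptions.NUMERIC_CHARS, maxStringLen = 8))
val sample = digitsOnly.nextString() // always exactly 8 characters from NUMERIC_CHARS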
tests/src/test/scala/com/nvidia/spark/rapids/HashAggregatesSuite.scala
@@ -73,7 +73,7 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite {
}

def firstDf(spark: SparkSession): DataFrame = {
-    val options = FuzzerOptions(asciiStringsOnly = true, numbersAsStrings = false,
+    val options = FuzzerOptions(validStringChars = FuzzerOptions.ALPHABET_CHARS,
maxStringLen = 4)
val schema = FuzzerUtils.createSchema(Seq(DataTypes.StringType, DataTypes.IntegerType))
FuzzerUtils.generateDataFrame(spark, schema, 100, options, seed = 0)
@@ -857,7 +857,7 @@ class HashAggregatesSuite extends SparkQueryCompareTestSuite {
private def randomDF(dataType: DataType)(spark: SparkSession) : DataFrame = {
val schema = FuzzerUtils.createSchema(Seq(DataTypes.StringType, dataType))
FuzzerUtils.generateDataFrame(spark, schema, rowCount = 1000,
-      options = FuzzerOptions(numbersAsStrings = false, asciiStringsOnly = true,
+      options = FuzzerOptions(validStringChars = FuzzerOptions.ALPHABET_CHARS,
maxStringLen = 2))
}

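The caller migration is mechanical, as this file shows. One behavioral note, grounded in the FuzzerUtils diff above: the removed ASCII_CHARS alphabet was lowercase-only, while FuzzerOptions.ALPHABET_CHARS also includes uppercase, so these tests now see a slightly wider character space. A before/after sketch of the pattern:

// Before: boolean flags selected the character space indirectly.
//   FuzzerOptions(asciiStringsOnly = true, numbersAsStrings = false, maxStringLen = 4)
// After: the character space is explicit, which is what lets CastOpSuite
// constrain inputs to cast-relevant alphabets.
val options = FuzzerOptions(validStringChars = FuzzerOptions.ALPHABET_CHARS,
  maxStringLen = 4)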