Commit 6393afa

[SPARK-49505][SQL] Create new SQL functions "randstr" and "uniform" to generate random strings or numbers within ranges

### What changes were proposed in this pull request?

This PR introduces two new SQL functions "randstr" and "uniform" to generate random strings or numbers within ranges.

* The "randstr" function returns a string of the specified length whose characters are chosen uniformly at random from the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. The string length must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively).
* The "uniform" function returns a random value with independent and identically distributed  values with the specified range of numbers. The random seed is optional. The provided numbers specifying the minimum and maximum values of the range must be constant. If both of these numbers are integers, then the result will also be an integer. Otherwise if one or both of these are floating-point numbers, then the result will also be a floating-point number.

For example:

```
SELECT randstr(5);
> ceV0P

SELECT randstr(10, 0) FROM VALUES (0), (1), (2) tab(col);
> ceV0PXaR2I
  fYxVfArnv7
  iSIv0VT2XL

SELECT uniform(10, 20.0F);
> 17.604954

SELECT uniform(10, 20, 0) FROM VALUES (0), (1), (2) tab(col);
> 15
  16
  17
```
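
For reference, `randstr` picks each character by drawing a value in [0, 61] and mapping it onto one of the three categories (10 digits, 26 lowercase letters, 26 uppercase letters; see `RandStr.evalInternal` in the diff below). A minimal Scala sketch of that mapping:

```
// Mirrors the character selection in RandStr.evalInternal: 62 possible
// choices split across '0'-'9', 'a'-'z', and 'A'-'Z'.
def randChar(raw: Int): Char = {
  val num = (raw % 62).abs
  if (num < 10) ('0' + num).toChar
  else if (num < 36) ('a' + num - 10).toChar
  else ('A' + num - 36).toChar
}
```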
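
Under the hood, `uniform` is runtime-replaceable: it rewrites to `min + (max - min) * rand(seed)`, computed in double precision and cast back to the result type (see `Uniform.replacement` in the diff below). A minimal sketch of that range mapping, where `rand01` stands for a `rand()` sample in [0, 1):

```
// Sketch only: maps a uniform sample in [0, 1) onto the range [min, max).
def uniformSample(min: Double, max: Double, rand01: Double): Double =
  min + (max - min) * rand01
```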

### Why are the changes needed?

This improves the SQL functionality of Apache Spark and its parity with other systems:
* https://clickhouse.com/docs/en/sql-reference/functions/random-functions#randuniform
* https://docs.snowflake.com/en/sql-reference/functions/uniform
* https://www.microfocus.com/documentation/silk-test/21.0.2/en/silktestclassic-help-en/STCLASSIC-8BFE8661-RANDSTRFUNCTION-REF.html
* https://docs.snowflake.com/en/sql-reference/functions/randstr

### Does this PR introduce _any_ user-facing change?

Yes, see above.

### How was this patch tested?

This PR adds golden-file-based test coverage.

### Was this patch authored or co-authored using generative AI tooling?

Not this time.

Closes #48004 from dtenedor/uniform-randstr-functions.

Authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
dtenedor authored and MaxGekk committed Sep 17, 2024
1 parent f586ffb commit 6393afa
Showing 7 changed files with 1,191 additions and 13 deletions.
@@ -384,7 +384,9 @@ object FunctionRegistry {
expression[Rand]("rand"),
expression[Rand]("random", true, Some("3.0.0")),
expression[Randn]("randn"),
expression[RandStr]("randstr"),
expression[Stack]("stack"),
expression[Uniform]("uniform"),
expression[ZeroIfNull]("zeroifnull"),
CaseWhen.registryEntry,

@@ -17,13 +17,18 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, TreePattern}
import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.random.XORShiftRandom

/**
@@ -33,8 +38,7 @@ import org.apache.spark.util.random.XORShiftRandom
*
* Since this expression is stateful, it cannot be a case object.
*/
abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic
with ExpressionWithRandomSeed {
trait RDG extends Expression with ExpressionWithRandomSeed {
/**
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize and initialize it.
@@ -43,12 +47,6 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm

override def stateful: Boolean = true

override protected def initializeInternal(partitionIndex: Int): Unit = {
rng = new XORShiftRandom(seed + partitionIndex)
}

override def seedExpression: Expression = child

@transient protected lazy val seed: Long = seedExpression match {
case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
@@ -57,6 +55,15 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
override def nullable: Boolean = false

override def dataType: DataType = DoubleType
}

abstract class NondeterministicUnaryRDG
extends RDG with UnaryLike[Expression] with Nondeterministic with ExpectsInputTypes {
override def seedExpression: Expression = child

override protected def initializeInternal(partitionIndex: Int): Unit = {
rng = new XORShiftRandom(seed + partitionIndex)
}

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
}
@@ -99,7 +106,7 @@ private[catalyst] object ExpressionWithRandomSeed {
since = "1.5.0",
group = "math_funcs")
// scalastyle:on line.size.limit
case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG {
case class Rand(child: Expression, hideSeed: Boolean = false) extends NondeterministicUnaryRDG {

def this() = this(UnresolvedSeed, true)

@@ -150,7 +157,7 @@ object Rand {
since = "1.5.0",
group = "math_funcs")
// scalastyle:on line.size.limit
case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG {
case class Randn(child: Expression, hideSeed: Boolean = false) extends NondeterministicUnaryRDG {

def this() = this(UnresolvedSeed, true)

@@ -181,3 +188,236 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG {
object Randn {
def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
}

@ExpressionDescription(
usage = """
_FUNC_(min, max[, seed]) - Returns a random value drawn independently and identically
distributed (i.i.d.) uniformly from within the specified range of numbers. The random seed
is optional. The provided numbers specifying the minimum and maximum values of the range
must be constant. If both of these numbers are integers, then the result will also be an
integer. Otherwise, if one or both of these are floating-point numbers, then the result
will also be a floating-point number.
""",
examples = """
Examples:
> SELECT _FUNC_(10, 20, 0) > 0 AS result;
true
""",
since = "4.0.0",
group = "math_funcs")
case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
extends RuntimeReplaceable with TernaryLike[Expression] with RDG {
def this(min: Expression, max: Expression) = this(min, max, UnresolvedSeed)

final override lazy val deterministic: Boolean = false
override val nodePatterns: Seq[TreePattern] =
Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED)

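// Result type resolution: any floating-point argument widens the result to
// FloatType/DoubleType; otherwise the widest integral argument type wins.
// A NULL argument or an unresolved or NULL-typed seed yields NullType.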
override val dataType: DataType = {
val first = min.dataType
val second = max.dataType
(min.dataType, max.dataType) match {
case _ if !seedExpression.resolved || seedExpression.dataType == NullType =>
NullType
case (_, NullType) | (NullType, _) => NullType
case (_, LongType) | (LongType, _)
if Seq(first, second).forall(integer) => LongType
case (_, IntegerType) | (IntegerType, _)
if Seq(first, second).forall(integer) => IntegerType
case (_, ShortType) | (ShortType, _)
if Seq(first, second).forall(integer) => ShortType
case (_, DoubleType) | (DoubleType, _) => DoubleType
case (_, FloatType) | (FloatType, _) => FloatType
case _ =>
throw SparkException.internalError(
s"Unexpected argument data types: ${min.dataType}, ${max.dataType}")
}
}

private def integer(t: DataType): Boolean = t match {
case _: ShortType | _: IntegerType | _: LongType => true
case _ => false
}

override def checkInputDataTypes(): TypeCheckResult = {
var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
def requiredType = "integer or floating-point"
Seq((min, "min", 0),
(max, "max", 1),
(seedExpression, "seed", 2)).foreach {
case (expr: Expression, name: String, index: Int) =>
if (result == TypeCheckResult.TypeCheckSuccess) {
if (!expr.foldable) {
result = DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> name,
"inputType" -> requiredType,
"inputExpr" -> toSQLExpr(expr)))
} else expr.dataType match {
case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType |
_: NullType =>
case _ =>
result = DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(index),
"requiredType" -> requiredType,
"inputSql" -> toSQLExpr(expr),
"inputType" -> toSQLType(expr.dataType)))
}
}
}
result
}

override def first: Expression = min
override def second: Expression = max
override def third: Expression = seedExpression

override def withNewSeed(newSeed: Long): Expression =
Uniform(min, max, Literal(newSeed, LongType))

override def withNewChildrenInternal(
newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
Uniform(newFirst, newSecond, newThird)

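// Runtime replacement: rewrite to min + (max - min) * rand(seed), computed
// in double precision and cast back to the resolved result type; any
// NULL-typed argument short-circuits to a NULL literal.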
override def replacement: Expression = {
if (Seq(min, max, seedExpression).exists(_.dataType == NullType)) {
Literal(null)
} else {
def cast(e: Expression, to: DataType): Expression = if (e.dataType == to) e else Cast(e, to)
cast(Add(
cast(min, DoubleType),
Multiply(
Subtract(
cast(max, DoubleType),
cast(min, DoubleType)),
Rand(seed))),
dataType)
}
}
}

@ExpressionDescription(
usage = """
_FUNC_(length[, seed]) - Returns a string of the specified length whose characters are chosen
uniformly at random from the following pool of characters: 0-9, a-z, A-Z. The random seed is
optional. The string length must be a constant two-byte or four-byte integer (SMALLINT or INT,
respectively).
""",
examples =
"""
Examples:
> SELECT _FUNC_(3, 0) AS result;
ceV
""",
since = "4.0.0",
group = "string_funcs")
case class RandStr(length: Expression, override val seedExpression: Expression)
extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic {
def this(length: Expression) = this(length, UnresolvedSeed)

override def nullable: Boolean = false
override def dataType: DataType = StringType
override def stateful: Boolean = true
override def left: Expression = length
override def right: Expression = seedExpression

/**
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize and initialize it.
*/
@transient protected var rng: XORShiftRandom = _

@transient protected lazy val seed: Long = seedExpression match {
case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
}
override protected def initializeInternal(partitionIndex: Int): Unit = {
rng = new XORShiftRandom(seed + partitionIndex)
}

override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType))
override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression =
RandStr(newFirst, newSecond)

override def checkInputDataTypes(): TypeCheckResult = {
var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
def requiredType = "INT or SMALLINT"
Seq((length, "length", 0),
(seedExpression, "seedExpression", 1)).foreach {
case (expr: Expression, name: String, index: Int) =>
if (result == TypeCheckResult.TypeCheckSuccess) {
if (!expr.foldable) {
result = DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> name,
"inputType" -> requiredType,
"inputExpr" -> toSQLExpr(expr)))
} else expr.dataType match {
case _: ShortType | _: IntegerType =>
case _: LongType if index == 1 =>
case _ =>
result = DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(index),
"requiredType" -> requiredType,
"inputSql" -> toSQLExpr(expr),
"inputType" -> toSQLType(expr.dataType)))
}
}
}
result
}

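// Interpreted path: fill a byte buffer one character at a time from the
// 62-character pool and wrap the bytes as a UTF8String.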
override def evalInternal(input: InternalRow): Any = {
val numChars = length.eval(input).asInstanceOf[Number].intValue()
val bytes = new Array[Byte](numChars)
(0 until numChars).foreach { i =>
// We generate a random number between 0 and 61, inclusive. Between the 62 different choices
// we choose 0-9, a-z, or A-Z, where each category comprises 10 choices, 26 choices, or 26
// choices, respectively (10 + 26 + 26 = 62).
val num = (rng.nextInt() % 62).abs
num match {
case _ if num < 10 =>
bytes.update(i, ('0' + num).toByte)
case _ if num < 36 =>
bytes.update(i, ('a' + num - 10).toByte)
case _ =>
bytes.update(i, ('A' + num - 36).toByte)
}
}
val result: UTF8String = UTF8String.fromBytes(bytes.toArray)
result
}

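// Codegen path: register a per-partition XORShiftRandom seeded with
// (seed + partitionIndex) and emit Java that mirrors evalInternal above.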
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = classOf[XORShiftRandom].getName
val rngTerm = ctx.addMutableState(className, "rng")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
val eval = length.genCode(ctx)
ev.copy(code =
code"""
|${eval.code}
|int length = (int)(${eval.value});
|char[] chars = new char[length];
|for (int i = 0; i < length; i++) {
| int v = Math.abs($rngTerm.nextInt() % 62);
| if (v < 10) {
| chars[i] = (char)('0' + v);
| } else if (v < 36) {
| chars[i] = (char)('a' + (v - 10));
| } else {
| chars[i] = (char)('A' + (v - 36));
| }
|}
|UTF8String ${ev.value} = UTF8String.fromString(new String(chars));
|boolean ${ev.isNull} = false;
|""".stripMargin,
isNull = FalseLiteral)
}
}
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{IntegerType, LongType}

class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -41,4 +42,27 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(Rand(Literal(1L), false).sql === "rand(1L)")
assert(Randn(Literal(1L), false).sql === "randn(1L)")
}

test("SPARK-49505: Test the RANDSTR and UNIFORM SQL functions without codegen") {
// Note that we use a seed of zero in these tests to keep the results deterministic.
def testRandStr(first: Any, result: Any): Unit = {
checkEvaluationWithoutCodegen(
RandStr(Literal(first), Literal(0)), CatalystTypeConverters.convertToCatalyst(result))
}
testRandStr(1, "c")
testRandStr(5, "ceV0P")
testRandStr(10, "ceV0PXaR2I")
testRandStr(10L, "ceV0PXaR2I")

def testUniform(first: Any, second: Any, result: Any): Unit = {
checkEvaluationWithoutCodegen(
Uniform(Literal(first), Literal(second), Literal(0)).replacement,
CatalystTypeConverters.convertToCatalyst(result))
}
testUniform(0, 1, 0)
testUniform(0, 10, 7)
testUniform(0L, 10L, 7L)
testUniform(10.0F, 20.0F, 17.604954F)
testUniform(10L, 20.0F, 17.604954F)
}
}
@@ -265,6 +265,7 @@
| org.apache.spark.sql.catalyst.expressions.RaiseErrorExpressionBuilder | raise_error | SELECT raise_error('custom error message') | struct<raise_error(USER_RAISED_EXCEPTION, map(errorMessage, custom error message)):void> |
| org.apache.spark.sql.catalyst.expressions.Rand | rand | SELECT rand() | struct<rand():double> |
| org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct<rand():double> |
| org.apache.spark.sql.catalyst.expressions.RandStr | randstr | SELECT randstr(3, 0) AS result | struct<result:string> |
| org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct<randn():double> |
| org.apache.spark.sql.catalyst.expressions.Rank | rank | SELECT a, b, rank(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct<a:string,b:int,RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int> |
| org.apache.spark.sql.catalyst.expressions.RegExpCount | regexp_count | SELECT regexp_count('Steven Jones and Stephen Smith are the best players', 'Ste(v&#124;ph)en') | struct<regexp_count(Steven Jones and Stephen Smith are the best players, Ste(v&#124;ph)en):int> |
@@ -367,6 +368,7 @@
| org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct<negative(1):int> |
| org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> |
| org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct<decode(unhex(537061726B2053514C), UTF-8):string> |
| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(10, 20, 0) > 0 AS result | struct<result:boolean> |
| org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct<unix_date(1970-01-02):int> |
| org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct<unix_micros(1970-01-01 00:00:01Z):bigint> |
| org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct<unix_millis(1970-01-01 00:00:01Z):bigint> |