Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{Interval, UTF8String}


object Cast {
Expand Down Expand Up @@ -55,6 +55,9 @@ object Cast {

case (_, DateType) => true

case (StringType, IntervalType) => true
case (IntervalType, StringType) => true

case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
case (DateType, _: NumericType) => true
Expand Down Expand Up @@ -232,6 +235,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case _ => _ => null
}

// IntervalConverter
private[this] def castToInterval(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => Interval.fromString(s.toString))
case _ => _ => null
}

// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
Expand Down Expand Up @@ -405,6 +415,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from)
case IntervalType => castToInterval(from)
case BooleanType => castToBoolean(from)
case ByteType => castToByte(from)
case ShortType => castToShort(from)
Expand Down Expand Up @@ -442,6 +453,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))")

case (StringType, IntervalType) =>
defineCodeGen(ctx, ev, c =>
s"org.apache.spark.unsafe.types.Interval.fromString($c.toString())")

// fallback for DecimalType, this must be before other numeric types
case (_, dt: DecimalType) =>
super.genCode(ctx, ev)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,4 +563,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
InternalRow(0L)))
}

test("case between string and interval") {
import org.apache.spark.unsafe.types.Interval

checkEvaluation(Cast(Literal("interval -3 month 7 hours"), IntervalType),
new Interval(-3, 7 * Interval.MICROS_PER_HOUR))
checkEvaluation(Cast(Literal.create(
new Interval(15, -3 * Interval.MICROS_PER_DAY), IntervalType), StringType),
"interval 1 years 3 months -3 days")
}

}
48 changes: 48 additions & 0 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.spark.unsafe.types;

import java.io.Serializable;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* The internal representation of interval type.
Expand All @@ -30,6 +32,52 @@ public final class Interval implements Serializable {
public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24;
public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7;

/**
* A function to generate regex which matches interval string's unit part like "3 years".
*
* First, we can leave out some units in interval string, and we only care about the value of
* unit, so here we use non-capturing group to wrap the actual regex.
* At the beginning of the actual regex, we should match spaces before the unit part.
* Next is the number part, starts with an optional "-" to represent negative value. We use
* capturing group to wrap this part as we need the value later.
* Finally is the unit name, ends with an optional "s".
*/
private static String unitRegex(String unit) {
return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add more comments explaining the regex, since most people would need to look up documentation on the regex.

}

private static Pattern p = Pattern.compile("interval" + unitRegex("year") + unitRegex("month") +
unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") +
unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond"));

private static long toLong(String s) {
if (s == null) {
return 0;
} else {
return Long.valueOf(s);
}
}

public static Interval fromString(String s) {
if (s == null) {
return null;
}
Matcher m = p.matcher(s);
if (!m.matches() || s.equals("interval")) {
return null;
} else {
long months = toLong(m.group(1)) * 12 + toLong(m.group(2));
long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK;
microseconds += toLong(m.group(4)) * MICROS_PER_DAY;
microseconds += toLong(m.group(5)) * MICROS_PER_HOUR;
microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE;
microseconds += toLong(m.group(7)) * MICROS_PER_SECOND;
microseconds += toLong(m.group(8)) * MICROS_PER_MILLI;
microseconds += toLong(m.group(9));
return new Interval((int) months, microseconds);
}
}

public final int months;
public final long microseconds;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,50 @@ public void toStringTest() {
i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123);
assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds");
}

@Test
public void fromStringTest() {
testSingleUnit("year", 3, 36, 0);
testSingleUnit("month", 3, 3, 0);
testSingleUnit("week", 3, 0, 3 * MICROS_PER_WEEK);
testSingleUnit("day", 3, 0, 3 * MICROS_PER_DAY);
testSingleUnit("hour", 3, 0, 3 * MICROS_PER_HOUR);
testSingleUnit("minute", 3, 0, 3 * MICROS_PER_MINUTE);
testSingleUnit("second", 3, 0, 3 * MICROS_PER_SECOND);
testSingleUnit("millisecond", 3, 0, 3 * MICROS_PER_MILLI);
testSingleUnit("microsecond", 3, 0, 3);

String input;

input = "interval -5 years 23 month";
Interval result = new Interval(-5 * 12 + 23, 0);
assertEquals(Interval.fromString(input), result);

// Error cases
input = "interval 3month 1 hour";
assertEquals(Interval.fromString(input), null);

input = "interval 3 moth 1 hour";
assertEquals(Interval.fromString(input), null);

input = "interval";
assertEquals(Interval.fromString(input), null);

input = "int";
assertEquals(Interval.fromString(input), null);

input = "";
assertEquals(Interval.fromString(input), null);

input = null;
assertEquals(Interval.fromString(input), null);
}

private void testSingleUnit(String unit, int number, int months, long microseconds) {
String input1 = "interval " + number + " " + unit;
String input2 = "interval " + number + " " + unit + "s";
Interval result = new Interval(months, microseconds);
assertEquals(Interval.fromString(input1), result);
assertEquals(Interval.fromString(input2), result);
}
}