Skip to content

Commit

Permalink
[nereids](datetime) fix wrong result type of datetime add with interv…
Browse files Browse the repository at this point in the history
…al as first arg (apache#26957)
  • Loading branch information
jacktengg committed Nov 14, 2023
1 parent 160a515 commit b7f33df
Show file tree
Hide file tree
Showing 5 changed files with 1,460 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,7 @@ public Expression visitArithmeticBinary(ArithmeticBinaryContext ctx) {
throw new ParseException("Only supported: " + Operator.ADD, ctx);
}
Interval interval = (Interval) left;
return new TimestampArithmetic(Operator.ADD, right, interval.value(), interval.timeUnit(), true);
return new TimestampArithmetic(Operator.ADD, right, interval.value(), interval.timeUnit());
}

if (right instanceof Interval) {
Expand All @@ -918,7 +918,7 @@ public Expression visitArithmeticBinary(ArithmeticBinaryContext ctx) {
throw new ParseException("Only supported: " + Operator.ADD + " and " + Operator.SUBTRACT, ctx);
}
Interval interval = (Interval) right;
return new TimestampArithmetic(op, left, interval.value(), interval.timeUnit(), false);
return new TimestampArithmetic(op, left, interval.value(), interval.timeUnit());
}

return ParserUtils.withOrigin(ctx, () -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,21 @@
public class TimestampArithmetic extends Expression implements BinaryExpression, PropagateNullableOnDateLikeV2Args {

private final String funcName;
private final boolean intervalFirst;
private final Operator op;
private final TimeUnit timeUnit;

public TimestampArithmetic(Operator op, Expression e1, Expression e2, TimeUnit timeUnit, boolean intervalFirst) {
this(null, op, e1, e2, timeUnit, intervalFirst);
public TimestampArithmetic(Operator op, Expression e1, Expression e2, TimeUnit timeUnit) {
this(null, op, e1, e2, timeUnit);
}

/**
* Full parameter constructor.
*/
public TimestampArithmetic(String funcName, Operator op, Expression e1, Expression e2, TimeUnit timeUnit,
boolean intervalFirst) {
public TimestampArithmetic(String funcName, Operator op, Expression e1, Expression e2, TimeUnit timeUnit) {
super(ImmutableList.of(e1, e2));
Preconditions.checkState(op == Operator.ADD || op == Operator.SUBTRACT);
this.funcName = funcName;
this.op = op;
this.intervalFirst = intervalFirst;
this.timeUnit = timeUnit;
}

Expand All @@ -76,21 +73,16 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
public TimestampArithmetic withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new TimestampArithmetic(this.funcName, this.op, children.get(0), children.get(1),
this.timeUnit, this.intervalFirst);
this.timeUnit);
}

public Expression withFuncName(String funcName) {
return new TimestampArithmetic(funcName, this.op, children.get(0), children.get(1), this.timeUnit,
this.intervalFirst);
return new TimestampArithmetic(funcName, this.op, children.get(0), children.get(1), this.timeUnit);
}

@Override
public DataType getDataType() throws UnboundException {
int dateChildIndex = 0;
if (intervalFirst) {
dateChildIndex = 1;
}
DataType childType = child(dateChildIndex).getDataType();
DataType childType = child(0).getDataType();
if (childType instanceof DateTimeV2Type) {
return childType;
}
Expand Down Expand Up @@ -137,21 +129,12 @@ public String toSql() {
strBuilder.append(")");
return strBuilder.toString();
}
if (intervalFirst) {
// Non-function-call like version with interval as first operand.
strBuilder.append("INTERVAL ");
strBuilder.append(child(1).toSql()).append(" ");
strBuilder.append(timeUnit);
strBuilder.append(" ").append(op.toString()).append(" ");
strBuilder.append(child(0).toSql());
} else {
// Non-function-call like version with interval as second operand.
strBuilder.append(child(0).toSql());
strBuilder.append(" ").append(op.toString()).append(" ");
strBuilder.append("INTERVAL ");
strBuilder.append(child(1).toSql()).append(" ");
strBuilder.append(timeUnit);
}
// Non-function-call like version with interval as second operand.
strBuilder.append(child(0).toSql());
strBuilder.append(" ").append(op.toString()).append(" ");
strBuilder.append("INTERVAL ");
strBuilder.append(child(1).toSql()).append(" ");
strBuilder.append(timeUnit);
return strBuilder.toString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ public void testTimestampFold() {

// a + interval 1 day
Slot a = SlotReference.of("a", DateTimeV2Type.SYSTEM_DEFAULT);
TimestampArithmetic arithmetic = new TimestampArithmetic(Operator.ADD, a, Literal.of(1), TimeUnit.DAY, false);
TimestampArithmetic arithmetic = new TimestampArithmetic(Operator.ADD, a, Literal.of(1), TimeUnit.DAY);
Expression process = process(arithmetic);
assertRewrite(process, process);
}
Expand Down
Loading

0 comments on commit b7f33df

Please sign in to comment.