Skip to content

Commit 5f7a298

Browse files
committed
Add sum and avg functions in eval
Signed-off-by: Vamsi Manohar <reddyvam@amazon.com>
1 parent c05a58c commit 5f7a298

File tree

14 files changed

+1727
-8
lines changed

14 files changed

+1727
-8
lines changed

core/src/main/java/org/opensearch/sql/calcite/utils/PPLOperandTypes.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77

88
package org.opensearch.sql.calcite.utils;
99

10-
import org.apache.calcite.sql.type.CompositeOperandTypeChecker;
11-
import org.apache.calcite.sql.type.FamilyOperandTypeChecker;
12-
import org.apache.calcite.sql.type.OperandTypes;
13-
import org.apache.calcite.sql.type.SqlTypeFamily;
10+
import org.apache.calcite.sql.type.*;
11+
import org.opensearch.sql.expression.function.PPLTypeChecker;
1412
import org.opensearch.sql.expression.function.UDFOperandMetadata;
1513

1614
/**
@@ -33,6 +31,9 @@ private PPLOperandTypes() {}
3331
public static final UDFOperandMetadata NUMERIC =
3432
UDFOperandMetadata.wrap((FamilyOperandTypeChecker) OperandTypes.NUMERIC);
3533

34+
public static final UDFOperandMetadata VARIADIC_NUMERIC =
35+
UDFOperandMetadata.wrap(
36+
PPLTypeChecker.wrapSameFamily(SqlOperandCountRanges.from(1), SqlTypeFamily.NUMERIC));
3637
public static final UDFOperandMetadata NUMERIC_OPTIONAL_STRING =
3738
UDFOperandMetadata.wrap(
3839
(CompositeOperandTypeChecker)

core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,13 @@
7373
import org.opensearch.sql.expression.function.udf.ip.CidrMatchFunction;
7474
import org.opensearch.sql.expression.function.udf.ip.CompareIpFunction;
7575
import org.opensearch.sql.expression.function.udf.ip.IPFunction;
76+
import org.opensearch.sql.expression.function.udf.math.AvgFunction;
7677
import org.opensearch.sql.expression.function.udf.math.CRC32Function;
7778
import org.opensearch.sql.expression.function.udf.math.ConvFunction;
7879
import org.opensearch.sql.expression.function.udf.math.DivideFunction;
7980
import org.opensearch.sql.expression.function.udf.math.EulerFunction;
8081
import org.opensearch.sql.expression.function.udf.math.ModFunction;
82+
import org.opensearch.sql.expression.function.udf.math.SumFunction;
8183

8284
/** Defines functions and operators that are implemented only by PPL */
8385
public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
@@ -104,6 +106,8 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable {
104106
public static final SqlOperator MOD = new ModFunction().toUDF("MOD");
105107
public static final SqlOperator CRC32 = new CRC32Function().toUDF("CRC32");
106108
public static final SqlOperator DIVIDE = new DivideFunction().toUDF("DIVIDE");
109+
public static final SqlOperator SUM = new SumFunction().toUDF("SUM");
110+
public static final SqlOperator AVG = new AvgFunction().toUDF("AVG");
107111
public static final SqlOperator SHA2 = CryptographicFunction.sha2().toUDF("SHA2");
108112
public static final SqlOperator CIDRMATCH = new CidrMatchFunction().toUDF("CIDRMATCH");
109113

core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -570,8 +570,17 @@ private abstract static class AbstractBuilder {
570570
public void registerOperator(BuiltinFunctionName functionName, SqlOperator... operators) {
571571
for (SqlOperator operator : operators) {
572572
SqlOperandTypeChecker typeChecker;
573+
573574
if (operator instanceof SqlUserDefinedFunction udfOperator) {
574-
typeChecker = extractTypeCheckerFromUDF(udfOperator);
575+
UDFOperandMetadata udfMetadata = (UDFOperandMetadata) udfOperator.getOperandTypeChecker();
576+
577+
// Register directly if it has PPLTypeChecker.
578+
if (udfMetadata != null && udfMetadata.hasPPLTypeChecker()) {
579+
register(
580+
functionName, wrapWithPPLTypeChecker(operator, udfMetadata.getPPLTypeChecker()));
581+
return;
582+
}
583+
typeChecker = extractTypeCheckerFromUDF(udfMetadata);
575584
} else {
576585
typeChecker = operator.getOperandTypeChecker();
577586
}
@@ -614,9 +623,7 @@ public void registerOperator(BuiltinFunctionName functionName, SqlOperator... op
614623
}
615624

616625
private static SqlOperandTypeChecker extractTypeCheckerFromUDF(
617-
SqlUserDefinedFunction udfOperator) {
618-
UDFOperandMetadata udfOperandMetadata =
619-
(UDFOperandMetadata) udfOperator.getOperandTypeChecker();
626+
UDFOperandMetadata udfOperandMetadata) {
620627
return (udfOperandMetadata == null) ? null : udfOperandMetadata.getInnerTypeChecker();
621628
}
622629

@@ -708,6 +715,25 @@ public PPLTypeChecker getTypeChecker() {
708715
};
709716
}
710717

718+
/**
719+
* Wrap a SqlOperator into a FunctionImp with a direct PPLTypeChecker. This is the preferred
720+
* path for UDFs that provide PPLTypeChecker directly.
721+
*/
722+
private static FunctionImp wrapWithPPLTypeChecker(
723+
SqlOperator operator, PPLTypeChecker pplTypeChecker) {
724+
return new FunctionImp() {
725+
@Override
726+
public RexNode resolve(RexBuilder builder, RexNode... args) {
727+
return builder.makeCall(operator, args);
728+
}
729+
730+
@Override
731+
public PPLTypeChecker getTypeChecker() {
732+
return pplTypeChecker; // Direct PPL integration - no wrapping needed!
733+
}
734+
};
735+
}
736+
711737
private static FunctionImp createFunctionImpWithTypeChecker(
712738
BiFunction<RexBuilder, RexNode, RexNode> resolver, PPLTypeChecker typeChecker) {
713739
return new FunctionImp1() {
@@ -751,6 +777,8 @@ void populate() {
751777
registerOperator(AND, SqlStdOperatorTable.AND);
752778
registerOperator(OR, SqlStdOperatorTable.OR);
753779
registerOperator(NOT, SqlStdOperatorTable.NOT);
780+
registerOperator(SUM, PPLBuiltinOperators.SUM);
781+
registerOperator(AVG, PPLBuiltinOperators.AVG);
754782
registerOperator(ADD, SqlStdOperatorTable.PLUS);
755783
registerOperator(SUBTRACT, SqlStdOperatorTable.MINUS);
756784
registerOperator(MULTIPLY, SqlStdOperatorTable.MULTIPLY);

core/src/main/java/org/opensearch/sql/expression/function/PPLTypeChecker.java

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.apache.calcite.rel.type.RelDataType;
1818
import org.apache.calcite.rel.type.RelDataTypeField;
1919
import org.apache.calcite.sql.SqlIntervalQualifier;
20+
import org.apache.calcite.sql.SqlOperandCountRange;
2021
import org.apache.calcite.sql.parser.SqlParserPos;
2122
import org.apache.calcite.sql.type.CompositeOperandTypeChecker;
2223
import org.apache.calcite.sql.type.FamilyOperandTypeChecker;
@@ -246,6 +247,121 @@ public List<List<ExprType>> getParameterTypes() {
246247
}
247248
}
248249

250+
class PPLSameTypeChecker implements PPLTypeChecker {
251+
private final SqlOperandCountRange operandCountRange;
252+
private final SqlTypeFamily expectedFamily; // null means no family enforcement
253+
private final boolean enforceExactType; // true = exact SqlTypeName, false = same family only
254+
255+
// Constructor for family enforcement with exact type control
256+
public PPLSameTypeChecker(
257+
SqlOperandCountRange operandCountRange,
258+
SqlTypeFamily expectedFamily,
259+
boolean enforceExactType) {
260+
this.operandCountRange = operandCountRange;
261+
this.expectedFamily = expectedFamily;
262+
this.enforceExactType = enforceExactType;
263+
}
264+
265+
// Constructor for family enforcement (same family, different types allowed)
266+
public PPLSameTypeChecker(
267+
SqlOperandCountRange operandCountRange, SqlTypeFamily expectedFamily) {
268+
this(operandCountRange, expectedFamily, false);
269+
}
270+
271+
// Constructor for no enforcement
272+
public PPLSameTypeChecker(SqlOperandCountRange operandCountRange) {
273+
this(operandCountRange, null, false);
274+
}
275+
276+
@Override
277+
public boolean checkOperandTypes(List<RelDataType> types) {
278+
if (!operandCountRange.isValidCount(types.size())) {
279+
return false;
280+
}
281+
282+
if (types.isEmpty()) return true;
283+
284+
SqlTypeFamily firstFamily = null;
285+
SqlTypeName firstTypeName = null;
286+
287+
for (RelDataType type : types) {
288+
SqlTypeName typeName = UserDefinedFunctionUtils.convertRelDataTypeToSqlTypeName(type);
289+
SqlTypeFamily family = typeName.getFamily();
290+
291+
if (firstFamily == null) {
292+
firstFamily = family;
293+
firstTypeName = typeName;
294+
295+
// Check expected family if specified
296+
if (expectedFamily != null && !expectedFamily.getTypeNames().contains(typeName)) {
297+
return false;
298+
}
299+
} else {
300+
// Check based on enforcement level
301+
if (enforceExactType) {
302+
// Must be exact same type
303+
if (typeName != firstTypeName) return false;
304+
} else {
305+
// Must be same family (different types within family allowed)
306+
if (family != firstFamily) return false;
307+
}
308+
}
309+
}
310+
return true;
311+
}
312+
313+
@Override
314+
public String getAllowedSignatures() {
315+
int min = operandCountRange.getMin();
316+
int max = operandCountRange.getMax();
317+
318+
String typeLabel;
319+
if (expectedFamily != null) {
320+
typeLabel = expectedFamily.name();
321+
} else {
322+
typeLabel = enforceExactType ? "SQL_TYPE" : "SQL_TYPE_FAMILY";
323+
}
324+
325+
if (min == -1 || max == -1) {
326+
return String.format("[%s, %s, %s, ...]", typeLabel, typeLabel, typeLabel);
327+
} else {
328+
List<String> signatures = new ArrayList<>();
329+
final int MAX_ARGS = 10;
330+
max = Math.min(MAX_ARGS, max);
331+
for (int i = min; i <= max; i++) {
332+
signatures.add("[" + String.join(",", Collections.nCopies(i, typeLabel)) + "]");
333+
}
334+
return String.join(",", signatures);
335+
}
336+
}
337+
338+
@Override
339+
public List<List<ExprType>> getParameterTypes() {
340+
if (expectedFamily != null) {
341+
List<ExprType> exprTypes = getExprTypes(expectedFamily);
342+
int minArgs = operandCountRange.getMin();
343+
344+
if (enforceExactType) {
345+
// Each type gets its own signature
346+
return exprTypes.stream()
347+
.map(type -> Collections.nCopies(minArgs, type))
348+
.collect(Collectors.toList());
349+
} else {
350+
// One signature with all possible types from the family
351+
return List.of(
352+
new ArrayList<>(exprTypes.subList(0, Math.min(minArgs, exprTypes.size()))));
353+
}
354+
} else {
355+
return List.of(List.of(ExprCoreType.UNKNOWN, ExprCoreType.UNKNOWN));
356+
}
357+
}
358+
359+
// Getter for operandCountRange to be used by UDFOperandMetadata
360+
public SqlOperandCountRange getOperandCountRange() {
361+
return operandCountRange;
362+
}
363+
}
364+
249365
@RequiredArgsConstructor
250366
class PPLComparableTypeChecker implements PPLTypeChecker {
251367
private final SameOperandTypeChecker innerTypeChecker;
@@ -418,6 +534,19 @@ static PPLComparableTypeChecker wrapComparable(SameOperandTypeChecker typeChecke
418534
return new PPLComparableTypeChecker(typeChecker);
419535
}
420536

537+
// Same family, different types allowed (e.g., INTEGER + DOUBLE for NUMERIC)
538+
static PPLSameTypeChecker wrapSameFamily(
539+
SqlOperandCountRange operandCountRange, SqlTypeFamily expectedFamily) {
540+
return new PPLSameTypeChecker(operandCountRange, expectedFamily, false);
541+
}
542+
543+
static PPLSameTypeChecker wrapPPLSameTypeChecker(
544+
SqlOperandCountRange operandCountRange,
545+
SqlTypeFamily expectedFamily,
546+
boolean enforceExactType) {
547+
return new PPLSameTypeChecker(operandCountRange, expectedFamily, enforceExactType);
548+
}
549+
421550
/**
422551
* Create a {@link PPLTypeChecker} from a list of allowed signatures consisted of {@link
423552
* ExprType}. This is useful to validate arguments against user-defined types (UDT) that does not

core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,24 @@
2727
public interface UDFOperandMetadata extends SqlOperandMetadata {
2828
SqlOperandTypeChecker getInnerTypeChecker();
2929

30+
/**
31+
* Check if this UDFOperandMetadata has a direct PPLTypeChecker.
32+
*
33+
* @return true if PPLTypeChecker is available, false if using wrapped SqlOperandTypeChecker
34+
*/
35+
default boolean hasPPLTypeChecker() {
36+
return false;
37+
}
38+
39+
/**
40+
* Get the direct PPLTypeChecker if available.
41+
*
42+
* @return PPLTypeChecker or null if not available
43+
*/
44+
default PPLTypeChecker getPPLTypeChecker() {
45+
return null;
46+
}
47+
3048
static UDFOperandMetadata wrap(FamilyOperandTypeChecker typeChecker) {
3149
return new UDFOperandMetadata() {
3250
@Override
@@ -106,6 +124,58 @@ public String getAllowedSignatures(SqlOperator op, String opName) {
106124
};
107125
}
108126

127+
static UDFOperandMetadata wrap(PPLTypeChecker pplTypeChecker) {
128+
return new UDFOperandMetadata() {
129+
@Override
130+
public SqlOperandTypeChecker getInnerTypeChecker() {
131+
return this;
132+
}
133+
134+
@Override
135+
public boolean hasPPLTypeChecker() {
136+
return true;
137+
}
138+
139+
@Override
140+
public PPLTypeChecker getPPLTypeChecker() {
141+
return pplTypeChecker;
142+
}
143+
144+
@Override
145+
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
146+
return Collections.emptyList();
147+
}
148+
149+
@Override
150+
public List<String> paramNames() {
151+
return Collections.emptyList();
152+
}
153+
154+
@Override
155+
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
156+
// Convert SqlCallBinding to List<RelDataType> and use PPLTypeChecker
157+
List<RelDataType> types = callBinding.collectOperandTypes();
158+
return pplTypeChecker.checkOperandTypes(types);
159+
}
160+
161+
@Override
162+
public SqlOperandCountRange getOperandCountRange() {
163+
// Extract operandCountRange from PPLSameTypeChecker if possible
164+
if (pplTypeChecker instanceof PPLTypeChecker.PPLSameTypeChecker) {
165+
PPLTypeChecker.PPLSameTypeChecker sameTypeChecker =
166+
(PPLTypeChecker.PPLSameTypeChecker) pplTypeChecker;
167+
return sameTypeChecker.getOperandCountRange();
168+
}
169+
return null;
170+
}
171+
172+
@Override
173+
public String getAllowedSignatures(SqlOperator op, String opName) {
174+
return pplTypeChecker.getAllowedSignatures();
175+
}
176+
};
177+
}
178+
109179
static UDFOperandMetadata wrapUDT(List<List<ExprType>> allowSignatures) {
110180
return new UDTOperandMetadata(allowSignatures);
111181
}

0 commit comments

Comments
 (0)