Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect dedup of non-deterministic functions #22686

Merged
merged 3 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ private ArrayReduceFunction()
.build())
.nullable()
.argumentNullability(false, true, false, false)
.nondeterministic()
hashhar marked this conversation as resolved.
Show resolved Hide resolved
.description("Reduce elements of the array into a single value")
.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ private ArrayTransformFunction()
.argumentType(arrayType(new TypeSignature("T")))
.argumentType(functionType(new TypeSignature("T"), new TypeSignature("U")))
.build())
.nondeterministic()
.description("Apply lambda to each element of the array")
.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ private MapFilterFunction()
.argumentType(mapType(new TypeSignature("K"), new TypeSignature("V")))
.argumentType(functionType(new TypeSignature("K"), new TypeSignature("V"), BOOLEAN.getTypeSignature()))
.build())
.nondeterministic()
.description("return map containing entries that match the given predicate")
.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ public MapTransformKeysFunction(BlockTypeOperators blockTypeOperators)
.argumentType(mapType(new TypeSignature("K1"), new TypeSignature("V")))
.argumentType(functionType(new TypeSignature("K1"), new TypeSignature("V"), new TypeSignature("K2")))
.build())
.nondeterministic()
.description("Apply lambda to each entry of the map and transform the key")
.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ private MapTransformValuesFunction()
.argumentType(mapType(new TypeSignature("K"), new TypeSignature("V1")))
.argumentType(functionType(new TypeSignature("K"), new TypeSignature("V1"), new TypeSignature("V2")))
.build())
.nondeterministic()
.description("Apply lambda to each entry of the map and transform the value")
.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ private MapZipWithFunction()
.argumentType(mapType(new TypeSignature("K"), new TypeSignature("V2")))
.argumentType(functionType(new TypeSignature("K"), new TypeSignature("V1"), new TypeSignature("V2"), new TypeSignature("V3")))
.build())
.nondeterministic()
.description("Merge two maps into a single map by applying the lambda function to the pair of values with the same key")
.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ private ZipWithFunction()
.argumentType(arrayType(new TypeSignature("U")))
.argumentType(functionType(new TypeSignature("T"), new TypeSignature("U"), new TypeSignature("R")))
.build())
.nondeterministic()
.description("Merge two arrays, element-wise, into a single array using the lambda function")
.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -672,9 +672,10 @@ public Set<ResolvedFunction> getResolvedFunctions()
.collect(toImmutableSet());
}

public ResolvedFunction getResolvedFunction(Node node)
public Optional<ResolvedFunction> getResolvedFunction(Node node)
{
return resolvedFunctions.get(NodeRef.of(node)).getFunction();
return Optional.ofNullable(resolvedFunctions.get(NodeRef.of(node)))
.map(RoutineEntry::getFunction);
}

public void addResolvedFunction(Node node, ResolvedFunction function, String authorization)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5625,9 +5625,9 @@ else if (field.getName().isPresent()) {

private ResolvedFunction getResolvedFunction(FunctionCall functionCall)
{
ResolvedFunction resolvedFunction = analysis.getResolvedFunction(functionCall);
verify(resolvedFunction != null, "function has not been analyzed yet: %s", functionCall);
return resolvedFunction;
Optional<ResolvedFunction> resolvedFunction = analysis.getResolvedFunction(functionCall);
verify(resolvedFunction.isPresent(), "function has not been analyzed yet: %s", functionCall);
return resolvedFunction.get();
}

private List<Expression> analyzeOrderBy(Node node, List<SortItem> sortItems, Scope orderByScope)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import io.trino.sql.tree.DefaultExpressionTraversalVisitor;
import io.trino.sql.tree.FunctionCall;

import java.util.Optional;

import static io.trino.spi.StandardErrorCode.MISSING_OVER;
import static io.trino.spi.function.FunctionKind.WINDOW;
import static io.trino.sql.analyzer.SemanticExceptions.semanticException;
Expand All @@ -32,9 +34,9 @@ protected Void visitFunctionCall(FunctionCall functionCall, Analysis analysis)

// pattern recognition functions are not resolved
if (!analysis.isPatternRecognitionFunction(functionCall)) {
ResolvedFunction resolvedFunction = analysis.getResolvedFunction(functionCall);
if (resolvedFunction != null && functionCall.getWindow().isEmpty() && resolvedFunction.functionKind() == WINDOW) {
throw semanticException(MISSING_OVER, functionCall, "Window function %s requires an OVER clause", resolvedFunction.signature().getName());
Optional<ResolvedFunction> resolvedFunction = analysis.getResolvedFunction(functionCall);
if (resolvedFunction.isPresent() && functionCall.getWindow().isEmpty() && resolvedFunction.get().functionKind() == WINDOW) {
throw semanticException(MISSING_OVER, functionCall, "Window function %s requires an OVER clause", resolvedFunction.get().signature().getName());
}
}
return super.visitFunctionCall(functionCall, analysis);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,7 @@ private PlanBuilder planAggregation(PlanBuilder subPlan, List<List<Symbol>> grou
// which requires that ORDER BY be a subset of arguments
// What can happen currently is that if the argument requires a coercion, the argument will take a different input that the ORDER BY clause, which is undefined behavior
Aggregation aggregation = new Aggregation(
analysis.getResolvedFunction(function),
analysis.getResolvedFunction(function).get(),
function.getArguments().stream()
.map(argument -> {
if (argument instanceof LambdaExpression) {
Expand Down Expand Up @@ -1769,7 +1769,7 @@ private PlanBuilder planWindow(
.orElse(NullTreatment.RESPECT);

WindowNode.Function function = new WindowNode.Function(
analysis.getResolvedFunction(windowFunction),
analysis.getResolvedFunction(windowFunction).get(),
windowFunction.getArguments().stream()
.map(argument -> {
if (argument instanceof LambdaExpression) {
Expand Down Expand Up @@ -1850,7 +1850,7 @@ private PlanBuilder planPatternRecognition(
.orElse(NullTreatment.RESPECT);

WindowNode.Function function = new WindowNode.Function(
analysis.getResolvedFunction(windowFunction),
analysis.getResolvedFunction(windowFunction).get(),
windowFunction.getArguments().stream()
.map(argument -> {
if (argument instanceof LambdaExpression) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1664,12 +1664,12 @@ private JsonTableColumn getColumn(
if (columnDefinition instanceof OrdinalityColumn) {
return new JsonTableOrdinalityColumn(index);
}
ResolvedFunction columnFunction = analysis.getResolvedFunction(columnDefinition);
Optional<ResolvedFunction> columnFunction = analysis.getResolvedFunction(columnDefinition);
IrJsonPath columnPath = new JsonPathTranslator(session, plannerContext).rewriteToIr(analysis.getJsonPathAnalysis(columnDefinition), ImmutableList.of());
if (columnDefinition instanceof QueryColumn queryColumn) {
return new JsonTableQueryColumn(
index,
columnFunction,
columnFunction.get(),
columnPath,
queryColumn.getWrapperBehavior().ordinal(),
queryColumn.getEmptyBehavior().ordinal(),
Expand All @@ -1684,7 +1684,7 @@ private JsonTableColumn getColumn(
.orElse(-1);
return new JsonTableValueColumn(
index,
columnFunction,
columnFunction.get(),
columnPath,
valueColumn.getEmptyBehavior().ordinal(),
emptyDefault,
Expand Down
11 changes: 11 additions & 0 deletions core/trino-main/src/main/java/io/trino/sql/planner/ScopeAware.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
*/
package io.trino.sql.planner;

import io.trino.metadata.ResolvedFunction;
import io.trino.sql.analyzer.Analysis;
import io.trino.sql.analyzer.CanonicalizationAware;
import io.trino.sql.analyzer.ResolvedField;
import io.trino.sql.analyzer.Scope;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.Node;

import java.util.Optional;
import java.util.OptionalInt;

import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -142,6 +145,14 @@ private Boolean scopeAwareComparison(Node left, Node right)
// References come from different scopes
return false;
}
if (left instanceof FunctionCall && right instanceof FunctionCall) {
Optional<ResolvedFunction> resolvedLeft = analysis.getResolvedFunction(left);
Optional<ResolvedFunction> resolvedRight = analysis.getResolvedFunction(right);

if ((resolvedLeft.isPresent() && !resolvedLeft.get().deterministic()) || (resolvedRight.isPresent() && !resolvedRight.get().deterministic())) {
return left == right;
}
}
if (leftExpression instanceof Identifier && rightExpression instanceof Identifier) {
return treeEqual(leftExpression, rightExpression, CanonicalizationAware::canonicalizationAwareComparison);
}
Expand Down
Loading