Skip to content

Commit

Permalink
[CALCITE-6704] Limit result size of RelMdUniqueKeys handler
Browse files Browse the repository at this point in the history
For certain query patterns RelMdUniqueKeys handler generates an
exponentially large number of unique keys that results into crashes
and OOM errors. The limit guards against the combinatorial explosion
that may appear for such use-cases and provides the users of a way to
tune further the upper bound if needed.
  • Loading branch information
zabetak committed Dec 12, 2024
1 parent 041619f commit b70025a
Show file tree
Hide file tree
Showing 4 changed files with 390 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,41 @@
/**
* RelMdUniqueKeys supplies a default implementation of
* {@link RelMetadataQuery#getUniqueKeys} for the standard logical algebra.
* The number of returned keys for each relational expression is bounded by a limit.
* The limit is used to restrict the exponential logic that can appear for certain query patterns
* and lead to CPU/memory exhaustion and crashes.
*/
public class RelMdUniqueKeys
implements MetadataHandler<BuiltInMetadata.UniqueKeys> {
public static final RelMetadataProvider SOURCE =
ReflectiveRelMetadataProvider.reflectiveSource(
new RelMdUniqueKeys(), BuiltInMetadata.UniqueKeys.Handler.class);

/**
* A limit about the number of unique keys returned by the handler.
* The limit must be in the range [0, Integer.MAX_VALUE].
*/
private final int limit;
//~ Constructors -----------------------------------------------------------

private RelMdUniqueKeys() {}
/**
* Creates a metadata handler for unique keys with the default limit.
*/
public RelMdUniqueKeys() {
this(1000);
}

/**
* Creates a metadata handler for unique keys with the specified limit.
*
* @param limit a non-negative integer that bounds the number of unique keys returned for each
* relational expression.
*/
public RelMdUniqueKeys(int limit) {
if (limit < 0) {
throw new IllegalArgumentException("Limit cannot be negative");
}
this.limit = limit;
}

//~ Methods ----------------------------------------------------------------

Expand Down Expand Up @@ -102,6 +127,9 @@ private RelMdUniqueKeys() {}

public @Nullable Set<ImmutableBitSet> getUniqueKeys(Sort rel, RelMetadataQuery mq,
boolean ignoreNulls) {
if (limit == 0) {
return ImmutableSet.of();
}
Double maxRowCount = mq.getMaxRowCount(rel);
if (maxRowCount != null && maxRowCount <= 1.0d) {
return ImmutableSet.of(ImmutableBitSet.of());
Expand Down Expand Up @@ -131,7 +159,7 @@ public Set<ImmutableBitSet> getUniqueKeys(Project rel, RelMetadataQuery mq,
Util.transform(program.getProjectList(), program::expandLocalRef));
}

private static Set<ImmutableBitSet> getProjectUniqueKeys(SingleRel rel, RelMetadataQuery mq,
private Set<ImmutableBitSet> getProjectUniqueKeys(SingleRel rel, RelMetadataQuery mq,
boolean ignoreNulls, List<RexNode> projExprs) {
// LogicalProject maps a set of rows to a different set;
// Without knowledge of the mapping function(whether it
Expand Down Expand Up @@ -171,9 +199,10 @@ private static Set<ImmutableBitSet> getProjectUniqueKeys(SingleRel rel, RelMetad

Multimap<Integer, Integer> mapInToOutPos = inToOutPosBuilder.build();

ImmutableSet.Builder<ImmutableBitSet> resultBuilder = ImmutableSet.builder();
Set<ImmutableBitSet> resultBuilder = new HashSet<>();
// Now add to the projUniqueKeySet the child keys that are fully
// projected.
outerLoop:
for (ImmutableBitSet colMask : childUniqueKeySet) {
if (!inColumnsUsed.contains(colMask)) {
// colMask contains a column that is not projected as RexInput => the key is not unique
Expand All @@ -184,10 +213,14 @@ private static Set<ImmutableBitSet> getProjectUniqueKeys(SingleRel rel, RelMetad
// the resulting unique keys would be {{0},{4}}, {{1},{4}}

Iterable<List<Integer>> product = Linq4j.product(Util.transform(colMask, mapInToOutPos::get));

resultBuilder.addAll(Util.transform(product, ImmutableBitSet::of));
for (List<Integer> passKey : product) {
if (resultBuilder.size() == limit) {
break outerLoop;
}
resultBuilder.add(ImmutableBitSet.of(passKey));
}
}
return resultBuilder.build();
return resultBuilder;
}

public @Nullable Set<ImmutableBitSet> getUniqueKeys(Join rel, RelMetadataQuery mq,
Expand Down Expand Up @@ -272,17 +305,7 @@ private static Set<ImmutableBitSet> getProjectUniqueKeys(SingleRel rel, RelMetad
.forEach(retSet::add);
}

// Remove sets that are supersets of other sets
final Set<ImmutableBitSet> reducedSet = new HashSet<>();
for (ImmutableBitSet bigger : retSet) {
if (retSet.stream()
.filter(smaller -> !bigger.equals(smaller))
.noneMatch(bigger::contains)) {
reducedSet.add(bigger);
}
}

return reducedSet;
return filterSupersets(retSet, limit);
}

/**
Expand Down Expand Up @@ -333,17 +356,23 @@ public Set<ImmutableBitSet> getUniqueKeys(Aggregate rel, RelMetadataQuery mq,

// If an input's unique column(s) value is returned (passed through) by an aggregation
// function, then the result of the function(s) is also unique.
final ImmutableSet.Builder<ImmutableBitSet> keysBuilder = ImmutableSet.builder();
Set<ImmutableBitSet> keysBuilder = new HashSet<>();
if (inputUniqueKeys != null) {
outerLoop:
for (ImmutableBitSet inputKey : inputUniqueKeys) {
Iterable<List<Integer>> product =
Linq4j.product(Util.transform(inputKey, i -> getPassedThroughCols(i, rel)));
keysBuilder.addAll(Util.transform(product, ImmutableBitSet::of));
for (List<Integer> passKey : product) {
if (keysBuilder.size() == limit) {
break outerLoop;
}
keysBuilder.add(ImmutableBitSet.of(passKey));
}
}
}

return filterSupersets(Sets.union(preciseUniqueKeys, keysBuilder.build()));
} else if (ignoreNulls) {
return filterSupersets(Sets.union(preciseUniqueKeys, keysBuilder), limit);
} else if (ignoreNulls && limit > 0) {
// group by keys form a unique key
return ImmutableSet.of(rel.getGroupSet());
} else {
Expand All @@ -358,7 +387,7 @@ public Set<ImmutableBitSet> getUniqueKeys(Aggregate rel, RelMetadataQuery mq,
* other keys. Given {@code {0},{1},{1,2}}, returns {@code {0},{1}}.
*/
private static Set<ImmutableBitSet> filterSupersets(
Set<ImmutableBitSet> uniqueKeys) {
Set<ImmutableBitSet> uniqueKeys, int limit) {
Set<ImmutableBitSet> minimalKeys = new HashSet<>();
outer:
for (ImmutableBitSet candidateKey : uniqueKeys) {
Expand All @@ -368,6 +397,9 @@ private static Set<ImmutableBitSet> filterSupersets(
continue outer;
}
}
if (minimalKeys.size() == limit) {
break outer;
}
minimalKeys.add(candidateKey);
}
return minimalKeys;
Expand Down Expand Up @@ -399,7 +431,7 @@ private static ImmutableBitSet getPassedThroughCols(Integer inputColumn,

public Set<ImmutableBitSet> getUniqueKeys(Union rel, RelMetadataQuery mq,
boolean ignoreNulls) {
if (!rel.all) {
if (!rel.all && limit > 0) {
return ImmutableSet.of(
ImmutableBitSet.range(rel.getRowType().getFieldCount()));
}
Expand All @@ -411,19 +443,24 @@ public Set<ImmutableBitSet> getUniqueKeys(Union rel, RelMetadataQuery mq,
*/
public Set<ImmutableBitSet> getUniqueKeys(Intersect rel,
RelMetadataQuery mq, boolean ignoreNulls) {
ImmutableSet.Builder<ImmutableBitSet> keys = new ImmutableSet.Builder<>();
Set<ImmutableBitSet> keys = new HashSet<>();
outerLoop:
for (RelNode input : rel.getInputs()) {
Set<ImmutableBitSet> uniqueKeys = mq.getUniqueKeys(input, ignoreNulls);
if (uniqueKeys != null) {
keys.addAll(uniqueKeys);
for (ImmutableBitSet inKey : uniqueKeys) {
if (keys.size() == limit) {
break outerLoop;
}
keys.add(inKey);
}
}
}
ImmutableSet<ImmutableBitSet> uniqueKeys = keys.build();
if (!uniqueKeys.isEmpty()) {
return uniqueKeys;
if (!keys.isEmpty()) {
return keys;
}

if (!rel.all) {
if (!rel.all && limit > 0) {
return ImmutableSet.of(
ImmutableBitSet.range(rel.getRowType().getFieldCount()));
}
Expand All @@ -440,7 +477,7 @@ public Set<ImmutableBitSet> getUniqueKeys(Minus rel,
return uniqueKeys;
}

if (!rel.all) {
if (!rel.all && limit > 0) {
return ImmutableSet.of(
ImmutableBitSet.range(rel.getRowType().getFieldCount()));
}
Expand All @@ -459,10 +496,15 @@ public Set<ImmutableBitSet> getUniqueKeys(Minus rel,
if (keys == null) {
return null;
}
Set<ImmutableBitSet> result = new HashSet<>(Math.min(keys.size(), limit));
for (ImmutableBitSet key : keys) {
if (result.size() == limit) {
break;
}
assert rel.getTable().isKey(key);
result.add(key);
}
return ImmutableSet.copyOf(keys);
return result;
}

public @Nullable Set<ImmutableBitSet> getUniqueKeys(Values rel, RelMetadataQuery mq,
Expand All @@ -484,10 +526,15 @@ public Set<ImmutableBitSet> getUniqueKeys(Minus rel,
}

ImmutableSet.Builder<ImmutableBitSet> keySetBuilder = ImmutableSet.builder();
int keySetSize = 0;
for (int i = 0; i < ranges.size(); i++) {
final Set<RexLiteral> range = ranges.get(i);
if (keySetSize == limit) {
break;
}
if (range.size() == tuples.size()) {
keySetBuilder.add(ImmutableBitSet.of(i));
keySetSize++;
}
}
return keySetBuilder.build();
Expand Down
Loading

0 comments on commit b70025a

Please sign in to comment.