From 26ff736eb9c401c768d385d776c7e40045728c9e Mon Sep 17 00:00:00 2001 From: mikhail-khludnev Date: Tue, 12 Nov 2024 00:52:42 +0300 Subject: [PATCH] fallback to MultiRange only when too many clauses Signed-off-by: mikhail-khludnev --- .../index/mapper/IpFieldMapper.java | 121 +++++++++++++----- 1 file changed, 86 insertions(+), 35 deletions(-) diff --git a/server/src/main/java/org/opensearch/index/mapper/IpFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/IpFieldMapper.java index 21d88556c5458..026d3846a4973 100644 --- a/server/src/main/java/org/opensearch/index/mapper/IpFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/IpFieldMapper.java @@ -41,6 +41,7 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.ConstantScoreQuery; import org.apache.lucene.search.IndexOrDocValuesQuery; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.PointRangeQuery; import org.apache.lucene.search.Query; @@ -69,6 +70,8 @@ import java.util.List; import java.util.Map; import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.IntSupplier; import java.util.function.Supplier; /** @@ -273,42 +276,42 @@ public Query termQuery(Object value, @Nullable QueryShardContext context) { @Override public Query termsQuery(List values, QueryShardContext context) { failIfNotIndexedAndNoDocValues(); - List concreteIPs = new ArrayList<>(); - List ranges = new ArrayList<>(); - IpMultiRangeQueryBuilder multiRange = new IpMultiRangeQueryBuilder(name()); - boolean multiRangeIsEmpty = true; - for (final Object value : values) { - if (value instanceof InetAddress) { - concreteIPs.add((InetAddress) value); - } else { - final String strVal = (value instanceof BytesRef) ? ((BytesRef) value).utf8ToString() : value.toString(); - if (strVal.contains("/")) { - // the `terms` query contains some prefix queries, so we cannot create a set query - // and need to fall back to a disjunction of `term` queries - // Query query = termQuery(strVal, context); + Tuple, List> ipsMasks = splitIpsAndMasks(values); + QueryUnion combiner = new QueryUnion(); + convertIps(ipsMasks.v1(), combiner); + convertMasks(ipsMasks.v2(), context, combiner, combiner.getAsInt()); + return combiner.get(); + } + + private void convertMasks(List masks, QueryShardContext context, Consumer combiner, int clauses) { + if (!masks.isEmpty()) { + // attempting to avoid too many exception at best + if (masks.size() + clauses >= IndexSearcher.getMaxClauseCount() - 1 && isSearchable()) { + IpMultiRangeQueryBuilder multiRange = new IpMultiRangeQueryBuilder(name()); + for (String strVal : masks) { final Tuple cidr = InetAddresses.parseCidr(strVal); PointRangeQuery query = (PointRangeQuery) InetAddressPoint.newPrefixQuery(name(), cidr.v1(), cidr.v2()); - // would be great to have union on ranges over bare points - // ranges.add(query); multiRange.add(query.getLowerPoint(), query.getUpperPoint()); - multiRangeIsEmpty = false; - } else { - concreteIPs.add(InetAddresses.forString(strVal)); + } + combiner.accept(multiRange.build()); + } else { + for (String strVal : masks) { + combiner.accept(termQuery(strVal, context)); } } } - if (!multiRangeIsEmpty) { - ranges.add(multiRange.build()); - } - if (!concreteIPs.isEmpty()) { + } + + private void convertIps(List inetAddresses, Consumer combiner) { + if (!inetAddresses.isEmpty()) { Supplier pointsQuery; - pointsQuery = () -> concreteIPs.size() == 1 - ? InetAddressPoint.newExactQuery(name(), concreteIPs.iterator().next()) - : InetAddressPoint.newSetQuery(name(), concreteIPs.toArray(new InetAddress[0])); + pointsQuery = () -> inetAddresses.size() == 1 + ? InetAddressPoint.newExactQuery(name(), inetAddresses.iterator().next()) + : InetAddressPoint.newSetQuery(name(), inetAddresses.toArray(new InetAddress[0])); if (hasDocValues()) { - List set = new ArrayList<>(concreteIPs.size()); - for (final InetAddress address : concreteIPs) { + List set = new ArrayList<>(inetAddresses.size()); + for (final InetAddress address : inetAddresses) { set.add(new BytesRef(InetAddressPoint.encode(address))); } Query dvQuery = SortedSetDocValuesField.newSlowSetQuery(name(), set); @@ -319,16 +322,26 @@ public Query termsQuery(List values, QueryShardContext context) { pointsQuery = () -> new IndexOrDocValuesQuery(wrap.get(), dvQuery); } } - ranges.add(pointsQuery.get()); - } - if (ranges.size() == 1) { - return ranges.iterator().next(); // CSQ? + combiner.accept(pointsQuery.get()); } - BooleanQuery.Builder union = new BooleanQuery.Builder(); - for (Query q : ranges) { - union.add(q, BooleanClause.Occur.SHOULD); + } + + private static Tuple, List> splitIpsAndMasks(List values) { + List concreteIPs = new ArrayList<>(); + List masks = new ArrayList<>(); + for (final Object value : values) { + if (value instanceof InetAddress) { + concreteIPs.add((InetAddress) value); + } else { + final String strVal = (value instanceof BytesRef) ? ((BytesRef) value).utf8ToString() : value.toString(); + if (strVal.contains("/")) { + masks.add(strVal); + } else { + concreteIPs.add(InetAddresses.forString(strVal)); + } + } } - return new ConstantScoreQuery(union.build()); + return Tuple.tuple(concreteIPs, masks); } @Override @@ -480,6 +493,44 @@ public DocValueFormat docValueFormat(@Nullable String format, ZoneId timeZone) { } return DocValueFormat.IP; } + + private static class QueryUnion implements Consumer, Supplier, IntSupplier { + Query first; + BooleanQuery.Builder union; + int cnt; + + @Override + public void accept(Query query) { + if (first == null) { + first = query; + } else { + if (union == null) { + union = new BooleanQuery.Builder(); + union.add(first, BooleanClause.Occur.SHOULD); + } + union.add(query, BooleanClause.Occur.SHOULD); + } + cnt++; + } + + @Override + public Query get() { + if (union != null) { + return new ConstantScoreQuery(union.build()); + } else { + if (first != null) { + return first; + } else { // no matches then + return new BooleanQuery.Builder().build(); + } + } + } + + @Override + public int getAsInt() { + return cnt; + } + } } /**