diff --git a/server/src/main/java/org/opensearch/index/query/IntervalsSourceProvider.java b/server/src/main/java/org/opensearch/index/query/IntervalsSourceProvider.java
index 44a14f3b3dec8..3461a0ebcf3ca 100644
--- a/server/src/main/java/org/opensearch/index/query/IntervalsSourceProvider.java
+++ b/server/src/main/java/org/opensearch/index/query/IntervalsSourceProvider.java
@@ -39,6 +39,7 @@
 import org.apache.lucene.queries.intervals.IntervalsSource;
 import org.apache.lucene.search.FuzzyQuery;
 import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.automaton.CompiledAutomaton;
 import org.opensearch.LegacyESVersion;
 import org.opensearch.Version;
 import org.opensearch.common.ParseField;
@@ -101,12 +102,14 @@ public static IntervalsSourceProvider fromXContent(XContentParser parser) throws
                 return Prefix.fromXContent(parser);
             case "wildcard":
                 return Wildcard.fromXContent(parser);
+            case "regexp":
+                return Regexp.fromXContent(parser);
             case "fuzzy":
                 return Fuzzy.fromXContent(parser);
         }
         throw new ParsingException(
             parser.getTokenLocation(),
-            "Unknown interval type [" + parser.currentName() + "], expecting one of [match, any_of, all_of, prefix, wildcard]"
+            "Unknown interval type [" + parser.currentName() + "], expecting one of [match, any_of, all_of, prefix, wildcard, regexp]"
         );
     }
 
@@ -631,6 +634,157 @@ String getUseField() {
         }
     }
 
+    /**
+     * Interval source provider for the {@code regexp} rule: matches terms accepted by a regular
+     * expression, optionally against {@code use_field} and with a cap on term expansions.
+     */
+    public static class Regexp extends IntervalsSourceProvider {
+
+        public static final String NAME = "regexp";
+        public static final int DEFAULT_FLAGS_VALUE = RegexpFlag.ALL.value();
+
+        private final String pattern;
+        private final int flags;
+        private final String useField;
+        private final Integer maxExpansions;
+
+        public Regexp(String pattern, int flags, String useField, Integer maxExpansions) {
+            this.pattern = pattern;
+            this.flags = flags;
+            this.useField = useField;
+            // A non-positive max_expansions is treated the same as an absent one.
+            this.maxExpansions = (maxExpansions != null && maxExpansions > 0) ? maxExpansions : null;
+        }
+
+        public Regexp(StreamInput in) throws IOException {
+            this.pattern = in.readString();
+            this.flags = in.readVInt();
+            this.useField = in.readOptionalString();
+            this.maxExpansions = in.readOptionalVInt();
+        }
+
+        @Override
+        public IntervalsSource getSource(QueryShardContext context, MappedFieldType fieldType) {
+            final org.apache.lucene.util.automaton.RegExp regexp = new org.apache.lucene.util.automaton.RegExp(pattern, flags);
+            final CompiledAutomaton automaton = new CompiledAutomaton(regexp.toAutomaton());
+
+            if (useField != null) {
+                // When use_field is set, that field (not the query's target field) is the one validated below.
+                fieldType = context.fieldMapper(useField);
+                assert fieldType != null;
+            }
+            checkPositions(fieldType);
+
+            final IntervalsSource regexpSource = maxExpansions == null
+                ? Intervals.multiterm(automaton, regexp.toString())
+                : Intervals.multiterm(automaton, maxExpansions, regexp.toString());
+            return useField == null ? regexpSource : Intervals.fixField(useField, regexpSource);
+        }
+
+        private void checkPositions(MappedFieldType type) {
+            if (type.getTextSearchInfo().hasPositions() == false) {
+                throw new IllegalArgumentException("Cannot create intervals over field [" + type.name() + "] with no positions indexed");
+            }
+        }
+
+        @Override
+        public void extractFields(Set<String> fields) {
+            if (useField != null) {
+                fields.add(useField);
+            }
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Regexp regexp = (Regexp) o;
+            return Objects.equals(pattern, regexp.pattern)
+                && flags == regexp.flags
+                && Objects.equals(useField, regexp.useField)
+                && Objects.equals(maxExpansions, regexp.maxExpansions);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(pattern, flags, useField, maxExpansions);
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(pattern);
+            out.writeVInt(flags);
+            out.writeOptionalString(useField);
+            out.writeOptionalVInt(maxExpansions);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject(NAME);
+            builder.field("pattern", pattern);
+            if (flags != DEFAULT_FLAGS_VALUE) {
+                builder.field("flags_value", flags);
+            }
+            if (useField != null) {
+                builder.field("use_field", useField);
+            }
+            if (maxExpansions != null) {
+                builder.field("max_expansions", maxExpansions);
+            }
+            builder.endObject();
+            return builder;
+        }
+
+        private static final ConstructingObjectParser<Regexp, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
+            String pattern = (String) args[0];
+            String flags = (String) args[1];
+            Integer flagsValue = (Integer) args[2];
+            String useField = (String) args[3];
+            Integer maxExpansions = (Integer) args[4];
+
+            // flags_value wins over flags when both are supplied.
+            if (flagsValue != null) {
+                return new Regexp(pattern, flagsValue, useField, maxExpansions);
+            } else if (flags != null) {
+                return new Regexp(pattern, RegexpFlag.resolveValue(flags), useField, maxExpansions);
+            } else {
+                return new Regexp(pattern, DEFAULT_FLAGS_VALUE, useField, maxExpansions);
+            }
+        });
+        static {
+            PARSER.declareString(constructorArg(), new ParseField("pattern"));
+            PARSER.declareString(optionalConstructorArg(), new ParseField("flags"));
+            PARSER.declareInt(optionalConstructorArg(), new ParseField("flags_value"));
+            PARSER.declareString(optionalConstructorArg(), new ParseField("use_field"));
+            PARSER.declareInt(optionalConstructorArg(), new ParseField("max_expansions"));
+        }
+
+        public static Regexp fromXContent(XContentParser parser) throws IOException {
+            return PARSER.parse(parser, null);
+        }
+
+        String getPattern() {
+            return pattern;
+        }
+
+        int getFlags() {
+            return flags;
+        }
+
+        String getUseField() {
+            return useField;
+        }
+
+        Integer getMaxExpansions() {
+            return maxExpansions;
+        }
+    }
+
     public static class Wildcard extends IntervalsSourceProvider {
 
         public static final String NAME = "wildcard";
diff --git a/server/src/main/java/org/opensearch/search/SearchModule.java b/server/src/main/java/org/opensearch/search/SearchModule.java
index 367fb28809cfa..cdc2509bbcb00 100644
--- a/server/src/main/java/org/opensearch/search/SearchModule.java
+++ b/server/src/main/java/org/opensearch/search/SearchModule.java
@@ -1254,6 +1254,11 @@ public static List<NamedWriteableRegistry.Entry> getIntervalsSourceProviderNamed
                 IntervalsSourceProvider.Wildcard.NAME,
                 IntervalsSourceProvider.Wildcard::new
             ),
+            new NamedWriteableRegistry.Entry(
+                IntervalsSourceProvider.class,
+                IntervalsSourceProvider.Regexp.NAME,
+                IntervalsSourceProvider.Regexp::new
+            ),
             new NamedWriteableRegistry.Entry(
                 IntervalsSourceProvider.class,
                 IntervalsSourceProvider.Fuzzy.NAME,
diff --git a/server/src/test/java/org/opensearch/index/query/IntervalQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/query/IntervalQueryBuilderTests.java
index 11f8c165877ae..9dd991f200714 100644
--- a/server/src/test/java/org/opensearch/index/query/IntervalQueryBuilderTests.java
+++ b/server/src/test/java/org/opensearch/index/query/IntervalQueryBuilderTests.java
@@ -42,6 +42,8 @@
 import org.apache.lucene.search.MatchNoDocsQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.automaton.CompiledAutomaton;
+import org.apache.lucene.util.automaton.RegExp;
 import org.opensearch.common.ParsingException;
 import org.opensearch.common.Strings;
 import org.opensearch.common.compress.CompressedXContent;
@@ -686,6 +688,115 @@ public void testWildcard() throws IOException {
         });
     }
 
+    // Builds the multiterm source the assertions below compare the parsed regexp query against.
+    private static IntervalsSource buildRegexpSource(String pattern, int flags, Integer maxExpansions) {
+        final RegExp regexp = new RegExp(pattern, flags);
+        CompiledAutomaton automaton = new CompiledAutomaton(regexp.toAutomaton());
+
+        if (maxExpansions != null) {
+            return Intervals.multiterm(automaton, maxExpansions, regexp.toString());
+        } else {
+            return Intervals.multiterm(automaton, regexp.toString());
+        }
+    }
+
+    public void testRegexp() throws IOException {
+        final int DEFAULT_FLAGS = RegexpFlag.ALL.value();
+        String json = "{ \"intervals\" : { \"" + TEXT_FIELD_NAME + "\": { " + "\"regexp\" : { \"pattern\" : \"te.m\" } } } }";
+
+        IntervalQueryBuilder builder = (IntervalQueryBuilder) parseQuery(json);
+        Query expected = new IntervalQuery(TEXT_FIELD_NAME, buildRegexpSource("te.m", DEFAULT_FLAGS, null));
+        assertEquals(expected, builder.toQuery(createShardContext()));
+
+        String no_positions_json = "{ \"intervals\" : { \""
+            + NO_POSITIONS_FIELD
+            + "\": { "
+            + "\"regexp\" : { \"pattern\" : \"[Tt]erm\" } } } }";
+        expectThrows(IllegalArgumentException.class, () -> {
+            IntervalQueryBuilder builder1 = (IntervalQueryBuilder) parseQuery(no_positions_json);
+            builder1.toQuery(createShardContext());
+        });
+
+        String fixed_field_json = "{ \"intervals\" : { \""
+            + TEXT_FIELD_NAME
+            + "\": { "
+            + "\"regexp\" : { \"pattern\" : \"te.m\", \"use_field\" : \"masked_field\" } } } }";
+
+        builder = (IntervalQueryBuilder) parseQuery(fixed_field_json);
+        expected = new IntervalQuery(TEXT_FIELD_NAME, Intervals.fixField(MASKED_FIELD, buildRegexpSource("te.m", DEFAULT_FLAGS, null)));
+        assertEquals(expected, builder.toQuery(createShardContext()));
+
+        String fixed_field_json_no_positions = "{ \"intervals\" : { \""
+            + TEXT_FIELD_NAME
+            + "\": { "
+            + "\"regexp\" : { \"pattern\" : \"te.m\", \"use_field\" : \""
+            + NO_POSITIONS_FIELD
+            + "\" } } } }";
+        expectThrows(IllegalArgumentException.class, () -> {
+            IntervalQueryBuilder builder1 = (IntervalQueryBuilder) parseQuery(fixed_field_json_no_positions);
+            builder1.toQuery(createShardContext());
+        });
+
+        String flags_json = "{ \"intervals\" : { \""
+            + TEXT_FIELD_NAME
+            + "\": { "
+            + "\"regexp\" : { \"pattern\" : \"te.m\", \"flags\" : \"NONE\" } } } }";
+
+        builder = (IntervalQueryBuilder) parseQuery(flags_json);
+        expected = new IntervalQuery(TEXT_FIELD_NAME, buildRegexpSource("te.m", RegexpFlag.NONE.value(), null));
+        assertEquals(expected, builder.toQuery(createShardContext()));
+
+        String flags_value_json = "{ \"intervals\" : { \""
+            + TEXT_FIELD_NAME
+            + "\": { "
+            + "\"regexp\" : { \"pattern\" : \"te.m\", \"flags_value\" : \""
+            + RegexpFlag.ANYSTRING.value()
+            + "\" } } } }";
+
+        builder = (IntervalQueryBuilder) parseQuery(flags_value_json);
+        expected = new IntervalQuery(TEXT_FIELD_NAME, buildRegexpSource("te.m", RegexpFlag.ANYSTRING.value(), null));
+        assertEquals(expected, builder.toQuery(createShardContext()));
+
+        String regexp_max_expand_json = "{ \"intervals\" : { \""
+            + TEXT_FIELD_NAME
+            + "\": { "
+            + "\"regexp\" : { \"pattern\" : \"te.m\", \"max_expansions\" : 500 } } } }";
+
+        builder = (IntervalQueryBuilder) parseQuery(regexp_max_expand_json);
+        expected = new IntervalQuery(TEXT_FIELD_NAME, buildRegexpSource("te.m", DEFAULT_FLAGS, 500));
+        assertEquals(expected, builder.toQuery(createShardContext()));
+
+        String regexp_neg_max_expand_json = "{ \"intervals\" : { \""
+            + TEXT_FIELD_NAME
+            + "\": { "
+            + "\"regexp\" : { \"pattern\" : \"te.m\", \"max_expansions\" : -20 } } } }";
+
+        builder = (IntervalQueryBuilder) parseQuery(regexp_neg_max_expand_json);
+        // max expansions use default
+        expected = new IntervalQuery(TEXT_FIELD_NAME, buildRegexpSource("te.m", DEFAULT_FLAGS, null));
+        assertEquals(expected, builder.toQuery(createShardContext()));
+
+        String regexp_over_max_expand_json = "{ \"intervals\" : { \""
+            + TEXT_FIELD_NAME
+            + "\": { "
+            + "\"regexp\" : { \"pattern\" : \"te.m\", \"max_expansions\" : "
+            + (BooleanQuery.getMaxClauseCount() + 1)
+            + " } } } }";
+        expectThrows(IllegalArgumentException.class, () -> {
+            IntervalQueryBuilder builder1 = (IntervalQueryBuilder) parseQuery(regexp_over_max_expand_json);
+            builder1.toQuery(createShardContext());
+        });
+
+        String regexp_max_expand_with_flags_json = "{ \"intervals\" : { \""
+            + TEXT_FIELD_NAME
+            + "\": { "
+            + "\"regexp\" : { \"pattern\" : \"te.m\", \"flags\": \"NONE\", \"max_expansions\" : 500 } } } }";
+
+        builder = (IntervalQueryBuilder) parseQuery(regexp_max_expand_with_flags_json);
+        expected = new IntervalQuery(TEXT_FIELD_NAME, buildRegexpSource("te.m", RegexpFlag.NONE.value(), 500));
+        assertEquals(expected, builder.toQuery(createShardContext()));
+    }
+
     private static IntervalsSource buildFuzzySource(String term, String label, int prefixLength, boolean transpositions, int editDistance) {
         FuzzyQuery fq = new FuzzyQuery(new Term("field", term), editDistance, prefixLength, 128, transpositions);
         return Intervals.multiterm(fq.getAutomata(), label);
diff --git a/server/src/test/java/org/opensearch/index/query/RegexpIntervalsSourceProviderTests.java b/server/src/test/java/org/opensearch/index/query/RegexpIntervalsSourceProviderTests.java
new file mode 100644
index 0000000000000..ba97bdddf52ff
--- /dev/null
+++ b/server/src/test/java/org/opensearch/index/query/RegexpIntervalsSourceProviderTests.java
@@ -0,0 +1,82 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.index.query;
+
+import static org.opensearch.index.query.IntervalsSourceProvider.Regexp;
+import static org.opensearch.index.query.IntervalsSourceProvider.fromXContent;
+
+import org.opensearch.common.io.stream.Writeable;
+import org.opensearch.common.xcontent.XContentParser;
+import org.opensearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Supplies random {@link Regexp} instances, a stream reader and an XContent parse hook to the base test case.
+ */
+public class RegexpIntervalsSourceProviderTests extends AbstractSerializingTestCase<Regexp> {
+    private static final List<String> FLAGS = Arrays.asList("INTERSECTION", "COMPLEMENT", "EMPTY", "ANYSTRING", "INTERVAL", "NONE");
+
+    @Override
+    protected Regexp createTestInstance() {
+        return createRandomRegexp();
+    }
+
+    static Regexp createRandomRegexp() {
+        return new Regexp(
+            randomAlphaOfLengthBetween(0, 3) + (randomBoolean() ? ".*?" : "." + randomAlphaOfLength(4)) + randomAlphaOfLengthBetween(0, 5),
+            randomBoolean() ? RegexpFlag.resolveValue(randomFrom(FLAGS)) : RegexpFlag.ALL.value(),
+            randomBoolean() ? randomAlphaOfLength(10) : null,
+            randomBoolean() ? randomIntBetween(-1, Integer.MAX_VALUE) : null
+        );
+    }
+
+    @Override
+    protected Regexp mutateInstance(Regexp instance) throws IOException {
+        String pattern = instance.getPattern();
+        int flags = instance.getFlags();
+        String useField = instance.getUseField();
+        Integer maxExpansions = instance.getMaxExpansions();
+        int ran = between(0, 3);
+        switch (ran) {
+            case 0:
+                pattern += randomBoolean() ? ".*?" : randomAlphaOfLength(5);
+                break;
+            case 1:
+                flags = (flags == RegexpFlag.ALL.value()) ? RegexpFlag.resolveValue(randomFrom(FLAGS)) : RegexpFlag.ALL.value();
+                break;
+            case 2:
+                useField = useField == null ? randomAlphaOfLength(5) : null;
+                break;
+            case 3:
+                maxExpansions = maxExpansions == null ? randomIntBetween(1, Integer.MAX_VALUE) : null;
+                break;
+            default:
+                throw new AssertionError("Illegal randomisation branch");
+        }
+        return new Regexp(pattern, flags, useField, maxExpansions);
+    }
+
+    @Override
+    protected Writeable.Reader<Regexp> instanceReader() {
+        return Regexp::new;
+    }
+
+    @Override
+    protected Regexp doParseInstance(XContentParser parser) throws IOException {
+        if (parser.nextToken() == XContentParser.Token.START_OBJECT) {
+            parser.nextToken();
+        }
+        Regexp regexp = (Regexp) fromXContent(parser);
+        assertEquals(XContentParser.Token.END_OBJECT, parser.nextToken());
+        return regexp;
+    }
+}