Skip to content

Commit

Permalink
EQL: implement case sensitivity for indexOf and endsWith string funct…
Browse files Browse the repository at this point in the history
…ions (elastic#57707)

* EQL: implement case sensitivity for indexOf and endsWith string functions
  • Loading branch information
aleksmaus authored Jun 9, 2020
1 parent 9a43e3e commit e808026
Show file tree
Hide file tree
Showing 12 changed files with 287 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@

package org.elasticsearch.xpack.eql.expression.function.scalar.string;

import org.elasticsearch.xpack.eql.session.EqlConfiguration;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.Expressions.ParamOrdinal;
import org.elasticsearch.xpack.ql.expression.FieldAttribute;
import org.elasticsearch.xpack.ql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.ql.expression.function.scalar.string.CaseSensitiveScalarFunction;
import org.elasticsearch.xpack.ql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.ql.expression.gen.script.Scripts;
import org.elasticsearch.xpack.ql.session.Configuration;
import org.elasticsearch.xpack.ql.tree.NodeInfo;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
Expand All @@ -32,17 +34,22 @@
* Function that checks if first parameter ends with the second parameter. Both parameters should be strings
* and the function returns a boolean value. The function is case insensitive.
*/
public class EndsWith extends ScalarFunction {
public class EndsWith extends CaseSensitiveScalarFunction {

private final Expression source;
private final Expression pattern;

public EndsWith(Source source, Expression src, Expression pattern) {
super(source, Arrays.asList(src, pattern));
public EndsWith(Source source, Expression src, Expression pattern, Configuration configuration) {
super(source, Arrays.asList(src, pattern), configuration);
this.source = src;
this.pattern = pattern;
}

@Override
public boolean isCaseSensitive() {
return ((EqlConfiguration) configuration()).isCaseSensitive();
}

@Override
protected TypeResolution resolveType() {
if (!childrenResolved()) {
Expand All @@ -59,7 +66,7 @@ protected TypeResolution resolveType() {

@Override
protected Pipe makePipe() {
return new EndsWithFunctionPipe(source(), this, Expressions.pipe(source), Expressions.pipe(pattern));
return new EndsWithFunctionPipe(source(), this, Expressions.pipe(source), Expressions.pipe(pattern), isCaseSensitive());
}

@Override
Expand All @@ -69,12 +76,12 @@ public boolean foldable() {

@Override
public Object fold() {
return doProcess(source.fold(), pattern.fold());
return doProcess(source.fold(), pattern.fold(), isCaseSensitive());
}

@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, EndsWith::new, source, pattern);
return NodeInfo.create(this, EndsWith::new, source, pattern, configuration());
}

@Override
Expand All @@ -86,14 +93,16 @@ public ScriptTemplate asScript() {
}

protected ScriptTemplate asScriptFrom(ScriptTemplate sourceScript, ScriptTemplate patternScript) {
return new ScriptTemplate(format(Locale.ROOT, formatTemplate("{eql}.%s(%s,%s)"),
return new ScriptTemplate(format(Locale.ROOT, formatTemplate("{eql}.%s(%s,%s,%s)"),
"endsWith",
sourceScript.template(),
patternScript.template()),
patternScript.template(),
"{}"),
paramsBuilder()
.script(sourceScript.params())
.script(patternScript.params())
.build(), dataType());
.script(sourceScript.params())
.script(patternScript.params())
.variable(isCaseSensitive())
.build(), dataType());
}

@Override
Expand All @@ -114,7 +123,7 @@ public Expression replaceChildren(List<Expression> newChildren) {
throw new IllegalArgumentException("expected [2] children but received [" + newChildren.size() + "]");
}

return new EndsWith(source(), newChildren.get(0), newChildren.get(1));
return new EndsWith(source(), newChildren.get(0), newChildren.get(1), configuration());
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ public class EndsWithFunctionPipe extends Pipe {

private final Pipe source;
private final Pipe pattern;
private final boolean isCaseSensitive;

public EndsWithFunctionPipe(Source source, Expression expression, Pipe src, Pipe pattern) {
public EndsWithFunctionPipe(Source source, Expression expression, Pipe src, Pipe pattern, boolean isCaseSensitive) {
super(source, expression, Arrays.asList(src, pattern));
this.source = src;
this.pattern = pattern;
this.isCaseSensitive = isCaseSensitive;
}

@Override
Expand Down Expand Up @@ -55,7 +57,7 @@ public boolean resolved() {
}

protected Pipe replaceChildren(Pipe newSource, Pipe newPattern) {
return new EndsWithFunctionPipe(source(), expression(), newSource, newPattern);
return new EndsWithFunctionPipe(source(), expression(), newSource, newPattern, isCaseSensitive);
}

@Override
Expand All @@ -66,12 +68,12 @@ public final void collectFields(QlSourceBuilder sourceBuilder) {

@Override
protected NodeInfo<EndsWithFunctionPipe> info() {
return NodeInfo.create(this, EndsWithFunctionPipe::new, expression(), source, pattern);
return NodeInfo.create(this, EndsWithFunctionPipe::new, expression(), source, pattern, isCaseSensitive);
}

@Override
public EndsWithFunctionProcessor asProcessor() {
return new EndsWithFunctionProcessor(source.asProcessor(), pattern.asProcessor());
return new EndsWithFunctionProcessor(source.asProcessor(), pattern.asProcessor(), isCaseSensitive);
}

public Pipe src() {
Expand All @@ -84,7 +86,7 @@ public Pipe pattern() {

@Override
public int hashCode() {
return Objects.hash(source, pattern);
return Objects.hash(source, pattern, isCaseSensitive);
}

@Override
Expand All @@ -99,6 +101,7 @@ public boolean equals(Object obj) {

EndsWithFunctionPipe other = (EndsWithFunctionPipe) obj;
return Objects.equals(source, other.source)
&& Objects.equals(pattern, other.pattern);
&& Objects.equals(pattern, other.pattern)
&& Objects.equals(isCaseSensitive, other.isCaseSensitive);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,33 @@ public class EndsWithFunctionProcessor implements Processor {

private final Processor source;
private final Processor pattern;
private final boolean isCaseSensitive;

public EndsWithFunctionProcessor(Processor source, Processor pattern) {
public EndsWithFunctionProcessor(Processor source, Processor pattern, boolean isCaseSensitive) {
this.source = source;
this.pattern = pattern;
this.isCaseSensitive = isCaseSensitive;
}

public EndsWithFunctionProcessor(StreamInput in) throws IOException {
source = in.readNamedWriteable(Processor.class);
pattern = in.readNamedWriteable(Processor.class);
isCaseSensitive = in.readBoolean();
}

@Override
public final void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(source);
out.writeNamedWriteable(pattern);
out.writeBoolean(isCaseSensitive);
}

@Override
public Object process(Object input) {
return doProcess(source.process(input), pattern.process(input));
return doProcess(source.process(input), pattern.process(input), isCaseSensitive());
}

public static Object doProcess(Object source, Object pattern) {
public static Object doProcess(Object source, Object pattern, boolean isCaseSensitive) {
if (source == null) {
return null;
}
Expand All @@ -56,7 +60,11 @@ public static Object doProcess(Object source, Object pattern) {
throw new EqlIllegalArgumentException("A string/char is required; received [{}]", pattern);
}

return source.toString().toLowerCase(Locale.ROOT).endsWith(pattern.toString().toLowerCase(Locale.ROOT));
if (isCaseSensitive) {
return source.toString().endsWith(pattern.toString());
} else {
return source.toString().toLowerCase(Locale.ROOT).endsWith(pattern.toString().toLowerCase(Locale.ROOT));
}
}

protected Processor source() {
Expand All @@ -66,7 +74,11 @@ protected Processor source() {
protected Processor pattern() {
return pattern;
}


protected boolean isCaseSensitive() {
return isCaseSensitive;
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
Expand All @@ -79,17 +91,18 @@ public boolean equals(Object obj) {

EndsWithFunctionProcessor other = (EndsWithFunctionProcessor) obj;
return Objects.equals(source(), other.source())
&& Objects.equals(pattern(), other.pattern());
&& Objects.equals(pattern(), other.pattern())
&& Objects.equals(isCaseSensitive(), other.isCaseSensitive());
}

@Override
public int hashCode() {
return Objects.hash(source(), pattern());
return Objects.hash(source(), pattern(), isCaseSensitive());
}


@Override
public String getWriteableName() {
return NAME;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@

package org.elasticsearch.xpack.eql.expression.function.scalar.string;

import org.elasticsearch.xpack.eql.session.EqlConfiguration;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.Expressions.ParamOrdinal;
import org.elasticsearch.xpack.ql.expression.FieldAttribute;
import org.elasticsearch.xpack.ql.expression.Literal;
import org.elasticsearch.xpack.ql.expression.function.OptionalArgument;
import org.elasticsearch.xpack.ql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.ql.expression.function.scalar.string.CaseSensitiveScalarFunction;
import org.elasticsearch.xpack.ql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.ql.expression.gen.script.Scripts;
import org.elasticsearch.xpack.ql.session.Configuration;
import org.elasticsearch.xpack.ql.tree.NodeInfo;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
Expand All @@ -35,17 +37,22 @@
* Find the first position (zero-indexed) of a string where a substring is found.
* If the optional parameter start is provided, then this will find the first occurrence at or after the start position.
*/
public class IndexOf extends ScalarFunction implements OptionalArgument {
public class IndexOf extends CaseSensitiveScalarFunction implements OptionalArgument {

private final Expression source, substring, start;

public IndexOf(Source source, Expression src, Expression substring, Expression start) {
super(source, Arrays.asList(src, substring, start != null ? start : new Literal(source, null, DataTypes.NULL)));
public IndexOf(Source source, Expression src, Expression substring, Expression start, Configuration configuration) {
super(source, Arrays.asList(src, substring, start != null ? start : new Literal(source, null, DataTypes.NULL)), configuration);
this.source = src;
this.substring = substring;
this.start = arguments().get(2);
}

@Override
public boolean isCaseSensitive() {
return ((EqlConfiguration) configuration()).isCaseSensitive();
}

@Override
protected TypeResolution resolveType() {
if (!childrenResolved()) {
Expand All @@ -61,13 +68,14 @@ protected TypeResolution resolveType() {
if (resolution.unresolved()) {
return resolution;
}

return isInteger(start, sourceText(), ParamOrdinal.THIRD);
}

@Override
protected Pipe makePipe() {
return new IndexOfFunctionPipe(source(), this, Expressions.pipe(source), Expressions.pipe(substring), Expressions.pipe(start));
return new IndexOfFunctionPipe(source(), this, Expressions.pipe(source), Expressions.pipe(substring),
Expressions.pipe(start), isCaseSensitive());
}

@Override
Expand All @@ -77,12 +85,12 @@ public boolean foldable() {

@Override
public Object fold() {
return doProcess(source.fold(), substring.fold(), start.fold());
return doProcess(source.fold(), substring.fold(), start.fold(), isCaseSensitive());
}

@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, IndexOf::new, source, substring, start);
return NodeInfo.create(this, IndexOf::new, source, substring, start, configuration());
}

@Override
Expand All @@ -93,18 +101,20 @@ public ScriptTemplate asScript() {

return asScriptFrom(sourceScript, substringScript, startScript);
}

protected ScriptTemplate asScriptFrom(ScriptTemplate sourceScript, ScriptTemplate substringScript, ScriptTemplate startScript) {
return new ScriptTemplate(format(Locale.ROOT, formatTemplate("{eql}.%s(%s,%s,%s)"),
return new ScriptTemplate(format(Locale.ROOT, formatTemplate("{eql}.%s(%s,%s,%s,%s)"),
"indexOf",
sourceScript.template(),
substringScript.template(),
startScript.template()),
startScript.template(),
"{}"),
paramsBuilder()
.script(sourceScript.params())
.script(substringScript.params())
.script(startScript.params())
.build(), dataType());
.script(sourceScript.params())
.script(substringScript.params())
.script(startScript.params())
.variable(isCaseSensitive())
.build(), dataType());
}

@Override
Expand All @@ -125,7 +135,7 @@ public Expression replaceChildren(List<Expression> newChildren) {
throw new IllegalArgumentException("expected [3] children but received [" + newChildren.size() + "]");
}

return new IndexOf(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2));
return new IndexOf(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), configuration());
}

}
}
Loading

0 comments on commit e808026

Please sign in to comment.