diff --git a/libs/grok/src/main/java/org/opensearch/grok/Grok.java b/libs/grok/src/main/java/org/opensearch/grok/Grok.java index 27b4735cbe2f0..3bae4b360e34b 100644 --- a/libs/grok/src/main/java/org/opensearch/grok/Grok.java +++ b/libs/grok/src/main/java/org/opensearch/grok/Grok.java @@ -53,6 +53,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Stack; import java.util.function.Consumer; import static java.util.Collections.unmodifiableList; @@ -106,11 +107,7 @@ private Grok(Map patternBank, String grokPattern, boolean namedC this.namedCaptures = namedCaptures; this.matcherWatchdog = matcherWatchdog; - for (Map.Entry entry : patternBank.entrySet()) { - String name = entry.getKey(); - String pattern = entry.getValue(); - forbidCircularReferences(name, new ArrayList<>(), pattern); - } + validatePatternBank(); String expression = toRegex(grokPattern); byte[] expressionBytes = expression.getBytes(StandardCharsets.UTF_8); @@ -125,46 +122,68 @@ private Grok(Map patternBank, String grokPattern, boolean namedC } /** - * Checks whether patterns reference each other in a circular manner and if so fail with an exception + * Entry point to recursively validate the pattern bank for circular dependencies and malformed pattern definitions + * via depth-first traversal. This implementation does not include memoization. + */ + private void validatePatternBank() { + for (String patternName : patternBank.keySet()) { + validatePatternBank(patternName, new Stack<>()); + } + } + + /** + * Checks whether patterns reference each other in a circular manner and, if so, fails with an exception. + * Also checks for malformed pattern definitions and fails with an exception. * * In a pattern, anything between %{ and } or : is considered * a reference to another named pattern. This method will navigate to all these named patterns and * check for a circular reference. 
*/ - private void forbidCircularReferences(String patternName, List path, String pattern) { - if (pattern.contains("%{" + patternName + "}") || pattern.contains("%{" + patternName + ":")) { - String message; - if (path.isEmpty()) { - message = "circular reference in pattern [" + patternName + "][" + pattern + "]"; - } else { - message = "circular reference in pattern [" + path.remove(path.size() - 1) + "][" + pattern + - "] back to pattern [" + patternName + "]"; - // add rest of the path: - if (path.isEmpty() == false) { - message += " via patterns [" + String.join("=>", path) + "]"; - } - } - throw new IllegalArgumentException(message); + private void validatePatternBank(String patternName, Stack path) { + String pattern = patternBank.get(patternName); + boolean isSelfReference = pattern.contains("%{" + patternName + "}") || + pattern.contains("%{" + patternName + ":"); + if (isSelfReference) { + throwExceptionForCircularReference(patternName, pattern); + } else if (path.contains(patternName)) { + // current pattern name is already in the path, fetch its predecessor + String prevPatternName = path.pop(); + String prevPattern = patternBank.get(prevPatternName); + throwExceptionForCircularReference(prevPatternName, prevPattern, patternName, path); } - + path.push(patternName); for (int i = pattern.indexOf("%{"); i != -1; i = pattern.indexOf("%{", i + 1)) { int begin = i + 2; - int brackedIndex = pattern.indexOf('}', begin); - int columnIndex = pattern.indexOf(':', begin); - int end; - if (brackedIndex != -1 && columnIndex == -1) { - end = brackedIndex; - } else if (columnIndex != -1 && brackedIndex == -1) { - end = columnIndex; - } else if (brackedIndex != -1 && columnIndex != -1) { - end = Math.min(brackedIndex, columnIndex); - } else { - throw new IllegalArgumentException("pattern [" + pattern + "] has circular references to other pattern definitions"); + int syntaxEndIndex = pattern.indexOf('}', begin); + if (syntaxEndIndex == -1) { + throw new 
IllegalArgumentException("Malformed pattern [" + patternName + "][" + pattern +"]"); + } + int semanticNameIndex = pattern.indexOf(':', begin); + int end = syntaxEndIndex; + if (semanticNameIndex != -1) { + end = Math.min(syntaxEndIndex, semanticNameIndex); } - String otherPatternName = pattern.substring(begin, end); - path.add(otherPatternName); - forbidCircularReferences(patternName, path, patternBank.get(otherPatternName)); + String dependsOnPattern = pattern.substring(begin, end); + validatePatternBank(dependsOnPattern, path); + } + path.pop(); + } + + private static void throwExceptionForCircularReference(String patternName, String pattern) { + throwExceptionForCircularReference(patternName, pattern, null, null); + } + + private static void throwExceptionForCircularReference(String patternName, String pattern, String originPatterName, + Stack path) { + StringBuilder message = new StringBuilder("circular reference in pattern ["); + message.append(patternName).append("][").append(pattern).append("]"); + if (originPatterName != null) { + message.append(" back to pattern [").append(originPatterName).append("]"); + } + if (path != null && path.size() > 1) { + message.append(" via patterns [").append(String.join("=>", path)).append("]"); } + throw new IllegalArgumentException(message.toString()); } private String groupMatch(String name, Region region, String pattern) { diff --git a/libs/grok/src/test/java/org/opensearch/grok/GrokTests.java b/libs/grok/src/test/java/org/opensearch/grok/GrokTests.java index 7ec1eb7f98f0e..a9c2d384b62a6 100644 --- a/libs/grok/src/test/java/org/opensearch/grok/GrokTests.java +++ b/libs/grok/src/test/java/org/opensearch/grok/GrokTests.java @@ -51,16 +51,16 @@ import java.util.function.IntConsumer; import java.util.function.LongConsumer; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; import static 
org.opensearch.grok.GrokCaptureType.BOOLEAN; import static org.opensearch.grok.GrokCaptureType.DOUBLE; import static org.opensearch.grok.GrokCaptureType.FLOAT; import static org.opensearch.grok.GrokCaptureType.INTEGER; import static org.opensearch.grok.GrokCaptureType.LONG; import static org.opensearch.grok.GrokCaptureType.STRING; -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.nullValue; public class GrokTests extends OpenSearchTestCase { @@ -344,7 +344,17 @@ public void testCircularReference() { String pattern = "%{NAME1}"; new Grok(bank, pattern, false, logger::warn); }); - assertEquals("circular reference in pattern [NAME3][!!!%{NAME1}!!!] back to pattern [NAME1] via patterns [NAME2]", + assertEquals("circular reference in pattern [NAME3][!!!%{NAME1}!!!] back to pattern [NAME1] via patterns [NAME1=>NAME2]", + e.getMessage()); + + e = expectThrows(IllegalArgumentException.class, () -> { + Map bank = new TreeMap<>(); + bank.put("NAME1", "!!!%{NAME2}!!!"); + bank.put("NAME2", "!!!%{NAME2}!!!"); + String pattern = "%{NAME1}"; + new Grok(bank, pattern, false, logger::warn); + }); + assertEquals("circular reference in pattern [NAME2][!!!%{NAME2}!!!]", e.getMessage()); e = expectThrows(IllegalArgumentException.class, () -> { @@ -358,7 +368,25 @@ public void testCircularReference() { new Grok(bank, pattern, false, logger::warn ); }); assertEquals("circular reference in pattern [NAME5][!!!%{NAME1}!!!] 
back to pattern [NAME1] " + - "via patterns [NAME2=>NAME3=>NAME4]", e.getMessage()); + "via patterns [NAME1=>NAME2=>NAME3=>NAME4]", e.getMessage()); + } + + public void testMalformedPattern() { + Exception e = expectThrows(IllegalArgumentException.class, () -> { + Map bank = new HashMap<>(); + bank.put("NAME1", "!!!%{NAME2:!!!"); + String pattern = "%{NAME1}"; + new Grok(bank, pattern, false, logger::warn); + }); + assertEquals("Malformed pattern [NAME1][!!!%{NAME2:!!!]", e.getMessage()); + + e = expectThrows(IllegalArgumentException.class, () -> { + Map bank = new HashMap<>(); + bank.put("NAME1", "!!!%{NAME2!!!"); + String pattern = "%{NAME1}"; + new Grok(bank, pattern, false, logger::warn); + }); + assertEquals("Malformed pattern [NAME1][!!!%{NAME2!!!]", e.getMessage()); } public void testBooleanCaptures() {