diff --git a/packages/kbn-esql-ast/src/walker/helpers.ts b/packages/kbn-esql-ast/src/walker/helpers.ts new file mode 100644 index 0000000000000..73f0f8d09360c --- /dev/null +++ b/packages/kbn-esql-ast/src/walker/helpers.ts @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +import { ESQLProperNode } from '../types'; + +export type NodeMatchTemplateKey = V | V[] | RegExp; +export type NodeMatchTemplate = { + [K in keyof ESQLProperNode]?: NodeMatchTemplateKey; +}; + +/** + * Creates a predicate function which matches a single AST node against a + * template object. The template object should have the same keys as the + * AST node, and the values should be: + * + * - An array matches if the node key is in the array. + * - A RegExp matches if the node key matches the RegExp. + * - Any other value matches if the node key is triple-equal to the value. + * + * @param template Template from which to create a predicate function. + * @returns A predicate function that matches nodes against the template. + */ +export const templateToPredicate = ( + template: NodeMatchTemplate +): ((node: ESQLProperNode) => boolean) => { + const keys = Object.keys(template) as Array; + const predicate = (child: ESQLProperNode) => { + for (const key of keys) { + const matcher = template[key]; + if (matcher instanceof Array) { + if (!(matcher as any[]).includes(child[key])) { + return false; + } + } else if (matcher instanceof RegExp) { + if (!matcher.test(String(child[key]))) { + return false; + } + } else if (child[key] !== matcher) { + return false; + } + } + + return true; + }; + + return predicate; +}; diff --git a/packages/kbn-esql-ast/src/walker/walker.test.ts b/packages/kbn-esql-ast/src/walker/walker.test.ts index 9f62c2f07d200..59375b275b162 100644 --- a/packages/kbn-esql-ast/src/walker/walker.test.ts +++ b/packages/kbn-esql-ast/src/walker/walker.test.ts @@ -81,6 +81,24 @@ describe('structurally can walk all nodes', () => { ]); }); + test('"visitAny" can capture command nodes', () => { + const { ast } = getAstAndSyntaxErrors('FROM index | STATS a = 123 | WHERE 123 | LIMIT 10'); + const commands: ESQLCommand[] = []; + + walk(ast, { + visitAny: (node) => { + if (node.type === 'command') commands.push(node); + }, + }); + + expect(commands.map(({ name }) => name).sort()).toStrictEqual([ + 'from', + 'limit', + 'stats', + 'where', + ]); + }); + describe('command options', () => { test('can visit command options', () => { const { ast } = getAstAndSyntaxErrors('FROM index METADATA _index'); @@ -93,19 +111,47 @@ describe('structurally can walk all nodes', () => { expect(options.length).toBe(1); expect(options[0].name).toBe('metadata'); }); + + test('"visitAny" can capture an options node', () => { + const { ast } = getAstAndSyntaxErrors('FROM index METADATA _index'); + const options: ESQLCommandOption[] = []; + + walk(ast, { + visitAny: (node) => { + if (node.type === 'option') options.push(node); + }, + }); + + expect(options.length).toBe(1); + expect(options[0].name).toBe('metadata'); + }); }); describe('command mode', () => { test('visits "mode" nodes', () => { const { ast } = getAstAndSyntaxErrors('FROM index | ENRICH a:b'); - const options: ESQLCommandMode[] = []; + const modes: ESQLCommandMode[] = []; walk(ast, { - visitCommandMode: (opt) => options.push(opt), + visitCommandMode: (opt) => modes.push(opt), }); - expect(options.length).toBe(1); - expect(options[0].name).toBe('a'); + expect(modes.length).toBe(1); + expect(modes[0].name).toBe('a'); + }); + + test('"visitAny" can capture a mode node', () => { + const { ast } = getAstAndSyntaxErrors('FROM index | ENRICH a:b'); + const modes: ESQLCommandMode[] = []; + + walk(ast, { + visitAny: (node) => { + if (node.type === 'mode') modes.push(node); + }, + }); + + expect(modes.length).toBe(1); + expect(modes[0].name).toBe('a'); }); }); @@ -123,6 +169,20 @@ describe('structurally can walk all nodes', () => { expect(sources[0].name).toBe('index'); }); + test('"visitAny" can capture a source node', () => { + const { ast } = getAstAndSyntaxErrors('FROM index'); + const sources: ESQLSource[] = []; + + walk(ast, { + visitAny: (node) => { + if (node.type === 'source') sources.push(node); + }, + }); + + expect(sources.length).toBe(1); + expect(sources[0].name).toBe('index'); + }); + test('iterates through all sources', () => { const { ast } = getAstAndSyntaxErrors('METRICS index, index2, index3, index4'); const sources: ESQLSource[] = []; @@ -142,7 +202,7 @@ describe('structurally can walk all nodes', () => { }); describe('columns', () => { - test('can through a single column', () => { + test('can walk through a single column', () => { const query = 'ROW x = 1'; const { ast } = getAstAndSyntaxErrors(query); const columns: ESQLColumn[] = []; @@ -159,6 +219,25 @@ describe('structurally can walk all nodes', () => { ]); }); + test('"visitAny" can capture a column', () => { + const query = 'ROW x = 1'; + const { ast } = getAstAndSyntaxErrors(query); + const columns: ESQLColumn[] = []; + + walk(ast, { + visitAny: (node) => { + if (node.type === 'column') columns.push(node); + }, + }); + + expect(columns).toMatchObject([ + { + type: 'column', + name: 'x', + }, + ]); + }); + test('can walk through multiple columns', () => { const query = 'FROM index | STATS a = 123, b = 456'; const { ast } = getAstAndSyntaxErrors(query); @@ -181,6 +260,52 @@ describe('structurally can walk all nodes', () => { }); }); + describe('functions', () => { + test('can walk through functions', () => { + const query = 'FROM a | STATS fn(1), agg(true)'; + const { ast } = getAstAndSyntaxErrors(query); + const nodes: ESQLFunction[] = []; + + walk(ast, { + visitFunction: (node) => nodes.push(node), + }); + + expect(nodes).toMatchObject([ + { + type: 'function', + name: 'fn', + }, + { + type: 'function', + name: 'agg', + }, + ]); + }); + + test('"visitAny" can capture function nodes', () => { + const query = 'FROM a | STATS fn(1), agg(true)'; + const { ast } = getAstAndSyntaxErrors(query); + const nodes: ESQLFunction[] = []; + + walk(ast, { + visitAny: (node) => { + if (node.type === 'function') nodes.push(node); + }, + }); + + expect(nodes).toMatchObject([ + { + type: 'function', + name: 'fn', + }, + { + type: 'function', + name: 'agg', + }, + ]); + }); + }); + describe('literals', () => { test('can walk a single literal', () => { const query = 'ROW x = 1'; @@ -301,6 +426,20 @@ describe('structurally can walk all nodes', () => { ]); }); + test('"visitAny" can capture a list literal', () => { + const query = 'ROW x = [1, 2]'; + const { ast } = getAstAndSyntaxErrors(query); + const lists: ESQLList[] = []; + + walk(ast, { + visitAny: (node) => { + if (node.type === 'list') lists.push(node); + }, + }); + + expect(lists.length).toBe(1); + }); + test('can walk plain literals inside list literal', () => { const query = 'ROW x = [1, 2] + [3.3]'; const { ast } = getAstAndSyntaxErrors(query); @@ -492,7 +631,6 @@ describe('structurally can walk all nodes', () => { test('can visit time interval nodes', () => { const query = 'FROM index | STATS a = 123 BY 1h'; const { ast } = getAstAndSyntaxErrors(query); - const intervals: ESQLTimeInterval[] = []; walk(ast, { @@ -507,6 +645,43 @@ describe('structurally can walk all nodes', () => { }, ]); }); + + test('"visitAny" can capture time interval expressions', () => { + const query = 'FROM index | STATS a = 123 BY 1h'; + const { ast } = getAstAndSyntaxErrors(query); + const intervals: ESQLTimeInterval[] = []; + + walk(ast, { + visitAny: (node) => { + if (node.type === 'timeInterval') intervals.push(node); + }, + }); + + expect(intervals).toMatchObject([ + { + type: 'timeInterval', + quantity: 1, + unit: 'h', + }, + ]); + }); + + test('"visitAny" does not capture time interval node if type-specific callback provided', () => { + const query = 'FROM index | STATS a = 123 BY 1h'; + const { ast } = getAstAndSyntaxErrors(query); + const intervals1: ESQLTimeInterval[] = []; + const intervals2: ESQLTimeInterval[] = []; + + walk(ast, { + visitTimeIntervalLiteral: (node) => intervals1.push(node), + visitAny: (node) => { + if (node.type === 'timeInterval') intervals2.push(node); + }, + }); + + expect(intervals1.length).toBe(1); + expect(intervals2.length).toBe(0); + }); }); describe('cast expression', () => { @@ -532,6 +707,30 @@ describe('structurally can walk all nodes', () => { }, ]); }); + + test('"visitAny" can capture cast expression', () => { + const query = 'FROM index | STATS a = 123::integer'; + const { ast } = getAstAndSyntaxErrors(query); + const casts: ESQLInlineCast[] = []; + + walk(ast, { + visitAny: (node) => { + if (node.type === 'inlineCast') casts.push(node); + }, + }); + + expect(casts).toMatchObject([ + { + type: 'inlineCast', + castType: 'integer', + value: { + type: 'literal', + literalType: 'integer', + value: 123, + }, + }, + ]); + }); }); }); }); @@ -576,7 +775,7 @@ describe('Walker.commands()', () => { }); }); -describe('Walker.params', () => { +describe('Walker.params()', () => { test('can collect all params', () => { const query = 'ROW x = ?'; const { ast } = getAstAndSyntaxErrors(query); @@ -613,10 +812,195 @@ describe('Walker.params', () => { }); }); +describe('Walker.find()', () => { + test('can find a bucket() function', () => { + const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)'; + const fn = Walker.find( + getAstAndSyntaxErrors(query).ast!, + (node) => node.type === 'function' && node.name === 'bucket' + ); + + expect(fn).toMatchObject({ + type: 'function', + name: 'bucket', + }); + }); + + test('finds the first "fn" function', () => { + const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)'; + const fn = Walker.find( + getAstAndSyntaxErrors(query).ast!, + (node) => node.type === 'function' && node.name === 'fn' + ); + + expect(fn).toMatchObject({ + type: 'function', + name: 'fn', + args: [ + { + type: 'literal', + value: 1, + }, + ], + }); + }); +}); + +describe('Walker.findAll()', () => { + test('find all "fn" functions', () => { + const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)'; + const list = Walker.findAll( + getAstAndSyntaxErrors(query).ast!, + (node) => node.type === 'function' && node.name === 'fn' + ); + + expect(list).toMatchObject([ + { + type: 'function', + name: 'fn', + args: [ + { + type: 'literal', + value: 1, + }, + ], + }, + { + type: 'function', + name: 'fn', + args: [ + { + type: 'literal', + value: 2, + }, + ], + }, + ]); + }); +}); + +describe('Walker.match()', () => { + test('can find a bucket() function', () => { + const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)'; + const fn = Walker.match(getAstAndSyntaxErrors(query).ast!, { + type: 'function', + name: 'bucket', + }); + + expect(fn).toMatchObject({ + type: 'function', + name: 'bucket', + }); + }); + + test('finds the first "fn" function', () => { + const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)'; + const fn = Walker.match(getAstAndSyntaxErrors(query).ast!, { type: 'function', name: 'fn' }); + + expect(fn).toMatchObject({ + type: 'function', + name: 'fn', + args: [ + { + type: 'literal', + value: 1, + }, + ], + }); + }); +}); + +describe('Walker.matchAll()', () => { + test('find all "fn" functions', () => { + const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)'; + const list = Walker.matchAll(getAstAndSyntaxErrors(query).ast!, { + type: 'function', + name: 'fn', + }); + + expect(list).toMatchObject([ + { + type: 'function', + name: 'fn', + args: [ + { + type: 'literal', + value: 1, + }, + ], + }, + { + type: 'function', + name: 'fn', + args: [ + { + type: 'literal', + value: 2, + }, + ], + }, + ]); + }); + + test('find all "fn" and "agg" functions', () => { + const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)'; + const list = Walker.matchAll(getAstAndSyntaxErrors(query).ast!, { + type: 'function', + name: ['fn', 'agg'], + }); + + expect(list).toMatchObject([ + { + type: 'function', + name: 'fn', + args: [ + { + type: 'literal', + value: 1, + }, + ], + }, + { + type: 'function', + name: 'fn', + args: [ + { + type: 'literal', + value: 2, + }, + ], + }, + { + type: 'function', + name: 'agg', + }, + ]); + }); + + test('find all functions which start with "b" or "a"', () => { + const query = 'FROM b | STATS var0 = bucket(bytes, 1 hour), fn(1), fn(2), agg(true)'; + const list = Walker.matchAll(getAstAndSyntaxErrors(query).ast!, { + type: 'function', + name: /^a|b/i, + }); + + expect(list).toMatchObject([ + { + type: 'function', + name: 'bucket', + }, + { + type: 'function', + name: 'agg', + }, + ]); + }); +}); + describe('Walker.hasFunction()', () => { test('can find assignment expression', () => { - const query1 = 'METRICS source bucket(bytes, 1 hour)'; - const query2 = 'METRICS source var0 = bucket(bytes, 1 hour)'; + const query1 = 'FROM a | STATS bucket(bytes, 1 hour)'; + const query2 = 'FROM b | STATS var0 = bucket(bytes, 1 hour)'; const has1 = Walker.hasFunction(getAstAndSyntaxErrors(query1).ast!, '='); const has2 = Walker.hasFunction(getAstAndSyntaxErrors(query2).ast!, '='); diff --git a/packages/kbn-esql-ast/src/walker/walker.ts b/packages/kbn-esql-ast/src/walker/walker.ts index 20e052d211fe1..e6ed54517435e 100644 --- a/packages/kbn-esql-ast/src/walker/walker.ts +++ b/packages/kbn-esql-ast/src/walker/walker.ts @@ -19,11 +19,13 @@ import type { ESQLList, ESQLLiteral, ESQLParamLiteral, + ESQLProperNode, ESQLSingleAstItem, ESQLSource, ESQLTimeInterval, ESQLUnknownItem, } from '../types'; +import { NodeMatchTemplate, templateToPredicate } from './helpers'; type Node = ESQLAstNode | ESQLAstNode[]; @@ -40,6 +42,13 @@ export interface WalkerOptions { visitTimeIntervalLiteral?: (node: ESQLTimeInterval) => void; visitInlineCast?: (node: ESQLInlineCast) => void; visitUnknown?: (node: ESQLUnknownItem) => void; + + /** + * Called for any node type that does not have a specific visitor. + * + * @param node Any valid AST node. + */ + visitAny?: (node: ESQLProperNode) => void; } export type WalkerAstNode = ESQLAstNode | ESQLAstNode[]; @@ -102,6 +111,82 @@ export class Walker { return params; }; + /** + * Finds and returns the first node that matches the search criteria. + * + * @param node AST node to start the search from. + * @param predicate A function that returns true if the node matches the search criteria. + * @returns The first node that matches the search criteria. + */ + public static readonly find = ( + node: WalkerAstNode, + predicate: (node: ESQLProperNode) => boolean + ): ESQLProperNode | undefined => { + let found: ESQLProperNode | undefined; + Walker.walk(node, { + visitAny: (child) => { + if (!found && predicate(child)) { + found = child; + } + }, + }); + return found; + }; + + /** + * Finds and returns all nodes that match the search criteria. + * + * @param node AST node to start the search from. + * @param predicate A function that returns true if the node matches the search criteria. + * @returns All nodes that match the search criteria. + */ + public static readonly findAll = ( + node: WalkerAstNode, + predicate: (node: ESQLProperNode) => boolean + ): ESQLProperNode[] => { + const list: ESQLProperNode[] = []; + Walker.walk(node, { + visitAny: (child) => { + if (predicate(child)) { + list.push(child); + } + }, + }); + return list; + }; + + /** + * Matches a single node against a template object. Returns the first node + * that matches the template. + * + * @param node AST node to match against the template. + * @param template Template object to match against the node. + * @returns The first node that matches the template + */ + public static readonly match = ( + node: WalkerAstNode, + template: NodeMatchTemplate + ): ESQLProperNode | undefined => { + const predicate = templateToPredicate(template); + return Walker.find(node, predicate); + }; + + /** + * Matches all nodes against a template object. Returns all nodes that match + * the template. + * + * @param node AST node to match against the template. + * @param template Template object to match against the node. + * @returns All nodes that match the template + */ + public static readonly matchAll = ( + node: WalkerAstNode, + template: NodeMatchTemplate + ): ESQLProperNode[] => { + const predicate = templateToPredicate(template); + return Walker.findAll(node, predicate); + }; + /** * Finds the first function that matches the predicate. * @@ -161,7 +246,8 @@ export class Walker { } public walkCommand(node: ESQLAstCommand): void { - this.options.visitCommand?.(node); + const { options } = this; + (options.visitCommand ?? options.visitAny)?.(node); switch (node.name) { default: { this.walk(node.args); @@ -171,7 +257,8 @@ export class Walker { } public walkOption(node: ESQLCommandOption): void { - this.options.visitCommandOption?.(node); + const { options } = this; + (options.visitCommandOption ?? options.visitAny)?.(node); for (const child of node.args) { this.walkAstItem(child); } @@ -188,11 +275,13 @@ export class Walker { } public walkMode(node: ESQLCommandMode): void { - this.options.visitCommandMode?.(node); + const { options } = this; + (options.visitCommandMode ?? options.visitAny)?.(node); } public walkListLiteral(node: ESQLList): void { - this.options.visitListLiteral?.(node); + const { options } = this; + (options.visitListLiteral ?? options.visitAny)?.(node); for (const value of node.values) { this.walkAstItem(value); } @@ -215,11 +304,11 @@ export class Walker { break; } case 'source': { - options.visitSource?.(node); + (options.visitSource ?? options.visitAny)?.(node); break; } case 'column': { - options.visitColumn?.(node); + (options.visitColumn ?? options.visitAny)?.(node); break; } case 'literal': { @@ -231,22 +320,23 @@ export class Walker { break; } case 'timeInterval': { - options.visitTimeIntervalLiteral?.(node); + (options.visitTimeIntervalLiteral ?? options.visitAny)?.(node); break; } case 'inlineCast': { - options.visitInlineCast?.(node); + (options.visitInlineCast ?? options.visitAny)?.(node); break; } case 'unknown': { - options.visitUnknown?.(node); + (options.visitUnknown ?? options.visitAny)?.(node); break; } } } public walkFunction(node: ESQLFunction): void { - this.options.visitFunction?.(node); + const { options } = this; + (options.visitFunction ?? options.visitAny)?.(node); const args = node.args; const length = args.length; for (let i = 0; i < length; i++) {