diff --git a/DEVELOPER_GUIDE.rst b/DEVELOPER_GUIDE.rst index 516cf23556..923cb459f9 100644 --- a/DEVELOPER_GUIDE.rst +++ b/DEVELOPER_GUIDE.rst @@ -147,6 +147,7 @@ The plugin codebase is in standard layout of Gradle project:: ├── plugin ├── protocol ├── ppl + ├── spark ├── sql ├── sql-cli ├── sql-jdbc @@ -161,6 +162,7 @@ Here are sub-folders (Gradle modules) for plugin source code: - ``core``: core query engine. - ``opensearch``: OpenSearch storage engine. - ``prometheus``: Prometheus storage engine. +- ``spark`` : Spark storage engine - ``protocol``: request/response protocol formatter. - ``common``: common util code. - ``integ-test``: integration and comparison test. diff --git a/common/build.gradle b/common/build.gradle index 992135e06d..ce449479d2 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -33,7 +33,7 @@ repositories { dependencies { api "org.antlr:antlr4-runtime:4.7.1" - api group: 'com.google.guava', name: 'guava', version: '31.0.1-jre' + api group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' api group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.17.1' api group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' api group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.9.3' @@ -43,7 +43,7 @@ dependencies { testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.assertj', name: 'assertj-core', version: '3.9.1' - testImplementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre' + testImplementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testImplementation group: 'org.mockito', name: 'mockito-core', version: '3.12.4' diff --git a/core/build.gradle b/core/build.gradle index 624c10fd6b..8205638138 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -44,7 +44,7 @@ pitest { } dependencies { - api group: 'com.google.guava', name: 'guava', version: '31.0.1-jre' + api group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' api group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' api group: 'com.facebook.presto', name: 'presto-matching', version: '0.240' api group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1' diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 6418f92686..2c4647004c 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -29,6 +29,7 @@ import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.math3.analysis.function.Exp; import org.opensearch.sql.DataSourceSchemaName; import org.opensearch.sql.analysis.symbol.Namespace; import org.opensearch.sql.analysis.symbol.Symbol; @@ -469,8 +470,13 @@ public LogicalPlan visitSort(Sort node, AnalysisContext context) { node.getSortList().stream() .map( sortField -> { - Expression expression = optimizer.optimize( - expressionAnalyzer.analyze(sortField.getField(), context), context); + var analyzed = expressionAnalyzer.analyze(sortField.getField(), context); + if (analyzed == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", sortField.getField()) + ); + } + Expression expression = 
optimizer.optimize(analyzed, context); return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression); }) .collect(Collectors.toList()); diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 43155a868a..601e3e00cc 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -186,7 +186,16 @@ public Expression visitFunction(Function node, AnalysisContext context) { FunctionName functionName = FunctionName.of(node.getFuncName()); List arguments = node.getFuncArgs().stream() - .map(unresolvedExpression -> analyze(unresolvedExpression, context)) + .map(unresolvedExpression -> { + var ret = analyze(unresolvedExpression, context); + if (ret == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", unresolvedExpression) + ); + } else { + return ret; + } + }) .collect(Collectors.toList()); return (Expression) repository.compile(context.getFunctionProperties(), functionName, arguments); diff --git a/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java index 4e3939bb14..f050824557 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java @@ -15,6 +15,7 @@ import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.expression.Expression; @@ -45,6 +46,28 @@ public LogicalPlan visitAlias(Alias node, AnalysisContext context) { return node.getDelegated().accept(this, context); } + @Override + public LogicalPlan visitNestedAllTupleFields(NestedAllTupleFields node, AnalysisContext context) { + List> args = new ArrayList<>(); + for (NamedExpression namedExpr : namedExpressions) { + if (isNestedFunction(namedExpr.getDelegated())) { + ReferenceExpression field = + (ReferenceExpression) ((FunctionExpression) namedExpr.getDelegated()) + .getArguments().get(0); + + // If path is same as NestedAllTupleFields path + if (field.getAttr().substring(0, field.getAttr().lastIndexOf(".")) + .equalsIgnoreCase(node.getPath())) { + args.add(Map.of( + "field", field, + "path", new ReferenceExpression(node.getPath(), STRING))); + } + } + } + + return mergeChildIfLogicalNested(args); + } + @Override public LogicalPlan visitFunction(Function node, AnalysisContext context) { if (node.getFuncName().equalsIgnoreCase(BuiltinFunctionName.NESTED.name())) { @@ -54,6 +77,8 @@ public LogicalPlan visitFunction(Function node, AnalysisContext context) { ReferenceExpression nestedField = (ReferenceExpression)expressionAnalyzer.analyze(expressions.get(0), context); Map args; + + // Path parameter is supplied if (expressions.size() == 2) { args = Map.of( "field", nestedField, @@ -65,16 +90,28 @@ public LogicalPlan visitFunction(Function node, AnalysisContext context) { "path", generatePath(nestedField.toString()) ); } - if (child instanceof LogicalNested) { - ((LogicalNested)child).addFields(args); - return child; - } else { - return new LogicalNested(child, new ArrayList<>(Arrays.asList(args)), 
namedExpressions); - } + + return mergeChildIfLogicalNested(new ArrayList<>(Arrays.asList(args))); } return null; } + /** + * NestedAnalyzer visits all functions in SELECT clause, creates logical plans for each and + * merges them. This is to avoid another merge rule in LogicalPlanOptimizer:create(). + * @param args field and path params to add to logical plan. + * @return child of logical nested with added args, or new LogicalNested. + */ + private LogicalPlan mergeChildIfLogicalNested(List> args) { + if (child instanceof LogicalNested) { + for (var arg : args) { + ((LogicalNested) child).addFields(arg); + } + return child; + } + return new LogicalNested(child, args, namedExpressions); + } + /** * Validate each parameter used in nested function in SELECT clause. Any supplied parameter * for a nested function in a SELECT statement must be a valid qualified name, and the field diff --git a/core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java index 3593488f46..734f37378b 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java @@ -10,13 +10,17 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.regex.Pattern; import java.util.stream.Collectors; import lombok.RequiredArgsConstructor; import org.opensearch.sql.analysis.symbol.Namespace; +import org.opensearch.sql.analysis.symbol.Symbol; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.data.type.ExprType; @@ -58,6 +62,11 @@ public List visitField(Field node, AnalysisContext context) { @Override public List visitAlias(Alias node, AnalysisContext context) { + // Expand all nested fields if used in SELECT clause + if (node.getDelegated() instanceof NestedAllTupleFields) { + return node.getDelegated().accept(this, context); + } + Expression expr = referenceIfSymbolDefined(node, context); return Collections.singletonList(DSL.named( unqualifiedNameIfFieldOnly(node, context), @@ -100,6 +109,29 @@ public List visitAllFields(AllFields node, new ReferenceExpression(entry.getKey(), entry.getValue()))).collect(Collectors.toList()); } + @Override + public List visitNestedAllTupleFields(NestedAllTupleFields node, + AnalysisContext context) { + TypeEnvironment environment = context.peek(); + Map lookupAllTupleFields = + environment.lookupAllTupleFields(Namespace.FIELD_NAME); + environment.resolve(new Symbol(Namespace.FIELD_NAME, node.getPath())); + + // Match all fields with same path as used in nested function. 
+ Pattern p = Pattern.compile(node.getPath() + "\\.[^\\.]+$"); + return lookupAllTupleFields.entrySet().stream() + .filter(field -> p.matcher(field.getKey()).find()) + .map(entry -> { + Expression nestedFunc = new Function( + "nested", + List.of( + new QualifiedName(List.of(entry.getKey().split("\\.")))) + ).accept(expressionAnalyzer, context); + return DSL.named("nested(" + entry.getKey() + ")", nestedFunc); + }) + .collect(Collectors.toList()); + } + /** * Get unqualified name if select item is just a field. For example, suppose an index * named "accounts", return "age" for "SELECT accounts.age". But do nothing for expression diff --git a/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java b/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java index c9fd8030e0..17d203f66f 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java +++ b/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java @@ -85,6 +85,17 @@ public Map lookupAllFields(Namespace namespace) { return result; } + /** + * Resolve all fields in the current environment. + * @param namespace a namespace + * @return all symbols in the namespace + */ + public Map lookupAllTupleFields(Namespace namespace) { + Map result = new LinkedHashMap<>(); + symbolTable.lookupAllTupleFields(namespace).forEach(result::putIfAbsent); + return result; + } + /** * Define symbol with the type. * diff --git a/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java b/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java index 45f77915f2..be7435c288 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java +++ b/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java @@ -128,6 +128,21 @@ public Map lookupAllFields(Namespace namespace) { return results; } + /** + * Look up all top level symbols in the namespace. + * + * @param namespace a namespace + * @return all symbols in the namespace map + */ + public Map lookupAllTupleFields(Namespace namespace) { + final LinkedHashMap allSymbols = + orderedTable.getOrDefault(namespace, new LinkedHashMap<>()); + final LinkedHashMap result = new LinkedHashMap<>(); + allSymbols.entrySet().stream() + .forEach(entry -> result.put(entry.getKey(), entry.getValue())); + return result; + } + /** * Check if namespace map in empty (none definition). 
* diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 3e81509fae..f02bc07ccc 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -25,6 +25,7 @@ import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Map; +import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; @@ -238,6 +239,10 @@ public T visitAllFields(AllFields node, C context) { return visitChildren(node, context); } + public T visitNestedAllTupleFields(NestedAllTupleFields node, C context) { + return visitChildren(node, context); + } + public T visitInterval(Interval node, C context) { return visitChildren(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index de2ab5404a..d5f10fcfd4 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -30,6 +30,7 @@ import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Map; +import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.ParseMethod; @@ -377,6 +378,10 @@ public Alias alias(String name, UnresolvedExpression expr, String alias) { return new Alias(name, expr, alias); } + public NestedAllTupleFields nestedAllTupleFields(String path) { + return new NestedAllTupleFields(path); + } + public static List exprList(UnresolvedExpression... exprList) { return Arrays.asList(exprList); } diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/NestedAllTupleFields.java b/core/src/main/java/org/opensearch/sql/ast/expression/NestedAllTupleFields.java new file mode 100644 index 0000000000..adf2025e6c --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/expression/NestedAllTupleFields.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.ast.expression; + +import java.util.Collections; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; + +/** + * Represents all tuple fields used in nested function. 
+ */ +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class NestedAllTupleFields extends UnresolvedExpression { + @Getter + private final String path; + + @Override + public List getChild() { + return Collections.emptyList(); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitNestedAllTupleFields(this, context); + } + + @Override + public String toString() { + return String.format("nested(%s.*)", path); + } +} diff --git a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java index 48098b9741..5010e41942 100644 --- a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java +++ b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java @@ -8,8 +8,7 @@ public enum DataSourceType { PROMETHEUS("prometheus"), OPENSEARCH("opensearch"), - JDBC("jdbc"); - + SPARK("spark"); private String text; DataSourceType(String text) { diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 6d83ee53a8..100cfd67af 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -24,6 +24,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.filteredAggregate; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; +import static org.opensearch.sql.ast.dsl.AstDSL.nestedAllTupleFields; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; import static org.opensearch.sql.ast.dsl.AstDSL.relation; import static org.opensearch.sql.ast.dsl.AstDSL.span; @@ -556,7 +557,7 @@ public void project_nested_field_arg() { List projectList = List.of( new NamedExpression( - "message.info", + "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING)), null) ); @@ -567,13 +568,13 @@ public void project_nested_field_arg() { LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), - DSL.named("message.info", + DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING))) ), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("message.info", + AstDSL.alias("nested(message.info)", function("nested", qualifiedName("message", "info")), null) ) ); @@ -583,6 +584,195 @@ public void project_nested_field_arg() { assertFalse(isNestedFunction(DSL.match(DSL.namedArgument("field", literal("message"))))); } + @Test + public void sort_with_nested_all_tuple_fields_throws_exception() { + assertThrows(UnsupportedOperationException.class, () -> analyze( + AstDSL.project( + AstDSL.sort( + AstDSL.relation("schema"), + field(nestedAllTupleFields("message")) + ), + AstDSL.alias("nested(message.*)", + nestedAllTupleFields("message")) + ) + )); + } + + @Test + public void filter_with_nested_all_tuple_fields_throws_exception() { + assertThrows(UnsupportedOperationException.class, () -> analyze( + AstDSL.project( + AstDSL.filter( + AstDSL.relation("schema"), + AstDSL.function("=", nestedAllTupleFields("message"), AstDSL.intLiteral(1))), + AstDSL.alias("nested(message.*)", + nestedAllTupleFields("message")) + ) + )); + } + + + @Test + public void project_nested_field_star_arg() { + List> nestedArgs = + List.of( + Map.of( + "field", new ReferenceExpression("message.info", STRING), + "path", new ReferenceExpression("message", 
STRING) + ) + ); + + List projectList = + List.of( + new NamedExpression("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))) + ); + + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.nested( + LogicalPlanDSL.relation("schema", table), + nestedArgs, + projectList), + DSL.named("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))) + ), + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("nested(message.*)", + nestedAllTupleFields("message")) + ) + ); + } + + @Test + public void project_nested_field_star_arg_with_another_nested_function() { + List> nestedArgs = + List.of( + Map.of( + "field", new ReferenceExpression("message.info", STRING), + "path", new ReferenceExpression("message", STRING) + ), + Map.of( + "field", new ReferenceExpression("comment.data", STRING), + "path", new ReferenceExpression("comment", STRING) + ) + ); + + List projectList = + List.of( + new NamedExpression("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))), + new NamedExpression("nested(comment.data)", + DSL.nested(DSL.ref("comment.data", STRING))) + ); + + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.nested( + LogicalPlanDSL.relation("schema", table), + nestedArgs, + projectList), + DSL.named("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("nested(comment.data)", + DSL.nested(DSL.ref("comment.data", STRING))) + ), + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("nested(message.*)", + nestedAllTupleFields("message")), + AstDSL.alias("nested(comment.*)", + nestedAllTupleFields("comment")) + ) + ); + } + + @Test + public void project_nested_field_star_arg_with_another_field() { + List> nestedArgs = + List.of( + Map.of( + "field", new ReferenceExpression("message.info", STRING), + "path", new ReferenceExpression("message", STRING) + ) + ); + + List projectList = + List.of( + new NamedExpression("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))), + new NamedExpression("comment.data", + DSL.ref("comment.data", STRING)) + ); + + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.nested( + LogicalPlanDSL.relation("schema", table), + nestedArgs, + projectList), + DSL.named("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("comment.data", + DSL.ref("comment.data", STRING)) + ), + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("nested(message.*)", + nestedAllTupleFields("message")), + AstDSL.alias("comment.data", + field("comment.data")) + ) + ); + } + + @Test + public void project_nested_field_star_arg_with_highlight() { + List> nestedArgs = + List.of( + Map.of( + "field", new ReferenceExpression("message.info", STRING), + "path", new ReferenceExpression("message", STRING) + ) + ); + + List projectList = + List.of( + new NamedExpression("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("highlight(fieldA)", + new HighlightExpression(DSL.literal("fieldA"))) + ); + + Map highlightArgs = new HashMap<>(); + + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.nested( + LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table), + DSL.literal("fieldA"), highlightArgs), + nestedArgs, + projectList), + DSL.named("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("highlight(fieldA)", + new 
HighlightExpression(DSL.literal("fieldA"))) + ), + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("nested(message.*)", + nestedAllTupleFields("message")), + AstDSL.alias("highlight(fieldA)", + new HighlightFunction(AstDSL.stringLiteral("fieldA"), highlightArgs)) + ) + ); + } + @Test public void project_nested_field_and_path_args() { List> nestedArgs = @@ -596,7 +786,7 @@ public void project_nested_field_and_path_args() { List projectList = List.of( new NamedExpression( - "message.info", + "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", STRING)), null) ); @@ -607,13 +797,13 @@ public void project_nested_field_and_path_args() { LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), - DSL.named("message.info", + DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", STRING))) ), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("message.info", + AstDSL.alias("nested(message.info)", function( "nested", qualifiedName("message", "info"), @@ -638,7 +828,7 @@ public void project_nested_deep_field_arg() { List projectList = List.of( new NamedExpression( - "message.info.id", + "nested(message.info.id)", DSL.nested(DSL.ref("message.info.id", STRING)), null) ); @@ -649,13 +839,13 @@ public void project_nested_deep_field_arg() { LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), - DSL.named("message.info.id", + DSL.named("nested(message.info.id)", DSL.nested(DSL.ref("message.info.id", STRING))) ), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("message.info.id", + AstDSL.alias("nested(message.info.id)", function("nested", qualifiedName("message", "info", "id")), null) ) ); @@ -678,11 +868,11 @@ public void project_multiple_nested() { List projectList = List.of( new NamedExpression( - "message.info", + "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING)), null), new NamedExpression( - "comment.data", + "nested(comment.data)", DSL.nested(DSL.ref("comment.data", STRING)), null) ); @@ -693,17 +883,17 @@ public void project_multiple_nested() { LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), - DSL.named("message.info", + DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING))), - DSL.named("comment.data", + DSL.named("nested(comment.data)", DSL.nested(DSL.ref("comment.data", STRING))) ), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("message.info", + AstDSL.alias("nested(message.info)", function("nested", qualifiedName("message", "info")), null), - AstDSL.alias("comment.data", + AstDSL.alias("nested(comment.data)", function("nested", qualifiedName("comment", "data")), null) ) ); diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/storage/OpenSearchDataSourceMetadataStorage.java b/datasources/src/main/java/org/opensearch/sql/datasources/storage/OpenSearchDataSourceMetadataStorage.java index f76e1ba9dc..d756f2e029 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/storage/OpenSearchDataSourceMetadataStorage.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/storage/OpenSearchDataSourceMetadataStorage.java @@ -217,6 +217,8 @@ private List searchInDataSourcesIndex(QueryBuilder query) { searchSourceBuilder.query(query); searchSourceBuilder.size(DATASOURCE_QUERY_RESULT_SIZE); 
searchRequest.source(searchSourceBuilder); + // Strongly consistent reads are required. More info: https://github.com/opensearch-project/sql/issues/1801. + searchRequest.preference("_primary"); ActionFuture searchResponseActionFuture; try (ThreadContext.StoredContext ignored = client.threadPool().getThreadContext() .stashContext()) { @@ -305,4 +307,4 @@ private void handleSigV4PropertiesEncryptionDecryption(Map prope encryptOrDecrypt(propertiesMap, isEncryption, list); } -} \ No newline at end of file +} diff --git a/docs/user/dql/basics.rst b/docs/user/dql/basics.rst index b7e8cf35a4..a03ac4db70 100644 --- a/docs/user/dql/basics.rst +++ b/docs/user/dql/basics.rst @@ -155,14 +155,17 @@ Result set: | Nanette| Bates| +---------+--------+ -One can also provide meta-field name(s) to retrieve reserved-fields (beginning with underscore) from OpenSearch documents. Meta-fields are not output -from wildcard calls (`SELECT *`) and must be explicitly included to be returned. +One can also provide meta-field name(s) to retrieve reserved-fields (beginning with underscore) from OpenSearch documents. They may also be used +in the query `WHERE` or `ORDER BY` clauses. Meta-fields are not output from wildcard calls (`SELECT *`) and must be explicitly included to be returned. + +Note: `_routing` is used differently in the `SELECT` and `WHERE` clauses. In `WHERE`, it contains the routing hash id. In `SELECT`, +it returns the shard used for the query (unless shards aren't active, in which case it returns the routing hash id). SQL query:: POST /_plugins/_sql { - "query" : "SELECT firstname, lastname, _id, _index, _sort FROM accounts" + "query" : "SELECT firstname, lastname, _id, _index, _sort, _routing FROM accounts WHERE _index = 'accounts'" } Explain:: @@ -175,6 +178,7 @@ Explain:: "firstname", "_id", "_index", + "_routing", "_sort", "lastname" ], diff --git a/docs/user/dql/expressions.rst b/docs/user/dql/expressions.rst index 39d381f59c..275795707c 100644 --- a/docs/user/dql/expressions.rst +++ b/docs/user/dql/expressions.rst @@ -160,7 +160,7 @@ Here is an example for different type of comparison operators:: +---------+----------+---------+----------+----------+---------+ It is possible to compare datetimes. When comparing different datetime types, for example `DATE` and `TIME`, both converted to `DATETIME`. -The following rule is applied on coversion: a `TIME` applied to today's date; `DATE` is interpreted at midnight. +The following rule is applied on conversion: a `TIME` is applied to today's date; a `DATE` is interpreted at midnight. 
See example below:: os> SELECT current_time() > current_date() AS `now.time > today`, typeof(current_time()) AS `now.time.type`, typeof(current_date()) AS `now.date.type`; fetched rows / total rows = 1/1 diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index cef87624a5..19260e8bea 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -4458,6 +4458,17 @@ Example with ``field`` and ``path`` parameters:: | b | +---------------------------------+ +Example with ``field.*`` used in SELECT clause:: + + os> SELECT nested(message.*) FROM nested; + fetched rows / total rows = 2/2 + +--------------------------+-----------------------------+------------------------+ + | nested(message.author) | nested(message.dayOfWeek) | nested(message.info) | + |--------------------------+-----------------------------+------------------------| + | e | 1 | a | + | f | 2 | b | + +--------------------------+-----------------------------+------------------------+ + Example with ``field`` and ``path`` parameters in the SELECT and WHERE clause:: diff --git a/docs/user/general/datatypes.rst b/docs/user/general/datatypes.rst index 4f1f3100c2..a265ffd4c9 100644 --- a/docs/user/general/datatypes.rst +++ b/docs/user/general/datatypes.rst @@ -155,15 +155,15 @@ The following matrix illustrates the conversions allowed by our query engine for +--------------+------+-------+---------+------+-------+--------+---------+--------------+------+--------+-----------+------+------+----------+----------+-----------+-----+--------+-----------+---------+ | STRING | E | E | E | E | E | E | IE | X | X | N/A | IE | IE | IE | IE | X | X | X | X | X | X | +--------------+------+-------+---------+------+-------+--------+---------+--------------+------+--------+-----------+------+------+----------+----------+-----------+-----+--------+-----------+---------+ -| TIMESTAMP | X | X | X | X | X | X | X | X | X | E | N/A | | | X | X | X | X | X | X | X | +| TIMESTAMP | X | X | X | X | X | X | X | X | X | E | N/A | IE | IE | IE | X | X | X | X | X | X | +--------------+------+-------+---------+------+-------+--------+---------+--------------+------+--------+-----------+------+------+----------+----------+-----------+-----+--------+-----------+---------+ -| DATE | X | X | X | X | X | X | X | X | X | E | | N/A | | X | X | X | X | X | X | X | +| DATE | X | X | X | X | X | X | X | X | X | E | E | N/A | IE | E | X | X | X | X | X | X | +--------------+------+-------+---------+------+-------+--------+---------+--------------+------+--------+-----------+------+------+----------+----------+-----------+-----+--------+-----------+---------+ -| TIME | X | X | X | X | X | X | X | X | X | E | | | N/A | X | X | X | X | X | X | X | +| TIME | X | X | X | X | X | X | X | X | X | E | E | E | N/A | E | X | X | X | X | X | X | +--------------+------+-------+---------+------+-------+--------+---------+--------------+------+--------+-----------+------+------+----------+----------+-----------+-----+--------+-----------+---------+ -| DATETIME | X | X | X | X | X | X | X | X | X | E | | | | N/A | X | X | X | X | X | X | +| DATETIME | X | X | X | X | X | X | X | X | X | E | E | E | E | N/A | X | X | X | X | X | X | +--------------+------+-------+---------+------+-------+--------+---------+--------------+------+--------+-----------+------+------+----------+----------+-----------+-----+--------+-----------+---------+ -| INTERVAL | X | X | X | X | X | X | X | X | X | E | | | | X | N/A | X | X | X | X | X | +| INTERVAL | X | X | X | X | X | 
X | X | X | X | E | X | X | X | X | N/A | X | X | X | X | X | +--------------+------+-------+---------+------+-------+--------+---------+--------------+------+--------+-----------+------+------+----------+----------+-----------+-----+--------+-----------+---------+ | GEO_POINT | X | X | X | X | X | X | X | X | X | | X | X | X | X | X | N/A | X | X | X | X | +--------------+------+-------+---------+------+-------+--------+---------+--------------+------+--------+-----------+------+------+----------+----------+-----------+-----+--------+-----------+---------+ @@ -236,8 +236,7 @@ Numeric values ranged from -2147483648 to +2147483647 are recognized as integer Date and Time Data Types ======================== -The date and time data types are the types that represent temporal values and SQL plugin supports types including DATE, TIME, DATETIME, TIMESTAMP and INTERVAL. By default, the OpenSearch DSL uses date type as the only date and time related type, which has contained all information about an absolute time point. To integrate with SQL language, each of the types other than timestamp is holding part of temporal or timezone information, and the usage to explicitly clarify the date and time types is reflected in the datetime functions (see `Functions `_ for details), where some functions might have restrictions in the input argument type. - +The datetime types supported by the SQL plugin are ``DATE``, ``TIME``, ``DATETIME``, ``TIMESTAMP``, and ``INTERVAL``, with date and time being used to represent temporal values. By default, the OpenSearch DSL uses ``date`` type as the only date and time related type as it contains all information about an absolute time point. To integrate with SQL language each of the types other than timestamp hold part of the temporal or timezone information. This information can be used to explicitly clarify the date and time types reflected in the datetime functions (see `Functions `_ for details), where some functions might have restrictions in the input argument type. Date ---- @@ -299,7 +298,7 @@ Interval data type represents a temporal duration or a period. The syntax is as | Interval | INTERVAL expr unit | +----------+--------------------+ -The expr is any expression that can be iterated to a quantity value eventually, see `Expressions `_ for details. The unit represents the unit for interpreting the quantity, including MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER and YEAR.The INTERVAL keyword and the unit specifier are not case sensitive. Note that there are two classes of intervals. Year-week intervals can store years, quarters, months and weeks. Day-time intervals can store days, hours, minutes, seconds and microseconds. Year-week intervals are comparable only with another year-week intervals. These two types of intervals can only comparable with the same type of themselves. +The expr is any expression that can be iterated to a quantity value eventually, see `Expressions `_ for details. The unit represents the unit for interpreting the quantity, including ``MICROSECOND``, ``SECOND``, ``MINUTE``, ``HOUR``, ``DAY``, ``WEEK``, ``MONTH``, ``QUARTER`` and ``YEAR``. The ``INTERVAL`` keyword and the unit specifier are not case sensitive. Note that there are two classes of intervals. Year-week intervals can store years, quarters, months and weeks. Day-time intervals can store days, hours, minutes, seconds and microseconds. Year-week intervals are comparable only with another year-week intervals. 
These two classes of intervals are only comparable with intervals of the same class. Conversion between date and time types @@ -320,7 +319,7 @@ Conversion from DATE Conversion from TIME >>>>>>>>>>>>>>>>>>>> -- Time value cannot be converted to any other date and time types since it does not contain any date information, so it is not meaningful to give no date info to a date/datetime/timestamp instance. +- When a time value is converted to any other datetime type, the date part of the new value is filled in with today's date, as with the `CURDATE` function. For example, a time value X converted to a timestamp would produce today's date at time X. Conversion from DATETIME @@ -354,18 +353,94 @@ A string can also represent and be converted to date and time types (except to i | True | False | True | +------------------------------------------------------------+-------------------------------------+----------------------------------+ +Please see `more examples here <../dql/expressions.rst#toc-entry-15>`_. + +Date formats +------------ + +The SQL plugin supports all named formats for the OpenSearch ``date`` data type, custom formats, and combinations of them. Please refer to the `OpenSearch docs `_ for a description of the formats. +The plugin detects which type of data is stored in a ``date`` field according to the formats given and returns results in the corresponding SQL types. +Given an index with the following mapping: + +.. code-block:: json + + { + "mappings" : { + "properties" : { + "date1" : { + "type" : "date", + "format": "yyyy-MM-dd" + }, + "date2" : { + "type" : "date", + "format": "date_time_no_millis" + }, + "date3" : { + "type" : "date", + "format": "hour_minute_second" + }, + "date4" : { + "type" : "date" + }, + "date5" : { + "type" : "date", + "format": "yyyy-MM-dd || time" + } + } + } + } + +Querying such an index will provide a response with a ``schema`` block as shown below. + +.. code-block:: json + + { + "query" : "SELECT * from date_formats LIMIT 0;" + } + +.. code-block:: json + + { + "schema": [ + { + "name": "date5", + "type": "timestamp" + }, + { + "name": "date4", + "type": "timestamp" + }, + { + "name": "date3", + "type": "time" + }, + { + "name": "date2", + "type": "timestamp" + }, + { + "name": "date1", + "type": "date" + } + ], + "datarows": [], + "total": 0, + "size": 0, + "status": 200 + } + String Data Types ================= -A string is a sequence of characters enclosed in either single or double quotes. For example, both 'text' and "text" will be treated as string literal. To use quote characters in a string literal, you can use two quotes of the same type as the enclosing quotes:: +A string is a sequence of characters enclosed in either single or double quotes. For example, both 'text' and "text" will be treated as string literal. 
To use quote characters in a string literal, you can use two quotes of the same type as the enclosing quotes or a backslash symbol (``\``):: - os> SELECT 'hello', "world", '"hello"', "'world'", '''hello''', """world""" + os> SELECT 'hello', "world", '"hello"', "'world'", '''hello''', """world""", 'I\'m', 'I''m', "I\"m" fetched rows / total rows = 1/1 - +-----------+-----------+-------------+-------------+---------------+---------------+ - | 'hello' | "world" | '"hello"' | "'world'" | '''hello''' | """world""" | - |-----------+-----------+-------------+-------------+---------------+---------------| - | hello | world | "hello" | 'world' | 'hello' | "world" | - +-----------+-----------+-------------+-------------+---------------+---------------+ + +-----------+-----------+-------------+-------------+---------------+---------------+----------+----------+----------+ + | 'hello' | "world" | '"hello"' | "'world'" | '''hello''' | """world""" | 'I\'m' | 'I''m' | "I\"m" | + |-----------+-----------+-------------+-------------+---------------+---------------+----------+----------+----------| + | hello | world | "hello" | 'world' | 'hello' | "world" | I'm | I'm | I"m | + +-----------+-----------+-------------+-------------+---------------+---------------+----------+----------+----------+ Boolean Data Types ================== diff --git a/docs/user/ppl/admin/spark_connector.rst b/docs/user/ppl/admin/spark_connector.rst new file mode 100644 index 0000000000..8ff8dd944e --- /dev/null +++ b/docs/user/ppl/admin/spark_connector.rst @@ -0,0 +1,92 @@ +.. highlight:: sh + +==================== +Spark Connector +==================== + +.. rubric:: Table of contents + +.. contents:: + :local: + :depth: 1 + + +Introduction +============ + +This page covers the Spark connector properties for dataSource configuration +and the nuances associated with the Spark connector. + + +Spark Connector Properties in DataSource Configuration +======================================================== +Spark Connector Properties. + +* ``spark.connector`` [Required]. + * This parameter provides the Spark client information for the connection. +* ``spark.sql.application`` [Optional]. + * This parameter provides the Spark SQL application jar. Default value is ``s3://spark-datasource/sql-job.jar``. +* ``emr.cluster`` [Required]. + * This parameter provides the EMR cluster id information. +* ``emr.auth.type`` [Required] + * This parameter provides the authentication type information. + * The Spark EMR connector currently supports the ``awssigv4`` authentication mechanism, and the following parameters are required: + * ``emr.auth.region``, ``emr.auth.access_key`` and ``emr.auth.secret_key`` +* ``spark.datasource.flint.*`` [Optional] + * This parameter provides the OpenSearch domain host information for Flint integration. + * ``spark.datasource.flint.integration`` [Optional] + * Default value for integration jar is ``s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar``. + * ``spark.datasource.flint.host`` [Optional] + * Default value for host is ``localhost``. + * ``spark.datasource.flint.port`` [Optional] + * Default value for port is ``9200``. + * ``spark.datasource.flint.scheme`` [Optional] + * Default value for scheme is ``http``. + * ``spark.datasource.flint.auth`` [Optional] + * Default value for auth is ``false``. + * ``spark.datasource.flint.region`` [Optional] + * Default value for region is ``us-west-2``. 
+ +Example Spark dataSource configuration +======================================== + +AWSSigV4 Auth:: + + [{ + "name" : "my_spark", + "connector": "spark", + "properties" : { + "spark.connector": "emr", + "emr.cluster" : "{{clusterId}}", + "emr.auth.type" : "awssigv4", + "emr.auth.region" : "us-east-1", + "emr.auth.access_key" : "{{accessKey}}", + "emr.auth.secret_key" : "{{secretKey}}", + "spark.datasource.flint.host" : "{{opensearchHost}}", + "spark.datasource.flint.port" : "{{opensearchPort}}", + "spark.datasource.flint.scheme" : "{{opensearchScheme}}", + "spark.datasource.flint.auth" : "{{opensearchAuth}}", + "spark.datasource.flint.region" : "{{opensearchRegion}}" + } + }] + + +Spark SQL Support +================== + +`sql` Function +---------------------------- +The Spark connector offers a `sql` function. This function can be used to run a Spark SQL query. +The function takes a Spark SQL query as input. The argument can be passed either by name or by position. +`source=my_spark.sql('select 1')` +or +`source=my_spark.sql(query='select 1')` +Example:: + + > source=my_spark.sql('select 1') + +---+ + | 1 | + |---+ + | 1 | + +---+ + diff --git a/docs/user/ppl/functions/datetime.rst b/docs/user/ppl/functions/datetime.rst index 917eea869c..fccfefca6b 100644 --- a/docs/user/ppl/functions/datetime.rst +++ b/docs/user/ppl/functions/datetime.rst @@ -621,7 +621,7 @@ DAY Description >>>>>>>>>>> -Usage: day(date) extracts the day of the month for date, in the range 1 to 31. The dates with value 0 such as '0000-00-00' or '2008-00-00' are invalid. +Usage: day(date) extracts the day of the month for date, in the range 1 to 31. Argument type: STRING/DATE/DATETIME/TIMESTAMP @@ -669,7 +669,7 @@ DAYOFMONTH Description >>>>>>>>>>> -Usage: dayofmonth(date) extracts the day of the month for date, in the range 1 to 31. The dates with value 0 such as '0000-00-00' or '2008-00-00' are invalid. +Usage: dayofmonth(date) extracts the day of the month for date, in the range 1 to 31. Argument type: STRING/DATE/DATETIME/TIMESTAMP @@ -694,7 +694,7 @@ DAY_OF_MONTH Description >>>>>>>>>>> -Usage: day_of_month(date) extracts the day of the month for date, in the range 1 to 31. The dates with value 0 such as '0000-00-00' or '2008-00-00' are invalid. +Usage: day_of_month(date) extracts the day of the month for date, in the range 1 to 31. Argument type: STRING/DATE/DATETIME/TIMESTAMP @@ -1074,52 +1074,52 @@ Example:: +----------------------------+ -MINUTE_OF_HOUR --------------- +MINUTE_OF_DAY +------------- Description >>>>>>>>>>> -Usage: minute(time) returns the minute for time, in the range 0 to 59. +Usage: minute_of_day(time) returns the number of minutes in the day, in the range 0 to 1439. 
Argument type: STRING/TIME/DATETIME/TIMESTAMP Return type: INTEGER -Synonyms: `MINUTE`_ - Example:: - os> source=people | eval `MINUTE_OF_HOUR(TIME('01:02:03'))` = MINUTE_OF_HOUR(TIME('01:02:03')) | fields `MINUTE_OF_HOUR(TIME('01:02:03'))` + os> source=people | eval `MINUTE_OF_DAY(TIME('01:02:03'))` = MINUTE_OF_DAY(TIME('01:02:03')) | fields `MINUTE_OF_DAY(TIME('01:02:03'))` fetched rows / total rows = 1/1 - +------------------------------------+ - | MINUTE_OF_HOUR(TIME('01:02:03')) | - |------------------------------------| - | 2 | - +------------------------------------+ + +-----------------------------------+ + | MINUTE_OF_DAY(TIME('01:02:03')) | + |-----------------------------------| + | 62 | + +-----------------------------------+ -MINUTE_OF_DAY ------- +MINUTE_OF_HOUR +-------------- Description >>>>>>>>>>> -Usage: minute(time) returns the amount of minutes in the day, in the range of 0 to 1439. +Usage: minute(time) returns the minute for time, in the range 0 to 59. Argument type: STRING/TIME/DATETIME/TIMESTAMP Return type: INTEGER +Synonyms: `MINUTE`_ + Example:: - os> source=people | eval `MINUTE_OF_DAY(TIME('01:02:03'))` = MINUTE_OF_DAY(TIME('01:02:03')) | fields `MINUTE_OF_DAY(TIME('01:02:03'))` + os> source=people | eval `MINUTE_OF_HOUR(TIME('01:02:03'))` = MINUTE_OF_HOUR(TIME('01:02:03')) | fields `MINUTE_OF_HOUR(TIME('01:02:03'))` fetched rows / total rows = 1/1 - +-----------------------------------+ - | MINUTE_OF_DAY(TIME('01:02:03')) | - |-----------------------------------| - | 62 | - +-----------------------------------+ + +------------------------------------+ + | MINUTE_OF_HOUR(TIME('01:02:03')) | + |------------------------------------| + | 2 | + +------------------------------------+ MONTH @@ -1128,7 +1128,7 @@ MONTH Description >>>>>>>>>>> -Usage: month(date) returns the month for date, in the range 1 to 12 for January to December. The dates with value 0 such as '0000-00-00' or '2008-00-00' are invalid. +Usage: month(date) returns the month for date, in the range 1 to 12 for January to December. Argument type: STRING/DATE/DATETIME/TIMESTAMP @@ -1154,7 +1154,7 @@ MONTH_OF_YEAR Description >>>>>>>>>>> -Usage: month_of_year(date) returns the month for date, in the range 1 to 12 for January to December. The dates with value 0 such as '0000-00-00' or '2008-00-00' are invalid. +Usage: month_of_year(date) returns the month for date, in the range 1 to 12 for January to December. 
Argument type: STRING/DATE/DATETIME/TIMESTAMP diff --git a/integ-test/build.gradle b/integ-test/build.gradle index fc97fff9a4..c71701d821 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -298,6 +298,9 @@ integTest { // Exclude JDBC related tests exclude 'org/opensearch/sql/jdbc/**' + + // Exclude this IT until running IT with security plugin enabled is ready + exclude 'org/opensearch/sql/ppl/CrossClusterSearchIT.class' } diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java index ac6949e77e..c942962fb8 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java @@ -17,14 +17,10 @@ import java.util.ArrayList; import java.util.List; import lombok.SneakyThrows; -import org.apache.commons.lang3.StringUtils; import org.junit.AfterClass; import org.junit.Assert; -import org.junit.BeforeClass; import org.junit.Test; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Request; -import org.opensearch.client.RequestOptions; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; import org.opensearch.sql.datasource.model.DataSourceMetadata; diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DatasourceClusterSettingsIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DatasourceClusterSettingsIT.java new file mode 100644 index 0000000000..8c4959707a --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DatasourceClusterSettingsIT.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.datasource; + +import static org.hamcrest.Matchers.equalTo; + +import java.io.IOException; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.json.JSONObject; +import org.junit.Test; +import org.opensearch.client.ResponseException; +import org.opensearch.sql.legacy.TestUtils; +import org.opensearch.sql.ppl.PPLIntegTestCase; + +public class DatasourceClusterSettingsIT extends PPLIntegTestCase { + + private static final Logger LOG = LogManager.getLogger(); + @Test + public void testGetDatasourceClusterSettings() throws IOException { + JSONObject clusterSettings = getAllClusterSettings(); + assertThat(clusterSettings.query("/defaults/plugins.query.datasources.encryption.masterkey"), + equalTo(null)); + } + + + @Test + public void testPutDatasourceClusterSettings() throws IOException { + final ResponseException exception = + expectThrows(ResponseException.class, () -> updateClusterSettings(new ClusterSetting(PERSISTENT, + "plugins.query.datasources.encryption.masterkey", + "masterkey"))); + JSONObject resp = new JSONObject(TestUtils.getResponseBody(exception.getResponse())); + assertThat(resp.getInt("status"), equalTo(400)); + assertThat(resp.query("/error/root_cause/0/reason"), + equalTo("final persistent setting [plugins.query.datasources.encryption.masterkey], not updateable")); + assertThat(resp.query("/error/type"), equalTo("settings_exception")); + } + +} diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java index 200c300f3b..ef80098df6 100644 --- 
a/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java @@ -53,6 +53,9 @@ public class PrettyFormatResponseIT extends SQLIntegTestCase { private static final Set messageFields = Sets.newHashSet( "message.dayOfWeek", "message.info", "message.author"); + private static final Set messageFieldsWithNestedFunction = Sets.newHashSet( + "nested(message.dayOfWeek)", "nested(message.info)", "nested(message.author)"); + private static final Set commentFields = Sets.newHashSet("comment.data", "comment.likes"); private static final List nameFields = Arrays.asList("firstname", "lastname"); @@ -211,7 +214,7 @@ public void selectNestedFieldWithWildcard() throws IOException { String.format(Locale.ROOT, "SELECT nested(message.*) FROM %s", TestsConstants.TEST_INDEX_NESTED_TYPE)); - assertContainsColumnsInAnyOrder(getSchema(response), messageFields); + assertContainsColumnsInAnyOrder(getSchema(response), messageFieldsWithNestedFunction); assertContainsData(getDataRows(response), messageFields); } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/DateTimeFormatsIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/DateTimeFormatsIT.java index 7cd95fb509..fc05e502c5 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/DateTimeFormatsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/DateTimeFormatsIT.java @@ -16,6 +16,7 @@ import java.io.IOException; import java.util.Locale; +import lombok.SneakyThrows; import org.json.JSONObject; import org.junit.jupiter.api.Test; import org.opensearch.client.Request; @@ -56,6 +57,69 @@ public void testDateFormatsWithOr() throws IOException { rows("1984-04-12 09:07:42.000123456")); } + @Test + @SneakyThrows + public void testCustomFormats() { + String query = String.format("SELECT custom_time, custom_timestamp, custom_date_or_date," + + "custom_date_or_custom_time, custom_time_parser_check FROM %s", TEST_INDEX_DATE_FORMATS); + JSONObject result = executeQuery(query); + verifySchema(result, + schema("custom_time", null, "time"), + schema("custom_timestamp", null, "timestamp"), + schema("custom_date_or_date", null, "date"), + schema("custom_date_or_custom_time", null, "timestamp"), + schema("custom_time_parser_check", null, "time")); + verifyDataRows(result, + rows("09:07:42", "1984-04-12 09:07:42", "1984-04-12", "1961-04-12 00:00:00", "23:44:36.321"), + rows("21:07:42", "1984-04-12 22:07:42", "1984-04-12", "1970-01-01 09:07:00", "09:01:16.542")); + } + + @Test + @SneakyThrows + public void testCustomFormats2() { + String query = String.format("SELECT custom_no_delimiter_date, custom_no_delimiter_time," + + "custom_no_delimiter_ts FROM %s", TEST_INDEX_DATE_FORMATS); + JSONObject result = executeQuery(query); + verifySchema(result, + schema("custom_no_delimiter_date", null, "date"), + schema("custom_no_delimiter_time", null, "time"), + schema("custom_no_delimiter_ts", null, "timestamp")); + verifyDataRows(result, + rows("1984-10-20", "10:20:30", "1984-10-20 15:35:48"), + rows("1961-04-12", "09:07:00", "1961-04-12 09:07:00")); + } + + @Test + @SneakyThrows + public void testIncompleteFormats() { + String query = String.format("SELECT incomplete_1, incomplete_2, incorrect," + + "incomplete_custom_time, incomplete_custom_date FROM %s", TEST_INDEX_DATE_FORMATS); + JSONObject result = executeQuery(query); + verifySchema(result, + schema("incomplete_1", null, "timestamp"), + schema("incomplete_2", null, "date"), + schema("incorrect", null, 
"timestamp"), + schema("incomplete_custom_time", null, "time"), + schema("incomplete_custom_date", null, "date")); + verifyDataRows(result, + rows("1984-01-01 00:00:00", null, null, "10:00:00", "1999-01-01"), + rows("2012-01-01 00:00:00", null, null, "20:00:00", "3021-01-01")); + } + + @Test + @SneakyThrows + public void testNumericFormats() { + String query = String.format("SELECT epoch_sec, epoch_milli" + + " FROM %s", TEST_INDEX_DATE_FORMATS); + JSONObject result = executeQuery(query); + verifySchema(result, + schema("epoch_sec", null, "timestamp"), + schema("epoch_milli", null, "timestamp")); + verifyDataRows(result, + rows("1970-01-01 00:00:42", "1970-01-01 00:00:00.042"), + rows("1970-01-02 03:55:00", "1970-01-01 00:01:40.5")); + } + protected JSONObject executeQuery(String query) throws IOException { Request request = new Request("POST", QUERY_API_ENDPOINT); request.setJsonEntity(String.format(Locale.ROOT, "{\n" + " \"query\": \"%s\"\n" + "}", query)); diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/IdentifierIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/IdentifierIT.java index d5c194968d..22632cc4de 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/IdentifierIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/IdentifierIT.java @@ -14,6 +14,8 @@ import static org.opensearch.sql.util.TestUtils.performRequest; import java.io.IOException; +import java.util.ArrayList; +import org.json.JSONArray; import org.json.JSONObject; import org.junit.jupiter.api.Test; import org.opensearch.client.Request; @@ -99,6 +101,80 @@ public void testMetafieldIdentifierTest() throws IOException { verifyDataRows(result, rows(30, id, index, 1.0, 1.0, -2)); } + @Test + public void testMetafieldIdentifierRoutingSelectTest() throws IOException { + // create an index, but the contents doesn't really matter + String index = "test.routing_select"; + String mapping = "{\"_routing\": {\"required\": true }}"; + new Index(index, mapping) + .addDocWithShardId("{\"age\": 31}", "test0", "test0") + .addDocWithShardId("{\"age\": 31}", "test1", "test1") + .addDocWithShardId("{\"age\": 32}", "test2", "test2") + .addDocWithShardId("{\"age\": 33}", "test3", "test3") + .addDocWithShardId("{\"age\": 34}", "test4", "test4") + .addDocWithShardId("{\"age\": 35}", "test5", "test5"); + + // Execute using field metadata values filtering on the routing shard hash id + final JSONObject result = new JSONObject(executeQuery( + "SELECT age, _id, _index, _routing " + + "FROM " + index, + "jdbc")); + + // Verify that the metadata values are returned when requested + verifySchema(result, + schema("age", null, "long"), + schema("_id", null, "keyword"), + schema("_index", null, "keyword"), + schema("_routing", null, "keyword")); + assertTrue(result.getJSONArray("schema").length() == 4); + + var datarows = result.getJSONArray("datarows"); + assertEquals(6, datarows.length()); + + // note that _routing in the SELECT clause returns the shard + for (int i = 0; i < 6; i++) { + assertEquals("test" + i, datarows.getJSONArray(i).getString(1)); + assertEquals(index, datarows.getJSONArray(i).getString(2)); + assertTrue(datarows.getJSONArray(i).getString(3).contains("[" + index + "]")); + } + } + + @Test + public void testMetafieldIdentifierRoutingFilterTest() throws IOException { + // create an index, but the contents doesn't really matter + String index = "test.routing_filter"; + String mapping = "{\"_routing\": {\"required\": true }}"; + new Index(index, mapping) + .addDocWithShardId("{\"age\": 31}", 
"test1", "test1") + .addDocWithShardId("{\"age\": 32}", "test2", "test2") + .addDocWithShardId("{\"age\": 33}", "test3", "test3") + .addDocWithShardId("{\"age\": 34}", "test4", "test4") + .addDocWithShardId("{\"age\": 35}", "test5", "test5") + .addDocWithShardId("{\"age\": 36}", "test6", "test6"); + + // Execute using field metadata values filtering on the routing shard hash id + final JSONObject result = new JSONObject(executeQuery( + "SELECT _id, _index, _routing " + + "FROM " + index + " " + + "WHERE _routing = \\\"test4\\\"", + "jdbc")); + + // Verify that the metadata values are returned when requested + verifySchema(result, + schema("_id", null, "keyword"), + schema("_index", null, "keyword"), + schema("_routing", null, "keyword")); + assertTrue(result.getJSONArray("schema").length() == 3); + + var datarows = result.getJSONArray("datarows"); + assertEquals(1, datarows.length()); + + assertEquals("test4", datarows.getJSONArray(0).getString(0)); + // note that _routing in the SELECT clause returns the shard, not the routing hash id + assertTrue(datarows.getJSONArray(0).getString(2).contains("[" + index + "]")); + + } + @Test public void testMetafieldIdentifierWithAliasTest() throws IOException { // create an index, but the contents doesn't matter @@ -152,16 +228,32 @@ private static class Index { } } + Index(String indexName, String mapping) throws IOException { + this.indexName = indexName; + + Request createIndex = new Request("PUT", "/" + indexName); + createIndex.setJsonEntity(mapping); + executeRequest(new Request("PUT", "/" + indexName)); + } + void addDoc(String doc) { Request indexDoc = new Request("POST", String.format("/%s/_doc?refresh=true", indexName)); indexDoc.setJsonEntity(doc); performRequest(client(), indexDoc); } - void addDoc(String doc, String id) { + public Index addDoc(String doc, String id) { Request indexDoc = new Request("POST", String.format("/%s/_doc/%s?refresh=true", indexName, id)); indexDoc.setJsonEntity(doc); performRequest(client(), indexDoc); + return this; + } + + public Index addDocWithShardId(String doc, String id, String routing) { + Request indexDoc = new Request("POST", String.format("/%s/_doc/%s?refresh=true&routing=%s", indexName, id, routing)); + indexDoc.setJsonEntity(doc); + performRequest(client(), indexDoc); + return this; } } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java index 69b54cfc4f..d3230188b7 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java @@ -89,8 +89,7 @@ public void nested_function_with_arrays_in_an_aggregate_function_in_select_test( verifyDataRows(result, rows(19)); } - // TODO not currently supported by legacy, should we add implementation in AstBuilder? 
- @Disabled + @Test public void nested_function_in_a_function_in_select_test() { String query = "SELECT upper(nested(message.info)) FROM " + TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS; @@ -104,6 +103,41 @@ public void nested_function_in_a_function_in_select_test() { rows("ZZ")); } + @Test + public void nested_all_function_in_a_function_in_select_test() { + String query = "SELECT nested(message.*) FROM " + + TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS + " WHERE nested(message.info) = 'a'"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, rows("e", 1, "a")); + } + + @Test + public void invalid_multiple_nested_all_function_in_a_function_in_select_test() { + String query = "SELECT nested(message.*), nested(message.info) FROM " + + TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS; + RuntimeException result = assertThrows( + RuntimeException.class, + () -> executeJdbcRequest(query) + ); + assertTrue( + result.getMessage().contains("IllegalArgumentException") + && result.getMessage().contains("Multiple entries with same key") + ); + } + + @Test + public void nested_all_function_with_limit_test() { + String query = "SELECT nested(message.*) FROM " + + TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS + " LIMIT 3"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, + rows("e", 1, "a"), + rows("f", 2, "b"), + rows("g", 1, "c") + ); + } + + @Test public void nested_function_with_array_of_multi_nested_field_test() { String query = "SELECT nested(message.author.name) FROM " + TEST_INDEX_MULTI_NESTED_TYPE; @@ -403,6 +437,107 @@ public void test_nested_in_where_as_predicate_expression_with_relevance_query() verifyDataRows(result, rows(10, "a")); } + @Test + public void nested_function_all_subfields() { + String query = "SELECT nested(message.*) FROM " + TEST_INDEX_NESTED_TYPE; + JSONObject result = executeJdbcRequest(query); + + assertEquals(6, result.getInt("total")); + verifySchema(result, + schema("nested(message.author)", null, "keyword"), + schema("nested(message.dayOfWeek)", null, "long"), + schema("nested(message.info)", null, "keyword")); + verifyDataRows(result, + rows("e", 1, "a"), + rows("f", 2, "b"), + rows("g", 1, "c"), + rows("h", 4, "c"), + rows("i", 5, "a"), + rows("zz", 6, "zz")); + } + + @Test + public void nested_function_all_subfields_and_specified_subfield() { + String query = "SELECT nested(message.*), nested(comment.data) FROM " + + TEST_INDEX_NESTED_TYPE; + JSONObject result = executeJdbcRequest(query); + + assertEquals(6, result.getInt("total")); + verifySchema(result, + schema("nested(message.author)", null, "keyword"), + schema("nested(message.dayOfWeek)", null, "long"), + schema("nested(message.info)", null, "keyword"), + schema("nested(comment.data)", null, "keyword")); + verifyDataRows(result, + rows("e", 1, "a", "ab"), + rows("f", 2, "b", "aa"), + rows("g", 1, "c", "aa"), + rows("h", 4, "c", "ab"), + rows("i", 5, "a", "ab"), + rows("zz", 6, "zz", new JSONArray(List.of("aa", "bb")))); + } + + @Test + public void nested_function_all_deep_nested_subfields() { + String query = "SELECT nested(message.author.address.*) FROM " + + TEST_INDEX_MULTI_NESTED_TYPE; + JSONObject result = executeJdbcRequest(query); + + assertEquals(6, result.getInt("total")); + verifySchema(result, + schema("nested(message.author.address.number)", null, "integer"), + schema("nested(message.author.address.street)", null, "keyword")); + verifyDataRows(result, + rows(1, "bc"), + rows(2, "ab"), + rows(3, "sk"), + rows(4, "mb"), + rows(5, "on"), + rows(6, "qc")); + } + + @Test + public void 
nested_function_all_subfields_for_two_nested_fields() { + String query = "SELECT nested(message.*), nested(comment.*) FROM " + + TEST_INDEX_NESTED_TYPE; + JSONObject result = executeJdbcRequest(query); + + assertEquals(6, result.getInt("total")); + verifySchema(result, + schema("nested(message.author)", null, "keyword"), + schema("nested(message.dayOfWeek)", null, "long"), + schema("nested(message.info)", null, "keyword"), + schema("nested(comment.data)", null, "keyword"), + schema("nested(comment.likes)", null, "long")); + verifyDataRows(result, + rows("e", 1, "a", "ab", 3), + rows("f", 2, "b", "aa", 2), + rows("g", 1, "c", "aa", 3), + rows("h", 4, "c", "ab", 1), + rows("i", 5, "a", "ab", 1), + rows("zz", 6, "zz", new JSONArray(List.of("aa", "bb")), 10)); + } + + @Test + public void nested_function_all_subfields_and_non_nested_field() { + String query = "SELECT nested(message.*), myNum FROM " + TEST_INDEX_NESTED_TYPE; + JSONObject result = executeJdbcRequest(query); + + assertEquals(6, result.getInt("total")); + verifySchema(result, + schema("nested(message.author)", null, "keyword"), + schema("nested(message.dayOfWeek)", null, "long"), + schema("nested(message.info)", null, "keyword"), + schema("myNum", null, "long")); + verifyDataRows(result, + rows("e", 1, "a", 1), + rows("f", 2, "b", 2), + rows("g", 1, "c", 3), + rows("h", 4, "c", 4), + rows("i", 5, "a", 4), + rows("zz", 6, "zz", new JSONArray(List.of(3, 4)))); + } + @Test public void nested_function_with_date_types_as_object_arrays_within_arrays_test() { String query = "SELECT nested(address.moveInDate) FROM " + TEST_INDEX_NESTED_SIMPLE; @@ -435,4 +570,23 @@ public void nested_function_with_date_types_as_object_arrays_within_arrays_test( ) ); } + + @Test + public void nested_function_all_subfields_in_wrong_clause() { + String query = "SELECT * FROM " + TEST_INDEX_NESTED_TYPE + " ORDER BY nested(message.*)"; + + Exception exception = assertThrows(RuntimeException.class, () -> + executeJdbcRequest(query)); + + assertTrue(exception.getMessage().contains("" + + "{\n" + + " \"error\": {\n" + + " \"reason\": \"There was internal problem at backend\",\n" + + " \"details\": \"Invalid use of expression nested(message.*)\",\n" + + " \"type\": \"UnsupportedOperationException\"\n" + + " },\n" + + " \"status\": 503\n" + + "}" + )); + } } diff --git a/integ-test/src/test/resources/date_formats.json b/integ-test/src/test/resources/date_formats.json index cc694930e9..13d46a0e8c 100644 --- a/integ-test/src/test/resources/date_formats.json +++ b/integ-test/src/test/resources/date_formats.json @@ -1,4 +1,4 @@ {"index": {}} -{"epoch_millis": "450608862000.123456", "epoch_second": "450608862.000123456", "date_optional_time": "1984-04-12T09:07:42.000Z", "strict_date_optional_time": "1984-04-12T09:07:42.000Z", "strict_date_optional_time_nanos": "1984-04-12T09:07:42.000123456Z", "basic_date": "19840412", "basic_date_time": "19840412T090742.000Z", "basic_date_time_no_millis": "19840412T090742Z", "basic_ordinal_date": "1984103", "basic_ordinal_date_time": "1984103T090742.000Z", "basic_ordinal_date_time_no_millis": "1984103T090742Z", "basic_time": "090742.000Z", "basic_time_no_millis": "090742Z", "basic_t_time": "T090742.000Z", "basic_t_time_no_millis": "T090742Z", "basic_week_date": "1984W154", "strict_basic_week_date": "1984W154", "basic_week_date_time": "1984W154T090742.000Z", "strict_basic_week_date_time": "1984W154T090742.000Z", "basic_week_date_time_no_millis": "1984W154T090742Z", "strict_basic_week_date_time_no_millis": "1984W154T090742Z", "date": 
"1984-04-12", "strict_date": "1984-04-12", "date_hour": "1984-04-12T09", "strict_date_hour": "1984-04-12T09", "date_hour_minute": "1984-04-12T09:07", "strict_date_hour_minute": "1984-04-12T09:07", "date_hour_minute_second": "1984-04-12T09:07:42", "strict_date_hour_minute_second": "1984-04-12T09:07:42", "date_hour_minute_second_fraction": "1984-04-12T09:07:42.000", "strict_date_hour_minute_second_fraction": "1984-04-12T09:07:42.000", "date_hour_minute_second_millis": "1984-04-12T09:07:42.000", "strict_date_hour_minute_second_millis": "1984-04-12T09:07:42.000", "date_time": "1984-04-12T09:07:42.000Z", "strict_date_time": "1984-04-12T09:07:42.000123456Z", "date_time_no_millis": "1984-04-12T09:07:42Z", "strict_date_time_no_millis": "1984-04-12T09:07:42Z", "hour": "09", "strict_hour": "09", "hour_minute": "09:07", "strict_hour_minute": "09:07", "hour_minute_second": "09:07:42", "strict_hour_minute_second": "09:07:42", "hour_minute_second_fraction": "09:07:42.000", "strict_hour_minute_second_fraction": "09:07:42.000", "hour_minute_second_millis": "09:07:42.000", "strict_hour_minute_second_millis": "09:07:42.000", "ordinal_date": "1984-103", "strict_ordinal_date": "1984-103", "ordinal_date_time": "1984-103T09:07:42.000123456Z", "strict_ordinal_date_time": "1984-103T09:07:42.000123456Z", "ordinal_date_time_no_millis": "1984-103T09:07:42Z", "strict_ordinal_date_time_no_millis": "1984-103T09:07:42Z", "time": "09:07:42.000Z", "strict_time": "09:07:42.000Z", "time_no_millis": "09:07:42Z", "strict_time_no_millis": "09:07:42Z", "t_time": "T09:07:42.000Z", "strict_t_time": "T09:07:42.000Z", "t_time_no_millis": "T09:07:42Z", "strict_t_time_no_millis": "T09:07:42Z", "week_date": "1984-W15-4", "strict_week_date": "1984-W15-4", "week_date_time": "1984-W15-4T09:07:42.000Z", "strict_week_date_time": "1984-W15-4T09:07:42.000Z", "week_date_time_no_millis": "1984-W15-4T09:07:42Z", "strict_week_date_time_no_millis": "1984-W15-4T09:07:42Z", "weekyear_week_day": "1984-W15-4", "strict_weekyear_week_day": "1984-W15-4", "year_month_day": "1984-04-12", "strict_year_month_day": "1984-04-12", "yyyy-MM-dd": "1984-04-12", "HH:mm:ss": "09:07:42", "yyyy-MM-dd_OR_epoch_millis": "1984-04-12", "hour_minute_second_OR_t_time": "09:07:42"} +{"epoch_millis": "450608862000.123456", "epoch_second": "450608862.000123456", "date_optional_time": "1984-04-12T09:07:42.000Z", "strict_date_optional_time": "1984-04-12T09:07:42.000Z", "strict_date_optional_time_nanos": "1984-04-12T09:07:42.000123456Z", "basic_date": "19840412", "basic_date_time": "19840412T090742.000Z", "basic_date_time_no_millis": "19840412T090742Z", "basic_ordinal_date": "1984103", "basic_ordinal_date_time": "1984103T090742.000Z", "basic_ordinal_date_time_no_millis": "1984103T090742Z", "basic_time": "090742.000Z", "basic_time_no_millis": "090742Z", "basic_t_time": "T090742.000Z", "basic_t_time_no_millis": "T090742Z", "basic_week_date": "1984W154", "strict_basic_week_date": "1984W154", "basic_week_date_time": "1984W154T090742.000Z", "strict_basic_week_date_time": "1984W154T090742.000Z", "basic_week_date_time_no_millis": "1984W154T090742Z", "strict_basic_week_date_time_no_millis": "1984W154T090742Z", "date": "1984-04-12", "strict_date": "1984-04-12", "date_hour": "1984-04-12T09", "strict_date_hour": "1984-04-12T09", "date_hour_minute": "1984-04-12T09:07", "strict_date_hour_minute": "1984-04-12T09:07", "date_hour_minute_second": "1984-04-12T09:07:42", "strict_date_hour_minute_second": "1984-04-12T09:07:42", "date_hour_minute_second_fraction": "1984-04-12T09:07:42.000", 
"strict_date_hour_minute_second_fraction": "1984-04-12T09:07:42.000", "date_hour_minute_second_millis": "1984-04-12T09:07:42.000", "strict_date_hour_minute_second_millis": "1984-04-12T09:07:42.000", "date_time": "1984-04-12T09:07:42.000Z", "strict_date_time": "1984-04-12T09:07:42.000123456Z", "date_time_no_millis": "1984-04-12T09:07:42Z", "strict_date_time_no_millis": "1984-04-12T09:07:42Z", "hour": "09", "strict_hour": "09", "hour_minute": "09:07", "strict_hour_minute": "09:07", "hour_minute_second": "09:07:42", "strict_hour_minute_second": "09:07:42", "hour_minute_second_fraction": "09:07:42.000", "strict_hour_minute_second_fraction": "09:07:42.000", "hour_minute_second_millis": "09:07:42.000", "strict_hour_minute_second_millis": "09:07:42.000", "ordinal_date": "1984-103", "strict_ordinal_date": "1984-103", "ordinal_date_time": "1984-103T09:07:42.000123456Z", "strict_ordinal_date_time": "1984-103T09:07:42.000123456Z", "ordinal_date_time_no_millis": "1984-103T09:07:42Z", "strict_ordinal_date_time_no_millis": "1984-103T09:07:42Z", "time": "09:07:42.000Z", "strict_time": "09:07:42.000Z", "time_no_millis": "09:07:42Z", "strict_time_no_millis": "09:07:42Z", "t_time": "T09:07:42.000Z", "strict_t_time": "T09:07:42.000Z", "t_time_no_millis": "T09:07:42Z", "strict_t_time_no_millis": "T09:07:42Z", "week_date": "1984-W15-4", "strict_week_date": "1984-W15-4", "week_date_time": "1984-W15-4T09:07:42.000Z", "strict_week_date_time": "1984-W15-4T09:07:42.000Z", "week_date_time_no_millis": "1984-W15-4T09:07:42Z", "strict_week_date_time_no_millis": "1984-W15-4T09:07:42Z", "weekyear_week_day": "1984-W15-4", "strict_weekyear_week_day": "1984-W15-4", "year_month_day": "1984-04-12", "strict_year_month_day": "1984-04-12", "yyyy-MM-dd": "1984-04-12", "custom_time": "09:07:42 AM", "yyyy-MM-dd_OR_epoch_millis": "1984-04-12", "hour_minute_second_OR_t_time": "09:07:42", "custom_timestamp": "1984-04-12 09:07:42 ---- AM", "custom_date_or_date": "1984-04-12", "custom_date_or_custom_time": "1961-04-12", "custom_time_parser_check": "85476321", "incomplete_1" : 1984, "incomplete_2": null, "incomplete_custom_date": 1999, "incomplete_custom_time" : 10, "incorrect" : null, "epoch_sec" : 42, "epoch_milli" : 42, "custom_no_delimiter_date" : "19841020", "custom_no_delimiter_time" : "102030", "custom_no_delimiter_ts" : "19841020153548"} {"index": {}} -{"epoch_millis": "450608862000.123456", "epoch_second": "450608862.000123456", "date_optional_time": "1984-04-12T09:07:42.000Z", "strict_date_optional_time": "1984-04-12T09:07:42.000Z", "strict_date_optional_time_nanos": "1984-04-12T09:07:42.000123456Z", "basic_date": "19840412", "basic_date_time": "19840412T090742.000Z", "basic_date_time_no_millis": "19840412T090742Z", "basic_ordinal_date": "1984103", "basic_ordinal_date_time": "1984103T090742.000Z", "basic_ordinal_date_time_no_millis": "1984103T090742Z", "basic_time": "090742.000Z", "basic_time_no_millis": "090742Z", "basic_t_time": "T090742.000Z", "basic_t_time_no_millis": "T090742Z", "basic_week_date": "1984W154", "strict_basic_week_date": "1984W154", "basic_week_date_time": "1984W154T090742.000Z", "strict_basic_week_date_time": "1984W154T090742.000Z", "basic_week_date_time_no_millis": "1984W154T090742Z", "strict_basic_week_date_time_no_millis": "1984W154T090742Z", "date": "1984-04-12", "strict_date": "1984-04-12", "date_hour": "1984-04-12T09", "strict_date_hour": "1984-04-12T09", "date_hour_minute": "1984-04-12T09:07", "strict_date_hour_minute": "1984-04-12T09:07", "date_hour_minute_second": "1984-04-12T09:07:42", 
"strict_date_hour_minute_second": "1984-04-12T09:07:42", "date_hour_minute_second_fraction": "1984-04-12T09:07:42.000", "strict_date_hour_minute_second_fraction": "1984-04-12T09:07:42.000", "date_hour_minute_second_millis": "1984-04-12T09:07:42.000", "strict_date_hour_minute_second_millis": "1984-04-12T09:07:42.000", "date_time": "1984-04-12T09:07:42.000Z", "strict_date_time": "1984-04-12T09:07:42.000123456Z", "date_time_no_millis": "1984-04-12T09:07:42Z", "strict_date_time_no_millis": "1984-04-12T09:07:42Z", "hour": "09", "strict_hour": "09", "hour_minute": "09:07", "strict_hour_minute": "09:07", "hour_minute_second": "09:07:42", "strict_hour_minute_second": "09:07:42", "hour_minute_second_fraction": "09:07:42.000", "strict_hour_minute_second_fraction": "09:07:42.000", "hour_minute_second_millis": "09:07:42.000", "strict_hour_minute_second_millis": "09:07:42.000", "ordinal_date": "1984-103", "strict_ordinal_date": "1984-103", "ordinal_date_time": "1984-103T09:07:42.000123456Z", "strict_ordinal_date_time": "1984-103T09:07:42.000123456Z", "ordinal_date_time_no_millis": "1984-103T09:07:42Z", "strict_ordinal_date_time_no_millis": "1984-103T09:07:42Z", "time": "09:07:42.000Z", "strict_time": "09:07:42.000Z", "time_no_millis": "09:07:42Z", "strict_time_no_millis": "09:07:42Z", "t_time": "T09:07:42.000Z", "strict_t_time": "T09:07:42.000Z", "t_time_no_millis": "T09:07:42Z", "strict_t_time_no_millis": "T09:07:42Z", "week_date": "1984-W15-4", "strict_week_date": "1984-W15-4", "week_date_time": "1984-W15-4T09:07:42.000Z", "strict_week_date_time": "1984-W15-4T09:07:42.000Z", "week_date_time_no_millis": "1984-W15-4T09:07:42Z", "strict_week_date_time_no_millis": "1984-W15-4T09:07:42Z", "weekyear_week_day": "1984-W15-4", "strict_weekyear_week_day": "1984-W15-4", "year_month_day": "1984-04-12", "strict_year_month_day": "1984-04-12", "yyyy-MM-dd": "1984-04-12", "HH:mm:ss": "09:07:42", "yyyy-MM-dd_OR_epoch_millis": "450608862000.123456", "hour_minute_second_OR_t_time": "T09:07:42.000Z"} +{"epoch_millis": "450608862000.123456", "epoch_second": "450608862.000123456", "date_optional_time": "1984-04-12T09:07:42.000Z", "strict_date_optional_time": "1984-04-12T09:07:42.000Z", "strict_date_optional_time_nanos": "1984-04-12T09:07:42.000123456Z", "basic_date": "19840412", "basic_date_time": "19840412T090742.000Z", "basic_date_time_no_millis": "19840412T090742Z", "basic_ordinal_date": "1984103", "basic_ordinal_date_time": "1984103T090742.000Z", "basic_ordinal_date_time_no_millis": "1984103T090742Z", "basic_time": "090742.000Z", "basic_time_no_millis": "090742Z", "basic_t_time": "T090742.000Z", "basic_t_time_no_millis": "T090742Z", "basic_week_date": "1984W154", "strict_basic_week_date": "1984W154", "basic_week_date_time": "1984W154T090742.000Z", "strict_basic_week_date_time": "1984W154T090742.000Z", "basic_week_date_time_no_millis": "1984W154T090742Z", "strict_basic_week_date_time_no_millis": "1984W154T090742Z", "date": "1984-04-12", "strict_date": "1984-04-12", "date_hour": "1984-04-12T09", "strict_date_hour": "1984-04-12T09", "date_hour_minute": "1984-04-12T09:07", "strict_date_hour_minute": "1984-04-12T09:07", "date_hour_minute_second": "1984-04-12T09:07:42", "strict_date_hour_minute_second": "1984-04-12T09:07:42", "date_hour_minute_second_fraction": "1984-04-12T09:07:42.000", "strict_date_hour_minute_second_fraction": "1984-04-12T09:07:42.000", "date_hour_minute_second_millis": "1984-04-12T09:07:42.000", "strict_date_hour_minute_second_millis": "1984-04-12T09:07:42.000", "date_time": "1984-04-12T09:07:42.000Z", 
"strict_date_time": "1984-04-12T09:07:42.000123456Z", "date_time_no_millis": "1984-04-12T09:07:42Z", "strict_date_time_no_millis": "1984-04-12T09:07:42Z", "hour": "09", "strict_hour": "09", "hour_minute": "09:07", "strict_hour_minute": "09:07", "hour_minute_second": "09:07:42", "strict_hour_minute_second": "09:07:42", "hour_minute_second_fraction": "09:07:42.000", "strict_hour_minute_second_fraction": "09:07:42.000", "hour_minute_second_millis": "09:07:42.000", "strict_hour_minute_second_millis": "09:07:42.000", "ordinal_date": "1984-103", "strict_ordinal_date": "1984-103", "ordinal_date_time": "1984-103T09:07:42.000123456Z", "strict_ordinal_date_time": "1984-103T09:07:42.000123456Z", "ordinal_date_time_no_millis": "1984-103T09:07:42Z", "strict_ordinal_date_time_no_millis": "1984-103T09:07:42Z", "time": "09:07:42.000Z", "strict_time": "09:07:42.000Z", "time_no_millis": "09:07:42Z", "strict_time_no_millis": "09:07:42Z", "t_time": "T09:07:42.000Z", "strict_t_time": "T09:07:42.000Z", "t_time_no_millis": "T09:07:42Z", "strict_t_time_no_millis": "T09:07:42Z", "week_date": "1984-W15-4", "strict_week_date": "1984-W15-4", "week_date_time": "1984-W15-4T09:07:42.000Z", "strict_week_date_time": "1984-W15-4T09:07:42.000Z", "week_date_time_no_millis": "1984-W15-4T09:07:42Z", "strict_week_date_time_no_millis": "1984-W15-4T09:07:42Z", "weekyear_week_day": "1984-W15-4", "strict_weekyear_week_day": "1984-W15-4", "year_month_day": "1984-04-12", "strict_year_month_day": "1984-04-12", "yyyy-MM-dd": "1984-04-12", "custom_time": "09:07:42 PM", "yyyy-MM-dd_OR_epoch_millis": "450608862000.123456", "hour_minute_second_OR_t_time": "T09:07:42.000Z", "custom_timestamp": "1984-04-12 10:07:42 ---- PM", "custom_date_or_date": "1984-04-12", "custom_date_or_custom_time": "09:07:00", "custom_time_parser_check": "::: 9-32476542", "incomplete_1" : 2012, "incomplete_2": null, "incomplete_custom_date": 3021, "incomplete_custom_time" : 20, "incorrect" : null, "epoch_sec" : 100500, "epoch_milli" : 100500, "custom_no_delimiter_date" : "19610412", "custom_no_delimiter_time" : "090700", "custom_no_delimiter_ts" : "19610412090700"} diff --git a/integ-test/src/test/resources/indexDefinitions/date_formats_index_mapping.json b/integ-test/src/test/resources/indexDefinitions/date_formats_index_mapping.json index 938f598d0b..65811f8d9e 100644 --- a/integ-test/src/test/resources/indexDefinitions/date_formats_index_mapping.json +++ b/integ-test/src/test/resources/indexDefinitions/date_formats_index_mapping.json @@ -289,9 +289,9 @@ "type" : "date", "format": "yyyy-MM-dd" }, - "HH:mm:ss" : { + "custom_time" : { "type" : "date", - "format": "HH:mm:ss" + "format": "hh:mm:ss a" }, "yyyy-MM-dd_OR_epoch_millis" : { "type" : "date", @@ -300,7 +300,63 @@ "hour_minute_second_OR_t_time" : { "type" : "date", "format": "hour_minute_second||t_time" + }, + "custom_timestamp" : { + "type" : "date", + "format": "yyyy-MM-dd hh:mm:ss ---- a" + }, + "custom_date_or_date" : { + "type" : "date", + "format": "yyyy-MM-dd||date" + }, + "custom_date_or_custom_time" : { + "type" : "date", + "format" : "yyyy-MM-dd || HH:mm:ss" + }, + "custom_time_parser_check" : { + "type" : "date", + "format" : "::: k-A || A " + }, + "incomplete_1" : { + "type" : "date", + "format" : "year" + }, + "incomplete_2" : { + "type" : "date", + "format" : "E-w" + }, + "incomplete_custom_date" : { + "type" : "date", + "format" : "uuuu" + }, + "incomplete_custom_time" : { + "type" : "date", + "format" : "HH" + }, + "incorrect" : { + "type" : "date", + "format" : "'___'" + }, + "epoch_sec" : { 
+ "type" : "date", + "format" : "epoch_second" + }, + "epoch_milli" : { + "type" : "date", + "format" : "epoch_millis" + }, + "custom_no_delimiter_date" : { + "type" : "date", + "format" : "uuuuMMdd" + }, + "custom_no_delimiter_time" : { + "type" : "date", + "format" : "HHmmss" + }, + "custom_no_delimiter_ts" : { + "type" : "date", + "format" : "uuuuMMddHHmmss" } } } -} \ No newline at end of file +} diff --git a/legacy/build.gradle b/legacy/build.gradle index dd96884346..d89f7affe7 100644 --- a/legacy/build.gradle +++ b/legacy/build.gradle @@ -88,7 +88,7 @@ dependencies { because 'https://www.whitesourcesoftware.com/vulnerability-database/WS-2019-0379' } } - implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre' + implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' implementation group: 'org.json', name: 'json', version:'20230227' implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0' implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDateType.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDateType.java index 3554a5b2b4..76947bf720 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDateType.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDateType.java @@ -12,6 +12,7 @@ import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; import java.util.List; +import java.util.Objects; import java.util.stream.Collectors; import lombok.EqualsAndHashCode; import org.opensearch.common.time.DateFormatter; @@ -27,10 +28,15 @@ public class OpenSearchDateType extends OpenSearchDataType { private static final OpenSearchDateType instance = new OpenSearchDateType(); + /** Numeric formats which support full datetime. */ + public static final List SUPPORTED_NAMED_NUMERIC_FORMATS = List.of( + FormatNames.EPOCH_MILLIS, + FormatNames.EPOCH_SECOND + ); + + /** List of named formats which support full datetime. */ public static final List SUPPORTED_NAMED_DATETIME_FORMATS = List.of( FormatNames.ISO8601, - FormatNames.EPOCH_MILLIS, - FormatNames.EPOCH_SECOND, FormatNames.BASIC_DATE_TIME, FormatNames.BASIC_DATE_TIME_NO_MILLIS, FormatNames.BASIC_ORDINAL_DATE_TIME, @@ -69,7 +75,7 @@ public class OpenSearchDateType extends OpenSearchDataType { FormatNames.STRICT_WEEK_DATE_TIME_NO_MILLIS ); - // list of named formats that only support year/month/day + /** List of named formats that only support year/month/day. */ public static final List SUPPORTED_NAMED_DATE_FORMATS = List.of( FormatNames.BASIC_DATE, FormatNames.BASIC_ORDINAL_DATE, @@ -77,16 +83,21 @@ public class OpenSearchDateType extends OpenSearchDataType { FormatNames.STRICT_DATE, FormatNames.YEAR_MONTH_DAY, FormatNames.STRICT_YEAR_MONTH_DAY, - FormatNames.YEAR_MONTH, - FormatNames.STRICT_YEAR_MONTH, - FormatNames.YEAR, - FormatNames.STRICT_YEAR, FormatNames.ORDINAL_DATE, FormatNames.STRICT_ORDINAL_DATE, FormatNames.WEEK_DATE, FormatNames.STRICT_WEEK_DATE, FormatNames.WEEKYEAR_WEEK_DAY, - FormatNames.STRICT_WEEKYEAR_WEEK_DAY, + FormatNames.STRICT_WEEKYEAR_WEEK_DAY + ); + + /** list of named formats which produce incomplete date, + * e.g. 1 or 2 are missing from tuple year/month/day. 
*/ + public static final List SUPPORTED_NAMED_INCOMPLETE_DATE_FORMATS = List.of( + FormatNames.YEAR_MONTH, + FormatNames.STRICT_YEAR_MONTH, + FormatNames.YEAR, + FormatNames.STRICT_YEAR, FormatNames.WEEK_YEAR, FormatNames.WEEK_YEAR_WEEK, FormatNames.STRICT_WEEKYEAR_WEEK, @@ -94,7 +105,7 @@ public class OpenSearchDateType extends OpenSearchDataType { FormatNames.STRICT_WEEKYEAR ); - // list of named formats that only support hour/minute/second + /** List of named formats that only support hour/minute/second. */ public static final List SUPPORTED_NAMED_TIME_FORMATS = List.of( FormatNames.BASIC_TIME, FormatNames.BASIC_TIME_NO_MILLIS, @@ -120,12 +131,17 @@ public class OpenSearchDateType extends OpenSearchDataType { FormatNames.STRICT_T_TIME_NO_MILLIS ); + /** Formatter symbols which used to format time or date correspondingly. + * {@link java.time.format.DateTimeFormatter}. */ + private static final String CUSTOM_FORMAT_TIME_SYMBOLS = "nNASsmHkKha"; + private static final String CUSTOM_FORMAT_DATE_SYMBOLS = "FecEWwYqQgdMLDyuG"; + @EqualsAndHashCode.Exclude - String formatString; + private final List formats; private OpenSearchDateType() { super(MappingType.Date); - this.formatString = ""; + this.formats = List.of(); } private OpenSearchDateType(ExprCoreType exprCoreType) { @@ -138,102 +154,194 @@ private OpenSearchDateType(ExprType exprType) { this.exprCoreType = (ExprCoreType) exprType; } - private OpenSearchDateType(String formatStringArg) { + private OpenSearchDateType(String format) { super(MappingType.Date); - this.formatString = formatStringArg; - this.exprCoreType = getExprTypeFromFormatString(formatStringArg); + this.formats = getFormatList(format); + this.exprCoreType = getExprTypeFromFormatString(format); + } + + public boolean hasFormats() { + return !formats.isEmpty(); } /** * Retrieves and splits a user defined format string from the mapping into a list of formats. * @return A list of format names and user defined formats. */ - private List getFormatList() { - String format = strip8Prefix(formatString); - List patterns = splitCombinedPatterns(format); - return patterns; + private List getFormatList(String format) { + format = strip8Prefix(format); + return splitCombinedPatterns(format).stream().map(String::trim).collect(Collectors.toList()); } - /** * Retrieves a list of named OpenSearch formatters given by user mapping. * @return a list of DateFormatters that can be used to parse a Date/Time/Timestamp. */ public List getAllNamedFormatters() { - return getFormatList().stream() + return formats.stream() .filter(formatString -> FormatNames.forName(formatString) != null) .map(DateFormatter::forPattern).collect(Collectors.toList()); } + /** + * Retrieves a list of numeric formatters that format for dates. + * @return a list of DateFormatters that can be used to parse a Date. + */ + public List getNumericNamedFormatters() { + return formats.stream() + .filter(formatString -> { + FormatNames namedFormat = FormatNames.forName(formatString); + return namedFormat != null && SUPPORTED_NAMED_NUMERIC_FORMATS.contains(namedFormat); + }) + .map(DateFormatter::forPattern).collect(Collectors.toList()); + } + + /** + * Retrieves a list of custom formats defined by the user. + * @return a list of formats as strings that can be used to parse a Date/Time/Timestamp. 
+ */ + public List getAllCustomFormats() { + return formats.stream() + .filter(format -> FormatNames.forName(format) == null) + .map(format -> { + try { + DateFormatter.forPattern(format); + return format; + } catch (Exception ignored) { + // parsing failed + return null; + } + }) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + } + /** * Retrieves a list of custom formatters defined by the user. * @return a list of DateFormatters that can be used to parse a Date/Time/Timestamp. */ public List getAllCustomFormatters() { - return getFormatList().stream() - .filter(formatString -> FormatNames.forName(formatString) == null) - .map(DateFormatter::forPattern).collect(Collectors.toList()); + return getAllCustomFormats().stream() + .map(DateFormatter::forPattern) + .collect(Collectors.toList()); } /** * Retrieves a list of named formatters that format for dates. - * * @return a list of DateFormatters that can be used to parse a Date. */ public List getDateNamedFormatters() { - return getFormatList().stream() + return formats.stream() .filter(formatString -> { FormatNames namedFormat = FormatNames.forName(formatString); - return SUPPORTED_NAMED_DATE_FORMATS.contains(namedFormat); + return namedFormat != null && SUPPORTED_NAMED_DATE_FORMATS.contains(namedFormat); }) .map(DateFormatter::forPattern).collect(Collectors.toList()); } /** * Retrieves a list of named formatters that format for Times. - * * @return a list of DateFormatters that can be used to parse a Time. */ public List getTimeNamedFormatters() { - return getFormatList().stream() + return formats.stream() .filter(formatString -> { FormatNames namedFormat = FormatNames.forName(formatString); - return SUPPORTED_NAMED_TIME_FORMATS.contains(namedFormat); + return namedFormat != null && SUPPORTED_NAMED_TIME_FORMATS.contains(namedFormat); }) .map(DateFormatter::forPattern).collect(Collectors.toList()); } - private ExprCoreType getExprTypeFromFormatString(String formatString) { - if (formatString.isEmpty()) { - // FOLLOW-UP: check the default formatter - and set it here instead - // of assuming that the default is always a timestamp - return TIMESTAMP; + /** + * Retrieves a list of named formatters that format for DateTimes. + * @return a list of DateFormatters that can be used to parse a DateTime. 
+ */ + public List getDateTimeNamedFormatters() { + return formats.stream() + .filter(formatString -> { + FormatNames namedFormat = FormatNames.forName(formatString); + return namedFormat != null && SUPPORTED_NAMED_DATETIME_FORMATS.contains(namedFormat); + }) + .map(DateFormatter::forPattern).collect(Collectors.toList()); + } + + private ExprCoreType getExprTypeFromCustomFormats(List formats) { + boolean isDate = false; + boolean isTime = false; + + for (String format : formats) { + if (!isTime) { + for (char symbol : CUSTOM_FORMAT_TIME_SYMBOLS.toCharArray()) { + if (format.contains(String.valueOf(symbol))) { + isTime = true; + break; + } + } + } + if (!isDate) { + for (char symbol : CUSTOM_FORMAT_DATE_SYMBOLS.toCharArray()) { + if (format.contains(String.valueOf(symbol))) { + isDate = true; + break; + } + } + } + if (isDate && isTime) { + return TIMESTAMP; + } + } + + if (isDate) { + return DATE; + } + if (isTime) { + return TIME; } - List namedFormatters = getAllNamedFormatters(); + // Incomplete or incorrect formats: can't be converted to DATE nor TIME, for example `year` + return TIMESTAMP; + } - if (namedFormatters.isEmpty()) { + private ExprCoreType getExprTypeFromFormatString(String formatString) { + List datetimeFormatters = getDateTimeNamedFormatters(); + List numericFormatters = getNumericNamedFormatters(); + + if (formatString.isEmpty() || !datetimeFormatters.isEmpty() || !numericFormatters.isEmpty()) { return TIMESTAMP; } - if (!getAllCustomFormatters().isEmpty()) { - // FOLLOW-UP: support custom format in + List timeFormatters = getTimeNamedFormatters(); + List dateFormatters = getDateNamedFormatters(); + if (!timeFormatters.isEmpty() && !dateFormatters.isEmpty()) { return TIMESTAMP; } + List customFormatters = getAllCustomFormats(); + if (!customFormatters.isEmpty()) { + ExprCoreType customFormatType = getExprTypeFromCustomFormats(customFormatters); + ExprCoreType combinedByDefaultFormats = customFormatType; + if (!dateFormatters.isEmpty()) { + combinedByDefaultFormats = DATE; + } + if (!timeFormatters.isEmpty()) { + combinedByDefaultFormats = TIME; + } + return customFormatType == combinedByDefaultFormats ? 
customFormatType : TIMESTAMP; + } + // if there is nothing in the dateformatter that accepts a year/month/day, then // we can assume the type is strictly a Time object - if (namedFormatters.size() == getTimeNamedFormatters().size()) { + if (!timeFormatters.isEmpty()) { return TIME; } // if there is nothing in the dateformatter that accepts a hour/minute/second, then // we can assume the type is strictly a Date object - if (namedFormatters.size() == getDateNamedFormatters().size()) { + if (!dateFormatters.isEmpty()) { return DATE; } - // According to the user mapping, this field may contain a DATE or a TIME + // Unknown or incorrect format provided return TIMESTAMP; } @@ -280,7 +388,7 @@ public static OpenSearchDateType of() { @Override public List getParent() { - return List.of(this.exprCoreType); + return List.of(exprCoreType); } @Override @@ -290,9 +398,9 @@ public boolean shouldCast(ExprType other) { @Override protected OpenSearchDataType cloneEmpty() { - if (this.formatString.isEmpty()) { - return OpenSearchDateType.of(this.exprCoreType); + if (formats.isEmpty()) { + return OpenSearchDateType.of(exprCoreType); } - return OpenSearchDateType.of(this.formatString); + return OpenSearchDateType.of(String.join(" || ", formats)); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java index abad197bd4..22a43d3444 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java @@ -45,6 +45,7 @@ import lombok.Setter; import org.opensearch.common.time.DateFormatter; import org.opensearch.common.time.DateFormatters; +import org.opensearch.common.time.FormatNames; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.model.ExprByteValue; import org.opensearch.sql.data.model.ExprCollectionValue; @@ -60,6 +61,7 @@ import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.opensearch.data.type.OpenSearchBinaryType; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; @@ -82,7 +84,7 @@ public class OpenSearchExprValueFactory { /** * Extend existing mapping by new data without overwrite. - * Called from aggregation only {@link AggregationQueryBuilder#buildTypeMapping}. + * Called from aggregation only {@see AggregationQueryBuilder#buildTypeMapping}. * @param typeMapping A data type mapping produced by aggregation. 
*/ public void extendTypeMapping(Map typeMapping) { @@ -124,14 +126,10 @@ public void extendTypeMapping(Map typeMapping) { .put(OpenSearchDataType.of(OpenSearchDataType.MappingType.Boolean), (c, dt) -> ExprBooleanValue.of(c.booleanValue())) //Handles the creation of DATE, TIME & DATETIME - .put(OpenSearchDateType.of(TIME), - this::createOpenSearchDateType) - .put(OpenSearchDateType.of(DATE), - this::createOpenSearchDateType) - .put(OpenSearchDateType.of(TIMESTAMP), - this::createOpenSearchDateType) - .put(OpenSearchDateType.of(DATETIME), - this::createOpenSearchDateType) + .put(OpenSearchDateType.of(TIME), this::createOpenSearchDateType) + .put(OpenSearchDateType.of(DATE), this::createOpenSearchDateType) + .put(OpenSearchDateType.of(TIMESTAMP), this::createOpenSearchDateType) + .put(OpenSearchDateType.of(DATETIME), this::createOpenSearchDateType) .put(OpenSearchDataType.of(OpenSearchDataType.MappingType.Ip), (c, dt) -> new OpenSearchExprIpValue(c.stringValue())) .put(OpenSearchDataType.of(OpenSearchDataType.MappingType.GeoPoint), @@ -217,137 +215,81 @@ private Optional type(String field) { } /** - * Parses value with the first matching formatter as an Instant to UTF. - * - * @param value - timestamp as string - * @param dateType - field type - * @return Instant without timezone - */ - private ExprValue parseTimestampString(String value, OpenSearchDateType dateType) { - Instant parsed = null; - for (DateFormatter formatter : dateType.getAllNamedFormatters()) { - try { - TemporalAccessor accessor = formatter.parse(value); - ZonedDateTime zonedDateTime = DateFormatters.from(accessor); - // remove the Zone - parsed = zonedDateTime.withZoneSameLocal(ZoneId.of("Z")).toInstant(); - } catch (IllegalArgumentException ignored) { - // nothing to do, try another format - } - } - - // FOLLOW-UP PR: Check custom formatters too - - // if no named formatters are available, use the default - if (dateType.getAllNamedFormatters().size() == 0 - || dateType.getAllCustomFormatters().size() > 0) { - try { - parsed = DateFormatters.from(DATE_TIME_FORMATTER.parse(value)).toInstant(); - } catch (DateTimeParseException e) { - // ignored - } - } - - if (parsed == null) { - // otherwise, throw an error that no formatters worked - throw new IllegalArgumentException( - String.format( - "Construct ExprTimestampValue from \"%s\" failed, unsupported date format.", value) - ); - } - - return new ExprTimestampValue(parsed); - } - - /** - * return the first matching formatter as a time without timezone. + * Parse value with the first matching formatter into {@link ExprValue} + * with corresponding {@link ExprCoreType}. 
* * @param value - time as string - * @param dateType - field data type - * @return time without timezone + * @param dataType - field data type + * @return Parsed value */ - private ExprValue parseTimeString(String value, OpenSearchDateType dateType) { - for (DateFormatter formatter : dateType.getAllNamedFormatters()) { - try { - TemporalAccessor accessor = formatter.parse(value); - ZonedDateTime zonedDateTime = DateFormatters.from(accessor); - return new ExprTimeValue( - zonedDateTime.withZoneSameLocal(ZoneId.of("Z")).toLocalTime()); - } catch (IllegalArgumentException ignored) { - // nothing to do, try another format - } - } + private ExprValue parseDateTimeString(String value, OpenSearchDateType dataType) { + List formatters = dataType.getAllNamedFormatters(); + formatters.addAll(dataType.getAllCustomFormatters()); + ExprCoreType returnFormat = (ExprCoreType) dataType.getExprType(); - // if no named formatters are available, use the default - if (dateType.getAllNamedFormatters().size() == 0) { - try { - return new ExprTimeValue( - DateFormatters.from(STRICT_HOUR_MINUTE_SECOND_FORMATTER.parse(value)).toLocalTime()); - } catch (DateTimeParseException e) { - // ignored - } - } - throw new IllegalArgumentException("Construct ExprTimeValue from \"" + value - + "\" failed, unsupported time format."); - } - - /** - * return the first matching formatter as a date without timezone. - * - * @param value - date as string - * @param dateType - field data type - * @return date without timezone - */ - private ExprValue parseDateString(String value, OpenSearchDateType dateType) { - for (DateFormatter formatter : dateType.getAllNamedFormatters()) { + for (DateFormatter formatter : formatters) { try { TemporalAccessor accessor = formatter.parse(value); ZonedDateTime zonedDateTime = DateFormatters.from(accessor); - // return the first matching formatter as a date without timezone - return new ExprDateValue( - zonedDateTime.withZoneSameLocal(ZoneId.of("Z")).toLocalDate()); - } catch (IllegalArgumentException ignored) { + switch (returnFormat) { + case TIME: return new ExprTimeValue( + zonedDateTime.withZoneSameLocal(UTC_ZONE_ID).toLocalTime()); + case DATE: return new ExprDateValue( + zonedDateTime.withZoneSameLocal(UTC_ZONE_ID).toLocalDate()); + default: return new ExprTimestampValue( + zonedDateTime.withZoneSameLocal(UTC_ZONE_ID).toInstant()); + } + } catch (IllegalArgumentException ignored) { // nothing to do, try another format } } - // if no named formatters are available, use the default - if (dateType.getAllNamedFormatters().size() == 0) { - try { - return new ExprDateValue( + // if no formatters are available, try the default formatter + try { + switch (returnFormat) { + case TIME: return new ExprTimeValue( + DateFormatters.from(STRICT_HOUR_MINUTE_SECOND_FORMATTER.parse(value)).toLocalTime()); + case DATE: return new ExprDateValue( DateFormatters.from(STRICT_YEAR_MONTH_DAY_FORMATTER.parse(value)).toLocalDate()); - } catch (DateTimeParseException e) { - // ignored + default: return new ExprTimestampValue( + DateFormatters.from(DATE_TIME_FORMATTER.parse(value)).toInstant()); } + } catch (DateTimeParseException ignored) { + // ignored } - throw new IllegalArgumentException("Construct ExprDateValue from \"" + value - + "\" failed, unsupported date format."); + + throw new IllegalArgumentException(String.format( + "Construct %s from \"%s\" failed, unsupported format.", returnFormat, value)); } private ExprValue createOpenSearchDateType(Content value, ExprType type) { OpenSearchDateType dt = 
(OpenSearchDateType) type; ExprType returnFormat = dt.getExprType(); - if (value.isNumber()) { - Instant epochMillis = Instant.ofEpochMilli(value.longValue()); - if (returnFormat == TIME) { - return new ExprTimeValue(LocalTime.from(epochMillis.atZone(UTC_ZONE_ID))); - } - if (returnFormat == DATE) { - return new ExprDateValue(LocalDate.ofInstant(epochMillis, UTC_ZONE_ID)); + if (value.isNumber()) { // isNumber + var numFormatters = dt.getNumericNamedFormatters(); + if (numFormatters.size() > 0 || !dt.hasFormats()) { + long epochMillis = 0; + if (numFormatters.contains(DateFormatter.forPattern( + FormatNames.EPOCH_SECOND.getSnakeCaseName()))) { + // no CamelCase for `EPOCH_*` formats + epochMillis = value.longValue() * 1000; + } else /* EPOCH_MILLIS */ { + epochMillis = value.longValue(); + } + Instant instant = Instant.ofEpochMilli(epochMillis); + switch ((ExprCoreType) returnFormat) { + case TIME: return new ExprTimeValue(LocalTime.from(instant.atZone(UTC_ZONE_ID))); + case DATE: return new ExprDateValue(LocalDate.ofInstant(instant, UTC_ZONE_ID)); + default: return new ExprTimestampValue(instant); + } + } else { + // custom format + return parseDateTimeString(value.stringValue(), dt); } - return new ExprTimestampValue(Instant.ofEpochMilli(value.longValue())); } - if (value.isString()) { - if (returnFormat == TIME) { - return parseTimeString(value.stringValue(), dt); - } - if (returnFormat == DATE) { - return parseDateString(value.stringValue(), dt); - } - // else timestamp/datetime - return parseTimestampString(value.stringValue(), dt); + return parseDateTimeString(value.stringValue(), dt); } return new ExprTimestampValue((Instant) value.objectValue()); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java index 973624d19a..0bbab796be 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java @@ -10,6 +10,7 @@ import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATA_FIELD_ID; import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATA_FIELD_INDEX; import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATA_FIELD_MAXSCORE; +import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATA_FIELD_ROUTING; import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATA_FIELD_SCORE; import static org.opensearch.sql.opensearch.storage.OpenSearchIndex.METADATA_FIELD_SORT; @@ -185,8 +186,10 @@ private void addMetaDataFieldsToBuilder( if (maxScore != null) { builder.put(METADATA_FIELD_MAXSCORE, maxScore); } - } else { // if (metaDataField.equals(METADATA_FIELD_SORT)) { + } else if (metaDataField.equals(METADATA_FIELD_SORT)) { builder.put(METADATA_FIELD_SORT, new ExprLongValue(hit.getSeqNo())); + } else { // if (metaDataField.equals(METADATA_FIELD_ROUTING)){ + builder.put(METADATA_FIELD_ROUTING, new ExprStringValue(hit.getShard().toString())); } }); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index 671f4113be..01c3aeb30d 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java 
@@ -113,7 +113,8 @@ public class OpenSearchSettings extends Settings { ENCYRPTION_MASTER_KEY.getKeyValue(), "0000000000000000", Setting.Property.NodeScope, - Setting.Property.Final); + Setting.Property.Final, + Setting.Property.Filtered); public static final Setting DATASOURCE_URI_ALLOW_HOSTS = Setting.simpleString( Key.DATASOURCES_URI_ALLOWHOSTS.getKeyValue(), diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index 6c620e5042..62617f744e 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -45,12 +45,15 @@ public class OpenSearchIndex implements Table { public static final String METADATA_FIELD_MAXSCORE = "_maxscore"; public static final String METADATA_FIELD_SORT = "_sort"; + public static final String METADATA_FIELD_ROUTING = "_routing"; + public static final java.util.Map METADATAFIELD_TYPE_MAP = Map.of( METADATA_FIELD_ID, ExprCoreType.STRING, METADATA_FIELD_INDEX, ExprCoreType.STRING, METADATA_FIELD_SCORE, ExprCoreType.FLOAT, METADATA_FIELD_MAXSCORE, ExprCoreType.FLOAT, - METADATA_FIELD_SORT, ExprCoreType.LONG + METADATA_FIELD_SORT, ExprCoreType.LONG, + METADATA_FIELD_ROUTING, ExprCoreType.STRING ); /** OpenSearch client connection. */ diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/type/OpenSearchDateTypeTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/type/OpenSearchDateTypeTest.java index f0add5bcd9..13393da732 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/type/OpenSearchDateTypeTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/type/OpenSearchDateTypeTest.java @@ -3,34 +3,37 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.opensearch.data.type; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNotSame; -import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assumptions.assumeFalse; +import static org.junit.jupiter.api.Assertions.fail; import static org.opensearch.sql.data.type.ExprCoreType.DATE; import static org.opensearch.sql.data.type.ExprCoreType.DATETIME; import static org.opensearch.sql.data.type.ExprCoreType.TIME; import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; import static org.opensearch.sql.opensearch.data.type.OpenSearchDateType.SUPPORTED_NAMED_DATETIME_FORMATS; import static org.opensearch.sql.opensearch.data.type.OpenSearchDateType.SUPPORTED_NAMED_DATE_FORMATS; +import static org.opensearch.sql.opensearch.data.type.OpenSearchDateType.SUPPORTED_NAMED_INCOMPLETE_DATE_FORMATS; +import static org.opensearch.sql.opensearch.data.type.OpenSearchDateType.SUPPORTED_NAMED_NUMERIC_FORMATS; import static org.opensearch.sql.opensearch.data.type.OpenSearchDateType.SUPPORTED_NAMED_TIME_FORMATS; import static org.opensearch.sql.opensearch.data.type.OpenSearchDateType.isDateTypeCompatible; +import 
com.google.common.collect.Lists; import java.util.EnumSet; +import java.util.List; +import java.util.stream.Stream; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.opensearch.common.time.FormatNames; -import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.data.type.ExprCoreType; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchDateTypeTest { @@ -53,150 +56,207 @@ class OpenSearchDateTypeTest { @Test public void isCompatible() { - // timestamp types is compatible with all date-types - assertTrue(TIMESTAMP.isCompatible(defaultDateType)); - assertTrue(TIMESTAMP.isCompatible(dateDateType)); - assertTrue(TIMESTAMP.isCompatible(timeDateType)); - assertTrue(TIMESTAMP.isCompatible(datetimeDateType)); - - // datetime - assertFalse(DATETIME.isCompatible(defaultDateType)); - assertTrue(DATETIME.isCompatible(dateDateType)); - assertTrue(DATETIME.isCompatible(timeDateType)); - assertFalse(DATETIME.isCompatible(datetimeDateType)); - - // time type - assertFalse(TIME.isCompatible(defaultDateType)); - assertFalse(TIME.isCompatible(dateDateType)); - assertTrue(TIME.isCompatible(timeDateType)); - assertFalse(TIME.isCompatible(datetimeDateType)); - - // date type - assertFalse(DATE.isCompatible(defaultDateType)); - assertTrue(DATE.isCompatible(dateDateType)); - assertFalse(DATE.isCompatible(timeDateType)); - assertFalse(DATE.isCompatible(datetimeDateType)); + assertAll( + // timestamp types is compatible with all date-types + () -> assertTrue(TIMESTAMP.isCompatible(defaultDateType)), + () -> assertTrue(TIMESTAMP.isCompatible(dateDateType)), + () -> assertTrue(TIMESTAMP.isCompatible(timeDateType)), + () -> assertTrue(TIMESTAMP.isCompatible(datetimeDateType)), + + // datetime + () -> assertFalse(DATETIME.isCompatible(defaultDateType)), + () -> assertTrue(DATETIME.isCompatible(dateDateType)), + () -> assertTrue(DATETIME.isCompatible(timeDateType)), + () -> assertFalse(DATETIME.isCompatible(datetimeDateType)), + + // time type + () -> assertFalse(TIME.isCompatible(defaultDateType)), + () -> assertFalse(TIME.isCompatible(dateDateType)), + () -> assertTrue(TIME.isCompatible(timeDateType)), + () -> assertFalse(TIME.isCompatible(datetimeDateType)), + + // date type + () -> assertFalse(DATE.isCompatible(defaultDateType)), + () -> assertTrue(DATE.isCompatible(dateDateType)), + () -> assertFalse(DATE.isCompatible(timeDateType)), + () -> assertFalse(DATE.isCompatible(datetimeDateType)) + ); } // `typeName` and `legacyTypeName` return the same thing for date objects: // https://github.com/opensearch-project/sql/issues/1296 @Test public void check_typeName() { - // always use the MappingType of "DATE" - assertEquals("DATE", defaultDateType.typeName()); - assertEquals("DATE", timeDateType.typeName()); - assertEquals("DATE", dateDateType.typeName()); - assertEquals("DATE", datetimeDateType.typeName()); + assertAll( + // always use the MappingType of "DATE" + () -> assertEquals("DATE", defaultDateType.typeName()), + () -> assertEquals("DATE", timeDateType.typeName()), + () -> assertEquals("DATE", dateDateType.typeName()), + () -> assertEquals("DATE", datetimeDateType.typeName()) + ); } @Test public void check_legacyTypeName() { - // always use the legacy "DATE" type - assertEquals("DATE", 
defaultDateType.legacyTypeName()); - assertEquals("DATE", timeDateType.legacyTypeName()); - assertEquals("DATE", dateDateType.legacyTypeName()); - assertEquals("DATE", datetimeDateType.legacyTypeName()); + assertAll( + // always use the legacy "DATE" type + () -> assertEquals("DATE", defaultDateType.legacyTypeName()), + () -> assertEquals("DATE", timeDateType.legacyTypeName()), + () -> assertEquals("DATE", dateDateType.legacyTypeName()), + () -> assertEquals("DATE", datetimeDateType.legacyTypeName()) + ); } @Test public void check_exprTypeName() { - // exprType changes based on type (no datetime): - assertEquals(TIMESTAMP, defaultDateType.getExprType()); - assertEquals(TIME, timeDateType.getExprType()); - assertEquals(DATE, dateDateType.getExprType()); - assertEquals(TIMESTAMP, datetimeDateType.getExprType()); + assertAll( + // exprType changes based on type (no datetime): + () -> assertEquals(TIMESTAMP, defaultDateType.getExprType()), + () -> assertEquals(TIME, timeDateType.getExprType()), + () -> assertEquals(DATE, dateDateType.getExprType()), + () -> assertEquals(TIMESTAMP, datetimeDateType.getExprType()) + ); } - @Test - public void checkSupportedFormatNamesCoverage() { - EnumSet allFormatNames = EnumSet.allOf(FormatNames.class); - allFormatNames.stream().forEach(formatName -> { - assertTrue( - SUPPORTED_NAMED_DATETIME_FORMATS.contains(formatName) - || SUPPORTED_NAMED_DATE_FORMATS.contains(formatName) - || SUPPORTED_NAMED_TIME_FORMATS.contains(formatName), - formatName + " not supported"); - }); + private static Stream getAllSupportedFormats() { + return EnumSet.allOf(FormatNames.class).stream().map(Arguments::of); } - @Test - public void checkTimestampFormatNames() { - SUPPORTED_NAMED_DATETIME_FORMATS.stream().forEach( - datetimeFormat -> { - String camelCaseName = datetimeFormat.getCamelCaseName(); - if (camelCaseName != null && !camelCaseName.isEmpty()) { - OpenSearchDateType dateType = - OpenSearchDateType.of(camelCaseName); - assertTrue(dateType.getExprType() == TIMESTAMP, camelCaseName - + " does not format to a TIMESTAMP type, instead got " - + dateType.getExprType()); - } - - String snakeCaseName = datetimeFormat.getSnakeCaseName(); - if (snakeCaseName != null && !snakeCaseName.isEmpty()) { - OpenSearchDateType dateType = OpenSearchDateType.of(snakeCaseName); - assertTrue(dateType.getExprType() == TIMESTAMP, snakeCaseName - + " does not format to a TIMESTAMP type, instead got " - + dateType.getExprType()); - } - } - ); + @ParameterizedTest + @MethodSource("getAllSupportedFormats") + public void check_supported_format_names_coverage(FormatNames formatName) { + assertTrue(SUPPORTED_NAMED_NUMERIC_FORMATS.contains(formatName) + || SUPPORTED_NAMED_DATETIME_FORMATS.contains(formatName) + || SUPPORTED_NAMED_DATE_FORMATS.contains(formatName) + || SUPPORTED_NAMED_TIME_FORMATS.contains(formatName) + || SUPPORTED_NAMED_INCOMPLETE_DATE_FORMATS.contains(formatName), + formatName + " not supported"); + } - // check the default format case - OpenSearchDateType dateType = OpenSearchDateType.of(""); - assertTrue(dateType.getExprType() == TIMESTAMP); + private static Stream getSupportedDatetimeFormats() { + return SUPPORTED_NAMED_DATETIME_FORMATS.stream().map(Arguments::of); } - @Test - public void checkDateFormatNames() { - SUPPORTED_NAMED_DATE_FORMATS.stream().forEach( - dateFormat -> { - String camelCaseName = dateFormat.getCamelCaseName(); - if (camelCaseName != null && !camelCaseName.isEmpty()) { - OpenSearchDateType dateType = - OpenSearchDateType.of(camelCaseName); - 
assertTrue(dateType.getExprType() == DATE, camelCaseName - + " does not format to a DATE type, instead got " - + dateType.getExprType()); - } - - String snakeCaseName = dateFormat.getSnakeCaseName(); - if (snakeCaseName != null && !snakeCaseName.isEmpty()) { - OpenSearchDateType dateType = OpenSearchDateType.of(snakeCaseName); - assertTrue(dateType.getExprType() == DATE, snakeCaseName - + " does not format to a DATE type, instead got " - + dateType.getExprType()); - } - } + @ParameterizedTest + @MethodSource("getSupportedDatetimeFormats") + public void check_datetime_format_names(FormatNames datetimeFormat) { + String camelCaseName = datetimeFormat.getCamelCaseName(); + if (camelCaseName != null && !camelCaseName.isEmpty()) { + OpenSearchDateType dateType = + OpenSearchDateType.of(camelCaseName); + assertSame(dateType.getExprType(), TIMESTAMP, camelCaseName + + " does not format to a TIMESTAMP type, instead got " + dateType.getExprType()); + } + + String snakeCaseName = datetimeFormat.getSnakeCaseName(); + if (snakeCaseName != null && !snakeCaseName.isEmpty()) { + OpenSearchDateType dateType = OpenSearchDateType.of(snakeCaseName); + assertSame(dateType.getExprType(), TIMESTAMP, snakeCaseName + + " does not format to a TIMESTAMP type, instead got " + dateType.getExprType()); + } else { + fail(); + } + } + + private static Stream getSupportedDateFormats() { + return SUPPORTED_NAMED_DATE_FORMATS.stream().map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("getSupportedDateFormats") + public void check_date_format_names(FormatNames dateFormat) { + String camelCaseName = dateFormat.getCamelCaseName(); + if (camelCaseName != null && !camelCaseName.isEmpty()) { + OpenSearchDateType dateType = OpenSearchDateType.of(camelCaseName); + assertSame(dateType.getExprType(), DATE, camelCaseName + + " does not format to a DATE type, instead got " + dateType.getExprType()); + } + + String snakeCaseName = dateFormat.getSnakeCaseName(); + if (snakeCaseName != null && !snakeCaseName.isEmpty()) { + OpenSearchDateType dateType = OpenSearchDateType.of(snakeCaseName); + assertSame(dateType.getExprType(), DATE, snakeCaseName + + " does not format to a DATE type, instead got " + dateType.getExprType()); + } else { + fail(); + } + } + + private static Stream getSupportedTimeFormats() { + return SUPPORTED_NAMED_TIME_FORMATS.stream().map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("getSupportedTimeFormats") + public void check_time_format_names(FormatNames timeFormat) { + String camelCaseName = timeFormat.getCamelCaseName(); + if (camelCaseName != null && !camelCaseName.isEmpty()) { + OpenSearchDateType dateType = OpenSearchDateType.of(camelCaseName); + assertSame(dateType.getExprType(), TIME, camelCaseName + + " does not format to a TIME type, instead got " + dateType.getExprType()); + } + + String snakeCaseName = timeFormat.getSnakeCaseName(); + if (snakeCaseName != null && !snakeCaseName.isEmpty()) { + OpenSearchDateType dateType = OpenSearchDateType.of(snakeCaseName); + assertSame(dateType.getExprType(), TIME, snakeCaseName + + " does not format to a TIME type, instead got " + dateType.getExprType()); + } else { + fail(); + } + } + + private static Stream get_format_combinations_for_test() { + return Stream.of( + Arguments.of(DATE, List.of("dd.MM.yyyy", "date"), "d && custom date"), + Arguments.of(TIME, List.of("time", "HH:mm"), "t && custom time"), + Arguments.of(TIMESTAMP, List.of("dd.MM.yyyy", "time"), "t && custom date"), + Arguments.of(TIMESTAMP, List.of("date", "HH:mm"), "d && custom 
time"), + Arguments.of(TIMESTAMP, List.of("dd.MM.yyyy HH:mm", "date_time"), "dt && custom datetime"), + Arguments.of(TIMESTAMP, List.of("dd.MM.yyyy", "date_time"), "dt && custom date"), + Arguments.of(TIMESTAMP, List.of("HH:mm", "date_time"), "dt && custom time"), + Arguments.of(TIMESTAMP, List.of("dd.MM.yyyy", "epoch_second"), "custom date && num"), + Arguments.of(TIMESTAMP, List.of("HH:mm", "epoch_second"), "custom time && num"), + Arguments.of(TIMESTAMP, List.of("date_time", "epoch_second"), "dt && num"), + Arguments.of(TIMESTAMP, List.of("date", "epoch_second"), "d && num"), + Arguments.of(TIMESTAMP, List.of("time", "epoch_second"), "t && num"), + Arguments.of(TIMESTAMP, List.of(""), "no formats given"), + Arguments.of(TIMESTAMP, List.of("time", "date"), "t && d"), + Arguments.of(TIMESTAMP, List.of("epoch_second"), "numeric"), + Arguments.of(TIME, List.of("time"), "t"), + Arguments.of(DATE, List.of("date"), "d"), + Arguments.of(TIMESTAMP, List.of("date_time"), "dt"), + Arguments.of(TIMESTAMP, List.of("unknown"), "unknown/incorrect"), + Arguments.of(DATE, List.of("uuuu"), "incomplete date"), + Arguments.of(TIME, List.of("HH"), "incomplete time"), + Arguments.of(DATE, List.of("E-w"), "incomplete"), + // E - day of week, w - week of year + Arguments.of(DATE, List.of("uuuu", "E-w"), "incomplete with year"), + Arguments.of(TIMESTAMP, List.of("---"), "incorrect"), + Arguments.of(TIMESTAMP, List.of("dd.MM.yyyy", "HH:mm"), "custom date and time"), + // D - day of year, N - nano of day + Arguments.of(TIMESTAMP, List.of("dd.MM.yyyy N", "uuuu:D:HH:mm"), "custom datetime"), + Arguments.of(DATE, List.of("dd.MM.yyyy", "uuuu:D"), "custom date"), + Arguments.of(TIME, List.of("HH:mm", "N"), "custom time") ); } + @ParameterizedTest(name = "[{index}] {2}") + @MethodSource("get_format_combinations_for_test") + public void check_ExprCoreType_of_combinations_of_custom_and_predefined_formats( + ExprCoreType expected, List formats, String testName) { + assertEquals(expected, OpenSearchDateType.of(String.join(" || ", formats)).getExprType()); + formats = Lists.reverse(formats); + assertEquals(expected, OpenSearchDateType.of(String.join(" || ", formats)).getExprType()); + } + @Test - public void checkTimeFormatNames() { - SUPPORTED_NAMED_TIME_FORMATS.stream().forEach( - timeFormat -> { - String camelCaseName = timeFormat.getCamelCaseName(); - if (camelCaseName != null && !camelCaseName.isEmpty()) { - OpenSearchDateType dateType = - OpenSearchDateType.of(camelCaseName); - assertTrue(dateType.getExprType() == TIME, camelCaseName - + " does not format to a TIME type, instead got " - + dateType.getExprType()); - } - - String snakeCaseName = timeFormat.getSnakeCaseName(); - if (snakeCaseName != null && !snakeCaseName.isEmpty()) { - OpenSearchDateType dateType = OpenSearchDateType.of(snakeCaseName); - assertTrue(dateType.getExprType() == TIME, snakeCaseName - + " does not format to a TIME type, instead got " - + dateType.getExprType()); - } - } - ); + public void dont_use_incorrect_format_as_custom() { + assertEquals(0, OpenSearchDateType.of(" ").getAllCustomFormatters().size()); } @Test - public void checkIfDateTypeCompatible() { + public void check_if_date_type_compatible() { assertTrue(isDateTypeCompatible(DATE)); assertFalse(isDateTypeCompatible(OpenSearchDataType.of( OpenSearchDataType.MappingType.Text))); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java 
b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java index a7e3531e8b..827606a961 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java @@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.data.model.ExprValueUtils.booleanValue; @@ -79,12 +80,17 @@ class OpenSearchExprValueFactoryTest { .put("dateStringV", OpenSearchDateType.of("date")) .put("timeStringV", OpenSearchDateType.of("time")) .put("epochMillisV", OpenSearchDateType.of("epoch_millis")) - .put("dateOrEpochMillisV", OpenSearchDateType.of("date_time_no_millis||epoch_millis")) - .put("timeNoMillisOrTimeV", OpenSearchDateType.of("time_no_millis||time")) - .put("dateOrOrdinalDateV", OpenSearchDateType.of("date||ordinal_date")) + .put("epochSecondV", OpenSearchDateType.of("epoch_second")) + .put("timeCustomV", OpenSearchDateType.of("HHmmss")) + .put("dateCustomV", OpenSearchDateType.of("uuuuMMdd")) + .put("dateTimeCustomV", OpenSearchDateType.of("uuuuMMddHHmmss")) + .put("dateOrEpochMillisV", OpenSearchDateType.of("date_time_no_millis || epoch_millis")) + .put("timeNoMillisOrTimeV", OpenSearchDateType.of("time_no_millis || time")) + .put("dateOrOrdinalDateV", OpenSearchDateType.of("date || ordinal_date")) .put("customFormatV", OpenSearchDateType.of("yyyy-MM-dd-HH-mm-ss")) .put("customAndEpochMillisV", - OpenSearchDateType.of("yyyy-MM-dd-HH-mm-ss||epoch_millis")) + OpenSearchDateType.of("yyyy-MM-dd-HH-mm-ss || epoch_millis")) + .put("incompleteFormatV", OpenSearchDateType.of("year")) .put("boolV", OpenSearchDataType.of(BOOLEAN)) .put("structV", OpenSearchDataType.of(STRUCT)) .put("structV.id", OpenSearchDataType.of(INTEGER)) @@ -116,26 +122,32 @@ class OpenSearchExprValueFactoryTest { @Test public void constructNullValue() { - assertEquals(nullValue(), tupleValue("{\"intV\":null}").get("intV")); - assertEquals(nullValue(), constructFromObject("intV", null)); - assertTrue(new OpenSearchJsonContent(null).isNull()); + assertAll( + () -> assertEquals(nullValue(), tupleValue("{\"intV\":null}").get("intV")), + () -> assertEquals(nullValue(), constructFromObject("intV", null)), + () -> assertTrue(new OpenSearchJsonContent(null).isNull()) + ); } @Test public void iterateArrayValue() throws JsonProcessingException { ObjectMapper mapper = new ObjectMapper(); var arrayIt = new OpenSearchJsonContent(mapper.readTree("[\"zz\",\"bb\"]")).array(); - assertTrue(arrayIt.next().stringValue().equals("zz")); - assertTrue(arrayIt.next().stringValue().equals("bb")); - assertTrue(!arrayIt.hasNext()); + assertAll( + () -> assertEquals("zz", arrayIt.next().stringValue()), + () -> assertEquals("bb", arrayIt.next().stringValue()), + () -> assertFalse(arrayIt.hasNext()) + ); } @Test public void iterateArrayValueWithOneElement() throws JsonProcessingException { ObjectMapper mapper = new ObjectMapper(); var arrayIt = new OpenSearchJsonContent(mapper.readTree("[\"zz\"]")).array(); - assertTrue(arrayIt.next().stringValue().equals("zz")); - assertTrue(!arrayIt.hasNext()); + assertAll( + () -> assertEquals("zz", arrayIt.next().stringValue()), + () -> 
assertFalse(arrayIt.hasNext()) + ); } @Test @@ -145,23 +157,29 @@ public void constructNullArrayValue() { @Test public void constructByte() { - assertEquals(byteValue((byte) 1), tupleValue("{\"byteV\":1}").get("byteV")); - assertEquals(byteValue((byte) 1), constructFromObject("byteV", 1)); - assertEquals(byteValue((byte) 1), constructFromObject("byteV", "1.0")); + assertAll( + () -> assertEquals(byteValue((byte) 1), tupleValue("{\"byteV\":1}").get("byteV")), + () -> assertEquals(byteValue((byte) 1), constructFromObject("byteV", 1)), + () -> assertEquals(byteValue((byte) 1), constructFromObject("byteV", "1.0")) + ); } @Test public void constructShort() { - assertEquals(shortValue((short) 1), tupleValue("{\"shortV\":1}").get("shortV")); - assertEquals(shortValue((short) 1), constructFromObject("shortV", 1)); - assertEquals(shortValue((short) 1), constructFromObject("shortV", "1.0")); + assertAll( + () -> assertEquals(shortValue((short) 1), tupleValue("{\"shortV\":1}").get("shortV")), + () -> assertEquals(shortValue((short) 1), constructFromObject("shortV", 1)), + () -> assertEquals(shortValue((short) 1), constructFromObject("shortV", "1.0")) + ); } @Test public void constructInteger() { - assertEquals(integerValue(1), tupleValue("{\"intV\":1}").get("intV")); - assertEquals(integerValue(1), constructFromObject("intV", 1)); - assertEquals(integerValue(1), constructFromObject("intV", "1.0")); + assertAll( + () -> assertEquals(integerValue(1), tupleValue("{\"intV\":1}").get("intV")), + () -> assertEquals(integerValue(1), constructFromObject("intV", 1)), + () -> assertEquals(integerValue(1), constructFromObject("intV", "1.0")) + ); } @Test @@ -171,168 +189,181 @@ public void constructIntegerValueInStringValue() { @Test public void constructLong() { - assertEquals(longValue(1L), tupleValue("{\"longV\":1}").get("longV")); - assertEquals(longValue(1L), constructFromObject("longV", 1L)); - assertEquals(longValue(1L), constructFromObject("longV", "1.0")); + assertAll( + () -> assertEquals(longValue(1L), tupleValue("{\"longV\":1}").get("longV")), + () -> assertEquals(longValue(1L), constructFromObject("longV", 1L)), + () -> assertEquals(longValue(1L), constructFromObject("longV", "1.0")) + ); } @Test public void constructFloat() { - assertEquals(floatValue(1f), tupleValue("{\"floatV\":1.0}").get("floatV")); - assertEquals(floatValue(1f), constructFromObject("floatV", 1f)); + assertAll( + () -> assertEquals(floatValue(1f), tupleValue("{\"floatV\":1.0}").get("floatV")), + () -> assertEquals(floatValue(1f), constructFromObject("floatV", 1f)) + ); } @Test public void constructDouble() { - assertEquals(doubleValue(1d), tupleValue("{\"doubleV\":1.0}").get("doubleV")); - assertEquals(doubleValue(1d), constructFromObject("doubleV", 1d)); + assertAll( + () -> assertEquals(doubleValue(1d), tupleValue("{\"doubleV\":1.0}").get("doubleV")), + () -> assertEquals(doubleValue(1d), constructFromObject("doubleV", 1d)) + ); } @Test public void constructString() { - assertEquals(stringValue("text"), tupleValue("{\"stringV\":\"text\"}").get("stringV")); - assertEquals(stringValue("text"), constructFromObject("stringV", "text")); + assertAll( + () -> assertEquals(stringValue("text"), + tupleValue("{\"stringV\":\"text\"}").get("stringV")), + () -> assertEquals(stringValue("text"), constructFromObject("stringV", "text")) + ); } @Test public void constructBoolean() { - assertEquals(booleanValue(true), tupleValue("{\"boolV\":true}").get("boolV")); - assertEquals(booleanValue(true), constructFromObject("boolV", true)); - 
assertEquals(booleanValue(true), constructFromObject("boolV", "true")); - assertEquals(booleanValue(true), constructFromObject("boolV", 1)); - assertEquals(booleanValue(false), constructFromObject("boolV", 0)); + assertAll( + () -> assertEquals(booleanValue(true), tupleValue("{\"boolV\":true}").get("boolV")), + () -> assertEquals(booleanValue(true), constructFromObject("boolV", true)), + () -> assertEquals(booleanValue(true), constructFromObject("boolV", "true")), + () -> assertEquals(booleanValue(true), constructFromObject("boolV", 1)), + () -> assertEquals(booleanValue(false), constructFromObject("boolV", 0)) + ); } @Test public void constructText() { - assertEquals(new OpenSearchExprTextValue("text"), - tupleValue("{\"textV\":\"text\"}").get("textV")); - assertEquals(new OpenSearchExprTextValue("text"), - constructFromObject("textV", "text")); - - assertEquals(new OpenSearchExprTextValue("text"), - tupleValue("{\"textKeywordV\":\"text\"}").get("textKeywordV")); - assertEquals(new OpenSearchExprTextValue("text"), - constructFromObject("textKeywordV", "text")); + assertAll( + () -> assertEquals(new OpenSearchExprTextValue("text"), + tupleValue("{\"textV\":\"text\"}").get("textV")), + () -> assertEquals(new OpenSearchExprTextValue("text"), + constructFromObject("textV", "text")), + + () -> assertEquals(new OpenSearchExprTextValue("text"), + tupleValue("{\"textKeywordV\":\"text\"}").get("textKeywordV")), + () -> assertEquals(new OpenSearchExprTextValue("text"), + constructFromObject("textKeywordV", "text")) + ); } @Test public void constructDates() { ExprValue dateStringV = constructFromObject("dateStringV", "1984-04-12"); - assertEquals(new ExprDateValue("1984-04-12"), dateStringV); - - assertEquals( - new ExprDateValue(LocalDate.ofInstant(Instant.ofEpochMilli(450576000000L), - UTC_ZONE_ID)), - constructFromObject("dateV", 450576000000L)); - - assertEquals( - new ExprDateValue("1984-04-12"), - constructFromObject("dateOrOrdinalDateV", "1984-103")); - assertEquals( - new ExprDateValue("2015-01-01"), - tupleValue("{\"dateV\":\"2015-01-01\"}").get("dateV")); + assertAll( + () -> assertEquals(new ExprDateValue("1984-04-12"), dateStringV), + () -> assertEquals(new ExprDateValue( + LocalDate.ofInstant(Instant.ofEpochMilli(450576000000L), UTC_ZONE_ID)), + constructFromObject("dateV", 450576000000L)), + () -> assertEquals(new ExprDateValue("1984-04-12"), + constructFromObject("dateOrOrdinalDateV", "1984-103")), + () -> assertEquals(new ExprDateValue("2015-01-01"), + tupleValue("{\"dateV\":\"2015-01-01\"}").get("dateV")) + ); } @Test public void constructTimes() { ExprValue timeStringV = constructFromObject("timeStringV","12:10:30.000Z"); - assertTrue(timeStringV.isDateTime()); - assertTrue(timeStringV instanceof ExprTimeValue); - assertEquals(new ExprTimeValue("12:10:30"), timeStringV); - - assertEquals( - new ExprTimeValue(LocalTime.from(Instant.ofEpochMilli(1420070400001L).atZone(UTC_ZONE_ID))), - constructFromObject("timeV", 1420070400001L)); - assertEquals( - new ExprTimeValue("09:07:42.000"), - constructFromObject("timeNoMillisOrTimeV", "09:07:42.000Z")); - assertEquals( - new ExprTimeValue("09:07:42"), - tupleValue("{\"timeV\":\"09:07:42\"}").get("timeV")); + assertAll( + () -> assertTrue(timeStringV.isDateTime()), + () -> assertTrue(timeStringV instanceof ExprTimeValue), + () -> assertEquals(new ExprTimeValue("12:10:30"), timeStringV), + () -> assertEquals(new ExprTimeValue(LocalTime.from( + Instant.ofEpochMilli(1420070400001L).atZone(UTC_ZONE_ID))), + constructFromObject("timeV", 
1420070400001L)), + () -> assertEquals(new ExprTimeValue("09:07:42.000"), + constructFromObject("timeNoMillisOrTimeV", "09:07:42.000Z")), + () -> assertEquals(new ExprTimeValue("09:07:42"), + tupleValue("{\"timeV\":\"09:07:42\"}").get("timeV")) + ); } @Test public void constructDatetime() { - assertEquals( - new ExprTimestampValue("2015-01-01 00:00:00"), - tupleValue("{\"timestampV\":\"2015-01-01\"}").get("timestampV")); - assertEquals( - new ExprTimestampValue("2015-01-01 12:10:30"), - tupleValue("{\"timestampV\":\"2015-01-01T12:10:30Z\"}").get("timestampV")); - assertEquals( - new ExprTimestampValue("2015-01-01 12:10:30"), - tupleValue("{\"timestampV\":\"2015-01-01T12:10:30\"}").get("timestampV")); - assertEquals( - new ExprTimestampValue("2015-01-01 12:10:30"), - tupleValue("{\"timestampV\":\"2015-01-01 12:10:30\"}").get("timestampV")); - assertEquals( - new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), - tupleValue("{\"timestampV\":1420070400001}").get("timestampV")); - assertEquals( - new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), - constructFromObject("timestampV", 1420070400001L)); - assertEquals( - new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), - constructFromObject("timestampV", Instant.ofEpochMilli(1420070400001L))); - assertEquals( - new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), - constructFromObject("epochMillisV", "1420070400001")); - assertEquals( - new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), - constructFromObject("epochMillisV", 1420070400001L)); - assertEquals( - new ExprTimestampValue("2015-01-01 12:10:30"), - constructFromObject("timestampV", "2015-01-01 12:10:30")); - assertEquals( - new ExprDatetimeValue("2015-01-01 12:10:30"), - constructFromObject("datetimeV", "2015-01-01 12:10:30")); - assertEquals( - new ExprDatetimeValue("2015-01-01 12:10:30"), - constructFromObject("datetimeDefaultV", "2015-01-01 12:10:30")); - assertEquals( - new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), - constructFromObject("dateOrEpochMillisV", "1420070400001")); - - // case: timestamp-formatted field, but it only gets a time: should match a time - assertEquals( - new ExprTimeValue("19:36:22"), - tupleValue("{\"timestampV\":\"19:36:22\"}").get("timestampV")); - - // case: timestamp-formatted field, but it only gets a date: should match a date - assertEquals( - new ExprDateValue("2011-03-03"), - tupleValue("{\"timestampV\":\"2011-03-03\"}").get("timestampV")); + assertAll( + () -> assertEquals( + new ExprTimestampValue("2015-01-01 00:00:00"), + tupleValue("{\"timestampV\":\"2015-01-01\"}").get("timestampV")), + () -> assertEquals( + new ExprTimestampValue("2015-01-01 12:10:30"), + tupleValue("{\"timestampV\":\"2015-01-01T12:10:30Z\"}").get("timestampV")), + () -> assertEquals( + new ExprTimestampValue("2015-01-01 12:10:30"), + tupleValue("{\"timestampV\":\"2015-01-01T12:10:30\"}").get("timestampV")), + () -> assertEquals( + new ExprTimestampValue("2015-01-01 12:10:30"), + tupleValue("{\"timestampV\":\"2015-01-01 12:10:30\"}").get("timestampV")), + () -> assertEquals( + new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), + constructFromObject("timestampV", 1420070400001L)), + () -> assertEquals( + new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), + constructFromObject("timestampV", Instant.ofEpochMilli(1420070400001L))), + () -> assertEquals( + new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), + constructFromObject("epochMillisV", "1420070400001")), + () -> assertEquals( 
+ new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), + constructFromObject("epochMillisV", 1420070400001L)), + () -> assertEquals( + new ExprTimestampValue(Instant.ofEpochSecond(142704001L)), + constructFromObject("epochSecondV", 142704001L)), + () -> assertEquals( + new ExprTimeValue("10:20:30"), + tupleValue("{ \"timeCustomV\" : 102030 }").get("timeCustomV")), + () -> assertEquals( + new ExprDateValue("1961-04-12"), + tupleValue("{ \"dateCustomV\" : 19610412 }").get("dateCustomV")), + () -> assertEquals( + new ExprTimestampValue("1984-05-10 20:30:40"), + tupleValue("{ \"dateTimeCustomV\" : 19840510203040 }").get("dateTimeCustomV")), + () -> assertEquals( + new ExprTimestampValue("2015-01-01 12:10:30"), + constructFromObject("timestampV", "2015-01-01 12:10:30")), + () -> assertEquals( + new ExprDatetimeValue("2015-01-01 12:10:30"), + constructFromObject("datetimeV", "2015-01-01 12:10:30")), + () -> assertEquals( + new ExprDatetimeValue("2015-01-01 12:10:30"), + constructFromObject("datetimeDefaultV", "2015-01-01 12:10:30")), + () -> assertEquals( + new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), + constructFromObject("dateOrEpochMillisV", "1420070400001")), + + // case: timestamp-formatted field, but it only gets a time: should match a time + () -> assertEquals( + new ExprTimeValue("19:36:22"), + tupleValue("{\"timestampV\":\"19:36:22\"}").get("timestampV")), + + // case: timestamp-formatted field, but it only gets a date: should match a date + () -> assertEquals( + new ExprDateValue("2011-03-03"), + tupleValue("{\"timestampV\":\"2011-03-03\"}").get("timestampV")) + ); } @Test public void constructDatetime_fromCustomFormat() { - // this is not the desirable behaviour - instead if accepts the default formatter assertEquals( new ExprDatetimeValue("2015-01-01 12:10:30"), - constructFromObject("customFormatV", "2015-01-01 12:10:30")); + constructFromObject("customFormatV", "2015-01-01-12-10-30")); - // this should pass when custom formats are supported IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, - () -> constructFromObject("customFormatV", "2015-01-01-12-10-30")); + () -> constructFromObject("customFormatV", "2015-01-01 12-10-30")); assertEquals( - "Construct ExprTimestampValue from \"2015-01-01-12-10-30\" failed, " - + "unsupported date format.", + "Construct TIMESTAMP from \"2015-01-01 12-10-30\" failed, " + + "unsupported format.", exception.getMessage()); assertEquals( new ExprDatetimeValue("2015-01-01 12:10:30"), constructFromObject("customAndEpochMillisV", "2015-01-01 12:10:30")); - // this should pass when custom formats are supported - exception = - assertThrows(IllegalArgumentException.class, - () -> constructFromObject("customAndEpochMillisV", "2015-01-01-12-10-30")); assertEquals( - "Construct ExprTimestampValue from \"2015-01-01-12-10-30\" failed, " - + "unsupported date format.", - exception.getMessage()); + new ExprDatetimeValue("2015-01-01 12:10:30"), + constructFromObject("customAndEpochMillisV", "2015-01-01-12-10-30")); } @Test @@ -341,8 +372,8 @@ public void constructDatetimeFromUnsupportedFormat_ThrowIllegalArgumentException assertThrows(IllegalArgumentException.class, () -> constructFromObject("timestampV", "2015-01-01 12:10")); assertEquals( - "Construct ExprTimestampValue from \"2015-01-01 12:10\" failed, " - + "unsupported date format.", + "Construct TIMESTAMP from \"2015-01-01 12:10\" failed, " + + "unsupported format.", exception.getMessage()); // fail with missing seconds @@ -350,8 +381,8 @@ public void 
constructDatetimeFromUnsupportedFormat_ThrowIllegalArgumentException assertThrows(IllegalArgumentException.class, () -> constructFromObject("dateOrEpochMillisV", "2015-01-01 12:10")); assertEquals( - "Construct ExprTimestampValue from \"2015-01-01 12:10\" failed, " - + "unsupported date format.", + "Construct TIMESTAMP from \"2015-01-01 12:10\" failed, " + + "unsupported format.", exception.getMessage()); } @@ -360,15 +391,15 @@ public void constructTimeFromUnsupportedFormat_ThrowIllegalArgumentException() { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, () -> constructFromObject("timeV", "2015-01-01")); assertEquals( - "Construct ExprTimeValue from \"2015-01-01\" failed, " - + "unsupported time format.", + "Construct TIME from \"2015-01-01\" failed, " + + "unsupported format.", exception.getMessage()); exception = assertThrows( IllegalArgumentException.class, () -> constructFromObject("timeStringV", "10:10")); assertEquals( - "Construct ExprTimeValue from \"10:10\" failed, " - + "unsupported time format.", + "Construct TIME from \"10:10\" failed, " + + "unsupported format.", exception.getMessage()); } @@ -377,18 +408,25 @@ public void constructDateFromUnsupportedFormat_ThrowIllegalArgumentException() { IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, () -> constructFromObject("dateV", "12:10:10")); assertEquals( - "Construct ExprDateValue from \"12:10:10\" failed, " - + "unsupported date format.", + "Construct DATE from \"12:10:10\" failed, " + + "unsupported format.", exception.getMessage()); exception = assertThrows( IllegalArgumentException.class, () -> constructFromObject("dateStringV", "abc")); assertEquals( - "Construct ExprDateValue from \"abc\" failed, " - + "unsupported date format.", + "Construct DATE from \"abc\" failed, " + + "unsupported format.", exception.getMessage()); } + @Test + public void constructDateFromIncompleteFormat() { + assertEquals( + new ExprDateValue("1984-01-01"), + constructFromObject("incompleteFormatV", "1984")); + } + @Test public void constructArray() { assertEquals( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java index 05e5d80c39..672fca12d7 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java @@ -32,8 +32,10 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.text.Text; +import org.opensearch.index.shard.ShardId; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchShardTarget; import org.opensearch.search.aggregations.Aggregations; import org.opensearch.search.fetch.subphase.highlight.HighlightField; import org.opensearch.sql.data.model.ExprFloatValue; @@ -148,9 +150,13 @@ void iterator_metafields() { new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 3.75F)); + ShardId shardId = new ShardId("index", "indexUUID", 42); + SearchShardTarget shardTarget = new SearchShardTarget("node", shardId, null, null); + when(searchHit1.getSourceAsString()).thenReturn("{\"id1\", 1}"); when(searchHit1.getId()).thenReturn("testId"); when(searchHit1.getIndex()).thenReturn("testIndex"); + when(searchHit1.getShard()).thenReturn(shardTarget); 
when(searchHit1.getScore()).thenReturn(3.75F); when(searchHit1.getSeqNo()).thenReturn(123456L); @@ -160,11 +166,12 @@ void iterator_metafields() { "id1", new ExprIntegerValue(1), "_index", new ExprStringValue("testIndex"), "_id", new ExprStringValue("testId"), + "_routing", new ExprStringValue(shardTarget.toString()), "_sort", new ExprLongValue(123456L), "_score", new ExprFloatValue(3.75F), "_maxscore", new ExprFloatValue(3.75F) )); - List includes = List.of("id1", "_index", "_id", "_sort", "_score", "_maxscore"); + List includes = List.of("id1", "_index", "_id", "_routing", "_sort", "_score", "_maxscore"); int i = 0; for (ExprValue hit : new OpenSearchResponse(searchResponse, factory, includes)) { if (i == 0) { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java index 11694813cc..39af59b6cd 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java @@ -187,9 +187,10 @@ void getReservedFieldTypes() { assertThat( fieldTypes, allOf( - aMapWithSize(5), + aMapWithSize(6), hasEntry("_id", ExprCoreType.STRING), hasEntry("_index", ExprCoreType.STRING), + hasEntry("_routing", ExprCoreType.STRING), hasEntry("_sort", ExprCoreType.LONG), hasEntry("_score", ExprCoreType.FLOAT), hasEntry("_maxscore", ExprCoreType.FLOAT) diff --git a/plugin/build.gradle b/plugin/build.gradle index 42d5723194..11f97ea857 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -126,6 +126,7 @@ dependencies { api project(':opensearch') api project(':prometheus') api project(':datasources') + api project(':spark') testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.12.13' testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index 36986c9afc..7e867be967 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -81,6 +81,7 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryAction; import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory; +import org.opensearch.sql.spark.storage.SparkStorageFactory; import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.FixedExecutorBuilder; @@ -221,6 +222,7 @@ private DataSourceServiceImpl createDataSourceService() { .add(new OpenSearchDataSourceFactory( new OpenSearchNodeClient(this.client), pluginSettings)) .add(new PrometheusStorageFactory(pluginSettings)) + .add(new SparkStorageFactory(this.client, pluginSettings)) .build(), dataSourceMetadataStorage, dataSourceUserAuthorizationHelper); diff --git a/ppl/build.gradle b/ppl/build.gradle index 365b8ff0a8..36cd935cf1 100644 --- a/ppl/build.gradle +++ b/ppl/build.gradle @@ -47,7 +47,7 @@ dependencies { runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12' implementation "org.antlr:antlr4-runtime:4.7.1" - implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre' + implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' api group: 'org.json', name: 
'json', version: '20230227' implementation group: 'org.apache.logging.log4j', name: 'log4j-core', version:'2.17.1' api project(':common') diff --git a/protocol/build.gradle b/protocol/build.gradle index 5d32a235ea..92a1aa0917 100644 --- a/protocol/build.gradle +++ b/protocol/build.gradle @@ -29,7 +29,7 @@ plugins { } dependencies { - implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre' + implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' implementation group: 'com.fasterxml.jackson.core', name: 'jackson-core', version: "${versions.jackson}" implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: "${versions.jackson_databind}" implementation group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: "${versions.jackson}" diff --git a/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java b/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java index ae66364419..3ce1dd8875 100644 --- a/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java +++ b/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java @@ -9,6 +9,7 @@ import java.util.Collection; import java.util.Iterator; import java.util.LinkedHashMap; +import java.util.Locale; import java.util.Map; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -58,7 +59,7 @@ public Map columnNameTypes() { Map colNameTypes = new LinkedHashMap<>(); schema.getColumns().forEach(column -> colNameTypes.put( getColumnName(column), - column.getExprType().typeName().toLowerCase())); + column.getExprType().typeName().toLowerCase(Locale.ROOT))); return colNameTypes; } diff --git a/settings.gradle b/settings.gradle index 6f7214cb3a..2140ad6c9e 100644 --- a/settings.gradle +++ b/settings.gradle @@ -19,4 +19,6 @@ include 'legacy' include 'sql' include 'prometheus' include 'benchmarks' -include 'datasources' \ No newline at end of file +include 'datasources' +include 'spark' + diff --git a/spark-sql-application/README.md b/spark-sql-application/README.md index b0505282ab..6422f294cd 100644 --- a/spark-sql-application/README.md +++ b/spark-sql-application/README.md @@ -3,6 +3,7 @@ This application execute sql query and store the result in OpenSearch index in following format ``` "stepId":"", +"applicationId":"" "schema": "json blob", "result": "json blob" ``` @@ -61,7 +62,8 @@ OpenSearch index document will look like "{'column_name':'Letter','data_type':'string'}", "{'column_name':'Number','data_type':'integer'}" ], - "stepId" : "s-JZSB1139WIVU" + "stepId" : "s-JZSB1139WIVU", + "applicationId" : "application_1687726870985_0003" } } ``` diff --git a/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala index f2dd0c869c..04fa92b25b 100644 --- a/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala +++ b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala @@ -84,13 +84,15 @@ object SQLJob { val schema = StructType(Seq( StructField("result", ArrayType(StringType, containsNull = true), nullable = true), StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), - StructField("stepId", StringType, nullable = true))) + StructField("stepId", StringType, nullable = true), + StructField("applicationId", StringType, nullable = true))) // Create the data rows val rows = Seq(( result.toJSON.collect.toList.map(_.replaceAll("\"", 
"'")), resultSchema.toJSON.collect.toList.map(_.replaceAll("\"", "'")), - sys.env.getOrElse("EMR_STEP_ID", ""))) + sys.env.getOrElse("EMR_STEP_ID", "unknown"), + spark.sparkContext.applicationId)) // Create the DataFrame for data spark.createDataFrame(rows).toDF(schema.fields.map(_.name): _*) diff --git a/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala b/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala index 2cdb06d6ca..7ec4e45450 100644 --- a/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala +++ b/spark-sql-application/src/test/scala/org/opensearch/sql/SQLJobTest.scala @@ -31,13 +31,15 @@ class SQLJobTest extends AnyFunSuite{ val expectedSchema = StructType(Seq( StructField("result", ArrayType(StringType, containsNull = true), nullable = true), StructField("schema", ArrayType(StringType, containsNull = true), nullable = true), - StructField("stepId", StringType, nullable = true) + StructField("stepId", StringType, nullable = true), + StructField("applicationId", StringType, nullable = true) )) val expectedRows = Seq( Row( Array("{'Letter':'A','Number':1}","{'Letter':'B','Number':2}", "{'Letter':'C','Number':3}"), Array("{'column_name':'Letter','data_type':'string'}", "{'column_name':'Number','data_type':'integer'}"), - "" + "unknown", + spark.sparkContext.applicationId ) ) val expected: DataFrame = spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema) diff --git a/spark/build.gradle b/spark/build.gradle new file mode 100644 index 0000000000..89842e5ea8 --- /dev/null +++ b/spark/build.gradle @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +plugins { + id 'java-library' + id "io.freefair.lombok" + id 'jacoco' +} + +repositories { + mavenCentral() +} + +dependencies { + api project(':core') + implementation project(':datasources') + + implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" + implementation group: 'org.json', name: 'json', version: '20230227' + implementation group: 'com.amazonaws', name: 'aws-java-sdk-emr', version: '1.12.1' + + testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') + testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0' + testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.2.0' + testImplementation 'junit:junit:4.13.1' +} + +test { + useJUnitPlatform() + testLogging { + events "passed", "skipped", "failed" + exceptionFormat "full" + } +} + +jacocoTestReport { + reports { + html.enabled true + xml.enabled true + } + afterEvaluate { + classDirectories.setFrom(files(classDirectories.files.collect { + fileTree(dir: it) + })) + } +} +test.finalizedBy(project.tasks.jacocoTestReport) + +jacocoTestCoverageVerification { + violationRules { + rule { + element = 'CLASS' + excludes = [ + 'org.opensearch.sql.spark.data.constants.*' + ] + limit { + counter = 'LINE' + minimum = 1.0 + } + limit { + counter = 'BRANCH' + minimum = 1.0 + } + } + } + afterEvaluate { + classDirectories.setFrom(files(classDirectories.files.collect { + fileTree(dir: it) + })) + } +} +check.dependsOn jacocoTestCoverageVerification +jacocoTestCoverageVerification.dependsOn jacocoTestReport diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java new file mode 100644 index 0000000000..1e2475c196 --- /dev/null +++ 
b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_SQL_APPLICATION_JAR; + +import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; +import com.amazonaws.services.elasticmapreduce.model.ActionOnFailure; +import com.amazonaws.services.elasticmapreduce.model.AddJobFlowStepsRequest; +import com.amazonaws.services.elasticmapreduce.model.AddJobFlowStepsResult; +import com.amazonaws.services.elasticmapreduce.model.DescribeStepRequest; +import com.amazonaws.services.elasticmapreduce.model.HadoopJarStepConfig; +import com.amazonaws.services.elasticmapreduce.model.StepConfig; +import com.amazonaws.services.elasticmapreduce.model.StepStatus; +import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import lombok.SneakyThrows; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.json.JSONObject; +import org.opensearch.sql.spark.helper.FlintHelper; +import org.opensearch.sql.spark.response.SparkResponse; + +public class EmrClientImpl implements SparkClient { + private final AmazonElasticMapReduce emr; + private final String emrCluster; + private final FlintHelper flint; + private final String sparkApplicationJar; + private static final Logger logger = LogManager.getLogger(EmrClientImpl.class); + private SparkResponse sparkResponse; + + /** + * Constructor for EMR Client Implementation. + * + * @param emr EMR helper + * @param flint Opensearch args for flint integration jar + * @param sparkResponse Response object to help with retrieving results from Opensearch index + */ + public EmrClientImpl(AmazonElasticMapReduce emr, String emrCluster, FlintHelper flint, + SparkResponse sparkResponse, String sparkApplicationJar) { + this.emr = emr; + this.emrCluster = emrCluster; + this.flint = flint; + this.sparkResponse = sparkResponse; + this.sparkApplicationJar = + sparkApplicationJar == null ? 
SPARK_SQL_APPLICATION_JAR : sparkApplicationJar; + } + + @Override + public JSONObject sql(String query) throws IOException { + runEmrApplication(query); + return sparkResponse.getResultFromOpensearchIndex(); + } + + @VisibleForTesting + void runEmrApplication(String query) { + + HadoopJarStepConfig stepConfig = new HadoopJarStepConfig() + .withJar("command-runner.jar") + .withArgs("spark-submit", + "--class","org.opensearch.sql.SQLJob", + "--jars", + flint.getFlintIntegrationJar(), + sparkApplicationJar, + query, + SPARK_INDEX_NAME, + flint.getFlintHost(), + flint.getFlintPort(), + flint.getFlintScheme(), + flint.getFlintAuth(), + flint.getFlintRegion() + ); + + StepConfig emrstep = new StepConfig() + .withName("Spark Application") + .withActionOnFailure(ActionOnFailure.CONTINUE) + .withHadoopJarStep(stepConfig); + + AddJobFlowStepsRequest request = new AddJobFlowStepsRequest() + .withJobFlowId(emrCluster) + .withSteps(emrstep); + + AddJobFlowStepsResult result = emr.addJobFlowSteps(request); + logger.info("EMR step ID: " + result.getStepIds()); + + String stepId = result.getStepIds().get(0); + DescribeStepRequest stepRequest = new DescribeStepRequest() + .withClusterId(emrCluster) + .withStepId(stepId); + + waitForStepExecution(stepRequest); + sparkResponse.setValue(stepId); + } + + @SneakyThrows + private void waitForStepExecution(DescribeStepRequest stepRequest) { + // Wait for the step to complete + boolean completed = false; + while (!completed) { + // Get the step status + StepStatus statusDetail = emr.describeStep(stepRequest).getStep().getStatus(); + // Check if the step has completed + if (statusDetail.getState().equals("COMPLETED")) { + completed = true; + logger.info("EMR step completed successfully."); + } else if (statusDetail.getState().equals("FAILED") + || statusDetail.getState().equals("CANCELLED")) { + logger.error("EMR step failed or cancelled."); + throw new RuntimeException("Spark SQL application failed."); + } else { + // Sleep for some time before checking the status again + Thread.sleep(2500); + } + } + } + +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java new file mode 100644 index 0000000000..99d8600dd0 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import java.io.IOException; +import org.json.JSONObject; + +/** + * Interface class for Spark Client. + */ +public interface SparkClient { + /** + * This method executes spark sql query. 
+ * + * @param query spark sql query + * @return spark query response + */ + JSONObject sql(String query) throws IOException; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java new file mode 100644 index 0000000000..65d5a01ba2 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/data/constants/SparkConstants.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.data.constants; + +public class SparkConstants { + public static final String EMR = "emr"; + public static final String STEP_ID_FIELD = "stepId.keyword"; + public static final String SPARK_SQL_APPLICATION_JAR = "s3://spark-datasource/sql-job.jar"; + public static final String SPARK_INDEX_NAME = ".query_execution_result"; + public static final String FLINT_INTEGRATION_JAR = + "s3://spark-datasource/flint-spark-integration-assembly-0.1.0-SNAPSHOT.jar"; + public static final String FLINT_DEFAULT_HOST = "localhost"; + public static final String FLINT_DEFAULT_PORT = "9200"; + public static final String FLINT_DEFAULT_SCHEME = "http"; + public static final String FLINT_DEFAULT_AUTH = "-1"; + public static final String FLINT_DEFAULT_REGION = "us-west-2"; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java b/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java new file mode 100644 index 0000000000..1936c266de --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java @@ -0,0 +1,102 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions.implementation; + +import static org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver.QUERY; + +import java.util.List; +import java.util.stream.Collectors; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.TableFunctionImplementation; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.spark.storage.SparkTable; +import org.opensearch.sql.storage.Table; + +/** + * Spark SQL function implementation. + */ +public class SparkSqlFunctionImplementation extends FunctionExpression + implements TableFunctionImplementation { + + private final FunctionName functionName; + private final List arguments; + private final SparkClient sparkClient; + + /** + * Constructor for spark sql function. 
+ * + * @param functionName name of the function + * @param arguments a list of expressions + * @param sparkClient spark client + */ + public SparkSqlFunctionImplementation( + FunctionName functionName, List arguments, SparkClient sparkClient) { + super(functionName, arguments); + this.functionName = functionName; + this.arguments = arguments; + this.sparkClient = sparkClient; + } + + @Override + public ExprValue valueOf(Environment valueEnv) { + throw new UnsupportedOperationException(String.format( + "Spark defined function [%s] is only " + + "supported in SOURCE clause with spark connector catalog", functionName)); + } + + @Override + public ExprType type() { + return ExprCoreType.STRUCT; + } + + @Override + public String toString() { + List args = arguments.stream() + .map(arg -> String.format("%s=%s", + ((NamedArgumentExpression) arg).getArgName(), + ((NamedArgumentExpression) arg).getValue().toString())) + .collect(Collectors.toList()); + return String.format("%s(%s)", functionName, String.join(", ", args)); + } + + @Override + public Table applyArguments() { + return new SparkTable(sparkClient, buildQueryFromSqlFunction(arguments)); + } + + /** + * This method builds a spark query request. + * + * @param arguments spark sql function arguments + * @return spark query request + */ + private SparkQueryRequest buildQueryFromSqlFunction(List arguments) { + + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + arguments.forEach(arg -> { + String argName = ((NamedArgumentExpression) arg).getArgName(); + Expression argValue = ((NamedArgumentExpression) arg).getValue(); + ExprValue literalValue = argValue.valueOf(); + if (argName.equals(QUERY)) { + sparkQueryRequest.setSql((String) literalValue.value()); + } else { + throw new ExpressionEvaluationException( + String.format("Invalid Function Argument:%s", argName)); + } + }); + return sparkQueryRequest; + } + +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java b/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java new file mode 100644 index 0000000000..624600e1a8 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions.resolver; + +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.ArrayList; +import java.util.List; +import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.implementation.SparkSqlFunctionImplementation; + +/** + * Function resolver for sql function of spark connector. 
+ */ +@RequiredArgsConstructor +public class SparkSqlTableFunctionResolver implements FunctionResolver { + private final SparkClient sparkClient; + + public static final String SQL = "sql"; + public static final String QUERY = "query"; + + @Override + public Pair<FunctionSignature, FunctionBuilder> resolve(FunctionSignature unresolvedSignature) { + FunctionName functionName = FunctionName.of(SQL); + FunctionSignature functionSignature = + new FunctionSignature(functionName, List.of(STRING)); + final List<String> argumentNames = List.of(QUERY); + + FunctionBuilder functionBuilder = (functionProperties, arguments) -> { + Boolean argumentsPassedByName = arguments.stream() + .noneMatch(arg -> StringUtils.isEmpty(((NamedArgumentExpression) arg).getArgName())); + Boolean argumentsPassedByPosition = arguments.stream() + .allMatch(arg -> StringUtils.isEmpty(((NamedArgumentExpression) arg).getArgName())); + if (!(argumentsPassedByName || argumentsPassedByPosition)) { + throw new SemanticCheckException("Arguments should be either passed by name or position"); + } + + if (arguments.size() != argumentNames.size()) { + throw new SemanticCheckException( + String.format("Missing arguments:[%s]", + String.join(",", argumentNames.subList(arguments.size(), argumentNames.size())))); + } + + if (argumentsPassedByPosition) { + List<Expression> namedArguments = new ArrayList<>(); + for (int i = 0; i < arguments.size(); i++) { + namedArguments.add(new NamedArgumentExpression(argumentNames.get(i), + ((NamedArgumentExpression) arguments.get(i)).getValue())); + } + return new SparkSqlFunctionImplementation(functionName, namedArguments, sparkClient); + } + return new SparkSqlFunctionImplementation(functionName, arguments, sparkClient); + }; + return Pair.of(functionSignature, functionBuilder); + } + + @Override + public FunctionName getFunctionName() { + return FunctionName.of(SQL); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java new file mode 100644 index 0000000000..cb2b31ddc1 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions.response; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.json.JSONArray; +import org.json.JSONObject; +import org.opensearch.sql.data.model.ExprBooleanValue; +import org.opensearch.sql.data.model.ExprByteValue; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprFloatValue; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprShortValue; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTimestampValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.executor.ExecutionEngine; + +/** + * Default implementation of SparkSqlFunctionResponseHandle.
+ */ +public class DefaultSparkSqlFunctionResponseHandle implements SparkSqlFunctionResponseHandle { + private Iterator<ExprValue> responseIterator; + private ExecutionEngine.Schema schema; + private static final Logger logger = + LogManager.getLogger(DefaultSparkSqlFunctionResponseHandle.class); + + /** + * Constructor. + * + * @param responseObject Spark responseObject. + */ + public DefaultSparkSqlFunctionResponseHandle(JSONObject responseObject) { + constructIteratorAndSchema(responseObject); + } + + private void constructIteratorAndSchema(JSONObject responseObject) { + List<ExprValue> result = new ArrayList<>(); + List<ExecutionEngine.Schema.Column> columnList; + JSONObject items = responseObject.getJSONObject("data"); + logger.info("Spark Application ID: " + items.getString("applicationId")); + columnList = getColumnList(items.getJSONArray("schema")); + for (int i = 0; i < items.getJSONArray("result").length(); i++) { + JSONObject row = new JSONObject( + items.getJSONArray("result").get(i).toString().replace("'", "\"")); + LinkedHashMap<String, ExprValue> linkedHashMap = extractRow(row, columnList); + result.add(new ExprTupleValue(linkedHashMap)); + } + this.schema = new ExecutionEngine.Schema(columnList); + this.responseIterator = result.iterator(); + } + + private static LinkedHashMap<String, ExprValue> extractRow( + JSONObject row, List<ExecutionEngine.Schema.Column> columnList) { + LinkedHashMap<String, ExprValue> linkedHashMap = new LinkedHashMap<>(); + for (ExecutionEngine.Schema.Column column : columnList) { + ExprType type = column.getExprType(); + if (type == ExprCoreType.BOOLEAN) { + linkedHashMap.put(column.getName(), ExprBooleanValue.of(row.getBoolean(column.getName()))); + } else if (type == ExprCoreType.LONG) { + linkedHashMap.put(column.getName(), new ExprLongValue(row.getLong(column.getName()))); + } else if (type == ExprCoreType.INTEGER) { + linkedHashMap.put(column.getName(), new ExprIntegerValue(row.getInt(column.getName()))); + } else if (type == ExprCoreType.SHORT) { + linkedHashMap.put(column.getName(), new ExprShortValue(row.getInt(column.getName()))); + } else if (type == ExprCoreType.BYTE) { + linkedHashMap.put(column.getName(), new ExprByteValue(row.getInt(column.getName()))); + } else if (type == ExprCoreType.DOUBLE) { + linkedHashMap.put(column.getName(), new ExprDoubleValue(row.getDouble(column.getName()))); + } else if (type == ExprCoreType.FLOAT) { + linkedHashMap.put(column.getName(), new ExprFloatValue(row.getFloat(column.getName()))); + } else if (type == ExprCoreType.DATE) { + linkedHashMap.put(column.getName(), new ExprDateValue(row.getString(column.getName()))); + } else if (type == ExprCoreType.TIMESTAMP) { + linkedHashMap.put(column.getName(), + new ExprTimestampValue(row.getString(column.getName()))); + } else if (type == ExprCoreType.STRING) { + linkedHashMap.put(column.getName(), new ExprStringValue(row.getString(column.getName()))); + } else { + throw new RuntimeException("Result contains invalid data type"); + } + } + + return linkedHashMap; + } + + private List<ExecutionEngine.Schema.Column> getColumnList(JSONArray schema) { + List<ExecutionEngine.Schema.Column> columnList = new ArrayList<>(); + for (int i = 0; i < schema.length(); i++) { + JSONObject column = new JSONObject(schema.get(i).toString().replace("'", "\"")); + columnList.add(new ExecutionEngine.Schema.Column( + column.get("column_name").toString(), + column.get("column_name").toString(), + getDataType(column.get("data_type").toString()))); + } + return columnList; + } + + private ExprCoreType getDataType(String sparkDataType) { + switch (sparkDataType) { + case "boolean": + return ExprCoreType.BOOLEAN; + case "long": + return ExprCoreType.LONG; + case "integer": + return 
ExprCoreType.INTEGER; + case "short": + return ExprCoreType.SHORT; + case "byte": + return ExprCoreType.BYTE; + case "double": + return ExprCoreType.DOUBLE; + case "float": + return ExprCoreType.FLOAT; + case "timestamp": + return ExprCoreType.DATE; + case "date": + return ExprCoreType.TIMESTAMP; + case "string": + case "varchar": + case "char": + return ExprCoreType.STRING; + default: + return ExprCoreType.UNKNOWN; + } + } + + @Override + public boolean hasNext() { + return responseIterator.hasNext(); + } + + @Override + public ExprValue next() { + return responseIterator.next(); + } + + @Override + public ExecutionEngine.Schema schema() { + return schema; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/response/SparkSqlFunctionResponseHandle.java b/spark/src/main/java/org/opensearch/sql/spark/functions/response/SparkSqlFunctionResponseHandle.java new file mode 100644 index 0000000000..da68b591eb --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/response/SparkSqlFunctionResponseHandle.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions.response; + +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.executor.ExecutionEngine; + +/** + * Handle Spark response. + */ +public interface SparkSqlFunctionResponseHandle { + + /** + * Return true if Spark response has more result. + */ + boolean hasNext(); + + /** + * Return Spark response as {@link ExprValue}. Attention, the method must been called when + * hasNext return true. + */ + ExprValue next(); + + /** + * Return ExecutionEngine.Schema of the Spark response. + */ + ExecutionEngine.Schema schema(); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java new file mode 100644 index 0000000000..28ce7dd19a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions.scan; + +import lombok.AllArgsConstructor; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * TableScanBuilder for sql function of spark connector. 
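+ * Builds a SparkSqlFunctionTableScanOperator for the given client and query request; project push-down is
+ * always accepted here since the projection is already expressed in the Spark SQL query itself.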
+ */ +@AllArgsConstructor +public class SparkSqlFunctionTableScanBuilder extends TableScanBuilder { + + private final SparkClient sparkClient; + + private final SparkQueryRequest sparkQueryRequest; + + @Override + public TableScanOperator build() { + return new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + } + + @Override + public boolean pushDownProject(LogicalProject project) { + return true; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java new file mode 100644 index 0000000000..85e854e422 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions.scan; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Locale; +import lombok.RequiredArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.json.JSONObject; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.response.DefaultSparkSqlFunctionResponseHandle; +import org.opensearch.sql.spark.functions.response.SparkSqlFunctionResponseHandle; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.storage.TableScanOperator; + +/** + * This a table scan operator to handle sql table function. + */ +@RequiredArgsConstructor +public class SparkSqlFunctionTableScanOperator extends TableScanOperator { + private final SparkClient sparkClient; + private final SparkQueryRequest request; + private SparkSqlFunctionResponseHandle sparkResponseHandle; + private static final Logger LOG = LogManager.getLogger(); + + @Override + public void open() { + super.open(); + this.sparkResponseHandle = AccessController.doPrivileged( + (PrivilegedAction) () -> { + try { + JSONObject responseObject = sparkClient.sql(request.getSql()); + return new DefaultSparkSqlFunctionResponseHandle(responseObject); + } catch (IOException e) { + LOG.error(e.getMessage()); + throw new RuntimeException( + String.format("Error fetching data from spark server: %s", e.getMessage())); + } + }); + } + + @Override + public boolean hasNext() { + return this.sparkResponseHandle.hasNext(); + } + + @Override + public ExprValue next() { + return this.sparkResponseHandle.next(); + } + + @Override + public String explain() { + return String.format(Locale.ROOT, "sql(%s)", request.getSql()); + } + + @Override + public ExecutionEngine.Schema schema() { + return this.sparkResponseHandle.schema(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/helper/FlintHelper.java b/spark/src/main/java/org/opensearch/sql/spark/helper/FlintHelper.java new file mode 100644 index 0000000000..b3c3c0871a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/helper/FlintHelper.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.helper; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_AUTH; +import static 
org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_HOST; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_PORT; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_REGION; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_DEFAULT_SCHEME; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_INTEGRATION_JAR; + +import lombok.Getter; + +public class FlintHelper { + @Getter + private final String flintIntegrationJar; + @Getter + private final String flintHost; + @Getter + private final String flintPort; + @Getter + private final String flintScheme; + @Getter + private final String flintAuth; + @Getter + private final String flintRegion; + + /** Arguments required to write data to opensearch index using flint integration. + * + * @param flintHost Opensearch host for flint + * @param flintPort Opensearch port for flint integration + * @param flintScheme Opensearch scheme for flint integration + * @param flintAuth Opensearch auth for flint integration + * @param flintRegion Opensearch region for flint integration + */ + public FlintHelper( + String flintIntegrationJar, + String flintHost, + String flintPort, + String flintScheme, + String flintAuth, + String flintRegion) { + this.flintIntegrationJar = + flintIntegrationJar == null ? FLINT_INTEGRATION_JAR : flintIntegrationJar; + this.flintHost = flintHost != null ? flintHost : FLINT_DEFAULT_HOST; + this.flintPort = flintPort != null ? flintPort : FLINT_DEFAULT_PORT; + this.flintScheme = flintScheme != null ? flintScheme : FLINT_DEFAULT_SCHEME; + this.flintAuth = flintAuth != null ? flintAuth : FLINT_DEFAULT_AUTH; + this.flintRegion = flintRegion != null ? flintRegion : FLINT_DEFAULT_REGION; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java new file mode 100644 index 0000000000..bc0944a784 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.request; + +import lombok.Data; + +/** + * Spark query request. + */ +@Data +public class SparkQueryRequest { + + /** + * SQL. 
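+ * The Spark SQL query string to be submitted for execution.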
+ */ + private String sql; + +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java new file mode 100644 index 0000000000..3e348381f2 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.response; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME; + +import com.google.common.annotations.VisibleForTesting; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.Setter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.json.JSONObject; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.ActionFuture; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException; + +@Data +public class SparkResponse { + private final Client client; + private String value; + private final String field; + private static final Logger LOG = LogManager.getLogger(); + + /** + * Response for spark sql query. 
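+ * Results are looked up in the Spark result index by matching the identifier field, and the matching
+ * document is deleted once it has been read.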
+ * + * @param client Opensearch client + * @param value Identifier field value + * @param field Identifier field name + */ + public SparkResponse(Client client, String value, String field) { + this.client = client; + this.value = value; + this.field = field; + } + + public JSONObject getResultFromOpensearchIndex() { + return searchInSparkIndex(QueryBuilders.termQuery(field, value)); + } + + private JSONObject searchInSparkIndex(QueryBuilder query) { + SearchRequest searchRequest = new SearchRequest(); + searchRequest.indices(SPARK_INDEX_NAME); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(query); + searchRequest.source(searchSourceBuilder); + ActionFuture searchResponseActionFuture; + try { + searchResponseActionFuture = client.search(searchRequest); + } catch (Exception e) { + throw new RuntimeException(e); + } + SearchResponse searchResponse = searchResponseActionFuture.actionGet(); + if (searchResponse.status().getStatus() != 200) { + throw new RuntimeException( + "Fetching result from " + SPARK_INDEX_NAME + " index failed with status : " + + searchResponse.status()); + } else { + JSONObject data = new JSONObject(); + for (SearchHit searchHit : searchResponse.getHits().getHits()) { + data.put("data", searchHit.getSourceAsMap()); + deleteInSparkIndex(searchHit.getId()); + } + return data; + } + } + + @VisibleForTesting + void deleteInSparkIndex(String id) { + DeleteRequest deleteRequest = new DeleteRequest(SPARK_INDEX_NAME); + deleteRequest.id(id); + ActionFuture deleteResponseActionFuture; + try { + deleteResponseActionFuture = client.delete(deleteRequest); + } catch (Exception e) { + throw new RuntimeException(e); + } + DeleteResponse deleteResponse = deleteResponseActionFuture.actionGet(); + if (deleteResponse.getResult().equals(DocWriteResponse.Result.DELETED)) { + LOG.debug("Spark result successfully deleted ", id); + } else if (deleteResponse.getResult().equals(DocWriteResponse.Result.NOT_FOUND)) { + throw new ResourceNotFoundException("Spark result with id " + + id + " doesn't exist"); + } else { + throw new RuntimeException("Deleting spark result information failed with : " + + deleteResponse.getResult().getLowercase()); + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java new file mode 100644 index 0000000000..3897e8690e --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.storage.TableScanOperator; + +/** + * Spark scan operator. + */ +@EqualsAndHashCode(onlyExplicitlyIncluded = true, callSuper = false) +@ToString(onlyExplicitlyIncluded = true) +public class SparkScan extends TableScanOperator { + + private final SparkClient sparkClient; + + @EqualsAndHashCode.Include + @Getter + @Setter + @ToString.Include + private SparkQueryRequest request; + + + /** + * Constructor. + * + * @param sparkClient sparkClient. 
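+ * @implNote this operator is effectively a placeholder: hasNext() always returns false and next() returns null;
+ *     the rows themselves are produced by SparkSqlFunctionTableScanOperator (see SparkTable.createScanBuilder()).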
+ */ + public SparkScan(SparkClient sparkClient) { + this.sparkClient = sparkClient; + this.request = new SparkQueryRequest(); + } + + @Override + public boolean hasNext() { + return false; + } + + @Override + public ExprValue next() { + return null; + } + + @Override + public String explain() { + return getRequest().toString(); + } + +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java new file mode 100644 index 0000000000..a5e35ecc4c --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import java.util.Collection; +import java.util.Collections; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.DataSourceSchemaName; +import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver; +import org.opensearch.sql.storage.StorageEngine; +import org.opensearch.sql.storage.Table; + +/** + * Spark storage engine implementation. + */ +@RequiredArgsConstructor +public class SparkStorageEngine implements StorageEngine { + private final SparkClient sparkClient; + + @Override + public Collection getFunctions() { + return Collections.singletonList( + new SparkSqlTableFunctionResolver(sparkClient)); + } + + @Override + public Table getTable(DataSourceSchemaName dataSourceSchemaName, String tableName) { + throw new RuntimeException("Unable to get table from storage engine."); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java new file mode 100644 index 0000000000..937679b50e --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import static org.opensearch.sql.spark.data.constants.SparkConstants.EMR; +import static org.opensearch.sql.spark.data.constants.SparkConstants.STEP_ID_FIELD; + +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; +import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder; +import java.security.AccessController; +import java.security.InvalidParameterException; +import java.security.PrivilegedAction; +import java.util.Map; +import lombok.RequiredArgsConstructor; +import org.opensearch.client.Client; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.datasource.model.DataSource; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.datasources.auth.AuthenticationType; +import org.opensearch.sql.spark.client.EmrClientImpl; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.helper.FlintHelper; +import org.opensearch.sql.spark.response.SparkResponse; +import org.opensearch.sql.storage.DataSourceFactory; +import org.opensearch.sql.storage.StorageEngine; + +/** + * Storage factory implementation for spark 
connector. + */ +@RequiredArgsConstructor +public class SparkStorageFactory implements DataSourceFactory { + private final Client client; + private final Settings settings; + + // Spark datasource configuration properties + public static final String CONNECTOR_TYPE = "spark.connector"; + public static final String SPARK_SQL_APPLICATION = "spark.sql.application"; + + // EMR configuration properties + public static final String EMR_CLUSTER = "emr.cluster"; + public static final String EMR_AUTH_TYPE = "emr.auth.type"; + public static final String EMR_REGION = "emr.auth.region"; + public static final String EMR_ROLE_ARN = "emr.auth.role_arn"; + public static final String EMR_ACCESS_KEY = "emr.auth.access_key"; + public static final String EMR_SECRET_KEY = "emr.auth.secret_key"; + + // Flint integration jar configuration properties + public static final String FLINT_INTEGRATION = "spark.datasource.flint.integration"; + public static final String FLINT_HOST = "spark.datasource.flint.host"; + public static final String FLINT_PORT = "spark.datasource.flint.port"; + public static final String FLINT_SCHEME = "spark.datasource.flint.scheme"; + public static final String FLINT_AUTH = "spark.datasource.flint.auth"; + public static final String FLINT_REGION = "spark.datasource.flint.region"; + + @Override + public DataSourceType getDataSourceType() { + return DataSourceType.SPARK; + } + + @Override + public DataSource createDataSource(DataSourceMetadata metadata) { + return new DataSource( + metadata.getName(), + DataSourceType.SPARK, + getStorageEngine(metadata.getProperties())); + } + + /** + * This function gets spark storage engine. + * + * @param requiredConfig spark config options + * @return spark storage engine object + */ + StorageEngine getStorageEngine(Map requiredConfig) { + SparkClient sparkClient; + if (requiredConfig.get(CONNECTOR_TYPE).equals(EMR)) { + sparkClient = + AccessController.doPrivileged((PrivilegedAction) () -> { + validateEMRConfigProperties(requiredConfig); + return new EmrClientImpl( + getEMRClient( + requiredConfig.get(EMR_ACCESS_KEY), + requiredConfig.get(EMR_SECRET_KEY), + requiredConfig.get(EMR_REGION)), + requiredConfig.get(EMR_CLUSTER), + new FlintHelper( + requiredConfig.get(FLINT_INTEGRATION), + requiredConfig.get(FLINT_HOST), + requiredConfig.get(FLINT_PORT), + requiredConfig.get(FLINT_SCHEME), + requiredConfig.get(FLINT_AUTH), + requiredConfig.get(FLINT_REGION)), + new SparkResponse(client, null, STEP_ID_FIELD), + requiredConfig.get(SPARK_SQL_APPLICATION)); + }); + } else { + throw new InvalidParameterException("Spark connector type is invalid."); + } + return new SparkStorageEngine(sparkClient); + } + + private void validateEMRConfigProperties(Map dataSourceMetadataConfig) + throws IllegalArgumentException { + if (dataSourceMetadataConfig.get(EMR_CLUSTER) == null + || dataSourceMetadataConfig.get(EMR_AUTH_TYPE) == null) { + throw new IllegalArgumentException("EMR config properties are missing."); + } else if (dataSourceMetadataConfig.get(EMR_AUTH_TYPE) + .equals(AuthenticationType.AWSSIGV4AUTH.getName()) + && (dataSourceMetadataConfig.get(EMR_ACCESS_KEY) == null + || dataSourceMetadataConfig.get(EMR_SECRET_KEY) == null)) { + throw new IllegalArgumentException("EMR auth keys are missing."); + } else if (!dataSourceMetadataConfig.get(EMR_AUTH_TYPE) + .equals(AuthenticationType.AWSSIGV4AUTH.getName())) { + throw new IllegalArgumentException("Invalid auth type."); + } + } + + private AmazonElasticMapReduce getEMRClient( + String emrAccessKey, String emrSecretKey, 
String emrRegion) { + return AmazonElasticMapReduceClientBuilder.standard() + .withCredentials(new AWSStaticCredentialsProvider( + new BasicAWSCredentials(emrAccessKey, emrSecretKey))) + .withRegion(emrRegion) + .build(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java new file mode 100644 index 0000000000..5151405db9 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import java.util.HashMap; +import java.util.Map; +import lombok.Getter; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.planner.DefaultImplementor; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanBuilder; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Spark table implementation. + * This can be constructed from SparkQueryRequest. + */ +public class SparkTable implements Table { + + private final SparkClient sparkClient; + + @Getter + private final SparkQueryRequest sparkQueryRequest; + + /** + * Constructor for entire Sql Request. + */ + public SparkTable(SparkClient sparkService, SparkQueryRequest sparkQueryRequest) { + this.sparkClient = sparkService; + this.sparkQueryRequest = sparkQueryRequest; + } + + @Override + public boolean exists() { + throw new UnsupportedOperationException( + "Exists operation is not supported in spark datasource"); + } + + @Override + public void create(Map schema) { + throw new UnsupportedOperationException( + "Create operation is not supported in spark datasource"); + } + + @Override + public Map getFieldTypes() { + return new HashMap<>(); + } + + @Override + public PhysicalPlan implement(LogicalPlan plan) { + SparkScan metricScan = + new SparkScan(sparkClient); + metricScan.setRequest(sparkQueryRequest); + return plan.accept(new DefaultImplementor(), metricScan); + } + + @Override + public TableScanBuilder createScanBuilder() { + return new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java new file mode 100644 index 0000000000..a94ac01f2f --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java @@ -0,0 +1,160 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.client; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_CLUSTER_ID; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.utils.TestUtils.getJson; + +import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; +import com.amazonaws.services.elasticmapreduce.model.AddJobFlowStepsResult; +import com.amazonaws.services.elasticmapreduce.model.DescribeStepResult; +import com.amazonaws.services.elasticmapreduce.model.Step; +import 
com.amazonaws.services.elasticmapreduce.model.StepStatus; +import lombok.SneakyThrows; +import org.json.JSONObject; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.helper.FlintHelper; +import org.opensearch.sql.spark.response.SparkResponse; + +@ExtendWith(MockitoExtension.class) +public class EmrClientImplTest { + + @Mock + private AmazonElasticMapReduce emr; + @Mock + private FlintHelper flint; + @Mock + private SparkResponse sparkResponse; + + @Test + @SneakyThrows + void testRunEmrApplication() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus stepStatus = new StepStatus(); + stepStatus.setState("COMPLETED"); + Step step = new Step(); + step.setStatus(stepStatus); + DescribeStepResult describeStepResult = new DescribeStepResult(); + describeStepResult.setStep(step); + when(emr.describeStep(any())).thenReturn(describeStepResult); + + EmrClientImpl emrClientImpl = new EmrClientImpl( + emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + emrClientImpl.runEmrApplication(QUERY); + } + + @Test + @SneakyThrows + void testRunEmrApplicationFailed() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus stepStatus = new StepStatus(); + stepStatus.setState("FAILED"); + Step step = new Step(); + step.setStatus(stepStatus); + DescribeStepResult describeStepResult = new DescribeStepResult(); + describeStepResult.setStep(step); + when(emr.describeStep(any())).thenReturn(describeStepResult); + + EmrClientImpl emrClientImpl = new EmrClientImpl( + emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + RuntimeException exception = Assertions.assertThrows(RuntimeException.class, + () -> emrClientImpl.runEmrApplication(QUERY)); + Assertions.assertEquals("Spark SQL application failed.", + exception.getMessage()); + } + + @Test + @SneakyThrows + void testRunEmrApplicationCancelled() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus stepStatus = new StepStatus(); + stepStatus.setState("CANCELLED"); + Step step = new Step(); + step.setStatus(stepStatus); + DescribeStepResult describeStepResult = new DescribeStepResult(); + describeStepResult.setStep(step); + when(emr.describeStep(any())).thenReturn(describeStepResult); + + EmrClientImpl emrClientImpl = new EmrClientImpl( + emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + RuntimeException exception = Assertions.assertThrows(RuntimeException.class, + () -> emrClientImpl.runEmrApplication(QUERY)); + Assertions.assertEquals("Spark SQL application failed.", + exception.getMessage()); + } + + @Test + @SneakyThrows + void testRunEmrApplicationRunnning() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus runningStatus = new StepStatus(); + runningStatus.setState("RUNNING"); + Step runningStep = new Step(); + runningStep.setStatus(runningStatus); + DescribeStepResult runningDescribeStepResult = new DescribeStepResult(); + runningDescribeStepResult.setStep(runningStep); + + StepStatus 
completedStatus = new StepStatus(); + completedStatus.setState("COMPLETED"); + Step completedStep = new Step(); + completedStep.setStatus(completedStatus); + DescribeStepResult completedDescribeStepResult = new DescribeStepResult(); + completedDescribeStepResult.setStep(completedStep); + + when(emr.describeStep(any())).thenReturn(runningDescribeStepResult) + .thenReturn(completedDescribeStepResult); + + EmrClientImpl emrClientImpl = new EmrClientImpl( + emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + emrClientImpl.runEmrApplication(QUERY); + } + + @Test + @SneakyThrows + void testSql() { + AddJobFlowStepsResult addStepsResult = new AddJobFlowStepsResult().withStepIds(EMR_CLUSTER_ID); + when(emr.addJobFlowSteps(any())).thenReturn(addStepsResult); + + StepStatus runningStatus = new StepStatus(); + runningStatus.setState("RUNNING"); + Step runningStep = new Step(); + runningStep.setStatus(runningStatus); + DescribeStepResult runningDescribeStepResult = new DescribeStepResult(); + runningDescribeStepResult.setStep(runningStep); + + StepStatus completedStatus = new StepStatus(); + completedStatus.setState("COMPLETED"); + Step completedStep = new Step(); + completedStep.setStatus(completedStatus); + DescribeStepResult completedDescribeStepResult = new DescribeStepResult(); + completedDescribeStepResult.setStep(completedStep); + + when(emr.describeStep(any())).thenReturn(runningDescribeStepResult) + .thenReturn(completedDescribeStepResult); + when(sparkResponse.getResultFromOpensearchIndex()) + .thenReturn(new JSONObject(getJson("select_query_response.json"))); + + EmrClientImpl emrClientImpl = new EmrClientImpl( + emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + emrClientImpl.sql(QUERY); + + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java new file mode 100644 index 0000000000..2b1020568a --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.constants; + +public class TestConstants { + public static final String QUERY = "select 1"; + public static final String EMR_CLUSTER_ID = "j-123456789"; +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java new file mode 100644 index 0000000000..18db5b9471 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; + +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import 
org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.implementation.SparkSqlFunctionImplementation; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.spark.storage.SparkTable; + +@ExtendWith(MockitoExtension.class) +public class SparkSqlFunctionImplementationTest { + @Mock + private SparkClient client; + + @Test + void testValueOfAndTypeToString() { + FunctionName functionName = new FunctionName("sql"); + List namedArgumentExpressionList + = List.of(DSL.namedArgument("query", DSL.literal(QUERY))); + SparkSqlFunctionImplementation sparkSqlFunctionImplementation + = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); + UnsupportedOperationException exception = assertThrows(UnsupportedOperationException.class, + () -> sparkSqlFunctionImplementation.valueOf()); + assertEquals("Spark defined function [sql] is only " + + "supported in SOURCE clause with spark connector catalog", exception.getMessage()); + assertEquals("sql(query=\"select 1\")", + sparkSqlFunctionImplementation.toString()); + assertEquals(ExprCoreType.STRUCT, sparkSqlFunctionImplementation.type()); + } + + @Test + void testApplyArguments() { + FunctionName functionName = new FunctionName("sql"); + List namedArgumentExpressionList + = List.of(DSL.namedArgument("query", DSL.literal(QUERY))); + SparkSqlFunctionImplementation sparkSqlFunctionImplementation + = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); + SparkTable sparkTable + = (SparkTable) sparkSqlFunctionImplementation.applyArguments(); + assertNotNull(sparkTable.getSparkQueryRequest()); + SparkQueryRequest sparkQueryRequest + = sparkTable.getSparkQueryRequest(); + assertEquals(QUERY, sparkQueryRequest.getSql()); + } + + @Test + void testApplyArgumentsException() { + FunctionName functionName = new FunctionName("sql"); + List namedArgumentExpressionList + = List.of(DSL.namedArgument("query", DSL.literal(QUERY)), + DSL.namedArgument("tmp", DSL.literal(12345))); + SparkSqlFunctionImplementation sparkSqlFunctionImplementation + = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); + ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, + () -> sparkSqlFunctionImplementation.applyArguments()); + assertEquals("Invalid Function Argument:tmp", exception.getMessage()); + } + +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java new file mode 100644 index 0000000000..94c87602b7 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions; + +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanBuilder; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanOperator; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import 
org.opensearch.sql.storage.TableScanOperator; + +public class SparkSqlFunctionTableScanBuilderTest { + @Mock + private SparkClient sparkClient; + + @Mock + private LogicalProject logicalProject; + + @Test + void testBuild() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder + = new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest); + TableScanOperator sqlFunctionTableScanOperator + = sparkSqlFunctionTableScanBuilder.build(); + Assertions.assertTrue(sqlFunctionTableScanOperator + instanceof SparkSqlFunctionTableScanOperator); + } + + @Test + void testPushProject() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder + = new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest); + Assertions.assertTrue(sparkSqlFunctionTableScanBuilder.pushDownProject(logicalProject)); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java new file mode 100644 index 0000000000..f6807f9913 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.utils.TestUtils.getJson; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import lombok.SneakyThrows; +import org.json.JSONObject; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprBooleanValue; +import org.opensearch.sql.data.model.ExprByteValue; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprFloatValue; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprShortValue; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTimestampValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanOperator; +import org.opensearch.sql.spark.request.SparkQueryRequest; + +@ExtendWith(MockitoExtension.class) +public class SparkSqlFunctionTableScanOperatorTest { + + @Mock + private SparkClient sparkClient; + + @Test + @SneakyThrows + void testEmptyQueryWithException() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + 
sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())) + .thenThrow(new IOException("Error Message")); + RuntimeException runtimeException + = assertThrows(RuntimeException.class, sparkSqlFunctionTableScanOperator::open); + assertEquals("Error fetching data from spark server: Error Message", + runtimeException.getMessage()); + } + + @Test + @SneakyThrows + void testClose() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + sparkSqlFunctionTableScanOperator.close(); + } + + @Test + @SneakyThrows + void testExplain() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + Assertions.assertEquals("sql(select 1)", + sparkSqlFunctionTableScanOperator.explain()); + } + + @Test + @SneakyThrows + void testQueryResponseIterator() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())) + .thenReturn(new JSONObject(getJson("select_query_response.json"))); + sparkSqlFunctionTableScanOperator.open(); + assertTrue(sparkSqlFunctionTableScanOperator.hasNext()); + ExprTupleValue firstRow = new ExprTupleValue(new LinkedHashMap<>() { + { + put("1", new ExprIntegerValue(1)); + } + }); + assertEquals(firstRow, sparkSqlFunctionTableScanOperator.next()); + Assertions.assertFalse(sparkSqlFunctionTableScanOperator.hasNext()); + } + + @Test + @SneakyThrows + void testQueryResponseAllTypes() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())) + .thenReturn(new JSONObject(getJson("all_data_type.json"))); + sparkSqlFunctionTableScanOperator.open(); + assertTrue(sparkSqlFunctionTableScanOperator.hasNext()); + ExprTupleValue firstRow = new ExprTupleValue(new LinkedHashMap<>() { + { + put("boolean", ExprBooleanValue.of(true)); + put("long", new ExprLongValue(922337203)); + put("integer", new ExprIntegerValue(2147483647)); + put("short", new ExprShortValue(32767)); + put("byte", new ExprByteValue(127)); + put("double", new ExprDoubleValue(9223372036854.775807)); + put("float", new ExprFloatValue(21474.83647)); + put("timestamp", new ExprDateValue("2023-07-01 10:31:30")); + put("date", new ExprTimestampValue("2023-07-01 10:31:30")); + put("string", new ExprStringValue("ABC")); + put("char", new ExprStringValue("A")); + } + }); + assertEquals(firstRow, sparkSqlFunctionTableScanOperator.next()); + Assertions.assertFalse(sparkSqlFunctionTableScanOperator.hasNext()); + } + + @Test + @SneakyThrows + void testQueryResponseInvalidDataType() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new 
SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())) + .thenReturn(new JSONObject(getJson("invalid_data_type.json"))); + + RuntimeException exception = Assertions.assertThrows(RuntimeException.class, + () -> sparkSqlFunctionTableScanOperator.open()); + Assertions.assertEquals("Result contains invalid data type", + exception.getMessage()); + } + + @Test + @SneakyThrows + void testQuerySchema() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator + = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())) + .thenReturn( + new JSONObject(getJson("select_query_response.json"))); + sparkSqlFunctionTableScanOperator.open(); + ArrayList columns = new ArrayList<>(); + columns.add(new ExecutionEngine.Schema.Column("1", "1", ExprCoreType.INTEGER)); + ExecutionEngine.Schema expectedSchema = new ExecutionEngine.Schema(columns); + assertEquals(expectedSchema, sparkSqlFunctionTableScanOperator.schema()); + } + +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java new file mode 100644 index 0000000000..e18fac36de --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java @@ -0,0 +1,141 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; + +import java.util.List; +import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionProperties; +import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.expression.function.TableFunctionImplementation; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.implementation.SparkSqlFunctionImplementation; +import org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.spark.storage.SparkTable; + +@ExtendWith(MockitoExtension.class) +public class SparkSqlTableFunctionResolverTest { + @Mock + private SparkClient client; + + @Mock + private FunctionProperties functionProperties; + + @Test + void testResolve() { + SparkSqlTableFunctionResolver sqlTableFunctionResolver + = new SparkSqlTableFunctionResolver(client); + FunctionName functionName = FunctionName.of("sql"); + List expressions + = 
List.of(DSL.namedArgument("query", DSL.literal(QUERY))); + FunctionSignature functionSignature = new FunctionSignature(functionName, expressions + .stream().map(Expression::type).collect(Collectors.toList())); + Pair resolution + = sqlTableFunctionResolver.resolve(functionSignature); + assertEquals(functionName, resolution.getKey().getFunctionName()); + assertEquals(functionName, sqlTableFunctionResolver.getFunctionName()); + assertEquals(List.of(STRING), resolution.getKey().getParamTypeList()); + FunctionBuilder functionBuilder = resolution.getValue(); + TableFunctionImplementation functionImplementation + = (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions); + assertTrue(functionImplementation instanceof SparkSqlFunctionImplementation); + SparkTable sparkTable + = (SparkTable) functionImplementation.applyArguments(); + assertNotNull(sparkTable.getSparkQueryRequest()); + SparkQueryRequest sparkQueryRequest = + sparkTable.getSparkQueryRequest(); + assertEquals(QUERY, sparkQueryRequest.getSql()); + } + + @Test + void testArgumentsPassedByPosition() { + SparkSqlTableFunctionResolver sqlTableFunctionResolver + = new SparkSqlTableFunctionResolver(client); + FunctionName functionName = FunctionName.of("sql"); + List expressions + = List.of(DSL.namedArgument(null, DSL.literal(QUERY))); + FunctionSignature functionSignature = new FunctionSignature(functionName, expressions + .stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution + = sqlTableFunctionResolver.resolve(functionSignature); + + assertEquals(functionName, resolution.getKey().getFunctionName()); + assertEquals(functionName, sqlTableFunctionResolver.getFunctionName()); + assertEquals(List.of(STRING), resolution.getKey().getParamTypeList()); + FunctionBuilder functionBuilder = resolution.getValue(); + TableFunctionImplementation functionImplementation + = (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions); + assertTrue(functionImplementation instanceof SparkSqlFunctionImplementation); + SparkTable sparkTable + = (SparkTable) functionImplementation.applyArguments(); + assertNotNull(sparkTable.getSparkQueryRequest()); + SparkQueryRequest sparkQueryRequest = + sparkTable.getSparkQueryRequest(); + assertEquals(QUERY, sparkQueryRequest.getSql()); + } + + @Test + void testMixedArgumentTypes() { + SparkSqlTableFunctionResolver sqlTableFunctionResolver + = new SparkSqlTableFunctionResolver(client); + FunctionName functionName = FunctionName.of("sql"); + List expressions + = List.of(DSL.namedArgument("query", DSL.literal(QUERY)), + DSL.namedArgument(null, DSL.literal(12345))); + FunctionSignature functionSignature = new FunctionSignature(functionName, expressions + .stream().map(Expression::type).collect(Collectors.toList())); + Pair resolution + = sqlTableFunctionResolver.resolve(functionSignature); + + assertEquals(functionName, resolution.getKey().getFunctionName()); + assertEquals(functionName, sqlTableFunctionResolver.getFunctionName()); + assertEquals(List.of(STRING), resolution.getKey().getParamTypeList()); + SemanticCheckException exception = assertThrows(SemanticCheckException.class, + () -> resolution.getValue().apply(functionProperties, expressions)); + + assertEquals("Arguments should be either passed by name or position", exception.getMessage()); + } + + @Test + void testWrongArgumentsSizeWhenPassedByName() { + SparkSqlTableFunctionResolver sqlTableFunctionResolver + = new SparkSqlTableFunctionResolver(client); + FunctionName functionName = 
FunctionName.of("sql"); + List expressions + = List.of(); + FunctionSignature functionSignature = new FunctionSignature(functionName, expressions + .stream().map(Expression::type).collect(Collectors.toList())); + Pair resolution + = sqlTableFunctionResolver.resolve(functionSignature); + + assertEquals(functionName, resolution.getKey().getFunctionName()); + assertEquals(functionName, sqlTableFunctionResolver.getFunctionName()); + assertEquals(List.of(STRING), resolution.getKey().getParamTypeList()); + SemanticCheckException exception = assertThrows(SemanticCheckException.class, + () -> resolution.getValue().apply(functionProperties, expressions)); + + assertEquals("Missing arguments:[query]", exception.getMessage()); + } + +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java b/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java new file mode 100644 index 0000000000..20210ea7e5 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.response; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.constants.TestConstants.EMR_CLUSTER_ID; +import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_INDEX_NAME; + +import java.util.Map; +import org.apache.lucene.search.TotalHits; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.ActionFuture; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.rest.RestStatus; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; + +@ExtendWith(MockitoExtension.class) +public class SparkResponseTest { + @Mock + private Client client; + @Mock + private SearchResponse searchResponse; + @Mock + private DeleteResponse deleteResponse; + @Mock + private SearchHit searchHit; + @Mock + private ActionFuture searchResponseActionFuture; + @Mock + private ActionFuture deleteResponseActionFuture; + + @Test + public void testGetResultFromOpensearchIndex() { + when(client.search(any())).thenReturn(searchResponseActionFuture); + when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + when(searchResponse.status()).thenReturn(RestStatus.OK); + when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit}, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 1.0F)); + Mockito.when(searchHit.getSourceAsMap()) + .thenReturn(Map.of("stepId", EMR_CLUSTER_ID)); + + + when(client.delete(any())).thenReturn(deleteResponseActionFuture); + when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); + when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.DELETED); + + SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); + 
assertFalse(sparkResponse.getResultFromOpensearchIndex().isEmpty()); + } + + @Test + public void testInvalidSearchResponse() { + when(client.search(any())).thenReturn(searchResponseActionFuture); + when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse); + when(searchResponse.status()).thenReturn(RestStatus.NO_CONTENT); + + SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); + RuntimeException exception = assertThrows(RuntimeException.class, + () -> sparkResponse.getResultFromOpensearchIndex()); + Assertions.assertEquals( + "Fetching result from " + SPARK_INDEX_NAME + + " index failed with status : " + RestStatus.NO_CONTENT, + exception.getMessage()); + } + + @Test + public void testSearchFailure() { + when(client.search(any())).thenThrow(RuntimeException.class); + SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); + assertThrows(RuntimeException.class, () -> sparkResponse.getResultFromOpensearchIndex()); + } + + @Test + public void testDeleteFailure() { + when(client.delete(any())).thenThrow(RuntimeException.class); + SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); + assertThrows(RuntimeException.class, () -> sparkResponse.deleteInSparkIndex("id")); + } + + @Test + public void testNotFoundDeleteResponse() { + when(client.delete(any())).thenReturn(deleteResponseActionFuture); + when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); + when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND); + + SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); + RuntimeException exception = assertThrows(ResourceNotFoundException.class, + () -> sparkResponse.deleteInSparkIndex("123")); + Assertions.assertEquals("Spark result with id 123 doesn't exist", exception.getMessage()); + } + + @Test + public void testInvalidDeleteResponse() { + when(client.delete(any())).thenReturn(deleteResponseActionFuture); + when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse); + when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOOP); + + SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId"); + RuntimeException exception = assertThrows(RuntimeException.class, + () -> sparkResponse.deleteInSparkIndex("123")); + Assertions.assertEquals( + "Deleting spark result information failed with : noop", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java new file mode 100644 index 0000000000..c57142f580 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; + +import lombok.SneakyThrows; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.spark.client.SparkClient; + +@ExtendWith(MockitoExtension.class) +public class SparkScanTest { + @Mock + private SparkClient sparkClient; + + @Test + @SneakyThrows + 
void testQueryResponseIteratorForQueryRangeFunction() { + SparkScan sparkScan = new SparkScan(sparkClient); + sparkScan.getRequest().setSql(QUERY); + Assertions.assertFalse(sparkScan.hasNext()); + assertNull(sparkScan.next()); + } + + @Test + @SneakyThrows + void testExplain() { + SparkScan sparkScan = new SparkScan(sparkClient); + sparkScan.getRequest().setSql(QUERY); + assertEquals( + "SparkQueryRequest(sql=select 1)", + sparkScan.explain()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java new file mode 100644 index 0000000000..d42e123678 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Collection; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.DataSourceSchemaName; +import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver; + +@ExtendWith(MockitoExtension.class) +public class SparkStorageEngineTest { + @Mock + private SparkClient client; + + @Test + public void getFunctions() { + SparkStorageEngine engine = new SparkStorageEngine(client); + Collection functionResolverCollection + = engine.getFunctions(); + assertNotNull(functionResolverCollection); + assertEquals(1, functionResolverCollection.size()); + assertTrue( + functionResolverCollection.iterator().next() instanceof SparkSqlTableFunctionResolver); + } + + @Test + public void getTable() { + SparkStorageEngine engine = new SparkStorageEngine(client); + RuntimeException exception = assertThrows(RuntimeException.class, + () -> engine.getTable(new DataSourceSchemaName("spark", "default"), "")); + assertEquals("Unable to get table from storage engine.", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java new file mode 100644 index 0000000000..c68adf2039 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java @@ -0,0 +1,182 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import static org.opensearch.sql.spark.constants.TestConstants.EMR_CLUSTER_ID; + +import java.security.InvalidParameterException; +import java.util.HashMap; +import lombok.SneakyThrows; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.client.Client; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.datasource.model.DataSource; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import 
org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.storage.StorageEngine; + +@ExtendWith(MockitoExtension.class) +public class SparkStorageFactoryTest { + @Mock + private Settings settings; + + @Mock + private Client client; + + @Test + void testGetConnectorType() { + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + Assertions.assertEquals( + DataSourceType.SPARK, sparkStorageFactory.getDataSourceType()); + } + + @Test + @SneakyThrows + void testGetStorageEngine() { + HashMap<String, String> properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + properties.put("emr.auth.type", "awssigv4"); + properties.put("emr.auth.access_key", "access_key"); + properties.put("emr.auth.secret_key", "secret_key"); + properties.put("emr.auth.region", "region"); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + StorageEngine storageEngine + = sparkStorageFactory.getStorageEngine(properties); + Assertions.assertTrue(storageEngine instanceof SparkStorageEngine); + } + + @Test + @SneakyThrows + void testInvalidConnectorType() { + HashMap<String, String> properties = new HashMap<>(); + properties.put("spark.connector", "random"); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + InvalidParameterException exception = Assertions.assertThrows(InvalidParameterException.class, + () -> sparkStorageFactory.getStorageEngine(properties)); + Assertions.assertEquals("Spark connector type is invalid.", + exception.getMessage()); + } + + @Test + @SneakyThrows + void testMissingAuth() { + HashMap<String, String> properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class, + () -> sparkStorageFactory.getStorageEngine(properties)); + Assertions.assertEquals("EMR config properties are missing.", + exception.getMessage()); + } + + @Test + @SneakyThrows + void testUnsupportedEmrAuth() { + HashMap<String, String> properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + properties.put("emr.auth.type", "basic"); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class, + () -> sparkStorageFactory.getStorageEngine(properties)); + Assertions.assertEquals("Invalid auth type.", + exception.getMessage()); + } + + @Test + @SneakyThrows + void testMissingCluster() { + HashMap<String, String> properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.auth.type", "awssigv4"); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class, + () -> sparkStorageFactory.getStorageEngine(properties)); + Assertions.assertEquals("EMR config properties are missing.", + exception.getMessage()); + } + + @Test + @SneakyThrows + void testMissingAuthKeys() { + HashMap<String, String> properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + properties.put("emr.auth.type", "awssigv4"); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + 
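// Auth type is awssigv4 but emr.auth.access_key and emr.auth.secret_key are omitted, so validation should fail. +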
IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class, + () -> sparkStorageFactory.getStorageEngine(properties)); + Assertions.assertEquals("EMR auth keys are missing.", + exception.getMessage()); + } + + @Test + @SneakyThrows + void testMissingAuthSecretKey() { + HashMap<String, String> properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + properties.put("emr.auth.type", "awssigv4"); + properties.put("emr.auth.access_key", "test"); + SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings); + IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class, + () -> sparkStorageFactory.getStorageEngine(properties)); + Assertions.assertEquals("EMR auth keys are missing.", + exception.getMessage()); + } + + @Test + void testCreateDataSourceSuccess() { + HashMap<String, String> properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + properties.put("emr.auth.type", "awssigv4"); + properties.put("emr.auth.access_key", "access_key"); + properties.put("emr.auth.secret_key", "secret_key"); + properties.put("emr.auth.region", "region"); + properties.put("spark.datasource.flint.host", "localhost"); + properties.put("spark.datasource.flint.port", "9200"); + properties.put("spark.datasource.flint.scheme", "http"); + properties.put("spark.datasource.flint.auth", "false"); + properties.put("spark.datasource.flint.region", "us-west-2"); + + DataSourceMetadata metadata = new DataSourceMetadata(); + metadata.setName("spark"); + metadata.setConnector(DataSourceType.SPARK); + metadata.setProperties(properties); + + DataSource dataSource = new SparkStorageFactory(client, settings).createDataSource(metadata); + Assertions.assertTrue(dataSource.getStorageEngine() instanceof SparkStorageEngine); + } + + @Test + void testSetSparkJars() { + HashMap<String, String> properties = new HashMap<>(); + properties.put("spark.connector", "emr"); + properties.put("spark.sql.application", "s3://spark/spark-sql-job.jar"); + properties.put("emr.cluster", EMR_CLUSTER_ID); + properties.put("emr.auth.type", "awssigv4"); + properties.put("emr.auth.access_key", "access_key"); + properties.put("emr.auth.secret_key", "secret_key"); + properties.put("emr.auth.region", "region"); + properties.put("spark.datasource.flint.integration", "s3://spark/flint-spark-integration.jar"); + + DataSourceMetadata metadata = new DataSourceMetadata(); + metadata.setName("spark"); + metadata.setConnector(DataSourceType.SPARK); + metadata.setProperties(properties); + + DataSource dataSource = new SparkStorageFactory(client, settings).createDataSource(metadata); + Assertions.assertTrue(dataSource.getStorageEngine() instanceof SparkStorageEngine); + } + +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java new file mode 100644 index 0000000000..39bd2eb199 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static 
org.mockito.Mockito.verifyNoMoreInteractions; +import static org.opensearch.sql.spark.constants.TestConstants.QUERY; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import lombok.SneakyThrows; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.spark.client.SparkClient; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanBuilder; +import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanOperator; +import org.opensearch.sql.spark.request.SparkQueryRequest; +import org.opensearch.sql.storage.read.TableScanBuilder; + +@ExtendWith(MockitoExtension.class) +public class SparkTableTest { + @Mock + private SparkClient client; + + @Test + void testUnsupportedOperation() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + SparkTable sparkTable = + new SparkTable(client, sparkQueryRequest); + + assertThrows(UnsupportedOperationException.class, sparkTable::exists); + assertThrows(UnsupportedOperationException.class, + () -> sparkTable.create(Collections.emptyMap())); + } + + @Test + void testCreateScanBuilderWithSqlTableFunction() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + SparkTable sparkTable = + new SparkTable(client, sparkQueryRequest); + TableScanBuilder tableScanBuilder = sparkTable.createScanBuilder(); + Assertions.assertNotNull(tableScanBuilder); + Assertions.assertTrue(tableScanBuilder instanceof SparkSqlFunctionTableScanBuilder); + } + + @Test + @SneakyThrows + void testGetFieldTypesFromSparkQueryRequest() { + SparkTable sparkTable + = new SparkTable(client, new SparkQueryRequest()); + Map<String, ExprType> expectedFieldTypes = new HashMap<>(); + Map<String, ExprType> fieldTypes = sparkTable.getFieldTypes(); + + assertEquals(expectedFieldTypes, fieldTypes); + verifyNoMoreInteractions(client); + assertNotNull(sparkTable.getSparkQueryRequest()); + } + + @Test + void testImplementWithSqlFunction() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + SparkTable sparkTable = + new SparkTable(client, sparkQueryRequest); + PhysicalPlan plan = sparkTable.implement( + new SparkSqlFunctionTableScanBuilder(client, sparkQueryRequest)); + assertTrue(plan instanceof SparkSqlFunctionTableScanOperator); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/utils/TestUtils.java b/spark/src/test/java/org/opensearch/sql/spark/utils/TestUtils.java new file mode 100644 index 0000000000..0630a85096 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/utils/TestUtils.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.utils; + +import java.io.IOException; +import java.util.Objects; + +public class TestUtils { + + /** + * Reads a JSON document from the test resources folder. + * @param filename name of the resource file. + * @return file contents as a String. + * @throws IOException if the resource cannot be read. 
+ */ + public static String getJson(String filename) throws IOException { + ClassLoader classLoader = TestUtils.class.getClassLoader(); + return new String( + Objects.requireNonNull(classLoader.getResourceAsStream(filename)).readAllBytes()); + } + +} + diff --git a/spark/src/test/resources/all_data_type.json b/spark/src/test/resources/all_data_type.json new file mode 100644 index 0000000000..a046912319 --- /dev/null +++ b/spark/src/test/resources/all_data_type.json @@ -0,0 +1,22 @@ +{ + "data": { + "result": [ + "{'boolean':true,'long':922337203,'integer':2147483647,'short':32767,'byte':127,'double':9223372036854.775807,'float':21474.83647,'timestamp':'2023-07-01 10:31:30','date':'2023-07-01 10:31:30','string':'ABC','char':'A'}" + ], + "schema": [ + "{'column_name':'boolean','data_type':'boolean'}", + "{'column_name':'long','data_type':'long'}", + "{'column_name':'integer','data_type':'integer'}", + "{'column_name':'short','data_type':'short'}", + "{'column_name':'byte','data_type':'byte'}", + "{'column_name':'double','data_type':'double'}", + "{'column_name':'float','data_type':'float'}", + "{'column_name':'timestamp','data_type':'timestamp'}", + "{'column_name':'date','data_type':'date'}", + "{'column_name':'string','data_type':'string'}", + "{'column_name':'char','data_type':'char'}" + ], + "stepId": "s-123456789", + "applicationId": "application-abc" + } +} diff --git a/spark/src/test/resources/invalid_data_type.json b/spark/src/test/resources/invalid_data_type.json new file mode 100644 index 0000000000..0eb08423c8 --- /dev/null +++ b/spark/src/test/resources/invalid_data_type.json @@ -0,0 +1,12 @@ +{ + "data": { + "result": [ + "{'struct_column':'struct_value'}" + ], + "schema": [ + "{'column_name':'struct_column','data_type':'struct'}" + ], + "stepId": "s-123456789", + "applicationId": "application-abc" + } +} diff --git a/spark/src/test/resources/invalid_response.json b/spark/src/test/resources/invalid_response.json new file mode 100644 index 0000000000..53222e0560 --- /dev/null +++ b/spark/src/test/resources/invalid_response.json @@ -0,0 +1,12 @@ +{ + "content": { + "result": [ + "{'1':1}" + ], + "schema": [ + "{'column_name':'1','data_type':'integer'}" + ], + "stepId": "s-123456789", + "applicationId": "application-abc" + } +} diff --git a/spark/src/test/resources/select_query_response.json b/spark/src/test/resources/select_query_response.json new file mode 100644 index 0000000000..24cb06b49e --- /dev/null +++ b/spark/src/test/resources/select_query_response.json @@ -0,0 +1,12 @@ +{ + "data": { + "result": [ + "{'1':1}" + ], + "schema": [ + "{'column_name':'1','data_type':'integer'}" + ], + "stepId": "s-123456789", + "applicationId": "application-abc" + } +} diff --git a/sql/build.gradle b/sql/build.gradle index 0f95b0850f..44dc37cf0f 100644 --- a/sql/build.gradle +++ b/sql/build.gradle @@ -45,7 +45,7 @@ dependencies { antlr "org.antlr:antlr4:4.7.1" implementation "org.antlr:antlr4-runtime:4.7.1" - implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre' + implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' implementation group: 'org.json', name: 'json', version:'20230227' implementation project(':common') implementation project(':core') diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 9b7aef6e27..e68edbbc58 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -328,7 +328,8 @@ nullNotnull ; functionCall - : scalarFunctionName 
LR_BRACKET functionArgs RR_BRACKET #scalarFunctionCall + : nestedFunctionName LR_BRACKET allTupleFields RR_BRACKET #nestedAllFunctionCall + | scalarFunctionName LR_BRACKET functionArgs RR_BRACKET #scalarFunctionCall | specificFunction #specificFunctionCall | windowFunctionClause #windowFunctionCall | aggregateFunction #aggregateFunctionCall @@ -813,6 +814,10 @@ columnName : qualifiedName ; +allTupleFields + : path=qualifiedName DOT STAR + ; + alias : ident ; diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index bad0543e02..7279553106 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -9,7 +9,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.between; import static org.opensearch.sql.ast.dsl.AstDSL.not; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; -import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.LIKE; @@ -41,6 +40,7 @@ import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.LikePredicateContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MathExpressionAtomContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MultiFieldRelevanceFunctionContext; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NestedAllFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NoFieldRelevanceFunctionContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NotExpressionContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NullLiteralContext; @@ -90,6 +90,7 @@ import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; @@ -102,6 +103,7 @@ import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AlternateMultiMatchQueryContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AndExpressionContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ColumnNameContext; @@ -150,6 +152,14 @@ public UnresolvedExpression visitNestedExpressionAtom(NestedExpressionAtomContex return visit(ctx.expression()); // Discard parenthesis around } + @Override + public UnresolvedExpression visitNestedAllFunctionCall( + NestedAllFunctionCallContext ctx) { + return new NestedAllTupleFields( + visitQualifiedName(ctx.allTupleFields().path).toString() + ); + } + @Override public UnresolvedExpression visitScalarFunctionCall(ScalarFunctionCallContext ctx) { return buildFunction(ctx.scalarFunctionName().getText(), ctx.functionArgs().functionArg()); diff --git 
a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index e017bd8cd6..3e56a89754 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -39,6 +39,7 @@ import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.common.antlr.SyntaxCheckException; class AstBuilderTest extends AstBuilderTestBase { @@ -86,6 +87,19 @@ public void can_build_select_all_from_index() { assertThrows(SyntaxCheckException.class, () -> buildAST("SELECT *")); } + @Test + public void can_build_nested_select_all() { + assertEquals( + project( + relation("test"), + alias("nested(field.*)", + new NestedAllTupleFields("field") + ) + ), + buildAST("SELECT nested(field.*) FROM test") + ); + } + @Test public void can_build_select_all_and_fields_from_index() { assertEquals(