diff --git a/core/src/main/java/org/opensearch/sql/calcite/plan/Scannable.java b/core/src/main/java/org/opensearch/sql/calcite/plan/Scannable.java new file mode 100644 index 00000000000..2ec341f2eaa --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/calcite/plan/Scannable.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.plan; + +import org.apache.calcite.linq4j.Enumerable; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * The customized table scan is implemented in OpenSearch module, to invoke this scan() method in + * core module, we add this interface. Now the only implementation is CalciteEnumerableIndexScan. + * When a RelNode after optimization is a Scannable, we can directly invoke scan() method to get the + * result of the scan instead of codegen and compile via Linq4j expression. + */ +public interface Scannable { + + public Enumerable<@Nullable Object> scan(); +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java b/core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java index c9365674f8d..14c8d8f369e 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java @@ -27,26 +27,35 @@ package org.opensearch.sql.calcite.utils; +import static java.util.Objects.requireNonNull; + import com.google.common.collect.ImmutableList; +import java.lang.reflect.Type; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; import java.time.Instant; import java.util.Properties; import java.util.function.Consumer; +import org.apache.calcite.adapter.enumerable.EnumerableConvention; +import org.apache.calcite.adapter.enumerable.EnumerableRel; import org.apache.calcite.adapter.java.JavaTypeFactory; import org.apache.calcite.avatica.AvaticaConnection; import org.apache.calcite.avatica.AvaticaFactory; +import org.apache.calcite.avatica.Meta; import org.apache.calcite.avatica.UnregisteredDriver; import org.apache.calcite.config.CalciteConnectionProperty; +import org.apache.calcite.interpreter.BindableConvention; import org.apache.calcite.interpreter.Bindables; import org.apache.calcite.jdbc.CalciteFactory; import org.apache.calcite.jdbc.CalciteJdbc41Factory; import org.apache.calcite.jdbc.CalcitePrepare; import org.apache.calcite.jdbc.CalciteSchema; import org.apache.calcite.jdbc.Driver; +import org.apache.calcite.linq4j.function.Function0; import org.apache.calcite.plan.Context; import org.apache.calcite.plan.Contexts; +import org.apache.calcite.plan.Convention; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptPlanner; import org.apache.calcite.plan.RelOptSchema; @@ -55,18 +64,23 @@ import org.apache.calcite.prepare.CalcitePrepareImpl; import org.apache.calcite.rel.RelHomogeneousShuttle; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.RelShuttle; import org.apache.calcite.rel.core.TableScan; import org.apache.calcite.rel.logical.LogicalTableScan; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.runtime.Bindable; import org.apache.calcite.runtime.Hook; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.server.CalciteServerStatement; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql2rel.SqlRexConvertletTable; import org.apache.calcite.tools.FrameworkConfig; import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.RelBuilder; @@ -74,11 +88,12 @@ import org.apache.calcite.util.Holder; import org.apache.calcite.util.Util; import org.opensearch.sql.calcite.CalcitePlanContext; +import org.opensearch.sql.calcite.plan.Scannable; import org.opensearch.sql.calcite.udf.udaf.NullableSqlAvgAggFunction; /** * Calcite Tools Helper. This class is used to create customized: 1. Connection 2. JavaTypeFactory - * 3. RelBuilder 4. RelRunner TODO delete it in future if possible. + * 3. RelBuilder 4. RelRunner 5. CalcitePreparingStmt. TODO delete it in future if possible. */ public class CalciteToolsHelper { @@ -153,6 +168,11 @@ public Connection connect( this.handler.onConnectionInit(connection); return connection; } + + @Override + protected Function0 createPrepareFactory() { + return OpenSearchPrepareImpl::new; + } } /** do nothing, just extend for a public construct for new */ @@ -214,6 +234,104 @@ public R perform( final RelOptCluster cluster = createCluster(planner, rexBuilder); return action.apply(cluster, catalogReader, prepareContext.getRootSchema().plus(), statement); } + + /** + * Customize CalcitePreparingStmt. Override {@link CalcitePrepareImpl#getPreparingStmt} and + * return {@link OpenSearchCalcitePreparingStmt} + */ + @Override + protected CalcitePrepareImpl.CalcitePreparingStmt getPreparingStmt( + CalcitePrepare.Context context, + Type elementType, + CalciteCatalogReader catalogReader, + RelOptPlanner planner) { + final JavaTypeFactory typeFactory = context.getTypeFactory(); + final EnumerableRel.Prefer prefer; + if (elementType == Object[].class) { + prefer = EnumerableRel.Prefer.ARRAY; + } else { + prefer = EnumerableRel.Prefer.CUSTOM; + } + final Convention resultConvention = + enableBindable ? BindableConvention.INSTANCE : EnumerableConvention.INSTANCE; + return new OpenSearchCalcitePreparingStmt( + this, + context, + catalogReader, + typeFactory, + context.getRootSchema(), + prefer, + createCluster(planner, new RexBuilder(typeFactory)), + resultConvention, + createConvertletTable()); + } + } + + /** + * Similar to {@link CalcitePrepareImpl.CalcitePreparingStmt}. Customize the logic to convert an + * EnumerableTableScan to BindableTableScan. + */ + public static class OpenSearchCalcitePreparingStmt + extends CalcitePrepareImpl.CalcitePreparingStmt { + + public OpenSearchCalcitePreparingStmt( + CalcitePrepareImpl prepare, + CalcitePrepare.Context context, + CatalogReader catalogReader, + RelDataTypeFactory typeFactory, + CalciteSchema schema, + EnumerableRel.Prefer prefer, + RelOptCluster cluster, + Convention resultConvention, + SqlRexConvertletTable convertletTable) { + super( + prepare, + context, + catalogReader, + typeFactory, + schema, + prefer, + cluster, + resultConvention, + convertletTable); + } + + @Override + protected PreparedResult implement(RelRoot root) { + Hook.PLAN_BEFORE_IMPLEMENTATION.run(root); + RelDataType resultType = root.rel.getRowType(); + boolean isDml = root.kind.belongsTo(SqlKind.DML); + if (root.rel instanceof Scannable scannable) { + final Bindable bindable = dataContext -> scannable.scan(); + + return new PreparedResultImpl( + resultType, + requireNonNull(parameterRowType, "parameterRowType"), + requireNonNull(fieldOrigins, "fieldOrigins"), + root.collation.getFieldCollations().isEmpty() + ? ImmutableList.of() + : ImmutableList.of(root.collation), + root.rel, + mapTableModOp(isDml, root.kind), + isDml) { + @Override + public String getCode() { + throw new UnsupportedOperationException(); + } + + @Override + public Bindable getBindable(Meta.CursorFactory cursorFactory) { + return bindable; + } + + @Override + public Type getElementType() { + return resultType.getFieldList().size() == 1 ? Object.class : Object[].class; + } + }; + } + return super.implement(root); + } } public static class OpenSearchRelRunners { diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLExplainIT.java index ffccf29e575..78eafc33950 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLExplainIT.java @@ -44,15 +44,34 @@ public void testExplainCommand() throws IOException { } @Test - public void testExplainCommandExtended() throws IOException { + public void testExplainCommandExtendedWithCodegen() throws IOException { var result = - executeWithReplace("explain extended source=test | where age = 20 | fields name, age"); + executeWithReplace( + "explain extended source=test | where age = 20 | join left=l right=r on l.age=r.age" + + " test"); assertTrue( result.contains( "public org.apache.calcite.linq4j.Enumerable bind(final" + " org.apache.calcite.DataContext root)")); } + @Test + public void testExplainCommandExtendedWithoutCodegen() throws IOException { + var result = + executeWithReplace("explain extended source=test | where age = 20 | fields name, age"); + if (isPushdownEnabled()) { + assertFalse( + result.contains( + "public org.apache.calcite.linq4j.Enumerable bind(final" + + " org.apache.calcite.DataContext root)")); + } else { + assertTrue( + result.contains( + "public org.apache.calcite.linq4j.Enumerable bind(final" + + " org.apache.calcite.DataContext root)")); + } + } + @Test public void testExplainCommandCost() throws IOException { var result = executeWithReplace("explain cost source=test | where age = 20 | fields name, age"); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/EnumerableIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/EnumerableIndexScanRule.java index 8cf021c9866..8dc7f3e187b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/EnumerableIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/EnumerableIndexScanRule.java @@ -27,7 +27,7 @@ public class EnumerableIndexScanRule extends ConverterRule { "EnumerableIndexScanRule") .withRuleFactory(EnumerableIndexScanRule::new); - /** Creates an EnumerableProjectRule. */ + /** Creates an EnumerableIndexScanRule. */ protected EnumerableIndexScanRule(Config config) { super(config); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java index e395804a7dd..6faaaa9db45 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteEnumerableIndexScan.java @@ -28,11 +28,13 @@ import org.apache.logging.log4j.Logger; import org.checkerframework.checker.nullness.qual.Nullable; import org.opensearch.sql.calcite.plan.OpenSearchRules; +import org.opensearch.sql.calcite.plan.Scannable; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.storage.OpenSearchIndex; /** The physical relational operator representing a scan of an OpenSearchIndex type. */ -public class CalciteEnumerableIndexScan extends AbstractCalciteIndexScan implements EnumerableRel { +public class CalciteEnumerableIndexScan extends AbstractCalciteIndexScan + implements Scannable, EnumerableRel { private static final Logger LOG = LogManager.getLogger(CalciteEnumerableIndexScan.class); /** @@ -85,6 +87,7 @@ public Result implement(EnumerableRelImplementor implementor, Prefer pref) { * each time to avoid reusing source builder. That's because the source builder has stats like PIT * or SearchAfter recorded during previous search. */ + @Override public Enumerable<@Nullable Object> scan() { return new AbstractEnumerable<>() { @Override