From 5a15ae4de8381d4ae167aeba351389a9dca2e997 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 24 Nov 2023 14:00:42 +0800 Subject: [PATCH] Test ppl transportaction Signed-off-by: zane-neo --- ml-algorithms/build.gradle | 3 + plugin/build.gradle | 24 ++++ .../ml/plugin/MachineLearningPlugin.java | 6 +- .../ml/rest/MyRestPPLQueryAction.java | 120 ++++++++++++++++++ 4 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index ceb10ca791..0a586e74ae 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -68,6 +68,9 @@ dependencies { configurations.all { resolutionStrategy.force 'com.google.protobuf:protobuf-java:3.21.9' +// resolutionStrategy.force 'com.fasterxml.jackson.core:jackson-annotations:2.16.0' +// resolutionStrategy.force 'org.opensearch.client:opensearch-rest-client:2.12.0-SNAPSHOT' +// resolutionStrategy.force 'com.fasterxml.jackson.core:jackson-databind:2.16.0' } jacocoTestReport { diff --git a/plugin/build.gradle b/plugin/build.gradle index af976e6f9f..826089b03f 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -48,6 +48,27 @@ dependencies { implementation project(':opensearch-ml-memory') implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" + implementation (group: 'opensearch-sql', name: 'opensearch-sql', version: "${common_utils_version}") { + exclude module: 'legacy' + exclude module: 'opensearch' + exclude module: 'prometheus' + exclude module: 'datasources' + exclude module: 'spark' + } + implementation (group: 'opensearch-sql', name: 'ppl', version: "${common_utils_version}") { + exclude group: 'org.reflections', module: 'reflections' + exclude group: 'com.google.guava', module: 'guava' + exclude group: 'org.json', module: 'json' + exclude module: 'common' + exclude module: 'core' + } + implementation (group: 'opensearch-sql', name: 'protocol', version: "${common_utils_version}") { + exclude group: 'com.google.guava', module: 'guava' + exclude group: 'com.fasterxml.jackson.core', module: 'jackson-core' + exclude group: 'com.fasterxml.jackson.core', module: 'jackson-databind' + exclude group: 'com.fasterxml.jackson.dataformat', module: 'jackson-dataformat-cbor' + exclude group: 'com.google.code.gson', module: 'gson' + } implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" implementation "org.opensearch:common-utils:${common_utils_version}" implementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") @@ -330,6 +351,9 @@ configurations.all { resolutionStrategy.force 'org.apache.httpcomponents:httpclient:4.5.14' resolutionStrategy.force 'commons-codec:commons-codec:1.15' resolutionStrategy.force 'org.slf4j:slf4j-api:1.7.36' +// resolutionStrategy.force 'com.fasterxml.jackson.core:jackson-annotations:2.16.0' +// resolutionStrategy.force 'org.opensearch.client:opensearch-rest-client:2.12.0-SNAPSHOT' +// resolutionStrategy.force 'com.fasterxml.jackson.core:jackson-databind:2.16.0' } apply plugin: 'com.netflix.nebula.ospackage' diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 5c294de64d..187064cf6b 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -144,6 +144,8 @@ import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +//import org.opensearch.ml.rest.MyRestPPLQueryAction; +import org.opensearch.ml.rest.MyRestPPLQueryAction; import org.opensearch.ml.rest.RestMLCreateConnectorAction; import org.opensearch.ml.rest.RestMLDeleteConnectorAction; import org.opensearch.ml.rest.RestMLDeleteModelAction; @@ -554,6 +556,7 @@ public List getRestHandlers( RestMemoryGetInteractionsAction restListInteractionsAction = new RestMemoryGetInteractionsAction(); RestMemoryDeleteConversationAction restDeleteConversationAction = new RestMemoryDeleteConversationAction(); RestMLUpdateConnectorAction restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); + MyRestPPLQueryAction restPPLQueryAction = new MyRestPPLQueryAction(); return ImmutableList .of( restMLStatsAction, @@ -587,7 +590,8 @@ public List getRestHandlers( restCreateInteractionAction, restListInteractionsAction, restDeleteConversationAction, - restMLUpdateConnectorAction + restMLUpdateConnectorAction, + restPPLQueryAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java b/plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java new file mode 100644 index 0000000000..d6f08c0444 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchSecurityException; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.sql.plugin.request.PPLQueryRequestFactory; +import org.opensearch.sql.plugin.transport.PPLQueryAction; +import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest; +import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; + + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.core.rest.RestStatus.OK; + +public class MyRestPPLQueryAction extends BaseRestHandler { + public static final String QUERY_API_ENDPOINT = "_ml/_ppl"; + public static final String EXPLAIN_API_ENDPOINT = "_ml/_ppl/_explain"; + public static final String LEGACY_QUERY_API_ENDPOINT = "_ml/_opendistro/_ppl"; + public static final String LEGACY_EXPLAIN_API_ENDPOINT = "_ml/_opendistro/_ppl/_explain"; + + private static final Logger LOG = LogManager.getLogger(); + + /** Constructor of RestPPLQueryAction. */ + public MyRestPPLQueryAction() { + super(); + } + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.POST, QUERY_API_ENDPOINT), new Route(RestRequest.Method.POST, EXPLAIN_API_ENDPOINT)); + } + +// @Override +// public List replacedRoutes() { +// return Arrays.asList( +// new ReplacedRoute( +// RestRequest.Method.POST, QUERY_API_ENDPOINT, +// RestRequest.Method.POST, LEGACY_QUERY_API_ENDPOINT), +// new ReplacedRoute( +// RestRequest.Method.POST, EXPLAIN_API_ENDPOINT, +// RestRequest.Method.POST, LEGACY_EXPLAIN_API_ENDPOINT)); +// } + + @Override + public String getName() { + return "ml_ppl_query_action"; + } + + @Override + protected Set responseParams() { + Set responseParams = new HashSet<>(super.responseParams()); + responseParams.addAll(Arrays.asList("format", "sanitize")); + return responseParams; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nodeClient) { + TransportPPLQueryRequest transportPPLQueryRequest = + new TransportPPLQueryRequest(PPLQueryRequestFactory.getPPLRequest(request)); + + return channel -> + nodeClient.execute( + PPLQueryAction.INSTANCE, + transportPPLQueryRequest, + new ActionListener<>() { + @Override + public void onResponse(TransportPPLQueryResponse response) { + sendResponse(channel, OK, response.getResult()); + } + + @Override + public void onFailure(Exception e) { + if (e instanceof IllegalAccessException) { + LOG.error("Error happened during query handling", e); + reportError(channel, e, BAD_REQUEST); + } else if (transportPPLQueryRequest.isExplainRequest()) { + LOG.error("Error happened during explain", e); + sendResponse( + channel, + INTERNAL_SERVER_ERROR, + "Failed to explain the query due to error: " + e.getMessage()); + } else if (e instanceof OpenSearchSecurityException) { + OpenSearchSecurityException exception = (OpenSearchSecurityException) e; + reportError(channel, exception, exception.status()); + } else { + LOG.error("Error happened during query handling", e); + reportError(channel, e, INTERNAL_SERVER_ERROR); + } + } + }); + } + + private void sendResponse(RestChannel channel, RestStatus status, String content) { + channel.sendResponse(new BytesRestResponse(status, "application/json; charset=UTF-8", content)); + } + + private void reportError(final RestChannel channel, final Exception e, final RestStatus status) { + channel.sendResponse(new BytesRestResponse(status, e.getMessage())); + } +}