diff --git a/datasources/build.gradle b/datasources/build.gradle index ef52db2305..6da0fdae34 100644 --- a/datasources/build.gradle +++ b/datasources/build.gradle @@ -29,6 +29,7 @@ dependencies { testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0' testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.2.0' testImplementation 'org.junit.jupiter:junit-jupiter:5.6.2' + testImplementation group: 'org.opensearch.test', name: 'framework', version: "${opensearch_version}" } test { @@ -37,6 +38,7 @@ test { events "passed", "skipped", "failed" exceptionFormat "full" } + systemProperty 'tests.security.manager', 'false' } jacocoTestReport { diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/exceptions/ErrorMessage.java b/datasources/src/main/java/org/opensearch/sql/datasources/exceptions/ErrorMessage.java index 6dbd9bcfb5..c96640693c 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/exceptions/ErrorMessage.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/exceptions/ErrorMessage.java @@ -7,6 +7,7 @@ package org.opensearch.sql.datasources.exceptions; import com.google.gson.Gson; +import com.google.gson.GsonBuilder; import com.google.gson.JsonObject; import lombok.Getter; import org.opensearch.rest.RestStatus; @@ -65,7 +66,8 @@ public String toString() { JsonObject jsonObject = new JsonObject(); jsonObject.addProperty("status", status); jsonObject.add("error", getErrorAsJson()); - return new Gson().toJson(jsonObject); + Gson gson = new GsonBuilder().setPrettyPrinting().create(); + return gson.toJson(jsonObject); } private JsonObject getErrorAsJson() { diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java index c75170c355..95efd2e8f5 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java @@ -21,6 +21,7 @@ import java.util.Locale; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchException; import org.opensearch.action.ActionListener; import org.opensearch.client.node.NodeClient; import org.opensearch.rest.BaseRestHandler; @@ -224,6 +225,9 @@ public void onFailure(Exception e) { private void handleException(Exception e, RestChannel restChannel) { if (e instanceof DataSourceNotFoundException) { reportError(restChannel, e, NOT_FOUND); + } else if (e instanceof OpenSearchException) { + OpenSearchException exception = (OpenSearchException) e; + reportError(restChannel, exception, exception.status()); } else { LOG.error("Error happened during request handling", e); if (isClientError(e)) { diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryActionTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryActionTest.java new file mode 100644 index 0000000000..3333bf6792 --- /dev/null +++ b/datasources/src/test/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryActionTest.java @@ -0,0 +1,102 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.datasources.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; +import static org.opensearch.sql.datasources.utils.Scheduler.SQL_WORKER_THREAD_POOL_NAME; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.OpenSearchSecurityException; +import org.opensearch.action.ActionListener; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.rest.RestStatus; +import org.opensearch.sql.datasources.model.transport.CreateDataSourceActionResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.FixedExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +@ExtendWith(MockitoExtension.class) +public class RestDataSourceQueryActionTest extends OpenSearchTestCase { + @Mock + private RestChannel channel; + private ThreadPool threadPool; + NodeClient client; + + private RestDataSourceQueryAction restDataSourceQueryAction; + + /** + * SetUp threadPool and NodeClient for all unit tests. + */ + @BeforeEach + public void setup() { + restDataSourceQueryAction = new RestDataSourceQueryAction(); + threadPool = new TestThreadPool(this.getClass().getName(), + new FixedExecutorBuilder( + Settings.EMPTY, + SQL_WORKER_THREAD_POOL_NAME, + 1, + 1000, + null)); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + + @Test + public void testSecurityExceptionWithCreateDataSource() throws Exception { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure( + new OpenSearchSecurityException("No Permissions for datasource Creation", + RestStatus.FORBIDDEN)); + return null; + }).when(client).execute(any(), any(), any()); + + RestRequest restRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withMethod( + RestRequest.Method.GET).build(); + restDataSourceQueryAction.handleRequest(restRequest, channel, client); + ArgumentCaptor responseArgumentCaptor + = ArgumentCaptor.forClass(RestResponse.class); + + verify(channel, timeout(1000).times(1)) + .sendResponse(responseArgumentCaptor.capture()); + RestResponse response = responseArgumentCaptor.getValue(); + Assertions.assertEquals(RestStatus.FORBIDDEN, response.status()); + Assertions.assertEquals("{\n" + + " \"status\": 403,\n" + + " \"error\": {\n" + + " \"type\": \"OpenSearchSecurityException\",\n" + + " \"reason\": \"There was internal problem at backend\",\n" + + " \"details\": \"No Permissions for datasource Creation\"\n" + + " }\n" + + "}", response.content().utf8ToString()); + } + + +}