Skip to content

Commit

Permalink
Fix wrong 503 error response code (#2493)
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <reddyvam@amazon.com>
  • Loading branch information
vamsimanohar authored Feb 5, 2024
1 parent 94bd664 commit 70d94e6
Show file tree
Hide file tree
Showing 18 changed files with 196 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
package org.opensearch.sql.datasources.rest;

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.rest.RestStatus.NOT_FOUND;
import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;
import static org.opensearch.rest.RestRequest.Method.*;

import com.google.common.collect.ImmutableList;
Expand All @@ -20,6 +20,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -282,6 +283,10 @@ private void handleException(Exception e, RestChannel restChannel) {
if (e instanceof DataSourceNotFoundException) {
MetricUtils.incrementNumericalMetric(MetricName.DATASOURCE_FAILED_REQ_COUNT_CUS);
reportError(restChannel, e, NOT_FOUND);
} else if (e instanceof OpenSearchSecurityException) {
MetricUtils.incrementNumericalMetric(MetricName.DATASOURCE_FAILED_REQ_COUNT_CUS);
OpenSearchSecurityException exception = (OpenSearchSecurityException) e;
reportError(restChannel, exception, exception.status());
} else if (e instanceof OpenSearchException) {
MetricUtils.incrementNumericalMetric(MetricName.DATASOURCE_FAILED_REQ_COUNT_SYS);
OpenSearchException exception = (OpenSearchException) e;
Expand All @@ -293,7 +298,7 @@ private void handleException(Exception e, RestChannel restChannel) {
reportError(restChannel, e, BAD_REQUEST);
} else {
MetricUtils.incrementNumericalMetric(MetricName.DATASOURCE_FAILED_REQ_COUNT_SYS);
reportError(restChannel, e, SERVICE_UNAVAILABLE);
reportError(restChannel, e, INTERNAL_SERVER_ERROR);
}
}
}
Expand Down
6 changes: 2 additions & 4 deletions docs/user/interfaces/asyncqueryinterface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ Sample Setting Value ::
"region":"eu-west-1",
"sparkSubmitParameter": "--conf spark.dynamicAllocation.enabled=false"
}'
If this setting is not configured during bootstrap, Async Query APIs will be disabled and it requires a cluster restart to enable them back again.
We make use of default aws credentials chain to make calls to the emr serverless application and also make sure the default credentials
have pass role permissions for emr-job-execution-role mentioned in the engine configuration.

The user must be careful before transitioning to a new application or region, as changing these parameters might lead to failures in the retrieval of results from previous async query jobs.
The system relies on the default AWS credentials chain for making calls to the EMR serverless application. It is essential to confirm that the default credentials possess the necessary permissions to pass the role required for EMR job execution, as specified in the engine configuration.
* ``applicationId``, ``executionRoleARN`` and ``region`` are required parameters.
* ``sparkSubmitParameter`` is an optional parameter. It can take the form ``--conf A=1 --conf B=2 ...``.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1755,7 +1755,7 @@ public void multipleIndicesOneNotExistWithoutHint() throws IOException {
Assert.fail("Expected exception, but call succeeded");
} catch (ResponseException e) {
Assert.assertEquals(
RestStatus.BAD_REQUEST.getStatus(), e.getResponse().getStatusLine().getStatusCode());
RestStatus.NOT_FOUND.getStatus(), e.getResponse().getStatusLine().getStatusCode());
final String entity = TestUtils.getResponseBody(e.getResponse());
Assert.assertThat(entity, containsString("\"type\": \"IndexNotFoundException\""));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public void testQueryEndpointShouldFail() throws IOException {
@Test
public void testQueryEndpointShouldFailWithNonExistIndex() throws IOException {
exceptionRule.expect(ResponseException.class);
exceptionRule.expect(hasProperty("response", statusCode(400)));
exceptionRule.expect(hasProperty("response", statusCode(404)));

client().performRequest(makePPLRequest("search source=non_exist_index"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public void queryExceedResourceLimitShouldFail() throws IOException {
String query = String.format("search source=%s age=20", TEST_INDEX_DOG);

ResponseException exception = expectThrows(ResponseException.class, () -> executeQuery(query));
assertEquals(503, exception.getResponse().getStatusLine().getStatusCode());
assertEquals(500, exception.getResponse().getStatusLine().getStatusCode());
assertThat(
exception.getMessage(),
Matchers.containsString("resource is not enough to run the" + " query, quit."));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void testCrossClusterSearchWithoutLocalFieldMappingShouldFail() throws IO
() -> executeQuery(String.format("search source=%s", TEST_INDEX_ACCOUNT_REMOTE)));
assertTrue(
exception.getMessage().contains("IndexNotFoundException")
&& exception.getMessage().contains("400 Bad Request"));
&& exception.getMessage().contains("404 Not Found"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ public void nested_function_all_subfields_in_wrong_clause() {
+ " \"details\": \"Invalid use of expression nested(message.*)\",\n"
+ " \"type\": \"UnsupportedOperationException\"\n"
+ " },\n"
+ " \"status\": 503\n"
+ " \"status\": 500\n"
+ "}"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
package org.opensearch.sql.legacy.plugin;

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.rest.RestStatus.OK;
import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;

import com.alibaba.druid.sql.parser.ParserException;
import com.google.common.collect.ImmutableList;
Expand All @@ -23,6 +23,7 @@
import java.util.regex.Pattern;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.inject.Injector;
Expand Down Expand Up @@ -171,21 +172,23 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
QueryAction queryAction = explainRequest(client, sqlRequest, format);
executeSqlRequest(request, queryAction, client, restChannel);
} catch (Exception e) {
logAndPublishMetrics(e);
reportError(restChannel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE);
handleException(restChannel, e);
}
},
(restChannel, exception) -> {
logAndPublishMetrics(exception);
reportError(
restChannel,
exception,
isClientError(exception) ? BAD_REQUEST : SERVICE_UNAVAILABLE);
});
this::handleException);
} catch (Exception e) {
logAndPublishMetrics(e);
return channel ->
reportError(channel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE);
return channel -> handleException(channel, e);
}
}

private void handleException(RestChannel restChannel, Exception exception) {
logAndPublishMetrics(exception);
if (exception instanceof OpenSearchException) {
OpenSearchException openSearchException = (OpenSearchException) exception;
reportError(restChannel, openSearchException, openSearchException.status());
} else {
reportError(
restChannel, exception, isClientError(exception) ? BAD_REQUEST : INTERNAL_SERVER_ERROR);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.sql.legacy.plugin;

import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
Expand Down Expand Up @@ -84,8 +84,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
return channel ->
channel.sendResponse(
new BytesRestResponse(
SERVICE_UNAVAILABLE,
ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus())
INTERNAL_SERVER_ERROR,
ErrorMessageFactory.createErrorMessage(e, INTERNAL_SERVER_ERROR.getStatus())
.toString()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,15 @@ public void collect(
indexToType.put(tableName, null);
} else if (sqlExprTableSource.getExpr() instanceof SQLBinaryOpExpr) {
SQLBinaryOpExpr sqlBinaryOpExpr = (SQLBinaryOpExpr) sqlExprTableSource.getExpr();
tableName = ((SQLIdentifierExpr) sqlBinaryOpExpr.getLeft()).getName();
SQLExpr leftSideOfExpression = sqlBinaryOpExpr.getLeft();
if (leftSideOfExpression instanceof SQLIdentifierExpr) {
tableName = ((SQLIdentifierExpr) sqlBinaryOpExpr.getLeft()).getName();
} else {
throw new ParserException(
"Left side of the expression ["
+ leftSideOfExpression.toString()
+ "] is expected to be an identifier");
}
SQLExpr rightSideOfExpression = sqlBinaryOpExpr.getRight();

// This assumes that right side of the expression is different name in query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.expr.SQLQueryExpr;
import com.alibaba.druid.sql.parser.ParserException;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -100,6 +101,15 @@ public void testSelectTheFieldWithConflictMappingShouldThrowException() {
rewriteTerm(sql);
}

@Test
public void testIssue2391_WithWrongBinaryOperation() {
String sql = "SELECT * from I_THINK/IM/A_URL";
exception.expect(ParserException.class);
exception.expectMessage(
"Left side of the expression [I_THINK / IM] is expected to be an identifier");
rewriteTerm(sql);
}

private String rewriteTerm(String sql) {
SQLQueryExpr sqlQueryExpr = SqlParserUtils.parse(sql);
sqlQueryExpr.accept(new TermFieldRewriter());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.rest.RestStatus.OK;
import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
Expand All @@ -17,7 +16,7 @@
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.OpenSearchException;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -116,8 +115,11 @@ public void onFailure(Exception e) {
channel,
INTERNAL_SERVER_ERROR,
"Failed to explain the query due to error: " + e.getMessage());
} else if (e instanceof OpenSearchSecurityException) {
OpenSearchSecurityException exception = (OpenSearchSecurityException) e;
} else if (e instanceof OpenSearchException) {
Metrics.getInstance()
.getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_CUS)
.increment();
OpenSearchException exception = (OpenSearchException) e;
reportError(channel, exception, exception.status());
} else {
LOG.error("Error happened during query handling", e);
Expand All @@ -130,7 +132,7 @@ public void onFailure(Exception e) {
Metrics.getInstance()
.getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS)
.increment();
reportError(channel, e, SERVICE_UNAVAILABLE);
reportError(channel, e, INTERNAL_SERVER_ERROR);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.sql.plugin.rest;

import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
Expand Down Expand Up @@ -79,8 +79,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
return channel ->
channel.sendResponse(
new BytesRestResponse(
SERVICE_UNAVAILABLE,
ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus())
INTERNAL_SERVER_ERROR,
ErrorMessageFactory.createErrorMessage(e, INTERNAL_SERVER_ERROR.getStatus())
.toString()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.execution.statestore.StateStore;
Expand All @@ -22,6 +25,9 @@ public class OpensearchAsyncQueryJobMetadataStorageService

private final StateStore stateStore;

private static final Logger LOGGER =
LogManager.getLogger(OpensearchAsyncQueryJobMetadataStorageService.class);

@Override
public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) {
AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId();
Expand All @@ -30,8 +36,13 @@ public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) {

@Override
public Optional<AsyncQueryJobMetadata> getJobMetadata(String qid) {
AsyncQueryId queryId = new AsyncQueryId(qid);
return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName())
.apply(queryId.docId());
try {
AsyncQueryId queryId = new AsyncQueryId(qid);
return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName())
.apply(queryId.docId());
} catch (Exception e) {
LOGGER.error("Error while fetching the job metadata.", e);
throw new AsyncQueryNotFoundException(String.format("Invalid QueryId: %s", qid));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package org.opensearch.sql.spark.rest;

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.rest.RestStatus.TOO_MANY_REQUESTS;
import static org.opensearch.rest.RestRequest.Method.DELETE;
import static org.opensearch.rest.RestRequest.Method.GET;
Expand All @@ -26,10 +26,12 @@
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException;
import org.opensearch.sql.datasources.exceptions.ErrorMessage;
import org.opensearch.sql.datasources.utils.Scheduler;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.utils.MetricUtils;
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.leasemanager.ConcurrencyLimitExceededException;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction;
Expand Down Expand Up @@ -112,12 +114,12 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
}
}

private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClient nodeClient)
throws IOException {
MetricUtils.incrementNumericalMetric(MetricName.ASYNC_QUERY_CREATE_API_REQUEST_COUNT);
CreateAsyncQueryRequest submitJobRequest =
CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser());
return restChannel ->
private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClient nodeClient) {
return restChannel -> {
try {
MetricUtils.incrementNumericalMetric(MetricName.ASYNC_QUERY_CREATE_API_REQUEST_COUNT);
CreateAsyncQueryRequest submitJobRequest =
CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser());
Scheduler.schedule(
nodeClient,
() ->
Expand All @@ -140,6 +142,10 @@ public void onFailure(Exception e) {
handleException(e, restChannel, restRequest.method());
}
}));
} catch (Exception e) {
handleException(e, restChannel, restRequest.method());
}
};
}

private RestChannelConsumer executeGetAsyncQueryResultRequest(
Expand Down Expand Up @@ -187,7 +193,7 @@ private void handleException(
reportError(restChannel, e, BAD_REQUEST);
addCustomerErrorMetric(requestMethod);
} else {
reportError(restChannel, e, SERVICE_UNAVAILABLE);
reportError(restChannel, e, INTERNAL_SERVER_ERROR);
addSystemErrorMetric(requestMethod);
}
}
Expand Down Expand Up @@ -227,7 +233,10 @@ private void reportError(final RestChannel channel, final Exception e, final Res
}

private static boolean isClientError(Exception e) {
return e instanceof IllegalArgumentException || e instanceof IllegalStateException;
return e instanceof IllegalArgumentException
|| e instanceof IllegalStateException
|| e instanceof DataSourceNotFoundException
|| e instanceof AsyncQueryNotFoundException;
}

private void addSystemErrorMetric(RestRequest.Method requestMethod) {
Expand Down
Loading

0 comments on commit 70d94e6

Please sign in to comment.