diff --git a/presto-main/src/main/java/io/prestosql/dispatcher/QueuedStatementResource.java b/presto-main/src/main/java/io/prestosql/dispatcher/QueuedStatementResource.java index 7f79f9db100b..e5db058aa5cb 100644 --- a/presto-main/src/main/java/io/prestosql/dispatcher/QueuedStatementResource.java +++ b/presto-main/src/main/java/io/prestosql/dispatcher/QueuedStatementResource.java @@ -14,16 +14,17 @@ package io.prestosql.dispatcher; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import com.google.common.collect.Ordering; -import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.FluentFuture; import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; import io.airlift.log.Logger; import io.airlift.units.Duration; import io.prestosql.client.QueryError; import io.prestosql.client.QueryResults; import io.prestosql.client.StatementStats; import io.prestosql.execution.ExecutionFailureInfo; +import io.prestosql.execution.QueryManagerConfig; import io.prestosql.execution.QueryState; import io.prestosql.server.HttpRequestSessionContext; import io.prestosql.server.ServerConfig; @@ -35,8 +36,10 @@ import io.prestosql.spi.security.GroupProvider; import io.prestosql.spi.security.Identity; +import javax.annotation.Nullable; +import javax.annotation.PostConstruct; import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; import javax.inject.Inject; import javax.servlet.http.HttpServletRequest; import javax.ws.rs.DELETE; @@ -58,19 +61,21 @@ import javax.ws.rs.core.UriInfo; import java.net.URI; -import java.util.Map.Entry; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import static com.clearspring.analytics.util.Preconditions.checkState; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.concurrent.MoreFutures.addTimeout; -import static io.airlift.concurrent.Threads.threadsNamed; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse; import static io.prestosql.execution.QueryState.FAILED; import static io.prestosql.execution.QueryState.QUEUED; @@ -103,16 +108,16 @@ public class QueuedStatementResource private final Executor responseExecutor; private final ScheduledExecutorService timeoutExecutor; - private final ConcurrentMap queries = new ConcurrentHashMap<>(); - private final ScheduledExecutorService queryPurger = newSingleThreadScheduledExecutor(threadsNamed("dispatch-query-purger")); private final boolean compressionEnabled; + private final QueryManager queryManager; @Inject public QueuedStatementResource( GroupProvider groupProvider, DispatchManager dispatchManager, DispatchExecutor executor, - ServerConfig serverConfig) + ServerConfig serverConfig, + QueryManagerConfig queryManagerConfig) { this.groupProvider = requireNonNull(groupProvider, "groupProvider is null"); this.dispatchManager = requireNonNull(dispatchManager, "dispatchManager is null"); @@ -122,43 +127,20 @@ public QueuedStatementResource( this.timeoutExecutor = requireNonNull(executor, "timeoutExecutor is null").getScheduledExecutor(); this.compressionEnabled = requireNonNull(serverConfig, "serverConfig is null").isQueryResultsCompressionEnabled(); - queryPurger.scheduleWithFixedDelay( - () -> { - try { - // snapshot the queries before checking states to avoid registration race - for (Entry entry : ImmutableSet.copyOf(queries.entrySet())) { - if (!entry.getValue().isSubmissionFinished()) { - continue; - } - - // forget about this query if the query manager is no longer tracking it - if (!dispatchManager.isQueryRegistered(entry.getKey())) { - Query query = queries.remove(entry.getKey()); - if (query != null) { - try { - query.destroy(); - } - catch (Throwable e) { - // this catch clause is broad so query purger does not get stuck - log.warn(e, "Error destroying identity"); - } - } - } - } - } - catch (Throwable e) { - log.warn(e, "Error removing old queries"); - } - }, - 200, - 200, - MILLISECONDS); + requireNonNull(queryManagerConfig, "queryManagerConfig is null"); + queryManager = new QueryManager(queryManagerConfig.getClientTimeout()); + } + + @PostConstruct + public void start() + { + queryManager.initialize(dispatchManager); } @PreDestroy public void stop() { - queryPurger.shutdownNow(); + queryManager.destroy(); } @ResourceSecurity(AUTHENTICATED_USER) @@ -174,18 +156,25 @@ public Response postStatement( throw badRequest(BAD_REQUEST, "SQL statement is empty"); } + Query query = registerQuery(statement, servletRequest, httpHeaders); + + return createQueryResultsResponse(query.getQueryResults(query.getLastToken(), uriInfo)); + } + + private Query registerQuery(String statement, HttpServletRequest servletRequest, HttpHeaders httpHeaders) + { String remoteAddress = servletRequest.getRemoteAddr(); Optional identity = Optional.ofNullable((Identity) servletRequest.getAttribute(AUTHENTICATED_IDENTITY)); MultivaluedMap headers = httpHeaders.getRequestHeaders(); SessionContext sessionContext = new HttpRequestSessionContext(headers, remoteAddress, identity, groupProvider); Query query = new Query(statement, sessionContext, dispatchManager); - queries.put(query.getQueryId(), query); + queryManager.registerQuery(query); // let authentication filter know that identity lifecycle has been handed off servletRequest.setAttribute(AUTHENTICATED_IDENTITY, null); - return createQueryResultsResponse(query.getQueryResults(query.getLastToken(), uriInfo), compressionEnabled); + return query; } @ResourceSecurity(PUBLIC) @@ -202,25 +191,21 @@ public void getStatus( { Query query = getQuery(queryId, slug, token); - // wait for query to be dispatched, up to the wait timeout - ListenableFuture futureStateChange = addTimeout( - query.waitForDispatched(), - () -> null, - WAIT_ORDERING.min(MAX_WAIT_TIME, maxWait), - timeoutExecutor); - - // when state changes, fetch the next result - ListenableFuture queryResultsFuture = Futures.transform( - futureStateChange, - ignored -> query.getQueryResults(token, uriInfo), - responseExecutor); - - // transform to Response - ListenableFuture response = Futures.transform( - queryResultsFuture, - queryResults -> createQueryResultsResponse(queryResults, compressionEnabled), - directExecutor()); - bindAsyncResponse(asyncResponse, response, responseExecutor); + ListenableFuture future = getStatus(query, token, maxWait, uriInfo); + bindAsyncResponse(asyncResponse, future, responseExecutor); + } + + private ListenableFuture getStatus(Query query, long token, Duration maxWait, UriInfo uriInfo) + { + long waitMillis = WAIT_ORDERING.min(MAX_WAIT_TIME, maxWait).toMillis(); + + return FluentFuture.from(query.waitForDispatched()) + // wait for query to be dispatched, up to the wait timeout + .withTimeout(waitMillis, MILLISECONDS, timeoutExecutor) + .catching(TimeoutException.class, ignored -> null, directExecutor()) + // when state changes, fetch the next result + .transform(ignored -> query.getQueryResults(token, uriInfo), responseExecutor) + .transform(this::createQueryResultsResponse, directExecutor()); } @ResourceSecurity(PUBLIC) @@ -239,14 +224,14 @@ public Response cancelQuery( private Query getQuery(QueryId queryId, String slug, long token) { - Query query = queries.get(queryId); + Query query = queryManager.getQuery(queryId); if (query == null || !query.getSlug().isValid(QUEUED_QUERY, slug, token)) { throw badRequest(NOT_FOUND, "Query not found"); } return query; } - private static Response createQueryResultsResponse(QueryResults results, boolean compressionEnabled) + private Response createQueryResultsResponse(QueryResults results) { Response.ResponseBuilder builder = Response.ok(results); if (!compressionEnabled) { @@ -320,8 +305,9 @@ private static final class Query private final Slug slug = Slug.createNew(); private final AtomicLong lastToken = new AtomicLong(); - @GuardedBy("this") - private ListenableFuture querySubmissionFuture; + private final long initTime = System.nanoTime(); + private final AtomicReference submissionGate = new AtomicReference<>(); + private final SettableFuture creationFuture = SettableFuture.create(); public Query(String query, SessionContext sessionContext, DispatchManager dispatchManager) { @@ -346,27 +332,38 @@ public long getLastToken() return lastToken.get(); } - public synchronized boolean isSubmissionFinished() + public boolean tryAbandonSubmissionWithTimeout(Duration querySubmissionTimeout) + { + return Duration.nanosSince(initTime).compareTo(querySubmissionTimeout) >= 0 && submissionGate.compareAndSet(null, false); + } + + public boolean isSubmissionAbandoned() + { + return Boolean.FALSE.equals(submissionGate.get()); + } + + public boolean isCreated() { - return querySubmissionFuture != null && querySubmissionFuture.isDone(); + return creationFuture.isDone(); } private ListenableFuture waitForDispatched() { - // if query query submission has not finished, wait for it to finish - synchronized (this) { - if (querySubmissionFuture == null) { - querySubmissionFuture = dispatchManager.createQuery(queryId, slug, sessionContext, query); - } - if (!querySubmissionFuture.isDone()) { - return querySubmissionFuture; - } + submitIfNeeded(); + if (!creationFuture.isDone()) { + return nonCancellationPropagating(creationFuture); } - // otherwise, wait for the query to finish return dispatchManager.waitForDispatched(queryId); } + private void submitIfNeeded() + { + if (submissionGate.compareAndSet(null, true)) { + creationFuture.setFuture(dispatchManager.createQuery(queryId, slug, sessionContext, query)); + } + } + public QueryResults getQueryResults(long token, UriInfo uriInfo) { long lastToken = this.lastToken.get(); @@ -377,14 +374,12 @@ public QueryResults getQueryResults(long token, UriInfo uriInfo) // advance (or stay at) the token this.lastToken.compareAndSet(lastToken, token); - synchronized (this) { - // if query submission has not finished, return simple empty result - if (querySubmissionFuture == null || !querySubmissionFuture.isDone()) { - return createQueryResults( - token + 1, - uriInfo, - DispatchInfo.queued(NO_DURATION, NO_DURATION)); - } + // if query submission has not finished, return simple empty result + if (!creationFuture.isDone()) { + return createQueryResults( + token + 1, + uriInfo, + DispatchInfo.queued(NO_DURATION, NO_DURATION)); } Optional dispatchInfo = dispatchManager.getDispatchInfo(queryId); @@ -398,9 +393,9 @@ public QueryResults getQueryResults(long token, UriInfo uriInfo) return createQueryResults(token + 1, uriInfo, dispatchInfo.get()); } - public synchronized void cancel() + public void cancel() { - querySubmissionFuture.addListener(() -> dispatchManager.cancelQuery(queryId), directExecutor()); + creationFuture.addListener(() -> dispatchManager.cancelQuery(queryId), directExecutor()); } public void destroy() @@ -468,4 +463,82 @@ private QueryError toQueryError(ExecutionFailureInfo executionFailureInfo) executionFailureInfo.toFailureInfo()); } } + + @ThreadSafe + private static class QueryManager + { + private final ConcurrentMap queries = new ConcurrentHashMap<>(); + private final ScheduledExecutorService scheduledExecutorService = newSingleThreadScheduledExecutor(daemonThreadsNamed("drain-state-query-manager")); + + private final Duration querySubmissionTimeout; + + public QueryManager(Duration querySubmissionTimeout) + { + this.querySubmissionTimeout = requireNonNull(querySubmissionTimeout, "querySubmissionTimeout is null"); + } + + public void initialize(DispatchManager dispatchManager) + { + scheduledExecutorService.scheduleWithFixedDelay(() -> syncWith(dispatchManager), 200, 200, MILLISECONDS); + } + + public void destroy() + { + scheduledExecutorService.shutdownNow(); + } + + private void syncWith(DispatchManager dispatchManager) + { + queries.forEach((queryId, query) -> { + if (shouldBePurged(dispatchManager, query)) { + removeQuery(queryId); + } + }); + } + + private boolean shouldBePurged(DispatchManager dispatchManager, Query query) + { + if (query.isSubmissionAbandoned()) { + // Query submission was explicitly abandoned + return true; + } + if (query.tryAbandonSubmissionWithTimeout(querySubmissionTimeout)) { + // Query took too long to be submitted by the client + return true; + } + if (query.isCreated() && !dispatchManager.isQueryRegistered(query.getQueryId())) { + // Query was created in the DispatchManager, and DispatchManager has already purged the query + return true; + } + return false; + } + + private void removeQuery(QueryId queryId) + { + Optional.ofNullable(queries.remove(queryId)) + .ifPresent(QueryManager::destroyQuietly); + } + + private static void destroyQuietly(Query query) + { + try { + query.destroy(); + } + catch (Throwable t) { + log.error(t, "Error destroying query"); + } + } + + public void registerQuery(Query query) + { + Query existingQuery = queries.putIfAbsent(query.getQueryId(), query); + checkState(existingQuery == null, "Query already registered"); + } + + @Nullable + public Query getQuery(QueryId queryId) + { + return queries.get(queryId); + } + } }