Skip to content

Commit

Permalink
Unify Driver#process method
Browse files Browse the repository at this point in the history
  • Loading branch information
arhimondr committed May 13, 2022
1 parent 2780c21 commit 1c4c2e5
Show file tree
Hide file tree
Showing 16 changed files with 67 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,7 @@ public ListenableFuture<Void> processFor(Duration duration)
driver = this.driver;
}

return driver.processFor(duration);
return driver.processForDuration(duration);
}

@Override
Expand Down
49 changes: 27 additions & 22 deletions core/trino-main/src/main/java/io/trino/operator/Driver.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static java.lang.Boolean.TRUE;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.NANOSECONDS;

//
// NOTE: As a general strategy the methods should "stage" a change and only
Expand All @@ -66,6 +67,8 @@ public class Driver
{
private static final Logger log = Logger.get(Driver.class);

private static final Duration UNLIMITED_DURATION = new Duration(Long.MAX_VALUE, NANOSECONDS);

private final DriverContext driverContext;
private final List<Operator> activeOperators;
// this is present only for debugging
Expand Down Expand Up @@ -268,33 +271,52 @@ private void processNewSources()
currentSplitAssignment = newAssignment;
}

public ListenableFuture<Void> processFor(Duration duration)
public ListenableFuture<Void> processForDuration(Duration duration)
{
return process(duration, Integer.MAX_VALUE);
}

public ListenableFuture<Void> processForNumberOfIterations(int maxIterations)
{
return process(UNLIMITED_DURATION, maxIterations);
}

public ListenableFuture<Void> processUntilBlocked()
{
return process(UNLIMITED_DURATION, Integer.MAX_VALUE);
}

@VisibleForTesting
public ListenableFuture<Void> process(Duration maxRuntime, int maxIterations)
{
checkLockNotHeld("Cannot process for a duration while holding the driver lock");

requireNonNull(duration, "duration is null");
requireNonNull(maxRuntime, "maxRuntime is null");
checkArgument(maxIterations > 0, "maxIterations must be greater than zero");

// if the driver is blocked we don't need to continue
SettableFuture<Void> blockedFuture = driverBlockedFuture.get();
if (!blockedFuture.isDone()) {
return blockedFuture;
}

long maxRuntime = duration.roundTo(TimeUnit.NANOSECONDS);
long maxRuntimeInNanos = maxRuntime.roundTo(TimeUnit.NANOSECONDS);

Optional<ListenableFuture<Void>> result = tryWithLock(100, TimeUnit.MILLISECONDS, true, () -> {
OperationTimer operationTimer = createTimer();
driverContext.startProcessTimer();
driverContext.getYieldSignal().setWithDelay(maxRuntime, driverContext.getYieldExecutor());
driverContext.getYieldSignal().setWithDelay(maxRuntimeInNanos, driverContext.getYieldExecutor());
try {
long start = System.nanoTime();
int iterations = 0;
do {
ListenableFuture<Void> future = processInternal(operationTimer);
iterations++;
if (!future.isDone()) {
return updateDriverBlockedFuture(future);
}
}
while (System.nanoTime() - start < maxRuntime && !isFinishedInternal());
while (System.nanoTime() - start < maxRuntimeInNanos && iterations < maxIterations && !isFinishedInternal());
}
finally {
driverContext.getYieldSignal().reset();
Expand All @@ -305,23 +327,6 @@ public ListenableFuture<Void> processFor(Duration duration)
return result.orElse(NOT_BLOCKED);
}

public ListenableFuture<Void> process()
{
checkLockNotHeld("Cannot process while holding the driver lock");

// if the driver is blocked we don't need to continue
SettableFuture<Void> blockedFuture = driverBlockedFuture.get();
if (!blockedFuture.isDone()) {
return blockedFuture;
}

Optional<ListenableFuture<Void>> result = tryWithLock(100, TimeUnit.MILLISECONDS, true, () -> {
ListenableFuture<Void> future = processInternal(createTimer());
return updateDriverBlockedFuture(future);
});
return result.orElse(NOT_BLOCKED);
}

private OperationTimer createTimer()
{
return new OperationTimer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ public boolean load(List<UpdateRequest> requests)
ScheduledSplit split = new ScheduledSplit(0, sourcePlanNodeId, new Split(INDEX_CONNECTOR_ID, new IndexSplit(recordSetForLookupSource), Lifespan.taskWide()));
driver.updateSplitAssignment(new SplitAssignment(sourcePlanNodeId, ImmutableSet.of(split), true));
while (!driver.isFinished()) {
ListenableFuture<Void> process = driver.process();
ListenableFuture<Void> process = driver.processUntilBlocked();
checkState(process.isDone(), "Driver should never block");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ private boolean loadNextPage()
if (driver.isFinished()) {
return false;
}
driver.process();
driver.processForNumberOfIterations(1);
nextPage = extractNonEmptyPage(pageBuffer);
}
currentPage = nextPage;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ private MaterializedResultWithPlan executeInternal(Session session, @Language("S
}

if (!driver.isFinished()) {
driver.process();
driver.processForNumberOfIterations(1);
processed = true;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public void testTableScanMemoryBlocking()
Split testSplit = new Split(new CatalogName("test"), new TestSplit(), Lifespan.taskWide());
driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(new ScheduledSplit(0, sourceId, testSplit)), true));

ListenableFuture<Void> blocked = driver.processFor(new Duration(1, NANOSECONDS));
ListenableFuture<Void> blocked = driver.processForDuration(new Duration(1, NANOSECONDS));

// the driver shouldn't block in the first call as it will be able to move a page between source and the sink operator
// but the operator should be blocked
Expand All @@ -128,7 +128,7 @@ public void testTableScanMemoryBlocking()
// in the subsequent calls both the driver and the operator should be blocked
// and they should stay blocked until more memory becomes available
for (int i = 0; i < 10; i++) {
blocked = driver.processFor(new Duration(1, NANOSECONDS));
blocked = driver.processForDuration(new Duration(1, NANOSECONDS));
assertFalse(blocked.isDone());
assertFalse(source.getOperatorContext().isWaitingForMemory().isDone());
}
Expand All @@ -140,7 +140,7 @@ public void testTableScanMemoryBlocking()
assertTrue(source.getOperatorContext().isWaitingForMemory().isDone());

// the driver shouldn't be blocked
blocked = driver.processFor(new Duration(1, NANOSECONDS));
blocked = driver.processForDuration(new Duration(1, NANOSECONDS));
assertTrue(blocked.isDone());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ private long runDriversUntilBlocked(Predicate<OperatorContext> reason)
// run driver, until it blocks
while (!isOperatorBlocked(drivers, reason)) {
for (Driver driver : drivers) {
driver.process();
driver.processForNumberOfIterations(1);
}
iterationsCount++;
}
Expand All @@ -454,7 +454,7 @@ private void assertDriversProgress(Predicate<OperatorContext> reason)
assertFalse(isOperatorBlocked(drivers, reason));
boolean progress = false;
for (Driver driver : drivers) {
ListenableFuture<Void> blocked = driver.process();
ListenableFuture<Void> blocked = driver.processUntilBlocked();
progress = progress | blocked.isDone();
}
// query should not block
Expand Down
20 changes: 10 additions & 10 deletions core/trino-main/src/test/java/io/trino/operator/TestDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public void testNormalFinish()
assertSame(driver.getDriverContext(), driverContext);

assertFalse(driver.isFinished());
ListenableFuture<Void> blocked = driver.processFor(new Duration(1, TimeUnit.SECONDS));
ListenableFuture<Void> blocked = driver.processForDuration(new Duration(1, TimeUnit.SECONDS));
assertTrue(blocked.isDone());
assertTrue(driver.isFinished());

Expand All @@ -127,7 +127,7 @@ public void testConcurrentClose()
Operator sink = createSinkOperator(types);
Driver driver = Driver.createDriver(driverContext, source, sink);
// let these threads race
scheduledExecutor.submit(() -> driver.processFor(new Duration(1, TimeUnit.NANOSECONDS))); // don't want to call isFinishedInternal in processFor
scheduledExecutor.submit(() -> driver.processForDuration(new Duration(1, TimeUnit.NANOSECONDS))); // don't want to call isFinishedInternal in processFor
scheduledExecutor.submit(driver::close);
while (!driverContext.isDone()) {
Uninterruptibles.sleepUninterruptibly(1, TimeUnit.MILLISECONDS);
Expand Down Expand Up @@ -179,13 +179,13 @@ public void testAddSourceFinish()
assertSame(driver.getDriverContext(), driverContext);

assertFalse(driver.isFinished());
assertFalse(driver.processFor(new Duration(1, TimeUnit.MILLISECONDS)).isDone());
assertFalse(driver.processForDuration(new Duration(1, TimeUnit.MILLISECONDS)).isDone());
assertFalse(driver.isFinished());

driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(new ScheduledSplit(0, sourceId, newMockSplit())), true));

assertFalse(driver.isFinished());
assertTrue(driver.processFor(new Duration(1, TimeUnit.SECONDS)).isDone());
assertTrue(driver.processForDuration(new Duration(1, TimeUnit.SECONDS)).isDone());
assertTrue(driver.isFinished());

assertTrue(sink.isFinished());
Expand All @@ -202,7 +202,7 @@ public void testBrokenOperatorCloseWhileProcessing()
assertSame(driver.getDriverContext(), driverContext);

// block thread in operator processing
Future<Boolean> driverProcessFor = executor.submit(() -> driver.processFor(new Duration(1, TimeUnit.MILLISECONDS)).isDone());
Future<Boolean> driverProcessFor = executor.submit(() -> driver.processForDuration(new Duration(1, TimeUnit.MILLISECONDS)).isDone());
brokenOperator.waitForLocked();

driver.close();
Expand All @@ -229,7 +229,7 @@ public void testBrokenOperatorProcessWhileClosing()
});
brokenOperator.waitForLocked();

assertTrue(driver.processFor(new Duration(1, TimeUnit.MILLISECONDS)).isDone());
assertTrue(driver.processForDuration(new Duration(1, TimeUnit.MILLISECONDS)).isDone());
assertTrue(driver.isFinished());

brokenOperator.unlock();
Expand All @@ -253,7 +253,7 @@ public void testMemoryRevocationRace()
// the table scan operator will request memory revocation with requestMemoryRevoking()
// while the driver is still not done with the processFor() method and before it moves to
// updateDriverBlockedFuture() method.
assertTrue(driver.processFor(new Duration(100, TimeUnit.MILLISECONDS)).isDone());
assertTrue(driver.processForDuration(new Duration(100, TimeUnit.MILLISECONDS)).isDone());
}

@Test
Expand All @@ -275,21 +275,21 @@ public void testBrokenOperatorAddSource()
Driver driver = Driver.createDriver(driverContext, source, brokenOperator);

// block thread in operator processing
Future<Boolean> driverProcessFor = executor.submit(() -> driver.processFor(new Duration(1, TimeUnit.MILLISECONDS)).isDone());
Future<Boolean> driverProcessFor = executor.submit(() -> driver.processForDuration(new Duration(1, TimeUnit.MILLISECONDS)).isDone());
brokenOperator.waitForLocked();

assertSame(driver.getDriverContext(), driverContext);

assertFalse(driver.isFinished());
// processFor always returns NOT_BLOCKED, because DriveLockResult was not acquired
assertTrue(driver.processFor(new Duration(1, TimeUnit.MILLISECONDS)).isDone());
assertTrue(driver.processForDuration(new Duration(1, TimeUnit.MILLISECONDS)).isDone());
assertFalse(driver.isFinished());

driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(new ScheduledSplit(0, sourceId, newMockSplit())), true));

assertFalse(driver.isFinished());
// processFor always returns NOT_BLOCKED, because DriveLockResult was not acquired
assertTrue(driver.processFor(new Duration(1, TimeUnit.SECONDS)).isDone());
assertTrue(driver.processForDuration(new Duration(1, TimeUnit.SECONDS)).isDone());
assertFalse(driver.isFinished());

driver.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public void testSemiJoin(boolean hashEnabled)

Driver driver = Driver.createDriver(driverContext, buildOperator, setBuilderOperator);
while (!driver.isFinished()) {
driver.process();
driver.processUntilBlocked();
}

// probe
Expand Down Expand Up @@ -186,7 +186,7 @@ public void testSemiJoinOnVarcharType(boolean hashEnabled)

Driver driver = Driver.createDriver(driverContext, buildOperator, setBuilderOperator);
while (!driver.isFinished()) {
driver.process();
driver.processUntilBlocked();
}

// probe
Expand Down Expand Up @@ -281,7 +281,7 @@ public void testBuildSideNulls(boolean hashEnabled)

Driver driver = Driver.createDriver(driverContext, buildOperator, setBuilderOperator);
while (!driver.isFinished()) {
driver.process();
driver.processUntilBlocked();
}

// probe
Expand Down Expand Up @@ -337,7 +337,7 @@ public void testProbeSideNulls(boolean hashEnabled)

Driver driver = Driver.createDriver(driverContext, buildOperator, setBuilderOperator);
while (!driver.isFinished()) {
driver.process();
driver.processUntilBlocked();
}

// probe
Expand Down Expand Up @@ -397,7 +397,7 @@ public void testProbeAndBuildNulls(boolean hashEnabled)

Driver driver = Driver.createDriver(driverContext, buildOperator, setBuilderOperator);
while (!driver.isFinished()) {
driver.process();
driver.processUntilBlocked();
}

// probe
Expand Down Expand Up @@ -455,7 +455,7 @@ public void testMemoryLimit(boolean hashEnabled)

Driver driver = Driver.createDriver(driverContext, buildOperator, setBuilderOperator);
while (!driver.isFinished()) {
driver.process();
driver.processUntilBlocked();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public static BuildSideSetup setupBuildSide(
sinkOperatorFactory.noMoreOperators();

while (!sourceDriver.isFinished()) {
sourceDriver.process();
sourceDriver.processUntilBlocked();
}

// build side operator factories
Expand Down Expand Up @@ -214,7 +214,7 @@ public static void buildLookupSource(ExecutorService executor, BuildSideSetup bu

while (!lookupSourceProvider.isDone()) {
for (Driver buildDriver : buildDrivers) {
buildDriver.process();
buildDriver.processForNumberOfIterations(1);
}
}
getFutureValue(lookupSourceProvider).close();
Expand All @@ -232,7 +232,7 @@ public static void runDriverInThread(ExecutorService executor, Driver driver)
executor.execute(() -> {
if (!driver.isFinished()) {
try {
driver.process();
driver.processUntilBlocked();
}
catch (TrinoException e) {
driver.getDriverContext().failed(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.ExceededMemoryLimitException;
import io.trino.RowPagesBuilder;
import io.trino.execution.Lifespan;
Expand Down Expand Up @@ -78,7 +79,6 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

Expand Down Expand Up @@ -111,6 +111,7 @@
import static java.util.Collections.singletonList;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down Expand Up @@ -528,7 +529,7 @@ private void innerJoinWithSpill(boolean probeHashEnabled, List<WhenSpill> whenSp
while (!lookupSourceProvider.isDone()) {
for (int i = 0; i < buildOperatorCount; i++) {
checkErrors(taskStateMachine);
buildDrivers.get(i).process();
buildDrivers.get(i).processForNumberOfIterations(1);
HashBuilderOperator buildOperator = buildSideSetup.getBuildOperators().get(i);
if (whenSpill.get(i) == WhenSpill.DURING_BUILD && buildOperator.getOperatorContext().getReservedRevocableBytes() > 0) {
checkState(!lookupSourceProvider.isDone(), "Too late, LookupSource already done");
Expand Down Expand Up @@ -598,9 +599,7 @@ private void innerJoinWithSpill(boolean probeHashEnabled, List<WhenSpill> whenSp

private static void processRow(Driver joinDriver, TaskStateMachine taskStateMachine)
{
joinDriver.getDriverContext().getYieldSignal().setWithDelay(TimeUnit.SECONDS.toNanos(1), joinDriver.getDriverContext().getYieldExecutor());
joinDriver.process();
joinDriver.getDriverContext().getYieldSignal().reset();
joinDriver.process(new Duration(1, NANOSECONDS), 1);
checkErrors(taskStateMachine);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ private static NestedLoopJoinOperatorFactory newJoinOperatorFactoryWithCompleted
nestedLoopBuildOperatorFactory.noMoreOperators();

while (nestedLoopBuildOperator.isBlocked().isDone()) {
driver.process();
driver.processUntilBlocked();
}

return joinOperatorFactory;
Expand Down
Loading

0 comments on commit 1c4c2e5

Please sign in to comment.