Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix asynchronous Python exception propagation in StreamingPythonExecutor/CNNScoreVariants. #7402

Merged
merged 3 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class ExampleStreamingPythonExecutor extends ReadWalker {
final StreamingPythonScriptExecutor<String> pythonExecutor = new StreamingPythonScriptExecutor<>(true);

private List<String> batchList = new ArrayList<>(batchSize);
private boolean batchIsOutstanding = false;
private int batchCount = 0;

@Override
Expand All @@ -77,8 +78,11 @@ public void apply(GATKRead read, ReferenceContext referenceContext, FeatureConte
// Extract data from the read and accumulate, unless we've reached a batch size, in which case we
// kick off an asynchronous batch write.
if (batchCount == batchSize) {
pythonExecutor.waitForPreviousBatchCompletion();
if (batchIsOutstanding) {
pythonExecutor.waitForPreviousBatchCompletion();
}
startAsynchronousBatchWrite(); // start a new batch
batchIsOutstanding = true;
}
batchList.add(String.format(
"Read at %s:%d-%d:\n%s\n",
Expand All @@ -91,12 +95,15 @@ public void apply(GATKRead read, ReferenceContext referenceContext, FeatureConte
* @return Success indicator.
*/
public Object onTraversalSuccess() {
pythonExecutor.waitForPreviousBatchCompletion(); // wait for the previous batch to complete, if there is one
if (batchCount != 0) {
// If we have any accumulated reads that haven't been dispatched, start one last
// async batch write, and then wait for it to complete
if (batchIsOutstanding) {
pythonExecutor.waitForPreviousBatchCompletion();
}
startAsynchronousBatchWrite();
pythonExecutor.waitForPreviousBatchCompletion();
batchIsOutstanding = false;
}

return true;
Expand All @@ -108,6 +115,7 @@ private void startAsynchronousBatchWrite() {
pythonExecutor.startBatchWrite(
String.format("for i in range(%s):\n tempFile.write(tool.readDataFIFO())" + NL + NL, batchCount),
batchList);
batchIsOutstanding = true;
batchList = new ArrayList<>(batchSize);
batchCount = 0;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,15 @@ public void startBatchWrite(final String pythonCommand, final List<T> batchList)
* @return returns null if no previous work to complete, otherwise a completed Future
*/
public Future<Integer> waitForPreviousBatchCompletion() {
// wait for the batch queue to be completely written
// Rather than waiting for the asyncWriter Future to complete first, and THEN waiting for
// the ack, call waitForAck() first instead, because it will will detect and propagate any
// exception that occurs on the python side that causes it to stop pulling data from the
// FIFO (which in turn can result in the background thread blocking, thereby preventing the
// asyncWriter Future from ever completing). This is safer than waiting for the Future first,
// since the Future might never complete if the async writer thread is blocked.
waitForAck();
// now that we have the ack, verify that the async batch write completed
final Future<Integer> numberOfItemsWritten = asyncWriter.waitForPreviousBatchCompletion();
if (numberOfItemsWritten != null) {
// wait for the written items to be completely consumed
waitForAck();
}
return numberOfItemsWritten;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public Future<Integer> waitForPreviousBatchCompletion() {
*/
public boolean terminate() {
boolean isCancelled = true;
if (previousBatch != null) {
if (previousBatch != null && !previousBatch.isDone()) {
logger.warn("Cancelling outstanding asynchronous writing");
isCancelled = previousBatch.cancel(true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.broadinstitute.hellbender.testutils.VariantContextTestUtils;
import org.broadinstitute.hellbender.utils.Utils;

import org.broadinstitute.hellbender.utils.python.PythonScriptExecutorException;
import org.testng.Assert;
import org.testng.SkipException;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;
Expand Down Expand Up @@ -65,6 +66,22 @@ public void testAllDefaultArgs() {
assertInfoFieldsAreClose(tempVcf, expectedVcf, GATKVCFConstants.CNN_1D_KEY);
}

@Test(groups = {"python"}, expectedExceptions = PythonScriptExecutorException.class)
public void testExceptionDuringAsyncBatch() {
final ArgumentsBuilder argsBuilder = new ArgumentsBuilder();
final File tempVcf = createTempFile("tester", ".vcf");
// the last variant in this vcf has a value of "." for the float attributes in the default CNN
// annotation set MQ, MQRankSum, ReadPosRankSum, SOR, VQSLOD, and QD
//TODO: move this into the large resources dir
final File malformedVCF = new File("src/test/resources/cnn_1d_chr20_subset_expected.badAnnotations.vcf");
argsBuilder.add(StandardArgumentDefinitions.VARIANT_LONG_NAME, malformedVCF)
.add(StandardArgumentDefinitions.OUTPUT_LONG_NAME, tempVcf.getPath())
.add(StandardArgumentDefinitions.REFERENCE_LONG_NAME, b37_reference_20_21)
.add(StandardArgumentDefinitions.ADD_OUTPUT_VCF_COMMANDLINE, "false");

runCommandLine(argsBuilder);
}

@Test(groups = {"python"})
public void testInferenceArchitecture() {
final boolean newExpectations = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import htsjdk.samtools.util.BufferedLineReader;
import org.broadinstitute.hellbender.GATKBaseTest;
import org.broadinstitute.hellbender.utils.runtime.AsynchronousStreamWriter;
import org.broadinstitute.hellbender.utils.runtime.ProcessOutput;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -180,8 +179,8 @@ public void testAsyncWriteService(final PythonScriptExecutor.PythonExecutableNam
for (int i = 0; i < ROUND_TRIP_COUNT; i++) {
if (i != 0 && (i % syncFrequency) == 0) {
// wait for the last batch to complete before we start a new one
streamingPythonExecutor.waitForPreviousBatchCompletion();
streamingPythonExecutor.startBatchWrite(String.format(PYTHON_TRANSFER_FIFO_TO_TEMP_FILE, count), fifoData);
streamingPythonExecutor.waitForPreviousBatchCompletion();
count = 0;
fifoData = new ArrayList<>(syncFrequency);
}
Expand All @@ -194,9 +193,6 @@ public void testAsyncWriteService(final PythonScriptExecutor.PythonExecutableNam
count++;
}

// wait for the writing to complete
streamingPythonExecutor.waitForPreviousBatchCompletion();

if (fifoData.size() != 0) {
streamingPythonExecutor.startBatchWrite(String.format(PYTHON_TRANSFER_FIFO_TO_TEMP_FILE, count), fifoData);
// wait for the writing to complete
Expand Down Expand Up @@ -254,6 +250,55 @@ public void testRaisePythonException(final PythonScriptExecutor.PythonExecutable
executeBadPythonCode(executableName,"raise Exception");
}

@Test(groups = "python", dataProvider = "supportedPythonVersions", dependsOnMethods = "testPythonExists",
expectedExceptions = PythonScriptExecutorException.class)
public void testRaiseAsynchronousPythonException(final PythonScriptExecutor.PythonExecutableName executableName) {
final StreamingPythonScriptExecutor<String> streamingPythonExecutor =
new StreamingPythonScriptExecutor<>(executableName, true);
Assert.assertNotNull(streamingPythonExecutor);
Assert.assertTrue(streamingPythonExecutor.start(Collections.emptyList(), true, null));

try {
streamingPythonExecutor.sendAsynchronousCommand("raise Exception" + NL);
streamingPythonExecutor.waitForAck();
} finally {
streamingPythonExecutor.terminate();
Assert.assertFalse(streamingPythonExecutor.getProcess().isAlive());
}
}

@Test(groups = "python", dataProvider = "supportedPythonVersions", dependsOnMethods = "testPythonExists",
expectedExceptions = PythonScriptExecutorException.class)
public void testRaiseAsynchronousBatchWritePythonException(final PythonScriptExecutor.PythonExecutableName executableName) {
final StreamingPythonScriptExecutor<String> streamingPythonExecutor =
new StreamingPythonScriptExecutor<>(executableName, true);
Assert.assertNotNull(streamingPythonExecutor);
Assert.assertTrue(streamingPythonExecutor.start(Collections.emptyList(), true, null));

try {
final int BATCH_SIZE = 1000;
final List<String> batchList = createLargeBatch(BATCH_SIZE);
streamingPythonExecutor.initStreamWriter(AsynchronousStreamWriter.stringSerializer);

final String batchCommand = String.format(
"for i in range(0, %d):"+ NL + "\t tool.readDataFIFO()" + NL + NL + "raise Exception" + NL,
BATCH_SIZE);
streamingPythonExecutor.startBatchWrite(batchCommand, batchList);
streamingPythonExecutor.waitForPreviousBatchCompletion();
} finally {
streamingPythonExecutor.terminate();
Assert.assertFalse(streamingPythonExecutor.getProcess().isAlive());
}
}

private List<String> createLargeBatch(final int batchSize) {
final List<String> batchList = new ArrayList<>(1000);
for (int i = 0; i < batchSize; i++) {
batchList.add(String.format("%d\n", i));
}
return batchList;
}

@Test(groups = "python", dataProvider="supportedPythonVersions", dependsOnMethods = "testPythonExists",
expectedExceptions = PythonScriptExecutorException.class)
public void testRaisePythonAssert(final PythonScriptExecutor.PythonExecutableName executableName) {
Expand Down
Loading