Skip to content

Commit

Permalink
Check digest of local file to decide whether to re-download it.
Browse files Browse the repository at this point in the history
Normally, for a non-clean build (outputBase is not clean), skyframe is able to detect modifications, invalid the action, and rerun. Before rerun, it will delete the stales outputs. So we only use `path.exists()` to decide whether we should download an input.

However, there are some files  under the outputBase are not tracked by skyframe. In that case, we could wrongly use a staled output.

PiperOrigin-RevId: 483352318
Change-Id: I7e100100e6c3218630c5dc9bf3f900b2de232e0e
  • Loading branch information
coeuvre authored and copybara-github committed Oct 24, 2022
1 parent 3d2bb2a commit ebd6e58
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import io.reactivex.rxjava3.core.Flowable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -102,7 +103,26 @@ protected AbstractActionInputPrefetcher(
this.patternsToDownload = patternsToDownload;
}

protected abstract boolean shouldDownloadFile(Path path, FileArtifactValue metadata);
private boolean shouldDownloadFile(Path path, FileArtifactValue metadata) {
if (!path.exists()) {
return true;
}

// In the most cases, skyframe should be able to detect source files modifications and delete
// staled outputs before action execution. However, there are some cases where outputs are not
// tracked by skyframe. We compare the digest here to make sure we don't use staled files.
try {
byte[] digest = path.getFastDigest();
if (digest == null) {
digest = path.getDigest();
}
return !Arrays.equals(digest, metadata.getDigest());
} catch (IOException ignored) {
return true;
}
}

protected abstract boolean canDownloadFile(Path path, FileArtifactValue metadata);

/**
* Downloads file to the given path via its metadata.
Expand Down Expand Up @@ -189,13 +209,14 @@ private Completable prefetchInputTreeOrSymlink(
// TODO(tjgq): Only download individual files that were requested within the tree.
// This isn't straightforward because multiple tree artifacts may share the same output tree
// when a ctx.actions.symlink is involved.
if (treeMetadata == null || !shouldDownloadAnyTreeFiles(treeFiles, treeMetadata)) {
if (treeMetadata == null || !canDownloadAnyTreeFiles(treeFiles, treeMetadata)) {
return Completable.complete();
}

PathFragment prefetchExecPath = treeMetadata.getMaterializationExecPath().orElse(execPath);

Completable prefetch = prefetchInputTree(provider, prefetchExecPath, treeFiles, priority);
Completable prefetch =
prefetchInputTree(provider, prefetchExecPath, treeFiles, treeMetadata, priority);

// If prefetching to a different path, plant a symlink into it.
if (!prefetchExecPath.equals(execPath)) {
Expand All @@ -207,6 +228,16 @@ private Completable prefetchInputTreeOrSymlink(
return prefetch;
}

private boolean canDownloadAnyTreeFiles(
Iterable<TreeFileArtifact> treeFiles, FileArtifactValue metadata) {
for (TreeFileArtifact treeFile : treeFiles) {
if (canDownloadFile(treeFile.getPath(), metadata)) {
return true;
}
}
return false;
}

private boolean shouldDownloadAnyTreeFiles(
Iterable<TreeFileArtifact> treeFiles, FileArtifactValue metadata) {
for (TreeFileArtifact treeFile : treeFiles) {
Expand All @@ -221,6 +252,7 @@ private Completable prefetchInputTree(
MetadataProvider provider,
PathFragment execPath,
List<TreeFileArtifact> treeFiles,
FileArtifactValue treeMetadata,
Priority priority) {
Path treeRoot = execRoot.getRelative(execPath);
HashMap<TreeFileArtifact, Path> treeFileTmpPathMap = new HashMap<>();
Expand Down Expand Up @@ -293,7 +325,15 @@ private Completable prefetchInputTree(
}
}
});
return downloadCache.executeIfNot(treeRoot, download);
return downloadCache.executeIfNot(
treeRoot,
Completable.defer(
() -> {
if (shouldDownloadAnyTreeFiles(treeFiles, treeMetadata)) {
return download;
}
return Completable.complete();
}));
}

private Completable prefetchInputFileOrSymlink(
Expand All @@ -306,7 +346,7 @@ private Completable prefetchInputFileOrSymlink(
PathFragment execPath = input.getExecPath();

FileArtifactValue metadata = metadataProvider.getMetadata(input);
if (metadata == null || !shouldDownloadFile(execRoot.getRelative(execPath), metadata)) {
if (metadata == null || !canDownloadFile(execRoot.getRelative(execPath), metadata)) {
return Completable.complete();
}

Expand All @@ -332,7 +372,7 @@ private Completable prefetchInputFileOrSymlink(
* download finished.
*/
private Completable downloadFileRx(Path path, FileArtifactValue metadata, Priority priority) {
if (!shouldDownloadFile(path, metadata)) {
if (!canDownloadFile(path, metadata)) {
return Completable.complete();
}
return downloadFileNoCheckRx(path, metadata, priority);
Expand Down Expand Up @@ -373,7 +413,16 @@ private Completable downloadFileNoCheckRx(
// Set eager=false here because we want cleanup the download *after* upstream is
// disposed.
/* eager= */ false);
return downloadCache.executeIfNot(path, download);

return downloadCache.executeIfNot(
finalPath,
Completable.defer(
() -> {
if (shouldDownloadFile(finalPath, metadata)) {
return download;
}
return Completable.complete();
}));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ protected void prefetchVirtualActionInput(VirtualActionInput input) throws IOExc
}

@Override
protected boolean shouldDownloadFile(Path path, FileArtifactValue metadata) {
return metadata.isRemote() && !path.exists();
protected boolean canDownloadFile(Path path, FileArtifactValue metadata) {
return metadata.isRemote();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -156,6 +159,40 @@ protected Pair<SpecialArtifact, ImmutableList<TreeFileArtifact>> createRemoteTre

protected abstract AbstractActionInputPrefetcher createPrefetcher(Map<HashCode, byte[]> cas);

@Test
public void prefetchFiles_fileExists_doNotDownload() throws IOException, InterruptedException {
Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
Map<HashCode, byte[]> cas = new HashMap<>();
Artifact a = createRemoteArtifact("file", "hello world", metadata, cas);
FileSystemUtils.writeContent(a.getPath(), "hello world".getBytes(UTF_8));
MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
AbstractActionInputPrefetcher prefetcher = spy(createPrefetcher(cas));

wait(prefetcher.prefetchFiles(metadata.keySet(), metadataProvider));

verify(prefetcher, never()).doDownloadFile(any(), any(), any(), any());
assertThat(prefetcher.downloadedFiles()).containsExactly(a.getPath());
assertThat(prefetcher.downloadsInProgress()).isEmpty();
}

@Test
public void prefetchFiles_fileExistsButContentMismatches_download()
throws IOException, InterruptedException {
Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
Map<HashCode, byte[]> cas = new HashMap<>();
Artifact a = createRemoteArtifact("file", "hello world remote", metadata, cas);
FileSystemUtils.writeContent(a.getPath(), "hello world local".getBytes(UTF_8));
MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
AbstractActionInputPrefetcher prefetcher = spy(createPrefetcher(cas));

wait(prefetcher.prefetchFiles(metadata.keySet(), metadataProvider));

verify(prefetcher).doDownloadFile(any(), eq(a.getExecPath()), any(), any());
assertThat(prefetcher.downloadedFiles()).containsExactly(a.getPath());
assertThat(prefetcher.downloadsInProgress()).isEmpty();
assertThat(FileSystemUtils.readContent(a.getPath(), UTF_8)).isEqualTo("hello world remote");
}

@Test
public void prefetchFiles_downloadRemoteFiles() throws Exception {
Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
Expand Down

0 comments on commit ebd6e58

Please sign in to comment.