Skip to content

Commit

Permalink
Move test sharding test cases to open source
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 491343307
Change-Id: Iab24e133f00b705870fe814efc8da55af5bfc360
  • Loading branch information
yuyue730 authored and copybara-github committed Nov 28, 2022
1 parent 7bd0ab6 commit d5880dd
Showing 1 changed file with 122 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import static com.google.common.truth.Truth.assertThat;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
Expand All @@ -23,6 +24,7 @@
import com.google.devtools.build.lib.analysis.AnalysisResult;
import com.google.devtools.build.lib.analysis.ConfiguredAspect;
import com.google.devtools.build.lib.analysis.ConfiguredTarget;
import com.google.devtools.build.lib.analysis.TransitiveInfoCollection;
import com.google.devtools.build.lib.analysis.util.BuildViewTestCase;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.collect.nestedset.Depset;
Expand All @@ -31,7 +33,9 @@
import com.google.devtools.build.lib.packages.TestTimeout;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -55,13 +59,12 @@ public final void createBuildFile() throws Exception {
"",
"sh_test(name = 'small_test_2',",
" srcs = ['small_test_2.sh'],",
" data = ['//testing/shbase:googletest.sh'],",
" size = 'small',",
" tags = ['tag2'])",
"",
"sh_test(name = 'large_test_1',",
" srcs = ['large_test_1.sh'],",
" data = ['//testing/shbase:googletest.sh', ':xUnit'],",
" data = [':xUnit'],",
" size = 'large',",
" tags = ['tag1'])",
"",
Expand All @@ -72,6 +75,112 @@ public final void createBuildFile() throws Exception {
"test_suite(name = 'smallTests', tags=['small'])");
}

private void assertSharded(ConfiguredTarget testRule, int expectSharding) {
ImmutableList<Artifact.DerivedArtifact> testStatusList = getTestStatusArtifacts(testRule);
if (expectSharding == 0) {
Artifact testResult = Iterables.getOnlyElement(testStatusList);
TestRunnerAction action = (TestRunnerAction) getGeneratingAction(testResult);
assertThat(action.isSharded()).isFalse();
assertThat(action.getExecutionSettings().getTotalShards()).isSameInstanceAs(0);
assertThat(action.getShardNum()).isSameInstanceAs(0);
return;
}

int totalShards = testStatusList.size();
Set<Integer> shardNumbers = new HashSet<>();
for (Artifact testResult : testStatusList) {
TestRunnerAction action = (TestRunnerAction) getGeneratingAction(testResult);
assertThat(action.isSharded()).isTrue();
assertThat(action.getExecutionSettings().getTotalShards()).isSameInstanceAs(totalShards);
assertThat(action.getTestLog().getExecPath().getPathString())
.endsWith(
String.format("shard_%d_of_%d/test.log", action.getShardNum() + 1, totalShards));
shardNumbers.add(action.getShardNum());
}
assertThat(shardNumbers).isEqualTo(sequenceSet(0, totalShards));
assertThat(shardNumbers).hasSize(expectSharding);
}

private static Set<Integer> sequenceSet(int start, int end) {
Preconditions.checkArgument(end > start);
Set<Integer> seqSet = new HashSet<>();
for (int i = start; i < end; i++) {
seqSet.add(i);
}
return seqSet;
}

private void writeJavaTests() throws IOException {
scratch.file(
"javatests/jt/BUILD",
"java_test(name = 'RGT',",
" srcs = ['RGT.java'])",
"java_test(name = 'RGT_none',",
" shard_count = 0,",
" srcs = ['RGT.java'])",
"java_test(name = 'RGT_many',",
" shard_count = 33,",
" srcs = ['RGT.java'])",
"java_test(name = 'RGT_small',",
" srcs = ['RGT.java'],",
" size = 'small')",
"",
"java_test(name = 'NoRunner',",
" main_class = 'NoTestRunnerTest.java',",
" use_testrunner = 0,",
" srcs = ['NoTestRunnerTest.java'])",
"");
}

@Test
public void testSharding() throws Exception {
useConfiguration("--test_sharding_strategy=explicit");

assertSharded(getConfiguredTarget("//tests:small_test_1"), 0);
assertSharded(getConfiguredTarget("//tests:large_test_1"), 0);

writeJavaTests();
assertSharded(getConfiguredTarget("//javatests/jt:NoRunner"), 0);
assertSharded(getConfiguredTarget("//javatests/jt:RGT"), 0);
assertSharded(getConfiguredTarget("//javatests/jt:RGT_small"), 0);
assertSharded(getConfiguredTarget("//javatests/jt:RGT_none"), 0);

// Has an explicit "shard_count" attribute.
assertSharded(getConfiguredTarget("//javatests/jt:RGT_many"), 33);
}

@Test
public void testShardingDisabled() throws Exception {
useConfiguration("--test_sharding_strategy=disabled");

assertSharded(getConfiguredTarget("//tests:small_test_1"), 0);
assertSharded(getConfiguredTarget("//tests:large_test_1"), 0);

writeJavaTests();
assertSharded(getConfiguredTarget("//javatests/jt:NoRunner"), 0);
assertSharded(getConfiguredTarget("//javatests/jt:RGT"), 0);
assertSharded(getConfiguredTarget("//javatests/jt:RGT_small"), 0);
assertSharded(getConfiguredTarget("//javatests/jt:RGT_none"), 0);

// Has an explicit "shard_count" attribute.
assertSharded(getConfiguredTarget("//javatests/jt:RGT_many"), 0);
}

@Test
public void testShardingForced() throws Exception {
useConfiguration("--test_sharding_strategy=forced=5");

assertSharded(getConfiguredTarget("//tests:small_test_1"), 5);
assertSharded(getConfiguredTarget("//tests:large_test_1"), 5);

writeJavaTests();
assertSharded(getConfiguredTarget("//javatests/jt:NoRunner"), 5);
assertSharded(getConfiguredTarget("//javatests/jt:RGT"), 5);
assertSharded(getConfiguredTarget("//javatests/jt:RGT_small"), 5);
assertSharded(getConfiguredTarget("//javatests/jt:RGT_none"), 5);
assertSharded(getConfiguredTarget("//javatests/jt:RGT_many"), 5);
}

@Test
public void testFlakyAttributeValidation() throws Exception {
scratch.file("flaky/BUILD",
Expand Down Expand Up @@ -364,12 +473,6 @@ public void testOverrideExecGroup() throws Exception {
assertThat(executionInfo).containsExactly("key", "good");
}

private ImmutableList<Artifact.DerivedArtifact> getTestStatusArtifacts(String label)
throws Exception {
ConfiguredTarget target = getConfiguredTarget(label);
return target.getProvider(TestProvider.class).getTestParams().getTestStatusArtifacts();
}

@Test
public void testNonExecutableCoverageReportGenerator() throws Exception {
useConfiguration(
Expand All @@ -381,4 +484,15 @@ public void testNonExecutableCoverageReportGenerator() throws Exception {
"sh_library(name = 'bad_cov_gen')",
"cc_test(name = 'some_test')");
}

private ImmutableList<Artifact.DerivedArtifact> getTestStatusArtifacts(String label)
throws Exception {
ConfiguredTarget target = getConfiguredTarget(label);
return target.getProvider(TestProvider.class).getTestParams().getTestStatusArtifacts();
}

private ImmutableList<Artifact.DerivedArtifact> getTestStatusArtifacts(
TransitiveInfoCollection target) {
return target.getProvider(TestProvider.class).getTestParams().getTestStatusArtifacts();
}
}

0 comments on commit d5880dd

Please sign in to comment.