From d5880ddf6d9af228a9ce87603edc723201c25ae2 Mon Sep 17 00:00:00 2001 From: Googler Date: Mon, 28 Nov 2022 07:53:51 -0800 Subject: [PATCH] Move test sharding test cases to open source PiperOrigin-RevId: 491343307 Change-Id: Iab24e133f00b705870fe814efc8da55af5bfc360 --- .../analysis/test/TestActionBuilderTest.java | 130 ++++++++++++++++-- 1 file changed, 122 insertions(+), 8 deletions(-) diff --git a/src/test/java/com/google/devtools/build/lib/analysis/test/TestActionBuilderTest.java b/src/test/java/com/google/devtools/build/lib/analysis/test/TestActionBuilderTest.java index 9302de011d35f7..90ba01d57c2ba9 100644 --- a/src/test/java/com/google/devtools/build/lib/analysis/test/TestActionBuilderTest.java +++ b/src/test/java/com/google/devtools/build/lib/analysis/test/TestActionBuilderTest.java @@ -15,6 +15,7 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; @@ -23,6 +24,7 @@ import com.google.devtools.build.lib.analysis.AnalysisResult; import com.google.devtools.build.lib.analysis.ConfiguredAspect; import com.google.devtools.build.lib.analysis.ConfiguredTarget; +import com.google.devtools.build.lib.analysis.TransitiveInfoCollection; import com.google.devtools.build.lib.analysis.util.BuildViewTestCase; import com.google.devtools.build.lib.cmdline.Label; import com.google.devtools.build.lib.collect.nestedset.Depset; @@ -31,7 +33,9 @@ import com.google.devtools.build.lib.packages.TestTimeout; import java.io.IOException; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -55,13 +59,12 @@ public final void createBuildFile() throws Exception { "", "sh_test(name = 'small_test_2',", " srcs = ['small_test_2.sh'],", - " data = ['//testing/shbase:googletest.sh'],", " size = 'small',", " tags = ['tag2'])", "", "sh_test(name = 'large_test_1',", " srcs = ['large_test_1.sh'],", - " data = ['//testing/shbase:googletest.sh', ':xUnit'],", + " data = [':xUnit'],", " size = 'large',", " tags = ['tag1'])", "", @@ -72,6 +75,112 @@ public final void createBuildFile() throws Exception { "test_suite(name = 'smallTests', tags=['small'])"); } + private void assertSharded(ConfiguredTarget testRule, int expectSharding) { + ImmutableList testStatusList = getTestStatusArtifacts(testRule); + if (expectSharding == 0) { + Artifact testResult = Iterables.getOnlyElement(testStatusList); + TestRunnerAction action = (TestRunnerAction) getGeneratingAction(testResult); + assertThat(action.isSharded()).isFalse(); + assertThat(action.getExecutionSettings().getTotalShards()).isSameInstanceAs(0); + assertThat(action.getShardNum()).isSameInstanceAs(0); + return; + } + + int totalShards = testStatusList.size(); + Set shardNumbers = new HashSet<>(); + for (Artifact testResult : testStatusList) { + TestRunnerAction action = (TestRunnerAction) getGeneratingAction(testResult); + assertThat(action.isSharded()).isTrue(); + assertThat(action.getExecutionSettings().getTotalShards()).isSameInstanceAs(totalShards); + assertThat(action.getTestLog().getExecPath().getPathString()) + .endsWith( + String.format("shard_%d_of_%d/test.log", action.getShardNum() + 1, totalShards)); + shardNumbers.add(action.getShardNum()); + } + assertThat(shardNumbers).isEqualTo(sequenceSet(0, totalShards)); + assertThat(shardNumbers).hasSize(expectSharding); + } + + private static Set sequenceSet(int start, int end) { + Preconditions.checkArgument(end > start); + Set seqSet = new HashSet<>(); + for (int i = start; i < end; i++) { + seqSet.add(i); + } + return seqSet; + } + + private void writeJavaTests() throws IOException { + scratch.file( + "javatests/jt/BUILD", + "java_test(name = 'RGT',", + " srcs = ['RGT.java'])", + "java_test(name = 'RGT_none',", + " shard_count = 0,", + " srcs = ['RGT.java'])", + "java_test(name = 'RGT_many',", + " shard_count = 33,", + " srcs = ['RGT.java'])", + "java_test(name = 'RGT_small',", + " srcs = ['RGT.java'],", + " size = 'small')", + "", + "java_test(name = 'NoRunner',", + " main_class = 'NoTestRunnerTest.java',", + " use_testrunner = 0,", + " srcs = ['NoTestRunnerTest.java'])", + ""); + } + + @Test + public void testSharding() throws Exception { + useConfiguration("--test_sharding_strategy=explicit"); + + assertSharded(getConfiguredTarget("//tests:small_test_1"), 0); + assertSharded(getConfiguredTarget("//tests:large_test_1"), 0); + + writeJavaTests(); + assertSharded(getConfiguredTarget("//javatests/jt:NoRunner"), 0); + assertSharded(getConfiguredTarget("//javatests/jt:RGT"), 0); + assertSharded(getConfiguredTarget("//javatests/jt:RGT_small"), 0); + assertSharded(getConfiguredTarget("//javatests/jt:RGT_none"), 0); + + // Has an explicit "shard_count" attribute. + assertSharded(getConfiguredTarget("//javatests/jt:RGT_many"), 33); + } + + @Test + public void testShardingDisabled() throws Exception { + useConfiguration("--test_sharding_strategy=disabled"); + + assertSharded(getConfiguredTarget("//tests:small_test_1"), 0); + assertSharded(getConfiguredTarget("//tests:large_test_1"), 0); + + writeJavaTests(); + assertSharded(getConfiguredTarget("//javatests/jt:NoRunner"), 0); + assertSharded(getConfiguredTarget("//javatests/jt:RGT"), 0); + assertSharded(getConfiguredTarget("//javatests/jt:RGT_small"), 0); + assertSharded(getConfiguredTarget("//javatests/jt:RGT_none"), 0); + + // Has an explicit "shard_count" attribute. + assertSharded(getConfiguredTarget("//javatests/jt:RGT_many"), 0); + } + + @Test + public void testShardingForced() throws Exception { + useConfiguration("--test_sharding_strategy=forced=5"); + + assertSharded(getConfiguredTarget("//tests:small_test_1"), 5); + assertSharded(getConfiguredTarget("//tests:large_test_1"), 5); + + writeJavaTests(); + assertSharded(getConfiguredTarget("//javatests/jt:NoRunner"), 5); + assertSharded(getConfiguredTarget("//javatests/jt:RGT"), 5); + assertSharded(getConfiguredTarget("//javatests/jt:RGT_small"), 5); + assertSharded(getConfiguredTarget("//javatests/jt:RGT_none"), 5); + assertSharded(getConfiguredTarget("//javatests/jt:RGT_many"), 5); + } + @Test public void testFlakyAttributeValidation() throws Exception { scratch.file("flaky/BUILD", @@ -364,12 +473,6 @@ public void testOverrideExecGroup() throws Exception { assertThat(executionInfo).containsExactly("key", "good"); } - private ImmutableList getTestStatusArtifacts(String label) - throws Exception { - ConfiguredTarget target = getConfiguredTarget(label); - return target.getProvider(TestProvider.class).getTestParams().getTestStatusArtifacts(); - } - @Test public void testNonExecutableCoverageReportGenerator() throws Exception { useConfiguration( @@ -381,4 +484,15 @@ public void testNonExecutableCoverageReportGenerator() throws Exception { "sh_library(name = 'bad_cov_gen')", "cc_test(name = 'some_test')"); } + + private ImmutableList getTestStatusArtifacts(String label) + throws Exception { + ConfiguredTarget target = getConfiguredTarget(label); + return target.getProvider(TestProvider.class).getTestParams().getTestStatusArtifacts(); + } + + private ImmutableList getTestStatusArtifacts( + TransitiveInfoCollection target) { + return target.getProvider(TestProvider.class).getTestParams().getTestStatusArtifacts(); + } }