diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/TypeCheckedTag.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/TypeCheckedTag.java index 979ecb1354afce..e50e8ad280e3f4 100644 --- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/TypeCheckedTag.java +++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/TypeCheckedTag.java @@ -81,14 +81,18 @@ public static TypeCheckedTag create( attrValues[attrIndex] = Attribute.valueToStarlark(nativeValue); } - // Check that all mandatory attributes have been specified. + // Check that all mandatory attributes have been specified, and fill in default values. for (int i = 0; i < attrValues.length; i++) { - if (tagClass.getAttributes().get(i).isMandatory() && attrValues[i] == null) { + Attribute attr = tagClass.getAttributes().get(i); + if (attr.isMandatory() && attrValues[i] == null) { throw ExternalDepsException.withMessage( Code.BAD_MODULE, "in tag at %s, mandatory attribute %s isn't being specified", tag.getLocation(), - tagClass.getAttributes().get(i).getPublicName()); + attr.getPublicName()); + } + if (attrValues[i] == null) { + attrValues[i] = Attribute.valueToStarlark(attr.getDefaultValueUnchecked()); } } return new TypeCheckedTag(tagClass, attrValues); @@ -106,11 +110,7 @@ public Object getValue(String name) throws EvalException { if (attrIndex == null) { return null; } - Object value = attrValues[attrIndex]; - if (value != null) { - return value; - } - return tagClass.getAttributes().get(attrIndex).getDefaultValueUnchecked(); + return attrValues[attrIndex]; } @Override diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java index b04d4c3056bffa..3521cd8d15ace6 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleExtensionResolutionTest.java @@ -720,6 +720,84 @@ public void labels_constructedInModuleExtension() throws Exception { .isEqualTo("requirements: get up at 6am. go to bed at 11pm."); } + /** Tests that a complex-typed attribute (here, string_list_dict) behaves well on a tag. */ + @Test + public void complexTypedAttribute() throws Exception { + scratch.file( + workspaceRoot.getRelative("MODULE.bazel").getPathString(), + "bazel_dep(name='data_repo', version='1.0')", + "ext = use_extension('//:defs.bzl', 'ext')", + "ext.tag(data={'foo':['val1','val2'],'bar':['val3','val4']})", + "use_repo(ext, 'foo', 'bar')"); + scratch.file( + workspaceRoot.getRelative("defs.bzl").getPathString(), + "load('@data_repo//:defs.bzl','data_repo')", + "tag = tag_class(attrs = {'data':attr.string_list_dict()})", + "def _ext_impl(ctx):", + " for mod in ctx.modules:", + " for tag in mod.tags.tag:", + " for key in tag.data:", + " data_repo(name=key,data=','.join(tag.data[key]))", + "ext = module_extension(implementation=_ext_impl, tag_classes={'tag':tag})"); + scratch.file(workspaceRoot.getRelative("BUILD").getPathString()); + scratch.file( + workspaceRoot.getRelative("data.bzl").getPathString(), + "load('@foo//:data.bzl', foo_data='data')", + "load('@bar//:data.bzl', bar_data='data')", + "data = 'foo:'+foo_data+' bar:'+bar_data"); + + SkyKey skyKey = BzlLoadValue.keyForBuild(Label.parseAbsoluteUnchecked("//:data.bzl")); + EvaluationResult result = + evaluator.evaluate(ImmutableList.of(skyKey), evaluationContext); + if (result.hasError()) { + throw result.getError().getException(); + } + assertThat(result.get(skyKey).getModule().getGlobal("data")) + .isEqualTo("foo:val1,val2 bar:val3,val4"); + } + + /** + * Tests that a complex-typed attribute (here, string_list_dict) behaves well when it has a + * default value and is omitted in a tag. + */ + @Test + public void complexTypedAttribute_default() throws Exception { + scratch.file( + workspaceRoot.getRelative("MODULE.bazel").getPathString(), + "bazel_dep(name='data_repo', version='1.0')", + "ext = use_extension('//:defs.bzl', 'ext')", + "ext.tag()", + "use_repo(ext, 'foo', 'bar')"); + scratch.file( + workspaceRoot.getRelative("defs.bzl").getPathString(), + "load('@data_repo//:defs.bzl','data_repo')", + "tag = tag_class(attrs = {", + " 'data': attr.string_list_dict(", + " default = {'foo':['val1','val2'],'bar':['val3','val4']},", + ")})", + "def _ext_impl(ctx):", + " for mod in ctx.modules:", + " for tag in mod.tags.tag:", + " for key in tag.data:", + " data_repo(name=key,data=','.join(tag.data[key]))", + "ext = module_extension(implementation=_ext_impl, tag_classes={'tag':tag})"); + scratch.file(workspaceRoot.getRelative("BUILD").getPathString()); + scratch.file( + workspaceRoot.getRelative("data.bzl").getPathString(), + "load('@foo//:data.bzl', foo_data='data')", + "load('@bar//:data.bzl', bar_data='data')", + "data = 'foo:'+foo_data+' bar:'+bar_data"); + + SkyKey skyKey = BzlLoadValue.keyForBuild(Label.parseAbsoluteUnchecked("//:data.bzl")); + EvaluationResult result = + evaluator.evaluate(ImmutableList.of(skyKey), evaluationContext); + if (result.hasError()) { + throw result.getError().getException(); + } + assertThat(result.get(skyKey).getModule().getGlobal("data")) + .isEqualTo("foo:val1,val2 bar:val3,val4"); + } + @Test public void generatedReposHaveCorrectMappings() throws Exception { scratch.file( diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/TypeCheckedTagTest.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/TypeCheckedTagTest.java index 14032b2e013c4e..56d0a11e474ada 100644 --- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/TypeCheckedTagTest.java +++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/TypeCheckedTagTest.java @@ -23,6 +23,8 @@ import static com.google.devtools.build.lib.packages.Attribute.attr; import static org.junit.Assert.assertThrows; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.devtools.build.lib.cmdline.Label; import com.google.devtools.build.lib.packages.Attribute.AllowedValueSet; import com.google.devtools.build.lib.packages.BuildType; @@ -30,8 +32,13 @@ import com.google.devtools.build.lib.packages.Type; import com.google.devtools.build.lib.util.FileTypeSet; import java.util.HashMap; +import net.starlark.java.eval.Dict; +import net.starlark.java.eval.Mutability; +import net.starlark.java.eval.Starlark; import net.starlark.java.eval.StarlarkInt; import net.starlark.java.eval.StarlarkList; +import net.starlark.java.eval.StarlarkSemantics; +import net.starlark.java.eval.Structure; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -40,6 +47,15 @@ @RunWith(JUnit4.class) public class TypeCheckedTagTest { + private static Object getattr(Structure structure, String fieldName) throws Exception { + return Starlark.getattr( + Mutability.IMMUTABLE, + StarlarkSemantics.DEFAULT, + structure, + fieldName, + /*defaultValue=*/ null); + } + @Test public void basic() throws Exception { TypeCheckedTag typeCheckedTag = @@ -48,7 +64,7 @@ public void basic() throws Exception { buildTag("tag_name").addAttr("foo", StarlarkInt.of(3)).build(), /*labelConversionContext=*/ null); assertThat(typeCheckedTag.getFieldNames()).containsExactly("foo"); - assertThat(typeCheckedTag.getValue("foo")).isEqualTo(StarlarkInt.of(3)); + assertThat(getattr(typeCheckedTag, "foo")).isEqualTo(StarlarkInt.of(3)); } @Test @@ -66,7 +82,7 @@ public void label() throws Exception { createRepositoryMapping(createModuleKey("test", "1.0"), "repo", "other_repo"), new HashMap<>())); assertThat(typeCheckedTag.getFieldNames()).containsExactly("foo"); - assertThat(typeCheckedTag.getValue("foo")) + assertThat(getattr(typeCheckedTag, "foo")) .isEqualTo( StarlarkList.immutableOf( Label.parseAbsoluteUnchecked("@myrepo//mypkg:thing1"), @@ -74,6 +90,39 @@ public void label() throws Exception { Label.parseAbsoluteUnchecked("@other_repo//pkg:thing3"))); } + @Test + public void label_withoutDefaultValue() throws Exception { + TypeCheckedTag typeCheckedTag = + TypeCheckedTag.create( + createTagClass( + attr("foo", BuildType.LABEL).allowedFileTypes(FileTypeSet.ANY_FILE).build()), + buildTag("tag_name").build(), + new LabelConversionContext( + Label.parseAbsoluteUnchecked("@myrepo//mypkg:defs.bzl"), + createRepositoryMapping(createModuleKey("test", "1.0"), "repo", "other_repo"), + new HashMap<>())); + assertThat(typeCheckedTag.getFieldNames()).containsExactly("foo"); + assertThat(getattr(typeCheckedTag, "foo")).isEqualTo(Starlark.NONE); + } + + @Test + public void stringListDict_default() throws Exception { + TypeCheckedTag typeCheckedTag = + TypeCheckedTag.create( + createTagClass( + attr("foo", Type.STRING_LIST_DICT) + .value(ImmutableMap.of("key", ImmutableList.of("value1", "value2"))) + .build()), + buildTag("tag_name").build(), + null); + assertThat(typeCheckedTag.getFieldNames()).containsExactly("foo"); + assertThat(getattr(typeCheckedTag, "foo")) + .isEqualTo( + Dict.builder() + .put("key", StarlarkList.immutableOf("value1", "value2")) + .buildImmutable()); + } + @Test public void multipleAttributesAndDefaults() throws Exception { TypeCheckedTag typeCheckedTag = @@ -88,9 +137,9 @@ public void multipleAttributesAndDefaults() throws Exception { .build(), /*labelConversionContext=*/ null); assertThat(typeCheckedTag.getFieldNames()).containsExactly("foo", "bar", "quux"); - assertThat(typeCheckedTag.getValue("foo")).isEqualTo("fooValue"); - assertThat(typeCheckedTag.getValue("bar")).isEqualTo(StarlarkInt.of(3)); - assertThat(typeCheckedTag.getValue("quux")) + assertThat(getattr(typeCheckedTag, "foo")).isEqualTo("fooValue"); + assertThat(getattr(typeCheckedTag, "bar")).isEqualTo(StarlarkInt.of(3)); + assertThat(getattr(typeCheckedTag, "quux")) .isEqualTo(StarlarkList.immutableOf("quuxValue1", "quuxValue2")); }