Skip to content

Commit

Permalink
Clean up proto_library.bzl implementation.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 411986102
  • Loading branch information
comius authored and copybara-github committed Nov 24, 2021
1 parent 46d5936 commit cbad324
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 80 deletions.
187 changes: 117 additions & 70 deletions src/main/starlark/builtins_bzl/common/proto/proto_library.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -24,73 +24,146 @@ ProtoInfo = _builtins.toplevel.ProtoInfo
native_proto_common = _builtins.toplevel.proto_common

def _check_srcs_package(target_package, srcs):
"""Makes sure the given srcs live in the given package."""
"""Check that .proto files in sources are from the same package.
This is done to avoid clashes with the generated sources."""

#TODO(bazel-team): this does not work with filegroups that contain files that are not in the package
for src in srcs:
if target_package != src.label.package:
fail("Proto source with label '%s' must be in same package as consuming rule." % src.label)

def _join(*path):
return "/".join([p for p in path if p != ""])

def _create_proto_info(ctx):
srcs = ctx.files.srcs
deps = [dep[ProtoInfo] for dep in ctx.attr.deps]
exports = [dep[ProtoInfo] for dep in ctx.attr.exports]
def _get_import_prefix(ctx):
"""Gets and verifies import_prefix attribute if it is declared."""

import_prefix = ctx.attr.import_prefix if hasattr(ctx.attr, "import_prefix") else ""

if not paths.is_normalized(import_prefix):
fail("should be normalized (without uplevel references or '.' path segments)", attr = "import_prefix")
if paths.is_absolute(import_prefix):
fail("should be a relative path", attr = "import_prefix")

return import_prefix

def _get_strip_import_prefix(ctx):
"""Gets and verifies strip_import_prefix."""

strip_import_prefix = ctx.attr.strip_import_prefix

if not paths.is_normalized(strip_import_prefix):
fail("should be normalized (without uplevel references or '.' path segments)", attr = "strip_import_prefix")
if strip_import_prefix.startswith("/"):

if paths.is_absolute(strip_import_prefix):
strip_import_prefix = strip_import_prefix[1:]
elif strip_import_prefix != "DO_NOT_STRIP": # Relative to current package
strip_import_prefix = _join(ctx.label.package, strip_import_prefix)
else:
strip_import_prefix = ""

has_generated_sources = False
if ctx.fragments.proto.generated_protos_in_virtual_imports():
has_generated_sources = any([not src.is_source for src in srcs])
return strip_import_prefix

direct_sources = []
if import_prefix != "" or strip_import_prefix != "" or has_generated_sources:
# Use virtual source roots
if paths.is_absolute(import_prefix):
fail("should be a relative path", attr = "import_prefix")
def _proto_library_impl(ctx):
semantics.preprocess(ctx)

virtual_imports = _join("_virtual_imports", ctx.label.name)
if ctx.label.workspace_name == "" or ctx.label.workspace_root.startswith(".."): # siblingRepositoryLayout
proto_path = _join(ctx.genfiles_dir.path, ctx.label.package, virtual_imports)
else:
proto_path = _join(ctx.genfiles_dir.path, ctx.label.workspace_root, ctx.label.package, virtual_imports)
# Verifies attributes.
_check_srcs_package(ctx.label.package, ctx.attr.srcs)
srcs = ctx.files.srcs
deps = [dep[ProtoInfo] for dep in ctx.attr.deps]
exports = [dep[ProtoInfo] for dep in ctx.attr.exports]
import_prefix = _get_import_prefix(ctx)
strip_import_prefix = _get_strip_import_prefix(ctx)

for src in srcs:
if ctx.label.workspace_name == "":
repository_relative_path = src.short_path
else:
repository_relative_path = paths.relativize(src.short_path, "../" + ctx.label.workspace_name)
proto_path, direct_sources = _create_proto_sources(ctx, srcs, import_prefix, strip_import_prefix)
descriptor_set = ctx.actions.declare_file(ctx.label.name + "-descriptor-set.proto.bin")
proto_info = _create_proto_info(ctx, direct_sources, deps, exports, proto_path, descriptor_set)
_write_descriptor_set(ctx, deps, proto_info, descriptor_set)

# We assume that the proto sources will not have conflicting artifacts
# with the same root relative path
data_runfiles = ctx.runfiles(
files = [proto_info.direct_descriptor_set],
transitive_files = depset(transitive = [proto_info.transitive_sources]),
)
return [
proto_info,
DefaultInfo(
files = depset([proto_info.direct_descriptor_set]),
default_runfiles = ctx.runfiles(), # empty
data_runfiles = data_runfiles,
),
]

if not repository_relative_path.startswith(strip_import_prefix):
fail(".proto file '%s' is not under the specified strip prefix '%s'" %
(src.short_path, strip_import_prefix))
import_path = repository_relative_path[len(strip_import_prefix):]
def _create_proto_sources(ctx, srcs, import_prefix, strip_import_prefix):
"""Transforms Files in srcs to ProtoSources, optionally symlinking them to _virtual_imports.
virtual_src = ctx.actions.declare_file(_join(virtual_imports, import_prefix, import_path))
ctx.actions.symlink(
output = virtual_src,
target_file = src,
progress_message = "Symlinking virtual .proto sources for %{label}",
)
direct_sources.append(native_proto_common.ProtoSource(virtual_src, src, proto_path))
Returns:
A pair proto_path, directs_sources.
"""
generate_protos_in_virtual_imports = False
if ctx.fragments.proto.generated_protos_in_virtual_imports():
generate_protos_in_virtual_imports = any([not src.is_source for src in srcs])

if import_prefix != "" or strip_import_prefix != "" or generate_protos_in_virtual_imports:
# Use virtual source roots
return _symlink_to_virtual_imports(ctx, srcs, import_prefix, strip_import_prefix)
else:
# No virtual source roots
proto_path = "."
direct_sources = []
for src in srcs:
direct_sources.append(native_proto_common.ProtoSource(src, src, ctx.label.workspace_root + src.root.path))
if ctx.label.workspace_name == "" or ctx.label.workspace_root.startswith(".."):
# source_root == ''|'bazel-out/foo/k8-fastbuild/bin'
source_root = src.root.path
else:
# source_root == ''|'bazel-out/foo/k8-fastbuild/bin' / 'external/repo'
source_root = _join(src.root.path, ctx.label.workspace_root)
direct_sources.append(native_proto_common.ProtoSource(src, src, source_root))

return ctx.label.workspace_root if ctx.label.workspace_root else ".", direct_sources

def _join(*path):
return "/".join([p for p in path if p != ""])

def _symlink_to_virtual_imports(ctx, srcs, import_prefix, strip_import_prefix):
"""Symlinks srcs to _virtual_imports.
Returns:
A pair proto_path, directs_sources.
"""
virtual_imports = _join("_virtual_imports", ctx.label.name)
if ctx.label.workspace_name == "" or ctx.label.workspace_root.startswith(".."): # siblingRepositoryLayout
# Example: `bazel-out/[repo/]target/bin / pkg / _virtual_imports/name`
proto_path = _join(ctx.genfiles_dir.path, ctx.label.package, virtual_imports)
else:
# Example: `bazel-out/target/bin / repo / pkg / _virtual_imports/name`
proto_path = _join(ctx.genfiles_dir.path, ctx.label.workspace_root, ctx.label.package, virtual_imports)

direct_sources = []
for src in srcs:
if ctx.label.workspace_name == "":
repository_relative_path = src.short_path
else:
# src.short_path = ../repo/pkg/a.proto
repository_relative_path = paths.relativize(src.short_path, "../" + ctx.label.workspace_name)

# Remove strip_import_prefix
if not repository_relative_path.startswith(strip_import_prefix):
fail(".proto file '%s' is not under the specified strip prefix '%s'" %
(src.short_path, strip_import_prefix))
import_path = repository_relative_path[len(strip_import_prefix):]

# Add import_prefix
virtual_src = ctx.actions.declare_file(_join(virtual_imports, import_prefix, import_path))

ctx.actions.symlink(
output = virtual_src,
target_file = src,
progress_message = "Symlinking virtual .proto sources for %{label}",
)
direct_sources.append(native_proto_common.ProtoSource(virtual_src, src, proto_path))
return proto_path, direct_sources

def _create_proto_info(ctx, direct_sources, deps, exports, proto_path, descriptor_set):
"""Constructs ProtoInfo."""

# Construct ProtoInfo
transitive_proto_sources = depset(
Expand All @@ -112,9 +185,8 @@ def _create_proto_info(ctx):
else:
check_deps_sources = depset(transitive = [dep.check_deps_sources for dep in deps])

direct_descriptor_set = ctx.actions.declare_file(ctx.label.name + "-descriptor-set.proto.bin")
transitive_descriptor_sets = depset(
direct = [direct_descriptor_set],
direct = [descriptor_set],
transitive = [dep.transitive_descriptor_sets for dep in deps],
)

Expand All @@ -137,21 +209,19 @@ def _create_proto_info(ctx):
transitive_proto_sources,
transitive_proto_path,
check_deps_sources,
direct_descriptor_set,
descriptor_set,
transitive_descriptor_sets,
exported_sources,
strict_importable_sources,
public_import_protos,
)

def _write_descriptor_set(ctx, proto_info):
descriptor_set = proto_info.direct_descriptor_set

def _write_descriptor_set(ctx, deps, proto_info, descriptor_set):
"""Writes descriptor set."""
if proto_info.direct_sources == []:
ctx.actions.write(descriptor_set, "")
return

deps = [dep[ProtoInfo] for dep in ctx.attr.deps]
dependencies_descriptor_sets = depset(transitive = [dep.transitive_descriptor_sets for dep in deps])

args = []
Expand All @@ -170,29 +240,6 @@ def _write_descriptor_set(ctx, proto_info):
additional_args = args,
)

def _proto_library_impl(ctx):
semantics.preprocess(ctx)

_check_srcs_package(ctx.label.package, ctx.attr.srcs)

proto_info = _create_proto_info(ctx)

_write_descriptor_set(ctx, proto_info)

data_runfiles = ctx.runfiles(
files = [proto_info.direct_descriptor_set],
transitive_files = depset(transitive = [proto_info.transitive_sources]),
)

return [
proto_info,
DefaultInfo(
files = depset([proto_info.direct_descriptor_set]),
default_runfiles = ctx.runfiles(), # empty
data_runfiles = data_runfiles,
),
]

proto_library = rule(
_proto_library_impl,
attrs = dict({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,8 @@ public void testStripImportPrefixWithDeps() throws Exception {
".");
}

private void testExternalRepoWithGeneratedProto(boolean siblingRepoLayout) throws Exception {
private void testExternalRepoWithGeneratedProto(
boolean siblingRepoLayout, boolean useVirtualImports) throws Exception {
if (!isThisBazel()) {
return;
}
Expand All @@ -361,6 +362,9 @@ private void testExternalRepoWithGeneratedProto(boolean siblingRepoLayout) throw
if (siblingRepoLayout) {
setBuildLanguageOptions("--experimental_sibling_repository_layout");
}
if (!useVirtualImports) {
useConfiguration("--noincompatible_generated_protos_in_virtual_imports");
}
invalidatePackages();

scratch.file("/foo/WORKSPACE");
Expand All @@ -369,7 +373,6 @@ private void testExternalRepoWithGeneratedProto(boolean siblingRepoLayout) throw
TestConstants.LOAD_PROTO_LIBRARY,
"proto_library(name='x', srcs=['generated.proto'])",
"genrule(name='g', srcs=[], outs=['generated.proto'], cmd='')");

scratch.file(
"a/BUILD",
TestConstants.LOAD_PROTO_LIBRARY,
Expand All @@ -380,27 +383,42 @@ private void testExternalRepoWithGeneratedProto(boolean siblingRepoLayout) throw
.getGenfilesFragment(
siblingRepoLayout ? RepositoryName.create("@foo") : RepositoryName.MAIN)
.toString();
String fooProtoRoot;
if (useVirtualImports) {
fooProtoRoot =
genfiles + (siblingRepoLayout ? "" : "/external/foo") + "/x/_virtual_imports/x";
} else {
fooProtoRoot = (siblingRepoLayout ? "../foo" : "external/foo");
}
ConfiguredTarget a = getConfiguredTarget("//a:a");
ProtoInfo aInfo = a.get(ProtoInfo.PROVIDER);
assertThat(aInfo.getTransitiveProtoSourceRoots().toList())
.containsExactly(
".", genfiles + (siblingRepoLayout ? "" : "/external/foo") + "/x/_virtual_imports/x");
assertThat(aInfo.getTransitiveProtoSourceRoots().toList()).containsExactly(".", fooProtoRoot);

ConfiguredTarget x = getConfiguredTarget("@foo//x:x");
ProtoInfo xInfo = x.get(ProtoInfo.PROVIDER);
assertThat(xInfo.getTransitiveProtoSourceRoots().toList())
.containsExactly(
genfiles + (siblingRepoLayout ? "" : "/external/foo") + "/x/_virtual_imports/x");
assertThat(xInfo.getTransitiveProtoSourceRoots().toList()).containsExactly(fooProtoRoot);
}

@Test
public void testExternalRepoWithGeneratedProto_withSubdirRepoLayout() throws Exception {
testExternalRepoWithGeneratedProto(/*siblingRepoLayout=*/ false);
testExternalRepoWithGeneratedProto(/*siblingRepoLayout=*/ false, true);
}

@Test
public void test_siblingRepoLayout_externalRepoWithGeneratedProto() throws Exception {
testExternalRepoWithGeneratedProto(/*siblingRepoLayout=*/ true);
testExternalRepoWithGeneratedProto(/*siblingRepoLayout=*/ true, true);
}

@Test
public void testExternalRepoWithGeneratedProto_withSubdirRepoLayoutAndNoVritualImports()
throws Exception {
testExternalRepoWithGeneratedProto(/*siblingRepoLayout=*/ false, false);
}

@Test
public void test_siblingRepoLayout_externalRepoWithGeneratedProtoAndNoVritualImports()
throws Exception {
testExternalRepoWithGeneratedProto(/*siblingRepoLayout=*/ true, false);
}

@Test
Expand Down

0 comments on commit cbad324

Please sign in to comment.