diff --git a/gapic/schema/api.py b/gapic/schema/api.py index 92c2b741c0..44c7c9854b 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -19,12 +19,15 @@ import collections import dataclasses +import itertools import keyword import os import sys from typing import Callable, Container, Dict, FrozenSet, Mapping, Optional, Sequence, Set, Tuple +from types import MappingProxyType from google.api_core import exceptions # type: ignore +from google.api import resource_pb2 from google.longrunning import operations_pb2 # type: ignore from google.protobuf import descriptor_pb2 @@ -58,11 +61,14 @@ def __getattr__(self, name: str): @classmethod def build( - cls, file_descriptor: descriptor_pb2.FileDescriptorProto, - file_to_generate: bool, naming: api_naming.Naming, - opts: Options = Options(), - prior_protos: Mapping[str, 'Proto'] = None, - load_services: bool = True + cls, + file_descriptor: descriptor_pb2.FileDescriptorProto, + file_to_generate: bool, + naming: api_naming.Naming, + opts: Options = Options(), + prior_protos: Mapping[str, 'Proto'] = None, + load_services: bool = True, + all_resources: Optional[Mapping[str, wrappers.CommonResource]] = None, ) -> 'Proto': """Build and return a Proto instance. @@ -85,7 +91,8 @@ def build( naming=naming, opts=opts, prior_protos=prior_protos or {}, - load_services=load_services + load_services=load_services, + all_resources=all_resources or {}, ).proto @cached_property @@ -104,6 +111,24 @@ def messages(self) -> Mapping[str, wrappers.MessageType]: if not v.meta.address.parent ) + @cached_property + def resource_messages(self) -> Mapping[str, wrappers.MessageType]: + """Return the file level resources of the proto.""" + file_resource_messages = ( + (res.type, wrappers.CommonResource.build(res).message_type) + for res in self.file_pb2.options.Extensions[resource_pb2.resource_definition] + ) + resource_messages = ( + (msg.options.Extensions[resource_pb2.resource].type, msg) + for msg in self.messages.values() + if msg.options.Extensions[resource_pb2.resource].type + ) + return collections.OrderedDict( + itertools.chain( + file_resource_messages, resource_messages, + ) + ) + @property def module_name(self) -> str: """Return the appropriate module name for this service. @@ -264,6 +289,13 @@ def disambiguate_keyword_fname( load_services=False, ) + # A file descriptor's file-level resources are NOT visible to any importers. + # The only way to make referenced resources visible is to aggregate them at + # the API level and then pass that around. + all_file_resources = collections.ChainMap( + *(proto.resource_messages for proto in pre_protos.values()) + ) + # Second pass uses all the messages and enums defined in the entire API. # This allows LRO returning methods to see all the types in the API, # bypassing the above missing import problem. @@ -274,6 +306,7 @@ def disambiguate_keyword_fname( naming=naming, opts=opts, prior_protos=pre_protos, + all_resources=MappingProxyType(all_file_resources), ) for name, proto in pre_protos.items() } @@ -390,7 +423,8 @@ def __init__( naming: api_naming.Naming, opts: Options = Options(), prior_protos: Mapping[str, Proto] = None, - load_services: bool = True + load_services: bool = True, + all_resources: Optional[Mapping[str, wrappers.CommonResource]] = None, ): self.proto_messages: Dict[str, wrappers.MessageType] = {} self.proto_enums: Dict[str, wrappers.EnumType] = {} @@ -432,9 +466,9 @@ def __init__( # below is because `repeated DescriptorProto message_type = 4;` in # descriptor.proto itself). self._load_children(file_descriptor.enum_type, self._load_enum, - address=self.address, path=(5,)) + address=self.address, path=(5,), resources=all_resources) self._load_children(file_descriptor.message_type, self._load_message, - address=self.address, path=(4,)) + address=self.address, path=(4,), resources=all_resources) # Edge case: Protocol buffers is not particularly picky about # ordering, and it is possible that a message will have had a field @@ -469,7 +503,7 @@ def __init__( # same files. if file_to_generate and load_services: self._load_children(file_descriptor.service, self._load_service, - address=self.address, path=(6,)) + address=self.address, path=(6,), resources=all_resources) # TODO(lukesneeringer): oneofs are on path 7. @property @@ -528,7 +562,8 @@ def api_messages(self) -> Mapping[str, wrappers.MessageType]: def _load_children(self, children: Sequence, loader: Callable, *, - address: metadata.Address, path: Tuple[int, ...]) -> Mapping: + address: metadata.Address, path: Tuple[int, ...], + resources: Mapping[str, wrappers.CommonResource]) -> Mapping: """Return wrapped versions of arbitrary children from a Descriptor. Args: @@ -554,7 +589,7 @@ def _load_children(self, # applicable loader function on each. answer = {} for child, i in zip(children, range(0, sys.maxsize)): - wrapped = loader(child, address=address, path=path + (i,)) + wrapped = loader(child, address=address, path=path + (i,), resources=resources) answer[wrapped.name] = wrapped return answer @@ -794,6 +829,7 @@ def _load_message(self, message_pb: descriptor_pb2.DescriptorProto, address: metadata.Address, path: Tuple[int], + resources: Mapping[str, wrappers.CommonResource], ) -> wrappers.MessageType: """Load message descriptions from DescriptorProtos.""" address = address.child(message_pb.name, path) @@ -810,12 +846,14 @@ def _load_message(self, address=address, loader=self._load_enum, path=path + (4,), + resources=resources, ) nested_messages = self._load_children( message_pb.nested_type, address=address, loader=self._load_message, path=path + (3,), + resources=resources, ) oneofs = self._get_oneofs( @@ -856,6 +894,7 @@ def _load_enum(self, enum: descriptor_pb2.EnumDescriptorProto, address: metadata.Address, path: Tuple[int], + resources: Mapping[str, wrappers.CommonResource], ) -> wrappers.EnumType: """Load enum descriptions from EnumDescriptorProtos.""" address = address.child(enum.name, path) @@ -886,6 +925,7 @@ def _load_service(self, service: descriptor_pb2.ServiceDescriptorProto, address: metadata.Address, path: Tuple[int], + resources: Mapping[str, wrappers.CommonResource], ) -> wrappers.Service: """Load comments for a service and its methods.""" address = address.child(service.name, path) @@ -905,6 +945,7 @@ def _load_service(self, ), methods=methods, service_pb=service, + visible_resources=resources, ) return self.proto_services[address.proto] diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 8fc4b8d712..da1e979908 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -33,7 +33,6 @@ from itertools import chain from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping, ClassVar, Optional, Sequence, Set, Tuple, Union) - from google.api import annotations_pb2 # type: ignore from google.api import client_pb2 from google.api import field_behavior_pb2 @@ -62,6 +61,12 @@ class Field: def __getattr__(self, name): return getattr(self.field_pb, name) + def __hash__(self): + # The only sense in which it is meaningful to say a field is equal to + # another field is if they are the same, i.e. they live in the same + # message type under the same moniker, i.e. they have the same id. + return id(self) + @property def name(self) -> str: """Used to prevent collisions with python keywords""" @@ -305,6 +310,15 @@ def recursive_field_types(self) -> Sequence[ return tuple(types) + @utils.cached_property + def recursive_fields(self) -> Sequence[Field]: + return frozenset(chain( + self.fields.values(), + (field + for t in self.recursive_field_types if isinstance(t, MessageType) + for field in t.fields.values()), + )) + @property def map(self) -> bool: """Return True if the given message is a map, False otherwise.""" @@ -860,6 +874,13 @@ class CommonResource: type_name: str pattern: str + @classmethod + def build(cls, resource: resource_pb2.ResourceDescriptor): + return cls( + type_name=resource.type, + pattern=next(iter(resource.pattern)) + ) + @utils.cached_property def message_type(self): message_pb = descriptor_pb2.DescriptorProto() @@ -880,6 +901,10 @@ class Service: """Description of a service (defined with the ``service`` keyword).""" service_pb: descriptor_pb2.ServiceDescriptorProto methods: Mapping[str, Method] + # N.B.: visible_resources is intended to be a read-only view + # whose backing store is owned by the API. + # This is represented by a types.MappingProxyType instance. + visible_resources: Mapping[str, MessageType] meta: metadata.Metadata = dataclasses.field( default_factory=metadata.Metadata, ) @@ -1021,6 +1046,14 @@ def gen_resources(message): if type_.resource_path: yield type_ + def gen_indirect_resources_used(message): + for field in message.recursive_fields: + resource = field.options.Extensions[ + resource_pb2.resource_reference] + resource_type = resource.type or resource.child_type + if resource_type: + yield self.visible_resources[resource_type] + return frozenset( msg for method in self.methods.values() @@ -1029,6 +1062,10 @@ def gen_resources(message): gen_resources( method.lro.response_type if method.lro else method.output ), + gen_indirect_resources_used(method.input), + gen_indirect_resources_used( + method.lro.response_type if method.lro else method.output + ), ) ) diff --git a/test_utils/test_utils.py b/test_utils/test_utils.py index 89fd735142..beab26518f 100644 --- a/test_utils/test_utils.py +++ b/test_utils/test_utils.py @@ -24,9 +24,16 @@ from google.protobuf import descriptor_pb2 as desc -def make_service(name: str = 'Placeholder', host: str = '', - methods: typing.Tuple[wrappers.Method] = (), - scopes: typing.Tuple[str] = ()) -> wrappers.Service: +def make_service( + name: str = "Placeholder", + host: str = "", + methods: typing.Tuple[wrappers.Method] = (), + scopes: typing.Tuple[str] = (), + visible_resources: typing.Optional[ + typing.Mapping[str, wrappers.CommonResource] + ] = None, +) -> wrappers.Service: + visible_resources = visible_resources or {} # Define a service descriptor, and set a host and oauth scopes if # appropriate. service_pb = desc.ServiceDescriptorProto(name=name) @@ -38,6 +45,7 @@ def make_service(name: str = 'Placeholder', host: str = '', return wrappers.Service( service_pb=service_pb, methods={m.name: m for m in methods}, + visible_resources=visible_resources, ) @@ -47,7 +55,8 @@ def make_service_with_method_options( *, http_rule: http_pb2.HttpRule = None, method_signature: str = '', - in_fields: typing.Tuple[desc.FieldDescriptorProto] = () + in_fields: typing.Tuple[desc.FieldDescriptorProto] = (), + visible_resources: typing.Optional[typing.Mapping[str, wrappers.CommonResource]] = None, ) -> wrappers.Service: # Declare a method with options enabled for long-running operations and # field headers. @@ -69,6 +78,7 @@ def make_service_with_method_options( return wrappers.Service( service_pb=service_pb, methods={method.name: method}, + visible_resources=visible_resources or {}, ) diff --git a/tests/unit/generator/test_generator.py b/tests/unit/generator/test_generator.py index 6a120bbe1d..a258c294cf 100644 --- a/tests/unit/generator/test_generator.py +++ b/tests/unit/generator/test_generator.py @@ -262,6 +262,7 @@ def test_get_filename_with_service(): methods=[], service_pb=descriptor_pb2.ServiceDescriptorProto( name="Eggs"), + visible_resources={}, ), }, ) diff --git a/tests/unit/samplegen/test_integration.py b/tests/unit/samplegen/test_integration.py index ca6e1a44f0..b652f1359e 100644 --- a/tests/unit/samplegen/test_integration.py +++ b/tests/unit/samplegen/test_integration.py @@ -80,7 +80,8 @@ def test_generate_sample_basic(): "classify_target": DummyField(name="classify_target") } ) - } + }, + visible_resources={}, ) schema = DummyApiSchema( @@ -216,7 +217,8 @@ def test_generate_sample_basic_unflattenable(): input=input_type, output=message_factory("$resp.taxonomy"), ) - } + }, + visible_resources={}, ) schema = DummyApiSchema( diff --git a/tests/unit/schema/test_api.py b/tests/unit/schema/test_api.py index fbe82a8f9b..5e74d653ae 100644 --- a/tests/unit/schema/test_api.py +++ b/tests/unit/schema/test_api.py @@ -18,6 +18,7 @@ import pytest from google.api import client_pb2 +from google.api import resource_pb2 from google.api_core import exceptions from google.longrunning import operations_pb2 from google.protobuf import descriptor_pb2 @@ -1079,3 +1080,131 @@ def test_enums(): assert enum.values[1].meta.doc == 'This is the one value.' assert enum.values[2].name == 'THREE' assert enum.values[2].meta.doc == '' + + +def test_file_level_resources(): + fdp = make_file_pb2( + name="nomenclature.proto", + package="nomenclature.linneaen.v1", + messages=( + make_message_pb2( + name="CreateSpeciesRequest", + fields=( + make_field_pb2(name='species', number=1, type=9), + ), + ), + make_message_pb2( + name="CreateSpeciesResponse" , + ), + ), + services=( + descriptor_pb2.ServiceDescriptorProto( + name="SpeciesService", + method=( + descriptor_pb2.MethodDescriptorProto( + name="CreateSpecies", + input_type="nomenclature.linneaen.v1.CreateSpeciesRequest", + output_type="nomenclature.linneaen.v1.CreateSpeciesResponse", + ), + ), + ), + ), + ) + res_pb2 = fdp.options.Extensions[resource_pb2.resource_definition] + definitions = [ + ("nomenclature.linnaen.com/Species", "families/{family}/genera/{genus}/species/{species}"), + ("nomenclature.linnaen.com/Phylum", "kingdoms/{kingdom}/phyla/{phylum}") + ] + for type_, pattern in definitions: + resource_definition = res_pb2.add() + resource_definition.type = type_ + resource_definition.pattern.append(pattern) + + + species_field = fdp.message_type[0].field[0] + resource_reference = species_field.options.Extensions[resource_pb2.resource_reference] + resource_reference.type = "nomenclature.linnaen.com/Species" + + api_schema = api.API.build([fdp], package='nomenclature.linneaen.v1') + actual = api_schema.protos['nomenclature.proto'].resource_messages + expected = { + "nomenclature.linnaen.com/Species": wrappers.CommonResource( + type_name="nomenclature.linnaen.com/Species", + pattern="families/{family}/genera/{genus}/species/{species}" + ).message_type, + "nomenclature.linnaen.com/Phylum": wrappers.CommonResource( + type_name="nomenclature.linnaen.com/Phylum", + pattern="kingdoms/{kingdom}/phyla/{phylum}" + ).message_type, + } + + assert actual == expected + + # The proto file _owns_ the file level resources, but the service needs to + # see them too because the client class owns all the helper methods. + service = api_schema.services["nomenclature.linneaen.v1.SpeciesService"] + actual = service.visible_resources + assert actual == expected + + # The service doesn't own any method that owns a message that references + # Phylum, so the service doesn't count it among its resource messages. + expected.pop("nomenclature.linnaen.com/Phylum") + expected = frozenset(expected.values()) + actual = service.resource_messages + + assert actual == expected + + +def test_resources_referenced_but_not_typed(reference_attr="type"): + fdp = make_file_pb2( + name="nomenclature.proto", + package="nomenclature.linneaen.v1", + messages=( + make_message_pb2( + name="Species", + ), + make_message_pb2( + name="CreateSpeciesRequest", + fields=( + make_field_pb2(name='species', number=1, type=9), + ), + ), + make_message_pb2( + name="CreateSpeciesResponse" , + ), + ), + services=( + descriptor_pb2.ServiceDescriptorProto( + name="SpeciesService", + method=( + descriptor_pb2.MethodDescriptorProto( + name="CreateSpecies", + input_type="nomenclature.linneaen.v1.CreateSpeciesRequest", + output_type="nomenclature.linneaen.v1.CreateSpeciesResponse", + ), + ), + ), + ), + ) + + # Set up the resource + species_resource_opts = fdp.message_type[0].options.Extensions[resource_pb2.resource] + species_resource_opts.type = "nomenclature.linnaen.com/Species" + species_resource_opts.pattern.append("families/{family}/genera/{genus}/species/{species}") + + # Set up the reference + name_resource_opts = fdp.message_type[1].field[0].options.Extensions[resource_pb2.resource_reference] + if reference_attr == "type": + name_resource_opts.type = species_resource_opts.type + else: + name_resource_opts.child_type = species_resource_opts.type + + api_schema = api.API.build([fdp], package="nomenclature.linneaen.v1") + expected = {api_schema.messages["nomenclature.linneaen.v1.Species"]} + actual = api_schema.services["nomenclature.linneaen.v1.SpeciesService"].resource_messages + + assert actual == expected + + +def test_resources_referenced_but_not_typed_child_type(): + test_resources_referenced_but_not_typed("child_type")