Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: file_level and indirectly used resources generate helper methods #642

Merged
merged 9 commits into from
Oct 9, 2020
69 changes: 57 additions & 12 deletions gapic/schema/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@

import collections
import dataclasses
import itertools
import keyword
import os
import sys
from typing import Callable, Container, Dict, FrozenSet, Mapping, Optional, Sequence, Set, Tuple
from types import MappingProxyType

from google.api_core import exceptions # type: ignore
from google.api import resource_pb2 # type: ignore
from google.longrunning import operations_pb2 # type: ignore
from google.protobuf import descriptor_pb2

Expand Down Expand Up @@ -58,11 +61,14 @@ def __getattr__(self, name: str):

@classmethod
def build(
cls, file_descriptor: descriptor_pb2.FileDescriptorProto,
file_to_generate: bool, naming: api_naming.Naming,
opts: Options = Options(),
prior_protos: Mapping[str, 'Proto'] = None,
load_services: bool = True
cls,
file_descriptor: descriptor_pb2.FileDescriptorProto,
file_to_generate: bool,
naming: api_naming.Naming,
opts: Options = Options(),
prior_protos: Mapping[str, 'Proto'] = None,
load_services: bool = True,
all_resources: Optional[Mapping[str, wrappers.MessageType]] = None,
) -> 'Proto':
"""Build and return a Proto instance.

Expand All @@ -85,7 +91,8 @@ def build(
naming=naming,
opts=opts,
prior_protos=prior_protos or {},
load_services=load_services
load_services=load_services,
all_resources=all_resources or {},
).proto

@cached_property
Expand All @@ -104,6 +111,24 @@ def messages(self) -> Mapping[str, wrappers.MessageType]:
if not v.meta.address.parent
)

@cached_property
def resource_messages(self) -> Mapping[str, wrappers.MessageType]:
"""Return the file level resources of the proto."""
file_resource_messages = (
(res.type, wrappers.CommonResource.build(res).message_type)
for res in self.file_pb2.options.Extensions[resource_pb2.resource_definition]
)
resource_messages = (
(msg.options.Extensions[resource_pb2.resource].type, msg)
for msg in self.messages.values()
if msg.options.Extensions[resource_pb2.resource].type
)
return collections.OrderedDict(
itertools.chain(
file_resource_messages, resource_messages,
)
)

@property
def module_name(self) -> str:
"""Return the appropriate module name for this service.
Expand Down Expand Up @@ -264,6 +289,13 @@ def disambiguate_keyword_fname(
load_services=False,
)

# A file descriptor's file-level resources are NOT visible to any importers.
# The only way to make referenced resources visible is to aggregate them at
# the API level and then pass that around.
all_file_resources = collections.ChainMap(
*(proto.resource_messages for proto in pre_protos.values())
)

# Second pass uses all the messages and enums defined in the entire API.
# This allows LRO returning methods to see all the types in the API,
# bypassing the above missing import problem.
Expand All @@ -274,6 +306,7 @@ def disambiguate_keyword_fname(
naming=naming,
opts=opts,
prior_protos=pre_protos,
all_resources=MappingProxyType(all_file_resources),
)
for name, proto in pre_protos.items()
}
Expand Down Expand Up @@ -390,7 +423,8 @@ def __init__(
naming: api_naming.Naming,
opts: Options = Options(),
prior_protos: Mapping[str, Proto] = None,
load_services: bool = True
load_services: bool = True,
all_resources: Optional[Mapping[str, wrappers.MessageType]] = None,
):
self.proto_messages: Dict[str, wrappers.MessageType] = {}
self.proto_enums: Dict[str, wrappers.EnumType] = {}
Expand Down Expand Up @@ -432,9 +466,11 @@ def __init__(
# below is because `repeated DescriptorProto message_type = 4;` in
# descriptor.proto itself).
self._load_children(file_descriptor.enum_type, self._load_enum,
address=self.address, path=(5,))
address=self.address, path=(5,),
resources=all_resources or {})
self._load_children(file_descriptor.message_type, self._load_message,
address=self.address, path=(4,))
address=self.address, path=(4,),
resources=all_resources or {})

# Edge case: Protocol buffers is not particularly picky about
# ordering, and it is possible that a message will have had a field
Expand Down Expand Up @@ -469,7 +505,8 @@ def __init__(
# same files.
if file_to_generate and load_services:
self._load_children(file_descriptor.service, self._load_service,
address=self.address, path=(6,))
address=self.address, path=(6,),
resources=all_resources or {})
# TODO(lukesneeringer): oneofs are on path 7.

@property
Expand Down Expand Up @@ -528,7 +565,8 @@ def api_messages(self) -> Mapping[str, wrappers.MessageType]:

def _load_children(self,
children: Sequence, loader: Callable, *,
address: metadata.Address, path: Tuple[int, ...]) -> Mapping:
address: metadata.Address, path: Tuple[int, ...],
resources: Mapping[str, wrappers.MessageType]) -> Mapping:
"""Return wrapped versions of arbitrary children from a Descriptor.

Args:
Expand All @@ -554,7 +592,8 @@ def _load_children(self,
# applicable loader function on each.
answer = {}
for child, i in zip(children, range(0, sys.maxsize)):
wrapped = loader(child, address=address, path=path + (i,))
wrapped = loader(child, address=address, path=path + (i,),
resources=resources)
answer[wrapped.name] = wrapped
return answer

Expand Down Expand Up @@ -794,6 +833,7 @@ def _load_message(self,
message_pb: descriptor_pb2.DescriptorProto,
address: metadata.Address,
path: Tuple[int],
resources: Mapping[str, wrappers.MessageType],
) -> wrappers.MessageType:
"""Load message descriptions from DescriptorProtos."""
address = address.child(message_pb.name, path)
Expand All @@ -810,12 +850,14 @@ def _load_message(self,
address=address,
loader=self._load_enum,
path=path + (4,),
resources=resources,
)
nested_messages = self._load_children(
message_pb.nested_type,
address=address,
loader=self._load_message,
path=path + (3,),
resources=resources,
)

oneofs = self._get_oneofs(
Expand Down Expand Up @@ -856,6 +898,7 @@ def _load_enum(self,
enum: descriptor_pb2.EnumDescriptorProto,
address: metadata.Address,
path: Tuple[int],
resources: Mapping[str, wrappers.MessageType],
) -> wrappers.EnumType:
"""Load enum descriptions from EnumDescriptorProtos."""
address = address.child(enum.name, path)
Expand Down Expand Up @@ -886,6 +929,7 @@ def _load_service(self,
service: descriptor_pb2.ServiceDescriptorProto,
address: metadata.Address,
path: Tuple[int],
resources: Mapping[str, wrappers.MessageType],
) -> wrappers.Service:
"""Load comments for a service and its methods."""
address = address.child(service.name, path)
Expand All @@ -905,6 +949,7 @@ def _load_service(self,
),
methods=methods,
service_pb=service,
visible_resources=resources,
)
return self.proto_services[address.proto]

Expand Down
39 changes: 38 additions & 1 deletion gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from itertools import chain
from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping,
ClassVar, Optional, Sequence, Set, Tuple, Union)

from google.api import annotations_pb2 # type: ignore
from google.api import client_pb2
from google.api import field_behavior_pb2
Expand Down Expand Up @@ -62,6 +61,12 @@ class Field:
def __getattr__(self, name):
return getattr(self.field_pb, name)

def __hash__(self):
# The only sense in which it is meaningful to say a field is equal to
# another field is if they are the same, i.e. they live in the same
# message type under the same moniker, i.e. they have the same id.
return id(self)

@property
def name(self) -> str:
"""Used to prevent collisions with python keywords"""
Expand Down Expand Up @@ -305,6 +310,15 @@ def recursive_field_types(self) -> Sequence[

return tuple(types)

@utils.cached_property
def recursive_fields(self) -> FrozenSet[Field]:
return frozenset(chain(
self.fields.values(),
(field
for t in self.recursive_field_types if isinstance(t, MessageType)
for field in t.fields.values()),
))

@property
def map(self) -> bool:
"""Return True if the given message is a map, False otherwise."""
Expand Down Expand Up @@ -860,6 +874,13 @@ class CommonResource:
type_name: str
pattern: str

@classmethod
def build(cls, resource: resource_pb2.ResourceDescriptor):
return cls(
type_name=resource.type,
pattern=next(iter(resource.pattern))
)

@utils.cached_property
def message_type(self):
message_pb = descriptor_pb2.DescriptorProto()
Expand All @@ -880,6 +901,10 @@ class Service:
"""Description of a service (defined with the ``service`` keyword)."""
service_pb: descriptor_pb2.ServiceDescriptorProto
methods: Mapping[str, Method]
# N.B.: visible_resources is intended to be a read-only view
# whose backing store is owned by the API.
# This is represented by a types.MappingProxyType instance.
visible_resources: Mapping[str, MessageType]
meta: metadata.Metadata = dataclasses.field(
default_factory=metadata.Metadata,
)
Expand Down Expand Up @@ -1021,6 +1046,14 @@ def gen_resources(message):
if type_.resource_path:
yield type_

def gen_indirect_resources_used(message):
for field in message.recursive_fields:
resource = field.options.Extensions[
resource_pb2.resource_reference]
resource_type = resource.type or resource.child_type
if resource_type:
yield self.visible_resources[resource_type]

return frozenset(
msg
for method in self.methods.values()
Expand All @@ -1029,6 +1062,10 @@ def gen_resources(message):
gen_resources(
method.lro.response_type if method.lro else method.output
),
gen_indirect_resources_used(method.input),
gen_indirect_resources_used(
method.lro.response_type if method.lro else method.output
),
)
)

Expand Down
18 changes: 14 additions & 4 deletions test_utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,16 @@
from google.protobuf import descriptor_pb2 as desc


def make_service(name: str = 'Placeholder', host: str = '',
methods: typing.Tuple[wrappers.Method] = (),
scopes: typing.Tuple[str] = ()) -> wrappers.Service:
def make_service(
name: str = "Placeholder",
host: str = "",
methods: typing.Tuple[wrappers.Method] = (),
scopes: typing.Tuple[str] = (),
visible_resources: typing.Optional[
typing.Mapping[str, wrappers.CommonResource]
] = None,
) -> wrappers.Service:
visible_resources = visible_resources or {}
# Define a service descriptor, and set a host and oauth scopes if
# appropriate.
service_pb = desc.ServiceDescriptorProto(name=name)
Expand All @@ -38,6 +45,7 @@ def make_service(name: str = 'Placeholder', host: str = '',
return wrappers.Service(
service_pb=service_pb,
methods={m.name: m for m in methods},
visible_resources=visible_resources,
)


Expand All @@ -47,7 +55,8 @@ def make_service_with_method_options(
*,
http_rule: http_pb2.HttpRule = None,
method_signature: str = '',
in_fields: typing.Tuple[desc.FieldDescriptorProto] = ()
in_fields: typing.Tuple[desc.FieldDescriptorProto] = (),
visible_resources: typing.Optional[typing.Mapping[str, wrappers.CommonResource]] = None,
) -> wrappers.Service:
# Declare a method with options enabled for long-running operations and
# field headers.
Expand All @@ -69,6 +78,7 @@ def make_service_with_method_options(
return wrappers.Service(
service_pb=service_pb,
methods={method.name: method},
visible_resources=visible_resources or {},
)


Expand Down
1 change: 1 addition & 0 deletions tests/unit/generator/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def test_get_filename_with_service():
methods=[],
service_pb=descriptor_pb2.ServiceDescriptorProto(
name="Eggs"),
visible_resources={},
),
},
)
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/samplegen/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def test_generate_sample_basic():
"classify_target": DummyField(name="classify_target")
}
)
}
},
visible_resources={},
)

schema = DummyApiSchema(
Expand Down Expand Up @@ -216,7 +217,8 @@ def test_generate_sample_basic_unflattenable():
input=input_type,
output=message_factory("$resp.taxonomy"),
)
}
},
visible_resources={},
)

schema = DummyApiSchema(
Expand Down
Loading