Skip to content

Commit

Permalink
feat: adds dynamic routing. (#1135)
Browse files Browse the repository at this point in the history
feat: adds dynamic routing files.
  • Loading branch information
atulep authored Feb 3, 2022
1 parent 5df9733 commit 8c191a5
Show file tree
Hide file tree
Showing 17 changed files with 1,205 additions and 107 deletions.
175 changes: 130 additions & 45 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@
import re
from itertools import chain
from typing import (Any, cast, Dict, FrozenSet, Iterator, Iterable, List, Mapping,
ClassVar, Optional, Sequence, Set, Tuple, Union)
ClassVar, Optional, Sequence, Set, Tuple, Union, Pattern)
from google.api import annotations_pb2 # type: ignore
from google.api import client_pb2
from google.api import field_behavior_pb2
from google.api import http_pb2
from google.api import resource_pb2
from google.api import routing_pb2
from google.api_core import exceptions
from google.api_core import path_template
from google.cloud import extended_operations_pb2 as ex_ops_pb2 # type: ignore
Expand All @@ -47,6 +48,7 @@

from gapic import utils
from gapic.schema import metadata
from gapic.utils import uri_sample


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -763,6 +765,118 @@ class RetryInfo:
retryable_exceptions: FrozenSet[exceptions.GoogleAPICallError]


@dataclasses.dataclass(frozen=True)
class RoutingParameter:
field: str
path_template: str

def _split_into_segments(self, path_template):
segments = path_template.split("/")
named_segment_ids = [i for i, x in enumerate(
segments) if "{" in x or "}" in x]
# bar/{foo}/baz, bar/{foo=one/two/three}/baz.
assert len(named_segment_ids) <= 2
if len(named_segment_ids) == 2:
# Need to merge a named segment.
i, j = named_segment_ids
segments = (
segments[:i] +
[self._merge_segments(segments[i: j + 1])] + segments[j + 1:]
)
return segments

def _convert_segment_to_regex(self, segment):
# Named segment
if "{" in segment:
assert "}" in segment
# Strip "{" and "}"
segment = segment[1:-1]
if "=" not in segment:
# e.g. {foo} should be {foo=*}
return self._convert_segment_to_regex("{" + f"{segment}=*" + "}")
key, sub_path_template = segment.split("=")
group_name = f"?P<{key}>"
sub_regex = self._convert_to_regex(sub_path_template)
return f"({group_name}{sub_regex})"
# Wildcards
if "**" in segment:
# ?: nameless capture
return ".*"
if "*" in segment:
return "[^/]+"
# Otherwise it's collection ID segment: transformed identically.
return segment

def _merge_segments(self, segments):
acc = segments[0]
for x in segments[1:]:
# Don't add "/" if it's followed by a "**"
# because "**" will eat it.
if x == ".*":
acc += "(?:/.*)?"
else:
acc += "/"
acc += x
return acc

def _how_many_named_segments(self, path_template):
return path_template.count("{")

def _convert_to_regex(self, path_template):
if self._how_many_named_segments(path_template) > 1:
# This also takes care of complex patterns (i.e. {foo}~{bar})
raise ValueError("There must be exactly one named segment. {} has {}.".format(
path_template, self._how_many_named_segments(path_template)))
segments = self._split_into_segments(path_template)
segment_regexes = [self._convert_segment_to_regex(x) for x in segments]
final_regex = self._merge_segments(segment_regexes)
return final_regex

def _to_regex(self, path_template: str) -> Pattern:
"""Converts path_template into a Python regular expression string.
Args:
path_template (str): A path template corresponding to a resource name.
It can only have 0 or 1 named segments. It can not contain complex resource ID path segments.
See https://google.aip.dev/122, https://google.aip.dev/4222
and https://google.aip.dev/client-libraries/4231 for more details.
Returns:
Pattern: A Pattern object that matches strings conforming to the path_template.
"""
return re.compile(f"^{self._convert_to_regex(path_template)}$")

def to_regex(self) -> Pattern:
return self._to_regex(self.path_template)

@property
def key(self) -> Union[str, None]:
if self.path_template == "":
return self.field
regex = self.to_regex()
group_names = list(regex.groupindex)
# Only 1 named segment is allowed and so only 1 key.
return group_names[0] if group_names else self.field

@property
def sample_request(self) -> str:
"""return json dict for sample request matching the uri template."""
sample = uri_sample.sample_from_path_template(
self.field, self.path_template)
return json.dumps(sample)


@dataclasses.dataclass(frozen=True)
class RoutingRule:
routing_parameters: List[RoutingParameter]

@classmethod
def try_parse_routing_rule(cls, routing_rule: routing_pb2.RoutingRule) -> Optional['RoutingRule']:
params = getattr(routing_rule, 'routing_parameters')
if not params:
return None
params = [RoutingParameter(x.field, x.path_template) for x in params]
return cls(params)


@dataclasses.dataclass(frozen=True)
class HttpRule:
"""Representation of the method's http bindings."""
Expand All @@ -788,59 +902,18 @@ def sample_from_path_fields(paths: List[Tuple[Field, str, str]]) -> Dict[str, An
Returns:
A new nested dict with the templates instantiated.
"""

request: Dict[str, Any] = {}

def _sample_names() -> Iterator[str]:
sample_num: int = 0
while True:
sample_num += 1
yield "sample{}".format(sample_num)

def add_field(obj, path, value):
"""Insert a field into a nested dict and return the (outer) dict.
Keys and sub-dicts are inserted if necessary to create the path.
e.g. if obj, as passed in, is {}, path is "a.b.c", and value is
"hello", obj will be updated to:
{'a':
{'b':
{
'c': 'hello'
}
}
}
Args:
obj: a (possibly) nested dict (parsed json)
path: a segmented field name, e.g. "a.b.c"
where each part is a dict key.
value: the value of the new key.
Returns:
obj, possibly modified
Raises:
AttributeError if the path references a key that is
not a dict.: e.g. path='a.b', obj = {'a':'abc'}
"""

segments = path.split('.')
leaf = segments.pop()
subfield = obj
for segment in segments:
subfield = subfield.setdefault(segment, {})
subfield[leaf] = value
return obj

sample_names = _sample_names()
sample_names_ = uri_sample.sample_names()
for field, path, template in paths:
sample_value = re.sub(
r"(\*\*|\*)",
lambda n: next(sample_names),
lambda n: next(sample_names_),
template or '*'
) if field.type == PrimitiveType.build(str) else field.mock_value_original_type
add_field(request, path, sample_value)
uri_sample.add_field(request, path, sample_value)

return request

sample = sample_from_path_fields(self.path_fields(method))
return sample

Expand Down Expand Up @@ -982,6 +1055,18 @@ def field_headers(self) -> Sequence[str]:

return next((tuple(pattern.findall(verb)) for verb in potential_verbs if verb), ())

@property
def explicit_routing(self):
return routing_pb2.routing in self.options.Extensions

@property
def routing_rule(self):
if self.explicit_routing:
routing_ext = self.options.Extensions[routing_pb2.routing]
routing_rule = RoutingRule.try_parse_routing_rule(routing_ext)
return routing_rule
return None

@property
def http_options(self) -> List[HttpRule]:
"""Return a list of the http bindings for this method."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,31 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = self._transport._wrapped_methods[self._transport.{{ method.name|snake_case}}]
{% if method.field_headers %}

{% if method.explicit_routing %}
header_params = {}
{% for routing_param in method.routing_rule.routing_parameters %}
{% if routing_param.path_template %} {# Need to match. #}

routing_param_regex = {{ routing_param.to_regex() }}
regex_match = routing_param_regex.match(request.{{ routing_param.field }})
if regex_match and regex_match.group("{{ routing_param.key }}"):
header_params["{{ routing_param.key }}"] = regex_match.group("{{ routing_param.key }}")

{% else %}

if request.{{ routing_param.field }}:
header_params["{{ routing_param.key }}"] = request.{{ routing_param.field }}

{% endif %}
{% endfor %} {# method.routing_rule.routing_parameters #}

if header_params:
metadata = tuple(metadata) + (
gapic_v1.routing_header.to_grpc_metadata(header_params),
)

{% elif method.field_headers %} {# implicit routing #}
# Certain fields should be provided within the metadata header;
# add these here.
metadata = tuple(metadata) + (
Expand All @@ -491,7 +514,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
{% endfor %}
)),
)
{% endif %}
{% endif %} {# method.explicit_routing #}

# Send the request.
{%+ if not method.void %}response = {% endif %}rpc(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,45 @@ async def test_{{ method_name }}_async_from_dict():
await test_{{ method_name }}_async(request_type=dict)


{% if method.field_headers and not method.client_streaming %}
{% if method.explicit_routing %}
def test_{{ method.name|snake_case }}_routing_parameters():
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
)

{% for routing_param in method.routing_rule.routing_parameters %}
# Any value that is part of the HTTP/1.1 URI should be sent as
# a field header. Set these to a non-empty value.
request = {{ method.input.ident }}({{ routing_param.sample_request }})

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.{{ method.name|snake_case }}),
'__call__') as call:
{% if method.void %}
call.return_value = None
{% elif method.lro %}
call.return_value = operations_pb2.Operation(name='operations/op')
{% elif method.server_streaming %}
call.return_value = iter([{{ method.output.ident }}()])
{% else %}
call.return_value = {{ method.output.ident }}()
{% endif %}
client.{{ method.name|snake_case }}(request)

# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
assert args[0] == request

_, _, kw = call.mock_calls[0]
# This test doesn't assert anything useful.
assert kw['metadata']
{% endfor %}
{% endif %}


{% if method.field_headers and not method.client_streaming and not method.explicit_routing %}
def test_{{ method_name }}_field_headers():
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
Expand Down
Loading

0 comments on commit 8c191a5

Please sign in to comment.