feat: add more generic helpers #19

Merged: 1 commit, Aug 2, 2021
337 changes: 336 additions & 1 deletion testbench/common.py
@@ -14,9 +14,18 @@

"""Common utils"""

import base64
import json
import random
import re
import scalpl
import struct
import types
from flask import Response as FlaskResponse
from google.protobuf import timestamp_pb2
from requests_toolbelt import MultipartDecoder
from requests_toolbelt.multipart.decoder import ImproperBodyPartContentException

import testbench

re_remove_index = re.compile(r"\[\d+\]+|^[0-9]+")
retry_return_error_code = re.compile(r"return-([0-9]+)$")
@@ -34,6 +43,10 @@ def to_snake_case(string):
return re_snake_case.sub("_", string).lower()


def remove_index(string):
return re_remove_index.sub("", string)
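
# For example (illustrative), remove_index("items[0].acl") == "items.acl"; this
# is used below to compare scalpl-style keys against REST field paths.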


# === FAKE REQUEST === #


@@ -59,3 +72,325 @@ class FakeRequest(types.SimpleNamespace):

def __init__(self, **kwargs):
super().__init__(**kwargs)

@classmethod
def init_xml(cls, request):
headers = {
key.lower(): value
for key, value in request.headers.items()
if key.lower().startswith("x-goog-") or key.lower() == "range"
}
args = request.args.to_dict()
args.update(cls.xml_headers_to_json_args(headers))
return cls(args=args, headers=headers)

@classmethod
def xml_headers_to_json_args(cls, headers):
field_map = {
"x-goog-if-generation-match": "ifGenerationMatch",
"x-goog-if-metageneration-match": "ifMetagenerationMatch",
"x-goog-acl": "predefinedAcl",
}
args = {}
for field_xml, field_json in field_map.items():
if field_xml in headers:
args[field_json] = headers[field_xml]
return args

def HasField(self, field):
return hasattr(self, field) and getattr(self, field) is not None

@classmethod
def init_protobuf(cls, request, context):
fake_request = FakeRequest(args={}, headers={})
fake_request.update_protobuf(request, context)
return fake_request

def update_protobuf(self, request, context):
for (
proto_field,
args_field,
) in FakeRequest.protobuf_wrapper_to_json_args.items():
if hasattr(request, proto_field) and request.HasField(proto_field):
self.args[args_field] = getattr(request, proto_field).value
setattr(self, proto_field, getattr(request, proto_field))
for (
proto_field,
args_field,
) in FakeRequest.protobuf_scalar_to_json_args.items():
if hasattr(request, proto_field):
self.args[args_field] = getattr(request, proto_field)
setattr(self, proto_field, getattr(request, proto_field))
csek_field = "common_object_request_params"
if hasattr(request, csek_field):
algorithm, key_b64, key_sha256_b64 = testbench.csek.extract(
request, False, context
)
self.headers["x-goog-encryption-algorithm"] = algorithm
self.headers["x-goog-encryption-key"] = key_b64
self.headers["x-goog-encryption-key-sha256"] = key_sha256_b64
setattr(self, csek_field, getattr(request, csek_field))
elif not hasattr(self, csek_field):
setattr(
self,
csek_field,
types.SimpleNamespace(
encryption_algorithm="", encryption_key="", encryption_key_sha256=""
),
)
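
# Illustrative example (header values made up): for an XML API request with
# headers {"x-goog-if-generation-match": "1234", "x-goog-acl": "private"},
# FakeRequest.init_xml() produces args containing
#   {"ifGenerationMatch": "1234", "predefinedAcl": "private"}
# in addition to any query parameters already present on the request.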


# === REST === #


def nested_key(data):
    # Takes a dict (or list) and returns a list of keys that work with the `scalpl` library.
if isinstance(data, list):
keys = []
for i in range(len(data)):
result = nested_key(data[i])
if isinstance(result, list):
if isinstance(data[i], dict):
keys.extend(["[%d].%s" % (i, item) for item in result])
elif isinstance(data[i], list):
keys.extend(["[%d]%s" % (i, item) for item in result])
elif result == "":
keys.append("[%d]" % i)
return keys
elif isinstance(data, dict):
keys = []
for key, value in data.items():
result = nested_key(value)
if isinstance(result, list):
if isinstance(value, dict):
keys.extend(["%s.%s" % (key, item) for item in result])
elif isinstance(value, list):
keys.extend(["%s%s" % (key, item) for item in result])
elif result == "":
keys.append("%s" % key)
return keys
else:
return ""


def parse_fields(fields):
# "kind,items(id,name)" -> ["kind", "items.id", "items.name"]
res = []
for i, c in enumerate(fields):
if c != " " and c != ")":
if c == "/":
res.append(".")
else:
res.append(c)
elif c == ")":
childrens_fields = []
tmp_field = []
while res:
if res[-1] != "," and res[-1] != "(":
tmp_field.append(res.pop())
else:
childrens_fields.append(tmp_field)
tmp_field = []
if res.pop() == "(":
break
parent_field = []
while res and res[-1] != "," and res[-1] != "(":
parent_field.append(res.pop())
for i, field in enumerate(childrens_fields):
res.extend(parent_field[::-1])
res.append(".")
while field:
res.append(field.pop())
if i < len(childrens_fields) - 1:
res.append(",")
return "".join(res).split(",")


def filter_response_rest(response, projection, fields):
if fields is not None:
fields = parse_fields(fields)
deleted_keys = set()
for key in nested_key(response):
        simplified_key = remove_index(key)
        maybe_delete = True
        if projection == "noAcl":
            maybe_delete = False
            if simplified_key.startswith("owner"):
                deleted_keys.add("owner")
            elif simplified_key.startswith("items.owner"):
                deleted_keys.add(key[0 : key.find("owner") + len("owner")])
            elif simplified_key.startswith("acl"):
                deleted_keys.add("acl")
            elif simplified_key.startswith("items.acl"):
                deleted_keys.add(key[0 : key.find("acl") + len("acl")])
            elif simplified_key.startswith("defaultObjectAcl"):
                deleted_keys.add("defaultObjectAcl")
            elif simplified_key.startswith("items.defaultObjectAcl"):
                deleted_keys.add(
                    key[0 : key.find("defaultObjectAcl") + len("defaultObjectAcl")]
                )
            else:
                maybe_delete = True
        if fields is not None:
            if maybe_delete:
                for field in fields:
                    if field != "" and simplified_key.startswith(field):
                        maybe_delete = False
                        break
                if maybe_delete:
                    deleted_keys.add(key)
proxy = scalpl.Cut(response)
for key in deleted_keys:
del proxy[key]
return proxy.data
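
# Illustrative example: with projection "noAcl" the ACL and owner keys are
# dropped, and a `fields` expression prunes everything else, e.g.
#   filter_response_rest(
#       {"kind": "k", "items": [{"name": "n", "updated": "u", "acl": [{"entity": "allUsers"}]}]},
#       "noAcl",
#       "kind,items(name)",
#   ) == {"kind": "k", "items": [{"name": "n"}]}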


def parse_multipart(request):
content_type = request.headers.get("content-type")
if content_type is None or not content_type.startswith("multipart/related"):
testbench.error.invalid("Content-type header in multipart upload", None)
_, _, boundary = content_type.partition("boundary=")
    if not boundary:
testbench.error.missing(
"boundary in content-type header in multipart upload", None
)

body = extract_media(request)
try:
decoder = MultipartDecoder(body, content_type)
except ImproperBodyPartContentException as e:
testbench.error.invalid("Multipart body is malformed\n%s" % str(body), None)
if len(decoder.parts) != 2:
testbench.error.invalid("Multipart body is malformed\n%s" % str(body), None)
resource = decoder.parts[0].text
metadata = json.loads(resource)
content_type_key = "content-type".encode("utf-8")
headers = decoder.parts[1].headers
content_type = {"content-type": headers[content_type_key].decode("utf-8")}
media = decoder.parts[1].content
return metadata, content_type, media
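
# Illustrative sketch (the boundary name is made up): parse_multipart() expects
# a body with CRLF line endings such as
#   --boundary_example
#   content-type: application/json
#
#   {"name": "object-name"}
#   --boundary_example
#   content-type: text/plain
#
#   media payload
#   --boundary_example--
# sent with "content-type: multipart/related; boundary=boundary_example", and
# returns ({"name": "object-name"}, {"content-type": "text/plain"}, b"media payload").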


def extract_media(request):
"""Extract the media from a flask Request.

    To avoid race conditions when using greenlets we cannot perform I/O in the
    constructor of GcsObject, or in any of the operations that modify the state
    of the service. Because the media is sometimes uploaded with chunked encoding,
    we need to do this I/O before finishing the GcsObject creation; if we did it
    after the GcsObject creation has started, the state of the application may
    change due to other I/O.

    :param request: flask.Request the HTTP request.
:return: the full media of the request.
:rtype: str
"""
if request.environ.get("HTTP_TRANSFER_ENCODING", "") == "chunked":
return request.environ.get("wsgi.input").read()
return request.data


# === RESPONSE === #


def extract_projection(request, default, context):
if context is not None:
return request.projection if request.projection != 0 else default
else:
projection_map = ["noAcl", "full"]
projection = request.args.get("projection")
return (
projection if projection in projection_map else projection_map[default - 1]
)
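
# Illustrative example: for a REST request "?projection=full" yields "full",
# while a missing or invalid value falls back to the default, e.g. a default of
# 1 maps to projection_map[0] == "noAcl".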


# === DATA === #


def corrupt_media(media):
    if not media:
        # Return a single random lowercase letter so the checksum still changes.
        return bytearray(random.choice("abcdefghijklmnopqrstuvwxyz"), "utf-8")
    # Flip the first byte so the corrupted media no longer matches its checksums.
    return b"B" + media[1:] if media[0:1] == b"A" else b"A" + media[1:]


# === HEADERS === #


def extract_instruction(request, context):
instruction = None
if context is not None:
if hasattr(context, "invocation_metadata"):
for key, value in context.invocation_metadata():
if key == "x-goog-emulator-instructions":
instruction = value
else:
instruction = request.headers.get("x-goog-emulator-instructions")
if instruction is None:
instruction = request.headers.get("x-goog-testbench-instructions")
return instruction
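
# Illustrative example: a REST request carrying the header
#   x-goog-emulator-instructions: return-503
# yields "return-503"; gRPC callers pass the same value through the invocation
# metadata instead.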


def enforce_patch_override(request):
if (
request.method == "POST"
and request.headers.get("X-Http-Method-Override", "") != "PATCH"
):
testbench.error.notallowed(context=None)


def rest_crc32c_to_proto(crc32c):
"""Convert from the REST representation of crc32c checksums to the proto representation.

REST uses base64 encoded 32-bit big endian integers, while protos use just `int32`.
"""
return struct.unpack(">I", base64.b64decode(crc32c.encode("utf-8")))[0]


def rest_crc32c_from_proto(crc32c):
"""Convert from the gRPC/proto representation of crc32c checksums to the REST representation.

REST uses base64 encoded 32-bit big endian integers, while protos use just `int32`.
"""
return base64.b64encode(struct.pack(">I", crc32c)).decode("utf-8")
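
# Illustrative example: the REST form is the big-endian bytes of the checksum,
# base64 encoded, so rest_crc32c_from_proto(0) == "AAAAAA==" and
# rest_crc32c_to_proto("AAAAAA==") == 0.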


def rest_rfc3339_to_proto(rfc3339):
"""Convert a RFC3339 timestamp to the google.protobuf.Timestamp format."""
ts = timestamp_pb2.Timestamp()
ts.FromJsonString(rfc3339)
return ts
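
# Illustrative example: rest_rfc3339_to_proto("2021-08-02T12:34:56Z") returns a
# google.protobuf.Timestamp whose ToJsonString() round-trips to the same value.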


def rest_adjust(data, adjustments):
"""
    Apply per-key 'actions' to a dictionary *if* the key is present.

When mapping between the gRPC and the REST representations of resources
(Bucket, Object, etc.) we sometimes need to change the name and/or format
of some fields.

    The `adjustments` mapping describes which keys (if present) need adjustment;
    each value is a function that returns the new key and value for the item in `data`.

Parameters
----------
data : dict
A dictionary, typically the REST representation of a resource
adjustments : dict
The keys in `data` that, if present, need adjustment. The values
in this dictionary are functions returning a (key, value) tuple
that replaces the existing tuple in `data`.

Returns
-------
dict
A copy of `data` with the changes prescribed by `adjustments`.
"""
modified = data.copy()
for key, action in adjustments.items():
value = modified.pop(key, None)
if value is not None:
k, v = action(value)
if k is not None:
modified[k] = v
return modified
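
# Illustrative example (the key names are made up): rename or reshape a field
# only when it is present, e.g.
#   rest_adjust(
#       {"name": "bucket-1", "timeCreated": "2021-08-02T12:00:00Z"},
#       {"name": lambda x: ("bucket_name", x), "missing": lambda x: ("other", x)},
#   ) == {"timeCreated": "2021-08-02T12:00:00Z", "bucket_name": "bucket-1"}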