Skip to content

Commit

Permalink
moving some acr specific recording infra to the general replacers (#1…
Browse files Browse the repository at this point in the history
…8716)

* moving some acr specific recording infra to the general replacers

* removing unneeded stuff

* removing a print statement
  • Loading branch information
seankane-msft authored May 19, 2021
1 parent c533d8e commit 03b6d21
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 64 deletions.
73 changes: 10 additions & 63 deletions sdk/containerregistry/azure-containerregistry/tests/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@
import copy
from datetime import datetime
import json
import logging
import os
from azure_devtools.scenario_tests.recording_processors import SubscriptionRecordingProcessor
import pytest
import re
import six
import time

Expand All @@ -19,31 +16,19 @@
ContainerRegistryClient,
ArtifactTagProperties,
ContentProperties,
ArtifactManifestProperties,
)

from azure.core.credentials import AccessToken
from azure.mgmt.containerregistry import ContainerRegistryManagementClient
from azure.mgmt.containerregistry.models import (
ImportImageParameters,
ImportSource,
ImportMode
)
from azure.mgmt.containerregistry.models import ImportImageParameters, ImportSource, ImportMode
from azure.identity import DefaultAzureCredential

from devtools_testutils import AzureTestCase, is_live
from azure_devtools.scenario_tests import (
GeneralNameReplacer,
RequestUrlNormalizer,
AuthenticationMetadataFilter,
RecordingProcessor,
)
from azure_devtools.scenario_tests import (
GeneralNameReplacer,
RequestUrlNormalizer,
AuthenticationMetadataFilter,
OAuthRequestResponsesFilter,
RecordingProcessor,
)
from azure_devtools.scenario_tests import RecordingProcessor


REDACTED = "REDACTED"
Expand Down Expand Up @@ -78,11 +63,8 @@ def process_request(self, request):
class AcrBodyReplacer(RecordingProcessor):
"""Replace request body for oauth2 exchanges"""

def __init__(self, replacement="redacted"):
self._replacement = replacement
def __init__(self):
self._401_replacement = 'Bearer realm="https://fake_url.azurecr.io/oauth2/token",service="fake_url.azurecr.io",scope="fake_scope",error="invalid_token"'
self._redacted_service = "https://fakeurl.azurecr.io"
self._regex = r"(https://)[a-zA-Z0-9]+(\.azurecr.io)"

def _scrub_body(self, body):
# type: (bytes) -> bytes
Expand All @@ -109,29 +91,16 @@ def _scrub_body_dict(self, body):
for k in ["access_token", "refresh_token"]:
if k in new_body.keys():
new_body[k] = REDACTED
if "service" in new_body.keys():
new_body["service"] = "fake_url.azurecr.io"
return new_body

def process_request(self, request):
if request.body:
request.body = self._scrub_body(request.body)

if "seankane.azurecr.io" in request.uri:
request.uri = request.uri.replace("seankane.azurecr.io", "fake_url.azurecr.io")
if "seankane.azurecr.io" in request.url:
request.url = request.url.replace("seankane.azurecr.io", "fake_url.azurecr.io")

if "seankaneanon.azurecr.io" in request.uri:
request.uri = request.uri.replace("seankaneanon.azurecr.io", "fake_url.azurecr.io")
if "seankaneanon.azurecr.io" in request.url:
request.url = request.url.replace("seankaneanon.azurecr.io", "fake_url.azurecr.io")

return request

def process_response(self, response):
try:
self.process_url(response)
headers = response["headers"]

if "www-authenticate" in headers:
Expand All @@ -144,24 +113,13 @@ def process_response(self, response):
if body["string"] == b"" or body["string"] == "null":
return response

if "seankane.azurecr.io" in body["string"]:
body["string"] = body["string"].replace("seankane.azurecr.io", "fake_url.azurecr.io")

if "seankaneanon.azurecr.io" in body["string"]:
body["string"] = body["string"].replace("seankaneanon.azurecr.io", "fake_url.azurecr.io")

refresh = json.loads(body["string"])
if "refresh_token" in refresh.keys():
refresh["refresh_token"] = REDACTED
if "access_token" in refresh.keys():
refresh["access_token"] = REDACTED
if "service" in refresh.keys():
s = refresh["service"].split(".")
s[0] = "fake_url"
refresh["service"] = ".".join(s)
body["string"] = json.dumps(refresh)
except ValueError:
# Python 2.7 doesn't have the below error
pass
except json.decoder.JSONDecodeError:
pass
Expand All @@ -170,12 +128,6 @@ def process_response(self, response):
except (KeyError, ValueError):
return response

def process_url(self, response):
try:
response["url"] = re.sub(self._regex, r"\1{}\2".format("fake_url"), response["url"])
except KeyError:
pass


class FakeTokenCredential(object):
"""Protocol for classes able to provide OAuth tokens.
Expand All @@ -191,18 +143,13 @@ def get_token(self, *args):

class ContainerRegistryTestClass(AzureTestCase):
def __init__(self, method_name):
super(ContainerRegistryTestClass, self).__init__(
method_name,
recording_processors=[
GeneralNameReplacer(),
OAuthRequestResponsesFilterACR(),
AuthenticationMetadataFilter(),
RequestUrlNormalizer(),
AcrBodyReplacer(),
ManagementRequestReplacer(),
],
)
super(ContainerRegistryTestClass, self).__init__(method_name)
self.repository = "library/busybox"
self.recording_processors.append(AcrBodyReplacer())
self.recording_processors.append(ManagementRequestReplacer())
for idx, p in enumerate(self.recording_processors):
if isinstance(p, OAuthRequestResponsesFilter):
self.recording_processors[idx] = OAuthRequestResponsesFilterACR()

def sleep(self, t):
if self.is_live:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from copy import deepcopy
import six

from .utilities import is_text_payload, is_json_payload, is_batch_payload
Expand Down Expand Up @@ -197,7 +198,7 @@ def process_request(self, request):

if is_text_payload(request) and request.body:
if isinstance(request.body, dict):
pass
request.body = self._process_body_as_dict(request.body)
else:
body = six.ensure_str(request.body)
if old in body:
Expand All @@ -210,8 +211,18 @@ def process_request(self, request):
for matched_object in matched_objects:
request.body = body.replace(matched_object, new)
body = body.replace(matched_object, new)

return request

def _process_body_as_dict(self, body):
new_body = deepcopy(body)

for key in new_body.keys():
for old, new in self.names_name:
new_body[key].replace(old, new)

return new_body

def process_response(self, response):
for old, new in self.names_name:
if is_text_payload(response) and response['body']['string']:
Expand All @@ -225,6 +236,13 @@ def process_response(self, response):
self.replace_header(response, 'location', old, new)
self.replace_header(response, 'operation-location', old, new)
self.replace_header(response, 'azure-asyncoperation', old, new)
self.replace_header(response, "www-authenticate", old, new)

try:
for old, new in self.names_name:
response["url"].replace(old, new)
except KeyError:
pass

try:
response["url"] = response["url"].replace(old, new)
Expand Down

0 comments on commit 03b6d21

Please sign in to comment.