diff --git a/src/spring/HISTORY.md b/src/spring/HISTORY.md index 38ce813364f..1ddf3fc0e87 100644 --- a/src/spring/HISTORY.md +++ b/src/spring/HISTORY.md @@ -1,5 +1,9 @@ Release History =============== +1.19.0 +--- +* Add new commands for managed component log streaming `az spring component list`, `az spring component instance list` and `az spring component logs`. + 1.18.0 --- * Add arguments `--bind-service-registry` in `spring app create`. diff --git a/src/spring/azext_spring/_help.py b/src/spring/azext_spring/_help.py index 3c7eda1bf0b..b6622fedfea 100644 --- a/src/spring/azext_spring/_help.py +++ b/src/spring/azext_spring/_help.py @@ -1541,3 +1541,45 @@ - name: Disable an APM globally. text: az spring apm disable-globally --name first-apm --service MyCluster --resource-group MyResourceGroup """ + +helps['spring component'] = """ + type: group + short-summary: (Enterprise Tier Only) Commands to handle managed components. +""" + +helps['spring component logs'] = """ + type: command + short-summary: (Enterprise Tier Only) Show logs for managed components. Logs will be streamed when setting '-f/--follow'. For now, only supports subcomponents of (a) Application Configuration Service (b) Spring Cloud Gateway + examples: + - name: Show logs for all instances of flux in Application Configuration Serice (Gen2) + text: az spring component logs --name flux-source-controller --service MyAzureSpringAppsInstance --resource-group MyResourceGroup --all-instances + - name: Show logs for a specific instance of application-configuration-service in Application Configuration Serice + text: az spring component logs --name application-configuration-service --service MyAzureSpringAppsInstance --resource-group MyResourceGroup --instance InstanceName + - name: Stream and watch logs for all instances of spring-cloud-gateway + text: az spring component logs --name spring-cloud-gateway --service MyAzureSpringAppsInstance --resource-group MyResourceGroup --all-instances --follow + - name: Show logs for a specific instance without specify the component name + text: az spring component logs --service MyAzureSpringAppsInstance --resource-group MyResourceGroup --instance InstanceName +""" + +helps['spring component list'] = """ + type: command + short-summary: (Enterprise Tier Only) List managed components. + examples: + - name: List all managed components + text: az spring component list --service MyAzureSpringAppsInstance --resource-group MyResourceGroup +""" + +helps['spring component instance'] = """ + type: group + short-summary: (Enterprise Tier Only) Commands to handle instances of a managed component. +""" + +helps['spring component instance list'] = """ + type: command + short-summary: (Enterprise Tier Only) List all available instances of a specific managed component in an Azure Spring Apps instance. + examples: + - name: List instances for spring-cloud-gateway of Spring Cloud Gateway + text: az spring component instance list --component spring-cloud-gateway --service MyAzureSpringAppsInstance --resource-group MyResourceGroup + - name: List instances for spring-cloud-gateway-operator of Spring Cloud Gateway + text: az spring component instance list --component spring-cloud-gateway-operator --service MyAzureSpringAppsInstance --resource-group MyResourceGroup +""" diff --git a/src/spring/azext_spring/_params.py b/src/spring/azext_spring/_params.py index f9c7230ce0d..afe4edb511f 100644 --- a/src/spring/azext_spring/_params.py +++ b/src/spring/azext_spring/_params.py @@ -284,7 +284,7 @@ def load_arguments(self, _): TestKeyType), help='Type of test-endpoint key') with self.argument_context('spring list-support-server-versions') as c: - c.argument('service', service_name_type, validator=not_support_enterprise) + c.argument('service', service_name_type, validator=not_support_enterprise) with self.argument_context('spring app') as c: c.argument('service', service_name_type) @@ -1095,3 +1095,36 @@ def prepare_logs_argument(c): c.argument('private_key', help='Private SSH Key algorithm of git repository.') c.argument('host_key', help='Public SSH Key of git repository.') c.argument('host_key_algorithm', help='SSH Key algorithm of git repository.') + + for scope in ['spring component']: + with self.argument_context(scope) as c: + c.argument('service', service_name_type) + + with self.argument_context('spring component logs') as c: + c.argument('name', options_list=['--name', '-n'], + help="Name of the component. Find component names from command `az spring component list`") + c.argument('all_instances', + help='The flag to indicate get logs for all instances of the component.', + action='store_true') + c.argument('instance', + options_list=['--instance', '-i'], + help='Name of an existing instance of the component.') + c.argument('follow', + options_list=['--follow ', '-f'], + help='The flag to indicate logs should be streamed.', + action='store_true') + c.argument('lines', + type=int, + help='Number of lines to show. Maximum is 10000. Default is 50.') + c.argument('since', + help='Only return logs newer than a relative duration like 5s, 2m, or 1h. Maximum is 1h') + c.argument('limit', + type=int, + help='Maximum kibibyte of logs to return. Ceiling number is 2048.') + c.argument('max_log_requests', + type=int, + help="Specify maximum number of concurrent logs to follow when get logs by all-instances.") + + with self.argument_context('spring component instance') as c: + c.argument('component', options_list=['--component', '-c'], + help="Name of the component. Find components from command `az spring component list`") diff --git a/src/spring/azext_spring/app.py b/src/spring/azext_spring/app.py index 9d0b0c4082f..33d852fb51d 100644 --- a/src/spring/azext_spring/app.py +++ b/src/spring/azext_spring/app.py @@ -8,7 +8,7 @@ from knack.log import get_logger from azure.cli.core.util import sdk_no_wait from azure.cli.core.azclierror import (ValidationError, ArgumentUsageError) -from .custom import app_get, _get_app_log +from .custom import app_get from ._utils import (get_spring_sku, wait_till_end, convert_argument_to_parameter_list) from ._deployment_factory import (deployment_selector, deployment_settings_options_from_resource, @@ -20,6 +20,7 @@ from .custom import app_tail_log_internal import datetime from time import sleep +from .log_stream.log_stream_operations import log_stream_from_url logger = get_logger(__name__) DEFAULT_DEPLOYMENT_NAME = "default" @@ -516,7 +517,7 @@ def _get_deployment_ignore_exception(client, resource_group, service, app_name, def _get_app_log_deploy_phase(url, auth, format_json, exceptions): try: - _get_app_log(url, auth, format_json, exceptions, chunk_size=10 * 1024, stderr=True) + log_stream_from_url(url, auth, format_json, exceptions, chunk_size=10 * 1024, stderr=True) except Exception: pass diff --git a/src/spring/azext_spring/commands.py b/src/spring/azext_spring/commands.py index 1fe265aa24d..2ac3543ad1a 100644 --- a/src/spring/azext_spring/commands.py +++ b/src/spring/azext_spring/commands.py @@ -31,7 +31,8 @@ transform_support_server_versions_output) from ._validators import validate_app_insights_command_not_supported_tier from ._marketplace import (transform_marketplace_plan_output) -from ._validators_enterprise import (validate_gateway_update, validate_api_portal_update, validate_dev_tool_portal, validate_customized_accelerator, validate_central_build_instance) +from ._validators_enterprise import (validate_gateway_update, validate_api_portal_update, validate_dev_tool_portal, validate_customized_accelerator) +from .managed_components.validators_managed_component import (validate_component_logs, validate_component_list, validate_instance_list) from ._app_managed_identity_validator import (validate_app_identity_remove_or_warning, validate_app_identity_assign_or_warning) @@ -118,6 +119,11 @@ def load_command_table(self, _): client_factory=cf_spring ) + managed_component_cmd_group = CliCommandType( + operations_tmpl='azext_spring.managed_components.managed_component_operations#{}', + client_factory=cf_spring + ) + with self.command_group('spring', custom_command_type=spring_routing_util, exception_handler=handle_asc_exception) as g: g.custom_command('create', 'spring_create', supports_no_wait=True) @@ -459,5 +465,16 @@ def load_command_table(self, _): g.custom_command('update', 'update_build_service', supports_no_wait=True) g.custom_show_command('show', 'build_service_show') + with self.command_group('spring component', + custom_command_type=managed_component_cmd_group, + exception_handler=handle_asc_exception) as g: + g.custom_command('logs', 'managed_component_logs', validator=validate_component_logs) + g.custom_command('list', 'managed_component_list', validator=validate_component_list) + + with self.command_group('spring component instance', + custom_command_type=managed_component_cmd_group, + exception_handler=handle_asc_exception) as g: + g.custom_command('list', 'managed_component_instance_list', validator=validate_instance_list) + with self.command_group('spring', exception_handler=handle_asc_exception): pass diff --git a/src/spring/azext_spring/custom.py b/src/spring/azext_spring/custom.py index 44e5b8b92bd..5e5358b8068 100644 --- a/src/spring/azext_spring/custom.py +++ b/src/spring/azext_spring/custom.py @@ -37,6 +37,7 @@ from collections import defaultdict from ._log_stream import LogStream from ._build_service import _update_default_build_agent_pool +from .log_stream.log_stream_operations import log_stream_from_url logger = get_logger(__name__) DEFAULT_DEPLOYMENT_NAME = "default" @@ -512,7 +513,7 @@ def app_get_build_log(cmd, client, resource_group, service, name, deployment=Non def app_tail_log(cmd, client, resource_group, service, name, deployment=None, instance=None, follow=False, lines=50, since=None, limit=2048, format_json=None): app_tail_log_internal(cmd, client, resource_group, service, name, deployment, instance, follow, lines, since, limit, - format_json, get_app_log=_get_app_log) + format_json, get_app_log=log_stream_from_url) def app_tail_log_internal(cmd, client, resource_group, service, name, @@ -1167,136 +1168,6 @@ def _get_redis_primary_key(cli_ctx, resource_id): return keys.primary_key -# pylint: disable=bare-except, too-many-statements -def _get_app_log(url, auth, format_json, exceptions, chunk_size=None, stderr=False): - logger_seg_regex = re.compile(r'([^\.])[^\.]+\.') - - def build_log_shortener(length): - if length <= 0: - raise InvalidArgumentValueError('Logger length in `logger{length}` should be positive') - - def shortener(record): - ''' - Try shorten the logger property to the specified length before feeding it to the formatter. - ''' - logger_name = record.get('logger', None) - if logger_name is None: - return record - - # first, try to shorten the package name to one letter, e.g., - # org.springframework.cloud.netflix.eureka.config.DiscoveryClientOptionalArgsConfiguration - # to: o.s.c.n.e.c.DiscoveryClientOptionalArgsConfiguration - while len(logger_name) > length: - logger_name, count = logger_seg_regex.subn(r'\1.', logger_name, 1) - if count < 1: - break - - # then, cut off the leading packages if necessary - logger_name = logger_name[-length:] - record['logger'] = logger_name - return record - - return shortener - - def build_formatter(): - ''' - Build the log line formatter based on the format_json argument. - ''' - nonlocal format_json - - def identity(o): - return o - - if format_json is None or len(format_json) == 0: - return identity - - logger_regex = re.compile(r'\blogger\{(\d+)\}') - match = logger_regex.search(format_json) - pre_processor = identity - if match: - length = int(match[1]) - pre_processor = build_log_shortener(length) - format_json = logger_regex.sub('logger', format_json, 1) - - first_exception = True - - def format_line(line): - nonlocal first_exception - try: - log_record = json.loads(line) - # Add n=\n so that in Windows CMD it's easy to specify customized format with line ending - # e.g., "{timestamp} {message}{n}" - # (Windows CMD does not escape \n in string literal.) - return format_json.format_map(pre_processor(defaultdict(str, n="\n", **log_record))) - except: - if first_exception: - # enable this format error logging only with --verbose - logger.info("Failed to format log line '{}'".format(line), exc_info=sys.exc_info()) - first_exception = False - return line - - return format_line - - def iter_lines(response, limit=2 ** 20, chunk_size=None): - ''' - Returns a line iterator from the response content. If no line ending was found and the buffered content size is - larger than the limit, the buffer will be yielded directly. - ''' - buffer = [] - total = 0 - for content in response.iter_content(chunk_size=chunk_size): - if not content: - if len(buffer) > 0: - yield b''.join(buffer) - break - - start = 0 - while start < len(content): - line_end = content.find(b'\n', start) - should_print = False - if line_end < 0: - next = (content if start == 0 else content[start:]) - buffer.append(next) - total += len(next) - start = len(content) - should_print = total >= limit - else: - buffer.append(content[start:line_end + 1]) - start = line_end + 1 - should_print = True - - if should_print: - yield b''.join(buffer) - buffer.clear() - total = 0 - - with requests.get(url, stream=True, auth=auth) as response: - try: - if response.status_code != 200: - failure_reason = response.reason - if response.content: - if isinstance(response.content, bytes): - failure_reason = "{}:{}".format(failure_reason, response.content.decode('utf-8')) - else: - failure_reason = "{}:{}".format(failure_reason, response.content) - raise CLIError("Failed to connect to the server with status code '{}' and reason '{}'".format( - response.status_code, failure_reason)) - std_encoding = sys.stdout.encoding - - formatter = build_formatter() - - for line in iter_lines(response, chunk_size=chunk_size): - decoded = (line.decode(encoding='utf-8', errors='replace') - .encode(std_encoding, errors='replace') - .decode(std_encoding, errors='replace')) - if stderr: - print(formatter(decoded), end='', file=sys.stderr) - else: - print(formatter(decoded), end='') - except CLIError as e: - exceptions.append(e) - - def storage_callback(pipeline_response, deserialized, headers): return models.StorageResource.deserialize(json.loads(pipeline_response.http_response.text())) diff --git a/src/spring/azext_spring/log_stream/__init__.py b/src/spring/azext_spring/log_stream/__init__.py new file mode 100644 index 00000000000..34913fb394d --- /dev/null +++ b/src/spring/azext_spring/log_stream/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/src/spring/azext_spring/log_stream/log_stream_operations.py b/src/spring/azext_spring/log_stream/log_stream_operations.py new file mode 100644 index 00000000000..57079c845dc --- /dev/null +++ b/src/spring/azext_spring/log_stream/log_stream_operations.py @@ -0,0 +1,164 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +import re +import requests +import json +import sys + +from azure.cli.core.azclierror import InvalidArgumentValueError +from collections import defaultdict +from knack.log import get_logger +from knack.util import CLIError +from .writer import DefaultWriter + + +logger = get_logger(__name__) + + +# pylint: disable=bare-except, too-many-statements +def iter_lines(response, limit=2 ** 20, chunk_size=None): + ''' + Returns a line iterator from the response content. If no line ending was found and the buffered content size is + larger than the limit, the buffer will be yielded directly. + ''' + buffer = [] + total = 0 + for content in response.iter_content(chunk_size=chunk_size): + if not content: + if len(buffer) > 0: + yield b''.join(buffer) + break + + start = 0 + while start < len(content): + line_end = content.find(b'\n', start) + should_print = False + if line_end < 0: + next = (content if start == 0 else content[start:]) + buffer.append(next) + total += len(next) + start = len(content) + should_print = total >= limit + else: + buffer.append(content[start:line_end + 1]) + start = line_end + 1 + should_print = True + + if should_print: + yield b''.join(buffer) + buffer.clear() + total = 0 + + +def log_stream_from_url(url, auth, format_json, exceptions, writer=DefaultWriter(), chunk_size=None, stderr=False): + logger_seg_regex = re.compile(r'([^\.])[^\.]+\.') + + def build_log_shortener(length): + if length <= 0: + raise InvalidArgumentValueError('Logger length in `logger{length}` should be positive') + + def shortener(record): + ''' + Try shorten the logger property to the specified length before feeding it to the formatter. + ''' + logger_name = record.get('logger', None) + if logger_name is None: + return record + + # first, try to shorten the package name to one letter, e.g., + # org.springframework.cloud.netflix.eureka.config.DiscoveryClientOptionalArgsConfiguration + # to: o.s.c.n.e.c.DiscoveryClientOptionalArgsConfiguration + while len(logger_name) > length: + logger_name, count = logger_seg_regex.subn(r'\1.', logger_name, 1) + if count < 1: + break + + # then, cut off the leading packages if necessary + logger_name = logger_name[-length:] + record['logger'] = logger_name + return record + + return shortener + + def build_formatter(): + ''' + Build the log line formatter based on the format_json argument. + ''' + nonlocal format_json + + def identity(o): + return o + + if format_json is None or len(format_json) == 0: + return identity + + logger_regex = re.compile(r'\blogger\{(\d+)\}') + match = logger_regex.search(format_json) + pre_processor = identity + if match: + length = int(match[1]) + pre_processor = build_log_shortener(length) + format_json = logger_regex.sub('logger', format_json, 1) + + first_exception = True + + def format_line(line): + nonlocal first_exception + try: + log_record = json.loads(line) + # Add n=\n so that in Windows CMD it's easy to specify customized format with line ending + # e.g., "{timestamp} {message}{n}" + # (Windows CMD does not escape \n in string literal.) + return format_json.format_map(pre_processor(defaultdict(str, n="\n", **log_record))) + except: + if first_exception: + # enable this format error logging only with --verbose + logger.info("Failed to format log line '{}'".format(line), exc_info=sys.exc_info()) + first_exception = False + return line + + return format_line + + try: + with requests.get(url, stream=True, auth=auth) as response: + try: + if response.status_code != 200: + failure_reason = response.reason + if response.content: + if isinstance(response.content, bytes): + failure_reason = "{}:{}".format(failure_reason, response.content.decode('utf-8')) + else: + failure_reason = "{}:{}".format(failure_reason, response.content) + raise CLIError("Failed to access the url '{}' with status code '{}' and reason '{}'".format( + url, response.status_code, failure_reason)) + std_encoding = sys.stdout.encoding + + formatter = build_formatter() + + for line in iter_lines(response, chunk_size=chunk_size): + decoded = (line.decode(encoding='utf-8', errors='replace') + .encode(std_encoding, errors='replace') + .decode(std_encoding, errors='replace')) + if stderr: + writer.write(formatter(decoded), end='', file=sys.stderr) + else: + writer.write(formatter(decoded), end='') + except CLIError as e: + exceptions.append(e) + except requests.exceptions.ConnectionError as e: + try: + message = str(e) + if "getaddrinfo failed" in message: + exceptions.append(CLIError("Failed to connect to \"{}\" due to getaddrinfo failed. " + "For an Azure Spring Apps instance deployed in a custom virtual network, " + "you can access log streaming by default from a private network. " + "But if you want to access real-time app logs from a public network, " + "please make sure 'Dataplane resources on public network' is enabled. " + "Learn more https://aka.ms/asa/component/logstream/vnet" + .format(url))) + else: + exceptions.append(CLIError("Failed to connecto to \"{}\" due to \"{}\"".format(url, message))) + except Exception: + exceptions.append(CLIError("Failed to connect to '{}'.".format(url))) diff --git a/src/spring/azext_spring/log_stream/writer.py b/src/spring/azext_spring/log_stream/writer.py new file mode 100644 index 00000000000..54a95947cde --- /dev/null +++ b/src/spring/azext_spring/log_stream/writer.py @@ -0,0 +1,18 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +import sys + + +class DefaultWriter: + def write(self, data, end='', file=None): + print(data, end=end, file=file) + + +class PrefixWriter(DefaultWriter): + def __init__(self, prefix): + self.prefix = prefix + + def write(self, data, end='', file=None): + super().write("{} {}".format(self.prefix, data), end=end, file=file) diff --git a/src/spring/azext_spring/managed_components/__init__.py b/src/spring/azext_spring/managed_components/__init__.py new file mode 100644 index 00000000000..34913fb394d --- /dev/null +++ b/src/spring/azext_spring/managed_components/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/src/spring/azext_spring/managed_components/managed_component.py b/src/spring/azext_spring/managed_components/managed_component.py new file mode 100644 index 00000000000..43330f991ad --- /dev/null +++ b/src/spring/azext_spring/managed_components/managed_component.py @@ -0,0 +1,174 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from knack.log import get_logger +from knack.util import CLIError +from azure.core.exceptions import ResourceNotFoundError +from ..application_configuration_service import DEFAULT_NAME as ACS_DEFAULT_NAME +from ..gateway import DEFAULT_NAME as SCG_DEFAULT_NAME + + +logger = get_logger(__name__) + + +# Acs +ACS = "application-configuration-service" +ACS_INSTANCE_PREFIX = "application-configuration-service" + + +# Flux +FLUX = "flux-source-controller" +FLUX_INSTANCE_PREFIX = "fluxcd-source-controller" + + +# Scg +SCG = "spring-cloud-gateway" + + +# Scg operator +SCG_OPERATOR = "spring-cloud-gateway-operator" + + +class ManagedComponentInstance: + def __init__(self, name): + self.name = name + + +class ManagedComponent: + def __init__(self, name): + self.name = name + + def get_name(self): + return self.name + + def get_api_name(self): + return self._to_camel_case(self.name) + + def match(self, name): + return name and self.name == name + + def match_ignore_case(self, name: str): + return name and self.name.lower() == name.lower() + + def list_instances(self, client, resource_group, service): + raise NotImplementedError("Must be implemented by child class.") + + def _to_camel_case(self, text): + if text is None or len(text) == 0: + return text + + s = text.replace("-", " ").replace("_", " ") + s = s.split() + + if len(s) == 1: + return s[0] + + return s[0] + ''.join(i.capitalize() for i in s[1:]) + + +class Acs(ManagedComponent): + def __init__(self): + super().__init__(ACS) + + def list_instances(self, client, resource_group, service): + try: + return self._list_instances(client, resource_group, service) + except ResourceNotFoundError: + raise CLIError("'{}' is a subcomponent of Application Configuration Service (ACS), " + "failed to perform operations when ACS is not enabled.".format(self.name)) + + def _list_instances(self, client, resource_group, service): + acs_arm_resource = client.configuration_services.get(resource_group, service, ACS_DEFAULT_NAME) + instance_array = acs_arm_resource.properties.instances + instances = [] + for i in instance_array: + if i.name.startswith(ACS_INSTANCE_PREFIX): + instances.append(ManagedComponentInstance(i.name)) + if len(instances) == 0: + logger.warning("No instance found for component {}.".format(self.name)) + return instances + + +class Flux(ManagedComponent): + def __init__(self): + super().__init__(FLUX) + + def list_instances(self, client, resource_group, service): + try: + return self._list_instances(client, resource_group, service) + except ResourceNotFoundError: + raise CLIError("'{}' is a subcomponent of Application Configuration Service (ACS) Gen2, " + "failed to perform operations when ACS is not enabled.".format(self.name)) + + def _list_instances(self, client, resource_group, service): + acs_arm_resource = client.configuration_services.get(resource_group, service, ACS_DEFAULT_NAME) + instance_array = acs_arm_resource.properties.instances + instances = [] + for i in instance_array: + if i.name.startswith(FLUX_INSTANCE_PREFIX): + instances.append(ManagedComponentInstance(i.name)) + if len(instances) == 0: + logger.warning("No instance found for component {}. " + "Please double check Application Configuration Service Gen2 is enabled.".format(self.name)) + return instances + + +class Scg(ManagedComponent): + def __init__(self): + super().__init__(SCG) + + def list_instances(self, client, resource_group, service): + try: + return self._list_instances(client, resource_group, service) + except ResourceNotFoundError: + raise CLIError("'{}' is a subcomponent of Spring Cloud Gateway (SCG), " + "failed to perform operations when SCG is not enabled.".format(self.name)) + + def _list_instances(self, client, resource_group, service): + scg_arm_resource = client.gateways.get(resource_group, service, SCG_DEFAULT_NAME) + instance_array = scg_arm_resource.properties.instances + instances = [] + for i in instance_array: + instances.append(ManagedComponentInstance(i.name)) + if len(instances) == 0: + logger.warning("No instance found for component {}.".format(self.name)) + return instances + + +class ScgOperator(ManagedComponent): + def __init__(self): + super().__init__(SCG_OPERATOR) + + def list_instances(self, client, resource_group, service): + try: + return self._list_instances(client, resource_group, service) + except ResourceNotFoundError: + raise CLIError("'{}' is a subcomponent of Spring Cloud Gateway (SCG), " + "failed to perform operations when SCG is not enabled.".format(self.name)) + + def _list_instances(self, client, resource_group, service): + scg_arm_resource = client.gateways.get(resource_group, service, SCG_DEFAULT_NAME) + instance_array = scg_arm_resource.properties.operator_properties.instances + instances = [] + for i in instance_array: + instances.append(ManagedComponentInstance(i.name)) + if len(instances) == 0: + logger.warning("No instance found for component {}.".format(self.name)) + return instances + + +supported_components = [ + Acs(), + Flux(), + Scg(), + ScgOperator(), +] + + +def get_component(component): + for c in supported_components: + if c.match(component): + return c + + return None diff --git a/src/spring/azext_spring/managed_components/managed_component_operations.py b/src/spring/azext_spring/managed_components/managed_component_operations.py new file mode 100644 index 00000000000..6e0fb4359b7 --- /dev/null +++ b/src/spring/azext_spring/managed_components/managed_component_operations.py @@ -0,0 +1,203 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from azure.cli.core._profile import Profile +from azure.cli.core.commands.client_factory import get_subscription_id +from knack.log import get_logger +from knack.util import CLIError +from six.moves.urllib import parse +from threading import Thread +from time import sleep + +from .managed_component import (Acs, Flux, Scg, ScgOperator, + ManagedComponentInstance, supported_components, get_component) + +from ..log_stream.writer import (DefaultWriter, PrefixWriter) +from ..log_stream.log_stream_operations import log_stream_from_url +from .._utils import (get_proxy_api_endpoint, BearerAuth) + + +logger = get_logger(__name__) + + +class ManagedComponentInstanceInfo: + component: str + instance: str + + def __init__(self, component, instance): + self.component = component + self.instance = instance + + +class QueryOptions: + def __init__(self, follow, lines, since, limit): + self.follow = follow + self.lines = lines + self.since = since + self.limit = limit + + +def managed_component_logs(cmd, client, resource_group, service, + name=None, all_instances=None, instance=None, + follow=None, max_log_requests=5, lines=50, since=None, limit=2048): + auth = _get_bearer_auth(cmd) + exceptions = [] + threads = None + queryOptions = QueryOptions(follow=follow, lines=lines, since=since, limit=limit) + if not name and instance: + threads = _get_log_threads_without_component(cmd, client, resource_group, service, + instance, auth, exceptions, queryOptions) + else: + url_dict = _get_log_stream_urls(cmd, client, resource_group, service, name, all_instances, + instance, queryOptions) + if (follow is True and len(url_dict) > max_log_requests): + raise CLIError("You are attempting to follow {} log streams, but maximum allowed concurrency is {}, " + "use --max-log-requests to increase the limit".format(len(url_dict), max_log_requests)) + threads = _get_log_threads(all_instances, url_dict, auth, exceptions) + + if follow and len(threads) > 1: + _parallel_start_threads(threads) + else: + _sequential_start_threads(threads) + + if exceptions: + raise exceptions[0] + + +def managed_component_list(cmd, client, resource_group, service): + return supported_components + + +def managed_component_instance_list(cmd, client, resource_group, service, component): + instances = _list_managed_component_instances(cmd, client, resource_group, service, component) + if instances is None or len(instances) == 0: + logger.warning("No instance found for component '{}'".format(component)) + return instances + + +def _list_managed_component_instances(cmd, client, resource_group, service, component): + managed_component = _get_component(component) + return managed_component.list_instances(client, resource_group, service) + + +def _get_component(component): + for c in supported_components: + if c.match(component): + return c + + return None + + +def _get_log_stream_urls(cmd, client, resource_group, service, component_name, + all_instances, instance, queryOptions: QueryOptions): + component_api_name = _get_component(component_name).get_api_name() + hostname = _get_hostname(cmd, client, resource_group, service) + url_dict = {} + + if component_name and all_instances is True: + instances: [ManagedComponentInstance] = _list_managed_component_instances(cmd, client, resource_group, service, component_name) + if instances is None or len(instances) == 0: + return url_dict + for i in instances: + url = _get_stream_url(hostname, component_api_name, i.name, queryOptions) + url_dict[url] = ManagedComponentInstanceInfo(component_name, i.name) + elif instance: + url = _get_stream_url(hostname, component_api_name, instance, queryOptions) + url_dict[url] = ManagedComponentInstanceInfo(component_name, instance) + + return url_dict + + +def _get_stream_url(hostname, component_name, instance_name, queryOptions: QueryOptions): + url_template = "https://{}/api/logstream/managedComponents/{}/instances/{}" + url = url_template.format(hostname, component_name, instance_name) + url = _attach_logs_query_options(url, queryOptions) + return url + + +def _get_bearer_auth(cmd): + profile = Profile(cli_ctx=cmd.cli_ctx) + creds, _, tenant = profile.get_raw_token() + token = creds[1] + return BearerAuth(token) + + +def _get_hostname(cmd, client, resource_group, service): + resource = client.services.get(resource_group, service) + return get_proxy_api_endpoint(cmd.cli_ctx, resource) + + +def _get_log_threads(all_instances, url_dict, auth, exceptions): + threads = [] + need_prefix = all_instances is True + for url in url_dict.keys(): + writer = _get_default_writer() + if need_prefix: + instance_info = url_dict[url] + prefix = "[{}]".format(instance_info.instance) + writer = _get_prefix_writer(prefix) + threads.append(Thread(target=log_stream_from_url, args=(url, auth, None, exceptions, writer))) + return threads + + +def _contains_alive_thread(threads: [Thread]): + for t in threads: + if t.is_alive(): + return True + + +def _parallel_start_threads(threads: [Thread]): + for t in threads: + t.daemon = True + t.start() + + while _contains_alive_thread(threads): + sleep(1) + # so that ctrl+c can stop the command + + +def _sequential_start_threads(threads: [Thread]): + for idx, t in enumerate(threads): + t.daemon = True + t.start() + + while t.is_alive(): + sleep(1) + # so that ctrl+c can stop the command + + +def _get_log_threads_without_component(cmd, client, resource_group, service, instance_name, auth, exceptions, queryOptions: QueryOptions): + hostname = _get_hostname(cmd, client, resource_group, service) + url_template = "https://{}/api/logstream/managedComponentInstances/{}" + url = url_template.format(hostname, instance_name) + url = _attach_logs_query_options(url, queryOptions) + + return [Thread(target=log_stream_from_url, args=(url, auth, None, exceptions, _get_default_writer()))] + + +def _attach_logs_query_options(url, queryOptions: QueryOptions): + params = {} + params["tailLines"] = queryOptions.lines + params["limitBytes"] = queryOptions.limit + if queryOptions.since: + params["sinceSeconds"] = queryOptions.since + if queryOptions.follow: + params["follow"] = True + + url += "?{}".format(parse.urlencode(params)) if params else "" + return url + + +def _get_prefix_writer(prefix): + """ + Define this method, so that we can mock this method in scenario test to test output + """ + return PrefixWriter(prefix) + + +def _get_default_writer(): + """ + Define this method, so that we can mock this method in scenario test to test output + """ + return DefaultWriter() diff --git a/src/spring/azext_spring/managed_components/validators_managed_component.py b/src/spring/azext_spring/managed_components/validators_managed_component.py new file mode 100644 index 00000000000..96631430479 --- /dev/null +++ b/src/spring/azext_spring/managed_components/validators_managed_component.py @@ -0,0 +1,95 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + + +from azure.cli.core.azclierror import InvalidArgumentValueError +from knack.log import get_logger +from ..managed_components.managed_component import supported_components +from .._validators import validate_log_lines as _validate_n_normalize_component_log_lines +from .._validators import validate_log_since as _validate_n_normalize_component_log_since +from .._validators import validate_log_limit as _validate_n_normalize_component_log_limit +from .._clierror import NotSupportedPricingTierError +from .._util_enterprise import is_enterprise_tier + + +logger = get_logger(__name__) + + +def validate_component_logs(cmd, namespace): + _validate_component_log_mutual_exclusive_param(namespace) + _validate_component_log_required_param(namespace) + _validate_n_normalize_component_for_logs(namespace) + _validate_n_normalize_component_log_lines(namespace) + _validate_n_normalize_component_log_since(namespace) + _validate_n_normalize_component_log_limit(namespace) + _validate_max_log_requests(namespace) + _validate_is_enterprise_tier(cmd, namespace) + + +def validate_component_list(cmd, namespace): + _validate_is_enterprise_tier(cmd, namespace) + + +def validate_instance_list(cmd, namespace): + _validate_component_for_instance_list(namespace) + _validate_is_enterprise_tier(cmd, namespace) + + +def _validate_max_log_requests(namespace): + if namespace.max_log_requests <= 1: + raise InvalidArgumentValueError("--max-log-requests should be larger than 0.") + + +def _validate_is_enterprise_tier(cmd, namespace): + if is_enterprise_tier(cmd, namespace.resource_group, namespace.service) is False: + raise NotSupportedPricingTierError("Only enterprise tier service instance is supported in this command.") + + +def _validate_n_normalize_component_for_logs(namespace): + # Component name is optional for logs + if namespace.name is None: + return + + (is_supported, component_standard_name) = _is_component_supported(namespace.name) + if is_supported: + namespace.name = component_standard_name + return + + _raise_invalid_component_error(namespace.name) + + +def _validate_component_for_instance_list(namespace): + if namespace.component: + (is_supported, component_standard_name) = _is_component_supported(namespace.component) + if is_supported: + namespace.component = component_standard_name + return + + _raise_invalid_component_error(namespace.component) + + +def _is_component_supported(user_input_component_name): + for c in supported_components: + if c.match_ignore_case(user_input_component_name): + return (True, c.get_name()) + return (False, None) + + +def _raise_invalid_component_error(user_input_component_name): + msg_template = "Component '{}' is not supported. Supported components are: '{}'." + component_names = list(map(lambda c: c.get_name(), supported_components)) + raise InvalidArgumentValueError(msg_template.format(user_input_component_name, ",".join(component_names))) + + +def _validate_component_log_mutual_exclusive_param(namespace): + if namespace.all_instances is True and namespace.instance is not None: + raise InvalidArgumentValueError("--all-instances cannot be set together with --instance/-i.") + + +def _validate_component_log_required_param(namespace): + if namespace.name is None and not namespace.instance: + raise InvalidArgumentValueError("When --name/-n is not set, --instance/-i is required.") + if namespace.name is None and namespace.instance: + logger.warning("--instance/-i is specified without --name/-n, will try best effort get logs by instance.") diff --git a/src/spring/azext_spring/tests/latest/managed_component/__init__.py b/src/spring/azext_spring/tests/latest/managed_component/__init__.py new file mode 100644 index 00000000000..99c0f28cd71 --- /dev/null +++ b/src/spring/azext_spring/tests/latest/managed_component/__init__.py @@ -0,0 +1,5 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ----------------------------------------------------------------------------- diff --git a/src/spring/azext_spring/tests/latest/managed_component/test_managed_component_scenarios.py b/src/spring/azext_spring/tests/latest/managed_component/test_managed_component_scenarios.py new file mode 100644 index 00000000000..d3970373983 --- /dev/null +++ b/src/spring/azext_spring/tests/latest/managed_component/test_managed_component_scenarios.py @@ -0,0 +1,415 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +import unittest + +from azure.cli.testsdk import (ScenarioTest, record_only, live_only) +from azure.cli.testsdk.base import ExecutionResult +from requests import Response +from ....managed_components.managed_component import get_component +from ...._utils import BearerAuth + + +try: + import unittest.mock as mock +except ImportError: + from unittest import mock + + +class TestingWriter: + def __init__(self, buffer): + self.buffer = buffer + + def write(self, data, end='', file=None): + self.buffer.append(data) + + +class ManagedComponentTest(ScenarioTest): + @mock.patch('azext_spring.commands.cf_spring', autospec=True) + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_component_list(self, is_enterprise_tier_mock, cf_spring_mock): + self.kwargs.update({ + 'serviceName': 'asae-name', + 'rg': 'resource-group', + }) + + cf_spring_mock.return_value = mock.MagicMock() + + is_enterprise_tier_mock.return_value = True + result: ExecutionResult = self.cmd('spring component list -s {serviceName} -g {rg}') + self.assertTrue(isinstance(result.get_output_in_json(), list)) + component_list: list = result.get_output_in_json() + + for e in component_list: + self.assertTrue(isinstance(e, dict)) + e: dict = e + self.assertTrue("name" in e) + component_obj = get_component(e["name"]) + self.assertIsNotNone(component_obj) + + @mock.patch('azext_spring.commands.cf_spring', autospec=True) + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_acs_component_instance_list(self, is_enterprise_tier_mock, cf_spring_mock): + self.kwargs.update({ + 'serviceName': 'asae-name', + 'rg': 'resource-group', + 'component': 'application-configuration-service' + }) + + is_enterprise_tier_mock.return_value = True + + client = mock.MagicMock() + client.configuration_services.get.return_value = self._get_mocked_acs_gen2() + cf_spring_mock.return_value = client + + # ACS (Gen1 or Gen2) is enabled in service instance. + result: ExecutionResult = self.cmd('spring component instance list -s {serviceName} -g {rg} -c {component}') + output = result.get_output_in_json() + self.assertTrue(type(output), list) + self.assertEqual(2, len(output)) + for e in output: + self.assertTrue(isinstance(e, dict)) + self.assertTrue("name" in e) + instance: str = e["name"] + self.assertTrue(instance.startswith("application-configuration-service")) + + @mock.patch('azext_spring.commands.cf_spring', autospec=True) + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_flux_component_instance_list(self, is_enterprise_tier_mock, client_factory_mock): + self.kwargs.update({ + 'serviceName': 'asae-name', + 'rg': 'resource-group', + 'component': 'flux-source-controller', + }) + + is_enterprise_tier_mock.return_value = True + + client = mock.MagicMock() + client.configuration_services.get.return_value = self._get_mocked_acs_gen2() + client_factory_mock.return_value = client + + # flux is a subcomponent of ACS Gen2, make sure it's enabled in service instance. + result: ExecutionResult = self.cmd('spring component instance list -s {serviceName} -g {rg} -c {component}') + output = result.get_output_in_json() + self.assertTrue(type(output), list) + self.assertEqual(1, len(output)) + for e in output: + self.assertTrue(isinstance(e, dict)) + self.assertTrue("name" in e) + instance: str = e["name"] + self.assertTrue(instance.startswith("fluxcd-source-controller")) + + @mock.patch('azext_spring.commands.cf_spring', autospec=True) + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_scg_component_instance_list(self, is_enterprise_tier_mock, client_factory_mock): + self.kwargs.update({ + 'serviceName': 'asae-name', + 'rg': 'resource-group', + 'component': 'spring-cloud-gateway', + }) + + is_enterprise_tier_mock.return_value = True + + client = mock.MagicMock() + client.gateways.get.return_value = self._get_mocked_scg() + client_factory_mock.return_value = client + + # scg is a subcomponent of Spring Cloud Gateway, need to enable it first. + result: ExecutionResult = self.cmd('spring component instance list -s {serviceName} -g {rg} -c {component}') + output = result.get_output_in_json() + self.assertTrue(type(output), list) + self.assertEqual(3, len(output)) + for e in output: + self.assertTrue(isinstance(e, dict)) + self.assertTrue("name" in e) + instance: str = e["name"] + self.assertTrue(instance.startswith("asc-scg-default")) + + @mock.patch('azext_spring.commands.cf_spring', autospec=True) + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_scg_operator_component_instance_list(self, is_enterprise_tier_mock, client_factory_mock): + self.kwargs.update({ + 'serviceName': 'asae-name', + 'rg': 'resource-group', + 'component': 'spring-cloud-gateway-operator' + }) + + is_enterprise_tier_mock.return_value = True + + client = mock.MagicMock() + client.gateways.get.return_value = self._get_mocked_scg() + client_factory_mock.return_value = client + + # scg operator is a subcomponent of Spring Cloud Gateway, need to enable it first. + result: ExecutionResult = self.cmd('spring component instance list -s {serviceName} -g {rg} -c {component}') + output = result.get_output_in_json() + self.assertTrue(type(output), list) + self.assertEqual(2, len(output)) + for e in output: + self.assertTrue(isinstance(e, dict)) + self.assertTrue("name" in e) + instance: str = e["name"] + self.assertTrue(instance.startswith("scg-operator")) + + @mock.patch('azext_spring.log_stream.log_stream_operations.iter_lines', autospec=True) + @mock.patch('azext_spring.log_stream.log_stream_operations.requests', autospec=True) + @mock.patch('azext_spring.commands.cf_spring', autospec=True) + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_hostname', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_bearer_auth', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_default_writer', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_prefix_writer', autospec=True) + def test_acs_log_stream(self, _get_prefix_writer_mock, _get_default_writer_mock, _get_bearer_auth_mock, + _get_hostname_mock, is_enterprise_tier_mock, client_factory_mock, requests_mock, + iter_lines_mock): + command_std_out = [] + _get_default_writer_mock.return_value = TestingWriter(command_std_out) + _get_prefix_writer_mock.return_value = TestingWriter(command_std_out) + + _get_bearer_auth_mock.return_value = BearerAuth("fake-bearer-token") + asae_name = "asae-name" + _get_hostname_mock.return_value = "{}.asc-test.net".format(asae_name) + + is_enterprise_tier_mock.return_value = True + + client = mock.MagicMock() + client.configuration_services.get.return_value = self._get_mocked_acs_gen2() + client_factory_mock.return_value = client + + response = Response() + response.status_code = 200 + requests_mock.get.return_value = response + + lines = [] + for i in range(50): + line = "Log line No.{}\n".format(i) + line = line.encode('utf-8') + lines.append(line) + iter_lines_mock.return_value = lines + + self.kwargs.update({ + 'serviceName': asae_name, + 'rg': 'resource-group', + 'component': 'application-configuration-service' + }) + + self.cmd('spring component logs -s {serviceName} -g {rg} -n {component} --all-instances --lines 50') + self.assertEqual(len(command_std_out), 100) + + @mock.patch('azext_spring.log_stream.log_stream_operations.iter_lines', autospec=True) + @mock.patch('azext_spring.log_stream.log_stream_operations.requests', autospec=True) + @mock.patch('azext_spring.commands.cf_spring', autospec=True) + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_hostname', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_bearer_auth', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_default_writer', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_prefix_writer', autospec=True) + def test_flux_log_stream(self, _get_prefix_writer_mock, _get_default_writer_mock, _get_bearer_auth_mock, + _get_hostname_mock, is_enterprise_tier_mock, client_factory_mock, requests_mock, + iter_lines_mock): + command_std_out = [] + _get_default_writer_mock.return_value = TestingWriter(command_std_out) + _get_prefix_writer_mock.return_value = TestingWriter(command_std_out) + + _get_bearer_auth_mock.return_value = BearerAuth("fake-bearer-token") + asae_name = "asae-name" + _get_hostname_mock.return_value = "{}.asc-test.net".format(asae_name) + + is_enterprise_tier_mock.return_value = True + + client = mock.MagicMock() + client.configuration_services.get.return_value = self._get_mocked_acs_gen2() + client_factory_mock.return_value = client + + response = Response() + response.status_code = 200 + requests_mock.get.return_value = response + + lines = [] + for i in range(50): + line = "Log line No.{}\n".format(i) + line = line.encode('utf-8') + lines.append(line) + iter_lines_mock.return_value = lines + + self.kwargs.update({ + 'serviceName': asae_name, + 'rg': 'resource-group', + 'component': 'flux-source-controller' + }) + + self.cmd('spring component logs -s {serviceName} -g {rg} -n {component} --all-instances --lines 50') + self.assertEqual(len(command_std_out), 50) + + @mock.patch('azext_spring.log_stream.log_stream_operations.iter_lines', autospec=True) + @mock.patch('azext_spring.log_stream.log_stream_operations.requests', autospec=True) + @mock.patch('azext_spring.commands.cf_spring', autospec=True) + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_hostname', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_bearer_auth', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_default_writer', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_prefix_writer', autospec=True) + def test_scg_log_stream(self, _get_prefix_writer_mock, _get_default_writer_mock, _get_bearer_auth_mock, + _get_hostname_mock, is_enterprise_tier_mock, client_factory_mock, requests_mock, + iter_lines_mock): + command_std_out = [] + _get_default_writer_mock.return_value = TestingWriter(command_std_out) + _get_prefix_writer_mock.return_value = TestingWriter(command_std_out) + + _get_bearer_auth_mock.return_value = BearerAuth("fake-bearer-token") + asae_name = "asae-name" + _get_hostname_mock.return_value = "{}.asc-test.net".format(asae_name) + + is_enterprise_tier_mock.return_value = True + + client = mock.MagicMock() + client.gateways.get.return_value = self._get_mocked_scg() + client_factory_mock.return_value = client + + response = Response() + response.status_code = 200 + requests_mock.get.return_value = response + + lines = [] + for i in range(50): + line = "Log line No.{}\n".format(i) + line = line.encode('utf-8') + lines.append(line) + iter_lines_mock.return_value = lines + + self.kwargs.update({ + 'serviceName': asae_name, + 'rg': 'resource-group', + 'component': 'spring-cloud-gateway' + }) + + self.cmd('spring component logs -s {serviceName} -g {rg} -n {component} --all-instances --lines 50') + self.assertEqual(len(command_std_out), 150) + + @mock.patch('azext_spring.log_stream.log_stream_operations.iter_lines', autospec=True) + @mock.patch('azext_spring.log_stream.log_stream_operations.requests', autospec=True) + @mock.patch('azext_spring.commands.cf_spring', autospec=True) + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_hostname', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_bearer_auth', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_default_writer', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_prefix_writer', autospec=True) + def test_scg_operator_log_stream(self, _get_prefix_writer_mock, _get_default_writer_mock, _get_bearer_auth_mock, + _get_hostname_mock, is_enterprise_tier_mock, client_factory_mock, requests_mock, + iter_lines_mock): + command_std_out = [] + _get_default_writer_mock.return_value = TestingWriter(command_std_out) + _get_prefix_writer_mock.return_value = TestingWriter(command_std_out) + + _get_bearer_auth_mock.return_value = BearerAuth("fake-bearer-token") + asae_name = "asae-name" + _get_hostname_mock.return_value = "{}.asc-test.net".format(asae_name) + + is_enterprise_tier_mock.return_value = True + + client = mock.MagicMock() + client.gateways.get.return_value = self._get_mocked_scg() + client_factory_mock.return_value = client + + response = Response() + response.status_code = 200 + requests_mock.get.return_value = response + + lines = [] + for i in range(50): + line = "Log line No.{}\n".format(i) + line = line.encode('utf-8') + lines.append(line) + iter_lines_mock.return_value = lines + + self.kwargs.update({ + 'serviceName': asae_name, + 'rg': 'resource-group', + 'component': 'spring-cloud-gateway-operator' + }) + + self.cmd('spring component logs -s {serviceName} -g {rg} -n {component} --all-instances --lines 50') + self.assertEqual(len(command_std_out), 100) + + @mock.patch('azext_spring.commands.cf_spring', autospec=True) + @mock.patch('azext_spring.log_stream.log_stream_operations.iter_lines', autospec=True) + @mock.patch('azext_spring.log_stream.log_stream_operations.requests', autospec=True) + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_hostname', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_bearer_auth', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_default_writer', autospec=True) + @mock.patch('azext_spring.managed_components.managed_component_operations._get_prefix_writer', autospec=True) + def test_log_stream_only_by_instance_name(self, _get_prefix_writer_mock, _get_default_writer_mock, + _get_bearer_auth_mock, _get_hostname_mock, is_enterprise_tier_mock, + requests_mock, iter_lines_mock, cf_spring_mock): + instance_names = [ + "application-configuration-service-6fb669cfc5-z6tq9", + "fluxcd-source-controller-675dbdd58b-8sk75", + "asc-scg-default-0", + "scg-operator-6d8895c44b-wcngh" + ] + + for i in instance_names: + + command_std_out = [] + _get_default_writer_mock.return_value = TestingWriter(command_std_out) + _get_prefix_writer_mock.return_value = TestingWriter(command_std_out) + + _get_bearer_auth_mock.return_value = BearerAuth("fake-bearer-token") + asae_name = "asae-name" + _get_hostname_mock.return_value = "{}.asc-test.net".format(asae_name) + + is_enterprise_tier_mock.return_value = True + + cf_spring_mock.return_value = mock.MagicMock() + + response = Response() + response.status_code = 200 + requests_mock.get.return_value = response + + lines = [] + for i in range(50): + line = "Log line No.{}\n".format(i) + line = line.encode('utf-8') + lines.append(line) + iter_lines_mock.return_value = lines + + self.kwargs.update({ + 'serviceName': asae_name, + 'rg': 'resource-group', + 'instance': i + }) + + self.cmd('spring component logs -s {serviceName} -g {rg} -i {instance} --lines 50') + self.assertEqual(len(command_std_out), 50) + + def _get_mocked_acs_gen2(self): + resource = mock.MagicMock() + instance_1 = mock.MagicMock() + instance_2 = mock.MagicMock() + instance_3 = mock.MagicMock() + resource.properties = mock.MagicMock() + resource.properties.instances = [instance_1, instance_2, instance_3] + instance_1.name = "application-configuration-service-11111111-1111" + instance_2.name = "application-configuration-service-11111111-2222" + instance_3.name = "fluxcd-source-controller-11111111-3333" + return resource + + def _get_mocked_scg(self): + resource = mock.MagicMock() + resource.properties = mock.MagicMock() + instance_1 = mock.MagicMock() + instance_2 = mock.MagicMock() + instance_3 = mock.MagicMock() + resource.properties.instances = [instance_1, instance_2, instance_3] + instance_1.name = "asc-scg-default-0" + instance_2.name = "asc-scg-default-1" + instance_3.name = "asc-scg-default-2" + resource.properties.operator_properties = mock.MagicMock() + operator_1 = mock.MagicMock() + operator_2 = mock.MagicMock() + resource.properties.operator_properties.instances = [operator_1, operator_2] + operator_1.name = "scg-operator-74947fdcb-8hj85" + operator_2.name = "scg-operator-74947fdcb-askdj" + return resource diff --git a/src/spring/azext_spring/tests/latest/managed_component/test_managed_component_validators.py b/src/spring/azext_spring/tests/latest/managed_component/test_managed_component_validators.py new file mode 100644 index 00000000000..c132ff6ee44 --- /dev/null +++ b/src/spring/azext_spring/tests/latest/managed_component/test_managed_component_validators.py @@ -0,0 +1,436 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ----------------------------------------------------------------------------- + +import unittest + +from argparse import Namespace +from azure.cli.core import AzCommandsLoader +from azure.cli.core.azclierror import InvalidArgumentValueError +from azure.cli.core.mock import DummyCli +from azure.cli.core.commands import AzCliCommand + +from ....managed_components.managed_component import get_component +from ....managed_components.validators_managed_component import (validate_component_logs, + validate_component_list, + validate_instance_list) +from ...._clierror import NotSupportedPricingTierError + +try: + import unittest.mock as mock +except ImportError: + from unittest import mock + + +valid_component_names = [ + "application-configuration-service", + "APPLICATION-configuration-service", + "APPlication-configuration-service", + "Application-configuration-serVIcE", + "application-CONFIGuratioN-service", + "flux-source-controller", + "FLUX-source-controller", + "flux-sOurce-controller", + "flux-source-controllEr", + "flux-source-controlleR", + "spring-cloud-gateway", + "SPrINg-cloud-gateway", + "spring-cloud-gaTeway", + "spring-cloud-Gateway", + "spring-cloud-GATEWAY", + "spring-cloud-gateway-operator", + "spring-cloud-gateway-operatoR", + "spring-cloud-gatewaY-operator", + "spring-CLOUD-gateway-operator", + "sprinG-cloud-gateway-operator" +] + + +invalid_component_names = [ + "app-configuration-service", + "", + "None", + "flux" +] + + +def _get_test_cmd(): + cli_ctx = DummyCli() + cli_ctx.data['subscription_id'] = '00000000-0000-0000-0000-000000000000' + loader = AzCommandsLoader(cli_ctx, resource_type='Microsoft.AppPlatform') + cmd = AzCliCommand(loader, 'test', None) + cmd.command_kwargs = {'resource_type': 'Microsoft.AppPlatform'} + cmd.cli_ctx = cli_ctx + return cmd + + +class TestValidateComponentList(unittest.TestCase): + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_tier(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = False + + with self.assertRaises(NotSupportedPricingTierError) as context: + validate_component_list(_get_test_cmd(), Namespace(resource_group="group", service="service")) + + is_enterprise_tier_mock.return_value = True + validate_component_list(_get_test_cmd(), Namespace(resource_group="group", service="service")) + + +class TestValidateComponentInstanceList(unittest.TestCase): + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_component_name(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + for c in valid_component_names: + ns = Namespace(resource_group="group", service="service", component=c) + validate_instance_list(_get_test_cmd(), ns) + component_obj = get_component(ns.component) + self.assertIsNotNone(component_obj) + + for c in invalid_component_names: + with self.assertRaises(InvalidArgumentValueError) as context: + ns = Namespace(resource_group="group", service="service", component=c) + validate_instance_list(_get_test_cmd(), ns) + + self.assertTrue("is not supported" in str(context.exception)) + self.assertTrue("Supported components are:" in str(context.exception)) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_tier(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + ns = Namespace(resource_group="group", service="service", component="application-configuration-service") + validate_instance_list(_get_test_cmd(), ns) + + is_enterprise_tier_mock.return_value = False + with self.assertRaises(NotSupportedPricingTierError) as context: + validate_instance_list(_get_test_cmd(), ns) + + +class TestValidateComponentLogs(unittest.TestCase): + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_mutual_exclusive_param(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + ns = Namespace( + resource_group="group", + service="service", + all_instances = True, + instance = "fake-instance-name" + ) + + with self.assertRaises(InvalidArgumentValueError) as context: + validate_component_logs(_get_test_cmd(), ns) + + self.assertEquals("--all-instances cannot be set together with --instance/-i.", str(context.exception)) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_required_param_missing(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name=None, + instance=None, + ) + + with self.assertRaises(InvalidArgumentValueError) as context: + validate_component_logs(_get_test_cmd(), ns) + + self.assertEquals("When --name/-n is not set, --instance/-i is required.", str(context.exception)) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_only_instance_name(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name=None, + instance="fake-instance-name", + lines=50, + limit=2048, + since=None, + max_log_requests=5 + ) + + with self.assertLogs('cli.azext_spring.managed_components.validators_managed_component', 'WARNING') as cm: + validate_component_logs(_get_test_cmd(), ns) + self.assertEquals(cm.output, ['WARNING:cli.azext_spring.managed_components.validators_managed_component:--instance/-i is specified without --name/-n, will try best effort get logs by instance.']) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_valid_component_name(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + for n in valid_component_names: + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name=n, + instance="fake-instance-name", + lines=50, + limit=2048, + since=None, + max_log_requests=5 + ) + validate_component_logs(_get_test_cmd(), ns) + + component_obj = get_component(ns.name) + self.assertIsNotNone(component_obj) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_invalid_component_name(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_valid_log_lines(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + valid_log_lines = [1, 2, 5, 10, 99, 100, 200, 10000] + + for n in valid_component_names: + for lines in valid_log_lines: + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name=n, + instance="fake-instance-name", + lines=lines, + limit=2048, + since=None, + max_log_requests=5 + ) + validate_component_logs(_get_test_cmd(), ns) + self.assertEquals(lines, ns.lines) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_log_lines_too_small(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + for n in valid_component_names: + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name=n, + instance="fake-instance-name", + lines=-1, + limit=2048, + since=None + ) + with self.assertRaises(InvalidArgumentValueError) as context: + validate_component_logs(_get_test_cmd(), ns) + self.assertEquals('--lines must be in the range [1,10000]', str(context.exception)) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_log_lines_too_big(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + for n in valid_component_names: + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name=n, + instance="fake-instance-name", + lines=10001, + limit=2048, + since=None, + max_log_requests=5 + ) + with self.assertLogs('cli.azext_spring._validators', 'ERROR') as cm: + validate_component_logs(_get_test_cmd(), ns) + expect_error_msgs = ['ERROR:cli.azext_spring._validators:' + '--lines can not be more than 10000, using 10000 instead'] + self.assertEquals(expect_error_msgs, cm.output) + self.assertEquals(10000, ns.lines) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_valid_log_since(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + valid_log_since = ['1h', + '1m', '2m', '5m', '10m', '11m', '20m', '30m', '40m', '50m', '59m', '60m', + '1s', '2s', '5s', '9s', '10s', '20s', '29s', '30s', '60s', '100s', '500s', '3000s', '3600s', + '1', '2', '5', '10', '20', '29', '30', '3000', '3600'] + + for n in valid_component_names: + for since in valid_log_since: + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name=n, + instance="fake-instance-name", + lines=10001, + limit=2048, + since=since, + max_log_requests=5 + ) + validate_component_logs(_get_test_cmd(), ns) + last = since[-1:] + since_in_seconds = int(since[:-1]) if last in ("hms") else int(since) + if last == 'h': + since_in_seconds = since_in_seconds * 3600 + elif last == 'm': + since_in_seconds = since_in_seconds * 60 + self.assertEquals(since_in_seconds, ns.since) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_invalid_log_since(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + invalid_log_since = ['asdf1h', '1masdf', 'asdfe2m', 'asd5m', '1efef0m', '11mm'] + + for n in valid_component_names: + for since in invalid_log_since: + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name=n, + instance="fake-instance-name", + lines=10001, + limit=2048, + since=since + ) + with self.assertRaises(InvalidArgumentValueError) as context: + validate_component_logs(_get_test_cmd(), ns) + self.assertEquals("--since contains invalid characters", str(context.exception)) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_log_since_too_big(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + invalid_log_since = ['2h', '61m', '3601s', '9000s', '9000'] + + for n in valid_component_names: + for since in invalid_log_since: + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name=n, + instance="fake-instance-name", + lines=10000, + limit=2048, + since=since + ) + with self.assertRaises(InvalidArgumentValueError) as context: + validate_component_logs(_get_test_cmd(), ns) + self.assertEquals("--since can not be more than 1h", str(context.exception)) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_valid_log_limit(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + valid_log_limit = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] + + for n in valid_component_names: + for limit in valid_log_limit: + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name=n, + instance="fake-instance-name", + lines=10000, + limit=limit, + since='1h', + max_log_requests=5 + ) + validate_component_logs(_get_test_cmd(), ns) + self.assertEquals(limit * 1024, ns.limit) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_negative_log_limit(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + invalid_log_limit = [-1, -2, -3, -4, -10, -100, -1000] + + for n in valid_component_names: + for limit in invalid_log_limit: + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name=n, + instance="fake-instance-name", + lines=10000, + limit=limit, + since='1h' + ) + with self.assertRaises(InvalidArgumentValueError) as context: + validate_component_logs(_get_test_cmd(), ns) + self.assertEquals('--limit must be in the range [1,2048]', str(context.exception)) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_log_limit_too_big(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + invalid_log_limit = [2049, 2050, 3000, 3001, 10000, 20000, 100000] + + for n in valid_component_names: + for limit in invalid_log_limit: + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name=n, + instance="fake-instance-name", + lines=10000, + limit=limit, + since='1h', + max_log_requests=5, + ) + with self.assertLogs('cli.azext_spring._validators', 'ERROR') as cm: + validate_component_logs(_get_test_cmd(), ns) + error_msgs = ['ERROR:cli.azext_spring._validators:' + '--limit can not be more than 2048, using 2048 instead'] + self.assertEquals(error_msgs, cm.output) + self.assertEquals(2048 * 1024, ns.limit) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_tier(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = True + + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name="application-configuration-service", + instance="fake-instance-name", + lines=10000, + limit=2048, + since='1h', + max_log_requests=5 + ) + + validate_component_logs(_get_test_cmd(), ns) + + @mock.patch('azext_spring.managed_components.validators_managed_component.is_enterprise_tier', autospec=True) + def test_invalid_tier(self, is_enterprise_tier_mock): + is_enterprise_tier_mock.return_value = False + + ns = Namespace( + resource_group="group", + service="service", + all_instances=False, + name="application-configuration-service", + instance="fake-instance-name", + lines=10000, + limit=2048, + since='1h', + max_log_requests=5, + ) + + with self.assertRaises(NotSupportedPricingTierError) as context: + validate_component_logs(_get_test_cmd(), ns) + self.assertEquals("Only enterprise tier service instance is supported in this command.", str(context.exception)) diff --git a/src/spring/azext_spring/tests/latest/managed_component/test_writer.py b/src/spring/azext_spring/tests/latest/managed_component/test_writer.py new file mode 100644 index 00000000000..7757af6ca30 --- /dev/null +++ b/src/spring/azext_spring/tests/latest/managed_component/test_writer.py @@ -0,0 +1,22 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ----------------------------------------------------------------------------- +import io +import unittest +from ....log_stream.writer import DefaultWriter, PrefixWriter + + +class TestValidateComponentList(unittest.TestCase): + def test_default_writer(self): + writer = DefaultWriter() + buffer = io.StringIO() + writer.write("test-data", end='', file=buffer) + self.assertEquals("test-data", buffer.getvalue().strip()) + + def test_prefix_writer(self): + writer = PrefixWriter("prefix") + buffer = io.StringIO() + writer.write("test-data", end='', file=buffer) + self.assertEquals("prefix test-data", buffer.getvalue().strip()) diff --git a/src/spring/setup.py b/src/spring/setup.py index a066fe56e3a..eb74c738294 100644 --- a/src/spring/setup.py +++ b/src/spring/setup.py @@ -16,7 +16,7 @@ # TODO: Confirm this is the right version number you want and it matches your # HISTORY.rst entry. -VERSION = '1.18.0' +VERSION = '1.19.0' # The full list of classifiers is available at # https://pypi.python.org/pypi?%3Aaction=list_classifiers