diff --git a/src/dashboard/apigateway/apigateway/apis/open/stage/serializers.py b/src/dashboard/apigateway/apigateway/apis/open/stage/serializers.py
index 21b8ed4e2..d4586c6a5 100644
--- a/src/dashboard/apigateway/apigateway/apis/open/stage/serializers.py
+++ b/src/dashboard/apigateway/apigateway/apis/open/stage/serializers.py
@@ -238,8 +238,11 @@ def create(self, validated_data):
         # 4. create or update header rewrite plugin config
         stage_transform_headers = proxy_http_config.get("transform_headers") or {}
         stage_config = HeaderRewriteConvertor.transform_headers_to_plugin_config(stage_transform_headers)
-        HeaderRewriteConvertor.alter_plugin(
-            instance.gateway_id, PluginBindingScopeEnum.STAGE.value, instance.id, stage_config
+        HeaderRewriteConvertor.sync_plugins(
+            instance.gateway_id,
+            PluginBindingScopeEnum.STAGE.value,
+            {instance.id: stage_config},
+            self.context["request"].user.username,
         )
 
         return instance
@@ -292,8 +295,11 @@ def update(self, instance, validated_data):
         # 3. create or update header rewrite plugin config
         stage_transform_headers = proxy_http_config.get("transform_headers") or {}
         stage_config = HeaderRewriteConvertor.transform_headers_to_plugin_config(stage_transform_headers)
-        HeaderRewriteConvertor.alter_plugin(
-            instance.gateway_id, PluginBindingScopeEnum.STAGE.value, instance.id, stage_config
+        HeaderRewriteConvertor.sync_plugins(
+            instance.gateway_id,
+            PluginBindingScopeEnum.STAGE.value,
+            {instance.id: stage_config},
+            self.context["request"].user.username,
         )
 
         return instance
diff --git a/src/dashboard/apigateway/apigateway/apis/web/resource/legacy_serializers.py b/src/dashboard/apigateway/apigateway/apis/web/resource/legacy_serializers.py
new file mode 100644
index 000000000..cfd5efe0e
--- /dev/null
+++ b/src/dashboard/apigateway/apigateway/apis/web/resource/legacy_serializers.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+#
+# TencentBlueKing is pleased to support the open source community by making
+# 蓝鲸智云 - API 网关(BlueKing - APIGateway) available.
+# Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
+# Licensed under the MIT License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+#     http://opensource.org/licenses/MIT
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
+# either express or implied. See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# We undertake not to change the open source license (MIT license) applicable
+# to the current version of the project delivered to anyone in the future.
+#
+# 1.13 版本:兼容旧版 (api_version=0.1) 资源 yaml 通过 openapi 导入
+import re
+
+from django.utils.translation import gettext as _
+from rest_framework import serializers
+
+from apigateway.core.constants import DEFAULT_LB_HOST_WEIGHT, STAGE_VAR_REFERENCE_PATTERN, LoadBalanceTypeEnum
+
+# 通过 openapi 导入时,只允许导入使用环境变量的后端地址
+RESOURCE_DOMAIN_PATTERN = re.compile(r"^http(s)?:\/\/\{%s\}$" % (STAGE_VAR_REFERENCE_PATTERN.pattern))
+
+HEADER_KEY_PATTERN = re.compile(r"^[a-zA-Z0-9-]{1,100}$")
+
+
+class LegacyResourceHostSLZ(serializers.Serializer):
+    host = serializers.RegexField(RESOURCE_DOMAIN_PATTERN)
+    weight = serializers.IntegerField(min_value=1, default=DEFAULT_LB_HOST_WEIGHT)
+
+
+class LegacyUpstreamsSLZ(serializers.Serializer):
+    loadbalance = serializers.ChoiceField(choices=LoadBalanceTypeEnum.get_choices(), required=False)
+    hosts = serializers.ListField(child=LegacyResourceHostSLZ(), allow_empty=False, required=False)
+
+    def validate(self, data):
+        if "hosts" in data and not data.get("loadbalance"):
+            raise serializers.ValidationError(_("hosts 存在时,需要指定 loadbalance 类型。"))
+
+        return data
+
+
+class LegacyTransformHeadersSLZ(serializers.Serializer):
+    set = serializers.DictField(label="设置", child=serializers.CharField(), required=False, allow_empty=True)
+    delete = serializers.ListField(label="删除", child=serializers.CharField(), required=False, allow_empty=True)
+
+    def _validate_headers_key(self, value):
+        for key in value:
+            if not HEADER_KEY_PATTERN.match(key):
+                raise serializers.ValidationError(_("Header 键由字母、数字、连接符(-)组成,长度小于100个字符。"))
+        return value
+
+    def validate_set(self, value):
+        return self._validate_headers_key(value)
+
+    def validate_delete(self, value):
+        return self._validate_headers_key(value)
diff --git a/src/dashboard/apigateway/apigateway/apis/web/resource/serializers.py b/src/dashboard/apigateway/apigateway/apis/web/resource/serializers.py
index 7cb2d4686..620541f13 100644
--- a/src/dashboard/apigateway/apigateway/apis/web/resource/serializers.py
+++ b/src/dashboard/apigateway/apigateway/apis/web/resource/serializers.py
@@ -40,6 +40,7 @@
 from apigateway.core.utils import get_path_display
 
 from .constants import MAX_LABEL_COUNT_PER_RESOURCE, PATH_PATTERN, RESOURCE_NAME_PATTERN
+from .legacy_serializers import LegacyTransformHeadersSLZ, LegacyUpstreamsSLZ
 
 
 class ResourceQueryInputSLZ(serializers.Serializer):
@@ -143,6 +144,11 @@ class HttpBackendConfigSLZ(serializers.Serializer):
     timeout = serializers.IntegerField(
         max_value=MAX_BACKEND_TIMEOUT_IN_SECOND, min_value=0, required=False, help_text="超时时间"
     )
+    # 1.13 版本: 兼容旧版 (api_version=0.1) 资源 yaml 通过 openapi 导入
+    legacy_upstreams = LegacyUpstreamsSLZ(allow_null=True, required=False, help_text="旧版 upstreams,管理端不需要处理")
+    legacy_transform_headers = LegacyTransformHeadersSLZ(
+        allow_null=True, required=False, help_text="旧版 transform_headers,管理端不需要处理"
+    )
 
 
 class ResourceInputSLZ(serializers.ModelSerializer):
diff --git a/src/dashboard/apigateway/apigateway/biz/resource/importer/importers.py b/src/dashboard/apigateway/apigateway/biz/resource/importer/importers.py
index 1a1501c72..c21d0d607 100644
--- a/src/dashboard/apigateway/apigateway/biz/resource/importer/importers.py
+++ b/src/dashboard/apigateway/apigateway/biz/resource/importer/importers.py
@@ -30,6 +30,8 @@
 from apigateway.core.constants import DEFAULT_BACKEND_NAME, HTTP_METHOD_ANY
 from apigateway.core.models import Backend, Gateway, Resource
 
+from .legacy_synchronizers import LegacyTransformHeadersToPluginSynchronizer, LegacyUpstreamToBackendSynchronizer
+
 logger = 
logging.getLogger(__name__) @@ -329,9 +331,15 @@ def import_resources(self): # 3. 补全标签 ID 数据 self._complete_label_ids() - # 4. 创建或更新资源 + # 4. [legacy upstreams] 创建或更新 backend,并替换资源对应的 backend + self._sync_legacy_upstreams_to_backend_and_replace_resource_backend() + + # 5. 创建或更新资源 self._create_or_update_resources() + # 6. [legacy transform-headers] 将 transform-headers 转换为 bk-header-rewrite 插件,并绑定到资源 + self._sync_legacy_transform_headers_to_plugins() + def get_selected_resource_data_list(self) -> List[ResourceData]: return self.resource_data_list @@ -387,3 +395,13 @@ def _create_or_update_resources(self) -> List[Resource]: username=self.username, ) return saver.save() + + def _sync_legacy_upstreams_to_backend_and_replace_resource_backend(self): + """根据 backend_config 中的 legacy_upstreams 创建 backend,并替换 resource_data_list 中资源关联的 backend""" + synchronizer = LegacyUpstreamToBackendSynchronizer(self.gateway, self.resource_data_list, self.username) + synchronizer.sync_backends_and_replace_resource_backend() + + def _sync_legacy_transform_headers_to_plugins(self): + """根据 backend_config 中的 legacy_transform_headers 创建 bk-header-rewrite 插件,并绑定到资源""" + synchronizer = LegacyTransformHeadersToPluginSynchronizer(self.gateway, self.resource_data_list, self.username) + synchronizer.sync_plugins() diff --git a/src/dashboard/apigateway/apigateway/biz/resource/importer/legacy_synchronizers.py b/src/dashboard/apigateway/apigateway/biz/resource/importer/legacy_synchronizers.py new file mode 100644 index 000000000..6781ff71a --- /dev/null +++ b/src/dashboard/apigateway/apigateway/biz/resource/importer/legacy_synchronizers.py @@ -0,0 +1,243 @@ +# +# TencentBlueKing is pleased to support the open source community by making +# 蓝鲸智云 - API 网关(BlueKing - APIGateway) available. +# Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the MIT License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://opensource.org/licenses/MIT +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +# either express or implied. See the License for the specific language governing permissions and +# limitations under the License. +# +# We undertake not to change the open source license (MIT license) applicable +# to the current version of the project delivered to anyone in the future. 
+# +# 1.13 版本: 兼容旧版 (api_version=0.1) 资源 yaml 通过 openapi 导入 +import logging +import re +from collections import defaultdict +from typing import Any, Dict, List, Optional + +from apigateway.apps.plugin.constants import PluginBindingScopeEnum +from apigateway.biz.resource.models import ResourceData +from apigateway.common.plugin.header_rewrite import HeaderRewriteConvertor +from apigateway.core.constants import DEFAULT_BACKEND_NAME, STAGE_VAR_PATTERN +from apigateway.core.models import Backend, BackendConfig, Gateway, Stage + +logger = logging.getLogger(__name__) + + +LEGACY_BACKEND_NAME_PREFIX = "backend-" + + +class LegacyUpstream: + def __init__(self, upstreams: Dict[str, Any]): + self.upstreams = upstreams + + def get_stage_id_to_backend_config( + self, + stages: List[Stage], + stage_id_to_timeout: Dict[int, int], + ) -> Dict[int, Dict]: + """获取此 upstream 对应的后端,在各个环境的后端配置""" + backend_configs = {} + + for stage in stages: + stage_vars = stage.vars + + hosts = [] + for host in self.upstreams["hosts"]: + scheme, host_ = host["host"].rstrip("/").split("://") + hosts.append( + { + "scheme": scheme, + "host": self._render_host(stage_vars, host_), + "weight": host["weight"], + } + ) + + backend_configs[stage.id] = { + "type": "node", + # 新创建的后端,其超时时间,默认使用 default 后端在各环境配置的超时时间 + "timeout": stage_id_to_timeout[stage.id], + "loadbalance": self.upstreams["loadbalance"], + "hosts": hosts, + } + + return backend_configs + + def _render_host(self, vars: Dict[str, Any], host: str) -> str: + def replace(matched): + return vars.get(matched.group(1), matched.group(0)) + + return re.sub(STAGE_VAR_PATTERN, replace, host) + + +class LegacyBackendCreator: + def __init__(self, gateway: Gateway, username: str): + self.gateway = gateway + self.username = username + + self._existing_backends = {backend.id: backend for backend in Backend.objects.filter(gateway=gateway)} + self._existing_backend_configs = self._get_existing_backend_configs() + self._max_legacy_backend_number = self._get_max_legacy_backend_number() + + def match_or_create_backend(self, stage_id_to_backend_config: Dict[int, Dict]) -> Backend: + """根据后端配置,匹配一个后端服务;如果未匹配,根据规则生成一个新的后端服务""" + # 排序 hosts,使其与 existing_backend_configs 中 hosts 顺序一致,便于对比数据 + for backend_config in stage_id_to_backend_config.values(): + backend_config["hosts"] = self._sort_hosts(backend_config["hosts"]) + + backend_id = self._match_existing_backend(stage_id_to_backend_config) + if backend_id: + return self._existing_backends[backend_id] + + new_backend_name = self._generate_new_backend_name() + backend = self._create_backend_and_backend_configs(new_backend_name, stage_id_to_backend_config) + + # 用新创建的 backend 更新辅助数据 + self._existing_backends[backend.id] = backend + self._existing_backend_configs[backend.id] = stage_id_to_backend_config + + return backend + + def _match_existing_backend(self, stage_id_to_backend_config: Dict[int, Dict]) -> Optional[int]: + for backend_id, existing_backend_configs in self._existing_backend_configs.items(): + if stage_id_to_backend_config == existing_backend_configs: + return backend_id + + return None + + def _get_existing_backend_configs(self) -> Dict[int, Dict[int, Dict]]: + # 对应关系:backend_id -> stage_id -> config + backend_configs: Dict[int, Dict[int, Dict]] = defaultdict(dict) + + for backend_config in BackendConfig.objects.filter(gateway=self.gateway): + config = backend_config.config + config["hosts"] = self._sort_hosts(config["hosts"]) + + backend_configs[backend_config.backend_id][backend_config.stage_id] = config + + return 
backend_configs + + def _generate_new_backend_name(self) -> str: + self._max_legacy_backend_number += 1 + return f"{LEGACY_BACKEND_NAME_PREFIX}{self._max_legacy_backend_number}" + + def _create_backend_and_backend_configs( + self, + backend_name: str, + stage_id_to_backend_config: Dict[int, Dict], + ) -> Backend: + backend = Backend.objects.create( + gateway=self.gateway, name=backend_name, created_by=self.username, updated_by=self.username + ) + + backend_configs = [ + BackendConfig( + gateway=self.gateway, + stage_id=stage_id, + backend=backend, + config=config, + created_by=self.username, + updated_by=self.username, + ) + for stage_id, config in stage_id_to_backend_config.items() + ] + BackendConfig.objects.bulk_create(backend_configs) + + return backend + + def _sort_hosts(self, hosts: List[Dict[str, Dict]]) -> List[Dict[str, Dict]]: + # 排序 host,使用 "==" 对比配置时顺序一致 + return sorted(hosts, key=lambda x: "{}://{}#{}".format(x["scheme"], x["host"], x["weight"])) + + def _get_max_legacy_backend_number(self) -> int: + """获取网关创建的后端中,后端名称中已使用的最大序号""" + names = Backend.objects.filter(gateway=self.gateway, name__startswith=LEGACY_BACKEND_NAME_PREFIX).values_list( + "name", flat=True + ) + + backend_numbers = [ + int(name[len(LEGACY_BACKEND_NAME_PREFIX) :]) + for name in names + if name[len(LEGACY_BACKEND_NAME_PREFIX) :].isdigit() + ] + return max(backend_numbers, default=0) + + +class LegacyUpstreamToBackendSynchronizer: + def __init__(self, gateway: Gateway, resource_data_list: List[ResourceData], username: str): + self.gateway = gateway + self.resource_data_list = resource_data_list + self.username = username + + def sync_backends_and_replace_resource_backend(self): + if not self._has_legacy_upstreams(): + return + + self._sync_backends_and_replace_resource_backend() + + def _has_legacy_upstreams(self) -> bool: + return any(resource_data.backend_config.legacy_upstreams for resource_data in self.resource_data_list) + + def _sync_backends_and_replace_resource_backend(self): + backend_creator = LegacyBackendCreator(self.gateway, self.username) + stages = list(Stage.objects.filter(gateway=self.gateway)) + stage_id_to_timeout = self._get_stage_id_to_default_timeout() + + for resource_data in self.resource_data_list: + if not resource_data.backend_config.legacy_upstreams: + continue + + legacy_upstream = LegacyUpstream(resource_data.backend_config.legacy_upstreams) + stage_id_to_backend_config = legacy_upstream.get_stage_id_to_backend_config(stages, stage_id_to_timeout) + backend = backend_creator.match_or_create_backend(stage_id_to_backend_config) + resource_data.backend = backend + + def _get_stage_id_to_default_timeout(self) -> Dict[int, int]: + return { + backend_config.stage_id: backend_config.config["timeout"] + for backend_config in BackendConfig.objects.filter( + gateway=self.gateway, + backend__name=DEFAULT_BACKEND_NAME, + ) + } + + +class LegacyTransformHeadersToPluginSynchronizer: + def __init__(self, gateway: Gateway, resource_data_list: List[ResourceData], username: str): + self.gateway = gateway + self.resource_data_list = resource_data_list + self.username = username + + def sync_plugins(self): + if not self._has_legacy_transform_headers(): + return + + scope_id_to_plugin_config = {} + for resource_data in self.resource_data_list: + transform_headers = resource_data.backend_config.legacy_transform_headers + if transform_headers is None: + continue + + assert resource_data.resource + + plugin_config = HeaderRewriteConvertor.transform_headers_to_plugin_config(transform_headers) + 
scope_id_to_plugin_config[resource_data.resource.id] = plugin_config + + HeaderRewriteConvertor.sync_plugins( + gateway_id=self.gateway.id, + scope_type=PluginBindingScopeEnum.RESOURCE.value, + scope_id_to_plugin_config=scope_id_to_plugin_config, + username=self.username, + ) + + def _has_legacy_transform_headers(self) -> bool: + return any( + resource_data.backend_config.legacy_transform_headers is not None + for resource_data in self.resource_data_list + ) diff --git a/src/dashboard/apigateway/apigateway/biz/resource/importer/swagger.py b/src/dashboard/apigateway/apigateway/biz/resource/importer/swagger.py index 62cf3fee9..0805e2b0f 100644 --- a/src/dashboard/apigateway/apigateway/biz/resource/importer/swagger.py +++ b/src/dashboard/apigateway/apigateway/biz/resource/importer/swagger.py @@ -23,7 +23,6 @@ from typing import Any, Dict, List, Optional import jsonschema -from django.utils.translation import gettext as _ from apigateway.biz.constants import SwaggerFormatEnum from apigateway.common.exceptions import SchemaValidationError @@ -197,14 +196,18 @@ def _adapt_backend(self, backend: Dict) -> Dict: """ 适配后端配置 """ - if backend.get("upstreams") or backend.get("transformHeaders"): - raise ValueError(_("当前版本,不支持 backend 中配置 upstreams, transformHeaders,请更新至最新版本资源 yaml 配置。")) + backend_type = backend.get("type", ProxyTypeEnum.HTTP.value).lower() + if backend_type != ProxyTypeEnum.HTTP.value: + raise ValueError(f"unsupported backend type: {backend['type']}") return { "method": backend["method"].upper(), "path": backend["path"], "match_subpath": backend.get("matchSubpath", False), "timeout": backend.get("timeout", 0), + # 1.13 版本: 兼容旧版 (api_version=0.1) 资源 yaml 通过 openapi 导入 + "legacy_upstreams": backend.get("upstreams"), + "legacy_transform_headers": backend.get("transformHeaders"), } def _adapt_description(self, summary: Optional[str], description: Optional[str]): @@ -235,7 +238,7 @@ def _adapt_auth_config(self, auth_config: dict): class ResourceSwaggerExporter: def __init__( self, - api_version: str = "0.2", + api_version: str = "2.0", include_bk_apigateway_resource: bool = True, title: str = "API Gateway Resources", description: str = "", diff --git a/src/dashboard/apigateway/apigateway/biz/resource/models.py b/src/dashboard/apigateway/apigateway/biz/resource/models.py index 233053f0c..e196d29c6 100644 --- a/src/dashboard/apigateway/apigateway/biz/resource/models.py +++ b/src/dashboard/apigateway/apigateway/biz/resource/models.py @@ -33,6 +33,9 @@ class ResourceBackendConfig(BaseModel): path: str match_subpath: bool = Field(default=False) timeout: int = Field(default=0) + # 1.13 版本: 兼容旧版 (api_version=0.1) 资源 yaml 通过 openapi 导入 + legacy_upstreams: Optional[dict] = Field(default=None, exclude=True) + legacy_transform_headers: Optional[dict] = Field(default=None, exclude=True) class ResourceData(BaseModel): diff --git a/src/dashboard/apigateway/apigateway/biz/resource/savers.py b/src/dashboard/apigateway/apigateway/biz/resource/savers.py index ab2bb6876..27d75a1e7 100644 --- a/src/dashboard/apigateway/apigateway/biz/resource/savers.py +++ b/src/dashboard/apigateway/apigateway/biz/resource/savers.py @@ -73,7 +73,10 @@ def _save_resources(self) -> bool: for resource_data in self.resource_data_list: if resource_data.resource: resource = resource_data.resource - resource.__dict__.update(updated_by=self.username, **resource_data.basic_data) + resource.updated_by = self.username + for key, value in resource_data.basic_data.items(): + setattr(resource, key, value) + 
update_resources.append(resource) else: resource = Resource( @@ -140,12 +143,11 @@ def _save_proxies(self, resource_ids: List[int]): proxy = proxies.get(resource_data.resource.id) if proxy: - proxy.__dict__.update( - type=ProxyTypeEnum.HTTP.value, - backend=resource_data.backend, - schema=schema, - _config=resource_data.backend_config.json(), - ) + proxy.type = ProxyTypeEnum.HTTP.value + proxy.backend = resource_data.backend + proxy.schema = schema + proxy._config = resource_data.backend_config.json() + update_proxies.append(proxy) else: proxy = Proxy( @@ -188,7 +190,8 @@ def _save_auth_configs(self, resource_ids: List[int]): auth_config.update(resource_data.auth_config.dict()) if context: - context.__dict__.update(_config=json.dumps(auth_config)) + context._config = json.dumps(auth_config) + update_contexts.append(context) else: context = Context( diff --git a/src/dashboard/apigateway/apigateway/common/plugin/header_rewrite.py b/src/dashboard/apigateway/apigateway/common/plugin/header_rewrite.py index e71a3431e..0cd6869ff 100644 --- a/src/dashboard/apigateway/apigateway/common/plugin/header_rewrite.py +++ b/src/dashboard/apigateway/apigateway/common/plugin/header_rewrite.py @@ -15,7 +15,7 @@ # We undertake not to change the open source license (MIT license) applicable # to the current version of the project delivered to anyone in the future. # -from typing import Optional +from typing import Dict, Optional from apigateway.apps.plugin.constants import PluginTypeCodeEnum from apigateway.apps.plugin.models import PluginBinding, PluginConfig, PluginType @@ -34,49 +34,89 @@ def transform_headers_to_plugin_config(transform_headers: dict) -> Optional[dict "remove": [{"key": key} for key in (transform_headers.get("delete") or [])], } - @staticmethod - def alter_plugin(gateway_id: int, scope_type: str, scope_id: int, plugin_config: Optional[dict]): - # 判断是否已经绑定header rewrite插件 - binding = ( - PluginBinding.objects.filter( + @classmethod + def sync_plugins( + cls, + gateway_id: int, + scope_type: str, + scope_id_to_plugin_config: Dict[int, Optional[Dict]], + username: str, + ): + """根据配置,同步 bk-header-rewrite 插件与 scope 对象的绑定 + - scope_type: Scope 类型 + - scope_id_to_plugin_config: Scope id 到插件配置的映射 + - username: 当前操作者的用户名 + """ + plugin_type = PluginType.objects.get(code=PluginTypeCodeEnum.BK_HEADER_REWRITE.value) + exist_bindings = { + binding.scope_id: binding + for binding in PluginBinding.objects.filter( + gateway_id=gateway_id, scope_type=scope_type, - scope_id=scope_id, - config__type__code=PluginTypeCodeEnum.BK_HEADER_REWRITE.value, - ) - .prefetch_related("config") - .first() - ) + scope_id__in=scope_id_to_plugin_config.keys(), + config__type=plugin_type, + ).prefetch_related("config") + } - if not binding and not plugin_config: - return + add_bindings = {} + update_plugin_configs = [] + delete_bindings = [] - if binding: - if plugin_config: - # 如果已经绑定, 更新插件配置 - config = binding.config - config.yaml = yaml_dumps(plugin_config) - PluginConfig.objects.bulk_update([config], ["yaml"]) - return + for scope_id, plugin_config in scope_id_to_plugin_config.items(): + if not plugin_config: + if scope_id in exist_bindings: + # 配置为空,但是插件已存在,则删除 + delete_bindings.append(exist_bindings[scope_id]) + continue - # 插件配置为空, 清理数据 - config = binding.config - PluginBinding.objects.bulk_delete([binding]) - PluginConfig.objects.bulk_delete([config]) - return + if scope_id in exist_bindings: + # 插件已绑定,更新插件配置 + plugin_config_obj = exist_bindings[scope_id].config + plugin_config_obj.yaml = yaml_dumps(plugin_config) + 
plugin_config_obj.updated_by = username + update_plugin_configs.append(plugin_config_obj) + else: + # 插件未绑定,新建插件配置 + add_bindings[scope_id] = PluginConfig( + gateway_id=gateway_id, + name=cls._generate_plugin_name(scope_type, scope_id), + type=plugin_type, + yaml=yaml_dumps(plugin_config), + created_by=username, + ) - # 如果没有绑定, 新建插件配置, 并绑定到scope - if plugin_config: - config = PluginConfig( - gateway_id=gateway_id, - name=f"{scope_type} [{scope_id}] header rewrite", - type=PluginType.objects.get(code=PluginTypeCodeEnum.BK_HEADER_REWRITE.value), - yaml=yaml_dumps(plugin_config), - ) - config.save() - binding = PluginBinding( - gateway_id=gateway_id, - scope_type=scope_type, - scope_id=scope_id, - config=config, - ) - PluginBinding.objects.bulk_create([binding]) + if add_bindings: + PluginConfig.objects.bulk_create(add_bindings.values(), batch_size=100) + + plugin_configs = { + config.name: config for config in PluginConfig.objects.filter(gateway_id=gateway_id, type=plugin_type) + } + + bindings = [] + for scope_id in add_bindings: + plugin_config = plugin_configs[cls._generate_plugin_name(scope_type, scope_id)] + bindings.append( + PluginBinding( + gateway_id=gateway_id, + scope_type=scope_type, + scope_id=scope_id, + config=plugin_config, + created_by=username, + ) + ) + PluginBinding.objects.bulk_create(bindings, batch_size=100) + + if update_plugin_configs: + PluginConfig.objects.bulk_update(update_plugin_configs, fields=["yaml", "updated_by"], batch_size=100) + + if delete_bindings: + PluginBinding.objects.filter( + gateway_id=gateway_id, id__in=[binding.id for binding in delete_bindings] + ).delete() + PluginConfig.objects.filter( + gateway_id=gateway_id, id__in=[binding.config.id for binding in delete_bindings] + ).delete() + + @staticmethod + def _generate_plugin_name(scope_type: str, scope_id: int) -> str: + return f"bk-header-rewrite::{scope_type}::{scope_id}" diff --git a/src/dashboard/apigateway/apigateway/core/management/commands/migrate_backend.py b/src/dashboard/apigateway/apigateway/core/management/commands/migrate_backend.py index 45738f2ec..0cd53abb7 100644 --- a/src/dashboard/apigateway/apigateway/core/management/commands/migrate_backend.py +++ b/src/dashboard/apigateway/apigateway/core/management/commands/migrate_backend.py @@ -16,13 +16,12 @@ # to the current version of the project delivered to anyone in the future. 
# import logging -import re -from collections import defaultdict -from typing import Any, Dict, List +from typing import Any, Dict from django.core.management.base import BaseCommand from django.core.paginator import Paginator +from apigateway.biz.resource.importer.legacy_synchronizers import LegacyBackendCreator, LegacyUpstream from apigateway.core.constants import DEFAULT_BACKEND_NAME, ContextScopeTypeEnum, ContextTypeEnum from apigateway.core.models import Backend, BackendConfig, Context, Gateway, Proxy, Stage @@ -49,15 +48,13 @@ def handle(self, *args, **options): def _handle_gateway(self, gateway: Gateway): # 创建默认backend - default_backend, _ = Backend.objects.get_or_create( - gateway=gateway, - name=DEFAULT_BACKEND_NAME, - ) + default_backend, _ = Backend.objects.get_or_create(gateway=gateway, name=DEFAULT_BACKEND_NAME) - # 迁移stage的proxy配置 + # 迁移 stage 的 proxy 配置 stages = list(Stage.objects.filter(gateway=gateway)) - # 记录stage配置的timeout, 用于后续resource的数据迁移 - stage_timeout: Dict[int, int] = {} + # 记录 stage 配置的 timeout, 用于后续 resource 的数据迁移 + stage_id_to_timeout: Dict[int, int] = {} + for stage in stages: context = Context.objects.filter( scope_type=ContextScopeTypeEnum.STAGE.value, @@ -69,134 +66,39 @@ def _handle_gateway(self, gateway: Gateway): continue config = context.config - self._handle_stage_backend(gateway, stage, default_backend, config) + stage_id_to_timeout[stage.id] = config["timeout"] - stage_timeout[stage.id] = config["timeout"] + self._handle_stage_backend(gateway, stage, default_backend, config) - # config 与已创建 backend 映射 - backend_stage_config: Dict[int, Dict[int, Any]] = self._get_backend_stage_config_map(gateway) + legacy_backend_creator = LegacyBackendCreator(gateway=gateway, username="cli") - resource_backend_count = self._get_max_resource_backend_count(gateway) # 迁移resource的proxy上游配置 qs = Proxy.objects.filter(resource__gateway=gateway).all().order_by("id") paginator = Paginator(qs, 100) for i in paginator.page_range: for proxy in paginator.page(i): config = proxy.config - if "upstreams" not in config or not config["upstreams"]: - # 关联resource与default_backend + if not config.get("upstreams"): + # 未配置自定义后端,关联 resource 到 default_backend proxy.backend = default_backend proxy.save() continue - resource_stage_config = self._get_resource_stage_config_map(stages, stage_timeout, config) - backend_id = self._match_existing_backend(backend_stage_config, resource_stage_config) - if backend_id is not None: - proxy.backend_id = backend_id - proxy.save() - continue + legacy_upstream = LegacyUpstream(config["upstreams"]) + stage_id_to_backend_config = legacy_upstream.get_stage_id_to_backend_config( + stages, stage_id_to_timeout + ) - resource_backend_count += 1 - backend = self._handle_resource_backend(gateway, resource_backend_count, resource_stage_config) - # 关联resource与backend + backend = legacy_backend_creator.match_or_create_backend(stage_id_to_backend_config) proxy.backend = backend proxy.save() - backend_stage_config[backend.id] = resource_stage_config - - def _get_max_resource_backend_count(self, gateway: Gateway): - count = 0 - names = Backend.objects.filter(gateway=gateway, name__startswith="backend-").values_list("name", flat=True) - for name in names: - if name.split("-")[-1].isdigit() and int(name.split("-")[-1]) > count: - count = int(name.split("-")[-1]) - - return count - - def _match_existing_backend(self, backend_stage_config, resource_stage_config): - for backend_id, stage_config in backend_stage_config.items(): - if stage_config == resource_stage_config: - 
return backend_id - - return None - - def _handle_resource_backend( + def _handle_stage_backend( self, gateway: Gateway, - resource_backend_count: int, - stage_config: Dict[int, Any], - ) -> Backend: - backend = Backend.objects.create( - gateway=gateway, - name=f"backend-{resource_backend_count}", - ) - - backend_configs = [] - for stage_id, config in stage_config.items(): - backend_config = BackendConfig( - gateway=gateway, - backend=backend, - stage_id=stage_id, - config=config, - ) - backend_configs.append(backend_config) - - if backend_configs: - BackendConfig.objects.bulk_create(backend_configs) - - return backend - - def _get_resource_stage_config_map( - self, - stages: List[Stage], - stage_timeout: Dict[int, int], + stage: Stage, + backend: Backend, proxy_http_config: Dict[str, Any], - ) -> Dict[int, Dict[str, Any]]: - stage_config = {} - for stage in stages: - vars = stage.vars - - hosts = [] - for host in proxy_http_config["upstreams"]["hosts"]: - scheme, _host = host["host"].rstrip("/").split("://") - - # 渲染host中的环境变量 - matches = re.findall(r"\{env.(\w+)\}", _host) - for key in matches: - if key in vars: - _host = _host.replace("{env." + key + "}", vars[key]) - - hosts.append({"scheme": scheme, "host": _host, "weight": host["weight"]}) - - hosts = self._sort_hosts(hosts) - - stage_config[stage.id] = { - "type": "node", - "timeout": stage_timeout[stage.id], - "loadbalance": proxy_http_config["upstreams"]["loadbalance"], - "hosts": hosts, - } - - return stage_config - - def _get_backend_stage_config_map(self, gateway: Gateway) -> Dict[int, Dict[int, Any]]: - backend_stage_config: Dict[int, Dict[int, Any]] = defaultdict(dict) - - for backend in Backend.objects.filter(gateway=gateway): - for backend_config in BackendConfig.objects.filter(gateway=gateway, backend=backend): - config = backend_config.config - config["hosts"] = self._sort_hosts(config["hosts"]) - - backend_stage_config[backend.id][backend_config.stage_id] = config - - return backend_stage_config - - def _sort_hosts(self, hosts: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - # 排序host, == 对比时顺序一致 - return sorted(hosts, key=lambda x: "{}://{}#{}".format(x["scheme"], x["host"], x["weight"])) - - def _handle_stage_backend( - self, gateway: Gateway, stage: Stage, backend: Backend, proxy_http_config: Dict[str, Any] ): hosts = [] for host in proxy_http_config["upstreams"]["hosts"]: diff --git a/src/dashboard/apigateway/apigateway/tests/apis/open/stage/test_views.py b/src/dashboard/apigateway/apigateway/tests/apis/open/stage/test_views.py index 83cf8c5cf..bd045020b 100644 --- a/src/dashboard/apigateway/apigateway/tests/apis/open/stage/test_views.py +++ b/src/dashboard/apigateway/apigateway/tests/apis/open/stage/test_views.py @@ -137,7 +137,7 @@ def test_sync(self, mocker, unique_gateway_name, request_factory): ) mocker.patch( - "apigateway.common.plugin.header_rewrite.HeaderRewriteConvertor.alter_plugin", + "apigateway.common.plugin.header_rewrite.HeaderRewriteConvertor.sync_plugins", return_value=True, ) diff --git a/src/dashboard/apigateway/apigateway/tests/apis/web/resource/test_legacy_serializers.py b/src/dashboard/apigateway/apigateway/tests/apis/web/resource/test_legacy_serializers.py new file mode 100644 index 000000000..f92a6b41f --- /dev/null +++ b/src/dashboard/apigateway/apigateway/tests/apis/web/resource/test_legacy_serializers.py @@ -0,0 +1,124 @@ +import pytest +from rest_framework.exceptions import ValidationError + +from apigateway.apis.web.resource.legacy_serializers import ( + LegacyResourceHostSLZ, + 
LegacyTransformHeadersSLZ, + LegacyUpstreamsSLZ, +) + + +class TestLegacyResourceHostSLZ: + @pytest.mark.parametrize( + "data, expected, expected_error", + [ + ( + {"host": "http://{env.foo}"}, + {"host": "http://{env.foo}", "weight": 100}, + None, + ), + ( + {"host": "http://{env.foo}", "weight": 10}, + {"host": "http://{env.foo}", "weight": 10}, + None, + ), + ( + {}, + None, + ValidationError, + ), + ( + {"host": "{env.foo}", "weight": 10}, + None, + ValidationError, + ), + ], + ) + def test_validate(self, data, expected, expected_error): + slz = LegacyResourceHostSLZ(data=data) + if not expected_error: + slz.is_valid(raise_exception=True) + assert slz.validated_data == expected + return + + with pytest.raises(expected_error): + slz.is_valid(raise_exception=True) + + +class TestLegacyUpstreamsSLZ: + @pytest.mark.parametrize( + "data, expected, expected_error", + [ + ( + {}, + {}, + None, + ), + ( + {"hosts": [{"host": "http://{env.foo}"}], "loadbalance": "roundrobin"}, + {"hosts": [{"host": "http://{env.foo}", "weight": 100}], "loadbalance": "roundrobin"}, + None, + ), + ( + {"hosts": [{"host": "http://{env.foo}"}]}, + None, + ValidationError, + ), + ], + ) + def test_validate(self, data, expected, expected_error): + slz = LegacyUpstreamsSLZ(data=data) + if not expected_error: + slz.is_valid(raise_exception=True) + assert slz.validated_data == expected + return + + with pytest.raises(expected_error): + slz.is_valid(raise_exception=True) + + +class TestLegacyTransformHeadersSLZ: + @pytest.mark.parametrize( + "data, expected, expected_error", + [ + ( + {}, + {}, + None, + ), + ( + {"set": {}, "delete": []}, + {"set": {}, "delete": []}, + None, + ), + ( + {"set": {"X-Token": "test"}, "delete": []}, + {"set": {"X-Token": "test"}, "delete": []}, + None, + ), + ( + {"set": {"X_Token": "test"}, "delete": ["X-Token"]}, + None, + ValidationError, + ), + ( + {"set": {"": "test"}, "delete": []}, + None, + ValidationError, + ), + ( + {"set": {"a" * 101: "test"}, "delete": []}, + None, + ValidationError, + ), + ], + ) + def test_validate(self, data, expected, expected_error): + slz = LegacyTransformHeadersSLZ(data=data) + if not expected_error: + slz.is_valid(raise_exception=True) + assert slz.validated_data == expected + return + + with pytest.raises(expected_error): + slz.is_valid(raise_exception=True) diff --git a/src/dashboard/apigateway/apigateway/tests/apis/web/resource/test_serializers.py b/src/dashboard/apigateway/apigateway/tests/apis/web/resource/test_serializers.py index f31f00863..070a45ce7 100644 --- a/src/dashboard/apigateway/apigateway/tests/apis/web/resource/test_serializers.py +++ b/src/dashboard/apigateway/apigateway/tests/apis/web/resource/test_serializers.py @@ -26,6 +26,7 @@ from apigateway.apis.web.resource.serializers import ( BackendPathCheckInputSLZ, + HttpBackendConfigSLZ, ResourceDataSLZ, ResourceExportOutputSLZ, ResourceImportInputSLZ, @@ -63,6 +64,64 @@ def test_has_updated(self, fake_resource, context, expected): assert slz.get_has_updated(fake_resource) is expected +class TestHttpBackendConfigSLZ: + @pytest.mark.parametrize( + "data, expected", + [ + ( + { + "method": "GET", + "path": "/test", + }, + { + "method": "GET", + "path": "/test", + "legacy_upstreams": None, + "legacy_transform_headers": None, + }, + ), + ( + { + "method": "GET", + "path": "/test", + "legacy_upstreams": None, + "legacy_transform_headers": None, + }, + { + "method": "GET", + "path": "/test", + "legacy_upstreams": None, + "legacy_transform_headers": None, + }, + ), + ( + { + "method": "GET", + 
"path": "/test", + "legacy_upstreams": { + "hosts": [{"host": "http://{env.foo}", "weight": 20}], + "loadbalance": "roundrobin", + }, + "legacy_transform_headers": {"set": {"x-token": "test"}, "delete": ["x-token"]}, + }, + { + "method": "GET", + "path": "/test", + "legacy_upstreams": { + "hosts": [{"host": "http://{env.foo}", "weight": 20}], + "loadbalance": "roundrobin", + }, + "legacy_transform_headers": {"set": {"x-token": "test"}, "delete": ["x-token"]}, + }, + ), + ], + ) + def test_validate(self, data, expected): + slz = HttpBackendConfigSLZ(data=data) + slz.is_valid(raise_exception=True) + assert slz.data == expected + + class TestResourceInputSLZ: @pytest.mark.parametrize( "description_en, expected", diff --git a/src/dashboard/apigateway/apigateway/tests/biz/resource/importer/test_legacy_synchronizers.py b/src/dashboard/apigateway/apigateway/tests/biz/resource/importer/test_legacy_synchronizers.py new file mode 100644 index 000000000..1f86a8227 --- /dev/null +++ b/src/dashboard/apigateway/apigateway/tests/biz/resource/importer/test_legacy_synchronizers.py @@ -0,0 +1,392 @@ +import copy + +import pytest +from ddf import G + +from apigateway.apps.plugin.models import PluginBinding, PluginConfig +from apigateway.biz.resource.importer.legacy_synchronizers import ( + LegacyBackendCreator, + LegacyTransformHeadersToPluginSynchronizer, + LegacyUpstream, + LegacyUpstreamToBackendSynchronizer, +) +from apigateway.core.constants import DEFAULT_BACKEND_NAME +from apigateway.core.models import Backend, BackendConfig, Stage + + +class TestLegacyUpstream: + def test_get_stage_id_to_backend_config(self, fake_gateway): + s1 = G(Stage, gateway=fake_gateway, _vars='{"foo": "bar.com"}') + s2 = G(Stage, gateway=fake_gateway, _vars='{"foo": "baz.com"}') + + upstreams = { + "hosts": [{"host": "https://{env.foo}/", "weight": 10}], + "loadbalance": "roundrobin", + } + stage_id_to_timeout = {s1.id: 20, s2.id: 30} + + result = LegacyUpstream(upstreams).get_stage_id_to_backend_config([s1, s2], stage_id_to_timeout) + assert result == { + s1.id: { + "type": "node", + "timeout": 20, + "loadbalance": "roundrobin", + "hosts": [{"scheme": "https", "host": "bar.com", "weight": 10}], + }, + s2.id: { + "type": "node", + "timeout": 30, + "loadbalance": "roundrobin", + "hosts": [{"scheme": "https", "host": "baz.com", "weight": 10}], + }, + } + + @pytest.mark.parametrize( + "vars, host, expected", + [ + ({"foo": "bar.com"}, "https://{env.foo}/", "https://bar.com/"), + ({"foo1": "bar.com", "foo2": "baz.com"}, "https://{env.foo1}/{env.foo2}", "https://bar.com/baz.com"), + ({}, "https://{env.foo}/", "https://{env.foo}/"), + ({"color": "green"}, "https://{env.foo}/", "https://{env.foo}/"), + ], + ) + def test_render_host(self, vars, host, expected): + result = LegacyUpstream({})._render_host(vars, host) + assert result == expected + + +class TestLegacyBackendCreator: + def test_match_or_create_backend(self, fake_gateway, fake_stage): + stage_id_to_backend_config = { + fake_stage.id: { + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [{"scheme": "https", "host": "foo.com", "weight": 10}], + } + } + creator = LegacyBackendCreator(fake_gateway, "admin") + result = creator.match_or_create_backend(stage_id_to_backend_config) + assert result.name == "backend-1" + assert BackendConfig.objects.get(backend=result).config == stage_id_to_backend_config[fake_stage.id] + + result = creator.match_or_create_backend(stage_id_to_backend_config) + assert result.name == "backend-1" + + stage_id_to_backend_config_2 
= copy.deepcopy(stage_id_to_backend_config) + stage_id_to_backend_config_2[fake_stage.id]["timeout"] = 10 + + result = creator.match_or_create_backend(stage_id_to_backend_config_2) + assert result.name == "backend-2" + + result = creator.match_or_create_backend(stage_id_to_backend_config) + assert result.name == "backend-1" + + @pytest.mark.parametrize( + "existing_backend_configs, stage_id_to_backend_config, expected", + [ + ( + { + 1: { + 1: { + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [ + {"scheme": "https", "host": "foo.com", "weight": 10}, + {"scheme": "http", "host": "bar.com", "weight": 10}, + ], + }, + 2: { + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [ + {"scheme": "http", "host": "foo.com", "weight": 20}, + {"scheme": "http", "host": "bar.com", "weight": 20}, + ], + }, + }, + }, + { + 1: { + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [ + {"scheme": "https", "host": "foo.com", "weight": 10}, + {"scheme": "http", "host": "bar.com", "weight": 10}, + ], + }, + 2: { + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [ + {"scheme": "http", "host": "foo.com", "weight": 20}, + {"scheme": "http", "host": "bar.com", "weight": 20}, + ], + }, + }, + 1, + ), + ( + { + 1: { + 1: { + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [ + {"scheme": "http", "host": "foo.com", "weight": 10}, + {"scheme": "http", "host": "bar.com", "weight": 10}, + ], + }, + }, + }, + { + 1: { + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [ + {"scheme": "http", "host": "bar.com", "weight": 10}, + {"scheme": "http", "host": "foo.com", "weight": 10}, + ], + }, + }, + None, + ), + ( + { + 1: { + 1: { + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [ + {"scheme": "https", "host": "foo.com", "weight": 10}, + ], + }, + }, + }, + { + 1: { + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [ + {"scheme": "http", "host": "foo.com", "weight": 10}, + ], + }, + }, + None, + ), + ], + ) + def test_match_existing_backend( + self, + fake_gateway, + existing_backend_configs, + stage_id_to_backend_config, + expected, + ): + creator = LegacyBackendCreator(fake_gateway, "admin") + creator._existing_backend_configs = existing_backend_configs + + result = creator._match_existing_backend(stage_id_to_backend_config) + assert result == expected + + def test_get_existing_backend_configs(self, fake_gateway, fake_stage): + b1 = G(Backend, name="default", gateway=fake_gateway) + b2 = G(Backend, name="backend-1", gateway=fake_gateway) + + G( + BackendConfig, + backend=b1, + gateway=fake_gateway, + stage=fake_stage, + config={ + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [{"scheme": "http", "host": "foo.com", "weight": 100}], + }, + ) + G( + BackendConfig, + backend=b2, + gateway=fake_gateway, + stage=fake_stage, + config={ + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [{"scheme": "http", "host": "bar.com", "weight": 100}], + }, + ) + + creator = LegacyBackendCreator(fake_gateway, "admin") + result = creator._get_existing_backend_configs() + assert result == { + b1.id: { + fake_stage.id: { + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [{"scheme": "http", "host": "foo.com", "weight": 100}], + } + }, + b2.id: { + fake_stage.id: { + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [{"scheme": 
"http", "host": "bar.com", "weight": 100}], + } + }, + } + + def test_generate_new_backend_name(self, fake_gateway): + creator = LegacyBackendCreator(fake_gateway, "admin") + result = creator._generate_new_backend_name() + assert result == "backend-1" + + creator._max_legacy_backend_number = 100 + result = creator._generate_new_backend_name() + assert result == "backend-101" + + def test_create_backend_and_backend_configs(self, fake_gateway, fake_stage): + creator = LegacyBackendCreator(fake_gateway, "admin") + creator._create_backend_and_backend_configs( + "backend-1", + { + fake_stage.id: { + "type": "node", + "timeout": 50, + "loadbalance": "roundrobin", + "hosts": [{"scheme": "http", "host": "foo.com", "weight": 100}], + } + }, + ) + assert Backend.objects.filter(name="backend-1", gateway=fake_gateway).exists() + assert BackendConfig.objects.filter(backend__name="backend-1", gateway=fake_gateway).exists() + + @pytest.mark.parametrize( + "hosts, expected", + [ + ( + [ + {"scheme": "http", "host": "foo.com", "weight": 10}, + {"scheme": "http", "host": "bar.com", "weight": 10}, + {"scheme": "http", "host": "baz.com", "weight": 10}, + ], + [ + {"scheme": "http", "host": "bar.com", "weight": 10}, + {"scheme": "http", "host": "baz.com", "weight": 10}, + {"scheme": "http", "host": "foo.com", "weight": 10}, + ], + ), + ], + ) + def test_sort_hosts(self, fake_gateway, hosts, expected): + creator = LegacyBackendCreator(fake_gateway, "admin") + result = creator._sort_hosts(hosts) + assert result == expected + + def test_get_max_legacy_backend_number(self, fake_gateway): + creator = LegacyBackendCreator(fake_gateway, "admin") + result = creator._get_max_legacy_backend_number() + assert result == 0 + + G(Backend, name="backend-1", gateway=fake_gateway) + G(Backend, name="backend-10", gateway=fake_gateway) + G(Backend, name="foo", gateway=fake_gateway) + G(Backend, name="backend-2", gateway=fake_gateway) + + result = creator._get_max_legacy_backend_number() + assert result == 10 + + +class TestLegacyUpstreamToBackendSynchronizer: + def test_sync_backends_and_replace_resource_backend__no_upstreams(self, fake_gateway, fake_resource_data): + synchronizer = LegacyUpstreamToBackendSynchronizer(fake_gateway, [fake_resource_data], "admin") + synchronizer.sync_backends_and_replace_resource_backend() + assert fake_resource_data.backend is None + + def test_sync_backends_and_replace_resource_backend__has_upstreams( + self, + fake_gateway, + fake_stage, + fake_resource_data, + ): + backend = G(Backend, name=DEFAULT_BACKEND_NAME, gateway=fake_gateway) + G( + BackendConfig, + gateway=fake_gateway, + stage=fake_stage, + backend=backend, + config={ + "type": "node", + "timeout": 30, + "loadbalance": "roundrobin", + "hosts": [{"scheme": "http", "host": "foo.com", "weight": 100}], + }, + ) + fake_resource_data.backend_config.legacy_upstreams = { + "loadbalance": "roundrobin", + "hosts": [{"host": "https://bar.com", "weight": 10}], + } + + synchronizer = LegacyUpstreamToBackendSynchronizer(fake_gateway, [fake_resource_data], "admin") + synchronizer.sync_backends_and_replace_resource_backend() + + backend = Backend.objects.get(gateway=fake_gateway, name="backend-1") + backend_config = BackendConfig.objects.get(gateway=fake_gateway, backend__name="backend-1") + assert fake_resource_data.backend == backend + assert backend_config.config == { + "type": "node", + "timeout": 30, + "loadbalance": "roundrobin", + "hosts": [{"scheme": "https", "host": "bar.com", "weight": 10}], + } + + +class 
TestLegacyTransformHeadersToPluginSynchronizer: + def test_sync_plugins(self, fake_gateway, fake_resource, fake_resource_data, fake_plugin_type_bk_header_rewrite): + fake_resource_data.resource = fake_resource + synchronizer = LegacyTransformHeadersToPluginSynchronizer(fake_gateway, [fake_resource_data], "admin") + + synchronizer.sync_plugins() + assert not PluginConfig.objects.filter(gateway=fake_gateway).exists() + assert not PluginBinding.objects.filter(gateway=fake_gateway).exists() + + # add + fake_resource_data.backend_config.legacy_transform_headers = { + "set": {"x-token": "test"}, + "delete": ["x-token"], + } + synchronizer.sync_plugins() + plugin_config = PluginConfig.objects.get(gateway=fake_gateway, type__code="bk-header-rewrite") + assert plugin_config.config == {"set": [{"key": "x-token", "value": "test"}], "remove": [{"key": "x-token"}]} + assert PluginBinding.objects.filter( + gateway=fake_gateway, scope_type="resource", scope_id=fake_resource.id + ).exists() + + # update + fake_resource_data.backend_config.legacy_transform_headers = { + "set": {"x-foo": "test"}, + "delete": ["x-bar"], + } + synchronizer.sync_plugins() + plugin_config = PluginConfig.objects.get(gateway=fake_gateway, type__code="bk-header-rewrite") + assert plugin_config.config == {"set": [{"key": "x-foo", "value": "test"}], "remove": [{"key": "x-bar"}]} + assert PluginBinding.objects.filter( + gateway=fake_gateway, scope_type="resource", scope_id=fake_resource.id + ).exists() + + # delete + fake_resource_data.backend_config.legacy_transform_headers = {} + synchronizer.sync_plugins() + assert not PluginConfig.objects.filter(gateway=fake_gateway).exists() + assert not PluginBinding.objects.filter(gateway=fake_gateway).exists() diff --git a/src/dashboard/apigateway/apigateway/tests/biz/resource/importer/test_swagger.py b/src/dashboard/apigateway/apigateway/tests/biz/resource/importer/test_swagger.py index cc38b1b1e..30113efe8 100644 --- a/src/dashboard/apigateway/apigateway/tests/biz/resource/importer/test_swagger.py +++ b/src/dashboard/apigateway/apigateway/tests/biz/resource/importer/test_swagger.py @@ -687,6 +687,8 @@ def fake_swagger_content(self): "method": "GET", "match_subpath": True, "timeout": 30, + "legacy_transform_headers": None, + "legacy_upstreams": None, }, "auth_config": { "auth_verified_required": False, @@ -738,6 +740,8 @@ def fake_swagger_content(self): "path": "/echo/", "match_subpath": False, "timeout": 0, + "legacy_transform_headers": None, + "legacy_upstreams": None, }, "auth_config": { "auth_verified_required": True, @@ -789,6 +793,8 @@ def fake_swagger_content(self): "path": "/echo/", "match_subpath": False, "timeout": 0, + "legacy_transform_headers": None, + "legacy_upstreams": None, }, "auth_config": { "auth_verified_required": True, @@ -810,26 +816,58 @@ def test_adapt_method(self, fake_swagger_content): assert importer._adapt_method("get") == "GET" assert importer._adapt_method("x-bk-apigateway-method-any") == "ANY" - def test_adapt_backend(self, fake_swagger_content): + def test_adapt_backend__error(self, fake_swagger_content): importer = ResourceSwaggerImporter(fake_swagger_content) with pytest.raises(ValueError): - importer._adapt_backend({"upstreams": {"foo": "bar"}}) + importer._adapt_backend({"type": "MOCK"}) - result = importer._adapt_backend( - { - "type": "HTTP", - "method": "get", - "path": "/foo", - "matchSubpath": True, - }, - ) - assert result == { - "method": "GET", - "path": "/foo", - "match_subpath": True, - "timeout": 0, - } + @pytest.mark.parametrize( + "backend, 
expected", + [ + ( + { + "type": "HTTP", + "method": "get", + "path": "/foo", + "matchSubpath": True, + }, + { + "method": "GET", + "path": "/foo", + "match_subpath": True, + "timeout": 0, + "legacy_upstreams": None, + "legacy_transform_headers": None, + }, + ), + ( + { + "type": "HTTP", + "method": "get", + "path": "/foo", + "matchSubpath": True, + "upstreams": {"loadbalance": "roundrobin", "hosts": [{"host": "http://foo.com", "weight": 100}]}, + "transformHeaders": {"set": {"x-token": "test"}, "delete": ["x-token"]}, + }, + { + "method": "GET", + "path": "/foo", + "match_subpath": True, + "timeout": 0, + "legacy_upstreams": { + "loadbalance": "roundrobin", + "hosts": [{"host": "http://foo.com", "weight": 100}], + }, + "legacy_transform_headers": {"set": {"x-token": "test"}, "delete": ["x-token"]}, + }, + ), + ], + ) + def test_adapt_backend(self, fake_swagger_content, backend, expected): + importer = ResourceSwaggerImporter(fake_swagger_content) + result = importer._adapt_backend(backend) + assert result == expected @pytest.mark.parametrize( "auth_config, expected", diff --git a/src/dashboard/apigateway/apigateway/tests/biz/resource/test_models.py b/src/dashboard/apigateway/apigateway/tests/biz/resource/test_models.py new file mode 100644 index 000000000..b4dfad110 --- /dev/null +++ b/src/dashboard/apigateway/apigateway/tests/biz/resource/test_models.py @@ -0,0 +1,62 @@ +import pytest + +from apigateway.biz.resource.models import ResourceBackendConfig + + +class TestResourceBackendConfig: + @pytest.mark.parametrize( + "data, expected, expected_legacy_upstreams, expected_legacy_transform_headers", + [ + ( + { + "method": "GET", + "path": "/user", + "match_subpath": True, + "timeout": 10, + }, + { + "method": "GET", + "path": "/user", + "match_subpath": True, + "timeout": 10, + }, + None, + None, + ), + ( + { + "method": "GET", + "path": "/user", + "match_subpath": True, + "timeout": 10, + "legacy_upstreams": { + "loadbalance": "roundrobin", + "hosts": [{"host": "http://foo.com", "weight": 10}], + }, + "legacy_transform_headers": { + "set": {"x-token": "test"}, + "delete": ["x-token"], + }, + }, + { + "method": "GET", + "path": "/user", + "match_subpath": True, + "timeout": 10, + }, + { + "loadbalance": "roundrobin", + "hosts": [{"host": "http://foo.com", "weight": 10}], + }, + { + "set": {"x-token": "test"}, + "delete": ["x-token"], + }, + ), + ], + ) + def test_dict(self, data, expected, expected_legacy_upstreams, expected_legacy_transform_headers): + config = ResourceBackendConfig.parse_obj(data) + assert config.legacy_upstreams == expected_legacy_upstreams + assert config.legacy_transform_headers == expected_legacy_transform_headers + assert config.dict() == expected diff --git a/src/dashboard/apigateway/apigateway/tests/biz/resource/test_savers.py b/src/dashboard/apigateway/apigateway/tests/biz/resource/test_savers.py index edb0c4ef3..a309cbe01 100644 --- a/src/dashboard/apigateway/apigateway/tests/biz/resource/test_savers.py +++ b/src/dashboard/apigateway/apigateway/tests/biz/resource/test_savers.py @@ -21,7 +21,7 @@ from apigateway.apps.label.models import APILabel, ResourceLabel from apigateway.biz.resource.savers import ResourceProxyDuplicateError, ResourcesSaver from apigateway.core.constants import ContextScopeTypeEnum, ContextTypeEnum, ProxyTypeEnum -from apigateway.core.models import Context, Proxy, Resource +from apigateway.core.models import Backend, Context, Proxy, Resource class TestResourceSavers: @@ -85,15 +85,24 @@ def test_complete_with_resource(self, fake_gateway, 
fake_resource_data): assert resource_data_list[1].resource == resource_2 def test_save_proxies(self, fake_gateway, fake_resource_data): + backend_1 = G(Backend, gateway=fake_gateway) + backend_2 = G(Backend, gateway=fake_gateway) + resource_1 = G(Resource, gateway=fake_gateway, name="foo1", method="GET") resource_2 = G(Resource, gateway=fake_gateway, name="foo2", method="POST") resource_data_list = [ - fake_resource_data.copy(update={"resource": resource_1}, deep=True), + fake_resource_data.copy(update={"resource": resource_1, "backend": backend_1}, deep=True), ] saver = ResourcesSaver(fake_gateway, resource_data_list, "admin") saver._save_proxies(resource_ids=[resource_1.id]) assert Proxy.objects.filter(resource__gateway=fake_gateway).count() == 1 + assert Proxy.objects.get(resource=resource_1).backend_id == backend_1.id + + # 测试 proxy backend 是否被更新 + resource_data_list[0].backend = backend_2 + saver._save_proxies(resource_ids=[resource_1.id]) + assert Proxy.objects.get(resource=resource_1).backend_id == backend_2.id resource_data_list = [ fake_resource_data.copy(update={"resource": resource_1}, deep=True), diff --git a/src/dashboard/apigateway/apigateway/tests/common/plugin/test_header_rewrite.py b/src/dashboard/apigateway/apigateway/tests/common/plugin/test_header_rewrite.py index 74f0882ae..ad49786c8 100644 --- a/src/dashboard/apigateway/apigateway/tests/common/plugin/test_header_rewrite.py +++ b/src/dashboard/apigateway/apigateway/tests/common/plugin/test_header_rewrite.py @@ -17,6 +17,7 @@ # import pytest +from apigateway.apps.plugin.models import PluginBinding, PluginConfig from apigateway.common.plugin.header_rewrite import HeaderRewriteConvertor @@ -34,3 +35,45 @@ class TestHeaderRewriteConvertor: ) def test_transform_headers_to_plugin_config(self, transform_headers, expected): assert HeaderRewriteConvertor.transform_headers_to_plugin_config(transform_headers) == expected + + def test_sync_plugins(self, fake_gateway, fake_resource, fake_plugin_type_bk_header_rewrite): + HeaderRewriteConvertor.sync_plugins(fake_gateway.id, "resource", {}, "admin") + assert not PluginConfig.objects.filter(gateway=fake_gateway).exists() + assert not PluginBinding.objects.filter(gateway=fake_gateway).exists() + + # add + scope_id_to_plugin_config = { + fake_resource.id: { + "set": [{"key": "x-token", "value": "test"}], + "remove": [{"key": "x-token"}], + } + } + HeaderRewriteConvertor.sync_plugins(fake_gateway.id, "resource", scope_id_to_plugin_config, "admin") + + plugin_config = PluginConfig.objects.get(gateway=fake_gateway, type__code="bk-header-rewrite") + assert plugin_config.config == {"set": [{"key": "x-token", "value": "test"}], "remove": [{"key": "x-token"}]} + assert PluginBinding.objects.filter( + gateway=fake_gateway, scope_type="resource", scope_id=fake_resource.id + ).exists() + + # update + scope_id_to_plugin_config = { + fake_resource.id: { + "set": [{"key": "x-foo", "value": "test"}], + "remove": [{"key": "x-bar"}], + } + } + HeaderRewriteConvertor.sync_plugins(fake_gateway.id, "resource", scope_id_to_plugin_config, "admin") + + plugin_config = PluginConfig.objects.get(gateway=fake_gateway, type__code="bk-header-rewrite") + assert plugin_config.config == {"set": [{"key": "x-foo", "value": "test"}], "remove": [{"key": "x-bar"}]} + assert PluginBinding.objects.filter( + gateway=fake_gateway, scope_type="resource", scope_id=fake_resource.id + ).exists() + + # delete + scope_id_to_plugin_config = {fake_resource.id: None} + HeaderRewriteConvertor.sync_plugins(fake_gateway.id, "resource", 
scope_id_to_plugin_config, "admin") + + assert not PluginConfig.objects.filter(gateway=fake_gateway).exists() + assert not PluginBinding.objects.filter(gateway=fake_gateway).exists() diff --git a/src/dashboard/apigateway/apigateway/tests/conftest.py b/src/dashboard/apigateway/apigateway/tests/conftest.py index e2f975940..06c9ee4ad 100644 --- a/src/dashboard/apigateway/apigateway/tests/conftest.py +++ b/src/dashboard/apigateway/apigateway/tests/conftest.py @@ -30,7 +30,7 @@ from django.urls import resolve, reverse from rest_framework.test import APIRequestFactory as DRFAPIRequestFactory -from apigateway.apps.plugin.constants import PluginBindingScopeEnum, PluginStyleEnum +from apigateway.apps.plugin.constants import PluginBindingScopeEnum, PluginStyleEnum, PluginTypeCodeEnum from apigateway.apps.plugin.models import PluginBinding, PluginConfig, PluginForm, PluginType from apigateway.apps.support.models import GatewaySDK, ReleasedResourceDoc, ResourceDoc, ResourceDocVersion from apigateway.biz.resource import ResourceHandler @@ -690,6 +690,18 @@ def echo_plugin_resource_binding(echo_plugin, fake_resource): ) +@pytest.fixture() +def fake_plugin_type_bk_header_rewrite(): + return PluginType.objects.get_or_create( + code=PluginTypeCodeEnum.BK_HEADER_REWRITE.value, + defaults={ + "name": PluginTypeCodeEnum.BK_HEADER_REWRITE.value, + "is_public": True, + "schema": None, + }, + ) + + @pytest.fixture() def clone_model(): """Clone a django model"""
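
A minimal usage sketch of the reworked bk-header-rewrite syncing introduced in this patch. The helper name sync_stage_header_rewrite and its stage/transform_headers/username arguments are illustrative placeholders, not code from the change; only the HeaderRewriteConvertor and PluginBindingScopeEnum calls are taken from the patched modules.

from apigateway.apps.plugin.constants import PluginBindingScopeEnum
from apigateway.common.plugin.header_rewrite import HeaderRewriteConvertor


def sync_stage_header_rewrite(stage, transform_headers: dict, username: str) -> None:
    # Converts {"set": {...}, "delete": [...]} into the plugin form
    # {"set": [{"key": ..., "value": ...}], "remove": [{"key": ...}]}; an empty result
    # causes sync_plugins to drop any existing bk-header-rewrite binding for the stage.
    plugin_config = HeaderRewriteConvertor.transform_headers_to_plugin_config(transform_headers)
    HeaderRewriteConvertor.sync_plugins(
        stage.gateway_id,
        PluginBindingScopeEnum.STAGE.value,
        {stage.id: plugin_config},
        username,
    )

The importer-side LegacyTransformHeadersToPluginSynchronizer drives the same sync_plugins entry point with resource-scoped ids, which is why the old per-scope alter_plugin helper is retired in this change.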