diff --git a/apps/backend/subscription/handler.py b/apps/backend/subscription/handler.py index dfacdaf2f..baad55100 100644 --- a/apps/backend/subscription/handler.py +++ b/apps/backend/subscription/handler.py @@ -19,6 +19,7 @@ from django.conf import settings from django.core.cache import cache from django.db.models import Max, Q, QuerySet +from django.forms.models import model_to_dict from django.utils.translation import get_language from django.utils.translation import ugettext as _ @@ -306,7 +307,9 @@ def retry( except models.Subscription.DoesNotExist: raise errors.SubscriptionNotExist({"subscription_id": self.subscription_id}) - if tools.check_subscription_is_disabled(subscription): + if tools.check_subscription_is_disabled( + subscription_info=model_to_dict(subscription), scope=subscription.scope, steps=subscription.steps + ): raise errors.SubscriptionIncludeGrayBizError() base_filter_kwargs = filter_values( @@ -421,7 +424,9 @@ def run(self, scope: Dict = None, actions: Dict[str, str] = None) -> Dict[str, i if subscription.is_running(): raise InstanceTaskIsRunning() - if tools.check_subscription_is_disabled(subscription): + if tools.check_subscription_is_disabled( + subscription_info=model_to_dict(subscription), scope=subscription.scope, steps=subscription.steps + ): raise errors.SubscriptionIncludeGrayBizError() subscription_task = models.SubscriptionTask.objects.create( diff --git a/apps/backend/subscription/serializers.py b/apps/backend/subscription/serializers.py index 70c41e10c..474fddc15 100644 --- a/apps/backend/subscription/serializers.py +++ b/apps/backend/subscription/serializers.py @@ -12,9 +12,11 @@ import base64 from bkcrypto.asymmetric.ciphers import BaseAsymmetricCipher +from django.utils.translation import gettext_lazy as _ from rest_framework import serializers from apps.backend.constants import SubscriptionSwithBizAction +from apps.backend.subscription.tools import check_subscription_is_disabled from apps.exceptions import ValidationError from apps.node_man import constants, models, tools from apps.node_man.models import ProcessStatus @@ -95,6 +97,8 @@ class CreateStepSerializer(serializers.Serializer): pid = serializers.IntegerField(required=False, label="父策略ID") def validate(self, attrs): + if check_subscription_is_disabled(subscription_info=attrs, scope=attrs["scope"], steps=attrs["steps"]): + raise ValidationError(_("订阅范围包含Gse2.0灰度业务")) step_types = {step["type"] for step in attrs["steps"]} if constants.SubStepType.AGENT not in step_types: return attrs diff --git a/apps/backend/subscription/tasks.py b/apps/backend/subscription/tasks.py index d58fdadcc..9e9adc5b2 100644 --- a/apps/backend/subscription/tasks.py +++ b/apps/backend/subscription/tasks.py @@ -18,6 +18,7 @@ from functools import wraps from typing import Any, Dict, List, Optional, Union +from django.forms.models import model_to_dict from django.utils.translation import ugettext as _ from apps.backend.celery import app @@ -800,12 +801,10 @@ def update_subscription_instances_chunk(subscription_ids: List[int]): 分片更新订阅状态 """ subscriptions = models.Subscription.objects.filter(id__in=subscription_ids, enable=True) - disable_subscription_biz_ids: List[int] = models.GlobalSettings.get_config( - key=models.GlobalSettings.KeyEnum.DISABLE_SUBSCRIPTION_SCOPE_LIST.value, - default=[], - ) for subscription in subscriptions: - if tools.check_subscription_is_disabled(subscription, disable_subscription_biz_ids): + if tools.check_subscription_is_disabled( + subscription_info=model_to_dict(subscription), scope=subscription.scope, steps=subscription.steps + ): logger.info("[update_subscription_instances] skipped for subscription disabled") continue logger.info(f"[update_subscription_instances] start: {subscription}") diff --git a/apps/backend/subscription/tools.py b/apps/backend/subscription/tools.py index 35384ae7b..3f4b5fe86 100644 --- a/apps/backend/subscription/tools.py +++ b/apps/backend/subscription/tools.py @@ -26,6 +26,7 @@ from django.db.models import Q from django.utils import timezone +from apps.backend.constants import InstNodeType from apps.backend.subscription import task_tools from apps.backend.subscription.commons import get_host_by_inst, list_biz_hosts from apps.backend.subscription.constants import SUBSCRIPTION_SCOPE_CACHE_TIME @@ -1450,36 +1451,51 @@ def update_job_status(pipeline_id, result=None): ) -def check_subscription_is_disabled(subscription: models.Subscription, disable_biz_ids: List[int] = None) -> bool: +def check_subscription_is_disabled( + subscription_info: typing.Dict[str, typing.Any], + scope: typing.Dict[str, typing.Any], + steps: typing.Union[typing.List[Dict[str, typing.Any]], typing.List[models.SubscriptionStep]], + disable_biz_ids: List[int] = None, +) -> bool: """ 检查订阅任务是否已被禁用巡检 """ - if not disable_biz_ids: - disable_biz_ids: List[int] = models.GlobalSettings.get_config( - key=models.GlobalSettings.KeyEnum.DISABLE_SUBSCRIPTION_SCOPE_LIST.value, - default=[], - ) + nodes = scope["nodes"] is_step_include_plugin: bool = False - for step in subscription.steps: - if step.type == constants.SubStepType.PLUGIN: + for step in steps: + sub_step_type = step.type if hasattr(step, "type") else step["type"] + if sub_step_type == constants.SubStepType.PLUGIN: is_step_include_plugin = True break # 非插件类任务不进行禁用 if not is_step_include_plugin: - logger.info(f"[check_subscription_is_disabled] {subscription}: not include plugin step, skipping") + logger.info(f"[check_subscription_is_disabled] {subscription_info}: not include plugin step, skipping") return False + disable_biz_ids: List[int] = models.GlobalSettings.get_config( + key=models.GlobalSettings.KeyEnum.DISABLE_SUBSCRIPTION_SCOPE_LIST.value, + default=[], + ) + nodes_biz_ids: typing.Optional[typing.List[int]] = [node.get("bk_biz_id") for node in nodes] + + full_biz_scope_ids: typing.Optional[typing.List[int]] = [ + node.get("bk_inst_id") for node in nodes if node.get("bk_obj_id") == InstNodeType.BIZ + ] + + nodes_biz_ids = nodes_biz_ids + full_biz_scope_ids + if any( [ - subscription.bk_biz_id in disable_biz_ids, - set(subscription.bk_biz_scope or []) & set(disable_biz_ids), + scope.get("bk_biz_id") in disable_biz_ids, + set(scope.get("bk_biz_scope") or []) & set(disable_biz_ids), + set(nodes_biz_ids) & set(disable_biz_ids), ] ): # 禁用规则:订阅业务范围中包含已被禁用的业务 - logger.info(f"[check_subscription_is_disabled] {subscription}: in the disable list, will be disabled") + logger.info(f"[check_subscription_is_disabled] {subscription_info}: in the disable list, will be disabled") return True - logger.info(f"[check_subscription_is_disabled] {subscription}: not in the disable list, skipping") + logger.info(f"[check_subscription_is_disabled] {subscription_info}: not in the disable list, skipping") return False diff --git a/apps/backend/subscription/views.py b/apps/backend/subscription/views.py index c241ee80f..c6ac8043d 100644 --- a/apps/backend/subscription/views.py +++ b/apps/backend/subscription/views.py @@ -20,6 +20,7 @@ from django.core.cache import caches from django.db import transaction from django.db.models import Q +from django.forms.models import model_to_dict from django.utils.translation import get_language from django.utils.translation import gettext_lazy as _ from drf_yasg.utils import swagger_auto_schema @@ -117,10 +118,6 @@ def create_subscription(self, request): if subscription.is_running(): raise InstanceTaskIsRunning() - # 立即执行场景,如果订阅被禁用,需要抛出异常 - if tools.check_subscription_is_disabled(subscription): - raise errors.SubscriptionIncludeGrayBizError() - subscription_task = models.SubscriptionTask.objects.create( subscription_id=subscription.id, scope=subscription.scope, actions={} ) @@ -198,6 +195,13 @@ def update_subscription(self, request): subscription = models.Subscription.objects.get(id=params["subscription_id"], is_deleted=False) except models.Subscription.DoesNotExist: raise errors.SubscriptionNotExist({"subscription_id": params["subscription_id"]}) + # 更新订阅不在序列化器中做校验,因为获取更新订阅的类型 step 需要查一次表 + if tools.check_subscription_is_disabled( + subscription_info=model_to_dict(subscription), + steps=subscription.steps, + scope=scope, + ): + raise errors.SubscriptionIncludeGrayBizError() subscription.name = params.get("name", "") subscription.node_type = scope["node_type"] @@ -271,10 +275,6 @@ def update_subscription(self, request): if subscription.is_running(): raise InstanceTaskIsRunning() - # 立即执行场景,如果订阅被禁用,需要抛出异常 - if tools.check_subscription_is_disabled(subscription): - raise errors.SubscriptionIncludeGrayBizError() - subscription_task = models.SubscriptionTask.objects.create( subscription_id=subscription.id, scope=subscription.scope, actions={} ) diff --git a/apps/backend/tests/subscription/test_views.py b/apps/backend/tests/subscription/test_views.py index 95ef37861..4a6e90920 100644 --- a/apps/backend/tests/subscription/test_views.py +++ b/apps/backend/tests/subscription/test_views.py @@ -32,6 +32,7 @@ ) from apps.node_man import constants from apps.node_man.models import ( + GlobalSettings, GsePluginDesc, Host, Packages, @@ -49,6 +50,8 @@ class TestSubscription(TestCase): 测试订阅相关的接口 """ + TEST_BIZ_ID = 2 + def setUp(self): mock.patch.stopall() self.get_host_object_attribute_client = mock.patch( @@ -92,7 +95,12 @@ def _test_create_subscription(self): { "bk_username": "admin", "bk_app_code": "blueking", - "scope": {"bk_biz_id": 2, "node_type": "TOPO", "object_type": "SERVICE", "nodes": [{"id": 123}]}, + "scope": { + "bk_biz_id": self.TEST_BIZ_ID, + "node_type": "TOPO", + "object_type": "SERVICE", + "nodes": [{"id": 123}], + }, "steps": [ { "id": "my_first", @@ -438,6 +446,10 @@ def test_run_task(self): self._test_check_task_ready(subscription_id=subscription_id, task_id_list=[task_id]) self._test_check_task_not_exist(subscription_id=subscription_id, task_id_list=[task_id]) + def test_disable_biz_subscriptin(self): + GlobalSettings.set_config(GlobalSettings.KeyEnum.DISABLE_SUBSCRIPTION_SCOPE_LIST.value, [self.TEST_BIZ_ID]) + self.assertRaises(AttributeError, self._test_create_subscription) + def test_delete_subscription(self): subscription_id = self._test_create_subscription()