diff --git a/src/api/bkuser_core/api/web/category/serializers.py b/src/api/bkuser_core/api/web/category/serializers.py index ed7978a44..97d037d06 100644 --- a/src/api/bkuser_core/api/web/category/serializers.py +++ b/src/api/bkuser_core/api/web/category/serializers.py @@ -192,6 +192,10 @@ class CategoryFileImportInputSLZ(serializers.Serializer): file = serializers.FileField(required=False) +class CategoryFileImportQuerySLZ(serializers.Serializer): + is_overwrite = serializers.BooleanField(required=False, default=False) + + class CategorySyncResponseOutputSLZ(serializers.Serializer): task_id = serializers.CharField(help_text="task_id for the sync job.") diff --git a/src/api/bkuser_core/api/web/category/views.py b/src/api/bkuser_core/api/web/category/views.py index 0750866ed..7f1a04eb0 100644 --- a/src/api/bkuser_core/api/web/category/views.py +++ b/src/api/bkuser_core/api/web/category/views.py @@ -25,6 +25,7 @@ CategoryExportInputSLZ, CategoryExportProfileOutputSLZ, CategoryFileImportInputSLZ, + CategoryFileImportQuerySLZ, CategoryMetaOutputSLZ, CategoryNamespaceSettingUpdateInputSLZ, CategoryProfileListInputSLZ, @@ -454,6 +455,9 @@ def post(self, request, *args, **kwargs): def _local_category_do_import(self, request, instance): """向本地目录导入数据文件""" + query_slz = CategoryFileImportQuerySLZ(data=request.query_params) + query_slz.is_valid(raise_exception=True) + slz = CategoryFileImportInputSLZ(data=request.data) slz.is_valid(raise_exception=True) @@ -465,7 +469,10 @@ def _local_category_do_import(self, request, instance): raise error_codes.CREATE_SYNC_TASK_FAILED.f(str(e)) instance_id = instance.id - params = {"raw_data_file": slz.validated_data["file"]} + params = { + "raw_data_file": slz.validated_data["file"], + "is_overwrite": query_slz.validated_data["is_overwrite"], + } try: # TODO: FileField 可能不能反序列化, 所以可能不能传到 celery 执行 adapter_sync(instance_id, operator=request.operator, task_id=task_id, **params) diff --git a/src/api/bkuser_core/categories/plugins/local/syncer.py b/src/api/bkuser_core/categories/plugins/local/syncer.py index 2f3e33484..f4d7e3b92 100644 --- a/src/api/bkuser_core/categories/plugins/local/syncer.py +++ b/src/api/bkuser_core/categories/plugins/local/syncer.py @@ -125,13 +125,13 @@ def __post_init__(self): self._default_password_valid_days = int(ConfigProvider(self.category_id).get("password_valid_days")) self.fetcher: ExcelFetcher = self.get_fetcher() - def sync(self, raw_data_file): + def sync(self, raw_data_file, is_overwrite): user_rows, departments = self.fetcher.fetch(raw_data_file) with transaction.atomic(): self._sync_departments(departments) with transaction.atomic(): - self._sync_users(self.fetcher.parser_set, user_rows) + self._sync_users(self.fetcher.parser_set, user_rows, is_overwrite) self._sync_leaders(self.fetcher.parser_set, user_rows) self._notify_init_passwords() @@ -175,7 +175,35 @@ def _judge_data_all_none(raw_data: list) -> bool: """某些状况下会读取 Excel 整个空行""" return all(x is None for x in raw_data) - def _sync_users(self, parser_set: "ParserSet", users: list): + def _department_profile_relation_handle( + self, is_overwrite, department_groups, profile_id, should_deleted_department_profile_relation_ids + ): + cell_parser = DepartmentCellParser(self.category_id) + # 已存在的用户-部门关系 + old_department_profile_relations = DepartmentThroughModel.objects.filter(profile_id=profile_id) + # Note: 有新关系可能存在重复数据,所以这里使用不变的old_department_set用于后续判断是否存在的依据, + # 而不使用后面会变更的old_department_relations数据 + old_department_set = set([r.department_id for r in old_department_profile_relations]) + old_department_relations = {r.department_id: r.id for r in old_department_profile_relations} + + for department in cell_parser.parse_to_db_obj(department_groups): + # 用户-部门关系已存在 + if department.pk in old_department_set: + # Note: 可能本次更新里存在重复数据,dict无法重复移除 + if department.pk in old_department_relations: + del old_department_relations[department.pk] + continue + + # 不存在则添加 + department_attachment = DepartmentThroughModel(department_id=department.pk, profile_id=profile_id) + self.db_sync_manager.magic_add(department_attachment) + + # 已存在的数据从old_department_relations移除后,最后剩下的数据,表示多余的,即本次更新里不存在的用户部门关系 + # 如果是覆盖,则记录需要删除多余数据 + if is_overwrite and len(old_department_relations) > 0: + should_deleted_department_profile_relation_ids.extend(old_department_relations.values()) + + def _sync_users(self, parser_set: "ParserSet", users: list, is_overwrite: bool = False): """在内存中操作&判断数据,bulk 插入""" logger.info("=========== trying to load profiles into memory ===========") @@ -184,6 +212,7 @@ def _sync_users(self, parser_set: "ParserSet", users: list): success_count = 0 total = len(users) + should_deleted_department_profile_relation_ids = [] for index, user_raw_info in enumerate(users): if self._judge_data_all_none(user_raw_info): logger.debug("empty line, skipping") @@ -235,8 +264,13 @@ def _sync_users(self, parser_set: "ParserSet", users: list): progress(index, total, f"loading {username}") try: updating_profile = Profile.objects.get(username=username, category_id=self.category_id) - - # 如果已经存在,则更新该 profile + # 已存在的用户:如果未勾选 <进行覆盖更新>(即is_overwrite为false)=》则忽略,反之则更新该 profile + if not is_overwrite: + logger.debug( + "username %s exist, and is_overwrite is false, so will not do update for this user, skip", + username, + ) + continue for name, value in profile_params.items(): if name == "extras": extras = updating_profile.extras or {} @@ -250,6 +284,7 @@ def _sync_users(self, parser_set: "ParserSet", users: list): setattr(updating_profile, name, value) profile_id = updating_profile.id + self.db_sync_manager.magic_add(updating_profile, SyncOperation.UPDATE.value) logger.debug("(%s/%s) username<%s> already exist, trying to update it", username, index + 1, total) @@ -278,15 +313,12 @@ def _sync_users(self, parser_set: "ParserSet", users: list): # 2 获取关联的部门DB实例,创建关联对象 progress(index, total, "adding profile & department relation") department_groups = parser_set.get_cell_data("department_name", user_raw_info) + self._department_profile_relation_handle( + is_overwrite, department_groups, profile_id, should_deleted_department_profile_relation_ids + ) - cell_parser = DepartmentCellParser(self.category_id) - for department in cell_parser.parse_to_db_obj(department_groups): - relation_params = {"department_id": department.pk, "profile_id": profile_id} - try: - DepartmentThroughModel.objects.get(**relation_params) - except DepartmentThroughModel.DoesNotExist: - department_attachment = DepartmentThroughModel(**relation_params) - self.db_sync_manager.magic_add(department_attachment) + if len(should_deleted_department_profile_relation_ids) > 0: + DepartmentThroughModel.objects.filter(id__in=should_deleted_department_profile_relation_ids).delete() # 需要在处理 leader 之前全部插入 DB self.db_sync_manager[Profile].sync_to_db() diff --git a/src/api/bkuser_core/tests/categories/plugins/local/test_syncer.py b/src/api/bkuser_core/tests/categories/plugins/local/test_syncer.py index e9ff9f281..7d2d7be8b 100644 --- a/src/api/bkuser_core/tests/categories/plugins/local/test_syncer.py +++ b/src/api/bkuser_core/tests/categories/plugins/local/test_syncer.py @@ -105,7 +105,7 @@ def test_sync_users(self, syncer, pre_create_department, users, make_parser_set, ) @pytest.mark.parametrize( - "users,titles,expected", + "users,titles,expected,is_overwrite", [ ( [ @@ -118,10 +118,11 @@ def test_sync_users(self, syncer, pre_create_department, users, make_parser_set, "bbbb": "xxxx@xxxx.xyz", "cccc": "cccc@xxxx.com", }, - ), + True, + ) ], ) - def test_update_existed_users(self, syncer, users, make_parser_set, titles, expected): + def test_update_existed_users(self, syncer, users, make_parser_set, titles, expected, is_overwrite): """测试更新已存在用户""" for u in ["aaaa", "cccc"]: a = make_simple_profile(username=u, force_create_params={"category_id": syncer.category_id}) @@ -130,7 +131,7 @@ def test_update_existed_users(self, syncer, users, make_parser_set, titles, expe # TODO: 当前 id 最大值是在 db_sync_manager 初始化时确定的,实际上并不科学 syncer.db_sync_manager._update_cache() - syncer._sync_users(make_parser_set(titles), users) + syncer._sync_users(make_parser_set(titles), users, is_overwrite=is_overwrite) for k, v in expected.items(): assert Profile.objects.get(category_id=syncer.category_id, username=k).email == v