diff --git a/src/backend/job/biz-job/src/main/kotlin/com/tencent/bkrepo/job/batch/task/cache/preload/ArtifactAccessLogEmbeddingJob.kt b/src/backend/job/biz-job/src/main/kotlin/com/tencent/bkrepo/job/batch/task/cache/preload/ArtifactAccessLogEmbeddingJob.kt index 8725cc8832..cb4987ac31 100644 --- a/src/backend/job/biz-job/src/main/kotlin/com/tencent/bkrepo/job/batch/task/cache/preload/ArtifactAccessLogEmbeddingJob.kt +++ b/src/backend/job/biz-job/src/main/kotlin/com/tencent/bkrepo/job/batch/task/cache/preload/ArtifactAccessLogEmbeddingJob.kt @@ -27,7 +27,8 @@ package com.tencent.bkrepo.job.batch.task.cache.preload -import com.tencent.bkrepo.common.api.util.executeAndMeasureTime +import com.tencent.bkrepo.common.api.exception.ErrorCodeException +import com.tencent.bkrepo.common.api.message.CommonMessageCode import com.tencent.bkrepo.common.artifact.event.base.EventType import com.tencent.bkrepo.common.mongo.constant.ID import com.tencent.bkrepo.common.mongo.constant.MIN_OBJECT_ID @@ -53,6 +54,7 @@ import org.springframework.data.mongodb.core.findOne import org.springframework.data.mongodb.core.query.Criteria import org.springframework.data.mongodb.core.query.Query import org.springframework.data.mongodb.core.query.gte +import org.springframework.data.mongodb.core.query.isEqualTo import org.springframework.stereotype.Component import java.time.Duration import java.time.LocalDate @@ -90,7 +92,7 @@ class ArtifactAccessLogEmbeddingJob( // 上个月的数据不存在时,使用上个月的访问记录生成数据 logger.info("collection[${lastMonthVectorStore.collectionName()}] not exists, try to create") lastMonthVectorStore.createCollection() - lastMonthVectorStore.findAccessLogAndInsert(1L) + findAndHandle(1L, null, null) { lastMonthVectorStore.insert(it.values) } logger.info("insert data into collection[${lastMonthVectorStore.collectionName()}] success") } @@ -98,13 +100,14 @@ class ArtifactAccessLogEmbeddingJob( // 当月数据不存在时候,使用月初至今的访问记录生成数据 logger.info("collection[${curMonthVectorStore.collectionName()}] not exists, try to create") curMonthVectorStore.createCollection() - curMonthVectorStore.findAccessLogAndInsert(0L, before = LocalDate.now().atStartOfDay()) + val startOfToday = LocalDate.now().atStartOfDay() + findAndHandle(0L, null, startOfToday) { curMonthVectorStore.insert(it.values) } } else { // 已有数据,使用昨日数据生成记录 logger.info("collection[${curMonthVectorStore.collectionName()}] exists, insert data of last day") val startOfToday = LocalDate.now().atStartOfDay() val startOfLastDay = LocalDate.now().minusDays(1L).atStartOfDay() - curMonthVectorStore.findAccessLogAndInsert(0L, after = startOfLastDay, before = startOfToday) + findAndHandle(0L, startOfLastDay, startOfToday) { curMonthVectorStore.insert(it.values) } } logger.info("insert data into collection[${curMonthVectorStore.collectionName()}] success") @@ -118,24 +121,42 @@ class ArtifactAccessLogEmbeddingJob( } /** - * 获取访问记录并写入向量数据库 + * 对指定项目的访问日志进行向量化 + * + * @param projectId 需要对访问日志进行向量化的项目 */ - private fun VectorStore.findAccessLogAndInsert( - minusMonth: Long, - after: LocalDateTime? = null, - before: LocalDateTime? = null - ) { - findAndHandle(minusMonth, after, before) { projectId, paths -> - val documents = paths.map { - val metadata = mapOf( - METADATA_KEY_DOWNLOAD_TIMESTAMP to it.value.downloadTimestamp.joinToString(","), - METADATA_KEY_ACCESS_COUNT to it.value.count.toString() - ) - Document(content = it.key, metadata = metadata) - } - val elapsed = measureTimeMillis { insert(documents) } - logger.info("[$projectId] insert ${documents.size} data into [${collectionName()}] in $elapsed ms") + fun embedAccessLog(projectId: String) { + if (!properties.enabled) { + return + } + logger.info("embed [$projectId] access log start") + val lastMonthVectorStore = createVectorStore(1L) + val curMonthVectorStore = createVectorStore(0L) + if (!lastMonthVectorStore.collectionExists() || !curMonthVectorStore.collectionExists()) { + throw ErrorCodeException(CommonMessageCode.RESOURCE_NOT_FOUND, "collection has not been created") + } + findAndHandleByProjectId(projectId, 1L, null, null) { + lastMonthVectorStore.insert(it.values) + } + findAndHandleByProjectId(projectId, 0L, null, LocalDate.now().atStartOfDay()) { + curMonthVectorStore.insert(it.values) + } + logger.info("embed [$projectId] access log finished") + } + + private fun VectorStore.insert(paths: Collection) { + if (paths.isEmpty()) { + return + } + val documents = paths.map { + val metadata = mapOf( + METADATA_KEY_DOWNLOAD_TIMESTAMP to it.downloadTimestamp.joinToString(","), + METADATA_KEY_ACCESS_COUNT to it.count.toString() + ) + Document(content = it.projectRepoFullPath, metadata = metadata) } + val elapsed = measureTimeMillis { insert(documents) } + logger.info("insert ${documents.size} data into [${collectionName()}] in $elapsed ms") } private fun createVectorStore(minusMonth: Long): VectorStore { @@ -150,12 +171,11 @@ class ArtifactAccessLogEmbeddingJob( return MilvusVectorStore(config, milvusClient, embeddingModel) } - private fun findAndHandle( minusMonth: Long, after: LocalDateTime? = null, before: LocalDateTime? = null, - handler: (String, Map) -> Unit + handler: (Map) -> Unit ) { val collectionName = collectionName(minusMonth) // buffer存储的内容结构为(projectId, (path, accessLog)) @@ -170,12 +190,30 @@ class ArtifactAccessLogEmbeddingJob( if (!outOfDateRange && acceptableProject && acceptableType) { val shouldFlush = projectBuffer.addToBuffer(operateLog) if (shouldFlush) { - handler(operateLog.projectId, projectBuffer[operateLog.projectId]!!) + handler(projectBuffer[operateLog.projectId]!!) projectBuffer.remove(operateLog.projectId) } } } - projectBuffer.forEach { (projectId, paths) -> handler(projectId, paths) } + projectBuffer.forEach { (_, paths) -> handler(paths) } + } + + private fun findAndHandleByProjectId( + projectId: String, + minusMonth: Long, + after: LocalDateTime? = null, + before: LocalDateTime? = null, + handler: (Map) -> Unit + ) { + // buffer存储的内容结构为(projectId, (path, accessLog)) + val projectBuffer = HashMap>() + findByProject(projectId, minusMonth, after, before) { operateLog -> + if (projectBuffer.addToBuffer(operateLog)) { + handler(projectBuffer[operateLog.projectId]!!) + projectBuffer.remove(operateLog.projectId) + } + } + handler(projectBuffer[projectId] ?: emptyMap()) } private fun HashMap>.addToBuffer(operateLog: OperateLog): Boolean { @@ -207,11 +245,6 @@ class ArtifactAccessLogEmbeddingJob( logger.warn("mongo collection[$collectionName] not exists") return } - val result = executeAndMeasureTime { - mongoTemplate.count(Query(Criteria.where(ID).gt(startId)), collectionName) - } - logger.info("count $collectionName elapsed[${result.second}]") - val count = result.first var progress = 0 var records: List var lastId = startId @@ -223,7 +256,7 @@ class ArtifactAccessLogEmbeddingJob( progress += records.size if (progress % 1000000 == 0) { val end = System.currentTimeMillis() - logger.info("find access log from db elapsed[${end - start}]ms, $progress/$count") + logger.info("find access log from db elapsed[${end - start}]ms, $progress") } records.forEach { handler(it) } @@ -231,6 +264,65 @@ class ArtifactAccessLogEmbeddingJob( } while (records.size == query.limit && shouldRun()) } + private fun findByProject( + projectId: String, + minusMonth: Long, + after: LocalDateTime?, + before: LocalDateTime?, + handler: (OperateLog) -> Unit, + ) { + val collectionName = collectionName(minusMonth) + if (!mongoTemplate.collectionExists(collectionName)) { + logger.warn("mongo collection[$collectionName] not exists") + return + } + val pageSize = properties.batchSize + var offset = 0L + var resultSize: Int + val criteria = buildProjectCriteria(projectId, after, before) + var progress = 0 + do { + val query = Query(criteria) + .limit(pageSize) + .skip(offset) + .with(Sort.by(TOperateLog::projectId.name).ascending()) + query.fields().include( + ID, + TOperateLog::projectId.name, + TOperateLog::repoName.name, + TOperateLog::type.name, + TOperateLog::resourceKey.name, + TOperateLog::createdDate.name + ) + + val start = System.currentTimeMillis() + val records = mongoTemplate.find(query, OperateLog::class.java, collectionName) + progress += records.size + if (progress % 10000 == 0) { + val end = System.currentTimeMillis() + logger.info("find [$projectId] access log from db elapsed[${end - start}]ms, $progress") + } + + // 记录制品访问时间 + records.forEach { handler(it) } + resultSize = records.size + offset += records.size + } while (resultSize == pageSize) + } + + private fun buildProjectCriteria(projectId: String, after: LocalDateTime?, before: LocalDateTime?): Criteria { + val criteria = Criteria + .where(TOperateLog::projectId.name).isEqualTo(projectId) + .and(TOperateLog::type.name).isEqualTo(EventType.NODE_DOWNLOADED.name) + if (after != null && before != null) { + criteria.and(TOperateLog::createdDate.name).gte(after).lt(before) + } else { + after?.let { criteria.and(TOperateLog::createdDate.name).gte(it) } + before?.let { criteria.and(TOperateLog::createdDate.name).lt(it) } + } + return criteria + } + private fun buildQuery(lastId: ObjectId): Query { val query = Query(Criteria.where(ID).gt(lastId)).limit(properties.batchSize).with(Sort.by(ID).ascending()) query.fields().include( diff --git a/src/backend/job/biz-job/src/main/kotlin/com/tencent/bkrepo/job/controller/user/UserEmbeddingController.kt b/src/backend/job/biz-job/src/main/kotlin/com/tencent/bkrepo/job/controller/user/UserEmbeddingController.kt new file mode 100644 index 0000000000..4c33671f44 --- /dev/null +++ b/src/backend/job/biz-job/src/main/kotlin/com/tencent/bkrepo/job/controller/user/UserEmbeddingController.kt @@ -0,0 +1,28 @@ +package com.tencent.bkrepo.job.controller.user + +import com.tencent.bkrepo.common.api.exception.ErrorCodeException +import com.tencent.bkrepo.common.api.message.CommonMessageCode +import com.tencent.bkrepo.common.security.permission.Principal +import com.tencent.bkrepo.common.security.permission.PrincipalType +import com.tencent.bkrepo.job.batch.task.cache.preload.ArtifactAccessLogEmbeddingJob +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor +import org.springframework.web.bind.annotation.PathVariable +import org.springframework.web.bind.annotation.PostMapping +import org.springframework.web.bind.annotation.RequestMapping +import org.springframework.web.bind.annotation.RestController + +@RestController +@RequestMapping("/api/embedding") +@Principal(type = PrincipalType.ADMIN) +class UserEmbeddingController( + private val artifactAccessLogEmbeddingJob: ArtifactAccessLogEmbeddingJob?, + private val executor: ThreadPoolTaskExecutor, +) { + @PostMapping("/project/{projectId}") + fun embed(@PathVariable projectId: String) { + if (artifactAccessLogEmbeddingJob == null) { + throw ErrorCodeException(CommonMessageCode.SYSTEM_ERROR, "unsupported operation") + } + executor.execute { artifactAccessLogEmbeddingJob.embedAccessLog(projectId) } + } +}