Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,15 @@ describe('Document By ID API Route', () => {
}),
}

// Mock transaction
mockDbChain.transaction.mockImplementation(async (callback) => {
const mockTx = {
update: vi.fn().mockReturnValue(updateChain),
}
await callback(mockTx)
})

// Mock db operations in sequence
mockDbChain.update.mockReturnValue(updateChain)
mockDbChain.select.mockReturnValue(selectChain)

const req = createMockRequest('PUT', validUpdateData)
Expand All @@ -231,7 +238,7 @@ describe('Document By ID API Route', () => {
expect(data.success).toBe(true)
expect(data.data.filename).toBe('updated-document.pdf')
expect(data.data.enabled).toBe(false)
expect(mockDbChain.update).toHaveBeenCalled()
expect(mockDbChain.transaction).toHaveBeenCalled()
expect(mockDbChain.select).toHaveBeenCalled()
})

Expand Down Expand Up @@ -298,8 +305,15 @@ describe('Document By ID API Route', () => {
}),
}

// Mock transaction
mockDbChain.transaction.mockImplementation(async (callback) => {
const mockTx = {
update: vi.fn().mockReturnValue(updateChain),
}
await callback(mockTx)
})

// Mock db operations in sequence
mockDbChain.update.mockReturnValue(updateChain)
mockDbChain.select.mockReturnValue(selectChain)

const req = createMockRequest('PUT', { markFailedDueToTimeout: true })
Expand All @@ -309,7 +323,7 @@ describe('Document By ID API Route', () => {

expect(response.status).toBe(200)
expect(data.success).toBe(true)
expect(mockDbChain.update).toHaveBeenCalled()
expect(mockDbChain.transaction).toHaveBeenCalled()
expect(updateChain.set).toHaveBeenCalledWith(
expect.objectContaining({
processingStatus: 'failed',
Expand Down Expand Up @@ -479,7 +493,9 @@ describe('Document By ID API Route', () => {
document: mockDocument,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.set.mockRejectedValue(new Error('Database error'))

// Mock transaction to throw an error
mockDbChain.transaction.mockRejectedValue(new Error('Database error'))

const req = createMockRequest('PUT', validUpdateData)
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
Expand Down
38 changes: 37 additions & 1 deletion apps/sim/app/api/knowledge/[id]/documents/[documentId]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { TAG_SLOTS } from '@/lib/constants/knowledge'
import { createLogger } from '@/lib/logs/console/logger'

export const dynamic = 'force-dynamic'
Expand All @@ -26,6 +27,14 @@ const UpdateDocumentSchema = z.object({
processingError: z.string().optional(),
markFailedDueToTimeout: z.boolean().optional(),
retryProcessing: z.boolean().optional(),
// Tag fields
tag1: z.string().optional(),
tag2: z.string().optional(),
tag3: z.string().optional(),
tag4: z.string().optional(),
tag5: z.string().optional(),
tag6: z.string().optional(),
tag7: z.string().optional(),
})

export async function GET(
Expand Down Expand Up @@ -213,9 +222,36 @@ export async function PUT(
updateData.processingStatus = validatedData.processingStatus
if (validatedData.processingError !== undefined)
updateData.processingError = validatedData.processingError

// Tag field updates
TAG_SLOTS.forEach((slot) => {
if ((validatedData as any)[slot] !== undefined) {
;(updateData as any)[slot] = (validatedData as any)[slot]
}
})
}

await db.update(document).set(updateData).where(eq(document.id, documentId))
await db.transaction(async (tx) => {
// Update the document
await tx.update(document).set(updateData).where(eq(document.id, documentId))

// If any tag fields were updated, also update the embeddings
const hasTagUpdates = TAG_SLOTS.some((field) => (validatedData as any)[field] !== undefined)

if (hasTagUpdates) {
const embeddingUpdateData: Record<string, string | null> = {}
TAG_SLOTS.forEach((field) => {
if ((validatedData as any)[field] !== undefined) {
embeddingUpdateData[field] = (validatedData as any)[field] || null
}
})

await tx
.update(embedding)
.set(embeddingUpdateData)
.where(eq(embedding.documentId, documentId))
}
})

// Fetch the updated document
const updatedDocument = await db
Expand Down
Loading
Loading