Skip to content

Commit

Permalink
Redesign for optional titles or tags
Browse files Browse the repository at this point in the history
  • Loading branch information
icereed committed Oct 7, 2024
1 parent 4776486 commit 96dc250
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 178 deletions.
129 changes: 96 additions & 33 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,28 @@ type GetDocumentsApiResponse struct {
} `json:"results"`
}

// Document is a stripped down version of the document object from paperless-ngx.
// It is the response payload of the /documents endpoint and part of the request payload of the /generate-suggestions endpoint.
type Document struct {
ID int `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
Tags []string `json:"tags"`
SuggestedTitle string `json:"suggested_title,omitempty"`
SuggestedTags []string `json:"suggested_tags,omitempty"`
ID int `json:"id"`
Title string `json:"title"`
Content string `json:"content"`
Tags []string `json:"tags"`
}

// GenerateSuggestionsRequest is the request payload for the /generate-suggestions endpoint.
// It carries the documents to process together with two switches that select
// which kinds of suggestion should be generated.
type GenerateSuggestionsRequest struct {
// Documents lists the documents for which suggestions are requested.
Documents []Document `json:"documents"`
// GenerateTitles, when true, requests a newly generated title for each document.
// When false, the document's existing title is carried over as the suggestion.
GenerateTitles bool `json:"generate_titles,omitempty"`
// GenerateTags, when true, requests newly generated tags for each document.
// When false, the document's existing tags (with tagToFilter removed) are carried over.
GenerateTags bool `json:"generate_tags,omitempty"`
}

// DocumentSuggestion is the response payload for the /generate-suggestions endpoint and,
// as an array, the request payload for the /update-documents endpoint.
type DocumentSuggestion struct {
// ID is the ID of the document this suggestion applies to.
ID int `json:"id"`
// OriginalDocument is a copy of the document the suggestion was produced from.
OriginalDocument Document `json:"original_document"`
// SuggestedTitle is the proposed title; it is omitted from the JSON when empty.
SuggestedTitle string `json:"suggested_title,omitempty"`
// SuggestedTags holds the proposed tag names; it is omitted from the JSON when empty.
SuggestedTags []string `json:"suggested_tags,omitempty"`
}

var (
Expand Down Expand Up @@ -207,14 +222,14 @@ func documentsHandler(c *gin.Context) {
func generateSuggestionsHandler(c *gin.Context) {
ctx := c.Request.Context()

var documents []Document
if err := c.ShouldBindJSON(&documents); err != nil {
var suggestionRequest GenerateSuggestionsRequest
if err := c.ShouldBindJSON(&suggestionRequest); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)})
log.Printf("Invalid request payload: %v", err)
return
}

results, err := processDocuments(ctx, documents)
results, err := generateDocumentSuggestions(ctx, suggestionRequest)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Error processing documents: %v", err)})
log.Printf("Error processing documents: %v", err)
Expand All @@ -227,7 +242,7 @@ func generateSuggestionsHandler(c *gin.Context) {
// updateDocumentsHandler updates documents with new titles
func updateDocumentsHandler(c *gin.Context) {
ctx := c.Request.Context()
var documents []Document
var documents []DocumentSuggestion
if err := c.ShouldBindJSON(&documents); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request payload: %v", err)})
log.Printf("Invalid request payload: %v", err)
Expand Down Expand Up @@ -348,7 +363,7 @@ func getDocumentsByTags(ctx context.Context, baseURL, apiToken string, tags []st
return documents, nil
}

func processDocuments(ctx context.Context, documents []Document) ([]Document, error) {
func generateDocumentSuggestions(ctx context.Context, suggestionRequest GenerateSuggestionsRequest) ([]DocumentSuggestion, error) {
llm, err := createLLM()
if err != nil {
return nil, fmt.Errorf("failed to create LLM client: %v", err)
Expand All @@ -369,6 +384,9 @@ func processDocuments(ctx context.Context, documents []Document) ([]Document, er
availableTagNames = append(availableTagNames, tagName)
}

documents := suggestionRequest.Documents
documentSuggestions := []DocumentSuggestion{}

var wg sync.WaitGroup
var mu sync.Mutex
errors := make([]error, 0)
Expand All @@ -385,27 +403,50 @@ func processDocuments(ctx context.Context, documents []Document) ([]Document, er
content = content[:5000]
}

suggestedTitle, err := getSuggestedTitle(ctx, llm, content)
if err != nil {
mu.Lock()
errors = append(errors, fmt.Errorf("Document %d: %v", documentID, err))
mu.Unlock()
log.Printf("Error processing document %d: %v", documentID, err)
return
var suggestedTitle string
var suggestedTags []string

if suggestionRequest.GenerateTitles {
suggestedTitle, err = getSuggestedTitle(ctx, llm, content)
if err != nil {
mu.Lock()
errors = append(errors, fmt.Errorf("Document %d: %v", documentID, err))
mu.Unlock()
log.Printf("Error processing document %d: %v", documentID, err)
return
}
}

suggestedTags, err := getSuggestedTags(ctx, llm, content, suggestedTitle, availableTagNames)
if err != nil {
mu.Lock()
errors = append(errors, fmt.Errorf("Document %d: %v", documentID, err))
mu.Unlock()
log.Printf("Error generating tags for document %d: %v", documentID, err)
return
if suggestionRequest.GenerateTags {
suggestedTags, err = getSuggestedTags(ctx, llm, content, suggestedTitle, availableTagNames)
if err != nil {
mu.Lock()
errors = append(errors, fmt.Errorf("Document %d: %v", documentID, err))
mu.Unlock()
log.Printf("Error generating tags for document %d: %v", documentID, err)
return
}
}

mu.Lock()
doc.SuggestedTitle = suggestedTitle
doc.SuggestedTags = suggestedTags
suggestion := DocumentSuggestion{
ID: documentID,
OriginalDocument: *doc,
}
// Titles
if suggestionRequest.GenerateTitles {
suggestion.SuggestedTitle = suggestedTitle
} else {
suggestion.SuggestedTitle = doc.Title
}

// Tags
if suggestionRequest.GenerateTags {
suggestion.SuggestedTags = suggestedTags
} else {
suggestion.SuggestedTags = removeTagFromList(doc.Tags, tagToFilter)
}
documentSuggestions = append(documentSuggestions, suggestion)
mu.Unlock()
log.Printf("Document %d processed successfully.", documentID)
}(&documents[i])
Expand All @@ -417,7 +458,17 @@ func processDocuments(ctx context.Context, documents []Document) ([]Document, er
return nil, errors[0]
}

return documents, nil
return documentSuggestions, nil
}

// removeTagFromList returns a new slice containing every entry of tags that is
// not equal to tagToRemove, preserving the original order. The input slice is
// left untouched, and the result is always non-nil: it is an empty slice, not
// nil, when tags is empty or every entry matches tagToRemove.
func removeTagFromList(tags []string, tagToRemove string) []string {
	kept := make([]string, 0)
	for i := range tags {
		if tags[i] == tagToRemove {
			continue
		}
		kept = append(kept, tags[i])
	}
	return kept
}

func getSuggestedTags(ctx context.Context, llm llms.Model, content string, suggestedTitle string, availableTags []string) ([]string, error) {
Expand Down Expand Up @@ -507,7 +558,7 @@ Content:
return strings.TrimSpace(strings.Trim(completion.Choices[0].Content, "\"")), nil
}

func updateDocuments(ctx context.Context, baseURL, apiToken string, documents []Document) error {
func updateDocuments(ctx context.Context, baseURL, apiToken string, documents []DocumentSuggestion) error {
client := &http.Client{}

// Fetch all available tags
Expand All @@ -524,8 +575,13 @@ func updateDocuments(ctx context.Context, baseURL, apiToken string, documents []

newTags := []int{}

tags := document.SuggestedTags
if len(tags) == 0 {
tags = document.OriginalDocument.Tags
}

// Map suggested tag names to IDs
for _, tagName := range document.SuggestedTags {
for _, tagName := range tags {
if tagID, exists := availableTags[tagName]; exists {
// Skip the tag that we are filtering
if tagName == tagToFilter {
Expand All @@ -537,13 +593,20 @@ func updateDocuments(ctx context.Context, baseURL, apiToken string, documents []
}
}

updatedFields["tags"] = newTags

if len(newTags) > 0 {
updatedFields["tags"] = newTags
} else {
log.Printf("No valid tags found for document %d, skipping.", documentID)
}
suggestedTitle := document.SuggestedTitle
if len(suggestedTitle) > 128 {
suggestedTitle = suggestedTitle[:128]
}
updatedFields["title"] = suggestedTitle
if suggestedTitle != "" {
updatedFields["title"] = suggestedTitle
} else {
log.Printf("No valid title found for document %d, skipping.", documentID)
}

// Send the update request
url := fmt.Sprintf("%s/api/documents/%d/", baseURL, documentID)
Expand Down
Loading

0 comments on commit 96dc250

Please sign in to comment.