Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Indexing Images via OCR #823

Merged
merged 15 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/khoj/interface/web/chat.html
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@

To get started, just start typing below. You can also type / to see a list of commands.
`.trim()
const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf'];
const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf'];
const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'image/jpeg', 'image/png'];
const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'jpg', 'jpeg', 'png'];
MythicalCow marked this conversation as resolved.
Show resolved Hide resolved
let chatOptions = [];
function createCopyParentText(message) {
return function(event) {
Expand Down Expand Up @@ -903,7 +903,12 @@
fileType = "text/html";
} else if (fileExtension === "pdf") {
fileType = "application/pdf";
} else {
} else if (fileExtension === "jpg" || fileExtension === "jpeg"){
fileType = "image/jpeg";
} else if (fileExtension === "png") {
fileType = "image/png";
}
else {
// Skip this file if its type is not supported
resolve();
return;
Expand Down
Empty file.
124 changes: 124 additions & 0 deletions src/khoj/processor/content/images/image_to_entries.py
MythicalCow marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import base64
import logging
import os
import tempfile
from datetime import datetime
from typing import Dict, List, Tuple

from rapidocr_onnxruntime import RapidOCR

from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry

logger = logging.getLogger(__name__)


class ImageToEntries(TextToEntries):
    """Index image files by running OCR on them and storing the extracted text as entries."""

    def __init__(self):
        super().__init__()

    # Define Functions
    def process(
        self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
    ) -> Tuple[int, int]:
        """Extract text entries from the given image files and update their embeddings.

        Args:
            files: Mapping of file name to raw image content. When `full_corpus`
                is False, a file whose content equals `b""` is treated as
                marked for deletion and is not processed.
            full_corpus: Whether `files` is the complete set of files to index.
            user: User the entries belong to; passed through to `update_embeddings`.
            regenerate: Passed through to `update_embeddings`.

        Returns:
            Tuple of (number of new embeddings, number of deleted embeddings)
            as returned by `update_embeddings`.
        """
        # Extract required fields from config
        if not full_corpus:
            deletion_file_names = {file for file in files if files[file] == b""}
            files = {file: content for file, content in files.items() if file not in deletion_file_names}
        else:
            deletion_file_names = None

        # Extract Entries from specified image files
        with timer("Extract entries from specified image files", logger):
            file_to_text_map, current_entries = ImageToEntries.extract_image_entries(files)

        # Split entries by max tokens supported by model
        with timer("Split entries by max token size supported by model", logger):
            current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)

        # Identify, mark and merge any new entries with previous entries
        with timer("Identify new or updated entries", logger):
            num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
                current_entries,
                DbEntry.EntryType.IMAGE,
                DbEntry.EntrySource.COMPUTER,
                "compiled",
                logger,
                deletion_file_names,
                user,
                regenerate=regenerate,
                file_to_text_map=file_to_text_map,
            )

        return num_new_embeddings, num_deleted_embeddings

    @staticmethod
    def extract_image_entries(image_files) -> Tuple[Dict, List[Entry]]:
        """Extract entries from the specified image files using OCR.

        Each image is written to a uniquely named temporary file, passed to
        RapidOCR, and the temporary file is removed again whether or not
        processing succeeded. A file that cannot be processed is logged and
        skipped; it does not abort the remaining files.

        Args:
            image_files: Mapping of image file name to its raw content.
                Only names ending in .png, .jpg or .jpeg (case-insensitive)
                are processed.

        Returns:
            Tuple of (mapping from file name to the list of text snippets
            extracted from it, list of entries built from those snippets).
        """
        file_to_text_map: Dict[str, List[str]] = {}
        entries: List[Entry] = []
        # Created on first use and shared by all files instead of once per file.
        loader = None

        for image_file, image_content in image_files.items():
            # use either png or jpg
            lowercase_name = image_file.lower()
            if lowercase_name.endswith(".png"):
                suffix = ".png"
            elif lowercase_name.endswith((".jpg", ".jpeg")):
                suffix = ".jpg"
            else:
                logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.")
                continue

            # Bound before the try block so the cleanup below never sees an undefined name.
            tmp_file = None
            try:
                if loader is None:
                    loader = RapidOCR()

                # write the image to a temporary file
                # delete=False so the file can be closed before the OCR engine opens it by path.
                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
                    tmp_file = f.name
                    f.write(image_content)

                result, _ = loader(tmp_file)
                # The second element of each result item is kept as the entry text.
                image_entries_per_file = [text[1] for text in result] if result else []

                # Build the entry-to-file map per file, so identical text found in
                # two different images is not attributed to the wrong file.
                entries.extend(
                    ImageToEntries.convert_image_entries_to_maps(
                        image_entries_per_file, dict.fromkeys(image_entries_per_file, image_file)
                    )
                )
                file_to_text_map[image_file] = image_entries_per_file
            except Exception as e:
                logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.")
                logger.warning(e, exc_info=True)
            finally:
                if tmp_file and os.path.exists(tmp_file):
                    os.remove(tmp_file)

        return file_to_text_map, entries

    @staticmethod
    def convert_image_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
        """Convert each parsed image text snippet into an Entry.

        Args:
            parsed_entries: Text snippets extracted from images.
            entry_to_file_map: Mapping from each text snippet to the name of
                the file it was extracted from.

        Returns:
            One Entry per snippet, in the same order as `parsed_entries`.
        """
        entries = []
        for parsed_entry in parsed_entries:
            entry_filename = entry_to_file_map[parsed_entry]
            # Prefix the file name to the compiled entry for context to model
            heading = f"{entry_filename}\n"
            entries.append(
                Entry(
                    compiled=f"{heading}{parsed_entry}",
                    raw=parsed_entry,
                    heading=heading,
                    file=f"{entry_filename}",
                )
            )

        logger.debug(f"Converted {len(parsed_entries)} image entries to dictionaries")

        return entries
23 changes: 22 additions & 1 deletion src/khoj/routers/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from khoj.database.models import GithubConfig, KhojUser, NotionConfig
from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.content.images.image_to_entries import ImageToEntries
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.notion.notion_to_entries import NotionToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
Expand Down Expand Up @@ -40,6 +41,7 @@ class IndexerInput(BaseModel):
markdown: Optional[dict[str, str]] = None
pdf: Optional[dict[str, bytes]] = None
plaintext: Optional[dict[str, str]] = None
image: Optional[dict[str, bytes]] = None


@indexer.post("/update")
Expand All @@ -63,7 +65,7 @@ async def update(
),
):
user = request.user.object
index_files: Dict[str, Dict[str, str]] = {"org": {}, "markdown": {}, "pdf": {}, "plaintext": {}}
index_files: Dict[str, Dict[str, str]] = {"org": {}, "markdown": {}, "pdf": {}, "plaintext": {}, "image": {}}
try:
logger.info(f"📬 Updating content index via API call by {client} client")
for file in files:
Expand All @@ -79,6 +81,7 @@ async def update(
markdown=index_files["markdown"],
pdf=index_files["pdf"],
plaintext=index_files["plaintext"],
image=index_files["image"],
)

if state.config == None:
Expand Down Expand Up @@ -129,6 +132,7 @@ async def update(
"num_markdown": len(index_files["markdown"]),
"num_pdf": len(index_files["pdf"]),
"num_plaintext": len(index_files["plaintext"]),
"num_image": len(index_files["image"]),
}

update_telemetry_state(
Expand Down Expand Up @@ -295,6 +299,23 @@ def configure_content(
logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True)
success = False

try:
# Initialize Image Search
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value) and files[
"image"
]:
logger.info("🖼️ Setting up search for images")
# Extract Entries, Generate Image Embeddings
text_search.setup(
ImageToEntries,
files.get("image"),
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
success = False
# Invalidate Query Cache
if user:
state.query_cache[user.uuid] = LRU()
Expand Down
4 changes: 2 additions & 2 deletions src/khoj/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]:
elif file_type in ["application/pdf"]:
return "pdf", encoding
elif file_type in ["image/jpeg"]:
return "jpeg", encoding
return "image", encoding
elif file_type in ["image/png"]:
return "png", encoding
return "image", encoding
elif content_group in ["code", "text"]:
return "plaintext", encoding
else:
Expand Down
1 change: 1 addition & 0 deletions src/khoj/utils/rawconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class ContentConfig(ConfigBase):
plaintext: Optional[TextContentConfig] = None
github: Optional[GithubContentConfig] = None
notion: Optional[NotionContentConfig] = None
image: Optional[TextContentConfig] = None


class ImageSearchConfig(ConfigBase):
Expand Down
Loading