Skip to content

Commit

Permalink
Index all text and code files in GitHub repos, not just md and org files
Browse files Browse the repository at this point in the history
  • Loading branch information
debanjum committed Apr 9, 2024
1 parent 510399f commit 6c45ce2
Showing 1 changed file with 50 additions and 11 deletions.
61 changes: 50 additions & 11 deletions src/khoj/processor/content/github/github_to_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
from typing import Any, Dict, List, Tuple

import requests
from magika import Magika

from khoj.database.models import Entry as DbEntry
from khoj.database.models import GithubConfig, KhojUser
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry, GithubContentConfig, GithubRepoConfig
from khoj.utils.rawconfig import GithubContentConfig, GithubRepoConfig

logger = logging.getLogger(__name__)
magika = Magika()


class GithubToEntries(TextToEntries):
Expand Down Expand Up @@ -61,7 +64,7 @@ def process_repo(self, repo: GithubRepoConfig):
repo_url = f"https://api.github.com/repos/{repo.owner}/{repo.name}"
repo_shorthand = f"{repo.owner}/{repo.name}"
logger.info(f"Processing github repo {repo_shorthand}")
with timer("Download markdown files from github repo", logger):
with timer("Download files from github repo", logger):
try:
markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo)
except ConnectionAbortedError as e:
Expand All @@ -70,8 +73,9 @@ def process_repo(self, repo: GithubRepoConfig):
logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
raise e

logger.info(f"Found {len(markdown_files)} markdown files in github repo {repo_shorthand}")
logger.info(f"Found {len(org_files)} org files in github repo {repo_shorthand}")
logger.info(
f"Found {len(markdown_files)} md, {len(org_files)} org and {len(plaintext_files)} text files in github repo {repo_shorthand}"
)
current_entries = []

with timer(f"Extract markdown entries from github repo {repo_shorthand}", logger):
Expand All @@ -84,6 +88,11 @@ def process_repo(self, repo: GithubRepoConfig):
*GithubToEntries.extract_org_entries(org_files)
)

with timer(f"Extract plaintext entries from github repo {repo_shorthand}", logger):
current_entries += PlaintextToEntries.convert_text_files_to_entries(
*GithubToEntries.extract_plaintext_entries(plaintext_files)
)

with timer(f"Split entries by max token size supported by model {repo_shorthand}", logger):
current_entries = TextToEntries.split_entries_by_max_tokens(current_entries, max_tokens=256)

Expand Down Expand Up @@ -116,10 +125,11 @@ def get_files(self, repo_url: str, repo: GithubRepoConfig):
raise ConnectionAbortedError("Github rate limit reached")

# Extract markdown files from the repository
markdown_files: List[Any] = []
org_files: List[Any] = []
markdown_files: List[Dict[str, str]] = []
org_files: List[Dict[str, str]] = []
plaintext_files: List[Dict[str, str]] = []
if "tree" not in contents:
return markdown_files, org_files
return markdown_files, org_files, plaintext_files

for item in contents["tree"]:
# Find all markdown files in the repository
Expand All @@ -138,9 +148,27 @@ def get_files(self, repo_url: str, repo: GithubRepoConfig):
# Add org file contents and URL to list
org_files += [{"content": self.get_file_contents(item["url"]), "path": url_path}]

return markdown_files, org_files
# Find, index remaining non-binary files in the repository
elif item["type"] == "blob":
url_path = f'https://github.com/{repo.owner}/{repo.name}/blob/{repo.branch}/{item["path"]}'
content_bytes = self.get_file_contents(item["url"], decode=False)
content_type, content_str = None, None
try:
content_type = magika.identify_bytes(content_bytes).output.mime_type
content_str = content_bytes.decode("utf-8")
except:
logger.error(
f"Unable to identify content type or decode content of file at {url_path}. Skip indexing it"
)
continue

# Add non-binary file contents and URL to list
if content_type.startswith("text/"):
plaintext_files += [{"content": content_str, "path": url_path}]

def get_file_contents(self, file_url):
return markdown_files, org_files, plaintext_files

def get_file_contents(self, file_url, decode=True):
# Get text from each markdown file
headers = {"Accept": "application/vnd.github.v3.raw"}
response = self.session.get(file_url, headers=headers, stream=True)
Expand All @@ -149,11 +177,11 @@ def get_file_contents(self, file_url):
if response.status_code != 200 and response.headers.get("X-RateLimit-Remaining") == "0":
raise ConnectionAbortedError("Github rate limit reached")

content = ""
content = "" if decode else b""
for chunk in response.iter_content(chunk_size=2048):
if chunk:
try:
content += chunk.decode("utf-8")
content += chunk.decode("utf-8") if decode else chunk
except Exception as e:
logger.error(f"Unable to decode chunk from {file_url}")
logger.error(e)
Expand All @@ -180,3 +208,14 @@ def extract_org_entries(org_files):
doc["content"], doc["path"], entries, entry_to_file_map
)
return entries, dict(entry_to_file_map)

@staticmethod
def extract_plaintext_entries(plaintext_files):
    """Convert raw plaintext files into indexable entries.

    Each item in *plaintext_files* is a dict with "content" and "path" keys.
    Returns a tuple of (entries, entry-to-file map) suitable for indexing.
    """
    extracted_entries = []
    file_map_pairs = []

    # Feed each file through the plaintext processor, threading the
    # accumulators back in so entries from all files are collected together.
    for plaintext_file in plaintext_files:
        extracted_entries, file_map_pairs = PlaintextToEntries.process_single_plaintext_file(
            plaintext_file["content"], plaintext_file["path"], extracted_entries, file_map_pairs
        )

    return extracted_entries, dict(file_map_pairs)

0 comments on commit 6c45ce2

Please sign in to comment.