Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove user customized search model #946

Merged
merged 14 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 1 addition & 22 deletions src/interface/web/app/settings/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ export default function SettingsView() {
};

const updateModel = (name: string) => async (id: string) => {
if (!userConfig?.is_active && name !== "search") {
if (!userConfig?.is_active) {
toast({
title: `Model Update`,
description: `You need to be subscribed to update ${name} models`,
Expand Down Expand Up @@ -1233,27 +1233,6 @@ export default function SettingsView() {
</CardFooter>
</Card>
)}
{userConfig.search_model_options.length > 0 && (
<Card className={cardClassName}>
<CardHeader className="text-xl flex flex-row">
<FileMagnifyingGlass className="h-7 w-7 mr-2" />
Search
</CardHeader>
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
<p className="text-gray-400">
Pick the search model to find your documents
</p>
<DropdownComponent
items={userConfig.search_model_options}
selected={
userConfig.selected_search_model_config
}
callbackFunc={updateModel("search")}
/>
</CardContent>
<CardFooter className="flex flex-wrap gap-4"></CardFooter>
</Card>
)}
{userConfig.paint_model_options.length > 0 && (
<Card className={cardClassName}>
<CardHeader className="text-xl flex flex-row">
Expand Down
33 changes: 13 additions & 20 deletions src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,18 +466,26 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config


def get_user_search_model_or_default(user=None):
if user and UserSearchModelConfig.objects.filter(user=user).exists():
return UserSearchModelConfig.objects.filter(user=user).first().setting
def get_default_search_model() -> SearchModelConfig:
default_search_model = SearchModelConfig.objects.filter(name="default").first()

if SearchModelConfig.objects.filter(name="default").exists():
return SearchModelConfig.objects.filter(name="default").first()
if default_search_model:
return default_search_model
else:
SearchModelConfig.objects.create()

return SearchModelConfig.objects.first()


def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:
if user:
user_search_model = UserSearchModelConfig.objects.filter(user=user).first()
if user_search_model:
return user_search_model.setting

return get_default_search_model()


def get_or_create_search_models():
search_models = SearchModelConfig.objects.all()
if search_models.count() == 0:
Expand All @@ -487,21 +495,6 @@ def get_or_create_search_models():
return search_models


async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
if not config:
return None
new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
return new_config


async def aget_user_search_model(user: KhojUser):
config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config:
return None
return config.setting


class ProcessLockAdapters:
@staticmethod
def get_process_lock(process_name: str):
Expand Down
2 changes: 2 additions & 0 deletions src/khoj/database/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class EntryAdmin(admin.ModelAdmin):
"created_at",
"updated_at",
"user",
"agent",
"file_source",
"file_type",
"file_name",
Expand All @@ -135,6 +136,7 @@ class EntryAdmin(admin.ModelAdmin):
list_filter = (
"file_type",
"user__email",
"search_model__name",
)
ordering = ("-created_at",)

Expand Down
182 changes: 182 additions & 0 deletions src/khoj/database/management/commands/change_default_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import logging
from typing import List

from django.core.management.base import BaseCommand
from django.db import transaction
from django.db.models import Count, Q
from tqdm import tqdm

from khoj.database.adapters import get_default_search_model
from khoj.database.models import (
Agent,
Entry,
KhojUser,
SearchModelConfig,
UserSearchModelConfig,
)
from khoj.processor.embeddings import EmbeddingsModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Command(BaseCommand):
help = "Convert all existing Entry objects to use a new default Search model."

def add_arguments(self, parser):
# Pass default SearchModelConfig ID
parser.add_argument(
"--search_model_id",
action="store",
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects and UserSearchModelConfig objects.",
required=True,
)

# Set the apply flag to apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects.
parser.add_argument(
"--apply",
action="store_true",
help="Apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. Otherwise, only display the number of Entry objects and UserSearchModelConfig objects that will be affected.",
)

def handle(self, *args, **options):
@transaction.atomic
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
entries = Entry.objects.filter(entry_filter).all()
compiled_entries = [entry.compiled for entry in entries]
updated_entries: List[Entry] = []
try:
embeddings = embeddings_model.embed_documents(compiled_entries)

except Exception as e:
logger.error(f"Error embedding documents: {e}")
return

for i, entry in enumerate(tqdm(entries)):
entry.embeddings = embeddings[i]
entry.search_model_id = search_model.id
updated_entries.append(entry)

Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])

search_model_config_id = options.get("search_model_id")
apply = options.get("apply")

logger.info(f"SearchModelConfig ID: {search_model_config_id}")
logger.info(f"Apply: {apply}")

embeddings_model = dict()

search_models = SearchModelConfig.objects.all()
for model in search_models:
embeddings_model.update(
{
model.name: EmbeddingsModel(
model.bi_encoder,
model.embeddings_inference_endpoint,
model.embeddings_inference_endpoint_api_key,
query_encode_kwargs=model.bi_encoder_query_encode_config,
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
model_kwargs=model.bi_encoder_model_config,
)
}
)

new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
logger.info(f"New default Search model: {new_default_search_model_config}")
user_search_model_configs_to_update = UserSearchModelConfig.objects.exclude(
setting_id=search_model_config_id
).all()
logger.info(f"Number of UserSearchModelConfig objects to update: {user_search_model_configs_to_update.count()}")

for user_config in user_search_model_configs_to_update:
affected_user = user_config.user
entry_filter = Q(user=affected_user)
relevant_entries = Entry.objects.filter(entry_filter).all()
logger.info(f"Number of Entry objects to update for user {affected_user}: {relevant_entries.count()}")

if apply:
try:
regenerate_entries(
entry_filter,
embeddings_model[new_default_search_model_config.name],
new_default_search_model_config,
)
user_config.setting = new_default_search_model_config
user_config.save()

logger.info(
f"Updated UserSearchModelConfig object for user {affected_user} to use the new default Search model."
)
logger.info(
f"Updated {relevant_entries.count()} Entry objects for user {affected_user} to use the new default Search model."
)

except Exception as e:
logger.error(f"Error embedding documents: {e}")

logger.info("----")

# There are also plenty of users who have indexed documents without explicitly creating a UserSearchModelConfig object. You would have to migrate these users as well, if the default is different from search_model_config_id.
current_default = get_default_search_model()
if current_default.id != new_default_search_model_config.id:
users_without_user_search_model_config = KhojUser.objects.annotate(
user_search_model_config_count=Count("usersearchmodelconfig")
).filter(user_search_model_config_count=0)

logger.info(f"Number of User objects to update: {users_without_user_search_model_config.count()}")
for user in users_without_user_search_model_config:
entry_filter = Q(user=user)
relevant_entries = Entry.objects.filter(entry_filter).all()
logger.info(f"Number of Entry objects to update for user {user}: {relevant_entries.count()}")

if apply:
try:
regenerate_entries(
entry_filter,
embeddings_model[new_default_search_model_config.name],
new_default_search_model_config,
)

UserSearchModelConfig.objects.create(user=user, setting=new_default_search_model_config)

logger.info(
f"Created UserSearchModelConfig object for user {user} to use the new default Search model."
)
logger.info(
f"Updated {relevant_entries.count()} Entry objects for user {user} to use the new default Search model."
)
except Exception as e:
logger.error(f"Error embedding documents: {e}")
else:
logger.info("Default is the same as search_model_config_id.")

all_agents = Agent.objects.all()
logger.info(f"Number of Agent objects to update: {all_agents.count()}")
for agent in all_agents:
entry_filter = Q(agent=agent)
relevant_entries = Entry.objects.filter(entry_filter).all()
logger.info(f"Number of Entry objects to update for agent {agent}: {relevant_entries.count()}")

if apply:
try:
regenerate_entries(
entry_filter,
embeddings_model[new_default_search_model_config.name],
new_default_search_model_config,
)
logger.info(
f"Updated {relevant_entries.count()} Entry objects for agent {agent} to use the new default Search model."
)
except Exception as e:
logger.error(f"Error embedding documents: {e}")
if apply and current_default.id != new_default_search_model_config.id:
# Get the existing default SearchModelConfig object and update its name
current_default.name = f"prev_default_{current_default.id}"
current_default.save()

# Update the new default SearchModelConfig object's name
new_default_search_model_config.name = "default"
new_default_search_model_config.save()
if not apply:
logger.info("Run the command with the --apply flag to apply the new default Search model.")
24 changes: 24 additions & 0 deletions src/khoj/database/migrations/0072_entry_search_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by Django 5.0.8 on 2024-10-21 21:09

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("database", "0071_subscription_enabled_trial_at_and_more"),
]

operations = [
migrations.AddField(
model_name="entry",
name="search_model",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="database.searchmodelconfig",
),
),
]
2 changes: 2 additions & 0 deletions src/khoj/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ class UserVoiceModelConfig(BaseModel):
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)


# TODO Delete this model once all users have been migrated to the server's default settings
class UserSearchModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
Expand Down Expand Up @@ -535,6 +536,7 @@ class EntrySource(models.TextChoices):
url = models.URLField(max_length=400, default=None, null=True, blank=True)
hashed_value = models.CharField(max_length=100)
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True)

def save(self, *args, **kwargs):
if self.user and self.agent:
Expand Down
6 changes: 4 additions & 2 deletions src/khoj/processor/content/text_to_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from khoj.database.adapters import (
EntryAdapters,
FileObjectAdapters,
get_user_search_model_or_default,
get_default_search_model,
get_user_default_search_model,
)
from khoj.database.models import Entry as DbEntry
from khoj.database.models import EntryDates, KhojUser
Expand Down Expand Up @@ -148,10 +149,10 @@ def update_embeddings(
hashes_to_process |= hashes_for_file - existing_entry_hashes

embeddings = []
model = get_user_default_search_model(user=user)
with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
model = get_user_search_model_or_default(user)
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)

added_entries: list[DbEntry] = []
Expand All @@ -177,6 +178,7 @@ def update_embeddings(
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
search_model=model,
)
)
try:
Expand Down
5 changes: 3 additions & 2 deletions src/khoj/routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
AutomationAdapters,
ConversationAdapters,
EntryAdapters,
get_default_search_model,
get_user_default_search_model,
get_user_photo,
get_user_search_model_or_default,
)
from khoj.database.models import (
Agent,
Expand Down Expand Up @@ -149,7 +150,7 @@ async def execute_search(
encoded_asymmetric_query = None
if t != SearchType.Image:
with timer("Encoding query took", logger=logger):
search_model = await sync_to_async(get_user_search_model_or_default)(user)
search_model = await sync_to_async(get_user_default_search_model)(user)
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)

with concurrent.futures.ThreadPoolExecutor() as executor:
Expand Down
Loading
Loading