Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for a first-party client app to call into Khoj (Part 1) #601

Merged
merged 9 commits into from
Jan 18, 2024
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ dependencies = [
"rapidocr-onnxruntime == 1.3.8",
"stripe == 7.3.0",
"openai-whisper >= 20231117",
"django-phonenumber-field == 7.3.0",
"phonenumbers == 8.13.27",
]
dynamic = ["version"]

Expand Down
1 change: 1 addition & 0 deletions src/khoj/app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"django.contrib.sessions",
"django.contrib.messages",
"django.contrib.staticfiles",
"phonenumber_field",
]

MIDDLEWARE = [
Expand Down
59 changes: 56 additions & 3 deletions src/khoj/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import requests
import schedule
from django.utils.timezone import make_aware
from fastapi import Response
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
Expand All @@ -20,27 +21,32 @@
from starlette.requests import HTTPConnection

from khoj.database.adapters import (
ClientApplicationAdapters,
ConversationAdapters,
SubscriptionState,
aget_or_create_user_by_phone_number,
aget_user_by_phone_number,
aget_user_subscription_state,
get_all_users,
get_or_create_search_models,
)
from khoj.database.models import KhojUser, Subscription
from khoj.database.models import ClientApplication, KhojUser, Subscription
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, configure_search, load_content
from khoj.utils import constants, state
from khoj.utils.config import SearchType
from khoj.utils.fs_syncer import collect_files
from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import FullConfig

logger = logging.getLogger(__name__)


class AuthenticatedKhojUser(SimpleUser):
def __init__(self, user):
def __init__(self, user, client_app: Optional[ClientApplication] = None):
self.object = user
super().__init__(user.email)
self.client_app = client_app
super().__init__(user.username)


class UserAuthenticationBackend(AuthenticationBackend):
Expand Down Expand Up @@ -108,6 +114,53 @@ async def authenticate(self, request: HTTPConnection):
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
# Get query params for client_id and client_secret
client_id = request.query_params.get("client_id")
if client_id:
# Get the client secret, which is passed in the Authorization header
client_secret = request.headers["Authorization"].split("Bearer ")[1]
if not client_secret:
return Response(
status_code=401,
content="Please provide a client secret in the Authorization header with a client_id query param.",
)

# Get the client application
client_application = await ClientApplicationAdapters.aget_client_application_by_id(client_id, client_secret)
if client_application is None:
return AuthCredentials(), UnauthenticatedUser()
# Get the identifier used for the user
phone_number = request.query_params.get("phone_number")
if is_none_or_empty(phone_number):
return AuthCredentials(), UnauthenticatedUser()

if not phone_number.startswith("+"):
phone_number = f"+{phone_number}"

create_if_not_exists = request.query_params.get("create_if_not_exists")
if create_if_not_exists:
user = await aget_or_create_user_by_phone_number(phone_number)
else:
user = await aget_user_by_phone_number(phone_number)

if user is None:
return AuthCredentials(), UnauthenticatedUser()

if not state.billing_enabled:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user, client_application)

subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return (
AuthCredentials(["authenticated", "premium"]),
AuthenticatedKhojUser(user),
)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user, client_application)
if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user:
Expand Down
43 changes: 39 additions & 4 deletions src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from khoj.database.models import (
ChatModelOptions,
ClientApplication,
Conversation,
Entry,
GithubConfig,
Expand All @@ -40,7 +41,7 @@
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import generate_random_name
from khoj.utils.helpers import generate_random_name, is_none_or_empty


class SubscriptionState(Enum):
Expand Down Expand Up @@ -85,6 +86,28 @@ async def get_or_create_user(token: dict) -> KhojUser:
return user


async def aget_or_create_user_by_phone_number(phone_number: str) -> KhojUser:
if is_none_or_empty(phone_number):
return None
user = await aget_user_by_phone_number(phone_number)
sabaimran marked this conversation as resolved.
Show resolved Hide resolved
if not user:
user = await acreate_user_by_phone_number(phone_number)
return user


async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
if is_none_or_empty(phone_number):
return None
user, _ = await KhojUser.objects.filter(phone_number=phone_number).aupdate_or_create(
sabaimran marked this conversation as resolved.
Show resolved Hide resolved
defaults={"username": phone_number, "phone_number": phone_number}
)
await user.asave()

await Subscription.objects.acreate(user=user, type="trial")

return user


async def get_or_create_user_by_email(email: str) -> KhojUser:
user, _ = await KhojUser.objects.filter(email=email).aupdate_or_create(defaults={"username": email, "email": email})
await user.asave()
Expand Down Expand Up @@ -187,6 +210,12 @@ async def get_user_by_token(token: dict) -> KhojUser:
return google_user.user


async def aget_user_by_phone_number(phone_number: str) -> KhojUser:
if is_none_or_empty(phone_number):
return None
return await KhojUser.objects.filter(phone_number=phone_number).prefetch_related("subscription").afirst()
sabaimran marked this conversation as resolved.
Show resolved Hide resolved


async def retrieve_user(session_id: str) -> KhojUser:
session = SessionStore(session_key=session_id)
if not await sync_to_async(session.exists)(session_key=session_id):
Expand Down Expand Up @@ -270,6 +299,12 @@ async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
return new_config


class ClientApplicationAdapters:
@staticmethod
async def aget_client_application_by_id(client_id: str, client_secret: str):
return await ClientApplication.objects.filter(client_id=client_id, client_secret=client_secret).afirst()


class ConversationAdapters:
@staticmethod
def get_conversation_by_user(user: KhojUser):
Expand All @@ -279,11 +314,11 @@ def get_conversation_by_user(user: KhojUser):
return Conversation.objects.create(user=user)

@staticmethod
async def aget_conversation_by_user(user: KhojUser):
conversation = Conversation.objects.filter(user=user)
async def aget_conversation_by_user(user: KhojUser, client_application: ClientApplication = None):
conversation = Conversation.objects.filter(user=user, client=client_application)
if await conversation.aexists():
return await conversation.afirst()
return await Conversation.objects.acreate(user=user)
return await Conversation.objects.acreate(user=user, client=client_application)

@staticmethod
async def adelete_conversation_by_user(user: KhojUser):
Expand Down
20 changes: 18 additions & 2 deletions src/khoj/database/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from khoj.database.models import (
ChatModelOptions,
ClientApplication,
Conversation,
KhojUser,
OfflineChatProcessorConversationConfig,
Expand All @@ -19,10 +20,24 @@
UserSearchModelConfig,
)

# Register your models here.

class KhojUserAdmin(UserAdmin):
list_display = (
"id",
"email",
"username",
"is_active",
"is_staff",
"is_superuser",
"phone_number",
)
search_fields = ("email", "username", "phone_number")
filter_horizontal = ("groups", "user_permissions")

fieldsets = (("Personal info", {"fields": ("phone_number",)}),) + UserAdmin.fieldsets


admin.site.register(KhojUser, UserAdmin)
admin.site.register(KhojUser, KhojUserAdmin)

admin.site.register(ChatModelOptions)
admin.site.register(SpeechToTextModelOptions)
Expand All @@ -33,6 +48,7 @@
admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig)
admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication)


@admin.register(Conversation)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Generated by Django 4.2.7 on 2024-01-04 12:22

import django.db.models.deletion
import phonenumber_field.modelfields
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("database", "0024_alter_entry_embeddings"),
]

operations = [
migrations.CreateModel(
name="ClientApplication",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("name", models.CharField(max_length=200)),
("client_id", models.CharField(max_length=200)),
("client_secret", models.CharField(max_length=200)),
],
options={
"abstract": False,
},
),
migrations.AddField(
model_name="khojuser",
name="phone_number",
field=phonenumber_field.modelfields.PhoneNumberField(
blank=True, default=None, max_length=128, null=True, region=None
),
),
migrations.AddField(
model_name="conversation",
name="client",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="database.clientapplication",
),
),
]
13 changes: 13 additions & 0 deletions src/khoj/database/migrations/0027_merge_20240118_1324.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Generated by Django 4.2.7 on 2024-01-18 13:24
from typing import List

from django.db import migrations


class Migration(migrations.Migration):
dependencies = [
("database", "0025_clientapplication_khojuser_phone_number_and_more"),
("database", "0026_searchmodelconfig_cross_encoder_inference_endpoint_and_more"),
]

operations: List[str] = []
12 changes: 12 additions & 0 deletions src/khoj/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.contrib.auth.models import AbstractUser
from django.db import models
from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField


class BaseModel(models.Model):
Expand All @@ -13,8 +14,18 @@ class Meta:
abstract = True


class ClientApplication(BaseModel):
name = models.CharField(max_length=200)
client_id = models.CharField(max_length=200)
client_secret = models.CharField(max_length=200)

def __str__(self):
return self.name


class KhojUser(AbstractUser):
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
phone_number = PhoneNumberField(null=True, default=None, blank=True)

def save(self, *args, **kwargs):
if not self.uuid:
Expand Down Expand Up @@ -165,6 +176,7 @@ class UserSearchModelConfig(BaseModel):
class Conversation(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)


class ReflectiveQuestion(BaseModel):
Expand Down
12 changes: 8 additions & 4 deletions src/khoj/routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ async def chat(
n: Optional[int] = 5,
d: Optional[float] = 0.18,
stream: Optional[bool] = False,
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response:
user: KhojUser = request.user.object
q = unquote(q)
Expand All @@ -372,7 +372,7 @@ async def chat(

q = q.replace(f"/{conversation_command.value}", "").strip()

meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
meta_log = (await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app)).conversation_log

compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
Expand All @@ -392,7 +392,11 @@ async def chat(

elif conversation_command == ConversationCommand.Notes and not await EntryAdapters.auser_has_entries(user):
no_entries_found_format = no_entries_found.format()
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
if stream:
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
else:
response_obj = {"response": no_entries_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)

elif conversation_command == ConversationCommand.Online:
try:
Expand Down
Loading
Loading