From 71fb08b55e2985e92c92247a7ae29fa1b2d272bb Mon Sep 17 00:00:00 2001 From: jimmoffet Date: Mon, 18 Nov 2024 17:54:01 -0800 Subject: [PATCH 1/2] cost tracking for chats --- backend/apps/rag/clients/vector_client.py | 6 +-- backend/apps/rag/main.py | 15 ++++-- backend/apps/rag/utils.py | 9 ++++ .../internal/migrations/013_add_user_info.py | 2 +- .../internal/migrations/014_add_chat_cost.py | 46 +++++++++++++++++++ .../apps/webui/internal/migrations/README.md | 2 +- backend/apps/webui/models/chats.py | 2 + src/lib/apis/rag/index.ts | 2 +- src/lib/apis/streaming/index.ts | 1 + src/lib/components/chat/Chat.svelte | 32 +++++++------ src/routes/+layout.svelte | 6 +-- 11 files changed, 97 insertions(+), 26 deletions(-) create mode 100644 backend/apps/webui/internal/migrations/014_add_chat_cost.py diff --git a/backend/apps/rag/clients/vector_client.py b/backend/apps/rag/clients/vector_client.py index d1fcbffaf..50344600c 100644 --- a/backend/apps/rag/clients/vector_client.py +++ b/backend/apps/rag/clients/vector_client.py @@ -17,7 +17,7 @@ class VectorItem(BaseModel): class GetResult(BaseModel): - ids: Optional[List[List[str]]] + ids: Optional[List[List[str | int]]] documents: Optional[List[List[str]]] metadatas: Optional[List[List[Any]]] @@ -109,11 +109,11 @@ def get(self, collection_name: str) -> Optional[GetResult]: return GetResult( ids=[[row[0] for row in result]], documents=[[row[1] for row in result]], - metadatas=[row[2] for row in result], + metadatas=[[row[2] for row in result]], ) return None except Exception as e: - log.error("Get Error:", e) + log.error(f"Get Error: {e}") return None def insert(self, collection_name: str, items: list[VectorItem]): diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 8f61c716d..cdfd88118 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -909,6 +909,9 @@ async def store_web_search(form_data: SearchForm, user=Depends(get_current_user) ) +from langchain_experimental.text_splitter import 
SemanticChunker + + async def store_data_in_vector_db( data, collection_name, overwrite: bool = False ) -> bool: @@ -1243,13 +1246,19 @@ async def scan_database_docs(user=Depends(get_admin_user)): log.info(f"Doc: {doc}") log.info(f"Content: {content}") + docs_in_collection = VECTOR_CLIENT.get(collection_name) + ids = docs_in_collection.ids[0] + chunks = docs_in_collection.documents[0] + metadatas = docs_in_collection.metadatas[0] + + log.info(f"Found {len(chunks)} chunks in collection {collection_name}") + # + # Ensure the content is properly loaded for processing if content and isinstance(content, str): data = content # You can convert to the format needed for further processing - # Here you could add your logic to split documents if needed - await sleep(60) - await store_data_in_vector_db(data, collection_name) + # await store_data_in_vector_db(data, collection_name) except Exception as e: log.exception(e) diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index 41a98ccd4..2f81b6e3c 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -225,8 +225,17 @@ async def generate_openai_embeddings_async( key: str, url: str = "https://api.openai.com/v1", ): + # The default limit is 100 + # connector = aiohttp.TCPConnector(limit=10) + # async with aiohttp.ClientSession(connector=connector) as session: + + # from aiolimiter import AsyncLimiter + # RATE_LIMIT_IN_MINUTE = 5 + # rate_limiter = AsyncLimiter(RATE_LIMIT_IN_MINUTE, 60.0) + async with aiohttp.ClientSession() as session: if isinstance(text, list): + # async with rate_limiter: embeddings = await generate_openai_batch_embeddings_async( model, text, key, url, session ) diff --git a/backend/apps/webui/internal/migrations/013_add_user_info.py b/backend/apps/webui/internal/migrations/013_add_user_info.py index 0f68669cc..448fdd98a 100644 --- a/backend/apps/webui/internal/migrations/013_add_user_info.py +++ b/backend/apps/webui/internal/migrations/013_add_user_info.py @@ -1,4 +1,4 @@ 
-"""Peewee migrations -- 002_add_local_sharing.py. +"""Peewee migrations -- 013_add_user_info.py. Some examples (model - class or model name):: diff --git a/backend/apps/webui/internal/migrations/014_add_chat_cost.py b/backend/apps/webui/internal/migrations/014_add_chat_cost.py new file mode 100644 index 000000000..022781c30 --- /dev/null +++ b/backend/apps/webui/internal/migrations/014_add_chat_cost.py @@ -0,0 +1,46 @@ +"""Peewee migrations -- 014_add_chat_cost.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields("chat", cost=pw.FloatField(default=0.0)) + + +def rollback(migrator: Migrator, database: pw.Database, 
*, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("chat", "cost") diff --git a/backend/apps/webui/internal/migrations/README.md b/backend/apps/webui/internal/migrations/README.md index 260214113..8f0eea6da 100644 --- a/backend/apps/webui/internal/migrations/README.md +++ b/backend/apps/webui/internal/migrations/README.md @@ -14,7 +14,7 @@ You will need to create a migration file to ensure that existing databases are u 2. Make your changes to the models. 3. From the `backend` directory, run the following command: ```bash - pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME} + pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///${SQLITE_DB} --directory apps/webui/internal/migrations ${MIGRATION_NAME} ``` - `$SQLITE_DB` should be the path to the database file. - `$MIGRATION_NAME` should be a descriptive name for the migration. diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index a6f1ae923..294a824d7 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -26,6 +26,8 @@ class Chat(Model): share_id = CharField(null=True, unique=True) archived = BooleanField(default=False) + cost = FloatField(default=0.0) + class Meta: database = DB diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index ca68827a3..0e8492c63 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -336,7 +336,7 @@ export const queryCollection = async ( export const scanDocs = async (token: string) => { let error = null; - const res = await fetch(`${RAG_API_BASE_URL}/scan`, { + const res = await fetch(`${RAG_API_BASE_URL}/scan_database_docs`, { method: 'GET', headers: { Accept: 'application/json', diff --git a/src/lib/apis/streaming/index.ts b/src/lib/apis/streaming/index.ts index 57d04014a..6f152b9d9 100644 --- 
a/src/lib/apis/streaming/index.ts +++ b/src/lib/apis/streaming/index.ts @@ -85,6 +85,7 @@ async function* openAIStreamToIterator( }; } catch (e) { console.error('Error extracting delta from SSE event:', e); + console.error('Error extracting delta from SSE event with data:', data); } } } diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index c08a92111..febdadfcb 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -218,6 +218,8 @@ ? chatContent.history : convertMessagesToHistory(chatContent.messages); title = chatContent.title; + currentDollarAmount = chatContent.cost ?? 0.0; + console.log('Found cost in stored chat', currentDollarAmount); const userSettings = await getUserSettings(localStorage.token); @@ -969,7 +971,7 @@ break; } - console.log('update', update); + // console.log('update', update); if (done || stopResponseFlag || _chatId !== $chatId) { responseMessage.done = true; @@ -1058,12 +1060,26 @@ responseMessage.info = { ...lastUsage, openai: true }; } + if (responseMessage.content) { + // get model name + const modelName = model.name ?? model.id; + console.log(`modelName: ${modelName}`); // i.e. FedRamp High Azure GPT 4 Omni + // count and log the number of words in the response + const tokens = Math.round(responseMessage.content.split(' ').length / 0.75); + console.log(`Response contains ${tokens} tokens`); + const tokenCost = tokens * 0.00001; + currentDollarAmount += tokenCost; + currentDollarAmount = Math.round(currentDollarAmount * 10000) / 10000; + console.log(`Cost: ${currentDollarAmount}`); + } + if ($chatId == _chatId) { if ($settings.saveChatHistory ?? 
true) { chat = await updateChatById(localStorage.token, _chatId, { models: selectedModels, messages: messages, - history: history + history: history, + cost: currentDollarAmount }); await chats.set(await getChatList(localStorage.token)); } @@ -1101,18 +1117,6 @@ scrollToBottom(); } - if (responseMessage.content) { - // get model name - const modelName = model.name ?? model.id; - console.log(`modelName: ${modelName}`); // i.e. FedRamp High Azure GPT 4 Omni - // count and log the number of words in the response - const tokens = Math.round(responseMessage.content.split(' ').length / 0.75); - console.log(`Response contains ${tokens} tokens`); - const tokenCost = tokens * 0.00001; - currentDollarAmount += tokenCost; - currentDollarAmount = Math.round(currentDollarAmount * 10000) / 10000; - } - if (messages.length == 2) { window.history.replaceState(history.state, '', `/c/${_chatId}`); diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index 1e0c8db75..877fae970 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -95,9 +95,9 @@ await socket.set(_socket); - _socket.on('connect_error', (err) => { - console.log('connect_error', err); - }); + // _socket.on('connect_error', (err) => { + // console.log('connect_error', err); + // }); _socket.on('connect', () => { console.log('connected'); From b75ad57483960340a8bed9edf80fb17dfe887bd6 Mon Sep 17 00:00:00 2001 From: jimmoffet Date: Fri, 22 Nov 2024 22:15:02 -0800 Subject: [PATCH 2/2] add census --- backend/apps/rag/main.py | 3 --- backend/apps/webui/routers/auths.py | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index cdfd88118..f47ac25fb 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -909,9 +909,6 @@ async def store_web_search(form_data: SearchForm, user=Depends(get_current_user) ) -from langchain_experimental.text_splitter import SemanticChunker - - async def store_data_in_vector_db( data, 
collection_name, overwrite: bool = False ) -> bool: diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py index 5ef9d8fd0..f22f9ce37 100644 --- a/backend/apps/webui/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -230,7 +230,7 @@ async def signin_oauth(request: Request, provider: str, form_data: SigninFormOau user_email_domains.append(domain) if provider == "github": log.error(f"github provided email is: {this_email}") - if domain in ["gsa.gov"]: + if domain in ["gsa.gov", "census.gov"]: email = this_email user_has_permitted_domain = True break @@ -411,7 +411,7 @@ async def signup(request: Request, form_data: SignupForm): names = names[0].split(".") if names and len(names) > 1: first_name = names[0] - last_name = names[1] + last_name = names[-1] if first_name and last_name: name = first_name.capitalize() + " " + last_name.capitalize()