Merge pull request #117 from GSA-TTS/batch-embedding

Census to domains

jimmoffet authored Nov 23, 2024
2 parents de1265a + b75ad57 commit cd90901
Showing 12 changed files with 96 additions and 28 deletions.
6 changes: 3 additions & 3 deletions backend/apps/rag/clients/vector_client.py
@@ -17,7 +17,7 @@ class VectorItem(BaseModel):


class GetResult(BaseModel):
ids: Optional[List[List[str]]]
ids: Optional[List[List[str | int]]]
documents: Optional[List[List[str]]]
metadatas: Optional[List[List[Any]]]

@@ -109,11 +109,11 @@ def get(self, collection_name: str) -> Optional[GetResult]:
return GetResult(
ids=[[row[0] for row in result]],
documents=[[row[1] for row in result]],
metadatas=[row[2] for row in result],
metadatas=[[row[2] for row in result]],
)
return None
except Exception as e:
log.error("Get Error:", e)
log.error(f"Get Error: {e}")
return None

def insert(self, collection_name: str, items: list[VectorItem]):
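
The `metadatas` fix keeps all three fields of `GetResult` in the same nested shape. A minimal sketch of the structure the corrected `get()` now builds, assuming `result` is a list of `(id, document, metadata)` rows as in the method above:

```python
# Sketch of the nested shape GetResult expects. The sample rows are
# illustrative; the model mirrors the one shown in the diff.
from typing import Any, List, Optional
from pydantic import BaseModel


class GetResult(BaseModel):
    ids: Optional[List[List[str | int]]]
    documents: Optional[List[List[str]]]
    metadatas: Optional[List[List[Any]]]


result = [(1, "chunk one", {"source": "a.pdf"}), (2, "chunk two", {"source": "b.pdf"})]

get_result = GetResult(
    ids=[[row[0] for row in result]],
    documents=[[row[1] for row in result]],
    # The extra brackets matter: List[List[Any]] expects one inner list per
    # query, so metadatas[0] lines up with ids[0] and documents[0].
    metadatas=[[row[2] for row in result]],
)

assert len(get_result.metadatas[0]) == len(get_result.ids[0])
```
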
12 changes: 9 additions & 3 deletions backend/apps/rag/main.py
@@ -1243,13 +1243,19 @@ async def scan_database_docs(user=Depends(get_admin_user)):
log.info(f"Doc: {doc}")
log.info(f"Content: {content}")

docs_in_collection = VECTOR_CLIENT.get(collection_name)
ids = docs_in_collection.ids[0]
chunks = docs_in_collection.documents[0]
metadatas = docs_in_collection.metadatas[0]

log.info(f"Found {len(chunks)} chunks in collection {collection_name}")
#

# Ensure the content is properly loaded for processing
if content and isinstance(content, str):
data = content # You can convert to the format needed for further processing

# Here you could add your logic to split documents if needed
await sleep(60)
await store_data_in_vector_db(data, collection_name)
# await store_data_in_vector_db(data, collection_name)

except Exception as e:
log.exception(e)
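
The new `scan_database_docs` block reads what is already stored for a collection and throttles instead of re-embedding right away. A rough sketch of that loop; `rescan_collections` is a hypothetical helper, with `client` standing in for `VECTOR_CLIENT` and the 60-second pause taken from the diff:

```python
# Hedged sketch of the rescan pattern: inspect existing chunks per collection
# and pause between collections rather than re-embedding immediately.
import logging
from asyncio import sleep

log = logging.getLogger(__name__)


async def rescan_collections(client, collections: list[str]) -> None:
    for collection_name in collections:
        docs_in_collection = client.get(collection_name)
        if docs_in_collection is None:
            log.warning(f"No documents found for collection {collection_name}")
            continue

        ids = docs_in_collection.ids[0]
        chunks = docs_in_collection.documents[0]

        log.info(f"Found {len(chunks)} chunks ({len(ids)} ids) in collection {collection_name}")

        # Throttle between collections to stay under embedding rate limits;
        # re-storing (store_data_in_vector_db) stays commented out, as in the diff.
        await sleep(60)
```
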
9 changes: 9 additions & 0 deletions backend/apps/rag/utils.py
@@ -225,8 +225,17 @@ async def generate_openai_embeddings_async(
key: str,
url: str = "https://api.openai.com/v1",
):
# The default limit is 100
# connector = aiohttp.TCPConnector(limit=10)
# async with aiohttp.ClientSession(connector=connector) as session:

# from aiolimiter import AsyncLimiter
# RATE_LIMIT_IN_MINUTE = 5
# rate_limiter = AsyncLimiter(RATE_LIMIT_IN_MINUTE, 60.0)

async with aiohttp.ClientSession() as session:
if isinstance(text, list):
# async with rate_limiter:
embeddings = await generate_openai_batch_embeddings_async(
model, text, key, url, session
)
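
The commented-out lines point at two throttles that could wrap the batch embedding call: a connection cap via `aiohttp.TCPConnector` and a request-rate cap via `aiolimiter.AsyncLimiter`. A minimal sketch of both together; `embed_batches` is a hypothetical helper, and the 5-requests-per-minute figure and payload shape are illustrative rather than the repo's settled values:

```python
# Sketch: limit concurrent connections and requests per minute while batching
# embedding calls against an OpenAI-compatible endpoint.
import aiohttp
from aiolimiter import AsyncLimiter

RATE_LIMIT_IN_MINUTE = 5
rate_limiter = AsyncLimiter(RATE_LIMIT_IN_MINUTE, 60.0)


async def embed_batches(
    model: str,
    batches: list[list[str]],
    key: str,
    url: str = "https://api.openai.com/v1",
) -> list[list[float]]:
    connector = aiohttp.TCPConnector(limit=10)  # default limit is 100
    embeddings: list[list[float]] = []
    async with aiohttp.ClientSession(connector=connector) as session:
        for batch in batches:
            async with rate_limiter:  # at most 5 requests per rolling minute
                async with session.post(
                    f"{url}/embeddings",
                    headers={"Authorization": f"Bearer {key}"},
                    json={"input": batch, "model": model},
                ) as r:
                    data = await r.json()
                    embeddings.extend(item["embedding"] for item in data["data"])
    return embeddings
```
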
2 changes: 1 addition & 1 deletion backend/apps/webui/internal/migrations/013_add_user_info.py
@@ -1,4 +1,4 @@
"""Peewee migrations -- 002_add_local_sharing.py.
"""Peewee migrations -- 013_add_user_info.py.
Some examples (model - class or model name)::
46 changes: 46 additions & 0 deletions backend/apps/webui/internal/migrations/014_add_chat_cost.py
@@ -0,0 +1,46 @@
"""Peewee migrations -- 013_add_user_info.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""

migrator.add_fields("chat", cost=pw.FloatField(default=0.0))


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""

migrator.remove_fields("chat", "cost")
2 changes: 1 addition & 1 deletion backend/apps/webui/internal/migrations/README.md
@@ -14,7 +14,7 @@ You will need to create a migration file to ensure that existing databases are u
2. Make your changes to the models.
3. From the `backend` directory, run the following command:
```bash
pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME}
pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///${SQLITE_DB} --directory apps/webui/internal/migrations ${MIGRATION_NAME}
```
- `$SQLITE_DB` should be the path to the database file.
- `$MIGRATION_NAME` should be a descriptive name for the migration.
2 changes: 2 additions & 0 deletions backend/apps/webui/models/chats.py
@@ -26,6 +26,8 @@ class Chat(Model):
share_id = CharField(null=True, unique=True)
archived = BooleanField(default=False)

cost = FloatField(default=0.0)

class Meta:
database = DB

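
With the matching migration in 014_add_chat_cost.py, the new `cost` column can be read and updated like any other Peewee field. A small sketch under that assumption; `accumulate_chat_cost` is a hypothetical helper, not code from the repo:

```python
# Sketch of accumulating per-chat spend on the new cost column, assuming the
# Chat Peewee model shown above.
def accumulate_chat_cost(chat_id: str, delta: float) -> float:
    chat = Chat.get(Chat.id == chat_id)
    new_cost = round((chat.cost or 0.0) + delta, 4)
    Chat.update(cost=new_cost).where(Chat.id == chat_id).execute()
    return new_cost
```
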
4 changes: 2 additions & 2 deletions backend/apps/webui/routers/auths.py
@@ -230,7 +230,7 @@ async def signin_oauth(request: Request, provider: str, form_data: SigninFormOau
user_email_domains.append(domain)
if provider == "github":
log.error(f"github provided email is: {this_email}")
if domain in ["gsa.gov"]:
if domain in ["gsa.gov", "census.gov"]:
email = this_email
user_has_permitted_domain = True
break
@@ -411,7 +411,7 @@ async def signup(request: Request, form_data: SignupForm):
names = names[0].split(".")
if names and len(names) > 1:
first_name = names[0]
last_name = names[1]
last_name = names[-1]
if first_name and last_name:
name = first_name.capitalize() + " " + last_name.capitalize()

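
The switch to `names[-1]` matters when the local part of the address has more than two dot-separated pieces. A quick illustration of the signup name parsing above, using a made-up address:

```python
# Why names[-1]: with a three-part local part, names[1] would pick the middle
# segment instead of the surname. The address is illustrative only.
email = "jane.q.public@example.gov"
names = email.split("@")[0].split(".")   # ["jane", "q", "public"]

first_name = names[0]      # "jane"
last_name = names[-1]      # "public"; names[1] would have picked "q"

name = first_name.capitalize() + " " + last_name.capitalize()  # "Jane Public"
```
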
2 changes: 1 addition & 1 deletion src/lib/apis/rag/index.ts
@@ -336,7 +336,7 @@ export const queryCollection = async (
export const scanDocs = async (token: string) => {
let error = null;

const res = await fetch(`${RAG_API_BASE_URL}/scan`, {
const res = await fetch(`${RAG_API_BASE_URL}/scan_database_docs`, {
method: 'GET',
headers: {
Accept: 'application/json',
1 change: 1 addition & 0 deletions src/lib/apis/streaming/index.ts
@@ -85,6 +85,7 @@ async function* openAIStreamToIterator(
};
} catch (e) {
console.error('Error extracting delta from SSE event:', e);
console.error('Error extracting delta from SSE event with data:', data);
}
}
}
32 changes: 18 additions & 14 deletions src/lib/components/chat/Chat.svelte
@@ -218,6 +218,8 @@
? chatContent.history
: convertMessagesToHistory(chatContent.messages);
title = chatContent.title;
currentDollarAmount = chatContent.cost ?? 0.0;
console.log('Found cost in stored chat', currentDollarAmount);
const userSettings = await getUserSettings(localStorage.token);
@@ -969,7 +971,7 @@
break;
}
console.log('update', update);
// console.log('update', update);
if (done || stopResponseFlag || _chatId !== $chatId) {
responseMessage.done = true;
@@ -1058,12 +1060,26 @@
responseMessage.info = { ...lastUsage, openai: true };
}
if (responseMessage.content) {
// get model name
const modelName = model.name ?? model.id;
console.log(`modelName: ${modelName}`); // i.e. FedRamp High Azure GPT 4 Omni
// count and log the number of words in the response
const tokens = Math.round(responseMessage.content.split(' ').length / 0.75);
console.log(`Response contains ${tokens} tokens`);
const tokenCost = tokens * 0.00001;
currentDollarAmount += tokenCost;
currentDollarAmount = Math.round(currentDollarAmount * 10000) / 10000;
console.log(`Cost: ${currentDollarAmount}`);
}
if ($chatId == _chatId) {
if ($settings.saveChatHistory ?? true) {
chat = await updateChatById(localStorage.token, _chatId, {
models: selectedModels,
messages: messages,
history: history
history: history,
cost: currentDollarAmount
});
await chats.set(await getChatList(localStorage.token));
}
@@ -1101,18 +1117,6 @@
scrollToBottom();
}
if (responseMessage.content) {
// get model name
const modelName = model.name ?? model.id;
console.log(`modelName: ${modelName}`); // i.e. FedRamp High Azure GPT 4 Omni
// count and log the number of words in the response
const tokens = Math.round(responseMessage.content.split(' ').length / 0.75);
console.log(`Response contains ${tokens} tokens`);
const tokenCost = tokens * 0.00001;
currentDollarAmount += tokenCost;
currentDollarAmount = Math.round(currentDollarAmount * 10000) / 10000;
}
if (messages.length == 2) {
window.history.replaceState(history.state, '', `/c/${_chatId}`);
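
The cost block moved above estimates tokens from the word count (`words / 0.75`), prices them at $0.00001 per token, and rounds the running total to four decimals before saving it with the chat. A Python mirror of that arithmetic with an illustrative 300-word response:

```python
# Worked example of the token/cost estimate the Svelte block computes. The
# word count and per-token price are illustrative of the formula above.
content = " ".join(["word"] * 300)               # pretend the response has 300 words

tokens = round(len(content.split(" ")) / 0.75)   # 300 / 0.75 = 400 tokens
token_cost = tokens * 0.00001                    # 400 * $0.00001 = $0.004

current_dollar_amount = 0.0
current_dollar_amount += token_cost
current_dollar_amount = round(current_dollar_amount * 10000) / 10000
print(current_dollar_amount)                     # 0.004
```
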
6 changes: 3 additions & 3 deletions src/routes/+layout.svelte
@@ -95,9 +95,9 @@
await socket.set(_socket);
_socket.on('connect_error', (err) => {
console.log('connect_error', err);
});
// _socket.on('connect_error', (err) => {
// console.log('connect_error', err);
// });
_socket.on('connect', () => {
console.log('connected');
