use llm instead of langchain
iQuxLE committed Aug 20, 2024
1 parent 5af5e0e commit 5a61ac6
Showing 2 changed files with 30 additions and 14 deletions.
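The substance of the change: embedding calls that previously went through langchain's OpenAIEmbeddings wrapper now go through the llm package. A minimal sketch of the new call pattern, assuming the llm package is installed and an OpenAI API key is configured; the short model aliases such as "ada-002" come from the new MODEL_MAP in vocab.py:

    import llm

    texts = ["first document", "second document"]
    # "ada-002" is llm's registered alias for text-embedding-ada-002
    embedding_model = llm.get_embedding_model("ada-002")
    # embed_multi yields one vector of floats per input text
    embeddings = list(embedding_model.embed_multi(texts))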
33 changes: 19 additions & 14 deletions src/curate_gpt/store/duckdb_adapter.py
@@ -10,7 +10,7 @@
 import time
 from dataclasses import dataclass, field
 from typing import Any, Callable, ClassVar, Dict, Iterable, Iterator, List, Mapping, Optional, Union
-from langchain_openai import OpenAIEmbeddings
+import llm
 import duckdb
 import numpy as np
 import openai
@@ -33,6 +33,8 @@
     IDS,
     METADATAS,
     MODEL_DIMENSIONS,
+    MODEL_MAP,
+    DEFAULT_MODEL,
     MODELS,
     OBJECT,
     OPENAI_MODEL_DIMENSIONS,
@@ -192,12 +194,12 @@ def _embedding_function(self, texts: Union[str, List[str], List[List[str]]], mod
         if model.startswith("openai:"):
             self._initialize_openai_client()
             openai_model = model.split(":", 1)[1]
-            if openai_model == "" or openai_model not in MODELS:
+            if openai_model == "" or openai_model not in MODEL_MAP.keys():
                 logger.info(
                     f"The model {openai_model} is not "
-                    f"one of {MODELS}. Defaulting to {MODELS[1]}"
+                    f"one of {[MODEL_MAP.keys()]}. Defaulting to {DEFAULT_MODEL}"
                 )
-                openai_model = MODELS[1]
+                openai_model = DEFAULT_MODEL

         responses = [
             self.openai_client.embeddings.create(input=text, model=openai_model)
@@ -343,10 +345,10 @@ def _process_objects(
         else:
             if model.startswith("openai:"):
                 openai_model = model.split(":", 1)[1]
-                if openai_model == "" or openai_model not in MODELS:
+                if openai_model == "" or openai_model not in MODEL_MAP.keys():
                     logger.info(f"The model {openai_model} is not "
-                                f"one of {MODELS}. Defaulting to {MODELS[0]}")
-                    openai_model = MODELS[0]  # ada 002
+                                f"one of {MODEL_MAP.keys()}. Defaulting to {DEFAULT_MODEL}")
+                    openai_model = DEFAULT_MODEL  # ada 002
             else:
                 logger.error(f"Something went wonky ## model: {model}")
             from transformers import GPT2Tokenizer
@@ -373,19 +375,20 @@ def _process_objects(
                     i += 1
                 else:
                     if current_batch:
-                        logger.info(f"Curent token count to embed: {current_token_count}")
+                        logger.info(f"Tokens: {current_token_count}")
                         texts = [tokenizer.decode(tokens) for tokens in current_batch]
-                        embeddings = OpenAIEmbeddings(model=openai_model, tiktoken_model_name=model).embed_documents(texts,
-                                                                                                                     openai_model)
-                        logger.info(f"len embeddings: {len(embeddings)}")
+                        short_name, _ = MODEL_MAP[openai_model]
+                        embedding_model = llm.get_embedding_model(short_name)
+                        embeddings = list(embedding_model.embed_multi(texts))
+                        logger.info(f"Number of Documents in batch: {len(embeddings)}")
                         batch_embeddings.extend(embeddings)

                     if len(doc_tokens) > 8192:
                         logger.warning(
                             f"Document with ID {ids[i]} exceeds the token limit alone and will be skipped.")
                         # try:
                         #     embeddings = OpenAIEmbeddings(model=model, tiktoken_model_name=model).embed_query(texts,
                         #                                                                                       model)
                         #     embeddings.average model)
                         #     batch_embeddings.extend(embeddings)
+                        # skipping
                         i += 1
@@ -395,9 +398,11 @@
                         current_token_count = 0

                 if current_batch:
+                    logger.info(f"Last batch, token count: {current_token_count}")
                     texts = [tokenizer.decode(tokens) for tokens in current_batch]
-                    embeddings = OpenAIEmbeddings(model=openai_model, tiktoken_model_name=openai_model).embed_documents(texts,
-                                                                                                                        openai_model)
+                    short_name, _ = MODEL_MAP[openai_model]
+                    embedding_model = llm.get_embedding_model(short_name)
+                    embeddings = list(embedding_model.embed_multi(texts))
                     batch_embeddings.extend(embeddings)
         logger.info(f"Trying to insert: {len(ids)} IDS, {len(metadatas)} METADATAS, {len(batch_embeddings)} EMBEDDINGS")
         try:
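Taken together, the duckdb_adapter.py hunks keep the existing token-batching strategy (documents are tokenized, packed into batches under the 8192-token embedding limit, and over-long documents are skipped) while swapping the embedding backend. A condensed sketch of that loop with the llm backend, assuming the same MODEL_MAP and a GPT2 tokenizer as a stand-in token counter:

    import llm
    from transformers import GPT2Tokenizer

    MODEL_MAP = {"text-embedding-ada-002": ("ada-002", 1536)}  # excerpt from vocab.py

    def embed_in_batches(docs, openai_model="text-embedding-ada-002", max_tokens=8192):
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        short_name, _ = MODEL_MAP[openai_model]
        embedding_model = llm.get_embedding_model(short_name)
        embeddings, batch, count = [], [], 0
        for doc in docs:
            tokens = tokenizer.encode(doc)
            if len(tokens) > max_tokens:
                continue  # over-long documents are skipped, as in the diff
            if count + len(tokens) > max_tokens and batch:
                # flush the current batch before it exceeds the limit
                texts = [tokenizer.decode(t) for t in batch]
                embeddings.extend(embedding_model.embed_multi(texts))
                batch, count = [], 0
            batch.append(tokens)
            count += len(tokens)
        if batch:  # final partial batch
            texts = [tokenizer.decode(t) for t in batch]
            embeddings.extend(embedding_model.embed_multi(texts))
        return embeddings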
11 changes: 11 additions & 0 deletions src/curate_gpt/store/vocab.py
@@ -26,3 +26,14 @@
     "text-embedding-3-large": 3072,
 }
 MODELS = ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"]
+
+MODEL_MAP = {
+    "text-embedding-ada-002": ("ada-002", 1536),
+    "text-embedding-3-small": ("3-small", 1536),
+    "text-embedding-3-large": ("3-large", 3072),
+    "text-embedding-3-small-512": ("3-small-512", 512),
+    "text-embedding-3-large-256": ("3-large-256", 256),
+    "text-embedding-3-large-1024": ("3-large-1024", 1024)
+}
+
+DEFAULT_MODEL = "text-embedding-ada-002"
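
Each MODEL_MAP entry is keyed by the full OpenAI model name and pairs the short alias that llm.get_embedding_model() expects with the embedding dimensionality. A quick illustration of the lookup:

    short_name, dimensions = MODEL_MAP["text-embedding-3-large"]
    # short_name == "3-large", dimensions == 3072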
