Commit 10d1ac4: ollama test

LarFii committed Oct 16, 2024
1 parent 3fabaf0
Showing 6 changed files with 94 additions and 5 deletions.
40 changes: 40 additions & 0 deletions examples/lightrag_ollama_demo.py
@@ -0,0 +1,40 @@
import os

from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc

WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=ollama_model_complete,
    llm_model_name='your_model_name',
    embedding_func=EmbeddingFunc(
        embedding_dim=768,
        max_token_size=8192,
        func=lambda texts: ollama_embedding(
            texts,
            embed_model="nomic-embed-text"
        )
    ),
)


with open("./book.txt") as f:
    rag.insert(f.read())

# Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))

# Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))

# Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))

# Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
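
A quick sanity check for the embedding side of the demo above, as a minimal sketch: it assumes the Ollama server is running locally and that the nomic-embed-text model has already been pulled (for example with ollama pull nomic-embed-text), and simply confirms that one text yields a 768-dimensional vector, matching the embedding_dim configured in the demo.

import asyncio

from lightrag.llm import ollama_embedding

async def check_embedding():
    # One short text in, one embedding out; nomic-embed-text produces 768-dimensional vectors.
    vectors = await ollama_embedding(["hello world"], embed_model="nomic-embed-text")
    print(len(vectors), len(vectors[0]))  # expected: 1 768

asyncio.run(check_embedding())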
2 changes: 1 addition & 1 deletion lightrag/__init__.py
@@ -1,5 +1,5 @@
from .lightrag import LightRAG, QueryParam

__version__ = "0.0.5"
__version__ = "0.0.6"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"
2 changes: 1 addition & 1 deletion lightrag/lightrag.py
@@ -6,7 +6,7 @@
from typing import Type, cast, Any
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM

-from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model_complete,hf_embedding
+from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
from .operate import (
    chunking_by_token_size,
    extract_entities,
50 changes: 48 additions & 2 deletions lightrag/llm.py
@@ -1,5 +1,6 @@
import os
import numpy as np
import ollama
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
from tenacity import (
    retry,
@@ -92,6 +93,34 @@ async def hf_model_if_cache(
    )
    return response_text

async def ollama_model_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    kwargs.pop("max_tokens", None)
    kwargs.pop("response_format", None)

    ollama_client = ollama.AsyncClient()
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    response = await ollama_client.chat(model=model, messages=messages, **kwargs)

    result = response["message"]["content"]

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": result, "model": model}})

    return result

async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
@@ -116,8 +145,6 @@ async def gpt_4o_mini_complete(
        **kwargs,
    )



async def hf_model_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -130,6 +157,18 @@ async def hf_model_complete(
        **kwargs,
    )

async def ollama_model_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    model_name = kwargs['hashing_kv'].global_config['llm_model_name']
    return await ollama_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )

@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
@@ -154,6 +193,13 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
    embeddings = outputs.last_hidden_state.mean(dim=1)
    return embeddings.detach().numpy()

async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
    embed_text = []
    for text in texts:
        data = ollama.embeddings(model=embed_model, prompt=text)
        embed_text.append(data["embedding"])

    return embed_text

if __name__ == "__main__":
    import asyncio
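
For reference, the new Ollama completion path can also be exercised outside LightRAG's caching and config plumbing: when no hashing_kv is supplied, ollama_model_if_cache skips the cache and calls ollama.AsyncClient().chat() directly. A minimal sketch, assuming a chat model has been pulled locally (the name "mistral" below is only a placeholder for whatever model the local Ollama instance serves):

import asyncio

from lightrag.llm import ollama_model_if_cache

async def ask():
    # No hashing_kv is passed, so the call goes straight to the Ollama chat endpoint.
    answer = await ollama_model_if_cache(
        "mistral",  # placeholder model name; substitute any locally pulled chat model
        "What does retrieval-augmented generation mean?",
        system_prompt="Answer in one sentence.",
    )
    print(answer)

asyncio.run(ask())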
3 changes: 3 additions & 0 deletions requirements.txt
@@ -6,3 +6,6 @@ nano-vectordb
hnswlib
xxhash
tenacity
transformers
torch
ollama
2 changes: 1 addition & 1 deletion setup.py
@@ -1,6 +1,6 @@
import setuptools

with open("README.md", "r") as fh:
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()


