Fix bug where args_hash was only computed when using the regular cache, so it was never computed when the embedding cache was used #408

Merged
merged 3 commits into from Dec 6, 2024
6 changes: 5 additions & 1 deletion README.md
@@ -596,7 +596,11 @@ if __name__ == "__main__":
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for embedding cache. Includes `enabled` (bool) to toggle cache and `similarity_threshold` (float) for cache retrieval | `{"enabled": False, "similarity_threshold": 0.95}` |
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains two parameters:
- `enabled`: Boolean value to enable/disable caching functionality. When enabled, questions and answers will be cached.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.

Default: `{"enabled": False, "similarity_threshold": 0.95}` | `{"enabled": False, "similarity_threshold": 0.95}` |

## API Server Implementation

66 changes: 49 additions & 17 deletions lightrag/llm.py
@@ -1,33 +1,32 @@
import os
import base64
import copy
from functools import lru_cache
import json
import os
import struct
from functools import lru_cache
from typing import List, Dict, Callable, Any

import aioboto3
import aiohttp
import numpy as np
import ollama

import torch
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
Timeout,
AsyncAzureOpenAI,
)

import base64
import struct

from pydantic import BaseModel, Field
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from pydantic import BaseModel, Field
from typing import List, Dict, Callable, Any

from .base import BaseKVStorage
from .utils import (
compute_args_hash,
@@ -66,7 +65,11 @@ async def openai_complete_if_cache(
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})

if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -86,7 +89,6 @@ async def openai_complete_if_cache(
return best_cached_response
else:
# Use regular cache
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
@@ -159,7 +161,11 @@ async def azure_openai_complete_if_cache(
messages.extend(history_messages)
if prompt is not None:
messages.append({"role": "user", "content": prompt})

if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -178,7 +184,7 @@ async def azure_openai_complete_if_cache(
if best_cached_response is not None:
return best_cached_response
else:
args_hash = compute_args_hash(model, messages)
# Use regular cache
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
@@ -271,6 +277,9 @@ async def bedrock_complete_if_cache(

hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -290,7 +299,6 @@ async def bedrock_complete_if_cache(
return best_cached_response
else:
# Use regular cache
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
@@ -343,6 +351,11 @@ def initialize_hf_model(model_name):
return hf_model, hf_tokenizer


@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def hf_model_if_cache(
model,
prompt,
@@ -360,6 +373,9 @@ async def hf_model_if_cache(
messages.append({"role": "user", "content": prompt})

if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -379,7 +395,6 @@ async def hf_model_if_cache(
return best_cached_response
else:
# Use regular cache
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
@@ -448,6 +463,11 @@ async def hf_model_if_cache(
return response_text


@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def ollama_model_if_cache(
model,
prompt,
@@ -468,7 +488,11 @@ async def ollama_model_if_cache(
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})

if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -488,7 +512,6 @@ async def ollama_model_if_cache(
return best_cached_response
else:
# Use regular cache
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
@@ -542,6 +565,11 @@ def initialize_lmdeploy_pipeline(
return lmdeploy_pipe


@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def lmdeploy_model_if_cache(
model,
prompt,
@@ -620,7 +648,11 @@ async def lmdeploy_model_if_cache(
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})

if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)

# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -640,7 +672,6 @@ async def lmdeploy_model_if_cache(
return best_cached_response
else:
# Use regular cache
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
@@ -831,7 +862,8 @@ async def openai_embedding(
)
async def nvidia_openai_embedding(
texts: list[str],
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1", # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
# refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
base_url: str = "https://integrate.api.nvidia.com/v1",
api_key: str = None,
input_type: str = "passage", # query for retrieval, passage for embedding
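
To summarize the change: each `*_if_cache` function now computes `args_hash` as soon as a `hashing_kv` store is present, before branching into the embedding cache or the regular cache, so the hash is always defined when the response is later written back. The sketch below shows that corrected shape in simplified form; the embedding-similarity lookup and the final store-back call are collapsed in the diff above, so their names and signatures here are assumptions, not the repository's exact code.

```python
from lightrag.utils import compute_args_hash  # imported the same way in lightrag/llm.py


async def _find_similar_cached_response(messages, hashing_kv, similarity_threshold):
    # Placeholder for the embedding-based lookup performed by the real code
    # (that part is collapsed in the diff above); this stub always misses.
    return None


async def complete_if_cache_sketch(model, prompt, call_llm, hashing_kv=None, **kwargs):
    """Simplified shape of the *_if_cache functions after this PR (illustrative only)."""
    messages = [{"role": "user", "content": prompt}]

    args_hash = None
    if hashing_kv is not None:
        # The fix: compute the hash whenever a cache store exists, for BOTH
        # cache paths, instead of only inside the regular-cache branch.
        args_hash = compute_args_hash(model, messages)

        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config",
            {"enabled": False, "similarity_threshold": 0.95},
        )
        if embedding_cache_config["enabled"]:
            # Embedding cache: return a semantically similar cached answer, if any.
            best_cached_response = await _find_similar_cached_response(
                messages, hashing_kv, embedding_cache_config["similarity_threshold"]
            )
            if best_cached_response is not None:
                return best_cached_response
        else:
            # Regular cache: exact-match lookup keyed by args_hash.
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]

    response = await call_llm(model, messages, **kwargs)

    if hashing_kv is not None:
        # Before this PR, args_hash was undefined here when the embedding cache
        # was enabled; now the fresh response can always be stored under it.
        # (Assumed store-back shape, simplified from the collapsed code.)
        await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
    return response
```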