Skip to content

Commit

Permalink
✨ feat: add LMDB runtime for caching
Browse files Browse the repository at this point in the history
🚀 Add an LMDB runtime implementation for caching data, with asyncio support. Update the set_data method to accept data of type Union[dict, str, bytes]. Configure "lightning" (LMDB) as the storage engine in MontyDatabaseClient. Update the init_client method to connect to LMDBClientAsyncWrapper and log the folder path, verify that LMDBClientAsyncWrapper loaded correctly, and return the client for further use. Switch global_cache_runtime to use LMDBRuntime as the default cache runtime. This commit enhances caching capabilities by adding LMDB as a storage option.
  • Loading branch information
sudoskys committed Apr 18, 2024
1 parent 0851b74 commit 960673a
Show file tree
Hide file tree
Showing 12 changed files with 275 additions and 47 deletions.
8 changes: 8 additions & 0 deletions app/receiver/telegram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ async def file_forward(self, receiver: Location, file_list: List[File]):
sticker=file_downloaded,
)
elif file_obj.file_name.endswith(".ogg"):
await self.bot.send_chat_action(
chat_id=receiver.chat_id, action="record_voice"
)
try:
await self.bot.send_voice(
chat_id=receiver.chat_id,
Expand All @@ -78,6 +81,9 @@ async def file_forward(self, receiver: Location, file_list: List[File]):
else:
raise e
else:
await self.bot.send_chat_action(
chat_id=receiver.chat_id, action="upload_document"
)
await self.bot.send_document(
chat_id=receiver.chat_id,
document=file_downloaded,
Expand Down Expand Up @@ -118,6 +124,8 @@ async def reply(
:param messages: OPENAI Format Message
:param reply_to_message: 是否回复消息
"""
if receiver.chat_id is not None:
await self.bot.send_chat_action(chat_id=receiver.chat_id, action="typing")
event_message = [
EventMessage.from_openai_message(message=item, locate=receiver)
for item in messages
Expand Down
48 changes: 47 additions & 1 deletion llmkira/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,56 @@
from loguru import logger

from .elara_runtime import ElaraClientAsyncWrapper
from .lmdb_runtime import LMDBClientAsyncWrapper
from .redis_runtime import RedisClientWrapper
from .runtime_schema import singleton, BaseRuntime


@singleton
class LMDBRuntime(BaseRuntime):
    """Cache runtime backed by a local LMDB directory (no network DSN)."""

    client: Optional["LMDBClientAsyncWrapper"] = None
    init_already = False
    dsn = None

    @staticmethod
    def check_client_dsn(dsn):
        """
        Reject URL-style DSNs — LMDB only works with a local filesystem path.

        :raise ValueError: Please Use Local Path
        """
        if "://" in dsn:
            raise ValueError("Please Use Local Path")

    def check_client(self):
        """Create the LMDB wrapper, defaulting the store to ./.cache/lmdb_dir."""
        if self.dsn is None:
            cache_root = pathlib.Path().cwd().joinpath(".cache")
            cache_root.mkdir(exist_ok=True)
            self.dsn = cache_root / "lmdb_dir"
        self.client = LMDBClientAsyncWrapper(str(self.dsn))
        logger.debug(f"🍩 LMDBClientAsyncWrapper Loaded --folder {self.dsn}")
        return True

    def init_client(self, verbose=False):
        """Build the client, mark the runtime initialised, and return the client."""
        if verbose:
            logger.info("Try To Connect To LMDBClientAsyncWrapper")
        self.check_client()
        self.init_already = True
        assert isinstance(
            self.client, LMDBClientAsyncWrapper
        ), f"LMDBClientAsyncWrapper type error {type(self.client)}"
        return self.client

    def get_client(self) -> "LMDBClientAsyncWrapper":
        """Return the singleton wrapper, lazily initialising it on first use."""
        if self.init_already:
            assert isinstance(
                self.client, LMDBClientAsyncWrapper
            ), f"Inited LMDBClientAsyncWrapper error {type(self.client)}"
        else:
            self.init_client()
            assert isinstance(
                self.client, LMDBClientAsyncWrapper
            ), f"LMDBClientAsyncWrapper error {type(self.client)}"
        return self.client


@singleton
class ElaraRuntime(BaseRuntime):
client: Optional["ElaraClientAsyncWrapper"] = None
Expand Down Expand Up @@ -115,4 +161,4 @@ def get_client(self) -> "RedisClientWrapper":
if RedisRuntime().check_client():
global_cache_runtime: BaseRuntime = RedisRuntime()
else:
global_cache_runtime: BaseRuntime = ElaraRuntime()
global_cache_runtime: BaseRuntime = LMDBRuntime()
3 changes: 2 additions & 1 deletion llmkira/cache/elara_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# @Time : 2023/7/10 下午9:44
import asyncio
import json
from typing import Union

import elara
from loguru import logger
Expand Down Expand Up @@ -38,7 +39,7 @@ async def read_data(self, key):
logger.trace(ex)
return data

async def set_data(self, key, value, timeout: int = None):
async def set_data(self, key, value: Union[dict, str, bytes], timeout: int = None):
"""
Set data to elara
:param key:
Expand Down
67 changes: 67 additions & 0 deletions llmkira/cache/lmdb_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import asyncio
import json
from typing import Union, Optional

import lmdb
from loguru import logger

from .runtime_schema import AbstractDataClass, PREFIX


class LMDBClientAsyncWrapper(AbstractDataClass):
    """
    LMDB-backed key/value store guarded by an asyncio lock.

    Values are serialized to bytes on write: dict/list -> JSON, str -> UTF-8,
    bytes stored verbatim. Reads reverse that mapping on a best-effort basis.
    """

    def __init__(self, backend, prefix=PREFIX):
        """
        :param backend: local directory path handed to lmdb.open()
        :param prefix: namespace string prepended to every key
        """
        self.prefix = prefix
        self.env = lmdb.open(backend)
        self.lock = asyncio.Lock()

    async def ping(self):
        # LMDB is an embedded store — there is no connection to probe.
        return True

    def update_backend(self, backend):
        """
        Point the wrapper at a new storage directory.

        Opens the new environment first, then closes the old one so its
        file handles and memory map are released (the previous code leaked
        the old environment).
        """
        new_env = lmdb.open(backend)
        old_env, self.env = self.env, new_env
        old_env.close()
        return True

    async def read_data(self, key) -> Optional[Union[dict, str, bytes]]:
        """
        Read data from LMDB.

        :return: the JSON-decoded object when the stored bytes are JSON,
            a str for UTF-8 text that is not JSON, raw bytes for data that
            is not valid UTF-8, or None when the key is absent.
        """
        data = None
        async with self.lock:
            with self.env.begin() as txn:
                raw_data = txn.get((self.prefix + str(key)).encode())
                if raw_data is not None:
                    try:
                        data = json.loads(raw_data.decode())
                    except json.JSONDecodeError:
                        # Decodable but not JSON: treat as plain text, unless it
                        # looks like corrupt/truncated JSON (starts with '{"'),
                        # in which case we deliberately return None.
                        if not raw_data.startswith(b'{"'):
                            data = raw_data.decode()
                    except UnicodeDecodeError:
                        # Not valid UTF-8: hand back the raw bytes unchanged.
                        data = raw_data
                    except Exception as ex:
                        logger.trace(ex)
        return data

    async def set_data(self, key, value: Union[dict, str, bytes], timeout: int = None):
        """
        Set data to LMDB.
        :param key:
        :param value: a dict, str or bytes
        :param timeout: seconds — NOTE: LMDB has no TTL support, so this
            parameter is accepted for interface compatibility but ignored
        :return: True on success
        """
        async with self.lock:
            with self.env.begin(write=True) as txn:
                if isinstance(value, (dict, list)):
                    value = json.dumps(value).encode()
                elif isinstance(value, str):
                    # Encode text to bytes; bytes values are stored verbatim.
                    value = value.encode()
                txn.put((self.prefix + str(key)).encode(), value)
        return True
2 changes: 1 addition & 1 deletion llmkira/cache/redis_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def update_backend(self, backend):
self._redis = Redis(connection_pool=self.connection_pool)
return True

async def set_data(self, key, value, timeout=None):
async def set_data(self, key, value: Union[dict, str, bytes], timeout=None):
if isinstance(value, (dict, list)):
value = json.dumps(value)
return await self._redis.set(
Expand Down
6 changes: 4 additions & 2 deletions llmkira/cache/runtime_schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# @Time : 2023/7/10 下午9:43
from abc import abstractmethod, ABC
from typing import Any
from typing import Any, Union

PREFIX = "oai_bot:"

Expand Down Expand Up @@ -46,7 +46,9 @@ def update_backend(self, backend):
pass

    @abstractmethod
    async def set_data(
        self, key: str, value: Union[dict, str, bytes], timeout: int = None
    ) -> Any:
        """
        Persist *value* under *key* in the backing store.

        :param key: storage key (implementations prepend their own prefix)
        :param value: a dict, str or bytes payload
        :param timeout: TTL in seconds — presumably honoured only by backends
            that support expiry (e.g. Redis); verify per implementation
        """
        pass

@abstractmethod
Expand Down
5 changes: 4 additions & 1 deletion llmkira/doc_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pymongo
from dotenv import load_dotenv
from loguru import logger
from montydb import errors as monty_errors, MontyClient
from montydb import errors as monty_errors, MontyClient, set_storage
from pydantic import model_validator, Field
from pydantic_settings import BaseSettings
from pymongo import MongoClient
Expand Down Expand Up @@ -39,6 +39,9 @@ def update_one(
class MontyDatabaseClient(DatabaseClient):
def __init__(self, db_name=None, collection_name=None):
local_repo = ".montydb"
set_storage(
local_repo, storage="lightning"
) # required, to set lightning as engine
self.client = MontyClient(local_repo)
self.update_db_collection(db_name, collection_name)

Expand Down
127 changes: 92 additions & 35 deletions llmkira/extra/voice/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import base64
import json
from io import BytesIO

import aiohttp
import edge_tts
from loguru import logger


async def request_cn_speech(text):
async def request_dui_speech(text):
"""
Call the DuerOS endpoint to generate synthesized voice.
:param text: The text to synthesize
Expand All @@ -18,15 +21,18 @@ async def request_cn_speech(text):
"volume": "50",
"audioType": "wav",
}

async with aiohttp.ClientSession() as session:
async with session.get(base_url, params=params) as response:
if (
response.status != 200
or response.headers.get("Content-Type") != "audio/wav"
):
return None
return await response.read()
try:
async with aiohttp.ClientSession() as session:
async with session.get(base_url, params=params) as response:
if (
response.status != 200
or response.headers.get("Content-Type") != "audio/wav"
):
return None
return await response.read()
except Exception as e:
logger.warning(f"DuerOS TTS Error: {e}")
return None


def get_audio_bytes_from_data_url(data_url):
Expand Down Expand Up @@ -58,7 +64,7 @@ async def request_reecho_speech(
:return: The synthesized voice audio data, or None if the request failed
"""
if not reecho_api_key:
return await request_cn_speech(text)
return None
url = "https://v1.reecho.ai/api/tts/simple-generate"
headers = {
"User-Agent": "Apifox/1.0.0 (https://apifox.com)",
Expand All @@ -73,21 +79,39 @@ async def request_reecho_speech(
"stability_boost": 50,
"probability_optimization": 99,
}
audio_bytes = None
async with aiohttp.ClientSession() as session:
async with session.post(
url, headers=headers, data=json.dumps(data)
) as response:
if response.status == 200:
response_json = await response.json()
audio_url = response_json["data"].get("audio", None)
audio_bytes = get_audio_bytes_from_data_url(audio_url)
if not audio_bytes:
return await request_cn_speech(text)
return audio_bytes


async def request_en_speech(text):
try:
audio_bytes = None
async with aiohttp.ClientSession() as session:
async with session.post(
url, headers=headers, data=json.dumps(data)
) as response:
if response.status == 200:
response_json = await response.json()
audio_url = response_json["data"].get("audio", None)
audio_bytes = get_audio_bytes_from_data_url(audio_url)
if not audio_bytes:
return None
return audio_bytes
except Exception as e:
logger.warning(f"Reecho TTS Error: {e}")
return None


async def request_edge_speech(text: str, voice: str = "en-GB-SoniaNeural"):
    """
    Synthesize *text* with Microsoft Edge TTS.
    :param text: The text to synthesize
    :param voice: Edge TTS voice name
    :return: The audio bytes, or None if synthesis failed
    """
    try:
        communicate = edge_tts.Communicate(text, voice)
        audio_parts = []
        async for part in communicate.stream():
            if part["type"] == "audio":
                audio_parts.append(part["data"])
        return b"".join(audio_parts)
    except Exception as e:
        logger.warning(f"Edge TTS Error: {e}")
        return None


async def request_novelai_speech(text):
"""
Call the NovelAI endpoint to generate synthesized voice.
:param text: The text to synthesize
Expand All @@ -102,12 +126,45 @@ async def request_en_speech(text):
"opus": "false",
"version": "v2",
}
async with aiohttp.ClientSession() as session:
async with session.get(base_url, params=params, headers=headers) as response:
if response.status != 200:
return None
audio_content_type = response.headers.get("Content-Type")
valid_content_types = ["audio/mpeg", "audio/ogg", "audio/opus"]
if audio_content_type not in valid_content_types:
return None
return await response.read()
try:
async with aiohttp.ClientSession() as session:
async with session.get(
base_url, params=params, headers=headers
) as response:
if response.status != 200:
return None
audio_content_type = response.headers.get("Content-Type")
valid_content_types = ["audio/mpeg", "audio/ogg", "audio/opus"]
if audio_content_type not in valid_content_types:
return None
return await response.read()
except Exception as e:
logger.warning(f"NovelAI TTS Error: {e}")
return None


async def request_cn(text, reecho_api_key: str = None):
    """
    Generate Chinese synthesized voice, preferring Reecho when a key is given.
    :param text: The text to synthesize
    :param reecho_api_key: The Reecho API token
    :return: The synthesized voice audio data, or None if every backend failed
    """
    if not reecho_api_key:
        return await request_dui_speech(text)
    stt = await request_reecho_speech(text, reecho_api_key)
    if stt:
        # BUG FIX: the previous version fell through here and returned None
        # even when Reecho succeeded — the audio was silently discarded.
        return stt
    # Reecho failed: fall back to the DuerOS endpoint.
    return await request_dui_speech(text)


async def request_en(text):
    """
    Generate English synthesized voice: NovelAI first, Edge TTS as fallback.
    :param text: The text to synthesize
    """
    novelai_audio = await request_novelai_speech(text)
    if not novelai_audio:
        return await request_edge_speech(text)
    return novelai_audio
Loading

0 comments on commit 960673a

Please sign in to comment.