diff --git a/flask4modelcache.py b/flask4modelcache.py
index 4dc85a9..ca70014 100644
--- a/flask4modelcache.py
+++ b/flask4modelcache.py
@@ -38,13 +38,16 @@ def response_hitquery(cache_resp):
 milvus_config = configparser.ConfigParser()
 milvus_config.read('modelcache/config/milvus_config.ini')
 
+es_config = configparser.ConfigParser()
+es_config.read('modelcache/config/elasticsearch_config.ini')
+
 # redis_config = configparser.ConfigParser()
 # redis_config.read('modelcache/config/redis_config.ini')
 
 # chromadb_config = configparser.ConfigParser()
 # chromadb_config.read('modelcache/config/chromadb_config.ini')
 
-data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
+data_manager = get_data_manager(CacheBase("elasticsearch", config=es_config),
                                 VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
diff --git a/modelcache/config/elasticsearch_config.ini b/modelcache/config/elasticsearch_config.ini
new file mode 100644
index 0000000..ba87a5a
--- /dev/null
+++ b/modelcache/config/elasticsearch_config.ini
@@ -0,0 +1,5 @@
+[elasticsearch]
+host =
+port =
+user =
+password =
\ No newline at end of file
diff --git a/modelcache/manager/scalar_data/manager.py b/modelcache/manager/scalar_data/manager.py
index 4c02c45..8ff3aee 100644
--- a/modelcache/manager/scalar_data/manager.py
+++ b/modelcache/manager/scalar_data/manager.py
@@ -27,6 +27,10 @@ def get(name, **kwargs):
         from modelcache.manager.scalar_data.sql_storage_sqlite import SQLStorage
         sql_url = kwargs.get("sql_url", SQL_URL[name])
         cache_base = SQLStorage(db_type=name, url=sql_url)
+    elif name == 'elasticsearch':
+        from modelcache.manager.scalar_data.sql_storage_es import SQLStorage
+        config = kwargs.get("config")
+        cache_base = SQLStorage(db_type=name, config=config)
     else:
         raise NotFoundError("cache store", name)
     return cache_base
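Note: the ini values are intentionally left empty (configparser returns quotes literally, so `host = ''` would yield the two-character string `''`; the template above uses bare empty values instead). With the factory change above, callers opt into the new backend purely through `get_data_manager`. A minimal sketch of the wiring, assuming `CacheBase`, `VectorBase`, and `get_data_manager` are importable from `modelcache.manager` as in the existing flask4modelcache.py, and that the ini has been filled in; the embedding dimension here is a placeholder:

```python
import configparser

from modelcache.manager import CacheBase, VectorBase, get_data_manager  # assumed export path

es_config = configparser.ConfigParser()
es_config.read('modelcache/config/elasticsearch_config.ini')
milvus_config = configparser.ConfigParser()
milvus_config.read('modelcache/config/milvus_config.ini')

# Scalar data (answers + query log) goes to Elasticsearch; embeddings stay in Milvus.
data_manager = get_data_manager(
    CacheBase("elasticsearch", config=es_config),
    VectorBase("milvus", dimension=768, milvus_config=milvus_config),  # 768: placeholder dimension
)
```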
"strict_date_optional_time||epoch_millis"}, + "gmt_modified": {"type": "date", "format": "strict_date_optional_time||epoch_millis"}, + "error_code": {"type": "integer"}, + "error_desc": {"type": "text"}, + "cache_hit": {"type": "keyword"}, + "delta_time": {"type": "float"}, + "model": {"type": "keyword"}, + "query": {"type": "text"}, + "hit_query": {"type": "text"}, + "answer": {"type": "text"} + } + } + } + + if not self.client.indices.exists(index=self.ans_index): + self.client.indices.create(index=self.ans_index, body=answer_index_body) + + if not self.client.indices.exists(index=self.log_index): + self.client.indices.create(index=self.log_index, body=log_index_body) + + def _insert(self, data: List) -> str or None: + doc = { + "answer": data[0], + "question": data[1], + "embedding_data": data[2].tolist() if hasattr(data[2], "tolist") else data[2], + "model": data[3], + "answer_type": 0, + "hit_count": 0, + "is_deleted": 0 + } + + try: + + response = self.client.index( + index=self.ans_index, + id=next(self.snowflake_id), + body=doc, + ) + return int(response['_id']) + except Exception as e: + + print(f"Failed to insert document: {e}") + return None + + def batch_insert(self, all_data: List[List]) -> List[str]: + successful_ids = [] + for data in all_data: + _id = self._insert(data) + if _id is not None: + successful_ids.append(_id) + self.client.indices.refresh(index=self.ans_index) # 批量插入后手动刷新 + + return successful_ids + + def insert_query_resp(self, query_resp, **kwargs): + doc = { + "error_code": query_resp.get('errorCode'), + "error_desc": query_resp.get('errorDesc'), + "cache_hit": query_resp.get('cacheHit'), + "model": kwargs.get('model'), + "query": kwargs.get('query'), + "delta_time": kwargs.get('delta_time'), + "hit_query": json.dumps(query_resp.get('hit_query'), ensure_ascii=False) if isinstance( + query_resp.get('hit_query'), list) else query_resp.get('hit_query'), + "answer": query_resp.get('answer'), + "hit_count": 0, + "is_deleted": 0 + + } + self.client.index(index=self.log_index, body=doc) + + def get_data_by_id(self, key: int): + try: + response = self.client.get(index=self.ans_index, id=key, _source=['question', 'answer', 'embedding_data', 'model']) + source = response["_source"] + result = [ + source.get('question'), + source.get('answer'), + source.get('embedding_data'), + source.get('model') + ] + return result + except Exception as e: + print(e) + + def update_hit_count_by_id(self, primary_id: int): + self.client.update( + index=self.ans_index, + id=primary_id, + body={"script": {"source": "ctx._source.hit_count += 1"}} + ) + + def get_ids(self, deleted=True): + query = { + "query": { + "term": {"is_deleted": 1 if deleted else 0} + } + } + response = self.client.search(index=self.ans_index, body=query) + return [hit["_id"] for hit in response["hits"]["hits"]] + + def mark_deleted(self, keys): + actions = [ + { + "_op_type": "update", + "_index": self.ans_index, + "_id": key, + "doc": {"is_deleted": 1} + } + for key in keys + ] + responses = helpers.bulk(self.client, actions) + return responses[0] # 返回更新的文档数 + + def model_deleted(self, model_name): + query = { + "query": { + "term": {"model": model_name} + } + } + + response = self.client.delete_by_query(index=self.ans_index, body=query) + return response["deleted"] + + def clear_deleted_data(self): + query = { + "query": { + "term": {"is_deleted": 1} + } + } + response = self.client.delete_by_query(index=self.ans_index, body=query) + return response["deleted"] + + def count(self, state: int = 0, is_all: bool = 
diff --git a/modelcache_mm/config/elasticsearch_config.ini b/modelcache_mm/config/elasticsearch_config.ini
new file mode 100644
index 0000000..ba87a5a
--- /dev/null
+++ b/modelcache_mm/config/elasticsearch_config.ini
@@ -0,0 +1,5 @@
+[elasticsearch]
+host =
+port =
+user =
+password =
\ No newline at end of file
diff --git a/modelcache_mm/manager/scalar_data/sql_storage_es.py b/modelcache_mm/manager/scalar_data/sql_storage_es.py
new file mode 100644
index 0000000..562b349
--- /dev/null
+++ b/modelcache_mm/manager/scalar_data/sql_storage_es.py
@@ -0,0 +1,198 @@
+# -*- coding: utf-8 -*-
+import json
+from typing import List, Optional
+from elasticsearch import Elasticsearch, helpers
+from snowflake import SnowflakeGenerator
+from modelcache_mm.manager.scalar_data.base import CacheStorage
+
+
+class SQLStorage(CacheStorage):
+    def __init__(
+            self,
+            db_type: str = "elasticsearch",
+            config=None
+    ):
+        self.host = config.get('elasticsearch', 'host')
+        self.port = int(config.get('elasticsearch', 'port'))
+        self.client = Elasticsearch(
+            hosts=[{"host": self.host, "port": self.port}],
+            timeout=30,
+            http_auth=(config.get('elasticsearch', 'user'),
+                       config.get('elasticsearch', 'password'))
+        )
+
+        self.log_index = "open_cache_mm_query_log"
+        self.ans_index = "open_cache_mm_answer"
+        self.create()
+        # Machine ID for the snowflake algorithm; instances that share one
+        # database must each be configured with a distinct ID.
+        self.instance_id = 1
+        # Snowflake ID generator
+        self.snowflake_id = SnowflakeGenerator(self.instance_id)
+
+    def create(self):
+        answer_index_body = {
+            "mappings": {
+                "properties": {
+                    "gmt_create": {"type": "date", "format": "strict_date_optional_time||epoch_millis"},
+                    "gmt_modified": {"type": "date", "format": "strict_date_optional_time||epoch_millis"},
+                    "question": {"type": "text"},
+                    "answer": {"type": "text"},
+                    "answer_type": {"type": "integer"},
+                    "hit_count": {"type": "integer"},
+                    "model": {"type": "keyword"},
+                    "image_url": {"type": "text"},
+                    "image_id": {"type": "text"},
+                    "is_deleted": {"type": "integer"},
+                }
+            }
+        }
+
+        log_index_body = {
+            "mappings": {
+                "properties": {
+                    "gmt_create": {"type": "date", "format": "strict_date_optional_time||epoch_millis"},
+                    "gmt_modified": {"type": "date", "format": "strict_date_optional_time||epoch_millis"},
+                    "error_code": {"type": "integer"},
+                    "error_desc": {"type": "text"},
+                    "cache_hit": {"type": "keyword"},
+                    "delta_time": {"type": "float"},
+                    "model": {"type": "keyword"},
+                    "query": {"type": "text"},
+                    "hit_query": {"type": "text"},
+                    "answer": {"type": "text"}
+                }
+            }
+        }
+
+        if not self.client.indices.exists(index=self.ans_index):
+            self.client.indices.create(index=self.ans_index, body=answer_index_body)
+
+        if not self.client.indices.exists(index=self.log_index):
+            self.client.indices.create(index=self.log_index, body=log_index_body)
+
+    def _insert(self, data: List) -> Optional[int]:
+        doc = {
+            "answer": data[0],
+            "question": data[1],
+            "image_url": data[2],
+            "image_id": data[3],
+            "model": data[4],
+            "answer_type": 0,
+            "hit_count": 0,
+            "is_deleted": 0
+        }
+
+        try:
+            response = self.client.index(
+                index=self.ans_index,
+                id=next(self.snowflake_id),
+                body=doc,
+            )
+            return int(response['_id'])
+        except Exception as e:
+            print(f"Failed to insert document: {e}")
+            return None
+
+    def batch_insert(self, all_data: List[List]) -> List[int]:
+        successful_ids = []
+        for data in all_data:
+            _id = self._insert(data)
+            if _id is not None:
+                successful_ids.append(_id)
+        self.client.indices.refresh(index=self.ans_index)  # manual refresh after the batch insert
+        return successful_ids
+
+    def insert_query_resp(self, query_resp, **kwargs):
+        doc = {
+            "error_code": query_resp.get('errorCode'),
+            "error_desc": query_resp.get('errorDesc'),
+            "cache_hit": query_resp.get('cacheHit'),
+            "model": kwargs.get('model'),
+            "query": kwargs.get('query'),
+            "delta_time": kwargs.get('delta_time'),
+            "hit_query": json.dumps(query_resp.get('hit_query'), ensure_ascii=False) if isinstance(
+                query_resp.get('hit_query'), list) else query_resp.get('hit_query'),
+            "answer": query_resp.get('answer'),
+            "hit_count": 0,
+            "is_deleted": 0
+        }
+        self.client.index(index=self.log_index, body=doc)
+
+    def get_data_by_id(self, key: int):
+        try:
+            response = self.client.get(index=self.ans_index, id=key,
+                                       _source=['question', 'image_url', 'image_id', 'answer', 'model'])
+            source = response["_source"]
+            result = [
+                source.get('question'),
+                source.get('image_url'),
+                source.get('image_id'),
+                source.get('answer'),
+                source.get('model')
+            ]
+            return result
+        except Exception as e:
+            print(e)
+
+    def update_hit_count_by_id(self, primary_id: int):
+        self.client.update(
+            index=self.ans_index,
+            id=primary_id,
+            body={"script": {"source": "ctx._source.hit_count += 1"}}
+        )
+
+    def get_ids(self, deleted=True):
+        query = {
+            "size": 10000,  # search returns only 10 hits by default
+            "query": {
+                "term": {"is_deleted": 1 if deleted else 0}
+            }
+        }
+        response = self.client.search(index=self.ans_index, body=query)
+        return [hit["_id"] for hit in response["hits"]["hits"]]
+
+    def mark_deleted(self, keys):
+        actions = [
+            {
+                "_op_type": "update",
+                "_index": self.ans_index,
+                "_id": key,
+                "doc": {"is_deleted": 1}
+            }
+            for key in keys
+        ]
+        responses = helpers.bulk(self.client, actions)
+        return responses[0]  # number of successfully updated documents
+
+    def model_deleted(self, model_name):
+        query = {
+            "query": {
+                "term": {"model": model_name}
+            }
+        }
+        response = self.client.delete_by_query(index=self.ans_index, body=query)
+        return response["deleted"]
+
+    def clear_deleted_data(self):
+        query = {
+            "query": {
+                "term": {"is_deleted": 1}
+            }
+        }
+        response = self.client.delete_by_query(index=self.ans_index, body=query)
+        return response["deleted"]
+
+    def count(self, state: int = 0, is_all: bool = False):
+        query = {"query": {"match_all": {}}} if is_all else {"query": {"term": {"is_deleted": state}}}
+        response = self.client.count(index=self.ans_index, body=query)
+        return response["count"]
+
+    def close(self):
+        self.client.close()
+
+    def count_answers(self):
+        query = {"query": {"match_all": {}}}
+        response = self.client.count(index=self.ans_index, body=query)
+        return response["count"]
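The multimodal class duplicates the text-only one except for the index names and the `image_url`/`image_id` fields, so the two files could eventually share a common base. A hedged smoke-test sketch for the new class, assuming the ini has been filled with reachable cluster credentials (the values are placeholders until then):

```python
import configparser

from modelcache_mm.manager.scalar_data.sql_storage_es import SQLStorage

config = configparser.ConfigParser()
config.read('modelcache_mm/config/elasticsearch_config.ini')  # fill in host/port/user/password first

storage = SQLStorage(config=config)  # creates both indices if they do not exist
print(storage.count(is_all=True))    # total documents in open_cache_mm_answer
storage.close()
```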
document: {e}") + return None + + def batch_insert(self, all_data: List[List]) -> List[str]: + successful_ids = [] + for data in all_data: + _id = self._insert(data) + if _id is not None: + successful_ids.append(_id) + self.client.indices.refresh(index=self.ans_index) # 批量插入后手动刷新 + + return successful_ids + + def insert_query_resp(self, query_resp, **kwargs): + doc = { + "error_code": query_resp.get('errorCode'), + "error_desc": query_resp.get('errorDesc'), + "cache_hit": query_resp.get('cacheHit'), + "model": kwargs.get('model'), + "query": kwargs.get('query'), + "delta_time": kwargs.get('delta_time'), + "hit_query": json.dumps(query_resp.get('hit_query'), ensure_ascii=False) if isinstance( + query_resp.get('hit_query'), list) else query_resp.get('hit_query'), + "answer": query_resp.get('answer'), + "hit_count": 0, + "is_deleted": 0 + + } + self.client.index(index=self.log_index, body=doc) + + def get_data_by_id(self, key: int): + try: + response = self.client.get(index=self.ans_index, id=key, _source=['question', 'image_url','image_id', 'answer', 'model']) + source = response["_source"] + result = [ + source.get('question'), + source.get('image_url'), + source.get('image_id'), + source.get('answer'), + source.get('model') + ] + return result + except Exception as e: + print(e) + + def update_hit_count_by_id(self, primary_id: int): + self.client.update( + index=self.ans_index, + id=primary_id, + body={"script": {"source": "ctx._source.hit_count += 1"}} + ) + + def get_ids(self, deleted=True): + query = { + "query": { + "term": {"is_deleted": 1 if deleted else 0} + } + } + response = self.client.search(index=self.ans_index, body=query) + return [hit["_id"] for hit in response["hits"]["hits"]] + + def mark_deleted(self, keys): + actions = [ + { + "_op_type": "update", + "_index": self.ans_index, + "_id": key, + "doc": {"is_deleted": 1} + } + for key in keys + ] + responses = helpers.bulk(self.client, actions) + return responses[0] # 返回更新的文档数 + + def model_deleted(self, model_name): + query = { + "query": { + "term": {"model": model_name} + } + } + + response = self.client.delete_by_query(index=self.ans_index, body=query) + return response["deleted"] + + def clear_deleted_data(self): + query = { + "query": { + "term": {"is_deleted": 1} + } + } + response = self.client.delete_by_query(index=self.ans_index, body=query) + return response["deleted"] + + def count(self, state: int = 0, is_all: bool = False): + query = {"query": {"match_all": {}}} if is_all else {"query": {"term": {"is_deleted": state}}} + response = self.client.count(index=self.ans_index, body=query) + return response["count"] + + def close(self): + self.client.close() + + def count_answers(self): + query = {"query": {"match_all": {}}} + response = self.client.count(index=self.ans_index, body=query) + return response["count"] diff --git a/requirements.txt b/requirements.txt index 77b3868..6d707d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,5 @@ modelscope==1.14.0 fastapi==0.115.5 uvicorn==0.32.0 chromadb==0.5.23 +elasticsearch==7.10.0 +snowflake-id==1.0.2