Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mg #4

Merged
merged 3 commits into from
Nov 6, 2024
Merged

mg #4

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ TRIPLET_GRAPH_ENABLED=True # enable the graph search for triplets
DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks

KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the top size of knowledge graph search for chunks
KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE=20 # the batch size of triplet extraction from the text

### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data
Expand Down
8 changes: 4 additions & 4 deletions assets/schema/dbgpt.sql
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ CREATE TABLE `gpts_app_collection` (
`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code',
`user_code` int(11) NOT NULL COMMENT 'user code',
`sys_code` varchar(255) NOT NULL COMMENT 'system app code',
`sys_code` varchar(255) NULL COMMENT 'system app code',
`created_at` datetime DEFAULT NULL COMMENT 'create time',
`updated_at` datetime DEFAULT NULL COMMENT 'last update time',
PRIMARY KEY (`id`),
Expand Down Expand Up @@ -439,10 +439,10 @@ CREATE TABLE `recommend_question` (
`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'create time',
`gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'last update time',
`app_code` varchar(255) DEFAULT NULL COMMENT 'Current AI assistant code',
`app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code',
`question` text DEFAULT NULL COMMENT 'question',
`user_code` int(11) NOT NULL COMMENT 'user code',
`sys_code` varchar(255) NOT NULL COMMENT 'system app code',
`sys_code` varchar(255) NULL COMMENT 'system app code',
`valid` varchar(10) DEFAULT 'true' COMMENT 'is it effective,true/false',
`chat_mode` varchar(255) DEFAULT NULL COMMENT 'Conversation scene mode,chat_knowledge...',
`params` text DEFAULT NULL COMMENT 'question param',
Expand All @@ -456,7 +456,7 @@ CREATE TABLE `user_recent_apps` (
`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id',
`gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'create time',
`gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'last update time',
`app_code` varchar(255) DEFAULT NULL COMMENT 'AI assistant code',
`app_code` varchar(255) NOT NULL COMMENT 'AI assistant code',
`last_accessed` timestamp NULL DEFAULT NULL COMMENT 'User recent usage time',
`user_code` varchar(255) DEFAULT NULL COMMENT 'user code',
`sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code',
Expand Down
32 changes: 30 additions & 2 deletions dbgpt/model/proxy/llms/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,43 @@ def __init__(
context_length: Optional[int] = 4096,
executor: Optional[Executor] = None,
):
"""
Tips: 星火大模型API当前有Lite、Pro、Pro-128K、Max、Max-32K和4.0 Ultra六个版本,各版本独立计量tokens。
传输协议 :ws(s),为提高安全性,强烈推荐wss

Spark4.0 Ultra 请求地址,对应的domain参数为4.0Ultra:
wss://spark-api.xf-yun.com/v4.0/chat

Spark Max-32K请求地址,对应的domain参数为max-32k
wss://spark-api.xf-yun.com/chat/max-32k

Spark Max请求地址,对应的domain参数为generalv3.5
wss://spark-api.xf-yun.com/v3.5/chat

Spark Pro-128K请求地址,对应的domain参数为pro-128k:
wss://spark-api.xf-yun.com/chat/pro-128k

Spark Pro请求地址,对应的domain参数为generalv3:
wss://spark-api.xf-yun.com/v3.1/chat

Spark Lite请求地址,对应的domain参数为lite:
wss://spark-api.xf-yun.com/v1.1/chat
"""
if not model_version:
model_version = model or os.getenv("XUNFEI_SPARK_API_VERSION")
if not api_base:
if model_version == SPARK_DEFAULT_API_VERSION:
api_base = "ws://spark-api.xf-yun.com/v3.1/chat"
domain = "generalv3"
elif model_version == "v4.0":
api_base = "ws://spark-api.xf-yun.com/v4.0/chat"
domain = "4.0Ultra"
elif model_version == "v3.5":
api_base = "ws://spark-api.xf-yun.com/v3.5/chat"
domain = "generalv3.5"
else:
api_base = "ws://spark-api.xf-yun.com/v2.1/chat"
domain = "generalv2"
api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
domain = "lite"
if not api_domain:
api_domain = domain
self._model = model
Expand Down
10 changes: 10 additions & 0 deletions dbgpt/rag/transformer/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Transformer base class."""

import logging
from abc import ABC, abstractmethod
from typing import List, Optional
Expand Down Expand Up @@ -37,6 +38,15 @@ class ExtractorBase(TransformerBase, ABC):
async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract results from text."""

@abstractmethod
async def batch_extract(
self,
texts: List[str],
batch_size: int = 1,
limit: Optional[int] = None,
) -> List:
"""Batch extract results from texts."""


class TranslatorBase(TransformerBase, ABC):
"""Translator base class."""
98 changes: 80 additions & 18 deletions dbgpt/rag/transformer/graph_extractor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""GraphExtractor class."""

import asyncio
import logging
import re
from typing import List, Optional
from typing import Dict, List, Optional

from dbgpt.core import Chunk, LLMClient
from dbgpt.rag.transformer.llm_extractor import LLMExtractor
Expand All @@ -23,35 +24,96 @@ def __init__(
self._chunk_history = chunk_history

config = self._chunk_history.get_config()

self._vector_space = config.name
self._max_chunks_once_load = config.max_chunks_once_load
self._max_threads = config.max_threads
self._topk = config.topk
self._score_threshold = config.score_threshold

async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Load similar chunks."""
# load similar chunks
chunks = await self._chunk_history.asimilar_search_with_scores(
text, self._topk, self._score_threshold
)
history = [
f"Section {i + 1}:\n{chunk.content}" for i, chunk in enumerate(chunks)
]
context = "\n".join(history) if history else ""

try:
# extract with chunk history
return await super()._extract(text, context, limit)

finally:
# save chunk to history
async def aload_chunk_context(self, texts: List[str]) -> Dict[str, str]:
"""Load chunk context."""
text_context_map: Dict[str, str] = {}

for text in texts:
# Load similar chunks
chunks = await self._chunk_history.asimilar_search_with_scores(
text, self._topk, self._score_threshold
)
history = [
f"Section {i + 1}:\n{chunk.content}" for i, chunk in enumerate(chunks)
]

# Save chunk to history
await self._chunk_history.aload_document_with_limit(
[Chunk(content=text, metadata={"relevant_cnt": len(history)})],
self._max_chunks_once_load,
self._max_threads,
)

# Save chunk context to map
context = "\n".join(history) if history else ""
text_context_map[text] = context
return text_context_map

async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract graphs from text.

Suggestion: to extract triplets in batches, call `batch_extract`.
"""
# Load similar chunks
text_context_map = await self.aload_chunk_context([text])
context = text_context_map[text]

# Extract with chunk history
return await super()._extract(text, context, limit)

async def batch_extract(
self,
texts: List[str],
batch_size: int = 1,
limit: Optional[int] = None,
) -> List[List[Graph]]:
"""Extract graphs from chunks in batches.

Returns list of graphs in same order as input texts (text <-> graphs).
"""
if batch_size < 1:
raise ValueError("batch_size >= 1")

# 1. Load chunk context
text_context_map = await self.aload_chunk_context(texts)

# Pre-allocate results list to maintain order
graphs_list: List[List[Graph]] = [None] * len(texts)
total_batches = (len(texts) + batch_size - 1) // batch_size

for batch_idx in range(total_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, len(texts))
batch_texts = texts[start_idx:end_idx]

# 2. Create tasks with their original indices
extraction_tasks = [
(
idx,
self._extract(text, text_context_map[text], limit),
)
for idx, text in enumerate(batch_texts, start=start_idx)
]

# 3. Process extraction in parallel while keeping track of indices
batch_results = await asyncio.gather(
*(task for _, task in extraction_tasks)
)

# 4. Place results in the correct positions
for (idx, _), graphs in zip(extraction_tasks, batch_results):
graphs_list[idx] = graphs

assert all(x is not None for x in graphs_list), "All positions should be filled"
return graphs_list

def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph]:
graph = MemoryGraph()
edge_count = 0
Expand Down
28 changes: 28 additions & 0 deletions dbgpt/rag/transformer/llm_extractor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""TripletExtractor class."""

import asyncio
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
Expand All @@ -22,6 +24,32 @@ async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract by LLM."""
return await self._extract(text, None, limit)

async def batch_extract(
self,
texts: List[str],
batch_size: int = 1,
limit: Optional[int] = None,
) -> List:
"""Batch extract by LLM."""
if batch_size < 1:
raise ValueError("batch_size >= 1")

results = []

for i in range(0, len(texts), batch_size):
batch_texts = texts[i : i + batch_size]

# Create tasks for current batch
extraction_tasks = [
self._extract(text, None, limit) for text in batch_texts
]

# Execute batch concurrently and wait for all to complete
batch_results = await asyncio.gather(*extraction_tasks)
results.extend(batch_results)

return results

async def _extract(
self, text: str, history: str = None, limit: Optional[int] = None
) -> List:
Expand Down
3 changes: 2 additions & 1 deletion dbgpt/rag/transformer/triplet_extractor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""TripletExtractor class."""

import logging
import re
from typing import Any, List, Optional, Tuple
Expand All @@ -12,7 +13,7 @@
"Some text is provided below. Given the text, "
"extract up to knowledge triplets as more as possible "
"in the form of (subject, predicate, object).\n"
"Avoid stopwords.\n"
"Avoid stopwords. The subject, predicate, object can not be none.\n"
"---------------------\n"
"Example:\n"
"Text: Alice is Bob's mother.\n"
Expand Down
44 changes: 42 additions & 2 deletions dbgpt/serve/dbgpts/hub/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
You can define your own models and DAOs here
"""
from datetime import datetime
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union

from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from sqlalchemy import Column, DateTime, Index, Integer, String, UniqueConstraint, desc

from dbgpt.storage.metadata import BaseDao, Model, db
from dbgpt.util.pagination_utils import PaginationResult

from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
Expand Down Expand Up @@ -109,3 +110,42 @@ def to_response(self, entity: ServeEntity) -> ServerResponse:
gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str,
)

def dbgpts_list(
self,
query_request: ServeRequest,
page: int,
page_size: int,
desc_order_column: Optional[str] = None,
) -> PaginationResult[ServerResponse]:
"""Get a page of dbgpts.

Args:
query_request (ServeRequest): The request schema object or dict for query.
page (int): The page number.
page_size (int): The page size.
desc_order_column(Optional[str]): The column for descending order.
Returns:
PaginationResult: The pagination result.
"""
session = self.get_raw_session()
try:
query = session.query(ServeEntity)
if query_request.name:
query = query.filter(ServeEntity.name.like(f"%{query_request.name}%"))
if desc_order_column:
query = query.order_by(desc(getattr(ServeEntity, desc_order_column)))
total_count = query.count()
items = query.offset((page - 1) * page_size).limit(page_size)
res_items = [self.to_response(item) for item in items]
total_pages = (total_count + page_size - 1) // page_size
finally:
session.close()

return PaginationResult(
items=res_items,
total_count=total_count,
total_pages=total_pages,
page=page,
page_size=page_size,
)
4 changes: 2 additions & 2 deletions dbgpt/serve/dbgpts/hub/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def init_app(self, system_app: SystemApp) -> None:
self._system_app = system_app

@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
def dao(self) -> ServeDao:
"""Returns the internal DAO."""
return self._dao

Expand Down Expand Up @@ -130,7 +130,7 @@ def get_list_by_page(
installed=request.installed,
)

return self.dao.get_list_page(query_request, page, page_size)
return self.dao.dbgpts_list(query_request, page, page_size)

def refresh_hub_from_git(
self,
Expand Down
8 changes: 0 additions & 8 deletions dbgpt/storage/graph_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,6 @@ class GraphStoreConfig(BaseModel):
default=False,
description="Enable graph community summary or not.",
)
document_graph_enabled: bool = Field(
default=True,
description="Enable document graph search or not.",
)
triplet_graph_enabled: bool = Field(
default=True,
description="Enable knowledge graph search or not.",
)


class GraphStoreBase(ABC):
Expand Down
8 changes: 0 additions & 8 deletions dbgpt/storage/graph_store/tugraph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,6 @@ def __init__(self, config: TuGraphStoreConfig) -> None:
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
or config.enable_summary
)
self._enable_document_graph = (
os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
or config.document_graph_enabled
)
self._enable_triplet_graph = (
os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
or config.triplet_graph_enabled
)
self._plugin_names = (
os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
or config.plugin_names
Expand Down
Loading