Merge pull request #4 from eosphoros-ai/main
mg
hiyizi authored Nov 6, 2024
2 parents b867fa2 + 8593f10 commit cfce1ac
Showing 12 changed files with 406 additions and 248 deletions.
1 change: 1 addition & 0 deletions .env.template
@@ -167,6 +167,7 @@ TRIPLET_GRAPH_ENABLED=True # enable the graph search for triplets
DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks

KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the top size of knowledge graph search for chunks
KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE=20 # the batch size of triplet extraction from the text

### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data
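The new batch-size knob pairs with the batch_extract API introduced below. A minimal sketch of reading these settings, assuming the variable names from the template above and plain integer parsing rather than the project's actual config loader:

import os

# Defaults mirror the template values above (5 and 20); both are assumptions.
chunk_search_top_size = int(os.getenv("KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE", "5"))
extraction_batch_size = int(os.getenv("KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE", "20"))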
10 changes: 10 additions & 0 deletions dbgpt/rag/transformer/base.py
@@ -1,4 +1,5 @@
"""Transformer base class."""

import logging
from abc import ABC, abstractmethod
from typing import List, Optional
@@ -37,6 +38,15 @@ class ExtractorBase(TransformerBase, ABC):
async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract results from text."""

@abstractmethod
async def batch_extract(
self,
texts: List[str],
batch_size: int = 1,
limit: Optional[int] = None,
) -> List:
"""Batch extract results from texts."""


class TranslatorBase(TransformerBase, ABC):
"""Translator base class."""
98 changes: 80 additions & 18 deletions dbgpt/rag/transformer/graph_extractor.py
@@ -1,8 +1,9 @@
"""GraphExtractor class."""

import asyncio
import logging
import re
from typing import List, Optional
from typing import Dict, List, Optional

from dbgpt.core import Chunk, LLMClient
from dbgpt.rag.transformer.llm_extractor import LLMExtractor
@@ -23,35 +24,96 @@ def __init__(
self._chunk_history = chunk_history

config = self._chunk_history.get_config()

self._vector_space = config.name
self._max_chunks_once_load = config.max_chunks_once_load
self._max_threads = config.max_threads
self._topk = config.topk
self._score_threshold = config.score_threshold

async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Load similar chunks."""
# load similar chunks
chunks = await self._chunk_history.asimilar_search_with_scores(
text, self._topk, self._score_threshold
)
history = [
f"Section {i + 1}:\n{chunk.content}" for i, chunk in enumerate(chunks)
]
context = "\n".join(history) if history else ""

try:
# extract with chunk history
return await super()._extract(text, context, limit)

finally:
# save chunk to history
async def aload_chunk_context(self, texts: List[str]) -> Dict[str, str]:
"""Load chunk context."""
text_context_map: Dict[str, str] = {}

for text in texts:
# Load similar chunks
chunks = await self._chunk_history.asimilar_search_with_scores(
text, self._topk, self._score_threshold
)
history = [
f"Section {i + 1}:\n{chunk.content}" for i, chunk in enumerate(chunks)
]

# Save chunk to history
await self._chunk_history.aload_document_with_limit(
[Chunk(content=text, metadata={"relevant_cnt": len(history)})],
self._max_chunks_once_load,
self._max_threads,
)

# Save chunk context to map
context = "\n".join(history) if history else ""
text_context_map[text] = context
return text_context_map

async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract graphs from text.
Suggestion: to extract triplets in batches, call `batch_extract`.
"""
# Load similar chunks
text_context_map = await self.aload_chunk_context([text])
context = text_context_map[text]

# Extract with chunk history
return await super()._extract(text, context, limit)

async def batch_extract(
self,
texts: List[str],
batch_size: int = 1,
limit: Optional[int] = None,
) -> List[List[Graph]]:
"""Extract graphs from chunks in batches.
Returns list of graphs in same order as input texts (text <-> graphs).
"""
if batch_size < 1:
raise ValueError("batch_size >= 1")

# 1. Load chunk context
text_context_map = await self.aload_chunk_context(texts)

# Pre-allocate the results list so each text's graphs land at its original index
graphs_list: List[Optional[List[Graph]]] = [None] * len(texts)
total_batches = (len(texts) + batch_size - 1) // batch_size

for batch_idx in range(total_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, len(texts))
batch_texts = texts[start_idx:end_idx]

# 2. Create tasks with their original indices
extraction_tasks = [
(
idx,
self._extract(text, text_context_map[text], limit),
)
for idx, text in enumerate(batch_texts, start=start_idx)
]

# 3. Process extraction in parallel while keeping track of indices
batch_results = await asyncio.gather(
*(task for _, task in extraction_tasks)
)

# 4. Place results in the correct positions
for (idx, _), graphs in zip(extraction_tasks, batch_results):
graphs_list[idx] = graphs

assert all(x is not None for x in graphs_list), "All positions should be filled"
return graphs_list

def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph]:
graph = MemoryGraph()
edge_count = 0
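Because results are written back by index after each asyncio.gather, the output order matches the input order even though extractions within a batch run concurrently. A usage sketch, assuming extractor is an already-initialized GraphExtractor and the texts are illustrative:

# graphs_list[i] holds the graphs extracted from texts[i].
texts = ["Alice is Bob's mother.", "Bob works at Acme."]
graphs_list = await extractor.batch_extract(texts, batch_size=20)
for text, graphs in zip(texts, graphs_list):
    print(text, "->", len(graphs), "graph(s)")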
28 changes: 28 additions & 0 deletions dbgpt/rag/transformer/llm_extractor.py
@@ -1,4 +1,6 @@
"""TripletExtractor class."""

import asyncio
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
@@ -22,6 +24,32 @@ async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract by LLM."""
return await self._extract(text, None, limit)

async def batch_extract(
self,
texts: List[str],
batch_size: int = 1,
limit: Optional[int] = None,
) -> List:
"""Batch extract by LLM."""
if batch_size < 1:
raise ValueError("batch_size must be >= 1")

results = []

for i in range(0, len(texts), batch_size):
batch_texts = texts[i : i + batch_size]

# Create tasks for current batch
extraction_tasks = [
self._extract(text, None, limit) for text in batch_texts
]

# Execute batch concurrently and wait for all to complete
batch_results = await asyncio.gather(*extraction_tasks)
results.extend(batch_results)

return results

async def _extract(
self, text: str, history: Optional[str] = None, limit: Optional[int] = None
) -> List:
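The loop above awaits each window of batch_size tasks before starting the next, so at most batch_size LLM calls are in flight at once. The same pattern in isolation, with a stand-in coroutine instead of a real LLM call (all names here are hypothetical):

import asyncio
from typing import List

async def fake_extract(text: str) -> List[str]:
    await asyncio.sleep(0)  # stand-in for an LLM round-trip
    return [f"result:{text}"]

async def batch_run(texts: List[str], batch_size: int = 2) -> List[List[str]]:
    results: List[List[str]] = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        # At most batch_size calls run concurrently.
        results.extend(await asyncio.gather(*(fake_extract(t) for t in batch)))
    return results

print(asyncio.run(batch_run(["a", "b", "c"], batch_size=2)))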
3 changes: 2 additions & 1 deletion dbgpt/rag/transformer/triplet_extractor.py
@@ -1,4 +1,5 @@
"""TripletExtractor class."""

import logging
import re
from typing import Any, List, Optional, Tuple
@@ -12,7 +13,7 @@
"Some text is provided below. Given the text, "
"extract up to knowledge triplets as more as possible "
"in the form of (subject, predicate, object).\n"
"Avoid stopwords.\n"
"Avoid stopwords. The subject, predicate, object can not be none.\n"
"---------------------\n"
"Example:\n"
"Text: Alice is Bob's mother.\n"
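The added sentence steers the model away from degenerate triplets. On the parsing side, a sketch of how '(subject, predicate, object)' lines can be filtered for empty or "none" parts (the regex and helper are hypothetical, not the file's actual parser):

import re
from typing import List, Tuple

def parse_triplets(response: str) -> List[Tuple[str, ...]]:
    """Keep only triplets whose three parts are non-empty and not 'none'."""
    triplets = []
    for inner in re.findall(r"\(([^()]*)\)", response):
        parts = [p.strip() for p in inner.split(",")]
        if len(parts) == 3 and all(p and p.lower() != "none" for p in parts):
            triplets.append(tuple(parts))
    return triplets

print(parse_triplets("(Alice, is mother of, Bob)\n(none, likes, )"))
# -> [('Alice', 'is mother of', 'Bob')]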
44 changes: 42 additions & 2 deletions dbgpt/serve/dbgpts/hub/models/models.py
@@ -2,11 +2,12 @@
You can define your own models and DAOs here
"""
from datetime import datetime
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union

from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint
from sqlalchemy import Column, DateTime, Index, Integer, String, UniqueConstraint, desc

from dbgpt.storage.metadata import BaseDao, Model, db
from dbgpt.util.pagination_utils import PaginationResult

from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
@@ -109,3 +110,42 @@ def to_response(self, entity: ServeEntity) -> ServerResponse:
gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str,
)

def dbgpts_list(
self,
query_request: ServeRequest,
page: int,
page_size: int,
desc_order_column: Optional[str] = None,
) -> PaginationResult[ServerResponse]:
"""Get a page of dbgpts.
Args:
query_request (ServeRequest): The request schema object or dict for query.
page (int): The page number.
page_size (int): The page size.
desc_order_column(Optional[str]): The column for descending order.
Returns:
PaginationResult: The pagination result.
"""
session = self.get_raw_session()
try:
query = session.query(ServeEntity)
if query_request.name:
query = query.filter(ServeEntity.name.like(f"%{query_request.name}%"))
if desc_order_column:
query = query.order_by(desc(getattr(ServeEntity, desc_order_column)))
total_count = query.count()
items = query.offset((page - 1) * page_size).limit(page_size)
res_items = [self.to_response(item) for item in items]
total_pages = (total_count + page_size - 1) // page_size
finally:
session.close()

return PaginationResult(
items=res_items,
total_count=total_count,
total_pages=total_pages,
page=page,
page_size=page_size,
)
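A usage sketch of the new DAO method; the filter value and ordering column are illustrative, and gmt_created is assumed to be a column on ServeEntity:

# Page 1 of dbgpts whose name contains "agent", newest first.
request = ServeRequest(name="agent")  # other fields left at defaults
result = dao.dbgpts_list(
    request, page=1, page_size=20, desc_order_column="gmt_created"
)
print(result.total_count, result.total_pages, [item.name for item in result.items])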
4 changes: 2 additions & 2 deletions dbgpt/serve/dbgpts/hub/service/service.py
@@ -41,7 +41,7 @@ def init_app(self, system_app: SystemApp) -> None:
self._system_app = system_app

@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
def dao(self) -> ServeDao:
"""Returns the internal DAO."""
return self._dao

@@ -130,7 +130,7 @@ def get_list_by_page(
installed=request.installed,
)

return self.dao.get_list_page(query_request, page, page_size)
return self.dao.dbgpts_list(query_request, page, page_size)

def refresh_hub_from_git(
self,
8 changes: 0 additions & 8 deletions dbgpt/storage/graph_store/base.py
@@ -27,14 +27,6 @@ class GraphStoreConfig(BaseModel):
default=False,
description="Enable graph community summary or not.",
)
document_graph_enabled: bool = Field(
default=True,
description="Enable document graph search or not.",
)
triplet_graph_enabled: bool = Field(
default=True,
description="Enable knowledge graph search or not.",
)


class GraphStoreBase(ABC):
8 changes: 0 additions & 8 deletions dbgpt/storage/graph_store/tugraph_store.py
@@ -83,14 +83,6 @@ def __init__(self, config: TuGraphStoreConfig) -> None:
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
or config.enable_summary
)
self._enable_document_graph = (
os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
or config.document_graph_enabled
)
self._enable_triplet_graph = (
os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
or config.triplet_graph_enabled
)
self._plugin_names = (
os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
or config.plugin_names