diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml new file mode 100644 index 00000000..7c12e0a2 --- /dev/null +++ b/.github/workflows/linting.yaml @@ -0,0 +1,30 @@ +name: Linting and Formatting + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + lint-and-format: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pre-commit + + - name: Run pre-commit + run: pre-commit run --all-files diff --git a/.gitignore b/.gitignore index 5a41ae32..fd4bd830 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ dickens/ book.txt lightrag-dev/ .idea/ -dist/ \ No newline at end of file +dist/ diff --git a/README.md b/README.md index dbabcb56..abd7ceb9 100644 --- a/README.md +++ b/README.md @@ -58,8 +58,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete ######### # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() -# import nest_asyncio -# nest_asyncio.apply() +# import nest_asyncio +# nest_asyncio.apply() ######### WORKING_DIR = "./dickens" @@ -157,7 +157,7 @@ rag = LightRAG(
Using Ollama Models - + * If you want to use Ollama models, you only need to set LightRAG as follows: ```python @@ -328,8 +328,8 @@ def main(): SET e.entity_type = node.entity_type, e.description = node.description, e.source_id = node.source_id, - e.displayName = node.id - REMOVE e:Entity + e.displayName = node.id + REMOVE e:Entity WITH e, node CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode RETURN count(*) @@ -382,7 +382,7 @@ def main(): except Exception as e: print(f"Error occurred: {e}") - + finally: driver.close() diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py index b455e6de..e4337a54 100644 --- a/examples/graph_visual_with_html.py +++ b/examples/graph_visual_with_html.py @@ -3,7 +3,7 @@ import random # Load the GraphML file -G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml') +G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml") # Create a Pyvis network net = Network(notebook=True) @@ -13,7 +13,7 @@ # Add colors to nodes for node in net.nodes: - node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF)) + node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF)) # Save and display the network -net.show('knowledge_graph.html') \ No newline at end of file +net.show("knowledge_graph.html") diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py index 22dde368..7377f21c 100644 --- a/examples/graph_visual_with_neo4j.py +++ b/examples/graph_visual_with_neo4j.py @@ -13,6 +13,7 @@ NEO4J_USERNAME = "neo4j" NEO4J_PASSWORD = "your_password" + def convert_xml_to_json(xml_path, output_path): """Converts XML file to JSON and saves the output.""" if not os.path.exists(xml_path): @@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path): json_data = xml_to_json(xml_path) if json_data: - with open(output_path, 'w', encoding='utf-8') as f: + with open(output_path, "w", encoding="utf-8") as f: json.dump(json_data, f, ensure_ascii=False, indent=2) print(f"JSON file created: {output_path}") return json_data @@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path): print("Failed to create JSON data") return None + def process_in_batches(tx, query, data, batch_size): """Process data in batches and execute the given query.""" for i in range(0, len(data), batch_size): - batch = data[i:i + batch_size] + batch = data[i : i + batch_size] tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch}) + def main(): # Paths - xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml') - json_file = os.path.join(WORKING_DIR, 'graph_data.json') + xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml") + json_file = os.path.join(WORKING_DIR, "graph_data.json") # Convert XML to JSON json_data = convert_xml_to_json(xml_file, json_file) @@ -46,8 +49,8 @@ def main(): return # Load nodes and edges - nodes = json_data.get('nodes', []) - edges = json_data.get('edges', []) + nodes = json_data.get("nodes", []) + edges = json_data.get("edges", []) # Neo4j queries create_nodes_query = """ @@ -56,8 +59,8 @@ def main(): SET e.entity_type = node.entity_type, e.description = node.description, e.source_id = node.source_id, - e.displayName = node.id - REMOVE e:Entity + e.displayName = node.id + REMOVE e:Entity WITH e, node CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode RETURN count(*) @@ -100,19 +103,24 @@ def main(): # Execute queries in batches with driver.session() as session: # Insert nodes in batches - session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES) + session.execute_write( + process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES + ) # Insert edges in batches - session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES) + session.execute_write( + process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES + ) # Set displayName and labels session.run(set_displayname_and_labels_query) except Exception as e: print(f"Error occurred: {e}") - + finally: driver.close() + if __name__ == "__main__": main() diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py index 25d3722c..2470fc00 100644 --- a/examples/lightrag_openai_compatible_demo.py +++ b/examples/lightrag_openai_compatible_demo.py @@ -52,6 +52,7 @@ async def test_funcs(): # asyncio.run(test_funcs()) + async def main(): try: embedding_dimension = await get_embedding_dim() @@ -61,35 +62,47 @@ async def main(): working_dir=WORKING_DIR, llm_model_func=llm_model_func, embedding_func=EmbeddingFunc( - embedding_dim=embedding_dimension, max_token_size=8192, func=embedding_func + embedding_dim=embedding_dimension, + max_token_size=8192, + func=embedding_func, ), ) - with open("./book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) # Perform naive search print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="naive") + ) ) # Perform local search print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) ) # Perform global search print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) + rag.query( + "What are the top themes in this story?", + param=QueryParam(mode="global"), + ) ) # Perform hybrid search print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) + rag.query( + "What are the top themes in this story?", + param=QueryParam(mode="hybrid"), + ) ) except Exception as e: print(f"An error occurred: {e}") + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py index 82cab228..a73f16c5 100644 --- a/examples/lightrag_siliconcloud_demo.py +++ b/examples/lightrag_siliconcloud_demo.py @@ -30,7 +30,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray: texts, model="netease-youdao/bce-embedding-base_v1", api_key=os.getenv("SILICONFLOW_API_KEY"), - max_token_size=512 + max_token_size=512, ) diff --git a/examples/vram_management_demo.py b/examples/vram_management_demo.py index ec750254..c173b913 100644 --- a/examples/vram_management_demo.py +++ b/examples/vram_management_demo.py @@ -27,11 +27,12 @@ # Read all .txt files from the TEXT_FILES_DIR directory texts = [] for filename in os.listdir(TEXT_FILES_DIR): - if filename.endswith('.txt'): + if filename.endswith(".txt"): file_path = os.path.join(TEXT_FILES_DIR, filename) - with open(file_path, 'r', encoding='utf-8') as file: + with open(file_path, "r", encoding="utf-8") as file: texts.append(file.read()) + # Batch insert texts into LightRAG with a retry mechanism def insert_texts_with_retry(rag, texts, retries=3, delay=5): for _ in range(retries): @@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5): rag.insert(texts) return except Exception as e: - print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...") + print( + f"Error occurred during insertion: {e}. Retrying in {delay} seconds..." + ) time.sleep(delay) raise RuntimeError("Failed to insert texts after multiple retries.") + insert_texts_with_retry(rag, texts) # Perform different types of queries and handle potential errors try: - print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="naive") + ) + ) except Exception as e: print(f"Error performing naive search: {e}") try: - print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) except Exception as e: print(f"Error performing local search: {e}") try: - print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="global") + ) + ) except Exception as e: print(f"Error performing global search: {e}") try: - print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="hybrid") + ) + ) except Exception as e: print(f"Error performing hybrid search: {e}") + # Function to clear VRAM resources def clear_vram(): os.system("sudo nvidia-smi --gpu-reset") + # Regularly clear VRAM to prevent overflow clear_vram_interval = 3600 # Clear once every hour start_time = time.time() diff --git a/lightrag/llm.py b/lightrag/llm.py index 4dcf535c..eaaa2b75 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -7,7 +7,13 @@ import numpy as np import ollama -from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI +from openai import ( + AsyncOpenAI, + APIConnectionError, + RateLimitError, + Timeout, + AsyncAzureOpenAI, +) import base64 import struct @@ -70,26 +76,31 @@ async def openai_complete_if_cache( ) return response.choices[0].message.content + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), ) -async def azure_openai_complete_if_cache(model, +async def azure_openai_complete_if_cache( + model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, - **kwargs): + **kwargs, +): if api_key: os.environ["AZURE_OPENAI_API_KEY"] = api_key if base_url: os.environ["AZURE_OPENAI_ENDPOINT"] = base_url - openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), - api_key=os.getenv("AZURE_OPENAI_API_KEY"), - api_version=os.getenv("AZURE_OPENAI_API_VERSION")) + openai_async_client = AsyncAzureOpenAI( + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("AZURE_OPENAI_API_VERSION"), + ) hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] @@ -114,6 +125,7 @@ async def azure_openai_complete_if_cache(model, ) return response.choices[0].message.content + class BedrockError(Exception): """Generic error for issues related to Amazon Bedrock""" @@ -205,8 +217,12 @@ async def bedrock_complete_if_cache( @lru_cache(maxsize=1) def initialize_hf_model(model_name): - hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True) - hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True) + hf_tokenizer = AutoTokenizer.from_pretrained( + model_name, device_map="auto", trust_remote_code=True + ) + hf_model = AutoModelForCausalLM.from_pretrained( + model_name, device_map="auto", trust_remote_code=True + ) if hf_tokenizer.pad_token is None: hf_tokenizer.pad_token = hf_tokenizer.eos_token @@ -328,8 +344,9 @@ async def gpt_4o_mini_complete( **kwargs, ) + async def azure_openai_complete( - prompt, system_prompt=None, history_messages=[], **kwargs + prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: return await azure_openai_complete_if_cache( "conversation-4o-mini", @@ -339,6 +356,7 @@ async def azure_openai_complete( **kwargs, ) + async def bedrock_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -418,9 +436,11 @@ async def azure_openai_embedding( if base_url: os.environ["AZURE_OPENAI_ENDPOINT"] = base_url - openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), - api_key=os.getenv("AZURE_OPENAI_API_KEY"), - api_version=os.getenv("AZURE_OPENAI_API_VERSION")) + openai_async_client = AsyncAzureOpenAI( + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("AZURE_OPENAI_API_VERSION"), + ) response = await openai_async_client.embeddings.create( model=model, input=texts, encoding_format="float" @@ -440,35 +460,28 @@ async def siliconcloud_embedding( max_token_size: int = 512, api_key: str = None, ) -> np.ndarray: - if api_key and not api_key.startswith('Bearer '): - api_key = 'Bearer ' + api_key + if api_key and not api_key.startswith("Bearer "): + api_key = "Bearer " + api_key - headers = { - "Authorization": api_key, - "Content-Type": "application/json" - } + headers = {"Authorization": api_key, "Content-Type": "application/json"} truncate_texts = [text[0:max_token_size] for text in texts] - payload = { - "model": model, - "input": truncate_texts, - "encoding_format": "base64" - } + payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"} base64_strings = [] async with aiohttp.ClientSession() as session: async with session.post(base_url, headers=headers, json=payload) as response: content = await response.json() - if 'code' in content: + if "code" in content: raise ValueError(content) - base64_strings = [item['embedding'] for item in content['data']] - + base64_strings = [item["embedding"] for item in content["data"]] + embeddings = [] for string in base64_strings: decode_bytes = base64.b64decode(string) n = len(decode_bytes) // 4 - float_array = struct.unpack('<' + 'f' * n, decode_bytes) + float_array = struct.unpack("<" + "f" * n, decode_bytes) embeddings.append(float_array) return np.array(embeddings) @@ -563,6 +576,7 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray: return embed_text + class Model(BaseModel): """ This is a Pydantic model class named 'Model' that is used to define a custom language model. @@ -580,14 +594,20 @@ class Model(BaseModel): The 'kwargs' dictionary contains the model name and API key to be passed to the function. """ - gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. The response must be a string") - kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc") + gen_func: Callable[[Any], str] = Field( + ..., + description="A function that generates the response from the llm. The response must be a string", + ) + kwargs: Dict[str, Any] = Field( + ..., + description="The arguments to pass to the callable function. Eg. the api key, model name, etc", + ) class Config: arbitrary_types_allowed = True -class MultiModel(): +class MultiModel: """ Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier. Could also be used for spliting across diffrent models or providers. @@ -611,26 +631,31 @@ class MultiModel(): ) ``` """ + def __init__(self, models: List[Model]): self._models = models self._current_model = 0 - + def _next_model(self): self._current_model = (self._current_model + 1) % len(self._models) return self._models[self._current_model] async def llm_model_func( - self, - prompt, system_prompt=None, history_messages=[], **kwargs + self, prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: - kwargs.pop("model", None) # stop from overwriting the custom model name + kwargs.pop("model", None) # stop from overwriting the custom model name next_model = self._next_model() - args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs) - - return await next_model.gen_func( - **args + args = dict( + prompt=prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + **next_model.kwargs, ) + return await next_model.gen_func(**args) + + if __name__ == "__main__": import asyncio diff --git a/lightrag/utils.py b/lightrag/utils.py index 9a68c16b..0da4a51a 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -185,6 +185,7 @@ def save_data_to_file(data, file_name): with open(file_name, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=4) + def xml_to_json(xml_file): try: tree = ET.parse(xml_file) @@ -194,31 +195,42 @@ def xml_to_json(xml_file): print(f"Root element: {root.tag}") print(f"Root attributes: {root.attrib}") - data = { - "nodes": [], - "edges": [] - } + data = {"nodes": [], "edges": []} # Use namespace - namespace = {'': 'http://graphml.graphdrawing.org/xmlns'} + namespace = {"": "http://graphml.graphdrawing.org/xmlns"} - for node in root.findall('.//node', namespace): + for node in root.findall(".//node", namespace): node_data = { - "id": node.get('id').strip('"'), - "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "", - "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "", - "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else "" + "id": node.get("id").strip('"'), + "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') + if node.find("./data[@key='d0']", namespace) is not None + else "", + "description": node.find("./data[@key='d1']", namespace).text + if node.find("./data[@key='d1']", namespace) is not None + else "", + "source_id": node.find("./data[@key='d2']", namespace).text + if node.find("./data[@key='d2']", namespace) is not None + else "", } data["nodes"].append(node_data) - for edge in root.findall('.//edge', namespace): + for edge in root.findall(".//edge", namespace): edge_data = { - "source": edge.get('source').strip('"'), - "target": edge.get('target').strip('"'), - "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0, - "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "", - "keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "", - "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else "" + "source": edge.get("source").strip('"'), + "target": edge.get("target").strip('"'), + "weight": float(edge.find("./data[@key='d3']", namespace).text) + if edge.find("./data[@key='d3']", namespace) is not None + else 0.0, + "description": edge.find("./data[@key='d4']", namespace).text + if edge.find("./data[@key='d4']", namespace) is not None + else "", + "keywords": edge.find("./data[@key='d5']", namespace).text + if edge.find("./data[@key='d5']", namespace) is not None + else "", + "source_id": edge.find("./data[@key='d6']", namespace).text + if edge.find("./data[@key='d6']", namespace) is not None + else "", } data["edges"].append(edge_data) diff --git a/requirements.txt b/requirements.txt index 5b3396fb..98f32b0a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,15 @@ accelerate aioboto3 +aiohttp graspologic hnswlib nano-vectordb networkx ollama openai +pyvis tenacity tiktoken torch transformers xxhash -pyvis -aiohttp \ No newline at end of file