
Vector storage dimension problem #146

Closed
wangsj1018 opened this issue Oct 26, 2024 · 4 comments

Comments

@wangsj1018

I'm using Zhipu's large model and keep running into a vector storage dimension error. Could someone please take a look?

Here is my code:

import os

import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc, compute_args_hash
from lightrag.base import BaseKVStorage
from zhipuai import ZhipuAI


WORKING_DIR = "./dickens"
ZHIPU_APIKEY = "my own key"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)
    
    

async def zhipu_model_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    model_name = kwargs['hashing_kv'].global_config['llm_model_name']
    return await zhipu_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )



async def zhipu_model_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    kwargs.pop("max_tokens", None)
    kwargs.pop("response_format", None)
    zhipu_client = ZhipuAI(api_key=ZHIPU_APIKEY)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    response = zhipu_client.chat.completions.create(
        model=model, 
        messages=messages, 
        **kwargs
    )

    result = response.choices[0].message.content

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": result, "model": model}})

    return result

    
async def zhipu_embedding(texts: list[str], embed_model) -> np.ndarray:
    zhipu_client = ZhipuAI(api_key=ZHIPU_APIKEY)
    response = zhipu_client.embeddings.create(model=embed_model, input=texts)

    return np.array([dp.embedding for dp in response.data])

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=zhipu_model_complete,  
    llm_model_name='GLM-4-Flash',
    embedding_func=EmbeddingFunc(
        embedding_dim=512,
        max_token_size=8192,
        func=lambda texts: zhipu_embedding(
            texts, 
            embed_model="embedding-3"
        )
    ),
)


with open("./book.txt", 'r', encoding='utf-8') as f:
    rag.insert(f.read())

# Perform naive search
# print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))

# # Perform local search
# print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))

# # Perform global search
# print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))

# # Perform hybrid search
print(rag.query("please provide the context or content from which I should", param=QueryParam(mode="hybrid")))

The error is as follows:

Traceback (most recent call last):
  File "D:\code\py\LightRAG\lightrag_zhipu_demo.py", line 87, in <module>
    rag.insert(f.read())
  File "D:\code\py\LightRAG\lightrag\lightrag.py", line 166, in insert
    return loop.run_until_complete(self.ainsert(string_or_strings))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\anaconda3\Lib\asyncio\base_events.py", line 687, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "D:\code\py\LightRAG\lightrag\lightrag.py", line 210, in ainsert
    await self.chunks_vdb.upsert(inserting_chunks)
  File "D:\code\py\LightRAG\lightrag\storage.py", line 103, in upsert
    results = self._client.upsert(datas=list_data)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\anaconda3\Lib\site-packages\nano_vectordb\dbs.py", line 108, in upsert
    self.__storage["matrix"] = np.vstack([self.__storage["matrix"], new_matrix])
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\anaconda3\Lib\site-packages\numpy\core\shape_base.py", line 289, in vstack
    return _nx.concatenate(arrs, 0, dtype=dtype, casting=casting)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 512 and the array at index 1 has size 2048
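
The ValueError says the existing storage matrix is 512 columns wide while the newly computed embeddings are 2048 columns wide. One quick, standalone way to confirm what embedding-3 really returns (a sketch; it reuses the ZHIPU_APIKEY placeholder from the snippet above) is to print the shape of the returned vectors and compare it with the embedding_dim passed to EmbeddingFunc:

import numpy as np
from zhipuai import ZhipuAI

# Sketch only: measure the real output width of the embedding model.
# ZHIPU_APIKEY is the same placeholder as in the code above.
client = ZhipuAI(api_key=ZHIPU_APIKEY)
resp = client.embeddings.create(model="embedding-3", input=["dimension check"])
vectors = np.array([d.embedding for d in resp.data])

# The second value of the shape is the true embedding dimension;
# per the traceback it is 2048, not the 512 declared in EmbeddingFunc.
print(vectors.shape)
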
@Dormiveglia-elf
Contributor

Please refer to my previous PR #116
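
For anyone hitting the same error: the declared embedding_dim has to match what the embedding model actually returns. A minimal sketch of that kind of change (probing the dimension once at startup rather than hard-coding it; not necessarily the exact change made in PR #116):

# Sketch only: probe the embedding model once and reuse its real dimension
# instead of hard-coding 512.
probe = ZhipuAI(api_key=ZHIPU_APIKEY).embeddings.create(
    model="embedding-3", input=["probe"]
)
actual_dim = len(probe.data[0].embedding)  # 2048 for embedding-3, per the error

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=zhipu_model_complete,
    llm_model_name="GLM-4-Flash",
    embedding_func=EmbeddingFunc(
        embedding_dim=actual_dim,  # was 512, which caused the vstack failure
        max_token_size=8192,
        func=lambda texts: zhipu_embedding(texts, embed_model="embedding-3"),
    ),
)

A WORKING_DIR created with the old setting may still hold 512-wide vector stores, so starting from a fresh directory before re-running insert is probably also needed.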

@cristianohello

How should I modify the code? It still throws the error.

@wangsj1018
Author

How should I modify the code? It still throws the error.

I'm not very good at Python; I'll keep practicing.

@wangsj1018
Author

Please refer to my previous PR #116

thx
