Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit

Permalink
chore: added QdrantKnowledgeBase.from_config()
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Jan 10, 2024
1 parent e343959 commit 6f01e6b
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 6 deletions.
3 changes: 1 addition & 2 deletions src/canopy/knowledge_base/qdrant/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Separate file to avoid circular imports
COLLECTION_NAME_PREFIX = "canopy--"
RESERVED_METADATA_KEYS = {"document_id", "text", "source", "chunk_id"}
DENSE_VECTOR_NAME = "dense"
RESERVED_METADATA_KEYS = {"document_id", "text", "source", "chunk_id"}
SPARSE_VECTOR_NAME = "sparse"
UUID_NAMESPACE = "867603e3-ba69-447d-a8ef-263dff19bda7"
21 changes: 19 additions & 2 deletions src/canopy/knowledge_base/qdrant/qdrant_knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,6 @@ def __init__(
else:
self._reranker = self._DEFAULT_COMPONENTS["reranker"]()

self._collection_params: Dict[str, Any] = {}

self._client, self._async_client = generate_clients(
location=location,
url=url,
Expand Down Expand Up @@ -293,6 +291,7 @@ async def aquery(
)]
>>> results = await kb.aquery(queries)
""" # noqa: E501
# TODO: Use aencode_queries() when implemented for the defaults
queries = self._encoder.encode_queries(queries)
results = [
await self._aquery_collection(q, global_metadata_filter) for q in queries
Expand Down Expand Up @@ -360,6 +359,7 @@ def upsert(
f"{forbidden_keys}. Please remove them and try again."
)

# TODO: Use achunk_documents, encode_documents when implemented for the defaults
chunks = self._chunker.chunk_documents(documents)
encoded_chunks = self._encoder.encode_documents(chunks)

Expand Down Expand Up @@ -625,6 +625,23 @@ def collection_name(self) -> str:
"""
return self._collection_name

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QdrantKnowledgeBase":
"""
Create a QdrantKnowledgeBase object from a configuration dictionary.
Args:
config: A dictionary containing the configuration for the Qdrant knowledge base.
Returns:
A QdrantKnowledgeBase object.
""" # noqa: E501

config = deepcopy(config)
config["params"] = config.get("params", {})
# TODO: Add support for collection creation config for use in the CLI
kb = cls._from_config(config)
return kb

@staticmethod
def _get_full_collection_name(collection_name: str) -> str:
if collection_name.startswith(COLLECTION_NAME_PREFIX):
Expand Down
9 changes: 9 additions & 0 deletions tests/system/knowledge_base/qdrant/test_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# ===========================================================
# QdrantKnowledgeBase test configuration file
# ===========================================================

knowledge_base:
params:
default_top_k: 5
collection_name: test-config-collection
default_top_k: 10
14 changes: 12 additions & 2 deletions tests/system/knowledge_base/qdrant/test_qdrant_knowledge_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import random
from copy import copy
from pathlib import Path

import pytest
from dotenv import load_dotenv
Expand All @@ -19,6 +20,7 @@
from canopy.models.data_models import Query

from qdrant_client.qdrant_remote import QdrantRemote
from canopy_cli.cli import _load_kb_config
from tests.system.knowledge_base.qdrant.common import (
assert_chunks_in_collection,
assert_ids_in_collection,
Expand Down Expand Up @@ -242,7 +244,7 @@ def test_create_existing_collection(collection_full_name, knowledge_base):


def test_kb_non_existing_collection(knowledge_base):
kb = copy.copy(knowledge_base)
kb = copy(knowledge_base)

kb._collection_name = f"{COLLECTION_NAME_PREFIX}non-existing-collection"

Expand Down Expand Up @@ -311,3 +313,11 @@ def test_create_with_collection_encoder_dimension_none(collection_name, chunker)
assert "failed to infer" in str(e.value)
assert "dimension" in str(e.value)
assert f"{encoder.__class__.__name__} does not support" in str(e.value)


def test_knowlege_base_from_config():
config_path = Path(__file__).with_name("test_config.yml")
kb_config = _load_kb_config(config_path)
kb = QdrantKnowledgeBase.from_config(kb_config)
assert kb.collection_name == COLLECTION_NAME_PREFIX + "test-config-collection"
assert kb._default_top_k == 10

0 comments on commit 6f01e6b

Please sign in to comment.