Skip to content

Commit 4d7bbee

Browse files
Jinash RouniyarJinash Rouniyar
authored andcommitted
Added support for knowledge param
1 parent 061c4ed commit 4d7bbee

File tree

3 files changed

+52
-0
lines changed

3 files changed

+52
-0
lines changed

integration/test_collection_openai.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,51 @@ def test_contextualai_generative_search_single(
787787
assert res.generated is None
788788

789789

790+
def test_contextualai_generative_with_knowledge_parameter(
791+
collection_factory: CollectionFactory,
792+
) -> None:
793+
"""Test Contextual AI generative search with knowledge parameter override."""
794+
api_key = os.environ.get("CONTEXTUAL_API_KEY")
795+
if api_key is None:
796+
pytest.skip("No Contextual AI API key found.")
797+
798+
collection = collection_factory(
799+
name="TestContextualAIGenerativeKnowledge",
800+
generative_config=Configure.Generative.contextualai(
801+
model="v2",
802+
max_tokens=100,
803+
temperature=0.1,
804+
system_prompt="You are a helpful assistant.",
805+
avoid_commentary=False,
806+
),
807+
vectorizer_config=Configure.Vectorizer.none(),
808+
properties=[
809+
Property(name="text", data_type=DataType.TEXT),
810+
],
811+
headers={"X-Contextual-Api-Key": api_key},
812+
ports=(8086, 50057),
813+
)
814+
if collection._connection._weaviate_version.is_lower_than(1, 23, 1):
815+
pytest.skip("Generative search requires Weaviate 1.23.1 or higher")
816+
817+
collection.data.insert_many(
818+
[
819+
DataObject(properties={"text": "base knowledge"}),
820+
]
821+
)
822+
823+
# Test with knowledge parameter override
824+
res = collection.generate.fetch_objects(
825+
single_prompt="What is the custom knowledge?",
826+
config=GenerativeConfig.contextualai(
827+
knowledge=["Custom knowledge override", "Additional context"],
828+
),
829+
)
830+
for obj in res.objects:
831+
assert obj.generated is not None
832+
assert isinstance(obj.generated, str)
833+
834+
790835
def test_contextualai_generative_and_rerank_combined(collection_factory: CollectionFactory) -> None:
791836
"""Test Contextual AI generative search combined with reranking."""
792837
contextual_api_key = os.environ.get("CONTEXTUAL_API_KEY")

test/collection/test_classes_generative.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ def test_generative_parameters_images_parsing(
422422
top_p=0.9,
423423
system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.",
424424
avoid_commentary=False,
425+
knowledge=["knowledge1", "knowledge2"],
425426
)._to_grpc(_GenerativeConfigRuntimeOptions(return_metadata=True)),
426427
generative_pb2.GenerativeProvider(
427428
return_metadata=True,
@@ -432,6 +433,7 @@ def test_generative_parameters_images_parsing(
432433
top_p=0.9,
433434
system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.",
434435
avoid_commentary=False,
436+
knowledge=base_pb2.TextArray(values=["knowledge1", "knowledge2"]),
435437
),
436438
),
437439
),

weaviate/collections/classes/generative.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@ class _GenerativeContextualAI(_GenerativeConfigRuntime):
455455
top_p: Optional[float]
456456
system_prompt: Optional[str]
457457
avoid_commentary: Optional[bool]
458+
knowledge: Optional[List[str]]
458459

459460
def _to_grpc(self, opts: _GenerativeConfigRuntimeOptions) -> generative_pb2.GenerativeProvider:
460461
self._validate_multi_modal(opts)
@@ -467,6 +468,7 @@ def _to_grpc(self, opts: _GenerativeConfigRuntimeOptions) -> generative_pb2.Gene
467468
top_p=self.top_p,
468469
system_prompt=self.system_prompt,
469470
avoid_commentary=self.avoid_commentary or False,
471+
knowledge=_to_text_array(self.knowledge),
470472
),
471473
)
472474

@@ -615,6 +617,7 @@ def contextualai(
615617
top_p: Optional[float] = None,
616618
system_prompt: Optional[str] = None,
617619
avoid_commentary: Optional[bool] = None,
620+
knowledge: Optional[List[str]] = None,
618621
) -> _GenerativeConfigRuntime:
619622
"""Create a `_GenerativeContextualAI` object for use with the `generative-contextualai` module.
620623
@@ -625,6 +628,7 @@ def contextualai(
625628
top_p: The top P to use. Defaults to `None`, which uses the server-defined default
626629
system_prompt: The system prompt to prepend to the conversation
627630
avoid_commentary: Whether to avoid model commentary in responses
631+
knowledge: Optional knowledge array to override the default knowledge from retrieved objects
628632
"""
629633
return _GenerativeContextualAI(
630634
model=model,
@@ -633,6 +637,7 @@ def contextualai(
633637
top_p=top_p,
634638
system_prompt=system_prompt,
635639
avoid_commentary=avoid_commentary,
640+
knowledge=knowledge,
636641
)
637642

638643
@staticmethod

0 commit comments

Comments
 (0)