Skip to content

Commit b04898c

Browse files
committed
Fix missing inclusion of neo4j_schema in custom prompt generation
1 parent c98f1ba commit b04898c

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

src/neo4j_graphrag/retrievers/text2cypher.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,13 @@ def __init__(
9999
self.result_formatter = validated_data.result_formatter
100100
self.custom_prompt = validated_data.custom_prompt
101101
if validated_data.custom_prompt:
102-
neo4j_schema = ""
102+
if (
103+
validated_data.neo4j_schema_model
104+
and validated_data.neo4j_schema_model.neo4j_schema
105+
):
106+
neo4j_schema = validated_data.neo4j_schema_model.neo4j_schema
107+
else:
108+
neo4j_schema = ""
103109
else:
104110
if (
105111
validated_data.neo4j_schema_model

tests/unit/retrievers/test_text2cypher.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,71 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples
245245
llm.invoke.assert_called_once_with("This is a custom prompt. test")
246246

247247

248+
@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
249+
def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples_for_prompt_params(
250+
_verify_version_mock: MagicMock,
251+
driver: MagicMock,
252+
llm: MagicMock,
253+
neo4j_record: MagicMock,
254+
) -> None:
255+
prompt = "This is a custom prompt. {query_text} {schema} {examples}"
256+
neo4j_schema = "dummy-schema"
257+
examples = ["example-1", "example-2"]
258+
259+
retriever = Text2CypherRetriever(
260+
driver=driver,
261+
llm=llm,
262+
custom_prompt=prompt,
263+
neo4j_schema=neo4j_schema,
264+
examples=examples,
265+
)
266+
267+
driver.execute_query.return_value = (
268+
[neo4j_record],
269+
None,
270+
None,
271+
)
272+
retriever.search(query_text="test")
273+
274+
llm.invoke.assert_called_once_with(
275+
"This is a custom prompt. test dummy-schema example-1\nexample-2"
276+
)
277+
278+
279+
@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
280+
def test_t2c_retriever_initialization_with_custom_prompt_and_unused_schema_and_examples(
281+
_verify_version_mock: MagicMock,
282+
driver: MagicMock,
283+
llm: MagicMock,
284+
neo4j_record: MagicMock,
285+
) -> None:
286+
prompt = "This is a custom prompt. {query_text} {schema} {examples}"
287+
neo4j_schema = "dummy-schema"
288+
examples = ["example-1", "example-2"]
289+
290+
retriever = Text2CypherRetriever(
291+
driver=driver,
292+
llm=llm,
293+
custom_prompt=prompt,
294+
neo4j_schema=neo4j_schema,
295+
examples=examples,
296+
)
297+
298+
driver.execute_query.return_value = (
299+
[neo4j_record],
300+
None,
301+
None,
302+
)
303+
retriever.search(
304+
query_text="test",
305+
prompt_params={"schema": "another-dummy-schema", "examples": "another-example"},
306+
)
307+
308+
llm.invoke.assert_called_once_with(
309+
"This is a custom prompt. test another-dummy-schema another-example"
310+
)
311+
312+
248313
@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version")
249314
def test_t2c_retriever_invalid_custom_prompt_type(
250315
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock

0 commit comments

Comments
 (0)