diff --git a/src/retriever.js b/src/retriever.js index ef5bb07..b589922 100644 --- a/src/retriever.js +++ b/src/retriever.js @@ -1,8 +1,8 @@ -import { ChatOpenAI } from '@langchain/openai' import { z } from 'zod' import 'neo4j-driver' import { Neo4jGraph } from '@langchain/community/graphs/neo4j_graph' import { ChatPromptTemplate, PromptTemplate } from '@langchain/core/prompts' +import { ChatOpenAI } from '@langchain/openai' import { Neo4jVectorStore } from '@langchain/community/vectorstores/neo4j_vector' import { OpenAIEmbeddings } from '@langchain/openai' import { createStructuredOutputRunnable } from 'langchain/chains/openai_functions' @@ -67,6 +67,10 @@ const entityChain = createStructuredOutputRunnable({ llm }) +await graph.query( + 'CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]' +) + /** * Collects the neighborhood of entities mentioned in the question. * @@ -80,18 +84,18 @@ async function structuredRetriever(question) { for (const entity of entities.names) { const response = await graph.query( `CALL db.index.fulltext.queryNodes('entity', $query, - {limit:2}) - YIELD node,score - CALL { - MATCH (node)-[r:!MENTIONS]->(neighbor) - RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS - output - UNION - MATCH (node)<-[r:!MENTIONS]-(neighbor) - RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS - output - } - RETURN output LIMIT 50`, + {limit:2}) + YIELD node,score + CALL { + MATCH (node)-[r:!MENTIONS]->(neighbor) + RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS + output + UNION + MATCH (node)<-[r:!MENTIONS]-(neighbor) + RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS + output + } + RETURN output LIMIT 50`, { query: generateFullTextQuery(entity) } ) @@ -141,8 +145,6 @@ const retrieverChain = RunnableSequence.from([ retriever ]) -// const combineDocumentsChain = RunnableSequence.from([combineDocuments]) - const answerChain = answerPrompt.pipe(llm).pipe(new StringOutputParser()) const chain = RunnableSequence.from([ @@ -193,14 +195,11 @@ async function ask(question) { logResult(answer) } -await ask('Who is the billionaire in the Avengers group?') -await ask('What is his other name?') await ask('Loki is a native of which planet?') -await ask('What is the name of his brother') -await ask('What is the other name of Bruce Banner?') -await ask('Where is the Stark Tower located?') -await ask('Who is Loki?') -await ask('Who is Jarvis?') +await ask('What is the name of his brother?') +await ask('Who is the villain among the two?') +await ask('Who is Tony Stark?') +await ask('Does he own the Stark Tower?') await graph.close() await neo4jVectorIndex.close()