diff --git a/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisor.java b/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisor.java index 8c301fbdd04..d692d4df4d5 100644 --- a/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisor.java +++ b/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisor.java @@ -155,11 +155,14 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh @Nullable protected Filter.Expression doGetFilterExpression(Map context) { - if (!context.containsKey(FILTER_EXPRESSION) - || !StringUtils.hasText(context.get(FILTER_EXPRESSION).toString())) { + var filterExpression = context.get(FILTER_EXPRESSION); + if (filterExpression instanceof Filter.Expression) { + return (Filter.Expression) filterExpression; + } + if (!context.containsKey(FILTER_EXPRESSION) || !StringUtils.hasText(filterExpression.toString())) { return this.searchRequest.getFilterExpression(); } - return new FilterExpressionTextParser().parse(context.get(FILTER_EXPRESSION).toString()); + return new FilterExpressionTextParser().parse(filterExpression.toString()); } @Override diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/retrieval/search/VectorStoreDocumentRetrieverIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/retrieval/search/VectorStoreDocumentRetrieverIT.java index 2cd6bf8b01a..6648d840505 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/retrieval/search/VectorStoreDocumentRetrieverIT.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/retrieval/search/VectorStoreDocumentRetrieverIT.java @@ -30,6 +30,7 @@ import org.springframework.ai.rag.retrieval.search.DocumentRetriever; import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever; import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; import org.springframework.ai.vectorstore.pgvector.PgVectorStore; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; @@ -134,4 +135,29 @@ void withRequestFilter() { assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("4").getId())); } + @Test + void withRequestFilterExpression() { + FilterExpressionBuilder b = new FilterExpressionBuilder(); + DocumentRetriever documentRetriever = VectorStoreDocumentRetriever.builder() + .vectorStore(this.pgVectorStore) + .similarityThreshold(0.50) + .topK(3) + .build(); + + Query query = Query.builder() + .text("Who is Anacletus?") + .context(Map.of(VectorStoreDocumentRetriever.FILTER_EXPRESSION, + b.eq("location", "Whispering Woods").build())) + .build(); + List retrievedDocuments = documentRetriever.retrieve(query); + + assertThat(retrievedDocuments).hasSize(2); + assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("1").getId())); + assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("2").getId())); + + // No request filter expression applied, so full access to all documents. + retrievedDocuments = documentRetriever.retrieve(new Query("Who is Birba?")); + assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("4").getId())); + } + }