diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f2349633..ce226199 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,14 +17,76 @@ env: POETRY_VERSION: "1.8.3" jobs: + prime-cache: + name: Prime HuggingFace Model Cache + runs-on: ubuntu-latest + env: + HF_HOME: ${{ github.workspace }}/hf_cache + steps: + - name: Check out repository + uses: actions/checkout@v3 + + - name: Cache HuggingFace Models + id: hf-cache + uses: actions/cache@v3 + with: + path: hf_cache + key: ${{ runner.os }}-hf-cache + + - name: Set HuggingFace token + run: | + mkdir -p ~/.huggingface + echo '{"token":"${{ secrets.HF_TOKEN }}"}' > ~/.huggingface/token + + - name: Set up Python 3.9 + uses: actions/setup-python@v4 + with: + python-version: 3.9 + cache: pip + + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + version: ${{ env.POETRY_VERSION }} + + - name: Install dependencies + run: | + poetry install --all-extras + + - name: Authenticate to Google Cloud + uses: google-github-actions/auth@v1 + with: + credentials_json: ${{ secrets.GOOGLE_CREDENTIALS }} + + - name: Run full test suite to prime cache + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + HF_HOME: ${{ github.workspace }}/hf_cache + OPENAI_API_KEY: ${{ secrets.OPENAI_KEY }} + GCP_LOCATION: ${{ secrets.GCP_LOCATION }} + GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} + COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} + MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} + VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }} + AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }} + AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }} + AZURE_OPENAI_DEPLOYMENT_NAME: ${{ secrets.AZURE_OPENAI_DEPLOYMENT_NAME }} + OPENAI_API_VERSION: ${{ secrets.OPENAI_API_VERSION }} + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + run: | + make test-all + test: name: Python ${{ matrix.python-version }} - ${{ matrix.connection }} [redis ${{ matrix.redis-version }}] runs-on: ubuntu-latest - + needs: prime-cache + env: + HF_HOME: ${{ github.workspace }}/hf_cache strategy: fail-fast: false matrix: - python-version: [3.9, '3.10', 3.11, 3.12, 3.13] + python-version: ['3.10', '3.11', 3.12, 3.13] connection: ['hiredis', 'plain'] redis-version: ['6.2.6-v9', 'latest', '8.0-M03'] @@ -32,11 +94,17 @@ jobs: - name: Check out repository uses: actions/checkout@v3 + - name: Cache HuggingFace Models + uses: actions/cache@v3 + with: + path: hf_cache + key: ${{ runner.os }}-hf-cache + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - cache: 'pip' + cache: pip - name: Install Poetry uses: snok/install-poetry@v1 @@ -68,22 +136,23 @@ jobs: - name: Run tests if: matrix.connection == 'plain' && matrix.redis-version == 'latest' env: + HF_HOME: ${{ github.workspace }}/hf_cache OPENAI_API_KEY: ${{ secrets.OPENAI_KEY }} GCP_LOCATION: ${{ secrets.GCP_LOCATION }} GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }} - AZURE_OPENAI_API_KEY: ${{secrets.AZURE_OPENAI_API_KEY}} - AZURE_OPENAI_ENDPOINT: ${{secrets.AZURE_OPENAI_ENDPOINT}} - AZURE_OPENAI_DEPLOYMENT_NAME: ${{secrets.AZURE_OPENAI_DEPLOYMENT_NAME}} - OPENAI_API_VERSION: ${{secrets.OPENAI_API_VERSION}} + AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }} + AZURE_OPENAI_ENDPOINT: ${{ 
secrets.AZURE_OPENAI_ENDPOINT }} + AZURE_OPENAI_DEPLOYMENT_NAME: ${{ secrets.AZURE_OPENAI_DEPLOYMENT_NAME }} + OPENAI_API_VERSION: ${{ secrets.OPENAI_API_VERSION }} AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} run: | make test-all - - name: Run tests + - name: Run tests (alternate) if: matrix.connection != 'plain' || matrix.redis-version != 'latest' run: | make test @@ -91,21 +160,22 @@ jobs: - name: Run notebooks if: matrix.connection == 'plain' && matrix.redis-version == 'latest' env: + HF_HOME: ${{ github.workspace }}/hf_cache OPENAI_API_KEY: ${{ secrets.OPENAI_KEY }} GCP_LOCATION: ${{ secrets.GCP_LOCATION }} GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }} - AZURE_OPENAI_API_KEY: ${{secrets.AZURE_OPENAI_API_KEY}} - AZURE_OPENAI_ENDPOINT: ${{secrets.AZURE_OPENAI_ENDPOINT}} - AZURE_OPENAI_DEPLOYMENT_NAME: ${{secrets.AZURE_OPENAI_DEPLOYMENT_NAME}} - OPENAI_API_VERSION: ${{secrets.OPENAI_API_VERSION}} + AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }} + AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }} + AZURE_OPENAI_DEPLOYMENT_NAME: ${{ secrets.AZURE_OPENAI_DEPLOYMENT_NAME }} + OPENAI_API_VERSION: ${{ secrets.OPENAI_API_VERSION }} AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} run: | docker run -d --name redis -p 6379:6379 redis/redis-stack-server:latest - make test-notebooks + make test-notebooks docs: runs-on: ubuntu-latest @@ -117,17 +187,17 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ env.PYTHON_VERSION }} - cache: 'pip' + cache: pip - name: Install Poetry uses: snok/install-poetry@v1 with: version: ${{ env.POETRY_VERSION }} - + - name: Install dependencies run: | poetry install --all-extras - name: Build docs run: | - make docs-build \ No newline at end of file + make docs-build diff --git a/redisvl/index/index.py b/redisvl/index/index.py index f52de03d..c4e5de62 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -1,6 +1,7 @@ import asyncio import json import threading +import time import warnings import weakref from typing import ( @@ -13,6 +14,7 @@ Iterable, List, Optional, + Tuple, Union, ) @@ -26,6 +28,8 @@ import redis import redis.asyncio as aredis +from redis.client import NEVER_DECODE +from redis.commands.helpers import get_protocol_version # type: ignore from redis.commands.search.indexDefinition import IndexDefinition from redisvl.exceptions import RedisModuleVersionError, RedisSearchError @@ -48,6 +52,14 @@ {"name": "searchlight", "ver": 20810}, ] +SearchParams = Union[ + Tuple[ + Union[str, BaseQuery], + Union[Dict[str, Union[str, int, float, bytes]], None], + ], + Union[str, BaseQuery], +] + def process_results( results: "Result", query: BaseQuery, storage_type: StorageType @@ -349,7 +361,7 @@ def client(self) -> Optional[redis.Redis]: return self.__redis_client @property - def _redis_client(self) -> Optional[redis.Redis]: + def _redis_client(self) -> redis.Redis: """ Get a Redis client instance. @@ -652,6 +664,73 @@ def aggregate(self, *args, **kwargs) -> "AggregateResult": except Exception as e: raise RedisSearchError(f"Error while aggregating: {str(e)}") from e + def batch_search( + self, + queries: List[SearchParams], + batch_size: int = 10, + ) -> List["Result"]: + """Perform a search against the index for multiple queries. 
+ + This method takes a list of queries and optionally query params and + returns a list of Result objects for each query. Results are + returned in the same order as the queries. + + Args: + queries (List[SearchParams]): The queries to search for. batch_size + (int, optional): The number of queries to search for at a time. + Defaults to 10. + + Returns: + List[Result]: The search results for each query. + """ + all_parsed = [] + search = self._redis_client.ft(self.schema.index.name) + options = {} + if get_protocol_version(self._redis_client) not in ["3", 3]: + options[NEVER_DECODE] = True + + for i in range(0, len(queries), batch_size): + batch_queries = queries[i : i + batch_size] + + # redis-py doesn't support calling `search` in a pipeline, + # so we need to manually execute each command in a pipeline + # and parse the results + with self._redis_client.pipeline(transaction=False) as pipe: + batch_built_queries = [] + for query in batch_queries: + if isinstance(query, tuple): + query_args, q = search._mk_query_args( # type: ignore + query[0], query_params=query[1] + ) + else: + query_args, q = search._mk_query_args( # type: ignore + query, query_params=None + ) + batch_built_queries.append(q) + pipe.execute_command( + "FT.SEARCH", + *query_args, + **options, + ) + + st = time.time() + results = pipe.execute() + + # We don't know how long each query took, so we'll use the total time + # for all queries in the batch as the duration for each query + duration = (time.time() - st) * 1000.0 + + for i, query_results in enumerate(results): + _built_query = batch_built_queries[i] + parsed_result = search._parse_search( # type: ignore + query_results, + query=_built_query, + duration=duration, + ) + # Return a parsed Result object for each query + all_parsed.append(parsed_result) + return all_parsed + def search(self, *args, **kwargs) -> "Result": """Perform a search against the index. @@ -669,6 +748,26 @@ def search(self, *args, **kwargs) -> "Result": except Exception as e: raise RedisSearchError(f"Error while searching: {str(e)}") from e + def batch_query( + self, queries: List[BaseQuery], batch_size: int = 10 + ) -> List[List[Dict[str, Any]]]: + """Execute a batch of queries and process results.""" + results = self.batch_search( + [(query.query, query.params) for query in queries], batch_size=batch_size + ) + all_parsed = [] + for query, batch_results in zip(queries, results): + parsed = process_results( + batch_results, + query=query, + storage_type=self.schema.index.storage_type, + ) + # Create separate lists of parsed results for each query + # passed in to the batch_search method, so that callers can + # access the results for each query individually + all_parsed.append(parsed) + return all_parsed + def _query(self, query: BaseQuery) -> List[Dict[str, Any]]: """Execute a query and process results.""" results = self.search(query.query, query_params=query.params) @@ -1211,6 +1310,71 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult": except Exception as e: raise RedisSearchError(f"Error while aggregating: {str(e)}") from e + async def batch_search( + self, queries: List[SearchParams], batch_size: int = 10 + ) -> List["Result"]: + """Perform a search against the index for multiple queries. + + This method takes a list of queries and returns a list of Result objects + for each query. Results are returned in the same order as the queries. + + Args: + queries (List[SearchParams]): The queries to search for. 
batch_size + (int, optional): The number of queries to search for at a time. + Defaults to 10. + + Returns: + List[Result]: The search results for each query. + """ + all_results = [] + client = await self._get_client() + search = client.ft(self.schema.index.name) + options = {} + if get_protocol_version(client) not in ["3", 3]: + options[NEVER_DECODE] = True + + for i in range(0, len(queries), batch_size): + batch_queries = queries[i : i + batch_size] + + # redis-py doesn't support calling `search` in a pipeline, + # so we need to manually execute each command in a pipeline + # and parse the results + async with client.pipeline(transaction=False) as pipe: + batch_built_queries = [] + for query in batch_queries: + if isinstance(query, tuple): + query_args, q = search._mk_query_args( # type: ignore + query[0], query_params=query[1] + ) + else: + query_args, q = search._mk_query_args( # type: ignore + query, query_params=None + ) + batch_built_queries.append(q) + pipe.execute_command( + "FT.SEARCH", + *query_args, + **options, + ) + + st = time.time() + results = await pipe.execute() + + # We don't know how long each query took, so we'll use the total time + # for all queries in the batch as the duration for each query + duration = (time.time() - st) * 1000.0 + + for i, query_results in enumerate(results): + _built_query = batch_built_queries[i] + parsed_result = search._parse_search( # type: ignore + query_results, + query=_built_query, + duration=duration, + ) + # Return a parsed Result object for each query + all_results.append(parsed_result) + return all_results + async def search(self, *args, **kwargs) -> "Result": """Perform a search on this index. @@ -1227,6 +1391,27 @@ async def search(self, *args, **kwargs) -> "Result": except Exception as e: raise RedisSearchError(f"Error while searching: {str(e)}") from e + async def batch_query( + self, queries: List[BaseQuery], batch_size: int = 10 + ) -> List[List[Dict[str, Any]]]: + """Asynchronously execute a batch of queries and process results.""" + results = await self.batch_search( + [(query.query, query.params) for query in queries], batch_size=batch_size + ) + all_parsed = [] + for query, batch_results in zip(queries, results): + parsed = process_results( + batch_results, + query=query, + storage_type=self.schema.index.storage_type, + ) + # Create separate lists of parsed results for each query + # passed in to the batch_search method, so that callers can + # access the results for each query individually + all_parsed.append(parsed) + + return all_parsed + async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]: """Asynchronously execute a query and process results.""" results = await self.search(query.query, query_params=query.params) diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index d1b42235..edc6c01a 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -8,6 +8,7 @@ from redisvl.exceptions import RedisModuleVersionError, RedisSearchError from redisvl.index import AsyncSearchIndex from redisvl.query import VectorQuery +from redisvl.query.query import FilterQuery from redisvl.redis.utils import convert_bytes from redisvl.schema import IndexSchema, StorageType @@ -436,3 +437,117 @@ async def test_async_search_index_validates_redis_modules(redis_url): await index.create(overwrite=True, drop=True) mock_validate_async_redis.assert_called_once() + + +@pytest.mark.asyncio +async def test_batch_search(async_index): + 
await async_index.create(overwrite=True, drop=True) + data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}] + await async_index.load(data, id_field="id") + + results = await async_index.batch_search(["@test:{foo}", "@test:{bar}"]) + assert len(results) == 2 + assert results[0].total == 1 + assert results[0].docs[0]["id"] == "rvl:1" + assert results[1].total == 1 + assert results[1].docs[0]["id"] == "rvl:2" + + +@pytest.mark.parametrize( + "queries", + [ + [ + [ + FilterQuery(filter_expression="@test:{foo}"), + FilterQuery(filter_expression="@test:{bar}"), + ], + [ + FilterQuery(filter_expression="@test:{foo}"), + FilterQuery(filter_expression="@test:{bar}"), + FilterQuery(filter_expression="@test:{baz}"), + FilterQuery(filter_expression="@test:{foo}"), + FilterQuery(filter_expression="@test:{bar}"), + FilterQuery(filter_expression="@test:{baz}"), + ], + ], + [ + [ + "@test:{foo}", + "@test:{bar}", + ], + [ + "@test:{foo}", + "@test:{bar}", + "@test:{baz}", + "@test:{foo}", + "@test:{bar}", + "@test:{baz}", + ], + ], + ], +) +@pytest.mark.asyncio +async def test_batch_search_with_multiple_batches(async_index, queries): + await async_index.create(overwrite=True, drop=True) + data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}] + await async_index.load(data, id_field="id") + + results = await async_index.batch_search(queries[0]) + assert len(results) == 2 + assert results[0].total == 1 + assert results[0].docs[0]["id"] == "rvl:1" + assert results[1].total == 1 + assert results[1].docs[0]["id"] == "rvl:2" + + results = await async_index.batch_search( + queries[1], + batch_size=2, + ) + assert len(results) == 6 + + # First (and only) result for the first query + assert results[0].total == 1 + assert results[0].docs[0]["id"] == "rvl:1" + + # Second (and only) result for the second query + assert results[1].total == 1 + assert results[1].docs[0]["id"] == "rvl:2" + + # Third query should have zero results because there is no baz + assert results[2].total == 0 + + # Then the pattern repeats + assert results[3].total == 1 + assert results[3].docs[0]["id"] == "rvl:1" + assert results[4].total == 1 + assert results[4].docs[0]["id"] == "rvl:2" + assert results[5].total == 0 + + +@pytest.mark.asyncio +async def test_batch_query(async_index): + await async_index.create(overwrite=True, drop=True) + data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}] + await async_index.load(data, id_field="id") + + query = FilterQuery(filter_expression="@test:{foo}") + results = await async_index.batch_query([query]) + + assert len(results) == 1 + assert results[0][0]["id"] == "rvl:1" + + +@pytest.mark.asyncio +async def test_batch_query_with_multiple_batches(async_index): + await async_index.create(overwrite=True, drop=True) + data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}] + await async_index.load(data, id_field="id") + + queries = [ + FilterQuery(filter_expression="@test:{foo}"), + FilterQuery(filter_expression="@test:{bar}"), + ] + results = await async_index.batch_query(queries, batch_size=1) + assert len(results) == 2 + assert results[0][0]["id"] == "rvl:1" + assert results[1][0]["id"] == "rvl:2" diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py index 368c048a..800f6a06 100644 --- a/tests/integration/test_search_index.py +++ b/tests/integration/test_search_index.py @@ -7,6 +7,7 @@ from redisvl.exceptions import RedisModuleVersionError, RedisSearchError from redisvl.index import SearchIndex from redisvl.query import 
VectorQuery +from redisvl.query.query import FilterQuery from redisvl.redis.utils import convert_bytes from redisvl.schema import IndexSchema, StorageType @@ -389,3 +390,109 @@ def test_search_index_validates_redis_modules(redis_url): index.create(overwrite=True, drop=True) mock_validate_sync_redis.assert_called_once() + + +def test_batch_search(index): + index.create(overwrite=True, drop=True) + data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}] + index.load(data, id_field="id") + + results = index.batch_search(["@test:{foo}", "@test:{bar}"]) + assert len(results) == 2 + assert results[0].total == 1 + assert results[0].docs[0]["id"] == "rvl:1" + assert results[1].total == 1 + assert results[1].docs[0]["id"] == "rvl:2" + + +@pytest.mark.parametrize( + "queries", + [ + [ + [ + FilterQuery(filter_expression="@test:{foo}"), + FilterQuery(filter_expression="@test:{bar}"), + ], + [ + FilterQuery(filter_expression="@test:{foo}"), + FilterQuery(filter_expression="@test:{bar}"), + FilterQuery(filter_expression="@test:{baz}"), + FilterQuery(filter_expression="@test:{foo}"), + FilterQuery(filter_expression="@test:{bar}"), + FilterQuery(filter_expression="@test:{baz}"), + ], + ], + [ + [ + "@test:{foo}", + "@test:{bar}", + ], + [ + "@test:{foo}", + "@test:{bar}", + "@test:{baz}", + "@test:{foo}", + "@test:{bar}", + "@test:{baz}", + ], + ], + ], +) +def test_batch_search_with_multiple_batches(index, queries): + index.create(overwrite=True, drop=True) + data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}] + index.load(data, id_field="id") + + results = index.batch_search(queries[0]) + assert len(results) == 2 + assert results[0].total == 1 + assert results[0].docs[0]["id"] == "rvl:1" + assert results[1].total == 1 + assert results[1].docs[0]["id"] == "rvl:2" + + results = index.batch_search( + queries[1], + batch_size=2, + ) + assert len(results) == 6 + + # First (and only) result for the first query + assert results[0].docs[0]["id"] == "rvl:1" + + # Second (and only) result for the second query + assert results[1].docs[0]["id"] == "rvl:2" + + # Third query should have zero results because there is no baz + assert results[2].total == 0 + + # Then the pattern repeats + assert results[3].docs[0]["id"] == "rvl:1" + assert results[4].docs[0]["id"] == "rvl:2" + assert results[5].total == 0 + + +def test_batch_query(index): + index.create(overwrite=True, drop=True) + data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}] + index.load(data, id_field="id") + + query = FilterQuery(filter_expression="@test:{foo}") + results = index.batch_query([query]) + + assert len(results) == 1 + assert results[0][0]["id"] == "rvl:1" + + +def test_batch_query_with_multiple_batches(index): + index.create(overwrite=True, drop=True) + data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}] + index.load(data, id_field="id") + + queries = [ + FilterQuery(filter_expression="@test:{foo}"), + FilterQuery(filter_expression="@test:{bar}"), + ] + results = index.batch_query(queries, batch_size=1) + assert len(results) == 2 + assert results[0][0]["id"] == "rvl:1" + assert results[1][0]["id"] == "rvl:2"
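
For reference, a minimal usage sketch of the `batch_search` and `batch_query` APIs introduced in this diff. The schema, key prefix, Redis URL, and sample data below are assumptions chosen to mirror the integration tests above; they are illustrative only and not part of the change itself.

# Illustrative sketch only: the index name/prefix "rvl", a single tag field
# "test", and a local Redis Stack at redis://localhost:6379 are assumptions
# mirroring the integration tests in this diff.
from redisvl.index import SearchIndex
from redisvl.query.query import FilterQuery

schema = {
    "index": {"name": "rvl", "prefix": "rvl"},
    "fields": [{"name": "test", "type": "tag"}],
}
index = SearchIndex.from_dict(schema, redis_url="redis://localhost:6379")
index.create(overwrite=True, drop=True)
index.load([{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}], id_field="id")

# batch_search: raw query strings (or (query, params) tuples) in; one redis-py
# Result per query out, returned in the same order the queries were submitted.
results = index.batch_search(["@test:{foo}", "@test:{bar}"], batch_size=2)
assert results[0].total == 1 and results[0].docs[0]["id"] == "rvl:1"

# batch_query: BaseQuery objects in; the processed documents (a list of dicts)
# for each query out, again in submission order.
parsed = index.batch_query([FilterQuery(filter_expression="@test:{foo}")])
assert parsed[0][0]["id"] == "rvl:1"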