diff --git a/README.md b/README.md
index 1d138377..22765c27 100644
--- a/README.md
+++ b/README.md
@@ -121,19 +121,18 @@ Choose from multiple Redis deployment options:
    })
    ```

-2. [Create a SearchIndex](https://docs.redisvl.com/en/stable/user_guide/01_getting_started.html#create-a-searchindex) class with an input schema and client connection in order to perform admin and search operations on your index in Redis:
+2. [Create a SearchIndex](https://docs.redisvl.com/en/stable/user_guide/01_getting_started.html#create-a-searchindex) class with an input schema to perform admin and search operations on your index in Redis:

    ```python
    from redis import Redis
    from redisvl.index import SearchIndex

-    # Establish Redis connection and define index
-    client = Redis.from_url("redis://localhost:6379")
-    index = SearchIndex(schema, client)
+    # Define the index
+    index = SearchIndex(schema, redis_url="redis://localhost:6379")

    # Create the index in Redis
    index.create()
    ```

-   > Async compliant search index class also available: [AsyncSearchIndex](https://docs.redisvl.com/en/stable/api/searchindex.html#redisvl.index.AsyncSearchIndex).
+   > An async-compatible index class is also available: [AsyncSearchIndex](https://docs.redisvl.com/en/stable/api/searchindex.html#redisvl.index.AsyncSearchIndex).

3. [Load](https://docs.redisvl.com/en/stable/user_guide/01_getting_started.html#load-data-to-searchindex) and [fetch](https://docs.redisvl.com/en/stable/user_guide/01_getting_started.html#fetch-an-object-from-redis) data to/from your Redis instance:

@@ -346,7 +345,7 @@ Commands:
  stats       Obtain statistics about an index
 ```

-> Read more about [using the CLI](https://docs.redisvl.com/en/stable/user_guide/cli.html).
+> Read more about [using the CLI](https://docs.redisvl.com/en/latest/overview/cli.html).

 ## 🚀 Why RedisVL?

diff --git a/docs/user_guide/01_getting_started.ipynb b/docs/user_guide/01_getting_started.ipynb
index 9b7d0f66..6130f589 100644
--- a/docs/user_guide/01_getting_started.ipynb
+++ b/docs/user_guide/01_getting_started.ipynb
@@ -81,7 +81,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -126,7 +126,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -178,7 +178,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -195,7 +195,7 @@
    "Now we also need to facilitate a Redis connection. 
There are a few ways to do this:\n", "\n", "- Create & manage your own client connection (recommended)\n", - "- Provide a simple Redis URL and let RedisVL connect on your behalf" + "- Provide a Redis URL and let RedisVL connect on your behalf (by default, it will connect to \"redis://localhost:6379\")" ] }, { @@ -209,7 +209,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -227,9 +227,13 @@ "from redis import Redis\n", "\n", "client = Redis.from_url(\"redis://localhost:6379\")\n", + "index = SearchIndex.from_dict(schema, redis_client=client)\n", "\n", - "index.set_client(client)\n", - "# optionally provide an async Redis client object to enable async index operations" + "# alternatively, provide an async Redis client object to enable async index operations\n", + "# from redis.asyncio import Redis\n", + "# from redisvl.index import AsyncSearchIndex\n", + "# client = Redis.from_url(\"redis://localhost:6379\")\n", + "# index = AsyncSearchIndex.from_dict(schema, redis_client=client)\n" ] }, { @@ -243,7 +247,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -258,8 +262,10 @@ } ], "source": [ - "index.connect(\"redis://localhost:6379\")\n", - "# optionally use an async client by passing use_async=True" + "index = SearchIndex.from_dict(schema, redis_url=\"redis://localhost:6379\")\n", + "\n", + "# If you don't specify a client or Redis URL, the index will attempt to\n", + "# connect to Redis at the default address (\"redis://localhost:6379\")." ] }, { @@ -273,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -297,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -315,7 +321,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -358,7 +364,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -392,7 +398,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -429,7 +435,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -454,13 +460,20 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "*=>[KNN 3 @user_embedding $vector AS vector_distance] RETURN 6 user age job credit_score vector_distance vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 3\n" + ] + }, { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>age</th><th>job</th><th>credit_score</th></tr>
<tr><td>0</td><td>john</td><td>1</td><td>engineer</td><td>high</td></tr>
<tr><td>0</td><td>mary</td><td>2</td><td>doctor</td><td>low</td></tr>
<tr><td>0.0566299557686</td><td>tyler</td><td>9</td><td>engineer</td><td>high</td></tr></table>
" + "table>vector_distanceuseragejobcredit_score0john1engineerhigh0mary2doctorlow0.0566299557686tyler9engineerhigh" ], "text/plain": [ "" @@ -487,7 +500,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -537,13 +550,12 @@ "\n", "client = Redis.from_url(\"redis://localhost:6379\")\n", "\n", - "index = AsyncSearchIndex.from_dict(schema)\n", - "await index.set_client(client)" + "index = AsyncSearchIndex.from_dict(schema, redis_client=client)" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -584,7 +596,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -609,14 +621,14 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "11:53:25 redisvl.index.index INFO Index already exists, overwriting.\n" + "11:28:32 redisvl.index.index INFO Index already exists, overwriting.\n" ] } ], @@ -627,13 +639,13 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>age</th><th>job</th><th>credit_score</th></tr>
<tr><td>0</td><td>john</td><td>1</td><td>engineer</td><td>high</td></tr>
<tr><td>0</td><td>mary</td><td>2</td><td>doctor</td><td>low</td></tr>
<tr><td>0.0566299557686</td><td>tyler</td><td>9</td><td>engineer</td><td>high</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>age</th><th>job</th><th>credit_score</th></tr>
<tr><td>0</td><td>mary</td><td>2</td><td>doctor</td><td>low</td></tr>
<tr><td>0</td><td>john</td><td>1</td><td>engineer</td><td>high</td></tr>
<tr><td>0.0566299557686</td><td>tyler</td><td>9</td><td>engineer</td><td>high</td></tr></table>
" ], "text/plain": [ "" @@ -659,7 +671,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -677,19 +689,19 @@ "│ num_records │ 22 │\n", "│ percent_indexed │ 1 │\n", "│ hash_indexing_failures │ 0 │\n", - "│ number_of_uses │ 5 │\n", - "│ bytes_per_record_avg │ 50.9091 │\n", + "│ number_of_uses │ 2 │\n", + "│ bytes_per_record_avg │ 47.8 │\n", "│ doc_table_size_mb │ 0.000423431 │\n", - "│ inverted_sz_mb │ 0.00106812 │\n", + "│ inverted_sz_mb │ 0.000911713 │\n", "│ key_table_size_mb │ 0.000165939 │\n", - "│ offset_bits_per_record_avg │ 8 │\n", - "│ offset_vectors_sz_mb │ 5.72205e-06 │\n", - "│ offsets_per_term_avg │ 0.272727 │\n", - "│ records_per_doc_avg │ 5.5 │\n", + "│ offset_bits_per_record_avg │ nan │\n", + "│ offset_vectors_sz_mb │ 0 │\n", + "│ offsets_per_term_avg │ 0 │\n", + "│ records_per_doc_avg │ 5 │\n", "│ sortable_values_size_mb │ 0 │\n", - "│ total_indexing_time │ 0.197 │\n", - "│ total_inverted_index_blocks │ 12 │\n", - "│ vector_index_sz_mb │ 0.0201416 │\n", + "│ total_indexing_time │ 0.239 │\n", + "│ total_inverted_index_blocks │ 11 │\n", + "│ vector_index_sz_mb │ 0.235603 │\n", "╰─────────────────────────────┴─────────────╯\n" ] } @@ -718,7 +730,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -727,7 +739,7 @@ "4" ] }, - "execution_count": 19, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -739,7 +751,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -748,7 +760,7 @@ "True" ] }, - "execution_count": 20, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -760,7 +772,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -771,7 +783,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "env", "language": "python", "name": "python3" }, diff --git a/docs/user_guide/02_hybrid_queries.ipynb b/docs/user_guide/02_hybrid_queries.ipynb index 4a89ffb3..9568669d 100644 --- a/docs/user_guide/02_hybrid_queries.ipynb +++ b/docs/user_guide/02_hybrid_queries.ipynb @@ -78,15 +78,20 @@ "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13:02:18 redisvl.index.index INFO Index already exists, overwriting.\n" + ] + } + ], "source": [ "from redisvl.index import SearchIndex\n", "\n", "# construct a search index from the schema\n", - "index = SearchIndex.from_dict(schema)\n", - "\n", - "# connect to local redis instance\n", - "index.connect(\"redis://localhost:6379\")\n", + "index = SearchIndex.from_dict(schema, redis_url=\"redis://localhost:6379\")\n", "\n", "# create the index (no data yet)\n", "index.create(overwrite=True)" @@ -101,8 +106,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m14:16:51\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m14:16:51\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_queries\n" + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. float64_cache\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 2. 
float64_session\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 3. float16_cache\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 4. float16_session\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 5. float32_session\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 6. float32_cache\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 7. bfloat_cache\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 8. user_queries\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 9. student tutor\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 10. tutor\n", + "\u001b[32m13:02:25\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 11. bfloat_session\n" ] } ], @@ -142,13 +157,13 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" ], "text/plain": [ "" @@ -183,7 +198,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -209,7 +224,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -235,7 +250,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -272,7 +287,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -307,7 +322,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -334,7 +349,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" ], "text/plain": [ "" @@ -360,7 +375,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" ], "text/plain": [ "" @@ -395,7 +410,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" ], "text/plain": [ "" @@ -423,7 +438,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -449,7 +464,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" ], "text/plain": [ "" @@ -475,7 +490,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -501,7 +516,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" ], "text/plain": [ "" @@ -527,7 +542,7 @@ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -554,21 +569,37 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[{'id': 'user_queries_docs:409ff48274724984ba14865db0495fc5',\n", - " 'score': 0.9090908893868948,\n", + "[{'id': 'user_queries_docs:01JMJJHE28ZW4F33ZNRKXRHYCS',\n", + " 'score': 1.8181817787737895,\n", + " 'vector_distance': '0',\n", + " 'user': 'john',\n", + " 'credit_score': 'high',\n", + " 'age': '18',\n", + " 'job': 'engineer',\n", + " 'office_location': '-122.4194,37.7749'},\n", + " {'id': 'user_queries_docs:01JMJJHE2899024DYPXT6424N9',\n", + " 'score': 0.0,\n", + " 'vector_distance': '0',\n", + " 'user': 'derrick',\n", + " 'credit_score': 'low',\n", + " 'age': '14',\n", + " 'job': 'doctor',\n", + " 'office_location': '-122.4194,37.7749'},\n", + " {'id': 'user_queries_docs:01JMJJPEYCQ89ZQW6QR27J72WT',\n", + " 'score': 1.8181817787737895,\n", " 'vector_distance': '0',\n", " 'user': 'john',\n", " 'credit_score': 'high',\n", " 'age': '18',\n", " 'job': 'engineer',\n", " 'office_location': '-122.4194,37.7749'},\n", - " {'id': 'user_queries_docs:69cb262c303a4147b213dfdec8bd4b01',\n", + " {'id': 'user_queries_docs:01JMJJPEYD544WB1TKDBJ3Z3J9',\n", " 'score': 0.0,\n", " 'vector_distance': '0',\n", " 'user': 'derrick',\n", @@ -576,15 +607,31 @@ " 'age': '14',\n", " 'job': 'doctor',\n", " 'office_location': '-122.4194,37.7749'},\n", - " {'id': 'user_queries_docs:562263669ff74a0295c515018d151d7b',\n", - " 'score': 0.9090908893868948,\n", + " {'id': 'user_queries_docs:01JMJJHE28B5R6T00DH37A7KSJ',\n", + " 'score': 1.8181817787737895,\n", " 'vector_distance': '0.109129190445',\n", " 'user': 'tyler',\n", " 'credit_score': 'high',\n", " 'age': '100',\n", " 'job': 'engineer',\n", " 'office_location': '-122.0839,37.3861'},\n", - " {'id': 'user_queries_docs:94176145f9de4e288ca2460cd5d1188e',\n", + " {'id': 'user_queries_docs:01JMJJPEYDPF9S5328WHCQN0ND',\n", + " 'score': 1.8181817787737895,\n", + " 'vector_distance': '0.109129190445',\n", + " 'user': 'tyler',\n", + " 'credit_score': 'high',\n", + " 'age': '100',\n", + " 'job': 'engineer',\n", + " 'office_location': '-122.0839,37.3861'},\n", + " {'id': 'user_queries_docs:01JMJJHE28G5F943YGWMB1ZX1V',\n", + " 'score': 0.0,\n", + " 'vector_distance': '0.158808946609',\n", + " 'user': 'tim',\n", + " 'credit_score': 'high',\n", + " 'age': '12',\n", + " 'job': 'dermatologist',\n", + " 'office_location': '-122.0839,37.3861'},\n", + " {'id': 'user_queries_docs:01JMJJPEYDKA9ARKHRK1D7KPXQ',\n", " 'score': 0.0,\n", " 'vector_distance': '0.158808946609',\n", " 'user': 'tim',\n", @@ -592,7 +639,7 @@ " 'age': '12',\n", " 'job': 'dermatologist',\n", " 'office_location': '-122.0839,37.3861'},\n", - " {'id': 'user_queries_docs:d0bcf6842862410583901004b6b3aeba',\n", + " {'id': 'user_queries_docs:01JMJJHE28NR7KF0EZEA433T2J',\n", " 'score': 0.0,\n", " 'vector_distance': '0.217882037163',\n", " 'user': 'taimur',\n", @@ -600,25 +647,17 @@ " 'age': '15',\n", " 'job': 'CEO',\n", " 'office_location': '-122.0839,37.3861'},\n", - " {'id': 'user_queries_docs:3dec0e9f2db04e19bff224c5a2a0ba3c',\n", + " {'id': 'user_queries_docs:01JMJJPEYD9EAVGJ2AZ8K9VX7Q',\n", " 'score': 0.0,\n", - " 'vector_distance': '0.266666650772',\n", - " 'user': 'nancy',\n", - " 'credit_score': 'high',\n", - " 'age': '94',\n", - " 'job': 'doctor',\n", - " 'office_location': '-122.4194,37.7749'},\n", - " {'id': 'user_queries_docs:93ee6c0e4ccb42f6b7af7858ea6a6408',\n", - " 'score': 0.0,\n", - " 'vector_distance': 
'0.653301358223',\n", - " 'user': 'joe',\n", - " 'credit_score': 'medium',\n", - " 'age': '35',\n", - " 'job': 'dentist',\n", + " 'vector_distance': '0.217882037163',\n", + " 'user': 'taimur',\n", + " 'credit_score': 'low',\n", + " 'age': '15',\n", + " 'job': 'CEO',\n", " 'office_location': '-122.0839,37.3861'}]" ] }, - "execution_count": 32, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -641,13 +680,13 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>score</th><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.4545454446934474</td><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" + "
<table><tr><th>score</th><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.4545454446934474</td><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" ], "text/plain": [ "" @@ -669,13 +708,13 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>score</th><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.4545454446934474</td><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.4545454446934474</td><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.4545454446934474</td><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.4545454446934474</td><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>score</th><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.4545454446934474</td><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.4545454446934474</td><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.4545454446934474</td><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.4545454446934474</td><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.4545454446934474</td><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.4545454446934474</td><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.4545454446934474</td><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -695,13 +734,13 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>score</th><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.0</td><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.0</td><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.0</td><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.0</td><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>score</th><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.0</td><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.0</td><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.0</td><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.0</td><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.0</td><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.0</td><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.0</td><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.0</td><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -732,13 +771,13 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" ], "text/plain": [ "" @@ -775,13 +814,13 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" ], "text/plain": [ "" @@ -819,7 +858,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -834,13 +873,13 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -859,13 +898,13 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr></table>
" ], "text/plain": [ "" @@ -884,13 +923,13 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -909,13 +948,13 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.266666650772</td><td>nancy</td><td>high</td><td>94</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.653301358223</td><td>joe</td><td>medium</td><td>35</td><td>dentist</td><td>-122.0839,37.3861</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th><th>office_location</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td><td>-122.4194,37.7749</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.217882037163</td><td>taimur</td><td>low</td><td>15</td><td>CEO</td><td>-122.0839,37.3861</td></tr></table>
" ], "text/plain": [ "" @@ -943,13 +982,13 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>user</th><th>credit_score</th><th>age</th><th>job</th></tr>
<tr><td>derrick</td><td>low</td><td>14</td><td>doctor</td></tr>
<tr><td>taimur</td><td>low</td><td>15</td><td>CEO</td></tr></table>
" + "
<table><tr><th>user</th><th>credit_score</th><th>age</th><th>job</th></tr>
<tr><td>derrick</td><td>low</td><td>14</td><td>doctor</td></tr>
<tr><td>taimur</td><td>low</td><td>15</td><td>CEO</td></tr>
<tr><td>derrick</td><td>low</td><td>14</td><td>doctor</td></tr>
<tr><td>taimur</td><td>low</td><td>15</td><td>CEO</td></tr></table>
" ], "text/plain": [ "" @@ -985,14 +1024,14 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "2 records match the filter expression @credit_score:{low} for the given index.\n" + "4 records match the filter expression @credit_score:{low} for the given index.\n" ] } ], @@ -1019,13 +1058,13 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td></tr>
<tr><td>0.158808946609</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td></tr></table>
" ], "text/plain": [ "" @@ -1060,13 +1099,13 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td></tr></table>
" ], "text/plain": [ "" @@ -1091,13 +1130,13 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td></tr></table>
" ], "text/plain": [ "" @@ -1131,13 +1170,13 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
<table><tr><th>vector_distance</th><th>age</th><th>user</th><th>credit_score</th><th>job</th><th>office_location</th></tr>
<tr><td>0.109129190445</td><td>100</td><td>tyler</td><td>high</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0</td><td>18</td><td>john</td><td>high</td><td>engineer</td><td>-122.4194,37.7749</td></tr></table>
" + "
<table><tr><th>vector_distance</th><th>age</th><th>user</th><th>credit_score</th><th>job</th><th>office_location</th></tr>
<tr><td>0.109129190445</td><td>100</td><td>tyler</td><td>high</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0.109129190445</td><td>100</td><td>tyler</td><td>high</td><td>engineer</td><td>-122.0839,37.3861</td></tr>
<tr><td>0</td><td>18</td><td>john</td><td>high</td><td>engineer</td><td>-122.4194,37.7749</td></tr>
<tr><td>0</td><td>18</td><td>john</td><td>high</td><td>engineer</td><td>-122.4194,37.7749</td></tr></table>
" ], "text/plain": [ "" @@ -1172,7 +1211,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -1181,7 +1220,7 @@ "'@job:(\"engineer\")=>[KNN 5 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY age DESC DIALECT 3 LIMIT 0 5'" ] }, - "execution_count": 49, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } @@ -1193,7 +1232,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 38, "metadata": {}, "outputs": [ { @@ -1202,7 +1241,7 @@ "'@credit_score:{high}'" ] }, - "execution_count": 50, + "execution_count": 38, "metadata": {}, "output_type": "execute_result" } @@ -1215,7 +1254,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 39, "metadata": {}, "outputs": [ { @@ -1224,7 +1263,7 @@ "'((@credit_score:{high} @age:[18 +inf]) @age:[-inf 100])'" ] }, - "execution_count": 51, + "execution_count": 39, "metadata": {}, "output_type": "execute_result" } @@ -1249,17 +1288,21 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'id': 'user_queries_docs:409ff48274724984ba14865db0495fc5', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:3dec0e9f2db04e19bff224c5a2a0ba3c', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:562263669ff74a0295c515018d151d7b', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", - "{'id': 'user_queries_docs:94176145f9de4e288ca2460cd5d1188e', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" + "{'id': 'user_queries_docs:01JMJJHE28G5F943YGWMB1ZX1V', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJHE28ZW4F33ZNRKXRHYCS', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJHE28B5R6T00DH37A7KSJ', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJHE28EX13NEE7BGBM8FH3', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJPEYCQ89ZQW6QR27J72WT', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJPEYDAN0M3V7EQEVPS6HX', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 
'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJPEYDPF9S5328WHCQN0ND', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'user_queries_docs:01JMJJPEYDKA9ARKHRK1D7KPXQ', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" ] } ], @@ -1271,7 +1314,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -1282,7 +1325,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 ('redisvl2')", + "display_name": "env", "language": "python", "name": "python3" }, @@ -1296,14 +1339,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.10" + "version": "3.11.11" }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "9b1e6e9c2967143209c2f955cb869d1d3234f92dc4787f49f155f3abbdfb1316" - } - } + "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/user_guide/04_vectorizers.ipynb b/docs/user_guide/04_vectorizers.ipynb index 3647a2fc..c4f862e1 100644 --- a/docs/user_guide/04_vectorizers.ipynb +++ b/docs/user_guide/04_vectorizers.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -68,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -80,9 +80,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Vector dimensions: 1536\n" + ] + }, + { + "data": { + "text/plain": [ + "[-0.0011391325388103724,\n", + " -0.003206387162208557,\n", + " 0.002380132209509611,\n", + " -0.004501554183661938,\n", + " -0.010328996926546097,\n", + " 0.012922565452754498,\n", + " -0.005491119809448719,\n", + " -0.0029864837415516376,\n", + " -0.007327961269766092,\n", + " -0.03365817293524742]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from redisvl.utils.vectorize import OpenAITextVectorizer\n", "\n", @@ -99,9 +126,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[-0.017466850578784943,\n", + " 1.8471690054866485e-05,\n", + " 0.00129731057677418,\n", + " -0.02555876597762108,\n", + " -0.019842341542243958,\n", + " 0.01603139191865921,\n", + " -0.0037347301840782166,\n", + " 0.0009670283179730177,\n", + " 0.006618348415941,\n", + " -0.02497442066669464]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Create many embeddings at once\n", "sentences = [\n", @@ -116,9 +163,17 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of Embeddings: 3\n" + ] + } + ], "source": [ "# openai also supports asyncronous requests, which we can use to speed up the vectorization process.\n", "embeddings = await oai.aembed_many(sentences)\n", @@ -138,7 +193,7 @@ }, { "cell_type": "code", - 
"execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -151,9 +206,23 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "AzureOpenAI API endpoint is required. Provide it in api_config or set the AZURE_OPENAI_ENDPOINT environment variable.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mredisvl\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mvectorize\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m AzureOpenAITextVectorizer\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# create a vectorizer\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m az_oai \u001b[38;5;241m=\u001b[39m \u001b[43mAzureOpenAITextVectorizer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdeployment_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Must be your CUSTOM deployment name\u001b[39;49;00m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mapi_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mapi_key\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mapi_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mapi_version\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mapi_version\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mazure_endpoint\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mazure_endpoint\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m test \u001b[38;5;241m=\u001b[39m az_oai\u001b[38;5;241m.\u001b[39membed(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThis is a test sentence.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mVector dimensions: \u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mlen\u001b[39m(test))\n", + "File \u001b[0;32m~/src/redis-vl-python/redisvl/utils/vectorize/text/azureopenai.py:78\u001b[0m, in \u001b[0;36mAzureOpenAITextVectorizer.__init__\u001b[0;34m(self, model, api_config, dtype)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 56\u001b[0m model: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext-embedding-ada-002\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 57\u001b[0m api_config: Optional[Dict] 
\u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 58\u001b[0m dtype: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 59\u001b[0m ):\n\u001b[1;32m 60\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Initialize the AzureOpenAI vectorizer.\u001b[39;00m\n\u001b[1;32m 61\u001b[0m \n\u001b[1;32m 62\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;124;03m ValueError: If an invalid dtype is provided.\u001b[39;00m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 78\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_initialize_clients\u001b[49m\u001b[43m(\u001b[49m\u001b[43mapi_config\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(model\u001b[38;5;241m=\u001b[39mmodel, dims\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_set_model_dims(model), dtype\u001b[38;5;241m=\u001b[39mdtype)\n", + "File \u001b[0;32m~/src/redis-vl-python/redisvl/utils/vectorize/text/azureopenai.py:106\u001b[0m, in \u001b[0;36mAzureOpenAITextVectorizer._initialize_clients\u001b[0;34m(self, api_config)\u001b[0m\n\u001b[1;32m 99\u001b[0m azure_endpoint \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 100\u001b[0m api_config\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mazure_endpoint\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 101\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m api_config\n\u001b[1;32m 102\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m os\u001b[38;5;241m.\u001b[39mgetenv(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAZURE_OPENAI_ENDPOINT\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 103\u001b[0m )\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m azure_endpoint:\n\u001b[0;32m--> 106\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 107\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAzureOpenAI API endpoint is required. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 108\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mProvide it in api_config or set the AZURE_OPENAI_ENDPOINT\u001b[39m\u001b[38;5;130;01m\\\u001b[39;00m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;124m environment variable.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 110\u001b[0m )\n\u001b[1;32m 112\u001b[0m api_version \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 113\u001b[0m api_config\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mapi_version\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m api_config\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m os\u001b[38;5;241m.\u001b[39mgetenv(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOPENAI_API_VERSION\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 116\u001b[0m )\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m api_version:\n", + "\u001b[0;31mValueError\u001b[0m: AzureOpenAI API endpoint is required. Provide it in api_config or set the AZURE_OPENAI_ENDPOINT environment variable." 
+ ] + } + ], "source": [ "from redisvl.utils.vectorize import AzureOpenAITextVectorizer\n", "\n", @@ -589,10 +658,7 @@ "from redisvl.index import SearchIndex\n", "\n", "# construct a search index from the schema\n", - "index = SearchIndex.from_yaml(\"./schema.yaml\")\n", - "\n", - "# connect to local redis instance\n", - "index.connect(\"redis://localhost:6379\")\n", + "index = SearchIndex.from_yaml(\"./schema.yaml\", redis_url=\"redis://localhost:6379\")\n", "\n", "# create the index (no data yet)\n", "index.create(overwrite=True)" @@ -695,7 +761,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "env", "language": "python", "name": "python3" }, @@ -709,7 +775,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.11.11" }, "orig_nbformat": 4 }, diff --git a/docs/user_guide/05_hash_vs_json.ipynb b/docs/user_guide/05_hash_vs_json.ipynb index 217ab63f..7918949d 100644 --- a/docs/user_guide/05_hash_vs_json.ipynb +++ b/docs/user_guide/05_hash_vs_json.ipynb @@ -139,10 +139,7 @@ "outputs": [], "source": [ "# construct a search index from the hash schema\n", - "hindex = SearchIndex.from_dict(hash_schema)\n", - "\n", - "# connect to local redis instance\n", - "hindex.connect(\"redis://localhost:6379\")\n", + "hindex = SearchIndex.from_dict(hash_schema, redis_url=\"redis://localhost:6379\")\n", "\n", "# create the index (no data yet)\n", "hindex.create(overwrite=True)" @@ -286,10 +283,12 @@ "\n", "t = (Tag(\"credit_score\") == \"high\") & (Text(\"job\") % \"enginee*\") & (Num(\"age\") > 17)\n", "\n", - "v = VectorQuery([0.1, 0.1, 0.5],\n", - " \"user_embedding\",\n", - " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", - " filter_expression=t)\n", + "v = VectorQuery(\n", + " vector=[0.1, 0.1, 0.5],\n", + " vector_field_name=\"user_embedding\",\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"office_location\"],\n", + " filter_expression=t\n", + ")\n", "\n", "\n", "results = hindex.query(v)\n", @@ -359,10 +358,7 @@ "outputs": [], "source": [ "# construct a search index from the json schema\n", - "jindex = SearchIndex.from_dict(json_schema)\n", - "\n", - "# connect to local redis instance\n", - "jindex.connect(\"redis://localhost:6379\")\n", + "jindex = SearchIndex.from_dict(json_schema, redis_url=\"redis://localhost:6379\")\n", "\n", "# create the index (no data yet)\n", "jindex.create(overwrite=True)" @@ -401,8 +397,6 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "\n", "json_data = data.copy()\n", "\n", "for d in json_data:\n", @@ -599,10 +593,7 @@ "outputs": [], "source": [ "# construct a search index from the json schema\n", - "bike_index = SearchIndex.from_dict(bike_schema)\n", - "\n", - "# connect to local redis instance\n", - "bike_index.connect(\"redis://localhost:6379\")\n", + "bike_index = SearchIndex.from_dict(bike_schema, redis_url=\"redis://localhost:6379\")\n", "\n", "# create the index (no data yet)\n", "bike_index.create(overwrite=True)" @@ -639,14 +630,15 @@ "\n", "vec = emb_model.embed(\"I'd like a bike for aggressive riding\")\n", "\n", - "v = VectorQuery(vector=vec,\n", - " vector_field_name=\"bike_embedding\",\n", - " return_fields=[\n", - " \"brand\",\n", - " \"name\",\n", - " \"$.metadata.type\"\n", - " ]\n", - " )\n", + "v = VectorQuery(\n", + " vector=vec,\n", + " vector_field_name=\"bike_embedding\",\n", + " return_fields=[\n", + " \"brand\",\n", + " \"name\",\n", + " 
\"$.metadata.type\"\n", + " ]\n", + ")\n", "\n", "\n", "results = bike_index.query(v)" diff --git a/redisvl/cli/index.py b/redisvl/cli/index.py index c5b350b3..bbb0d3c7 100644 --- a/redisvl/cli/index.py +++ b/redisvl/cli/index.py @@ -59,9 +59,7 @@ def create(self, args: Namespace): """ if not args.schema: logger.error("Schema must be provided to create an index") - index = SearchIndex.from_yaml(args.schema) - redis_url = create_redis_url(args) - index.connect(redis_url) + index = SearchIndex.from_yaml(args.schema, redis_url=create_redis_url(args)) index.create() logger.info("Index created successfully") @@ -120,8 +118,7 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex: schema = IndexSchema.from_dict({"index": {"name": args.index}}) index = SearchIndex(schema=schema, redis_url=redis_url) elif args.schema: - index = SearchIndex.from_yaml(args.schema) - index.set_client(conn) + index = SearchIndex.from_yaml(args.schema, redis_client=conn) else: logger.error("Index name or schema must be provided") exit(0) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 86cdb02e..41e6e214 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -1,4 +1,5 @@ import asyncio +import weakref from typing import Any, Dict, List, Optional from redis import Redis @@ -22,14 +23,19 @@ from redisvl.index import AsyncSearchIndex, SearchIndex from redisvl.query import RangeQuery from redisvl.query.filter import FilterExpression +from redisvl.redis.connection import RedisConnectionFactory +from redisvl.utils.log import get_logger from redisvl.utils.utils import ( current_timestamp, deprecated_argument, serialize, + sync_wrapper, validate_vector_dims, ) from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +logger = get_logger("[RedisVL]") + class SemanticCache(BaseLLMCache): """Semantic Cache for Large Language Models.""" @@ -128,13 +134,18 @@ def __init__( name, prefix, vectorizer.dims, vectorizer.dtype # type: ignore ) schema = self._modify_schema(schema, filterable_fields) - self._index = SearchIndex(schema=schema) - # Handle redis connection if redis_client: - self._index.set_client(redis_client) - elif redis_url: - self._index.connect(redis_url=redis_url, **connection_kwargs) + self._owns_redis_client = False + else: + self._owns_redis_client = True + + self._index = SearchIndex( + schema=schema, + redis_client=redis_client, + redis_url=redis_url, + **connection_kwargs, + ) # Check for existing cache index if not overwrite and self._index.exists(): @@ -174,17 +185,18 @@ def _modify_schema( async def _get_async_index(self) -> AsyncSearchIndex: """Lazily construct the async search index class.""" - if not self._aindex: - # Construct async index if necessary - self._aindex = AsyncSearchIndex(schema=self._index.schema) - # Connect Redis async client - redis_client = self.redis_kwargs["redis_client"] - redis_url = self.redis_kwargs["redis_url"] - connection_kwargs = self.redis_kwargs["connection_kwargs"] - if redis_client is not None: - await self._aindex.set_client(redis_client) - elif redis_url: - await self._aindex.connect(redis_url, **connection_kwargs) # type: ignore + # Construct async index if necessary + async_client = None + if self._aindex is None: + client = self.redis_kwargs.get("redis_client") + if isinstance(client, Redis): + async_client = RedisConnectionFactory.sync_to_async_redis(client) + self._aindex = AsyncSearchIndex( + schema=self._index.schema, + redis_client=async_client, + 
redis_url=self.redis_kwargs["redis_url"], + **self.redis_kwargs["connection_kwargs"], + ) return self._aindex @property @@ -284,13 +296,13 @@ async def adrop( def _refresh_ttl(self, key: str) -> None: """Refresh the time-to-live for the specified key.""" if self._ttl: - self._index.client.expire(key, self._ttl) # type: ignore + self._index.expire_keys(key, self._ttl) async def _async_refresh_ttl(self, key: str) -> None: """Async refresh the time-to-live for the specified key.""" aindex = await self._get_async_index() if self._ttl: - await aindex.client.expire(key, self._ttl) # type: ignore + await aindex.expire_keys(key, self._ttl) def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: """Converts a text prompt to its vector representation using the @@ -311,7 +323,9 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: def _check_vector_dims(self, vector: List[float]): """Checks the size of the provided vector and raises an error if it doesn't match the search index vector dimensions.""" - schema_vector_dims = self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims # type: ignore + schema_vector_dims = self._index.schema.fields[ + CACHE_VECTOR_FIELD_NAME + ].attrs.dims # type: ignore validate_vector_dims(len(vector), schema_vector_dims) def check( @@ -386,7 +400,8 @@ def check( # Search the cache! cache_search_results = self._index.query(query) redis_keys, cache_hits = self._process_cache_results( - cache_search_results, return_fields # type: ignore + cache_search_results, + return_fields, # type: ignore ) # Extend TTL on keys for key in redis_keys: @@ -467,7 +482,8 @@ async def acheck( # Search the cache! cache_search_results = await aindex.query(query) redis_keys, cache_hits = self._process_cache_results( - cache_search_results, return_fields # type: ignore + cache_search_results, + return_fields, # type: ignore ) # Extend TTL on keys await asyncio.gather(*[self._async_refresh_ttl(key) for key in redis_keys]) @@ -640,7 +656,6 @@ def update(self, key: str, **kwargs) -> None: """ if kwargs: for k, v in kwargs.items(): - # Make sure the item is in the index schema if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]): raise ValueError(f"{k} is not a valid field within the cache entry") @@ -683,7 +698,6 @@ async def aupdate(self, key: str, **kwargs) -> None: if kwargs: for k, v in kwargs.items(): - # Make sure the item is in the index schema if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]): raise ValueError(f"{k} is not a valid field within the cache entry") @@ -702,3 +716,26 @@ async def aupdate(self, key: str, **kwargs) -> None: await aindex.load(data=[kwargs], keys=[key]) await self._async_refresh_ttl(key) + + def disconnect(self): + if self._owns_redis_client is False: + return + if self._index: + self._index.disconnect() + if self._aindex: + self._aindex.disconnect_sync() + + async def adisconnect(self): + if not self._owns_redis_client: + return + if self._index: + self._index.disconnect() + if self._aindex: + await self._aindex.disconnect() + self._aindex = None + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.adisconnect() diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index be519dd8..4a7e72c3 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -110,12 +110,12 @@ def _initialize_index( schema = SemanticRouterIndexSchema.from_params( self.name, 
self.vectorizer.dims, self.vectorizer.dtype # type: ignore ) - self._index = SearchIndex(schema=schema) - - if redis_client: - self._index.set_client(redis_client) - elif redis_url: - self._index.connect(redis_url=redis_url, **connection_kwargs) + self._index = SearchIndex( + schema=schema, + redis_client=redis_client, + redis_url=redis_url, + **connection_kwargs, + ) # Check for existing router index existed = self._index.exists() diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 6924ce5d..6825afa9 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -97,13 +97,12 @@ def __init__( name, prefix, vectorizer.dims, vectorizer.dtype # type: ignore ) - self._index = SearchIndex(schema=schema) - - # handle redis connection - if redis_client: - self._index.set_client(redis_client) - elif redis_url: - self._index.connect(redis_url=redis_url, **connection_kwargs) + self._index = SearchIndex( + schema=schema, + redis_client=redis_client, + redis_url=redis_url, + **connection_kwargs, + ) # Check for existing session index if not overwrite and self._index.exists(): diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index 42b86628..4e46010c 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -61,12 +61,12 @@ def __init__( schema = StandardSessionIndexSchema.from_params(name, prefix) - self._index = SearchIndex(schema=schema) - - if redis_client: - self._index.set_client(redis_client) - else: - self._index.connect(redis_url=redis_url, **connection_kwargs) + self._index = SearchIndex( + schema=schema, + redis_client=redis_client, + redis_url=redis_url, + **connection_kwargs, + ) self._index.create(overwrite=False) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 8914b4c9..59d20f28 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -1,8 +1,8 @@ import asyncio -import atexit import json import threading -from functools import wraps +import warnings +import weakref from typing import ( TYPE_CHECKING, Any, @@ -16,6 +16,8 @@ Union, ) +from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper + if TYPE_CHECKING: from redis.commands.search.aggregation import AggregateResult from redis.commands.search.document import Document @@ -33,7 +35,6 @@ from redisvl.redis.connection import ( RedisConnectionFactory, convert_index_info_to_schema, - validate_modules, ) from redisvl.redis.utils import convert_bytes from redisvl.schema import IndexSchema, StorageType @@ -42,6 +43,12 @@ logger = get_logger(__name__) +REQUIRED_MODULES_FOR_INTROSPECTION = [ + {"name": "search", "ver": 20810}, + {"name": "searchlight", "ver": 20810}, +] + + def process_results( results: "Result", query: BaseQuery, storage_type: StorageType ) -> List[Dict[str, Any]]: @@ -94,36 +101,6 @@ def _process(doc: "Document") -> Dict[str, Any]: return [_process(doc) for doc in results.docs] -def setup_redis(): - def decorator(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - result = func(self, *args, **kwargs) - RedisConnectionFactory.validate_sync_redis( - self._redis_client, self._lib_name - ) - return result - - return wrapper - - return decorator - - -def setup_async_redis(): - def decorator(func): - @wraps(func) - async def wrapper(self, *args, **kwargs): - result = 
await func(self, *args, **kwargs) - await RedisConnectionFactory.validate_async_redis( - self._redis_client, self._lib_name - ) - return result - - return wrapper - - return decorator - - class BaseSearchIndex: """Base search engine class""" @@ -218,8 +195,7 @@ def from_dict(cls, schema_dict: Dict[str, Any], **kwargs): def disconnect(self): """Disconnect from the Redis database.""" - self._redis_client = None - return self + raise NotImplementedError("This method should be implemented by subclasses.") def key(self, id: str) -> str: """Construct a redis key as a combination of an index key prefix (optional) @@ -255,8 +231,7 @@ class SearchIndex(BaseSearchIndex): from redisvl.index import SearchIndex # initialize the index object with schema from file - index = SearchIndex.from_yaml("schemas/schema.yaml") - index.connect(redis_url="redis://localhost:6379") + index = SearchIndex.from_yaml("schemas/schema.yaml", redis_url="redis://localhost:6379") # create the index index.create(overwrite=True) @@ -269,12 +244,13 @@ class SearchIndex(BaseSearchIndex): """ + @deprecated_argument("connection_args", "Use connection_kwargs instead.") def __init__( self, schema: IndexSchema, redis_client: Optional[redis.Redis] = None, redis_url: Optional[str] = None, - connection_args: Dict[str, Any] = {}, + connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ): """Initialize the RedisVL search index with a schema, Redis client @@ -287,10 +263,12 @@ def __init__( instantiated redis client. redis_url (Optional[str]): The URL of the Redis server to connect to. - connection_args (Dict[str, Any], optional): Redis client connection + connection_kwargs (Dict[str, Any], optional): Redis client connection args. """ - # final validation on schema object + if "connection_args" in kwargs: + connection_kwargs = kwargs.pop("connection_args") + if not isinstance(schema, IndexSchema): raise ValueError("Must provide a valid IndexSchema object") @@ -298,13 +276,24 @@ def __init__( self._lib_name: Optional[str] = kwargs.pop("lib_name", None) - # set up redis connection - self._redis_client: Optional[redis.Redis] = None + # Store connection parameters + self.__redis_client = redis_client + self._redis_url = redis_url + self._connection_kwargs = connection_kwargs or {} + self._lock = threading.Lock() - if redis_client is not None: - self.set_client(redis_client) - elif redis_url is not None: - self.connect(redis_url, **connection_args) + self._owns_redis_client = redis_client is None + if self._owns_redis_client: + weakref.finalize(self, self.disconnect) + + def disconnect(self): + """Disconnect from the Redis database.""" + if self._owns_redis_client is False: + logger.info("Index does not own client, not disconnecting") + return + if self.__redis_client: + self.__redis_client.close() + self.__redis_client = None @classmethod def from_existing( @@ -323,31 +312,30 @@ def from_existing( instantiated redis client. redis_url (Optional[str]): The URL of the Redis server to connect to. - """ - # Handle redis instance - if redis_url: - redis_client = RedisConnectionFactory.connect( - redis_url=redis_url, use_async=False, **kwargs - ) - if not redis_client: - raise ValueError( - "Must provide either a redis_url or redis_client to fetch Redis index info." - ) - - # Validate modules - installed_modules = RedisConnectionFactory.get_modules(redis_client) + Raises: + ValueError: If redis_url or redis_client is not provided. + RedisModuleVersionError: If required Redis modules are not installed. 
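+
+        For example (illustrative; assumes an index named "my-index" was
+        created previously and Redis is reachable at the given URL):
+
+        .. code-block:: python
+
+            from redisvl.index import SearchIndex
+
+            index = SearchIndex.from_existing(
+                "my-index", redis_url="redis://localhost:6379"
+            )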
+ """ try: - required_modules = [ - {"name": "search", "ver": 20810}, - {"name": "searchlight", "ver": 20810}, - ] - validate_modules(installed_modules, required_modules) + if redis_url: + redis_client = RedisConnectionFactory.get_redis_connection( + redis_url=redis_url, + required_modules=REQUIRED_MODULES_FOR_INTROSPECTION, + **kwargs, + ) + elif redis_client: + RedisConnectionFactory.validate_sync_redis( + redis_client, required_modules=REQUIRED_MODULES_FOR_INTROSPECTION + ) except RedisModuleVersionError as e: raise RedisModuleVersionError( f"Loading from existing index failed. {str(e)}" ) + if not redis_client: + raise ValueError("Must provide either a redis_url or redis_client") + # Fetch index info and convert to schema index_info = cls._info(name, redis_client) schema_dict = convert_index_info_to_schema(index_info) @@ -357,8 +345,25 @@ def from_existing( @property def client(self) -> Optional[redis.Redis]: """The underlying redis-py client object.""" - return self._redis_client + return self.__redis_client + + @property + def _redis_client(self) -> Optional[redis.Redis]: + """ + Get a Redis client instance. + Lazily creates a Redis client instance if it doesn't exist. + """ + if self.__redis_client is None: + with self._lock: + if self.__redis_client is None: + self.__redis_client = RedisConnectionFactory.get_redis_connection( + url=self._redis_url, + **self._connection_kwargs, + ) + return self.__redis_client + + @deprecated_function("connect", "Pass connection parameters in __init__.") def connect(self, redis_url: Optional[str] = None, **kwargs): """Connect to a Redis instance using the provided `redis_url`, falling back to the `REDIS_URL` environment variable (if available). @@ -368,26 +373,25 @@ def connect(self, redis_url: Optional[str] = None, **kwargs): Args: redis_url (Optional[str], optional): The URL of the Redis server to - connect to. If not provided, the method defaults to using the - `REDIS_URL` environment variable. + connect to. Raises: redis.exceptions.ConnectionError: If the connection to the Redis server fails. ValueError: If the Redis URL is not provided nor accessible through the `REDIS_URL` environment variable. + ModuleNotFoundError: If required Redis modules are not installed. .. code-block:: python index.connect(redis_url="redis://localhost:6379") """ - client = RedisConnectionFactory.connect( - redis_url=redis_url, use_async=False, **kwargs + self.__redis_client = RedisConnectionFactory.get_redis_connection( + redis_url=redis_url, **kwargs ) - return self.set_client(client) - @setup_redis() + @deprecated_function("set_client", "Pass connection parameters in __init__.") def set_client(self, redis_client: redis.Redis, **kwargs): """Manually set the Redis client to use with the search index. @@ -412,10 +416,8 @@ def set_client(self, redis_client: redis.Redis, **kwargs): index.set_client(client) """ - if not isinstance(redis_client, redis.Redis): - raise TypeError("Invalid Redis client instance") - - self._redis_client = redis_client + RedisConnectionFactory.validate_sync_redis(redis_client) + self.__redis_client = redis_client return self def create(self, overwrite: bool = False, drop: bool = False) -> None: @@ -519,6 +521,23 @@ def drop_keys(self, keys: Union[str, List[str]]) -> int: else: return self._redis_client.delete(keys) # type: ignore + def expire_keys( + self, keys: Union[str, List[str]], ttl: int + ) -> Union[int, List[int]]: + """Set the expiration time for a specific entry or entries in Redis. 
+ + Args: + keys (Union[str, List[str]]): The entry ID or IDs to set the expiration for. + ttl (int): The time-to-live in seconds. + """ + if isinstance(keys, list): + pipe = self._redis_client.pipeline() # type: ignore + for key in keys: + pipe.expire(key, ttl) + return pipe.execute() + else: + return self._redis_client.expire(keys, ttl) # type: ignore + def load( self, data: Iterable[Any], @@ -765,6 +784,12 @@ def info(self, name: Optional[str] = None) -> Dict[str, Any]: index_name = name or self.schema.index.name return self._info(index_name, self._redis_client) # type: ignore + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disconnect() + class AsyncSearchIndex(BaseSearchIndex): """A search index class for interacting with Redis as a vector database in @@ -779,8 +804,10 @@ class AsyncSearchIndex(BaseSearchIndex): from redisvl.index import AsyncSearchIndex # initialize the index object with schema from file - index = AsyncSearchIndex.from_yaml("schemas/schema.yaml") - await index.connect(redis_url="redis://localhost:6379") + index = AsyncSearchIndex.from_yaml( + "schemas/schema.yaml", + redis_url="redis://localhost:6379" + ) # create the index await index.create(overwrite=True) @@ -793,18 +820,30 @@ class AsyncSearchIndex(BaseSearchIndex): """ + @deprecated_argument("redis_kwargs", "Use connection_kwargs instead.") def __init__( self, schema: IndexSchema, + *, + redis_url: Optional[str] = None, + redis_client: Optional[aredis.Redis] = None, + connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ): """Initialize the RedisVL async search index with a schema. Args: schema (IndexSchema): Index schema object. - connection_args (Dict[str, Any], optional): Redis client connection + redis_url (Optional[str], optional): The URL of the Redis server to + connect to. + redis_client (Optional[aredis.Redis]): An + instantiated redis client. + connection_kwargs (Optional[Dict[str, Any]]): Redis client connection args. 
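+
+        For example (illustrative; assumes ``schema`` is a valid ``IndexSchema``
+        object and the connection is established lazily on first use):
+
+        .. code-block:: python
+
+            from redisvl.index import AsyncSearchIndex
+
+            # let the index create and own its client lazily
+            index = AsyncSearchIndex(schema, redis_url="redis://localhost:6379")
+
+            # or reuse a client that the caller owns and is responsible for closing
+            # index = AsyncSearchIndex(schema, redis_client=client)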
""" + if "redis_kwargs" in kwargs: + connection_kwargs = kwargs.pop("redis_kwargs") + # final validation on schema object if not isinstance(schema, IndexSchema): raise ValueError("Must provide a valid IndexSchema object") @@ -813,38 +852,15 @@ def __init__( self._lib_name: Optional[str] = kwargs.pop("lib_name", None) - # set up empty redis connection - self._redis_client: Optional[aredis.Redis] = None - - if "redis_client" in kwargs or "redis_url" in kwargs: - logger.warning( - "Must use set_client() or connect() methods to provide a Redis connection to AsyncSearchIndex" - ) - - atexit.register(self._cleanup_connection) - - def _cleanup_connection(self): - if self._redis_client: - - def run_in_thread(): - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self._redis_client.aclose()) - loop.close() - except RuntimeError: - pass - - # Run cleanup in a background thread to avoid event loop issues - thread = threading.Thread(target=run_in_thread) - thread.start() - thread.join() - - self._redis_client = None + # Store connection parameters + self._redis_client = redis_client + self._redis_url = redis_url + self._connection_kwargs = connection_kwargs or {} + self._lock = asyncio.Lock() - def disconnect(self): - """Disconnect and cleanup the underlying async redis connection.""" - self._cleanup_connection() + self._owns_redis_client = redis_client is None + if self._owns_redis_client: + weakref.finalize(self, sync_wrapper(self.disconnect)) @classmethod async def from_existing( @@ -864,106 +880,99 @@ async def from_existing( redis_url (Optional[str]): The URL of the Redis server to connect to. """ - if redis_url: - redis_client = RedisConnectionFactory.connect( - redis_url=redis_url, use_async=True, **kwargs - ) - - if not redis_client: + if not redis_url and not redis_client: raise ValueError( "Must provide either a redis_url or redis_client to fetch Redis index info." ) - # Validate modules - installed_modules = await RedisConnectionFactory.get_modules_async(redis_client) - try: - required_modules = [ - {"name": "search", "ver": 20810}, - {"name": "searchlight", "ver": 20810}, - ] - validate_modules(installed_modules, required_modules) + if redis_url: + redis_client = await RedisConnectionFactory._get_aredis_connection( + url=redis_url, + required_modules=REQUIRED_MODULES_FOR_INTROSPECTION, + **kwargs, + ) + elif redis_client: + await RedisConnectionFactory.validate_async_redis( + redis_client, required_modules=REQUIRED_MODULES_FOR_INTROSPECTION + ) except RedisModuleVersionError as e: raise RedisModuleVersionError( f"Loading from existing index failed. {str(e)}" ) from e + if redis_client is None: + raise ValueError( + "Failed to obtain a valid Redis client. " + "Please provide a valid redis_client or redis_url." + ) + # Fetch index info and convert to schema index_info = await cls._info(name, redis_client) schema_dict = convert_index_info_to_schema(index_info) schema = IndexSchema.from_dict(schema_dict) - index = cls(schema, **kwargs) - await index.set_client(redis_client) - return index + return cls(schema, redis_client=redis_client, **kwargs) @property def client(self) -> Optional[aredis.Redis]: """The underlying redis-py client object.""" return self._redis_client + @deprecated_function("connect", "Pass connection parameters in __init__.") async def connect(self, redis_url: Optional[str] = None, **kwargs): - """Connect to a Redis instance using the provided `redis_url`, falling - back to the `REDIS_URL` environment variable (if available). 
-
-        Note: Additional keyword arguments (`**kwargs`) can be used to provide
-        extra options specific to the Redis connection.
-
-        Args:
-            redis_url (Optional[str], optional): The URL of the Redis server to
-                connect to. If not provided, the method defaults to using the
-                `REDIS_URL` environment variable.
-
-        Raises:
-            redis.exceptions.ConnectionError: If the connection to the Redis
-                server fails.
-            ValueError: If the Redis URL is not provided nor accessible
-                through the `REDIS_URL` environment variable.
-
-        .. code-block:: python
-
-            index.connect(redis_url="redis://localhost:6379")
-
-        """
-        client = RedisConnectionFactory.connect(
-            redis_url=redis_url, use_async=True, **kwargs
+        """[DEPRECATED] Connect to a Redis instance. Use connection parameters in __init__."""
+        warnings.warn(
+            "connect() is deprecated; pass connection parameters in __init__",
+            DeprecationWarning,
        )
-        return await self.set_client(client)
-
-    @setup_async_redis()
-    async def set_client(self, redis_client: aredis.Redis):
-        """Manually set the Redis client to use with the search index.
-
-        This method configures the search index to use a specific
-        Async Redis client. It is useful for cases where an external,
-        custom-configured client is preferred instead of creating a new one.
-
-        Args:
-            redis_client (aredis.Redis): An Async Redis
-                client instance to be used for the connection.
-
-        Raises:
-            TypeError: If the provided client is not valid.
-
-        .. code-block:: python
+        client = await RedisConnectionFactory._get_aredis_connection(
+            url=redis_url, **kwargs
+        )
+        await self.set_client(client)
 
-            import redis.asyncio as aredis
-            from redisvl.index import AsyncSearchIndex
+    @deprecated_function("set_client", "Pass connection parameters in __init__.")
+    async def set_client(self, redis_client: Union[aredis.Redis, redis.Redis]):
+        """
+        [DEPRECATED] Manually set the Redis client to use with the search index.
+        This method is deprecated; please provide connection parameters in __init__.
+        """
+        redis_client = await self._validate_client(redis_client)
+        await self.disconnect()
+        async with self._lock:
+            self._redis_client = redis_client
+        return self
 
-        # async Redis client and index
-        client = aredis.Redis.from_url("redis://localhost:6379")
-        index = AsyncSearchIndex.from_yaml("schemas/schema.yaml")
-        await index.set_client(client)
+    async def _get_client(self) -> aredis.Redis:
+        """Lazily instantiate and return the async Redis client."""
+        if self._redis_client is None:
+            async with self._lock:
+                # Double-check to protect against concurrent access
+                if self._redis_client is None:
+                    kwargs = self._connection_kwargs
+                    if self._redis_url:
+                        kwargs["url"] = self._redis_url
+                    self._redis_client = (
+                        await RedisConnectionFactory._get_aredis_connection(**kwargs)
+                    )
+                    await RedisConnectionFactory.validate_async_redis(
+                        self._redis_client, self._lib_name
+                    )
+        return self._redis_client
 
-        """
+    async def _validate_client(
+        self, redis_client: Union[aredis.Redis, redis.Redis]
+    ) -> aredis.Redis:
         if isinstance(redis_client, redis.Redis):
-            print("Setting client and converting from async", flush=True)
-            self._redis_client = RedisConnectionFactory.sync_to_async_redis(
-                redis_client
+            warnings.warn(
+                "Converting sync Redis client to async client is deprecated "
+                "and will be removed in the next major version. 
Please use an " + "async Redis client instead.", + DeprecationWarning, ) - else: - self._redis_client = redis_client - - return self + redis_client = RedisConnectionFactory.sync_to_async_redis(redis_client) + elif not isinstance(redis_client, aredis.Redis): + raise ValueError("Invalid client type: must be redis.asyncio.Redis") + return redis_client async def create(self, overwrite: bool = False, drop: bool = False) -> None: """Asynchronously create an index in Redis with the current schema @@ -990,6 +999,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None: # overwrite an index in Redis; drop associated data (clean slate) await index.create(overwrite=True, drop=True) """ + client = await self._get_client() redis_fields = self.schema.redis_fields if not redis_fields: @@ -1005,7 +1015,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None: await self.delete(drop) try: - await self._redis_client.ft(self.schema.index.name).create_index( # type: ignore + await client.ft(self.schema.index.name).create_index( fields=redis_fields, definition=IndexDefinition( prefix=[self.schema.index.prefix], index_type=self._storage.type @@ -1025,10 +1035,9 @@ async def delete(self, drop: bool = True): Raises: redis.exceptions.ResponseError: If the index does not exist. """ + client = await self._get_client() try: - await self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore - delete_documents=drop - ) + await client.ft(self.schema.index.name).dropindex(delete_documents=drop) except Exception as e: raise RedisSearchError(f"Error while deleting index: {str(e)}") from e @@ -1039,16 +1048,15 @@ async def clear(self) -> int: Returns: int: Count of records deleted from Redis. """ - # Track deleted records + client = await self._get_client() total_records_deleted: int = 0 - # Paginate using queries and delete in batches async for batch in self.paginate( FilterQuery(FilterExpression("*"), return_fields=["id"]), page_size=500 ): batch_keys = [record["id"] for record in batch] - records_deleted = await self._redis_client.delete(*batch_keys) # type: ignore - total_records_deleted += records_deleted # type: ignore + records_deleted = await client.delete(*batch_keys) + total_records_deleted += records_deleted return total_records_deleted @@ -1061,10 +1069,29 @@ async def drop_keys(self, keys: Union[str, List[str]]) -> int: Returns: int: Count of records deleted from Redis. """ - if isinstance(keys, List): - return await self._redis_client.delete(*keys) # type: ignore + client = await self._get_client() + if isinstance(keys, list): + return await client.delete(*keys) + else: + return await client.delete(keys) + + async def expire_keys( + self, keys: Union[str, List[str]], ttl: int + ) -> Union[int, List[int]]: + """Set the expiration time for a specific entry or entries in Redis. + + Args: + keys (Union[str, List[str]]): The entry ID or IDs to set the expiration for. + ttl (int): The time-to-live in seconds. 
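+
+        Returns:
+            Union[int, List[int]]: The EXPIRE command response for a single
+                key, or a list of responses when a list of keys is given.
+
+        For example (illustrative; assumes the keys already exist in Redis):
+
+        .. code-block:: python
+
+            # give two entries a 60 second time-to-live
+            await index.expire_keys(["rvl:doc:1", "rvl:doc:2"], 60)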
+ """ + client = await self._get_client() + if isinstance(keys, list): + pipe = client.pipeline() + for key in keys: + pipe.expire(key, ttl) + return await pipe.execute() else: - return await self._redis_client.delete(keys) # type: ignore + return await client.expire(keys, ttl) async def load( self, @@ -1124,9 +1151,10 @@ async def add_field(d): keys = await index.load(data, preprocess=add_field) """ + client = await self._get_client() try: return await self._storage.awrite( - self._redis_client, # type: ignore + client, objects=data, id_field=id_field, keys=keys, @@ -1150,7 +1178,8 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]: Returns: Dict[str, Any]: The fetched object. """ - obj = await self._storage.aget(self._redis_client, [self.key(id)]) # type: ignore + client = await self._get_client() + obj = await self._storage.aget(client, [self.key(id)]) if obj: return convert_bytes(obj[0]) return None @@ -1165,10 +1194,9 @@ async def aggregate(self, *args, **kwargs) -> "AggregateResult": Returns: Result: Raw Redis aggregation results. """ + client = await self._get_client() try: - return await self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore - *args, **kwargs - ) + return client.ft(self.schema.index.name).aggregate(*args, **kwargs) except Exception as e: raise RedisSearchError(f"Error while aggregating: {str(e)}") from e @@ -1182,10 +1210,9 @@ async def search(self, *args, **kwargs) -> "Result": Returns: Result: Raw Redis search results. """ + client = await self._get_client() try: - return await self._redis_client.ft(self.schema.index.name).search( # type: ignore - *args, **kwargs - ) + return await client.ft(self.schema.index.name).search(*args, **kwargs) # type: ignore except Exception as e: raise RedisSearchError(f"Error while searching: {str(e)}") from e @@ -1256,7 +1283,7 @@ async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerato """ if not isinstance(page_size, int): - raise TypeError("page_size must be an integer") + raise TypeError("page_size must be of type int") if page_size <= 0: raise ValueError("page_size must be greater than 0") @@ -1268,7 +1295,6 @@ async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerato if not results: break yield results - # increment the pagination tracker first += page_size async def listall(self) -> List[str]: @@ -1277,9 +1303,8 @@ async def listall(self) -> List[str]: Returns: List[str]: The list of indices in the database. """ - return convert_bytes( - await self._redis_client.execute_command("FT._LIST") # type: ignore - ) + client = await self._get_client() + return convert_bytes(await client.execute_command("FT._LIST")) async def exists(self) -> bool: """Check if the index exists in Redis. @@ -1289,15 +1314,6 @@ async def exists(self) -> bool: """ return self.schema.index.name in await self.listall() - @staticmethod - async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]: - try: - return convert_bytes(await redis_client.ft(name).info()) # type: ignore - except Exception as e: - raise RedisSearchError( - f"Error while fetching {name} index info: {str(e)}" - ) from e - async def info(self, name: Optional[str] = None) -> Dict[str, Any]: """Get information about the index. @@ -1308,5 +1324,33 @@ async def info(self, name: Optional[str] = None) -> Dict[str, Any]: Returns: dict: A dictionary containing the information about the index. 
""" + client = await self._get_client() index_name = name or self.schema.index.name - return await self._info(index_name, self._redis_client) # type: ignore + return await type(self)._info(index_name, client) + + @staticmethod + async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]: + try: + return convert_bytes(await redis_client.ft(name).info()) # type: ignore + except Exception as e: + raise RedisSearchError( + f"Error while fetching {name} index info: {str(e)}" + ) from e + + async def disconnect(self): + if self._owns_redis_client is False: + return + if self._redis_client is not None: + await self._redis_client.aclose() # type: ignore + self._redis_client = None + + def disconnect_sync(self): + if self._redis_client is None or self._owns_redis_client is False: + return + sync_wrapper(self.disconnect)() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.disconnect() diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 6cf6cd4f..cc95c479 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -1,5 +1,6 @@ import os -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Union +from warnings import warn from redis import Redis from redis.asyncio import Connection as AsyncConnection @@ -12,6 +13,7 @@ from redisvl.exceptions import RedisModuleVersionError from redisvl.redis.constants import DEFAULT_REQUIRED_MODULES from redisvl.redis.utils import convert_bytes +from redisvl.utils.utils import deprecated_function from redisvl.version import __version__ @@ -189,9 +191,12 @@ class RedisConnectionFactory: """ @classmethod + @deprecated_function( + "connect", "Please use `get_redis_connection` or `get_async_redis_connection`." + ) def connect( cls, redis_url: Optional[str] = None, use_async: bool = False, **kwargs - ) -> None: + ) -> Union[Redis, AsyncRedis]: """Create a connection to the Redis database based on a URL and some connection kwargs. @@ -217,12 +222,18 @@ def connect( return connection_func(redis_url, **kwargs) # type: ignore @staticmethod - def get_redis_connection(url: Optional[str] = None, **kwargs) -> Redis: + def get_redis_connection( + url: Optional[str] = None, + required_modules: Optional[List[Dict[str, Any]]] = None, + **kwargs, + ) -> Redis: """Creates and returns a synchronous Redis client. Args: url (Optional[str]): The URL of the Redis server. If not provided, the environment variable REDIS_URL is used. + required_modules (Optional[List[Dict[str, Any]]]): List of required + Redis modules with version requirements. **kwargs: Additional keyword arguments to be passed to the Redis client constructor. @@ -232,14 +243,57 @@ def get_redis_connection(url: Optional[str] = None, **kwargs) -> Redis: Raises: ValueError: If url is not provided and REDIS_URL environment variable is not set. + RedisModuleVersionError: If required Redis modules are not installed. + """ + url = url or get_address_from_env() + client = Redis.from_url(url, **kwargs) + + RedisConnectionFactory.validate_sync_redis( + client, required_modules=required_modules + ) + + return client + + @staticmethod + async def _get_aredis_connection( + url: Optional[str] = None, + required_modules: Optional[List[Dict[str, Any]]] = None, + **kwargs, + ) -> AsyncRedis: + """Creates and returns an asynchronous Redis client. + + NOTE: This method is the future form of `get_async_redis_connection` but is + only used internally by the library now. 
+ + Args: + url (Optional[str]): The URL of the Redis server. If not provided, + the environment variable REDIS_URL is used. + required_modules (Optional[List[Dict[str, Any]]]): List of required + Redis modules with version requirements. + **kwargs: Additional keyword arguments to be passed to the async + Redis client constructor. + + Returns: + AsyncRedis: An asynchronous Redis client instance. + + Raises: + ValueError: If url is not provided and REDIS_URL environment + variable is not set. + RedisModuleVersionError: If required Redis modules are not installed. """ - if url: - return Redis.from_url(url, **kwargs) - # fallback to env var REDIS_URL - return Redis.from_url(get_address_from_env(), **kwargs) + url = url or get_address_from_env() + client = AsyncRedis.from_url(url, **kwargs) + + await RedisConnectionFactory.validate_async_redis( + client, required_modules=required_modules + ) + return client @staticmethod - def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedis: + def get_async_redis_connection( + url: Optional[str] = None, + **kwargs, + ) -> AsyncRedis: """Creates and returns an asynchronous Redis client. Args: @@ -255,19 +309,22 @@ def get_async_redis_connection(url: Optional[str] = None, **kwargs) -> AsyncRedi ValueError: If url is not provided and REDIS_URL environment variable is not set. """ - if url: - return AsyncRedis.from_url(url, **kwargs) - # fallback to env var REDIS_URL - return AsyncRedis.from_url(get_address_from_env(), **kwargs) + warn( + "get_async_redis_connection will become async in the next major release.", + DeprecationWarning, + ) + url = url or get_address_from_env() + return AsyncRedis.from_url(url, **kwargs) @staticmethod def sync_to_async_redis(redis_client: Redis) -> AsyncRedis: + """Convert a synchronous Redis client to an asynchronous one.""" # pick the right connection class connection_class: Type[AbstractConnection] = ( AsyncSSLConnection if redis_client.connection_pool.connection_class == SSLConnection else AsyncConnection - ) + ) # type: ignore # make async client return AsyncRedis.from_pool( # type: ignore AsyncConnectionPool( @@ -291,6 +348,9 @@ def validate_sync_redis( required_modules: Optional[List[Dict[str, Any]]] = None, ) -> None: """Validates the sync Redis client.""" + if not isinstance(redis_client, Redis): + raise TypeError("Invalid Redis client instance") + # Set client library name _lib_name = make_lib_name(lib_name) try: diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index a8f07338..4c40d41a 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -1,11 +1,13 @@ +import asyncio import inspect import json +import logging import warnings from contextlib import contextmanager from enum import Enum from functools import wraps from time import time -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Coroutine, Dict, Optional from warnings import warn from pydantic import BaseModel @@ -13,7 +15,7 @@ def create_ulid() -> str: - """Generate a unique indentifier to group related Redis documents.""" + """Generate a unique identifier to group related Redis documents.""" return str(ULID()) @@ -132,3 +134,60 @@ def assert_no_warnings(): with warnings.catch_warnings(): warnings.simplefilter("error") yield + + +def deprecated_function(name: Optional[str] = None, replacement: Optional[str] = None): + """ + Decorator to mark a function as deprecated. + + When the wrapped function is called, the decorator will log a deprecation + warning. 
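+
+    For example (illustrative):
+
+    .. code-block:: python
+
+        @deprecated_function("connect", "Pass connection parameters in __init__.")
+        def connect(self, redis_url=None, **kwargs):
+            ...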
+ """ + + def decorator(func): + fn_name = name or func.__name__ + warning_message = ( + f"Function {fn_name} is deprecated and will be " + "removed in the next major release. " + ) + if replacement: + warning_message += replacement + + @wraps(func) + def wrapper(*args, **kwargs): + warn(warning_message, category=DeprecationWarning, stacklevel=3) + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def sync_wrapper(fn: Callable[[], Coroutine[Any, Any, Any]]) -> Callable[[], None]: + def wrapper(): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + try: + if loop is None or not loop.is_running(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + task = loop.create_task(fn()) + loop.run_until_complete(task) + except RuntimeError: + # This could happen if an object stored an event loop and now + # that event loop is closed. There's nothing we can do other than + # advise the user to use explicit cleanup methods. + # + # Uses logging module instead of get_logger() to avoid I/O errors + # if the wrapped function is called as a finalizer. + logging.info( + f"Could not run the async function {fn.__name__} because the event loop is closed. " + "This usually means the object was not properly cleaned up. Please use explicit " + "cleanup methods (e.g., disconnect(), close()) or use the object as an async " + "context manager.", + ) + return + + return wrapper diff --git a/tests/conftest.py b/tests/conftest.py index 9ed97bd5..8657cb44 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -54,9 +54,7 @@ async def async_client(redis_url): """ An async Redis client that uses the dynamic `redis_url`. """ - async with await RedisConnectionFactory.get_async_redis_connection( - redis_url - ) as client: + async with await RedisConnectionFactory._get_aredis_connection(redis_url) as client: yield client diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index 9dc8460d..2aa7ae0f 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -1,4 +1,8 @@ +import warnings + import pytest +from redis import Redis as SyncRedis +from redis.asyncio import Redis from redisvl.exceptions import RedisSearchError from redisvl.index import AsyncSearchIndex @@ -15,8 +19,8 @@ def index_schema(): @pytest.fixture -def async_index(index_schema): - return AsyncSearchIndex(schema=index_schema) +def async_index(index_schema, async_client): + return AsyncSearchIndex(schema=index_schema, redis_client=async_client) @pytest.fixture @@ -33,7 +37,7 @@ def test_search_index_properties(index_schema, async_index): assert async_index.schema == index_schema # custom settings assert async_index.name == index_schema.index.name == "my_index" - assert async_index.client == None + assert async_index.client # default settings assert async_index.prefix == index_schema.index.prefix == "rvl" assert async_index.key_separator == index_schema.index.key_separator == ":" @@ -45,7 +49,7 @@ def test_search_index_properties(index_schema, async_index): def test_search_index_from_yaml(async_index_from_yaml): assert async_index_from_yaml.name == "json-test" - assert async_index_from_yaml.client == None + assert async_index_from_yaml.client is None assert async_index_from_yaml.prefix == "json" assert async_index_from_yaml.key_separator == ":" assert async_index_from_yaml.storage_type == StorageType.JSON @@ -54,7 +58,7 @@ def test_search_index_from_yaml(async_index_from_yaml): def 
test_search_index_from_dict(async_index_from_dict): assert async_index_from_dict.name == "my_index" - assert async_index_from_dict.client == None + assert async_index_from_dict.client is None assert async_index_from_dict.prefix == "rvl" assert async_index_from_dict.key_separator == ":" assert async_index_from_dict.storage_type == StorageType.HASH @@ -64,7 +68,6 @@ def test_search_index_from_dict(async_index_from_dict): @pytest.mark.asyncio async def test_search_index_from_existing(async_client, async_index): - await async_index.set_client(async_client) await async_index.create(overwrite=True) try: @@ -107,9 +110,7 @@ async def test_search_index_from_existing_complex(async_client): }, ], } - async_index = await AsyncSearchIndex.from_dict(schema).set_client( - redis_client=async_client - ) + async_index = AsyncSearchIndex.from_dict(schema, redis_client=async_client) await async_index.create(overwrite=True) try: @@ -132,36 +133,44 @@ def test_search_index_no_prefix(index_schema): @pytest.mark.asyncio async def test_search_index_redis_url(redis_url, index_schema): - async_index = await AsyncSearchIndex(schema=index_schema).connect( - redis_url=redis_url - ) + async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) + # Client is None until a command is run + assert async_index.client is None + + # Lazily create the client by running a command + await async_index.create(overwrite=True, drop=True) assert async_index.client - async_index.disconnect() - assert async_index.client == None + await async_index.disconnect() + assert async_index.client is None @pytest.mark.asyncio async def test_search_index_client(async_client, index_schema): - async_index = await AsyncSearchIndex(schema=index_schema).set_client( - redis_client=async_client - ) + async_index = AsyncSearchIndex(schema=index_schema, redis_client=async_client) assert async_index.client == async_client @pytest.mark.asyncio -async def test_search_index_set_client(async_client, client, async_index): - await async_index.set_client(async_client) - assert async_index.client == async_client - await async_index.set_client(client) +async def test_search_index_set_client(client, redis_url, index_schema): + async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) + # Ignore deprecation warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + await async_index.create(overwrite=True, drop=True) + assert isinstance(async_index.client, Redis) - async_index.disconnect() - assert async_index.client == None + # Tests deprecated sync -> async conversion behavior + assert isinstance(client, SyncRedis) + await async_index.set_client(client) + assert isinstance(async_index.client, Redis) + + await async_index.disconnect() + assert async_index.client is None @pytest.mark.asyncio -async def test_search_index_create(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_create(async_index): await async_index.create(overwrite=True, drop=True) assert await async_index.exists() assert async_index.name in convert_bytes( @@ -170,8 +179,7 @@ async def test_search_index_create(async_client, async_index): @pytest.mark.asyncio -async def test_search_index_delete(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_delete(async_index): await async_index.create(overwrite=True, drop=True) await async_index.delete(drop=True) assert not await async_index.exists() @@ -181,8 +189,7 @@ async def 
test_search_index_delete(async_client, async_index): @pytest.mark.asyncio -async def test_search_index_clear(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_clear(async_index): await async_index.create(overwrite=True, drop=True) data = [{"id": "1", "test": "foo"}] await async_index.load(data, id_field="id") @@ -193,8 +200,7 @@ async def test_search_index_clear(async_client, async_index): @pytest.mark.asyncio -async def test_search_index_drop_key(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_drop_key(async_index): await async_index.create(overwrite=True, drop=True) data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}] keys = await async_index.load(data, id_field="id") @@ -206,8 +212,7 @@ async def test_search_index_drop_key(async_client, async_index): @pytest.mark.asyncio -async def test_search_index_drop_keys(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_drop_keys(async_index): await async_index.create(overwrite=True, drop=True) data = [ {"id": "1", "test": "foo"}, @@ -226,8 +231,7 @@ async def test_search_index_drop_keys(async_client, async_index): @pytest.mark.asyncio -async def test_search_index_load_and_fetch(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_load_and_fetch(async_index): await async_index.create(overwrite=True, drop=True) data = [{"id": "1", "test": "foo"}] await async_index.load(data, id_field="id") @@ -245,8 +249,7 @@ async def test_search_index_load_and_fetch(async_client, async_index): @pytest.mark.asyncio -async def test_search_index_load_preprocess(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_load_preprocess(async_index): await async_index.create(overwrite=True, drop=True) data = [{"id": "1", "test": "foo"}] @@ -270,15 +273,13 @@ async def bad_preprocess(record): @pytest.mark.asyncio -async def test_search_index_load_empty(async_client, async_index): - await async_index.set_client(async_client) +async def test_search_index_load_empty(async_index): await async_index.create(overwrite=True, drop=True) await async_index.load([]) @pytest.mark.asyncio -async def test_no_id_field(async_client, async_index): - await async_index.set_client(async_client) +async def test_no_id_field(async_index): await async_index.create(overwrite=True, drop=True) bad_data = [{"wrong_key": "1", "value": "test"}] @@ -288,8 +289,7 @@ async def test_no_id_field(async_client, async_index): @pytest.mark.asyncio -async def test_check_index_exists_before_delete(async_client, async_index): - await async_index.set_client(async_client) +async def test_check_index_exists_before_delete(async_index): await async_index.create(overwrite=True, drop=True) await async_index.delete(drop=True) with pytest.raises(RedisSearchError): @@ -297,8 +297,7 @@ async def test_check_index_exists_before_delete(async_client, async_index): @pytest.mark.asyncio -async def test_check_index_exists_before_search(async_client, async_index): - await async_index.set_client(async_client) +async def test_check_index_exists_before_search(async_index): await async_index.create(overwrite=True, drop=True) await async_index.delete(drop=True) @@ -313,10 +312,87 @@ async def test_check_index_exists_before_search(async_client, async_index): @pytest.mark.asyncio -async def test_check_index_exists_before_info(async_client, async_index): - await 
async_index.set_client(async_client) +async def test_check_index_exists_before_info(async_index): await async_index.create(overwrite=True, drop=True) await async_index.delete(drop=True) with pytest.raises(RedisSearchError): await async_index.info() + + +@pytest.mark.asyncio +async def test_search_index_that_does_not_own_client_context_manager(async_index): + async with async_index: + await async_index.create(overwrite=True, drop=True) + assert async_index._redis_client + client = async_index._redis_client + assert async_index._redis_client == client + + +@pytest.mark.asyncio +async def test_search_index_that_does_not_own_client_context_manager_with_exception( + async_index, +): + try: + async with async_index: + await async_index.create(overwrite=True, drop=True) + client = async_index._redis_client + raise ValueError("test") + except ValueError: + pass + assert async_index._redis_client == client + + +@pytest.mark.asyncio +async def test_search_index_that_does_not_own_client_disconnect(async_index): + await async_index.create(overwrite=True, drop=True) + client = async_index._redis_client + await async_index.disconnect() + assert async_index._redis_client == client + + +@pytest.mark.asyncio +async def test_search_index_that_does_not_own_client_disconnect_sync(async_index): + await async_index.create(overwrite=True, drop=True) + client = async_index._redis_client + async_index.disconnect_sync() + assert async_index._redis_client == client + + +@pytest.mark.asyncio +async def test_search_index_that_owns_client_context_manager(index_schema, redis_url): + async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) + async with async_index: + await async_index.create(overwrite=True, drop=True) + assert async_index._redis_client + assert async_index._redis_client is None + + +@pytest.mark.asyncio +async def test_search_index_that_owns_client_context_manager_with_exception( + index_schema, redis_url +): + async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) + try: + async with async_index: + await async_index.create(overwrite=True, drop=True) + raise ValueError("test") + except ValueError: + pass + assert async_index._redis_client is None + + +@pytest.mark.asyncio +async def test_search_index_that_owns_client_disconnect(index_schema, redis_url): + async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) + await async_index.create(overwrite=True, drop=True) + await async_index.disconnect() + assert async_index._redis_client is None + + +@pytest.mark.asyncio +async def test_search_index_that_owns_client_disconnect_sync(index_schema, redis_url): + async_index = AsyncSearchIndex(schema=index_schema, redis_url=redis_url) + await async_index.create(overwrite=True, drop=True) + await async_index.disconnect() + assert async_index._redis_client is None diff --git a/tests/integration/test_connection.py b/tests/integration/test_connection.py index aa1f4e2a..8e2d2ea8 100644 --- a/tests/integration/test_connection.py +++ b/tests/integration/test_connection.py @@ -10,7 +10,6 @@ RedisConnectionFactory, compare_versions, convert_index_info_to_schema, - get_address_from_env, unpack_redis_modules, validate_modules, ) @@ -19,6 +18,9 @@ EXPECTED_LIB_NAME = f"redis-py(redisvl_v{__version__})" +# Remove after we remove connect() method from RedisConnectionFactory +pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning") + def test_unpack_redis_modules(): module_list = [ @@ -129,41 +131,38 @@ def test_validate_modules_not_exist(): ) -def 
test_sync_redis_connect(redis_url): - client = RedisConnectionFactory.connect(redis_url) - assert client is not None - assert isinstance(client, Redis) - # Perform a simple operation - assert client.ping() - - -@pytest.mark.asyncio -async def test_async_redis_connect(redis_url): - client = RedisConnectionFactory.connect(redis_url, use_async=True) - assert client is not None - assert isinstance(client, AsyncRedis) - # Perform a simple operation - assert await client.ping() - - -def test_missing_env_var(): - redis_url = os.getenv("REDIS_URL") - if redis_url: - del os.environ["REDIS_URL"] +class TestConnect: + def test_sync_redis_connect(self, redis_url): + client = RedisConnectionFactory.connect(redis_url) + assert client is not None + assert isinstance(client, Redis) + # Perform a simple operation + assert client.ping() + + @pytest.mark.asyncio + async def test_async_redis_connect(self, redis_url): + client = RedisConnectionFactory.connect(redis_url, use_async=True) + assert client is not None + assert isinstance(client, AsyncRedis) + # Perform a simple operation + assert await client.ping() + + def test_missing_env_var(self): + redis_url = os.getenv("REDIS_URL") + if redis_url: + del os.environ["REDIS_URL"] + with pytest.raises(ValueError): + RedisConnectionFactory.connect() + os.environ["REDIS_URL"] = redis_url + + def test_invalid_url_format(self): with pytest.raises(ValueError): - RedisConnectionFactory.connect() - os.environ["REDIS_URL"] = redis_url - - -def test_invalid_url_format(): - with pytest.raises(ValueError): - RedisConnectionFactory.connect(redis_url="invalid_url_format") - + RedisConnectionFactory.connect(redis_url="invalid_url_format") -def test_unknown_redis(): - bad_client = RedisConnectionFactory.connect(redis_url="redis://fake:1234") - with pytest.raises(ConnectionError): - bad_client.ping() + def test_unknown_redis(self): + with pytest.raises(ConnectionError): + bad_client = RedisConnectionFactory.connect(redis_url="redis://fake:1234") + bad_client.ping() def test_validate_redis(client): diff --git a/tests/integration/test_flow.py b/tests/integration/test_flow.py index b448a636..7542528a 100644 --- a/tests/integration/test_flow.py +++ b/tests/integration/test_flow.py @@ -43,9 +43,7 @@ @pytest.mark.parametrize("schema", [hash_schema, json_schema]) def test_simple(client, schema, sample_data): - index = SearchIndex.from_dict(schema) - # assign client (only for testing) - index.set_client(client) + index = SearchIndex.from_dict(schema, redis_client=client) # create the index index.create(overwrite=True, drop=True) diff --git a/tests/integration/test_flow_async.py b/tests/integration/test_flow_async.py index fbfa7d22..a368f677 100644 --- a/tests/integration/test_flow_async.py +++ b/tests/integration/test_flow_async.py @@ -47,9 +47,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("schema", [hash_schema, json_schema]) async def test_simple(async_client, schema, sample_data): - index = AsyncSearchIndex.from_dict(schema) - # assign client (only for testing) - await index.set_client(async_client) + index = AsyncSearchIndex.from_dict(schema, redis_client=async_client) # create the index await index.create(overwrite=True, drop=True) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 5b32b918..b489caf9 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -108,15 +108,21 @@ def test_get_index(cache): @pytest.mark.asyncio async def test_get_async_index(cache): - aindex = await cache._get_async_index() - 
assert isinstance(aindex, AsyncSearchIndex) + async with cache: + aindex = await cache._get_async_index() + assert isinstance(aindex, AsyncSearchIndex) @pytest.mark.asyncio async def test_get_async_index_from_provided_client(cache_with_redis_client): - aindex = await cache_with_redis_client._get_async_index() - assert isinstance(aindex, AsyncSearchIndex) - assert aindex == cache_with_redis_client.aindex + async with cache_with_redis_client: + aindex = await cache_with_redis_client._get_async_index() + # Shouldn't have to do this because it already was done + await aindex.create(overwrite=True, drop=True) + assert await aindex.exists() + assert isinstance(aindex, AsyncSearchIndex) + assert aindex == cache_with_redis_client.aindex + assert await cache_with_redis_client.aindex.exists() def test_delete(cache_no_cleanup): @@ -126,8 +132,9 @@ def test_delete(cache_no_cleanup): @pytest.mark.asyncio async def test_async_delete(cache_no_cleanup): - await cache_no_cleanup.adelete() - assert not cache_no_cleanup.index.exists() + async with cache_no_cleanup: + await cache_no_cleanup.adelete() + assert not cache_no_cleanup.index.exists() def test_store_and_check(cache, vectorizer): @@ -150,8 +157,9 @@ async def test_async_store_and_check(cache, vectorizer): response = "This is a test response." vector = vectorizer.embed(prompt) - await cache.astore(prompt, response, vector=vector) - check_result = await cache.acheck(vector=vector, distance_threshold=0.4) + async with cache: + await cache.astore(prompt, response, vector=vector) + check_result = await cache.acheck(vector=vector, distance_threshold=0.4) assert len(check_result) == 1 print(check_result, flush=True) @@ -202,36 +210,37 @@ async def test_async_return_fields(cache, vectorizer): response = "This is a test response." vector = vectorizer.embed(prompt) - await cache.astore(prompt, response, vector=vector) - - # check default return fields - check_result = await cache.acheck(vector=vector) - assert set(check_result[0].keys()) == { - "key", - "entry_id", - "prompt", - "response", - "vector_distance", - "inserted_at", - "updated_at", - } - - # check specific return fields - fields = [ - "key", - "entry_id", - "prompt", - "response", - "vector_distance", - ] - check_result = await cache.acheck(vector=vector, return_fields=fields) - assert set(check_result[0].keys()) == set(fields) + async with cache: + await cache.astore(prompt, response, vector=vector) - # check only some return fields - fields = ["inserted_at", "updated_at"] - check_result = await cache.acheck(vector=vector, return_fields=fields) - fields.append("key") - assert set(check_result[0].keys()) == set(fields) + # check default return fields + check_result = await cache.acheck(vector=vector) + assert set(check_result[0].keys()) == { + "key", + "entry_id", + "prompt", + "response", + "vector_distance", + "inserted_at", + "updated_at", + } + + # check specific return fields + fields = [ + "key", + "entry_id", + "prompt", + "response", + "vector_distance", + ] + check_result = await cache.acheck(vector=vector, return_fields=fields) + assert set(check_result[0].keys()) == set(fields) + + # check only some return fields + fields = ["inserted_at", "updated_at"] + check_result = await cache.acheck(vector=vector, return_fields=fields) + fields.append("key") + assert set(check_result[0].keys()) == set(fields) # Test clearing the cache @@ -253,9 +262,10 @@ async def test_async_clear(cache, vectorizer): response = "This is a test response." 
     vector = vectorizer.embed(prompt)
 
-    await cache.astore(prompt, response, vector=vector)
-    await cache.aclear()
-    check_result = await cache.acheck(vector=vector)
+    async with cache:
+        await cache.astore(prompt, response, vector=vector)
+        await cache.aclear()
+        check_result = await cache.acheck(vector=vector)
 
     assert len(check_result) == 0
 
@@ -279,10 +289,11 @@ async def test_async_ttl_expiration(cache_with_ttl, vectorizer):
     response = "This is a test response."
     vector = vectorizer.embed(prompt)
 
-    await cache_with_ttl.astore(prompt, response, vector=vector)
-    await asyncio.sleep(3)
+    async with cache_with_ttl:
+        await cache_with_ttl.astore(prompt, response, vector=vector)
+        await asyncio.sleep(3)
 
-    check_result = await cache_with_ttl.acheck(vector=vector)
+        check_result = await cache_with_ttl.acheck(vector=vector)
 
     assert len(check_result) == 0
 
@@ -305,10 +316,11 @@ async def test_async_custom_ttl(cache_with_ttl, vectorizer):
     response = "This is a test response."
     vector = vectorizer.embed(prompt)
 
-    await cache_with_ttl.astore(prompt, response, vector=vector, ttl=5)
-    await asyncio.sleep(3)
+    async with cache_with_ttl:
+        await cache_with_ttl.astore(prompt, response, vector=vector, ttl=5)
+        await asyncio.sleep(3)
+        check_result = await cache_with_ttl.acheck(vector=vector)
 
-    check_result = await cache_with_ttl.acheck(vector=vector)
     assert len(check_result) != 0
     assert cache_with_ttl.ttl == 2
 
@@ -333,11 +345,12 @@ async def test_async_ttl_refresh(cache_with_ttl, vectorizer):
     response = "This is a test response."
     vector = vectorizer.embed(prompt)
 
-    await cache_with_ttl.astore(prompt, response, vector=vector)
+    async with cache_with_ttl:
+        await cache_with_ttl.astore(prompt, response, vector=vector)
 
-    for _ in range(3):
-        await asyncio.sleep(1)
-        check_result = await cache_with_ttl.acheck(vector=vector)
+        for _ in range(3):
+            await asyncio.sleep(1)
+            check_result = await cache_with_ttl.acheck(vector=vector)
 
     assert len(check_result) == 1
 
@@ -362,11 +375,13 @@ async def test_async_drop_document(cache, vectorizer):
     response = "This is a test response."
    vector = vectorizer.embed(prompt)
 
-    await cache.astore(prompt, response, vector=vector)
-    check_result = await cache.acheck(vector=vector)
+    async with cache:
+        await cache.astore(prompt, response, vector=vector)
+        check_result = await cache.acheck(vector=vector)
+
+        await cache.adrop(ids=[check_result[0]["entry_id"]])
+        recheck_result = await cache.acheck(vector=vector)
 
-    await cache.adrop(ids=[check_result[0]["entry_id"]])
-    recheck_result = await cache.acheck(vector=vector)
     assert len(recheck_result) == 0
 
@@ -411,12 +426,14 @@ async def test_async_drop_documents(cache, vectorizer):
         vector = vectorizer.embed(prompt)
         await cache.astore(prompt, response, vector=vector)
 
-    check_result = await cache.acheck(vector=vector, num_results=3)
-    print(check_result, flush=True)
-    ids = [r["entry_id"] for r in check_result[0:2]]  # drop first 2 entries
-    await cache.adrop(ids=ids)
+    async with cache:
+        check_result = await cache.acheck(vector=vector, num_results=3)
+        print(check_result, flush=True)
+        ids = [r["entry_id"] for r in check_result[0:2]]  # drop first 2 entries
+        await cache.adrop(ids=ids)
+
+        recheck_result = await cache.acheck(vector=vector, num_results=3)
 
-    recheck_result = await cache.acheck(vector=vector, num_results=3)
     assert len(recheck_result) == 1
 
@@ -445,19 +462,22 @@ def test_updating_document(cache):
 async def test_async_updating_document(cache):
     prompt = "This is a test prompt."
     response = "This is a test response."
-    await cache.astore(prompt=prompt, response=response)
-    check_result = await cache.acheck(prompt=prompt, return_fields=["updated_at"])
-    key = check_result[0]["key"]
+    async with cache:
+        await cache.astore(prompt=prompt, response=response)
 
-    await asyncio.sleep(1)
+        check_result = await cache.acheck(prompt=prompt, return_fields=["updated_at"])
+        key = check_result[0]["key"]
 
-    metadata = {"foo": "bar"}
-    await cache.aupdate(key=key, metadata=metadata)
+        await asyncio.sleep(1)
+
+        metadata = {"foo": "bar"}
+        await cache.aupdate(key=key, metadata=metadata)
+
+        updated_result = await cache.acheck(
+            prompt=prompt, return_fields=["updated_at", "metadata"]
+        )
 
-    updated_result = await cache.acheck(
-        prompt=prompt, return_fields=["updated_at", "metadata"]
-    )
     assert updated_result[0]["metadata"] == metadata
     assert updated_result[0]["updated_at"] > check_result[0]["updated_at"]
 
@@ -486,10 +506,12 @@ async def test_async_ttl_expiration_after_update(cache_with_ttl, vectorizer):
 
     assert cache_with_ttl.ttl == 4
 
-    await cache_with_ttl.astore(prompt, response, vector=vector)
-    await asyncio.sleep(5)
+    async with cache_with_ttl:
+        await cache_with_ttl.astore(prompt, response, vector=vector)
+        await asyncio.sleep(5)
+
+        check_result = await cache_with_ttl.acheck(vector=vector)
 
-    check_result = await cache_with_ttl.acheck(vector=vector)
     assert len(check_result) == 0
 
@@ -941,3 +963,42 @@ def test_deprecated_dtype_argument(redis_url):
         redis_url=redis_url,
         overwrite=True,
     )
+
+
+@pytest.mark.asyncio
+async def test_cache_async_context_manager(redis_url):
+    async with SemanticCache(
+        name="test_cache_async_context_manager", redis_url=redis_url
+    ) as cache:
+        await cache.astore("test prompt", "test response")
+        assert cache._aindex
+    assert cache._aindex is None
+
+
+@pytest.mark.asyncio
+async def test_cache_async_context_manager_with_exception(redis_url):
+    try:
+        async with SemanticCache(
+            name="test_cache_async_context_manager_with_exception", redis_url=redis_url
+        ) as cache:
+            await cache.astore("test prompt", "test response")
+            raise ValueError("test")
+    except ValueError:
+        pass
+    assert cache._aindex is None
+
+
+@pytest.mark.asyncio
+async def test_cache_async_disconnect(redis_url):
+    cache = SemanticCache(name="test_cache_async_disconnect", redis_url=redis_url)
+    await cache.astore("test prompt", "test response")
+    await cache.adisconnect()
+    assert cache._aindex is None
+
+
+def test_cache_disconnect(redis_url):
+    cache = SemanticCache(name="test_cache_disconnect", redis_url=redis_url)
+    cache.store("test prompt", "test response")
+    cache.disconnect()
+    # We keep this index object around because it isn't lazily created
+    assert cache._index.client is None
diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py
index 2c6dc376..271d36da 100644
--- a/tests/integration/test_query.py
+++ b/tests/integration/test_query.py
@@ -92,12 +92,10 @@ def index(sample_data, redis_url):
                 },
             },
         ],
-        }
+        },
+        redis_url=redis_url,
     )
 
-    # connect to local redis instance
-    index.connect(redis_url)
-
     # create the index (no data yet)
     index.create(overwrite=True)
diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py
index 4f4392d3..9d649cc8 100644
--- a/tests/integration/test_search_index.py
+++ b/tests/integration/test_search_index.py
@@ -1,9 +1,10 @@
+import warnings
+
 import pytest
 
 from redisvl.exceptions import RedisSearchError
 from redisvl.index import SearchIndex
 from redisvl.query import VectorQuery
-from redisvl.redis.connection import RedisConnectionFactory, validate_modules
 from redisvl.redis.utils import convert_bytes
 from redisvl.schema import IndexSchema, StorageType
 
@@ -27,8 +28,8 @@ def index_schema():
 
 
 @pytest.fixture
-def index(index_schema):
-    return SearchIndex(schema=index_schema)
+def index(index_schema, client):
+    return SearchIndex(schema=index_schema, redis_client=client)
 
 
 @pytest.fixture
@@ -45,7 +46,7 @@ def test_search_index_properties(index_schema, index):
     assert index.schema == index_schema
     # custom settings
     assert index.name == index_schema.index.name == "my_index"
-    assert index.client == None
+
     # default settings
     assert index.prefix == index_schema.index.prefix == "rvl"
     assert index.key_separator == index_schema.index.key_separator == ":"
@@ -73,7 +74,6 @@ def test_search_index_from_dict(index_from_dict):
 
 
 def test_search_index_from_existing(client, index):
-    index.set_client(client)
     index.create(overwrite=True)
 
     try:
@@ -134,10 +134,14 @@ def test_search_index_no_prefix(index_schema):
 
 
 def test_search_index_redis_url(redis_url, index_schema):
     index = SearchIndex(schema=index_schema, redis_url=redis_url)
+    # Client is not set until a command runs
+    assert index.client is None
+
+    index.create(overwrite=True)
     assert index.client
 
     index.disconnect()
-    assert index.client == None
+    assert index.client is None
 
 
 def test_search_index_client(client, index_schema):
@@ -145,34 +149,36 @@ def test_search_index_client(client, index_schema):
     assert index.client == client
 
 
-def test_search_index_set_client(async_client, client, index):
-    index.set_client(client)
-    assert index.client == client
-    # should not be able to set the sync client here
-    with pytest.raises(TypeError):
-        index.set_client(async_client)
+def test_search_index_set_client(async_client, redis_url, index_schema):
+    index = SearchIndex(schema=index_schema, redis_url=redis_url)
 
-    index.disconnect()
-    assert index.client == None
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=DeprecationWarning)
+        index.create(overwrite=True, drop=True)
+        assert index.client
+        # should not be able to set an async client here
+        with pytest.raises(TypeError):
+            index.set_client(async_client)
+        assert index.client is not async_client
+
+    index.disconnect()
+    assert index.client is None
 
 
-def test_search_index_create(client, index):
-    index.set_client(client)
+def test_search_index_create(index):
     index.create(overwrite=True, drop=True)
     assert index.exists()
     assert index.name in convert_bytes(index.client.execute_command("FT._LIST"))
 
 
-def test_search_index_delete(client, index):
-    index.set_client(client)
+def test_search_index_delete(index):
     index.create(overwrite=True, drop=True)
     index.delete(drop=True)
     assert not index.exists()
     assert index.name not in convert_bytes(index.client.execute_command("FT._LIST"))
 
 
-def test_search_index_clear(client, index):
-    index.set_client(client)
+def test_search_index_clear(index):
     index.create(overwrite=True, drop=True)
     data = [{"id": "1", "test": "foo"}]
     index.load(data, id_field="id")
@@ -182,8 +188,7 @@ def test_search_index_clear(client, index):
     assert index.exists()
 
 
-def test_search_index_drop_key(client, index):
-    index.set_client(client)
+def test_search_index_drop_key(index):
     index.create(overwrite=True, drop=True)
     data = [{"id": "1", "test": "foo"}, {"id": "2", "test": "bar"}]
     keys = index.load(data, id_field="id")
@@ -195,8 +200,7 @@ def test_search_index_drop_key(client, index):
     assert index.fetch(keys[1]) is not None  # still have all other entries
 
 
-def test_search_index_drop_keys(client, index):
-    index.set_client(client)
+def test_search_index_drop_keys(index):
     index.create(overwrite=True, drop=True)
     data = [
         {"id": "1", "test": "foo"},
@@ -215,22 +219,20 @@
     assert index.exists()
 
 
-def test_search_index_load_and_fetch(client, index):
-    index.set_client(client)
+def test_search_index_load_and_fetch(index):
     index.create(overwrite=True, drop=True)
     data = [{"id": "1", "test": "foo"}]
     index.load(data, id_field="id")
 
     res = index.fetch("1")
-    assert res["test"] == convert_bytes(client.hget("rvl:1", "test")) == "foo"
+    assert res["test"] == convert_bytes(index.client.hget("rvl:1", "test")) == "foo"
 
     index.delete(drop=True)
     assert not index.exists()
     assert not index.fetch("1")
 
 
-def test_search_index_load_preprocess(client, index):
-    index.set_client(client)
+def test_search_index_load_preprocess(index):
     index.create(overwrite=True, drop=True)
     data = [{"id": "1", "test": "foo"}]
 
@@ -240,7 +242,7 @@ def preprocess(record):
     index.load(data, id_field="id", preprocess=preprocess)
 
     res = index.fetch("1")
-    assert res["test"] == convert_bytes(client.hget("rvl:1", "test")) == "bar"
+    assert res["test"] == convert_bytes(index.client.hget("rvl:1", "test")) == "bar"
 
     def bad_preprocess(record):
         return 1
@@ -249,8 +251,7 @@ def bad_preprocess(record):
         index.load(data, id_field="id", preprocess=bad_preprocess)
 
 
-def test_no_id_field(client, index):
-    index.set_client(client)
+def test_no_id_field(index):
     index.create(overwrite=True, drop=True)
 
     bad_data = [{"wrong_key": "1", "value": "test"}]
@@ -259,16 +260,14 @@
         index.load(bad_data, id_field="key")
 
 
-def test_check_index_exists_before_delete(client, index):
-    index.set_client(client)
+def test_check_index_exists_before_delete(index):
     index.create(overwrite=True, drop=True)
     index.delete(drop=True)
 
     with pytest.raises(RedisSearchError):
         index.delete()
 
 
-def test_check_index_exists_before_search(client, index):
-    index.set_client(client)
+def test_check_index_exists_before_search(index):
     index.create(overwrite=True, drop=True)
     index.delete(drop=True)
 
@@ -282,8 +281,7 @@
     index.search(query.query, query_params=query.params)
 
 
-def test_check_index_exists_before_info(client, index):
-    index.set_client(client)
+def test_check_index_exists_before_info(index):
     index.create(overwrite=True, drop=True)
     index.delete(drop=True)
 
@@ -293,4 +291,57 @@
 
 def test_index_needs_valid_schema():
     with pytest.raises(ValueError, match=r"Must provide a valid IndexSchema object"):
-        index = SearchIndex(schema="Not A Valid Schema")
+        SearchIndex(schema="Not A Valid Schema")  # type: ignore
+
+
+def test_search_index_that_does_not_own_client_context_manager(index):
+    with index:
+        index.create(overwrite=True, drop=True)
+        assert index.client
+        client = index.client
+    # Client should not have changed outside of the context manager
+    assert index.client == client
+
+
+def test_search_index_that_does_not_own_client_context_manager_with_exception(index):
+    with pytest.raises(ValueError):
+        with index:
+            index.create(overwrite=True, drop=True)
+            client = index.client
+            raise ValueError("test")
+    # Client should not have changed outside of the context manager
+    assert index.client == client
+
+
+def test_search_index_that_does_not_own_client_disconnect(index):
+    index.create(overwrite=True, drop=True)
+    client = index.client
+    index.disconnect()
+    # Client should not have changed after disconnecting
+    assert index.client == client
+
+
+def test_search_index_that_owns_client_context_manager(index_schema, redis_url):
+    index = SearchIndex(schema=index_schema, redis_url=redis_url)
+    with index:
+        index.create(overwrite=True, drop=True)
+        assert index.client
+    assert index.client is None
+
+
+def test_search_index_that_owns_client_context_manager_with_exception(
+    index_schema, redis_url
+):
+    index = SearchIndex(schema=index_schema, redis_url=redis_url)
+    with pytest.raises(ValueError):
+        with index:
+            index.create(overwrite=True, drop=True)
+            raise ValueError("test")
+    assert index.client is None
+
+
+def test_search_index_that_owns_client_disconnect(index_schema, redis_url):
+    index = SearchIndex(schema=index_schema, redis_url=redis_url)
+    index.create(overwrite=True, drop=True)
+    index.disconnect()
+    assert index.client is None
diff --git a/tests/integration/test_search_results.py b/tests/integration/test_search_results.py
index 83efad53..c451f039 100644
--- a/tests/integration/test_search_results.py
+++ b/tests/integration/test_search_results.py
@@ -44,10 +44,7 @@ def index(sample_data, redis_url):
     }
 
     # construct a search index from the schema
-    index = SearchIndex.from_dict(json_schema)
-
-    # connect to local redis instance
-    index.connect(redis_url=redis_url)
+    index = SearchIndex.from_dict(json_schema, redis_url=redis_url)
 
     # create the index (no data yet)
     index.create(overwrite=True)
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 4b82a6c9..a6e3dd02 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -1,4 +1,3 @@
-import warnings
 from functools import wraps
 
 import numpy as np
@@ -10,7 +9,11 @@
     convert_bytes,
     make_dict,
 )
-from redisvl.utils.utils import assert_no_warnings, deprecated_argument
+from redisvl.utils.utils import (
+    assert_no_warnings,
+    deprecated_argument,
+    deprecated_function,
+)
 
 
 def test_even_number_of_elements():
@@ -420,3 +423,21 @@ async def test_func(old_arg=None, new_arg=None):
 
     with assert_no_warnings():
         await test_func()
+
+
+class TestDeprecatedFunction:
+    def test_deprecated_function_warning(self):
+        @deprecated_function("new_func", "Use new_func2")
+        def old_func():
+            pass
+
+        with pytest.warns(DeprecationWarning):
+            old_func()
+
+    def test_deprecated_function_warning_with_name(self):
+        @deprecated_function("new_func", "Use new_func2")
+        def old_func():
+            pass
+
+        with pytest.warns(DeprecationWarning, match="new_func"):
+            old_func()
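
The hunks above all enforce the same connection-ownership rule: an index (or cache) only tears down a Redis client it created itself. A minimal sketch of that contract, assuming `schema` is an index schema dict like the ones in the fixtures above:

```python
from redis import Redis
from redisvl.index import SearchIndex

# Owned client: built lazily from the URL, closed when the context exits.
index = SearchIndex.from_dict(schema, redis_url="redis://localhost:6379")
with index:
    index.create(overwrite=True)
assert index.client is None  # the owned client was torn down

# Borrowed client: provided by the caller, left untouched on exit.
client = Redis.from_url("redis://localhost:6379")
index = SearchIndex.from_dict(schema, redis_client=client)
with index:
    index.create(overwrite=True)
assert index.client == client  # the borrowed client survives
```

The async side (`AsyncSearchIndex`, and `SemanticCache` via `async with` / `adisconnect()`) follows the same split, as the `test_llmcache.py` hunks exercise.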