-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for pgvector's hnsw (0.7.4) and generic support for Postgres (16) indexes #309
Changes from all commits
498e8b2
9b68a2a
3f65080
5147f07
0cc417c
381215d
2c30c59
ac3eb27
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,5 @@ results/* | |
venv | ||
|
||
.idea | ||
|
||
.vscode |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
FROM neurips23 | ||
|
||
# install Postgres and dev package | ||
RUN apt-get update && \ | ||
apt-get -y install wget gnupg2 lsb-release | ||
RUN sh -c 'echo "deb https://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' | ||
RUN wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add - | ||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get -y install postgresql-16 postgresql-server-dev-16 | ||
|
||
# install git | ||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y git | ||
|
||
# create a Postgres database | ||
USER postgres | ||
RUN mkdir /var/lib/postgresql/test_database | ||
RUN /usr/lib/postgresql/16/bin/initdb -D /var/lib/postgresql/test_database | ||
USER root | ||
|
||
# create the script that we will use start the database during the tests | ||
RUN echo "su - postgres -c \"/usr/lib/postgresql/16/bin/pg_ctl \ | ||
-D /var/lib/postgresql/test_database \ | ||
-l /var/lib/postgresql/test_database_logfile \ | ||
-o \\\"-F -p 5432\\\" start\"" > /home/app/start_database.sh | ||
RUN chmod +x /home/app/start_database.sh | ||
|
||
# install python deps | ||
RUN pip3 install pgvector==0.3.3 psycopg==3.2.1 | ||
|
||
# install linux-tools-generic into docker so that devs can use perf if they want | ||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y linux-tools-generic | ||
|
||
# clone FlameGraph for the same purpose | ||
RUN git clone https://github.com/brendangregg/FlameGraph | ||
Comment on lines
+29
to
+33
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can be reverted as mentioned in PR descr, only useful for FlameGraph'ing purposes |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
import subprocess | ||
import time | ||
import numpy as np | ||
import psycopg | ||
import concurrent.futures | ||
|
||
from math import ceil | ||
from neurips23.streaming.base import BaseStreamingANN | ||
from pgvector.psycopg import register_vector | ||
|
||
|
||
PG_CONN_STR = "dbname=postgres user=postgres port=5432 host=localhost" | ||
|
||
|
||
def cursor_print_and_execute(cur, query): | ||
print(query) | ||
cur.execute(query) | ||
|
||
class BaseStreamingANNPostgres(BaseStreamingANN): | ||
# Child classes should implement the following methods ..: | ||
# - determine_index_op_class(self, metric) | ||
# - determine_query_op(self, metric) | ||
# | ||
# .. as well as setting the following attributes in their __init__ methods before calling super().__init__: | ||
# - self.name | ||
# - self.pg_index_method | ||
# - self.guc_prefix | ||
def determine_index_op_class(self, metric): | ||
raise NotImplementedError() | ||
|
||
def determine_query_op(self, metric): | ||
raise NotImplementedError() | ||
|
||
def __init__(self, metric, index_params): | ||
self.n_insert_conns = index_params.get("insert_conns") | ||
if self.n_insert_conns == None: | ||
raise Exception('Missing parameter insert_conns') | ||
|
||
# save it for later use in __str__() | ||
self._index_params = index_params | ||
|
||
# we'll initialize the connections later in set_query_arguments() per "query-arguments" set | ||
self.conns = [] | ||
|
||
self.index_build_params = {k: v for k, v in index_params.items() if k != "insert_conns"} | ||
|
||
self.ind_op_class = self.determine_index_op_class(metric) | ||
|
||
self.search_query = f"SELECT id FROM test_tbl ORDER BY vec_col {self.determine_query_op(metric)} %b LIMIT %b" | ||
|
||
start_database_result = subprocess.run(['bash', '/home/app/start_database.sh'], capture_output=True, text=True) | ||
if start_database_result.returncode != 0: | ||
raise Exception(f'Failed to start the database: {start_database_result.stderr}') | ||
|
||
def setup(self, dtype, max_pts, ndim): | ||
if dtype != 'float32': | ||
raise Exception('Invalid data type') | ||
|
||
index_build_params_clause = "" | ||
if self.index_build_params: | ||
index_build_params_clause = "WITH (" | ||
first = True | ||
for k, v in self.index_build_params.items(): | ||
if not first: | ||
index_build_params_clause += ", " | ||
|
||
first = False | ||
index_build_params_clause += f"{k} = {v}" | ||
|
||
index_build_params_clause += ")" | ||
|
||
# create the table and index by using a temporary connection | ||
with psycopg.connect(PG_CONN_STR, autocommit=True) as conn: | ||
with conn.cursor() as cur: | ||
cursor_print_and_execute(cur, f"CREATE TABLE test_tbl (id bigint, vec_col vector({ndim}))") | ||
cursor_print_and_execute(cur, f"CREATE INDEX vec_col_idx ON test_tbl USING {self.pg_index_method} (vec_col {self.ind_op_class}) {index_build_params_clause}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To stabilize the measurements, should we disable autovacuum when creating the table since we explicitly vacuum the table at some point? @orhankislal? |
||
|
||
# required by insert() & delete() | ||
self.max_pts = max_pts | ||
self.active_indices = set() | ||
self.num_unprocessed_deletes = 0 | ||
|
||
print('Index class constructed and ready') | ||
|
||
def done(self): | ||
# close any existing connections | ||
for conn in self.conns: | ||
conn.close() | ||
|
||
super().done() | ||
|
||
def insert(self, X, ids): | ||
n_insert_rows = len(ids) | ||
|
||
self.active_indices.update(ids+1) | ||
|
||
print('#active pts', len(self.active_indices), '#unprocessed deletes', self.num_unprocessed_deletes, '#inserting', n_insert_rows) | ||
|
||
# Execute VACUUM if the number of active points + the number of unprocessed deletes exceeds the max_pts | ||
if len(self.active_indices) + self.num_unprocessed_deletes >= self.max_pts: | ||
print('Executing VACUUM') | ||
|
||
start_time = time.time() | ||
|
||
with self.conns[0].cursor() as cur: | ||
cur.execute('VACUUM test_tbl') | ||
|
||
exec_time = time.time() - start_time | ||
|
||
log_dict = { | ||
'vacuum': self.num_unprocessed_deletes, | ||
'exec_time': exec_time | ||
} | ||
|
||
print('Timing:', log_dict) | ||
|
||
self.num_unprocessed_deletes = 0 | ||
|
||
def copy_data(conn_idx, id_start_idx, id_end_idx): | ||
with self.conns[conn_idx].cursor().copy("COPY test_tbl (id, vec_col) FROM STDIN WITH (FORMAT BINARY)") as copy: | ||
copy.set_types(["int8", "vector"]) | ||
for id, vec in zip(ids[id_start_idx:id_end_idx], X[id_start_idx:id_end_idx]): | ||
copy.write_row((id, vec)) | ||
|
||
# Run the copy_data function in parallel | ||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_insert_conns) as executor: | ||
chunk_size = ceil(n_insert_rows / self.n_insert_conns) | ||
copy_futures = [] | ||
for conn_idx, id_start_idx in enumerate(range(0, n_insert_rows, chunk_size)): | ||
id_end_idx = min(id_start_idx + chunk_size, n_insert_rows) | ||
copy_futures.append(executor.submit(copy_data, conn_idx, id_start_idx, id_end_idx)) | ||
|
||
start_time = time.time() | ||
|
||
for copy_future in concurrent.futures.as_completed(copy_futures): | ||
# raise any exceptions that occurred during execution | ||
copy_future.result() | ||
|
||
exec_time = time.time() - start_time | ||
|
||
log_dict = { | ||
'insert': n_insert_rows, | ||
'exec_time': exec_time | ||
} | ||
|
||
print('Timing:', log_dict) | ||
|
||
def delete(self, ids): | ||
n_delete_rows = len(ids) | ||
|
||
start_time = time.time() | ||
|
||
with self.conns[0].cursor() as cur: | ||
# delete ids in batches of 1000 | ||
for i in range(0, n_delete_rows, 1000): | ||
subset = [x for x in ids[i:i+1000]] | ||
cur.execute("DELETE FROM test_tbl WHERE id = ANY(%s)", (subset,)) | ||
|
||
exec_time = time.time() - start_time | ||
|
||
log_dict = { | ||
'delete': n_delete_rows, | ||
'exec_time': exec_time | ||
} | ||
|
||
print('Timing:', log_dict) | ||
|
||
self.active_indices.difference_update(ids+1) | ||
self.num_unprocessed_deletes += n_delete_rows | ||
|
||
def query(self, X, k): | ||
def batch_query(conn_idx, query_vec_start_idx, query_vec_end_idx): | ||
batch_result_id_lists = [] | ||
for query_vec in X[query_vec_start_idx: query_vec_end_idx]: | ||
with self.conns[conn_idx].cursor() as cur: | ||
try: | ||
cur.execute(self.search_query, (query_vec, k, ), binary=True, prepare=True) | ||
except Exception as e: | ||
raise Exception(f"Error '{e}' when querying with k={k}\nQuery vector was:\n{query_vec}") from e | ||
|
||
result_tuples = cur.fetchall() | ||
|
||
result_ids = list(map(lambda tup: tup[0], result_tuples)) | ||
|
||
if len(result_ids) < k: | ||
# Pad with -1 if we have less than k results. This is only needed if the | ||
# index-access method cannot guarantee returning k results. | ||
# | ||
# As of today, this is only possible with PostgresPgvectorHnsw when | ||
# ef_search < k. | ||
result_ids.extend([-1] * (k - len(result_ids))) | ||
Comment on lines
+185
to
+191
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it ok to pad self.res like this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, should be fine. The recall computation code performs a set intersection between the ground-truth results and reported results, so any -1 results would be ignored as intended. |
||
|
||
batch_result_id_lists.append(result_ids) | ||
|
||
return batch_result_id_lists | ||
|
||
total_queries = len(X) | ||
|
||
result_id_lists = [] | ||
|
||
# Run the batch_query function in parallel | ||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_query_conns) as executor: | ||
chunk_size = ceil(total_queries / self.n_query_conns) | ||
query_futures = [] | ||
for conn_idx, query_vec_start_idx in enumerate(range(0, total_queries, chunk_size)): | ||
query_vec_end_idx = min(query_vec_start_idx + chunk_size, total_queries) | ||
query_futures.append(executor.submit(batch_query, conn_idx, query_vec_start_idx, query_vec_end_idx)) | ||
|
||
start_time = time.time() | ||
|
||
# wait for all futures to complete | ||
done, not_done = concurrent.futures.wait(query_futures) | ||
|
||
exec_time = time.time() - start_time | ||
|
||
log_dict = { | ||
'query': total_queries, | ||
'exec_time': exec_time | ||
} | ||
|
||
print('Timing:', log_dict) | ||
|
||
assert len(not_done) == 0 | ||
assert len(done) == len(query_futures) | ||
|
||
# retrieve the results in the order they were submitted to avoid messing up the order | ||
for query_future in query_futures: | ||
batch_result_id_lists = query_future.result() | ||
result_id_lists.extend(batch_result_id_lists) | ||
|
||
self.res = np.vstack(result_id_lists, dtype=np.int32) | ||
|
||
def set_query_arguments(self, query_args): | ||
# save it for later use in __str__() | ||
self._query_args = query_args | ||
|
||
# close any existing connections | ||
for conn in self.conns: | ||
conn.close() | ||
|
||
# By using a temporary connection, truncate the table since set_query_arguments() is called | ||
# before each testing phase with new set of query params. | ||
with psycopg.connect(PG_CONN_STR, autocommit=True) as conn: | ||
with conn.cursor() as cur: | ||
cursor_print_and_execute(cur, "TRUNCATE test_tbl") | ||
Comment on lines
+243
to
+245
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Although the other algo implementations don't seem to be doing this, I thought that we have to reset all the data in main table and the index here, i.e., before switching to a different query-args set. Does that make sense? |
||
|
||
self.n_query_conns = query_args.get("query_conns") | ||
if self.n_query_conns == None: | ||
raise Exception('Missing parameter query_conns') | ||
|
||
n_conns_needed = max(self.n_query_conns, self.n_insert_conns) | ||
|
||
self.conns = [psycopg.connect(PG_CONN_STR, autocommit=True) for _ in range(n_conns_needed)] | ||
|
||
# so that we can insert np arrays as pgvector's vector data type transparently | ||
for conn in self.conns: | ||
register_vector(conn) | ||
|
||
guc_args = {k: v for k, v in query_args.items() if k != "query_conns"} | ||
|
||
for conn in self.conns: | ||
with conn.cursor() as cur: | ||
for k, v in guc_args.items(): | ||
cursor_print_and_execute(cur, f"SET {self.guc_prefix}.{k} TO {v}") | ||
|
||
# disable seqscan for all connections since we mainly want to test index-scan | ||
cursor_print_and_execute(cur, f"SET enable_seqscan TO OFF") | ||
|
||
def __str__(self): | ||
build_args_str = ' '.join([f'{k}={v}' for k, v in sorted(self._index_params.items())]) | ||
query_args_str = ' '.join([f'{k}={v}' for k, v in sorted(self._query_args.items())]) | ||
|
||
return f'{self.name}({build_args_str} - {query_args_str})' | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can be reverted as mentioned in PR descr, only useful for FlameGraph'ing purposes