diff --git a/pyproject.toml b/pyproject.toml index f3ae99eec..e048a19cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,13 +59,14 @@ all = [ "redis", "chromadb", "psycopg2", + "psycopg", ] qdrant = [ "qdrant-client" ] pinecone = [ "pinecone-client" ] weaviate = [ "weaviate-client" ] elastic = [ "elasticsearch" ] -pgvector = [ "pgvector", "psycopg2" ] +pgvector = [ "pgvector", "psycopg" ] pgvecto_rs = [ "psycopg2" ] redis = [ "redis" ] chromadb = [ "chromadb" ] diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index 7d4e4f684..d1325322a 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -20,6 +20,7 @@ class IndexType(str, Enum): Flat = "FLAT" AUTOINDEX = "AUTOINDEX" ES_HNSW = "hnsw" + ES_IVFFlat = "ivfflat" GPU_IVF_FLAT = "GPU_IVF_FLAT" GPU_IVF_PQ = "GPU_IVF_PQ" GPU_CAGRA = "GPU_CAGRA" diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py index 5de49ae76..a4cc584e8 100644 --- a/vectordb_bench/backend/clients/pgvector/config.py +++ b/vectordb_bench/backend/clients/pgvector/config.py @@ -1,29 +1,62 @@ +from abc import abstractmethod +from typing import Any, Mapping, Optional, Sequence, TypedDict from pydantic import BaseModel, SecretStr -from ..api import DBConfig, DBCaseConfig, IndexType, MetricType +from typing_extensions import LiteralString +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s" + +class PgVectorConfigDict(TypedDict): + """These keys will be directly used as kwargs in psycopg connection string, + so the names must match exactly psycopg API""" + + user: str + password: str + host: str + port: int + dbname: str + + class PgVectorConfig(DBConfig): - user_name: SecretStr = "postgres" + user_name: SecretStr = SecretStr("postgres") password: SecretStr host: str = "localhost" port: int = 5432 db_name: str - def to_dict(self) -> dict: + def to_dict(self) -> PgVectorConfigDict: user_str = self.user_name.get_secret_value() pwd_str = self.password.get_secret_value() return { - "host" : self.host, - "port" : self.port, - "dbname" : self.db_name, - "user" : user_str, - "password" : pwd_str + "host": self.host, + "port": self.port, + "dbname": self.db_name, + "user": user_str, + "password": pwd_str, } + +class PgVectorIndexParam(TypedDict): + metric: str + index_type: str + index_creation_with_options: Sequence[dict[str, Any]] + maintenance_work_mem: Optional[str] + max_parallel_workers: Optional[int] + + +class PgVectorSearchParam(TypedDict): + metric_fun_op: LiteralString + + +class PgVectorSessionCommands(TypedDict): + session_options: Sequence[dict[str, Any]] + + class PgVectorIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType | None = None - index: IndexType + create_index_before_load: bool = False + create_index_after_load: bool = True def parse_metric(self) -> str: if self.metric_type == MetricType.L2: @@ -32,7 +65,7 @@ def parse_metric(self) -> str: return "vector_ip_ops" return "vector_cosine_ops" - def parse_metric_fun_op(self) -> str: + def parse_metric_fun_op(self) -> LiteralString: if self.metric_type == MetricType.L2: return "<->" elif self.metric_type == MetricType.IP: @@ -46,48 +79,137 @@ def parse_metric_fun_str(self) -> str: return "max_inner_product" return "cosine_distance" + @abstractmethod + def index_param(self) -> PgVectorIndexParam: + ... + + @abstractmethod + def search_param(self) -> PgVectorSearchParam: + ... + + @abstractmethod + def session_param(self) -> PgVectorSessionCommands: + ... + + @staticmethod + def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]: + """Walk through mappings, creating a List of {key1 = value} pairs. That will be used to build a where clause""" + options = [] + for option_name, value in with_options.items(): + if value is not None: + options.append( + { + "option_name": option_name, + "val": str(value), + } + ) + return options + + @staticmethod + def _optionally_build_set_options( + set_mapping: Mapping[str, Any] + ) -> Sequence[dict[str, Any]]: + """Walk through options, creating 'SET 'key1 = "value1";' commands""" + session_options = [] + for setting_name, value in set_mapping.items(): + if value: + session_options.append( + {"parameter": { + "setting_name": setting_name, + "val": str(value), + }, + } + ) + return session_options + + +class PgVectorIVFFlatConfig(PgVectorIndexConfig): + """ + An IVFFlat index divides vectors into lists, and then searches a subset of those lists that are + closest to the query vector. It has faster build times and uses less memory than HNSW, + but has lower query performance (in terms of speed-recall tradeoff). + + Three keys to achieving good recall are: + + Create the index after the table has some data + Choose an appropriate number of lists - a good place to start is rows / 1000 for up to 1M rows and sqrt(rows) for + over 1M rows. + When querying, specify an appropriate number of probes (higher is better for recall, lower is better for speed) - + a good place to start is sqrt(lists) + """ + + lists: int | None + probes: int | None + index: IndexType = IndexType.ES_IVFFlat + maintenance_work_mem: Optional[str] = None + max_parallel_workers: Optional[int] = None + + def index_param(self) -> PgVectorIndexParam: + index_parameters = {"lists": self.lists} + return { + "metric": self.parse_metric(), + "index_type": self.index.value, + "index_creation_with_options": self._optionally_build_with_options( + index_parameters + ), + "maintenance_work_mem": self.maintenance_work_mem, + "max_parallel_workers": self.max_parallel_workers, + } - -class HNSWConfig(PgVectorIndexConfig): - M: int - efConstruction: int - ef: int | None = None - index: IndexType = IndexType.HNSW - - def index_param(self) -> dict: + def search_param(self) -> PgVectorSearchParam: return { - "m" : self.M, - "ef_construction" : self.efConstruction, - "metric" : self.parse_metric() + "metric_fun_op": self.parse_metric_fun_op(), } - def search_param(self) -> dict: + def session_param(self) -> PgVectorSessionCommands: + session_parameters = {"ivfflat.probes": self.probes} return { - "ef" : self.ef, - "metric_fun" : self.parse_metric_fun_str(), - "metric_fun_op" : self.parse_metric_fun_op(), + "session_options": self._optionally_build_set_options(session_parameters) } -class IVFFlatConfig(PgVectorIndexConfig): - lists: int | None = 1000 - probes: int | None = 10 - index: IndexType = IndexType.IVFFlat +class PgVectorHNSWConfig(PgVectorIndexConfig): + """ + An HNSW index creates a multilayer graph. It has better query performance than IVFFlat (in terms of + speed-recall tradeoff), but has slower build times and uses more memory. Also, an index can be + created without any data in the table since there isn't a training step like IVFFlat. + """ + + m: int | None # DETAIL: Valid values are between "2" and "100". + ef_construction: ( + int | None + ) # ef_construction must be greater than or equal to 2 * m + ef_search: int | None + index: IndexType = IndexType.ES_HNSW + maintenance_work_mem: Optional[str] = None + max_parallel_workers: Optional[int] = None + + def index_param(self) -> PgVectorIndexParam: + index_parameters = {"m": self.m, "ef_construction": self.ef_construction} + return { + "metric": self.parse_metric(), + "index_type": self.index.value, + "index_creation_with_options": self._optionally_build_with_options( + index_parameters + ), + "maintenance_work_mem": self.maintenance_work_mem, + "max_parallel_workers": self.max_parallel_workers, + } - def index_param(self) -> dict: + def search_param(self) -> PgVectorSearchParam: return { - "lists" : self.lists, - "metric" : self.parse_metric() + "metric_fun_op": self.parse_metric_fun_op(), } - def search_param(self) -> dict: + def session_param(self) -> PgVectorSessionCommands: + session_parameters = {"hnsw.ef_search": self.ef_search} return { - "probes" : self.probes, - "metric_fun" : self.parse_metric_fun_str(), - "metric_fun_op" : self.parse_metric_fun_op(), + "session_options": self._optionally_build_set_options(session_parameters) } + _pgvector_case_config = { - IndexType.HNSW: HNSWConfig, - IndexType.IVFFlat: IVFFlatConfig, + IndexType.HNSW: PgVectorHNSWConfig, + IndexType.ES_HNSW: PgVectorHNSWConfig, + IndexType.IVFFlat: PgVectorIVFFlatConfig, } diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index afe4218c4..8f8244412 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -1,25 +1,36 @@ """Wrapper around the Pgvector vector database over VectorDB""" -import io import logging +import pprint from contextlib import contextmanager -from typing import Any -import pandas as pd -import psycopg2 -import psycopg2.extras +from typing import Any, Generator, Optional, Tuple, Sequence -from ..api import IndexType, VectorDB, DBCaseConfig +import numpy as np +import psycopg +from pgvector.psycopg import register_vector +from psycopg import Connection, Cursor, sql + +from ..api import VectorDB +from .config import PgVectorConfigDict, PgVectorIndexConfig log = logging.getLogger(__name__) + class PgVector(VectorDB): - """ Use SQLAlchemy instructions""" + """Use psycopg instructions""" + + conn: psycopg.Connection[Any] | None = None + cursor: psycopg.Cursor[Any] | None = None + + # TODO add filters support + _unfiltered_search: sql.Composed + def __init__( self, dim: int, - db_config: dict, - db_case_config: DBCaseConfig, - collection_name: str = "PgVectorCollection", + db_config: PgVectorConfigDict, + db_case_config: PgVectorIndexConfig, + collection_name: str = "pg_vector_collection", drop_old: bool = False, **kwargs, ): @@ -29,43 +40,88 @@ def __init__( self.table_name = collection_name self.dim = dim - self._index_name = "pqvector_index" + self._index_name = "pgvector_index" self._primary_field = "id" self._vector_field = "embedding" # construct basic units - self.conn = psycopg2.connect(**self.db_config) - self.conn.autocommit = False - self.cursor = self.conn.cursor() + self.conn, self.cursor = self._create_connection(**self.db_config) # create vector extension - self.cursor.execute('CREATE EXTENSION IF NOT EXISTS vector') + self.cursor.execute("CREATE EXTENSION IF NOT EXISTS vector") self.conn.commit() - if drop_old : - log.info(f"Pgvector client drop table : {self.table_name}") + log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}") + if not any( + ( + self.case_config.create_index_before_load, + self.case_config.create_index_after_load, + ) + ): + err = f"{self.name} config must create an index using create_index_before_load and/or create_index_after_load" + log.error(err) + raise RuntimeError( + f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}" + ) + + if drop_old: # self.pg_table.drop(pg_engine, checkfirst=True) self._drop_index() self._drop_table() self._create_table(dim) - self._create_index() + if self.case_config.create_index_before_load: + self._create_index() self.cursor.close() self.conn.close() self.cursor = None self.conn = None + @staticmethod + def _create_connection(**kwargs) -> Tuple[Connection, Cursor]: + conn = psycopg.connect(**kwargs) + register_vector(conn) + conn.autocommit = False + cursor = conn.cursor() + + assert conn is not None, "Connection is not initialized" + assert cursor is not None, "Cursor is not initialized" + + return conn, cursor + @contextmanager - def init(self) -> None: + def init(self) -> Generator[None, None, None]: """ Examples: >>> with self.init(): >>> self.insert_embeddings() >>> self.search_embedding() """ - self.conn = psycopg2.connect(**self.db_config) - self.conn.autocommit = False - self.cursor = self.conn.cursor() + + self.conn, self.cursor = self._create_connection(**self.db_config) + + # index configuration may have commands defined that we should set during each client session + session_options: Sequence[dict[str, Any]] = self.case_config.session_param()["session_options"] + + if len(session_options) > 0: + for setting in session_options: + command = sql.SQL("SET {setting_name} " + "= {val};").format( + setting_name=sql.Identifier(setting['parameter']['setting_name']), + val=sql.Identifier(str(setting['parameter']['val'])), + ) + log.debug(command.as_string(self.cursor)) + self.cursor.execute(command) + self.conn.commit() + + self._unfiltered_search = sql.Composed( + [ + sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format( + sql.Identifier(self.table_name) + ), + sql.SQL(self.case_config.search_param()["metric_fun_op"]), + sql.SQL(" %s::vector LIMIT %s::int"), + ] + ) try: yield @@ -78,54 +134,166 @@ def init(self) -> None: def _drop_table(self): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + log.info(f"{self.name} client drop table : {self.table_name}") - self.cursor.execute(f'DROP TABLE IF EXISTS public."{self.table_name}"') + self.cursor.execute( + sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format( + table_name=sql.Identifier(self.table_name) + ) + ) self.conn.commit() def ready_to_load(self): pass def optimize(self): - log.info(f"{self.name} optimizing") - self._drop_index() - self._create_index() + self._post_insert() - def ready_to_search(self): - pass + def _post_insert(self): + log.info(f"{self.name} post insert before optimize") + if self.case_config.create_index_after_load: + self._drop_index() + self._create_index() def _drop_index(self): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + log.info(f"{self.name} client drop index : {self._index_name}") - self.cursor.execute(f'DROP INDEX IF EXISTS "{self._index_name}"') + drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format( + index_name=sql.Identifier(self._index_name) + ) + log.debug(drop_index_sql.as_string(self.cursor)) + self.cursor.execute(drop_index_sql) self.conn.commit() + def _set_parallel_index_build_param(self): + assert self.conn is not None, "Connection is not initialized" + assert self.cursor is not None, "Cursor is not initialized" + + index_param = self.case_config.index_param() + + if index_param["maintenance_work_mem"] is not None: + self.cursor.execute( + sql.SQL("SET maintenance_work_mem TO {};").format( + index_param["maintenance_work_mem"] + ) + ) + self.cursor.execute( + sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format( + sql.Identifier(self.db_config["user"]), + index_param["maintenance_work_mem"], + ) + ) + self.conn.commit() + + if index_param["max_parallel_workers"] is not None: + self.cursor.execute( + sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format( + index_param["max_parallel_workers"] + ) + ) + self.cursor.execute( + sql.SQL( + "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';" + ).format( + sql.Identifier(self.db_config["user"]), + index_param["max_parallel_workers"], + ) + ) + self.cursor.execute( + sql.SQL("SET max_parallel_workers TO '{}';").format( + index_param["max_parallel_workers"] + ) + ) + self.cursor.execute( + sql.SQL( + "ALTER USER {} SET max_parallel_workers TO '{}';" + ).format( + sql.Identifier(self.db_config["user"]), + index_param["max_parallel_workers"], + ) + ) + self.cursor.execute( + sql.SQL( + "ALTER TABLE {} SET (parallel_workers = {});" + ).format( + sql.Identifier(self.table_name), + index_param["max_parallel_workers"], + ) + ) + self.conn.commit() + + results = self.cursor.execute( + sql.SQL("SHOW max_parallel_maintenance_workers;") + ).fetchall() + results.extend( + self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall() + ) + results.extend( + self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall() + ) + log.info(f"{self.name} parallel index creation parameters: {results}") + def _create_index(self): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" + log.info(f"{self.name} client create index : {self._index_name}") index_param = self.case_config.index_param() - if self.case_config.index == IndexType.HNSW: - log.debug(f'Creating HNSW index. m={index_param["m"]}, ef_construction={index_param["ef_construction"]}') - self.cursor.execute(f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" USING hnsw (embedding {index_param["metric"]}) WITH (m={index_param["m"]}, ef_construction={index_param["ef_construction"]});') - elif self.case_config.index == IndexType.IVFFlat: - log.debug(f'Creating IVFFLAT index. list={index_param["lists"]}') - self.cursor.execute(f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" USING ivfflat (embedding {index_param["metric"]}) WITH (lists={index_param["lists"]});') + self._set_parallel_index_build_param() + options = [] + for option in index_param["index_creation_with_options"]: + if option['val'] is not None: + options.append( + sql.SQL("{option_name} = {val}").format( + option_name=sql.Identifier(option['option_name']), + val=sql.Identifier(str(option['val'])), + ) + ) + if any(options): + with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) else: - assert "Invalid index type {self.case_config.index}" + with_clause = sql.Composed(()) + + index_create_sql = sql.SQL( + "CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} USING {index_type} (embedding {embedding_metric})" + ).format( + index_name=sql.Identifier(self._index_name), + table_name=sql.Identifier(self.table_name), + index_type=sql.Identifier(index_param["index_type"]), + embedding_metric=sql.Identifier(index_param["metric"]), + ) + index_create_sql_with_with_clause = ( + index_create_sql + with_clause + ).join(" ") + log.debug(index_create_sql_with_with_clause.as_string(self.cursor)) + self.cursor.execute(index_create_sql_with_with_clause) self.conn.commit() - def _create_table(self, dim : int): + def _create_table(self, dim: int): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" try: + log.info(f"{self.name} client create table : {self.table_name}") + # create table - self.cursor.execute(f'CREATE TABLE IF NOT EXISTS public."{self.table_name}" (id BIGINT PRIMARY KEY, embedding vector({dim}));') - self.cursor.execute(f'ALTER TABLE public."{self.table_name}" ALTER COLUMN embedding SET STORAGE PLAIN;') + self.cursor.execute( + sql.SQL( + "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));" + ).format(table_name=sql.Identifier(self.table_name), dim=dim) + ) + self.cursor.execute( + sql.SQL( + "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;" + ).format(table_name=sql.Identifier(self.table_name)) + ) self.conn.commit() except Exception as e: - log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}") + log.warning( + f"Failed to create pgvector table: {self.table_name} error: {e}" + ) raise e from None def insert_embeddings( @@ -133,25 +301,32 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> (int, Exception): + ) -> Tuple[int, Optional[Exception]]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" try: - items = { - "id": metadata, - "embedding": embeddings - } - df = pd.DataFrame(items) - csv_buffer = io.StringIO() - df.to_csv(csv_buffer, index=False, header=False) - csv_buffer.seek(0) - self.cursor.copy_expert(f"COPY public.\"{self.table_name}\" FROM STDIN WITH (FORMAT CSV)", csv_buffer) + metadata_arr = np.array(metadata) + embeddings_arr = np.array(embeddings) + + with self.cursor.copy( + sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( + table_name=sql.Identifier(self.table_name) + ) + ) as copy: + copy.set_types(["bigint", "vector"]) + for i, row in enumerate(metadata_arr): + copy.write_row((row, embeddings_arr[i])) self.conn.commit() + if kwargs.get("last_batch"): + self._post_insert() + return len(metadata), None except Exception as e: - log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}") + log.warning( + f"Failed to insert data into pgvector table ({self.table_name}), error: {e}" + ) return 0, e def search_embedding( @@ -164,17 +339,9 @@ def search_embedding( assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" - search_param =self.case_config.search_param() - - if self.case_config.index == IndexType.HNSW: - self.cursor.execute(f'SET hnsw.ef_search = {search_param["ef"]}') - self.cursor.execute(f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding {search_param['metric_fun_op']} '{query}' LIMIT {k};") - elif self.case_config.index == IndexType.IVFFlat: - self.cursor.execute(f'SET ivfflat.probes = {search_param["probes"]}') - self.cursor.execute(f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding {search_param['metric_fun_op']} '{query}' LIMIT {k};") - else: - assert "Invalid index type {self.case_config.index}" - self.conn.commit() - result = self.cursor.fetchall() + # TODO add filters support + result = self.cursor.execute( + self._unfiltered_search, (query, k), prepare=True, binary=True + ) - return [int(i[0]) for i in result] + return [int(i[0]) for i in result.fetchall()] diff --git a/vectordb_bench/frontend/components/run_test/caseSelector.py b/vectordb_bench/frontend/components/run_test/caseSelector.py index 9af023518..49b839163 100644 --- a/vectordb_bench/frontend/components/run_test/caseSelector.py +++ b/vectordb_bench/frontend/components/run_test/caseSelector.py @@ -65,25 +65,28 @@ def caseConfigSetting(st, allCaseConfigs, case, activedDbList): key = "%s-%s-%s" % (db, case, config.label.value) if config.inputType == InputType.Text: caseConfig[config.label] = column.text_input( - config.label.value, + config.displayLabel if config.displayLabel else config.label.value, key=key, + help=config.inputHelp, value=config.inputConfig["value"], ) elif config.inputType == InputType.Option: caseConfig[config.label] = column.selectbox( - config.label.value, + config.displayLabel if config.displayLabel else config.label.value, config.inputConfig["options"], key=key, + help=config.inputHelp, ) elif config.inputType == InputType.Number: caseConfig[config.label] = column.number_input( - config.label.value, + config.displayLabel if config.displayLabel else config.label.value, # format="%d", step=config.inputConfig.get("step", 1), min_value=config.inputConfig["min"], max_value=config.inputConfig["max"], key=key, value=config.inputConfig["value"], + help=config.inputHelp, ) k += 1 if k == 0: diff --git a/vectordb_bench/frontend/const/dbCaseConfigs.py b/vectordb_bench/frontend/const/dbCaseConfigs.py index 9b122910a..1e69c57aa 100644 --- a/vectordb_bench/frontend/const/dbCaseConfigs.py +++ b/vectordb_bench/frontend/const/dbCaseConfigs.py @@ -49,6 +49,8 @@ class CaseConfigInput(BaseModel): label: CaseConfigParamType inputType: InputType = InputType.Text inputConfig: dict = {} + inputHelp: str = "" + displayLabel: str = "" # todo type should be a function isDisplayed: typing.Any = lambda x: True @@ -71,6 +73,18 @@ class CaseConfigInput(BaseModel): }, ) +CaseConfigParamInput_IndexType_PgVector = CaseConfigInput( + label=CaseConfigParamType.IndexType, + inputHelp="Select Index Type", + inputType=InputType.Option, + inputConfig={ + "options": [ + IndexType.HNSW.value, + IndexType.IVFFlat.value, + ], + }, +) + CaseConfigParamInput_M = CaseConfigInput( label=CaseConfigParamType.M, inputType=InputType.Number, @@ -83,6 +97,19 @@ class CaseConfigInput(BaseModel): == IndexType.HNSW.value, ) +CaseConfigParamInput_m = CaseConfigInput( + label=CaseConfigParamType.m, + inputType=InputType.Number, + inputConfig={ + "min": 4, + "max": 64, + "value": 16, + }, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) + == IndexType.HNSW.value, +) + + CaseConfigParamInput_EFConstruction_Milvus = CaseConfigInput( label=CaseConfigParamType.EFConstruction, inputType=InputType.Number, @@ -115,6 +142,30 @@ class CaseConfigInput(BaseModel): }, ) +CaseConfigParamInput_maintenance_work_mem_PgVector = CaseConfigInput( + label=CaseConfigParamType.maintenance_work_mem, + inputHelp="Recommended value: 1.33x the index size, not to exceed the available free memory." + "Specify in gigabytes. e.g. 8GB", + inputType=InputType.Text, + inputConfig={ + "value": "8GB", + }, +) + +CaseConfigParamInput_max_parallel_workers_PgVector = CaseConfigInput( + label=CaseConfigParamType.max_parallel_workers, + displayLabel="Max parallel workers", + inputHelp="Recommended value: (cpu cores - 1). This will set the parameters: max_parallel_maintenance_workers," + " max_parallel_workers & table(parallel_workers)", + inputType=InputType.Number, + inputConfig={ + "min": 0, + "max": 1024, + "value": 16, + }, +) + + CaseConfigParamInput_EFConstruction_PgVectoRS = CaseConfigInput( label=CaseConfigParamType.EFConstruction, inputType=InputType.Number, @@ -127,6 +178,19 @@ class CaseConfigInput(BaseModel): == IndexType.HNSW.value, ) +CaseConfigParamInput_EFConstruction_PgVector = CaseConfigInput( + label=CaseConfigParamType.ef_construction, + inputType=InputType.Number, + inputConfig={ + "min": 8, + "max": 1024, + "value": 256, + }, + isDisplayed=lambda config: config[CaseConfigParamType.IndexType] + == IndexType.HNSW.value, +) + + CaseConfigParamInput_M_ES = CaseConfigInput( label=CaseConfigParamType.M, inputType=InputType.Number, @@ -371,24 +435,13 @@ class CaseConfigInput(BaseModel): ], ) -CaseConfigParamInput_IndexType_PG = CaseConfigInput( - label=CaseConfigParamType.IndexType, - inputType=InputType.Option, - inputConfig={ - "options": [ - IndexType.HNSW.value, - IndexType.IVFFlat.value, - ], - }, -) - CaseConfigParamInput_Lists = CaseConfigInput( label=CaseConfigParamType.lists, inputType=InputType.Number, inputConfig={ "min": 1, "max": 65536, - "value": 1000, + "value": 10, }, isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.IVFFlat.value], @@ -397,37 +450,47 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_Probes = CaseConfigInput( label=CaseConfigParamType.probes, inputType=InputType.Number, + inputConfig={ + "min": 1, + "max": 65536, + "value": 1, + }, +) + +CaseConfigParamInput_Lists_PgVector = CaseConfigInput( + label=CaseConfigParamType.lists, + inputType=InputType.Number, inputConfig={ "min": 1, "max": 65536, "value": 10, }, isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.IVFFlat.value], + == IndexType.IVFFlat.value, ) -CaseConfigParamInput_EF_PG = CaseConfigInput( - label=CaseConfigParamType.EF, +CaseConfigParamInput_Probes_PgVector = CaseConfigInput( + label=CaseConfigParamType.probes, inputType=InputType.Number, inputConfig={ "min": 1, "max": 65536, - "value": 128, + "value": 1, }, isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.HNSW.value], + == IndexType.IVFFlat.value, ) -CaseConfigParamInput_EFC_PG = CaseConfigInput( - label=CaseConfigParamType.EFConstruction, +CaseConfigParamInput_EFSearch_PgVector = CaseConfigInput( + label=CaseConfigParamType.ef_search, inputType=InputType.Number, inputConfig={ "min": 1, - "max": 65536, - "value": 300, + "max": 2048, + "value": 256, }, isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.HNSW.value], + == IndexType.HNSW.value, ) CaseConfigParamInput_QuantizationType_PgVectoRS = CaseConfigInput( @@ -518,20 +581,22 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_NumCandidates_ES, ] -PgVectorLoadingConfig = [ - CaseConfigParamInput_IndexType_PG, - CaseConfigParamInput_Lists, - CaseConfigParamInput_M, - CaseConfigParamInput_EFC_PG -] -PgVectorPerformanceConfig = [ - CaseConfigParamInput_IndexType_PG, - CaseConfigParamInput_Lists, - CaseConfigParamInput_Probes, - CaseConfigParamInput_M, - CaseConfigParamInput_EF_PG, - CaseConfigParamInput_EFC_PG -] +PgVectorLoadingConfig = [CaseConfigParamInput_IndexType_PgVector, + CaseConfigParamInput_Lists_PgVector, + CaseConfigParamInput_m, + CaseConfigParamInput_EFConstruction_PgVector, + CaseConfigParamInput_maintenance_work_mem_PgVector, + CaseConfigParamInput_max_parallel_workers_PgVector, + ] +PgVectorPerformanceConfig = [CaseConfigParamInput_IndexType_PgVector, + CaseConfigParamInput_m, + CaseConfigParamInput_EFConstruction_PgVector, + CaseConfigParamInput_EFSearch_PgVector, + CaseConfigParamInput_Lists_PgVector, + CaseConfigParamInput_Probes_PgVector, + CaseConfigParamInput_maintenance_work_mem_PgVector, + CaseConfigParamInput_max_parallel_workers_PgVector, + ] PgVectoRSLoadingConfig = [ CaseConfigParamInput_IndexType, diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 3c2a5b9aa..ec1b610e1 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -1,8 +1,8 @@ import logging import pathlib from datetime import date -from typing import Self -from enum import Enum +from enum import Enum, StrEnum, auto +from typing import List, Self, Sequence, Set import ujson @@ -37,8 +37,10 @@ class CaseConfigParamType(Enum): IndexType = "IndexType" M = "M" EFConstruction = "efConstruction" + ef_construction = "ef_construction" EF = "ef" SearchList = "search_list" + ef_search = "ef_search" Nlist = "nlist" Nprobe = "nprobe" MaxConnections = "maxConnections" @@ -60,7 +62,8 @@ class CaseConfigParamType(Enum): cache_dataset_on_device = "cache_dataset_on_device" refine_ratio = "refine_ratio" level = "level" - + maintenance_work_mem = "maintenance_work_mem" + max_parallel_workers = "max_parallel_workers" class CustomizedCase(BaseModel): pass