diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index 0024bf600..d9ec5d83b 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -38,6 +38,22 @@ class DBConfig(ABC, BaseModel): """ db_label: str = "" + version: str = "" + note: str = "" + + @staticmethod + def common_short_configs() -> list[str]: + """ + short input, such as `db_label`, `version` + """ + return ["version", "db_label"] + + @staticmethod + def common_long_configs() -> list[str]: + """ + long input, such as `note` + """ + return ["note"] @abstractmethod def to_dict(self) -> dict: @@ -45,7 +61,10 @@ def to_dict(self) -> dict: @validator("*") def not_empty_field(cls, v, field): - if field.name == "db_label": + if ( + field.name in cls.common_short_configs() + or field.name in cls.common_long_configs() + ): return v if not v and isinstance(v, (str, SecretStr)): raise ValueError("Empty string!") diff --git a/vectordb_bench/frontend/components/check_results/data.py b/vectordb_bench/frontend/components/check_results/data.py index 1e6bba00e..b3cac21e1 100644 --- a/vectordb_bench/frontend/components/check_results/data.py +++ b/vectordb_bench/frontend/components/check_results/data.py @@ -24,7 +24,10 @@ def getFilterTasks( task for task in tasks if task.task_config.db_name in dbNames - and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames + and task.task_config.case_config.case_id.case_cls( + task.task_config.case_config.custom_case + ).name + in caseNames ] return filterTasks @@ -35,17 +38,20 @@ def mergeTasks(tasks: list[CaseResult]): db_name = task.task_config.db_name db = task.task_config.db.value db_label = task.task_config.db_config.db_label or "" - case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case) + version = task.task_config.db_config.version or "" + case = task.task_config.case_config.case_id.case_cls( + task.task_config.case_config.custom_case + ) dbCaseMetricsMap[db_name][case.name] = { "db": db, "db_label": db_label, + "version": version, "metrics": mergeMetrics( dbCaseMetricsMap[db_name][case.name].get("metrics", {}), asdict(task.metrics), ), "label": getBetterLabel( - dbCaseMetricsMap[db_name][case.name].get( - "label", ResultLabel.FAILED), + dbCaseMetricsMap[db_name][case.name].get("label", ResultLabel.FAILED), task.label, ), } @@ -57,6 +63,7 @@ def mergeTasks(tasks: list[CaseResult]): metrics = metricInfo["metrics"] db = metricInfo["db"] db_label = metricInfo["db_label"] + version = metricInfo["version"] label = metricInfo["label"] if label == ResultLabel.NORMAL: mergedTasks.append( @@ -64,6 +71,7 @@ def mergeTasks(tasks: list[CaseResult]): "db_name": db_name, "db": db, "db_label": db_label, + "version": version, "case_name": case_name, "metricsSet": set(metrics.keys()), **metrics, @@ -79,8 +87,7 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict: metrics = {**metrics_1} for key, value in metrics_2.items(): metrics[key] = ( - getBetterMetric( - key, value, metrics[key]) if key in metrics else value + getBetterMetric(key, value, metrics[key]) if key in metrics else value ) return metrics diff --git a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py index 8f4f35c93..257608413 100644 --- a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py +++ b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py @@ -1,9 +1,10 @@ from pydantic import ValidationError -from vectordb_bench.frontend.config.styles import * +from vectordb_bench.backend.clients import DB +from vectordb_bench.frontend.config.styles import DB_CONFIG_SETTING_COLUMNS from vectordb_bench.frontend.utils import inputIsPassword -def dbConfigSettings(st, activedDbList): +def dbConfigSettings(st, activedDbList: list[DB]): expander = st.expander("Configurations for the selected databases", True) dbConfigs = {} @@ -27,7 +28,7 @@ def dbConfigSettings(st, activedDbList): return dbConfigs, isAllValid -def dbConfigSettingItem(st, activeDb): +def dbConfigSettingItem(st, activeDb: DB): st.markdown( f"