Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add new db_config for better labeling: version, note #350

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion vectordb_bench/backend/clients/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,33 @@ class DBConfig(ABC, BaseModel):
"""

db_label: str = ""
version: str = ""
note: str = ""

@staticmethod
def common_short_configs() -> list[str]:
"""
short input, such as `db_label`, `version`
"""
return ["version", "db_label"]

@staticmethod
def common_long_configs() -> list[str]:
"""
long input, such as `note`
"""
return ["note"]

@abstractmethod
def to_dict(self) -> dict:
raise NotImplementedError

@validator("*")
def not_empty_field(cls, v, field):
if field.name == "db_label":
if (
field.name in cls.common_short_configs()
or field.name in cls.common_long_configs()
):
return v
if not v and isinstance(v, (str, SecretStr)):
raise ValueError("Empty string!")
Expand Down
19 changes: 13 additions & 6 deletions vectordb_bench/frontend/components/check_results/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ def getFilterTasks(
task
for task in tasks
if task.task_config.db_name in dbNames
and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
and task.task_config.case_config.case_id.case_cls(
task.task_config.case_config.custom_case
).name
in caseNames
]
return filterTasks

Expand All @@ -35,17 +38,20 @@ def mergeTasks(tasks: list[CaseResult]):
db_name = task.task_config.db_name
db = task.task_config.db.value
db_label = task.task_config.db_config.db_label or ""
case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
version = task.task_config.db_config.version or ""
case = task.task_config.case_config.case_id.case_cls(
task.task_config.case_config.custom_case
)
dbCaseMetricsMap[db_name][case.name] = {
"db": db,
"db_label": db_label,
"version": version,
"metrics": mergeMetrics(
dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
asdict(task.metrics),
),
"label": getBetterLabel(
dbCaseMetricsMap[db_name][case.name].get(
"label", ResultLabel.FAILED),
dbCaseMetricsMap[db_name][case.name].get("label", ResultLabel.FAILED),
task.label,
),
}
Expand All @@ -57,13 +63,15 @@ def mergeTasks(tasks: list[CaseResult]):
metrics = metricInfo["metrics"]
db = metricInfo["db"]
db_label = metricInfo["db_label"]
version = metricInfo["version"]
label = metricInfo["label"]
if label == ResultLabel.NORMAL:
mergedTasks.append(
{
"db_name": db_name,
"db": db,
"db_label": db_label,
"version": version,
"case_name": case_name,
"metricsSet": set(metrics.keys()),
**metrics,
Expand All @@ -79,8 +87,7 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
metrics = {**metrics_1}
for key, value in metrics_2.items():
metrics[key] = (
getBetterMetric(
key, value, metrics[key]) if key in metrics else value
getBetterMetric(key, value, metrics[key]) if key in metrics else value
)

return metrics
Expand Down
52 changes: 37 additions & 15 deletions vectordb_bench/frontend/components/run_test/dbConfigSetting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from pydantic import ValidationError
from vectordb_bench.frontend.config.styles import *
from vectordb_bench.backend.clients import DB
from vectordb_bench.frontend.config.styles import DB_CONFIG_SETTING_COLUMNS
from vectordb_bench.frontend.utils import inputIsPassword


def dbConfigSettings(st, activedDbList):
def dbConfigSettings(st, activedDbList: list[DB]):
expander = st.expander("Configurations for the selected databases", True)

dbConfigs = {}
Expand All @@ -27,7 +28,7 @@ def dbConfigSettings(st, activedDbList):
return dbConfigs, isAllValid


def dbConfigSettingItem(st, activeDb):
def dbConfigSettingItem(st, activeDb: DB):
st.markdown(
f"<div style='font-weight: 600; font-size: 20px; margin-top: 16px;'>{activeDb.value}</div>",
unsafe_allow_html=True,
Expand All @@ -36,20 +37,41 @@ def dbConfigSettingItem(st, activeDb):

dbConfigClass = activeDb.config_cls
properties = dbConfigClass.schema().get("properties")
propertiesItems = list(properties.items())
moveDBLabelToLast(propertiesItems)
dbConfig = {}
for j, property in enumerate(propertiesItems):
column = columns[j % DB_CONFIG_SETTING_COLUMNS]
key, value = property
idx = 0

# db config (unique)
for key, property in properties.items():
if (
key not in dbConfigClass.common_short_configs()
and key not in dbConfigClass.common_long_configs()
):
column = columns[idx % DB_CONFIG_SETTING_COLUMNS]
idx += 1
dbConfig[key] = column.text_input(
key,
key="%s-%s" % (activeDb.name, key),
value=property.get("default", ""),
type="password" if inputIsPassword(key) else "default",
)
# db config (common short labels)
for key in dbConfigClass.common_short_configs():
column = columns[idx % DB_CONFIG_SETTING_COLUMNS]
idx += 1
dbConfig[key] = column.text_input(
key,
key="%s-%s" % (activeDb, key),
value=value.get("default", ""),
type="password" if inputIsPassword(key) else "default",
key="%s-%s" % (activeDb.name, key),
value="",
type="default",
placeholder="optional, for labeling results",
)
return dbConfig


def moveDBLabelToLast(propertiesItems):
propertiesItems.sort(key=lambda x: 1 if x[0] == "db_label" else 0)
# db config (common long text_input)
for key in dbConfigClass.common_long_configs():
dbConfig[key] = st.text_area(
key,
key="%s-%s" % (activeDb.name, key),
value="",
placeholder="optional",
)
return dbConfig
4 changes: 3 additions & 1 deletion vectordb_bench/frontend/components/run_test/initStyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ def initStyle(st):
div[data-testid='stHorizontalBlock'] {gap: 8px;}
/* check box */
.stCheckbox p { color: #000; font-size: 18px; font-weight: 600; }
/* db selector - db_name should not wrap */
div[data-testid="stVerticalBlockBorderWrapper"] div[data-testid="stCheckbox"] div[data-testid="stWidgetLabel"] p { white-space: nowrap; }
</style>""",
unsafe_allow_html=True,
)
)
12 changes: 8 additions & 4 deletions vectordb_bench/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import pathlib
from datetime import date
from enum import Enum, StrEnum, auto
from typing import List, Self, Sequence, Set
from typing import List, Self

import ujson

from .backend.clients import (
DB,
DBConfig,
DBCaseConfig,
IndexType,
)
from .backend.cases import CaseType
from .base import BaseModel
Expand Down Expand Up @@ -128,9 +127,14 @@ class TaskConfig(BaseModel):

@property
def db_name(self):
db = self.db.value
db_name = f"{self.db.value}"
db_label = self.db_config.db_label
return f"{db}-{db_label}" if db_label else db
if db_label:
db_name += f"-{db_label}"
version = self.db_config.version
if version:
db_name += f"-{version}"
return db_name


class ResultLabel(Enum):
Expand Down