Skip to content

[ENH] When embedding functions have defined default_space, use them if the user hasn't specified #4321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def create_collection(
configuration = {}
if embedding_function is not None:
configuration["embedding_function"] = embedding_function
else:
if configuration.get("embedding_function") is None:
configuration["embedding_function"] = embedding_function
model = self._server.create_collection(
name=name,
metadata=metadata,
Expand Down
18 changes: 18 additions & 0 deletions chromadb/api/collection_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,24 @@ def create_collection_configuration_to_json(
"config": ef.get_config(),
}
register_embedding_function(type(ef)) # type: ignore
if hnsw_config is not None and hnsw_config.get("space") is None:
try:
hnsw_config["space"] = ef.default_space()
Copy link
Collaborator

@HammadB HammadB Apr 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there not some base impl that will always be used? I am bit a confused by this pattern. You could also check if its the base impl.

If default_space is not specified, what is the behavior we want?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, the base impl is l2, which is the space we define in our docs as the default. so when available, it’ll pull from the ef’s default otherwise the base one

Copy link
Contributor Author

@jairad26 jairad26 Apr 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and if they’re running on an old version or are using a legacy custom ef (which won't have the base impl), the try will fail and issue a warning

except Exception:
warnings.warn(
f"default_space not supported for {ef.name()}",
DeprecationWarning,
stacklevel=2,
)
if spann_config is not None and spann_config.get("space") is None:
try:
spann_config["space"] = ef.default_space()
except Exception:
warnings.warn(
f"default_space not supported for {ef.name()}",
DeprecationWarning,
stacklevel=2,
)
except Exception as e:
warnings.warn(
f"legacy embedding function config: {e}",
Expand Down
92 changes: 92 additions & 0 deletions chromadb/test/configurations/test_collection_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,3 +776,95 @@ def test_default_collection_creation(client: ClientAPI) -> None:
ef = config.get("embedding_function")
assert ef is not None
assert ef.name() == "default"


def test_default_space_hnsw(client: ClientAPI) -> None:
"""Test that the default space is used for HNSW if not specified in config."""
client.reset()
hnsw_config_no_space: CreateHNSWConfiguration = {"ef_construction": 123}
coll = client.create_collection(
name="test_default_space_hnsw",
configuration={
"hnsw": hnsw_config_no_space,
"embedding_function": CustomEmbeddingFunction(),
},
)
config = load_collection_configuration_from_json(coll.configuration_json)
assert config is not None
hnsw_config = config.get("hnsw")
assert hnsw_config is not None
assert hnsw_config.get("space") == "cosine" # Should match EF default
assert hnsw_config.get("ef_construction") == 123
assert config.get("spann") is None
ef = config.get("embedding_function")
assert ef is not None
assert ef.name() == "custom_ef"


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_default_space_spann(client: ClientAPI) -> None:
"""Test that the default space is used for SPANN if not specified in config."""
client.reset()
spann_config_no_space: CreateSpannConfiguration = {"ef_construction": 123}
coll = client.create_collection(
name="test_default_space_spann",
configuration={
"spann": spann_config_no_space,
"embedding_function": CustomEmbeddingFunction(),
},
)
config = load_collection_configuration_from_json(coll.configuration_json)
assert config is not None
spann_config = config.get("spann")
assert spann_config is not None
assert spann_config.get("space") == "cosine" # Should match EF default
assert spann_config.get("ef_construction") == 123
assert config.get("hnsw") is None
ef = config.get("embedding_function")
assert ef is not None
assert ef.name() == "custom_ef"


def test_override_default_space_hnsw(client: ClientAPI) -> None:
"""Test that specifying space in HNSW config overrides the EF default."""
client.reset()
hnsw_config_override: CreateHNSWConfiguration = {"space": "l2"}
coll = client.create_collection(
name="test_override_space_hnsw",
configuration={
"hnsw": hnsw_config_override,
"embedding_function": CustomEmbeddingFunction(),
},
)
config = load_collection_configuration_from_json(coll.configuration_json)
assert config is not None
hnsw_config = config.get("hnsw")
assert hnsw_config is not None
assert hnsw_config.get("space") == "l2" # Should be overridden value
assert config.get("spann") is None
ef = config.get("embedding_function")
assert ef is not None
assert ef.name() == "custom_ef"


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_override_default_space_spann(client: ClientAPI) -> None:
"""Test that specifying space in SPANN config overrides the EF default."""
client.reset()
spann_config_override: CreateSpannConfiguration = {"space": "l2"}
coll = client.create_collection(
name="test_override_space_spann",
configuration={
"spann": spann_config_override,
"embedding_function": CustomEmbeddingFunction(),
},
)
config = load_collection_configuration_from_json(coll.configuration_json)
assert config is not None
spann_config = config.get("spann")
assert spann_config is not None
assert spann_config.get("space") == "l2" # Should be overridden value
assert config.get("hnsw") is None
ef = config.get("embedding_function")
assert ef is not None
assert ef.name() == "custom_ef"
16 changes: 7 additions & 9 deletions chromadb/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from chromadb.serde import BaseModelJSONSerializable
from chromadb.api.collection_configuration import (
CollectionConfiguration,
HNSWConfiguration,
SpannConfiguration,
collection_configuration_to_json,
load_collection_configuration_from_json,
)
Expand Down Expand Up @@ -155,8 +153,8 @@ def get_configuration(self) -> CollectionConfiguration:
stacklevel=2,
)
return CollectionConfiguration(
hnsw=HNSWConfiguration(),
spann=SpannConfiguration(),
hnsw=None,
spann=None,
embedding_function=None,
)

Expand All @@ -175,11 +173,11 @@ def get_model_fields(self) -> Dict[Any, Any]:
@override
def from_json(cls, json_map: Dict[str, Any]) -> Self:
"""Deserializes a Collection object from JSON"""
configuration: CollectionConfiguration = {
"hnsw": {},
"spann": {},
"embedding_function": None,
}
configuration = CollectionConfiguration(
hnsw=None,
spann=None,
embedding_function=None,
)
try:
configuration_json = json_map.get("configuration_json", None)
configuration = load_collection_configuration_from_json(configuration_json)
Expand Down
Loading