Skip to content

Commit d02025d

Browse files
fix remaining issues
1 parent 1340690 commit d02025d

File tree

10 files changed

+58
-77
lines changed

10 files changed

+58
-77
lines changed

redisvl/extensions/llmcache/schema.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Dict, List, Optional
22

3-
from pydantic import BaseModel, Field, field_validator, model_validator
3+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
44

55
from redisvl.extensions.constants import (
66
CACHE_VECTOR_FIELD_NAME,
@@ -80,33 +80,33 @@ class CacheHit(BaseModel):
8080
filters: Optional[Dict[str, Any]] = Field(default=None)
8181
"""Optional filter data stored on the cache entry for customizing retrieval"""
8282

83+
# Allow extra fields to simplify handling filters
84+
model_config = ConfigDict(extra="allow")
85+
8386
@model_validator(mode="before")
8487
@classmethod
85-
def validate_cache_hit(cls, values):
88+
def validate_cache_hit(cls, values: Dict[str, Any]) -> Dict[str, Any]:
8689
# Deserialize metadata if necessary
8790
if "metadata" in values and isinstance(values["metadata"], str):
8891
values["metadata"] = deserialize(values["metadata"])
8992

90-
# Separate filters from other fields
91-
known_fields = set(cls.model_fields.keys())
92-
filters = {k: v for k, v in values.items() if k not in known_fields}
93-
94-
# Add filters to valuesgiy s
95-
if filters:
96-
values["filters"] = filters
97-
98-
# Remove filter fields from the main values
99-
for k in filters:
100-
values.pop(k)
93+
# Collect any extra fields and store them as filters
94+
extra_data = values.pop("__pydantic_extra__", {}) or {}
95+
if extra_data:
96+
current_filters = values.get("filters") or {}
97+
if not isinstance(current_filters, dict):
98+
current_filters = {}
99+
current_filters.update(extra_data)
100+
values["filters"] = current_filters
101101

102102
return values
103103

104-
def to_dict(self) -> Dict:
104+
def to_dict(self) -> Dict[str, Any]:
105+
"""Convert this model to a dictionary, merging filters into the result."""
105106
data = self.model_dump(exclude_none=True)
106-
if self.filters:
107-
data.update(self.filters)
107+
if data.get("filters"):
108+
data.update(data["filters"])
108109
del data["filters"]
109-
110110
return data
111111

112112

redisvl/extensions/llmcache/semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def __init__(
141141
existing_index = SearchIndex.from_existing(
142142
name, redis_client=self._index.client
143143
)
144-
if existing_index.schema != self._index.schema:
144+
if existing_index.schema.to_dict() != self._index.schema.to_dict():
145145
raise ValueError(
146146
f"Existing index {name} schema does not match the user provided schema for the semantic cache. "
147147
"If you wish to overwrite the index schema, set overwrite=True during initialization."

redisvl/extensions/router.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

redisvl/extensions/router/schema.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import warnings
12
from enum import Enum
2-
from typing import Dict, List, Optional
3+
from typing import Any, Dict, List, Optional
34

4-
from pydantic import BaseModel, Field, field_validator
5+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
6+
from typing_extensions import Annotated
57

68
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
79
from redisvl.schema import IndexSchema
@@ -16,7 +18,7 @@ class Route(BaseModel):
1618
"""List of reference phrases for the route."""
1719
metadata: Dict[str, str] = Field(default={})
1820
"""Metadata associated with the route."""
19-
distance_threshold: float = Field(default=0.5)
21+
distance_threshold: Annotated[float, Field(strict=True, default=0.5, gt=0, le=1)]
2022
"""Distance threshold for matching the route."""
2123

2224
@field_validator("name")
@@ -35,13 +37,6 @@ def references_must_not_be_empty(cls, v):
3537
raise ValueError("All references must be non-empty strings")
3638
return v
3739

38-
@field_validator("distance_threshold")
39-
@classmethod
40-
def distance_threshold_must_be_positive(cls, v):
41-
if v is not None and v <= 0:
42-
raise ValueError("Route distance threshold must be greater than zero")
43-
return v
44-
4540

4641
class RouteMatch(BaseModel):
4742
"""Model representing a matched route with distance information."""
@@ -66,28 +61,26 @@ class DistanceAggregationMethod(Enum):
6661
class RoutingConfig(BaseModel):
6762
"""Configuration for routing behavior."""
6863

69-
# distance_threshold: float = Field(default=0.5)
70-
"""The threshold for semantic distance."""
71-
max_k: int = Field(default=1)
72-
64+
"""The maximum number of top matches to return."""
65+
max_k: Annotated[int, Field(strict=True, default=1, gt=0)] = 1
7366
"""Aggregation method to use to classify queries."""
7467
aggregation_method: DistanceAggregationMethod = Field(
7568
default=DistanceAggregationMethod.avg
7669
)
7770

78-
"""The maximum number of top matches to return."""
79-
distance_threshold: float = Field(
80-
default=0.5,
81-
deprecated=True,
82-
description="Global distance threshold is deprecated all distance_thresholds now apply at route level.",
83-
)
71+
model_config = ConfigDict(extra="ignore")
8472

85-
@field_validator("max_k")
73+
@model_validator(mode="before")
8674
@classmethod
87-
def max_k_must_be_positive(cls, v):
88-
if v <= 0:
89-
raise ValueError("max_k must be a positive integer")
90-
return v
75+
def remove_distance_threshold(cls, values: Dict[str, Any]) -> Dict[str, Any]:
76+
if "distance_threshold" in values:
77+
warnings.warn(
78+
"The 'distance_threshold' field is deprecated and will be ignored. Set distance_threshold per Route.",
79+
DeprecationWarning,
80+
stacklevel=2,
81+
)
82+
values.pop("distance_threshold")
83+
return values
9184

9285

9386
class SemanticRouterIndexSchema(IndexSchema):

redisvl/extensions/router/semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _initialize_index(
123123
existing_index = SearchIndex.from_existing(
124124
self.name, redis_client=self._index.client
125125
)
126-
if existing_index.schema != self._index.schema:
126+
if existing_index.schema.to_dict() != self._index.schema.to_dict():
127127
raise ValueError(
128128
f"Existing index {self.name} schema does not match the user provided schema for the semantic router. "
129129
"If you wish to overwrite the index schema, set overwrite=True during initialization."

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def __init__(
110110
existing_index = SearchIndex.from_existing(
111111
name, redis_client=self._index.client
112112
)
113-
if existing_index.schema != self._index.schema:
113+
if existing_index.schema.to_dict() != self._index.schema.to_dict():
114114
raise ValueError(
115115
f"Existing index {name} schema does not match the user provided schema for the semantic session. "
116116
"If you wish to overwrite the index schema, set overwrite=True during initialization."

redisvl/utils/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def wrapper(func):
7777
@wraps(func)
7878
def inner(*args, **kwargs):
7979
argument_names = func.__code__.co_varnames
80+
print(argument_names, flush=True)
8081

8182
if argument in argument_names:
8283
warn(message, DeprecationWarning, stacklevel=2)

schemas/semantic_router.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,3 @@ vectorizer:
2020
routing_config:
2121
max_k: 2
2222
aggregation_method: avg
23-
distance_threshold: 0.3

tests/integration/test_semantic_router.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66
from redis.exceptions import ConnectionError
77

88
from redisvl.exceptions import RedisModuleVersionError
9-
from redisvl.extensions.llmcache.semantic import SemanticCache
109
from redisvl.extensions.router import SemanticRouter
11-
from redisvl.extensions.router.schema import Route, RoutingConfig
10+
from redisvl.extensions.router.schema import (
11+
DistanceAggregationMethod,
12+
Route,
13+
RoutingConfig,
14+
)
1215
from redisvl.redis.connection import compare_versions
1316
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
1417

@@ -58,7 +61,6 @@ def disable_deprecation_warnings():
5861
def test_initialize_router(semantic_router):
5962
assert semantic_router.name == "test-router"
6063
assert len(semantic_router.routes) == 2
61-
assert semantic_router.routing_config.distance_threshold == 0.3
6264
assert semantic_router.routing_config.max_k == 2
6365

6466

@@ -114,10 +116,13 @@ def test_multiple_query(semantic_router):
114116

115117

116118
def test_update_routing_config(semantic_router):
117-
new_config = RoutingConfig(distance_threshold=0.5, max_k=1)
119+
new_config = RoutingConfig(max_k=27, aggregation_method="min")
118120
semantic_router.update_routing_config(new_config)
119-
assert semantic_router.routing_config.distance_threshold == 0.5
120-
assert semantic_router.routing_config.max_k == 1
121+
assert semantic_router.routing_config.max_k == 27
122+
assert (
123+
semantic_router.routing_config.aggregation_method
124+
== DistanceAggregationMethod.min
125+
)
121126

122127

123128
def test_vector_query(semantic_router):
@@ -189,7 +194,7 @@ def test_from_dict(semantic_router):
189194
new_router = SemanticRouter.from_dict(
190195
router_dict, redis_client=semantic_router._index.client, overwrite=True
191196
)
192-
assert new_router == semantic_router
197+
assert new_router.to_dict() == router_dict
193198

194199

195200
def test_to_yaml(semantic_router):
@@ -203,7 +208,7 @@ def test_from_yaml(semantic_router):
203208
new_router = SemanticRouter.from_yaml(
204209
yaml_file, redis_client=semantic_router._index.client, overwrite=True
205210
)
206-
assert new_router == semantic_router
211+
assert new_router.to_dict() == semantic_router.to_dict()
207212

208213

209214
def test_to_dict_missing_fields():
@@ -290,14 +295,14 @@ def test_different_vector_dtypes(redis_url, routes):
290295
def test_bad_dtype_connecting_to_exiting_router(redis_url, routes):
291296
try:
292297
router = SemanticRouter(
293-
name="float64 router",
298+
name="float64-router",
294299
routes=routes,
295300
dtype="float64",
296301
redis_url=redis_url,
297302
)
298303

299304
same_type = SemanticRouter(
300-
name="float64 router",
305+
name="float64-router",
301306
routes=routes,
302307
dtype="float64",
303308
redis_url=redis_url,
@@ -308,7 +313,7 @@ def test_bad_dtype_connecting_to_exiting_router(redis_url, routes):
308313

309314
with pytest.raises(ValueError):
310315
bad_type = SemanticRouter(
311-
name="float64 router",
316+
name="float64-router",
312317
routes=routes,
313318
dtype="float16",
314319
redis_url=redis_url,

tests/unit/test_route_schema.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_route_invalid_threshold_zero():
7474
metadata={"key": "value"},
7575
distance_threshold=0,
7676
)
77-
assert "Route distance threshold must be greater than zero" in str(excinfo.value)
77+
assert "Input should be greater than 0" in str(excinfo.value)
7878

7979

8080
def test_route_invalid_threshold_negative():
@@ -85,7 +85,7 @@ def test_route_invalid_threshold_negative():
8585
metadata={"key": "value"},
8686
distance_threshold=-0.1,
8787
)
88-
assert "Route distance threshold must be greater than zero" in str(excinfo.value)
88+
assert "Input should be greater than 0" in str(excinfo.value)
8989

9090

9191
def test_route_match():
@@ -115,4 +115,4 @@ def test_routing_config_valid():
115115
def test_routing_config_invalid_max_k():
116116
with pytest.raises(ValidationError) as excinfo:
117117
RoutingConfig(max_k=0)
118-
assert "max_k must be a positive integer" in str(excinfo.value)
118+
assert "Input should be greater than 0" in str(excinfo.value)

0 commit comments

Comments
 (0)