@@ -35,10 +35,11 @@ def routes():
35
35
36
36
37
37
@pytest .fixture
38
- def semantic_router (client , routes ):
38
+ def semantic_router (client , routes , hf_vectorizer ):
39
39
router = SemanticRouter (
40
40
name = "test-router" ,
41
41
routes = routes ,
42
+ vectorizer = hf_vectorizer ,
42
43
routing_config = RoutingConfig (max_k = 2 ),
43
44
redis_client = client ,
44
45
overwrite = False ,
@@ -86,7 +87,7 @@ def test_data_optimization():
86
87
87
88
88
89
def test_routes_different_distance_thresholds_optimizer_default (
89
- semantic_router , routes , redis_url , test_data_optimization
90
+ semantic_router , routes , redis_url , test_data_optimization , hf_vectorizer
90
91
):
91
92
redis_version = semantic_router ._index .client .info ()["redis_version" ]
92
93
if not compare_versions (redis_version , "7.0.0" ):
@@ -101,6 +102,7 @@ def test_routes_different_distance_thresholds_optimizer_default(
101
102
router = SemanticRouter (
102
103
name = "test_routes_different_distance_optimizer" ,
103
104
routes = routes ,
105
+ vectorizer = hf_vectorizer ,
104
106
redis_url = redis_url ,
105
107
overwrite = True ,
106
108
)
@@ -119,7 +121,7 @@ def test_routes_different_distance_thresholds_optimizer_default(
119
121
120
122
121
123
def test_routes_different_distance_thresholds_optimizer_precision (
122
- semantic_router , routes , redis_url , test_data_optimization
124
+ semantic_router , routes , redis_url , test_data_optimization , hf_vectorizer
123
125
):
124
126
125
127
redis_version = semantic_router ._index .client .info ()["redis_version" ]
@@ -135,6 +137,7 @@ def test_routes_different_distance_thresholds_optimizer_precision(
135
137
router = SemanticRouter (
136
138
name = "test_routes_different_distance_optimizer" ,
137
139
routes = routes ,
140
+ vectorizer = hf_vectorizer ,
138
141
redis_url = redis_url ,
139
142
overwrite = True ,
140
143
)
@@ -155,7 +158,7 @@ def test_routes_different_distance_thresholds_optimizer_precision(
155
158
156
159
157
160
def test_routes_different_distance_thresholds_optimizer_recall (
158
- semantic_router , routes , redis_url , test_data_optimization
161
+ semantic_router , routes , redis_url , test_data_optimization , hf_vectorizer
159
162
):
160
163
redis_version = semantic_router ._index .client .info ()["redis_version" ]
161
164
if not compare_versions (redis_version , "7.0.0" ):
@@ -170,6 +173,7 @@ def test_routes_different_distance_thresholds_optimizer_recall(
170
173
router = SemanticRouter (
171
174
name = "test_routes_different_distance_optimizer" ,
172
175
routes = routes ,
176
+ vectorizer = hf_vectorizer ,
173
177
redis_url = redis_url ,
174
178
overwrite = True ,
175
179
)
0 commit comments