Skip to content

Commit e406d76

Browse files
tylerhutchersonabrookins
authored andcommitted
start centralizing the use of fixtures for hugging face models
1 parent b5f3780 commit e406d76

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from testcontainers.compose import DockerCompose
66

77
from redisvl.redis.connection import RedisConnectionFactory
8+
from redisvl.utils.vectorize import HFTextVectorizer
89

910

1011
@pytest.fixture(autouse=True)
@@ -68,6 +69,15 @@ def client(redis_url):
6869
yield conn
6970

7071

72+
@pytest.fixture(scope="session", autouse=True)
73+
def hf_vectorizer():
74+
return HFTextVectorizer(
75+
model="sentence-transformers/all-mpnet-base-v2",
76+
token=os.getenv("HF_TOKEN"),
77+
cache_folder=os.getenv("SENTENCE_TRANSFORMERS_HOME"),
78+
)
79+
80+
7181
@pytest.fixture
7282
def sample_datetimes():
7383
return {

tests/integration/test_threshold_optimizer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ def routes():
3535

3636

3737
@pytest.fixture
38-
def semantic_router(client, routes):
38+
def semantic_router(client, routes, hf_vectorizer):
3939
router = SemanticRouter(
4040
name="test-router",
4141
routes=routes,
42+
vectorizer=hf_vectorizer,
4243
routing_config=RoutingConfig(max_k=2),
4344
redis_client=client,
4445
overwrite=False,
@@ -86,7 +87,7 @@ def test_data_optimization():
8687

8788

8889
def test_routes_different_distance_thresholds_optimizer_default(
89-
semantic_router, routes, redis_url, test_data_optimization
90+
semantic_router, routes, redis_url, test_data_optimization, hf_vectorizer
9091
):
9192
redis_version = semantic_router._index.client.info()["redis_version"]
9293
if not compare_versions(redis_version, "7.0.0"):
@@ -101,6 +102,7 @@ def test_routes_different_distance_thresholds_optimizer_default(
101102
router = SemanticRouter(
102103
name="test_routes_different_distance_optimizer",
103104
routes=routes,
105+
vectorizer=hf_vectorizer,
104106
redis_url=redis_url,
105107
overwrite=True,
106108
)
@@ -119,7 +121,7 @@ def test_routes_different_distance_thresholds_optimizer_default(
119121

120122

121123
def test_routes_different_distance_thresholds_optimizer_precision(
122-
semantic_router, routes, redis_url, test_data_optimization
124+
semantic_router, routes, redis_url, test_data_optimization, hf_vectorizer
123125
):
124126

125127
redis_version = semantic_router._index.client.info()["redis_version"]
@@ -135,6 +137,7 @@ def test_routes_different_distance_thresholds_optimizer_precision(
135137
router = SemanticRouter(
136138
name="test_routes_different_distance_optimizer",
137139
routes=routes,
140+
vectorizer=hf_vectorizer,
138141
redis_url=redis_url,
139142
overwrite=True,
140143
)
@@ -155,7 +158,7 @@ def test_routes_different_distance_thresholds_optimizer_precision(
155158

156159

157160
def test_routes_different_distance_thresholds_optimizer_recall(
158-
semantic_router, routes, redis_url, test_data_optimization
161+
semantic_router, routes, redis_url, test_data_optimization, hf_vectorizer
159162
):
160163
redis_version = semantic_router._index.client.info()["redis_version"]
161164
if not compare_versions(redis_version, "7.0.0"):
@@ -170,6 +173,7 @@ def test_routes_different_distance_thresholds_optimizer_recall(
170173
router = SemanticRouter(
171174
name="test_routes_different_distance_optimizer",
172175
routes=routes,
176+
vectorizer=hf_vectorizer,
173177
redis_url=redis_url,
174178
overwrite=True,
175179
)

0 commit comments

Comments
 (0)