Commit ca4b9f1
fix: address final PR review recommendations (#579)
Addresses remaining issues from PR review comments:

1. Fixed async context manager (CRITICAL - Issue #1)
   - Added a warning that the default calls sync connect/disconnect
   - Documented that implementations with async clients should override
   - Included an example of a proper async implementation
   - Prevents event-loop blocking for async implementations

2. Updated connection logging (Issue #2)
   - Changed from INFO to DEBUG level for connect/disconnect
   - Added the class name to log messages for clarity
   - Emphasizes these are flag-only operations (not real connections)
   - Reduces log noise for the base implementation

3. Added batch size upper-bound warning (Issue #4)
   - Warns when batch_size > 10,000 (memory concerns)
   - Recommends 100-1000 for optimal performance
   - Helps prevent out-of-memory errors
   - Non-blocking (warning only, not an error)

4. Documented timeout limitation (Issue #3)
   - Clarified that the default implementation doesn't enforce the timeout
   - Added an example with signal-based timeout enforcement (Unix)
   - Added a simple example without a timeout
   - Subclasses can choose an appropriate timeout strategy

All changes are backward compatible and non-breaking.

Test results:
- 24/24 tests passing
- All linting checks pass (Ruff, format)

Signed-off-by: manavgup <manavg@gmail.com>
1 parent 877f155 commit ca4b9f1

File tree

1 file changed: +57 −20 lines


backend/vectordbs/vector_store.py

Lines changed: 57 additions & 20 deletions
@@ -87,7 +87,9 @@ def connect(self) -> None:
         """
         self._connected = True
         self._connection_metadata["connected_at"] = time.time()
-        logger.info("Connected to vector store (base implementation)")
+        logger.debug(
+            "Connection flag set for %s (base implementation - override for real connections)", self.__class__.__name__
+        )

     def disconnect(self) -> None:
         """Close connection to the vector database.
@@ -115,7 +117,10 @@ def disconnect(self) -> None:
         """
         self._connected = False
         self._connection_metadata["disconnected_at"] = time.time()
-        logger.info("Disconnected from vector store (base implementation)")
+        logger.debug(
+            "Connection flag cleared for %s (base implementation - override for real disconnections)",
+            self.__class__.__name__,
+        )

     @property
     def is_connected(self) -> bool:
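The two hunks above reduce log noise by moving to DEBUG-level, lazily %-formatted messages that name the concrete class. A minimal self-contained sketch of the same pattern (the `InMemoryVectorStore` class here is hypothetical, not from the repo):

```python
import logging
import time

logger = logging.getLogger(__name__)


class InMemoryVectorStore:
    """Toy stand-in for the base VectorStore (illustrative only)."""

    def __init__(self) -> None:
        self._connected = False
        self._connection_metadata: dict[str, float] = {}

    def connect(self) -> None:
        self._connected = True
        self._connection_metadata["connected_at"] = time.time()
        # DEBUG level + lazy %-formatting: the message is only rendered
        # when DEBUG is enabled, keeping INFO-level logs free of noise.
        logger.debug(
            "Connection flag set for %s (base implementation - override for real connections)",
            self.__class__.__name__,
        )

    def disconnect(self) -> None:
        self._connected = False
        self._connection_metadata["disconnected_at"] = time.time()
        logger.debug(
            "Connection flag cleared for %s (base implementation - override for real disconnections)",
            self.__class__.__name__,
        )
```

Passing the class name as a lazy `%s` argument (rather than f-string interpolation) means the string is only built when a DEBUG handler is actually attached.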
@@ -174,31 +179,33 @@ async def async_connection_context(self) -> AsyncIterator[None]:
         disconnecting connections that IT created. If a connection already exists,
         it leaves it intact on exit to avoid breaking calling code.

+        Warning:
+            The default implementation calls synchronous connect()/disconnect() methods,
+            which may block the event loop. Subclasses with async database clients should
+            override this method to use async connection methods instead.
+
         Usage:
             async with vector_store.async_connection_context():
                 await vector_store.async_add_documents(...)

-        Example with existing connection:
-            >>> store = VectorStore(settings)
-            >>> store.connect()  # Manual connection
-            >>> async with store.async_connection_context():
-            ...     await store.async_query(...)  # Uses existing connection
-            >>> # Connection still active after context exit
-            >>> store.is_connected
-            True
-
-        Example without existing connection:
-            >>> store = VectorStore(settings)
-            >>> async with store.async_connection_context():
-            ...     await store.async_query(...)  # Creates connection
-            >>> # Connection cleaned up after context exit
-            >>> store.is_connected
-            False
+        Example for implementations with async clients (override recommended):
+            >>> @asynccontextmanager
+            ... async def async_connection_context(self):
+            ...     needs_disconnect = False
+            ...     try:
+            ...         if not self._connected:
+            ...             await self.async_connect()  # Async method
+            ...             needs_disconnect = True
+            ...         yield
+            ...     finally:
+            ...         if needs_disconnect:
+            ...             await self.async_disconnect()  # Async method
         """
         # Track if WE created the connection
         needs_disconnect = False
         try:
             if not self._connected:
+                # Default uses sync methods - override if using async client
                 self.connect()
                 needs_disconnect = True  # Only disconnect what we connected
             yield
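The hunk above documents the recommended override for stores with async clients. A runnable sketch of such an override (the `AsyncClientStore`, `async_connect`, and `async_disconnect` names are illustrative, and `asyncio.sleep(0)` stands in for a real async client handshake):

```python
import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager


class AsyncClientStore:
    """Toy store with a truly async client (names are illustrative)."""

    def __init__(self) -> None:
        self._connected = False

    async def async_connect(self) -> None:
        await asyncio.sleep(0)  # stands in for an async client handshake
        self._connected = True

    async def async_disconnect(self) -> None:
        await asyncio.sleep(0)
        self._connected = False

    @asynccontextmanager
    async def async_connection_context(self) -> AsyncIterator[None]:
        # Only disconnect connections that THIS context created,
        # mirroring the base-class contract in the docstring above.
        needs_disconnect = False
        try:
            if not self._connected:
                await self.async_connect()  # never blocks the event loop
                needs_disconnect = True
            yield
        finally:
            if needs_disconnect:
                await self.async_disconnect()


async def main() -> tuple[bool, bool]:
    store = AsyncClientStore()
    async with store.async_connection_context():
        inside = store._connected  # connection created by the context
    return inside, store._connected  # cleaned up again on exit


print(asyncio.run(main()))  # → (True, False)
```

Because `async_connect` is awaited rather than called synchronously, a slow database handshake yields control back to the event loop instead of stalling every other coroutine.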
@@ -274,7 +281,9 @@ def _health_check_impl(self, timeout: float) -> dict[str, Any]:  # noqa: ARG002
             memory usage, query latency).

         Args:
-            timeout: Maximum time to wait for health check in seconds
+            timeout: Maximum time to wait for health check in seconds.
+                Note: The default implementation does not enforce this timeout.
+                Subclasses should implement timeout handling for actual health checks.

         Returns:
             Dictionary with health status information. Default keys:
@@ -286,7 +295,25 @@ def _health_check_impl(self, timeout: float) -> dict[str, Any]:  # noqa: ARG002
             VectorStoreError: If health check fails due to connection issues
             TimeoutError: If health check exceeds timeout duration

-        Example:
+        Example with timeout enforcement:
+            >>> import signal
+            >>> def _health_check_impl(self, timeout: float) -> dict[str, Any]:
+            ...     def timeout_handler(signum, frame):
+            ...         raise TimeoutError(f"Health check exceeded {timeout}s")
+            ...
+            ...     # Set timeout (Unix-like systems only)
+            ...     signal.signal(signal.SIGALRM, timeout_handler)
+            ...     signal.alarm(int(timeout))
+            ...     try:
+            ...         # Perform actual health check
+            ...         result = self.client.health_check()
+            ...         signal.alarm(0)  # Cancel alarm
+            ...         return {"status": "healthy", "nodes": result.nodes}
+            ...     except Exception:
+            ...         signal.alarm(0)  # Cancel alarm
+            ...         raise
+
+        Example without timeout (simple):
             >>> def _health_check_impl(self, timeout: float) -> dict[str, Any]:
             ...     return {
             ...         "status": "healthy",
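The signal-based example in the hunk above only works on Unix-like systems (SIGALRM is unavailable on Windows). A portable alternative, sketched here under the assumption that the probe is idempotent and read-only, runs the check in a worker thread and abandons it on timeout (`health_check_with_timeout` and `probe` are hypothetical helpers, not repo code):

```python
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeout
from typing import Any, Callable


def health_check_with_timeout(check: Callable[[], dict[str, Any]], timeout: float) -> dict[str, Any]:
    """Portable timeout wrapper for a health probe (illustrative sketch).

    Runs `check` in a worker thread and waits at most `timeout` seconds.
    On timeout the worker is abandoned rather than killed, which is
    acceptable for an idempotent, read-only health check.
    """
    pool = ThreadPoolExecutor(max_workers=1)
    future = pool.submit(check)
    try:
        return future.result(timeout=timeout)
    except FuturesTimeout:
        raise TimeoutError(f"Health check exceeded {timeout}s") from None
    finally:
        # wait=False returns immediately even if the probe is still running
        pool.shutdown(wait=False)


def probe() -> dict[str, Any]:
    # Stand-in for the real client call (hypothetical)
    return {"status": "healthy"}


print(health_check_with_timeout(probe, timeout=1.0))  # → {'status': 'healthy'}
```

Note that `Future.result(timeout=...)` enforces the deadline on the waiter only; a genuinely hung probe thread keeps running in the background, so this pattern suits probes without side effects.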
@@ -484,6 +511,16 @@ def _batch_chunks(self, chunks: list[EmbeddedChunk], batch_size: int) -> list[list[EmbeddedChunk]]:
                 "Common batch sizes: 100 (conservative), 500 (balanced), 1000 (aggressive)"
             )

+        # Warn about very large batch sizes that may cause memory issues
+        if batch_size > 10000:
+            logger.warning(
+                "Batch size %d is very large and may cause memory issues. "
+                "Consider using smaller batches (100-1000 recommended). "
+                "Collection: %s",
+                batch_size,
+                getattr(chunks[0], "collection_name", "unknown") if chunks else "unknown",
+            )
+
         batches: list[list[EmbeddedChunk]] = []
         for i in range(0, len(chunks), batch_size):
             batches.append(chunks[i : i + batch_size])
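The new guard above warns (rather than raises) on oversized batches and then slices as before. A simplified, runnable sketch of that logic with plain values in place of `EmbeddedChunk` objects (`batch_chunks` is a hypothetical stand-in for `_batch_chunks`):

```python
import logging

logger = logging.getLogger(__name__)

BATCH_WARN_THRESHOLD = 10_000  # mirrors the diff's upper-bound check


def batch_chunks(chunks: list, batch_size: int) -> list[list]:
    """Split `chunks` into batches, warning on very large batch sizes."""
    if batch_size > BATCH_WARN_THRESHOLD:
        # Non-blocking: warn only, never raise, so existing callers keep working
        logger.warning(
            "Batch size %d is very large and may cause memory issues. "
            "Consider using smaller batches (100-1000 recommended).",
            batch_size,
        )
    return [chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)]


print(batch_chunks(list(range(7)), 3))  # → [[0, 1, 2], [3, 4, 5], [6]]
```

Keeping the check as a warning preserves backward compatibility: callers who deliberately use huge batches still succeed, but the memory risk is surfaced in the logs.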
