|
4 | 4 |
|
5 | 5 | from datetime import datetime |
6 | 6 | from enum import Enum |
7 | | -from typing import Any |
| 7 | +from typing import Any, Generic, TypeVar |
8 | 8 |
|
9 | | -from pydantic import BaseModel, ConfigDict |
| 9 | +from pydantic import BaseModel, ConfigDict, Field, field_validator |
10 | 10 |
|
11 | 11 | # Simplified embedding type |
12 | 12 | Embedding = float |
@@ -313,3 +313,327 @@ def vectors(self) -> Embeddings: |
313 | 313 | def vectors(self, value: Embeddings) -> None: |
314 | 314 | """Set vectors (updates embeddings).""" |
315 | 315 | self.embeddings = value |
| 316 | + |
| 317 | + |
| 318 | +class EmbeddedChunk(DocumentChunk): |
| 319 | + """A DocumentChunk with mandatory (non-optional) embeddings. |
| 320 | +
|
| 321 | + This class ensures that embeddings are always present, making it suitable |
| 322 | + for vector database operations that require embedded content. |
| 323 | +
|
| 324 | + Attributes: |
| 325 | + embeddings: Required vector embedding (non-optional) |
| 326 | + All other attributes inherited from DocumentChunk |
| 327 | + """ |
| 328 | + |
| 329 | + embeddings: Embeddings = Field(..., description="Required vector embedding for this chunk") |
| 330 | + |
| 331 | + @field_validator("embeddings") |
| 332 | + @classmethod |
| 333 | + def validate_embeddings_not_empty(cls, v: Embeddings) -> Embeddings: |
| 334 | + """Ensure embeddings list is not empty.""" |
| 335 | + if not v: |
| 336 | + raise ValueError("Embeddings cannot be empty") |
| 337 | + return v |
| 338 | + |
| 339 | + @classmethod |
| 340 | + def from_chunk(cls, chunk: DocumentChunk, embeddings: Embeddings | None = None) -> EmbeddedChunk: |
| 341 | + """Convert a DocumentChunk to an EmbeddedChunk. |
| 342 | +
|
| 343 | + Args: |
| 344 | + chunk: The source DocumentChunk |
| 345 | + embeddings: Optional embeddings to use (if not provided, uses chunk.embeddings) |
| 346 | +
|
| 347 | + Returns: |
| 348 | + EmbeddedChunk with mandatory embeddings |
| 349 | +
|
| 350 | + Raises: |
| 351 | + ValueError: If embeddings are not available |
| 352 | + """ |
| 353 | + emb = embeddings or chunk.embeddings |
| 354 | + if not emb: |
| 355 | + raise ValueError("Cannot create EmbeddedChunk without embeddings") |
| 356 | + |
| 357 | + return cls( |
| 358 | + chunk_id=chunk.chunk_id, |
| 359 | + text=chunk.text, |
| 360 | + embeddings=emb, |
| 361 | + metadata=chunk.metadata, |
| 362 | + document_id=chunk.document_id, |
| 363 | + parent_chunk_id=chunk.parent_chunk_id, |
| 364 | + child_chunk_ids=chunk.child_chunk_ids, |
| 365 | + level=chunk.level, |
| 366 | + ) |
| 367 | + |
| 368 | + def to_vector_metadata(self) -> dict[str, Any]: |
| 369 | + """Convert to metadata dictionary suitable for vector database storage. |
| 370 | +
|
| 371 | + Returns: |
| 372 | + Dictionary containing all metadata fields for vector DB |
| 373 | + """ |
| 374 | + base_metadata = { |
| 375 | + "chunk_id": self.chunk_id, |
| 376 | + "text": self.text, |
| 377 | + "document_id": self.document_id, |
| 378 | + "parent_chunk_id": self.parent_chunk_id, |
| 379 | + "child_chunk_ids": self.child_chunk_ids, |
| 380 | + "level": self.level, |
| 381 | + } |
| 382 | + |
| 383 | + # Add chunk metadata if present |
| 384 | + if self.metadata: |
| 385 | + base_metadata.update(self.metadata.model_dump(exclude_none=True)) |
| 386 | + |
| 387 | + # Remove None values |
| 388 | + return {k: v for k, v in base_metadata.items() if v is not None} |
| 389 | + |
| 390 | + def to_vector_db(self) -> dict[str, Any]: |
| 391 | + """Convert to complete vector database record format. |
| 392 | +
|
| 393 | + Returns: |
| 394 | + Dictionary containing embeddings and metadata for vector DB insertion |
| 395 | + """ |
| 396 | + return {"id": self.chunk_id, "vector": self.embeddings, "metadata": self.to_vector_metadata()} |
| 397 | + |
| 398 | + |
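A rough usage sketch (not part of the diff): assuming DocumentChunk requires only chunk_id and text (other required fields, if any, are not shown here), a chunk that already carries embeddings can be promoted and serialized for insertion.

    # Sketch only; DocumentChunk's required fields and the sample values are assumptions.
    chunk = DocumentChunk(chunk_id="c1", text="hello world", embeddings=[0.1, 0.2, 0.3])
    embedded = EmbeddedChunk.from_chunk(chunk)   # raises ValueError if no embeddings are available
    record = embedded.to_vector_db()             # {"id": "c1", "vector": [0.1, 0.2, 0.3], "metadata": {...}}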
| 399 | +class DocumentIngestionRequest(BaseModel): |
| 400 | + """Request model for ingesting documents into a vector database. |
| 401 | +
|
| 402 | + Handles batching and provides convenient access to embedded chunks. |
| 403 | +
|
| 404 | + Attributes: |
| 405 | + chunks: List of document chunks to ingest |
| 406 | + collection_id: Target collection identifier |
| 407 | + batch_size: Number of chunks to process per batch (default: 100) |
| 408 | + """ |
| 409 | + |
| 410 | + chunks: list[DocumentChunk] = Field(..., description="Document chunks to ingest") |
| 411 | + collection_id: str = Field(..., description="Target collection identifier") |
| 412 | + batch_size: int = Field(default=100, ge=1, le=1000, description="Batch size for processing") |
| 413 | + |
| 414 | + model_config = ConfigDict(from_attributes=True) |
| 415 | + |
| 416 | + @field_validator("chunks") |
| 417 | + @classmethod |
| 418 | + def validate_chunks_not_empty(cls, v: list[DocumentChunk]) -> list[DocumentChunk]: |
| 419 | + """Ensure chunks list is not empty.""" |
| 420 | + if not v: |
| 421 | + raise ValueError("Chunks list cannot be empty") |
| 422 | + return v |
| 423 | + |
| 424 | + def get_embedded_chunks(self) -> list[EmbeddedChunk]: |
| 425 | + """Extract chunks that have embeddings. |
| 426 | +
|
| 427 | +        Returns: |
| 428 | +            List of EmbeddedChunk instances for chunks that already carry embeddings |
| 429 | +
|
| 430 | +        Note: |
| 431 | +            Chunks without embeddings are skipped rather than raising an error |
| 432 | +        """ |
| 433 | + embedded_chunks = [] |
| 434 | + for chunk in self.chunks: |
| 435 | + if chunk.embeddings: |
| 436 | + embedded_chunks.append(EmbeddedChunk.from_chunk(chunk)) |
| 437 | + return embedded_chunks |
| 438 | + |
| 439 | + def get_batches(self) -> list[list[DocumentChunk]]: |
| 440 | + """Split chunks into batches based on batch_size. |
| 441 | +
|
| 442 | + Returns: |
| 443 | + List of chunk batches |
| 444 | + """ |
| 445 | + batches = [] |
| 446 | + for i in range(0, len(self.chunks), self.batch_size): |
| 447 | + batches.append(self.chunks[i : i + self.batch_size]) |
| 448 | + return batches |
| 449 | + |
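A hedged sketch of ingestion batching, assuming all_chunks is a non-empty list[DocumentChunk] and process() is a hypothetical ingestion call: with 250 chunks and the default batch_size of 100, get_batches() yields batches of 100, 100 and 50.

    request = DocumentIngestionRequest(chunks=all_chunks, collection_id="docs")
    for batch in request.get_batches():
        # each batch is a list[DocumentChunk] of at most request.batch_size items
        process(batch)  # hypothetical vector DB client call
    embedded = request.get_embedded_chunks()  # only chunks that already have embeddings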
| 450 | + |
| 451 | +class VectorSearchRequest(BaseModel): |
| 452 | + """Request model for vector database searches. |
| 453 | +
|
| 454 | + Standardizes search operations with support for both text and vector queries. |
| 455 | +
|
| 456 | + Attributes: |
| 457 | + query_text: Text query to search for (optional if query_vector provided) |
| 458 | + query_vector: Pre-computed query embedding (optional if query_text provided) |
| 459 | + collection_id: Collection to search in |
| 460 | + top_k: Number of results to return (default: 10) |
| 461 | + metadata_filter: Optional metadata filtering criteria |
| 462 | + include_metadata: Whether to include metadata in results (default: True) |
| 463 | + include_vectors: Whether to include vectors in results (default: False) |
| 464 | + """ |
| 465 | + |
| 466 | + query_text: str | None = Field(default=None, description="Text query to search for") |
| 467 | + query_vector: Embeddings | None = Field(default=None, description="Pre-computed query embedding") |
| 468 | + collection_id: str = Field(..., description="Collection to search in") |
| 469 | + top_k: int = Field(default=10, ge=1, le=100, description="Number of results to return") |
| 470 | + metadata_filter: DocumentMetadataFilter | None = Field(default=None, description="Optional metadata filtering") |
| 471 | + include_metadata: bool = Field(default=True, description="Include metadata in results") |
| 472 | + include_vectors: bool = Field(default=False, description="Include vectors in results") |
| 473 | + |
| 474 | + model_config = ConfigDict(from_attributes=True) |
| 475 | + |
| 476 | + @field_validator("query_text", "query_vector") |
| 477 | + @classmethod |
| 478 | + def validate_query_provided(cls, v: Any) -> Any: |
| 479 | +        """Per-field pass-through; the cross-field check is done in model_post_init.""" |
| 480 | +        # A field validator sees one field at a time, so "at least one query" is enforced in model_post_init below |
| 481 | + return v |
| 482 | + |
| 483 | + def model_post_init(self, __context: Any) -> None: |
| 484 | + """Validate that at least one query type is provided.""" |
| 485 | + if not self.query_text and not self.query_vector: |
| 486 | + raise ValueError("Either query_text or query_vector must be provided") |
| 487 | + |
| 488 | + def to_vector_query(self) -> VectorQuery: |
| 489 | + """Convert to VectorQuery model. |
| 490 | +
|
| 491 | + Returns: |
| 492 | + VectorQuery instance for backward compatibility |
| 493 | + """ |
| 494 | + return VectorQuery( |
| 495 | + text=self.query_text or "", |
| 496 | + embeddings=self.query_vector, |
| 497 | + metadata_filter=self.metadata_filter, |
| 498 | + number_of_results=self.top_k, |
| 499 | + ) |
| 500 | + |
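To illustrate (a sketch, not from the diff; collection name and vector size are placeholders): at least one of query_text or query_vector must be supplied, and to_vector_query() maps the request onto the existing VectorQuery model.

    text_request = VectorSearchRequest(query_text="neural networks", collection_id="docs", top_k=5)
    vector_request = VectorSearchRequest(query_vector=[0.1] * 768, collection_id="docs")
    legacy = text_request.to_vector_query()  # VectorQuery(text="neural networks", number_of_results=5, ...)
    # VectorSearchRequest(collection_id="docs") raises, since neither query type was given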
| 501 | + |
| 502 | +class CollectionConfig(BaseModel): |
| 503 | + """Configuration for a vector database collection. |
| 504 | +
|
| 505 | + Manages collection settings with database-specific validation. |
| 506 | +
|
| 507 | + Attributes: |
| 508 | + collection_name: Name of the collection |
| 509 | + dimension: Vector dimension size |
| 510 | + metric_type: Distance metric (L2, IP, COSINE) |
| 511 | + index_type: Index type (FLAT, IVF_FLAT, HNSW, etc.) |
| 512 | + index_params: Database-specific index parameters |
| 513 | + description: Optional collection description |
| 514 | + """ |
| 515 | + |
| 516 | + collection_name: str = Field(..., min_length=1, max_length=255, description="Collection name") |
| 517 | + dimension: int = Field(..., ge=1, le=4096, description="Vector dimension size") |
| 518 | + metric_type: str = Field(default="L2", description="Distance metric type") |
| 519 | + index_type: str = Field(default="HNSW", description="Index type") |
| 520 | + index_params: dict[str, Any] = Field(default_factory=dict, description="Database-specific index parameters") |
| 521 | + description: str | None = Field(default=None, max_length=1000, description="Collection description") |
| 522 | + |
| 523 | + model_config = ConfigDict(from_attributes=True) |
| 524 | + |
| 525 | + @field_validator("metric_type") |
| 526 | + @classmethod |
| 527 | + def validate_metric_type(cls, v: str) -> str: |
| 528 | + """Validate metric type is one of the supported types.""" |
| 529 | + valid_metrics = ["L2", "IP", "COSINE", "HAMMING", "JACCARD"] |
| 530 | + v_upper = v.upper() |
| 531 | + if v_upper not in valid_metrics: |
| 532 | + raise ValueError(f"Invalid metric_type. Must be one of: {', '.join(valid_metrics)}") |
| 533 | + return v_upper |
| 534 | + |
| 535 | + @field_validator("index_type") |
| 536 | + @classmethod |
| 537 | + def validate_index_type(cls, v: str) -> str: |
| 538 | + """Validate index type is one of the supported types.""" |
| 539 | + valid_indexes = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "ANNOY"] |
| 540 | + v_upper = v.upper() |
| 541 | + if v_upper not in valid_indexes: |
| 542 | + raise ValueError(f"Invalid index_type. Must be one of: {', '.join(valid_indexes)}") |
| 543 | + return v_upper |
| 544 | + |
| 545 | + def to_dict(self) -> dict[str, Any]: |
| 546 | + """Convert to dictionary for database-specific usage. |
| 547 | +
|
| 548 | + Returns: |
| 549 | + Dictionary representation suitable for vector DB creation |
| 550 | + """ |
| 551 | + return self.model_dump(exclude_none=True) |
| 552 | + |
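A small sketch of collection configuration (names and dimension are illustrative); the validators upper-case metric_type and index_type, so lower-case input is accepted.

    config = CollectionConfig(collection_name="documents", dimension=768, metric_type="cosine", index_type="hnsw")
    assert config.metric_type == "COSINE" and config.index_type == "HNSW"
    creation_kwargs = config.to_dict()  # description is dropped because exclude_none=True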
| 553 | + |
| 554 | +# Generic type variable for response data |
| 555 | +T = TypeVar("T") |
| 556 | + |
| 557 | + |
| 558 | +class VectorDBResponse(BaseModel, Generic[T]): |
| 559 | + """Generic response wrapper for vector database operations. |
| 560 | +
|
| 561 | + Provides consistent success/error handling across all vector DB operations. |
| 562 | +
|
| 563 | + Attributes: |
| 564 | + success: Whether the operation succeeded |
| 565 | + data: Response data (type depends on operation) |
| 566 | + error: Error message if operation failed |
| 567 | + metadata: Optional metadata about the operation |
| 568 | + """ |
| 569 | + |
| 570 | + success: bool = Field(..., description="Whether the operation succeeded") |
| 571 | + data: T | None = Field(default=None, description="Response data") |
| 572 | + error: str | None = Field(default=None, description="Error message if operation failed") |
| 573 | + metadata: dict[str, Any] = Field(default_factory=dict, description="Optional operation metadata") |
| 574 | + |
| 575 | + model_config = ConfigDict(from_attributes=True) |
| 576 | + |
| 577 | + @classmethod |
| 578 | + def create_success(cls, data: T, metadata: dict[str, Any] | None = None) -> VectorDBResponse[T]: |
| 579 | + """Create a success response. |
| 580 | +
|
| 581 | + Args: |
| 582 | + data: The response data |
| 583 | + metadata: Optional metadata about the operation |
| 584 | +
|
| 585 | + Returns: |
| 586 | + VectorDBResponse with success=True |
| 587 | + """ |
| 588 | + return cls(success=True, data=data, error=None, metadata=metadata or {}) |
| 589 | + |
| 590 | + @classmethod |
| 591 | + def create_error(cls, error: str, metadata: dict[str, Any] | None = None) -> VectorDBResponse[T]: |
| 592 | + """Create an error response. |
| 593 | +
|
| 594 | + Args: |
| 595 | + error: Error message |
| 596 | + metadata: Optional metadata about the operation |
| 597 | +
|
| 598 | + Returns: |
| 599 | + VectorDBResponse with success=False |
| 600 | + """ |
| 601 | + return cls(success=False, data=None, error=error, metadata=metadata or {}) |
| 602 | + |
| 603 | + def is_success(self) -> bool: |
| 604 | + """Check if the operation was successful. |
| 605 | +
|
| 606 | + Returns: |
| 607 | + True if success, False otherwise |
| 608 | + """ |
| 609 | + return self.success |
| 610 | + |
| 611 | + def is_error(self) -> bool: |
| 612 | + """Check if the operation failed. |
| 613 | +
|
| 614 | + Returns: |
| 615 | + True if error, False otherwise |
| 616 | + """ |
| 617 | + return not self.success |
| 618 | + |
| 619 | + def get_data_or_raise(self) -> T: |
| 620 | + """Get data or raise an exception if operation failed. |
| 621 | +
|
| 622 | + Returns: |
| 623 | + The response data |
| 624 | +
|
| 625 | + Raises: |
| 626 | + ValueError: If the operation failed |
| 627 | + """ |
| 628 | + if self.is_error(): |
| 629 | + raise ValueError(f"Operation failed: {self.error}") |
| 630 | + if self.data is None: |
| 631 | + raise ValueError("No data available in response") |
| 632 | + return self.data |
| 633 | + |
| 634 | + |
| 635 | +# Type aliases for common response types |
| 636 | +VectorDBIngestionResponse = VectorDBResponse[list[str]] # List of ingested IDs |
| 637 | +VectorDBSearchResponse = VectorDBResponse[list[QueryResult]] # Search results |
| 638 | +VectorDBCollectionResponse = VectorDBResponse[dict[str, Any]] # Collection info |
| 639 | +VectorDBDeleteResponse = VectorDBResponse[bool] # Delete success status |
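Finally, a hedged sketch of the response wrapper via the aliases above (QueryResult is assumed to be defined earlier in the module; IDs and collection name are placeholders):

    ok = VectorDBIngestionResponse.create_success(["c1", "c2"], metadata={"collection": "docs"})
    ids = ok.get_data_or_raise()  # ["c1", "c2"]

    failed = VectorDBCollectionResponse.create_error("collection not found")
    failed.is_error()             # True
    failed.get_data_or_raise()    # raises ValueError("Operation failed: collection not found")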