diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index b1a263d..c5c11e5 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -8,6 +8,7 @@ import os import re import shutil +import sys import time import black @@ -22,15 +23,26 @@ def timelog(label: str, verbose: bool = True): """Context manager to log the time taken by a block of code. With verbose=False it prints nothing.""" + dim = colorama.Style.DIM + reset = colorama.Style.RESET_ALL + if verbose: + print( + f"{dim}{label}...{reset}", + end="", + flush=True, + file=sys.stderr, + ) start_time = time.time() try: yield finally: elapsed_time = time.time() - start_time if verbose: - dim = colorama.Style.DIM - reset = colorama.Style.RESET_ALL - print(f"{dim}{elapsed_time:.3f}s -- {label}{reset}") + print( + f"{dim} {elapsed_time:.3f}s{reset}", + file=sys.stderr, + flush=True, + ) def pretty_print(obj: object, prefix: str = "", suffix: str = "") -> None: diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 1ea7ce9..3bbc572 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -93,10 +93,15 @@ def add_embedding( if key is not None: self._model.add_embedding(key, embedding) - def add_embeddings(self, embeddings: NormalizedEmbeddings) -> None: + def add_embeddings( + self, keys: None | list[str], embeddings: NormalizedEmbeddings + ) -> None: assert embeddings.ndim == 2 assert embeddings.shape[1] == self._embedding_size self._vectors = np.concatenate((self._vectors, embeddings), axis=0) + if keys is not None: + for key, embedding in zip(keys, embeddings): + self._model.add_embedding(key, embedding) async def add_key(self, key: str, cache: bool = True) -> None: embeddings = (await self.get_embedding(key, cache=cache)).reshape(1, -1) diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index b2476cb..025b050 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -141,10 +141,15 @@ async def add_messages_with_indexing( Exception: Any error """ storage = await self.settings.get_storage_provider() + if source_ids: + if len(source_ids) != len(messages): + raise ValueError( + f"Length of source_ids {len(source_ids)} " + f"must match length of messages {len(messages)}" + ) async with storage: - # Mark source IDs as ingested before adding messages - # This way, if indexing fails, the rollback will also undo the marks + # Mark source IDs as ingested (will be rolled back on error) if source_ids: for source_id in source_ids: storage.mark_source_ingested(source_id) diff --git a/src/typeagent/knowpro/convknowledge.py b/src/typeagent/knowpro/convknowledge.py index c445328..59f9adf 100644 --- a/src/typeagent/knowpro/convknowledge.py +++ b/src/typeagent/knowpro/convknowledge.py @@ -13,6 +13,10 @@ # TODO: Move ModelWrapper and create_typechat_model() to aitools package. +# TODO: Make this a parameter that can be configured (e.g. from command line). 
+DEFAULT_TIMEOUT_SECONDS = 30 + + class ModelWrapper(typechat.TypeChatLanguageModel): def __init__( self, @@ -34,6 +38,7 @@ async def complete( key_name = "AZURE_OPENAI_API_KEY" env[key_name] = api_key self.base_model = typechat.create_language_model(env) + self.base_model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS return await self.base_model.complete(prompt) @@ -46,6 +51,7 @@ def create_typechat_model() -> typechat.TypeChatLanguageModel: shared_token_provider = auth.get_shared_token_provider() env[key_name] = shared_token_provider.get_token() model = typechat.create_language_model(env) + model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS if shared_token_provider is not None: model = ModelWrapper(model, shared_token_provider) return model diff --git a/src/typeagent/knowpro/fuzzyindex.py b/src/typeagent/knowpro/fuzzyindex.py index 44bea04..6ace1b3 100644 --- a/src/typeagent/knowpro/fuzzyindex.py +++ b/src/typeagent/knowpro/fuzzyindex.py @@ -22,10 +22,9 @@ def __init__( # Use VectorBase for storage and operations on embeddings. self._vector_base = VectorBase(settings) - # Initialize with embeddings if provided. + # Add embeddings to vectorbase if provided. if embeddings is not None: - for embedding in embeddings: - self._vector_base.add_embedding(None, embedding) + self._vector_base.add_embeddings(None, embeddings) def __len__(self) -> int: return len(self._vector_base) @@ -43,7 +42,7 @@ def get(self, pos: int) -> NormalizedEmbedding: return self._vector_base.get_embedding_at(pos) def push(self, embeddings: NormalizedEmbeddings) -> None: - self._vector_base.add_embeddings(embeddings) + self._vector_base.add_embeddings(None, embeddings) async def add_texts(self, texts: list[str]) -> None: await self._vector_base.add_keys(texts) diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index e6d2070..a19834e 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -26,14 +26,7 @@ ITimestampToTextRangeIndex, ) -__all__ = [ - "ConversationMetadata", - "IReadonlyCollection", - "ICollection", - "IMessageCollection", - "ISemanticRefCollection", - "IStorageProvider", -] +STATUS_INGESTED = "ingested" @dataclass @@ -150,7 +143,13 @@ def is_source_ingested(self, source_id: str) -> bool: """Check if a source has already been ingested.""" ... - def mark_source_ingested(self, source_id: str) -> None: + def get_source_status(self, source_id: str) -> str | None: + """Get the ingestion status of a source.""" + ... + + def mark_source_ingested( + self, source_id: str, status: str = STATUS_INGESTED + ) -> None: """Mark a source as ingested (no commit; call within transaction context).""" ... 
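Note: a minimal sketch, not part of the patch, of how a caller might use the extended source-tracking protocol above. `ingest_once` and `do_ingest` are hypothetical names; `storage` is any `IStorageProvider`, and the async-with transaction semantics mirror `add_messages_with_indexing()` in conversation_base.py.

    # Sketch only: skip sources that already have any recorded status, and rely
    # on the provider's transaction to roll the mark back if ingestion fails.
    async def ingest_once(storage, source_id: str, do_ingest) -> bool:
        if storage.get_source_status(source_id) is not None:
            return False  # already "ingested", or recorded with a failure status
        async with storage:  # commits on success, rolls the mark back on error
            storage.mark_source_ingested(source_id)  # status defaults to STATUS_INGESTED
            await do_ingest()
        return True
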
@@ -191,4 +190,5 @@ class IConversation[ "IReadonlyCollection", "ISemanticRefCollection", "IStorageProvider", + "STATUS_INGESTED", ] diff --git a/src/typeagent/storage/memory/provider.py b/src/typeagent/storage/memory/provider.py index e80286e..83ef6ab 100644 --- a/src/typeagent/storage/memory/provider.py +++ b/src/typeagent/storage/memory/provider.py @@ -17,6 +17,7 @@ ITermToRelatedTermsIndex, ITermToSemanticRefIndex, ITimestampToTextRangeIndex, + STATUS_INGESTED, ) from .collections import MemoryMessageCollection, MemorySemanticRefCollection from .convthreads import ConversationThreads @@ -150,7 +151,22 @@ def is_source_ingested(self, source_id: str) -> bool: """ return source_id in self._ingested_sources - def mark_source_ingested(self, source_id: str) -> None: + def get_source_status(self, source_id: str) -> str | None: + """Get the ingestion status of a source. + + Args: + source_id: External source identifier (email ID, file path, etc.) + + Returns: + The ingestion status if the source has been ingested, None otherwise. + """ + if source_id in self._ingested_sources: + return STATUS_INGESTED + return None + + def mark_source_ingested( + self, source_id: str, status: str = STATUS_INGESTED + ) -> None: """Mark a source as ingested. Args: diff --git a/src/typeagent/storage/memory/semrefindex.py b/src/typeagent/storage/memory/semrefindex.py index 3aa650a..ec44c87 100644 --- a/src/typeagent/storage/memory/semrefindex.py +++ b/src/typeagent/storage/memory/semrefindex.py @@ -136,7 +136,7 @@ async def add_batch_to_semantic_ref_index_from_list[ for i, knowledge_result in enumerate(knowledge_results): if isinstance(knowledge_result, Failure): raise RuntimeError( - f"Knowledge extraction failed: {knowledge_result.message}" + f"Knowledge extraction failed: {knowledge_result.message:.150}" ) text_location = batch[i] knowledge = knowledge_result.value diff --git a/src/typeagent/storage/sqlite/messageindex.py b/src/typeagent/storage/sqlite/messageindex.py index f5cbd13..8b7afd5 100644 --- a/src/typeagent/storage/sqlite/messageindex.py +++ b/src/typeagent/storage/sqlite/messageindex.py @@ -34,8 +34,15 @@ def __init__( if self._size(): cursor = self.db.cursor() cursor.execute("SELECT embedding FROM MessageTextIndex") - for row in cursor.fetchall(): - self._vectorbase.add_embedding(None, deserialize_embedding(row[0])) + rows = cursor.fetchall() + if rows: + embeddings: list[NormalizedEmbedding] = [ + deserialize_embedding(row[0]) for row in rows + ] + embeddings_array = np.stack(embeddings, axis=0).astype( + np.float32, copy=False + ) + self._vectorbase.add_embeddings(None, embeddings_array) async def size(self) -> int: return self._size() @@ -383,7 +390,7 @@ async def deserialize(self, data: interfaces.MessageTextIndexData) -> None: ) # Update VectorBase - self._vectorbase.add_embeddings(embeddings) + self._vectorbase.add_embeddings(None, embeddings) async def clear(self) -> None: """Clear the message text index.""" diff --git a/src/typeagent/storage/sqlite/provider.py b/src/typeagent/storage/sqlite/provider.py index ac4e52b..b3a63a9 100644 --- a/src/typeagent/storage/sqlite/provider.py +++ b/src/typeagent/storage/sqlite/provider.py @@ -10,7 +10,7 @@ from ...aitools.vectorbase import TextEmbeddingIndexSettings from ...knowpro import interfaces from ...knowpro.convsettings import MessageTextIndexSettings, RelatedTermIndexSettings -from ...knowpro.interfaces import ConversationMetadata +from ...knowpro.interfaces import ConversationMetadata, STATUS_INGESTED from .collections import 
SqliteMessageCollection, SqliteSemanticRefCollection from .messageindex import SqliteMessageTextIndex from .propindex import SqlitePropertyIndex @@ -56,6 +56,7 @@ def __init__( self.db = sqlite3.connect(db_path) # Configure SQLite for optimal bulk insertion performance + # TODO: Move into init_db_schema() self.db.execute("PRAGMA foreign_keys = ON") # Improve write performance for bulk operations self.db.execute("PRAGMA synchronous = NORMAL") # Faster than FULL, still safe @@ -625,11 +626,30 @@ def is_source_ingested(self, source_id: str) -> bool: """ cursor = self.db.cursor() cursor.execute( - "SELECT 1 FROM IngestedSources WHERE source_id = ?", (source_id,) + "SELECT status FROM IngestedSources WHERE source_id = ?", (source_id,) ) - return cursor.fetchone() is not None + row = cursor.fetchone() + return row is not None and row[0] == STATUS_INGESTED + + def get_source_status(self, source_id: str) -> str | None: + """Get the ingestion status of a source. + + Args: + source_id: External source identifier (email ID, file path, etc.) + + Returns: + The status string if the source exists, or None if it hasn't been ingested. + """ + cursor = self.db.cursor() + cursor.execute( + "SELECT status FROM IngestedSources WHERE source_id = ?", (source_id,) + ) + row = cursor.fetchone() + return row[0] if row else None - def mark_source_ingested(self, source_id: str) -> None: + def mark_source_ingested( + self, source_id: str, status: str = STATUS_INGESTED + ) -> None: """Mark a source as ingested. This performs an INSERT but does NOT commit. It should be called within @@ -641,6 +661,6 @@ def mark_source_ingested(self, source_id: str) -> None: """ cursor = self.db.cursor() cursor.execute( - "INSERT OR IGNORE INTO IngestedSources (source_id) VALUES (?)", - (source_id,), + "INSERT OR REPLACE INTO IngestedSources (source_id, status) VALUES (?, ?)", + (source_id, status), ) diff --git a/src/typeagent/storage/sqlite/reltermsindex.py b/src/typeagent/storage/sqlite/reltermsindex.py index e56221a..dec29db 100644 --- a/src/typeagent/storage/sqlite/reltermsindex.py +++ b/src/typeagent/storage/sqlite/reltermsindex.py @@ -5,7 +5,9 @@ import sqlite3 -from ...aitools.embeddings import NormalizedEmbeddings +import numpy as np + +from ...aitools.embeddings import NormalizedEmbedding from ...aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase from ...knowpro import interfaces from .schema import deserialize_embedding, serialize_embedding @@ -145,13 +147,13 @@ def __init__(self, db: sqlite3.Connection, settings: TextEmbeddingIndexSettings) "SELECT term, term_embedding FROM RelatedTermsFuzzy ORDER BY term" ) rows = cursor.fetchall() + embeddings: list[NormalizedEmbedding] = [] for term, blob in rows: assert blob is not None, term - embedding: NormalizedEmbeddings = deserialize_embedding(blob) - # Add to VectorBase at the correct ordinal - self._vector_base.add_embedding(term, embedding) self._terms_list.append(term) - self._added_terms.add(term) + embeddings.append(deserialize_embedding(blob)) + # Bulk add embeddings to VectorBase + self._vector_base.add_embeddings(None, np.array(embeddings)) async def lookup_term( self, diff --git a/src/typeagent/storage/sqlite/schema.py b/src/typeagent/storage/sqlite/schema.py index 4ec25b8..db6933d 100644 --- a/src/typeagent/storage/sqlite/schema.py +++ b/src/typeagent/storage/sqlite/schema.py @@ -10,7 +10,7 @@ import numpy as np from ...aitools.embeddings import NormalizedEmbedding -from ...knowpro.interfaces import ConversationMetadata +from ...knowpro.interfaces import 
ConversationMetadata, STATUS_INGESTED
 
 # Constants
 CONVERSATION_SCHEMA_VERSION = 1
@@ -141,9 +141,10 @@
 
 # Table for tracking ingested source IDs (e.g., email IDs, file paths)
 # This prevents re-ingesting the same content on subsequent runs
-INGESTED_SOURCES_SCHEMA = """
+INGESTED_SOURCES_SCHEMA = f"""
 CREATE TABLE IF NOT EXISTS IngestedSources (
-    source_id TEXT PRIMARY KEY  -- External source identifier (email ID, file path, etc.)
+    source_id TEXT PRIMARY KEY,  -- External source identifier (email ID, file path, etc.)
+    status TEXT NOT NULL DEFAULT '{STATUS_INGESTED}'  -- Status of the source (e.g., 'ingested')
 );
 """
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 1a9f7d7..84bd6ee 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-from contextlib import redirect_stdout
+from contextlib import redirect_stderr, redirect_stdout
 from io import StringIO
 import os
 
@@ -10,11 +10,11 @@
 
 def test_timelog():
     buf = StringIO()
-    with redirect_stdout(buf):
+    with redirect_stderr(buf):
         with utils.timelog("test block"):
             pass
     out = buf.getvalue()
-    assert "s -- test block" in out
+    assert "test block..." in out
 
 
 def test_pretty_print():
diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py
index 22e3440..62abd39 100644
--- a/tests/test_vectorbase.py
+++ b/tests/test_vectorbase.py
@@ -48,6 +48,26 @@ def test_add_embedding(vector_base: VectorBase, sample_embeddings: Samples):
         np.testing.assert_array_equal(vector_base.serialize_embedding_at(i), embedding)
 
 
+def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples):
+    """Adding multiple embeddings at once matches repeated single adds."""
+    keys = list(sample_embeddings.keys())
+    for key, embedding in sample_embeddings.items():
+        vector_base.add_embedding(key, embedding)
+
+    bulk_vector_base = make_vector_base()
+    stacked_embeddings = np.stack([sample_embeddings[key] for key in keys], axis=0)
+    bulk_vector_base.add_embeddings(keys, stacked_embeddings)
+
+    assert len(bulk_vector_base) == len(vector_base)
+    np.testing.assert_array_equal(bulk_vector_base.serialize(), vector_base.serialize())
+
+    sequential_cache = vector_base._model._embedding_cache
+    bulk_cache = bulk_vector_base._model._embedding_cache
+    assert set(sequential_cache.keys()) == set(bulk_cache.keys())
+    for key in keys:
+        np.testing.assert_array_equal(bulk_cache[key], sequential_cache[key])
+
+
 @pytest.mark.asyncio
 async def test_add_key(vector_base: VectorBase, sample_embeddings: Samples):
     """Test adding keys to the VectorBase."""
diff --git a/tools/ingest_email.py b/tools/ingest_email.py
index 34b65f7..ef4b623 100644
--- a/tools/ingest_email.py
+++ b/tools/ingest_email.py
@@ -14,8 +14,16 @@
     python query.py --database email.db --query "What was discussed?"
""" +""" +TODO + +- Catch auth errors and stop rather than marking as failed +- Collect knowledge outside db transaction to reduce lock time +""" + import argparse import asyncio +from email.header import decode_header from pathlib import Path import sys import time @@ -74,7 +82,7 @@ def collect_email_files(paths: list[str], verbose: bool) -> list[Path]: elif path.is_dir(): eml_files = sorted(path.glob("*.eml")) if verbose: - print(f" Found {len(eml_files)} .eml files in {path}") + print(f"Found {len(eml_files)} .eml files in {path}") email_files.extend(eml_files) else: print(f"Error: Not a file or directory: {path}", file=sys.stderr) @@ -83,6 +91,21 @@ def collect_email_files(paths: list[str], verbose: bool) -> list[Path]: return email_files +def decode_encoded_word(s: str) -> str: + """Decode an RFC 2047 encoded string.""" + if "=?utf-8?" not in s: + return s # Fast path for common case + decoded_parts = decode_header(s) + return "".join( + ( + part.decode(encoding or "utf-8", errors="replace") + if isinstance(part, bytes) + else part + ) + for part, encoding in decoded_parts + ) + + async def ingest_emails( paths: list[str], database: str, @@ -91,16 +114,15 @@ async def ingest_emails( """Ingest email files into a database.""" # Collect all .eml files - if verbose: - print("Collecting email files...") - email_files = collect_email_files(paths, verbose) + with utils.timelog("Collecting email files"): + email_files = collect_email_files(paths, verbose) if not email_files: print("Error: No .eml files found", file=sys.stderr) sys.exit(1) if verbose: - print(f"Found {len(email_files)} email files to ingest") + print(f"Found {len(email_files)} email files in total to ingest") # Load environment for model API access if verbose: @@ -144,54 +166,79 @@ async def ingest_emails( for i, email_file in enumerate(email_files): try: if verbose: - print(f"\n[{i + 1}/{len(email_files)}] {email_file}") + print(f"[{i + 1}/{len(email_files)}] {email_file}", end="", flush=True) + if status := storage_provider.get_source_status(str(email_file)): + skipped_count += 1 + if verbose: + print(f" [Previously {status}, skipping]") + continue + else: + if verbose: + print() email = import_email_from_file(str(email_file)) - email_id = email.metadata.id + source_id = email.metadata.id + if verbose: + print(f" Email ID: {source_id}", end="") # Check if this email was already ingested - if email_id and storage_provider.is_source_ingested(email_id): + if source_id and (status := storage_provider.get_source_status(source_id)): skipped_count += 1 if verbose: - print(f" [Already ingested, skipping]") + print(f" [Previously {status}, skipping]") + async with storage_provider: + storage_provider.mark_source_ingested(str(email_file), status) continue + else: + if verbose: + print() if verbose: print(f" From: {email.metadata.sender}") if email.metadata.subject: - print(f" Subject: {email.metadata.subject}") + print( + f" Subject: {decode_encoded_word(email.metadata.subject).replace('\n', '\\n')}" + ) print(f" Date: {email.timestamp}") print(f" Body chunks: {len(email.text_chunks)}") for chunk in email.text_chunks: - # Show first 200 chars of each chunk - preview = chunk[:200].replace("\n", " ") - if len(chunk) > 200: - preview += "..." + # Show first N chars of each decoded chunk + N = 150 + chunk = decode_encoded_word(chunk) + preview = repr(chunk[: N + 1])[1:-1] + if len(preview) > N: + preview = preview[: N - 3] + "..." 
print(f" {preview}") # Pass source_id to mark as ingested atomically with the message - source_ids = [email_id] if email_id else None await email_memory.add_messages_with_indexing( - [email], source_ids=source_ids - ) + [email], source_ids=[str(email_file)] + ) # This may raise, esp. if the knowledge extraction fails (see except below) successful_count += 1 # Print progress periodically - if not verbose and (i + 1) % batch_size == 0: + if (i + 1) % batch_size == 0: elapsed = time.time() - start_time semref_count = await semref_coll.size() print( - f" [{i + 1}/{len(email_files)}] {successful_count} imported | " - f"{semref_count} refs | {elapsed:.1f}s elapsed" + f"\n[{i + 1}/{len(email_files)}] {successful_count} imported | " + f"{semref_count} refs | {elapsed:.1f}s elapsed\n" ) except Exception as e: failed_count += 1 print(f"Error processing {email_file}: {e}", file=sys.stderr) + exc_name = ( + e.__class__.__qualname__ + if e.__class__.__module__ == "builtins" + else f"{e.__class__.__module__}.{e.__class__.__qualname__}" + ) + async with storage_provider: + storage_provider.mark_source_ingested(str(email_file), exc_name) if verbose: import traceback - traceback.print_exc() + traceback.print_exc(limit=10) # Final summary elapsed = time.time() - start_time