Merged
Changes from all commits (19 commits)
18 changes: 15 additions & 3 deletions src/typeagent/aitools/utils.py
@@ -8,6 +8,7 @@
import os
import re
import shutil
import sys
import time

import black
@@ -22,15 +23,26 @@ def timelog(label: str, verbose: bool = True):
"""Context manager to log the time taken by a block of code.

With verbose=False it prints nothing."""
dim = colorama.Style.DIM
reset = colorama.Style.RESET_ALL
if verbose:
print(
f"{dim}{label}...{reset}",
end="",
flush=True,
file=sys.stderr,
)
start_time = time.time()
try:
yield
finally:
elapsed_time = time.time() - start_time
if verbose:
dim = colorama.Style.DIM
reset = colorama.Style.RESET_ALL
print(f"{dim}{elapsed_time:.3f}s -- {label}{reset}")
print(
f"{dim} {elapsed_time:.3f}s{reset}",
file=sys.stderr,
flush=True,
)


def pretty_print(obj: object, prefix: str = "", suffix: str = "") -> None:
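Reviewer note: with this change, `timelog` writes the dimmed label to stderr immediately (no newline) and appends the elapsed time on the same line when the block exits, so stdout stays clean for piping. A minimal sketch of the new behavior, mirroring the updated test further down:

```python
import time
from contextlib import redirect_stderr
from io import StringIO

from typeagent.aitools import utils

# Capture stderr to show that both the label and the elapsed time go there now.
buf = StringIO()
with redirect_stderr(buf):
    with utils.timelog("demo"):
        time.sleep(0.01)

# Output includes colorama's ANSI dim/reset codes, e.g.:
# '\x1b[2mdemo...\x1b[0m\x1b[2m 0.010s\x1b[0m\n'
print(repr(buf.getvalue()))
```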
7 changes: 6 additions & 1 deletion src/typeagent/aitools/vectorbase.py
@@ -93,10 +93,15 @@ def add_embedding(
if key is not None:
self._model.add_embedding(key, embedding)

def add_embeddings(self, embeddings: NormalizedEmbeddings) -> None:
def add_embeddings(
self, keys: None | list[str], embeddings: NormalizedEmbeddings
) -> None:
assert embeddings.ndim == 2
assert embeddings.shape[1] == self._embedding_size
self._vectors = np.concatenate((self._vectors, embeddings), axis=0)
if keys is not None:
for key, embedding in zip(keys, embeddings):
self._model.add_embedding(key, embedding)

async def add_key(self, key: str, cache: bool = True) -> None:
embeddings = (await self.get_embedding(key, cache=cache)).reshape(1, -1)
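Reviewer note: `add_embeddings` now takes an optional `keys` list so the model's embedding cache is populated in the same pass as the bulk vector append, matching what repeated `add_embedding(key, embedding)` calls would do. A minimal sketch (the embedding size of 3 and the pre-normalized unit vectors are illustrative):

```python
import numpy as np

from typeagent.aitools.vectorbase import VectorBase


def bulk_add_demo(vb: VectorBase) -> None:
    # add_embeddings requires a 2-D (n, embedding_size) array.
    embeddings = np.zeros((2, 3), dtype=np.float32)
    embeddings[0, 0] = 1.0  # unit vector for "alpha"
    embeddings[1, 1] = 1.0  # unit vector for "beta"
    vb.add_embeddings(["alpha", "beta"], embeddings)  # vectors + cache keys
    vb.add_embeddings(None, embeddings)  # keys=None: vectors only, no caching
```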
9 changes: 7 additions & 2 deletions src/typeagent/knowpro/conversation_base.py
@@ -141,10 +141,15 @@ async def add_messages_with_indexing(
Exception: Any error
"""
storage = await self.settings.get_storage_provider()
if source_ids:
if len(source_ids) != len(messages):
raise ValueError(
f"Length of source_ids {len(source_ids)} "
f"must match length of messages {len(messages)}"
)

async with storage:
# Mark source IDs as ingested before adding messages
# This way, if indexing fails, the rollback will also undo the marks
# Mark source IDs as ingested (will be rolled back on error)
if source_ids:
for source_id in source_ids:
storage.mark_source_ingested(source_id)
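Reviewer note: the new length check fails fast, before the storage transaction opens, so a mismatched call does no storage work at all. An illustrative call (the object names are placeholders):

```python
# `conversation` stands in for a real conversation object; so do the messages.
async def demo(conversation, msg_a, msg_b) -> None:
    try:
        await conversation.add_messages_with_indexing(
            messages=[msg_a, msg_b],
            source_ids=["mail-1"],  # 1 id for 2 messages -> rejected up front
        )
    except ValueError as exc:
        print(exc)  # Length of source_ids 1 must match length of messages 2
```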
6 changes: 6 additions & 0 deletions src/typeagent/knowpro/convknowledge.py
@@ -13,6 +13,10 @@
# TODO: Move ModelWrapper and create_typechat_model() to aitools package.


# TODO: Make this a parameter that can be configured (e.g. from command line).
DEFAULT_TIMEOUT_SECONDS = 30


class ModelWrapper(typechat.TypeChatLanguageModel):
def __init__(
self,
@@ -34,6 +38,7 @@ async def complete(
key_name = "AZURE_OPENAI_API_KEY"
env[key_name] = api_key
self.base_model = typechat.create_language_model(env)
self.base_model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS
return await self.base_model.complete(prompt)


@@ -46,6 +51,7 @@ def create_typechat_model() -> typechat.TypeChatLanguageModel:
shared_token_provider = auth.get_shared_token_provider()
env[key_name] = shared_token_provider.get_token()
model = typechat.create_language_model(env)
model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS
if shared_token_provider is not None:
model = ModelWrapper(model, shared_token_provider)
return model
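Reviewer note: both model-creation paths now apply the same 30-second timeout. The TODO above asks for this to be configurable; one possible sketch is an environment override (the variable name `TYPEAGENT_TIMEOUT_SECONDS` is hypothetical, not part of this PR):

```python
import os

# Hypothetical override: fall back to 30 when the variable is unset or invalid.
try:
    DEFAULT_TIMEOUT_SECONDS = int(os.environ.get("TYPEAGENT_TIMEOUT_SECONDS", "30"))
except ValueError:
    DEFAULT_TIMEOUT_SECONDS = 30
```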
7 changes: 3 additions & 4 deletions src/typeagent/knowpro/fuzzyindex.py
@@ -22,10 +22,9 @@ def __init__(
# Use VectorBase for storage and operations on embeddings.
self._vector_base = VectorBase(settings)

# Initialize with embeddings if provided.
# Add embeddings to vectorbase if provided.
if embeddings is not None:
for embedding in embeddings:
self._vector_base.add_embedding(None, embedding)
self._vector_base.add_embeddings(None, embeddings)

def __len__(self) -> int:
return len(self._vector_base)
@@ -43,7 +42,7 @@ def get(self, pos: int) -> NormalizedEmbedding:
return self._vector_base.get_embedding_at(pos)

def push(self, embeddings: NormalizedEmbeddings) -> None:
self._vector_base.add_embeddings(embeddings)
self._vector_base.add_embeddings(None, embeddings)

async def add_texts(self, texts: list[str]) -> None:
await self._vector_base.add_keys(texts)
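Reviewer note: replacing the per-row loop with one bulk call is more than cosmetic. Since `add_embeddings` grows its array with `np.concatenate` (see vectorbase.py above), one n-row concatenate copies the data once, while n single-row adds re-copy the growing array each time, which is quadratic in the number of rows. A standalone illustration:

```python
import numpy as np

base = np.empty((0, 4), dtype=np.float32)
new = np.ones((3, 4), dtype=np.float32)

bulk = np.concatenate((base, new), axis=0)  # one allocation + copy

looped = base
for row in new:  # each iteration copies everything accumulated so far
    looped = np.concatenate((looped, row[None, :]), axis=0)

assert np.array_equal(bulk, looped)
```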
18 changes: 9 additions & 9 deletions src/typeagent/knowpro/interfaces_storage.py
@@ -26,14 +26,7 @@
ITimestampToTextRangeIndex,
)

__all__ = [
"ConversationMetadata",
"IReadonlyCollection",
"ICollection",
"IMessageCollection",
"ISemanticRefCollection",
"IStorageProvider",
]
STATUS_INGESTED = "ingested"


@dataclass
@@ -150,7 +143,13 @@ def is_source_ingested(self, source_id: str) -> bool:
"""Check if a source has already been ingested."""
...

def mark_source_ingested(self, source_id: str) -> None:
def get_source_status(self, source_id: str) -> str | None:
"""Get the ingestion status of a source."""
...

def mark_source_ingested(
self, source_id: str, status: str = STATUS_INGESTED
) -> None:
"""Mark a source as ingested (no commit; call within transaction context)."""
...

@@ -191,4 +190,5 @@ class IConversation[
"IReadonlyCollection",
"ISemanticRefCollection",
"IStorageProvider",
"STATUS_INGESTED",
]
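Reviewer note: taken together, `get_source_status` and `mark_source_ingested` enable idempotent re-runs of ingestion. A sketch of the intended caller pattern against any `IStorageProvider` (`storage` and `ingest_one` are placeholders; the commit belongs to the surrounding transaction, per the docstring above):

```python
from typeagent.knowpro.interfaces import STATUS_INGESTED


def ingest_if_new(storage, source_id: str, ingest_one) -> bool:
    """Ingest source_id unless a previous run already did."""
    if storage.get_source_status(source_id) == STATUS_INGESTED:
        return False  # already done on a previous run
    ingest_one(source_id)
    storage.mark_source_ingested(source_id)  # no commit here; caller owns the txn
    return True
```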
18 changes: 17 additions & 1 deletion src/typeagent/storage/memory/provider.py
@@ -17,6 +17,7 @@
ITermToRelatedTermsIndex,
ITermToSemanticRefIndex,
ITimestampToTextRangeIndex,
STATUS_INGESTED,
)
from .collections import MemoryMessageCollection, MemorySemanticRefCollection
from .convthreads import ConversationThreads
@@ -150,7 +151,22 @@ def is_source_ingested(self, source_id: str) -> bool:
"""
return source_id in self._ingested_sources

def mark_source_ingested(self, source_id: str) -> None:
def get_source_status(self, source_id: str) -> str | None:
"""Get the ingestion status of a source.

Args:
source_id: External source identifier (email ID, file path, etc.)

Returns:
The ingestion status if the source has been ingested, None otherwise.
"""
if source_id in self._ingested_sources:
return STATUS_INGESTED
return None

def mark_source_ingested(
self, source_id: str, status: str = STATUS_INGESTED
) -> None:
"""Mark a source as ingested.

Args:
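Reviewer note: one asymmetry worth flagging — the in-memory `get_source_status` returns the constant `STATUS_INGESTED` for any marked source rather than echoing the status passed to `mark_source_ingested`, unlike the SQLite provider below, which stores and returns the actual status. A small sketch:

```python
from typeagent.knowpro.interfaces import STATUS_INGESTED


def memory_status_demo(provider) -> None:
    # `provider` is a memory storage provider (construction elided).
    provider.mark_source_ingested("mail-1", status="partial")
    # The custom status is not echoed back by this provider:
    assert provider.get_source_status("mail-1") == STATUS_INGESTED
    assert provider.get_source_status("never-seen") is None
```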
2 changes: 1 addition & 1 deletion src/typeagent/storage/memory/semrefindex.py
@@ -136,7 +136,7 @@ async def add_batch_to_semantic_ref_index_from_list[
for i, knowledge_result in enumerate(knowledge_results):
if isinstance(knowledge_result, Failure):
raise RuntimeError(
f"Knowledge extraction failed: {knowledge_result.message}"
f"Knowledge extraction failed: {knowledge_result.message:.150}"
)
text_location = batch[i]
knowledge = knowledge_result.value
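Reviewer note: the `:.150` format spec applies string precision, capping a potentially huge failure message at 150 characters in the raised error:

```python
msg = "x" * 500
text = f"Knowledge extraction failed: {msg:.150}"
# Precision on a str value truncates it to at most 150 characters.
assert text.endswith("x" * 150)
assert len(text) == len("Knowledge extraction failed: ") + 150
```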
13 changes: 10 additions & 3 deletions src/typeagent/storage/sqlite/messageindex.py
@@ -34,8 +34,15 @@ def __init__(
if self._size():
cursor = self.db.cursor()
cursor.execute("SELECT embedding FROM MessageTextIndex")
for row in cursor.fetchall():
self._vectorbase.add_embedding(None, deserialize_embedding(row[0]))
rows = cursor.fetchall()
if rows:
embeddings: list[NormalizedEmbedding] = [
deserialize_embedding(row[0]) for row in rows
]
embeddings_array = np.stack(embeddings, axis=0).astype(
np.float32, copy=False
)
self._vectorbase.add_embeddings(None, embeddings_array)

async def size(self) -> int:
return self._size()
@@ -383,7 +390,7 @@ async def deserialize(self, data: interfaces.MessageTextIndexData) -> None:
)

# Update VectorBase
self._vectorbase.add_embeddings(embeddings)
self._vectorbase.add_embeddings(None, embeddings)

async def clear(self) -> None:
"""Clear the message text index."""
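Reviewer note: the startup load now batches — deserialize every row, `np.stack` the 1-D embeddings into one `(n, d)` array, and hand that to the bulk API; `astype(np.float32, copy=False)` is a no-op when the blobs already decode as float32. A standalone illustration of the stacking step:

```python
import numpy as np

rows = [np.ones(4, dtype=np.float32), np.zeros(4, dtype=np.float32)]
# n embeddings of size d become a single (n, d) array, copy-free if already float32.
stacked = np.stack(rows, axis=0).astype(np.float32, copy=False)
assert stacked.shape == (2, 4) and stacked.dtype == np.float32
```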
32 changes: 26 additions & 6 deletions src/typeagent/storage/sqlite/provider.py
@@ -10,7 +10,7 @@
from ...aitools.vectorbase import TextEmbeddingIndexSettings
from ...knowpro import interfaces
from ...knowpro.convsettings import MessageTextIndexSettings, RelatedTermIndexSettings
from ...knowpro.interfaces import ConversationMetadata
from ...knowpro.interfaces import ConversationMetadata, STATUS_INGESTED
from .collections import SqliteMessageCollection, SqliteSemanticRefCollection
from .messageindex import SqliteMessageTextIndex
from .propindex import SqlitePropertyIndex
@@ -56,6 +56,7 @@ def __init__(
self.db = sqlite3.connect(db_path)

# Configure SQLite for optimal bulk insertion performance
# TODO: Move into init_db_schema()
self.db.execute("PRAGMA foreign_keys = ON")
# Improve write performance for bulk operations
self.db.execute("PRAGMA synchronous = NORMAL") # Faster than FULL, still safe
@@ -625,11 +626,30 @@ def is_source_ingested(self, source_id: str) -> bool:
"""
cursor = self.db.cursor()
cursor.execute(
"SELECT 1 FROM IngestedSources WHERE source_id = ?", (source_id,)
"SELECT status FROM IngestedSources WHERE source_id = ?", (source_id,)
)
return cursor.fetchone() is not None
row = cursor.fetchone()
return row is not None and row[0] == STATUS_INGESTED

def get_source_status(self, source_id: str) -> str | None:
"""Get the ingestion status of a source.

Args:
source_id: External source identifier (email ID, file path, etc.)

Returns:
The status string if the source exists, or None if it hasn't been ingested.
"""
cursor = self.db.cursor()
cursor.execute(
"SELECT status FROM IngestedSources WHERE source_id = ?", (source_id,)
)
row = cursor.fetchone()
return row[0] if row else None

def mark_source_ingested(self, source_id: str) -> None:
def mark_source_ingested(
self, source_id: str, status: str = STATUS_INGESTED
) -> None:
"""Mark a source as ingested.

This performs an INSERT but does NOT commit. It should be called within
@@ -641,6 +661,6 @@ def mark_source_ingested(self, source_id: str) -> None:
"""
cursor = self.db.cursor()
cursor.execute(
"INSERT OR IGNORE INTO IngestedSources (source_id) VALUES (?)",
(source_id,),
"INSERT OR REPLACE INTO IngestedSources (source_id, status) VALUES (?, ?)",
(source_id, status),
)
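Reviewer note: switching `INSERT OR IGNORE` to `INSERT OR REPLACE` turns re-marking from a no-op into an upsert, so a later call can overwrite the stored status. A standalone sqlite3 demo of that semantics:

```python
import sqlite3

db = sqlite3.connect(":memory:")
db.execute(
    "CREATE TABLE IngestedSources ("
    "source_id TEXT PRIMARY KEY, status TEXT NOT NULL DEFAULT 'ingested')"
)
sql = "INSERT OR REPLACE INTO IngestedSources (source_id, status) VALUES (?, ?)"
db.execute(sql, ("mail-1", "ingested"))
db.execute(sql, ("mail-1", "failed"))  # OR IGNORE would have kept 'ingested'
row = db.execute(
    "SELECT status FROM IngestedSources WHERE source_id = ?", ("mail-1",)
).fetchone()
assert row == ("failed",)
```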
12 changes: 7 additions & 5 deletions src/typeagent/storage/sqlite/reltermsindex.py
@@ -5,7 +5,9 @@

import sqlite3

from ...aitools.embeddings import NormalizedEmbeddings
import numpy as np

from ...aitools.embeddings import NormalizedEmbedding
from ...aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase
from ...knowpro import interfaces
from .schema import deserialize_embedding, serialize_embedding
@@ -145,13 +147,13 @@ def __init__(self, db: sqlite3.Connection, settings: TextEmbeddingIndexSettings)
"SELECT term, term_embedding FROM RelatedTermsFuzzy ORDER BY term"
)
rows = cursor.fetchall()
embeddings: list[NormalizedEmbedding] = []
for term, blob in rows:
assert blob is not None, term
embedding: NormalizedEmbeddings = deserialize_embedding(blob)
# Add to VectorBase at the correct ordinal
self._vector_base.add_embedding(term, embedding)
self._terms_list.append(term)
self._added_terms.add(term)
embeddings.append(deserialize_embedding(blob))
# Bulk add embeddings to VectorBase
self._vector_base.add_embeddings(None, np.array(embeddings))

async def lookup_term(
self,
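Reviewer note: one edge case in this bulk load — `np.array` on an empty list produces shape `(0,)`, not `(0, d)`, which would trip the `ndim == 2` assert in `add_embeddings` if the RelatedTermsFuzzy table can be empty on this path; messageindex.py above guards the same pattern with `if rows:`. Illustration:

```python
import numpy as np

assert np.array([]).shape == (0,)  # 1-D: would fail an `ndim == 2` assert
vecs = [np.ones(4, dtype=np.float32), np.zeros(4, dtype=np.float32)]
assert np.array(vecs).shape == (2, 4)  # fine with at least one row
```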
7 changes: 4 additions & 3 deletions src/typeagent/storage/sqlite/schema.py
@@ -10,7 +10,7 @@
import numpy as np

from ...aitools.embeddings import NormalizedEmbedding
from ...knowpro.interfaces import ConversationMetadata
from ...knowpro.interfaces import ConversationMetadata, STATUS_INGESTED

# Constants
CONVERSATION_SCHEMA_VERSION = 1
@@ -141,9 +141,10 @@

# Table for tracking ingested source IDs (e.g., email IDs, file paths)
# This prevents re-ingesting the same content on subsequent runs
INGESTED_SOURCES_SCHEMA = """
INGESTED_SOURCES_SCHEMA = f"""
CREATE TABLE IF NOT EXISTS IngestedSources (
source_id TEXT PRIMARY KEY -- External source identifier (email ID, file path, etc.)
source_id TEXT PRIMARY KEY, -- External source identifier (email ID, file path, etc.)
status TEXT NOT NULL DEFAULT '{STATUS_INGESTED}' -- Status of the source (e.g., 'ingested')
);
"""

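Reviewer note: the interpolated default must render as a quoted SQL string literal. A quick sanity check that the rendered schema applies the default when status is omitted:

```python
import sqlite3

STATUS_INGESTED = "ingested"  # mirrors the constant from knowpro.interfaces
db = sqlite3.connect(":memory:")
db.execute(
    f"CREATE TABLE IngestedSources ("
    f"source_id TEXT PRIMARY KEY, "
    f"status TEXT NOT NULL DEFAULT '{STATUS_INGESTED}')"
)
db.execute("INSERT INTO IngestedSources (source_id) VALUES ('mail-1')")
assert db.execute("SELECT status FROM IngestedSources").fetchone() == ("ingested",)
```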
6 changes: 3 additions & 3 deletions tests/test_utils.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from contextlib import redirect_stdout
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
import os

@@ -10,11 +10,11 @@

def test_timelog():
buf = StringIO()
with redirect_stdout(buf):
with redirect_stderr(buf):
with utils.timelog("test block"):
pass
out = buf.getvalue()
assert "s -- test block" in out
assert "test block..." in out


def test_pretty_print():
20 changes: 20 additions & 0 deletions tests/test_vectorbase.py
@@ -48,6 +48,26 @@ def test_add_embedding(vector_base: VectorBase, sample_embeddings: Samples):
np.testing.assert_array_equal(vector_base.serialize_embedding_at(i), embedding)


def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples):
"""Adding multiple embeddings at once matches repeated single adds."""
keys = list(sample_embeddings.keys())
for key, embedding in sample_embeddings.items():
vector_base.add_embedding(key, embedding)

bulk_vector_base = make_vector_base()
stacked_embeddings = np.stack([sample_embeddings[key] for key in keys], axis=0)
bulk_vector_base.add_embeddings(keys, stacked_embeddings)

assert len(bulk_vector_base) == len(vector_base)
np.testing.assert_array_equal(bulk_vector_base.serialize(), vector_base.serialize())

sequential_cache = vector_base._model._embedding_cache
bulk_cache = bulk_vector_base._model._embedding_cache
assert set(sequential_cache.keys()) == set(bulk_cache.keys())
for key in keys:
np.testing.assert_array_equal(bulk_cache[key], sequential_cache[key])


@pytest.mark.asyncio
async def test_add_key(vector_base: VectorBase, sample_embeddings: Samples):
"""Test adding keys to the VectorBase."""