Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions tools/load_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""
JSON to SQLite Database Loader

This tool loads a JSON-serialized podcast database into a SQLite database
that can be queried using tools/query.py.

Usage:
python tools/load_json.py <index_path> --database <db_file>
python tools/load_json.py tests/testdata/Episode_53_AdrianTchaikovsky_index -d transcript.db

The index_path should exclude the "_data.json" suffix.
"""

import argparse
import asyncio
import os

from typeagent.aitools import utils
from typeagent.knowpro.convsettings import ConversationSettings
from typeagent.podcasts import podcast
from typeagent.storage.utils import create_storage_provider


async def load_json_to_database(
    podcast_file_prefix: str,
    dbname: str,
    verbose: bool = False,
) -> None:
    """Load JSON-serialized podcast data into a SQLite database.

    Args:
        podcast_file_prefix: Path to podcast index files (without "_data.json" suffix)
        dbname: Path to SQLite database file (must be empty)
        verbose: Whether to show verbose output

    Raises:
        RuntimeError: If the target database already contains messages.
    """
    if verbose:
        print(f"Loading podcast from JSON: {podcast_file_prefix}")
        print(f"Target database: {dbname}")

    # Create settings and attach a storage provider bound to the target database.
    settings = ConversationSettings()
    settings.storage_provider = await create_storage_provider(
        settings.message_text_index_settings,
        settings.related_term_index_settings,
        dbname,
        podcast.PodcastMessage,
    )

    # Refuse to load into a non-empty database so two conversations are
    # never mixed in one file.
    provider = await settings.get_storage_provider()
    msgs = await provider.get_message_collection()
    msg_count = await msgs.size()
    if msg_count > 0:
        raise RuntimeError(
            f"Database '{dbname}' already contains {msg_count} messages. "
            "The database must be empty to load new data. "
            "Please use a different database file or remove the existing one."
        )

    # Load podcast from JSON files; the provider's async context manager
    # handles commit/close of the underlying database on exit.
    with utils.timelog(f"Loading podcast from {podcast_file_prefix!r}"):
        async with provider:
            conversation = await podcast.Podcast.read_from_file(
                podcast_file_prefix, settings, dbname
            )

    # Print statistics
    if verbose:
        # NOTE(review): assumes semantic_refs is always present after a
        # successful load (only semantic_ref_index is guarded) — confirm.
        print("\nSuccessfully loaded podcast data:")  # was f-string with no placeholders
        print(f" {await conversation.messages.size()} messages")
        print(f" {await conversation.semantic_refs.size()} semantic refs")
        if conversation.semantic_ref_index:
            print(
                f" {await conversation.semantic_ref_index.size()} semantic ref index entries"
            )

    print(f"\nDatabase created: {dbname}")
    print("\nTo query the database, use:")  # was f-string with no placeholders
    print(f" python tools/query.py --database '{dbname}' --query 'Your question here'")


def main():
    """Parse command-line arguments, validate the index path, and run the loader."""
    parser = argparse.ArgumentParser(
        description="Load JSON-serialized podcast data into a SQLite database",
    )
    parser.add_argument(
        "-d",
        "--database",
        required=True,
        help="Path to the SQLite database file (must be empty)",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Show verbose output including statistics",
    )
    parser.add_argument(
        "index_path",
        help="Path to the podcast index files (excluding the '_data.json' suffix)",
    )
    ns = parser.parse_args()

    # Fail fast if the serialized index is not where the user said it is.
    index_file = ns.index_path + "_data.json"
    if not os.path.exists(index_file):
        raise SystemExit(
            f"Error: Podcast index file not found: {index_file}\n"
            f"Please verify the path exists and is accessible.\n"
            f"Note: The path should exclude the '_data.json' suffix."
        )

    # Environment variables (API keys etc.) must be in place before loading.
    utils.load_dotenv()

    # Drive the async loader to completion.
    asyncio.run(load_json_to_database(ns.index_path, ns.database, ns.verbose))


if __name__ == "__main__":
    main()
68 changes: 14 additions & 54 deletions tools/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
Topic,
)
from typeagent.podcasts import podcast
from typeagent.storage.sqlite.provider import SqliteStorageProvider
from typeagent.storage.utils import create_storage_provider

### Classes ###
Expand Down Expand Up @@ -536,24 +535,6 @@ async def main():
args = parser.parse_args()
fill_in_debug_defaults(parser, args)

# Validate required podcast argument
if args.podcast is None and args.database is None:
scriptname = sys.argv[0]
raise SystemExit(
f"Error: Either --podcast or --database is required.\n"
f"Usage: python {scriptname} --podcast <path_to_index>\n"
f" or: python {scriptname} --database <path_to_database>\n"
f"Example: python {scriptname} --podcast tests/testdata/Episode_53_AdrianTchaikovsky_index"
)
if args.podcast is not None:
index_file = args.podcast + "_data.json"
if not os.path.exists(index_file):
raise SystemExit(
f"Error: Podcast index file not found: {index_file}\n"
f"Please verify the path exists and is accessible.\n"
f"Note: The path should exclude the '_data.json' suffix."
)

if args.logfire:
utils.setup_logfire()

Expand All @@ -564,9 +545,18 @@ async def main():
args.database,
podcast.PodcastMessage,
)
query_context = await load_podcast_index(
args.podcast, settings, args.database, args.verbose
)

# Load existing database
provider = await settings.get_storage_provider()
msgs = await provider.get_message_collection()
if await msgs.size() == 0:
raise SystemExit(f"Error: Database '{args.database}' is empty.")

with utils.timelog(f"Loading conversation from database {args.database!r}"):
conversation = await podcast.Podcast.create(settings)

await print_conversation_stats(conversation, args.verbose)
query_context = query.QueryEvalContext(conversation)

ar_list, ar_index = load_index_file(
args.qafile, "question", QuestionAnswerData, args.verbose
Expand Down Expand Up @@ -943,12 +933,6 @@ def make_arg_parser(description: str) -> argparse.ArgumentParser:
),
)

parser.add_argument(
"--podcast",
type=str,
default=None,
help="Path to the podcast index files (excluding the '_data.json' suffix)",
)
explain_qa = "a list of questions and answers to test the full pipeline"
parser.add_argument(
"--qafile",
Expand All @@ -973,8 +957,8 @@ def make_arg_parser(description: str) -> argparse.ArgumentParser:
"-d",
"--database",
type=str,
default=None,
help="Path to the SQLite database file (default: in-memory)",
required=True,
help="Path to the SQLite database file",
)
parser.add_argument(
"--query",
Expand Down Expand Up @@ -1110,30 +1094,6 @@ def fill_in_debug_defaults(
### Data loading ###


async def load_podcast_index(
podcast_file_prefix: str,
settings: ConversationSettings,
dbname: str | None,
verbose: bool = True,
) -> query.QueryEvalContext:
provider = await settings.get_storage_provider()
msgs = await provider.get_message_collection()
if await msgs.size() > 0: # Sqlite provider with existing non-empty database
with utils.timelog(f"Reusing database {dbname!r}"):
conversation = await podcast.Podcast.create(settings)
else:
with utils.timelog(f"Loading podcast from {podcast_file_prefix!r}"):
conversation = await podcast.Podcast.read_from_file(
podcast_file_prefix, settings, dbname
)
if isinstance(provider, SqliteStorageProvider):
provider.db.commit()

await print_conversation_stats(conversation, verbose)

return query.QueryEvalContext(conversation)


def load_index_file[T: Mapping[str, typing.Any]](
file: str | None, selector: str, cls: type[T], verbose: bool = True
) -> tuple[list[T], dict[str, T]]:
Expand Down
Loading