diff --git a/tools/load_json.py b/tools/load_json.py
new file mode 100644
index 00000000..573b4a79
--- /dev/null
+++ b/tools/load_json.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""
+JSON to SQLite Database Loader
+
+This tool loads a JSON-serialized podcast database into a SQLite database
+that can be queried using tools/query.py.
+
+Usage:
+    python tools/load_json.py <index_path> --database <dbname>
+    python tools/load_json.py tests/testdata/Episode_53_AdrianTchaikovsky_index -d transcript.db
+
+The index_path should exclude the "_data.json" suffix.
+"""
+
+import argparse
+import asyncio
+import os
+
+from typeagent.aitools import utils
+from typeagent.knowpro.convsettings import ConversationSettings
+from typeagent.podcasts import podcast
+from typeagent.storage.utils import create_storage_provider
+
+
+async def load_json_to_database(
+    podcast_file_prefix: str,
+    dbname: str,
+    verbose: bool = False,
+) -> None:
+    """Load JSON-serialized podcast data into a SQLite database.
+
+    Args:
+        podcast_file_prefix: Path to podcast index files (without "_data.json" suffix)
+        dbname: Path to SQLite database file (must be empty)
+        verbose: Whether to show verbose output
+    """
+    if verbose:
+        print(f"Loading podcast from JSON: {podcast_file_prefix}")
+        print(f"Target database: {dbname}")
+
+    # Create settings and storage provider
+    settings = ConversationSettings()
+    settings.storage_provider = await create_storage_provider(
+        settings.message_text_index_settings,
+        settings.related_term_index_settings,
+        dbname,
+        podcast.PodcastMessage,
+    )
+
+    # Get the storage provider to check if database is empty
+    provider = await settings.get_storage_provider()
+    msgs = await provider.get_message_collection()
+
+    # Check if database already has data
+    msg_count = await msgs.size()
+    if msg_count > 0:
+        raise RuntimeError(
+            f"Database '{dbname}' already contains {msg_count} messages. "
+            "The database must be empty to load new data. "
+            "Please use a different database file or remove the existing one."
+        )
+
+    # Load podcast from JSON files
+    with utils.timelog(f"Loading podcast from {podcast_file_prefix!r}"):
+        async with provider:
+            conversation = await podcast.Podcast.read_from_file(
+                podcast_file_prefix, settings, dbname
+            )
+
+    # Print statistics
+    if verbose:
+        print(f"\nSuccessfully loaded podcast data:")
+        print(f"  {await conversation.messages.size()} messages")
+        print(f"  {await conversation.semantic_refs.size()} semantic refs")
+        if conversation.semantic_ref_index:
+            print(
+                f"  {await conversation.semantic_ref_index.size()} semantic ref index entries"
+            )
+
+    print(f"\nDatabase created: {dbname}")
+    print(f"\nTo query the database, use:")
+    print(f"  python tools/query.py --database '{dbname}' --query 'Your question here'")
+
+
+def main():
+    """Main entry point."""
+    parser = argparse.ArgumentParser(
+        description="Load JSON-serialized podcast data into a SQLite database",
+    )
+
+    parser.add_argument(
+        "-d",
+        "--database",
+        required=True,
+        help="Path to the SQLite database file (must be empty)",
+    )
+
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        action="store_true",
+        help="Show verbose output including statistics",
+    )
+
+    parser.add_argument(
+        "index_path",
+        help="Path to the podcast index files (excluding the '_data.json' suffix)",
+    )
+
+    args = parser.parse_args()
+
+    # Ensure index file exists
+    index_file = args.index_path + "_data.json"
+    if not os.path.exists(index_file):
+        raise SystemExit(
+            f"Error: Podcast index file not found: {index_file}\n"
+            f"Please verify the path exists and is accessible.\n"
+            f"Note: The path should exclude the '_data.json' suffix."
+        )
+
+    # Load environment variables for API access
+    utils.load_dotenv()
+
+    # Run the loading process
+    asyncio.run(load_json_to_database(args.index_path, args.database, args.verbose))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/query.py b/tools/query.py
index 3e93ecde..f2cc311f 100644
--- a/tools/query.py
+++ b/tools/query.py
@@ -55,7 +55,6 @@
     Topic,
 )
 from typeagent.podcasts import podcast
-from typeagent.storage.sqlite.provider import SqliteStorageProvider
 from typeagent.storage.utils import create_storage_provider
 
 ### Classes ###
@@ -536,24 +535,6 @@ async def main():
     args = parser.parse_args()
     fill_in_debug_defaults(parser, args)
 
-    # Validate required podcast argument
-    if args.podcast is None and args.database is None:
-        scriptname = sys.argv[0]
-        raise SystemExit(
-            f"Error: Either --podcast or --database is required.\n"
-            f"Usage: python {scriptname} --podcast <index_path>\n"
-            f"   or: python {scriptname} --database <database>\n"
-            f"Example: python {scriptname} --podcast tests/testdata/Episode_53_AdrianTchaikovsky_index"
-        )
-    if args.podcast is not None:
-        index_file = args.podcast + "_data.json"
-        if not os.path.exists(index_file):
-            raise SystemExit(
-                f"Error: Podcast index file not found: {index_file}\n"
-                f"Please verify the path exists and is accessible.\n"
-                f"Note: The path should exclude the '_data.json' suffix."
-            )
-
     if args.logfire:
         utils.setup_logfire()
 
@@ -564,9 +545,18 @@
         args.database,
         podcast.PodcastMessage,
     )
-    query_context = await load_podcast_index(
-        args.podcast, settings, args.database, args.verbose
-    )
+
+    # Load existing database
+    provider = await settings.get_storage_provider()
+    msgs = await provider.get_message_collection()
+    if await msgs.size() == 0:
+        raise SystemExit(f"Error: Database '{args.database}' is empty.")
+
+    with utils.timelog(f"Loading conversation from database {args.database!r}"):
+        conversation = await podcast.Podcast.create(settings)
+
+    await print_conversation_stats(conversation, args.verbose)
+    query_context = query.QueryEvalContext(conversation)
 
     ar_list, ar_index = load_index_file(
         args.qafile, "question", QuestionAnswerData, args.verbose
@@ -943,12 +933,6 @@
         ),
     )
 
-    parser.add_argument(
-        "--podcast",
-        type=str,
-        default=None,
-        help="Path to the podcast index files (excluding the '_data.json' suffix)",
-    )
     explain_qa = "a list of questions and answers to test the full pipeline"
     parser.add_argument(
         "--qafile",
@@ -973,8 +957,8 @@
         "-d",
         "--database",
         type=str,
-        default=None,
-        help="Path to the SQLite database file (default: in-memory)",
+        required=True,
+        help="Path to the SQLite database file",
     )
     parser.add_argument(
         "--query",
@@ -1110,30 +1094,6 @@
 
 ### Data loading ###
 
-async def load_podcast_index(
-    podcast_file_prefix: str,
-    settings: ConversationSettings,
-    dbname: str | None,
-    verbose: bool = True,
-) -> query.QueryEvalContext:
-    provider = await settings.get_storage_provider()
-    msgs = await provider.get_message_collection()
-    if await msgs.size() > 0:  # Sqlite provider with existing non-empty database
-        with utils.timelog(f"Reusing database {dbname!r}"):
-            conversation = await podcast.Podcast.create(settings)
-    else:
-        with utils.timelog(f"Loading podcast from {podcast_file_prefix!r}"):
-            conversation = await podcast.Podcast.read_from_file(
-                podcast_file_prefix, settings, dbname
-            )
-        if isinstance(provider, SqliteStorageProvider):
-            provider.db.commit()
-
-    await print_conversation_stats(conversation, verbose)
-
-    return query.QueryEvalContext(conversation)
-
-
 def load_index_file[T: Mapping[str, typing.Any]](
     file: str | None, selector: str, cls: type[T], verbose: bool = True
 ) -> tuple[list[T], dict[str, T]]: