diff --git a/Dockerfile b/Dockerfile
index 3c0eb92..b57e698 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -27,5 +27,5 @@
 COPY --from=builder /app/.venv /app/.venv
 COPY --from=builder /app/src /app/src
 EXPOSE 9001
-
-CMD ["python", "src/server.py", "--host", "0.0.0.0", "--transport", "sse"]
+# TODO Move to a path
+ENTRYPOINT ["python", "src/server.py", "--host", "0.0.0.0", "--port", "9001", "--transport", "http", "--path", "/db-tools"]
diff --git a/build/env.tmpl b/build/env.tmpl
new file mode 100644
index 0000000..cafbe30
--- /dev/null
+++ b/build/env.tmpl
@@ -0,0 +1,9 @@
+DB_HOST=yourdb
+DB_USER=claude-mariadb
+DB_PASSWORD=testing-testing-testing
+DB_PORT=3306
+DB_NAME=ThatDB
+MCP_READ_ONLY=true
+MCP_MAX_POOL_SIZE=10
+JWT_ISSUER=https://truth.domain.tld
+JWT_AUDIENCE=https://truth.domain.tld
diff --git a/pyproject.toml b/pyproject.toml
index 4c22b8b..3d26063 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,10 +6,11 @@
 readme = "README.md"
 requires-python = ">=3.11"
 dependencies = [
     "asyncmy>=0.2.10",
-    "fastmcp[websockets]==2.12.1",
+    "fastmcp==2.13.1",
     "google-genai>=1.15.0",
     "openai>=1.78.1",
     "python-dotenv>=1.1.0",
     "sentence-transformers>=4.1.0",
     "tokenizers==0.21.2",
+    "python-json-logger>=4.0.0"
 ]
diff --git a/src/config.py b/src/config.py
index 85270c2..73d905d 100644
--- a/src/config.py
+++ b/src/config.py
@@ -1,18 +1,23 @@
 # config.py
 import os
 from dotenv import load_dotenv
-import logging
-from logging.handlers import RotatingFileHandler
-from pathlib import Path
+
+# Import our dedicated logging configuration
+from logging_config import setup_logger, get_logger, setup_third_party_logging
 
 # Load environment variables from .env file
 load_dotenv()
 
+# --- Authentication Configuration ---
+JWT_AUDIENCE = os.getenv("JWT_AUDIENCE", "mariadb_ops_server")
+JWT_ISSUER = os.getenv("JWT_ISSUER", "http://localhost")
+
 # --- Logging Configuration ---
 LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
 LOG_FILE_PATH = os.getenv("LOG_FILE", "logs/mcp_server.log")
 LOG_MAX_BYTES = int(os.getenv("LOG_MAX_BYTES", 10 * 1024 * 1024))
 LOG_BACKUP_COUNT = int(os.getenv("LOG_BACKUP_COUNT", 5))
+THIRD_PARTY_LOG_LEVEL = os.getenv("THIRD_PARTY_LOG_LEVEL", "WARNING").upper()
 
 ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS")
 if ALLOWED_ORIGINS:
@@ -26,36 +31,18 @@
 else:
     ALLOWED_HOSTS = ["localhost", "127.0.0.1"]
 
-# Get the root logger
-root_logger = logging.getLogger()
-root_logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
-
-# Create formatter
-log_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-
-# Remove existing handlers to avoid duplication if script is reloaded
-for handler in root_logger.handlers[:]:
-    root_logger.removeHandler(handler)
-
-# Console Handler
-console_handler = logging.StreamHandler()
-console_handler.setFormatter(log_formatter)
-root_logger.addHandler(console_handler)
-
-# File Handler - Ensure log directory exists
-log_file = Path(LOG_FILE_PATH)
-log_file.parent.mkdir(parents=True, exist_ok=True)
-
-file_handler = RotatingFileHandler(
-    log_file,
-    maxBytes=LOG_MAX_BYTES,
-    backupCount=LOG_BACKUP_COUNT
+# Set up the dedicated logger for this project (NOT the root logger)
+logger = setup_logger(
+    log_level=LOG_LEVEL,
+    log_file_path=LOG_FILE_PATH,
+    log_max_bytes=LOG_MAX_BYTES,
+    log_backup_count=LOG_BACKUP_COUNT,
+    enable_console=True,
+    enable_file=True
 )
-file_handler.setFormatter(log_formatter)
-root_logger.addHandler(file_handler)
-
-# The specific logger used in server.py and elsewhere will inherit this configuration.
-logger = logging.getLogger(__name__)
+
+# Configure third-party library logging to reduce noise
+setup_third_party_logging(level=THIRD_PARTY_LOG_LEVEL)
 
 # --- Database Configuration ---
 DB_HOST = os.getenv("DB_HOST", "localhost")
@@ -104,4 +91,4 @@
 logger.info(f"No EMBEDDING_PROVIDER selected or it is set to None. Disabling embedding features.")
 logger.info(f"Read-only mode: {MCP_READ_ONLY}")
-logger.info(f"Logging to console and to file: {LOG_FILE_PATH} (Level: {LOG_LEVEL}, MaxSize: {LOG_MAX_BYTES}B, Backups: {LOG_BACKUP_COUNT})")
\ No newline at end of file
+logger.info(f"Logging to console and to file: {LOG_FILE_PATH} (Level: {LOG_LEVEL}, MaxSize: {LOG_MAX_BYTES}B, Backups: {LOG_BACKUP_COUNT})")
diff --git a/src/embeddings.py b/src/embeddings.py
index e16ea1c..df4fb1d 100644
--- a/src/embeddings.py
+++ b/src/embeddings.py
@@ -1,19 +1,22 @@
-import logging
 import sys
 import os
 import asyncio
 from typing import List, Optional, Dict, Any, Union, Awaitable
 import numpy as np
 
-# Import configuration variables and the logger instance
+# Import configuration variables
 from config import (
     EMBEDDING_PROVIDER,
     OPENAI_API_KEY,
     GEMINI_API_KEY,
-    HF_MODEL,
-    logger
+    HF_MODEL
 )
 
+# Import the dedicated logger
+from logging_config import get_logger
+
+logger = get_logger("embeddings")
+
 # Import specific client libraries
 try:
     from openai import AsyncOpenAI, OpenAIError
diff --git a/src/logging_config.py b/src/logging_config.py
new file mode 100644
index 0000000..22b558f
--- /dev/null
+++ b/src/logging_config.py
@@ -0,0 +1,153 @@
+"""
+Logging configuration for the mariadb-mcp project.
+
+This module sets up a dedicated logger that does NOT configure the root logger,
+following Python best practices for library code.
+"""
+import logging
+import os
+from logging.handlers import RotatingFileHandler
+from pathlib import Path
+from pythonjsonlogger import jsonlogger
+
+
+# Logger name for this project - NOT the root logger
+LOGGER_NAME = "mariadb_mcp"
+
+
+class CustomJsonFormatter(jsonlogger.JsonFormatter):
+    """
+    Custom JSON formatter that includes timestamp, calling context,
+    and all relevant fields for structured logging.
+    """
+    def add_fields(self, log_record, record, message_dict):
+        super(CustomJsonFormatter, self).add_fields(log_record, record, message_dict)
+
+        # Ensure timestamp is always present
+        if not log_record.get('timestamp'):
+            log_record['timestamp'] = self.formatTime(record, self.datefmt)
+
+        # Add calling context
+        log_record['level'] = record.levelname
+        log_record['logger'] = record.name
+        log_record['module'] = record.module
+        log_record['function'] = record.funcName
+        log_record['line'] = record.lineno
+
+        # Add process/thread info if relevant
+        if record.process:
+            log_record['process'] = record.process
+        if record.thread:
+            log_record['thread'] = record.thread
+
+
+def setup_logger(
+    log_level: str = "INFO",
+    log_file_path: str = "logs/mcp_server.log",
+    log_max_bytes: int = 10 * 1024 * 1024,
+    log_backup_count: int = 5,
+    enable_console: bool = True,
+    enable_file: bool = True
+) -> logging.Logger:
+    """
+    Set up the dedicated logger for mariadb-mcp.
+
+    This function creates a logger with the name "mariadb_mcp" and configures
+    it with console and/or file handlers. It does NOT touch the root logger.
+
+    Args:
+        log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+        log_file_path: Path to the log file
+        log_max_bytes: Maximum size of log file before rotation
+        log_backup_count: Number of backup log files to keep
+        enable_console: Whether to enable console logging
+        enable_file: Whether to enable file logging
+
+    Returns:
+        Configured logger instance
+    """
+    # Get the dedicated logger (NOT root logger)
+    logger = logging.getLogger(LOGGER_NAME)
+
+    # Set the level
+    logger.setLevel(getattr(logging, log_level.upper(), logging.INFO))
+
+    # Prevent propagation to root logger to avoid duplicate logs
+    logger.propagate = False
+
+    # Remove any existing handlers to avoid duplication
+    for handler in logger.handlers[:]:
+        logger.removeHandler(handler)
+
+    # Create formatter with timestamp and calling context
+    formatter = CustomJsonFormatter(
+        fmt='%(timestamp)s %(level)s %(name)s %(module)s %(funcName)s:%(lineno)d %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+
+    # Console Handler
+    if enable_console:
+        console_handler = logging.StreamHandler()
+        console_handler.setFormatter(formatter)
+        logger.addHandler(console_handler)
+
+    # File Handler
+    if enable_file:
+        # Ensure log directory exists
+        log_file = Path(log_file_path)
+        log_file.parent.mkdir(parents=True, exist_ok=True)
+
+        file_handler = RotatingFileHandler(
+            log_file,
+            maxBytes=log_max_bytes,
+            backupCount=log_backup_count
+        )
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
+
+    return logger
+
+
+def get_logger(name: str | None = None) -> logging.Logger:
+    """
+    Get a child logger under the mariadb_mcp logger hierarchy.
+
+    Args:
+        name: Optional name for the child logger. If None, returns the root mariadb_mcp logger.
+
+    Returns:
+        Logger instance
+    """
+    if name:
+        return logging.getLogger(f"{LOGGER_NAME}.{name}")
+    return logging.getLogger(LOGGER_NAME)
+
+
+def setup_third_party_logging(level: str = "WARNING"):
+    """
+    Configure logging for third-party libraries like fastmcp, uvicorn, etc.
+
+    This sets the logging level for known third-party loggers to reduce noise
+    without touching the root logger.
+
+    Args:
+        level: Logging level for third-party libraries
+    """
+    third_party_loggers = [
+        "fastmcp",
+        "uvicorn",
+        "uvicorn.access",
+        "uvicorn.error",
+        "starlette",
+        "asyncmy",
+        "httpx",
+        "httpcore"
+    ]
+
+    log_level = getattr(logging, level.upper(), logging.WARNING)
+
+    for logger_name in third_party_loggers:
+        third_party_logger = logging.getLogger(logger_name)
+        third_party_logger.setLevel(log_level)
+        # Leave propagation enabled so any root-level handlers still receive these records
+        third_party_logger.propagate = True
diff --git a/src/server.py b/src/server.py
index 7b18e66..98b2bc6 100644
--- a/src/server.py
+++ b/src/server.py
@@ -4,22 +4,23 @@
 from config import (
     DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME, DB_CHARSET,
     MCP_READ_ONLY, MCP_MAX_POOL_SIZE, EMBEDDING_PROVIDER,
-    ALLOWED_ORIGINS, ALLOWED_HOSTS,
+    ALLOWED_ORIGINS, ALLOWED_HOSTS, JWT_ISSUER, JWT_AUDIENCE,
     logger
 )
-import asyncio
+import datetime
 import argparse
 import re
 from typing import List, Dict, Any, Optional
-from functools import partial 
+from functools import partial
 import asyncmy
-import anyio 
+import anyio
 from fastmcp import FastMCP, Context
 from starlette.middleware import Middleware
 from starlette.middleware.cors import CORSMiddleware
+from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.trustedhost import TrustedHostMiddleware
 
 # Import EmbeddingService for vector store creation
@@ -31,7 +32,13 @@
 embedding_service = EmbeddingService()
 
 from asyncmy.errors import Error as AsyncMyError
 
+from fastmcp.server.auth.providers.jwt import JWTVerifier
+auth = JWTVerifier(
+    jwks_uri=f"{JWT_ISSUER}/.well-known/jwks.json",
+    issuer=JWT_ISSUER,
+    audience=JWT_AUDIENCE
+)
 # --- MariaDB MCP Server Class ---
 class MariaDBServer:
     """
@@ -39,7 +46,7 @@
     Manages the database connection pool.
     """
     def __init__(self, server_name="MariaDB_Server", autocommit=True):
-        self.mcp = FastMCP(server_name)
+        # Pass the verifier to FastMCP so incoming requests are actually authenticated
+        self.mcp = FastMCP(server_name, stateless_http=True, auth=auth)
         self.pool: Optional[asyncmy.Pool] = None
         self.autocommit = not MCP_READ_ONLY
         self.is_read_only = MCP_READ_ONLY
@@ -50,7 +57,7 @@ async def create_vector_store(self, database_name: str, vector_store_name: str, model_name: Optional[str] = None, distance_function: Optional[str] = None) -> dict:
         """
         This tool creates a table which stores embeddings.
-        
+
         Creates a new vector store (table) with a predefined schema if it doesn't already exist.
         It first checks if the database exists, creating it if necessary.
         Then, it checks if the table exists; if so, it reports that.
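A note on the new auth wiring: the `JWTVerifier` resolves signing keys from `{JWT_ISSUER}/.well-known/jwks.json` and rejects tokens whose `iss`/`aud` claims don't match `JWT_ISSUER`/`JWT_AUDIENCE`. For local testing, a token with matching claims can be minted by hand; the sketch below uses PyJWT (not a dependency of this project) with hypothetical key material and `kid`:

```python
# Illustrative only: mint a bearer token that satisfies the JWTVerifier above.
# PyJWT is NOT in pyproject.toml; the key file and "kid" here are hypothetical.
import datetime
import jwt  # pip install "pyjwt[crypto]"

with open("dev_private_key.pem") as fh:  # assumed local RSA test key
    private_key = fh.read()

claims = {
    "iss": "https://truth.domain.tld",  # must equal JWT_ISSUER
    "aud": "https://truth.domain.tld",  # must equal JWT_AUDIENCE
    "sub": "local-test-client",
    "exp": datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1),
}

# The verifier looks the signing key up by "kid" in the issuer's JWKS document,
# so the header must name a key actually published there.
token = jwt.encode(claims, private_key, algorithm="RS256", headers={"kid": "dev-key-1"})
print(token)
```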
@@ -88,13 +95,13 @@ async def initialize_pool(self):
                 "autocommit": self.autocommit,
                 "pool_recycle": 3600
             }
-            
+
             if DB_CHARSET:
                 pool_params["charset"] = DB_CHARSET
                 logger.info(f"Creating connection pool for {DB_USER}@{DB_HOST}:{DB_PORT}/{DB_NAME} (max size: {MCP_MAX_POOL_SIZE}, charset: {DB_CHARSET})")
             else:
                 logger.info(f"Creating connection pool for {DB_USER}@{DB_HOST}:{DB_PORT}/{DB_NAME} (max size: {MCP_MAX_POOL_SIZE})")
-            
+
             self.pool = await asyncmy.create_pool(**pool_params)
             logger.info("Connection pool initialized successfully.")
         except AsyncMyError as e:
@@ -126,14 +133,14 @@ async def _execute_query(self, sql: str, params: Optional[tuple] = None, databas
             raise RuntimeError("Database connection pool not available.")
 
         allowed_prefixes = ('SELECT', 'SHOW', 'DESC', 'DESCRIBE', 'USE')
-        
+
         # Strip SQL comments from query
         # Remove single-line comments (-- comment)
         sql_no_comments = re.sub(r'--.*?$', '', sql, flags=re.MULTILINE)
         # Remove multi-line comments (/* comment */)
         sql_no_comments = re.sub(r'/\*.*?\*/', '', sql_no_comments, flags=re.DOTALL)
         sql_no_comments = sql_no_comments.strip()
-        
+
         query_upper = sql_no_comments.upper()
         is_allowed_read_query = any(query_upper.startswith(prefix) for prefix in allowed_prefixes)
@@ -181,12 +188,12 @@ async def _execute_query(self, sql: str, params: Optional[tuple] = None, databas
             conn_state = f"Connection: {'acquired' if conn else 'not acquired'}"
             logger.error(f"Unexpected error during query execution ({conn_state}): {e}", exc_info=True)
             raise RuntimeError(f"An unexpected error occurred: {e}") from e
-    
+
     async def _database_exists(self, database_name: str) -> bool:
         """Checks if a database exists."""
         if not database_name or not database_name.isidentifier():
             logger.warning(f"_database_exists called with invalid database_name: {database_name}")
-            return False 
+            return False
 
         sql = "SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = %s"
         try:
@@ -195,7 +202,7 @@
         except Exception as e:
             logger.error(f"Error checking if database '{database_name}' exists: {e}", exc_info=True)
             return False
-        
+
     async def _table_exists(self, database_name: str, table_name: str) -> bool:
         """Checks if a table exists in the given database."""
         if not database_name or not database_name.isidentifier() or \
@@ -256,7 +263,7 @@ async def _is_vector_store(self, database_name: str, table_name: str) -> bool:
             logger.error(f"Error checking if '{database_name}.{table_name}' is a vector store: {e}", exc_info=True)
             return False # Treat errors as "not a vector store" for safety in deletion context
-    
+
     # --- MCP Tool Definitions ---
 
     async def list_databases(self) -> List[str]:
@@ -332,7 +339,7 @@ async def get_table_schema(self, database_name: str, table_name: str) -> Dict[st
         except Exception as e:
             logger.error(f"TOOL ERROR: get_table_schema failed for database_name={database_name}, table_name={table_name}: {e}", exc_info=True)
             raise RuntimeError(f"Could not retrieve schema for table '{database_name}.{table_name}'.")
-    
+
     async def get_table_schema_with_relations(self, database_name: str, table_name: str) -> Dict[str, Any]:
         """
         Retrieves table schema with foreign key relationship information.
@@ -349,10 +356,10 @@ async def get_table_schema_with_relations(self, database_name: str, table_name:
         try:
             # 1. Get basic schema information
             basic_schema = await self.get_table_schema(database_name, table_name)
-            
+
             # 2. Retrieve foreign key information
             fk_sql = """
-                SELECT 
+                SELECT
                     kcu.COLUMN_NAME as column_name,
                     kcu.CONSTRAINT_NAME as constraint_name,
                     kcu.REFERENCED_TABLE_NAME as referenced_table,
@@ -363,20 +370,20 @@ async def get_table_schema_with_relations(self, database_name: str, table_name:
                 INNER JOIN information_schema.REFERENTIAL_CONSTRAINTS rc
                     ON kcu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
                     AND kcu.CONSTRAINT_SCHEMA = rc.CONSTRAINT_SCHEMA
-                WHERE kcu.TABLE_SCHEMA = %s 
-                AND kcu.TABLE_NAME = %s 
+                WHERE kcu.TABLE_SCHEMA = %s
+                AND kcu.TABLE_NAME = %s
                 AND kcu.REFERENCED_TABLE_NAME IS NOT NULL
                 ORDER BY kcu.CONSTRAINT_NAME, kcu.ORDINAL_POSITION
             """
-            
+
             fk_results = await self._execute_query(fk_sql, params=(database_name, table_name))
-            
+
             # 3. Add foreign key information to the basic schema
             enhanced_schema = {}
             for col_name, col_info in basic_schema.items():
                 enhanced_schema[col_name] = col_info.copy()
                 enhanced_schema[col_name]['foreign_key'] = None
-            
+
             # 4. Add foreign key information to the corresponding columns
             for fk_row in fk_results:
                 column_name = fk_row['column_name']
@@ -388,16 +395,16 @@ async def get_table_schema_with_relations(self, database_name: str, table_name:
                         'on_update': fk_row['on_update'],
                         'on_delete': fk_row['on_delete']
                     }
-            
+
             # 5. Return the enhanced schema with foreign key relations
             result = {
                 'table_name': table_name,
                 'columns': enhanced_schema
             }
-            
+
             logger.info(f"TOOL END: get_table_schema_with_relations completed. Columns: {len(enhanced_schema)}, Foreign keys: {len(fk_results)}")
             return result
-            
+
         except Exception as e:
             logger.error(f"TOOL ERROR: get_table_schema_with_relations failed for database_name={database_name}, table_name={table_name}: {e}", exc_info=True)
             raise RuntimeError(f"Could not retrieve schema with relations for table '{database_name}.{table_name}': {str(e)}")
@@ -421,7 +428,7 @@ async def execute_sql(self, sql_query: str, database_name: str, parameters: Opti
         except Exception as e:
             logger.error(f"TOOL ERROR: execute_sql failed for database_name={database_name}, sql_query={sql_query[:100]}, parameters={parameters}: {e}", exc_info=True)
             raise
-    
+
     async def create_database(self, database_name: str) -> Dict[str, Any]:
         """
         Creates a new database if it doesn't exist.
@@ -500,14 +507,14 @@ async def create_vector_store_tool(self,
                 raise ValueError(f"Invalid distance_function: '{distance_function}'. Must be one of {list(valid_distance_functions_map.keys())}.")
         else:
             logger.info(f"Distance function not provided, defaulting to '{processed_distance_function_sql}'.")
-        
+
         logger.info(f"Using SQL distance function: '{processed_distance_function_sql}'.")
 
         # --- Database Existence Check ---
         if not await self._database_exists(database_name):
             logger.info(f"Database '{database_name}' does not exist. Attempting to create it.")
             try:
-                await self.create_database(database_name) 
+                await self.create_database(database_name)
             except Exception as db_create_e:
                 logger.error(f"Failed to ensure database '{database_name}' existence: {db_create_e}", exc_info=True)
                 raise RuntimeError(f"Failed to ensure database '{database_name}' exists before creating vector store. Reason: {str(db_create_e)}")
@@ -537,7 +544,7 @@ async def create_vector_store_tool(self,
         try:
             # --- Execute Query ---
             await self._execute_query(schema_query, database=database_name)
-            
+
             success_message = f"Vector store '{vector_store_name}' created successfully in database '{database_name}' with {processed_distance_function_sql} distance."
             logger.info(f"TOOL END: create_vector_store completed. {success_message}")
             return {
@@ -563,7 +570,7 @@ async def list_vector_stores(self, database_name: str) -> List[str]:
         Returns:
         - List[str]: A list of table names that are identified as vector stores.
           Returns an empty list if no such tables are found or if the database doesn't exist.
-        
+
         Raises:
         - ValueError: If the database_name is invalid.
         - RuntimeError: For database errors during the operation.
@@ -593,20 +600,20 @@
                     AND T1.COLUMN_NAME = T2.COLUMN_NAME
             WHERE T1.TABLE_SCHEMA = %s
               AND UPPER(T1.COLUMN_NAME) = 'EMBEDDING'
-              AND UPPER(T1.DATA_TYPE) = 'VECTOR' 
+              AND UPPER(T1.DATA_TYPE) = 'VECTOR'
             ORDER BY T1.TABLE_NAME;
         """
 
         try:
             results = await self._execute_query(sql_query, params=(database_name,), database='information_schema')
-            
+
             store_list = [row['TABLE_NAME'] for row in results if 'TABLE_NAME' in row]
-            
+
             if not store_list:
                 logger.info(f"No vector stores found in database '{database_name}'.")
             else:
                 logger.info(f"Found {len(store_list)} vector store(s) in database '{database_name}': {store_list}")
-            
+
             logger.info(f"TOOL END: list_vector_stores completed for database '{database_name}'.")
             return store_list
 
@@ -614,7 +621,7 @@ async def list_vector_stores(self, database_name: str) -> List[str]:
             error_message = f"Failed to list vector stores in database '{database_name}'."
             logger.error(f"TOOL ERROR: list_vector_stores. {error_message} Error: {e}", exc_info=True)
             raise RuntimeError(f"{error_message} Reason: {str(e)}")
-    
+
     async def delete_vector_store(self,
                                   database_name: str,
                                   vector_store_name: str) -> Dict[str, Any]:
@@ -659,13 +666,13 @@ async def delete_vector_store(self,
                 message = f"Table '{vector_store_name}' in database '{database_name}' is not a valid vector store (missing indexed 'embedding' column of type VECTOR). Deletion aborted."
                 logger.warning(message)
                 return {"status": "not_vector_store", "message": message}
-        
+
         # --- SQL Query for Deletion ---
         drop_query = f"DROP TABLE IF EXISTS `{vector_store_name}`;"
 
         try:
             await self._execute_query(drop_query, database=database_name)
-            
+
             success_message = f"Vector store '{vector_store_name}' deleted successfully from database '{database_name}'."
             logger.info(f"TOOL END: delete_vector_store. {success_message}")
             return {
@@ -683,7 +690,7 @@ async def delete_vector_store(self,
                 "database_name": database_name,
                 "vector_store_name": vector_store_name
             }
-    
+
     async def insert_docs_vector_store(self, database_name: str, vector_store_name: str, documents: List[str], metadata: Optional[List[dict]] = None) -> dict:
         """
         Insert a batch of documents (with optional metadata) into a vector store.
@@ -727,7 +734,7 @@ async def insert_docs_vector_store(self, database_name: str, vector_store_name:
         if errors:
             result["errors"] = errors
         return result
-    
+
     async def search_vector_store(self, user_query: str, database_name: str, vector_store_name: str, k: int = 7) -> list:
         """
         Search a vector store for the most similar documents to a query using semantic search.
@@ -758,7 +765,7 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_
             emb_str = json.dumps(embedding)
             # Prepare the search query
             search_query = f"""
-                SELECT 
+                SELECT
                     document,
                     metadata,
                     VEC_DISTANCE_COSINE(embedding, VEC_FromText(%s)) AS distance
@@ -779,7 +786,7 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_
         except Exception as e:
             logger.error(f"Failed to search vector store {database_name}.{vector_store_name}: {e}", exc_info=True)
             return []
-    
+
     # --- Tool Registration (Synchronous) ---
     def register_tools(self):
         """Registers the class methods as MCP tools using the instance. This is synchronous."""
@@ -791,58 +798,58 @@
         async def list_databases() -> List[str]:
             """Lists all accessible databases on the connected MariaDB server."""
             return await self.list_databases()
-        
+
         @self.mcp.tool
         async def list_tables(database_name: str) -> List[str]:
             """Lists all tables within the specified database."""
             return await self.list_tables(database_name)
-        
+
         @self.mcp.tool
         async def get_table_schema(database_name: str, table_name: str) -> Dict[str, Any]:
             """Retrieves the schema for a specific table in a database."""
             return await self.get_table_schema(database_name, table_name)
-        
+
         @self.mcp.tool
         async def get_table_schema_with_relations(database_name: str, table_name: str) -> Dict[str, Any]:
             """Retrieves table schema with foreign key relationship information."""
             return await self.get_table_schema_with_relations(database_name, table_name)
-        
+
        @self.mcp.tool
         async def execute_sql(sql_query: str, database_name: str, parameters: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
             """Executes a read-only SQL query against a specified database."""
             return await self.execute_sql(sql_query, database_name, parameters)
-        
+
         @self.mcp.tool
         async def create_database(database_name: str) -> Dict[str, Any]:
             """Creates a new database if it doesn't exist."""
             return await self.create_database(database_name)
-        
+
         if EMBEDDING_PROVIDER is not None:
             @self.mcp.tool
             async def create_vector_store(database_name: str, vector_store_name: str, model_name: Optional[str] = None, distance_function: Optional[str] = None) -> dict:
                 """Creates a table which stores embeddings."""
                 return await self.create_vector_store(database_name, vector_store_name, model_name, distance_function)
-            
+
             @self.mcp.tool
             async def list_vector_stores(database_name: str) -> List[str]:
                 """Lists all vector stores in a database."""
                 return await self.list_vector_stores(database_name)
-            
+
             @self.mcp.tool
             async def delete_vector_store(database_name: str, vector_store_name: str) -> Dict[str, Any]:
                 """Deletes a vector store from the specified database."""
                 return await self.delete_vector_store(database_name, vector_store_name)
-            
+
             @self.mcp.tool
             async def insert_docs_vector_store(database_name: str, vector_store_name: str, documents: List[str], metadata: Optional[List[dict]] = None) -> dict:
                 """Insert a batch of documents into a vector store."""
                 return await self.insert_docs_vector_store(database_name, vector_store_name, documents, metadata)
-            
+
             @self.mcp.tool
             async def search_vector_store(user_query: str, database_name: str, vector_store_name: str, k: int = 7) -> list:
                 """Search a vector store for similar documents."""
                 return await self.search_vector_store(user_query, database_name, vector_store_name, k)
-        
+
         logger.info("Registered MCP tools explicitly.")
 
     # --- Async Main Server Logic ---
@@ -868,7 +875,7 @@ async def run_async_server(self, transport="stdio", host="127.0.0.1", port=9001,
                 allow_methods=["GET", "POST"],
                 allow_headers=["*"],
             ),
-            Middleware(TrustedHostMiddleware, 
+            Middleware(TrustedHostMiddleware,
                        allowed_hosts=ALLOWED_HOSTS)
         ]
         if transport == "sse":
@@ -881,7 +888,7 @@ async def run_async_server(self, transport="stdio", host="127.0.0.1", port=9001,
             logger.info(f"Starting MCP server via {transport}...")
         else:
             logger.error(f"Unsupported transport type: {transport}")
-            return 
+            return
 
         # 4. Run the appropriate async listener from FastMCP
         await self.mcp.run_async(transport=transport, **transport_kwargs)
@@ -916,10 +923,10 @@ async def run_async_server(self, transport="stdio", host="127.0.0.1", port=9001,
     try:
         # 2. Use anyio.run to manage the event loop and call the main async server logic
        anyio.run(
-            partial(server.run_async_server, 
-                    transport=args.transport, 
-                    host=args.host, 
-                    port=args.port, 
+            partial(server.run_async_server,
+                    transport=args.transport,
+                    host=args.host,
+                    port=args.port,
                     path=args.path)
         )
         logger.info("Server finished gracefully.")
@@ -930,4 +937,4 @@ async def run_async_server(self, transport="stdio", host="127.0.0.1", port=9001,
         logger.critical(f"Server failed to start or crashed: {e}", exc_info=True)
         exit_code = 1
     finally:
-        logger.info(f"Server exiting with code {exit_code}.")
\ No newline at end of file
+        logger.info(f"Server exiting with code {exit_code}.")
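With this change every record passes through CustomJsonFormatter, so the rotated log file is line-delimited JSON rather than the old plain-text format. A small consumption sketch, assuming the default LOG_FILE path from src/config.py and the field names CustomJsonFormatter adds:

```python
# Sketch: filter ERROR/CRITICAL records out of the structured log.
# Field names come from CustomJsonFormatter in src/logging_config.py;
# "logs/mcp_server.log" is the LOG_FILE default from src/config.py.
import json

with open("logs/mcp_server.log") as fh:
    for raw in fh:
        record = json.loads(raw)
        if record["level"] in ("ERROR", "CRITICAL"):
            # The formatter pins failures down to module/function/line.
            print(f'{record["timestamp"]} {record["logger"]} '
                  f'{record["module"]}.{record["function"]}:{record["line"]} '
                  f'-> {record.get("message", "")}')
```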