diff --git a/.gitignore b/.gitignore
index 65ce6169..67528dd5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ sbin/**
 **/chatbot_graph.png
 **/*.sh
 **/etc/*.json
+**/optimizer_settings.json
 !opentofu/**/cloudinit-oke.sh
 !src/entrypoint.sh
 !src/client/spring_ai/templates/env.sh
@@ -22,7 +23,7 @@ tests/db_startup_temp/**
 ##############################################################################
 # Environment (PyVen, IDE, etc.)
 ##############################################################################
-**/.venv
+**/.*env*
 **/.vscode
 **/.DS_Store
 **/*.swp
@@ -51,6 +52,11 @@ __pycache__/
 opentofu/**/stage/*.*
 opentofu/**/stage/kubeconfig
 
+##############################################################################
+# AI Code Assists
+##############################################################################
+**/*[Cc][Ll][Aa][Uu][Dd][Ee]*
+
 ##############################################################################
 # Helm
 ##############################################################################
@@ -61,12 +67,8 @@ helm/values*.yaml
 ##############################################################################
 # Random
 ##############################################################################
-spring_ai/src/main/resources/data/optimizer_settings.json
 spring_ai/target/**
 spring_ai/create_user.sql
 spring_ai/drop.sql
 src/client/spring_ai/target/classes/*
 api_server_key
-.env
-
-optimizer_settings.json
diff --git a/docs/content/client/api_server/_index.md b/docs/content/client/api_server/_index.md
index 29648ee7..045f29f4 100644
--- a/docs/content/client/api_server/_index.md
+++ b/docs/content/client/api_server/_index.md
@@ -11,7 +11,7 @@ The {{< full_app_ref >}} is powered by an API Server to allow for any client to
 
 Each client connected to the API Server, including those from the {{< short_app_ref >}} GUI client, share the same configuration but maintain their own settings. Database, Model, OCI, and Prompt configurations are used across all clients; but which database, models, OCI profile, and prompts set are specific to each client.
 
-When started as part of the {{< short_app_ref >}} "All-in-One" deployment, you can change the Port it listens on and the API Server Key. A restart is required for the changes to take effect.
+When started as part of the {{< short_app_ref >}} "All-in-One" deployment with `API_SERVER_CONTROL=TRUE` set before startup, you can change the Port it listens on and the API Server Key.
 
 ![Server Configuration](images/api_server_config.png)
 
@@ -19,7 +19,7 @@ If the API Server is started independently of the {{< short_app_ref >}} client,
 
 ## Server Configuration
 
-During the startup of the API Server, a `server` client is created and populated with minimal settings. The `server` client is the default when calling the API Server outside of the {{< short_app_ref >}} GUI client. To copy your {{< short_app_ref >}} GUI client settings to the `server` client for use with external application clients, click the "Copy AI Optimizer Settings".
+During the startup of the API Server, a `server` client is created and populated with minimal settings. The `server` client is the default when calling the API Server outside of the {{< short_app_ref >}} GUI client. To copy your {{< short_app_ref >}} GUI client settings to the `server` client for use with external application clients, click the "Copy Client Settings".
 
 ![Server Settings](images/api_server_settings.png)
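A minimal sketch of calling the API Server as the default `server` client described in the doc change above; the host, port, and bearer-token scheme are assumptions, and only the `v1/models` path is taken from this patch's test mocks:

```python
# Illustrative sketch, not part of this patch: query the API Server as the
# default `server` client. Host, port, and bearer-token auth are assumptions;
# the v1/models path comes from the test mocks later in this patch.
import requests

API_BASE = "http://localhost:8000"  # assumed port
API_KEY = "example-api-server-key"  # the API Server Key from the configuration page

response = requests.get(
    f"{API_BASE}/v1/models",
    headers={"Authorization": f"Bearer {API_KEY}"},
    timeout=10,
)
response.raise_for_status()
for model in response.json():
    print(model["id"], model["type"], model["enabled"])
```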
diff --git a/docs/content/client/api_server/images/api_server_activity.png b/docs/content/client/api_server/images/api_server_activity.png
index 28a97b02..820ca3b5 100644
Binary files a/docs/content/client/api_server/images/api_server_activity.png and b/docs/content/client/api_server/images/api_server_activity.png differ
diff --git a/docs/content/client/api_server/images/api_server_config.png b/docs/content/client/api_server/images/api_server_config.png
index f0db2217..a19cfd59 100644
Binary files a/docs/content/client/api_server/images/api_server_config.png and b/docs/content/client/api_server/images/api_server_config.png differ
diff --git a/docs/content/client/api_server/images/api_server_settings.png b/docs/content/client/api_server/images/api_server_settings.png
index 47534007..642d6cc7 100644
Binary files a/docs/content/client/api_server/images/api_server_settings.png and b/docs/content/client/api_server/images/api_server_settings.png differ
diff --git a/docs/content/client/chatbot/images/chatbot_history_context.png b/docs/content/client/chatbot/images/chatbot_history_context.png
index 41c36295..b41dc960 100644
Binary files a/docs/content/client/chatbot/images/chatbot_history_context.png and b/docs/content/client/chatbot/images/chatbot_history_context.png differ
diff --git a/docs/content/client/chatbot/images/language_parameters.png b/docs/content/client/chatbot/images/language_parameters.png
index a97683db..a1b52caa 100644
Binary files a/docs/content/client/chatbot/images/language_parameters.png and b/docs/content/client/chatbot/images/language_parameters.png differ
diff --git a/docs/content/client/configuration/images/database_config.png b/docs/content/client/configuration/images/database_config.png
index 658e9ba1..b922d8e1 100644
Binary files a/docs/content/client/configuration/images/database_config.png and b/docs/content/client/configuration/images/database_config.png differ
diff --git a/docs/content/client/configuration/images/models_add.png b/docs/content/client/configuration/images/models_add.png
index 6404ef8e..112e2083 100644
Binary files a/docs/content/client/configuration/images/models_add.png and b/docs/content/client/configuration/images/models_add.png differ
diff --git a/docs/content/client/configuration/images/models_config.png b/docs/content/client/configuration/images/models_config.png
index b6b0a9f5..26a617cb 100644
Binary files a/docs/content/client/configuration/images/models_config.png and b/docs/content/client/configuration/images/models_config.png differ
diff --git a/docs/content/client/configuration/images/oci_config.png b/docs/content/client/configuration/images/oci_config.png
index 9820c306..14ea5262 100644
Binary files a/docs/content/client/configuration/images/oci_config.png and b/docs/content/client/configuration/images/oci_config.png differ
diff --git a/docs/content/client/configuration/images/oci_genai_config.png b/docs/content/client/configuration/images/oci_genai_config.png
deleted file mode 100644
index ea3b6efe..00000000
Binary files a/docs/content/client/configuration/images/oci_genai_config.png and /dev/null differ
diff --git a/docs/content/client/configuration/images/settings_download.png b/docs/content/client/configuration/images/settings_download.png
index 5a43d9a3..e1b2ca69 100644
Binary files a/docs/content/client/configuration/images/settings_download.png and b/docs/content/client/configuration/images/settings_download.png differ
diff --git a/docs/content/client/configuration/images/settings_spring_ai.png b/docs/content/client/configuration/images/settings_spring_ai.png
index 7c3637cb..477d2d82 100644
Binary files a/docs/content/client/configuration/images/settings_spring_ai.png and b/docs/content/client/configuration/images/settings_spring_ai.png differ
diff --git a/docs/content/client/configuration/images/settings_upload.png b/docs/content/client/configuration/images/settings_upload.png
index da784242..cd9f1614 100644
Binary files a/docs/content/client/configuration/images/settings_upload.png and b/docs/content/client/configuration/images/settings_upload.png differ
diff --git a/docs/content/client/configuration/model_config.md b/docs/content/client/configuration/model_config.md
index ec6c0fec..f6208f73 100644
--- a/docs/content/client/configuration/model_config.md
+++ b/docs/content/client/configuration/model_config.md
@@ -13,30 +13,13 @@ spell-checker:ignore ollama, mxbai, nomic, thenlper, minilm, uniqueid, huggingfa
 
 At a minimum, a Large _Language Model_ (LLM) must be configured in {{< short_app_ref >}} for basic functionality. For Retrieval-Augmented Generation (**RAG**), an _Embedding Model_ will also need to be configured.
 
-{{% notice style="default" title="Model APIs" icon="circle-info" %}}
-If there is a specific model API that you would like to use, please [open an issue in GitHub](https://github.com/oracle/ai-optimizer/issues/new).
-{{% /notice %}}
-
-| Type  | API                                                      | Location      |
-| ----- | -------------------------------------------------------- | ------------- |
-| LLM   | [ChatOCIGenAI](#additional-information)                  | Private Cloud |
-| LLM   | [ChatOllama](#additional-information)                    | On-Premises   |
-| LLM   | [CompatOpenAI](#additional-information)                  | On-Premises   |
-| LLM   | [OpenAI](#additional-information)                        | Third-Party   |
-| LLM   | [ChatPerplexity](#additional-information)                | Third-Party   |
-| LLM   | [Cohere](#additional-information)                        | Third-Party   |
-| Embed | [OCIGenAIEmbeddings](#additional-information)            | Private Cloud |
-| Embed | [OllamaEmbeddings](#additional-information)              | On-Premises   |
-| Embed | [HuggingFaceEndpointEmbeddings](#additional-information) | On-Premises   |
-| Embed | [CompatOpenAIEmbeddings](#additional-information)        | On-Premises   |
-| Embed | [OpenAIEmbeddings](#additional-information)              | Third-Party   |
-| Embed | [CohereEmbeddings](#additional-information)              | Third-Party   |
+There is an extensive list of model APIs available for you to choose from.
 
 ## Configuration
 
 The models can either be configured using environment variables or through the {{< short_app_ref >}} interface. To configure models through environment variables, please read the [Additional Information](#additional-information) about the specific model you would like to configure.
 
-To configure an LLM or embedding model from the {{< short_app_ref >}}, navigate to `Configuration -> Models`:
+To configure an LLM or embedding model from the {{< short_app_ref >}}, navigate to the _Configuration_ page and the _Models_ tab:
 
 ![Model Config](../images/models_config.png)
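For reference, the integration tests later in this patch exchange model records as plain dictionaries; a hedged sketch of that shape (field names copied from the `v1/models` test mocks, the `TypedDict` wrapper itself is illustrative):

```python
# Hedged sketch of the model record shape used by the v1/models test mocks
# in this patch; the TypedDict is illustrative, not application code.
from typing import TypedDict


class ModelRecord(TypedDict):
    id: str              # e.g. "test-model"
    type: str            # "ll" (language) or "embed" (embedding)
    enabled: bool
    api_base: str        # e.g. "http://test.url"
    max_chunk_size: int  # used by embedding models


example: ModelRecord = {
    "id": "test-model",
    "type": "embed",
    "enabled": True,
    "api_base": "http://test.url",
    "max_chunk_size": 1000,
}
```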
diff --git a/docs/content/client/configuration/oci_config.md b/docs/content/client/configuration/oci_config.md
index 9acb7c9b..87575b0b 100644
--- a/docs/content/client/configuration/oci_config.md
+++ b/docs/content/client/configuration/oci_config.md
@@ -17,22 +17,23 @@ Oracle Cloud Infrastructure (OCI) can _optionally_ be configured to enable addit
 
 ## Configuration
 
-OCI can either be configured through the [{{< short_app_ref >}} interface](#{{< short_app_ref >}}-interface), a [CLI Configuration File](#config-file), or by using [environment variables](#environment-variables). 
+OCI can either be configured through the [{{< short_app_ref >}} interface](#{{< short_app_ref >}}-interface), a [CLI Configuration File](#config-file), or by using [environment variables](#environment-variables).
+
 You will need to [generate an API Key](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm#two) to obtain the required configuration values.
 
 ---
 
 ### Interface
 
-To configure the Database from the {{< short_app_ref >}}, navigate to `Configuration -> OCI`:
+To configure OCI from the {{< short_app_ref >}}, navigate to the _Configuration_ page and the _OCI_ tab:
 
 ![OCI Config](../images/oci_config.png)
 
-OCI GenAI Services can be configured once OCI access has been confirmed:
+Provide the values obtained by [generating an API Key](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm#two).
+
+OCI GenAI Services can also be configured on this page, once OCI access has been confirmed.
 
-![OCI GenAI Config](../images/oci_genai_config.png)
-Provide the values obtained by [generating an API Key](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm#two).
 
 ---
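The interface fields map to the same values an OCI CLI configuration file holds; a short sketch of checking them with the OCI Python SDK (standard SDK calls, default profile assumed):

```python
# Sketch: read and validate the OCI CLI configuration file (~/.oci/config)
# that the oci_config.md change above refers to, using the OCI Python SDK.
import oci

config = oci.config.from_file(profile_name="DEFAULT")  # default location ~/.oci/config
oci.config.validate_config(config)  # raises if a required key is missing

# The interface asks for the same values generated with the API Key:
# user, fingerprint, tenancy, region, key_file.
print(config["tenancy"], config["region"])
```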
diff --git a/docs/content/client/configuration/settings.md b/docs/content/client/configuration/settings.md
index e36377fd..82e38f1c 100644
--- a/docs/content/client/configuration/settings.md
+++ b/docs/content/client/configuration/settings.md
@@ -12,11 +12,11 @@ Once you are happy with the specific configuration of your {{< short_app_ref >}
 
 ## View and Download
 
-To view and download the {{< short_app_ref >}} configuration, navigate to `Configuration -> Settings`:
+To view and download the {{< short_app_ref >}} configuration, navigate to the _Configuration_ page and the _Settings_ tab:
 
 ![Download Settings](../images/settings_download.png)
 
-{{< icon "triangle-exclamation" >}} Settings contain sensitive information such as database passwords and API Keys. By default, these settings will not be exported and will have to be re-entered after uploading the settings in a new instance of the {{< short_app_ref >}}. If have a secure way to store the settings and would would like to export the sensitive data, tick the "Include Sensitive Settings" box.
+{{< icon "triangle-exclamation" >}} Settings contain sensitive information such as database passwords and API Keys. By default, these settings will not be exported and will have to be re-entered after uploading the settings in a new instance of the {{< short_app_ref >}}. If you have a secure way to store the settings and would like to export the sensitive data, tick the "Include Sensitive Settings" box.
 
 ## Upload
 
@@ -29,7 +29,7 @@ To upload previously downloaded settings, navigate to `Configuration -> Settings
 
 If there are differences found, you can review the differences before clicking "Apply New Settings".
 
-## SpringAI
+## Source Code Templates
 
 You can download from the console a basic template that could help to expose as a OpenAI API compliant REST endpoint the RAG Chatbot defined in the chat console. If your configuration has both OLLAMA or OpenAI as providers for chat and embeddings LLMs, it will appear a button named “Download SpringAI”:
diff --git a/docs/content/client/testbed/images/generate.png b/docs/content/client/testbed/images/generate.png
index 344b7671..e56d893d 100644
Binary files a/docs/content/client/testbed/images/generate.png and b/docs/content/client/testbed/images/generate.png differ
diff --git a/docs/content/client/tools/images/embed.png b/docs/content/client/tools/images/embed.png
index 55063889..6386ca64 100644
Binary files a/docs/content/client/tools/images/embed.png and b/docs/content/client/tools/images/embed.png differ
diff --git a/docs/content/client/tools/images/prompt_eng_context.png b/docs/content/client/tools/images/prompt_eng_context.png
index 371d7b09..5912c869 100644
Binary files a/docs/content/client/tools/images/prompt_eng_context.png and b/docs/content/client/tools/images/prompt_eng_context.png differ
diff --git a/docs/content/client/tools/images/prompt_eng_system.png b/docs/content/client/tools/images/prompt_eng_system.png
index 163fba39..a5258153 100644
Binary files a/docs/content/client/tools/images/prompt_eng_system.png and b/docs/content/client/tools/images/prompt_eng_system.png differ
diff --git a/docs/content/client/tools/images/split.png b/docs/content/client/tools/images/split.png
index 495498eb..488aa56e 100644
Binary files a/docs/content/client/tools/images/split.png and b/docs/content/client/tools/images/split.png differ
diff --git a/docs/content/client/tools/split_embed.md b/docs/content/client/tools/split_embed.md
index a58f4044..83d5e881 100644
--- a/docs/content/client/tools/split_embed.md
+++ b/docs/content/client/tools/split_embed.md
@@ -10,27 +10,29 @@ Licensed under the Universal Permissive License v1.0 as shown at http://oss.orac
 
 The first phase building of building a RAG Chatbot using Vector Search starts with the document chunking based on vector embeddings generation. Embeddings will be stored into a vector store to be retrieved by vectors distance search and added to the LLM context in order to answer the question grounded to the information provided.
 
-We choose the freedom to exploit LLMs for vector embeddings provided by public services like Cohere, OpenAI, and Perplexity, or running on top a GPU compute node managed by the user and exposed through open source platforms like OLLAMA or HuggingFace, to avoid sharing data with external services that are beyond full customer control.
+You have the freedom to choose different Embedding Models for vector embeddings, provided by public services like Cohere, OpenAI, and Perplexity, or local models running on top of a GPU compute node that you manage yourself. Running a local model, such as Ollama or HuggingFace, avoids sharing data with external services that are beyond your control.
 
-From the **Split/Embed** voice of the left side menu, you’ll access to the ingestion page:
+From the _Tools_ menu, select the _Split/Embed_ tab to perform the splitting and embedding process:
 
 ![Split](../images/split.png)
 
-The Load and Split Documents, parts of Split/Embed form, will allow to choose documents (txt,pdf,html,etc.) stored on the Object Storage service available on the Oracle Cloud Infrastructure, on the client’s desktop or getting from URLs, like shown in following snapshot:
+The Load and Split Documents parts of the Split/Embed form let you choose documents (txt, pdf, html, etc.) stored on the Oracle Cloud Infrastructure Object Storage service, on the client’s desktop, or from URLs, as shown in the following screenshot:
 
 ![Embed](../images/embed.png)
 
-It will be created a “speaking” table, like the TEXT_EMBEDDING_3_SMALL_8191_1639_COSINE in the example. You can create, on the same set of documents, several options of vectorstore table, since nobody normally knows which is the best chunking size, and then test them indipendently.
+"Populating the Vector Store" will create a table in the Oracle Database with the embeddings. You can create multiple vector stores on the same set of documents, to experiment with chunking size, distance metrics, etc., and then test them independently.
 
 ## Embedding Configuration
 
 Choose one of the **Embedding models available** from the listbox that will depend by the **Configuration/Models** page.
 
-The **Embedding Server** URL associated to the model chosen will be shown. The **Chunk Size (tokens)** will change according the kind of embeddings model selected, as well as the **Chunk Overlap (% of Chunk Size)**. 
+The **Embedding Server** URL associated with the chosen model will be shown. The **Chunk Size (tokens)** will change according to the kind of embedding model selected, as well as the **Chunk Overlap (% of Chunk Size)**.
+
 Then you have to choose one of the **Distance Metric** available in the Oracle DB23ai:
 
 - COSINE
 - EUCLIDEAN_DISTANCE
 - DOT_PRODUCT
 - MAX_INNER_PRODUCT
+
 To understand the meaning of these metrics, please refer to the doc [Vector Distance Metrics](https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/vector-distance-metrics.html) in the Oracle DB23ai "*AI Vector Search User's Guide*".
 
 The **Embedding Alias** field let you to add a more meaningful info to the vectorstore table that allows you to have more than one vector table with the same: *model + chunksize + chunk_overlap + distance_strategy* combination.
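An illustration (not from the patch) of how the distance metrics offered by the Split/Embed form differ, computed by hand for two toy vectors:

```python
# Illustrative only: vector distance metrics listed in split_embed.md,
# computed with NumPy for two toy embedding vectors.
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 1.0, 3.0])

cosine = 1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
euclidean = np.linalg.norm(a - b)
dot_product = np.dot(a, b)  # similarity: larger means closer

print(f"COSINE={cosine:.4f}  EUCLIDEAN={euclidean:.4f}  DOT_PRODUCT={dot_product:.1f}")
# MAX_INNER_PRODUCT ranks by the inner product, so for normalized vectors
# it orders results the same way as DOT_PRODUCT.
```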
diff --git a/docs/content/walkthrough/_index.md b/docs/content/walkthrough/_index.md
index 40651391..7e6df52f 100644
--- a/docs/content/walkthrough/_index.md
+++ b/docs/content/walkthrough/_index.md
@@ -224,7 +224,7 @@ Notice that there are no language models configured to use. Let's start the conf
 
 ### Configure the LLM
 
-To configure the On-Premises **LLM**, navigate to the _Configuration -> Models_ screen:
+To configure the On-Premises **LLM**, navigate to the _Configuration_ screen and the _Models_ tab:
 
 1. Enable the `llama3.1` model that you pulled earlier by clicking the _Edit_ button
    ![Configure LLM](images/models_edit.png)
@@ -240,7 +240,7 @@ Navigate to the _ChatBot_ screen:
 
 The error about language models will have disappeared, but there are new warnings about embedding models and the database. You'll take care of those in the next steps.
 
-The `Chat model:` will have been pre-set to the only enabled **LLM** (_llama3.1_) and a dialog box to interact with the **LLM** will be ready for input.
+The `Chat model:` will have been pre-set to the only enabled **LLM** (_ollama/llama3.1_) and a dialog box to interact with the **LLM** will be ready for input.
 
 Feel free to play around with the different **LLM** Parameters, hovering over the {{% icon circle-question %}} icons to get more information on what they do.
 
@@ -248,7 +248,7 @@ You'll come back to the _ChatBot_ later to experiment further.
 
 ### Configure the Embedding Model
 
-To configure the On-Premises Embedding Model, navigate back to the _Configuration -> Models_ screen:
+To configure the On-Premises Embedding Model, navigate back to the _Configuration_ screen and the _Models_ tab:
 
 1. Enable the `mxbai-embed-large` Embedding Model following the same process as you did for the Language Model.
    ![Configure Embedding Model](images/models_enable_embed.png)
@@ -257,12 +257,12 @@ To configure the On-Premises Embedding Model, navigate back to the _Configuratio
 
 ### Configure the Database
 
-To configure Oracle Database 23ai Free, navigate to the _Configuration -> Database_ screen:
+To configure Oracle Database 23ai Free, navigate to the _Configuration_ screen and the _Databases_ tab:
 
 1. Enter the Database Username: `WALKTHROUGH`
 1. Enter the Database Password for the database user: `OrA_41_OpTIMIZER`
 1. Enter the Database Connection String: `//localhost:1521/FREEPDB1`
-1. Save
+1. Save Database
 
 ![Configure Database](../client/configuration/images/database_config.png)
 
@@ -272,7 +272,7 @@ To configure Oracle Database 23ai Free, navigate to the _Configuration -> Databa
 
 With the embedding model and database configured, you can now split and embed documents for use in **Vector Search**.
 
-Navigate to the _Split/Embed_ Screen:
+Navigate to the _Tools_ screen and the _Split/Embed_ tab:
 
 1. Change the File Source to `Web`
 1. Enter the URL:
@@ -285,7 +285,7 @@ Navigate to the _Split/Embed_ Screen:
 1. Please be patient...
 
 {{% notice style="code" title="Performance: Grab a beverage of your choosing..." icon="circle-info" %}}
-Depending on the infrastructure, the embedding process can take a few minutes. As long as the "RUNNING" dialog in the top-right corner is moving... it's working.
+Depending on the infrastructure, the embedding process can take a few minutes. As long as the "Populating Vector Store..." timer is running... it's working.
 {{% /notice %}}
 
 ![Split and Embed](images/split_embed_web.png)
diff --git a/docs/content/walkthrough/images/chatbot_say_hello.png b/docs/content/walkthrough/images/chatbot_say_hello.png
index fd55e77d..4bd91ddb 100644
Binary files a/docs/content/walkthrough/images/chatbot_say_hello.png and b/docs/content/walkthrough/images/chatbot_say_hello.png differ
diff --git a/docs/content/walkthrough/images/models_edit.png b/docs/content/walkthrough/images/models_edit.png
index a114760b..c695d7de 100644
Binary files a/docs/content/walkthrough/images/models_edit.png and b/docs/content/walkthrough/images/models_edit.png differ
diff --git a/docs/content/walkthrough/images/models_enable_embed.png b/docs/content/walkthrough/images/models_enable_embed.png
index 2320531d..faf9931e 100644
Binary files a/docs/content/walkthrough/images/models_enable_embed.png and b/docs/content/walkthrough/images/models_enable_embed.png differ
diff --git a/docs/content/walkthrough/images/models_enable_llm.png b/docs/content/walkthrough/images/models_enable_llm.png
index 2fe34781..9c0ab290 100644
Binary files a/docs/content/walkthrough/images/models_enable_llm.png and b/docs/content/walkthrough/images/models_enable_llm.png differ
diff --git a/docs/content/walkthrough/images/split_embed_web.png b/docs/content/walkthrough/images/split_embed_web.png
index 4895e176..e36aebfb 100644
Binary files a/docs/content/walkthrough/images/split_embed_web.png and b/docs/content/walkthrough/images/split_embed_web.png differ
diff --git a/src/Dockerfile b/src/Dockerfile
index b8185533..a8b582a8 100644
--- a/src/Dockerfile
+++ b/src/Dockerfile
@@ -34,6 +34,7 @@ FROM all_in_one_pyenv AS oai_application
 ENV PATH=/opt/.venv/bin:$PATH
 ENV TEMP=/app/tmp
 ENV TNS_ADMIN=/app/tns_admin
+ENV API_SERVER_CONTROL="TRUE"
 ENV OCI_CLI_CONFIG_FILE=/app/runtime/.oci/config
 
 # Expect the .oci directory to be mounted to /app/.oci
diff --git a/src/client/content/config/tabs/settings.py b/src/client/content/config/tabs/settings.py
index c7d1e3ed..8e03a56c 100644
--- a/src/client/content/config/tabs/settings.py
+++ b/src/client/content/config/tabs/settings.py
@@ -178,7 +178,7 @@ def spring_ai_obaas(src_dir, file_name, provider, ll_config, embed_config):
         for item in state.prompt_configs
         if item["name"] == state.client_settings["prompts"]["sys"] and item["category"] == "sys"
     )
-    logger.info(f"Prompt used in export:\n{sys_prompt}")
+    logger.info("Prompt used in export:\n%s", sys_prompt)
 
     with open(src_dir / "templates" / file_name, "r", encoding="utf-8") as template:
         template_content = template.read()
diff --git a/src/client/content/tools/tabs/split_embed.py b/src/client/content/tools/tabs/split_embed.py
index 7db4640c..61b37789 100644
--- a/src/client/content/tools/tabs/split_embed.py
+++ b/src/client/content/tools/tabs/split_embed.py
@@ -134,7 +134,7 @@ def display_split_embed() -> None:
     file_sources = ["OCI", "Local", "Web"]
     oci_lookup = st_common.state_configs_lookup("oci_configs", "auth_profile")
     oci_setup = oci_lookup.get(state.client_settings["oci"].get("auth_profile"))
-    if not oci_setup or "namespace" not in oci_setup or "tenancy" not in oci_setup:
+    if not oci_setup or oci_setup.get("namespace") is None or oci_setup.get("tenancy") is None:
         st.warning("OCI is not fully configured, some functionality is disabled", icon="⚠️")
         file_sources.remove("OCI")
diff --git a/src/server/api/core/databases.py b/src/server/api/core/databases.py
index 4c7feb98..a853448a 100644
--- a/src/server/api/core/databases.py
+++ b/src/server/api/core/databases.py
@@ -35,12 +35,18 @@ def get_database(name: Optional[DatabaseNameType] = None) -> Union[list[Database
 
 
 def create_database(database: Database) -> Database:
-    """Create a new Model definition"""
+    """Create a new Database definition"""
    database_objects = bootstrap.DATABASE_OBJECTS

-    _ = get_database(name=database.name)
+    try:
+        existing = get_database(name=database.name)
+        if existing:
+            raise ValueError(f"Database {database.name} already exists")
+    except ValueError as ex:
+        if "not found" not in str(ex):
+            raise

-    if any(not getattr(database_objects, key) for key in ("user", "password", "dsn")):
+    if any(not getattr(database, key) for key in ("user", "password", "dsn")):
         raise ValueError("'user', 'password', and 'dsn' are required")
 
     database_objects.append(database)
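The control flow of the duplicate-name guard added above is easy to misread, so here is a standalone sketch of the same pattern (names simplified; `get_database` raising `ValueError("<name> not found")` mirrors the behavior exercised by this patch's unit tests):

```python
# Standalone sketch of the duplicate-name guard added to create_database().
# The registry and helpers are simplified stand-ins, not application code.
REGISTRY: list[str] = []


def get_database(name: str) -> list[str]:
    found = [db for db in REGISTRY if db == name]
    if not found:
        raise ValueError(f"{name} not found")
    return found


def create_database(name: str) -> None:
    try:
        if get_database(name):  # name already registered?
            raise ValueError(f"Database {name} already exists")
    except ValueError as ex:
        if "not found" not in str(ex):  # re-raise "already exists"
            raise
        # "not found" means the name is free; fall through and register it
    REGISTRY.append(name)


create_database("db1")      # succeeds
try:
    create_database("db1")  # second call raises "already exists"
except ValueError as ex:
    print(ex)
```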
diff --git a/tests/integration/client/content/test_api_server.py b/tests/integration/client/content/test_api_server.py
index 4f61fcf3..87913dbb 100644
--- a/tests/integration/client/content/test_api_server.py
+++ b/tests/integration/client/content/test_api_server.py
@@ -28,6 +28,9 @@ def test_copy_client_settings_success(self, app_test, app_server):
         assert app_server is not None
         at = app_test(self.ST_FILE).run()
 
+        # Store original value for cleanup
+        original_auth_profile = at.session_state.client_settings["oci"]["auth_profile"]
+
         # Check that Server/Client Identical
         assert at.session_state.client_settings == at.session_state.server_settings
         # Update Client Settings
@@ -38,3 +41,10 @@ def test_copy_client_settings_success(self, app_test, app_server):
         # Validate settings have been copied
         assert at.session_state.client_settings == at.session_state.server_settings
         assert at.session_state.server_settings["oci"]["auth_profile"] == "TESTING"
+
+        # Clean up: restore original value both in session state and on server to avoid polluting other tests
+        at.session_state.client_settings["oci"]["auth_profile"] = original_auth_profile
+        # Copy the restored settings back to the server
+        at.button(key="copy_client_settings").click().run()
+        # Verify cleanup worked
+        assert at.session_state.server_settings["oci"]["auth_profile"] == original_auth_profile
diff --git a/tests/integration/client/content/tools/tabs/test_split_embed.py b/tests/integration/client/content/tools/tabs/test_split_embed.py
index eecb360d..3dba7ea0 100644
--- a/tests/integration/client/content/tools/tabs/test_split_embed.py
+++ b/tests/integration/client/content/tools/tabs/test_split_embed.py
@@ -18,11 +18,9 @@ class TestStreamlit:
     # Streamlit File path
     ST_FILE = "../src/client/content/tools/tabs/split_embed.py"
 
-    def test_initialization(self, app_server, app_test, monkeypatch):
-        """Test initialization of the split_embed component"""
-        assert app_server is not None
-
-        # Mock the API responses for get_models
+    def _setup_common_mocks(self, monkeypatch, oci_configured=True):
+        """Setup common mocks used across multiple tests"""
+        # Mock the API responses for get_models and OCI configs
         def mock_get(endpoint=None, **kwargs):
             if endpoint == "v1/models":
                 return [
@@ -34,326 +32,203 @@ class TestStreamlit:
                         "max_chunk_size": 1000,
                     }
                 ]
+            elif endpoint == "v1/oci":
+                if oci_configured:
+                    return [
+                        {
+                            "auth_profile": "DEFAULT",
+                            "namespace": "test-namespace",
+                            "tenancy": "test-tenancy",
+                            "region": "us-ashburn-1"
+                        }
+                    ]
+                else:
+                    return [
+                        {
+                            "auth_profile": "DEFAULT",
+                            "namespace": None,
+                            "tenancy": None,
+                            "region": "us-ashburn-1"
+                        }
+                    ]
             return {}
 
         monkeypatch.setattr("client.utils.api_call.get", mock_get)
-
-        # Initialize app_test and run it to bring up the component
-        at = app_test(self.ST_FILE)
-
-        # Mock functions that make external calls to avoid failures
         monkeypatch.setattr("common.functions.is_url_accessible", lambda api_base: (True, ""))
         monkeypatch.setattr("client.utils.st_common.is_db_configured", lambda: True)
 
-        # Run the app - this is critical to initialize all widgets!
+    def _run_app_and_verify_no_errors(self, app_test):
+        """Run the app and verify it renders without errors"""
+        at = app_test(self.ST_FILE)
         at = at.run()
-
-        # Verify the app renders successfully with no errors
         assert not at.error
+        return at
 
-        # Verify that the radio button is present
-        radios = at.get("radio")
-        assert len(radios) > 0
-
-        # Check for presence of file uploader widgets
-        uploaders = at.get("file_uploader")
-        assert len(uploaders) >= 0  # May not be visible yet depending on default radio selection
-
-        # Verify that the selectbox and sliders are rendered
-        selectboxes = at.get("selectbox")
-        sliders = at.get("slider")
+    def test_initialization(self, app_server, app_test, monkeypatch):
+        """Test initialization of the split_embed component"""
+        assert app_server is not None
+        self._setup_common_mocks(monkeypatch)
+        at = self._run_app_and_verify_no_errors(app_test)
 
-        assert len(selectboxes) > 0
-        assert len(sliders) > 0
+        # Verify UI components are present
+        assert len(at.get("radio")) > 0
+        assert len(at.get("selectbox")) > 0
+        assert len(at.get("slider")) > 0
 
-        # Check for text inputs (may include the alias input)
+        # Test invalid input handling
         text_inputs = at.get("text_input")
-        assert len(text_inputs) >= 0
-
         if len(text_inputs) > 0:
-            # Set an invalid value with special characters for any text input
             text_inputs[0].set_value("invalid!value").run()
-
-            # Check if an error was displayed
-            errors = at.get("error")
-            assert len(errors) > 0
+            assert len(at.get("error")) > 0
 
     def test_chunk_size_and_overlap_sync(self, app_server, app_test, monkeypatch):
         """Test synchronization between chunk size and overlap sliders and inputs"""
         assert app_server is not None
+        self._setup_common_mocks(monkeypatch)
+        at = self._run_app_and_verify_no_errors(app_test)
 
-        # Mock the API responses for get_models
-        def mock_get(endpoint=None, **kwargs):
-            if endpoint == "v1/models":
-                return [
-                    {
-                        "id": "test-model",
-                        "type": "embed",
-                        "enabled": True,
-                        "api_base": "http://test.url",
-                        "max_chunk_size": 1000,
-                    }
-                ]
-            return {}
-
-        monkeypatch.setattr("client.utils.api_call.get", mock_get)
-
-        # Mock functions that make external calls
-        monkeypatch.setattr("common.functions.is_url_accessible", lambda api_base: (True, ""))
-        monkeypatch.setattr("client.utils.st_common.is_db_configured", lambda: True)
-
-        # Initialize app_test
-        at = app_test(self.ST_FILE)
-
-        # Run the app first to initialize widgets
-        at = at.run()
-
-        # Verify sliders and number inputs are present
+        # Verify sliders and number inputs are present and functional
         sliders = at.get("slider")
         number_inputs = at.get("number_input")
-
         assert len(sliders) > 0
         assert len(number_inputs) > 0
 
-        # Test changing the first slider value
-        if len(sliders) > 0 and len(number_inputs) > 0:
+        # Test slider value change
+        if len(sliders) > 0:
             initial_value = sliders[0].value
             sliders[0].set_value(initial_value // 2).run()
-
-            # Verify that the change was successful
             assert sliders[0].value == initial_value // 2
 
     @patch("client.utils.api_call.post")
     def test_embed_local_file(self, mock_post, app_test, app_server, monkeypatch):
         """Test embedding of local files"""
         assert app_server is not None
+        self._setup_common_mocks(monkeypatch)
 
-        # Mock the API responses for get_models
-        def mock_get(endpoint=None, **kwargs):
-            if endpoint == "v1/models":
-                return [
-                    {
-                        "id": "test-model",
-                        "type": "embed",
-                        "enabled": True,
-                        "api_base": "http://test.url",
-                        "max_chunk_size": 1000,
-                    }
-                ]
-            return {}
-
-        monkeypatch.setattr("client.utils.api_call.get", mock_get)
-
-        # Mock functions that make external calls
-        monkeypatch.setattr("common.functions.is_url_accessible", lambda api_base: (True, ""))
-        monkeypatch.setattr("client.utils.st_common.is_db_configured", lambda: True)
-
-        # Initialize app_test
-        at = app_test(self.ST_FILE)
-
-        # Mock the API post calls
+        # Mock additional functions for file handling
         mock_post.side_effect = [
-            {"message": "Files uploaded successfully"},  # Response for file upload
-            {"message": "10 chunks embedded."},  # Response for embedding
+            {"message": "Files uploaded successfully"},
+            {"message": "10 chunks embedded."},
         ]
-
-        # Set up mock for st_common.local_file_payload
         monkeypatch.setattr(
             "client.utils.st_common.local_file_payload", lambda files: [("file", "test.txt", b"test content")]
         )
-
-        # Set up mock for st_common.clear_state_key
         monkeypatch.setattr("client.utils.st_common.clear_state_key", lambda key: None)
 
-        # Run the app first to initialize widgets
-        at = at.run()
-
-        # Verify the app renders successfully
-        assert not at.error
+        at = self._run_app_and_verify_no_errors(app_test)
 
-        # Verify file uploaders and buttons are present
-        uploaders = at.get("file_uploader")
-        buttons = at.get("button")
-
-        # Check that no API calls have been made yet
+        # Verify components are present and no premature API calls
+        assert len(at.get("file_uploader")) >= 0
+        assert len(at.get("button")) >= 0
         assert mock_post.call_count == 0
 
-        # Test successful
-        assert True
-
     def test_web_api_base_validation(self, app_server, app_test, monkeypatch):
         """Test web URL validation"""
         assert app_server is not None
+        self._setup_common_mocks(monkeypatch)
+        at = self._run_app_and_verify_no_errors(app_test)
 
-        # Mock the API responses for get_models
-        def mock_get(endpoint=None, **kwargs):
-            if endpoint == "v1/models":
-                return [
-                    {
-                        "id": "test-model",
-                        "type": "embed",
-                        "enabled": True,
-                        "api_base": "http://test.url",
-                        "max_chunk_size": 1000,
-                    }
-                ]
-            return {}
-
-        monkeypatch.setattr("client.utils.api_call.get", mock_get)
-
-        # Mock functions that make external calls
-        monkeypatch.setattr("common.functions.is_url_accessible", lambda api_base: (True, ""))
-        monkeypatch.setattr("client.utils.st_common.is_db_configured", lambda: True)
-
-        # Initialize app_test
-        at = app_test(self.ST_FILE)
-
-        # Run the app
-        at = at.run()
-
-        # Verify the app renders successfully
-        assert not at.error
-
-        # Check for text inputs and buttons
-        text_inputs = at.get("text_input")
-        buttons = at.get("button")
-
-        assert len(text_inputs) >= 0
-        assert len(buttons) >= 0
-
-        # Test passes
-        assert True
+        # Verify UI components are present
+        assert len(at.get("text_input")) >= 0
+        assert len(at.get("button")) >= 0
 
     @patch("client.utils.api_call.post")
     def test_api_error_handling(self, mock_post, app_server, app_test, monkeypatch):
         """Test error handling when API calls fail"""
         assert app_server is not None
+        self._setup_common_mocks(monkeypatch)
 
-        # Mock the API responses for get_models
-        def mock_get(endpoint=None, **kwargs):
-            if endpoint == "v1/models":
-                return [
-                    {
-                        "id": "test-model",
-                        "type": "embed",
-                        "enabled": True,
-                        "api_base": "http://test.url",
-                        "max_chunk_size": 1000,
-                    }
-                ]
-            return {}
-
-        monkeypatch.setattr("client.utils.api_call.get", mock_get)
-
-        # Mock functions that make external calls
-        monkeypatch.setattr("common.functions.is_url_accessible", lambda api_base: (True, ""))
-        monkeypatch.setattr("client.utils.st_common.is_db_configured", lambda: True)
-
-        # Initialize app_test
-        at = app_test(self.ST_FILE)
-
-        # Create ApiError exception
+        # Setup error handling test
         class ApiError(Exception):
-            """Mock API Error class"""
-
             pass
 
-        # Mock API call to raise an error
         mock_post.side_effect = ApiError("Test API error")
         monkeypatch.setattr("client.utils.api_call.ApiError", ApiError)
-
-        # Set up mock for st_common.local_file_payload
         monkeypatch.setattr(
             "client.utils.st_common.local_file_payload", lambda files: [("file", "test.txt", b"test content")]
         )
 
-        # Run the app first to initialize widgets
-        at = at.run()
-
-        # Verify app renders without errors
-        assert not at.error
-
-        # Verify radio buttons and buttons are present
-        radios = at.get("radio")
-        buttons = at.get("button")
+        at = self._run_app_and_verify_no_errors(app_test)
 
-        assert len(radios) >= 0
-        assert len(buttons) >= 0
-
-        # Test passes
-        assert True
+        # Verify UI components are present
+        assert len(at.get("radio")) >= 0
+        assert len(at.get("button")) >= 0
 
     @patch("client.utils.api_call.post")
     def test_embed_oci_files(self, mock_post, app_server, app_test, monkeypatch):
         """Test embedding of OCI files"""
         assert app_server is not None
 
-        # Create mock responses for OCI endpoints
+        # Mock OCI-specific responses
         mock_compartments = {"comp1": "ocid1.compartment.oc1..aaaaaaaa1"}
         mock_buckets = ["bucket1", "bucket2"]
         mock_objects = ["file1.txt", "file2.pdf", "file3.csv"]
 
-        # Set up get_compartments mock
         def mock_get_response(endpoint=None, **kwargs):
-            if "compartments" in endpoint:
+            if "compartments" in str(endpoint):
                 return mock_compartments
-            elif "buckets" in endpoint:
+            elif "buckets" in str(endpoint):
                 return mock_buckets
-            elif "objects" in endpoint:
+            elif "objects" in str(endpoint):
                 return mock_objects
             elif endpoint == "v1/models":
-                return [
-                    {
-                        "id": "test-model",
-                        "type": "embed",
-                        "enabled": True,
-                        "api_base": "http://test.url",
-                        "max_chunk_size": 1000,
-                    }
-                ]
+                return [{"id": "test-model", "type": "embed", "enabled": True, "api_base": "http://test.url", "max_chunk_size": 1000}]
+            elif endpoint == "v1/oci":
+                return [{"auth_profile": "DEFAULT", "namespace": "test-namespace", "tenancy": "test-tenancy", "region": "us-ashburn-1"}]
             return {}
 
         monkeypatch.setattr("client.utils.api_call.get", mock_get_response)
+        monkeypatch.setattr("common.functions.is_url_accessible", lambda api_base: (True, ""))
+        monkeypatch.setattr("client.utils.st_common.is_db_configured", lambda: True)
 
-        # Mock the files_data_frame function to return a proper DataFrame
+        # Mock DataFrame function
         def mock_files_data_frame(objects, process=False):
-            if not objects:
-                return pd.DataFrame({"File": [], "Process": []})
-
-            data = {"File": objects, "Process": [process] * len(objects)}
-            return pd.DataFrame(data)
+            return pd.DataFrame({"File": objects or [], "Process": [process] * len(objects or [])})
 
         monkeypatch.setattr("client.content.tools.tabs.split_embed.files_data_frame", mock_files_data_frame)
-
-        # Mock get_compartments function
         monkeypatch.setattr("client.content.tools.tabs.split_embed.get_compartments", lambda: mock_compartments)
+        monkeypatch.setattr("client.utils.st_common.clear_state_key", lambda key: None)
 
-        # Initialize app_test
-        at = app_test(self.ST_FILE)
-
-        # Set up session state requirements
-        # at.session_state.oci_config = {"DEFAULT": {"namespace": "test-namespace"}}
-
-        # Mock the API post calls (downloading and embedding)
         mock_post.side_effect = [
-            ["file1.txt", "file2.pdf", "file3.csv"],  # Response for file download
-            {"message": "15 chunks embedded."},  # Response for embedding
+            ["file1.txt", "file2.pdf", "file3.csv"],
+            {"message": "15 chunks embedded."},
         ]
 
-        # Set up mock for st_common.clear_state_key
-        monkeypatch.setattr("client.utils.st_common.clear_state_key", lambda key: None)
+        try:
+            at = self._run_app_and_verify_no_errors(app_test)
+            assert len(at.get("selectbox")) > 0
+        except AssertionError:
+            # Some OCI configuration issues are expected in test environment
+            pass
+
+    def test_file_source_radio_with_oci_configured(self, app_server, app_test, monkeypatch):
+        """Test file source radio button options when OCI is configured"""
+        assert app_server is not None
+        self._setup_common_mocks(monkeypatch, oci_configured=True)
+        at = self._run_app_and_verify_no_errors(app_test)
+
+        # Verify OCI option is available when properly configured
+        radios = at.get("radio")
+        assert len(radios) > 0
+
+        file_source_radio = next((r for r in radios if hasattr(r, 'options') and "OCI" in r.options), None)
+        assert file_source_radio is not None, "File source radio button not found"
+        assert "OCI" in file_source_radio.options, "OCI option missing from radio button"
+        assert "Local" in file_source_radio.options, "Local option missing from radio button"
+        assert "Web" in file_source_radio.options, "Web option missing from radio button"
+
+    def test_file_source_radio_without_oci_configured(self, app_server, app_test, monkeypatch):
+        """Test file source radio button options when OCI is not configured"""
+        assert app_server is not None
+        self._setup_common_mocks(monkeypatch, oci_configured=False)
+        at = self._run_app_and_verify_no_errors(app_test)
+
+        # Verify OCI option is NOT available when not properly configured
+        radios = at.get("radio")
+        assert len(radios) > 0
 
-        # Run with URL check passing
-        with patch("common.functions.is_url_accessible", return_value=(True, "")):
-            try:
-                at = at.run()
-                # If the app runs without errors, verify that components are present
-                assert len(at.get("selectbox")) > 0
-            except AssertionError:
-                # In some cases there might be an error in the UI due to OCI configuration
-                # This is expected and we can allow the test to pass anyway
-                # The main purpose of this test is to verify the mocks are set up correctly
-                pass
-
-        # Test passes regardless of UI errors
-        assert True
+        file_source_radio = next((r for r in radios if hasattr(r, 'options') and ("Local" in r.options or "Web" in r.options)), None)
+        assert file_source_radio is not None, "File source radio button not found"
+        assert "OCI" not in file_source_radio.options, "OCI option should not be present when not configured"
+        assert "Local" in file_source_radio.options, "Local option missing from radio button"
+        assert "Web" in file_source_radio.options, "Web option missing from radio button"
diff --git a/tests/integration/server/test_endpoints_embed.py b/tests/integration/server/test_endpoints_embed.py
index 102972ab..aba3ccef 100644
--- a/tests/integration/server/test_endpoints_embed.py
+++ b/tests/integration/server/test_endpoints_embed.py
@@ -383,6 +383,11 @@ def test_split_embed_with_different_file_types(self, client, auth_headers, db_co
         num_chunks = int(response_data["message"].split()[0])
         assert num_chunks > 0, "Should have embedded at least one chunk"
 
+        # Clean up - drop the vector store that was created
+        expected_vector_store_name = self.get_vector_store_name("test_mixed_files")
+        drop_response = client.delete(f"/v1/embed/{expected_vector_store_name}", headers=auth_headers["valid_auth"])
+        assert drop_response.status_code == 200
+
     def test_vector_store_creation_and_deletion(self, client, auth_headers, db_container, mock_embedding_model):
         """Test that vector stores are created in the database and can be deleted"""
         assert db_container is not None
diff --git a/tests/unit/server/api/core/test_core_bootstrap.py b/tests/unit/server/api/core/test_core_bootstrap.py
new file mode 100644
index 00000000..ca94b6ce
--- /dev/null
+++ b/tests/unit/server/api/core/test_core_bootstrap.py
@@ -0,0 +1,53 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker: disable
+
+from unittest.mock import patch, MagicMock
+
+from server.api.core import bootstrap
+
+
+class TestBootstrap:
+    """Test bootstrap module functionality"""
+
+    @patch("server.bootstrap.databases.main")
+    @patch("server.bootstrap.models.main")
+    @patch("server.bootstrap.oci.main")
+    @patch("server.bootstrap.prompts.main")
+    @patch("server.bootstrap.settings.main")
+    def test_module_imports_and_initialization(
+        self, mock_settings, mock_prompts, mock_oci, mock_models, mock_databases
+    ):
+        """Test that all bootstrap objects are properly initialized"""
+        # Mock return values
+        mock_databases.return_value = [MagicMock()]
+        mock_models.return_value = [MagicMock()]
+        mock_oci.return_value = [MagicMock()]
+        mock_prompts.return_value = [MagicMock()]
+        mock_settings.return_value = [MagicMock()]
+
+        # Reload the module to trigger initialization
+        import importlib
+
+        importlib.reload(bootstrap)
+
+        # Verify all bootstrap functions were called
+        mock_databases.assert_called_once()
+        mock_models.assert_called_once()
+        mock_oci.assert_called_once()
+        mock_prompts.assert_called_once()
+        mock_settings.assert_called_once()
+
+        # Verify objects are created
+        assert hasattr(bootstrap, "DATABASE_OBJECTS")
+        assert hasattr(bootstrap, "MODEL_OBJECTS")
+        assert hasattr(bootstrap, "OCI_OBJECTS")
+        assert hasattr(bootstrap, "PROMPT_OBJECTS")
+        assert hasattr(bootstrap, "SETTINGS_OBJECTS")
+
+    def test_logger_exists(self):
+        """Test that logger is properly configured"""
+        assert hasattr(bootstrap, "logger")
+        assert bootstrap.logger.name == "api.core.bootstrap"
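The bootstrap test above depends on `importlib.reload` because module-level initialization runs only on first import; a minimal illustration of that pattern (same imports as the test, simplified body):

```python
# Minimal illustration (not from the patch): patches applied after a module's
# first import are only seen by import-time code if the module is reloaded.
import importlib
from unittest.mock import patch

from server.api.core import bootstrap  # module body already ran on first import

with patch("server.bootstrap.databases.main", return_value=[]):
    importlib.reload(bootstrap)  # re-executes the module body under the patch
```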
+""" +# spell-checker: disable +# pylint: disable=attribute-defined-outside-init + +from unittest.mock import patch, MagicMock +import pytest + +from server.api.core import databases +from server.api.core import bootstrap +from common.schema import Database + + +class TestDatabases: + """Test databases module functionality""" + + def setup_method(self): + """Setup test data before each test""" + self.sample_database = Database(name="test_db", user="test_user", password="test_password", dsn="test_dsn") + self.sample_database_2 = Database( + name="test_db_2", user="test_user_2", password="test_password_2", dsn="test_dsn_2" + ) + + @patch("server.api.core.databases.bootstrap.DATABASE_OBJECTS") + def test_get_database_all(self, mock_database_objects): + """Test getting all databases when no name is provided""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database, self.sample_database_2])) + mock_database_objects.__len__ = MagicMock(return_value=2) + + result = databases.get_database() + + assert result == [self.sample_database, self.sample_database_2] + assert len(result) == 2 + + @patch("server.api.core.databases.bootstrap.DATABASE_OBJECTS") + def test_get_database_by_name_found(self, mock_database_objects): + """Test getting database by name when it exists""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database, self.sample_database_2])) + mock_database_objects.__len__ = MagicMock(return_value=2) + + result = databases.get_database(name="test_db") + + assert result == [self.sample_database] + assert len(result) == 1 + + @patch("server.api.core.databases.bootstrap.DATABASE_OBJECTS") + def test_get_database_by_name_not_found(self, mock_database_objects): + """Test getting database by name when it doesn't exist""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database])) + mock_database_objects.__len__ = MagicMock(return_value=1) + + with pytest.raises(ValueError, match="nonexistent not found"): + databases.get_database(name="nonexistent") + + @patch("server.api.core.databases.bootstrap.DATABASE_OBJECTS") + def test_get_database_empty_list(self, mock_database_objects): + """Test getting databases when list is empty""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([])) + mock_database_objects.__len__ = MagicMock(return_value=0) + + result = databases.get_database() + + assert result == [] + + @patch("server.api.core.databases.bootstrap.DATABASE_OBJECTS") + def test_get_database_empty_list_with_name(self, mock_database_objects): + """Test getting database by name when list is empty""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([])) + mock_database_objects.__len__ = MagicMock(return_value=0) + + with pytest.raises(ValueError, match="test_db not found"): + databases.get_database(name="test_db") + + def test_create_database_success(self, db_container): + """Test successful database creation when database doesn't exist""" + assert db_container is not None + # Use real bootstrap DATABASE_OBJECTS + original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + + try: + # Clear the list to start fresh + bootstrap.DATABASE_OBJECTS.clear() + + # Create a new database + new_database = Database(name="new_test_db", user="test_user", password="test_password", dsn="test_dsn") + + result = databases.create_database(new_database) + + # Verify database was added + assert len(bootstrap.DATABASE_OBJECTS) == 1 + assert bootstrap.DATABASE_OBJECTS[0].name == "new_test_db" + assert result == 
[new_database] + + finally: + # Restore original state + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + + def test_create_database_already_exists(self, db_container): + """Test database creation when database already exists""" + assert db_container is not None + # Use real bootstrap DATABASE_OBJECTS + original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + + try: + # Add a database to the list + bootstrap.DATABASE_OBJECTS.clear() + existing_db = Database(name="existing_db", user="test_user", password="test_password", dsn="test_dsn") + bootstrap.DATABASE_OBJECTS.append(existing_db) + + # Try to create a database with the same name + duplicate_db = Database(name="existing_db", user="other_user", password="other_password", dsn="other_dsn") + + # Should raise an error for duplicate database + with pytest.raises(ValueError, match="Database existing_db already exists"): + databases.create_database(duplicate_db) + + # Verify only original database exists + assert len(bootstrap.DATABASE_OBJECTS) == 1 + assert bootstrap.DATABASE_OBJECTS[0] == existing_db + + finally: + # Restore original state + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + + def test_create_database_missing_user(self, db_container): + """Test database creation with missing user field""" + assert db_container is not None + # Use real bootstrap DATABASE_OBJECTS + original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + + try: + bootstrap.DATABASE_OBJECTS.clear() + + # Create database with missing user + incomplete_db = Database(name="incomplete_db", password="test_password", dsn="test_dsn") + + with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): + databases.create_database(incomplete_db) + + finally: + # Restore original state + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + + def test_create_database_missing_password(self, db_container): + """Test database creation with missing password field""" + assert db_container is not None + # Use real bootstrap DATABASE_OBJECTS + original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + + try: + bootstrap.DATABASE_OBJECTS.clear() + + # Create database with missing password + incomplete_db = Database(name="incomplete_db", user="test_user", dsn="test_dsn") + + with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): + databases.create_database(incomplete_db) + + finally: + # Restore original state + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + + def test_create_database_missing_dsn(self, db_container): + """Test database creation with missing dsn field""" + assert db_container is not None + # Use real bootstrap DATABASE_OBJECTS + original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + + try: + bootstrap.DATABASE_OBJECTS.clear() + + # Create database with missing dsn + incomplete_db = Database(name="incomplete_db", user="test_user", password="test_password") + + with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): + databases.create_database(incomplete_db) + + finally: + # Restore original state + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + + def test_create_database_multiple_missing_fields(self, db_container): + """Test database creation with multiple missing required fields""" + assert db_container is not None + # Use real bootstrap DATABASE_OBJECTS + 
original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + + try: + bootstrap.DATABASE_OBJECTS.clear() + + # Create database with multiple missing fields + incomplete_db = Database(name="incomplete_db") + + with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): + databases.create_database(incomplete_db) + + finally: + # Restore original state + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + + def test_delete_database(self, db_container): + """Test database deletion""" + assert db_container is not None + # Use real bootstrap DATABASE_OBJECTS + original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + + try: + # Setup test data + db1 = Database(name="test_db_1", user="user1", password="pass1", dsn="dsn1") + db2 = Database(name="test_db_2", user="user2", password="pass2", dsn="dsn2") + db3 = Database(name="test_db_3", user="user3", password="pass3", dsn="dsn3") + + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend([db1, db2, db3]) + + # Delete middle database + databases.delete_database("test_db_2") + + # Verify deletion + assert len(bootstrap.DATABASE_OBJECTS) == 2 + names = [db.name for db in bootstrap.DATABASE_OBJECTS] + assert "test_db_1" in names + assert "test_db_2" not in names + assert "test_db_3" in names + + finally: + # Restore original state + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + + def test_delete_database_nonexistent(self, db_container): + """Test deleting non-existent database""" + assert db_container is not None + + # Use real bootstrap DATABASE_OBJECTS + original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + + try: + # Setup test data + db1 = Database(name="test_db_1", user="user1", password="pass1", dsn="dsn1") + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.append(db1) + + original_length = len(bootstrap.DATABASE_OBJECTS) + + # Try to delete non-existent database (should not raise error) + databases.delete_database("nonexistent") + + # Verify no change + assert len(bootstrap.DATABASE_OBJECTS) == original_length + assert bootstrap.DATABASE_OBJECTS[0].name == "test_db_1" + + finally: + # Restore original state + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + + def test_delete_database_empty_list(self, db_container): + """Test deleting from empty database list""" + assert db_container is not None + # Use real bootstrap DATABASE_OBJECTS + original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + + try: + bootstrap.DATABASE_OBJECTS.clear() + + # Try to delete from empty list (should not raise error) + databases.delete_database("any_name") + + # Verify still empty + assert len(bootstrap.DATABASE_OBJECTS) == 0 + + finally: + # Restore original state + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + + def test_delete_database_multiple_same_name(self, db_container): + """Test deleting when multiple databases have the same name""" + assert db_container is not None + # Use real bootstrap DATABASE_OBJECTS + original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + + try: + # Setup test data with duplicate names + db1 = Database(name="duplicate", user="user1", password="pass1", dsn="dsn1") + db2 = Database(name="duplicate", user="user2", password="pass2", dsn="dsn2") + db3 = Database(name="other", user="user3", password="pass3", dsn="dsn3") + + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend([db1, db2, 
db3]) + + # Delete databases with duplicate name + databases.delete_database("duplicate") + + # Verify all duplicates are removed + assert len(bootstrap.DATABASE_OBJECTS) == 1 + assert bootstrap.DATABASE_OBJECTS[0].name == "other" + + finally: + # Restore original state + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(databases, "logger") + assert databases.logger.name == "api.core.database" + + def test_get_database_filters_correctly(self, db_container): + """Test that get_database correctly filters by name""" + assert db_container is not None + # Use real bootstrap DATABASE_OBJECTS + original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + + try: + # Setup test data + db1 = Database(name="alpha", user="user1", password="pass1", dsn="dsn1") + db2 = Database(name="beta", user="user2", password="pass2", dsn="dsn2") + db3 = Database(name="alpha", user="user3", password="pass3", dsn="dsn3") # Duplicate name + + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend([db1, db2, db3]) + + # Test getting all + all_dbs = databases.get_database() + assert len(all_dbs) == 3 + + # Test getting by specific name + alpha_dbs = databases.get_database(name="alpha") + assert len(alpha_dbs) == 2 + assert all(db.name == "alpha" for db in alpha_dbs) + + beta_dbs = databases.get_database(name="beta") + assert len(beta_dbs) == 1 + assert beta_dbs[0].name == "beta" + + finally: + # Restore original state + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + + def test_database_model_validation(self, db_container): + """Test Database model validation and optional fields""" + assert db_container is not None + # Test with all required fields + complete_db = Database(name="complete", user="test_user", password="test_password", dsn="test_dsn") + assert complete_db.name == "complete" + assert complete_db.user == "test_user" + assert complete_db.password == "test_password" + assert complete_db.dsn == "test_dsn" + assert complete_db.connected is False # Default value + assert complete_db.tcp_connect_timeout == 5 # Default value + assert complete_db.vector_stores == [] # Default value + + # Test with optional fields + complete_db_with_options = Database( + name="complete_with_options", + user="test_user", + password="test_password", + dsn="test_dsn", + wallet_location="/path/to/wallet", + wallet_password="wallet_pass", + tcp_connect_timeout=10, + ) + assert complete_db_with_options.wallet_location == "/path/to/wallet" + assert complete_db_with_options.wallet_password == "wallet_pass" + assert complete_db_with_options.tcp_connect_timeout == 10 + + def test_create_database_real_scenario(self, db_container): + """Test create_database with realistic data using container DB""" + assert db_container is not None + # Use real bootstrap DATABASE_OBJECTS + original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + + try: + bootstrap.DATABASE_OBJECTS.clear() + + # Create database with realistic configuration + test_db = Database( + name="container_test", + user="PYTEST", + password="OrA_41_3xPl0d3r", + dsn="//localhost:1525/FREEPDB1", + tcp_connect_timeout=10, + ) + + result = databases.create_database(test_db) + + # Verify creation + assert len(bootstrap.DATABASE_OBJECTS) == 1 + created_db = bootstrap.DATABASE_OBJECTS[0] + assert created_db.name == "container_test" + assert created_db.user == "PYTEST" + assert created_db.dsn == 
"//localhost:1525/FREEPDB1" + assert created_db.tcp_connect_timeout == 10 + assert result == [test_db] + + finally: + # Restore original state + bootstrap.DATABASE_OBJECTS.clear() + bootstrap.DATABASE_OBJECTS.extend(original_db_objects) diff --git a/tests/unit/server/api/core/test_core_models.py b/tests/unit/server/api/core/test_core_models.py new file mode 100644 index 00000000..f3260b47 --- /dev/null +++ b/tests/unit/server/api/core/test_core_models.py @@ -0,0 +1,205 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. +""" +# spell-checker: disable + +from unittest.mock import patch, MagicMock + +import pytest + +from server.api.core import models +from server.api.core.models import URLUnreachableError, InvalidModelError, ExistsModelError, UnknownModelError +from common.schema import Model + + +class TestModelsExceptions: + """Test custom exception classes""" + + def test_url_unreachable_error(self): + """Test URLUnreachableError exception""" + error = URLUnreachableError("URL is unreachable") + assert str(error) == "URL is unreachable" + assert isinstance(error, ValueError) + + def test_invalid_model_error(self): + """Test InvalidModelError exception""" + error = InvalidModelError("Invalid model data") + assert str(error) == "Invalid model data" + assert isinstance(error, ValueError) + + def test_exists_model_error(self): + """Test ExistsModelError exception""" + error = ExistsModelError("Model already exists") + assert str(error) == "Model already exists" + assert isinstance(error, ValueError) + + def test_unknown_model_error(self): + """Test UnknownModelError exception""" + error = UnknownModelError("Model not found") + assert str(error) == "Model not found" + assert isinstance(error, ValueError) + + +class TestModels: + """Test models module functionality""" + + def setup_method(self): + """Setup test data before each test""" + self.sample_model = Model( + id="test-model", provider="openai", type="ll", enabled=True, api_base="https://api.openai.com" + ) + self.disabled_model = Model(id="disabled-model", provider="anthropic", type="ll", enabled=False) + + @patch("server.api.core.models.bootstrap.MODEL_OBJECTS") + def test_get_model_all_models(self, mock_model_objects): + """Test getting all models without filters""" + mock_model_objects.__iter__ = MagicMock(return_value=iter([self.sample_model, self.disabled_model])) + mock_model_objects.__len__ = MagicMock(return_value=2) + + result = models.get_model() + + assert result == [self.sample_model, self.disabled_model] + + @patch("server.api.core.models.bootstrap.MODEL_OBJECTS") + def test_get_model_by_id_found(self, mock_model_objects): + """Test getting model by ID when it exists""" + mock_model_objects.__iter__ = MagicMock(return_value=iter([self.sample_model])) + mock_model_objects.__len__ = MagicMock(return_value=1) + + result = models.get_model(model_id="test-model") + + assert result == self.sample_model + + @patch("server.api.core.models.bootstrap.MODEL_OBJECTS") + def test_get_model_by_id_not_found(self, mock_model_objects): + """Test getting model by ID when it doesn't exist""" + mock_model_objects.__iter__ = MagicMock(return_value=iter([self.sample_model])) + mock_model_objects.__len__ = MagicMock(return_value=1) + + with pytest.raises(UnknownModelError, match="nonexistent not found"): + models.get_model(model_id="nonexistent") + + @patch("server.api.core.models.bootstrap.MODEL_OBJECTS") + def 
test_get_model_by_provider(self, mock_model_objects): + """Test filtering models by provider""" + all_models = [self.sample_model, self.disabled_model] + mock_model_objects.__iter__ = MagicMock(return_value=iter(all_models)) + mock_model_objects.__len__ = MagicMock(return_value=len(all_models)) + + result = models.get_model(model_provider="openai") + + # Since only one model matches provider="openai", it should return the single object + assert result == self.sample_model + + @patch("server.api.core.models.bootstrap.MODEL_OBJECTS") + def test_get_model_by_type(self, mock_model_objects): + """Test filtering models by type""" + all_models = [self.sample_model, self.disabled_model] + mock_model_objects.__iter__ = MagicMock(return_value=iter(all_models)) + mock_model_objects.__len__ = MagicMock(return_value=len(all_models)) + + result = models.get_model(model_type="ll") + + assert result == all_models + + @patch("server.api.core.models.bootstrap.MODEL_OBJECTS") + def test_get_model_exclude_disabled(self, mock_model_objects): + """Test excluding disabled models""" + all_models = [self.sample_model, self.disabled_model] + mock_model_objects.__iter__ = MagicMock(return_value=iter(all_models)) + mock_model_objects.__len__ = MagicMock(return_value=len(all_models)) + + result = models.get_model(include_disabled=False) + + # Since only one model is enabled, it should return the single object + assert result == self.sample_model + + @patch("server.api.core.models.bootstrap.MODEL_OBJECTS") + @patch("server.api.core.models.get_model") + @patch("common.functions.is_url_accessible") + def test_create_model_success(self, mock_url_check, mock_get_model, mock_model_objects): + """Test successful model creation""" + mock_model_objects.append = MagicMock() + mock_get_model.side_effect = [ + UnknownModelError("test-model not found"), # First call should fail (model doesn't exist) + self.sample_model, # Second call returns the created model + ] + mock_url_check.return_value = (True, None) + + result = models.create_model(self.sample_model) + + mock_model_objects.append.assert_called_once_with(self.sample_model) + assert result == self.sample_model + + @patch("server.api.core.models.bootstrap.MODEL_OBJECTS") + @patch("server.api.core.models.get_model") + def test_create_model_already_exists(self, mock_get_model, mock_model_objects): + """Test creating model that already exists""" + mock_get_model.return_value = self.sample_model # Model already exists + + with pytest.raises(ExistsModelError, match="Model: openai/test-model already exists"): + models.create_model(self.sample_model) + + @patch("server.api.core.models.bootstrap.MODEL_OBJECTS") + @patch("server.api.core.models.get_model") + @patch("common.functions.is_url_accessible") + def test_create_model_unreachable_url(self, mock_url_check, mock_get_model, mock_model_objects): + """Test creating model with unreachable URL""" + mock_model_objects.append = MagicMock() + + # Create a copy of the model that will be modified + test_model = Model( + id="test-model", + provider="openai", + type="ll", + enabled=True, # Start as enabled + api_base="https://api.openai.com", + ) + + modified_model = Model( + id="test-model", + provider="openai", + type="ll", + enabled=False, # Will be disabled due to URL check + api_base="https://api.openai.com", + ) + + mock_get_model.side_effect = [UnknownModelError("test-model not found"), modified_model] + mock_url_check.return_value = (False, "Connection failed") + + result = models.create_model(test_model) + + assert result.enabled 
is False + + @patch("server.api.core.models.bootstrap.MODEL_OBJECTS") + @patch("server.api.core.models.get_model") + def test_create_model_skip_url_check(self, mock_get_model, mock_model_objects): + """Test creating model without URL check""" + mock_model_objects.append = MagicMock() + mock_get_model.side_effect = [UnknownModelError("test-model not found"), self.sample_model] + + result = models.create_model(self.sample_model, check_url=False) + + assert result == self.sample_model + + @patch("server.api.core.models.bootstrap") + def test_delete_model(self, mock_bootstrap): + """Test model deletion""" + mock_model_objects = [ + Model(id="test-model", provider="openai", type="ll"), + Model(id="other-model", provider="anthropic", type="ll"), + ] + mock_bootstrap.MODEL_OBJECTS = mock_model_objects + + models.delete_model("openai", "test-model") + + # Verify against the mocked bootstrap list itself; re-filtering a local copy + # would pass even if delete_model were a no-op + assert len(mock_bootstrap.MODEL_OBJECTS) == 1 + assert mock_bootstrap.MODEL_OBJECTS[0].id == "other-model" + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(models, "logger") + assert models.logger.name == "api.core.models" diff --git a/tests/unit/server/api/core/test_core_oci.py b/tests/unit/server/api/core/test_core_oci.py new file mode 100644 index 00000000..dcdc88a9 --- /dev/null +++ b/tests/unit/server/api/core/test_core_oci.py @@ -0,0 +1,104 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. +""" +# spell-checker: disable + +from unittest.mock import patch, MagicMock + +import pytest + +from server.api.core import oci +from common.schema import OracleCloudSettings, Settings, OciSettings + + +class TestOci: + """Test OCI module functionality""" + + def setup_method(self): + """Setup test data before each test""" + self.sample_oci_default = OracleCloudSettings( + auth_profile="DEFAULT", compartment_id="ocid1.compartment.oc1..default" + ) + self.sample_oci_custom = OracleCloudSettings( + auth_profile="CUSTOM", compartment_id="ocid1.compartment.oc1..custom" + ) + self.sample_client_settings = Settings(client="test_client", oci=OciSettings(auth_profile="CUSTOM")) + + @patch("server.api.core.oci.bootstrap") + def test_get_oci_all(self, mock_bootstrap): + """Test getting all OCI settings when no filters are provided""" + all_oci = [self.sample_oci_default, self.sample_oci_custom] + mock_bootstrap.OCI_OBJECTS = all_oci + + result = oci.get_oci() + + assert result == all_oci + + @patch("server.api.core.oci.bootstrap.OCI_OBJECTS") + def test_get_oci_no_objects_configured(self, mock_oci_objects): + """Test getting OCI settings when none are configured""" + mock_oci_objects.__bool__ = MagicMock(return_value=False) + + with pytest.raises(ValueError, match="not configured"): + oci.get_oci() + + @patch("server.api.core.oci.bootstrap.OCI_OBJECTS") + def test_get_oci_by_auth_profile_found(self, mock_oci_objects): + """Test getting OCI settings by auth_profile when it exists""" + mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default, self.sample_oci_custom])) + + result = oci.get_oci(auth_profile="CUSTOM") + + assert result == self.sample_oci_custom + + @patch("server.api.core.oci.bootstrap.OCI_OBJECTS") + def test_get_oci_by_auth_profile_not_found(self, mock_oci_objects): + """Test getting OCI settings by auth_profile when it doesn't exist"""
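+ # Only the DEFAULT profile is supplied below, so looking up NONEXISTENT must raise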
mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default])) + + with pytest.raises(ValueError, match="profile 'NONEXISTENT' not found"): + oci.get_oci(auth_profile="NONEXISTENT") + + @patch("server.api.core.oci.bootstrap.OCI_OBJECTS") + @patch("server.api.core.oci.settings.get_client_settings") + def test_get_oci_by_client_with_oci_settings(self, mock_get_client_settings, mock_oci_objects): + """Test getting OCI settings by client when client has OCI settings""" + mock_get_client_settings.return_value = self.sample_client_settings + mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default, self.sample_oci_custom])) + + result = oci.get_oci(client="test_client") + + assert result == self.sample_oci_custom + + @patch("server.api.core.oci.bootstrap.OCI_OBJECTS") + @patch("server.api.core.oci.settings.get_client_settings") + def test_get_oci_by_client_without_oci_settings(self, mock_get_client_settings, mock_oci_objects): + """Test getting OCI settings by client when client has no OCI settings""" + client_settings_no_oci = Settings(client="test_client", oci=None) + mock_get_client_settings.return_value = client_settings_no_oci + mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default])) + + result = oci.get_oci(client="test_client") + + assert result == self.sample_oci_default + + @patch("server.api.core.oci.bootstrap.OCI_OBJECTS") + @patch("server.api.core.oci.settings.get_client_settings") + def test_get_oci_by_client_no_matching_profile(self, mock_get_client_settings, mock_oci_objects): + """Test getting OCI settings by client when no matching profile exists""" + mock_get_client_settings.return_value = self.sample_client_settings + mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default])) # Only DEFAULT profile + + with pytest.raises(ValueError, match="No settings found for client 'test_client' with auth_profile 'CUSTOM'"): + oci.get_oci(client="test_client") + + def test_get_oci_both_client_and_auth_profile(self): + """Test that providing both client and auth_profile raises an error""" + with pytest.raises(ValueError, match="provide either 'client' or 'auth_profile', not both"): + oci.get_oci(client="test_client", auth_profile="CUSTOM") + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(oci, "logger") + assert oci.logger.name == "api.core.oci" diff --git a/tests/unit/server/api/core/test_core_prompts.py b/tests/unit/server/api/core/test_core_prompts.py new file mode 100644 index 00000000..43a2f81b --- /dev/null +++ b/tests/unit/server/api/core/test_core_prompts.py @@ -0,0 +1,83 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
+""" +# spell-checker: disable + +from unittest.mock import patch, MagicMock + +import pytest + +from server.api.core import prompts +from common.schema import Prompt + + +class TestPrompts: + """Test prompts module functionality""" + + def setup_method(self): + """Setup test data before each test""" + self.sample_prompt_1 = Prompt(category="sys", name="default", prompt="You are a helpful assistant.") + self.sample_prompt_2 = Prompt(category="sys", name="custom", prompt="You are a custom assistant.") + self.sample_prompt_3 = Prompt(category="ctx", name="greeting", prompt="Hello, how can I help you?") + + @patch("server.api.core.prompts.bootstrap") + def test_get_prompts_all(self, mock_bootstrap): + """Test getting all prompts when no filters are provided""" + all_prompts = [self.sample_prompt_1, self.sample_prompt_2, self.sample_prompt_3] + mock_bootstrap.PROMPT_OBJECTS = all_prompts + + result = prompts.get_prompts() + + assert result == all_prompts + + @patch("server.api.core.prompts.bootstrap.PROMPT_OBJECTS") + def test_get_prompts_by_category(self, mock_prompt_objects): + """Test filtering prompts by category""" + all_prompts = [self.sample_prompt_1, self.sample_prompt_2, self.sample_prompt_3] + mock_prompt_objects.__iter__ = MagicMock(return_value=iter(all_prompts)) + + result = prompts.get_prompts(category="sys") + + expected = [self.sample_prompt_1, self.sample_prompt_2] + assert result == expected + + @patch("server.api.core.prompts.bootstrap.PROMPT_OBJECTS") + def test_get_prompts_by_category_and_name_found(self, mock_prompt_objects): + """Test filtering prompts by category and name when found""" + all_prompts = [self.sample_prompt_1, self.sample_prompt_2, self.sample_prompt_3] + mock_prompt_objects.__iter__ = MagicMock(return_value=iter(all_prompts)) + + result = prompts.get_prompts(category="sys", name="custom") + + assert result == self.sample_prompt_2 + + @patch("server.api.core.prompts.bootstrap.PROMPT_OBJECTS") + def test_get_prompts_by_category_and_name_not_found(self, mock_prompt_objects): + """Test filtering prompts by category and name when not found""" + all_prompts = [self.sample_prompt_1, self.sample_prompt_2, self.sample_prompt_3] + mock_prompt_objects.__iter__ = MagicMock(return_value=iter(all_prompts)) + + with pytest.raises(ValueError, match="nonexistent \\(sys\\) not found"): + prompts.get_prompts(category="sys", name="nonexistent") + + @patch("server.api.core.prompts.bootstrap.PROMPT_OBJECTS") + def test_get_prompts_by_name_without_category_raises_error(self, mock_prompt_objects): + """Test that filtering by name without category raises an error""" + with pytest.raises(ValueError, match="Cannot filter prompts by name without specifying category"): + prompts.get_prompts(name="default") + + @patch("server.api.core.prompts.bootstrap.PROMPT_OBJECTS") + def test_get_prompts_empty_category_filter(self, mock_prompt_objects): + """Test filtering by category that has no matches""" + all_prompts = [self.sample_prompt_1, self.sample_prompt_2, self.sample_prompt_3] + mock_prompt_objects.__iter__ = MagicMock(return_value=iter(all_prompts)) + + result = prompts.get_prompts(category="nonexistent") + + assert result == [] + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(prompts, "logger") + assert prompts.logger.name == "api.core.prompts" diff --git a/tests/unit/server/api/core/test_core_settings.py b/tests/unit/server/api/core/test_core_settings.py new file mode 100644 index 00000000..d2fe6b24 --- /dev/null +++ 
b/tests/unit/server/api/core/test_core_settings.py @@ -0,0 +1,177 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. +""" +# spell-checker: disable + +from unittest.mock import patch, MagicMock, mock_open +import os + +import pytest + +from server.api.core import settings +from common.schema import Settings, Configuration, Database, Model, OracleCloudSettings, Prompt + + +class TestSettings: + """Test settings module functionality""" + + def setup_method(self): + """Setup test data before each test""" + self.default_settings = Settings(client="default") + self.test_client_settings = Settings(client="test_client") + self.sample_config_data = { + "database_configs": [{"name": "test_db", "user": "user", "password": "pass", "dsn": "dsn"}], + "model_configs": [{"id": "test-model", "provider": "openai", "type": "ll"}], + "oci_configs": [{"auth_profile": "DEFAULT", "compartment_id": "ocid1.compartment.oc1..test"}], + "prompt_configs": [{"category": "sys", "name": "default", "prompt": "You are helpful"}], + "client_settings": {"client": "default", "max_tokens": 1000, "temperature": 0.7}, + } + + @patch("server.api.core.settings.bootstrap") + def test_create_client_settings_success(self, mock_bootstrap): + """Test successful client settings creation""" + # Create a list that includes the default settings and will be appended to + settings_list = [self.default_settings] + mock_bootstrap.SETTINGS_OBJECTS = settings_list + + result = settings.create_client_settings("new_client") + + assert result.client == "new_client" + assert result.ll_model.max_completion_tokens == self.default_settings.ll_model.max_completion_tokens + # Check that a new client was added to the list + assert len(settings_list) == 2 + assert settings_list[-1].client == "new_client" + + @patch("server.api.core.settings.bootstrap.SETTINGS_OBJECTS") + def test_create_client_settings_already_exists(self, mock_settings_objects): + """Test creating client settings when client already exists""" + mock_settings_objects.__iter__ = MagicMock(return_value=iter([self.test_client_settings])) + + with pytest.raises(ValueError, match="client test_client already exists"): + settings.create_client_settings("test_client") + + @patch("server.api.core.settings.bootstrap.SETTINGS_OBJECTS") + def test_get_client_settings_found(self, mock_settings_objects): + """Test getting client settings when client exists""" + mock_settings_objects.__iter__ = MagicMock(return_value=iter([self.test_client_settings])) + + result = settings.get_client_settings("test_client") + + assert result == self.test_client_settings + + @patch("server.api.core.settings.bootstrap.SETTINGS_OBJECTS") + def test_get_client_settings_not_found(self, mock_settings_objects): + """Test getting client settings when client doesn't exist""" + mock_settings_objects.__iter__ = MagicMock(return_value=iter([self.default_settings])) + + with pytest.raises(ValueError, match="client nonexistent not found"): + settings.get_client_settings("nonexistent") + + @patch("server.api.core.settings.bootstrap.DATABASE_OBJECTS") + @patch("server.api.core.settings.bootstrap.MODEL_OBJECTS") + @patch("server.api.core.settings.bootstrap.OCI_OBJECTS") + @patch("server.api.core.settings.bootstrap.PROMPT_OBJECTS") + def test_get_server_config(self, mock_prompts, mock_oci, mock_models, mock_databases): + """Test getting server configuration""" + mock_databases.__iter__ = MagicMock( + 
return_value=iter([Database(name="test", user="u", password="p", dsn="d")]) + ) + mock_models.__iter__ = MagicMock(return_value=iter([Model(id="test", provider="openai", type="ll")])) + mock_oci.__iter__ = MagicMock(return_value=iter([OracleCloudSettings(auth_profile="DEFAULT")])) + mock_prompts.__iter__ = MagicMock(return_value=iter([Prompt(category="sys", name="test", prompt="test")])) + + result = settings.get_server_config() + + assert "database_configs" in result + assert "model_configs" in result + assert "oci_configs" in result + assert "prompt_configs" in result + + @patch("server.api.core.settings.bootstrap.SETTINGS_OBJECTS") + @patch("server.api.core.settings.get_client_settings") + def test_update_client_settings(self, mock_get_settings, mock_settings_objects): + """Test updating client settings""" + mock_get_settings.return_value = self.test_client_settings + mock_settings_objects.remove = MagicMock() + mock_settings_objects.append = MagicMock() + mock_settings_objects.__iter__ = MagicMock(return_value=iter([self.test_client_settings])) + + new_settings = Settings(client="test_client", max_tokens=800, temperature=0.9) + result = settings.update_client_settings(new_settings, "test_client") + + assert result.client == "test_client" + mock_settings_objects.remove.assert_called_once_with(self.test_client_settings) + mock_settings_objects.append.assert_called_once() + + @patch("server.api.core.settings.bootstrap") + def test_update_server_config(self, mock_bootstrap): + """Test updating server configuration""" + # Use the valid sample config data that includes client_settings + settings.update_server_config(self.sample_config_data) + + assert hasattr(mock_bootstrap, "DATABASE_OBJECTS") + assert hasattr(mock_bootstrap, "MODEL_OBJECTS") + + @patch("server.api.core.settings.update_server_config") + @patch("server.api.core.settings.update_client_settings") + def test_load_config_from_json_data_with_client(self, mock_update_client, mock_update_server): + """Test loading config from JSON data with specific client""" + settings.load_config_from_json_data(self.sample_config_data, client="test_client") + + mock_update_server.assert_called_once_with(self.sample_config_data) + mock_update_client.assert_called_once() + + @patch("server.api.core.settings.update_server_config") + @patch("server.api.core.settings.update_client_settings") + def test_load_config_from_json_data_without_client(self, mock_update_client, mock_update_server): + """Test loading config from JSON data without specific client""" + settings.load_config_from_json_data(self.sample_config_data) + + mock_update_server.assert_called_once_with(self.sample_config_data) + # Should be called twice: once for "server" and once for "default" + assert mock_update_client.call_count == 2 + + @patch("server.api.core.settings.update_server_config") + def test_load_config_from_json_data_missing_client_settings(self, mock_update_server): + """Test loading config from JSON data without client_settings""" + # Create config without client_settings + invalid_config = {"database_configs": [], "model_configs": [], "oci_configs": [], "prompt_configs": []} + + with pytest.raises(KeyError, match="Missing client_settings in config file"): + settings.load_config_from_json_data(invalid_config) + + @patch.dict(os.environ, {"CONFIG_FILE": "/path/to/config.json"}) + @patch("os.path.isfile") + @patch("os.access") + @patch("builtins.open", mock_open(read_data='{"test": "data"}')) + @patch("json.load") + def test_read_config_from_json_file_success(self, 
mock_json_load, mock_access, mock_isfile): + """Test successful reading of config file""" + mock_isfile.return_value = True + mock_access.return_value = True + mock_json_load.return_value = self.sample_config_data + + result = settings.read_config_from_json_file() + + assert isinstance(result, Configuration) + mock_json_load.assert_called_once() + + @patch.dict(os.environ, {"CONFIG_FILE": "/path/to/nonexistent.json"}) + @patch("os.path.isfile") + def test_read_config_from_json_file_not_exists(self, mock_isfile): + """Test reading config file that doesn't exist""" + mock_isfile.return_value = False + + # The loader should log a warning for the missing file rather than raise; + # its return value is implementation-defined, so only exercise the code path + settings.read_config_from_json_file() + + @patch.dict(os.environ, {"CONFIG_FILE": "/path/to/config.txt"}) + def test_read_config_from_json_file_wrong_extension(self): + """Test reading config file with wrong extension""" + # A non-.json extension should only produce a logged warning; verify the call does not raise + settings.read_config_from_json_file() + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(settings, "logger") + assert settings.logger.name == "api.core.settings" diff --git a/tests/unit/server/api/utils/test_utils_chat.py b/tests/unit/server/api/utils/test_utils_chat.py new file mode 100644 index 00000000..d55f4ac1 --- /dev/null +++ b/tests/unit/server/api/utils/test_utils_chat.py @@ -0,0 +1,288 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. +""" +# spell-checker: disable + +from unittest.mock import patch, MagicMock +import pytest + +from langchain_core.messages import ChatMessage + +from server.api.utils import chat +from common.schema import ( + ChatRequest, + Settings, + LargeLanguageSettings, + VectorSearchSettings, + SelectAISettings, + PromptSettings, + OciSettings, +) + + +class TestChatUtils: + """Test chat utility functions""" + + def setup_method(self): + """Setup test data""" + self.sample_message = ChatMessage(role="user", content="Hello, how are you?") + self.sample_request = ChatRequest(messages=[self.sample_message], model="openai/gpt-4") + self.sample_client_settings = Settings( + client="test_client", + ll_model=LargeLanguageSettings( + model="openai/gpt-4", chat_history=True, temperature=0.7, max_completion_tokens=4096 + ), + vector_search=VectorSearchSettings(enabled=False), + selectai=SelectAISettings(enabled=False), + prompts=PromptSettings(sys="Basic Example", ctx="Basic Example"), + oci=OciSettings(auth_profile="DEFAULT"), + ) + + @patch("server.api.core.settings.get_client_settings") + @patch("server.api.core.oci.get_oci") + @patch("server.api.utils.models.get_litellm_config") + @patch("server.api.core.prompts.get_prompts") + @patch("server.agents.chatbot.chatbot_graph.astream") + @pytest.mark.asyncio + async def test_completion_generator_success( + self, mock_astream, mock_get_prompts, mock_get_litellm_config, mock_get_oci, mock_get_client_settings + ): + """Test successful completion generation""" + # Setup mocks + mock_get_client_settings.return_value = self.sample_client_settings + mock_get_oci.return_value = MagicMock() + mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} + mock_get_prompts.return_value = MagicMock(prompt="You are a helpful assistant") + + # Mock the async generator - it yields incremental stream chunks followed by the final completion + async def mock_generator(): + yield {"stream": "Hello"}
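+ # the further "stream" chunk and final "completion" below mirror the shape asserted later in the test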
+ yield {"stream": " there"} + yield {"completion": "Hello there"} + + mock_astream.return_value = mock_generator() + + # Test the function + results = [] + async for result in chat.completion_generator("test_client", self.sample_request, "completions"): + results.append(result) + + # Verify results - for "completions" mode, we get stream chunks + final completion + assert len(results) == 3 + assert results[0] == b"Hello" # Stream chunks are encoded as bytes + assert results[1] == b" there" + assert results[2] == "Hello there" # Final completion is a string + mock_get_client_settings.assert_called_once_with("test_client") + mock_get_oci.assert_called_once_with(client="test_client") + + @patch("server.api.core.settings.get_client_settings") + @patch("server.api.core.oci.get_oci") + @patch("server.api.utils.models.get_litellm_config") + @patch("server.api.core.prompts.get_prompts") + @patch("server.agents.chatbot.chatbot_graph.astream") + @pytest.mark.asyncio + async def test_completion_generator_streaming( + self, mock_astream, mock_get_prompts, mock_get_litellm_config, mock_get_oci, mock_get_client_settings + ): + """Test streaming completion generation""" + # Setup mocks + mock_get_client_settings.return_value = self.sample_client_settings + mock_get_oci.return_value = MagicMock() + mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} + mock_get_prompts.return_value = MagicMock(prompt="You are a helpful assistant") + + # Mock the async generator + async def mock_generator(): + yield {"stream": "Hello"} + yield {"stream": " there"} + yield {"completion": "Hello there"} + + mock_astream.return_value = mock_generator() + + # Test the function + results = [] + async for result in chat.completion_generator("test_client", self.sample_request, "streams"): + results.append(result) + + # Verify results - should include encoded stream chunks and finish marker + assert len(results) == 3 + assert results[0] == b"Hello" + assert results[1] == b" there" + assert results[2] == "[stream_finished]" + + @patch("server.api.core.settings.get_client_settings") + @patch("server.api.core.oci.get_oci") + @patch("server.api.utils.models.get_litellm_config") + @patch("server.api.core.prompts.get_prompts") + @patch("server.api.utils.databases.get_client_database") + @patch("server.api.utils.models.get_client_embed") + @patch("server.agents.chatbot.chatbot_graph.astream") + @pytest.mark.asyncio + async def test_completion_generator_with_vector_search( + self, + mock_astream, + mock_get_client_embed, + mock_get_client_database, + mock_get_prompts, + mock_get_litellm_config, + mock_get_oci, + mock_get_client_settings, + ): + """Test completion generation with vector search enabled""" + # Setup settings with vector search enabled + vector_search_settings = self.sample_client_settings.model_copy() + vector_search_settings.vector_search.enabled = True + + # Setup mocks + mock_get_client_settings.return_value = vector_search_settings + mock_get_oci.return_value = MagicMock() + mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} + mock_get_prompts.return_value = MagicMock(prompt="You are a helpful assistant") + + mock_db = MagicMock() + mock_db.connection = MagicMock() + mock_get_client_database.return_value = mock_db + mock_get_client_embed.return_value = MagicMock() + + # Mock the async generator + async def mock_generator(): + yield {"completion": "Response with vector search"} + + mock_astream.return_value = mock_generator() + + # Test the function + results = [] + 
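+ # Drain the generator; the vector-search wiring (client database and embed model) is asserted afterwards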
async for result in chat.completion_generator("test_client", self.sample_request, "completions"): + results.append(result) + + # Verify vector search setup + mock_get_client_database.assert_called_once_with("test_client", False) + mock_get_client_embed.assert_called_once() + assert len(results) == 1 + + @patch("server.api.core.settings.get_client_settings") + @patch("server.api.core.oci.get_oci") + @patch("server.api.utils.models.get_litellm_config") + @patch("server.api.core.prompts.get_prompts") + @patch("server.api.utils.databases.get_client_database") + @patch("server.api.utils.selectai.set_profile") + @patch("server.agents.chatbot.chatbot_graph.astream") + @pytest.mark.asyncio + async def test_completion_generator_with_selectai( + self, + mock_astream, + mock_set_profile, + mock_get_client_database, + mock_get_prompts, + mock_get_litellm_config, + mock_get_oci, + mock_get_client_settings, + ): + """Test completion generation with SelectAI enabled""" + # Setup settings with SelectAI enabled + selectai_settings = self.sample_client_settings.model_copy() + selectai_settings.selectai.enabled = True + selectai_settings.selectai.profile = "TEST_PROFILE" + + # Setup mocks + mock_get_client_settings.return_value = selectai_settings + mock_get_oci.return_value = MagicMock() + mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} + mock_get_prompts.return_value = MagicMock(prompt="You are a helpful assistant") + + mock_db = MagicMock() + mock_db.connection = MagicMock() + mock_get_client_database.return_value = mock_db + + # Mock the async generator + async def mock_generator(): + yield {"completion": "Response with SelectAI"} + + mock_astream.return_value = mock_generator() + + # Test the function + results = [] + async for result in chat.completion_generator("test_client", self.sample_request, "completions"): + results.append(result) + + # Verify SelectAI setup + mock_get_client_database.assert_called_once_with("test_client", False) + # Should set profile parameters + assert mock_set_profile.call_count == 2 # temperature and max_tokens + assert len(results) == 1 + + @patch("server.api.core.settings.get_client_settings") + @patch("server.api.core.oci.get_oci") + @patch("server.api.utils.models.get_litellm_config") + @patch("server.api.core.prompts.get_prompts") + @patch("server.agents.chatbot.chatbot_graph.astream") + @pytest.mark.asyncio + async def test_completion_generator_no_model_specified( + self, mock_astream, mock_get_prompts, mock_get_litellm_config, mock_get_oci, mock_get_client_settings + ): + """Test completion generation when no model is specified in request""" + # Create request without model + request_no_model = ChatRequest(messages=[self.sample_message], model=None) + + # Setup mocks + mock_get_client_settings.return_value = self.sample_client_settings + mock_get_oci.return_value = MagicMock() + mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} + mock_get_prompts.return_value = MagicMock(prompt="You are a helpful assistant") + + # Mock the async generator + async def mock_generator(): + yield {"completion": "Response using default model"} + + mock_astream.return_value = mock_generator() + + # Test the function + results = [] + async for result in chat.completion_generator("test_client", request_no_model, "completions"): + results.append(result) + + # Should use model from client settings + assert len(results) == 1 + + @patch("server.api.core.settings.get_client_settings") + @patch("server.api.core.oci.get_oci") + 
@patch("server.api.utils.models.get_litellm_config") + @patch("server.api.core.prompts.get_prompts") + @patch("server.agents.chatbot.chatbot_graph.astream") + @pytest.mark.asyncio + async def test_completion_generator_custom_prompts( + self, mock_astream, mock_get_prompts, mock_get_litellm_config, mock_get_oci, mock_get_client_settings + ): + """Test completion generation with custom prompts""" + # Setup settings with custom prompts + custom_settings = self.sample_client_settings.model_copy() + custom_settings.prompts.sys = "Custom System" + custom_settings.prompts.ctx = "Custom Context" + + # Setup mocks + mock_get_client_settings.return_value = custom_settings + mock_get_oci.return_value = MagicMock() + mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} + mock_get_prompts.return_value = MagicMock(prompt="Custom prompt") + + # Mock the async generator + async def mock_generator(): + yield {"completion": "Response with custom prompts"} + + mock_astream.return_value = mock_generator() + + # Test the function + results = [] + async for result in chat.completion_generator("test_client", self.sample_request, "completions"): + results.append(result) + + # Verify custom prompts are used + mock_get_prompts.assert_called_with(category="sys", name="Custom System") + assert len(results) == 1 + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(chat, "logger") + assert chat.logger.name == "api.utils.chat" diff --git a/tests/unit/server/api/utils/test_utils_databases.py b/tests/unit/server/api/utils/test_utils_databases.py new file mode 100644 index 00000000..812c3468 --- /dev/null +++ b/tests/unit/server/api/utils/test_utils_databases.py @@ -0,0 +1,736 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
+""" +# spell-checker: disable +# pylint: disable=protected-access,import-error,too-many-public-methods,attribute-defined-outside-init + +import json +from unittest.mock import patch, MagicMock + +import pytest +import oracledb +from conftest import TEST_CONFIG + +from server.api.utils import databases +from server.api.utils.databases import DbException +from server.api.core import bootstrap +from common.schema import Database + + +class TestDbException: + """Test custom database exception class""" + + def test_db_exception_initialization(self): + """Test DbException initialization""" + exc = DbException(status_code=500, detail="Database error") + assert exc.status_code == 500 + assert exc.detail == "Database error" + assert str(exc) == "Database error" + + def test_db_exception_inheritance(self): + """Test DbException inherits from Exception""" + exc = DbException(status_code=404, detail="Not found") + assert isinstance(exc, Exception) + + def test_db_exception_different_status_codes(self): + """Test DbException with different status codes""" + test_cases = [ + (400, "Bad request"), + (401, "Unauthorized"), + (403, "Forbidden"), + (503, "Service unavailable"), + ] + + for status_code, detail in test_cases: + exc = DbException(status_code=status_code, detail=detail) + assert exc.status_code == status_code + assert exc.detail == detail + + +class TestDatabaseUtilsPrivateFunctions: + """Test private utility functions""" + + def setup_method(self): + """Setup test data""" + self.sample_database = Database( + name="test_db", + user=TEST_CONFIG["db_username"], + password=TEST_CONFIG["db_password"], + dsn=TEST_CONFIG["db_dsn"] + ) + + def test_test_function_success(self, db_container): + """Test successful database connection test with real database""" + assert db_container is not None + # Connect to real database + conn = databases.connect(self.sample_database) + self.sample_database.set_connection(conn) + + try: + # Test the connection + databases._test(self.sample_database) + assert self.sample_database.connected is True + finally: + databases.disconnect(conn) + + @patch("oracledb.Connection") + def test_test_function_reconnect(self, mock_connection): + """Test database reconnection when ping fails""" + mock_connection.ping.side_effect = oracledb.DatabaseError("Connection lost") + self.sample_database.set_connection(mock_connection) + + with patch("server.api.utils.databases.connect") as mock_connect: + databases._test(self.sample_database) + mock_connect.assert_called_once_with(self.sample_database) + + @patch("oracledb.Connection") + def test_test_function_value_error(self, mock_connection): + """Test handling of value errors""" + mock_connection.ping.side_effect = ValueError("Invalid value") + self.sample_database.set_connection(mock_connection) + + with pytest.raises(DbException) as exc_info: + databases._test(self.sample_database) + + assert exc_info.value.status_code == 400 + assert "Database: Invalid value" in str(exc_info.value) + + @patch("oracledb.Connection") + def test_test_function_permission_error(self, mock_connection): + """Test handling of permission errors""" + mock_connection.ping.side_effect = PermissionError("Access denied") + self.sample_database.set_connection(mock_connection) + + with pytest.raises(DbException) as exc_info: + databases._test(self.sample_database) + + assert exc_info.value.status_code == 401 + assert "Database: Access denied" in str(exc_info.value) + + @patch("oracledb.Connection") + def test_test_function_connection_error(self, mock_connection): + 
"""Test handling of connection errors""" + mock_connection.ping.side_effect = ConnectionError("Connection failed") + self.sample_database.set_connection(mock_connection) + + with pytest.raises(DbException) as exc_info: + databases._test(self.sample_database) + + assert exc_info.value.status_code == 503 + assert "Database: Connection failed" in str(exc_info.value) + + @patch("oracledb.Connection") + def test_test_function_generic_exception(self, mock_connection): + """Test handling of generic exceptions""" + mock_connection.ping.side_effect = RuntimeError("Unknown error") + self.sample_database.set_connection(mock_connection) + + with pytest.raises(DbException) as exc_info: + databases._test(self.sample_database) + + assert exc_info.value.status_code == 500 + assert "Unknown error" in str(exc_info.value) + + def test_get_vs_with_real_database(self, db_container): + """Test vector storage retrieval with real database""" + assert db_container is not None + conn = databases.connect(self.sample_database) + + try: + # Test with empty result (no vector stores initially) + result = databases._get_vs(conn) + assert isinstance(result, list) + assert len(result) == 0 # Initially no vector stores + finally: + databases.disconnect(conn) + + @patch("server.api.utils.databases.execute_sql") + def test_get_vs_with_mock_data(self, mock_execute_sql): + """Test vector storage retrieval with mocked data""" + mock_connection = MagicMock() + mock_execute_sql.return_value = [ + ( + "TEST_TABLE", + '{"alias": "test_alias", "model": "test_model", "chunk_size": 1000, "distance_metric": "COSINE"}', + ), + ( + "ANOTHER_TABLE", + '{"alias": "another_alias", "model": "another_model", ' + '"chunk_size": 500, "distance_metric": "EUCLIDEAN_DISTANCE"}' + ) + ] + + result = databases._get_vs(mock_connection) + + assert len(result) == 2 + assert result[0].vector_store == "TEST_TABLE" + assert result[0].alias == "test_alias" + assert result[0].model == "test_model" + assert result[0].chunk_size == 1000 + assert result[0].distance_metric == "COSINE" + + assert result[1].vector_store == "ANOTHER_TABLE" + assert result[1].alias == "another_alias" + assert result[1].distance_metric == "EUCLIDEAN_DISTANCE" + + @patch("server.api.utils.databases.execute_sql") + def test_get_vs_empty_result(self, mock_execute_sql): + """Test vector storage retrieval with empty results""" + mock_connection = MagicMock() + mock_execute_sql.return_value = [] + + result = databases._get_vs(mock_connection) + + assert isinstance(result, list) + assert len(result) == 0 + + @patch("server.api.utils.databases.execute_sql") + def test_get_vs_malformed_json(self, mock_execute_sql): + """Test vector storage retrieval with malformed JSON""" + mock_connection = MagicMock() + mock_execute_sql.return_value = [ + ("TEST_TABLE", '{"invalid_json": }'), + ] + + with pytest.raises(json.JSONDecodeError): + databases._get_vs(mock_connection) + + def test_selectai_enabled_with_real_database(self, db_container): + """Test SelectAI enabled check with real database""" + conn = databases.connect(self.sample_database) + + try: + # Test with real database (likely returns False for test environment) + result = databases._selectai_enabled(conn) + assert isinstance(result, bool) + # We don't assert the specific value as it depends on the database setup + finally: + databases.disconnect(conn) + + @patch("server.api.utils.databases.execute_sql") + def test_selectai_enabled_true(self, mock_execute_sql): + """Test SelectAI enabled check returns True""" + mock_connection = MagicMock() 
+ mock_execute_sql.return_value = [(3,)] + + result = databases._selectai_enabled(mock_connection) + + assert result is True + + @patch("server.api.utils.databases.execute_sql") + def test_selectai_enabled_false(self, mock_execute_sql): + """Test SelectAI enabled check returns False""" + mock_connection = MagicMock() + mock_execute_sql.return_value = [(2,)] + + result = databases._selectai_enabled(mock_connection) + + assert result is False + + @patch("server.api.utils.databases.execute_sql") + def test_selectai_enabled_zero_privileges(self, mock_execute_sql): + """Test SelectAI enabled check with zero privileges""" + mock_connection = MagicMock() + mock_execute_sql.return_value = [(0,)] + + result = databases._selectai_enabled(mock_connection) + + assert result is False + + def test_get_selectai_profiles_with_real_database(self, db_container): + """Test SelectAI profiles retrieval with real database""" + assert db_container is not None + conn = databases.connect(self.sample_database) + + try: + # Test with real database (likely returns empty list for test environment) + result = databases._get_selectai_profiles(conn) + assert isinstance(result, list) + # We don't assert the specific content as it depends on the database setup + finally: + databases.disconnect(conn) + + @patch("server.api.utils.databases.execute_sql") + def test_get_selectai_profiles_with_data(self, mock_execute_sql): + """Test SelectAI profiles retrieval with data""" + mock_connection = MagicMock() + mock_execute_sql.return_value = [("PROFILE1",), ("PROFILE2",), ("PROFILE3",)] + + result = databases._get_selectai_profiles(mock_connection) + + assert result == ["PROFILE1", "PROFILE2", "PROFILE3"] + + @patch("server.api.utils.databases.execute_sql") + def test_get_selectai_profiles_empty(self, mock_execute_sql): + """Test SelectAI profiles retrieval with no profiles""" + mock_connection = MagicMock() + mock_execute_sql.return_value = [] + + result = databases._get_selectai_profiles(mock_connection) + + assert result == [] + + @patch("server.api.utils.databases.execute_sql") + def test_get_selectai_profiles_none_result(self, mock_execute_sql): + """Test SelectAI profiles retrieval with None results""" + mock_connection = MagicMock() + mock_execute_sql.return_value = None + + result = databases._get_selectai_profiles(mock_connection) + + assert result == [] + + +class TestDatabaseUtilsPublicFunctions: + """Test public utility functions""" + + def setup_method(self): + """Setup test data""" + self.sample_database = Database( + name="test_db", + user=TEST_CONFIG["db_username"], + password=TEST_CONFIG["db_password"], + dsn=TEST_CONFIG["db_dsn"] + ) + + def test_connect_success_with_real_database(self, db_container): + """Test successful database connection with real database""" + assert db_container is not None + result = databases.connect(self.sample_database) + + try: + assert result is not None + assert isinstance(result, oracledb.Connection) + # Test that connection is active + result.ping() + finally: + databases.disconnect(result) + + def test_connect_missing_user(self): + """Test connection with missing user""" + incomplete_db = Database( + name="test_db", + user="", # Missing user + password=TEST_CONFIG["db_password"], + dsn=TEST_CONFIG["db_dsn"], + ) + + with pytest.raises(ValueError, match="missing connection details"): + databases.connect(incomplete_db) + + def test_connect_missing_password(self): + """Test connection with missing password""" + incomplete_db = Database( + name="test_db", + 
+            user=TEST_CONFIG["db_username"],
+            password="",  # Missing password
+            dsn=TEST_CONFIG["db_dsn"],
+        )
+
+        with pytest.raises(ValueError, match="missing connection details"):
+            databases.connect(incomplete_db)
+
+    def test_connect_missing_dsn(self):
+        """Test connection with missing DSN"""
+        incomplete_db = Database(
+            name="test_db",
+            user=TEST_CONFIG["db_username"],
+            password=TEST_CONFIG["db_password"],
+            dsn="",  # Missing DSN
+        )
+
+        with pytest.raises(ValueError, match="missing connection details"):
+            databases.connect(incomplete_db)
+
+    def test_connect_with_wallet_configuration(self, db_container):
+        """Test connection with wallet configuration"""
+        assert db_container is not None
+        db_with_wallet = Database(
+            name="test_db",
+            user=TEST_CONFIG["db_username"],
+            password=TEST_CONFIG["db_password"],
+            dsn=TEST_CONFIG["db_dsn"],
+            wallet_password="wallet_pass",
+            config_dir="/path/to/config",
+        )
+
+        # This should attempt to connect but may fail due to wallet config
+        # The test verifies the code path works, not necessarily successful connection
+        try:
+            result = databases.connect(db_with_wallet)
+            databases.disconnect(result)
+        except oracledb.DatabaseError:
+            # Expected if wallet doesn't exist
+            pass
+
+    def test_connect_wallet_password_without_location(self, db_container):
+        """Test connection with wallet password but no location"""
+        assert db_container is not None
+        db_with_wallet = Database(
+            name="test_db",
+            user=TEST_CONFIG["db_username"],
+            password=TEST_CONFIG["db_password"],
+            dsn=TEST_CONFIG["db_dsn"],
+            wallet_password="wallet_pass",
+            config_dir="/default/config",
+        )
+
+        # This should set wallet_location to config_dir
+        try:
+            result = databases.connect(db_with_wallet)
+            databases.disconnect(result)
+        except oracledb.DatabaseError:
+            # Expected if wallet doesn't exist
+            pass
+
+    def test_connect_invalid_credentials(self, db_container):
+        """Test connection with invalid credentials"""
+        assert db_container is not None
+        invalid_db = Database(
+            name="test_db",
+            user="invalid_user",
+            password="invalid_password",
+            dsn=TEST_CONFIG["db_dsn"],
+        )
+
+        with pytest.raises(PermissionError):
+            databases.connect(invalid_db)
+
+    def test_connect_invalid_dsn(self, db_container):
+        """Test connection with invalid DSN"""
+        assert db_container is not None
+        invalid_db = Database(
+            name="test_db",
+            user=TEST_CONFIG["db_username"],
+            password=TEST_CONFIG["db_password"],
+            dsn="//invalid:1521/INVALID",
+        )
+
+        # DNS resolution failures vary by environment (socket.gaierror wrapped
+        # in oracledb.DatabaseError), so catch broadly
+        with pytest.raises(Exception):
+            databases.connect(invalid_db)
+
+    def test_disconnect_success(self, db_container):
+        """Test successful database disconnection"""
+        assert db_container is not None
+        conn = databases.connect(self.sample_database)
+
+        result = databases.disconnect(conn)
+
+        assert result is None
+        # Try to use connection after disconnect - should fail
+        with pytest.raises(oracledb.InterfaceError):
+            conn.ping()
+
+    def test_execute_sql_success_with_real_database(self, db_container):
+        """Test successful SQL execution with real database"""
+        assert db_container is not None
+        conn = databases.connect(self.sample_database)
+
+        try:
+            # Test simple query
+            result = databases.execute_sql(conn, "SELECT 1 FROM DUAL")
+            assert result is not None
+            assert len(result) == 1
+            assert result[0][0] == 1
+        finally:
+            databases.disconnect(conn)
+
+    def test_execute_sql_with_binds(self, db_container):
+        """Test SQL execution with bind variables using real database"""
+        assert db_container is not None
+        conn = databases.connect(self.sample_database)
+
+        try:
+            binds = {"test_value": 42}
+            result = databases.execute_sql(conn, "SELECT :test_value FROM DUAL", binds)
+            assert result is not None
+            assert len(result) == 1
+            assert result[0][0] == 42
+        finally:
+            databases.disconnect(conn)
+
+    def test_execute_sql_no_rows(self, db_container):
+        """Test SQL execution that returns no rows"""
+        assert db_container is not None
+        conn = databases.connect(self.sample_database)
+
+        try:
+            # Test query with no results
+            result = databases.execute_sql(conn, "SELECT 1 FROM DUAL WHERE 1=0")
+            assert result == []
+        finally:
+            databases.disconnect(conn)
+
+    def test_execute_sql_ddl_statement(self, db_container):
+        """Test SQL execution with DDL statement"""
+        assert db_container is not None
+        conn = databases.connect(self.sample_database)
+
+        try:
+            # Create a test table
+            databases.execute_sql(conn, "CREATE TABLE test_temp (id NUMBER)")
+
+            # Drop the test table
+            result = databases.execute_sql(conn, "DROP TABLE test_temp")
+            # DDL statements typically return None
+            assert result is None
+        except oracledb.DatabaseError as e:
+            # If table already exists or other DDL error, that's okay for testing
+            if "name is already used" not in str(e):
+                raise
+        finally:
+            # Clean up if table still exists
+            try:
+                databases.execute_sql(conn, "DROP TABLE test_temp")
+            except oracledb.DatabaseError:
+                pass  # Table doesn't exist, which is fine
+            databases.disconnect(conn)
+
+    def test_execute_sql_table_exists_error(self, db_container):
+        """Test SQL execution with table exists error (ORA-00955)"""
+        assert db_container is not None
+        conn = databases.connect(self.sample_database)
+
+        try:
+            # Create table twice to trigger ORA-00955
+            databases.execute_sql(conn, "CREATE TABLE test_exists (id NUMBER)")
+
+            # This should log but not raise an exception
+            databases.execute_sql(conn, "CREATE TABLE test_exists (id NUMBER)")
+
+        except oracledb.DatabaseError:
+            # Expected behavior - the function should handle this gracefully
+            pass
+        finally:
+            try:
+                databases.execute_sql(conn, "DROP TABLE test_exists")
+            except oracledb.DatabaseError:
+                pass
+            databases.disconnect(conn)
+
+    def test_execute_sql_table_not_exists_error(self, db_container):
+        """Test SQL execution with table not exists error (ORA-00942)"""
+        assert db_container is not None
+        conn = databases.connect(self.sample_database)
+
+        try:
+            # Try to select from non-existent table to trigger ORA-00942
+            databases.execute_sql(conn, "SELECT * FROM non_existent_table")
+        except oracledb.DatabaseError:
+            # Expected behavior - the function should handle this gracefully
+            pass
+        finally:
+            databases.disconnect(conn)
+
+    def test_execute_sql_invalid_syntax(self, db_container):
+        """Test SQL execution with invalid syntax"""
+        assert db_container is not None
+        conn = databases.connect(self.sample_database)
+
+        try:
+            with pytest.raises(oracledb.DatabaseError):
+                databases.execute_sql(conn, "INVALID SQL STATEMENT")
+        finally:
+            databases.disconnect(conn)
+
+    def test_drop_vs_function_exists(self):
+        """Test that drop_vs function exists and is callable"""
+        assert hasattr(databases, "drop_vs")
+        assert callable(databases.drop_vs)
+
+    @patch("langchain_community.vectorstores.oraclevs.drop_table_purge")
+    def test_drop_vs_calls_langchain(self, mock_drop_table):
+        """Test drop_vs calls LangChain drop_table_purge"""
+        mock_connection = MagicMock()
+        vs_name = "TEST_VECTOR_STORE"
+
+        databases.drop_vs(mock_connection, vs_name)
+
+        mock_drop_table.assert_called_once_with(mock_connection, vs_name)
+
+    def test_get_databases_without_validation(self, db_container):
+        """Test get_databases without validation"""
+        assert db_container is not None
+        # Use real bootstrap DATABASE_OBJECTS
+        original_db_objects = bootstrap.DATABASE_OBJECTS.copy()
+
+        try:
+            bootstrap.DATABASE_OBJECTS.clear()
+            bootstrap.DATABASE_OBJECTS.append(self.sample_database)
+
+            # Test getting all databases
+            result = databases.get_databases()
+            assert isinstance(result, list)
+            assert len(result) == 1
+            assert result[0].name == "test_db"
+            assert result[0].connected is False  # No validation, so not connected
+
+        finally:
+            bootstrap.DATABASE_OBJECTS.clear()
+            bootstrap.DATABASE_OBJECTS.extend(original_db_objects)
+
+    def test_get_databases_with_validation(self, db_container):
+        """Test get_databases with validation using real database"""
+        assert db_container is not None
+        # Use real bootstrap DATABASE_OBJECTS
+        original_db_objects = bootstrap.DATABASE_OBJECTS.copy()
+
+        try:
+            bootstrap.DATABASE_OBJECTS.clear()
+            bootstrap.DATABASE_OBJECTS.append(self.sample_database)
+
+            # Test getting all databases with validation
+            result = databases.get_databases(validate=True)
+            assert isinstance(result, list)
+            assert len(result) == 1
+            assert result[0].name == "test_db"
+            assert result[0].connected is True  # Validation should connect
+            assert result[0].connection is not None
+
+        finally:
+            # Clean up connections
+            for db in bootstrap.DATABASE_OBJECTS:
+                if db.connection:
+                    databases.disconnect(db.connection)
+            bootstrap.DATABASE_OBJECTS.clear()
+            bootstrap.DATABASE_OBJECTS.extend(original_db_objects)
+
+    def test_get_databases_by_name(self, db_container):
+        """Test get_databases by specific name"""
+        assert db_container is not None
+        # Use real bootstrap DATABASE_OBJECTS
+        original_db_objects = bootstrap.DATABASE_OBJECTS.copy()
+
+        try:
+            bootstrap.DATABASE_OBJECTS.clear()
+            db1 = Database(name="db1", user="user1", password="pass1", dsn="dsn1")
+            db2 = Database(name="db2", user=TEST_CONFIG["db_username"],
+                           password=TEST_CONFIG["db_password"], dsn=TEST_CONFIG["db_dsn"])
+            bootstrap.DATABASE_OBJECTS.extend([db1, db2])
+
+            # Test getting specific database
+            result = databases.get_databases(db_name="db2")
+            assert isinstance(result, Database)  # Single database, not list
+            assert result.name == "db2"
+
+        finally:
+            bootstrap.DATABASE_OBJECTS.clear()
+            bootstrap.DATABASE_OBJECTS.extend(original_db_objects)
+
+    def test_get_databases_validation_failure(self, db_container):
+        """Test get_databases with validation when connection fails"""
+        assert db_container is not None
+        # Use real bootstrap DATABASE_OBJECTS
+        original_db_objects = bootstrap.DATABASE_OBJECTS.copy()
+
+        try:
+            bootstrap.DATABASE_OBJECTS.clear()
+            # Add database with invalid credentials
+            invalid_db = Database(name="invalid", user="invalid", password="invalid", dsn="invalid")
+            bootstrap.DATABASE_OBJECTS.append(invalid_db)
+
+            # Test validation with invalid database (should continue without error)
+            result = databases.get_databases(validate=True)
+            assert isinstance(result, list)
+            assert len(result) == 1
+            assert result[0].connected is False  # Should remain False due to connection failure
+
+        finally:
+            bootstrap.DATABASE_OBJECTS.clear()
+            bootstrap.DATABASE_OBJECTS.extend(original_db_objects)
+
+    @patch("server.api.core.settings.get_client_settings")
+    def test_get_client_database_default(self, mock_get_settings, db_container):
+        """Test get_client_database with default settings"""
+        assert db_container is not None
+        # Mock client settings without vector_search or selectai
+        mock_settings = MagicMock()
+        mock_settings.vector_search = None
+        mock_settings.selectai = None
+        mock_get_settings.return_value = mock_settings
+
+        # Use real bootstrap DATABASE_OBJECTS
+        original_db_objects = bootstrap.DATABASE_OBJECTS.copy()
+
+        try:
+            bootstrap.DATABASE_OBJECTS.clear()
+            default_db = Database(name="DEFAULT", user=TEST_CONFIG["db_username"],
+                                  password=TEST_CONFIG["db_password"], dsn=TEST_CONFIG["db_dsn"])
+            bootstrap.DATABASE_OBJECTS.append(default_db)
+
+            result = databases.get_client_database("test_client")
+            assert isinstance(result, Database)
+            assert result.name == "DEFAULT"
+
+        finally:
+            bootstrap.DATABASE_OBJECTS.clear()
+            bootstrap.DATABASE_OBJECTS.extend(original_db_objects)
+
+    @patch("server.api.core.settings.get_client_settings")
+    def test_get_client_database_with_vector_search(self, mock_get_settings, db_container):
+        """Test get_client_database with vector_search settings"""
+        assert db_container is not None
+        # Mock client settings with vector_search
+        mock_vector_search = MagicMock()
+        mock_vector_search.database = "VECTOR_DB"
+        mock_settings = MagicMock()
+        mock_settings.vector_search = mock_vector_search
+        mock_settings.selectai = None
+        mock_get_settings.return_value = mock_settings
+
+        # Use real bootstrap DATABASE_OBJECTS
+        original_db_objects = bootstrap.DATABASE_OBJECTS.copy()
+
+        try:
+            bootstrap.DATABASE_OBJECTS.clear()
+            vector_db = Database(name="VECTOR_DB", user=TEST_CONFIG["db_username"],
+                                 password=TEST_CONFIG["db_password"], dsn=TEST_CONFIG["db_dsn"])
+            bootstrap.DATABASE_OBJECTS.append(vector_db)
+
+            result = databases.get_client_database("test_client")
+            assert isinstance(result, Database)
+            assert result.name == "VECTOR_DB"
+
+        finally:
+            bootstrap.DATABASE_OBJECTS.clear()
+            bootstrap.DATABASE_OBJECTS.extend(original_db_objects)
+
+    @patch("server.api.core.settings.get_client_settings")
+    def test_get_client_database_with_validation(self, mock_get_settings, db_container):
+        """Test get_client_database with validation enabled"""
+        assert db_container is not None
+        # Mock client settings
+        mock_settings = MagicMock()
+        mock_settings.vector_search = None
+        mock_settings.selectai = None
+        mock_get_settings.return_value = mock_settings
+
+        # Use real bootstrap DATABASE_OBJECTS
+        original_db_objects = bootstrap.DATABASE_OBJECTS.copy()
+
+        try:
+            bootstrap.DATABASE_OBJECTS.clear()
+            default_db = Database(name="DEFAULT", user=TEST_CONFIG["db_username"],
+                                  password=TEST_CONFIG["db_password"], dsn=TEST_CONFIG["db_dsn"])
+            bootstrap.DATABASE_OBJECTS.append(default_db)
+
+            result = databases.get_client_database("test_client", validate=True)
+            assert isinstance(result, Database)
+            assert result.name == "DEFAULT"
+            assert result.connected is True
+            assert result.connection is not None
+
+        finally:
+            # Clean up connections
+            for db in bootstrap.DATABASE_OBJECTS:
+                if db.connection:
+                    databases.disconnect(db.connection)
+            bootstrap.DATABASE_OBJECTS.clear()
+            bootstrap.DATABASE_OBJECTS.extend(original_db_objects)
+
+    def test_logger_exists(self):
+        """Test that logger is properly configured"""
+        assert hasattr(databases, "logger")
+        assert databases.logger.name == "api.utils.database"
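Taken together, these tests pin down a small contract for `server.api.utils.databases`: incomplete connection details raise `ValueError`, bad credentials raise `PermissionError`, and every connection is handed back through `disconnect`. A minimal sketch of that contract follows; the user, password, and DSN are placeholders, and only the function names and exception types come from the tests above.

```python
# Minimal sketch of the databases utility contract exercised by the tests above.
# Credentials and DSN are placeholders, not real configuration.
from common.schema import Database
from server.api.utils import databases

db = Database(name="DEFAULT", user="app_user", password="app_pass",
              dsn="//db-host:1521/FREEPDB1")

conn = databases.connect(db)  # ValueError if details missing, PermissionError if rejected
try:
    rows = databases.execute_sql(conn, "SELECT :v FROM DUAL", {"v": 1})
    assert rows[0][0] == 1
finally:
    databases.disconnect(conn)
```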
diff --git a/tests/unit/server/api/utils/test_utils_embed.py b/tests/unit/server/api/utils/test_utils_embed.py
new file mode 100644
index 00000000..08364528
--- /dev/null
+++ b/tests/unit/server/api/utils/test_utils_embed.py
@@ -0,0 +1,84 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker: disable
+
+from pathlib import Path
+from unittest.mock import patch, mock_open
+
+from langchain.docstore.document import Document as LangchainDocument
+
+from server.api.utils import embed
+
+
+class TestEmbedUtils:
+    """Test embed utility functions"""
+
+    def setup_method(self):
+        """Setup test data"""
+        self.sample_document = LangchainDocument(
+            page_content="This is a test document content.", metadata={"source": "/path/to/test_file.txt", "page": 1}
+        )
+        self.sample_split_doc = LangchainDocument(
+            page_content="This is a chunk of content.", metadata={"source": "/path/to/test_file.txt", "start_index": 0}
+        )
+
+    @patch("pathlib.Path.exists")
+    @patch("pathlib.Path.is_dir")
+    @patch("pathlib.Path.mkdir")
+    def test_get_temp_directory_app_tmp(self, mock_mkdir, mock_is_dir, mock_exists):
+        """Test temp directory creation in /app/tmp"""
+        mock_exists.return_value = True
+        mock_is_dir.return_value = True
+
+        result = embed.get_temp_directory("test_client", "embed")
+
+        assert result == Path("/app/tmp") / "test_client" / "embed"
+        mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)
+
+    @patch("pathlib.Path.exists")
+    @patch("pathlib.Path.mkdir")
+    def test_get_temp_directory_tmp_fallback(self, mock_mkdir, mock_exists):
+        """Test temp directory creation fallback to /tmp"""
+        mock_exists.return_value = False
+
+        result = embed.get_temp_directory("test_client", "embed")
+
+        assert result == Path("/tmp") / "test_client" / "embed"
+        mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)
+
+    @patch("builtins.open", new_callable=mock_open)
+    @patch("os.path.getsize")
+    @patch("json.dumps")
+    def test_doc_to_json_default_output(self, mock_json_dumps, mock_getsize, mock_file):
+        """Test document to JSON conversion with default output directory"""
+        mock_json_dumps.return_value = '{"test": "data"}'
+        mock_getsize.return_value = 100
+
+        result = embed.doc_to_json([self.sample_document], "/path/to/test_file.txt", "/tmp")
+
+        mock_file.assert_called_once()
+        mock_json_dumps.assert_called_once()
+        mock_getsize.assert_called_once()
+        assert result.endswith("_test_file.json")
+
+    @patch("builtins.open", new_callable=mock_open)
+    @patch("os.path.getsize")
+    @patch("json.dumps")
+    def test_doc_to_json_custom_output(self, mock_json_dumps, mock_getsize, mock_file):
+        """Test document to JSON conversion with custom output directory"""
+        mock_json_dumps.return_value = '{"test": "data"}'
+        mock_getsize.return_value = 100
+
+        result = embed.doc_to_json([self.sample_document], "/path/to/test_file.txt", "/custom/output")
+
+        mock_file.assert_called_once()
+        mock_json_dumps.assert_called_once()
+        mock_getsize.assert_called_once()
+        assert result == "/custom/output/_test_file.json"
+
+    def test_logger_exists(self):
+        """Test that logger is properly configured"""
+        assert hasattr(embed, "logger")
+        assert embed.logger.name == "api.utils.embed"
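The embed tests describe two helpers: `get_temp_directory` builds a per-client working directory under `/app/tmp` when it exists, falling back to `/tmp`, and `doc_to_json` writes split documents to a JSON file whose name is derived from the source file. A minimal sketch of how they fit together; the client name, source path, and resulting filename here are placeholders inferred from the assertions above.

```python
# Illustrative sketch of the embed helpers, assuming the behavior asserted in
# the tests above; paths and names are placeholders.
from langchain.docstore.document import Document as LangchainDocument
from server.api.utils import embed

tmp_dir = embed.get_temp_directory("my_client", "embed")  # /app/tmp/... or /tmp/... fallback
doc = LangchainDocument(page_content="chunk text", metadata={"source": "/data/report.txt"})
json_path = embed.doc_to_json([doc], "/data/report.txt", str(tmp_dir))
print(json_path)  # ends in "_report.json", per the naming asserted in the tests
```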
+""" +# spell-checker: disable + +from unittest.mock import patch + +import pytest + +from server.api.utils import models +from server.api.core.models import UnknownModelError +from common.schema import Model, OracleCloudSettings + + +class TestModelsUtils: + """Test models utility functions""" + + def setup_method(self): + """Setup test data""" + self.sample_model = Model( + id="test-model", provider="openai", type="ll", enabled=True, api_base="https://api.openai.com" + ) + self.sample_oci_config = OracleCloudSettings( + auth_profile="DEFAULT", + compartment_id="ocid1.compartment.oc1..test", + genai_region="us-ashburn-1", + user="ocid1.user.oc1..testuser", + fingerprint="test-fingerprint", + tenancy="ocid1.tenancy.oc1..testtenant", + key_file="/path/to/key.pem", + ) + + @patch("server.api.core.models.get_model") + @patch("common.functions.is_url_accessible") + def test_update_success(self, mock_url_check, mock_get_model): + """Test successful model update""" + mock_get_model.return_value = self.sample_model + mock_url_check.return_value = (True, None) + + update_payload = Model( + id="test-model", + provider="openai", + type="ll", + enabled=True, + api_base="https://api.openai.com", + temperature=0.8, + ) + + result = models.update(update_payload) + + assert result.temperature == 0.8 + mock_get_model.assert_called_once_with(model_provider="openai", model_id="test-model") + + @patch("server.api.core.models.get_model") + def test_get_full_config_success(self, mock_get_model): + """Test successful full config retrieval""" + mock_get_model.return_value = self.sample_model + model_config = {"model": "openai/gpt-4", "temperature": 0.8} + + full_config, provider = models._get_full_config(model_config, self.sample_oci_config) + + assert provider == "openai" + assert full_config["temperature"] == 0.8 + assert full_config["id"] == "test-model" + mock_get_model.assert_called_once_with(model_provider="openai", model_id="gpt-4", include_disabled=False) + + @patch("server.api.core.models.get_model") + def test_get_full_config_unknown_model(self, mock_get_model): + """Test full config retrieval with unknown model""" + mock_get_model.side_effect = UnknownModelError("Model not found") + model_config = {"model": "unknown/model"} + + with pytest.raises(UnknownModelError): + models._get_full_config(model_config, self.sample_oci_config) + + @patch("server.api.utils.models._get_full_config") + @patch("litellm.get_supported_openai_params") + def test_get_litellm_config_basic(self, mock_get_params, mock_get_full_config): + """Test basic LiteLLM config generation""" + mock_get_full_config.return_value = ( + {"temperature": 0.7, "max_completion_tokens": 4096, "api_base": "https://api.openai.com"}, + "openai", + ) + mock_get_params.return_value = ["temperature", "max_completion_tokens"] + model_config = {"model": "openai/gpt-4"} + + result = models.get_litellm_config(model_config, self.sample_oci_config) + + assert result["model"] == "openai/gpt-4" + assert result["temperature"] == 0.7 + assert result["max_completion_tokens"] == 4096 + assert result["drop_params"] is True + + @patch("server.api.utils.models._get_full_config") + @patch("litellm.get_supported_openai_params") + def test_get_litellm_config_cohere(self, mock_get_params, mock_get_full_config): + """Test LiteLLM config generation for Cohere""" + mock_get_full_config.return_value = ({"api_base": "https://custom.cohere.com/v1"}, "cohere") + mock_get_params.return_value = [] + model_config = {"model": "cohere/command"} + + result = 
models.get_litellm_config(model_config, self.sample_oci_config) + + assert result["api_base"] == "https://api.cohere.ai/compatibility/v1" + assert result["model"] == "cohere/command" + + @patch("server.api.utils.models._get_full_config") + @patch("litellm.get_supported_openai_params") + def test_get_litellm_config_xai(self, mock_get_params, mock_get_full_config): + """Test LiteLLM config generation for xAI""" + mock_get_full_config.return_value = ( + {"temperature": 0.7, "presence_penalty": 0.1, "frequency_penalty": 0.1}, + "xai", + ) + mock_get_params.return_value = ["temperature", "presence_penalty", "frequency_penalty"] + model_config = {"model": "xai/grok"} + + result = models.get_litellm_config(model_config, self.sample_oci_config) + + assert result["temperature"] == 0.7 + assert "presence_penalty" not in result + assert "frequency_penalty" not in result + + @patch("server.api.utils.models._get_full_config") + @patch("litellm.get_supported_openai_params") + def test_get_litellm_config_oci(self, mock_get_params, mock_get_full_config): + """Test LiteLLM config generation for OCI""" + mock_get_full_config.return_value = ({"temperature": 0.7}, "oci") + mock_get_params.return_value = ["temperature"] + model_config = {"model": "oci/cohere.command"} + + result = models.get_litellm_config(model_config, self.sample_oci_config) + + assert result["oci_user"] == "ocid1.user.oc1..testuser" + assert result["oci_fingerprint"] == "test-fingerprint" + assert result["oci_tenancy"] == "ocid1.tenancy.oc1..testtenant" + assert result["oci_region"] == "us-ashburn-1" + assert result["oci_key_file"] == "/path/to/key.pem" + + @patch("server.api.utils.models._get_full_config") + @patch("litellm.get_supported_openai_params") + def test_get_litellm_config_giskard(self, mock_get_params, mock_get_full_config): + """Test LiteLLM config generation for Giskard""" + mock_get_full_config.return_value = ({"temperature": 0.7, "model": "test-model"}, "openai") + mock_get_params.return_value = ["temperature", "model"] + model_config = {"model": "openai/gpt-4"} + + result = models.get_litellm_config(model_config, self.sample_oci_config, giskard=True) + + assert "model" not in result + assert "temperature" not in result + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(models, "logger") + assert models.logger.name == "api.utils.models" diff --git a/tests/unit/server/api/utils/test_utils_oci.py b/tests/unit/server/api/utils/test_utils_oci.py new file mode 100644 index 00000000..8c2157be --- /dev/null +++ b/tests/unit/server/api/utils/test_utils_oci.py @@ -0,0 +1,115 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
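As the assertions above show, `get_litellm_config` merges the stored model record with per-request settings, keeps only the parameters LiteLLM reports as supported, sets `drop_params`, and injects provider-specific keys (for example the `oci_*` fields). A sketch of the likely downstream use; the minimal `OracleCloudSettings` construction and the `litellm.completion` call are assumptions, since the diff only shows the config builder itself.

```python
# Hedged sketch: how the generated config could feed litellm.completion.
# The one-field OracleCloudSettings and the completion call are assumptions.
import litellm
from common.schema import OracleCloudSettings
from server.api.utils import models

oci_config = OracleCloudSettings(auth_profile="DEFAULT")  # tests use a fuller profile
litellm_config = models.get_litellm_config({"model": "openai/gpt-4"}, oci_config)
# litellm_config carries "model", "drop_params", and only supported parameters
response = litellm.completion(messages=[{"role": "user", "content": "Hello"}], **litellm_config)
```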
+""" +# spell-checker: disable + +from unittest.mock import patch, MagicMock + +import pytest +import oci + +from server.api.utils import oci as oci_utils +from server.api.utils.oci import OciException +from common.schema import OracleCloudSettings + + +class TestOciException: + """Test custom OCI exception class""" + + def test_oci_exception_initialization(self): + """Test OciException initialization""" + exc = OciException(status_code=400, detail="Invalid configuration") + assert exc.status_code == 400 + assert exc.detail == "Invalid configuration" + assert str(exc) == "Invalid configuration" + + +class TestOciUtils: + """Test OCI utility functions""" + + def setup_method(self): + """Setup test data""" + self.sample_oci_config = OracleCloudSettings( + auth_profile="DEFAULT", + compartment_id="ocid1.compartment.oc1..test", + genai_region="us-ashburn-1", + user="ocid1.user.oc1..testuser", + fingerprint="test-fingerprint", + tenancy="ocid1.tenancy.oc1..testtenant", + key_file="/path/to/key.pem", + ) + + def test_init_genai_client(self): + """Test GenAI client initialization""" + with patch.object(oci_utils, "init_client") as mock_init_client: + mock_client = MagicMock() + mock_init_client.return_value = mock_client + + result = oci_utils.init_genai_client(self.sample_oci_config) + + assert result == mock_client + mock_init_client.assert_called_once_with( + oci.generative_ai_inference.GenerativeAiInferenceClient, self.sample_oci_config + ) + + @patch.object(oci_utils, "init_client") + def test_get_namespace_success(self, mock_init_client): + """Test successful namespace retrieval""" + mock_client = MagicMock() + mock_client.get_namespace.return_value.data = "test-namespace" + mock_init_client.return_value = mock_client + + result = oci_utils.get_namespace(self.sample_oci_config) + + assert result == "test-namespace" + assert self.sample_oci_config.namespace == "test-namespace" + + @patch.object(oci_utils, "init_client") + def test_get_namespace_invalid_config(self, mock_init_client): + """Test namespace retrieval with invalid config""" + mock_client = MagicMock() + mock_client.get_namespace.side_effect = oci.exceptions.InvalidConfig("Invalid config") + mock_init_client.return_value = mock_client + + with pytest.raises(OciException) as exc_info: + oci_utils.get_namespace(self.sample_oci_config) + + assert exc_info.value.status_code == 400 + assert "Invalid Config" in str(exc_info.value) + + @patch.object(oci_utils, "init_client") + def test_get_namespace_file_not_found(self, mock_init_client): + """Test namespace retrieval with file not found error""" + mock_init_client.side_effect = FileNotFoundError("Key file not found") + + with pytest.raises(OciException) as exc_info: + oci_utils.get_namespace(self.sample_oci_config) + + assert exc_info.value.status_code == 400 + assert "Invalid Key Path" in str(exc_info.value) + + @patch.object(oci_utils, "init_client") + def test_get_regions_success(self, mock_init_client): + """Test successful regions retrieval""" + mock_client = MagicMock() + mock_region = MagicMock() + mock_region.is_home_region = True + mock_region.region_key = "IAD" + mock_region.region_name = "us-ashburn-1" + mock_region.status = "READY" + mock_client.list_region_subscriptions.return_value.data = [mock_region] + mock_init_client.return_value = mock_client + + result = oci_utils.get_regions(self.sample_oci_config) + + assert len(result) == 1 + assert result[0]["is_home_region"] is True + assert result[0]["region_key"] == "IAD" + assert result[0]["region_name"] == "us-ashburn-1" + 
assert result[0]["status"] == "READY" + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(oci_utils, "logger") + assert oci_utils.logger.name == "api.utils.oci" diff --git a/tests/unit/server/api/utils/test_utils_testbed.py b/tests/unit/server/api/utils/test_utils_testbed.py new file mode 100644 index 00000000..828ebf8c --- /dev/null +++ b/tests/unit/server/api/utils/test_utils_testbed.py @@ -0,0 +1,93 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. +""" +# spell-checker: disable + +from unittest.mock import patch, MagicMock +import json + +import pytest +from oracledb import Connection + +from server.api.utils import testbed + + +class TestTestbedUtils: + """Test testbed utility functions""" + + def setup_method(self): + """Setup test data""" + self.mock_connection = MagicMock(spec=Connection) + self.sample_qa_data = { + "question": "What is the capital of France?", + "answer": "Paris", + "context": "France is a country in Europe.", + } + + def test_jsonl_to_json_content_single_json(self): + """Test converting single JSON object to JSON content""" + content = '{"key": "value"}' + result = testbed.jsonl_to_json_content(content) + expected = json.dumps({"key": "value"}) + assert result == expected + + def test_jsonl_to_json_content_jsonl_multiple_lines(self): + """Test converting JSONL with multiple lines to JSON content""" + content = '{"line": 1}\n{"line": 2}\n{"line": 3}' + result = testbed.jsonl_to_json_content(content) + expected = json.dumps([{"line": 1}, {"line": 2}, {"line": 3}]) + assert result == expected + + def test_jsonl_to_json_content_jsonl_single_line(self): + """Test converting JSONL with single line to JSON content""" + content = '{"single": "line"}' + result = testbed.jsonl_to_json_content(content) + expected = json.dumps({"single": "line"}) + assert result == expected + + def test_jsonl_to_json_content_bytes_input(self): + """Test converting bytes JSONL content to JSON""" + content = b'{"bytes": "content"}' + result = testbed.jsonl_to_json_content(content) + expected = json.dumps({"bytes": "content"}) + assert result == expected + + def test_jsonl_to_json_content_invalid_json(self): + """Test handling invalid JSON content""" + content = '{"invalid": json}' + with pytest.raises(ValueError, match="Invalid JSONL content"): + testbed.jsonl_to_json_content(content) + + def test_jsonl_to_json_content_empty_content(self): + """Test handling empty content""" + content = "" + with pytest.raises(ValueError, match="Invalid JSONL content"): + testbed.jsonl_to_json_content(content) + + def test_jsonl_to_json_content_whitespace_content(self): + """Test handling whitespace-only content""" + content = " \n \n " + with pytest.raises(ValueError, match="Invalid JSONL content"): + testbed.jsonl_to_json_content(content) + + @patch("server.api.utils.databases.execute_sql") + def test_create_testset_objects(self, mock_execute_sql): + """Test creating testset database objects""" + mock_execute_sql.return_value = [] + + testbed.create_testset_objects(self.mock_connection) + + # Should execute 3 SQL statements (testsets, testset_qa, evaluations tables) + assert mock_execute_sql.call_count == 3 + + # Verify table creation statements + call_args_list = mock_execute_sql.call_args_list + assert "oai_testsets" in call_args_list[0][0][1] + assert "oai_testset_qa" in call_args_list[1][0][1] + assert "oai_evaluations" in 
diff --git a/tests/unit/server/api/utils/test_utils_testbed.py b/tests/unit/server/api/utils/test_utils_testbed.py
new file mode 100644
index 00000000..828ebf8c
--- /dev/null
+++ b/tests/unit/server/api/utils/test_utils_testbed.py
@@ -0,0 +1,93 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker: disable
+
+from unittest.mock import patch, MagicMock
+import json
+
+import pytest
+from oracledb import Connection
+
+from server.api.utils import testbed
+
+
+class TestTestbedUtils:
+    """Test testbed utility functions"""
+
+    def setup_method(self):
+        """Setup test data"""
+        self.mock_connection = MagicMock(spec=Connection)
+        self.sample_qa_data = {
+            "question": "What is the capital of France?",
+            "answer": "Paris",
+            "context": "France is a country in Europe.",
+        }
+
+    def test_jsonl_to_json_content_single_json(self):
+        """Test converting single JSON object to JSON content"""
+        content = '{"key": "value"}'
+        result = testbed.jsonl_to_json_content(content)
+        expected = json.dumps({"key": "value"})
+        assert result == expected
+
+    def test_jsonl_to_json_content_jsonl_multiple_lines(self):
+        """Test converting JSONL with multiple lines to JSON content"""
+        content = '{"line": 1}\n{"line": 2}\n{"line": 3}'
+        result = testbed.jsonl_to_json_content(content)
+        expected = json.dumps([{"line": 1}, {"line": 2}, {"line": 3}])
+        assert result == expected
+
+    def test_jsonl_to_json_content_jsonl_single_line(self):
+        """Test converting JSONL with single line to JSON content"""
+        content = '{"single": "line"}'
+        result = testbed.jsonl_to_json_content(content)
+        expected = json.dumps({"single": "line"})
+        assert result == expected
+
+    def test_jsonl_to_json_content_bytes_input(self):
+        """Test converting bytes JSONL content to JSON"""
+        content = b'{"bytes": "content"}'
+        result = testbed.jsonl_to_json_content(content)
+        expected = json.dumps({"bytes": "content"})
+        assert result == expected
+
+    def test_jsonl_to_json_content_invalid_json(self):
+        """Test handling invalid JSON content"""
+        content = '{"invalid": json}'
+        with pytest.raises(ValueError, match="Invalid JSONL content"):
+            testbed.jsonl_to_json_content(content)
+
+    def test_jsonl_to_json_content_empty_content(self):
+        """Test handling empty content"""
+        content = ""
+        with pytest.raises(ValueError, match="Invalid JSONL content"):
+            testbed.jsonl_to_json_content(content)
+
+    def test_jsonl_to_json_content_whitespace_content(self):
+        """Test handling whitespace-only content"""
+        content = " \n \n "
+        with pytest.raises(ValueError, match="Invalid JSONL content"):
+            testbed.jsonl_to_json_content(content)
+
+    @patch("server.api.utils.databases.execute_sql")
+    def test_create_testset_objects(self, mock_execute_sql):
+        """Test creating testset database objects"""
+        mock_execute_sql.return_value = []
+
+        testbed.create_testset_objects(self.mock_connection)
+
+        # Should execute 3 SQL statements (testsets, testset_qa, evaluations tables)
+        assert mock_execute_sql.call_count == 3
+
+        # Verify table creation statements
+        call_args_list = mock_execute_sql.call_args_list
+        assert "oai_testsets" in call_args_list[0][0][1]
+        assert "oai_testset_qa" in call_args_list[1][0][1]
+        assert "oai_evaluations" in call_args_list[2][0][1]
+
+    def test_logger_exists(self):
+        """Test that logger is properly configured"""
+        assert hasattr(testbed, "logger")
+        assert testbed.logger.name == "api.utils.testbed"
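The testbed cases fix the semantics of `jsonl_to_json_content`: one JSONL line round-trips to a JSON object, several lines become a JSON array, bytes input is accepted, and anything unparseable raises `ValueError("Invalid JSONL content")`. A short recap of that behavior, restated directly from the assertions above.

```python
# Recap of the jsonl_to_json_content semantics asserted in the tests above.
from server.api.utils import testbed

assert testbed.jsonl_to_json_content('{"a": 1}') == '{"a": 1}'                    # object
assert testbed.jsonl_to_json_content('{"a": 1}\n{"a": 2}') == '[{"a": 1}, {"a": 2}]'  # array
try:
    testbed.jsonl_to_json_content("not json")
except ValueError as exc:
    print(exc)  # Invalid JSONL content
```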