|
4 | 4 | """ |
5 | 5 |
|
6 | 6 | import os |
7 | | -import openai |
8 | 7 | import sqlalchemy as sa |
9 | 8 |
|
10 | 9 | from dotenv import load_dotenv |
11 | | -from langchain_openai import AzureOpenAIEmbeddings |
12 | | -from langchain_openai import OpenAIEmbeddings |
13 | | -from llama_index.llms.azure_openai import AzureOpenAI |
14 | | -from llama_index.llms.openai import OpenAI |
15 | | -from llama_index.embeddings.langchain import LangchainEmbedding |
16 | 10 | from llama_index.core.utilities.sql_wrapper import SQLDatabase |
17 | 11 | from llama_index.core.query_engine import NLSQLTableQueryEngine |
18 | 12 | from llama_index.core import Settings |
19 | 13 |
|
20 | | - |
21 | | -def configure_llm(): |
22 | | - """ |
23 | | - Configure LLM. Use either vanilla Open AI, or Azure Open AI. |
24 | | - """ |
25 | | - |
26 | | - openai.api_type = os.getenv("OPENAI_API_TYPE") |
27 | | - openai.azure_endpoint = os.getenv("OPENAI_AZURE_ENDPOINT") |
28 | | - openai.api_version = os.getenv("OPENAI_AZURE_API_VERSION") |
29 | | - openai.api_key = os.getenv("OPENAI_API_KEY") |
30 | | - |
31 | | - if openai.api_type == "openai": |
32 | | - llm = OpenAI( |
33 | | - api_key=os.getenv("OPENAI_API_KEY"), |
34 | | - temperature=0.0 |
35 | | - ) |
36 | | - elif openai.api_type == "azure": |
37 | | - llm = AzureOpenAI( |
38 | | - engine=os.getenv("LLM_INSTANCE"), |
39 | | - azure_endpoint=os.getenv("OPENAI_AZURE_ENDPOINT"), |
40 | | - api_key = os.getenv("OPENAI_API_KEY"), |
41 | | - api_version = os.getenv("OPENAI_AZURE_API_VERSION"), |
42 | | - temperature=0.0 |
43 | | - ) |
44 | | - else: |
45 | | - raise ValueError(f"Open AI API type not defined or invalid: {openai.api_type}") |
46 | | - |
47 | | - Settings.llm = llm |
48 | | - if openai.api_type == "openai": |
49 | | - Settings.embed_model = LangchainEmbedding(OpenAIEmbeddings()) |
50 | | - elif openai.api_type == "azure": |
51 | | - Settings.embed_model = LangchainEmbedding( |
52 | | - AzureOpenAIEmbeddings( |
53 | | - azure_endpoint=os.getenv("OPENAI_AZURE_ENDPOINT"), |
54 | | - model=os.getenv("EMBEDDING_MODEL_INSTANCE") |
55 | | - ) |
56 | | - ) |
| 14 | +from boot import configure_llm |
57 | 15 |
|
58 | 16 |
|
59 | 17 | def main(): |
|
0 commit comments