Skip to content

Commit 6f7cd57

Browse files
committed
LlamaIndex: Make model configurable
1 parent 08e0167 commit 6f7cd57

File tree

3 files changed

+27
-14
lines changed

3 files changed

+27
-14
lines changed
Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
import os
2+
from typing import Tuple
23

34
import openai
45
from langchain_openai import AzureOpenAIEmbeddings
56
from langchain_openai import OpenAIEmbeddings
6-
from llama_index.core import Settings
7+
from llama_index.core.base.embeddings.base import BaseEmbedding
8+
from llama_index.core.llms import LLM
79
from llama_index.llms.azure_openai import AzureOpenAI
810
from llama_index.llms.openai import OpenAI
911
from llama_index.embeddings.langchain import LangchainEmbedding
1012

1113

12-
def configure_llm():
14+
MODEL_NAME = "gpt-4o"
15+
16+
17+
def configure_llm() -> Tuple[LLM, BaseEmbedding]:
1318
"""
1419
Configure LLM. Use either vanilla Open AI, or Azure Open AI.
1520
"""
@@ -21,27 +26,32 @@ def configure_llm():
2126

2227
if openai.api_type == "openai":
2328
llm = OpenAI(
29+
model=MODEL_NAME,
30+
temperature=0.0,
2431
api_key=os.getenv("OPENAI_API_KEY"),
25-
temperature=0.0
2632
)
2733
elif openai.api_type == "azure":
2834
llm = AzureOpenAI(
35+
model=MODEL_NAME,
36+
temperature=0.0,
2937
engine=os.getenv("LLM_INSTANCE"),
3038
azure_endpoint=os.getenv("OPENAI_AZURE_ENDPOINT"),
3139
api_key = os.getenv("OPENAI_API_KEY"),
3240
api_version = os.getenv("OPENAI_AZURE_API_VERSION"),
33-
temperature=0.0
3441
)
3542
else:
3643
raise ValueError(f"Open AI API type not defined or invalid: {openai.api_type}")
3744

38-
Settings.llm = llm
3945
if openai.api_type == "openai":
40-
Settings.embed_model = LangchainEmbedding(OpenAIEmbeddings())
46+
embed_model = LangchainEmbedding(OpenAIEmbeddings(model=MODEL_NAME))
4147
elif openai.api_type == "azure":
42-
Settings.embed_model = LangchainEmbedding(
48+
embed_model = LangchainEmbedding(
4349
AzureOpenAIEmbeddings(
4450
azure_endpoint=os.getenv("OPENAI_AZURE_ENDPOINT"),
4551
model=os.getenv("EMBEDDING_MODEL_INSTANCE")
4652
)
4753
)
54+
else:
55+
embed_model = None
56+
57+
return llm, embed_model

topic/machine-learning/llama-index/demo_mcp.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,17 @@
2424

2525
from dotenv import load_dotenv
2626
from llama_index.core.agent.workflow import FunctionAgent
27-
from llama_index.llms.openai import OpenAI
27+
from llama_index.core.llms import LLM
2828
from llama_index.tools.mcp import BasicMCPClient, McpToolSpec
2929

3030
from boot import configure_llm
3131

3232

3333
class Agent:
3434

35+
def __init__(self, llm: LLM):
36+
self.llm = llm
37+
3538
async def get_tools(self):
3639
# Connect to the CrateDB MCP server using `streamable-http` transport.
3740
mcp_url = os.getenv("CRATEDB_MCP_URL", "http://127.0.0.1:8000/mcp/")
@@ -49,7 +52,7 @@ async def get_agent(self):
4952
return FunctionAgent(
5053
name="Agent",
5154
description="CrateDB text-to-SQL agent",
52-
llm=OpenAI(model="gpt-4o"),
55+
llm=self.llm,
5356
tools=await self.get_tools(),
5457
system_prompt=Instructions.full(),
5558
)
@@ -69,10 +72,10 @@ def main():
6972

7073
# Configure application.
7174
load_dotenv()
72-
configure_llm()
75+
llm, embed_model = configure_llm()
7376

7477
# Use an agent that uses the CrateDB MCP server.
75-
agent = Agent()
78+
agent = Agent(llm)
7679

7780
# Invoke an inquiry.
7881
print("Running query")

topic/machine-learning/llama-index/demo_nlsql.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from dotenv import load_dotenv
1010
from llama_index.core.utilities.sql_wrapper import SQLDatabase
1111
from llama_index.core.query_engine import NLSQLTableQueryEngine
12-
from llama_index.core import Settings
1312

1413
from boot import configure_llm
1514

@@ -21,7 +20,7 @@ def main():
2120

2221
# Configure application.
2322
load_dotenv()
24-
configure_llm()
23+
llm, embed_model = configure_llm()
2524

2625
# Configure database connection and query engine.
2726
print("Connecting to CrateDB")
@@ -33,7 +32,8 @@ def main():
3332
query_engine = NLSQLTableQueryEngine(
3433
sql_database=sql_database,
3534
tables=[os.getenv("CRATEDB_TABLE_NAME")],
36-
llm=Settings.llm
35+
llm=llm,
36+
embed_model=embed_model,
3737
)
3838

3939
# Invoke an inquiry.

0 commit comments

Comments
 (0)