diff --git a/doccano_mini/components.py b/doccano_mini/components.py
index 3bbb601..06d63a0 100644
--- a/doccano_mini/components.py
+++ b/doccano_mini/components.py
@@ -1,4 +1,8 @@
+import os
+
 import streamlit as st
+from langchain.llms import OpenAI
+from langchain.schema import BaseLanguageModel
 
 CODE = """from langchain.chains import load_chain
 
@@ -19,3 +23,19 @@ def display_download_button():
 def display_usage():
     st.header("Usage")
     st.code(CODE)
+
+
+def openai_model_form() -> BaseLanguageModel:
+    # https://platform.openai.com/docs/models/gpt-3-5
+    AVAILABLE_MODELS = (
+        "gpt-3.5-turbo",
+        "gpt-3.5-turbo-0301",
+        "text-davinci-003",
+        "text-davinci-002",
+        "code-davinci-002",
+    )
+    api_key = st.text_input("API key", value=os.environ.get("OPENAI_API_KEY", ""), type="password")
+    model_name = st.selectbox("Model", AVAILABLE_MODELS, index=2)
+    temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.7, step=0.01)
+    top_p = st.slider("Top-p", min_value=0.0, max_value=1.0, value=1.0, step=0.01)
+    return OpenAI(model_name=model_name, temperature=temperature, top_p=top_p, openai_api_key=api_key)  # type:ignore
diff --git a/doccano_mini/models.py b/doccano_mini/models.py
deleted file mode 100644
index 806b101..0000000
--- a/doccano_mini/models.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# https://platform.openai.com/docs/models/gpt-3-5
-AVAILABLE_MODELS = (
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-0301",
-    "text-davinci-003",
-    "text-davinci-002",
-    "code-davinci-002",
-)
diff --git a/doccano_mini/pages/01_Text_Classification.py b/doccano_mini/pages/01_Text_Classification.py
index 71285f2..ad09d76 100644
--- a/doccano_mini/pages/01_Text_Classification.py
+++ b/doccano_mini/pages/01_Text_Classification.py
@@ -1,12 +1,12 @@
-import os
-
 import streamlit as st
 from langchain.chains import LLMChain
-from langchain.llms import OpenAI
 
-from doccano_mini.components import display_download_button, display_usage
+from doccano_mini.components import (
+    display_download_button,
+    display_usage,
+    openai_model_form,
+)
 from doccano_mini.examples import make_classification_example
-from doccano_mini.models import AVAILABLE_MODELS
 from doccano_mini.prompts import make_classification_prompt
 
 st.title("Text Classification")
@@ -24,20 +24,16 @@
 prompt.prefix = instruction
 
 st.header("Test")
-api_key = st.text_input("Enter API key", value=os.environ.get("OPENAI_API_KEY", ""), type="password")
 col1, col2 = st.columns([3, 1])
 text = col1.text_area(label="Please enter your text.", value="", height=300)
-# Use text-davinci-003 by default.
-model_name = col2.selectbox("Model", AVAILABLE_MODELS, index=2)
-temperature = col2.slider("Temperature", min_value=0.0, max_value=1.0, value=0.7, step=0.01)
-top_p = col2.slider("Top-p", min_value=0.0, max_value=1.0, value=1.0, step=0.01)
+with col2:
+    llm = openai_model_form()
 
 with st.expander("See your prompt"):
     st.markdown(f"```\n{prompt.format(input=text)}\n```")
 
 if st.button("Predict"):
-    llm = OpenAI(model_name=model_name, temperature=temperature, top_p=top_p, openai_api_key=api_key)  # type:ignore
     chain = LLMChain(llm=llm, prompt=prompt)
     response = chain.run(text)
     label = response.split(":")[1]
 
diff --git a/doccano_mini/pages/09_Task_Free.py b/doccano_mini/pages/09_Task_Free.py
index 2b9c34c..aafcbb8 100644
--- a/doccano_mini/pages/09_Task_Free.py
+++ b/doccano_mini/pages/09_Task_Free.py
@@ -1,12 +1,12 @@
-import os
-
 import streamlit as st
 from langchain.chains import LLMChain
-from langchain.llms import OpenAI
 
-from doccano_mini.components import display_download_button, display_usage
+from doccano_mini.components import (
+    display_download_button,
+    display_usage,
+    openai_model_form,
+)
 from doccano_mini.examples import make_task_free_example
-from doccano_mini.models import AVAILABLE_MODELS
 from doccano_mini.prompts import make_task_free_prompt
 
 st.title("Task Free")
@@ -31,13 +31,8 @@
 st.markdown(f"Your prompt\n```\n{prompt.format(**inputs)}\n```")
 
-# Use text-davinci-003 by default.
-api_key = st.text_input("Enter API key", value=os.environ.get("OPENAI_API_KEY", ""), type="password")
-model_name = st.selectbox("Model", AVAILABLE_MODELS, index=2)
-temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.7, step=0.01)
-top_p = st.slider("Top-p", min_value=0.0, max_value=1.0, value=1.0, step=0.01)
+llm = openai_model_form()
 
 if st.button("Predict"):
-    llm = OpenAI(model_name=model_name, temperature=temperature, top_p=top_p, openai_api_key=api_key)  # type:ignore
     chain = LLMChain(llm=llm, prompt=prompt)
     response = chain.run(**inputs)
     st.text(response)