From 67918d29328f04a6f71b4959b7f6dd6b1b01392c Mon Sep 17 00:00:00 2001
From: luo-anthony
Date: Wed, 29 Nov 2023 20:02:25 -0600
Subject: [PATCH] Update to support OpenAI Python v1

---
 developergpt/cli.py                 | 65 +++++++++++++--------
 developergpt/huggingface_adapter.py |  2 +
 developergpt/openai_adapter.py      | 68 +++++++++++++++++------------
 developergpt/utils.py               |  3 +-
 requirements.txt                    |  6 +--
 5 files changed, 78 insertions(+), 66 deletions(-)

diff --git a/developergpt/cli.py b/developergpt/cli.py
index 29cac94..f7cf04a 100644
--- a/developergpt/cli.py
+++ b/developergpt/cli.py
@@ -8,8 +8,7 @@
 
 import click
 import inquirer
-import openai
-from openai import error
+from openai import OpenAI
 from prompt_toolkit import PromptSession
 from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
 from prompt_toolkit.shortcuts import CompleteStyle
@@ -21,26 +20,6 @@
 session: PromptSession = PromptSession()
 
 
-def handle_api_error(f):
-    """Handle API errors gracefully"""
-
-    @wraps(f)
-    def internal(*args, **kwargs):
-        try:
-            return f(*args, **kwargs)
-        except error.RateLimitError:
-            console.print("[bold red] Rate limit exceeded. Try again later.[/bold red]")
-            sys.exit(-1)
-        except error.ServiceUnavailableError:
-            console.print("[bold red] Service Unavailable. Try again later.[/bold red]")
-            sys.exit(-1)
-        except error.InvalidRequestError as e:
-            console.log(f"[bold red] Invalid Request: {e}[/bold red]")
-            sys.exit(-1)
-
-    return internal
-
-
 @click.group()
 @click.option(
     "--temperature",
@@ -70,8 +49,9 @@ def main(ctx, temperature, model):
     ctx.ensure_object(dict)
 
     if model in config.OPENAI_MODEL_MAP:
-        openai.api_key = config.get_environ_key(config.OPEN_AI_API_KEY, console)
-        openai_adapter.check_open_ai_key(console)
+        client = OpenAI(api_key=config.get_environ_key(config.OPEN_AI_API_KEY, console))
+        openai_adapter.check_open_ai_key(console, client)
+        ctx.obj["client"] = client
     elif model in config.HF_MODEL_MAP:
         ctx.obj["api_key"] = config.get_environ_key_optional(
             config.HUGGING_FACE_API_KEY, console
@@ -88,7 +68,6 @@
 @main.command(help="Chat with DeveloperGPT")
 @click.pass_context
 @click.argument("user_input", nargs=-1)
-@handle_api_error
 def chat(ctx, user_input):
     if user_input:
         user_input = str(" ".join(user_input))
@@ -100,7 +79,6 @@ def chat(ctx, user_input):
         input_messages = [openai_adapter.INITIAL_CHAT_SYSTEM_MSG]
     elif model in config.HF_MODEL_MAP:
         input_messages = huggingface_adapter.BASE_INPUT_CHAT_MSGS
-        api_token = ctx.obj.get("api_key", None)
     else:
         return
 
@@ -115,12 +93,23 @@
             continue
 
         if model in config.OPENAI_MODEL_MAP:
+            client = ctx.obj["client"]
             input_messages = openai_adapter.get_model_chat_response(
-                user_input, console, input_messages, ctx.obj["temperature"], model
+                user_input=user_input,
+                console=console,
+                input_messages=input_messages,
+                temperature=ctx.obj["temperature"],
+                model=model,
+                client=client,
             )
         elif model in config.HF_MODEL_MAP:
+            api_token = ctx.obj.get("api_key", None)
             input_messages = huggingface_adapter.get_model_chat_response(
-                user_input, console, input_messages, api_token, model
+                user_input=user_input,
+                console=console,
+                input_messages=input_messages,
+                api_token=api_token,
+                model=model,
             )
         user_input = None
 
@@ -135,7 +124,6 @@
     help="Get commands without command or argument explanations (less accurate)",
 )
 @click.pass_context
-@handle_api_error
 def cmd(ctx, user_input, fast):
     input_request = "\nDesired Command Request: "
 
@@ -149,8 +137,6 @@ def cmd(ctx, user_input, fast):
     session.history.append_string(user_input)
 
     model = ctx.obj["model"]
-    if model in config.HF_MODEL_MAP:
-        api_token = ctx.obj.get("api_key", None)
 
     if not user_input:
         console.print("[gray]Type 'quit' to exit[/gray]")
@@ -169,13 +155,24 @@ def cmd(ctx, user_input, fast):
         if not user_input:
             continue
 
+        model_output = None
         if model in config.OPENAI_MODEL_MAP:
+            client = ctx.obj["client"]
             model_output = openai_adapter.model_command(
-                user_input, console, fast, model
+                user_input=user_input,
+                console=console,
+                fast_mode=fast,
+                model=model,
+                client=client,
             )
         elif model in config.HF_MODEL_MAP:
+            api_token = ctx.obj.get("api_key", None)
             model_output = huggingface_adapter.model_command(
-                user_input, console, api_token, fast, model
+                user_input=user_input,
+                console=console,
+                api_token=api_token,
+                fast_mode=fast,
+                model=model,
             )
 
         user_input = None  # clear input for next iteration
@@ -193,7 +190,7 @@ def cmd(ctx, user_input, fast):
     questions = [
         inquirer.List("Next", message="What would you like to do?", choices=options)
     ]
-    selected_option = inquirer.prompt(questions)["Next"]
+    selected_option = inquirer.prompt(questions)["Next"]  # type: ignore
 
     if selected_option == "Revise Query":
         input_request = "Revised Command Request: "
diff --git a/developergpt/huggingface_adapter.py b/developergpt/huggingface_adapter.py
index a81f9d1..2421b47 100644
--- a/developergpt/huggingface_adapter.py
+++ b/developergpt/huggingface_adapter.py
@@ -65,6 +65,7 @@ def format_assistant_output(output: str) -> str:
 
 
 def model_command(
+    *,
     user_input: str,
     console: Console,
     api_token: Optional[str],
@@ -163,6 +164,7 @@ def format_bloom_chat_input(messages: list) -> str:
 
 
 def get_model_chat_response(
+    *,
     user_input: str,
     console: Console,
     input_messages: list,
diff --git a/developergpt/openai_adapter.py b/developergpt/openai_adapter.py
index 6ac4a4c..0e19ddb 100644
--- a/developergpt/openai_adapter.py
+++ b/developergpt/openai_adapter.py
@@ -3,9 +3,10 @@
 """
 import sys
 from datetime import datetime
+from typing import Optional
 
 import openai
-from openai import error
+from openai import OpenAI
 from rich.console import Console
 from rich.live import Live
 from rich.markdown import Markdown
@@ -155,11 +156,13 @@ def format_assistant_response(assistant_response: str) -> dict:
 
 
 def get_model_chat_response(
+    *,
     user_input: str,
     console: Console,
     input_messages: list,
     temperature: float,
     model: str,
+    client: "OpenAI",
 ) -> list:
     MAX_TOKENS = 4000
     RESERVED_OUTPUT_TOKENS = 1024
@@ -173,7 +176,7 @@ def get_model_chat_response(
     n_output_tokens = max(RESERVED_OUTPUT_TOKENS, MAX_TOKENS - n_input_tokens)
 
     """Get the response from the model."""
-    response = openai.ChatCompletion.create(
+    response = client.chat.completions.create(
         model=model_name,
         messages=input_messages,
         max_tokens=n_output_tokens,
@@ -189,14 +192,21 @@
         title_align="left",
         width=panel_width,
     )
-
-    with Live(output_panel, refresh_per_second=4):
-        for chunk in response:
-            msg = chunk["choices"][0]["delta"].get("content", "")
-            collected_messages.append(msg)
-            output_panel.renderable = Markdown(
-                "".join(collected_messages), inline_code_theme="monokai"
-            )
+    try:
+        with Live(output_panel, refresh_per_second=4):
+            for chunk in response:
+                msg = chunk.choices[0].delta.content
+                if msg:
+                    collected_messages.append(msg)
+                    output_panel.renderable = Markdown(
+                        "".join(collected_messages), inline_code_theme="monokai"
+                    )
+    except openai.RateLimitError:
+        console.print("[bold red] Rate limit exceeded. Try again later.[/bold red]")
+        sys.exit(-1)
+    except openai.BadRequestError as e:
+        console.log(f"[bold red] Bad Request: {e}[/bold red]")
+        sys.exit(-1)
 
     full_response = "".join(collected_messages)
     input_messages.append(format_assistant_response(full_response))
@@ -204,11 +214,8 @@
 
 
 def model_command(
-    user_input: str,
-    console: Console,
-    fast_mode: bool,
-    model: str,
-) -> list:
+    *, user_input: str, console: Console, fast_mode: bool, model: str, client: "OpenAI"
+) -> Optional[str]:
     MAX_TOKENS = 4000
     RESERVED_OUTPUT_TOKENS = 1024
     MAX_INPUT_TOKENS = MAX_TOKENS - RESERVED_OUTPUT_TOKENS
@@ -234,25 +241,30 @@
     )
     n_output_tokens = max(RESERVED_OUTPUT_TOKENS, MAX_TOKENS - n_input_tokens)
+    try:
+        with console.status("[bold blue]Decoding request") as _:
+            response = client.chat.completions.create(
+                model=model_name,
+                messages=input_messages,
+                max_tokens=n_output_tokens,
+                temperature=TEMP,
+            )
+    except openai.RateLimitError:
+        console.print("[bold red] Rate limit exceeded. Try again later.[/bold red]")
+        sys.exit(-1)
+    except openai.BadRequestError as e:
+        console.log(f"[bold red] Bad Request: {e}[/bold red]")
+        sys.exit(-1)
-
-    with console.status("[bold blue]Decoding request") as _:
-        response = openai.ChatCompletion.create(
-            model=model_name,
-            messages=input_messages,
-            max_tokens=n_output_tokens,
-            temperature=TEMP,
-        )
-
-    model_output = response["choices"][0]["message"]["content"].strip()
-
+
+    model_output = response.choices[0].message.content
     return model_output
 
 
-def check_open_ai_key(console: "Console") -> None:
+def check_open_ai_key(console: "Console", client: "OpenAI") -> None:
     """Check if the OpenAI API key is valid."""
     try:
-        _ = openai.Model.list()
-    except error.AuthenticationError:
+        _ = client.models.list()
+    except openai.AuthenticationError:
         console.print(
             f"[bold red]Error: Invalid OpenAI API key. Check your {config.OPEN_AI_API_KEY} environment variable.[/bold red]"
         )
         sys.exit(-1)
diff --git a/developergpt/utils.py b/developergpt/utils.py
index b41cf7f..6f58a72 100644
--- a/developergpt/utils.py
+++ b/developergpt/utils.py
@@ -6,6 +6,7 @@
 import json
 import os
 import sys
+from typing import Optional
 
 import pyperclip
 import requests
@@ -39,7 +40,7 @@ def pretty_print_commands(commands: list, console: Console, panel_width: int) ->
 
 
 def print_command_response(
-    model_output: str, console: Console, fast_mode: bool, model: str
+    model_output: Optional[str], console: Console, fast_mode: bool, model: str
 ) -> list:
     if not model_output:
         return []
diff --git a/requirements.txt b/requirements.txt
index 93d5f57..ada5687 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
-openai <= 0.27.8
-pydantic <= 1.10.9
+openai >= 1.1.0
+pydantic < 2.0.0
+text_generation >= 0.6.0
 click
 tiktoken
 rich
 inquirer
 prompt_toolkit
-text_generation >= 0.6.0
 pyperclip
 requests
\ No newline at end of file
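
Note for reviewers: the v0-to-v1 migration pattern applied throughout this
patch is sketched below. The snippet is illustrative only and not part of the
diff; the model name and prompt are placeholders, and it assumes
openai >= 1.1.0 with OPENAI_API_KEY set in the environment.

    import sys

    import openai
    from openai import OpenAI

    # v0 configured a module-level key (openai.api_key = ...); the v1 client
    # is an explicit object and reads OPENAI_API_KEY from the environment.
    client = OpenAI()

    try:
        # v0: openai.ChatCompletion.create(...) returned dicts;
        # v1: client.chat.completions.create(...) returns typed objects.
        stream = client.chat.completions.create(
            model="gpt-3.5-turbo",  # placeholder model name
            messages=[{"role": "user", "content": "Say hello"}],
            stream=True,
        )
        for chunk in stream:
            # v0 delta access: chunk["choices"][0]["delta"].get("content", "")
            # v1 deltas are attributes, and content can be None.
            msg = chunk.choices[0].delta.content
            if msg:
                print(msg, end="", flush=True)
    except openai.RateLimitError:
        # v1 moves exceptions from openai.error.* to the top-level package.
        sys.exit("Rate limit exceeded. Try again later.")
    except openai.BadRequestError as e:
        # v1 renames InvalidRequestError to BadRequestError.
        sys.exit(f"Bad request: {e}")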