Update to support OpenAI Python v1 #28

Merged: 1 commit, Nov 30, 2023
developergpt/cli.py (31 additions, 34 deletions)
@@ -8,8 +8,7 @@
 
 import click
 import inquirer
-import openai
-from openai import error
+from openai import OpenAI
 from prompt_toolkit import PromptSession
 from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
 from prompt_toolkit.shortcuts import CompleteStyle
@@ -21,26 +20,6 @@
 session: PromptSession = PromptSession()
 
 
-def handle_api_error(f):
-    """Handle API errors gracefully"""
-
-    @wraps(f)
-    def internal(*args, **kwargs):
-        try:
-            return f(*args, **kwargs)
-        except error.RateLimitError:
-            console.print("[bold red] Rate limit exceeded. Try again later.[/bold red]")
-            sys.exit(-1)
-        except error.ServiceUnavailableError:
-            console.print("[bold red] Service Unavailable. Try again later.[/bold red]")
-            sys.exit(-1)
-        except error.InvalidRequestError as e:
-            console.log(f"[bold red] Invalid Request: {e}[/bold red]")
-            sys.exit(-1)
-
-    return internal
-
-
 @click.group()
 @click.option(
     "--temperature",
@@ -70,8 +49,9 @@ def main(ctx, temperature, model):
 
     ctx.ensure_object(dict)
     if model in config.OPENAI_MODEL_MAP:
-        openai.api_key = config.get_environ_key(config.OPEN_AI_API_KEY, console)
-        openai_adapter.check_open_ai_key(console)
+        client = OpenAI(api_key=config.get_environ_key(config.OPEN_AI_API_KEY, console))
+        openai_adapter.check_open_ai_key(console, client)
+        ctx.obj["client"] = client
     elif model in config.HF_MODEL_MAP:
         ctx.obj["api_key"] = config.get_environ_key_optional(
             config.HUGGING_FACE_API_KEY, console
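
This hunk is the heart of the migration: v0 configured a module-global `openai.api_key`, while v1 uses an explicit `OpenAI` client object that carries its own credentials, stored here in Click's `ctx.obj` so the subcommands can reuse it. A minimal sketch of the pattern, assuming `OPENAI_API_KEY` is set in the environment:

```python
# Minimal sketch of the v0 -> v1 credential pattern (assumes OPENAI_API_KEY is set).
import os

from openai import OpenAI

# v0 (no longer works on openai>=1.0):
#   import openai
#   openai.api_key = os.environ["OPENAI_API_KEY"]

# v1: an explicit client owns the credentials and all API calls hang off it.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# DeveloperGPT stashes this client in Click's context object (ctx.obj["client"])
# so that the chat and cmd subcommands share a single instance.
```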
@@ -88,7 +68,6 @@ def main(ctx, temperature, model):
 @main.command(help="Chat with DeveloperGPT")
 @click.pass_context
 @click.argument("user_input", nargs=-1)
-@handle_api_error
 def chat(ctx, user_input):
     if user_input:
         user_input = str(" ".join(user_input))
@@ -100,7 +79,6 @@ def chat(ctx, user_input):
         input_messages = [openai_adapter.INITIAL_CHAT_SYSTEM_MSG]
     elif model in config.HF_MODEL_MAP:
         input_messages = huggingface_adapter.BASE_INPUT_CHAT_MSGS
-        api_token = ctx.obj.get("api_key", None)
     else:
         return

@@ -115,12 +93,23 @@ def chat(ctx, user_input):
                 continue
 
         if model in config.OPENAI_MODEL_MAP:
+            client = ctx.obj["client"]
             input_messages = openai_adapter.get_model_chat_response(
-                user_input, console, input_messages, ctx.obj["temperature"], model
+                user_input=user_input,
+                console=console,
+                input_messages=input_messages,
+                temperature=ctx.obj["temperature"],
+                model=model,
+                client=client,
             )
         elif model in config.HF_MODEL_MAP:
+            api_token = ctx.obj.get("api_key", None)
             input_messages = huggingface_adapter.get_model_chat_response(
-                user_input, console, input_messages, api_token, model
+                user_input=user_input,
+                console=console,
+                input_messages=input_messages,
+                api_token=api_token,
+                model=model,
             )
 
         user_input = None
@@ -135,7 +124,6 @@ def chat(ctx, user_input):
     help="Get commands without command or argument explanations (less accurate)",
 )
 @click.pass_context
-@handle_api_error
 def cmd(ctx, user_input, fast):
     input_request = "\nDesired Command Request: "
 
@@ -149,8 +137,6 @@ def cmd(ctx, user_input, fast):
         session.history.append_string(user_input)
 
     model = ctx.obj["model"]
-    if model in config.HF_MODEL_MAP:
-        api_token = ctx.obj.get("api_key", None)
 
     if not user_input:
         console.print("[gray]Type 'quit' to exit[/gray]")
@@ -169,13 +155,24 @@ def cmd(ctx, user_input, fast):
         if not user_input:
             continue
 
+        model_output = None
         if model in config.OPENAI_MODEL_MAP:
+            client = ctx.obj["client"]
             model_output = openai_adapter.model_command(
-                user_input, console, fast, model
+                user_input=user_input,
+                console=console,
+                fast_mode=fast,
+                model=model,
+                client=client,
             )
         elif model in config.HF_MODEL_MAP:
+            api_token = ctx.obj.get("api_key", None)
             model_output = huggingface_adapter.model_command(
-                user_input, console, api_token, fast, model
+                user_input=user_input,
+                console=console,
+                api_token=api_token,
+                fast_mode=fast,
+                model=model,
             )
         user_input = None  # clear input for next iteration
 
@@ -193,7 +190,7 @@ def cmd(ctx, user_input, fast):
         questions = [
             inquirer.List("Next", message="What would you like to do?", choices=options)
         ]
-        selected_option = inquirer.prompt(questions)["Next"]
+        selected_option = inquirer.prompt(questions)["Next"]  # type: ignore
 
         if selected_option == "Revise Query":
            input_request = "Revised Command Request: "
developergpt/huggingface_adapter.py (2 additions, 0 deletions)
@@ -65,6 +65,7 @@ def format_assistant_output(output: str) -> str:
 
 
 def model_command(
+    *,
     user_input: str,
     console: Console,
     api_token: Optional[str],
@@ -163,6 +164,7 @@ def format_bloom_chat_input(messages: list) -> str:
 
 
 def get_model_chat_response(
+    *,
     user_input: str,
     console: Console,
     input_messages: list,
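
The only change to this file is the bare `*` marker added to both signatures, which makes every parameter keyword-only; that is why the call sites in `cli.py` above switch from positional to named arguments. A quick illustration:

```python
# The bare "*" forces callers to pass every argument by name.
def model_command(*, user_input: str, fast_mode: bool) -> None:
    print(user_input, fast_mode)


model_command(user_input="list files", fast_mode=True)  # OK
# model_command("list files", True)  # TypeError: takes 0 positional arguments
```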
developergpt/openai_adapter.py (40 additions, 28 deletions)
@@ -3,9 +3,10 @@
 """
 import sys
 from datetime import datetime
+from typing import Optional
 
 import openai
-from openai import error
+from openai import OpenAI
 from rich.console import Console
 from rich.live import Live
 from rich.markdown import Markdown
@@ -155,11 +156,13 @@ def format_assistant_response(assistant_response: str) -> dict:
 
 
 def get_model_chat_response(
+    *,
     user_input: str,
     console: Console,
     input_messages: list,
     temperature: float,
     model: str,
+    client: "OpenAI",
 ) -> list:
     MAX_TOKENS = 4000
     RESERVED_OUTPUT_TOKENS = 1024
@@ -173,7 +176,7 @@ def get_model_chat_response(
     n_output_tokens = max(RESERVED_OUTPUT_TOKENS, MAX_TOKENS - n_input_tokens)
 
     """Get the response from the model."""
-    response = openai.ChatCompletion.create(
+    response = client.chat.completions.create(
         model=model_name,
         messages=input_messages,
         max_tokens=n_output_tokens,
@@ -189,26 +192,30 @@ def get_model_chat_response(
         title_align="left",
         width=panel_width,
     )
-
-    with Live(output_panel, refresh_per_second=4):
-        for chunk in response:
-            msg = chunk["choices"][0]["delta"].get("content", "")
-            collected_messages.append(msg)
-            output_panel.renderable = Markdown(
-                "".join(collected_messages), inline_code_theme="monokai"
-            )
+    try:
+        with Live(output_panel, refresh_per_second=4):
+            for chunk in response:
+                msg = chunk.choices[0].delta.content
+                if msg:
+                    collected_messages.append(msg)
+                    output_panel.renderable = Markdown(
+                        "".join(collected_messages), inline_code_theme="monokai"
+                    )
+    except openai.RateLimitError:
+        console.print("[bold red] Rate limit exceeded. Try again later.[/bold red]")
+        sys.exit(-1)
+    except openai.BadRequestError as e:
+        console.log(f"[bold red] Bad Request: {e}[/bold red]")
+        sys.exit(-1)
 
     full_response = "".join(collected_messages)
     input_messages.append(format_assistant_response(full_response))
     return input_messages
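
Two v1 behaviors drive this rewrite: streamed chunks are typed `ChatCompletionChunk` objects, so the delta is read by attribute access rather than dict lookup, and `delta.content` can be `None` (for example on the role-only first chunk and the final chunk), hence the added `if msg:` guard. A standalone sketch of the same loop, assuming a valid `OPENAI_API_KEY` in the environment and using `print` in place of the Rich live panel:

```python
# Minimal v1 streaming sketch (assumes OPENAI_API_KEY is set in the environment).
from openai import OpenAI

client = OpenAI()
stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
)

collected = []
for chunk in stream:
    msg = chunk.choices[0].delta.content  # attribute access; may be None
    if msg:
        collected.append(msg)
        print(msg, end="", flush=True)

print("\nfull response:", "".join(collected))
```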


 def model_command(
-    user_input: str,
-    console: Console,
-    fast_mode: bool,
-    model: str,
-) -> list:
+    *, user_input: str, console: Console, fast_mode: bool, model: str, client: "OpenAI"
+) -> Optional[str]:
     MAX_TOKENS = 4000
     RESERVED_OUTPUT_TOKENS = 1024
     MAX_INPUT_TOKENS = MAX_TOKENS - RESERVED_OUTPUT_TOKENS
@@ -234,25 +241,30 @@ def model_command(
         )
 
     n_output_tokens = max(RESERVED_OUTPUT_TOKENS, MAX_TOKENS - n_input_tokens)
-
-    with console.status("[bold blue]Decoding request") as _:
-        response = openai.ChatCompletion.create(
-            model=model_name,
-            messages=input_messages,
-            max_tokens=n_output_tokens,
-            temperature=TEMP,
-        )
-
-    model_output = response["choices"][0]["message"]["content"].strip()
-
+    try:
+        with console.status("[bold blue]Decoding request") as _:
+            response = client.chat.completions.create(
+                model=model_name,
+                messages=input_messages,
+                max_tokens=n_output_tokens,
+                temperature=TEMP,
+            )
+    except openai.RateLimitError:
+        console.print("[bold red] Rate limit exceeded. Try again later.[/bold red]")
+        sys.exit(-1)
+    except openai.BadRequestError as e:
+        console.log(f"[bold red] Bad Request: {e}[/bold red]")
+        sys.exit(-1)
+
+    model_output = response.choices[0].message.content
     return model_output
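
The return annotation changes from `list` (which was inaccurate even in v0, since the function returned a string) to `Optional[str]` because v1 types `message.content` as `Optional[str]`, and the old `.strip()` would raise on `None`. Callers therefore need a guard, which `utils.print_command_response` (updated below) already provides:

```python
# Why Optional[str]: in v1, message.content may be None (e.g. tool-call replies).
from typing import Optional


def use_output(model_output: Optional[str]) -> str:
    if not model_output:  # guards None and the empty string
        return ""
    return model_output.strip()  # now safe to call str methods
```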


-def check_open_ai_key(console: "Console") -> None:
+def check_open_ai_key(console: "Console", client: "OpenAI") -> None:
     """Check if the OpenAI API key is valid."""
     try:
-        _ = openai.Model.list()
-    except error.AuthenticationError:
+        _ = client.models.list()
+    except openai.AuthenticationError:
         console.print(
             f"[bold red]Error: Invalid OpenAI API key. Check your {config.OPEN_AI_API_KEY} environment variable.[/bold red]"
         )
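
`client.models.list()` replaces v0's `openai.Model.list()` as a cheap authenticated request for validating the key, and `openai.AuthenticationError` is now raised from the top-level package since `openai.error` is gone. A standalone usage sketch:

```python
# Standalone v1-style key check (the key value here is a placeholder).
import openai
from openai import OpenAI

client = OpenAI(api_key="sk-possibly-invalid")
try:
    client.models.list()  # any authenticated endpoint would do
    print("API key is valid")
except openai.AuthenticationError:
    print("Invalid OpenAI API key")
```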
developergpt/utils.py (2 additions, 1 deletion)
@@ -6,6 +6,7 @@
 import json
 import os
 import sys
+from typing import Optional
 
 import pyperclip
 import requests
@@ -39,7 +40,7 @@ def pretty_print_commands(commands: list, console: Console, panel_width: int) ->
 
 
 def print_command_response(
-    model_output: str, console: Console, fast_mode: bool, model: str
+    model_output: Optional[str], console: Console, fast_mode: bool, model: str
 ) -> list:
     if not model_output:
         return []
requirements.txt (3 additions, 3 deletions)
@@ -1,10 +1,10 @@
-openai <= 0.27.8
-pydantic <= 1.10.9
+openai >= 1.1.0
+pydantic < 2.0.0
+text_generation >= 0.6.0
 click
 tiktoken
 rich
 inquirer
 prompt_toolkit
-text_generation >= 0.6.0
 pyperclip
 requests
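
The pins flip from an upper bound on the old SDK (`openai <= 0.27.8`) to a lower bound on the new one (`openai >= 1.1.0`), while `pydantic` stays below 2.x, presumably for compatibility with the other pinned dependencies. A quick post-upgrade sanity check (a hypothetical helper, not part of the repo):

```python
# Hypothetical sanity check that the installed SDK matches the new requirement.
import openai

major = int(openai.__version__.split(".")[0])
if major < 1:
    raise RuntimeError(
        f"DeveloperGPT now requires openai>=1.1.0, found {openai.__version__}; "
        "run: pip install -U -r requirements.txt"
    )
print(f"openai SDK {openai.__version__} OK")
```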