This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add OpenAI Chat Completion Agent (GPT-3.5/GPT-4) (#5061)
* Add OpenAI Chat Completion Agent * Change model list task to NA * Update OpenAI Chat Completion README.md limitations
- Loading branch information
Showing
6 changed files
with
288 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# OpenAI Chat Completion API | ||
This is an agent that interfaces with [OpenAI's chat completion api](https://platform.openai.com/docs/api-reference/chat/create) (/v1/chat/completions). | ||
|
||
The chat completion endpoint supports the most advanced large language models as of June 2023 such as | ||
* [GPT-3.5](https://platform.openai.com/docs/models/gpt-3-5) | ||
* [GPT-4](https://platform.openai.com/docs/models/gpt-4) | ||
|
||
## Setup | ||
```bash | ||
pip install openai | ||
``` | ||
|
||
More info on setup is outlined in the official docs [here](https://platform.openai.com/docs/api-reference/introduction). | ||
|
||
Once the openai Python package is installed, you can start using the endpoint as long as you have a valid OpenAI API key generated and ready-to-use. | ||
|
||
## Interactive example | ||
|
||
``` | ||
parlai self_chat -m openai_chat_completions --num-self-chats 1 --selfchat-max-turns 6 --openai-api-key <insert your api key> --max-tokens 40 --model-name gpt-3.5-turbo --init-prompt "You are chatting with a fellow Wizard of the grand frost order. You are defending the village from sword wielding golden retrievers raiding the countryside." --name wizard1 --role user --counterparty-role user --counterparty-name moose --partner-model-file zoo:blender/blender_90M/model | ||
``` | ||
|
||
### Output | ||
``` | ||
[context]: Hi! | ||
[OpenaiChatCompletionsAgent_1]: Hello fellow wizard! I see you're also busy defending the village from these sword-wielding golden retrievers. These creatures are quite strong and ferocious, aren't they? | ||
[TransformerGenerator_2]: i ' m not a wizard , but i ' ve always wanted to be one . what do you do for a living ? | ||
[OpenaiChatCompletionsAgent_1]: Oh, I'm not sure if I would call it a living, per se. As a wizard of the grand frost order, my duty is to protect the village and its people from all sorts of | ||
[TransformerGenerator_2]: that ' s very noble of you . i ' d love to be a grand wizard one day . | ||
[OpenaiChatCompletionsAgent_1]: It takes years of study and practice to become a grand wizard, but if you have the determination and passion for it, you can certainly achieve it. As a grand wizard, you won't just be | ||
[TransformerGenerator_2]: do you have any other hobbies besides being a wizard ? i like to think of it as a hobby . | ||
[OpenaiChatCompletionsAgent_1]: As a wizard, my primary focus is on studying and practicing magic to improve my abilities and protect the village. However, when I have some free time, I enjoy reading books on history, mythology, | ||
[TransformerGenerator_2]: what kind of books do you like to read ? i have a lot of free time as well . | ||
[OpenaiChatCompletionsAgent_1]: I enjoy reading books on history, folklore, and mythology. I find these topics fascinating, and they often give me inspiration for spells and incantations. I also enjoy reading fictional works, such as | ||
[TransformerGenerator_2]: i like fantasy books too . i like the ones that focus on the real world and not just fantasy . | ||
``` | ||
|
||
## Self chat example | ||
``` | ||
parlai interactive -m openai_chat_completions --openai-api-key <insert your api key> --max-tokens 40 --model-name gpt-4 | ||
``` | ||
|
||
### Output | ||
``` | ||
Enter Your Message: Can you describe a pack of golden retriever knights roaming the countryside? | ||
[OpenaiChatCompletionsAgent]: In the enchanting countryside, a majestic sight awaited anyone who happened to stumble upon it. A pack of Golden Retriever knights, glorious canines draped in gleaming armor, solemnly roamed | ||
Enter Your Message: Can you write a sentence describing how they roam the land for name brand kibble? | ||
[OpenaiChatCompletionsAgent]: the vast landscapes in pursuit of the fabled name-brand kibble, rumored to grant strength and power to those valiant enough to consume its heavenly morsels. | ||
``` | ||
|
||
## Limitations | ||
This API wrapper has three major limitations | ||
1. Cost - Repeatedly prompting the API can be expensive. | ||
2. Rate limiting - API queries can run into rate limiting issues which will cause the conversation to error out. [Official docs](https://platform.openai.com/docs/guides/rate-limits) offers more insight on dealing with this issue. | ||
3. Token Limit - A combination of prompt and response can usually only be up to 8k tokens and may be smaller depending on the model requested for chat completions [official docs](https://openai.com/pricing). This limits the size of both the initial prompt as well as the length of conversation that we can feed back into the model. Exceeding this limit will cause the conversation to error out. | ||
4. Self Chat - A self chat conducted between two OpenAI completion agents will not properly use the name and role arguments (as well as the counterparty versions). When this occurs, the turn history is not accurate because both agent-1 and agent-2 believe that their utterances are attached to `name` and `role` and that the other speaker is attributed to `counterparty-name` and `counterparty-role`. Ideally, agent-2 identifies its utterances to match `counterparty-name` and `counterparty-role`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
196 changes: 196 additions & 0 deletions
196
parlai/agents/openai_chat_completions/openai_chat_completions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Optional | ||
from parlai.core.params import ParlaiParser | ||
from parlai.core.opt import Opt | ||
from parlai.core.agents import Agent | ||
from parlai.core.message import Message | ||
|
||
try: | ||
import openai | ||
except ImportError: | ||
raise ImportError('Please run `pip install openai`.') | ||
|
||
|
||
class OpenaiChatCompletionsAgent(Agent): | ||
@classmethod | ||
def add_cmdline_args( | ||
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None | ||
) -> ParlaiParser: | ||
group = parser.add_argument_group('Chat Completion Arguments') | ||
group.add_argument( | ||
'--openai-api-key', | ||
type=str, | ||
required=True, | ||
help='Add your OpenAI api key', | ||
) | ||
group.add_argument( | ||
'--model-name', | ||
type=str, | ||
required=True, | ||
help="""Choose model name like GPT-4 or GPT 3.5""", | ||
) | ||
group.add_argument( | ||
'--init-prompt', | ||
type=str, | ||
default='', | ||
help="""Initial prompt that starts the conversation. Turns of conversation are appended to subsequent OpenAI | ||
completion queries.""", | ||
) | ||
group.add_argument( | ||
'--role', | ||
type=str, | ||
default='assistant', | ||
choices=['user', 'system', 'assistant'], | ||
help='Role of the author of message', | ||
) | ||
group.add_argument( | ||
'--counterparty-role', | ||
type=str, | ||
default='assistant', | ||
choices=['user', 'system', 'assistant'], | ||
help='Role of the other speaker', | ||
) | ||
group.add_argument( | ||
'--name', | ||
type=str, | ||
help='Name of the author of the message. Alphanumeric (with underscores) strings up to 64 chars are allowed', | ||
) | ||
group.add_argument( | ||
'--counterparty-name', | ||
type=str, | ||
help='Name of the other speaker. Alphanumeric (with underscores) strings up to 64 chars are allowed', | ||
) | ||
group.add_argument( | ||
'--max-tokens', | ||
type=int, | ||
required=True, | ||
help='The max number of tokens generated as a single conversation turn', | ||
) | ||
group.add_argument( | ||
'--temperature', | ||
type=float, | ||
default=1.0, | ||
help="""Temperature ranges between 0-2 such that higher temperature will make outputs more random while lower | ||
values make the output more deterministic""", | ||
) | ||
group.add_argument( | ||
'--top-p', | ||
type=float, | ||
default=1.0, | ||
help='Determines nucleus sampling rate', | ||
) | ||
group.add_argument( | ||
'--stop-sequence', | ||
type=str, | ||
help='Stop sequence is a string that will stop further generation of tokens', | ||
) | ||
group.add_argument( | ||
'--presence-penalty', | ||
type=float, | ||
default=0.0, | ||
help="""Presence penalty ranges between -2.0 to 2.0 such that more positive values will reduce chance of generating | ||
tokens that already appeared in text""", | ||
) | ||
group.add_argument( | ||
'--frequency-penalty', | ||
type=float, | ||
default=0.0, | ||
help="""Frequency penalty ranges between -2.0 to 2.0 such that more positive values will reduce chance of | ||
generating tokens that appear frequently""", | ||
) | ||
return parser | ||
|
||
def __init__(self, opt, shared=None): | ||
super().__init__(opt) | ||
self.id = 'OpenaiChatCompletionsAgent' | ||
self.turns = [] | ||
self.history = FakeHistory(self) | ||
self.model_name = opt.get('model_name') | ||
self.init_prompt = opt.get('init_prompt') | ||
self.role = opt.get('role') | ||
self.counterparty_role = opt.get('counterparty_role') | ||
self.name = opt.get('name') | ||
self.counterparty_name = opt.get('counterparty_name') | ||
self.max_tokens = opt.get('max_tokens') | ||
self.temperature = opt.get('temperature') | ||
self.top_p = opt.get('top_p') | ||
self.stop_sequence = opt.get('stop_sequence') | ||
self.presence_penalty = opt.get('presence_penalty') | ||
self.frequency_penalty = opt.get('frequency_penalty') | ||
|
||
# check that string self.init_prompt is not empty nor None | ||
if self.init_prompt: | ||
self.turns.append({'role': 'system', 'content': self.init_prompt}) | ||
|
||
openai.api_key = opt.get('openai_api_key') | ||
|
||
def reset(self): | ||
""" | ||
Reset the agent, clearing its observation. | ||
Many subclasses implement additional reset logic. | ||
""" | ||
self.observation = None | ||
self.turns = [] | ||
|
||
def observe(self, observation): | ||
""" | ||
Receive an observation/action dict. | ||
""" | ||
self.observation = observation | ||
|
||
if self.observation['id'] == 'context': | ||
msg = {'role': 'system', 'content': observation['text']} | ||
else: | ||
msg = {'role': self.counterparty_role, 'content': observation['text']} | ||
if self.counterparty_name: | ||
msg['name'] = self.counterparty_name | ||
self.turns.append(msg) | ||
|
||
return observation | ||
|
||
def act(self): | ||
""" | ||
Generate response to last seen observation. | ||
""" | ||
resp = self.query_chat_completion_api() | ||
resp_txt = resp['choices'][0]['message']['content'] | ||
|
||
msg = {'role': self.role, 'content': resp_txt} | ||
if self.name: | ||
msg['name'] = self.name | ||
self.turns.append(msg) | ||
|
||
return Message({'id': self.getID(), 'text': resp_txt, 'episode_done': False}) | ||
|
||
def query_chat_completion_api(self): | ||
response = openai.ChatCompletion.create( | ||
model=self.model_name, | ||
messages=self.turns, | ||
temperature=self.temperature, | ||
top_p=self.top_p, | ||
stop=self.stop_sequence, | ||
max_tokens=self.max_tokens, | ||
presence_penalty=self.presence_penalty, | ||
frequency_penalty=self.frequency_penalty, | ||
) | ||
return response | ||
|
||
|
||
class FakeHistory: | ||
def __init__(self, agent): | ||
self.agent = agent | ||
|
||
def add_reply(self, text, role='assistant', name=None): | ||
msg = {'role': role, 'content': text} | ||
if name: | ||
msg['name'] = name | ||
self.agent.turns.append() | ||
|
||
def get_history_str(self): | ||
return str(self.agent.turns) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#!/usr/bin/env python3 | ||
|
||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
def download(datapath): | ||
pass |