Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LangChain Integration #60

Merged
merged 12 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 64 additions & 15 deletions libs/manubot_ai_editor/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import time
import json

import openai
from langchain_openai import OpenAI, ChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage

from manubot_ai_editor import env_vars

Expand Down Expand Up @@ -141,12 +142,13 @@ def __init__(
super().__init__()

# make sure the OpenAI API key is set
openai.api_key = openai_api_key
if openai_api_key is None:
# attempt to get the OpenAI API key from the environment, since one
# wasn't specified as an argument
openai_api_key = os.environ.get(env_vars.OPENAI_API_KEY, None)

if openai.api_key is None:
openai.api_key = os.environ.get(env_vars.OPENAI_API_KEY, None)

if openai.api_key is None or openai.api_key.strip() == "":
# if it's *still* not set, bail
if openai_api_key is None or openai_api_key.strip() == "":
raise ValueError(
f"OpenAI API key not found. Please provide it as parameter "
f"or set it as an the environment variable "
Expand Down Expand Up @@ -253,6 +255,22 @@ def __init__(

self.several_spaces_pattern = re.compile(r"\s+")

if self.endpoint == "edits":
falquaddoomi marked this conversation as resolved.
Show resolved Hide resolved
# FIXME: what's the "edits" equivalent in langchain?
falquaddoomi marked this conversation as resolved.
Show resolved Hide resolved
client_cls = OpenAI
elif self.endpoint == "chat":
client_cls = ChatOpenAI
else:
client_cls = OpenAI
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to take care of this anymore. Before, there were a "completion" and "edits" endpoints, but now we only have a "chat" endpoint I believe. Let's research a little bit, but I think we only need the ChatOpenAI class here.


# construct the OpenAI client after all the rest of
# the settings above have been processed
self.client = client_cls(
api_key=openai_api_key,
**self.model_parameters,
)


def get_prompt(
self, paragraph_text: str, section_name: str = None, resolved_prompt: str = None
) -> str | tuple[str, str]:
Expand Down Expand Up @@ -526,17 +544,48 @@ def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolv
flush=True,
)

if self.endpoint == "edits":
completions = openai.Edit.create(**params)
elif self.endpoint == "chat":
completions = openai.ChatCompletion.create(**params)
# FIXME: 'params' contains a lot of fields that we're not
# currently passing to the langchain client. i need to figure
# out where they're supposed to be given, e.g. in the client
# init or with each request.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are those fields in params?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at it again, "a lot" is an overstatement, sorry. On top of the model_parameters dict that gets merged into it and aside from prompt (or the other variants based on whether it's a "chat" or "edits" model) GPT3CompletionModel.get_params() introduces just:

  • n: I assume this is the number of responses you want the API to generate
    • it seems that it's always 1, and it LangChain's invoke() returns a single response anyway, so I assume we can ignore this one
  • stop: despite being None all the time and probably not necessary to include in invoke()
    • this one's easy to integrate, since invoke() takes stop as an argument; I'll just go ahead and add it
  • max_tokens: it seems this is taken at client initialization in LangChain
    • I'll see if there's a way to provide it for each invoke() call, or to change its value prior to the call

Correct me if I'm wrong, but since model_parameters is already used to initialize the client and since AFAICT it's not changed after that, I don't think we need to include its contents in invoke().

I'll go ahead and make the other changes, though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I didn't forget what the code does, the only field that should go in each request/invoke (instead of using them to initialize the client) is max_tokens, because for each paragraph we restrict the model to generate up to twice (or so) the number of tokens in the input paragraph. So that should go into each request, not the client (or update the client before each request?).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, after I made the comment above I discovered that invoke() does take max_tokens as well as stop; I've added it in my most recent commits. I assume we still don't need to change n from 1, which AFAICT is the default for invoke() as well, so I left that out of the call to invoke().


# map the prompt to langchain's prompt types, based on what
# kind of endpoint we're using
if "messages" in params:
falquaddoomi marked this conversation as resolved.
Show resolved Hide resolved
# map the messages to langchain's message types
# based on the 'role' field
prompts = [
HumanMessage(content=msg["content"])
if msg["role"] == "user" else
falquaddoomi marked this conversation as resolved.
Show resolved Hide resolved
SystemMessage(content=msg["content"])
for msg in params["messages"]
]
elif "instruction" in params:
# since we don't know how to use the edits endpoint, we'll just
# concatenate the instruction and input and use the regular
# completion endpoint
# FIXME: there's probably a langchain equivalent for
# "edits", so we should change this to use that
prompts = [
HumanMessage(content=params["instruction"]),
HumanMessage(content=params["input"]),
]
elif "prompt" in params:
prompts = [HumanMessage(content=params["prompt"])]

response = self.client.invoke(prompts)

if isinstance(response, BaseMessage):
message = response.content.strip()
else:
completions = openai.Completion.create(**params)
message = response.strip()

# FIXME: the prior code retrieved the first of the 'choices'
# response from the openai client. now, we only get one
# response from the langchain client, but i should check
# if that's really how langchain works or if there is a way
# to get multiple 'choices' back from the backend.

if self.endpoint == "chat":
message = completions.choices[0].message.content.strip()
else:
message = completions.choices[0].text.strip()
except Exception as e:
error_message = str(e)
print(f"Error: {error_message}")
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

setuptools.setup(
name="manubot-ai-editor",
version="0.5.2",
version="0.5.3",
author="Milton Pividori",
author_email="miltondp@gmail.com",
description="A Manubot plugin to revise a manuscript using GPT-3",
Expand All @@ -25,7 +25,7 @@
],
python_requires=">=3.10",
install_requires=[
"openai==0.28",
"langchain-openai==0.2.0",
falquaddoomi marked this conversation as resolved.
Show resolved Hide resolved
"pyyaml",
],
classifiers=[
Expand Down
17 changes: 6 additions & 11 deletions tests/test_model_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pytest

from manubot_ai_editor.editor import ManuscriptEditor, env_vars
from manubot_ai_editor import models
from manubot_ai_editor.models import GPT3CompletionModel, RandomManuscriptRevisionModel

MANUSCRIPTS_DIR = Path(__file__).parent / "manuscripts"
Expand All @@ -32,12 +31,12 @@ def test_model_object_init_without_openai_api_key():

@mock.patch.dict("os.environ", {env_vars.OPENAI_API_KEY: "env_var_test_value"})
def test_model_object_init_with_openai_api_key_as_environment_variable():
GPT3CompletionModel(
model = GPT3CompletionModel(
title="Test title",
keywords=["test", "keywords"],
)

assert models.openai.api_key == "env_var_test_value"
assert model.client.openai_api_key.get_secret_value() == "env_var_test_value"


def test_model_object_init_with_openai_api_key_as_parameter():
Expand All @@ -46,30 +45,26 @@ def test_model_object_init_with_openai_api_key_as_parameter():
if env_vars.OPENAI_API_KEY in os.environ:
os.environ.pop(env_vars.OPENAI_API_KEY)

GPT3CompletionModel(
model = GPT3CompletionModel(
title="Test title",
keywords=["test", "keywords"],
openai_api_key="test_value",
)

from manubot_ai_editor import models

assert models.openai.api_key == "test_value"
assert model.client.openai_api_key.get_secret_value() == "test_value"
finally:
os.environ = _environ


@mock.patch.dict("os.environ", {env_vars.OPENAI_API_KEY: "env_var_test_value"})
def test_model_object_init_with_openai_api_key_as_parameter_has_higher_priority():
GPT3CompletionModel(
model = GPT3CompletionModel(
title="Test title",
keywords=["test", "keywords"],
openai_api_key="test_value",
)

from manubot_ai_editor import models

assert models.openai.api_key == "test_value"
assert model.client.openai_api_key.get_secret_value() == "test_value"


def test_model_object_init_default_language_model():
Expand Down