Skip to content

Commit

Permalink
Merge pull request #46 from michalwarda/main
Browse files Browse the repository at this point in the history
Add possibility to use api_key per chat invocation.
  • Loading branch information
brainlid authored Dec 4, 2023
2 parents 749731a + 00704a1 commit 407449b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
14 changes: 9 additions & 5 deletions lib/chat_models/chat_open_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ defmodule LangChain.ChatModels.ChatOpenAI do
field :endpoint, :string, default: "https://api.openai.com/v1/chat/completions"
# field :model, :string, default: "gpt-4"
field :model, :string, default: "gpt-3.5-turbo"
# API key for OpenAI. If not set, will use global api key.
field :api_key, :string

# What sampling temperature to use, between 0 and 2. Higher values like 0.8
# will make the output more random, while lower values like 0.2 will make it
Expand Down Expand Up @@ -65,6 +67,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
:model,
:temperature,
:frequency_penalty,
:api_key,
:seed,
:n,
:stream,
Expand All @@ -73,10 +76,10 @@ defmodule LangChain.ChatModels.ChatOpenAI do
]
@required_fields [:model]

@spec get_api_key() :: String.t()
defp get_api_key() do
@spec get_api_key(t) :: String.t()
defp get_api_key(%ChatOpenAI{api_key: api_key}) do
# if no API key is set default to `""` which will raise a Stripe API error
Config.resolve(:openai_key, "")
api_key || Config.resolve(:openai_key, "")
end

@spec get_org_id() :: String.t() | nil
Expand Down Expand Up @@ -125,6 +128,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
def for_api(%ChatOpenAI{} = openai, messages, functions) do
%{
model: openai.model,
api_key: openai.api_key,
temperature: openai.temperature,
frequency_penalty: openai.frequency_penalty,
n: openai.n,
Expand Down Expand Up @@ -242,7 +246,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
Req.new(
url: openai.endpoint,
json: for_api(openai, messages, functions),
auth: {:bearer, get_api_key()},
auth: {:bearer, get_api_key(openai)},
receive_timeout: openai.receive_timeout,
retry: :transient,
max_retries: 3,
Expand Down Expand Up @@ -327,7 +331,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
Req.new(
url: openai.endpoint,
json: for_api(openai, messages, functions),
auth: {:bearer, get_api_key()},
auth: {:bearer, get_api_key(openai)},
receive_timeout: openai.receive_timeout,
finch_request: finch_fun
)
Expand Down
4 changes: 3 additions & 1 deletion test/chat_models/chat_open_ai_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
ChatOpenAI.new(%{
"model" => "gpt-3.5-turbo-0613",
"temperature" => 1,
"frequency_penalty" => 0.5
"frequency_penalty" => 0.5,
"api_key" => "api_key"
})

data = ChatOpenAI.for_api(openai, [], [])
assert data.model == "gpt-3.5-turbo-0613"
assert data.temperature == 1
assert data.frequency_penalty == 0.5
assert data.response_format == %{"type" => "text"}
assert data.api_key == "api_key"
end

test "generates a map for an API call with JSON response set to true" do
Expand Down

0 comments on commit 407449b

Please sign in to comment.