Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: GoogleGemini generation_config param #665

Merged
merged 4 commits into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions lib/langchain/llm/google_gemini.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def initialize(api_key:, default_options: {})

chat_parameters.update(
model: {default: @defaults[:chat_completion_model_name]},
temperature: {default: @defaults[:temperature]}
temperature: {default: @defaults[:temperature]},
generation_config: {default: nil},
safety_settings: {default: nil}
)
chat_parameters.remap(
messages: :contents,
Expand All @@ -42,13 +44,25 @@ def chat(params = {})
raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?

parameters = chat_parameters.to_params(params)
parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
parameters[:generation_config] ||= {}
parameters[:generation_config][:temperature] ||= parameters[:temperature] if parameters[:temperature]
parameters.delete(:temperature)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not change:

parameters[:generation_config][:temperature] ||= parameters[:temperature] if parameters[:temperature]
parameters.delete(:temperature)

to

parameters[:generation_config][:temperature] ||= parameters.delete(:temperature) if parameters[:temperature]

everywhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we always need to delete the parameter even if it is already defined in the generation_config. Otherwise actual payload will have additional unneeded parameters. This case already covered in the specs.

parameters[:generation_config][:top_p] ||= parameters[:top_p] if parameters[:top_p]
parameters.delete(:top_p)
parameters[:generation_config][:top_k] ||= parameters[:top_k] if parameters[:top_k]
parameters.delete(:top_k)
parameters[:generation_config][:max_output_tokens] ||= parameters[:max_tokens] if parameters[:max_tokens]
parameters.delete(:max_tokens)
parameters[:generation_config][:response_mime_type] ||= parameters[:response_format] if parameters[:response_format]
parameters.delete(:response_format)
parameters[:generation_config][:stop_sequences] ||= parameters[:stop] if parameters[:stop]
parameters.delete(:stop)

uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}")

request = Net::HTTP::Post.new(uri)
request.content_type = "application/json"
request.body = parameters.to_json
request.body = Langchain::Utils::HashTransformer.deep_transform_keys(parameters) { |key| Langchain::Utils::HashTransformer.camelize_lower(key.to_s).to_sym }.to_json

response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
http.request(request)
Expand Down
25 changes: 25 additions & 0 deletions lib/langchain/utils/hash_transformer.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module Langchain
module Utils
class HashTransformer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we should've just yanked out the code out of the activesupport lib? This is fine, but I was wondering if it could be more elegant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code here is 100% genuine, not copied from ActiveSupport in any way. In fact It is GPT generated.

# Converts a string to camelCase
def self.camelize_lower(str)
str.split("_").inject([]) { |buffer, e| buffer.push(buffer.empty? ? e : e.capitalize) }.join
end

# Recursively transforms the keys of a hash to camel case
def self.deep_transform_keys(hash, &block)
case hash
when Hash
hash.each_with_object({}) do |(key, value), result|
new_key = block.call(key)
result[new_key] = deep_transform_keys(value, &block)
end
when Array
hash.map { |item| deep_transform_keys(item, &block) }
else
hash
end
end
end
end
end
47 changes: 47 additions & 0 deletions spec/langchain/llm/google_gemini_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,22 @@
RSpec.describe Langchain::LLM::GoogleGemini do
let(:subject) { described_class.new(api_key: "123") }

describe "#initialize" do
it "initializes with default options" do
expect(subject.api_key).to eq("123")
expect(subject.defaults[:chat_completion_model_name]).to eq("gemini-1.5-pro-latest")
expect(subject.defaults[:embeddings_model_name]).to eq("text-embedding-004")
expect(subject.defaults[:temperature]).to eq(0.0)
end

it "merges default options with provided options" do
custom_options = {chat_completion_model_name: "custom-model", temperature: 2.0}
google_gemini_with_custom_options = described_class.new(api_key: "123", default_options: custom_options)
expect(google_gemini_with_custom_options.defaults[:chat_completion_model_name]).to eq("custom-model")
expect(google_gemini_with_custom_options.defaults[:temperature]).to eq(2.0)
end
end

describe "#embed" do
let(:embedding) { [0.013168523, -0.008711934, -0.046782676] }
let(:raw_embedding_response) { double(body: File.read("spec/fixtures/llm/google_gemini/embed.json")) }
Expand All @@ -23,11 +39,42 @@
describe "#chat" do
let(:messages) { [{role: "user", parts: [{text: "How high is the sky?"}]}] }
let(:raw_chat_completions_response) { double(body: File.read("spec/fixtures/llm/google_gemini/chat.json")) }
let(:params) { {messages: messages, model: "gemini-1.5-pro-latest", system: "system instruction", tool_choice: "AUTO", tools: [{name: "tool1"}], temperature: 1.1, response_format: "application/json", stop: ["A", "B"], generation_config: {temperature: 1.7, top_p: 1.3, response_schema: {"type" => "object", "description" => "sample schema"}}, safety_settings: [{category: "HARM_CATEGORY_UNSPECIFIED", threshold: "BLOCK_ONLY_HIGH"}]} }

before do
allow(Net::HTTP).to receive(:start).and_return(raw_chat_completions_response)
end

it "raises an error if messages are not provided" do
expect { subject.chat({}) }.to raise_error(ArgumentError, "messages argument is required")
end

it "correctly processes and sends parameters" do
expect(Net::HTTP::Post).to receive(:new) do |uri|
expect(uri.to_s).to include("https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest:generateContent?key=123")
end.and_call_original

allow_any_instance_of(Net::HTTP::Post).to receive(:body=) do |request, body|
parsed_body = JSON.parse(body)

expect(parsed_body["model"]).to eq("gemini-1.5-pro-latest")
expect(parsed_body["contents"]).to eq([{"parts" => [{"text" => "How high is the sky?"}], "role" => "user"}])
expect(parsed_body["systemInstruction"]).to eq({"parts" => [{"text" => "system instruction"}]})
expect(parsed_body["toolConfig"]).to eq({"functionCallingConfig" => {"mode" => "AUTO"}})
expect(parsed_body["tools"]).to eq({"functionDeclarations" => [{"name" => "tool1"}]})
expect(parsed_body["temperature"]).to eq(nil)
expect(parsed_body["generationConfig"]["temperature"]).to eq(1.7)
expect(parsed_body["topP"]).to eq(nil)
expect(parsed_body["generationConfig"]["topP"]).to eq(1.3)
expect(parsed_body["responseFormat"]).to eq(nil)
expect(parsed_body["generationConfig"]["responseMimeType"]).to eq("application/json")
expect(parsed_body["generationConfig"]["responseSchema"]).to eq({"type" => "object", "description" => "sample schema"})
expect(parsed_body["safetySettings"]).to eq([{"category" => "HARM_CATEGORY_UNSPECIFIED", "threshold" => "BLOCK_ONLY_HIGH"}])
end

subject.chat(params)
end

it "returns valid llm response object" do
response = subject.chat(messages: messages)

Expand Down
67 changes: 67 additions & 0 deletions spec/langchain/utils/hash_transformer_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
RSpec.describe Langchain::Utils::HashTransformer do
describe ".camelize_lower" do
it "converts snake_case to camelCase" do
expect(described_class.camelize_lower("example_key")).to eq("exampleKey")
expect(described_class.camelize_lower("nested_key_example")).to eq("nestedKeyExample")
end

it "handles strings without underscores" do
expect(described_class.camelize_lower("example")).to eq("example")
end

it "handles empty strings" do
expect(described_class.camelize_lower("")).to eq("")
end
end

describe ".deep_transform_keys" do
it "transforms keys of a simple hash" do
hash = {example_key: "value", another_key: "another_value"}
result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

expect(result).to eq({exampleKey: "value", anotherKey: "another_value"})
end

it "transforms keys of a nested hash" do
hash = {example_key: {nested_key: "value"}}
result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

expect(result).to eq({exampleKey: {nestedKey: "value"}})
end

it "transforms keys of an array of hashes" do
hash = {array_key: [{nested_key: "value"}, {another_key: "another_value"}]}
result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

expect(result).to eq({arrayKey: [{nestedKey: "value"}, {anotherKey: "another_value"}]})
end

it "handles arrays of non-hash elements" do
hash = {array_key: ["string", 123, :symbol]}
result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

expect(result).to eq({arrayKey: ["string", 123, :symbol]})
end

it "handles non-hash, non-array values" do
hash = {simple_key: "value"}
result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

expect(result).to eq({simpleKey: "value"})
end

it "handles empty hashes" do
hash = {}
result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

expect(result).to eq({})
end

it "handles deeply nested structures" do
hash = {level_one: {level_two: {level_three: {nested_key: "value"}}}}
result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

expect(result).to eq({levelOne: {levelTwo: {levelThree: {nestedKey: "value"}}}})
end
end
end