fix: GoogleGemini generation_config param (#665)
* Fix generation_config param

* Sync all parameters to match the Gemini API

* Migrate from active_support to in-house utils

---------

Co-authored-by: Andrei Bondarev <andrei@sourcelabs.io>
mazenkhalil and andreibondarev authored Jun 22, 2024
1 parent da045bb commit 4c2dad0
Showing 4 changed files with 156 additions and 3 deletions.
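
Before the file-by-file diff, a minimal usage sketch of what this change enables. This is not part of the commit: the API key lookup and prompt text are placeholders, and the parameter names come from the diff and specs below.

require "langchain"

# Placeholder key; the constructor signature is def initialize(api_key:, default_options: {}).
llm = Langchain::LLM::GoogleGemini.new(api_key: ENV["GOOGLE_GEMINI_API_KEY"])

# generation_config and safety_settings are now accepted directly, and loose
# options such as temperature or response_format are folded into
# generation_config before the request is built.
response = llm.chat(
  messages: [{role: "user", parts: [{text: "How high is the sky?"}]}],
  temperature: 0.7,
  response_format: "application/json",
  safety_settings: [{category: "HARM_CATEGORY_UNSPECIFIED", threshold: "BLOCK_ONLY_HIGH"}]
)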
20 changes: 17 additions & 3 deletions lib/langchain/llm/google_gemini.rb
@@ -18,7 +18,9 @@ def initialize(api_key:, default_options: {})

   chat_parameters.update(
     model: {default: @defaults[:chat_completion_model_name]},
-    temperature: {default: @defaults[:temperature]}
+    temperature: {default: @defaults[:temperature]},
+    generation_config: {default: nil},
+    safety_settings: {default: nil}
   )
   chat_parameters.remap(
     messages: :contents,
@@ -42,13 +44,25 @@ def chat(params = {})
   raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?

   parameters = chat_parameters.to_params(params)
-  parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
+  parameters[:generation_config] ||= {}
+  parameters[:generation_config][:temperature] ||= parameters[:temperature] if parameters[:temperature]
+  parameters.delete(:temperature)
+  parameters[:generation_config][:top_p] ||= parameters[:top_p] if parameters[:top_p]
+  parameters.delete(:top_p)
+  parameters[:generation_config][:top_k] ||= parameters[:top_k] if parameters[:top_k]
+  parameters.delete(:top_k)
+  parameters[:generation_config][:max_output_tokens] ||= parameters[:max_tokens] if parameters[:max_tokens]
+  parameters.delete(:max_tokens)
+  parameters[:generation_config][:response_mime_type] ||= parameters[:response_format] if parameters[:response_format]
+  parameters.delete(:response_format)
+  parameters[:generation_config][:stop_sequences] ||= parameters[:stop] if parameters[:stop]
+  parameters.delete(:stop)

   uri = URI("https://generativelanguage.googleapis.com/v1beta/models/#{parameters[:model]}:generateContent?key=#{api_key}")

   request = Net::HTTP::Post.new(uri)
   request.content_type = "application/json"
-  request.body = parameters.to_json
+  request.body = Langchain::Utils::HashTransformer.deep_transform_keys(parameters) { |key| Langchain::Utils::HashTransformer.camelize_lower(key.to_s).to_sym }.to_json

   response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
     http.request(request)
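
To make the fix above concrete: before this change, a caller-supplied generation_config was clobbered whenever a loose temperature was present; with the ||= guards, the explicit generation_config value wins and the loose key is removed from the payload. A standalone sketch of that behavior (illustration only, not library code), using the same values as the spec below:

params = {temperature: 1.1, generation_config: {temperature: 1.7}}

# Same folding steps as in #chat: the explicit generation_config value (1.7)
# takes precedence over the loose :temperature (1.1), which is then dropped.
params[:generation_config] ||= {}
params[:generation_config][:temperature] ||= params[:temperature] if params[:temperature]
params.delete(:temperature)

params # => {generation_config: {temperature: 1.7}}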
25 changes: 25 additions & 0 deletions lib/langchain/utils/hash_transformer.rb
@@ -0,0 +1,25 @@
module Langchain
  module Utils
    class HashTransformer
      # Converts a string to camelCase
      def self.camelize_lower(str)
        str.split("_").inject([]) { |buffer, e| buffer.push(buffer.empty? ? e : e.capitalize) }.join
      end

      # Recursively transforms the keys of a hash to camel case
      def self.deep_transform_keys(hash, &block)
        case hash
        when Hash
          hash.each_with_object({}) do |(key, value), result|
            new_key = block.call(key)
            result[new_key] = deep_transform_keys(value, &block)
          end
        when Array
          hash.map { |item| deep_transform_keys(item, &block) }
        else
          hash
        end
      end
    end
  end
end
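
A quick illustration of the new utility in isolation (the payload hash is invented for the example); this mirrors how #chat uses it to produce the camelCase keys the Gemini REST API expects:

# Example payload with snake_case keys, as built inside #chat.
payload = {generation_config: {max_output_tokens: 256}, safety_settings: []}

Langchain::Utils::HashTransformer.deep_transform_keys(payload) do |key|
  Langchain::Utils::HashTransformer.camelize_lower(key.to_s).to_sym
end
# => {generationConfig: {maxOutputTokens: 256}, safetySettings: []}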
47 changes: 47 additions & 0 deletions spec/langchain/llm/google_gemini_spec.rb
@@ -3,6 +3,22 @@
RSpec.describe Langchain::LLM::GoogleGemini do
  let(:subject) { described_class.new(api_key: "123") }

  describe "#initialize" do
    it "initializes with default options" do
      expect(subject.api_key).to eq("123")
      expect(subject.defaults[:chat_completion_model_name]).to eq("gemini-1.5-pro-latest")
      expect(subject.defaults[:embeddings_model_name]).to eq("text-embedding-004")
      expect(subject.defaults[:temperature]).to eq(0.0)
    end

    it "merges default options with provided options" do
      custom_options = {chat_completion_model_name: "custom-model", temperature: 2.0}
      google_gemini_with_custom_options = described_class.new(api_key: "123", default_options: custom_options)
      expect(google_gemini_with_custom_options.defaults[:chat_completion_model_name]).to eq("custom-model")
      expect(google_gemini_with_custom_options.defaults[:temperature]).to eq(2.0)
    end
  end

  describe "#embed" do
    let(:embedding) { [0.013168523, -0.008711934, -0.046782676] }
    let(:raw_embedding_response) { double(body: File.read("spec/fixtures/llm/google_gemini/embed.json")) }
@@ -23,11 +39,42 @@
  describe "#chat" do
    let(:messages) { [{role: "user", parts: [{text: "How high is the sky?"}]}] }
    let(:raw_chat_completions_response) { double(body: File.read("spec/fixtures/llm/google_gemini/chat.json")) }
    let(:params) { {messages: messages, model: "gemini-1.5-pro-latest", system: "system instruction", tool_choice: "AUTO", tools: [{name: "tool1"}], temperature: 1.1, response_format: "application/json", stop: ["A", "B"], generation_config: {temperature: 1.7, top_p: 1.3, response_schema: {"type" => "object", "description" => "sample schema"}}, safety_settings: [{category: "HARM_CATEGORY_UNSPECIFIED", threshold: "BLOCK_ONLY_HIGH"}]} }

    before do
      allow(Net::HTTP).to receive(:start).and_return(raw_chat_completions_response)
    end

    it "raises an error if messages are not provided" do
      expect { subject.chat({}) }.to raise_error(ArgumentError, "messages argument is required")
    end

    it "correctly processes and sends parameters" do
      expect(Net::HTTP::Post).to receive(:new) do |uri|
        expect(uri.to_s).to include("https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest:generateContent?key=123")
      end.and_call_original

      allow_any_instance_of(Net::HTTP::Post).to receive(:body=) do |request, body|
        parsed_body = JSON.parse(body)

        expect(parsed_body["model"]).to eq("gemini-1.5-pro-latest")
        expect(parsed_body["contents"]).to eq([{"parts" => [{"text" => "How high is the sky?"}], "role" => "user"}])
        expect(parsed_body["systemInstruction"]).to eq({"parts" => [{"text" => "system instruction"}]})
        expect(parsed_body["toolConfig"]).to eq({"functionCallingConfig" => {"mode" => "AUTO"}})
        expect(parsed_body["tools"]).to eq({"functionDeclarations" => [{"name" => "tool1"}]})
        expect(parsed_body["temperature"]).to eq(nil)
        expect(parsed_body["generationConfig"]["temperature"]).to eq(1.7)
        expect(parsed_body["topP"]).to eq(nil)
        expect(parsed_body["generationConfig"]["topP"]).to eq(1.3)
        expect(parsed_body["responseFormat"]).to eq(nil)
        expect(parsed_body["generationConfig"]["responseMimeType"]).to eq("application/json")
        expect(parsed_body["generationConfig"]["responseSchema"]).to eq({"type" => "object", "description" => "sample schema"})
        expect(parsed_body["safetySettings"]).to eq([{"category" => "HARM_CATEGORY_UNSPECIFIED", "threshold" => "BLOCK_ONLY_HIGH"}])
      end

      subject.chat(params)
    end

    it "returns valid llm response object" do
      response = subject.chat(messages: messages)

67 changes: 67 additions & 0 deletions spec/langchain/utils/hash_transformer_spec.rb
@@ -0,0 +1,67 @@
RSpec.describe Langchain::Utils::HashTransformer do
  describe ".camelize_lower" do
    it "converts snake_case to camelCase" do
      expect(described_class.camelize_lower("example_key")).to eq("exampleKey")
      expect(described_class.camelize_lower("nested_key_example")).to eq("nestedKeyExample")
    end

    it "handles strings without underscores" do
      expect(described_class.camelize_lower("example")).to eq("example")
    end

    it "handles empty strings" do
      expect(described_class.camelize_lower("")).to eq("")
    end
  end

  describe ".deep_transform_keys" do
    it "transforms keys of a simple hash" do
      hash = {example_key: "value", another_key: "another_value"}
      result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

      expect(result).to eq({exampleKey: "value", anotherKey: "another_value"})
    end

    it "transforms keys of a nested hash" do
      hash = {example_key: {nested_key: "value"}}
      result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

      expect(result).to eq({exampleKey: {nestedKey: "value"}})
    end

    it "transforms keys of an array of hashes" do
      hash = {array_key: [{nested_key: "value"}, {another_key: "another_value"}]}
      result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

      expect(result).to eq({arrayKey: [{nestedKey: "value"}, {anotherKey: "another_value"}]})
    end

    it "handles arrays of non-hash elements" do
      hash = {array_key: ["string", 123, :symbol]}
      result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

      expect(result).to eq({arrayKey: ["string", 123, :symbol]})
    end

    it "handles non-hash, non-array values" do
      hash = {simple_key: "value"}
      result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

      expect(result).to eq({simpleKey: "value"})
    end

    it "handles empty hashes" do
      hash = {}
      result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

      expect(result).to eq({})
    end

    it "handles deeply nested structures" do
      hash = {level_one: {level_two: {level_three: {nested_key: "value"}}}}
      result = described_class.deep_transform_keys(hash) { |key| described_class.camelize_lower(key.to_s).to_sym }

      expect(result).to eq({levelOne: {levelTwo: {levelThree: {nestedKey: "value"}}}})
    end
  end
end
