Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add native tool functionality (e.g. google_search for Gemini) #250

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/chains/llm_chain.ex
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ defmodule LangChain.Chains.LLMChain do
alias LangChain.Function
alias LangChain.LangChainError
alias LangChain.Utils
alias LangChain.NativeTool

@primary_key false
embedded_schema do
Expand Down Expand Up @@ -273,7 +274,7 @@ defmodule LangChain.Chains.LLMChain do
@doc """
Add a tool to an LLMChain.
"""
@spec add_tools(t(), Function.t() | [Function.t()]) :: t() | no_return()
@spec add_tools(t(), NativeTool.t() | Function.t() | [Function.t()]) :: t() | no_return()
def add_tools(%LLMChain{tools: existing} = chain, tools) do
updated = existing ++ List.wrap(tools)

Expand Down
49 changes: 37 additions & 12 deletions lib/chat_models/chat_google_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
alias LangChain.LangChainError
alias LangChain.Utils
alias LangChain.Callbacks
alias LangChain.NativeTool

@behaviour ChatModel

Expand Down Expand Up @@ -176,14 +177,25 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
|> LangChain.Utils.conditionally_add_to_map("safetySettings", google_ai.safety_settings)

if functions && not Enum.empty?(functions) do
req
|> Map.put("tools", [
%{
# Google AI functions use an OpenAI compatible format.
# See: https://ai.google.dev/docs/function_calling#how_it_works
"functionDeclarations" => Enum.map(functions, &for_api/1)
}
])
native_tools = Enum.filter(functions, &match?(%NativeTool{}, &1))
function_tools = Enum.filter(functions, &match?(%Function{}, &1))

tools_array = []
tools_array =
if function_tools != [] do
tools_array ++ [%{"functionDeclarations" => Enum.map(function_tools, &for_api/1)}]
else
tools_array
end

tools_array =
if native_tools != [] do
tools_array ++ Enum.map(native_tools, &for_api/1)
else
tools_array
end

Map.put(req, "tools", tools_array)
else
req
end
Expand All @@ -201,9 +213,6 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
end

def for_api(%Message{role: :tool} = message) do
# Function response is whacky. They don't explain why it has this extra nested structure.
#
# https://ai.google.dev/gemini-api/docs/function-calling#expandable-7
%{
"role" => map_role(:tool),
"parts" => Enum.map(message.tool_results, &for_api/1)
Expand All @@ -224,6 +233,13 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
}
end

# Converts a message whose content is a list of `ContentPart`s into Google's
# multi-part message format.
def for_api(%Message{content: content} = message) when is_list(content) do
  %{
    # Bug fix: translate the role through `map_role/1` (e.g. :assistant ->
    # "model") like the other clauses do, instead of sending the raw atom,
    # which Google's API does not accept.
    "role" => map_role(message.role),
    "parts" => Enum.map(content, &for_api/1)
  }
end

# A text content part maps to Google's simple `{"text": ...}` part shape.
def for_api(%ContentPart{type: :text, content: text}) do
  %{"text" => text}
end
Expand Down Expand Up @@ -316,6 +332,14 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
end
end

# A native tool with a configuration map is sent as `%{name => configuration}`;
# one without any configuration is sent as just its bare name.
def for_api(%NativeTool{name: name, configuration: %{} = configuration}),
  do: %{name => configuration}

def for_api(%NativeTool{name: name, configuration: nil}), do: name

@doc """
Calls the Google AI API passing the ChatGoogleAI struct with configuration, plus
either a simple message or the list of messages to act as the prompt.
Expand Down Expand Up @@ -527,7 +551,8 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
role: unmap_role(content_data["role"]),
content: text_part,
complete: true,
index: data["index"]
index: data["index"],
metadata: (if data["groundingMetadata"], do: data["groundingMetadata"], else: nil)
}
|> Utils.conditionally_add_to_map(:tool_calls, tool_calls_from_parts)
|> Utils.conditionally_add_to_map(:tool_results, tool_result_from_parts)
Expand Down
6 changes: 5 additions & 1 deletion lib/message.ex
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ defmodule LangChain.Message do
# A `:tool` role contains one or more `tool_results` from the system having
# used tools.
field :tool_results, :any, virtual: true

# Additional metadata about the message.
field :metadata, :map
end

@type t :: %Message{}
Expand All @@ -117,7 +120,8 @@ defmodule LangChain.Message do
:tool_calls,
:tool_results,
:index,
:name
:name,
:metadata
]
@create_fields @update_fields
@required_fields [:role]
Expand Down
51 changes: 51 additions & 0 deletions lib/native_tool.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
defmodule LangChain.NativeTool do
  @moduledoc """
  Represents a provider-native tool (for example Google Gemini's
  `google_search`) that is executed by the LLM service itself rather than by
  application code.

  A native tool is identified by its provider-specific `name` and an optional
  `configuration` map that is passed through to the provider's API unchanged.
  """
  use Ecto.Schema
  import Ecto.Changeset

  alias __MODULE__
  alias LangChain.LangChainError

  # Consistent with the library's other embedded schemas (e.g. LLMChain):
  # no autogenerated primary key on the struct.
  @primary_key false
  embedded_schema do
    field :name, :string
    field :configuration, :map
  end

  @type t :: %NativeTool{}
  # Provider-specific options keyed by string, passed through as-is.
  @type configuration :: %{String.t() => any()}

  @create_fields [:name, :configuration]
  @required_fields [:name]

  @doc """
  Build a new native tool from `attrs`.

  Returns `{:ok, %NativeTool{}}` on success or `{:error, changeset}` when
  validation fails (`:name` is required).
  """
  @spec new(attrs :: map()) :: {:ok, t} | {:error, Ecto.Changeset.t()}
  def new(attrs \\ %{}) do
    %NativeTool{}
    |> cast(attrs, @create_fields)
    |> common_validation()
    |> apply_action(:insert)
  end

  @doc """
  Build a new native tool and return it, raising `LangChainError` if the
  attributes are invalid.
  """
  @spec new!(attrs :: map()) :: t() | no_return()
  def new!(attrs \\ %{}) do
    case new(attrs) do
      {:ok, native_tool} ->
        native_tool

      {:error, changeset} ->
        raise LangChainError, changeset
    end
  end

  # Shared validations applied by `new/1`.
  defp common_validation(changeset) do
    validate_required(changeset, @required_fields)
  end
end
24 changes: 24 additions & 0 deletions test/chat_models/chat_google_ai_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,30 @@ defmodule ChatModels.ChatGoogleAITest do
end
end

describe "google_search native tool" do
  @tag live_call: true, live_google_ai: true
  test "should include grounding metadata in response" do
    alias LangChain.Chains.LLMChain
    alias LangChain.Message
    alias LangChain.NativeTool

    # Live call against Gemini with the provider-executed search tool enabled.
    chat_model =
      ChatGoogleAI.new!(%{temperature: 0, stream: false, model: "gemini-2.0-flash"})

    search_tool = NativeTool.new!(%{name: "google_search", configuration: %{}})
    question = Message.new_user!("What is the current Google stock price?")

    chain =
      %{llm: chat_model, verbose: false, stream: false}
      |> LLMChain.new!()
      |> LLMChain.add_message(question)
      |> LLMChain.add_tools(search_tool)

    {:ok, updated_chain} = LLMChain.run(chain)

    # The grounded answer comes back as an assistant message whose metadata
    # carries the provider's grounding information.
    assert %Message{role: :assistant} = updated_chain.last_message
    assert Map.has_key?(updated_chain.last_message.metadata, "groundingChunks")
  end
end

describe "calculator with GoogleAI model" do
@tag live_call: true, live_google_ai: true
test "should work" do
Expand Down