Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

COT agent: add max_iterations option #162

Merged
merged 7 commits into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/langchain.rb
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,8 @@ module Prompt
autoload :PromptTemplate, "langchain/prompt/prompt_template"
autoload :FewShotPromptTemplate, "langchain/prompt/few_shot_prompt_template"
end

module Errors
class BaseError < StandardError; end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,21 @@ module Langchain::Agent
# agent.run(question: "How many full soccer fields would be needed to cover the distance between NYC and DC in a straight line?")
# #=> "Approximately 2,945 soccer fields would be needed to cover the distance between NYC and DC in a straight line."
class ChainOfThoughtAgent < Base
attr_reader :llm, :tools
attr_reader :llm, :tools, :max_iterations

# Initializes the Agent
#
# @param llm [Object] The LLM client to use
# @param tools [Array] The tools to use
# @param max_iterations [Integer] The maximum number of iterations to run
# @return [ChainOfThoughtAgent] The Agent::ChainOfThoughtAgent instance
bborn marked this conversation as resolved.
Show resolved Hide resolved
def initialize(llm:, tools: [])
def initialize(llm:, tools: [], max_iterations: 10)
Langchain::Tool::Base.validate_tools!(tools: tools)

@tools = tools

@llm = llm
@max_iterations = max_iterations
end

# Validate tools when they're re-assigned
Expand All @@ -51,7 +53,8 @@ def run(question:)
tools: tools
)

loop do
final_response = nil
max_iterations.times do
Langchain.logger.info("[#{self.class.name}]".red + ": Sending the prompt to the #{llm.class} LLM")

response = llm.complete(prompt: prompt, stop_sequences: ["Observation:"])
Expand Down Expand Up @@ -81,9 +84,12 @@ def run(question:)
end
else
# Return the final answer
break response.match(/Final Answer: (.*)/)&.send(:[], -1)
final_response = response.match(/Final Answer: (.*)/)&.send(:[], -1)
break
end
end

final_response || raise(MaxIterationsReachedError.new(max_iterations))
end

private
Expand Down Expand Up @@ -114,5 +120,11 @@ def prompt_template
file_path: Langchain.root.join("langchain/agent/chain_of_thought_agent/chain_of_thought_agent_prompt.json")
)
end

class MaxIterationsReachedError < Langchain::Errors::BaseError
def initialize(max_iterations)
super("Agent stopped after #{max_iterations} iterations")
end
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@
it "runs the agent" do
expect(subject.run(question: question)).to eq(final_answer)
end

it "raises an error after max_iterations" do
allow(subject).to receive(:max_iterations).and_return(1)

expect { subject.run(question: question) }.to raise_error(Langchain::Agent::ChainOfThoughtAgent::MaxIterationsReachedError)
end
end

describe "#create_prompt" do
Expand Down