diff --git a/lib/langchain.rb b/lib/langchain.rb index 485abfe1a..26526e1ff 100644 --- a/lib/langchain.rb +++ b/lib/langchain.rb @@ -130,4 +130,8 @@ module Prompt autoload :PromptTemplate, "langchain/prompt/prompt_template" autoload :FewShotPromptTemplate, "langchain/prompt/few_shot_prompt_template" end + + module Errors + class BaseError < StandardError; end + end end diff --git a/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb b/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb index 3240a0150..8f335e4e8 100644 --- a/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb +++ b/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb @@ -16,19 +16,21 @@ module Langchain::Agent # agent.run(question: "How many full soccer fields would be needed to cover the distance between NYC and DC in a straight line?") # #=> "Approximately 2,945 soccer fields would be needed to cover the distance between NYC and DC in a straight line." class ChainOfThoughtAgent < Base - attr_reader :llm, :tools + attr_reader :llm, :tools, :max_iterations # Initializes the Agent # # @param llm [Object] The LLM client to use # @param tools [Array] The tools to use + # @param max_iterations [Integer] The maximum number of iterations to run # @return [ChainOfThoughtAgent] The Agent::ChainOfThoughtAgent instance - def initialize(llm:, tools: []) + def initialize(llm:, tools: [], max_iterations: 10) Langchain::Tool::Base.validate_tools!(tools: tools) @tools = tools @llm = llm + @max_iterations = max_iterations end # Validate tools when they're re-assigned @@ -51,7 +53,8 @@ def run(question:) tools: tools ) - loop do + final_response = nil + max_iterations.times do Langchain.logger.info("[#{self.class.name}]".red + ": Sending the prompt to the #{llm.class} LLM") response = llm.complete(prompt: prompt, stop_sequences: ["Observation:"]) @@ -81,9 +84,12 @@ def run(question:) end else # Return the final answer - break response.match(/Final Answer: (.*)/)&.send(:[], -1) + final_response = response.match(/Final Answer: (.*)/)&.send(:[], -1) + break end end + + final_response || raise(MaxIterationsReachedError.new(max_iterations)) end private @@ -114,5 +120,11 @@ def prompt_template file_path: Langchain.root.join("langchain/agent/chain_of_thought_agent/chain_of_thought_agent_prompt.json") ) end + + class MaxIterationsReachedError < Langchain::Errors::BaseError + def initialize(max_iterations) + super("Agent stopped after #{max_iterations} iterations") + end + end end end diff --git a/spec/langchain/agent/chain_of_thought_agent/chain_of_thought_agent_spec.rb b/spec/langchain/agent/chain_of_thought_agent/chain_of_thought_agent_spec.rb index 068446029..59daa47a6 100644 --- a/spec/langchain/agent/chain_of_thought_agent/chain_of_thought_agent_spec.rb +++ b/spec/langchain/agent/chain_of_thought_agent/chain_of_thought_agent_spec.rb @@ -76,6 +76,12 @@ it "runs the agent" do expect(subject.run(question: question)).to eq(final_answer) end + + it "raises an error after max_iterations" do + allow(subject).to receive(:max_iterations).and_return(1) + + expect { subject.run(question: question) }.to raise_error(Langchain::Agent::ChainOfThoughtAgent::MaxIterationsReachedError) + end end describe "#create_prompt" do