diff --git a/main.py b/main.py
index 59c6a29d..30609561 100644
--- a/main.py
+++ b/main.py
@@ -73,7 +73,7 @@ def main():
     math_agent = agent_thread_pool.submit(
         agent_factory.run_agent,
         "MathAgent",
-        "Solve the problem that Albert is wondering how much pizza he can eat in one day. He buys 2 large pizzas and 2 small pizzas. A large pizza has 16 slices and a small pizza has 8 slices. If he eats it all, how many pieces does he eat that day?"
+        "A freelance graphic designer in Canada earns CAD 500 per project and is planning to work on projects from clients in both the UK and Canada this month. With an expected 3 projects from Canadian clients and 2 from UK clients (paying GBP 400 each), how much will the designer earn in total in CAD by the end of the month?"
     )
 
     narrative_agent = agent_thread_pool.submit(
diff --git a/src/llm_kernel/llm_classes/base_llm.py b/src/llm_kernel/llm_classes/base_llm.py
index a6c27700..503dcd41 100644
--- a/src/llm_kernel/llm_classes/base_llm.py
+++ b/src/llm_kernel/llm_classes/base_llm.py
@@ -76,10 +76,7 @@ def address_request_list(self,
                              agent_process,
                              temperature=0.0
                              ):
-        steps = agent_process.prompt.split(";")
-        for step in steps:
-            self.process
-        return
+        raise NotImplementedError
 
     @abstractmethod
     def process(self,
diff --git a/src/llm_kernel/llm_classes/bed_rock.py b/src/llm_kernel/llm_classes/bed_rock.py
index f6c9fcb5..80b8b939 100644
--- a/src/llm_kernel/llm_classes/bed_rock.py
+++ b/src/llm_kernel/llm_classes/bed_rock.py
@@ -2,6 +2,7 @@
 from .base_llm import BaseLLMKernel
 
 import time
+from ...utils.message import Response
 
 class BedrockLLM(BaseLLMKernel):
 
@@ -32,7 +33,7 @@ def process(self,
         re.search(r'claude', self.model_name, re.IGNORECASE)
         agent_process.set_status("executing")
         agent_process.set_start_time(time.time())
-        prompt = agent_process.prompt
+        prompt = agent_process.message.prompt
         from langchain_core.prompts import ChatPromptTemplate
         chat_template = ChatPromptTemplate.from_messages([
             ("user", f"{prompt}")
@@ -41,7 +42,11 @@ def process(self,
         self.model.model_kwargs['temperature'] = temperature
         try:
             response = self.model(messages)
-            agent_process.set_response(response.content)
+            agent_process.set_response(
+                Response(
+                    response_message=response.content
+                )
+            )
         except IndexError:
             raise IndexError(f"{self.model_name} can not generate a valid result, please try again")
         agent_process.set_status("done")
diff --git a/src/llm_kernel/llm_classes/gemini_llm.py b/src/llm_kernel/llm_classes/gemini_llm.py
index 9f6548e2..0e9a7f34 100644
--- a/src/llm_kernel/llm_classes/gemini_llm.py
+++ b/src/llm_kernel/llm_classes/gemini_llm.py
@@ -3,6 +3,8 @@
 from .base_llm import BaseLLMKernel
 import time
 from ...utils.utils import get_from_env
+
+from ...utils.message import Response
 
 class GeminiLLM(BaseLLMKernel):
     def __init__(self, llm_name: str, max_gpu_memory: dict = None,
@@ -35,7 +37,8 @@ def process(self,
         assert re.search(r'gemini', self.model_name, re.IGNORECASE)
         agent_process.set_status("executing")
         agent_process.set_start_time(time.time())
-        prompt = agent_process.prompt
+        prompt = agent_process.message.prompt
+        # TODO: add tool calling
         self.logger.log(
             f"{agent_process.agent_name} is switched to executing.\n",
             level = "executing"
@@ -45,7 +48,11 @@ def process(self,
         )
         try:
             result = outputs.candidates[0].content.parts[0].text
-            agent_process.set_response(result)
+            agent_process.set_response(
+                Response(
+                    response_message = result
+                )
+            )
         except IndexError:
             raise IndexError(f"{self.model_name} can not generate a valid result, please try again")
         agent_process.set_status("done")
diff --git a/src/llm_kernel/llm_classes/gpt_llm.py b/src/llm_kernel/llm_classes/gpt_llm.py
index 868adf02..e9234ede 100644
--- a/src/llm_kernel/llm_classes/gpt_llm.py
+++ b/src/llm_kernel/llm_classes/gpt_llm.py
@@ -3,6 +3,8 @@
 import time
 
 from openai import OpenAI
 
+from ...utils.message import Response
+
 class GPTLLM(BaseLLMKernel):
     def __init__(self, llm_name: str,
@@ -27,7 +29,7 @@ def process(self,
         assert re.search(r'gpt', self.model_name, re.IGNORECASE)
         agent_process.set_status("executing")
         agent_process.set_start_time(time.time())
-        prompt = agent_process.prompt
+        prompt = agent_process.message.prompt
         self.logger.log(
             f"{agent_process.agent_name} is switched to executing.\n",
             level = "executing"
@@ -36,9 +38,18 @@ def process(self,
             model=self.model_name,
             messages=[
                 {"role": "user", "content": prompt}
-            ]
+            ],
+            tools = agent_process.message.tools,
+            tool_choice = "required" if agent_process.message.tools else None
+        )
+
+        # print(response.choices[0].message)
+        agent_process.set_response(
+            Response(
+                response_message = response.choices[0].message.content,
+                tool_calls = response.choices[0].message.tool_calls
+            )
         )
-        agent_process.set_response(response.choices[0].message.content)
         agent_process.set_status("done")
         agent_process.set_end_time(time.time())
         return
diff --git a/src/llm_kernel/llm_classes/ollama_llm.py b/src/llm_kernel/llm_classes/ollama_llm.py
index 2a4cc5a1..dd2e9342 100644
--- a/src/llm_kernel/llm_classes/ollama_llm.py
+++ b/src/llm_kernel/llm_classes/ollama_llm.py
@@ -3,6 +3,7 @@
 import time
 
 import ollama
+from ...utils.message import Response
 
 class OllamaLLM(BaseLLMKernel):
     def __init__(self, llm_name: str,
@@ -30,17 +31,23 @@ def process(self,
         assert re.search(r'ollama', self.mode, re.IGNORECASE)
         agent_process.set_status("executing")
         agent_process.set_start_time(time.time())
-        prompt = agent_process.prompt
+        prompt = agent_process.message.prompt
         self.logger.log(
             f"{agent_process.agent_name} is switched to executing.\n",
             level = "executing"
         )
-        response = ollama.chat(model=self.model_name, messages=[
+        response = ollama.chat(
+            model=self.model_name,
+            messages=[
             {
                 "role": "user", "content": prompt
-            }
-        ])
-        agent_process.set_response(response['message']['content'])
+            }],
+        )
+        agent_process.set_response(
+            Response(
+                response_message = response['message']['content']
+            )
+        )
         agent_process.set_status("done")
         agent_process.set_end_time(time.time())
         return
diff --git a/src/llm_kernel/llm_classes/open_llm.py b/src/llm_kernel/llm_classes/open_llm.py
index 195f69fa..40754bb0 100644
--- a/src/llm_kernel/llm_classes/open_llm.py
+++ b/src/llm_kernel/llm_classes/open_llm.py
@@ -4,6 +4,8 @@
 import time
 from transformers import AutoTokenizer
 
+from ...utils.message import Response
+
 class OpenLLM(BaseLLMKernel):
 
     def load_llm_and_tokenizer(self) -> None:
@@ -49,7 +51,8 @@ def process(self,
                 timestamp = agent_process.get_time_limit()
             )
         else:
-            prompt = agent_process.prompt
+            # prompt = agent_process.prompt
+            prompt = agent_process.message.prompt
             input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
             attention_masks = input_ids != self.tokenizer.pad_token_id
             input_ids = input_ids.to(self.eval_device)
@@ -67,7 +70,7 @@ def process(self,
 
         output_ids = outputs["result"]
 
-        prompt = agent_process.prompt
+        prompt = agent_process.message.prompt
         result = self.tokenizer.decode(output_ids, skip_special_tokens=True)
         result = result[len(prompt)+1: ]
 
@@ -78,7 +81,11 @@ def process(self,
             self.context_manager.clear_restoration(
                 agent_process.get_pid()
             )
-            agent_process.set_response(result)
+            agent_process.set_response(
+                Response(
+                    response_message=result
+                )
+            )
             agent_process.set_status("done")
 
         else:
@@ -95,7 +102,11 @@ def process(self,
                     "beam_attention_masks": outputs["beam_attention_masks"]
                 }
             )
-            agent_process.set_response(result)
+            agent_process.set_response(
+                Response(
+                    response_message = result
+                )
+            )
             agent_process.set_status("suspending")
 
         agent_process.set_end_time(time.time())
diff --git a/src/utils/message.py b/src/utils/message.py
new file mode 100644
index 00000000..d148147f
--- /dev/null
+++ b/src/utils/message.py
@@ -0,0 +1,18 @@
+class Message:
+    def __init__(self,
+                 prompt,
+                 context = None,
+                 tools = None
+        ) -> None:
+        self.prompt = prompt
+        self.context = context
+        self.tools = tools
+
+class Response:
+    def __init__(
+            self,
+            response_message,
+            tool_calls = None
+        ) -> None:
+        self.response_message = response_message
+        self.tool_calls = tool_calls
diff --git a/tests/test_llms.py b/tests/test_llms.py
index 8d8637e9..3dfb819a 100644
--- a/tests/test_llms.py
+++ b/tests/test_llms.py
@@ -6,6 +6,8 @@
 from openagi.src.agents.agent_process import AgentProcess
 from src.context.simple_context import SimpleContextManager
 
+from src.utils.message import Message
+
 # Load environment variables once for all tests
 load_dotenv(find_dotenv())
 
@@ -28,21 +30,25 @@ def test_closed_llm():
     llm = LLMKernel(llm_type, max_new_tokens = 10)
     agent_process = AgentProcess(
         agent_name="Narrative Agent",
-        prompt="Craft a tale about a valiant warrior on a quest to uncover priceless treasures hidden within a mystical island."
+        message = Message(
+            prompt="Craft a tale about a valiant warrior on a quest to uncover priceless treasures hidden within a mystical island."
+        )
     )
     llm.address_request(agent_process)
     response = agent_process.get_response()
-    assert isinstance(response, str), "The response should be a string"
+    assert isinstance(response.response_message, str), "The response should be a string"
 
 def test_open_llm(llm_setup):
     llm = llm_setup
     agent_process = AgentProcess(
         agent_name="Narrative Agent",
-        prompt="Craft a tale about a valiant warrior on a quest to uncover priceless treasures hidden within a mystical island."
+        message = Message(
+            prompt="Craft a tale about a valiant warrior on a quest to uncover priceless treasures hidden within a mystical island."
+        )
     )
     llm.address_request(agent_process)
    response = agent_process.get_response()
-    assert isinstance(response, str), "The response should be a string"
+    assert isinstance(response.response_message, str), "The response should be a string"
     if torch.cuda.device_count() > 0:
         context_manager = SimpleContextManager()
         agent_process.set_pid(0)
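
For reference, a minimal usage sketch of the new Message and Response containers introduced in src/utils/message.py. It only exercises the fields defined in this diff; the example prompt and values are illustrative assumptions, not part of the change.

    # Illustrative sketch only -- field names come from src/utils/message.py above;
    # the prompt text and values are made up for demonstration.
    from src.utils.message import Message, Response

    # Agents now wrap a request in a Message instead of passing a raw prompt string
    # (see the AgentProcess(message=Message(...)) calls in tests/test_llms.py).
    request = Message(
        prompt="What is 120 + 5 * 16?",
        tools=None,   # optional OpenAI-style tool schemas, consumed by gpt_llm.py
    )

    # LLM kernels now answer with a Response object instead of a bare string,
    # so callers read .response_message (and .tool_calls on tool-capable backends).
    reply = Response(
        response_message="200",
        tool_calls=None,
    )

    assert request.prompt.startswith("What is")
    assert isinstance(reply.response_message, str)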