diff --git a/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/_prompter.py b/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/_prompter.py index 2602feeb2c61..71bb4e7a5d44 100644 --- a/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/_prompter.py +++ b/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/_prompter.py @@ -184,7 +184,7 @@ async def find_index_topics(self, input_string: str) -> List[str]: return topic_list - async def generalize_task(self, task_description: str) -> str: + async def generalize_task(self, task_description: str, revise: bool | None = True) -> str: """ Attempts to rewrite a task description in a more general form. """ @@ -198,29 +198,31 @@ async def generalize_task(self, task_description: str) -> str: user_message.append(task_description) self._clear_history() - await self.call_model( + generalized_task = await self.call_model( summary="Ask the model to rephrase the task in a list of important points", system_message_content=sys_message, user_content=user_message, ) - user_message = [ - "Do you see any parts of this list that are irrelevant to actually solving the task? If so, explain which items are irrelevant." - ] - await self.call_model( - summary="Ask the model to identify irrelevant points", - system_message_content=sys_message, - user_content=user_message, - ) + if revise: + user_message = [ + "Do you see any parts of this list that are irrelevant to actually solving the task? If so, explain which items are irrelevant." + ] + await self.call_model( + summary="Ask the model to identify irrelevant points", + system_message_content=sys_message, + user_content=user_message, + ) + + user_message = [ + "Revise your original list to include only the most general terms, those that are critical to solving the task, removing any themes or descriptions that are not essential to the solution. Your final list may be shorter, but do not leave out any part of the task that is needed for solving the task. Do not add any additional commentary either before or after the list." + ] + generalized_task = await self.call_model( + summary="Ask the model to make a final list of general terms", + system_message_content=sys_message, + user_content=user_message, + ) - user_message = [ - "Revise your original list to include only the most general terms, those that are critical to solving the task, removing any themes or descriptions that are not essential to the solution. Your final list may be shorter, but do not leave out any part of the task that is needed for solving the task. Do not add any additional commentary either before or after the list." - ] - generalized_task = await self.call_model( - summary="Ask the model to make a final list of general terms", - system_message_content=sys_message, - user_content=user_message, - ) return generalized_task async def validate_insight(self, insight: str, task_description: str) -> bool: diff --git a/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/memory_controller.py b/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/memory_controller.py index 3a25f6ea18f9..acf5a649d72f 100644 --- a/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/memory_controller.py +++ b/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/memory_controller.py @@ -16,6 +16,11 @@ # Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating # the settings that change frequently, as when loading many settings from a single YAML file. class MemoryControllerConfig(TypedDict, total=False): + generalize_task: bool + revise_generalized_task: bool + generate_topics: bool + validate_memos: bool + max_memos_to_retrieve: int max_train_trials: int max_test_trials: int MemoryBank: "MemoryBankConfig" @@ -33,6 +38,11 @@ class MemoryController: task_assignment_callback: An optional callback used to assign a task to any agent managed by the caller. config: An optional dict that can be used to override the following values: + - generalize_task: Whether to rewrite tasks in more general terms. + - revise_generalized_task: Whether to critique then rewrite the generalized task. + - generate_topics: Whether to base retrieval directly on tasks, or on topics extracted from tasks. + - validate_memos: Whether to apply a final validation stage to retrieved memos. + - max_memos_to_retrieve: The maximum number of memos to return from retrieve_relevant_memos(). - max_train_trials: The maximum number of learning iterations to attempt when training on a task. - max_test_trials: The total number of attempts made when testing for failure on a task. - MemoryBank: A config dict passed to MemoryBank. @@ -91,10 +101,20 @@ def __init__( self.logger.enter_function() # Apply default settings and any config overrides. + self.generalize_task = True + self.revise_generalized_task = True + self.generate_topics = True + self.validate_memos = True + self.max_memos_to_retrieve = 10 self.max_train_trials = 10 self.max_test_trials = 3 memory_bank_config = None if config is not None: + self.generalize_task = config.get("generalize_task", self.generalize_task) + self.revise_generalized_task = config.get("revise_generalized_task", self.revise_generalized_task) + self.generate_topics = config.get("generate_topics", self.generate_topics) + self.validate_memos = config.get("validate_memos", self.validate_memos) + self.max_memos_to_retrieve = config.get("max_memos_to_retrieve", self.max_memos_to_retrieve) self.max_train_trials = config.get("max_train_trials", self.max_train_trials) self.max_test_trials = config.get("max_test_trials", self.max_test_trials) memory_bank_config = config.get("MemoryBank", memory_bank_config) @@ -178,8 +198,10 @@ async def add_memo(self, insight: str, task: None | str = None, index_on_both: b if task is not None: self.logger.info("\nGIVEN TASK:") self.logger.info(task) - # Generalize the task. - generalized_task = await self.prompter.generalize_task(task) + if self.generalize_task: + generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task) + else: + generalized_task = task self.logger.info("\nGIVEN INSIGHT:") self.logger.info(insight) @@ -196,7 +218,10 @@ async def add_memo(self, insight: str, task: None | str = None, index_on_both: b text_to_index = task self.logger.info("\nTOPICS EXTRACTED FROM TASK:") - topics = await self.prompter.find_index_topics(text_to_index) + if self.generate_topics: + topics = await self.prompter.find_index_topics(text_to_index) + else: + topics = [text_to_index] self.logger.info("\n".join(topics)) self.logger.info("") @@ -218,7 +243,10 @@ async def add_task_solution_pair_to_memory(self, task: str, solution: str) -> No self.logger.info(solution) # Get a list of topics from the task. - topics = await self.prompter.find_index_topics(task.strip()) + if self.generate_topics: + topics = await self.prompter.find_index_topics(task.strip()) + else: + topics = [task.strip()] self.logger.info("\nTOPICS EXTRACTED FROM TASK:") self.logger.info("\n".join(topics)) self.logger.info("") @@ -238,8 +266,14 @@ async def retrieve_relevant_memos(self, task: str) -> List[Memo]: self.logger.info(task) # Get a list of topics from the generalized task. - generalized_task = await self.prompter.generalize_task(task) - task_topics = await self.prompter.find_index_topics(generalized_task) + if self.generalize_task: + generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task) + else: + generalized_task = task + if self.generate_topics: + task_topics = await self.prompter.find_index_topics(generalized_task) + else: + task_topics = [generalized_task] self.logger.info("\nTOPICS EXTRACTED FROM TASK:") self.logger.info("\n".join(task_topics)) self.logger.info("") @@ -250,7 +284,9 @@ async def retrieve_relevant_memos(self, task: str) -> List[Memo]: # Apply a final validation stage to keep only the memos that the LLM concludes are sufficiently relevant. validated_memos: List[Memo] = [] for memo in memo_list: - if await self.prompter.validate_insight(memo.insight, task): + if len(validated_memos) >= self.max_memos_to_retrieve: + break + if (not self.validate_memos) or await self.prompter.validate_insight(memo.insight, task): validated_memos.append(memo) self.logger.info("\n{} VALIDATED MEMOS".format(len(validated_memos))) diff --git a/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/utils/chat_completion_client_recorder.py b/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/utils/chat_completion_client_recorder.py index 16124db1f3c5..d9cb84a87c5d 100644 --- a/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/utils/chat_completion_client_recorder.py +++ b/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/utils/chat_completion_client_recorder.py @@ -41,10 +41,9 @@ class ChatCompletionClientRecorder(ChatCompletionClient): create calls) or a "stream" (a list of streamed outputs for create_stream calls). ReplayChatCompletionClient and ChatCompletionCache do similar things, but with significant differences: - - ReplayChatCompletionClient replays pre-defined responses in a specified order - without recording anything or checking the messages sent to the client. - - ChatCompletionCache caches responses and replays them for messages that have been seen before, - regardless of order, and calls the base client for any uncached messages. + + - ReplayChatCompletionClient replays pre-defined responses in a specified order without recording anything or checking the messages sent to the client. + - ChatCompletionCache caches responses and replays them for messages that have been seen before, regardless of order, and calls the base client for any uncached messages. """ def __init__( diff --git a/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/utils/teachability.py b/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/utils/teachability.py index f8a09ee40e34..d9f511b93201 100644 --- a/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/utils/teachability.py +++ b/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/utils/teachability.py @@ -14,10 +14,11 @@ class Teachability(Memory): Gives an AssistantAgent the ability to learn quickly from user teachings, hints, and advice. Steps for usage: - 1. Instantiate MemoryController. - 2. Instantiate Teachability, passing the memory controller as a parameter. - 3. Instantiate an AssistantAgent, passing the teachability instance (wrapped in a list) as the memory parameter. - 4. Use the AssistantAgent as usual, such as for chatting with the user. + + 1. Instantiate MemoryController. + 2. Instantiate Teachability, passing the memory controller as a parameter. + 3. Instantiate an AssistantAgent, passing the teachability instance (wrapped in a list) as the memory parameter. + 4. Use the AssistantAgent as usual, such as for chatting with the user. """ def __init__(self, memory_controller: "MemoryController", name: str | None = None) -> None: diff --git a/python/samples/task_centric_memory/README.md b/python/samples/task_centric_memory/README.md index 2ae16228efc7..f78146ec5c84 100644 --- a/python/samples/task_centric_memory/README.md +++ b/python/samples/task_centric_memory/README.md @@ -41,7 +41,7 @@ or else modify `utils/client.py` as appropriate for the model you choose. ## Running the Samples The following samples are listed in order of increasing complexity. -Execute the corresponding commands from this (autogen_ext/task_centric_memory) directory. +Execute the corresponding commands from the `python/samples/task_centric_memory` directory. ### Making AssistantAgent Teachable diff --git a/python/samples/task_centric_memory/configs/self_teaching.yaml b/python/samples/task_centric_memory/configs/self_teaching.yaml index a6e28eaa6245..7007d3c9cb51 100644 --- a/python/samples/task_centric_memory/configs/self_teaching.yaml +++ b/python/samples/task_centric_memory/configs/self_teaching.yaml @@ -15,10 +15,10 @@ client: Apprentice: name_of_agent_or_team: AssistantAgent # AssistantAgent or MagenticOneGroupChat disable_prefix_caching: 1 # If true, prepends a small random string to the context, to decorrelate repeated runs. - TaskCentricMemoryController: + MemoryController: max_train_trials: 10 max_test_trials: 3 - TaskCentricMemoryBank: + MemoryBank: path: ./memory_bank/self_teaching relevance_conversion_threshold: 1.7 n_results: 25