diff --git a/README.md b/README.md index 5cdac289..0f9e761f 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ You can utilize non-visual models (e.g., GPT-4) for each agent by configuring th Optionally, you can set a backup language model (LLM) engine in the `BACKUP_AGENT` field to handle cases where the primary engines fail during inference. Ensure you configure these settings accurately to leverage non-visual models effectively. -#### +#### NOTE 💡 UFO also supports other LLMs and advanced configurations, such as customize your own model, please check the [documents](./model_worker/readme.md) for more details. Because of the limitations of model input, a lite version of the prompt is provided to allow users to experience it, which is configured in `config_dev`.yaml. ### 📔 Step 3: Additional Setting for RAG (optional). diff --git a/ufo/agent/agent.py b/ufo/agent/agent.py index 44c20bae..3dfdbf13 100644 --- a/ufo/agent/agent.py +++ b/ufo/agent/agent.py @@ -254,7 +254,6 @@ def build_human_demonstration_retriever(self, db_path: str) -> None: self.human_demonstration_retriever = self.retriever_factory.create_retriever("demonstration", db_path) - class HostAgent(BasicAgent): """ The HostAgent class the manager of AppAgents. diff --git a/ufo/automator/ui_control/control_filter.py b/ufo/automator/ui_control/control_filter.py new file mode 100644 index 00000000..6035d46e --- /dev/null +++ b/ufo/automator/ui_control/control_filter.py @@ -0,0 +1,217 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import heapq +import sentence_transformers +import re +import warnings + +warnings.filterwarnings("ignore") + +class ControlFilterFactory: + """ + Factory class to filter control items. + """ + + @staticmethod + def create_control_filter(control_filter_type: str, *args, **kwargs): + """ + Create a control filter model based on the given type. + :param control_filter_type: The type of control filter model to create. + :return: The created retriever. + """ + if control_filter_type == "text": + return TextControlFilter(*args, **kwargs) + elif control_filter_type == "semantic": + return SemanticControlFilter(*args, **kwargs) + elif control_filter_type == "icon": + return IconControlFilter(*args, **kwargs) + else: + raise ValueError("Invalid retriever type: {}".format(control_filter_type)) + + @staticmethod + def plan_to_keywords(plan:str) -> list: + """ + Gets keywords from the plan. + We only consider the words in the plan that are alphabetic or Chinese characters. + Args: + plan (str): The plan to be parsed. + Returns: + list: A list of keywords extracted from the plan. + """ + plans = plan.split("\n") + keywords = [] + for plan in plans: + words = plan.replace("'", "").strip(".").split() + words = [word for word in words if word.isalpha() or bool(re.fullmatch(r'[\u4e00-\u9fa5]+', word))] + keywords.extend(words) + return keywords + +class ControlFilterModel: + """ + ControlFilterModel represents a model for filtering control items. + """ + + _instances = {} + + def __new__(cls, model_path): + """ + Creates a new instance of ControlFilterModel. + Args: + model_path (str): The path to the model. + Returns: + ControlFilterModel: The ControlFilterModel instance. + """ + if model_path not in cls._instances: + instance = super(ControlFilterModel, cls).__new__(cls) + instance.model = cls.load_model(model_path) + cls._instances[model_path] = instance + return cls._instances[model_path] + + @staticmethod + def load_model(model_path): + """ + Loads the model from the given model path. + Args: + model_path (str): The path to the model. + Returns: + SentenceTransformer: The loaded SentenceTransformer model. + """ + return sentence_transformers.SentenceTransformer(model_path) + + def get_embedding(self, content): + """ + Encodes the given object into an embedding. + Args: + content: The content to encode. + Returns: + The embedding of the object. + """ + return self.model.encode(content) + + def control_filter(self, keywords, control_item): + """ + Calculates the cosine similarity between the embeddings of the given keywords and the control item. + Args: + keywords (str): The keywords to be used for calculating the similarity. + control_item (str): The control item to be compared with the keywords. + Returns: + float: The cosine similarity between the embeddings of the keywords and the control item. + """ + keywords_embedding = self.get_embedding(keywords) + control_item_embedding = self.get_embedding(control_item) + return self.cos_sim(keywords_embedding, control_item_embedding) + + @staticmethod + def cos_sim(embedding1, embedding2): + """ + Computes the cosine similarity between two embeddings. + """ + return sentence_transformers.util.cos_sim(embedding1, embedding2) + + +class TextControlFilter: + """ + A class that provides methods for filtering control items based on keywords. + """ + + @staticmethod + def control_filter(filtered_control_info, control_items, keywords): + """ + Filters control items based on keywords. + Args: + filtered_control_info (list): A list of control items that have already been filtered. + control_items (list): A list of control items to be filtered. + keywords (list): A list of keywords to filter the control items. + """ + for control_item in control_items: + if control_item not in filtered_control_info: + control_text = control_item['control_text'].lower() + if any(keyword in control_text or control_text in keyword for keyword in keywords): + filtered_control_info.append(control_item) + +class SemanticControlFilter(ControlFilterModel): + """ + A class that represents a semantic model for control filtering. + """ + + def control_filter_score(self, control_text, keywords): + """ + Calculates the score for a control item based on the similarity between its text and a set of keywords. + Args: + control_text (str): The text of the control item. + keywords (list): A list of keywords. + Returns: + float: The score indicating the similarity between the control text and the keywords. + """ + keywords_embedding = self.get_embedding(keywords) + control_text_embedding = self.get_embedding(control_text) + return max(self.cos_sim(control_text_embedding, keywords_embedding).tolist()[0]) + + def control_filter(self, filtered_control_info, control_items, keywords, top_k): + """ + Filters control items based on their similarity to a set of keywords. + Args: + filtered_control_info (list): A list of already filtered control items. + control_items (list): A list of control items to be filtered. + keywords (list): A list of keywords. + top_k (int): The number of top control items to be selected. + """ + scores = [] + for control_item in control_items: + if control_item not in filtered_control_info: + control_text = control_item['control_text'].lower() + score = self.control_filter_score(control_text, keywords) + else: + score = -100.0 + scores.append(score) + topk_items = heapq.nlargest(top_k, enumerate(scores), key=lambda x: x[1]) + topk_indices = [item[0] for item in topk_items] + + filtered_control_info.extend([control_items[i] for i in topk_indices]) + +class IconControlFilter(ControlFilterModel): + """ + Represents a model for filtering control icons based on keywords. + Attributes: + Inherits attributes from ControlFilterModel. + Methods: + control_filter_score(control_icon, keywords): Calculates the score of a control icon based on its similarity to the given keywords. + control_filter(filtered_control_info, control_items, cropped_icons_dict, keywords, top_k): Filters control items based on their scores and returns the top-k items. + """ + + def control_filter_score(self, control_icon, keywords): + """ + Calculates the score of a control icon based on its similarity to the given keywords. + Args: + control_icon: The control icon image. + keywords: The keywords to compare the control icon against. + Returns: + The maximum similarity score between the control icon and the keywords. + """ + keywords_embedding = self.get_embedding(keywords) + control_icon_embedding = self.get_embedding(control_icon) + return max(self.cos_sim(control_icon_embedding, keywords_embedding).tolist()[0]) + + def control_filter(self, filtered_control_info, control_items, cropped_icons_dict, keywords, top_k): + """ + Filters control items based on their scores and returns the top-k items. + Args: + filtered_control_info: The list of already filtered control items. + control_items: The list of all control items. + cropped_icons: The dictionary of the cropped icons. + keywords: The keywords to compare the control icons against. + top_k: The number of top items to return. + Returns: + The list of top-k control items based on their scores. + """ + scores = [] + for label, cropped_icon in cropped_icons_dict.items(): + if label not in [info['label'] for info in filtered_control_info]: + score = self.control_filter_score(cropped_icon, keywords) + scores.append((score, label)) + else: + scores.append((-100.0, label)) + topk_items = heapq.nlargest(top_k, scores, key=lambda x: x[0]) + topk_labels = [item[1] for item in topk_items] + + filtered_control_info.extend([control_item for control_item in control_items if control_item['label'] in topk_labels ]) \ No newline at end of file diff --git a/ufo/automator/ui_control/screenshot.py b/ufo/automator/ui_control/screenshot.py index 3290b530..ce9a9a0f 100644 --- a/ufo/automator/ui_control/screenshot.py +++ b/ufo/automator/ui_control/screenshot.py @@ -258,6 +258,23 @@ def get_annotation_dict(self) -> Dict: annotation_dict[label_text] = control return annotation_dict + def get_cropped_icons_dict(self) -> Dict: + """ + Get the dictionary of the cropped icons. + :return: The dictionary of the cropped icons. + """ + cropped_icons_dict = {} + image = self.photographer.capture() + window_rect = self.photographer.control.rectangle() + for i, control in enumerate(self.sub_control_list): + if self.annotation_type == "number": + label_text = str(i+1) + elif self.annotation_type == "letter": + label_text = self.number_to_letter(i) + control_rect = control.rectangle() + cropped_icons_dict[label_text] = image.crop(self.coordinate_adjusted(window_rect, control_rect)) + return cropped_icons_dict + def capture(self, save_path:Optional[str] = None): """ @@ -370,6 +387,20 @@ def get_annotation_dict(self, control, sub_control_list: List, annotation_type=" screenshot = self.screenshot_factory.create_screenshot("app_window", control) screenshot = AnnotationDecorator(screenshot, sub_control_list, annotation_type) return screenshot.get_annotation_dict() + + + def get_cropped_icons_dict(self, control, sub_control_list: List, annotation_type="number") -> Dict: + """ + Get the dictionary of the cropped icons. + :param control: The control item to capture. + :param sub_control_list: The list of the controls to annotate. + :param annotation_type: The type of the annotation. + :return: The dictionary of the cropped icons. + """ + + screenshot = self.screenshot_factory.create_screenshot("app_window", control) + screenshot = AnnotationDecorator(screenshot, sub_control_list, annotation_type) + return screenshot.get_cropped_icons_dict() diff --git a/ufo/config/config_dev.yaml b/ufo/config/config_dev.yaml index 214c5e46..636d2974 100644 --- a/ufo/config/config_dev.yaml +++ b/ufo/config/config_dev.yaml @@ -51,3 +51,11 @@ INPUT_TEXT_ENTER: True # whether to press enter after typing the text ## APIs related USE_APIS: True # Whether to use the API WORD_API_PROMPT: "ufo/prompts/base/{mode}/word_api.yaml" # The prompt for the word API + +# For control filtering +#'TEXT' for only rich text filter, 'SEMANTIC' for only semantic similarity match, 'ICON' for only icon match +CONTROL_FILTER_TYPE: ["TEXT", "SEMANTIC", "ICON"] # The control filter type +CONTROL_FILTER_TOP_K_SEMANTIC: 15 # The control filter top k for semantic similarity +CONTROL_FILTER_TOP_K_ICON: 15 # The control filter top k for icon similarity +CONTROL_FILTER_MODEL_SEMANTIC_NAME: "all-MiniLM-L6-v2" # The control filter model name of semantic similarity +CONTROL_FILTER_MODEL_ICON_NAME: "clip-ViT-B-32" # The control filter model name of icon similarity diff --git a/ufo/module/processor.py b/ufo/module/processor.py index 66f625ef..6460e5fe 100644 --- a/ufo/module/processor.py +++ b/ufo/module/processor.py @@ -18,6 +18,9 @@ from ..config.config import Config from . import interactor +# Lazy import the control_filter factory to aviod long loading time. +control_filter = utils.LazyImport("..automator.ui_control.control_filter") + configs = Config.get_instance().config_data BACKEND = configs["CONTROL_BACKEND"] @@ -46,10 +49,9 @@ def __init__(self, index: int, log_path: str, photographer: PhotographerFacade, self.request_logger = request_logger self.logger = logger self._app_window = app_window - + self.global_step = global_step self.round_step = round_step - self.prev_status = prev_status self.index = index @@ -59,7 +61,6 @@ def __init__(self, index: int, log_path: str, photographer: PhotographerFacade, self._response = None self._cost = 0 self._control_label = None - self._control_text = None self._response_json = None @@ -83,7 +84,7 @@ def process(self): self.print_step_info() self.capture_screenshot() self.get_control_info() - self.get_prompt_message() + self.get_prompt_message() self.get_response() if self.is_error(): @@ -386,14 +387,12 @@ def execute_action(self): """ # Get the application window - new_app_window = self._desktop_windows_dict.get(self.control_label, None) if new_app_window is None: return - # Get the application name - self.app_root = control.get_application_name(new_app_window) - + self.app_root = control.get_application_name(new_app_window) + try: new_app_window.is_normal() @@ -404,7 +403,7 @@ def execute_action(self): return self._status = "CONTINUE" - + if new_app_window is not self._app_window and self._app_window is not None: utils.print_with_color( "Switching to a new application...", "magenta") @@ -473,6 +472,7 @@ def create_app_agent(self): + class AppAgentProcessor(BaseProcessor): def __init__(self, index: int, log_path: str, photographer: PhotographerFacade, request: str, request_logger: Logger, logger: Logger, app_agent: AppAgent, round_step:int, global_step: int, @@ -506,8 +506,9 @@ def __init__(self, index: int, log_path: str, photographer: PhotographerFacade, self._args = None self._image_url = [] self._control_reannotate = None - - + self.cropped_icons_dict = {} + self.control_filter_factory = control_filter.ControlFilterFactory() + def print_step_info(self): """ Print the step information. @@ -530,6 +531,7 @@ def capture_screenshot(self): control_list = control.find_control_elements_in_descendants(BACKEND, self._app_window, control_type_list = configs["CONTROL_LIST"], class_name_list = configs["CONTROL_LIST"]) self._annotation_dict = self.photographer.get_annotation_dict(self._app_window, control_list, annotation_type="number") + self.cropped_icons_dict = self.photographer.get_cropped_icons_dict(self._app_window, control_list, annotation_type="number") self.photographer.capture_app_window_screenshot(self._app_window, save_path=screenshot_save_path) self.photographer.capture_app_window_screenshot_with_annotation(self._app_window, control_list, annotation_type="number", save_path=annotated_screenshot_save_path) @@ -550,14 +552,55 @@ def capture_screenshot(self): self._image_url += [screenshot_url, screenshot_annotated_url] - def get_control_info(self): """ Get the control information. """ self._control_info = control.get_control_info_dict(self._annotation_dict, ["control_text", "control_type" if BACKEND == "uia" else "control_class"]) + + def get_filtered_control_info(self, plan:str): + """ + Get the filtered control information. + + :param plan: The plan string. + + Return: + The filtered control information. + """ + + control_filter_type = configs["CONTROL_FILTER_TYPE"] + control_filter_type_lower = [control_filter_type_lower.lower() for control_filter_type_lower in control_filter_type] + is_text_required = 'text' in control_filter_type_lower + is_semantic_required = 'semantic' in control_filter_type_lower + is_icon_required = 'icon' in control_filter_type_lower + + if control_filter_type and not (is_text_required or is_semantic_required or is_icon_required): + + raise ValueError(f"Unsupported CONTROL_FILTER_TYPE: {control_filter_type}") + + elif not control_filter_type: + + return self._control_info + else: + filtered_control_info = [] + keywords = self.control_filter_factory.plan_to_keywords(plan) + + if is_text_required: + model_text = self.control_filter_factory.create_control_filter('text') + model_text.control_filter(filtered_control_info, self._control_info, keywords) + + if is_semantic_required: + model_semantic = self.control_filter_factory.create_control_filter('semantic', configs["CONTROL_FILTER_MODEL_SEMANTIC_NAME"]) + model_semantic.control_filter(filtered_control_info, self._control_info, keywords, configs["CONTROL_FILTER_TOP_K_SEMANTIC"]) + + if is_icon_required: + model_icon = self.control_filter_factory.create_control_filter('icon', configs["CONTROL_FILTER_MODEL_ICON_NAME"]) + model_icon.control_filter(filtered_control_info, self._control_info, self.cropped_icons_dict, keywords, configs["CONTROL_FILTER_TOP_K_ICON"]) + + + def get_prompt_message(self): """ Get the prompt message. @@ -590,10 +633,13 @@ def get_prompt_message(self): if agent_memory.length > 0: prev_plan = agent_memory.get_latest_item().to_dict()["Plan"].strip() + filtered_control_info = self.get_filtered_control_info(prev_plan) else: prev_plan = "" + filtered_control_info = self.get_filtered_control_info(HostAgent.memory.get_latest_item().to_dict()["Plan"]) + self._prompt_message = self.AppAgent.message_constructor(examples, tips, external_knowledge_prompt, self._image_url, request_history, action_history, - self._control_info, prev_plan, self.request, configs["INCLUDE_LAST_SCREENSHOT"]) + filtered_control_info, prev_plan, self.request, configs["INCLUDE_LAST_SCREENSHOT"]) self.request_logger.debug(json.dumps({"step": self.global_step, "prompt": self._prompt_message, "status": ""})) @@ -698,7 +744,6 @@ def update_memory(self): additional_memory = {"Step": self.global_step, "RoundStep": self.get_process_step(), "AgentStep": self.AppAgent.get_step(), "Round": self.index, "Action": self._action, "Request": self.request, "Agent": "ActAgent", "AgentName": self.AppAgent.name, "Application": app_root, "Cost": self._cost, "Results": self._results} - app_agent_step_memory.set_values_from_dict(self._response_json) app_agent_step_memory.set_values_from_dict(additional_memory) @@ -748,5 +793,4 @@ def get_control_reannotate(self): :return: The control to reannotate. """ - return self._control_reannotate - \ No newline at end of file + return self._control_reannotate \ No newline at end of file diff --git a/ufo/module/round.py b/ufo/module/round.py index dad23046..9d3b9e69 100644 --- a/ufo/module/round.py +++ b/ufo/module/round.py @@ -11,7 +11,6 @@ from . import processor configs = Config.get_instance().config_data -BACKEND = configs["CONTROL_BACKEND"] @@ -44,7 +43,7 @@ def __init__(self, task: str, logger: Logger, request_logger: Logger, photograph self.application = "" self.app_root = "" self.app_window = None - + self._cost = 0.0 self.control_reannotate = [] @@ -62,7 +61,7 @@ def process_application_selection(self) -> None: host_agent_processor = processor.HostAgentProcessor(index=self.index, log_path=self.log_path, photographer=self.photographer, request=self.request, round_step=self.get_step(), global_step=self.global_step, request_logger=self.request_logger, logger=self.logger, host_agent=self.HostAgent, prev_status=self.get_status(), app_window=self.app_window) - + host_agent_processor.process() self._status = host_agent_processor.get_process_status() @@ -78,13 +77,13 @@ def process_action_selection(self) -> None: """ Select an action with the application. """ - + app_agent_processor = processor.AppAgentProcessor(index=self.index, log_path=self.log_path, photographer=self.photographer, request=self.request, round_step=self.get_step(), global_step=self.global_step, process_name=self.application, request_logger=self.request_logger, logger=self.logger, app_agent=self.AppAgent, app_window=self.app_window, control_reannotate=self.control_reannotate, prev_status=self.get_status()) - + app_agent_processor.process() - + self._status = app_agent_processor.get_process_status() self._step += app_agent_processor.get_process_step() self.update_cost(app_agent_processor.get_process_cost()) @@ -98,16 +97,16 @@ def get_status(self) -> str: return: The status of the session. """ return self._status - - - + + + def get_step(self) -> int: """ Get the step of the session. return: The step of the session. """ return self._step - + def get_cost(self) -> float: """ @@ -115,8 +114,8 @@ def get_cost(self) -> float: return: The cost of the session. """ return self._cost - - + + def print_cost(self) -> None: # Print the total cost @@ -125,7 +124,7 @@ def print_cost(self) -> None: formatted_cost = '${:.2f}'.format(total_cost) utils.print_with_color(f"Request total cost for current round is {formatted_cost}", "yellow") - + def get_results(self) -> str: """ Get the results of the session. @@ -139,7 +138,7 @@ def get_results(self) -> str: else: result = "" return result - + def set_index(self, index: int) -> None: """ @@ -152,15 +151,15 @@ def set_global_step(self, global_step: int) -> None: Set the global step of the session. """ self.global_step = global_step - - + + def get_application_window(self) -> object: """ Get the application of the session. return: The application of the session. """ return self.app_window - + def update_cost(self, cost: float) -> None: """ @@ -169,4 +168,4 @@ def update_cost(self, cost: float) -> None: if isinstance(cost, float) and isinstance(self._cost, float): self._cost += cost else: - self._cost = None + self._cost = None \ No newline at end of file diff --git a/ufo/module/session.py b/ufo/module/session.py index 2ffa1cd3..15d56d57 100644 --- a/ufo/module/session.py +++ b/ufo/module/session.py @@ -21,13 +21,14 @@ class Session(object): """ A session for UFO. """ - + def __init__(self, task): """ Initialize a session. :param task: The name of current task. :param gpt_key: GPT key. """ + self.task = task self._step = 0 self._round = 0 @@ -56,7 +57,7 @@ def __init__(self, task): utils.print_with_color(interactor.WELCOME_TEXT, "cyan") self.request = interactor.first_request() - + self.round_list = [] self._current_round = self.create_round() @@ -73,8 +74,8 @@ def create_round(self) -> round.Round: self.round_list.append(new_round) return new_round - - + + def experience_saver(self) -> None: """ Save the current trajectory as agent experience. @@ -122,7 +123,7 @@ def round_hostagent_execution(self) -> None: current_round.set_global_step(self.get_step()) current_round.process_application_selection() - + self._status = current_round.get_status() self._step += 1 @@ -133,7 +134,7 @@ def round_appagent_execution(self) -> None: """ Execute the app agent in the current round. """ - + current_round = self.get_current_round() current_round.set_global_step(self.get_step()) @@ -141,8 +142,7 @@ def round_appagent_execution(self) -> None: self._status = current_round.get_status() self._step += 1 - - + def get_current_round(self) -> round.Round: """ @@ -152,7 +152,6 @@ def get_current_round(self) -> round.Round: return self._current_round - def get_round_num(self) -> int: """ Get the round of the session. @@ -196,7 +195,7 @@ def get_cost(self) -> float: def print_cost(self) -> None: """ Print the total cost. - """ + """ total_cost = self.get_cost() if isinstance(total_cost, float): @@ -219,6 +218,7 @@ def get_results(self) -> str: return result + def get_application_window(self) -> object: """ Get the application of the session. @@ -237,6 +237,7 @@ def update_cost(self, cost: float) -> None: self._cost = None + def set_state(self, state) -> None: """ Set the state of the session. @@ -267,6 +268,7 @@ def initialize_logger(log_path: str, log_filename: str) -> logging.Logger: # Remove existing handlers if PRINT_LOG is False logger.handlers = [] + log_file_path = os.path.join(log_path, log_filename) file_handler = logging.FileHandler(log_file_path, encoding="utf-8") formatter = logging.Formatter('%(message)s') diff --git a/ufo/module/state.py b/ufo/module/state.py index 932f743d..83e2027e 100644 --- a/ufo/module/state.py +++ b/ufo/module/state.py @@ -68,7 +68,6 @@ def handle(self, session): pass - class RoundFinishState(SessionState): """ The state when a single round is finished.