Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

control filter #61

Merged
merged 7 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ You can utilize non-visual models (e.g., GPT-4) for each agent by configuring th

Optionally, you can set a backup language model (LLM) engine in the `BACKUP_AGENT` field to handle cases where the primary engines fail during inference. Ensure you configure these settings accurately to leverage non-visual models effectively.

####
#### NOTE
💡 UFO also supports other LLMs and advanced configurations, such as customize your own model, please check the [documents](./model_worker/readme.md) for more details. Because of the limitations of model input, a lite version of the prompt is provided to allow users to experience it, which is configured in `config_dev`.yaml.

### 📔 Step 3: Additional Setting for RAG (optional).
Expand Down
1 change: 0 additions & 1 deletion ufo/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def build_human_demonstration_retriever(self, db_path: str) -> None:
self.human_demonstration_retriever = self.retriever_factory.create_retriever("demonstration", db_path)



class HostAgent(BasicAgent):
"""
The HostAgent class the manager of AppAgents.
Expand Down
217 changes: 217 additions & 0 deletions ufo/automator/ui_control/control_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import heapq
import sentence_transformers
import re
import warnings

warnings.filterwarnings("ignore")

class ControlFilterFactory:
"""
Factory class to filter control items.
"""

@staticmethod
def create_control_filter(control_filter_type: str, *args, **kwargs):
"""
Create a control filter model based on the given type.
:param control_filter_type: The type of control filter model to create.
:return: The created retriever.
"""
if control_filter_type == "text":
return TextControlFilter(*args, **kwargs)
elif control_filter_type == "semantic":
return SemanticControlFilter(*args, **kwargs)
elif control_filter_type == "icon":
return IconControlFilter(*args, **kwargs)
else:
raise ValueError("Invalid retriever type: {}".format(control_filter_type))

@staticmethod
def plan_to_keywords(plan:str) -> list:
"""
Gets keywords from the plan.
We only consider the words in the plan that are alphabetic or Chinese characters.
Args:
plan (str): The plan to be parsed.
Returns:
list: A list of keywords extracted from the plan.
"""
plans = plan.split("\n")
keywords = []
for plan in plans:
words = plan.replace("'", "").strip(".").split()
words = [word for word in words if word.isalpha() or bool(re.fullmatch(r'[\u4e00-\u9fa5]+', word))]
keywords.extend(words)
return keywords

class ControlFilterModel:
"""
ControlFilterModel represents a model for filtering control items.
"""

_instances = {}

def __new__(cls, model_path):
"""
Creates a new instance of ControlFilterModel.
Args:
model_path (str): The path to the model.
Returns:
ControlFilterModel: The ControlFilterModel instance.
"""
if model_path not in cls._instances:
instance = super(ControlFilterModel, cls).__new__(cls)
instance.model = cls.load_model(model_path)
cls._instances[model_path] = instance
return cls._instances[model_path]

@staticmethod
def load_model(model_path):
"""
Loads the model from the given model path.
Args:
model_path (str): The path to the model.
Returns:
SentenceTransformer: The loaded SentenceTransformer model.
"""
return sentence_transformers.SentenceTransformer(model_path)

def get_embedding(self, content):
"""
Encodes the given object into an embedding.
Args:
content: The content to encode.
Returns:
The embedding of the object.
"""
return self.model.encode(content)

def control_filter(self, keywords, control_item):
"""
Calculates the cosine similarity between the embeddings of the given keywords and the control item.
Args:
keywords (str): The keywords to be used for calculating the similarity.
control_item (str): The control item to be compared with the keywords.
Returns:
float: The cosine similarity between the embeddings of the keywords and the control item.
"""
keywords_embedding = self.get_embedding(keywords)
control_item_embedding = self.get_embedding(control_item)
return self.cos_sim(keywords_embedding, control_item_embedding)

@staticmethod
def cos_sim(embedding1, embedding2):
"""
Computes the cosine similarity between two embeddings.
"""
return sentence_transformers.util.cos_sim(embedding1, embedding2)


class TextControlFilter:
"""
A class that provides methods for filtering control items based on keywords.
"""

@staticmethod
def control_filter(filtered_control_info, control_items, keywords):
"""
Filters control items based on keywords.
Args:
filtered_control_info (list): A list of control items that have already been filtered.
control_items (list): A list of control items to be filtered.
keywords (list): A list of keywords to filter the control items.
"""
for control_item in control_items:
if control_item not in filtered_control_info:
control_text = control_item['control_text'].lower()
if any(keyword in control_text or control_text in keyword for keyword in keywords):
filtered_control_info.append(control_item)

class SemanticControlFilter(ControlFilterModel):
"""
A class that represents a semantic model for control filtering.
"""

def control_filter_score(self, control_text, keywords):
"""
Calculates the score for a control item based on the similarity between its text and a set of keywords.
Args:
control_text (str): The text of the control item.
keywords (list): A list of keywords.
Returns:
float: The score indicating the similarity between the control text and the keywords.
"""
keywords_embedding = self.get_embedding(keywords)
control_text_embedding = self.get_embedding(control_text)
return max(self.cos_sim(control_text_embedding, keywords_embedding).tolist()[0])

def control_filter(self, filtered_control_info, control_items, keywords, top_k):
"""
Filters control items based on their similarity to a set of keywords.
Args:
filtered_control_info (list): A list of already filtered control items.
control_items (list): A list of control items to be filtered.
keywords (list): A list of keywords.
top_k (int): The number of top control items to be selected.
"""
scores = []
for control_item in control_items:
if control_item not in filtered_control_info:
control_text = control_item['control_text'].lower()
score = self.control_filter_score(control_text, keywords)
else:
score = -100.0
scores.append(score)
topk_items = heapq.nlargest(top_k, enumerate(scores), key=lambda x: x[1])
topk_indices = [item[0] for item in topk_items]

filtered_control_info.extend([control_items[i] for i in topk_indices])

class IconControlFilter(ControlFilterModel):
"""
Represents a model for filtering control icons based on keywords.
Attributes:
Inherits attributes from ControlFilterModel.
Methods:
control_filter_score(control_icon, keywords): Calculates the score of a control icon based on its similarity to the given keywords.
control_filter(filtered_control_info, control_items, cropped_icons_dict, keywords, top_k): Filters control items based on their scores and returns the top-k items.
"""

def control_filter_score(self, control_icon, keywords):
"""
Calculates the score of a control icon based on its similarity to the given keywords.
Args:
control_icon: The control icon image.
keywords: The keywords to compare the control icon against.
Returns:
The maximum similarity score between the control icon and the keywords.
"""
keywords_embedding = self.get_embedding(keywords)
control_icon_embedding = self.get_embedding(control_icon)
return max(self.cos_sim(control_icon_embedding, keywords_embedding).tolist()[0])

def control_filter(self, filtered_control_info, control_items, cropped_icons_dict, keywords, top_k):
"""
Filters control items based on their scores and returns the top-k items.
Args:
filtered_control_info: The list of already filtered control items.
control_items: The list of all control items.
cropped_icons: The dictionary of the cropped icons.
keywords: The keywords to compare the control icons against.
top_k: The number of top items to return.
Returns:
The list of top-k control items based on their scores.
"""
scores = []
for label, cropped_icon in cropped_icons_dict.items():
if label not in [info['label'] for info in filtered_control_info]:
score = self.control_filter_score(cropped_icon, keywords)
scores.append((score, label))
else:
scores.append((-100.0, label))
topk_items = heapq.nlargest(top_k, scores, key=lambda x: x[0])
topk_labels = [item[1] for item in topk_items]

filtered_control_info.extend([control_item for control_item in control_items if control_item['label'] in topk_labels ])
31 changes: 31 additions & 0 deletions ufo/automator/ui_control/screenshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,23 @@ def get_annotation_dict(self) -> Dict:
annotation_dict[label_text] = control
return annotation_dict

def get_cropped_icons_dict(self) -> Dict:
"""
Get the dictionary of the cropped icons.
:return: The dictionary of the cropped icons.
"""
cropped_icons_dict = {}
image = self.photographer.capture()
window_rect = self.photographer.control.rectangle()
for i, control in enumerate(self.sub_control_list):
if self.annotation_type == "number":
label_text = str(i+1)
elif self.annotation_type == "letter":
label_text = self.number_to_letter(i)
control_rect = control.rectangle()
cropped_icons_dict[label_text] = image.crop(self.coordinate_adjusted(window_rect, control_rect))
return cropped_icons_dict


def capture(self, save_path:Optional[str] = None):
"""
Expand Down Expand Up @@ -370,6 +387,20 @@ def get_annotation_dict(self, control, sub_control_list: List, annotation_type="
screenshot = self.screenshot_factory.create_screenshot("app_window", control)
screenshot = AnnotationDecorator(screenshot, sub_control_list, annotation_type)
return screenshot.get_annotation_dict()


def get_cropped_icons_dict(self, control, sub_control_list: List, annotation_type="number") -> Dict:
"""
Get the dictionary of the cropped icons.
:param control: The control item to capture.
:param sub_control_list: The list of the controls to annotate.
:param annotation_type: The type of the annotation.
:return: The dictionary of the cropped icons.
"""

screenshot = self.screenshot_factory.create_screenshot("app_window", control)
screenshot = AnnotationDecorator(screenshot, sub_control_list, annotation_type)
return screenshot.get_cropped_icons_dict()



Expand Down
8 changes: 8 additions & 0 deletions ufo/config/config_dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,11 @@ INPUT_TEXT_ENTER: True # whether to press enter after typing the text
## APIs related
USE_APIS: True # Whether to use the API
WORD_API_PROMPT: "ufo/prompts/base/{mode}/word_api.yaml" # The prompt for the word API

# For control filtering
#'TEXT' for only rich text filter, 'SEMANTIC' for only semantic similarity match, 'ICON' for only icon match
CONTROL_FILTER_TYPE: ["TEXT", "SEMANTIC", "ICON"] # The control filter type
CONTROL_FILTER_TOP_K_SEMANTIC: 15 # The control filter top k for semantic similarity
CONTROL_FILTER_TOP_K_ICON: 15 # The control filter top k for icon similarity
CONTROL_FILTER_MODEL_SEMANTIC_NAME: "all-MiniLM-L6-v2" # The control filter model name of semantic similarity
CONTROL_FILTER_MODEL_ICON_NAME: "clip-ViT-B-32" # The control filter model name of icon similarity
Loading