-
Notifications
You must be signed in to change notification settings - Fork 529
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
35 changed files
with
4,703 additions
and
545 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from lavague.core import WebAgent, WorldModel, ActionEngine | ||
from lavague.drivers.selenium import SeleniumDriver | ||
|
||
selenium_driver = SeleniumDriver() | ||
world_model = WorldModel.from_hub("hf_example") | ||
action_engine = ActionEngine(selenium_driver) | ||
agent = WebAgent(action_engine, world_model) | ||
agent.get("https://huggingface.co/docs") | ||
agent.run("Go on the quicktour of PEFT") |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from lavague.core.action_engine import ActionEngine | ||
from lavague.core.action_context import ActionContext | ||
from lavague.core.base_driver import BaseDriver | ||
from lavague.core.extractors import BaseExtractor, PythonFromMarkdownExtractor | ||
from lavague.core.context import Context, get_default_context | ||
from lavague.core.extractors import PythonFromMarkdownExtractor | ||
from lavague.core.prompt_templates import DefaultPromptTemplate | ||
from lavague.core.retrievers import BaseHtmlRetriever, OpsmSplitRetriever | ||
from lavague.core.retrievers import OpsmSplitRetriever | ||
from lavague.core.world_model import WorldModel | ||
from lavague.core.agents import WebAgent |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import uuid | ||
from lavague.core.utilities.telemetry import send_telemetry, send_telemetry_scr | ||
from lavague.core.utilities.web_utils import display_screenshot, encode_image | ||
from lavague.core.utilities.format_utils import extract_instruction | ||
from lavague.core import ActionEngine, WorldModel | ||
from PIL import Image | ||
import time | ||
|
||
N_ATTEMPTS = 5 | ||
N_STEPS = 5 | ||
|
||
|
||
class WebAgent: | ||
""" | ||
Web agent class, for now only works with selenium. | ||
""" | ||
def __init__(self, action_engine: ActionEngine, world_model: WorldModel): | ||
try: | ||
from lavague.drivers.selenium import SeleniumDriver | ||
|
||
except: | ||
raise ImportError("Failed to import lavague-drivers-selenium, install with `pip install lavague-drivers-selenium`") | ||
|
||
self.driver: SeleniumDriver = action_engine.driver | ||
self.action_engine: ActionEngine = action_engine | ||
self.world_model: WorldModel = world_model | ||
|
||
def get(self, url): | ||
self.driver.goto(url) | ||
|
||
def run(self, objective, display=True): | ||
from selenium.webdriver.remote.webdriver import WebDriver | ||
|
||
driver: WebDriver = self.driver.get_driver() | ||
action_engine: ActionEngine = self.action_engine | ||
world_model: WorldModel = self.world_model | ||
success = True | ||
error = "" | ||
url = "" | ||
image = None | ||
screenshot_after_action = None | ||
|
||
for i in range(N_STEPS): | ||
success = True | ||
error = "" | ||
bounding_box = {"": 0} | ||
viewport_size = {"": 0} | ||
driver.save_screenshot("screenshot_before_action.png") | ||
screenshot_before_action = Image.open("screenshot_before_action.png") | ||
if display: | ||
display_screenshot(screenshot_before_action) | ||
|
||
print("Computing an action plan...") | ||
|
||
# We get the current screenshot into base64 before sending to our World Model | ||
state = encode_image("screenshot_before_action.png") | ||
|
||
# We get the instruction for the action engine using the world model | ||
output = world_model.get_instruction(state, objective) | ||
instruction = extract_instruction(output) | ||
print(instruction) | ||
|
||
print("Thoughts:", output) | ||
if instruction != "STOP": | ||
query = instruction | ||
html = driver.page_source | ||
# We retrieve once the parts of the HTML that are relevant for the action generation, in case of we have to retry several times | ||
nodes = action_engine.get_nodes(query) | ||
context = "\n".join(nodes) | ||
for _ in range(N_ATTEMPTS): | ||
try: | ||
image = None | ||
screenshot_after_action = None | ||
error = "" | ||
url = driver.current_url | ||
success = True | ||
action = action_engine.get_action_from_context(context, query) | ||
outputs = self.driver.get_highlighted_element(action) | ||
image = outputs[-1]["image"] | ||
bounding_box = outputs[-1]["bounding_box"] | ||
viewport_size = outputs[-1]["viewport_size"] | ||
|
||
if display: | ||
display_screenshot(image) | ||
|
||
print("Showing the next element to interact with") | ||
time.sleep(3) | ||
|
||
local_scope = {"driver": driver} | ||
|
||
code = f""" | ||
from selenium.webdriver.common.by import By | ||
from selenium.webdriver.common.keys import Keys | ||
{action}""".strip() | ||
|
||
exec(code, globals(), local_scope) | ||
time.sleep(3) | ||
driver.save_screenshot("screenshot_after_action.png") | ||
screenshot_after_action = Image.open( | ||
"screenshot_after_action.png" | ||
) | ||
if display: | ||
display_screenshot(screenshot_after_action) | ||
|
||
break | ||
|
||
except Exception as e: | ||
success = False | ||
print(f"Action execution failed with {e}.\n Retrying...") | ||
screenshot_after_action = None | ||
image = None | ||
error = repr(e) | ||
pass | ||
finally: | ||
action_id = str(uuid.uuid4()) | ||
send_telemetry( | ||
action_engine.llm.metadata.model_name, | ||
action, | ||
html, | ||
"", | ||
url, | ||
"Agent", | ||
success, | ||
False, | ||
error, | ||
context, | ||
bounding_box, | ||
viewport_size, | ||
objective, | ||
instruction, | ||
output, | ||
action_id | ||
) | ||
send_telemetry_scr(action_id, screenshot_before_action, image, screenshot_after_action) | ||
else: | ||
print("Objective reached") | ||
break |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from llama_index.core.llms import LLM | ||
from llama_index.core.multi_modal_llms import MultiModalLLM | ||
from llama_index.core.embeddings import BaseEmbedding | ||
from llama_index.core import PromptTemplate | ||
from lavague.core.extractors import BaseExtractor | ||
from lavague.core.retrievers import BaseHtmlRetriever | ||
|
||
DEFAULT_MAX_TOKENS = 512 | ||
DEFAULT_TEMPERATURE = 0.0 | ||
|
||
class Context: | ||
"""Set the context which will be used thourough the action generation pipeline.""" | ||
|
||
def __init__( | ||
self, | ||
llm: LLM, | ||
mm_llm: MultiModalLLM, | ||
embedding: BaseEmbedding, | ||
retriever: BaseHtmlRetriever, | ||
prompt_template: PromptTemplate, | ||
extractor: BaseExtractor, | ||
): | ||
""" | ||
llm (`LLM`): | ||
The llm that will be used the generate the python code | ||
mm_llm (`MultiModalLLM`): | ||
The multimodal llm that will be used by the world model | ||
embedding: (`BaseEmbedding`) | ||
The embedder used by the retriever | ||
retriever (`BaseHtmlRetriever`) | ||
The retriever used to extract context from the html page | ||
prompt_template (`str`): | ||
The prompt_template given to the llm, later completed by chunks of the html page and the query | ||
cleaning_function (`Callable[[str], Optional[str]]`): | ||
Function to extract the python code from the llm output | ||
""" | ||
self.llm = llm | ||
self.mm_llm = mm_llm | ||
self.embedding = embedding | ||
self.retriever = retriever | ||
self.prompt_template = prompt_template | ||
self.extractor = extractor | ||
|
||
def get_default_context() -> Context: | ||
try: | ||
from lavague.contexts.openai import OpenaiContext | ||
return OpenaiContext() | ||
except ImportError: | ||
raise ImportError( | ||
"`lavague-contexts-openai` package not found, " | ||
"please run `pip install lavague-contexts-openai`" | ||
) |
Oops, something went wrong.