Skip to content

Commit

Permalink
quick tour is working
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrunel committed May 15, 2024
1 parent f4f8769 commit a21f169
Show file tree
Hide file tree
Showing 35 changed files with 4,703 additions and 545 deletions.
9 changes: 9 additions & 0 deletions examples/agent_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from lavague.core import WebAgent, WorldModel, ActionEngine
from lavague.drivers.selenium import SeleniumDriver

selenium_driver = SeleniumDriver()
world_model = WorldModel.from_hub("hf_example")
action_engine = ActionEngine(selenium_driver)
agent = WebAgent(action_engine, world_model)
agent.get("https://huggingface.co/docs")
agent.run("Go on the quicktour of PEFT")
Binary file added examples/screenshots/output.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion lavague-cli/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ documentation = "https://docs.lavague.ai/en/latest/"
packages = [{include = "lavague/"}]

[tool.poetry.dependencies]
python = "^3.8.1"
python = "^3.10.0"
lavague-core = "^0.1.0"
click = "^8.1.7"
importlib = "^1.0.4"
Expand Down
9 changes: 5 additions & 4 deletions lavague-core/lavague/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from lavague.core.action_engine import ActionEngine
from lavague.core.action_context import ActionContext
from lavague.core.base_driver import BaseDriver
from lavague.core.extractors import BaseExtractor, PythonFromMarkdownExtractor
from lavague.core.context import Context, get_default_context
from lavague.core.extractors import PythonFromMarkdownExtractor
from lavague.core.prompt_templates import DefaultPromptTemplate
from lavague.core.retrievers import BaseHtmlRetriever, OpsmSplitRetriever
from lavague.core.retrievers import OpsmSplitRetriever
from lavague.core.world_model import WorldModel
from lavague.core.agents import WebAgent
26 changes: 0 additions & 26 deletions lavague-core/lavague/core/action_context.py

This file was deleted.

71 changes: 35 additions & 36 deletions lavague-core/lavague/core/action_engine.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from __future__ import annotations
from typing import Optional, Generator
from typing import Optional, Generator, List
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core import get_response_synthesizer
from llama_index.core import PromptTemplate
from llama_index.core import get_response_synthesizer, PromptTemplate, QueryBundle
from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.base.embeddings.base import BaseEmbedding
from lavague.core.extractors import BaseExtractor
from lavague.core.retrievers import BaseHtmlRetriever
from lavague.core.base_driver import BaseDriver
from lavague.core.action_context import ActionContext

from lavague.core.context import Context, get_default_context

class ActionEngine:
"""
Expand All @@ -18,45 +16,23 @@ class ActionEngine:
Args:
driver (`BaseDriver`):
The Web driver used to interact with the headless browser
llm (`BaseLLM`):
The llm that will be used the generate the python code
embedding: (`BaseEmbedding`)
The embedder used by the retriever
retriever (`BaseHtmlRetriever`)
The retriever used to extract context from the html page
prompt_template (`str`):
The prompt_template given to the llm, later completed by chunks of the html page and the query
cleaning_function (`Callable[[str], Optional[str]]`):
Function to extract the python code from the llm output
"""

def __init__(
self,
driver: BaseDriver,
llm: BaseLLM,
embedding: BaseEmbedding,
retriever: BaseHtmlRetriever,
prompt_template: PromptTemplate,
extractor: BaseExtractor,
context: Optional[Context] = None,
):
self.driver = driver
self.llm = llm
self.embedding = embedding
self.retriever = retriever
self.prompt_template = prompt_template.partial_format(
if context is None:
context = get_default_context()
self.driver: BaseDriver = driver
self.llm: BaseLLM = context.llm
self.embedding: BaseEmbedding = context.embedding
self.retriever: BaseHtmlRetriever = context.retriever
self.prompt_template: PromptTemplate = context.prompt_template.partial_format(
driver_capability=driver.get_capability()
)
self.extractor = extractor

def from_context(driver: BaseDriver, context: ActionContext) -> ActionEngine:
return ActionEngine(
driver,
context.llm,
context.embedding,
context.retriever,
context.prompt_template,
context.extractor,
)
self.extractor: BaseExtractor = context.extractor

def _get_query_engine(self, streaming: bool = True) -> RetrieverQueryEngine:
"""
Expand Down Expand Up @@ -86,6 +62,29 @@ def _get_query_engine(self, streaming: bool = True) -> RetrieverQueryEngine:

return query_engine

def get_nodes(self, query: str) -> List[str]:
"""
Get the nodes from the html page
Args:
html (`str`): The html page
Return:
`List[str]`: The nodes
"""
source_nodes = self.retriever.retrieve_html(self.driver, self.embedding, QueryBundle(query_str=query))
source_nodes = [node.text for node in source_nodes]
return source_nodes

def get_action_from_context(self, context: str, query: str) -> str:
"""
Generate the code from a query and a context
"""
prompt = self.prompt_template.format(context_str=context, query_str=query)
response = self.llm.complete(prompt).text
code = self.extractor.extract(response)
return code

def get_action(self, query: str) -> Optional[str]:
"""
Generate the code from a query
Expand Down
137 changes: 137 additions & 0 deletions lavague-core/lavague/core/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import uuid
from lavague.core.utilities.telemetry import send_telemetry, send_telemetry_scr
from lavague.core.utilities.web_utils import display_screenshot, encode_image
from lavague.core.utilities.format_utils import extract_instruction
from lavague.core import ActionEngine, WorldModel
from PIL import Image
import time

N_ATTEMPTS = 5
N_STEPS = 5


class WebAgent:
"""
Web agent class, for now only works with selenium.
"""
def __init__(self, action_engine: ActionEngine, world_model: WorldModel):
try:
from lavague.drivers.selenium import SeleniumDriver

except:
raise ImportError("Failed to import lavague-drivers-selenium, install with `pip install lavague-drivers-selenium`")

self.driver: SeleniumDriver = action_engine.driver
self.action_engine: ActionEngine = action_engine
self.world_model: WorldModel = world_model

def get(self, url):
self.driver.goto(url)

def run(self, objective, display=True):
from selenium.webdriver.remote.webdriver import WebDriver

driver: WebDriver = self.driver.get_driver()
action_engine: ActionEngine = self.action_engine
world_model: WorldModel = self.world_model
success = True
error = ""
url = ""
image = None
screenshot_after_action = None

for i in range(N_STEPS):
success = True
error = ""
bounding_box = {"": 0}
viewport_size = {"": 0}
driver.save_screenshot("screenshot_before_action.png")
screenshot_before_action = Image.open("screenshot_before_action.png")
if display:
display_screenshot(screenshot_before_action)

print("Computing an action plan...")

# We get the current screenshot into base64 before sending to our World Model
state = encode_image("screenshot_before_action.png")

# We get the instruction for the action engine using the world model
output = world_model.get_instruction(state, objective)
instruction = extract_instruction(output)
print(instruction)

print("Thoughts:", output)
if instruction != "STOP":
query = instruction
html = driver.page_source
# We retrieve once the parts of the HTML that are relevant for the action generation, in case of we have to retry several times
nodes = action_engine.get_nodes(query)
context = "\n".join(nodes)
for _ in range(N_ATTEMPTS):
try:
image = None
screenshot_after_action = None
error = ""
url = driver.current_url
success = True
action = action_engine.get_action_from_context(context, query)
outputs = self.driver.get_highlighted_element(action)
image = outputs[-1]["image"]
bounding_box = outputs[-1]["bounding_box"]
viewport_size = outputs[-1]["viewport_size"]

if display:
display_screenshot(image)

print("Showing the next element to interact with")
time.sleep(3)

local_scope = {"driver": driver}

code = f"""
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
{action}""".strip()

exec(code, globals(), local_scope)
time.sleep(3)
driver.save_screenshot("screenshot_after_action.png")
screenshot_after_action = Image.open(
"screenshot_after_action.png"
)
if display:
display_screenshot(screenshot_after_action)

break

except Exception as e:
success = False
print(f"Action execution failed with {e}.\n Retrying...")
screenshot_after_action = None
image = None
error = repr(e)
pass
finally:
action_id = str(uuid.uuid4())
send_telemetry(
action_engine.llm.metadata.model_name,
action,
html,
"",
url,
"Agent",
success,
False,
error,
context,
bounding_box,
viewport_size,
objective,
instruction,
output,
action_id
)
send_telemetry_scr(action_id, screenshot_before_action, image, screenshot_after_action)
else:
print("Objective reached")
break
8 changes: 7 additions & 1 deletion lavague-core/lavague/core/base_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def get_driver(self) -> Any:
"""Return the expected variable name and the driver object"""
pass

@abstractmethod
def resize_driver(driver, width, height):
"""
Resize the driver to a targeted height and width.
"""

@abstractmethod
def get_url(self) -> Optional[str]:
"""Get the url of the current page"""
Expand All @@ -61,7 +67,7 @@ def get_html(self, clean: bool = True) -> str:
pass

@abstractmethod
def get_screenshot(self, filename: str) -> None:
def save_screenshot(self, filename: str) -> None:
"""Save a screenshot to the file filename"""
pass

Expand Down
52 changes: 52 additions & 0 deletions lavague-core/lavague/core/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from llama_index.core.llms import LLM
from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core import PromptTemplate
from lavague.core.extractors import BaseExtractor
from lavague.core.retrievers import BaseHtmlRetriever

DEFAULT_MAX_TOKENS = 512
DEFAULT_TEMPERATURE = 0.0

class Context:
"""Set the context which will be used thourough the action generation pipeline."""

def __init__(
self,
llm: LLM,
mm_llm: MultiModalLLM,
embedding: BaseEmbedding,
retriever: BaseHtmlRetriever,
prompt_template: PromptTemplate,
extractor: BaseExtractor,
):
"""
llm (`LLM`):
The llm that will be used the generate the python code
mm_llm (`MultiModalLLM`):
The multimodal llm that will be used by the world model
embedding: (`BaseEmbedding`)
The embedder used by the retriever
retriever (`BaseHtmlRetriever`)
The retriever used to extract context from the html page
prompt_template (`str`):
The prompt_template given to the llm, later completed by chunks of the html page and the query
cleaning_function (`Callable[[str], Optional[str]]`):
Function to extract the python code from the llm output
"""
self.llm = llm
self.mm_llm = mm_llm
self.embedding = embedding
self.retriever = retriever
self.prompt_template = prompt_template
self.extractor = extractor

def get_default_context() -> Context:
try:
from lavague.contexts.openai import OpenaiContext
return OpenaiContext()
except ImportError:
raise ImportError(
"`lavague-contexts-openai` package not found, "
"please run `pip install lavague-contexts-openai`"
)
Loading

0 comments on commit a21f169

Please sign in to comment.