diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 00000000..25729df3 --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,129 @@ +env: + REGISTRY: ghcr.io/anthropics/anthropic-quickstarts +name: build +on: + pull_request: + paths: + - computer-use-demo/** + push: + branches: + - main + paths: + - computer-use-demo/** +jobs: + build: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + strategy: + fail-fast: true + matrix: + platform: + - amd64 + - arm64 + steps: + - uses: actions/checkout@v4 + - name: Login to ghcr.io + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{github.actor}} + password: ${{secrets.GITHUB_TOKEN}} + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Set image tag + run: | + short_sha=$(git rev-parse --short ${{ github.sha }}) + echo "TAG=${REGISTRY}:computer-use-demo-${short_sha}" >> "$GITHUB_ENV" + - name: Build Docker image + uses: docker/build-push-action@v5 + with: + platforms: linux/${{ matrix.platform }} + context: computer-use-demo + push: false + tags: ${{ env.TAG }} + cache-from: type=gha,scope=computer-use-${{ matrix.platform }} + cache-to: type=gha,mode=max,scope=computer-use-${{ matrix.platform }} + load: true + - name: Run container + run: docker run -d -p 8501:8501 ${{ env.TAG }} # Streamlit listens on 8501 (probed by the next step) + - name: Check streamlit + run: | + timeout=60 + start_time=$(date +%s) + docker_id=$(docker ps --filter "ancestor=${{ env.TAG }}" --format "{{.ID}}") + echo "docker_id=$docker_id" >> "$GITHUB_ENV" + while sleep 1; do # pause between probes instead of busy-looping + current_time=$(date +%s) + elapsed=$((current_time - start_time)) + if [ $elapsed -ge $timeout ]; then + echo "Timeout reached. Container did not respond within $timeout seconds." 
+ exit 1 + fi + response=$(docker exec $docker_id curl -s -o /dev/null -w "%{http_code}" http://127.0.0.1:8501 || echo "000") + if [ "$response" = "200" ]; then + echo "Container responded with 200 OK" + exit 0 + fi + done + - name: Check VNC + run: docker exec $docker_id nc localhost 5900 -z + - name: Check noVNC + run: docker exec $docker_id curl -s -o /dev/null -w "%{http_code}" http://localhost:6080 | grep -q 200 || exit 1 + - name: Check landing page + run: docker exec $docker_id curl -s -o /dev/null -w "%{http_code}" http://localhost:8080 | grep -q 200 || exit 1 + - name: Determine push tags + run: | + if [ "${{ github.event_name }}" == "pull_request" ]; then + echo "PUSH_TAGS=${TAG}-${{ matrix.platform }}" >> "$GITHUB_ENV" + else + echo "PUSH_TAGS=${TAG}-${{ matrix.platform }},${REGISTRY}:computer-use-demo-latest-${{ matrix.platform }}" >> "$GITHUB_ENV" + fi + - name: Push Docker image + uses: docker/build-push-action@v5 + with: + platforms: linux/${{ matrix.platform }} + context: computer-use-demo # same context as the tested build above; the Dockerfile lives in this directory 
+ push: true + tags: ${{ env.PUSH_TAGS }} + cache-from: type=gha,scope=computer-use-${{ matrix.platform }} + cache-to: type=gha,mode=max,scope=computer-use-${{ matrix.platform }} + merge: + runs-on: ubuntu-latest + needs: + - build + permissions: + contents: read + packages: write + steps: + - uses: actions/checkout@v4 + - name: Login to ghcr.io + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{github.actor}} + password: ${{secrets.GITHUB_TOKEN}} + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Set image tag + run: | + echo "SHORT_SHA=$(git rev-parse --short ${{ github.sha }})" >> "$GITHUB_ENV" + - name: Create SHA manifest and push + run: | + docker buildx imagetools create -t \ + ${REGISTRY}:computer-use-demo-${SHORT_SHA} \ + ${REGISTRY}:computer-use-demo-${SHORT_SHA}-amd64 \ + ${REGISTRY}:computer-use-demo-${SHORT_SHA}-arm64 + + - name: Create latest manifest and push + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + run: | + docker buildx imagetools create -t \ + ${REGISTRY}:computer-use-demo-latest \ + ${REGISTRY}:computer-use-demo-latest-amd64 \ + ${REGISTRY}:computer-use-demo-latest-arm64 diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 00000000..c9d1b481 --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,49 @@ +name: tests +on: + pull_request: {} + push: + branches: + - main +jobs: + ruff: + runs-on: ubuntu-latest + defaults: + run: + working-directory: computer-use-demo + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v1 + pyright: + runs-on: ubuntu-latest + defaults: + run: + working-directory: computer-use-demo + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + cache: "pip" + python-version: "3.11.6" + - run: | + python -m venv .venv + source .venv/bin/activate + pip install -r dev-requirements.txt + - run: echo 
"$PWD/.venv/bin" >> $GITHUB_PATH + - uses: jakebailey/pyright-action@v1 + pytest: + runs-on: ubuntu-latest + defaults: + run: + working-directory: computer-use-demo + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + cache: "pip" + python-version: "3.11.6" + - run: | + python -m venv .venv + source .venv/bin/activate + pip install -r dev-requirements.txt + - run: echo "$PWD/.venv/bin" >> $GITHUB_PATH + - run: pytest tests --junitxml=junit/test-results.xml diff --git a/README.md b/README.md index e0e469ff..1e8c7fdb 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Anthropic Quickstarts -Anthropic Quickstarts is a collection of projects designed to help developers quickly get started with building deployable applications using the Anthropic API. Each quickstart provides a foundation that you can easily build upon and customize for your specific needs. +Anthropic Quickstarts is a collection of projects designed to help developers quickly get started with building applications using the Anthropic API. Each quickstart provides a foundation that you can easily build upon and customize for your specific needs. ## Getting Started @@ -8,18 +8,24 @@ To use these quickstarts, you'll need an Anthropic API key. If you don't have on ## Available Quickstarts -### 1. Customer Support Agent +### Customer Support Agent -Our first quickstart project is a customer support agent powered by Claude. This project demonstrates how to leverage Claude's natural language understanding and generation capabilities to create an AI-assisted customer support system with access to a knowledge base. +A customer support agent powered by Claude. This project demonstrates how to leverage Claude's natural language understanding and generation capabilities to create an AI-assisted customer support system with access to a knowledge base. [Go to Customer Support Agent Quickstart](./customer-support-agent) -### 2. 
Financial Data Analyst +### Financial Data Analyst -Our second quickstart is a financial data analyst powered by Claude. This project demonstrates how to leverage Claude's capabilities with interactive data visualization to analyze financial data via chat. +A financial data analyst powered by Claude. This project demonstrates how to leverage Claude's capabilities with interactive data visualization to analyze financial data via chat. [Go to Financial Data Analyst Quickstart](./financial-data-analyst) +### Computer Use Demo + +An environment and tools that Claude can use to control a desktop computer. This project demonstrates how to leverage the computer use capabilities of the new Claude 3.5 Sonnet model. + +[Go to Computer Use Demo Quickstart](./computer-use-demo) + ## General Usage Each quickstart project comes with its own README and setup instructions. Generally, you'll follow these steps: diff --git a/computer-use-demo/.gitignore b/computer-use-demo/.gitignore new file mode 100644 index 00000000..b4035dd3 --- /dev/null +++ b/computer-use-demo/.gitignore @@ -0,0 +1,4 @@ +.venv +.ruff_cache +__pycache__ +.pytest_cache diff --git a/computer-use-demo/.zed/settings.json b/computer-use-demo/.zed/settings.json new file mode 100644 index 00000000..460ba02b --- /dev/null +++ b/computer-use-demo/.zed/settings.json @@ -0,0 +1,12 @@ +{ + "preferred_line_length": 88, + "languages": { + "Python": { + "language_servers": ["pyright", "ruff"] + } + }, + "telemetry": { + "diagnostics": false, + "metrics": false + } +} diff --git a/computer-use-demo/Dockerfile b/computer-use-demo/Dockerfile new file mode 100644 index 00000000..f3b00255 --- /dev/null +++ b/computer-use-demo/Dockerfile @@ -0,0 +1,105 @@ +FROM docker.io/ubuntu:22.04 + +ENV DEBIAN_FRONTEND=noninteractive +ENV DEBIAN_PRIORITY=high + +RUN apt-get update && \ + apt-get -y upgrade && \ + apt-get -y install \ + build-essential \ + # UI Requirements + xvfb \ + xterm \ + xdotool \ + scrot \ + imagemagick \ + sudo \ 
mutter \ + x11vnc \ + # Python/pyenv reqs + build-essential \ + libssl-dev \ + zlib1g-dev \ + libbz2-dev \ + libreadline-dev \ + libsqlite3-dev \ + curl \ + git \ + libncursesw5-dev \ + xz-utils \ + tk-dev \ + libxml2-dev \ + libxmlsec1-dev \ + libffi-dev \ + liblzma-dev \ + # Network tools + net-tools \ + netcat \ + # PPA req + software-properties-common && \ + # Userland apps + sudo add-apt-repository ppa:mozillateam/ppa && \ + sudo apt-get install -y --no-install-recommends \ + libreoffice \ + firefox-esr \ + x11-apps \ + xpdf \ + gedit \ + xpaint \ + tint2 \ + galculator \ + pcmanfm \ + unzip && \ + apt-get clean + +# Install noVNC +RUN git clone --branch v1.5.0 https://github.com/novnc/noVNC.git /opt/noVNC && \ + git clone --branch v0.12.0 https://github.com/novnc/websockify /opt/noVNC/utils/websockify && \ + ln -s /opt/noVNC/vnc.html /opt/noVNC/index.html + +# setup user +ENV USERNAME=computeruse +ENV HOME=/home/$USERNAME +RUN useradd -m -s /bin/bash -d $HOME $USERNAME +RUN echo "${USERNAME} ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers +USER computeruse +WORKDIR $HOME + +# setup python +RUN git clone https://github.com/pyenv/pyenv.git ~/.pyenv && \ + cd ~/.pyenv && src/configure && make -C src && cd .. 
&& \ + echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc && \ + echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc && \ + echo 'eval "$(pyenv init -)"' >> ~/.bashrc +ENV PYENV_ROOT="$HOME/.pyenv" +ENV PATH="$PYENV_ROOT/bin:$PATH" +ENV PYENV_VERSION_MAJOR=3 +ENV PYENV_VERSION_MINOR=11 +ENV PYENV_VERSION_PATCH=6 +ENV PYENV_VERSION=$PYENV_VERSION_MAJOR.$PYENV_VERSION_MINOR.$PYENV_VERSION_PATCH +RUN eval "$(pyenv init -)" && \ + pyenv install $PYENV_VERSION && \ + pyenv global $PYENV_VERSION && \ + pyenv rehash + +ENV PATH="$HOME/.pyenv/shims:$HOME/.pyenv/bin:$PATH" + +RUN python -m pip install --upgrade pip==23.1.2 setuptools==58.0.4 wheel==0.40.0 && \ + python -m pip config set global.disable-pip-version-check true + +# only reinstall if requirements.txt changes +COPY --chown=$USERNAME:$USERNAME computer_use_demo/requirements.txt $HOME/computer_use_demo/requirements.txt +RUN python -m pip install -r $HOME/computer_use_demo/requirements.txt + +# setup desktop env & app +COPY --chown=$USERNAME:$USERNAME image/ $HOME +COPY --chown=$USERNAME:$USERNAME computer_use_demo/ $HOME/computer_use_demo/ + +ARG DISPLAY_NUM=1 +ARG HEIGHT=768 +ARG WIDTH=1024 +ENV DISPLAY_NUM=$DISPLAY_NUM +ENV HEIGHT=$HEIGHT +ENV WIDTH=$WIDTH + +ENTRYPOINT [ "./entrypoint.sh" ] diff --git a/computer-use-demo/LICENSE b/computer-use-demo/LICENSE new file mode 100644 index 00000000..a2981d1c --- /dev/null +++ b/computer-use-demo/LICENSE @@ -0,0 +1,7 @@ +Copyright 2024 Anthropic, PBC. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/computer-use-demo/README.md b/computer-use-demo/README.md new file mode 100644 index 00000000..c31739ea --- /dev/null +++ b/computer-use-demo/README.md @@ -0,0 +1,154 @@ +# Anthropic Computer Use Demo + +> [!CAUTION] +> Computer use is a beta feature. Please be aware that computer use poses unique risks that are distinct from standard API features or chat interfaces. These risks are heightened when using computer use to interact with the internet. To minimize risks, consider taking precautions such as: +> +> 1. Use a dedicated virtual machine or container with minimal privileges to prevent direct system attacks or accidents. +> 2. Avoid giving the model access to sensitive data, such as account login information, to prevent information theft. +> 3. Limit internet access to an allowlist of domains to reduce exposure to malicious content. +> 4. 
Ask a human to confirm decisions that may result in meaningful real-world consequences as well as any tasks requiring affirmative consent, such as accepting cookies, executing financial transactions, or agreeing to terms of service. +> +> In some circumstances, Claude will follow commands found in content even if they conflict with the user's instructions. For example, instructions on webpages or contained in images may override the user's instructions or cause Claude to make mistakes. We suggest taking precautions to isolate Claude from sensitive data and actions to avoid risks related to prompt injection. +> +> Finally, please inform end users of relevant risks and obtain their consent prior to enabling computer use in your own products. + +This repository helps you get started with computer use on Claude, with reference implementations of: + +* Build files to create a Docker container with all necessary dependencies +* A computer use agent loop using the Anthropic API, Bedrock, or Vertex to access the updated Claude 3.5 Sonnet model +* Anthropic-defined computer use tools +* A streamlit app for interacting with the agent loop + +Please use [this form](https://forms.gle/BT1hpBrqDPDUrCqo7) to provide feedback on the quality of the model responses, the API itself, or the quality of the documentation - we cannot wait to hear from you! + +> [!IMPORTANT] +> The Beta API used in this reference implementation is subject to change. Please refer to the [API release notes](https://docs.anthropic.com/en/release-notes/api) for the most up-to-date information. + +> [!IMPORTANT] +> The components are weakly separated: the agent loop runs in the container being controlled by Claude, can only be used by one session at a time, and must be restarted or reset between sessions if necessary. 
+ +## Quickstart: running the Docker container + +### Anthropic API + +```bash +export ANTHROPIC_API_KEY=%your_api_key% +docker run \ + -e ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY \ + -v $HOME/.anthropic:/home/computeruse/.anthropic \ + -p 5900:5900 \ + -p 8501:8501 \ + -p 6080:6080 \ + -p 8080:8080 \ + -it ghcr.io/anthropics/anthropic-quickstarts:computer-use-demo-latest +``` + +Once the container is running, open your browser to [http://localhost:8080](http://localhost:8080) to access the combined interface that includes both the agent chat and desktop view. + +The container stores settings like API key and custom system prompt in `~/.anthropic/`. Mount this directory to persist these settings between container runs. + +Alternative access points: +- Streamlit interface only: [http://localhost:8501](http://localhost:8501) +- Desktop view only: [http://localhost:6080/vnc.html](http://localhost:6080/vnc.html) +- Direct VNC connection: `vnc://localhost:5900` (for VNC clients) + +### Bedrock + +You'll need to pass in AWS credentials with appropriate permissions to use Claude on Bedrock. + +You have a few options for authenticating with Bedrock. See the [boto3 documentation](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#environment-variables) for more details and options. 
+ +#### Option 1: (suggested) Use the host's AWS credentials file and AWS profile + +```bash +export AWS_PROFILE=%your_aws_profile% +docker run \ + -e API_PROVIDER=bedrock \ + -e AWS_PROFILE=$AWS_PROFILE \ + -v $HOME/.aws/credentials:/home/computeruse/.aws/credentials \ + -v $HOME/.anthropic:/home/computeruse/.anthropic \ + -p 5900:5900 \ + -p 8501:8501 \ + -p 6080:6080 \ + -p 8080:8080 \ + -it ghcr.io/anthropics/anthropic-quickstarts:computer-use-demo-latest +``` + +#### Option 2: Use an access key and secret + +```bash +export AWS_ACCESS_KEY_ID=%your_aws_access_key% +export AWS_SECRET_ACCESS_KEY=%your_aws_secret_access_key% +export AWS_SESSION_TOKEN=%your_aws_session_token% +docker run \ + -e API_PROVIDER=bedrock \ + -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \ + -e AWS_SESSION_TOKEN=$AWS_SESSION_TOKEN \ + -v $HOME/.anthropic:/home/computeruse/.anthropic \ + -p 5900:5900 \ + -p 8501:8501 \ + -p 6080:6080 \ + -p 8080:8080 \ + -it ghcr.io/anthropics/anthropic-quickstarts:computer-use-demo-latest +``` + +### Vertex +You'll need to pass in Google Cloud credentials with appropriate permissions to use Claude on Vertex. + +```bash +docker build . -t computer-use-demo +gcloud auth application-default login +export VERTEX_REGION=%your_vertex_region% +export VERTEX_PROJECT_ID=%your_vertex_project_id% +docker run \ + -e API_PROVIDER=vertex \ + -e CLOUD_ML_REGION=$VERTEX_REGION \ + -e ANTHROPIC_VERTEX_PROJECT_ID=$VERTEX_PROJECT_ID \ + -v $HOME/.config/gcloud/application_default_credentials.json:/home/computeruse/.config/gcloud/application_default_credentials.json \ + -p 5900:5900 \ + -p 8501:8501 \ + -p 6080:6080 \ + -p 8080:8080 \ + -it computer-use-demo +``` +This example shows how to use the Google Cloud Application Default Credentials to authenticate with Vertex. 
+ +You can also set `GOOGLE_APPLICATION_CREDENTIALS` to use an arbitrary credential file, see the [Google Cloud Authentication documentation](https://cloud.google.com/docs/authentication/application-default-credentials#GAC) for more details. + +## Screen size +Environment variables `WIDTH` and `HEIGHT` can be used to set the screen size. For example: + +```bash +docker run \ + -e ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY \ + -v $(pwd)/computer_use_demo:/home/computeruse/computer_use_demo/ \ + -v $HOME/.anthropic:/home/computeruse/.anthropic \ + -p 5900:5900 \ + -p 8501:8501 \ + -p 6080:6080 \ + -p 8080:8080 \ + -e WIDTH=1920 \ + -e HEIGHT=1080 \ + -it ghcr.io/anthropics/anthropic-quickstarts:computer-use-demo-latest +``` + +We do not recommend sending screenshots in resolutions above [XGA/WXGA](https://en.wikipedia.org/wiki/Display_resolution_standards#XGA) to avoid issues related to [image resizing](https://docs.anthropic.com/en/docs/build-with-claude/vision#evaluate-image-size). +Relying on the image resizing behavior in the API will result in lower model accuracy and slower performance than implementing scaling in your tools directly. The `computer` tool implementation in this project demonstrates how to scale both images and coordinates from higher resolutions to the suggested resolutions. + +## Development +```bash +./setup.sh # configure venv, install development dependencies, and install pre-commit hooks +docker build . 
-t computer-use-demo:local # manually build the docker image (optional) +export ANTHROPIC_API_KEY=%your_api_key% +docker run \ + -e ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY \ + -v $(pwd)/computer_use_demo:/home/computeruse/computer_use_demo/ `# mount local python module for development` \ + -v $HOME/.anthropic:/home/computeruse/.anthropic \ + -p 5900:5900 \ + -p 8501:8501 \ + -p 6080:6080 \ + -p 8080:8080 \ + -it computer-use-demo:local # can also use ghcr.io/anthropics/anthropic-quickstarts:computer-use-demo-latest +``` +The docker run command above mounts the local `computer_use_demo` module into the container, so that you can edit files from the host. Streamlit is already configured with auto reloading. diff --git a/computer-use-demo/computer_use_demo/__init__.py b/computer-use-demo/computer_use_demo/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/computer-use-demo/computer_use_demo/loop.py b/computer-use-demo/computer_use_demo/loop.py new file mode 100644 index 00000000..317a70ca --- /dev/null +++ b/computer-use-demo/computer_use_demo/loop.py @@ -0,0 +1,250 @@ +""" +Agentic sampling loop that calls the Anthropic API and local implementation of anthropic-defined computer use tools. 
+""" + +import platform +from collections.abc import Callable +from datetime import datetime +from enum import StrEnum +from typing import Any, cast + +from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse +from anthropic.types import ( + MessageParam, + ToolParam, + ToolResultBlockParam, +) +from anthropic.types.beta import ( + BetaContentBlock, + BetaContentBlockParam, + BetaImageBlockParam, + BetaMessage, + BetaMessageParam, + BetaTextBlockParam, + BetaToolParam, + BetaToolResultBlockParam, +) + +from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult + +BETA_FLAG = "computer-use-2024-10-22" + + +class APIProvider(StrEnum): + ANTHROPIC = "anthropic" + BEDROCK = "bedrock" + VERTEX = "vertex" + + +PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = { + APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022", + APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0", + APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022", +} + + +# This system prompt is optimized for the Docker environment in this repository and +# specific tool combinations enabled. +# We encourage modifying this system prompt to ensure the model has context for the +# environment it is running in, and to provide any additional information that may be +# helpful for the task at hand. +SYSTEM_PROMPT = f""" +* You are utilising an Ubuntu virtual machine using {platform.machine()} architecture with internet access. +* You can feel free to install Ubuntu applications with your bash tool. Use curl instead of wget. +* To open firefox, please just click on the firefox icon. Note, firefox-esr is what is installed on your system. +* Using bash tool you can start GUI applications, but you need to set export DISPLAY=:1 and use a subshell. For example "(DISPLAY=:1 xterm &)". GUI apps run with bash tool will appear within your desktop environment, but they may take some time to appear. Take a screenshot to confirm it did. 
+* When using your bash tool with commands that are expected to output very large quantities of text, redirect into a tmp file and use str_replace_editor or `grep -n -B -A ` to confirm output. +* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. +* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. +* The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. + + + +* When using Firefox, if a startup wizard appears, IGNORE IT. Do not even click "skip this step". Instead, click on the address bar where it says "Search or enter address", and enter the appropriate search term or URL there. +* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool. +* When viewing a webpage, first use your computer tool to view it and explore it. But, if there is a lot of text on that page, instead curl the html of that page to a file on disk and then using your StrReplaceEditTool to view the contents in plain text. 
+""" + + +async def sampling_loop( + *, + model: str, + provider: APIProvider, + system_prompt_suffix: str, + messages: list[BetaMessageParam], + output_callback: Callable[[BetaContentBlock], None], + tool_output_callback: Callable[[ToolResult, str], None], + api_response_callback: Callable[[APIResponse[BetaMessage]], None], + api_key: str, + only_n_most_recent_images: int | None = None, + max_tokens: int = 4096, +): + """ + Agentic sampling loop for the assistant/tool interaction of computer use: repeatedly calls the selected provider's API, runs every requested tool, appends the assistant turn and tool results to `messages`, and returns `messages` once a response requests no tools. + """ + tool_collection = ToolCollection( + ComputerTool(), + BashTool(), + EditTool(), + ) + system = ( + f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}" + ) + + while True: + if only_n_most_recent_images: + _maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images) + + # Call the API + # we use raw_response to provide debug information to streamlit. Your + # implementation may be able to call the SDK directly with: + # `response = client.messages.create(...)` instead. 
+ if provider == APIProvider.ANTHROPIC: + raw_response = Anthropic( + api_key=api_key + ).beta.messages.with_raw_response.create( + max_tokens=max_tokens, + messages=messages, + model=model, + system=system, + tools=cast(list[BetaToolParam], tool_collection.to_params()), + extra_headers={"anthropic-beta": BETA_FLAG}, + ) + elif provider == APIProvider.VERTEX: + raw_response = AnthropicVertex().messages.with_raw_response.create( + max_tokens=max_tokens, + messages=cast(list[MessageParam], messages), + model=model, + system=system, + tools=cast(list[ToolParam], tool_collection.to_params()), + extra_headers={"anthropic-beta": BETA_FLAG}, + ) + elif provider == APIProvider.BEDROCK: + raw_response = AnthropicBedrock().messages.with_raw_response.create( + max_tokens=max_tokens, + messages=cast(list[MessageParam], messages), + model=model, + system=system, + tools=cast(list[ToolParam], tool_collection.to_params()), + extra_body={"anthropic_beta": [BETA_FLAG]}, + ) + + api_response_callback(cast(APIResponse[BetaMessage], raw_response)) + + response = raw_response.parse() + + messages.append( + { + "role": "assistant", + "content": cast(list[BetaContentBlockParam], response.content), + } + ) + + tool_result_content: list[BetaToolResultBlockParam] = [] + for content_block in cast(list[BetaContentBlock], response.content): + output_callback(content_block) + if content_block.type == "tool_use": + result = await tool_collection.run( + name=content_block.name, + tool_input=cast(dict[str, Any], content_block.input), + ) + tool_result_content.append( + _make_api_tool_result(result, content_block.id) + ) + tool_output_callback(result, content_block.id) + + if not tool_result_content: + return messages + + messages.append({"content": tool_result_content, "role": "user"}) + + +def _maybe_filter_to_n_most_recent_images( + messages: list[BetaMessageParam], + images_to_keep: int, + min_removal_threshold: int = 10, +): + """ + With the assumption that images are screenshots that are of 
diminishing value as + the conversation progresses, remove all but the final `images_to_keep` tool_result + images in place, with a chunk of min_removal_threshold to reduce the amount we + break the implicit prompt cache. + """ + if images_to_keep is None: + return messages + + tool_result_blocks = cast( + list[ToolResultBlockParam], + [ + item + for message in messages + for item in ( + message["content"] if isinstance(message["content"], list) else [] + ) + if isinstance(item, dict) and item.get("type") == "tool_result" + ], + ) + + total_images = sum( + 1 + for tool_result in tool_result_blocks + for content in tool_result.get("content", []) + if isinstance(content, dict) and content.get("type") == "image" + ) + + images_to_remove = total_images - images_to_keep + # for better cache behavior, we want to remove in chunks + images_to_remove -= images_to_remove % min_removal_threshold + + for tool_result in tool_result_blocks: + if isinstance(tool_result.get("content"), list): + new_content = [] + for content in tool_result.get("content", []): + if isinstance(content, dict) and content.get("type") == "image": + if images_to_remove > 0: + images_to_remove -= 1 + continue + new_content.append(content) + tool_result["content"] = new_content + + +def _make_api_tool_result( + result: ToolResult, tool_use_id: str +) -> BetaToolResultBlockParam: + """Convert an agent ToolResult to an API ToolResultBlockParam.""" + tool_result_content: list[BetaTextBlockParam | BetaImageBlockParam] | str = [] + is_error = False + if result.error: + is_error = True + tool_result_content = _maybe_prepend_system_tool_result(result, result.error) + else: + if result.output: + tool_result_content.append( + { + "type": "text", + "text": _maybe_prepend_system_tool_result(result, result.output), + } + ) + if result.base64_image: + tool_result_content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": result.base64_image, + }, + } + ) + return { 
+ "type": "tool_result", + "content": tool_result_content, + "tool_use_id": tool_use_id, + "is_error": is_error, + } + + +def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str): + if result.system: + result_text = f"{result.system}\n{result_text}" + return result_text diff --git a/computer-use-demo/computer_use_demo/requirements.txt b/computer-use-demo/computer_use_demo/requirements.txt new file mode 100644 index 00000000..8b3760d1 --- /dev/null +++ b/computer-use-demo/computer_use_demo/requirements.txt @@ -0,0 +1,5 @@ +streamlit>=1.38.0 +anthropic[bedrock,vertex]>=0.36.2 +jsonschema==4.22.0 +boto3>=1.28.57 +google-auth<3,>=2 diff --git a/computer-use-demo/computer_use_demo/streamlit.py b/computer-use-demo/computer_use_demo/streamlit.py new file mode 100644 index 00000000..ac99b5c6 --- /dev/null +++ b/computer-use-demo/computer_use_demo/streamlit.py @@ -0,0 +1,364 @@ +""" +Entrypoint for streamlit, see https://docs.streamlit.io/ +""" + +import asyncio +import base64 +import os +import subprocess +from datetime import datetime +from enum import StrEnum +from functools import partial +from pathlib import PosixPath +from typing import cast + +import streamlit as st +from anthropic import APIResponse +from anthropic.types import ( + TextBlock, +) +from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock +from anthropic.types.tool_use_block import ToolUseBlock +from streamlit.delta_generator import DeltaGenerator + +from computer_use_demo.loop import ( + PROVIDER_TO_DEFAULT_MODEL_NAME, + APIProvider, + sampling_loop, +) +from computer_use_demo.tools import ToolResult + +CONFIG_DIR = PosixPath("~/.anthropic").expanduser() +API_KEY_FILE = CONFIG_DIR / "api_key" +STREAMLIT_STYLE = """ + +""" + +WARNING_TEXT = "⚠️ Security Alert: Never provide access to sensitive accounts or data, as malicious web content can hijack Claude's behavior" + + +class Sender(StrEnum): + USER = "user" + BOT = "assistant" + TOOL = "tool" + + +def 
setup_state(): + if "messages" not in st.session_state: + st.session_state.messages = [] + if "api_key" not in st.session_state: + # Try to load API key from file first, then environment + st.session_state.api_key = load_from_storage("api_key") or os.getenv( + "ANTHROPIC_API_KEY", "" + ) + if "api_key_input" not in st.session_state: + st.session_state.api_key_input = st.session_state.api_key + if "provider" not in st.session_state: + st.session_state.provider = ( + os.getenv("API_PROVIDER", "anthropic") or APIProvider.ANTHROPIC + ) + if "provider_radio" not in st.session_state: + st.session_state.provider_radio = st.session_state.provider + if "model" not in st.session_state: + _reset_model() + if "auth_validated" not in st.session_state: + st.session_state.auth_validated = False + if "responses" not in st.session_state: + st.session_state.responses = {} + if "tools" not in st.session_state: + st.session_state.tools = {} + if "only_n_most_recent_images" not in st.session_state: + st.session_state.only_n_most_recent_images = 10 + if "custom_system_prompt" not in st.session_state: + st.session_state.custom_system_prompt = load_from_storage("system_prompt") or "" + if "hide_images" not in st.session_state: + st.session_state.hide_images = False + + +def _reset_model(): + st.session_state.model = PROVIDER_TO_DEFAULT_MODEL_NAME[ + cast(APIProvider, st.session_state.provider) + ] + + +async def main(): + """Render loop for streamlit""" + setup_state() + + st.markdown(STREAMLIT_STYLE, unsafe_allow_html=True) + + st.title("Claude Computer Use Demo") + + if not os.getenv("HIDE_WARNING", False): + st.warning(WARNING_TEXT) + + with st.sidebar: + + def _reset_api_provider(): + if st.session_state.provider_radio != st.session_state.provider: + _reset_model() + st.session_state.provider = st.session_state.provider_radio + st.session_state.auth_validated = False + + provider_options = [option.value for option in APIProvider] + st.radio( + "API Provider", + 
options=provider_options, + index=provider_options.index(st.session_state.provider), + key="provider_radio", + format_func=lambda x: x.title(), + on_change=_reset_api_provider, + ) + + st.text_input("Model", key="model") + + if st.session_state.provider == APIProvider.ANTHROPIC: + st.text_input( + "Anthropic API Key", + value=st.session_state.api_key, + type="password", + key="api_key_input", + on_change=lambda: save_to_storage( + "api_key", st.session_state.api_key_input + ), + ) + st.session_state.api_key = st.session_state.api_key_input + + st.number_input( + "Only send N most recent images", + min_value=0, + key="only_n_most_recent_images", + help="To decrease the total tokens sent, remove older screenshots from the conversation", + ) + st.text_area( + "Custom System Prompt Suffix", + key="custom_system_prompt", + help="Additional instructions to append to the system prompt. see computer_use_demo/loop.py for the base system prompt.", + on_change=lambda: save_to_storage( + "system_prompt", st.session_state.custom_system_prompt + ), + ) + st.checkbox("Hide screenshots", key="hide_images") + + if st.button("Reset", type="primary"): + with st.spinner("Resetting..."): + st.session_state.clear() + setup_state() + + subprocess.run("pkill Xvfb; pkill tint2", shell=True) # noqa: ASYNC221 + await asyncio.sleep(1) + subprocess.run("./start_all.sh", shell=True) # noqa: ASYNC221 + + if not st.session_state.auth_validated: + if auth_error := validate_auth( + st.session_state.provider, st.session_state.api_key + ): + st.warning(f"Please resolve the following auth issue:\n\n{auth_error}") + return + else: + st.session_state.auth_validated = True + + chat, http_logs = st.tabs(["Chat", "HTTP Exchange Logs"]) + new_message = st.chat_input( + "Type a message to send to Claude to control the computer..." 
+ ) + + with chat: + # render past chats + for message in st.session_state.messages: + if isinstance(message["content"], str): + _render_message(message["role"], message["content"]) + elif isinstance(message["content"], list): + for block in message["content"]: + # the tool result we send back to the Anthropic API isn't sufficient to render all details, + # so we store the tool use responses + if isinstance(block, dict) and block["type"] == "tool_result": + _render_message( + Sender.TOOL, st.session_state.tools[block["tool_use_id"]] + ) + else: + _render_message( + message["role"], + cast(BetaTextBlock | BetaToolUseBlock, block), + ) + + # render past http exchanges + for identity, response in st.session_state.responses.items(): + _render_api_response(response, identity, http_logs) + + # render past chats + if new_message: + st.session_state.messages.append( + { + "role": Sender.USER, + "content": [TextBlock(type="text", text=new_message)], + } + ) + _render_message(Sender.USER, new_message) + + try: + most_recent_message = st.session_state["messages"][-1] + except IndexError: + return + + if most_recent_message["role"] is not Sender.USER: + # we don't have a user message to respond to, exit early + return + + with st.spinner("Running Agent..."): + # run the agent sampling loop with the newest message + st.session_state.messages = await sampling_loop( + system_prompt_suffix=st.session_state.custom_system_prompt, + model=st.session_state.model, + provider=st.session_state.provider, + messages=st.session_state.messages, + output_callback=partial(_render_message, Sender.BOT), + tool_output_callback=partial( + _tool_output_callback, tool_state=st.session_state.tools + ), + api_response_callback=partial( + _api_response_callback, + tab=http_logs, + response_state=st.session_state.responses, + ), + api_key=st.session_state.api_key, + only_n_most_recent_images=st.session_state.only_n_most_recent_images, + ) + + +def validate_auth(provider: APIProvider, api_key: str | 
None): + if provider == APIProvider.ANTHROPIC: + if not api_key: + return "Enter your Anthropic API key in the sidebar to continue." + if provider == APIProvider.BEDROCK: + import boto3 + + if not boto3.Session().get_credentials(): + return "You must have AWS credentials set up to use the Bedrock API." + if provider == APIProvider.VERTEX: + import google.auth + from google.auth.exceptions import DefaultCredentialsError + + if not os.environ.get("CLOUD_ML_REGION"): + return "Set the CLOUD_ML_REGION environment variable to use the Vertex API." + try: + google.auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + except DefaultCredentialsError: + return "Your google cloud credentials are not set up correctly." + + +def load_from_storage(filename: str) -> str | None: + """Load data from a file in the storage directory.""" + try: + file_path = CONFIG_DIR / filename + if file_path.exists(): + data = file_path.read_text().strip() + if data: + return data + except Exception as e: + st.write(f"Debug: Error loading {filename}: {e}") + return None + + +def save_to_storage(filename: str, data: str) -> None: + """Save data to a file in the storage directory.""" + try: + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + file_path = CONFIG_DIR / filename + file_path.write_text(data) + # Ensure only user can read/write the file + file_path.chmod(0o600) + except Exception as e: + st.write(f"Debug: Error saving {filename}: {e}") + + +def _api_response_callback( + response: APIResponse[BetaMessage], + tab: DeltaGenerator, + response_state: dict[str, APIResponse[BetaMessage]], +): + """ + Handle an API response by storing it to state and rendering it. 
+ """ + response_id = datetime.now().isoformat() + response_state[response_id] = response + _render_api_response(response, response_id, tab) + + +def _tool_output_callback( + tool_output: ToolResult, tool_id: str, tool_state: dict[str, ToolResult] +): + """Handle a tool output by storing it to state and rendering it.""" + tool_state[tool_id] = tool_output + _render_message(Sender.TOOL, tool_output) + + +def _render_api_response( + response: APIResponse[BetaMessage], response_id: str, tab: DeltaGenerator +): + """Render an API response to a streamlit tab""" + with tab: + with st.expander(f"Request/Response ({response_id})"): + newline = "\n\n" + st.markdown( + f"`{response.http_request.method} {response.http_request.url}`{newline}{newline.join(f'`{k}: {v}`' for k, v in response.http_request.headers.items())}" + ) + st.json(response.http_request.read().decode()) + st.markdown( + f"`{response.http_response.status_code}`{newline}{newline.join(f'`{k}: {v}`' for k, v in response.headers.items())}" + ) + st.json(response.http_response.text) + + +def _render_message( + sender: Sender, + message: str | BetaTextBlock | BetaToolUseBlock | ToolResult, +): + """Convert input from the user or output from the agent to a streamlit message. Nothing is rendered for an empty message, or for a tool result that carries only a screenshot while screenshots are hidden.""" + # streamlit's hotreloading breaks isinstance checks, so we need to check for class names + is_tool_result = not isinstance(message, str) and ( + isinstance(message, ToolResult) + or message.__class__.__name__ == "ToolResult" + or message.__class__.__name__ == "CLIResult" + ) + if not message or ( + is_tool_result + and st.session_state.hide_images + and not getattr(message, "error", None)  # test the values: the attributes always exist on a ToolResult + and not getattr(message, "output", None) + ): + return + with st.chat_message(sender): + if is_tool_result: + message = cast(ToolResult, message) + if message.output: + if message.__class__.__name__ == "CLIResult": + st.code(message.output) + else: + st.markdown(message.output) + if message.error: + st.error(message.error) + if message.base64_image and not
st.session_state.hide_images: + st.image(base64.b64decode(message.base64_image)) + elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock): + st.write(message.text) + elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock): + st.code(f"Tool Use: {message.name}\nInput: {message.input}") + else: + st.markdown(message) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/computer-use-demo/computer_use_demo/tools/__init__.py b/computer-use-demo/computer_use_demo/tools/__init__.py new file mode 100644 index 00000000..1fd037f1 --- /dev/null +++ b/computer-use-demo/computer_use_demo/tools/__init__.py @@ -0,0 +1,14 @@ +from .base import CLIResult, ToolResult +from .bash import BashTool +from .collection import ToolCollection +from .computer import ComputerTool +from .edit import EditTool + +__all__ = [  # public API: must be lowercase `__all__` and hold name strings + "BashTool", + "CLIResult", + "ComputerTool", + "EditTool", + "ToolCollection", + "ToolResult", +] diff --git a/computer-use-demo/computer_use_demo/tools/base.py b/computer-use-demo/computer_use_demo/tools/base.py new file mode 100644 index 00000000..37eafbf0 --- /dev/null +++ b/computer-use-demo/computer_use_demo/tools/base.py @@ -0,0 +1,94 @@ +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass, fields, replace +from typing import Any, ClassVar, Literal, Optional, Required, TypedDict + +APIToolType = Literal["computer_20241022", "text_editor_20241022", "bash_20241022"] +APIToolName = Literal["computer", "str_replace_editor", "bash"] + + +class AnthropicAPIToolParam(TypedDict): + """API shape for Anthropic-defined tools.""" + + name: Required[APIToolName] + type: Required[APIToolType] + + +class ComputerToolOptions(TypedDict): + display_height_px: Required[int] + display_width_px: Required[int] + display_number: Optional[int] + + +class BaseAnthropicTool(metaclass=ABCMeta): + """Abstract base class for Anthropic-defined tools.""" + + name: ClassVar[APIToolName] + api_type: ClassVar[APIToolType] + +
@property + def options(self) -> ComputerToolOptions | None: + return None + + @abstractmethod + def __call__(self, **kwargs) -> Any: + """Executes the tool with the given arguments.""" + ... + + def to_params( + self, + ) -> dict: # -> AnthropicToolParam & Optional[ComputerToolOptions] + """Creates the shape necessary to this tool to the Anthropic API.""" + return { + "name": self.name, + "type": self.api_type, + **(self.options or {}), + } + + +@dataclass(kw_only=True, frozen=True) +class ToolResult: + """Represents the result of a tool execution.""" + + output: str | None = None + error: str | None = None + base64_image: str | None = None + system: str | None = None + + def __bool__(self): + return any(getattr(self, field.name) for field in fields(self)) + + def __add__(self, other: "ToolResult"): + def combine_fields( + field: str | None, other_field: str | None, concatenate: bool = True + ): + if field and other_field: + if concatenate: + return field + other_field + raise ValueError("Cannot combine tool results") + return field or other_field + + return ToolResult( + output=combine_fields(self.output, other.output), + error=combine_fields(self.error, other.error), + base64_image=combine_fields(self.base64_image, other.base64_image, False), + system=combine_fields(self.system, other.system), + ) + + def replace(self, **kwargs): + """Returns a new ToolResult with the given fields replaced.""" + return replace(self, **kwargs) + + +class CLIResult(ToolResult): + """A ToolResult that can be rendered as a CLI output.""" + + +class ToolFailure(ToolResult): + """A ToolResult that represents a failure.""" + + +class ToolError(Exception): + """Raised when a tool encounters an error.""" + + def __init__(self, message): + self.message = message diff --git a/computer-use-demo/computer_use_demo/tools/bash.py b/computer-use-demo/computer_use_demo/tools/bash.py new file mode 100644 index 00000000..e5d3df50 --- /dev/null +++ b/computer-use-demo/computer_use_demo/tools/bash.py 
@@ -0,0 +1,135 @@ +import asyncio +import os + +from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult + + +class _BashSession: + """A session of a bash shell.""" + + _started: bool + _process: asyncio.subprocess.Process + + command: str = "/bin/bash" + _output_delay: float = 0.2 # seconds + _timeout: float = 120.0 # seconds + _sentinel: str = "<>" + + def __init__(self): + self._started = False + self._timed_out = False + + async def start(self): + if self._started: + return + + self._process = await asyncio.create_subprocess_shell( + self.command, + preexec_fn=os.setsid, + shell=True, + bufsize=0, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + self._started = True + + def stop(self): + """Terminate the bash shell.""" + if not self._started: + raise ToolError("Session has not started.") + if self._process.returncode is not None: + return + self._process.terminate() + + async def run(self, command: str): + """Execute a command in the bash shell.""" + if not self._started: + raise ToolError("Session has not started.") + if self._process.returncode is not None: + return ToolResult( + system="tool must be restarted", + error=f"bash has exited with returncode {self._process.returncode}", + ) + if self._timed_out: + raise ToolError( + f"timed out: bash has not returned in {self._timeout} seconds and must be restarted", + ) + + # we know these are not None because we created the process with PIPEs + assert self._process.stdin + assert self._process.stdout + assert self._process.stderr + + # send command to the process + self._process.stdin.write( + command.encode() + f"; echo '{self._sentinel}'\n".encode() + ) + await self._process.stdin.drain() + + # read output from the process, until the sentinel is found + try: + async with asyncio.timeout(self._timeout): + while True: + await asyncio.sleep(self._output_delay) + # if we read directly from stdout/stderr, it will wait forever for + # EOF. 
use the StreamReader buffer directly instead. + output = self._process.stdout._buffer.decode() # pyright: ignore[reportAttributeAccessIssue] + if self._sentinel in output: + # strip the sentinel and break + output = output[: output.index(self._sentinel)] + break + except asyncio.TimeoutError: + self._timed_out = True + raise ToolError( + f"timed out: bash has not returned in {self._timeout} seconds and must be restarted", + ) from None + + if output.endswith("\n"): + output = output[:-1] + + error = self._process.stderr._buffer.decode() # pyright: ignore[reportAttributeAccessIssue] + if error.endswith("\n"): + error = error[:-1] + + # clear the buffers so that the next output can be read correctly + self._process.stdout._buffer.clear() # pyright: ignore[reportAttributeAccessIssue] + self._process.stderr._buffer.clear() # pyright: ignore[reportAttributeAccessIssue] + + return CLIResult(output=output, error=error) + + +class BashTool(BaseAnthropicTool): + """ + A tool that allows the agent to run bash commands. + The tool parameters are defined by Anthropic and are not editable. 
+ """ + + _session: _BashSession | None + name = "bash" + api_type = "bash_20241022" + + def __init__(self): + self._session = None + super().__init__() + + async def __call__( + self, command: str | None = None, restart: bool = False, **kwargs + ): + if restart: + if self._session: + self._session.stop() + self._session = _BashSession() + await self._session.start() + + return ToolResult(system="tool has been restarted.") + + if self._session is None: + self._session = _BashSession() + await self._session.start() + + if command is not None: + return await self._session.run(command) + + raise ToolError("no command provided.") diff --git a/computer-use-demo/computer_use_demo/tools/collection.py b/computer-use-demo/computer_use_demo/tools/collection.py new file mode 100644 index 00000000..12e52123 --- /dev/null +++ b/computer-use-demo/computer_use_demo/tools/collection.py @@ -0,0 +1,32 @@ +"""Collection classes for managing multiple tools.""" + +from typing import Any + +from .base import ( + BaseAnthropicTool, + ToolError, + ToolFailure, + ToolResult, +) + + +class ToolCollection: + """A collection of anthropic-defined tools.""" + + def __init__(self, *tools: BaseAnthropicTool): + self.tools = tools + self.tool_map = {tool.name: tool for tool in tools} + + def to_params( + self, + ) -> list[dict]: # -> List[AnthropicToolParam & Optional[ComputerToolOptions]] + return [tool.to_params() for tool in self.tools] + + async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult: + tool = self.tool_map.get(name) + if not tool: + return ToolFailure(error=f"Tool {name} is invalid") + try: + return await tool(**tool_input) + except ToolError as e: + return ToolFailure(error=e.message) diff --git a/computer-use-demo/computer_use_demo/tools/computer.py b/computer-use-demo/computer_use_demo/tools/computer.py new file mode 100644 index 00000000..7e48c981 --- /dev/null +++ b/computer-use-demo/computer_use_demo/tools/computer.py @@ -0,0 +1,249 @@ +import asyncio 
+import base64 +import os +import shlex +import shutil +from enum import StrEnum +from pathlib import Path +from typing import Literal, TypedDict +from uuid import uuid4 + +from .base import BaseAnthropicTool, ComputerToolOptions, ToolError, ToolResult +from .run import run + +OUTPUT_DIR = "/tmp/outputs" + +TYPING_DELAY_MS = 12 +TYPING_GROUP_SIZE = 50 + +Action = Literal[ + "key", + "type", + "mouse_move", + "left_click", + "left_click_drag", + "right_click", + "middle_click", + "double_click", + "screenshot", + "cursor_position", +] + + +class Resolution(TypedDict): + width: int + height: int + + +# sizes above XGA/WXGA are not recommended (see README.md) +# scale down to one of these targets if ComputerTool._scaling_enabled is set +MAX_SCALING_TARGETS: dict[str, Resolution] = { + "XGA": Resolution(width=1024, height=768), # 4:3 + "WXGA": Resolution(width=1280, height=800), # 16:10 + "FWXGA": Resolution(width=1366, height=768), # ~16:9 +} + + +class ScalingSource(StrEnum): + COMPUTER = "computer" + API = "api" + + +def chunks(s: str, chunk_size: int) -> list[str]: + return [s[i : i + chunk_size] for i in range(0, len(s), chunk_size)] + + +class ComputerTool(BaseAnthropicTool): + """ + A tool that allows the agent to interact with the screen, keyboard, and mouse of the current computer. + The tool parameters are defined by Anthropic and are not editable. 
+ """ + + name = "computer" + api_type = "computer_20241022" + width: int + height: int + display_num: int | None + + _screenshot_delay = 2.0 + _scaling_enabled = True + + @property + def options(self) -> ComputerToolOptions: + width, height = self.scale_coordinates( + ScalingSource.COMPUTER, self.width, self.height + ) + return { + "display_width_px": width, + "display_height_px": height, + "display_number": self.display_num, + } + + def __init__(self): + super().__init__() + + self.width = int(os.getenv("WIDTH") or 0) + self.height = int(os.getenv("HEIGHT") or 0) + assert self.width and self.height, "WIDTH, HEIGHT must be set" + if (display_num := os.getenv("DISPLAY_NUM")) is not None: + self.display_num = int(display_num) + self._display_prefix = f"DISPLAY=:{self.display_num} " + else: + self.display_num = None + self._display_prefix = "" + + self.xdotool = f"{self._display_prefix}xdotool" + + async def __call__( + self, + *, + action: Action, + text: str | None = None, + coordinate: tuple[int, int] | None = None, + **kwargs, + ): + if action in ("mouse_move", "left_click_drag"): + if coordinate is None: + raise ToolError(f"coordinate is required for {action}") + if text is not None: + raise ToolError(f"text is not accepted for {action}") + if not isinstance(coordinate, list) or len(coordinate) != 2: + raise ToolError(f"{coordinate} must be a tuple of length 2") + if not all(isinstance(i, int) and i >= 0 for i in coordinate): + raise ToolError(f"{coordinate} must be a tuple of non-negative ints") + + x, y = self.scale_coordinates( + ScalingSource.API, coordinate[0], coordinate[1] + ) + + if action == "mouse_move": + return await self.shell(f"{self.xdotool} mousemove --sync {x} {y}") + elif action == "left_click_drag": + return await self.shell( + f"{self.xdotool} mousedown 1 mousemove --sync {x} {y} mouseup 1" + ) + + if action in ("key", "type"): + if text is None: + raise ToolError(f"text is required for {action}") + if coordinate is not None: + raise 
ToolError(f"coordinate is not accepted for {action}") + if not isinstance(text, str): + raise ToolError(f"{text} must be a string")  # ToolError takes the message positionally; it has no `output` kwarg + + if action == "key": + return await self.shell(f"{self.xdotool} key -- {text}") + elif action == "type": + results: list[ToolResult] = [] + for chunk in chunks(text, TYPING_GROUP_SIZE): + cmd = f"{self.xdotool} type --delay {TYPING_DELAY_MS} -- {shlex.quote(chunk)}" + results.append(await self.shell(cmd, take_screenshot=False)) + screenshot_base64 = (await self.screenshot()).base64_image + return ToolResult( + output="".join(result.output or "" for result in results), + error="".join(result.error or "" for result in results), + base64_image=screenshot_base64, + ) + + if action in ( + "left_click", + "right_click", + "double_click", + "middle_click", + "screenshot", + "cursor_position", + ): + if text is not None: + raise ToolError(f"text is not accepted for {action}") + if coordinate is not None: + raise ToolError(f"coordinate is not accepted for {action}") + + if action == "screenshot": + return await self.screenshot() + elif action == "cursor_position": + result = await self.shell( + f"{self.xdotool} getmouselocation --shell", + take_screenshot=False, + ) + output = result.output or "" + x, y = self.scale_coordinates( + ScalingSource.COMPUTER, + int(output.split("X=")[1].split("\n")[0]), + int(output.split("Y=")[1].split("\n")[0]), + ) + return result.replace(output=f"X={x},Y={y}") + else: + click_arg = { + "left_click": "1", + "right_click": "3", + "middle_click": "2", + "double_click": "--repeat 2 --delay 500 1", + }[action] + return await self.shell(f"{self.xdotool} click {click_arg}") + + raise ToolError(f"Invalid action: {action}") + + async def screenshot(self): + """Take a screenshot of the current screen and return the base64 encoded image.""" + output_dir = Path(OUTPUT_DIR) + output_dir.mkdir(parents=True, exist_ok=True) + path = output_dir / f"screenshot_{uuid4().hex}.png" + + # Try gnome-screenshot first + if
shutil.which("gnome-screenshot"): + screenshot_cmd = f"{self._display_prefix}gnome-screenshot -f {path} -p" + else: + # Fall back to scrot if gnome-screenshot isn't available + screenshot_cmd = f"{self._display_prefix}scrot -p {path}" + + result = await self.shell(screenshot_cmd, take_screenshot=False) + if self._scaling_enabled: + x, y = self.scale_coordinates( + ScalingSource.COMPUTER, self.width, self.height + ) + await self.shell( + f"convert {path} -resize {x}x{y}! {path}", take_screenshot=False + ) + + if path.exists(): + return result.replace( + base64_image=base64.b64encode(path.read_bytes()).decode() + ) + raise ToolError(f"Failed to take screenshot: {result.error}") + + async def shell(self, command: str, take_screenshot=True) -> ToolResult: + """Run a shell command and return the output, error, and optionally a screenshot.""" + _, stdout, stderr = await run(command) + base64_image = None + + if take_screenshot: + # delay to let things settle before taking a screenshot + await asyncio.sleep(self._screenshot_delay) + base64_image = (await self.screenshot()).base64_image + + return ToolResult(output=stdout, error=stderr, base64_image=base64_image) + + def scale_coordinates(self, source: ScalingSource, x: int, y: int): + """Scale coordinates to a target maximum resolution.""" + if not self._scaling_enabled: + return x, y + ratio = self.width / self.height + target_dimension = None + for dimension in MAX_SCALING_TARGETS.values(): + # allow some error in the aspect ratio - not ratios are exactly 16:9 + if abs(dimension["width"] / dimension["height"] - ratio) < 0.02: + if dimension["width"] < self.width: + target_dimension = dimension + break + if target_dimension is None: + return x, y + # should be less than 1 + x_scaling_factor = target_dimension["width"] / self.width + y_scaling_factor = target_dimension["height"] / self.height + if source == ScalingSource.API: + if x > self.width or y > self.height: + raise ToolError(f"Coordinates {x}, {y} are out of 
bounds") + # scale up + return round(x / x_scaling_factor), round(y / y_scaling_factor) + # scale down + return round(x * x_scaling_factor), round(y * y_scaling_factor) diff --git a/computer-use-demo/computer_use_demo/tools/edit.py b/computer-use-demo/computer_use_demo/tools/edit.py new file mode 100644 index 00000000..d0609f9b --- /dev/null +++ b/computer-use-demo/computer_use_demo/tools/edit.py @@ -0,0 +1,282 @@ +from collections import defaultdict +from pathlib import Path +from typing import Literal, get_args + +from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult +from .run import maybe_truncate, run + +Command = Literal[ + "view", + "create", + "str_replace", + "insert", + "undo_edit", +] +SNIPPET_LINES: int = 4 + + +class EditTool(BaseAnthropicTool): + """ + A filesystem editor tool that allows the agent to view, create, and edit files. + The tool parameters are defined by Anthropic and are not editable. + """ + + api_type = "text_editor_20241022" + name = "str_replace_editor" + + _file_history: dict[Path, list[str]] + + def __init__(self): + self._file_history = defaultdict(list) + super().__init__() + + async def __call__( + self, + *, + command: Command, + path: str, + file_text: str | None = None, + view_range: list[int] | None = None, + old_str: str | None = None, + new_str: str | None = None, + insert_line: int | None = None, + **kwargs, + ): + _path = Path(path) + self.validate_path(command, _path) + if command == "view": + return await self.view(_path, view_range) + elif command == "create": + if file_text is None:  # only a missing parameter is an error; "" creates an empty file + raise ToolError("Parameter `file_text` is required for command: create") + self.write_file(_path, file_text) + self._file_history[_path].append(file_text) + return ToolResult(output=f"File created successfully at: {_path}") + elif command == "str_replace": + if not old_str: + raise ToolError( + "Parameter `old_str` is required for command: str_replace" + ) + return self.str_replace(_path, old_str, new_str) + elif command ==
"insert": + if insert_line is None: + raise ToolError( + "Parameter `insert_line` is required for command: insert" + ) + if not new_str: + raise ToolError("Parameter `new_str` is required for command: insert") + return self.insert(_path, insert_line, new_str) + elif command == "undo_edit": + return self.undo_edit(_path) + raise ToolError( + f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}' + ) + + def validate_path(self, command: str, path: Path): + """ + Check that the path/command combination is valid. + """ + # Check if its an absolute path + if not path.is_absolute(): + suggested_path = Path("") / path + raise ToolError( + f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?" + ) + # Check if path exists + if not path.exists() and command != "create": + raise ToolError( + f"The path {path} does not exist. Please provide a valid path." + ) + if path.exists() and command == "create": + raise ToolError( + f"File already exists at: {path}. Cannot overwrite files using command `create`." + ) + # Check if the path points to a directory + if path.is_dir(): + if command != "view": + raise ToolError( + f"The path {path} is a directory and only the `view` command can be used on directories" + ) + + async def view(self, path: Path, view_range: list[int] | None = None): + """Implement the view command""" + if path.is_dir(): + if view_range: + raise ToolError( + "The `view_range` parameter is not allowed when `path` points to a directory." 
+ ) + + _, stdout, stderr = await run( + rf"find {path} -maxdepth 2 -not -path '*/\.*'" + ) + if not stderr: + stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n" + return CLIResult(output=stdout, error=stderr) + + file_content = self.read_file(path) + init_line = 1 + if view_range: + if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range): + raise ToolError( + "Invalid `view_range`. It should be a list of two integers." + ) + file_lines = file_content.split("\n") + n_lines_file = len(file_lines) + init_line, final_line = view_range + if init_line < 1 or init_line > n_lines_file: + raise ToolError( + f"Invalid `view_range`: {view_range}. It's first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}" + ) + if final_line > n_lines_file: + raise ToolError( + f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`" + ) + if final_line != -1 and final_line < init_line: + raise ToolError( + f"Invalid `view_range`: {view_range}. 
It's second element `{final_line}` should be larger or equal than its first `{init_line}`" + ) + + if final_line == -1: + file_content = "\n".join(file_lines[init_line - 1 :]) + else: + file_content = "\n".join(file_lines[init_line - 1 : final_line]) + + return CLIResult( + output=self._make_output(file_content, str(path), init_line=init_line) + ) + + def str_replace(self, path: Path, old_str: str, new_str: str | None): + """Implement the str_replace command, which replaces old_str with new_str in the file content""" + # Read the file content + file_content = self.read_file(path).expandtabs() + old_str = old_str.expandtabs() + new_str = new_str.expandtabs() if new_str is not None else "" + + # Check if old_str is unique in the file + occurrences = file_content.count(old_str) + if occurrences == 0: + raise ToolError( + f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}." + ) + elif occurrences > 1: + file_content_lines = file_content.split("\n") + lines = [ + idx + 1 + for idx, line in enumerate(file_content_lines) + if old_str in line + ] + raise ToolError( + f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique" + ) + + # Replace old_str with new_str + new_file_content = file_content.replace(old_str, new_str) + + # Write the new content to the file + self.write_file(path, new_file_content) + + # Save the content to history + self._file_history[path].append(file_content) + + # Create a snippet of the edited section + replacement_line = file_content.split(old_str)[0].count("\n") + start_line = max(0, replacement_line - SNIPPET_LINES) + end_line = replacement_line + SNIPPET_LINES + new_str.count("\n") + snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1]) + + # Prepare the success message + success_msg = f"The file {path} has been edited. 
" + success_msg += self._make_output( + snippet, f"a snippet of {path}", start_line + 1 + ) + success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary." + + return CLIResult(output=success_msg) + + def insert(self, path: Path, insert_line: int, new_str: str): + """Implement the insert command, which inserts new_str at the specified line in the file content.""" + file_text = self.read_file(path).expandtabs() + new_str = new_str.expandtabs() + file_text_lines = file_text.split("\n") + n_lines_file = len(file_text_lines) + + if insert_line < 0 or insert_line > n_lines_file: + raise ToolError( + f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}" + ) + + new_str_lines = new_str.split("\n") + new_file_text_lines = ( + file_text_lines[:insert_line] + + new_str_lines + + file_text_lines[insert_line:] + ) + snippet_lines = ( + file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line] + + new_str_lines + + file_text_lines[insert_line : insert_line + SNIPPET_LINES] + ) + + new_file_text = "\n".join(new_file_text_lines) + snippet = "\n".join(snippet_lines) + + self.write_file(path, new_file_text) + self._file_history[path].append(file_text) + + success_msg = f"The file {path} has been edited. " + success_msg += self._make_output( + snippet, + "a snippet of the edited file", + max(1, insert_line - SNIPPET_LINES + 1), + ) + success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary." + return CLIResult(output=success_msg) + + def undo_edit(self, path: Path): + """Implement the undo_edit command.""" + if not self._file_history[path]: + raise ToolError(f"No edit history found for {path}.") + + old_text = self._file_history[path].pop() + self.write_file(path, old_text) + + return CLIResult( + output=f"Last edit to {path} undone successfully. 
{self._make_output(old_text, str(path))}" + ) + + def read_file(self, path: Path): + """Read the content of a file from a given path; raise a ToolError if an error occurs.""" + try: + return path.read_text() + except Exception as e: + raise ToolError(f"Ran into {e} while trying to read {path}") from None + + def write_file(self, path: Path, file: str): + """Write the content of a file to a given path; raise a ToolError if an error occurs.""" + try: + path.write_text(file) + except Exception as e: + raise ToolError(f"Ran into {e} while trying to write to {path}") from None + + def _make_output( + self, + file_content: str, + file_descriptor: str, + init_line: int = 1, + expand_tabs: bool = True, + ): + """Generate output for the CLI based on the content of a file.""" + file_content = maybe_truncate(file_content) + if expand_tabs: + file_content = file_content.expandtabs() + file_content = "\n".join( + [ + f"{i + init_line:6}\t{line}" + for i, line in enumerate(file_content.split("\n")) + ] + ) + return ( + f"Here's the result of running `cat -n` on {file_descriptor}:\n" + + file_content + + "\n" + ) diff --git a/computer-use-demo/computer_use_demo/tools/run.py b/computer-use-demo/computer_use_demo/tools/run.py new file mode 100644 index 00000000..89db980a --- /dev/null +++ b/computer-use-demo/computer_use_demo/tools/run.py @@ -0,0 +1,42 @@ +"""Utility to run shell commands asynchronously with a timeout.""" + +import asyncio + +TRUNCATED_MESSAGE: str = "To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for." 
+MAX_RESPONSE_LEN: int = 16000 + + +def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN): + """Truncate content and append a notice if content exceeds the specified length.""" + return ( + content + if not truncate_after or len(content) <= truncate_after + else content[:truncate_after] + TRUNCATED_MESSAGE + ) + + +async def run( + cmd: str, + timeout: float | None = 120.0, # seconds + truncate_after: int | None = MAX_RESPONSE_LEN, +): + """Run a shell command asynchronously with a timeout.""" + process = await asyncio.create_subprocess_shell( + cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + + try: + stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout) + return ( + process.returncode or 0, + maybe_truncate(stdout.decode(), truncate_after=truncate_after), + maybe_truncate(stderr.decode(), truncate_after=truncate_after), + ) + except asyncio.TimeoutError as exc: + try: + process.kill() + except ProcessLookupError: + pass + raise TimeoutError( + f"Command '{cmd}' timed out after {timeout} seconds" + ) from exc diff --git a/computer-use-demo/dev-requirements.txt b/computer-use-demo/dev-requirements.txt new file mode 100644 index 00000000..daee0ee2 --- /dev/null +++ b/computer-use-demo/dev-requirements.txt @@ -0,0 +1,5 @@ +-r computer_use_demo/requirements.txt +ruff==0.6.7 +pre-commit==3.8.0 +pytest==8.3.3 +pytest-asyncio==0.23.6 diff --git a/computer-use-demo/image/.config/tint2/applications/firefox-custom.desktop b/computer-use-demo/image/.config/tint2/applications/firefox-custom.desktop new file mode 100755 index 00000000..94802126 --- /dev/null +++ b/computer-use-demo/image/.config/tint2/applications/firefox-custom.desktop @@ -0,0 +1,8 @@ +[Desktop Entry] +Name=Firefox Custom +Comment=Open Firefox with custom URL +Exec=firefox-esr -new-window +Icon=firefox-esr +Terminal=false +Type=Application +Categories=Network;WebBrowser; diff --git 
a/computer-use-demo/image/.config/tint2/applications/gedit.desktop b/computer-use-demo/image/.config/tint2/applications/gedit.desktop new file mode 100755 index 00000000..d5af03f4 --- /dev/null +++ b/computer-use-demo/image/.config/tint2/applications/gedit.desktop @@ -0,0 +1,8 @@ +[Desktop Entry] +Name=Gedit +Comment=Open gedit +Exec=gedit +Icon=text-editor-symbolic +Terminal=false +Type=Application +Categories=TextEditor; diff --git a/computer-use-demo/image/.config/tint2/applications/terminal.desktop b/computer-use-demo/image/.config/tint2/applications/terminal.desktop new file mode 100644 index 00000000..0c2d45d4 --- /dev/null +++ b/computer-use-demo/image/.config/tint2/applications/terminal.desktop @@ -0,0 +1,8 @@ +[Desktop Entry] +Name=Terminal +Comment=Open Terminal +Exec=xterm +Icon=utilities-terminal +Terminal=false +Type=Application +Categories=System;TerminalEmulator; diff --git a/computer-use-demo/image/.config/tint2/tint2rc b/computer-use-demo/image/.config/tint2/tint2rc new file mode 100644 index 00000000..5db6d312 --- /dev/null +++ b/computer-use-demo/image/.config/tint2/tint2rc @@ -0,0 +1,100 @@ +#------------------------------------- +# Panel +panel_items = TL +panel_size = 100% 60 +panel_margin = 0 0 +panel_padding = 2 0 2 +panel_background_id = 1 +wm_menu = 0 +panel_dock = 0 +panel_position = bottom center horizontal +panel_layer = top +panel_monitor = all +panel_shrink = 0 +autohide = 0 +autohide_show_timeout = 0 +autohide_hide_timeout = 0.5 +autohide_height = 2 +strut_policy = follow_size +panel_window_name = tint2 +disable_transparency = 1 +mouse_effects = 1 +font_shadow = 0 +mouse_hover_icon_asb = 100 0 10 +mouse_pressed_icon_asb = 100 0 0 +scale_relative_to_dpi = 0 +scale_relative_to_screen_height = 0 + +#------------------------------------- +# Taskbar +taskbar_mode = single_desktop +taskbar_hide_if_empty = 0 +taskbar_padding = 0 0 2 +taskbar_background_id = 0 +taskbar_active_background_id = 0 +taskbar_name = 1 +taskbar_hide_inactive_tasks = 
0 +taskbar_hide_different_monitor = 0 +taskbar_hide_different_desktop = 0 +taskbar_always_show_all_desktop_tasks = 0 +taskbar_name_padding = 4 2 +taskbar_name_background_id = 0 +taskbar_name_active_background_id = 0 +taskbar_name_font_color = #e3e3e3 100 +taskbar_name_active_font_color = #ffffff 100 +taskbar_distribute_size = 0 +taskbar_sort_order = none +task_align = left + +#------------------------------------- +# Launcher +launcher_padding = 4 8 4 +launcher_background_id = 0 +launcher_icon_background_id = 0 +launcher_icon_size = 48 +launcher_icon_asb = 100 0 0 +launcher_icon_theme_override = 0 +startup_notifications = 1 +launcher_tooltip = 1 + +#------------------------------------- +# Launcher icon +launcher_item_app = /usr/share/applications/libreoffice-calc.desktop +launcher_item_app = /home/computeruse/.config/tint2/applications/terminal.desktop +launcher_item_app = /home/computeruse/.config/tint2/applications/firefox-custom.desktop +launcher_item_app = /usr/share/applications/xpaint.desktop +launcher_item_app = /usr/share/applications/xpdf.desktop +launcher_item_app = /home/computeruse/.config/tint2/applications/gedit.desktop +launcher_item_app = /usr/share/applications/galculator.desktop + +#------------------------------------- +# Background definitions +# ID 1 +rounded = 0 +border_width = 0 +background_color = #000000 60 +border_color = #000000 30 + +# ID 2 +rounded = 4 +border_width = 1 +background_color = #777777 20 +border_color = #777777 30 + +# ID 3 +rounded = 4 +border_width = 1 +background_color = #777777 20 +border_color = #ffffff 40 + +# ID 4 +rounded = 4 +border_width = 1 +background_color = #aa4400 100 +border_color = #aa7733 100 + +# ID 5 +rounded = 4 +border_width = 1 +background_color = #aaaa00 100 +border_color = #aaaa00 100 diff --git a/computer-use-demo/image/.streamlit/config.toml b/computer-use-demo/image/.streamlit/config.toml new file mode 100644 index 00000000..544f575c --- /dev/null +++ 
import os
import socket
from http.server import HTTPServer, SimpleHTTPRequestHandler


class HTTPServerV6(HTTPServer):
    """`HTTPServer` that opens its listening socket with the IPv6 address family."""

    address_family = socket.AF_INET6


def run_server():
    """Serve the `static_content` directory next to this script on port 8080, forever."""
    # Anchor on the script's absolute location: with a relative __file__,
    # dirname() is "" and the old string concatenation pointed at
    # "/static_content" in the filesystem root.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(os.path.join(script_dir, "static_content"))
    server_address = ("::", 8080)
    httpd = HTTPServerV6(server_address, SimpleHTTPRequestHandler)
    # Flush so the line shows up immediately when stdout is redirected to a file.
    print("Starting HTTP server on port 8080...", flush=True)  # noqa: T201
    httpd.serve_forever()


if __name__ == "__main__":
    run_server()
+ + +
#!/bin/bash
# Launch the noVNC proxy in the background and wait until it is listening.
echo "starting noVNC"

# Start noVNC with explicit websocket settings; all output goes to /tmp/novnc.log
/opt/noVNC/utils/novnc_proxy \
    --vnc localhost:5900 \
    --listen 6080 \
    --web /opt/noVNC \
    > /tmp/novnc.log 2>&1 &

# Poll once per second, for at most 10 seconds, until port 6080 is listening
timeout=10
while [ $timeout -gt 0 ]; do
    if netstat -tuln | grep -q ":6080 "; then
        break
    fi
    sleep 1
    ((timeout--))
done

# The counter only reaches 0 when the port never appeared: report the failure
# (with the proxy's log) instead of announcing success.
if [ $timeout -eq 0 ]; then
    echo "noVNC failed to start; log output:" >&2
    cat /tmp/novnc.log >&2
    exit 1
fi

echo "noVNC started successfully"
+ + + + +
#!/bin/bash
# Start Xvfb on $DISPLAY unless a server is already answering there.
# Reads WIDTH, HEIGHT, DISPLAY_NUM and DISPLAY from the environment.
set -e  # Exit on error

DPI=96
RES_AND_DEPTH=${WIDTH}x${HEIGHT}x24

# Succeeds only when the display's lock file exists AND the display answers.
# The lock file alone is not proof of a live server: it can be left behind
# by a previous run.
check_xvfb_running() {
    [ -e "/tmp/.X${DISPLAY_NUM}-lock" ] && xdpyinfo >/dev/null 2>&1
}

# Poll xdpyinfo every 0.1 s until the display answers; give up after 10 s.
wait_for_xvfb() {
    local timeout=10
    local start_time
    start_time=$(date +%s)
    while ! xdpyinfo >/dev/null 2>&1; do
        if [ $(($(date +%s) - start_time)) -gt $timeout ]; then
            echo "Xvfb failed to start within $timeout seconds" >&2
            return 1
        fi
        sleep 0.1
    done
    return 0
}

# Nothing to do if a server is already serving the display
if check_xvfb_running; then
    echo "Xvfb is already running on display ${DISPLAY}"
    exit 0
fi

# Start Xvfb in the background and remember its PID
Xvfb "$DISPLAY" -ac -screen 0 "$RES_AND_DEPTH" -retro -dpi "$DPI" -nolisten tcp -nolisten unix &
XVFB_PID=$!

# Wait for Xvfb to start
if wait_for_xvfb; then
    echo "Xvfb started successfully on display ${DISPLAY}"
    echo "Xvfb PID: $XVFB_PID"
else
    echo "Xvfb failed to start"
    # The server may already have exited; a failing kill must not pre-empt
    # the explicit exit below under `set -e`.
    kill "$XVFB_PID" 2>/dev/null || true
    exit 1
fi
async def test_loop():
    """Drive `sampling_loop` through one tool-use reply followed by a final text reply."""
    # First parsed reply requests a tool call; the second contains only text.
    tool_use_reply = mock.Mock(
        spec=BetaMessage,
        content=[
            TextBlock(type="text", text="Hello"),
            ToolUseBlock(
                type="tool_use", id="1", name="computer", input={"action": "test"}
            ),
        ],
    )
    final_reply = mock.Mock(
        spec=BetaMessage, content=[TextBlock(type="text", text="Done!")]
    )
    raw_response = mock.Mock()
    raw_response.parse.side_effect = [tool_use_reply, final_reply]

    client = mock.Mock()
    client.beta.messages.with_raw_response.create.return_value = raw_response

    tool_collection = mock.AsyncMock()
    tool_collection.run.return_value = mock.Mock(
        output="Tool output", error=None, base64_image=None
    )

    output_callback = mock.Mock()
    tool_output_callback = mock.Mock()
    api_response_callback = mock.Mock()

    messages: list[BetaMessageParam] = [{"role": "user", "content": "Test message"}]
    with (
        mock.patch("computer_use_demo.loop.Anthropic", return_value=client),
        mock.patch(
            "computer_use_demo.loop.ToolCollection", return_value=tool_collection
        ),
    ):
        result = await sampling_loop(
            model="test-model",
            provider=APIProvider.ANTHROPIC,
            system_prompt_suffix="",
            messages=messages,
            output_callback=output_callback,
            tool_output_callback=tool_output_callback,
            api_response_callback=api_response_callback,
            api_key="test-key",
        )

    # Conversation shape: user prompt, assistant reply, tool result, assistant reply.
    assert len(result) == 4
    assert result[0] == {"role": "user", "content": "Test message"}
    assert [entry["role"] for entry in result[1:]] == ["assistant", "user", "assistant"]

    assert client.beta.messages.with_raw_response.create.call_count == 2
    tool_collection.run.assert_called_once_with(
        name="computer", tool_input={"action": "test"}
    )
    output_callback.assert_called_with(TextBlock(text="Done!", type="text"))
    assert output_callback.call_count == 3
    assert tool_output_callback.call_count == 1
    assert api_response_callback.call_count == 2
@pytest.fixture
def bash_tool():
    """Provide a new `BashTool` instance to each test."""
    tool = BashTool()
    return tool


@pytest.mark.asyncio
async def test_bash_tool_restart(bash_tool):
    """A restart is reported and the tool still runs commands afterwards."""
    restart_result = await bash_tool(restart=True)
    assert restart_result.system == "tool has been restarted."

    # Verify the tool can be used after restart
    followup = await bash_tool(command="echo 'Hello after restart'")
    assert "Hello after restart" in followup.output


@pytest.mark.asyncio
async def test_bash_tool_run_command(bash_tool):
    """A simple echo yields its text on output and an empty error."""
    outcome = await bash_tool(command="echo 'Hello, World!'")
    assert (outcome.output.strip(), outcome.error) == ("Hello, World!", "")


@pytest.mark.asyncio
async def test_bash_tool_no_command(bash_tool):
    """Calling the tool with no arguments raises `ToolError`."""
    with pytest.raises(ToolError, match="no command provided."):
        await bash_tool()


@pytest.mark.asyncio
async def test_bash_tool_session_creation(bash_tool):
    """Running a command leaves a session object on the tool."""
    outcome = await bash_tool(command="echo 'Session created'")
    assert bash_tool._session is not None
    assert "Session created" in outcome.output


@pytest.mark.asyncio
async def test_bash_tool_session_reuse(bash_tool):
    """Two consecutive commands on the same tool both produce their output."""
    first = await bash_tool(command="echo 'First command'")
    second = await bash_tool(command="echo 'Second command'")

    assert "First command" in first.output
    assert "Second command" in second.output


@pytest.mark.asyncio
async def test_bash_tool_session_error(bash_tool):
    """An unknown command surfaces the shell's message in `error`."""
    outcome = await bash_tool(command="invalid_command_that_does_not_exist")
    assert "command not found" in outcome.error


@pytest.mark.asyncio
async def test_bash_tool_non_zero_exit(bash_tool):
    """A silent non-zero exit produces neither output nor error text."""
    outcome = await bash_tool(command="bash -c 'exit 1'")
    assert outcome.error.strip() == ""
    assert outcome.output.strip() == ""


@pytest.mark.asyncio
async def test_bash_tool_timeout(bash_tool):
    """A command that outlives the session timeout raises `ToolError`."""
    # Run one command first so that a session exists to configure.
    await bash_tool(command="echo 'Hello, World!'")
    bash_tool._session._timeout = 0.1  # Set a very short timeout for testing
    expected = "timed out: bash has not returned in 0.1 seconds and must be restarted"
    with pytest.raises(ToolError, match=expected):
        await bash_tool(command="sleep 1")
@pytest.mark.asyncio
async def test_computer_tool_mouse_move(computer_tool):
    """`mouse_move` issues a single xdotool mousemove and returns the shell result."""
    with patch.object(computer_tool, "shell", new_callable=AsyncMock) as shell_mock:
        shell_mock.return_value = ToolResult(output="Mouse moved")
        outcome = await computer_tool(action="mouse_move", coordinate=[100, 200])

    shell_mock.assert_called_once_with(
        f"{computer_tool.xdotool} mousemove --sync 100 200"
    )
    assert outcome.output == "Mouse moved"


@pytest.mark.asyncio
async def test_computer_tool_type(computer_tool):
    """`type` sends the text in one shell call and attaches a screenshot."""
    with (
        patch.object(computer_tool, "shell", new_callable=AsyncMock) as shell_mock,
        patch.object(
            computer_tool, "screenshot", new_callable=AsyncMock
        ) as screenshot_mock,
    ):
        shell_mock.return_value = ToolResult(output="Text typed")
        screenshot_mock.return_value = ToolResult(base64_image="base64_screenshot")
        outcome = await computer_tool(action="type", text="Hello, World!")

    assert shell_mock.call_count == 1
    issued_command = shell_mock.call_args[0][0]
    assert "type --delay 12 -- 'Hello, World!'" in issued_command
    assert outcome.output == "Text typed"
    assert outcome.base64_image == "base64_screenshot"


@pytest.mark.asyncio
async def test_computer_tool_screenshot(computer_tool):
    """`screenshot` delegates once to the tool's screenshot method."""
    with patch.object(
        computer_tool, "screenshot", new_callable=AsyncMock
    ) as screenshot_mock:
        screenshot_mock.return_value = ToolResult(base64_image="base64_screenshot")
        outcome = await computer_tool(action="screenshot")

    screenshot_mock.assert_called_once()
    assert outcome.base64_image == "base64_screenshot"


@pytest.mark.asyncio
async def test_computer_tool_scaling(computer_tool):
    """Coordinates map between 1366x768 (API) and 1920x1080 (computer) when scaling is on."""
    computer_tool._scaling_enabled = True
    computer_tool.width, computer_tool.height = 1920, 1080

    # Test scaling from API to computer
    scaled = computer_tool.scale_coordinates(ScalingSource.API, 1366, 768)
    assert tuple(scaled) == (1920, 1080)

    # Test scaling from computer to API
    scaled = computer_tool.scale_coordinates(ScalingSource.COMPUTER, 1920, 1080)
    assert tuple(scaled) == (1366, 768)

    # Test no scaling when disabled
    computer_tool._scaling_enabled = False
    scaled = computer_tool.scale_coordinates(ScalingSource.API, 1366, 768)
    assert tuple(scaled) == (1366, 768)


@pytest.mark.asyncio
async def test_computer_tool_scaling_with_different_aspect_ratio(computer_tool):
    """A 1920x1200 (16:10) screen maps to and from 1280x800."""
    computer_tool._scaling_enabled = True
    computer_tool.width, computer_tool.height = 1920, 1200  # 16:10 aspect ratio

    # Test scaling from API to computer
    scaled = computer_tool.scale_coordinates(ScalingSource.API, 1280, 800)
    assert tuple(scaled) == (1920, 1200)

    # Test scaling from computer to API
    scaled = computer_tool.scale_coordinates(ScalingSource.COMPUTER, 1920, 1200)
    assert tuple(scaled) == (1280, 800)


@pytest.mark.asyncio
async def test_computer_tool_no_scaling_for_unsupported_resolution(computer_tool):
    """A 4096x2160 screen is passed through unchanged in both directions."""
    computer_tool._scaling_enabled = True
    computer_tool.width, computer_tool.height = 4096, 2160

    # Test no scaling for unsupported resolution
    for source in (ScalingSource.API, ScalingSource.COMPUTER):
        scaled = computer_tool.scale_coordinates(source, 4096, 2160)
        assert tuple(scaled) == (4096, 2160)


@pytest.mark.asyncio
async def test_computer_tool_scaling_out_of_bounds(computer_tool):
    """API coordinates beyond the screen raise `ToolError`."""
    computer_tool._scaling_enabled = True
    computer_tool.width, computer_tool.height = 1920, 1080

    # Test scaling from API with out of bounds coordinates
    with pytest.raises(ToolError, match="Coordinates .*, .* are out of bounds"):
        computer_tool.scale_coordinates(ScalingSource.API, 2000, 1500)


@pytest.mark.asyncio
async def test_computer_tool_invalid_action(computer_tool):
    """An unknown action name raises `ToolError`."""
    with pytest.raises(ToolError, match="Invalid action: invalid_action"):
        await computer_tool(action="invalid_action")


@pytest.mark.asyncio
async def test_computer_tool_missing_coordinate(computer_tool):
    """`mouse_move` without a coordinate raises `ToolError`."""
    with pytest.raises(ToolError, match="coordinate is required for mouse_move"):
        await computer_tool(action="mouse_move")
+@pytest.mark.asyncio +async def test_computer_tool_missing_text(computer_tool): + with pytest.raises(ToolError, match="text is required for type"): + await computer_tool(action="type") diff --git a/computer-use-demo/tests/tools/edit_test.py b/computer-use-demo/tests/tools/edit_test.py new file mode 100644 index 00000000..c6484152 --- /dev/null +++ b/computer-use-demo/tests/tools/edit_test.py @@ -0,0 +1,330 @@ +from pathlib import Path +from unittest.mock import patch + +import pytest + +from computer_use_demo.tools.base import CLIResult, ToolError, ToolResult +from computer_use_demo.tools.edit import EditTool + + +@pytest.mark.asyncio +async def test_view_command(): + edit_tool = EditTool() + + # Test viewing a file that exists + with patch("pathlib.Path.exists", return_value=True), patch( + "pathlib.Path.is_dir", return_value=False + ), patch("pathlib.Path.read_text") as mock_read_text: + mock_read_text.return_value = "File content" + result = await edit_tool(command="view", path="/test/file.txt") + assert isinstance(result, CLIResult) + assert result.output + assert "File content" in result.output + + # Test viewing a directory + with patch("pathlib.Path.exists", return_value=True), patch( + "pathlib.Path.is_dir", return_value=True + ), patch("computer_use_demo.tools.edit.run") as mock_run: + mock_run.return_value = (None, "file1.txt\nfile2.txt", None) + result = await edit_tool(command="view", path="/test/dir") + assert isinstance(result, CLIResult) + assert result.output + assert "file1.txt" in result.output + assert "file2.txt" in result.output + + # Test viewing a file with a specific range + with patch("pathlib.Path.exists", return_value=True), patch( + "pathlib.Path.is_dir", return_value=False + ), patch("pathlib.Path.read_text") as mock_read_text: + mock_read_text.return_value = "Line 1\nLine 2\nLine 3\nLine 4" + result = await edit_tool( + command="view", path="/test/file.txt", view_range=[2, 3] + ) + assert isinstance(result, CLIResult) + assert 
@pytest.mark.asyncio
async def test_create_command():
    """Exercise `create`: success, missing `file_text`, and an already-existing path."""
    tool = EditTool()

    # Test creating a new file with content
    with (
        patch("pathlib.Path.exists", return_value=False),
        patch("pathlib.Path.write_text") as write_text_mock,
    ):
        outcome = await tool(
            command="create", path="/test/newfile.txt", file_text="New file content"
        )
        assert isinstance(outcome, ToolResult)
        assert outcome.output
        assert "File created successfully" in outcome.output
        write_text_mock.assert_called_once_with("New file content")

    # Test attempting to create a file without content
    with (
        patch("pathlib.Path.exists", return_value=False),
        pytest.raises(ToolError, match="Parameter `file_text` is required"),
    ):
        await tool(command="create", path="/test/newfile.txt")

    # Test attempting to create a file that already exists
    with (
        patch("pathlib.Path.exists", return_value=True),
        pytest.raises(ToolError, match="File already exists"),
    ):
        await tool(
            command="create", path="/test/existingfile.txt", file_text="Content"
        )
@pytest.mark.asyncio
async def test_insert_command():
    """Exercise `insert`: middle, start and end of file, invalid line, and history."""
    tool = EditTool()
    target = "/test/file.txt"

    # Test inserting a string at a valid line number
    with (
        patch("pathlib.Path.exists", return_value=True),
        patch("pathlib.Path.is_dir", return_value=False),
        patch("pathlib.Path.read_text", return_value="Line 1\nLine 2\nLine 3"),
        patch("pathlib.Path.write_text") as write_text_mock,
    ):
        outcome = await tool(
            command="insert", path=target, insert_line=2, new_str="New Line"
        )
        assert isinstance(outcome, CLIResult)
        assert outcome.output
        assert "has been edited" in outcome.output
        write_text_mock.assert_called_once_with("Line 1\nLine 2\nNew Line\nLine 3")

    # Test inserting a string at the beginning of the file (line 0)
    with (
        patch("pathlib.Path.exists", return_value=True),
        patch("pathlib.Path.is_dir", return_value=False),
        patch("pathlib.Path.read_text", return_value="Line 1\nLine 2"),
        patch("pathlib.Path.write_text") as write_text_mock,
    ):
        outcome = await tool(
            command="insert", path=target, insert_line=0, new_str="New First Line"
        )
        assert isinstance(outcome, CLIResult)
        assert outcome.output
        assert "has been edited" in outcome.output
        write_text_mock.assert_called_once_with("New First Line\nLine 1\nLine 2")

    # Test inserting a string at the end of the file
    with (
        patch("pathlib.Path.exists", return_value=True),
        patch("pathlib.Path.is_dir", return_value=False),
        patch("pathlib.Path.read_text", return_value="Line 1\nLine 2"),
        patch("pathlib.Path.write_text") as write_text_mock,
    ):
        outcome = await tool(
            command="insert", path=target, insert_line=2, new_str="New Last Line"
        )
        assert isinstance(outcome, CLIResult)
        assert outcome.output
        assert "has been edited" in outcome.output
        write_text_mock.assert_called_once_with("Line 1\nLine 2\nNew Last Line")

    # Test attempting to insert at an invalid line number
    with (
        patch("pathlib.Path.exists", return_value=True),
        patch("pathlib.Path.is_dir", return_value=False),
        patch("pathlib.Path.read_text", return_value="Line 1\nLine 2"),
        pytest.raises(ToolError, match="Invalid `insert_line` parameter"),
    ):
        await tool(
            command="insert", path=target, insert_line=5, new_str="Invalid Line"
        )

    # Verify that the file history is updated after insertion
    tool._file_history.clear()
    with (
        patch("pathlib.Path.exists", return_value=True),
        patch("pathlib.Path.is_dir", return_value=False),
        patch("pathlib.Path.read_text", return_value="Original content"),
        patch("pathlib.Path.write_text"),
    ):
        await tool(command="insert", path=target, insert_line=1, new_str="New Line")
    assert tool._file_history[Path("/test/file.txt")] == ["Original content"]
mock_write_text.assert_called_with("Original content") + + # Test undoing an insert operation + edit_tool._file_history.clear() + with patch("pathlib.Path.exists", return_value=True), patch( + "pathlib.Path.is_dir", return_value=False + ), patch("pathlib.Path.read_text") as mock_read_text, patch( + "pathlib.Path.write_text" + ) as mock_write_text: + mock_read_text.return_value = "Line 1\nLine 2" + await edit_tool( + command="insert", path="/test/file.txt", insert_line=1, new_str="New Line" + ) + mock_read_text.return_value = "Line 1\nNew Line\nLine 2" + result = await edit_tool(command="undo_edit", path="/test/file.txt") + assert isinstance(result, CLIResult) + assert result.output + assert "Last edit to /test/file.txt undone successfully" in result.output + mock_write_text.assert_called_with("Line 1\nLine 2") + + # Test attempting to undo when there's no history + edit_tool._file_history.clear() + with patch("pathlib.Path.exists", return_value=True), patch( + "pathlib.Path.is_dir", return_value=False + ): + with pytest.raises(ToolError, match="No edit history found"): + await edit_tool(command="undo_edit", path="/test/file.txt") + + +@pytest.mark.asyncio +async def test_validate_path(): + edit_tool = EditTool() + + # Test with valid absolute paths + with patch("pathlib.Path.exists", return_value=True), patch( + "pathlib.Path.is_dir", return_value=False + ): + edit_tool.validate_path("view", Path("/valid/path.txt")) + + # Test with relative paths (should raise an error) + with pytest.raises(ToolError, match="not an absolute path"): + edit_tool.validate_path("view", Path("relative/path.txt")) + + # Test with non-existent paths for non-create commands (should raise an error) + with patch("pathlib.Path.exists", return_value=False): + with pytest.raises(ToolError, match="does not exist"): + edit_tool.validate_path("view", Path("/nonexistent/file.txt")) + + # Test with existing paths for create command (should raise an error) + with patch("pathlib.Path.exists", 
return_value=True): + with pytest.raises(ToolError, match="File already exists"): + edit_tool.validate_path("create", Path("/existing/file.txt")) + + # Test with directory paths for non-view commands (should raise an error) + with patch("pathlib.Path.exists", return_value=True), patch( + "pathlib.Path.is_dir", return_value=True + ): + with pytest.raises(ToolError, match="is a directory"): + edit_tool.validate_path("str_replace", Path("/directory/path")) + + # Test with directory path for view command (should not raise an error) + with patch("pathlib.Path.exists", return_value=True), patch( + "pathlib.Path.is_dir", return_value=True + ): + edit_tool.validate_path("view", Path("/directory/path"))