Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[computer-use-demo] Add prompt caching #64

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 48 additions & 4 deletions computer-use-demo/computer_use_demo/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,18 @@
from enum import StrEnum
from typing import Any, cast

from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
from anthropic import (
Anthropic,
AnthropicBedrock,
AnthropicVertex,
APIResponse,
BaseModel,
)
from anthropic.types import (
ToolResultBlockParam,
)
from anthropic.types.beta import (
BetaCacheControlEphemeralParam,
BetaContentBlock,
BetaContentBlockParam,
BetaImageBlockParam,
Expand All @@ -24,8 +31,6 @@

from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult

BETA_FLAG = "computer-use-2024-10-22"


class APIProvider(StrEnum):
ANTHROPIC = "anthropic"
Expand Down Expand Up @@ -74,6 +79,7 @@ async def sampling_loop(
api_key: str,
only_n_most_recent_images: int | None = None,
max_tokens: int = 4096,
prompt_caching: bool = True,
):
"""
Agentic sampling loop for the assistant/tool interaction of computer use.
Expand All @@ -98,6 +104,11 @@ async def sampling_loop(
elif provider == APIProvider.BEDROCK:
client = AnthropicBedrock()

betas = ["computer-use-2024-10-22"]
if prompt_caching:
betas.append("prompt-caching-2024-07-31")
_add_prompt_caching_headers(messages)

# Call the API
# we use raw_response to provide debug information to streamlit. Your
# implementation may be able call the SDK directly with:
Expand All @@ -108,7 +119,7 @@ async def sampling_loop(
model=model,
system=system,
tools=tool_collection.to_params(),
betas=["computer-use-2024-10-22"],
betas=betas,
)

api_response_callback(cast(APIResponse[BetaMessage], raw_response))
Expand Down Expand Up @@ -230,3 +241,36 @@ def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str):
if result.system:
result_text = f"<system>{result.system}</system>\n{result_text}"
return result_text


MAX_PROMPT_CACHING_BREAKPOINTS = 4


def _add_prompt_caching_headers(
messages: list[BetaMessageParam],
):
prompt_caching_breakpoints = 0
for message in messages:
if isinstance(message["content"], str):
continue

params: list[BetaContentBlockParam] = []
for content_block in message["content"]:
if isinstance(content_block, BaseModel):
content_block_param = cast(
BetaContentBlockParam, content_block.to_dict()
)
else:
content_block_param = content_block
params.append(content_block_param)

if (
isinstance(content_block_param, dict)
and content_block_param.get("type") == "image"
and prompt_caching_breakpoints < MAX_PROMPT_CACHING_BREAKPOINTS
):
content_block_param["cache_control"] = BetaCacheControlEphemeralParam(
type="ephemeral"
)
prompt_caching_breakpoints += 1
message["content"] = params
2 changes: 1 addition & 1 deletion computer-use-demo/computer_use_demo/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _reset_api_provider():
st.session_state.messages.append(
{
"role": Sender.USER,
"content": [TextBlock(type="text", text=new_message)],
"content": [BetaTextBlock(type="text", text=new_message)],
}
)
_render_message(Sender.USER, new_message)
Expand Down
19 changes: 12 additions & 7 deletions computer-use-demo/tests/loop_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from unittest import mock

from anthropic.types import TextBlock, ToolUseBlock
from anthropic.types.beta import BetaMessage, BetaMessageParam
from anthropic.types.beta import (
BetaMessage,
BetaMessageParam,
BetaTextBlock,
BetaToolUseBlock,
)

from computer_use_demo.loop import APIProvider, sampling_loop

Expand All @@ -13,13 +17,13 @@ async def test_loop():
mock.Mock(
spec=BetaMessage,
content=[
TextBlock(type="text", text="Hello"),
ToolUseBlock(
BetaTextBlock(type="text", text="Hello"),
BetaToolUseBlock(
type="tool_use", id="1", name="computer", input={"action": "test"}
),
],
),
mock.Mock(spec=BetaMessage, content=[TextBlock(type="text", text="Done!")]),
mock.Mock(spec=BetaMessage, content=[BetaTextBlock(type="text", text="Done!")]),
]

tool_collection = mock.AsyncMock()
Expand Down Expand Up @@ -49,7 +53,8 @@ async def test_loop():
)

assert len(result) == 4
assert result[0] == {"role": "user", "content": "Test message"}
assert result[0]["role"] == "user"
assert result[0]["content"] == "Test message"
assert result[1]["role"] == "assistant"
assert result[2]["role"] == "user"
assert result[3]["role"] == "assistant"
Expand All @@ -58,7 +63,7 @@ async def test_loop():
tool_collection.run.assert_called_once_with(
name="computer", tool_input={"action": "test"}
)
output_callback.assert_called_with(TextBlock(text="Done!", type="text"))
output_callback.assert_called_with(BetaTextBlock(text="Done!", type="text"))
assert output_callback.call_count == 3
assert tool_output_callback.call_count == 1
assert api_response_callback.call_count == 2
5 changes: 3 additions & 2 deletions computer-use-demo/tests/streamlit_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from unittest import mock

import pytest
from anthropic.types.beta import BetaTextBlock
from streamlit.testing.v1 import AppTest

from computer_use_demo.streamlit import Sender, TextBlock
from computer_use_demo.streamlit import Sender


@pytest.fixture
Expand All @@ -18,6 +19,6 @@ def test_streamlit(streamlit_app: AppTest):
streamlit_app.chat_input[0].set_value("Hello").run()
assert patch.called
assert patch.call_args.kwargs["messages"] == [
{"role": Sender.USER, "content": [TextBlock(text="Hello", type="text")]}
{"role": Sender.USER, "content": [BetaTextBlock(text="Hello", type="text")]}
]
assert not streamlit_app.exception
Loading