
PR feedback
vblagoje committed Aug 30, 2024
1 parent c8d93f1 commit 0c7a0a6
Showing 2 changed files with 56 additions and 15 deletions.
56 changes: 41 additions & 15 deletions integrations/anthropic/example/prompt_caching.py
@@ -1,15 +1,18 @@
# To run this example, you will need to set an `ANTHROPIC_API_KEY` environment variable.

import time

from haystack import Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.components.converters import HTMLToDocument
from haystack.components.fetchers import LinkContentFetcher
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils import Secret

from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator

enable_prompt_caching = True

msg = ChatMessage.from_system(
"You are a prompt expert who answers questions based on the given documents.\n"
"Here are the documents:\n"
@@ -18,6 +21,19 @@
"{% endfor %}"
)


def measure_and_print_streaming_chunk():
    first_token_time = None

    def stream_callback(chunk: StreamingChunk) -> None:
        nonlocal first_token_time
        if first_token_time is None:
            first_token_time = time.time()
        print(chunk.content, flush=True, end="")

    return stream_callback, lambda: first_token_time


fetch_pipeline = Pipeline()
fetch_pipeline.add_component("fetcher", LinkContentFetcher())
fetch_pipeline.add_component("converter", HTMLToDocument())
@@ -36,28 +52,32 @@
final_prompt_msg = result["prompt_builder"]["prompt"][0]

# We add a cache control header to the prompt message
final_prompt_msg.meta["cache_control"] = {"type": "ephemeral"}
if enable_prompt_caching:
    final_prompt_msg.meta["cache_control"] = {"type": "ephemeral"}

generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if enable_prompt_caching else {}
claude_llm = AnthropicChatGenerator(
    api_key=Secret.from_env_var("ANTHROPIC_API_KEY"),
    generation_kwargs=generation_kwargs,
)

# Build QA pipeline
qa_pipeline = Pipeline()
qa_pipeline.add_component(
"llm",
AnthropicChatGenerator(
api_key=Secret.from_env_var("ANTHROPIC_API_KEY"),
streaming_callback=print_streaming_chunk,
generation_kwargs={"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}},
),
)
qa_pipeline.add_component("llm", claude_llm)

questions = [
"Why is Monte-Carlo Tree Search used in LATS",
"Summarize LATS selection, expansion, evaluation, simulation, backpropagation, and reflection",
"What's this paper about?",
"What's the main contribution of this paper?",
"How can findings from this paper be applied to real-world problems?",
]

# Answer the questions using prompt caching (i.e., the entire document is cached and each question is run against it)
for question in questions:
print("Question: " + question)
start_time = time.time()
streaming_callback, get_first_token_time = measure_and_print_streaming_chunk()
claude_llm.streaming_callback = streaming_callback

result = qa_pipeline.run(
data={
"llm": {
@@ -69,5 +89,11 @@
        }
    )

print("\n\nChecking cache usage:", result["llm"]["replies"][0].meta.get("usage"))
print("\n")
end_time = time.time()
total_time = end_time - start_time
time_to_first_token = get_first_token_time() - start_time

print(f"\nTotal time: {total_time:.2f} seconds")
print(f"Time to first token: {time_to_first_token:.2f} seconds")
print(f"Cache usage: {result['llm']['replies'][0].meta.get('usage')}")
print("\n" + "=" * 50)
15 changes: 15 additions & 0 deletions (second changed file)
@@ -188,6 +188,21 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
            self._convert_to_anthropic_format(non_system_messages) if non_system_messages else []
        )

        extra_headers = filtered_generation_kwargs.get("extra_headers", {})
        prompt_caching_on = "anthropic-beta" in extra_headers and "prompt-caching" in extra_headers["anthropic-beta"]
        has_cached_messages = any("cache_control" in m for m in system_messages_formatted) or any(
            "cache_control" in m for m in messages_formatted
        )
        if has_cached_messages and not prompt_caching_on:
            logger.warn(
                "Prompt caching is not enabled but you requested individual messages to be cached. "
                "Messages will be sent to the API without prompt caching."
            )
            for m in system_messages_formatted:
                m.pop("cache_control", None)
            for m in messages_formatted:
                m.pop("cache_control", None)

        response: Union[Message, Stream[MessageStreamEvent]] = self.client.messages.create(
            max_tokens=filtered_generation_kwargs.pop("max_tokens", 512),
            system=system_messages_formatted or filtered_generation_kwargs.pop("system", ""),
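
The guard added above can also be read in isolation. The following is a standalone sketch, not code from the commit: the helper name and the sample message dict are illustrative assumptions, while the header check, the warning text, and the `cache_control` stripping mirror the logic in the diff.

# Standalone sketch of the cache_control guard (helper name and sample
# message are illustrative assumptions, not part of the commit).
import logging
from typing import Any, Dict, List

logger = logging.getLogger(__name__)


def drop_unsupported_cache_control(
    messages_formatted: List[Dict[str, Any]], extra_headers: Dict[str, str]
) -> List[Dict[str, Any]]:
    # Caching counts as enabled only if the anthropic-beta header requests it.
    prompt_caching_on = "prompt-caching" in extra_headers.get("anthropic-beta", "")
    has_cached_messages = any("cache_control" in m for m in messages_formatted)
    if has_cached_messages and not prompt_caching_on:
        logger.warning(
            "Prompt caching is not enabled but you requested individual messages to be cached. "
            "Messages will be sent to the API without prompt caching."
        )
        for m in messages_formatted:
            m.pop("cache_control", None)
    return messages_formatted


# Example: the marker is stripped because no beta header is set.
msgs = [{"role": "user", "content": "Hi", "cache_control": {"type": "ephemeral"}}]
print(drop_unsupported_cache_control(msgs, extra_headers={}))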
