21 | 21 | BinaryContent, |
22 | 22 | BuiltinToolCallPart, |
23 | 23 | BuiltinToolReturnPart, |
| 24 | + CachePoint, |
24 | 25 | DocumentUrl, |
25 | 26 | ImageUrl, |
26 | 27 | ModelMessage, |
@@ -296,9 +297,14 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes |
296 | 297 | tool_call_id=tool_use['toolUseId'], |
297 | 298 | ), |
298 | 299 | ) |
| 300 | + cache_read_tokens = response['usage'].get('cacheReadInputTokens', 0) |
| 301 | + cache_write_tokens = response['usage'].get('cacheWriteInputTokens', 0) |
| 302 | + input_tokens = response['usage']['inputTokens'] + cache_read_tokens + cache_write_tokens |
299 | 303 | u = usage.RequestUsage( |
300 | | - input_tokens=response['usage']['inputTokens'], |
| 304 | + input_tokens=input_tokens, |
301 | 305 | output_tokens=response['usage']['outputTokens'], |
| 306 | + cache_read_tokens=cache_read_tokens, |
| 307 | + cache_write_tokens=cache_write_tokens, |
302 | 308 | ) |
303 | 309 | response_id = response.get('ResponseMetadata', {}).get('RequestId', None) |
304 | 310 | return ModelResponse( |
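
Editor's note on the usage hunk above: the mapping treats Bedrock's `inputTokens` as excluding cached tokens, so `cacheReadInputTokens` and `cacheWriteInputTokens` are added back to get the total prompt size. A minimal sketch of that arithmetic (the numbers are illustrative):

```python
# Illustrative numbers only; mirrors the accounting in the hunk above.
usage_block = {'inputTokens': 12, 'outputTokens': 50, 'cacheReadInputTokens': 1024}
cache_read = usage_block.get('cacheReadInputTokens', 0)    # tokens served from the cache
cache_write = usage_block.get('cacheWriteInputTokens', 0)  # tokens newly written to the cache
total_input = usage_block['inputTokens'] + cache_read + cache_write
assert total_input == 1036
```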
@@ -346,7 +352,12 @@ async def _messages_create( |
346 | 352 | 'inferenceConfig': inference_config, |
347 | 353 | } |
348 | 354 |
349 | | - tool_config = self._map_tool_config(model_request_parameters) |
| 355 | + tool_config = self._map_tool_config( |
| 356 | + model_request_parameters, |
| 357 | + should_add_cache_point=( |
| 358 | + not system_prompt and BedrockModelProfile.from_profile(self.profile).bedrock_supports_prompt_caching |
| 359 | + ), |
| 360 | + ) |
350 | 361 | if tool_config: |
351 | 362 | params['toolConfig'] = tool_config |
352 | 363 |
@@ -395,11 +406,16 @@ def _map_inference_config( |
395 | 406 |
396 | 407 | return inference_config |
397 | 408 |
398 | | - def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> ToolConfigurationTypeDef | None: |
| 409 | + def _map_tool_config( |
| 410 | + self, model_request_parameters: ModelRequestParameters, should_add_cache_point: bool = False |
| 411 | + ) -> ToolConfigurationTypeDef | None: |
399 | 412 | tools = self._get_tools(model_request_parameters) |
400 | 413 | if not tools: |
401 | 414 | return None |
402 | 415 |
| 416 | + if should_add_cache_point: |
| 417 | + tools[-1]['cachePoint'] = {'type': 'default'} |
| 418 | + |
403 | 419 | tool_choice: ToolChoiceTypeDef |
404 | 420 | if not model_request_parameters.allow_text_output: |
405 | 421 | tool_choice = {'any': {}} |
@@ -429,7 +445,12 @@ async def _map_messages( # noqa: C901 |
429 | 445 | if isinstance(part, SystemPromptPart) and part.content: |
430 | 446 | system_prompt.append({'text': part.content}) |
431 | 447 | elif isinstance(part, UserPromptPart): |
432 | | - bedrock_messages.extend(await self._map_user_prompt(part, document_count)) |
| 448 | + has_leading_cache_point, user_messages = await self._map_user_prompt( |
| 449 | + part, document_count, profile.bedrock_supports_prompt_caching |
| 450 | + ) |
| 451 | + if has_leading_cache_point: |
| 452 | + system_prompt.append({'cachePoint': {'type': 'default'}}) |
| 453 | + bedrock_messages.extend(user_messages) |
433 | 454 | elif isinstance(part, ToolReturnPart): |
434 | 455 | assert part.tool_call_id is not None |
435 | 456 | bedrock_messages.append( |
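
In `_map_messages`, a `CachePoint` at the very start of a user prompt has nothing before it in the user turn to cache, so it is promoted into a cache marker appended to the system prompt block instead. Roughly (the system prompt text is invented):

```python
# Mirrors the promotion in the hunk above.
system_prompt = [{'text': 'You are a helpful assistant.'}]
has_leading_cache_point = True  # the first user-content item was a CachePoint
if has_leading_cache_point:
    system_prompt.append({'cachePoint': {'type': 'default'}})
# system_prompt == [{'text': ...}, {'cachePoint': {'type': 'default'}}]
```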
@@ -522,13 +543,22 @@ async def _map_messages( # noqa: C901 |
522 | 543 |
523 | 544 | return system_prompt, processed_messages |
524 | 545 |
525 | | - @staticmethod |
526 | | - async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) -> list[MessageUnionTypeDef]: |
| 546 | + async def _map_user_prompt( # noqa: C901 |
| 547 | + self, part: UserPromptPart, document_count: Iterator[int], supports_caching: bool |
| 548 | + ) -> tuple[bool, list[MessageUnionTypeDef]]: |
527 | 549 | content: list[ContentBlockUnionTypeDef] = [] |
| 550 | + has_leading_cache_point = False |
| 551 | + |
528 | 552 | if isinstance(part.content, str): |
529 | 553 | content.append({'text': part.content}) |
530 | 554 | else: |
531 | | - for item in part.content: |
| 555 | + if part.content and isinstance(part.content[0], CachePoint): |
| 556 | + has_leading_cache_point = True |
| 557 | + items_to_process = part.content[1:] |
| 558 | + else: |
| 559 | + items_to_process = part.content |
| 560 | + |
| 561 | + for item in items_to_process: |
532 | 562 | if isinstance(item, str): |
533 | 563 | content.append({'text': item}) |
534 | 564 | elif isinstance(item, BinaryContent): |
@@ -578,11 +608,15 @@ async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) |
578 | 608 | ), f'Unsupported video format: {format}' |
579 | 609 | video: VideoBlockTypeDef = {'format': format, 'source': {'bytes': downloaded_item['data']}} |
580 | 610 | content.append({'video': video}) |
| 611 | + elif isinstance(item, CachePoint): |
| 612 | + if supports_caching: |
| 613 | + content.append({'cachePoint': {'type': 'default'}}) |
| 614 | + continue |
581 | 615 | elif isinstance(item, AudioUrl): # pragma: no cover |
582 | 616 | raise NotImplementedError('Audio is not supported yet.') |
583 | 617 | else: |
584 | 618 | assert_never(item) |
585 | | - return [{'role': 'user', 'content': content}] |
| 619 | + return has_leading_cache_point, [{'role': 'user', 'content': content}] |
586 | 620 |
587 | 621 | @staticmethod |
588 | 622 | def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef: |
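
Putting the pieces together, a caller places `CachePoint` markers inside the user content. A hedged end-to-end sketch (the model id and prompt text are illustrative, and it assumes `CachePoint` is exported from `pydantic_ai.messages` as the import hunk at the top suggests):

```python
from pydantic_ai import Agent
from pydantic_ai.messages import CachePoint

agent = Agent('bedrock:anthropic.claude-3-5-sonnet-20240620-v1:0')
result = agent.run_sync([
    'Long, stable context that is worth caching...',  # the cacheable prefix
    CachePoint(),                                     # everything above is cached
    'The actual question for this run.',
])
```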
@@ -674,9 +708,14 @@ def timestamp(self) -> datetime: |
674 | 708 | return self._timestamp |
675 | 709 |
676 | 710 | def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage: |
| 711 | + cache_read_tokens = metadata['usage'].get('cacheReadInputTokens', 0) |
| 712 | + cache_write_tokens = metadata['usage'].get('cacheWriteInputTokens', 0) |
| 713 | + input_tokens = metadata['usage']['inputTokens'] + cache_read_tokens + cache_write_tokens |
677 | 714 | return usage.RequestUsage( |
678 | | - input_tokens=metadata['usage']['inputTokens'], |
| 715 | + input_tokens=input_tokens, |
679 | 716 | output_tokens=metadata['usage']['outputTokens'], |
| 717 | + cache_write_tokens=cache_write_tokens, |
| 718 | + cache_read_tokens=cache_read_tokens, |
680 | 719 | ) |
681 | 720 |
682 | 721 |
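
Since the streaming path (`_map_usage` above) now populates the same counters as the non-streaming one, downstream accounting can read them uniformly. Continuing the sketch, and assuming the aggregated run usage exposes the same field names the diff adds to `RequestUsage`:

```python
run_usage = result.usage()
print(run_usage.input_tokens)  # includes cache read/write tokens, per the mapping above
print(run_usage.cache_read_tokens, run_usage.cache_write_tokens)
```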