|
34 | 34 | import org.slf4j.Logger; |
35 | 35 | import org.slf4j.LoggerFactory; |
36 | 36 | import reactor.core.publisher.Flux; |
37 | | -import reactor.core.publisher.Sinks; |
38 | | -import reactor.core.publisher.Sinks.EmitFailureHandler; |
39 | 37 | import reactor.core.scheduler.Schedulers; |
40 | 38 | import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; |
41 | 39 | import software.amazon.awssdk.core.SdkBytes; |
|
51 | 49 | import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics; |
52 | 50 | import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; |
53 | 51 | import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; |
54 | | -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; |
55 | 52 | import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; |
56 | | -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; |
57 | 53 | import software.amazon.awssdk.services.bedrockruntime.model.DocumentBlock; |
58 | 54 | import software.amazon.awssdk.services.bedrockruntime.model.DocumentSource; |
59 | 55 | import software.amazon.awssdk.services.bedrockruntime.model.ImageBlock; |
|
76 | 72 |
|
77 | 73 | import org.springframework.ai.bedrock.converse.api.BedrockMediaFormat; |
78 | 74 | import org.springframework.ai.bedrock.converse.api.ConverseApiUtils; |
| 75 | +import org.springframework.ai.bedrock.converse.api.ConverseChatResponseStream; |
79 | 76 | import org.springframework.ai.bedrock.converse.api.URLValidator; |
80 | 77 | import org.springframework.ai.chat.messages.AssistantMessage; |
81 | 78 | import org.springframework.ai.chat.messages.MessageType; |
|
84 | 81 | import org.springframework.ai.chat.metadata.ChatGenerationMetadata; |
85 | 82 | import org.springframework.ai.chat.metadata.ChatResponseMetadata; |
86 | 83 | import org.springframework.ai.chat.metadata.DefaultUsage; |
| 84 | +import org.springframework.ai.chat.metadata.Usage; |
87 | 85 | import org.springframework.ai.chat.model.ChatModel; |
88 | 86 | import org.springframework.ai.chat.model.ChatResponse; |
89 | 87 | import org.springframework.ai.chat.model.Generation; |
@@ -682,11 +680,17 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh |
682 | 680 | .system(converseRequest.system()) |
683 | 681 | .additionalModelRequestFields(converseRequest.additionalModelRequestFields()) |
684 | 682 | .toolConfig(converseRequest.toolConfig()) |
| 683 | + .requestMetadata(converseRequest.requestMetadata()) |
685 | 684 | .build(); |
686 | 685 |
|
687 | | - Flux<ConverseStreamOutput> response = converseStream(converseStreamRequest); |
| 686 | + Usage accumulatedUsage = null; |
| 687 | + if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null) { |
| 688 | + accumulatedUsage = perviousChatResponse.getMetadata().getUsage(); |
| 689 | + } |
688 | 690 |
|
689 | | - Flux<ChatResponse> chatResponses = ConverseApiUtils.toChatResponse(response, perviousChatResponse); |
| 691 | + Flux<ChatResponse> chatResponses = new ConverseChatResponseStream(this.bedrockRuntimeAsyncClient, |
| 692 | + converseStreamRequest, accumulatedUsage) |
| 693 | + .stream(); |
690 | 694 |
|
691 | 695 | Flux<ChatResponse> chatResponseFlux = chatResponses.switchMap(chatResponse -> { |
692 | 696 |
|
@@ -733,48 +737,6 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh |
733 | 737 | }); |
734 | 738 | } |
735 | 739 |
|
736 | | - public static final EmitFailureHandler DEFAULT_EMIT_FAILURE_HANDLER = EmitFailureHandler |
737 | | - .busyLooping(Duration.ofSeconds(10)); |
738 | | - |
739 | | - /** |
740 | | - * Invoke the model and return the response stream. |
741 | | - * |
742 | | - * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html |
743 | | - * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html |
744 | | - * https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream |
745 | | - * @param converseStreamRequest Model invocation request. |
746 | | - * @return The model invocation response stream. |
747 | | - */ |
748 | | - public Flux<ConverseStreamOutput> converseStream(ConverseStreamRequest converseStreamRequest) { |
749 | | - Assert.notNull(converseStreamRequest, "'converseStreamRequest' must not be null"); |
750 | | - |
751 | | - Sinks.Many<ConverseStreamOutput> eventSink = Sinks.many().multicast().onBackpressureBuffer(); |
752 | | - |
753 | | - ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder() |
754 | | - .onDefault(output -> { |
755 | | - logger.debug("Received converse stream output:{}", output); |
756 | | - eventSink.emitNext(output, DEFAULT_EMIT_FAILURE_HANDLER); |
757 | | - }) |
758 | | - .build(); |
759 | | - |
760 | | - ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder() |
761 | | - .onEventStream(stream -> stream.subscribe(e -> e.accept(visitor))) |
762 | | - .onComplete(() -> { |
763 | | - eventSink.emitComplete(DEFAULT_EMIT_FAILURE_HANDLER); |
764 | | - logger.info("Completed streaming response."); |
765 | | - }) |
766 | | - .onError(error -> { |
767 | | - logger.error("Error handling Bedrock converse stream response", error); |
768 | | - eventSink.emitError(error, DEFAULT_EMIT_FAILURE_HANDLER); |
769 | | - }) |
770 | | - .build(); |
771 | | - |
772 | | - this.bedrockRuntimeAsyncClient.converseStream(converseStreamRequest, responseHandler); |
773 | | - |
774 | | - return eventSink.asFlux(); |
775 | | - |
776 | | - } |
777 | | - |
778 | 740 | /** |
779 | 741 | * Use the provided convention for reporting observation data |
780 | 742 | * @param observationConvention The provided convention |
|
0 commit comments