Skip to content

Commit a897177

Browse files
emopti-jrufer authored and ilayaperumalg committed
Fix converse streaming issues:
- Correct finish reason when stop reason is not tool_use
- Populate finish reason for non-tool_use cases
- Ensure multiple tool calls are output in ChatResponse

Closes gh-4374, gh-4126, gh-3251

Added missing requestMetadata mapping for ConverseStreamRequest.
Switch tests to use cross-region inference for model access reliability.

Signed-off-by: Jared Rufer <jrufer@emopti.com>
1 parent 967d00c commit a897177

File tree

10 files changed

+344
-498
lines changed

10 files changed

+344
-498
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 10 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,6 @@
3434
import org.slf4j.Logger;
3535
import org.slf4j.LoggerFactory;
3636
import reactor.core.publisher.Flux;
37-
import reactor.core.publisher.Sinks;
38-
import reactor.core.publisher.Sinks.EmitFailureHandler;
3937
import reactor.core.scheduler.Schedulers;
4038
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
4139
import software.amazon.awssdk.core.SdkBytes;
@@ -51,9 +49,7 @@
5149
import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics;
5250
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
5351
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
54-
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
5552
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
56-
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
5753
import software.amazon.awssdk.services.bedrockruntime.model.DocumentBlock;
5854
import software.amazon.awssdk.services.bedrockruntime.model.DocumentSource;
5955
import software.amazon.awssdk.services.bedrockruntime.model.ImageBlock;
@@ -76,6 +72,7 @@
7672

7773
import org.springframework.ai.bedrock.converse.api.BedrockMediaFormat;
7874
import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
75+
import org.springframework.ai.bedrock.converse.api.ConverseChatResponseStream;
7976
import org.springframework.ai.bedrock.converse.api.URLValidator;
8077
import org.springframework.ai.chat.messages.AssistantMessage;
8178
import org.springframework.ai.chat.messages.MessageType;
@@ -84,6 +81,7 @@
8481
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
8582
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
8683
import org.springframework.ai.chat.metadata.DefaultUsage;
84+
import org.springframework.ai.chat.metadata.Usage;
8785
import org.springframework.ai.chat.model.ChatModel;
8886
import org.springframework.ai.chat.model.ChatResponse;
8987
import org.springframework.ai.chat.model.Generation;
@@ -682,11 +680,17 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
682680
.system(converseRequest.system())
683681
.additionalModelRequestFields(converseRequest.additionalModelRequestFields())
684682
.toolConfig(converseRequest.toolConfig())
683+
.requestMetadata(converseRequest.requestMetadata())
685684
.build();
686685

687-
Flux<ConverseStreamOutput> response = converseStream(converseStreamRequest);
686+
Usage accumulatedUsage = null;
687+
if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null) {
688+
accumulatedUsage = perviousChatResponse.getMetadata().getUsage();
689+
}
688690

689-
Flux<ChatResponse> chatResponses = ConverseApiUtils.toChatResponse(response, perviousChatResponse);
691+
Flux<ChatResponse> chatResponses = new ConverseChatResponseStream(this.bedrockRuntimeAsyncClient,
692+
converseStreamRequest, accumulatedUsage)
693+
.stream();
690694

691695
Flux<ChatResponse> chatResponseFlux = chatResponses.switchMap(chatResponse -> {
692696

@@ -733,48 +737,6 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
733737
});
734738
}
735739

736-
public static final EmitFailureHandler DEFAULT_EMIT_FAILURE_HANDLER = EmitFailureHandler
737-
.busyLooping(Duration.ofSeconds(10));
738-
739-
/**
740-
* Invoke the model and return the response stream.
741-
*
742-
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
743-
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
744-
* https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream
745-
* @param converseStreamRequest Model invocation request.
746-
* @return The model invocation response stream.
747-
*/
748-
public Flux<ConverseStreamOutput> converseStream(ConverseStreamRequest converseStreamRequest) {
749-
Assert.notNull(converseStreamRequest, "'converseStreamRequest' must not be null");
750-
751-
Sinks.Many<ConverseStreamOutput> eventSink = Sinks.many().multicast().onBackpressureBuffer();
752-
753-
ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder()
754-
.onDefault(output -> {
755-
logger.debug("Received converse stream output:{}", output);
756-
eventSink.emitNext(output, DEFAULT_EMIT_FAILURE_HANDLER);
757-
})
758-
.build();
759-
760-
ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder()
761-
.onEventStream(stream -> stream.subscribe(e -> e.accept(visitor)))
762-
.onComplete(() -> {
763-
eventSink.emitComplete(DEFAULT_EMIT_FAILURE_HANDLER);
764-
logger.info("Completed streaming response.");
765-
})
766-
.onError(error -> {
767-
logger.error("Error handling Bedrock converse stream response", error);
768-
eventSink.emitError(error, DEFAULT_EMIT_FAILURE_HANDLER);
769-
})
770-
.build();
771-
772-
this.bedrockRuntimeAsyncClient.converseStream(converseStreamRequest, responseHandler);
773-
774-
return eventSink.asFlux();
775-
776-
}
777-
778740
/**
779741
* Use the provided convention for reporting observation data
780742
* @param observationConvention The provided convention

0 commit comments

Comments
 (0)