Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 90 additions & 116 deletions core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,8 @@ private Single<Event> appendNewMessageToSession(
}

if (this.artifactService != null && saveInputBlobsAsArtifacts) {
// The runner directly saves the artifacts (if applicable) in the
// user message and replaces the artifact data with a file name
// placeholder.
// The runner directly saves the artifacts (if applicable) in the user message and replaces
// the artifact data with a file name placeholder.
for (int i = 0; i < newMessage.parts().get().size(); i++) {
Part part = newMessage.parts().get().get(i);
if (part.inlineData().isEmpty()) {
Expand Down Expand Up @@ -426,12 +425,10 @@ public Flowable<Event> runAsync(

// Create initial context
InvocationContext initialContext =
newInvocationContextBuilder(
session,
Optional.of(newMessage),
/* liveRequestQueue= */ Optional.empty(),
runConfig)
newInvocationContextBuilder(session)
.invocationId(invocationId)
.runConfig(runConfig)
.userContent(newMessage)
.build();

return Telemetry.traceFlowable(
Expand All @@ -455,6 +452,9 @@ public Flowable<Event> runAsync(
: Single.just(null))
.flatMapPublisher(
event -> {
if (event == null) {
return Flowable.empty();
}
// Get the updated session after the message and state delta are
// applied
return this.sessionService
Expand All @@ -464,80 +464,14 @@ public Flowable<Event> runAsync(
session.id(),
Optional.empty())
.flatMapPublisher(
updatedSession -> {
// Create context with updated session for
// beforeRunCallback
InvocationContext contextWithUpdatedSession =
newInvocationContextBuilder(
updatedSession,
event.content(),
/* liveRequestQueue= */ Optional.empty(),
runConfig)
.invocationId(invocationId)
.agent(
this.findAgentToRun(
updatedSession, rootAgent))
.build();

// Call beforeRunCallback with updated session
Maybe<Event> beforeRunEvent =
this.pluginManager
.beforeRunCallback(contextWithUpdatedSession)
.map(
content ->
Event.builder()
.id(Event.generateEventId())
.invocationId(
contextWithUpdatedSession
.invocationId())
.author("model")
.content(Optional.of(content))
.build());

// Agent execution
Flowable<Event> agentEvents =
contextWithUpdatedSession
.agent()
.runAsync(contextWithUpdatedSession)
.flatMap(
agentEvent ->
this.sessionService
.appendEvent(
updatedSession, agentEvent)
.flatMap(
registeredEvent -> {
// TODO: remove this hack
// after
// deprecating runAsync with
// Session.
copySessionStates(
updatedSession,
session);
return contextWithUpdatedSession
.combinedPlugin()
.onEventCallback(
contextWithUpdatedSession,
registeredEvent)
.defaultIfEmpty(
registeredEvent);
})
.toFlowable());

// If beforeRunCallback returns content, emit it and
// skip
// agent
return beforeRunEvent
.toFlowable()
.switchIfEmpty(agentEvents)
.concatWith(
Completable.defer(
() ->
pluginManager.runAfterRunCallback(
contextWithUpdatedSession)))
.concatWith(
Completable.defer(
() -> compactEvents(updatedSession)));
});
updatedSession ->
runAgentWithFreshSession(
session,
updatedSession,
event,
invocationId,
runConfig,
rootAgent));
}))
.doOnError(
throwable -> {
Expand All @@ -552,6 +486,64 @@ public Flowable<Event> runAsync(
}
}

private Flowable<Event> runAgentWithFreshSession(
Session session,
Session updatedSession,
Event event,
String invocationId,
RunConfig runConfig,
BaseAgent rootAgent) {
// Create context with updated session for beforeRunCallback
InvocationContext contextWithUpdatedSession =
newInvocationContextBuilder(updatedSession)
.invocationId(invocationId)
.agent(this.findAgentToRun(updatedSession, rootAgent))
.runConfig(runConfig)
.userContent(event.content().orElseGet(Content::fromParts))
.build();

// Call beforeRunCallback with updated session
Maybe<Event> beforeRunEvent =
this.pluginManager
.beforeRunCallback(contextWithUpdatedSession)
.map(
content ->
Event.builder()
.id(Event.generateEventId())
.invocationId(contextWithUpdatedSession.invocationId())
.author("model")
.content(Optional.of(content))
.build());

// Agent execution
Flowable<Event> agentEvents =
contextWithUpdatedSession
.agent()
.runAsync(contextWithUpdatedSession)
.flatMap(
agentEvent ->
this.sessionService
.appendEvent(updatedSession, agentEvent)
.flatMap(
registeredEvent -> {
// TODO: remove this hack after deprecating runAsync with Session.
copySessionStates(updatedSession, session);
return contextWithUpdatedSession
.combinedPlugin()
.onEventCallback(contextWithUpdatedSession, registeredEvent)
.defaultIfEmpty(registeredEvent);
})
.toFlowable());

// If beforeRunCallback returns content, emit it and skip agent
return beforeRunEvent
.toFlowable()
.switchIfEmpty(agentEvents)
.concatWith(
Completable.defer(() -> pluginManager.runAfterRunCallback(contextWithUpdatedSession)))
.concatWith(Completable.defer(() -> compactEvents(updatedSession)));
}

private Completable compactEvents(Session session) {
return Optional.ofNullable(eventsCompactionConfig)
.map(SlidingWindowEventCompactor::new)
Expand Down Expand Up @@ -590,43 +582,25 @@ private InvocationContext newInvocationContextForLive(
runConfigBuilder.setInputAudioTranscription(AudioTranscriptionConfig.builder().build());
}
}
return newInvocationContext(
session, /* newMessage= */ Optional.empty(), liveRequestQueue, runConfigBuilder.build());
InvocationContext.Builder builder =
newInvocationContextBuilder(session)
.runConfig(runConfigBuilder.build())
.userContent(Content.fromParts());
liveRequestQueue.ifPresent(builder::liveRequestQueue);
return builder.build();
}

/**
* Creates an {@link InvocationContext} for the given session, request queue, and config.
*
* @return a new {@link InvocationContext}.
*/
private InvocationContext newInvocationContext(
Session session,
Optional<Content> newMessage,
Optional<LiveRequestQueue> liveRequestQueue,
RunConfig runConfig) {
return newInvocationContextBuilder(session, newMessage, liveRequestQueue, runConfig).build();
}

private InvocationContext.Builder newInvocationContextBuilder(
Session session,
Optional<Content> newMessage,
Optional<LiveRequestQueue> liveRequestQueue,
RunConfig runConfig) {
private InvocationContext.Builder newInvocationContextBuilder(Session session) {
BaseAgent rootAgent = this.agent;
var invocationContextBuilder =
InvocationContext.builder()
.sessionService(this.sessionService)
.artifactService(this.artifactService)
.memoryService(this.memoryService)
.pluginManager(this.pluginManager)
.agent(rootAgent)
.session(session)
.userContent(newMessage.orElseGet(() -> Content.fromParts()))
.runConfig(runConfig)
.resumabilityConfig(this.resumabilityConfig)
.agent(this.findAgentToRun(session, rootAgent));
liveRequestQueue.ifPresent(invocationContextBuilder::liveRequestQueue);
return invocationContextBuilder;
return InvocationContext.builder()
.sessionService(this.sessionService)
.artifactService(this.artifactService)
.memoryService(this.memoryService)
.pluginManager(this.pluginManager)
.agent(rootAgent)
.session(session)
.resumabilityConfig(this.resumabilityConfig)
.agent(this.findAgentToRun(session, rootAgent));
}

/**
Expand Down