Skip to content

Commit

Permalink
Create parent interaction ID in MLAgentExecutor
Browse files Browse the repository at this point in the history
Signed-off-by: Arjun kumar Giri <arjung@amazon.com>
  • Loading branch information
arjunkumargiri committed Dec 1, 2023
1 parent b2ae0ba commit 03f469f
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 170 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
import org.opensearch.ml.engine.Executable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;

import com.google.gson.Gson;

Expand All @@ -58,6 +60,7 @@ public class MLAgentExecutor implements Executable {

public static final String MEMORY_ID = "memory_id";
public static final String QUESTION = "question";
public static final String PARENT_INTERACTION_ID = "parent_interaction_id";

private Client client;
private Settings settings;
Expand Down Expand Up @@ -108,94 +111,49 @@ public void execute(Input input, ActionListener<Output> listener) {
MLAgent mlAgent = MLAgent.parse(parser);
String memoryType = mlAgent.getMemory().getType();
String memoryId = inputDataSet.getParameters().get(MEMORY_ID);
String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID);
String appType = mlAgent.getAppType();
String title = inputDataSet.getParameters().get(QUESTION);

ConversationIndexMemory.Factory conversationIndexMemoryFactory =
(ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);
conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.wrap(m -> {
inputDataSet.getParameters().put(MEMORY_ID, m.getConversationId());
ActionListener<Object> agentActionListener = ActionListener.wrap(output -> {
if (output != null) {
Gson gson = new Gson();
if (output instanceof ModelTensorOutput) {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) output;
modelTensorOutput.getMlModelOutputs().forEach(outs -> {
for (ModelTensor mlModelTensor : outs.getMlModelTensors()) {
modelTensors.add(mlModelTensor);
}
});
} else if (output instanceof ModelTensor) {
modelTensors.add((ModelTensor) output);
} else if (output instanceof List) {
if (((List) output).get(0) instanceof ModelTensor) {
((List<ModelTensor>) output).forEach(mlModelTensor -> modelTensors.add(mlModelTensor));
} else if (((List) output).get(0) instanceof ModelTensors) {
((List<ModelTensors>) output).forEach(outs -> {
for (ModelTensor mlModelTensor : outs.getMlModelTensors()) {
modelTensors.add(mlModelTensor);
}
});
} else {
Object finalOutput = output;
String result = output instanceof String
? (String) output
: AccessController
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(finalOutput));
modelTensors.add(ModelTensor.builder().name("response").result(result).build());
}
} else {
Object finalOutput = output;
String result = output instanceof String
? (String) output
: AccessController
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(finalOutput));
modelTensors.add(ModelTensor.builder().name("response").result(result).build());
}
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(outputs).build());
} else {
listener.onResponse(null);
}
String question = inputDataSet.getParameters().get(QUESTION);

if (memoryType != null
&& memoryFactoryMap.containsKey(memoryType)
&& (memoryId == null || parentInteractionId == null)) {
ConversationIndexMemory.Factory conversationIndexMemoryFactory =
(ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType);
conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListener.wrap(memory -> {
inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId());
ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory;
// Create root interaction ID
ConversationIndexMessage msg = ConversationIndexMessage
.conversationIndexMessageBuilder()
.type(appType)
.question(question)
.response("")
.finalAnswer(true)
.sessionId(memory.getConversationId())
.build();
conversationIndexMemory
.save(msg, null, null, null, ActionListener.<CreateInteractionResponse>wrap(interaction -> {
log.info("Created parent interaction ID: " + interaction.getId());
inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interaction.getId());
ActionListener<Object> agentActionListener = createAgentActionListener(
listener,
outputs,
modelTensors
);
executeAgent(inputDataSet, mlAgent, agentActionListener);
}, ex -> {
log.error("Failed to create parent interaction", ex);
listener.onFailure(ex);
}));
}, ex -> {
log.error("Failed to run flow agent", ex);
log.error("Failed to read conversation memory", ex);
listener.onFailure(ex);
});

if ("flow".equals(mlAgent.getType())) {
MLFlowAgentRunner flowAgentExecutor = new MLFlowAgentRunner(
client,
settings,
clusterService,
xContentRegistry,
toolFactories,
memoryFactoryMap
);
flowAgentExecutor.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
} else if ("cot".equals(mlAgent.getType())) {
MLReActAgentRunner reactAgentExecutor = new MLReActAgentRunner(
client,
settings,
clusterService,
xContentRegistry,
toolFactories,
memoryFactoryMap
);
reactAgentExecutor.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
} else if ("conversational".equals(mlAgent.getType())) {
MLChatAgentRunner chatAgentRunner = new MLChatAgentRunner(
client,
settings,
clusterService,
xContentRegistry,
toolFactories,
memoryFactoryMap
);
chatAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
}
}, ex -> {
log.error("Failed to read conversation memory", ex);
listener.onFailure(ex);
}));
}));
} else {
ActionListener<Object> agentActionListener = createAgentActionListener(listener, outputs, modelTensors);
executeAgent(inputDataSet, mlAgent, agentActionListener);
}
}
} else {
listener.onFailure(new ResourceNotFoundException("Agent not found"));
Expand All @@ -209,6 +167,90 @@ public void execute(Input input, ActionListener<Output> listener) {

}

private void executeAgent(RemoteInferenceInputDataSet inputDataSet, MLAgent mlAgent, ActionListener<Object> agentActionListener) {
if ("flow".equals(mlAgent.getType())) {
MLFlowAgentRunner flowAgentExecutor = new MLFlowAgentRunner(
client,
settings,
clusterService,
xContentRegistry,
toolFactories,
memoryFactoryMap
);
flowAgentExecutor.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
} else if ("cot".equals(mlAgent.getType())) {
MLReActAgentRunner reactAgentExecutor = new MLReActAgentRunner(
client,
settings,
clusterService,
xContentRegistry,
toolFactories,
memoryFactoryMap
);
reactAgentExecutor.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
} else if ("conversational".equals(mlAgent.getType())) {
MLChatAgentRunner chatAgentRunner = new MLChatAgentRunner(
client,
settings,
clusterService,
xContentRegistry,
toolFactories,
memoryFactoryMap
);
chatAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
}
}

private ActionListener<Object> createAgentActionListener(
ActionListener<Output> listener,
List<ModelTensors> outputs,
List<ModelTensor> modelTensors
) {
return ActionListener.wrap(output -> {
if (output != null) {
Gson gson = new Gson();
if (output instanceof ModelTensorOutput) {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) output;
modelTensorOutput.getMlModelOutputs().forEach(outs -> {
for (ModelTensor mlModelTensor : outs.getMlModelTensors()) {
modelTensors.add(mlModelTensor);
}
});
} else if (output instanceof ModelTensor) {
modelTensors.add((ModelTensor) output);
} else if (output instanceof List) {
if (((List) output).get(0) instanceof ModelTensor) {
((List<ModelTensor>) output).forEach(mlModelTensor -> modelTensors.add(mlModelTensor));
} else if (((List) output).get(0) instanceof ModelTensors) {
((List<ModelTensors>) output).forEach(outs -> {
for (ModelTensor mlModelTensor : outs.getMlModelTensors()) {
modelTensors.add(mlModelTensor);
}
});
} else {
Object finalOutput = output;
String result = output instanceof String
? (String) output
: AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(finalOutput));
modelTensors.add(ModelTensor.builder().name("response").result(result).build());
}
} else {
Object finalOutput = output;
String result = output instanceof String
? (String) output
: AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(finalOutput));
modelTensors.add(ModelTensor.builder().name("response").result(result).build());
}
listener.onResponse(ModelTensorOutput.builder().mlModelOutputs(outputs).build());
} else {
listener.onResponse(null);
}
}, ex -> {
log.error("Failed to run flow agent", ex);
listener.onFailure(ex);
});
}

public XContentParser createXContentParserFromRegistry(NamedXContentRegistry xContentRegistry, BytesReference bytesReference)
throws IOException {
return XContentHelper.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON);
Expand Down
Loading

0 comments on commit 03f469f

Please sign in to comment.