1616
1717package io .serverlessworkflow .impl .executors .ai ;
1818
19- import dev .langchain4j .data .message .AiMessage ;
20- import dev .langchain4j .data .message .ChatMessage ;
21- import dev .langchain4j .data .message .SystemMessage ;
22- import dev .langchain4j .data .message .UserMessage ;
23- import dev .langchain4j .model .chat .ChatModel ;
24- import dev .langchain4j .model .chat .response .ChatResponse ;
25- import dev .langchain4j .model .output .FinishReason ;
26- import dev .langchain4j .model .output .TokenUsage ;
2719import io .serverlessworkflow .ai .api .types .CallAILangChainChatModel ;
2820import io .serverlessworkflow .api .types .TaskBase ;
2921import io .serverlessworkflow .api .types .ai .CallAIChatModel ;
3426import io .serverlessworkflow .impl .WorkflowModelFactory ;
3527import io .serverlessworkflow .impl .executors .CallableTask ;
3628import io .serverlessworkflow .impl .resources .ResourceLoader ;
37- import io .serverlessworkflow .impl .services .ChatModelService ;
38- import java .util .ArrayList ;
39- import java .util .HashSet ;
40- import java .util .List ;
41- import java .util .Map ;
42- import java .util .ServiceLoader ;
43- import java .util .Set ;
4429import java .util .concurrent .CompletableFuture ;
45- import java .util .regex .Matcher ;
46- import java .util .regex .Pattern ;
4730
4831public class AIChatModelCallExecutor implements CallableTask <CallAIChatModel > {
4932
50- private static final Pattern VARIABLE_PATTERN = Pattern .compile ("\\ {\\ {\\ s*(.+?)\\ s*\\ }\\ }" );
51-
5233 @ Override
5334 public void init (CallAIChatModel task , WorkflowApplication application , ResourceLoader loader ) {}
5435
@@ -58,12 +39,13 @@ public CompletableFuture<WorkflowModel> apply(
5839 WorkflowModelFactory modelFactory = workflowContext .definition ().application ().modelFactory ();
5940 if (taskContext .task () instanceof CallAILangChainChatModel callAILangChainChatModel ) {
6041 return CompletableFuture .completedFuture (
61- modelFactory .fromAny (doCall ( callAILangChainChatModel , input . asJavaObject ())));
62- }
63-
64- if (taskContext .task () instanceof CallAIChatModel callAIChatModel ) {
42+ modelFactory .fromAny (
43+ new CallAILangChainChatModelExecutor ()
44+ . apply ( callAILangChainChatModel , input . asJavaObject ())));
45+ } else if (taskContext .task () instanceof CallAIChatModel callAIChatModel ) {
6546 return CompletableFuture .completedFuture (
66- modelFactory .fromAny (doCall (callAIChatModel , input .asJavaObject ())));
47+ modelFactory .fromAny (
48+ new CallAIChatModelExecutor ().apply (callAIChatModel , input .asJavaObject ())));
6749 }
6850 throw new IllegalArgumentException (
6951 "AIChatModelCallExecutor can only process CallAIChatModel tasks, but received: "
@@ -74,112 +56,4 @@ public CompletableFuture<WorkflowModel> apply(
7456 public boolean accept (Class <? extends TaskBase > clazz ) {
7557 return CallAIChatModel .class .isAssignableFrom (clazz );
7658 }
77-
78- private Object doCall (CallAILangChainChatModel callAIChatModel , Object javaObject ) {
79- ChatModel chatModel = callAIChatModel .getChatModel ();
80- Class <?> chatModelRequest = callAIChatModel .getChatModelRequest ();
81- }
82-
83- private Object doCall (CallAIChatModel callAIChatModel , Object javaObject ) {
84- validate (callAIChatModel , javaObject );
85- ChatModel chatModel = createChatModel (callAIChatModel );
86- Map <String , Object > substitutions = (Map <String , Object >) javaObject ;
87-
88- List <ChatMessage > messages = new ArrayList <>();
89-
90- if (callAIChatModel .getChatModelRequest ().getSystemMessages () != null ) {
91- for (String systemMessage : callAIChatModel .getChatModelRequest ().getSystemMessages ()) {
92- String fixedUserMessage = replaceVariables (systemMessage , substitutions );
93- messages .add (new SystemMessage (fixedUserMessage ));
94- }
95- }
96-
97- if (callAIChatModel .getChatModelRequest ().getUserMessages () != null ) {
98- for (String userMessage : callAIChatModel .getChatModelRequest ().getUserMessages ()) {
99- String fixedUserMessage = replaceVariables (userMessage , substitutions );
100- messages .add (new UserMessage (fixedUserMessage ));
101- }
102- }
103-
104- return prepareResponse (chatModel .chat (messages ), javaObject );
105- }
106-
107- private String replaceVariables (String template , Map <String , Object > substitutions ) {
108- Set <String > variables = extractVariables (template );
109- for (Map .Entry <String , Object > entry : substitutions .entrySet ()) {
110- String variable = entry .getKey ();
111- Object value = entry .getValue ();
112- if (value != null && variables .contains (variable )) {
113- template = template .replace ("{{" + variable + "}}" , value .toString ());
114- }
115- }
116- return template ;
117- }
118-
119- private void validate (CallAIChatModel callAIChatModel , Object javaObject ) {
120- // TODO
121- }
122-
123- private ChatModel createChatModel (CallAIChatModel callAIChatModel ) {
124- ChatModelService chatModelService = getAvailableModel ();
125- if (chatModelService != null ) {
126- return chatModelService .getChatModel (callAIChatModel .getPreferences ());
127- }
128- throw new IllegalStateException (
129- "No LLM models found. Please ensure that you have the required dependencies in your classpath." );
130- }
131-
132- private ChatModelService getAvailableModel () {
133- ServiceLoader <ChatModelService > loader = ServiceLoader .load (ChatModelService .class );
134-
135- for (ChatModelService service : loader ) {
136- return service ;
137- }
138-
139- throw new IllegalStateException (
140- "No LLM models found. Please ensure that you have the required dependencies in your classpath." );
141- }
142-
143- private Map <String , Object > prepareResponse (ChatResponse response , Object javaObject ) {
144-
145- String id = response .id ();
146- String modelName = response .modelName ();
147- TokenUsage tokenUsage = response .tokenUsage ();
148- FinishReason finishReason = response .finishReason ();
149- AiMessage aiMessage = response .aiMessage ();
150-
151- Map <String , Object > responseMap = (Map <String , Object >) javaObject ;
152- if (response .id () != null ) {
153- responseMap .put ("id" , id );
154- }
155-
156- if (modelName != null ) {
157- responseMap .put ("modelName" , modelName );
158- }
159-
160- if (tokenUsage != null ) {
161- responseMap .put ("tokenUsage.inputTokenCount" , tokenUsage .inputTokenCount ());
162- responseMap .put ("tokenUsage.outputTokenCount" , tokenUsage .outputTokenCount ());
163- responseMap .put ("tokenUsage.totalTokenCount" , tokenUsage .totalTokenCount ());
164- }
165-
166- if (finishReason != null ) {
167- responseMap .put ("finishReason" , finishReason .name ());
168- }
169-
170- if (aiMessage != null ) {
171- responseMap .put ("text" , aiMessage .text ());
172- }
173-
174- return responseMap ;
175- }
176-
177- private static Set <String > extractVariables (String template ) {
178- Set <String > variables = new HashSet <>();
179- Matcher matcher = VARIABLE_PATTERN .matcher (template );
180- while (matcher .find ()) {
181- variables .add (matcher .group (1 ));
182- }
183- return variables ;
184- }
18559}
0 commit comments