Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix:1559 修复自定义AI不能使用的问题。 #1599

Merged
merged 1 commit into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ public SseEmitter distributeAISql(ChatQueryRequest queryRequest, SseEmitter sseE
case CHAT2DBAI:
return chatWithChat2dbAi(queryRequest, sseEmitter, uid);
case RESTAI :
return chatWithRestAi(queryRequest, sseEmitter);
return chatWithRestAi(queryRequest, sseEmitter, uid);
case FASTCHATAI:
return chatWithFastChatAi(queryRequest, sseEmitter, uid);
case AZUREAI :
Expand All @@ -261,9 +261,15 @@ public SseEmitter distributeAISql(ChatQueryRequest queryRequest, SseEmitter sseE
* @param sseEmitter
* @return
*/
private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter) {
RestAIEventSourceListener eventSourceListener = new RestAIEventSourceListener(sseEmitter);
RestAIClient.getInstance().restCompletions(buildPrompt(prompt), eventSourceListener);
private SseEmitter chatWithRestAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
String prompt = buildPrompt(queryRequest);
List<FastChatMessage> messages = getFastChatMessage(uid, prompt);

buildSseEmitter(sseEmitter, uid);

RestAIEventSourceListener restAIEventSourceListener = new RestAIEventSourceListener(sseEmitter);
RestAIClient.getInstance().streamCompletions(messages, restAIEventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ai.chat2db.server.web.api.util.ApplicationContextUtil;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

/**
* @author moji
Expand All @@ -19,6 +20,11 @@ public class RestAIClient {
*/
public static final String AI_SQL_SOURCE = "ai.sql.source";

/**
* Customized AI interface KEY
*/
public static final String REST_AI_API_KEY = "rest.ai.apiKey";

/**
* Customized AI interface address
*/
Expand All @@ -29,17 +35,24 @@ public class RestAIClient {
*/
public static final String REST_AI_STREAM_OUT = "rest.ai.stream";

private static RestAiStreamClient REST_AI_STREAM_CLIENT;
/**
* Custom AI interface model
*/
public static final String REST_AI_MODEL = "rest.ai.model";

public static RestAiStreamClient getInstance() {


private static RestAIStreamClient REST_AI_STREAM_CLIENT;

public static RestAIStreamClient getInstance() {
if (REST_AI_STREAM_CLIENT != null) {
return REST_AI_STREAM_CLIENT;
} else {
return singleton();
}
}

private static RestAiStreamClient singleton() {
private static RestAIStreamClient singleton() {
if (REST_AI_STREAM_CLIENT == null) {
synchronized (RestAIClient.class) {
if (REST_AI_STREAM_CLIENT == null) {
Expand All @@ -55,17 +68,23 @@ private static RestAiStreamClient singleton() {
*/
public static void refresh() {
String apiUrl = "";
Boolean stream = Boolean.TRUE;
String apiKey = "";
String model = "";
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config apiHostConfig = configService.find(REST_AI_URL).getData();
if (apiHostConfig != null) {
apiUrl = apiHostConfig.getContent();
}
Config config = configService.find(REST_AI_STREAM_OUT).getData();
Config config = configService.find(REST_AI_API_KEY).getData();
if (config != null) {
stream = Boolean.valueOf(config.getContent());
apiKey = config.getContent();
}
Config deployConfig = configService.find(REST_AI_MODEL).getData();
if (deployConfig != null && StringUtils.isNotBlank(deployConfig.getContent())) {
model = deployConfig.getContent();
}
REST_AI_STREAM_CLIENT = new RestAiStreamClient(apiUrl, stream);
REST_AI_STREAM_CLIENT = RestAIStreamClient.builder().apiKey(apiKey).apiHost(apiUrl).model(model)
.build();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package ai.chat2db.server.web.api.controller.ai.rest.client;

import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.controller.ai.fastchat.interceptor.FastChatHeaderAuthorizationInterceptor;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsOptions;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.apache.commons.collections4.CollectionUtils;
import org.jetbrains.annotations.NotNull;

import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

/**
* Custom AI interface client
* @author moji
*/
@Slf4j
public class RestAIStreamClient {
/**
* apikey
*/
@Getter
@NotNull
private String apiKey;

/**
* apiHost
*/
@Getter
@NotNull
private String apiHost;

/**
* model
*/
@Getter
private String model;
/**
* okHttpClient
*/
@Getter
private OkHttpClient okHttpClient;

/**
* Construct instance object
*
* @param builder
*/
public RestAIStreamClient(Builder builder) {
this.apiKey = builder.apiKey;
this.apiHost = builder.apiHost;
this.model = builder.model;
this.okHttpClient = new OkHttpClient
.Builder()
.addInterceptor(new FastChatHeaderAuthorizationInterceptor(this.apiKey))
.connectTimeout(10, TimeUnit.SECONDS)
.writeTimeout(50, TimeUnit.SECONDS)
.readTimeout(50, TimeUnit.SECONDS)
.build();
}

/**
* structure
*
* @return
*/
public static RestAIStreamClient.Builder builder() {
return new RestAIStreamClient.Builder();
}

/**
* builder
*/
public static final class Builder {
private String apiKey;

private String apiHost;

private String model;


/**
* OkhttpClient
*/
private OkHttpClient okHttpClient;

public Builder() {
}

public RestAIStreamClient.Builder apiKey(String apiKeyValue) {
this.apiKey = apiKeyValue;
return this;
}

/**
* @param apiHostValue
* @return
*/
public RestAIStreamClient.Builder apiHost(String apiHostValue) {
this.apiHost = apiHostValue;
return this;
}

/**
* @param modelValue
* @return
*/
public RestAIStreamClient.Builder model(String modelValue) {
this.model = modelValue;
return this;
}


public RestAIStreamClient.Builder okHttpClient(OkHttpClient val) {
this.okHttpClient = val;
return this;
}

public RestAIStreamClient build() {
return new RestAIStreamClient(this);
}

}


/**
* Q&A interface stream form
*
* @param chatMessages
* @param eventSourceListener
*/
public void streamCompletions(List<FastChatMessage> chatMessages, EventSourceListener eventSourceListener) {
if (CollectionUtils.isEmpty(chatMessages)) {
log.error("param error:Rest AI Prompt cannot be empty");
throw new ParamBusinessException("prompt");
}
if (Objects.isNull(eventSourceListener)) {
log.error("param error:RestAIEventSourceListener cannot be empty");
throw new ParamBusinessException();
}
log.info("Rest AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent());
try {

FastChatCompletionsOptions chatCompletionsOptions = new FastChatCompletionsOptions(chatMessages);
chatCompletionsOptions.setStream(true);
chatCompletionsOptions.setModel(this.model);

EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
String requestBody = mapper.writeValueAsString(chatCompletionsOptions);
Request request = new Request.Builder()
.url(apiHost)
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
.build();
//Create event
EventSource eventSource = factory.newEventSource(request, eventSourceListener);
log.info("finish invoking rest ai");
} catch (Exception e) {
log.error("rest ai error", e);
eventSourceListener.onFailure(null, e, null);
throw new ParamBusinessException();
}
}


}
Loading