Skip to content

[WIP] Add CRaC Support #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions pom.xml
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.1.3</version>
<version>3.2.0</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>com.apolloconfig.apollo.ai</groupId>
@@ -16,10 +16,13 @@
<description>a smart qa bot</description>
<properties>
<java.version>17</java.version>
<openai-gpt3-java.version>0.16.0</openai-gpt3-java.version>
<guava.version>32.1.2-jre</guava.version>
<openai-gpt3-java.version>0.18.2</openai-gpt3-java.version>
<guava.version>32.1.3-jre</guava.version>
<flexmark.version>0.64.8</flexmark.version>
<milvus.version>2.3.0</milvus.version>
<milvus.version>2.3.3</milvus.version>
<!-- There is a bug in 4.1.101.Final -->
<netty.codec.http2.version>4.1.100.Final</netty.codec.http2.version>
<crac.version>1.4.0</crac.version>
</properties>

<dependencyManagement>
@@ -44,6 +47,16 @@
<artifactId>milvus-sdk-java</artifactId>
<version>${milvus.version}</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-codec-http2</artifactId>
<version>${netty.codec.http2.version}</version>
</dependency>
<dependency>
<groupId>org.crac</groupId>
<artifactId>crac</artifactId>
<version>${crac.version}</version>
</dependency>
</dependencies>
</dependencyManagement>

@@ -82,6 +95,10 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.crac</groupId>
<artifactId>crac</artifactId>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
Original file line number Diff line number Diff line change
@@ -4,8 +4,10 @@
import io.milvus.client.MilvusServiceClient;
import io.milvus.param.ConnectParam;
import java.util.Map;
import org.crac.Context;
import org.crac.Resource;

public class MilvusClientFactory {
public class MilvusClientFactory implements Resource {

private static final MilvusClientFactory INSTANCE = new MilvusClientFactory();
private static final Map<String, MilvusServiceClient> clients = Maps.newConcurrentMap();
@@ -43,4 +45,14 @@ private MilvusServiceClient createClient(String host, int port) {
private MilvusServiceClient createCloudClient(String uri, String token) {
return new MilvusServiceClient(ConnectParam.newBuilder().withUri(uri).withToken(token).build());
}

@Override
public void beforeCheckpoint(Context<? extends Resource> context) throws Exception {
clients.clear();
}

@Override
public void afterRestore(Context<? extends Resource> context) throws Exception {

}
}
Original file line number Diff line number Diff line change
@@ -34,20 +34,30 @@
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import org.crac.Context;
import org.crac.Resource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;

@Profile("milvus")
@Service
class MilvusService implements VectorDBService {
class MilvusService implements VectorDBService, Resource {

private final MilvusServiceClient milvusServiceClient;
private static final Logger LOGGER = LoggerFactory.getLogger(MilvusService.class);

private MilvusServiceClient milvusServiceClient;
private final MilvusConfig milvusConfig;
private final List<Float> dummyEmbeddings = Lists.newArrayList();

public MilvusService(MilvusConfig milvusConfig) {
this.milvusConfig = milvusConfig;
this.init();
}

private void init() {
if (milvusConfig.isUseZillzCloud()) {
this.milvusServiceClient = MilvusClientFactory.getCloudClient(
milvusConfig.getZillizCloudUri(),
@@ -413,4 +423,16 @@ private void ensureFileCollection() {
);
}

@Override
public void beforeCheckpoint(Context<? extends Resource> context) throws Exception {
LOGGER.info("beforeCheckpoint");
this.milvusServiceClient = null;
}

@Override
public void afterRestore(Context<? extends Resource> context) throws Exception {
LOGGER.info("afterRestore");
this.init();
LOGGER.info("afterRestore done");
}
}
Original file line number Diff line number Diff line change
@@ -10,19 +10,25 @@
import com.theokanning.openai.embedding.EmbeddingRequest;
import io.reactivex.Flowable;
import java.util.List;
import org.crac.Context;
import org.crac.Resource;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;

@Profile("openai")
@Component
class OpenAiService implements AiService {
class OpenAiService implements AiService, Resource {

private static final String DEFAULT_MODEL = "gpt-3.5-turbo";
private static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002";

private final com.theokanning.openai.service.OpenAiService service;
private com.theokanning.openai.service.OpenAiService service;

public OpenAiService() {
init();
}

private void init() {
service = OpenAiServiceFactory.getService(System.getenv("OPENAI_API_KEY"));
}

@@ -60,4 +66,14 @@ public List<Embedding> getEmbeddings(List<String> chunks) {

return service.createEmbeddings(embeddingRequest).getData();
}

@Override
public void beforeCheckpoint(Context<? extends Resource> context) throws Exception {
this.service = null;
}

@Override
public void afterRestore(Context<? extends Resource> context) throws Exception {
this.init();
}
}
Original file line number Diff line number Diff line change
@@ -23,9 +23,11 @@
import okhttp3.Authenticator;
import okhttp3.Credentials;
import okhttp3.OkHttpClient;
import org.crac.Context;
import org.crac.Resource;
import retrofit2.Retrofit;

public class OpenAiServiceFactory {
public class OpenAiServiceFactory implements Resource {

private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60);

@@ -85,6 +87,16 @@ private OkHttpClient client(String apiKey) {
.build();
}

@Override
public void beforeCheckpoint(Context<? extends Resource> context) throws Exception {
SERVICES.clear();
}

@Override
public void afterRestore(Context<? extends Resource> context) throws Exception {

}

private static class DelegatingSocketFactory extends SocketFactory {

private final SocketFactory delegate;
14 changes: 14 additions & 0 deletions src/main/scripts/startup.sh
Original file line number Diff line number Diff line change
@@ -19,13 +19,27 @@ SERVICE_NAME=qa-bot
LOG_DIR=/opt/logs
## Adjust server port if necessary
SERVER_PORT=${SERVER_PORT:=9090}
## Adjust crac files dir if necessary
CRAC_FILES_DIR=/opt/crac

## Create log directory if not existed because JDK 8+ won't do that
mkdir -p $LOG_DIR

mkdir -p $CRAC_FILES_DIR

## Adjust memory settings if necessary
#export JAVA_OPTS="-Xms2560m -Xmx2560m -Xss256k -XX:MetaspaceSize=128m -XX:MaxMetaspaceSize=384m -XX:NewSize=1536m -XX:MaxNewSize=1536m -XX:SurvivorRatio=8"

# Check for 'checkpoint' argument
if [ "$1" = "checkpoint" ]; then
export JAVA_OPTS="$JAVA_OPTS -Dspring.context.checkpoint=onRefresh -XX:CRaCCheckpointTo=$CRAC_FILES_DIR"
fi

# Check for 'restore' argument
if [ "$1" = "restore" ]; then
export JAVA_OPTS="$JAVA_OPTS -XX:CRaCRestoreFrom=$CRAC_FILES_DIR"
fi

## Only uncomment the following when you are using server jvm
export JAVA_OPTS="$JAVA_OPTS -server -XX:-ReduceInitialCardMarks"