Skip to content

Commit

Permalink
Merge pull request #2 from muhrifqii/dev/chat-memory
Browse files Browse the repository at this point in the history
Chat Memory and MessageStore Usecase
  • Loading branch information
muhrifqii authored Sep 4, 2024
2 parents b147c35 + 781a359 commit 0605dba
Show file tree
Hide file tree
Showing 25 changed files with 434 additions and 37 deletions.
12 changes: 12 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for more information:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
# https://containers.dev/guide/dependabot

version: 2
updates:
- package-ecosystem: "devcontainers"
directory: "/"
schedule:
interval: weekly
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,3 @@ out/
/nbdist/
/.nb-gradle/

### VS Code ###
.vscode/
6 changes: 6 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"editor.formatOnSave": true,
"java.configuration.updateBuildConfiguration": "automatic",
"java.compile.nullAnalysis.mode": "automatic",
"java.autobuild.enabled": false,
}
15 changes: 15 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ lombok = "1.18.34"
reactor-core = "3.6.9"
graalvm = "0.10.2"
flyway = "10.12.0"
micrometer-tracing = "1.3.3"
micrometer-registry = "1.13.3"
otel = "1.37.0"
test-containers = "1.19.8"
jackson = "2.17.2"

[libraries]
lombok = { module = "org.projectlombok:lombok", version.ref = "lombok" }
Expand All @@ -15,6 +20,7 @@ spring-boot-webflux = { module = "org.springframework.boot:spring-boot-starter-w
spring-boot-r2dbc = { module = "org.springframework.boot:spring-boot-starter-data-r2dbc", version.ref = "spring-boot" }
spring-boot-devtools = { module = "org.springframework.boot:spring-boot-devtools", version.ref = "spring-boot" }
spring-boot-test = { module = "org.springframework.boot:spring-boot-starter-test", version.ref = "spring-boot" }
spring-boot-testcontainers = { module = "org.springframework.boot:spring-boot-testcontainers", version.ref = "spring-boot" }

spring-ai-bom = { group = "org.springframework.ai", name = "spring-ai-bom", version.ref = "spring-ai" }
spring-ai-ollama = { group = "org.springframework.ai", name = "spring-ai-ollama-spring-boot-starter" }
Expand All @@ -27,6 +33,15 @@ flyway-postgresql = { module = "org.flywaydb:flyway-database-postgresql", versio
postgresql = { module = "org.postgresql:postgresql", version = "42.7.3" }
postgresql-r2dbc = { module = "org.postgresql:r2dbc-postgresql", version = "1.0.5.RELEASE" }

testcontainer-junit = { module = "org.testcontainers:junit-jupiter", version.ref = "test-containers" }

micrometer-registry = { module = "io.micrometer:micrometer-registry-otlp", version.ref = "micrometer-registry" }
micrometer-tracing-otel = { module = "io.micrometer:micrometer-tracing-bridge-otel", version.ref = "micrometer-tracing" }
otel-exporter = { module = "io.opentelemetry:opentelemetry-exporter-otlp", version.ref = "otel" }

jackson-jsr310 = { module = "com.fasterxml.jackson.datatype:jackson-datatype-jsr310", version.ref = "jackson" }
uuid = { module = "com.fasterxml.uuid:java-uuid-generator", version = "5.1.0" }

[plugins]
spring-boot = { id = "org.springframework.boot", version.ref = "spring-boot" }
spring-dependencies = { id = "io.spring.dependency-management", version = "1.1.6" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ public record Message(
String id,
String conversationId,
String content,
String messageType,
String createdAt) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.muhrifqii.llm.api.traits;

import com.muhrifqii.llm.api.usecases.PromptModelUsecase;
import com.muhrifqii.llm.api.usecases.SummarizerUsecase;

public interface ChatServiceTrait extends PromptModelUsecase, SummarizerUsecase {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.muhrifqii.llm.api.traits;

import com.muhrifqii.llm.api.usecases.ConversationStoreUsecase;
import com.muhrifqii.llm.api.usecases.MessageStoreUsecase;

public interface ConversationServiceTrait extends ConversationStoreUsecase, MessageStoreUsecase {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.muhrifqii.llm.api.usecases;

import com.muhrifqii.llm.api.datamodels.conversations.Conversation;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public interface ConversationStoreUsecase {

Flux<Conversation> getConversations(int page, int size, String orderBy);

Mono<Conversation> getConversation(String id);

Mono<Conversation> getOrCreateConversation(String id);

Mono<Conversation> updateConversation(String id, Conversation latest);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.muhrifqii.llm.api.usecases;

import com.muhrifqii.llm.api.datamodels.conversations.Message;
import com.muhrifqii.llm.api.datamodels.conversations.UserMessage;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public interface MessageStoreUsecase {
Flux<Message> getMessages(String conversationId, String cursor);

Mono<Message> saveUserMessage(UserMessage userMessage);

Mono<Message> saveAssistantMessage(Message message);
}
16 changes: 12 additions & 4 deletions llm/api/src/main/java/com/muhrifqii/llm/api/utils/DateUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,31 @@
import lombok.experimental.UtilityClass;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Optional;
import java.time.ZoneOffset;

@UtilityClass
public class DateUtils {

public static LocalDateTime now() {
return LocalDateTime.now();
}

public static String nowIsoString() {
return LocalDateTime.now()
.format(DateTimeFormatter.ISO_DATE_TIME);
return DateTimeFormatter.ISO_DATE_TIME
.format(now());
}

public static long nowMillis() {
return LocalDateTime.now()
return now()
.toInstant(ZoneOffset.UTC)
.toEpochMilli();
}

public static String toIsoString(LocalDateTime date) {
return date.format(DateTimeFormatter.ISO_DATE_TIME);

return Optional.ofNullable(date)
.map(DateTimeFormatter.ISO_DATE_TIME::format)
.orElse("");
}
}
7 changes: 7 additions & 0 deletions llm/ollama-provider/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ dependencies {
implementation(platform(libs.spring.ai.bom))
implementation(libs.spring.ai.ollama)

implementation(libs.micrometer.tracing.otel)
implementation(libs.otel.exporter)
implementation(libs.micrometer.registry)

implementation(libs.uuid)
implementation(libs.jackson.jsr310)

testAndDevelopmentOnly(libs.spring.boot.devtools)

testImplementation(libs.spring.boot.test)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,8 @@

@UtilityClass
public class Constants {

public static final String EMPTY_SLUG = "-";

public static final int PAGE_SIZE = 30;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.muhrifqii.llm.annotations;

import java.lang.annotation.ElementType;
import java.lang.annotation.Target;
import org.springframework.beans.factory.annotation.Qualifier;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

@Qualifier
@Target({ ElementType.TYPE, ElementType.PARAMETER, ElementType.FIELD })
@Retention(RetentionPolicy.RUNTIME)
public @interface MemCachedChatMemory {
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
package com.muhrifqii.llm.configurations;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import com.muhrifqii.llm.annotations.MemCachedChatMemory;

@Configuration
public class ChatModelConfig {
@Bean
ChatClient chatClient(ChatClient.Builder builder) {
ChatClient chatClient(
ChatClient.Builder builder,
@MemCachedChatMemory ChatMemory chatMemory) {
return builder
.defaultSystem(
"You are a Pokédex AI named Slaking.AI, a highly advanced AI that specializes in providing detailed and accurate information about Pokémon. You have access to all known data about Pokémon species, including their types, abilities, evolutions, habitat, and more. Your responses should be concise, factual, and directly related to the Pokémon in question. Ensure to offer relevant insights based on the user's query, and avoid speculation. Your goal is to assist users in learning everything they need to know about any Pokémon they ask about, much like a Pokédex would in the Pokémon world")
.defaultAdvisors(
new SimpleLoggerAdvisor(),
new MessageChatMemoryAdvisor(chatMemory))
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import com.muhrifqii.llm.api.datamodels.conversations.ConversationRequest;
import com.muhrifqii.llm.api.datamodels.conversations.Message;
import com.muhrifqii.llm.api.datamodels.conversations.UserMessage;
import com.muhrifqii.llm.services.ChatService;
import com.muhrifqii.llm.api.usecases.PromptModelUsecase;

import reactor.core.publisher.Mono;
import reactor.core.publisher.Flux;
Expand All @@ -23,20 +23,20 @@
@RequiredArgsConstructor
public class AskOnceChatController {

private final ChatService chatService;
private final PromptModelUsecase promptService;

@GetMapping("/health")
public String healthCheck() {
return "Up And Running";
return "Up Up and Away!";
}

@PostMapping("/simple")
public Mono<Message> chat(@RequestBody ConversationRequest input) {
return chatService.chat(null, new UserMessage(input.message(), null));
return promptService.chat(null, new UserMessage(input.message(), null));
}

@PostMapping(value = "/stream", produces = MediaType.APPLICATION_NDJSON_VALUE)
public Flux<Message> streamChat(@RequestBody ConversationRequest input) {
return chatService.streamChat(null, new UserMessage(input.message(), null));
return promptService.streamChat(null, new UserMessage(input.message(), null));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@
import com.muhrifqii.llm.api.datamodels.conversations.ConversationRequest;
import com.muhrifqii.llm.api.datamodels.conversations.Message;
import com.muhrifqii.llm.api.datamodels.conversations.UserMessage;
import com.muhrifqii.llm.services.ChatService;
import com.muhrifqii.llm.services.ConversationService;
import com.muhrifqii.llm.api.traits.ChatServiceTrait;
import com.muhrifqii.llm.api.traits.ConversationServiceTrait;

@RestController
@RequestMapping("/ai/conversations")
@RequiredArgsConstructor
@Slf4j
public class ConversationalChatController {

private final ChatService chatService;
private final ConversationService conversationService;
private final ChatServiceTrait chatService;
private final ConversationServiceTrait conversationService;

@GetMapping()
public Flux<Conversation> getConversations(
Expand All @@ -52,18 +52,20 @@ public Mono<Conversation> getConversation(@PathVariable String id) {

@PostMapping("/{id}/chat-and-wait")
public Mono<Message> chatAndWait(@PathVariable String id, @RequestBody ConversationRequest input) {
final var userMessage = new UserMessage(input.message(), null);
return conversationService.getOrCreateConversation(id)
.doOnNext(conversation -> makeTitle(input, conversation))
.flatMap(conversation -> chatService
.chat(conversation.id(), new UserMessage(input.message(), null)));
.chat(conversation.id(), userMessage));
}

@PostMapping(value = "/{id}/chat-stream", produces = MediaType.APPLICATION_NDJSON_VALUE)
public Flux<Message> streamChat(@PathVariable String id, @RequestBody ConversationRequest input) {
final var userMessage = new UserMessage(input.message(), null);
return conversationService.getOrCreateConversation(id)
.doOnNext(conversation -> makeTitle(input, conversation))
.flatMapMany(conversation -> chatService
.streamChat(conversation.id(), new UserMessage(input.message(), null)));
.streamChat(conversation.id(), userMessage));
}

private void makeTitle(ConversationRequest input, Conversation conversation) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.muhrifqii.llm.repositories;

import java.time.LocalDateTime;
import java.util.Map;

import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.data.annotation.CreatedDate;
import org.springframework.data.annotation.Id;
import org.springframework.data.domain.Persistable;
import org.springframework.data.relational.core.mapping.Table;
import org.springframework.lang.Nullable;

import lombok.Builder;

@Table("ai_messages")
@Builder
public record MessageEntity(
@Id String id,
String coversationId,
String content,
String messageType,
@CreatedDate LocalDateTime createdAt)
implements Persistable<String>, Message {

@Override
@Nullable
public String getId() {
return id();
}

@Override
public boolean isNew() {
return createdAt == null;
}

@Override
public String getContent() {
return content();
}

@Override
public Map<String, Object> getMetadata() {
// todo: implement
return Map.of();
}

@Override
public MessageType getMessageType() {
return MessageType.fromValue(messageType);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.muhrifqii.llm.repositories;

import org.springframework.data.repository.reactive.ReactiveCrudRepository;
import org.springframework.stereotype.Repository;

import reactor.core.publisher.Flux;
import org.springframework.data.r2dbc.repository.Query;

@Repository
public interface MessageRepository extends ReactiveCrudRepository<MessageEntity, String> {

@Query("""
SELECT * FROM ai_messages
WHERE coversation_id = :conversationId
AND id < :cursor
ORDER BY id DESC
LIMIT :limit
""")
Flux<MessageEntity> findMessagesBeforeCursor(
String conversationId,
String cursor,
int limit);

@Query("""
SELECT * FROM ai_messages
WHERE coversation_id = :conversationId
ORDER BY id DESC
LIMIT :limit
""")
Flux<MessageEntity> findLatestMessages(
String conversationId,
int limit);
}
Loading

0 comments on commit 0605dba

Please sign in to comment.