Commit c0a08df

loveTsong authored and surpercodehang committed
[fel] add default memory (ModelEngine-Group#194)
* [fel] implement memory for recent N histories
* [fel] record LLM results to memory
* [fel] apply RecentMemory as default memory for conversation
* [fel] add null check for incoming message

(cherry picked from commit ae3beb0)
1 parent 4bda5c0 commit c0a08df

File tree

6 files changed: +279 -1 lines changed
New file: RecentMemory.java (package modelengine.fel.core.memory.support)

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
 * This file is a part of the ModelEngine Project.
 * Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

package modelengine.fel.core.memory.support;

import modelengine.fel.core.chat.ChatMessage;
import modelengine.fel.core.memory.Memory;
import modelengine.fel.core.template.BulkStringTemplate;
import modelengine.fel.core.template.support.DefaultBulkStringTemplate;
import modelengine.fitframework.inspection.Validation;
import modelengine.fitframework.util.MapBuilder;

import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.function.Function;
import java.util.stream.Collectors;

import static modelengine.fitframework.inspection.Validation.notNull;

/**
 * A memory implementation that keeps only the most recent N history records.
 *
 * @author 宋永坦
 * @since 2025-07-04
 */
public class RecentMemory implements Memory {
    private final Queue<ChatMessage> records;
    private final BulkStringTemplate bulkTemplate;
    private final Function<ChatMessage, Map<String, String>> extractor;

    /**
     * Constructor that specifies the maximum number of history records to keep.
     *
     * @param maxCount The maximum number of history records to keep, as an {@code int}.
     * @throws IllegalArgumentException If {@code maxCount < 0}.
     */
    public RecentMemory(int maxCount) {
        this(maxCount,
                new DefaultBulkStringTemplate("{{type}}:{{text}}", "\n"),
                message -> MapBuilder.<String, String>get()
                        .put("type", message.type().getRole())
                        .put("text", message.text())
                        .build());
    }

    /**
     * Constructor that specifies the maximum number of history records, the rendering template,
     * and the extraction function.
     *
     * @param maxCount The maximum number of history records to keep, as an {@code int}.
     * @param bulkTemplate The bulk string template, as a {@link BulkStringTemplate}.
     * @param extractor The function that converts a {@link ChatMessage} into a
     * {@link Map}{@code <}{@link String}, {@link String}{@code >}.
     * @throws IllegalArgumentException If {@code maxCount < 0}, or {@code bulkTemplate} or {@code extractor} is {@code null}.
     */
    public RecentMemory(int maxCount, BulkStringTemplate bulkTemplate,
            Function<ChatMessage, Map<String, String>> extractor) {
        Validation.greaterThanOrEquals(maxCount, 0, "The max count should >= 0.");
        this.records = new ArrayBlockingQueue<>(maxCount);
        this.bulkTemplate = notNull(bulkTemplate, "The bulkTemplate cannot be null.");
        this.extractor = notNull(extractor, "The extractor cannot be null.");
    }

    @Override
    public void add(ChatMessage message) {
        notNull(message, "The message cannot be null.");
        if (!this.records.offer(message)) {
            this.records.poll();
            this.records.offer(message);
        }
    }

    @Override
    public void set(List<ChatMessage> messages) {
        notNull(messages, "The messages cannot be null.");
        messages.forEach(this::add);
    }

    @Override
    public void clear() {
        this.records.clear();
    }

    @Override
    public List<ChatMessage> messages() {
        return this.records.stream().toList();
    }

    @Override
    public String text() {
        return this.records.stream()
                .map(this.extractor)
                .collect(Collectors.collectingAndThen(Collectors.toList(), this.bulkTemplate::render));
    }
}
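Not part of the commit, but a minimal usage sketch of the class above may help: it shows the eviction behavior and the default {{type}}:{{text}} rendering. The message classes are the ones used elsewhere in this diff; the exact role strings printed by text() depend on the MessageType implementation.

// Usage sketch (assumption: not part of this commit). Demonstrates eviction and the
// default "{{type}}:{{text}}" rendering of RecentMemory.
import modelengine.fel.core.chat.support.AiMessage;
import modelengine.fel.core.chat.support.HumanMessage;
import modelengine.fel.core.memory.Memory;
import modelengine.fel.core.memory.support.RecentMemory;

public class RecentMemoryExample {
    public static void main(String[] args) {
        Memory memory = new RecentMemory(2);           // keep only the two most recent records
        memory.add(new HumanMessage("Hi"));
        memory.add(new AiMessage("Hello!"));
        memory.add(new AiMessage("How can I help?"));  // queue full: the oldest record ("Hi") is evicted
        // text() renders each record as "<role>:<text>" and joins them with newlines;
        // the concrete role strings come from ChatMessage.type().getRole().
        System.out.println(memory.text());
        memory.clear();                                // drops the remaining records
    }
}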
New file: RecentMemoryTest.java (package modelengine.fel.core.memory.support)

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
 * This file is a part of the ModelEngine Project.
 * Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

package modelengine.fel.core.memory.support;

import modelengine.fel.core.chat.ChatMessage;
import modelengine.fel.core.chat.support.AiMessage;

import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

/**
 * Tests for {@link RecentMemory}.
 *
 * @author 宋永坦
 * @since 2025-07-04
 */
class RecentMemoryTest {
    private final List<ChatMessage> inputChatMessages =
            Arrays.asList(new AiMessage("1"), new AiMessage("2"), new AiMessage("3"));

    @Test
    void shouldKeepAllMessagesWhenAddGivenLessMessage() {
        RecentMemory recentMemory = new RecentMemory(4);
        this.inputChatMessages.forEach(recentMemory::add);
        List<ChatMessage> messages = recentMemory.messages();

        assertEquals(inputChatMessages.size(), messages.size());
        for (int i = 0; i < inputChatMessages.size(); ++i) {
            assertEquals(inputChatMessages.get(i).text(), messages.get(i).text());
        }
    }

    @Test
    void shouldKeepMaxCountMessagesWhenAddGivenOverMaxCountMessages() {
        RecentMemory recentMemory = new RecentMemory(2);
        this.inputChatMessages.forEach(recentMemory::add);
        List<ChatMessage> messages = recentMemory.messages();

        assertEquals(2, messages.size());
        assertEquals(inputChatMessages.get(1).text(), messages.get(0).text());
        assertEquals(inputChatMessages.get(2).text(), messages.get(1).text());
    }

    @Test
    void shouldKeepMaxCountMessagesWhenSetGivenOverMaxCountMessages() {
        RecentMemory recentMemory = new RecentMemory(2);
        recentMemory.set(this.inputChatMessages);
        List<ChatMessage> messages = recentMemory.messages();

        assertEquals(2, messages.size());
        assertEquals(inputChatMessages.get(1).text(), messages.get(0).text());
        assertEquals(inputChatMessages.get(2).text(), messages.get(1).text());
    }
}

framework/fel/java/fel-flow/pom.xml

Lines changed: 5 additions & 0 deletions
@@ -74,5 +74,10 @@
     <artifactId>assertj-core</artifactId>
     <scope>test</scope>
 </dependency>
+<dependency>
+    <groupId>org.mockito</groupId>
+    <artifactId>mockito-core</artifactId>
+    <scope>test</scope>
+</dependency>
 </dependencies>
</project>

framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/flows/Conversation.java

Lines changed: 11 additions & 1 deletion
@@ -9,10 +9,10 @@
 import modelengine.fel.core.chat.ChatMessage;
 import modelengine.fel.core.chat.ChatOption;
 import modelengine.fel.core.memory.Memory;
+import modelengine.fel.core.memory.support.RecentMemory;
 import modelengine.fel.engine.activities.AiStart;
 import modelengine.fel.engine.activities.FlowCallBack;
 import modelengine.fel.engine.operators.models.StreamingConsumer;
-import modelengine.fel.engine.operators.sources.Source;
 import modelengine.fel.engine.util.StateKey;
 import modelengine.fit.waterflow.domain.context.FlowSession;
 import modelengine.fit.waterflow.domain.stream.operators.Operators;
@@ -33,6 +33,8 @@
  * @since 2024-04-28
  */
 public class Conversation<D, R> {
+    private static final int DEFAULT_HISTORY_COUNT = 20;
+
     private final AiProcessFlow<D, R> flow;
     private final FlowSession session;
     private final AtomicReference<ConverseListener<R>> converseListener = new AtomicReference<>(null);
@@ -66,6 +68,7 @@ public Conversation(AiProcessFlow<D, R> flow, FlowSession session) {
     @SafeVarargs
     public final ConverseLatch<R> offer(D... data) {
         ConverseLatch<R> latch = setListener(this.flow);
+        this.initMemory();
         FlowSession newSession = FlowSession.newRootSession(this.session, this.session.preserved());
         newSession.getWindow().setFrom(null);
         this.flow.start().offer(data, newSession);
@@ -85,6 +88,7 @@ public final ConverseLatch<R> offer(D... data) {
     public ConverseLatch<R> offer(String nodeId, List<?> data) {
         Validation.notBlank(nodeId, "invalid nodeId.");
         ConverseLatch<R> latch = setListener(this.flow);
+        this.initMemory();
         FlowSession newSession = new FlowSession(this.session);
         newSession.getWindow().setFrom(null);
         this.flow.origin().offer(nodeId, data.toArray(new Object[0]), newSession);
@@ -231,4 +235,10 @@ private FlowSession setConverseListener(FlowSession session) {
         session.setInnerState(StateKey.CONVERSE_LISTENER, new AtomicReference<>(new ConcurrentHashMap<>()));
         return session;
     }
+
+    private void initMemory() {
+        if (this.session.getInnerState(StateKey.HISTORY) == null) {
+            this.session.setInnerState(StateKey.HISTORY, new RecentMemory(DEFAULT_HISTORY_COUNT));
+        }
+    }
 }
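The net effect: a Conversation now always has a history. initMemory() runs at the start of each offer(...) and, when the session carries no Memory under StateKey.HISTORY, installs a RecentMemory keeping the last DEFAULT_HISTORY_COUNT (20) records. A caller that wants a different policy can bind its own Memory before offering data; the fragment below is a hedged sketch of that idea, using only class and session-key names that appear in this diff (the AiProcessFlow construction itself is omitted).

// Hedged sketch, not part of the commit: supplying a custom history window.
static <D, R> Conversation<D, R> conversationWithHistory(AiProcessFlow<D, R> flow, int window) {
    FlowSession session = new FlowSession();
    // initMemory() only installs the default RecentMemory(20) when StateKey.HISTORY
    // is still unset, so a Memory bound here takes precedence.
    session.setInnerState(StateKey.HISTORY, new RecentMemory(window));
    return new Conversation<>(flow, session);
}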

framework/fel/java/fel-flow/src/main/java/modelengine/fel/engine/operators/models/LlmEmitter.java

Lines changed: 25 additions & 0 deletions
@@ -8,12 +8,15 @@

 import modelengine.fel.core.chat.ChatMessage;
 import modelengine.fel.core.chat.Prompt;
+import modelengine.fel.core.chat.support.HumanMessage;
+import modelengine.fel.core.memory.Memory;
 import modelengine.fel.engine.util.StateKey;
 import modelengine.fit.waterflow.bridge.fitflow.FitBoundedEmitter;
 import modelengine.fit.waterflow.domain.context.FlowSession;
 import modelengine.fitframework.flowable.Publisher;
 import modelengine.fitframework.inspection.Validation;
 import modelengine.fitframework.util.ObjectUtils;
+import modelengine.fitframework.util.StringUtils;

 /**
  * Streaming model emitter.
@@ -26,6 +29,8 @@ public class LlmEmitter<O extends ChatMessage> extends FitBoundedEmitter<O, ChatMessage> {

     private final ChatChunk chunkAcc = new ChatChunk();
     private final StreamingConsumer<ChatMessage, ChatMessage> consumer;
+    private final Memory memory;
+    private final ChatMessage question;

     /**
      * Initializes the {@link LlmEmitter}.
@@ -38,6 +43,9 @@ public LlmEmitter(Publisher<O> publisher, Prompt prompt, FlowSession session) {
         super(publisher, data -> data);
         Validation.notNull(session, "The session cannot be null.");
         this.consumer = ObjectUtils.nullIf(session.getInnerState(StateKey.STREAMING_CONSUMER), EMPTY_CONSUMER);
+        this.memory = session.getInnerState(StateKey.HISTORY);
+        this.question =
+                ObjectUtils.getIfNull(session.getInnerState(StateKey.HISTORY_INPUT), () -> getDefaultQuestion(prompt));
     }

     @Override
@@ -46,4 +54,21 @@ public void emit(ChatMessage data, FlowSession trans) {
         this.chunkAcc.merge(data);
         this.consumer.accept(this.chunkAcc, data);
     }
+
+    @Override
+    public void complete() {
+        if (this.memory != null && this.chunkAcc.toolCalls().isEmpty()) {
+            this.memory.add(this.question);
+            this.memory.add(this.chunkAcc);
+        }
+        super.complete();
+    }
+
+    private static ChatMessage getDefaultQuestion(Prompt prompt) {
+        int size = prompt.messages().size();
+        if (size == 0) {
+            return new HumanMessage(StringUtils.EMPTY);
+        }
+        return prompt.messages().get(size - 1);
+    }
 }
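Two details of the recording logic added here are worth noting: complete() skips recording when the accumulated chunk carries tool calls, so intermediate tool-call turns stay out of the history, and the recorded question comes from StateKey.HISTORY_INPUT when present, falling back to the last prompt message (or an empty HumanMessage for an empty prompt). The lines below are a hedged sketch of pinning the question explicitly; they use only session keys and types that appear in this diff and are not part of the commit.

// Hedged sketch, not part of the commit: record the raw user input as the question
// instead of the (possibly template-rendered) last prompt message.
FlowSession session = new FlowSession();
session.setInnerState(StateKey.HISTORY, new RecentMemory(20));
session.setInnerState(StateKey.HISTORY_INPUT, new HumanMessage("What does FEL stand for?"));
// An LlmEmitter created with this session then adds this HumanMessage and the
// accumulated model answer to the RecentMemory when the stream completes.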
New file: LlmEmitterTest.java (package modelengine.fel.engine.operators.models)

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
 * This file is a part of the ModelEngine Project.
 * Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

package modelengine.fel.engine.operators.models;

import modelengine.fel.core.chat.ChatMessage;
import modelengine.fel.core.chat.Prompt;
import modelengine.fel.core.chat.support.AiMessage;
import modelengine.fel.core.chat.support.ChatMessages;
import modelengine.fel.core.memory.Memory;
import modelengine.fel.core.tool.ToolCall;
import modelengine.fel.engine.util.StateKey;
import modelengine.fit.waterflow.domain.context.FlowSession;
import modelengine.fitframework.flowable.Choir;
import modelengine.fitframework.util.StringUtils;

import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

/**
 * Tests for {@link LlmEmitter}.
 *
 * @author 宋永坦
 * @since 2025-07-05
 */
class LlmEmitterTest {
    @Test
    void shouldAddMemoryWhenCompleteGivenLlmOutput() {
        String output = "data1";
        Prompt prompt = ChatMessages.fromList(Collections.emptyList());
        Choir<ChatMessage> dataSource = Choir.create(emitter -> {
            emitter.emit(new AiMessage(output));
            emitter.complete();
        });
        FlowSession flowSession = new FlowSession();
        Memory mockMemory = Mockito.mock(Memory.class);
        ArgumentCaptor<ChatMessage> captor = ArgumentCaptor.forClass(ChatMessage.class);
        Mockito.doNothing().when(mockMemory).add(captor.capture());
        flowSession.setInnerState(StateKey.HISTORY, mockMemory);

        LlmEmitter<ChatMessage> llmEmitter = new LlmEmitter<>(dataSource, prompt, flowSession);
        llmEmitter.start(flowSession);

        List<ChatMessage> captured = captor.getAllValues();
        assertEquals(2, captured.size());
        assertEquals(StringUtils.EMPTY, captured.get(0).text());
        assertEquals(output, captured.get(1).text());
    }

    @Test
    void shouldNotAddMemoryWhenCompleteGivenLlmToolCallOutput() {
        String output = "data1";
        Prompt prompt = ChatMessages.fromList(Collections.emptyList());
        Choir<ChatMessage> dataSource = Choir.create(emitter -> {
            emitter.emit(new AiMessage(output, Arrays.asList(ToolCall.custom().id("id1").build())));
            emitter.complete();
        });
        FlowSession flowSession = new FlowSession();
        Memory mockMemory = Mockito.mock(Memory.class);
        flowSession.setInnerState(StateKey.HISTORY, mockMemory);

        LlmEmitter<ChatMessage> llmEmitter = new LlmEmitter<>(dataSource, prompt, flowSession);
        llmEmitter.start(flowSession);

        Mockito.verify(mockMemory, Mockito.times(0)).add(Mockito.any());
    }
}

0 commit comments