Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
* This file is a part of the ModelEngine Project.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

package modelengine.fel.core.memory.support;

import modelengine.fel.core.chat.ChatMessage;
import modelengine.fel.core.memory.Memory;
import modelengine.fel.core.template.BulkStringTemplate;
import modelengine.fel.core.template.support.DefaultBulkStringTemplate;
import modelengine.fitframework.inspection.Validation;
import modelengine.fitframework.util.MapBuilder;

import java.util.List;
import java.util.Map;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Function;
import java.util.stream.Collectors;

import static modelengine.fitframework.inspection.Validation.notNull;

/**
* 表示使用最近一定次数历史记录的实现。
*
* @author 宋永坦
* @since 2025-07-04
*/
public class RecentMemory implements Memory {
private final LinkedBlockingQueue<ChatMessage> records;
private final BulkStringTemplate bulkTemplate;
private final Function<ChatMessage, Map<String, String>> extractor;

/**
* 指定最大保留历史记录数量的构造方法。
*
* @param maxCount 表示最大保留历史记录数量的 {@code int}。
* @throws IllegalArgumentException 当 {@code maxCount < 0} 时。
*/
public RecentMemory(int maxCount) {
this(maxCount,
new DefaultBulkStringTemplate("{{type}}:{{text}}", "\n"),
message -> MapBuilder.<String, String>get()
.put("type", message.type().getRole())
.put("text", message.text())
.build());
}

/**
* 指定最大保留历史记录数量、渲染模板、抽取方法的构造方法。
*
* @param maxCount 表示最大保留历史记录数量的 {@code int}。
* @param bulkTemplate 表示批量字符串模板的 {@link BulkStringTemplate}。
* @param extractor 表示将 {@link ChatMessage} 转换成
* {@link Map}{@code <}{@link String}, {@link String}{@code >} 的处理函数。
* @throws IllegalArgumentException 当 {@code maxCount < 0}、{@code bulkTemplate}、{@code extractor} 为 {@code null} 时。
*/
public RecentMemory(int maxCount, BulkStringTemplate bulkTemplate,
Function<ChatMessage, Map<String, String>> extractor) {
Validation.greaterThanOrEquals(maxCount, 0, "The max count should >= 0.");
this.records = new LinkedBlockingQueue<>(maxCount);
this.bulkTemplate = notNull(bulkTemplate, "The bulkTemplate cannot be null.");
this.extractor = notNull(extractor, "The extractor cannot be null.");
}

@Override
public void add(ChatMessage message) {
if (!this.records.offer(message)) {
this.records.poll();
this.records.offer(message);
}
}

@Override
public void set(List<ChatMessage> messages) {
messages.forEach(this::add);
}

@Override
public void clear() {
this.records.clear();
}

@Override
public List<ChatMessage> messages() {
return this.records.stream().toList();
}

@Override
public String text() {
return this.records.stream()
.map(this.extractor)
.collect(Collectors.collectingAndThen(Collectors.toList(), this.bulkTemplate::render));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
* This file is a part of the ModelEngine Project.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

package modelengine.fel.core.memory.support;

import modelengine.fel.core.chat.ChatMessage;
import modelengine.fel.core.chat.support.AiMessage;

import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

/**
* 表示 {@link RecentMemory} 的测试。
*
* @author 宋永坦
* @since 2025-07-04
*/
class RecentMemoryTest {
private final List<ChatMessage> inputChatMessages =
Arrays.asList(new AiMessage("1"), new AiMessage("2"), new AiMessage("3"));

@Test
void shouldKeepAllMessagesWhenAddGivenLessMessage() {
RecentMemory recentMemory = new RecentMemory(4);
this.inputChatMessages.forEach(recentMemory::add);
List<ChatMessage> messages = recentMemory.messages();

assertEquals(inputChatMessages.size(), messages.size());
for (int i = 0; i < inputChatMessages.size(); ++i) {
assertEquals(inputChatMessages.get(i).text(), messages.get(i).text());
}
}

@Test
void shouldKeepMaxCountMessagesWhenAddGivenOverMaxCountMessages() {
RecentMemory recentMemory = new RecentMemory(2);
this.inputChatMessages.forEach(recentMemory::add);
List<ChatMessage> messages = recentMemory.messages();

assertEquals(2, messages.size());
assertEquals(inputChatMessages.get(1).text(), messages.get(0).text());
assertEquals(inputChatMessages.get(2).text(), messages.get(1).text());
}

@Test
void shouldKeepMaxCountMessagesWhenSetGivenOverMaxCountMessages() {
RecentMemory recentMemory = new RecentMemory(2);
recentMemory.set(this.inputChatMessages);
List<ChatMessage> messages = recentMemory.messages();

assertEquals(2, messages.size());
assertEquals(inputChatMessages.get(1).text(), messages.get(0).text());
assertEquals(inputChatMessages.get(2).text(), messages.get(1).text());
}
}
5 changes: 5 additions & 0 deletions framework/fel/java/fel-flow/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,10 @@
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import modelengine.fel.core.chat.ChatMessage;
import modelengine.fel.core.chat.ChatOption;
import modelengine.fel.core.memory.Memory;
import modelengine.fel.core.memory.support.RecentMemory;
import modelengine.fel.engine.activities.AiStart;
import modelengine.fel.engine.activities.FlowCallBack;
import modelengine.fel.engine.operators.models.StreamingConsumer;
import modelengine.fel.engine.operators.sources.Source;
import modelengine.fel.engine.util.StateKey;
import modelengine.fit.waterflow.domain.context.FlowSession;
import modelengine.fit.waterflow.domain.stream.operators.Operators;
Expand All @@ -33,6 +33,8 @@
* @since 2024-04-28
*/
public class Conversation<D, R> {
private static final int DEFAULT_HISTORY_COUNT = 20;

private final AiProcessFlow<D, R> flow;
private final FlowSession session;
private final AtomicReference<ConverseListener<R>> converseListener = new AtomicReference<>(null);
Expand Down Expand Up @@ -66,6 +68,7 @@ public Conversation(AiProcessFlow<D, R> flow, FlowSession session) {
@SafeVarargs
public final ConverseLatch<R> offer(D... data) {
ConverseLatch<R> latch = setListener(this.flow);
this.initMemory();
FlowSession newSession = FlowSession.newRootSession(this.session, this.session.preserved());
newSession.getWindow().setFrom(null);
this.flow.start().offer(data, newSession);
Expand All @@ -85,6 +88,7 @@ public final ConverseLatch<R> offer(D... data) {
public ConverseLatch<R> offer(String nodeId, List<?> data) {
Validation.notBlank(nodeId, "invalid nodeId.");
ConverseLatch<R> latch = setListener(this.flow);
this.initMemory();
FlowSession newSession = new FlowSession(this.session);
newSession.getWindow().setFrom(null);
this.flow.origin().offer(nodeId, data.toArray(new Object[0]), newSession);
Expand Down Expand Up @@ -231,4 +235,10 @@ private FlowSession setConverseListener(FlowSession session) {
session.setInnerState(StateKey.CONVERSE_LISTENER, new AtomicReference<>(new ConcurrentHashMap<>()));
return session;
}

private void initMemory() {
if (this.session.getInnerState(StateKey.HISTORY) == null) {
this.session.setInnerState(StateKey.HISTORY, new RecentMemory(DEFAULT_HISTORY_COUNT));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@

import modelengine.fel.core.chat.ChatMessage;
import modelengine.fel.core.chat.Prompt;
import modelengine.fel.core.chat.support.HumanMessage;
import modelengine.fel.core.memory.Memory;
import modelengine.fel.engine.util.StateKey;
import modelengine.fit.waterflow.bridge.fitflow.FitBoundedEmitter;
import modelengine.fit.waterflow.domain.context.FlowSession;
import modelengine.fitframework.flowable.Publisher;
import modelengine.fitframework.inspection.Validation;
import modelengine.fitframework.util.ObjectUtils;
import modelengine.fitframework.util.StringUtils;

/**
* 流式模型发射器。
Expand All @@ -26,6 +29,8 @@ public class LlmEmitter<O extends ChatMessage> extends FitBoundedEmitter<O, Chat

private final ChatChunk chunkAcc = new ChatChunk();
private final StreamingConsumer<ChatMessage, ChatMessage> consumer;
private final Memory memory;
private final ChatMessage question;

/**
* 初始化 {@link LlmEmitter}。
Expand All @@ -38,6 +43,9 @@ public LlmEmitter(Publisher<O> publisher, Prompt prompt, FlowSession session) {
super(publisher, data -> data);
Validation.notNull(session, "The session cannot be null.");
this.consumer = ObjectUtils.nullIf(session.getInnerState(StateKey.STREAMING_CONSUMER), EMPTY_CONSUMER);
this.memory = session.getInnerState(StateKey.HISTORY);
this.question =
ObjectUtils.getIfNull(session.getInnerState(StateKey.HISTORY_INPUT), () -> getDefaultQuestion(prompt));
}

@Override
Expand All @@ -46,4 +54,21 @@ public void emit(ChatMessage data, FlowSession trans) {
this.chunkAcc.merge(data);
this.consumer.accept(this.chunkAcc, data);
}

@Override
public void complete() {
if (this.memory != null && this.chunkAcc.toolCalls().isEmpty()) {
this.memory.add(this.question);
this.memory.add(this.chunkAcc);
}
super.complete();
}

private static ChatMessage getDefaultQuestion(Prompt prompt) {
int size = prompt.messages().size();
if (size == 0) {
return new HumanMessage(StringUtils.EMPTY);
}
return prompt.messages().get(size - 1);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
* This file is a part of the ModelEngine Project.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

package modelengine.fel.engine.operators.models;

import modelengine.fel.core.chat.ChatMessage;
import modelengine.fel.core.chat.Prompt;
import modelengine.fel.core.chat.support.AiMessage;
import modelengine.fel.core.chat.support.ChatMessages;
import modelengine.fel.core.memory.Memory;
import modelengine.fel.core.tool.ToolCall;
import modelengine.fel.engine.util.StateKey;
import modelengine.fit.waterflow.domain.context.FlowSession;
import modelengine.fitframework.flowable.Choir;
import modelengine.fitframework.util.StringUtils;

import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

/**
* 表示 {@link LlmEmitter} 的测试。
*
* @author 宋永坦
* @since 2025-07-05
*/
class LlmEmitterTest {
@Test
void shouldAddMemoryWhenCompleteGivenLlmOutput() {
String output = "data1";
Prompt prompt = ChatMessages.fromList(Collections.emptyList());
Choir<ChatMessage> dataSource = Choir.create(emitter -> {
emitter.emit(new AiMessage(output));
emitter.complete();
});
FlowSession flowSession = new FlowSession();
Memory mockMemory = Mockito.mock(Memory.class);
ArgumentCaptor<ChatMessage> captor = ArgumentCaptor.forClass(ChatMessage.class);
Mockito.doNothing().when(mockMemory).add(captor.capture());
flowSession.setInnerState(StateKey.HISTORY, mockMemory);

LlmEmitter<ChatMessage> llmEmitter = new LlmEmitter<>(dataSource, prompt, flowSession);
llmEmitter.start(flowSession);

List<ChatMessage> captured = captor.getAllValues();
assertEquals(2, captured.size());
assertEquals(StringUtils.EMPTY, captured.get(0).text());
assertEquals(output, captured.get(1).text());
}

@Test
void shouldNotAddMemoryWhenCompleteGivenLlmToolCallOutput() {
String output = "data1";
Prompt prompt = ChatMessages.fromList(Collections.emptyList());
Choir<ChatMessage> dataSource = Choir.create(emitter -> {
emitter.emit(new AiMessage(output, Arrays.asList(ToolCall.custom().id("id1").build())));
emitter.complete();
});
FlowSession flowSession = new FlowSession();
Memory mockMemory = Mockito.mock(Memory.class);
flowSession.setInnerState(StateKey.HISTORY, mockMemory);

LlmEmitter<ChatMessage> llmEmitter = new LlmEmitter<>(dataSource, prompt, flowSession);
llmEmitter.start(flowSession);

Mockito.verify(mockMemory, Mockito.times(0)).add(Mockito.any());
}
}