Skip to content

Commit

Permalink
Implement Function Calling using the FCC's Broadband measuring report…
Browse files Browse the repository at this point in the history
… data

Introduces the chatOptionsBuilder parameter to the AiAgent.createPrompt(..) method.
  • Loading branch information
michaelsembwever committed Apr 28, 2024
1 parent 7a96233 commit 78bdddd
Show file tree
Hide file tree
Showing 11 changed files with 234 additions and 22 deletions.
7 changes: 5 additions & 2 deletions src/main/java/com/datastax/ai/agent/AiApplication.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.UUID;

import com.datastax.ai.agent.base.AiAgent;
import com.datastax.ai.agent.broadbandStats.AiAgentFccBroadbandDataTool;
import com.datastax.ai.agent.history.AiAgentSession;
import com.datastax.ai.agent.llmCache.AiAgentSessionVector;
import com.datastax.ai.agent.reranking.AiAgentReranker;
Expand All @@ -42,6 +43,7 @@
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.vectorstore.CassandraVectorStore;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
Expand Down Expand Up @@ -71,7 +73,8 @@ public AiAgentSessionVector agent(
AiAgentReranker reranker = AiAgentReranker.create(session);
AiAgentVector vector = AiAgentVector.create(reranker, store);
AiAgentTavily tavily = AiAgentTavily.create(vector);
return AiAgentSessionVector.create(tavily, cqlSession, embeddingClient);
AiAgentFccBroadbandDataTool fcc = AiAgentFccBroadbandDataTool.create(tavily);
return AiAgentSessionVector.create(fcc, cqlSession, embeddingClient);
}

@Route("")
Expand All @@ -92,7 +95,7 @@ public AiChatUI(AiAgentSessionVector agent) {

UserMessage message = new UserMessage(question);
message.getProperties().put(AiAgentSession.SESSION_ID, sessionId);
Prompt prompt = agent.createPrompt(message, Map.of());
Prompt prompt = agent.createPrompt(message, Map.of(), OpenAiChatOptions.builder());

agent.send(prompt)
.subscribe((response) -> {
Expand Down
11 changes: 9 additions & 2 deletions src/main/java/com/datastax/ai/agent/base/AiAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,20 @@
import reactor.core.publisher.Flux;


public interface AiAgent {
public interface AiAgent<T extends Object/*ChatOptionsBuilder*/> {

Prompt createPrompt(UserMessage userMessage, Map<String,Object> promptProperties);
Prompt createPrompt(
UserMessage userMessage,
Map<String,Object> promptProperties,
T chatOptionsBuilder);

Flux<ChatResponse> send(Prompt prompt);

default Map<String,Object> promptProperties(Map<String,Object> promptProperties) {
return promptProperties;
}

default T chatOptionsBuilder(T chatOptionsBuilder) {
return chatOptionsBuilder;
}
}
12 changes: 9 additions & 3 deletions src/main/java/com/datastax/ai/agent/base/AiAgentBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;
Expand All @@ -34,7 +35,7 @@


@Configuration
public class AiAgentBase implements AiAgent {
public class AiAgentBase implements AiAgent<OpenAiChatOptions.Builder> {

@Value("classpath:/prompt-templates/system-prompt-qa.txt")
private Resource systemPrompt;
Expand All @@ -50,13 +51,18 @@ public static AiAgentBase create(StreamingChatClient chatClient) {
}

@Override
public Prompt createPrompt(UserMessage userMessage, Map<String,Object> promptProperties) {
public Prompt createPrompt(
UserMessage userMessage,
Map<String,Object> promptProperties,
OpenAiChatOptions.Builder chatOptionsBuilder) {

Message systemMessage
= new SystemPromptTemplate(this.systemPrompt)
.createMessage(promptProperties(promptProperties));

return new Prompt(List.of(systemMessage, userMessage));
return new Prompt(
List.of(systemMessage, userMessage),
chatOptionsBuilder(chatOptionsBuilder).build());
}

@Override
Expand Down
10 changes: 7 additions & 3 deletions src/main/java/com/datastax/ai/agent/base/AiAgentDelegator.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import reactor.core.publisher.Flux;


public abstract class AiAgentDelegator implements AiAgent {
public abstract class AiAgentDelegator<T extends Object/*ChatOptionsBuilder*/> implements AiAgent<T> {

private final AiAgent agent;

Expand All @@ -34,8 +34,12 @@ public AiAgentDelegator(AiAgent agent) {
}

@Override
public Prompt createPrompt(UserMessage userMessage, Map<String,Object> promptProperties) {
return agent.createPrompt(userMessage, promptProperties);
public Prompt createPrompt(
UserMessage userMessage,
Map<String,Object> promptProperties,
T chatOptionsBuilder) {

return agent.createPrompt(userMessage, promptProperties, chatOptionsBuilder(chatOptionsBuilder));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* See the NOTICE file distributed with this work for additional information
* regarding copyright ownership.
*/
package com.datastax.ai.agent.broadbandStats;

import java.util.Map;

import com.datastax.ai.agent.base.AiAgent;
import com.datastax.ai.agent.base.AiAgentDelegator;


import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatOptions;


/**
* You will need to
* `export TAVILY_API_KEY=…`
*
* Get a free API key at https://app.tavily.com/
*/
public class AiAgentFccBroadbandDataTool extends AiAgentDelegator<OpenAiChatOptions.Builder> {

public static AiAgentFccBroadbandDataTool create(AiAgent agent) {
return new AiAgentFccBroadbandDataTool(agent);
}

AiAgentFccBroadbandDataTool(AiAgent agent) {
super(agent);
}

@Override
public Prompt createPrompt(
UserMessage message,
Map<String,Object> promptProperties,
OpenAiChatOptions.Builder chatOptionsBuilder) {

return super.createPrompt(
message,
promptProperties(promptProperties),
chatOptionsBuilder(chatOptionsBuilder));
}

@Override
public OpenAiChatOptions.Builder chatOptionsBuilder(OpenAiChatOptions.Builder chatOptionsBuilder) {
return chatOptionsBuilder.withFunction("fccBroadbandDataService");
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* See the NOTICE file distributed with this work for additional information
* regarding copyright ownership.
*/
package com.datastax.ai.agent.broadbandStats;

import java.time.Instant;
import java.util.function.Function;

import com.datastax.ai.agent.broadbandStats.FccBroadbandDataService.Request;
import com.datastax.ai.agent.broadbandStats.FccBroadbandDataService.Response;
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;

final class FccBroadbandDataService implements Function<Request, Response> {

public record TimeRange(Instant from, Instant to) {}

public record Request(int device_id, TimeRange range) {}

public record Response(
long sk_tx_bytes,
long sk_rx_bytes,
long cust_wired_tx_bytes,
long cust_wired_rx_bytes,
long cust_wifi_tx_bytes,
long cust_wifi_rx_bytes) {}

private final CqlSession cqlSession;

FccBroadbandDataService(CqlSession cqlSession) {
this.cqlSession = cqlSession;
}

@Override
public Response apply(Request t) {

SimpleStatement stmt = QueryBuilder.selectFrom("datastax_ai_agent", "network_traffic").all()
.whereColumn("unit_id").isEqualTo(QueryBuilder.literal(t.device_id))
.whereColumn("dtime").isGreaterThanOrEqualTo(QueryBuilder.literal(t.range.from))
.whereColumn("dtime").isLessThanOrEqualTo(QueryBuilder.literal(t.range.to))
.build();


Response sum = new Response(0, 0, 0, 0, 0, 0);
for (Row r : cqlSession.execute(stmt)) {
sum = new Response(
sum.sk_tx_bytes + r.getLong("sk_tx_bytes"),
sum.sk_rx_bytes + r.getLong("sk_rx_bytes"),
sum.cust_wired_tx_bytes + r.getLong("cust_wired_tx_bytes"),
sum.cust_wired_rx_bytes + r.getLong("cust_wired_rx_bytes"),
sum.cust_wifi_tx_bytes + r.getLong("cust_wifi_tx_bytes"),
sum.cust_wifi_rx_bytes + r.getLong("cust_wifi_rx_bytes")
);
}
return sum;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* See the NOTICE file distributed with this work for additional information
* regarding copyright ownership.
*/
package com.datastax.ai.agent.broadbandStats;

import java.util.function.Function;

import com.datastax.ai.agent.broadbandStats.FccBroadbandDataService.Request;
import com.datastax.ai.agent.broadbandStats.FccBroadbandDataService.Response;
import com.datastax.oss.driver.api.core.CqlSession;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Description;

@Configuration
public class FccBroadbandDataTool {

@Bean
@Description("Get device data usage over a period of time")
public Function<Request, Response> fccBroadbandDataService(CqlSession cqlSession) {
return new FccBroadbandDataService(cqlSession);
}
}
14 changes: 11 additions & 3 deletions src/main/java/com/datastax/ai/agent/history/AiAgentSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import reactor.core.publisher.Flux;


public final class AiAgentSession implements AiAgent {
public final class AiAgentSession implements AiAgent<Object> {

public static final String SESSION_ID = AiAgentSession.class.getSimpleName() + "_sessionId";
public static final String CONVERSATION_TS = AiAgentSession.class.getSimpleName() + "_message_timestamp";
Expand All @@ -57,7 +57,11 @@ public static AiAgentSession create(AiAgent agent, CqlSession cqlSession) {
}

@Override
public Prompt createPrompt(UserMessage message, Map<String,Object> promptProperties) {
public Prompt createPrompt(
UserMessage message,
Map<String,Object> promptProperties,
Object chatOptionsBuilder) {

String sessionId = message.getProperties().get(SESSION_ID).toString();
List<Message> history = chatHistory.get(sessionId, CHAT_HISTORY_WINDOW_SIZE);

Expand All @@ -68,7 +72,11 @@ public Prompt createPrompt(UserMessage message, Map<String,Object> promptPropert
promptProperties = new HashMap<>(promptProperties);
promptProperties.put("conversation", conversationStr);
message.getProperties().put(CONVERSATION_TS, Instant.now());
return agent.createPrompt(message, promptProperties(promptProperties));

return agent.createPrompt(
message,
promptProperties(promptProperties),
chatOptionsBuilder(chatOptionsBuilder));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

import static org.springframework.ai.vectorstore.CassandraVectorStore.SIMILARITY_FIELD_NAME;

public class AiAgentReranker extends AiAgentDelegator {
public class AiAgentReranker extends AiAgentDelegator<Object> {

private static final Logger logger = LoggerFactory.getLogger(AiAgentReranker.class);

Expand All @@ -51,7 +51,10 @@ public static AiAgentReranker create(AiAgent agent) {
}

@Override
public Prompt createPrompt(UserMessage message, Map<String, Object> promptProperties) {
public Prompt createPrompt(
UserMessage message,
Map<String, Object> promptProperties,
Object chatOptionsBuilder) {

List<Document> similarDocuments = (List<Document>) promptProperties.get("documents");
promptProperties = promptProperties(promptProperties);
Expand Down Expand Up @@ -81,7 +84,7 @@ public Prompt createPrompt(UserMessage message, Map<String, Object> promptProper
}

promptProperties.put("documents", similarDocuments);
return super.createPrompt(message, promptProperties);
return super.createPrompt(message, promptProperties, chatOptionsBuilder);
}


Expand Down
10 changes: 7 additions & 3 deletions src/main/java/com/datastax/ai/agent/tavily/AiAgentTavily.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
*
* Get a free API key at https://app.tavily.com/
*/
public class AiAgentTavily extends AiAgentDelegator {
public class AiAgentTavily extends AiAgentDelegator<Object> {

private static final String TAVILY_URL = "https://api.tavily.com/search";

Expand All @@ -59,7 +59,11 @@ public static AiAgentTavily create(AiAgent agent) {
}

@Override
public Prompt createPrompt(UserMessage message, Map<String,Object> promptProperties) {
public Prompt createPrompt(
UserMessage message,
Map<String,Object> promptProperties,
Object chatOptionsBuilder) {

promptProperties = new HashMap<>(promptProperties(promptProperties));
if ( message.getContent().length() >= 5 ) {
JSONObject post = new JSONObject();
Expand All @@ -78,6 +82,6 @@ public Prompt createPrompt(UserMessage message, Map<String,Object> promptPropert
} else {
promptProperties.put("search_results", "[]");
}
return super.createPrompt(message, promptProperties);
return super.createPrompt(message, promptProperties, chatOptionsBuilder);
}
}
Loading

0 comments on commit 78bdddd

Please sign in to comment.