Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ push-test-agent: buildx-create build-kagent-adk
echo "Building FROM DOCKER_REGISTRY=$(DOCKER_REGISTRY)/$(DOCKER_REPO)/kagent-adk:$(VERSION)"
$(DOCKER_BUILDER) build --push $(BUILD_ARGS) $(TOOLS_IMAGE_BUILD_ARGS) -t $(DOCKER_REGISTRY)/kebab:latest -f go/test/e2e/agents/kebab/Dockerfile ./go/test/e2e/agents/kebab
kubectl apply --namespace kagent --context kind-$(KIND_CLUSTER_NAME) -f go/test/e2e/agents/kebab/agent.yaml
$(DOCKER_BUILDER) build --push $(BUILD_ARGS) $(TOOLS_IMAGE_BUILD_ARGS) -t $(DOCKER_REGISTRY)/poem-flow:latest -f python/samples/crewai/poem_flow/Dockerfile ./python

.PHONY: create-kind-cluster
create-kind-cluster:
Expand Down
88 changes: 88 additions & 0 deletions go/internal/database/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package database

import (
"encoding/json"
"errors"
"fmt"
"slices"
"time"
Expand Down Expand Up @@ -57,6 +58,13 @@ type Client interface {
StoreCheckpointWrites(writes []*LangGraphCheckpointWrite) error
ListCheckpoints(userID, threadID, checkpointNS string, checkpointID *string, limit int) ([]*LangGraphCheckpointTuple, error)
DeleteCheckpoint(userID, threadID string) error

// CrewAI methods
StoreCrewAIMemory(memory *CrewAIAgentMemory) error
SearchCrewAIMemoryByTask(userID, threadID, taskDescription string, limit int) ([]*CrewAIAgentMemory, error)
ResetCrewAIMemory(userID, threadID string) error
StoreCrewAIFlowState(state *CrewAIFlowState) error
GetCrewAIFlowState(userID, threadID string) (*CrewAIFlowState, error)
}

type LangGraphCheckpointTuple struct {
Expand Down Expand Up @@ -578,3 +586,83 @@ func (c *clientImpl) DeleteCheckpoint(userID, threadID string) error {
})

}

// CrewAI methods

// StoreCrewAIMemory stores CrewAI agent memory
func (c *clientImpl) StoreCrewAIMemory(memory *CrewAIAgentMemory) error {
err := save(c.db, memory)
if err != nil {
return fmt.Errorf("failed to store CrewAI agent memory: %w", err)
}
return nil
}

// SearchCrewAIMemoryByTask searches CrewAI agent memory by task description across all agents for a session
func (c *clientImpl) SearchCrewAIMemoryByTask(userID, threadID, taskDescription string, limit int) ([]*CrewAIAgentMemory, error) {
var memories []*CrewAIAgentMemory

// Search for task_description within the JSON memory_data field
// Using JSON_EXTRACT or JSON_UNQUOTE for MySQL/PostgreSQL, or simple LIKE for SQLite
// Sort by created_at DESC, then by score ASC (if score exists in JSON)
query := c.db.Where(
"user_id = ? AND thread_id = ? AND (memory_data LIKE ? OR JSON_EXTRACT(memory_data, '$.task_description') LIKE ?)",
userID, threadID, "%"+taskDescription+"%", "%"+taskDescription+"%",
).Order("created_at DESC, JSON_EXTRACT(memory_data, '$.score') ASC")

// Apply limit
if limit > 0 {
query = query.Limit(limit)
}

err := query.Find(&memories).Error
if err != nil {
return nil, fmt.Errorf("failed to search CrewAI agent memory by task: %w", err)
}

return memories, nil
}

// ResetCrewAIMemory deletes all CrewAI agent memory for a session
func (c *clientImpl) ResetCrewAIMemory(userID, threadID string) error {
result := c.db.Where(
"user_id = ? AND thread_id = ?",
userID, threadID,
).Delete(&CrewAIAgentMemory{})

if result.Error != nil {
return fmt.Errorf("failed to reset CrewAI agent memory: %w", result.Error)
}

return nil
}

// StoreCrewAIFlowState stores CrewAI flow state
func (c *clientImpl) StoreCrewAIFlowState(state *CrewAIFlowState) error {
err := save(c.db, state)
if err != nil {
return fmt.Errorf("failed to store CrewAI flow state: %w", err)
}
return nil
}

// GetCrewAIFlowState retrieves the most recent CrewAI flow state
func (c *clientImpl) GetCrewAIFlowState(userID, threadID string) (*CrewAIFlowState, error) {
var state CrewAIFlowState

// Get the most recent state by ordering by created_at DESC
// Thread_id is equivalent to flow_uuid used by CrewAI because in each session there is only one flow
err := c.db.Where(
"user_id = ? AND thread_id = ?",
userID, threadID,
).Order("created_at DESC").First(&state).Error

if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil // Return nil for not found, as expected by the Python client
}
return nil, fmt.Errorf("failed to get CrewAI flow state: %w", err)
}

return &state, nil
}
145 changes: 145 additions & 0 deletions go/internal/database/fake/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"sort"
"strings"
"sync"

"github.com/kagent-dev/kagent/go/api/v1alpha2"
Expand All @@ -26,6 +27,8 @@ type InMemoryFakeClient struct {
pushNotifications map[string]*protocol.TaskPushNotificationConfig // key: taskID
checkpoints map[string]*database.LangGraphCheckpoint // key: user_id:thread_id:checkpoint_ns:checkpoint_id
checkpointWrites map[string][]*database.LangGraphCheckpointWrite // key: user_id:thread_id:checkpoint_ns:checkpoint_id
crewaiMemory map[string][]*database.CrewAIAgentMemory // key: user_id:thread_id:agent_id
crewaiFlowStates map[string]*database.CrewAIFlowState // key: user_id:thread_id
nextFeedbackID int
}

Expand All @@ -43,6 +46,8 @@ func NewClient() database.Client {
pushNotifications: make(map[string]*protocol.TaskPushNotificationConfig),
checkpoints: make(map[string]*database.LangGraphCheckpoint),
checkpointWrites: make(map[string][]*database.LangGraphCheckpointWrite),
crewaiMemory: make(map[string][]*database.CrewAIAgentMemory),
crewaiFlowStates: make(map[string]*database.CrewAIFlowState),
nextFeedbackID: 1,
}
}
Expand Down Expand Up @@ -724,3 +729,143 @@ func (c *InMemoryFakeClient) ListWrites(userID, threadID, checkpointNS, checkpoi

return writes[start:end], nil
}

// CrewAI methods

// StoreCrewAIMemory stores CrewAI agent memory
func (c *InMemoryFakeClient) StoreCrewAIMemory(memory *database.CrewAIAgentMemory) error {
c.mu.Lock()
defer c.mu.Unlock()

if c.crewaiMemory == nil {
c.crewaiMemory = make(map[string][]*database.CrewAIAgentMemory)
}

key := fmt.Sprintf("%s:%s", memory.UserID, memory.ThreadID)
c.crewaiMemory[key] = append(c.crewaiMemory[key], memory)

return nil
}

// SearchCrewAIMemoryByTask searches CrewAI agent memory by task description across all agents for a session
func (c *InMemoryFakeClient) SearchCrewAIMemoryByTask(userID, threadID, taskDescription string, limit int) ([]*database.CrewAIAgentMemory, error) {
c.mu.RLock()
defer c.mu.RUnlock()

if c.crewaiMemory == nil {
return []*database.CrewAIAgentMemory{}, nil
}

var allMemories []*database.CrewAIAgentMemory

// Search across all agents for this user/thread
for key, memories := range c.crewaiMemory {
// Key format is "user_id:thread_id"
if strings.HasPrefix(key, userID+":"+threadID) {
for _, memory := range memories {
// Parse the JSON memory data and search for task_description
var memoryData map[string]interface{}
if err := json.Unmarshal([]byte(memory.MemoryData), &memoryData); err == nil {
if taskDesc, ok := memoryData["task_description"].(string); ok {
if strings.Contains(strings.ToLower(taskDesc), strings.ToLower(taskDescription)) {
allMemories = append(allMemories, memory)
}
}
}
// Fallback to simple string search if JSON parsing fails
if len(allMemories) == 0 && strings.Contains(strings.ToLower(memory.MemoryData), strings.ToLower(taskDescription)) {
allMemories = append(allMemories, memory)
}
}
}
}

// Sort by created_at DESC, then by score ASC (if score exists in JSON)
sort.Slice(allMemories, func(i, j int) bool {
// First sort by created_at DESC (most recent first)
if !allMemories[i].CreatedAt.Equal(allMemories[j].CreatedAt) {
return allMemories[i].CreatedAt.After(allMemories[j].CreatedAt)
}

// If created_at is equal, sort by score ASC
var scoreI, scoreJ float64
var memoryDataI, memoryDataJ map[string]interface{}

if err := json.Unmarshal([]byte(allMemories[i].MemoryData), &memoryDataI); err == nil {
if score, ok := memoryDataI["score"].(float64); ok {
scoreI = score
}
}

if err := json.Unmarshal([]byte(allMemories[j].MemoryData), &memoryDataJ); err == nil {
if score, ok := memoryDataJ["score"].(float64); ok {
scoreJ = score
}
}

return scoreI < scoreJ
})

// Apply limit
if limit > 0 && len(allMemories) > limit {
allMemories = allMemories[:limit]
}

return allMemories, nil
}

// ResetCrewAIMemory deletes all CrewAI agent memory for a session
func (c *InMemoryFakeClient) ResetCrewAIMemory(userID, threadID string) error {
c.mu.Lock()
defer c.mu.Unlock()

if c.crewaiMemory == nil {
return nil
}

// Find and delete all memory entries for this user/thread combination
keysToDelete := make([]string, 0)
for key := range c.crewaiMemory {
// Key format is "user_id:thread_id"
if strings.HasPrefix(key, userID+":"+threadID) {
keysToDelete = append(keysToDelete, key)
}
}

// Delete the entries
for _, key := range keysToDelete {
delete(c.crewaiMemory, key)
}

return nil
}

// StoreCrewAIFlowState stores CrewAI flow state
func (c *InMemoryFakeClient) StoreCrewAIFlowState(state *database.CrewAIFlowState) error {
c.mu.Lock()
defer c.mu.Unlock()

if c.crewaiFlowStates == nil {
c.crewaiFlowStates = make(map[string]*database.CrewAIFlowState)
}

key := fmt.Sprintf("%s:%s", state.UserID, state.ThreadID)
c.crewaiFlowStates[key] = state

return nil
}

// GetCrewAIFlowState retrieves CrewAI flow state
func (c *InMemoryFakeClient) GetCrewAIFlowState(userID, threadID string) (*database.CrewAIFlowState, error) {
c.mu.RLock()
defer c.mu.RUnlock()

if c.crewaiFlowStates == nil {
return nil, nil
}

key := fmt.Sprintf("%s:%s", userID, threadID)
state := c.crewaiFlowStates[key]

return state, nil
}
4 changes: 4 additions & 0 deletions go/internal/database/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ func (m *Manager) Initialize() error {
&ToolServer{},
&LangGraphCheckpoint{},
&LangGraphCheckpointWrite{},
&CrewAIAgentMemory{},
&CrewAIFlowState{},
)

if err != nil {
Expand Down Expand Up @@ -130,6 +132,8 @@ func (m *Manager) Reset(recreateTables bool) error {
&ToolServer{},
&LangGraphCheckpoint{},
&LangGraphCheckpointWrite{},
&CrewAIAgentMemory{},
&CrewAIFlowState{},
)

if err != nil {
Expand Down
25 changes: 25 additions & 0 deletions go/internal/database/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,29 @@ type LangGraphCheckpointWrite struct {
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"`
}

// CrewAIAgentMemory represents long-term memory for CrewAI agents
type CrewAIAgentMemory struct {
UserID string `gorm:"primaryKey;not null" json:"user_id"`
ThreadID string `gorm:"primaryKey;not null" json:"thread_id"`
CreatedAt time.Time `gorm:"autoCreateTime;index:idx_crewai_memory_list" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"`
// MemoryData contains JSON serialized memory data including task_description, score, metadata, datetime
MemoryData string `gorm:"type:text;not null" json:"memory_data"`
}

// CrewAIFlowState represents flow state for CrewAI flows
type CrewAIFlowState struct {
UserID string `gorm:"primaryKey;not null" json:"user_id"`
ThreadID string `gorm:"primaryKey;not null" json:"thread_id"`
MethodName string `gorm:"primaryKey;not null" json:"method_name"`
CreatedAt time.Time `gorm:"autoCreateTime;index:idx_crewai_flow_state_list" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"`
// StateData contains JSON serialized flow state data
StateData string `gorm:"type:text;not null" json:"state_data"`
}

// TableName methods to match Python table names
func (Agent) TableName() string { return "agent" }
func (Event) TableName() string { return "event" }
Expand All @@ -190,3 +213,5 @@ func (Tool) TableName() string { return "tool" }
func (ToolServer) TableName() string { return "toolserver" }
func (LangGraphCheckpoint) TableName() string { return "lg_checkpoint" }
func (LangGraphCheckpointWrite) TableName() string { return "lg_checkpoint_write" }
func (CrewAIAgentMemory) TableName() string { return "crewai_agent_memory" }
func (CrewAIFlowState) TableName() string { return "crewai_flow_state" }
Loading
Loading