diff --git a/neo/message/json.go b/neo/message/json.go index 4ada8f6cf9..199d2ffa8c 100644 --- a/neo/message/json.go +++ b/neo/message/json.go @@ -75,6 +75,66 @@ func (json *JSON) Text(text string) *JSON { return json } +// Map set from map +func (json *JSON) Map(msg map[string]interface{}) *JSON { + if msg == nil { + return json + } + + if text, ok := msg["text"].(string); ok { + json.Message.Text = text + } + + if done, ok := msg["done"].(bool); ok { + json.Message.Done = done + } + + if confirm, ok := msg["confirm"].(bool); ok { + json.Message.Confirm = confirm + } + + if command, ok := msg["command"].(map[string]interface{}); ok { + json.Message.Command = &Command{} + if id, ok := command["id"].(string); ok { + json.Message.Command.ID = id + } + if name, ok := command["name"].(string); ok { + json.Message.Command.Name = name + } + if request, ok := command["request"].(string); ok { + json.Message.Command.Reqeust = request + } + } + + if actions, ok := msg["actions"].([]interface{}); ok { + for _, action := range actions { + if v, ok := action.(map[string]interface{}); ok { + action := Action{} + if name, ok := v["name"].(string); ok { + action.Name = name + } + if t, ok := v["type"].(string); ok { + action.Type = t + } + if payload, ok := v["payload"].(map[string]interface{}); ok { + action.Payload = payload + } + + if next, ok := v["next"].(string); ok { + action.Next = next + } + json.Message.Actions = append(json.Message.Actions, action) + } + } + } + + if data, ok := msg["data"].(map[string]interface{}); ok { + json.Message.Data = data + } + + return json +} + // Done set the done func (json *JSON) Done() *JSON { json.Message.Done = true diff --git a/neo/message/types.go b/neo/message/types.go index 45d3b26a79..468339df92 100644 --- a/neo/message/types.go +++ b/neo/message/types.go @@ -2,13 +2,13 @@ package message // Message the message type Message struct { - Text string `json:"text,omitempty"` - Error string `json:"error,omitempty"` - Done bool `json:"done,omitempty"` - Confirm bool `json:"confirm,omitempty"` - Command *Command `json:"command,omitempty"` - Actions []Action `json:"actions,omitempty"` - Data map[string]interface{} + Text string `json:"text,omitempty"` + Error string `json:"error,omitempty"` + Done bool `json:"done,omitempty"` + Confirm bool `json:"confirm,omitempty"` + Command *Command `json:"command,omitempty"` + Actions []Action `json:"actions,omitempty"` + Data map[string]interface{} `json:"-,omitempty"` } // Action the action diff --git a/neo/neo.go b/neo/neo.go index f55b875624..eb28412297 100644 --- a/neo/neo.go +++ b/neo/neo.go @@ -180,6 +180,7 @@ func (neo *DSL) Answer(ctx command.Context, question string, answer Answer) erro return } + log.Trace("Command with AI: question: %s messages:%v", question, messages) err = req.Run(messages, func(msg *message.JSON) int { chanStream <- msg return 1 @@ -193,6 +194,7 @@ func (neo *DSL) Answer(ctx command.Context, question string, answer Answer) erro } // chat with AI + log.Trace("Chat with AI: question:%s messages:%v", question, messages) _, ex := neo.AI.ChatCompletionsWith(ctx, messages, neo.Option, func(data []byte) int { chanStream <- message.NewOpenAI(data) return 1 @@ -246,8 +248,13 @@ func (neo *DSL) Answer(ctx command.Context, question string, answer Answer) erro return true } - msg.Write(w) content = msg.Append(content) + err := neo.write(msg, w, ctx, messages, content) + if err != nil { + log.Warn("Neo write process msg: %v error: %s", msg, err.Error()) + msg.Write(w) + } + return !msg.IsDone() case <-ctx.Done(): @@ -302,6 +309,46 @@ func (neo *DSL) prompts() []map[string]interface{} { return prompts } +// after the after hook +func (neo *DSL) write(msg *message.JSON, w io.Writer, ctx command.Context, messages []map[string]interface{}, content []byte) error { + + if neo.Write == "" { + msg.Write(w) + return nil + } + + args := []interface{}{ctx, messages, msg} + if msg.IsDone() { + args = append(args, string(content)) + } + + p, err := process.Of(neo.Write, args...) + if err != nil { + return err + } + + res, err := p.WithSID(ctx.Sid).Exec() + if err != nil { + return err + } + + if res == nil { + return fmt.Errorf("Neo custom write return null") + } + + if messages, ok := res.([]interface{}); ok { + for _, new := range messages { + if v, ok := new.(map[string]interface{}); ok { + newMsg := message.New().Map(v) + newMsg.Write(w) + } + } + return nil + } + + return fmt.Errorf("Neo custom write return not map") +} + // prepare the messages func (neo *DSL) prepare(ctx command.Context, messages []map[string]interface{}) []map[string]interface{} { if neo.Prepare == "" { diff --git a/neo/types.go b/neo/types.go index df7be80651..37129e965b 100644 --- a/neo/types.go +++ b/neo/types.go @@ -18,6 +18,7 @@ type DSL struct { ConversationSetting conversation.Setting `json:"conversation" yaml:"conversation"` Option map[string]interface{} `json:"option"` Prepare string `json:"prepare,omitempty"` + Write string `json:"write,omitempty"` Prompts []aigc.Prompt `json:"prompts,omitempty"` Allows []string `json:"allows,omitempty"` Command Command `json:"command,omitempty"`