Skip to content

Commit

Permalink
feat: generic invocation
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Jul 19, 2024
1 parent dfe1d2e commit 5b96e61
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 118 deletions.
9 changes: 7 additions & 2 deletions internal/core/plugin_daemon/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@ const (
type PluginAccessAction string

const (
PLUGIN_ACCESS_ACTION_INVOKE_TOOL PluginAccessAction = "invoke_tool"
PLUGIN_ACCESS_ACTION_INVOKE_LLM PluginAccessAction = "invoke_llm"
PLUGIN_ACCESS_ACTION_INVOKE_TOOL PluginAccessAction = "invoke_tool"
PLUGIN_ACCESS_ACTION_INVOKE_LLM PluginAccessAction = "invoke_llm"
PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING PluginAccessAction = "invoke_text_embedding"
PLUGIN_ACCESS_ACTION_INVOKE_RERANK PluginAccessAction = "invoke_rerank"
PLUGIN_ACCESS_ACTION_INVOKE_TTS PluginAccessAction = "invoke_tts"
PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT PluginAccessAction = "invoke_speech2text"
PLUGIN_ACCESS_ACTION_INVOKE_MODERATION PluginAccessAction = "invoke_moderation"
)

const (
Expand Down
151 changes: 124 additions & 27 deletions internal/core/plugin_daemon/model_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,46 +5,29 @@ import (

"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
)

func getInvokeModelMap(
func genericInvokePlugin[Req any, Rsp any](
session *session_manager.Session,
request *Req,
response_buffer_size int,
typ PluginAccessType,
action PluginAccessAction,
request *requests.RequestInvokeLLM,
) map[string]any {
req := getBasicPluginAccessMap(session.ID(), session.UserID(), PLUGIN_ACCESS_TYPE_MODEL, action)
data := req["data"].(map[string]any)

data["provider"] = request.Provider
data["model"] = request.Model
data["model_type"] = request.ModelType
data["model_parameters"] = request.ModelParameters
data["prompt_messages"] = request.PromptMessages
data["tools"] = request.Tools
data["stop"] = request.Stop
data["stream"] = request.Stream
data["credentials"] = request.Credentials

return req
}

func InvokeLLM(
session *session_manager.Session,
request *requests.RequestInvokeLLM,
) (
*stream.StreamResponse[plugin_entities.InvokeModelResponseChunk], error,
*stream.StreamResponse[Rsp], error,
) {
runtime := plugin_manager.Get(session.PluginIdentity())
if runtime == nil {
return nil, errors.New("plugin not found")
}

response := stream.NewStreamResponse[plugin_entities.InvokeModelResponseChunk](512)
response := stream.NewStreamResponse[Rsp](response_buffer_size)

listener := runtime.Listen(session.ID())
listener.AddListener(func(message []byte) {
Expand All @@ -56,7 +39,7 @@ func InvokeLLM(

switch chunk.Type {
case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
chunk, err := parser.UnmarshalJsonBytes[plugin_entities.InvokeModelResponseChunk](chunk.Data)
chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
if err != nil {
log.Error("unmarshal json failed: %s", err.Error())
return
Expand All @@ -66,8 +49,15 @@ func InvokeLLM(
invokeDify(runtime, session, chunk.Data)
case plugin_entities.SESSION_MESSAGE_TYPE_END:
response.Close()
case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
if err != nil {
break
}
response.WriteError(errors.New(e.Error))
response.Close()
default:
log.Error("unknown stream message type: %s", chunk.Type)
response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
response.Close()
}
})
Expand All @@ -79,10 +69,117 @@ func InvokeLLM(
runtime.Write(session.ID(), []byte(parser.MarshalJson(
getInvokeModelMap(
session,
PLUGIN_ACCESS_ACTION_INVOKE_LLM,
typ,
action,
request,
),
)))

return response, nil
}

func getInvokeModelMap(
session *session_manager.Session,
typ PluginAccessType,
action PluginAccessAction,
request any,
) map[string]any {
req := getBasicPluginAccessMap(session.ID(), session.UserID(), typ, action)
data := req["data"].(map[string]any)

for k, v := range parser.StructToMap(request) {
data[k] = v
}

return req
}

func InvokeLLM(
session *session_manager.Session,
request *requests.RequestInvokeLLM,
) (
*stream.StreamResponse[model_entities.LLMResultChunk], error,
) {
return genericInvokePlugin[requests.RequestInvokeLLM, model_entities.LLMResultChunk](
session,
request,
512,
PLUGIN_ACCESS_TYPE_MODEL,
PLUGIN_ACCESS_ACTION_INVOKE_LLM,
)
}

func InvokeTextEmbedding(
session *session_manager.Session,
request *requests.RequestInvokeTextEmbedding,
) (
*stream.StreamResponse[model_entities.TextEmbeddingResult], error,
) {
return genericInvokePlugin[requests.RequestInvokeTextEmbedding, model_entities.TextEmbeddingResult](
session,
request,
1,
PLUGIN_ACCESS_TYPE_MODEL,
PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
)
}

func InvokeRerank(
session *session_manager.Session,
request *requests.RequestInvokeRerank,
) (
*stream.StreamResponse[model_entities.RerankResult], error,
) {
return genericInvokePlugin[requests.RequestInvokeRerank, model_entities.RerankResult](
session,
request,
1,
PLUGIN_ACCESS_TYPE_MODEL,
PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
)
}

func InvokeTTS(
session *session_manager.Session,
request *requests.RequestInvokeTTS,
) (
*stream.StreamResponse[string], error,
) {
return genericInvokePlugin[requests.RequestInvokeTTS, string](
session,
request,
1,
PLUGIN_ACCESS_TYPE_MODEL,
PLUGIN_ACCESS_ACTION_INVOKE_TTS,
)
}

func InvokeSpeech2Text(
session *session_manager.Session,
request *requests.RequestInvokeSpeech2Text,
) (
*stream.StreamResponse[string], error,
) {
return genericInvokePlugin[requests.RequestInvokeSpeech2Text, string](
session,
request,
1,
PLUGIN_ACCESS_TYPE_MODEL,
PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
)
}

func InvokeModeration(
session *session_manager.Session,
request *requests.RequestInvokeModeration,
) (
*stream.StreamResponse[bool], error,
) {
return genericInvokePlugin[requests.RequestInvokeModeration, bool](
session,
request,
1,
PLUGIN_ACCESS_TYPE_MODEL,
PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
)
}
77 changes: 7 additions & 70 deletions internal/core/plugin_daemon/tool_service.go
Original file line number Diff line number Diff line change
@@ -1,86 +1,23 @@
package plugin_daemon

import (
"errors"

"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
)

func getInvokeToolMap(
session *session_manager.Session,
action PluginAccessAction,
request *requests.RequestInvokeTool,
) map[string]any {
req := getBasicPluginAccessMap(session.ID(), session.UserID(), PLUGIN_ACCESS_TYPE_TOOL, action)
data := req["data"].(map[string]any)

data["provider"] = request.Provider
data["tool"] = request.Tool
data["parameters"] = request.ToolParameters
data["credentials"] = request.Credentials

return req
}

func InvokeTool(
session *session_manager.Session,
request *requests.RequestInvokeTool,
) (
*stream.StreamResponse[plugin_entities.ToolResponseChunk], error,
) {
runtime := plugin_manager.Get(session.PluginIdentity())
if runtime == nil {
return nil, errors.New("plugin not found")
}

response := stream.NewStreamResponse[plugin_entities.ToolResponseChunk](512)

listener := runtime.Listen(session.ID())
listener.AddListener(func(message []byte) {
chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message)
if err != nil {
log.Error("unmarshal json failed: %s", err.Error())
return
}

switch chunk.Type {
case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
chunk, err := parser.UnmarshalJsonBytes[plugin_entities.ToolResponseChunk](chunk.Data)
if err != nil {
log.Error("unmarshal json failed: %s", err.Error())
return
}
response.Write(chunk)
case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
invokeDify(runtime, session, chunk.Data)
case plugin_entities.SESSION_MESSAGE_TYPE_END:
response.Close()
case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
if err != nil {
break
}
response.WriteError(errors.New(e.Error))
response.Close()
default:
response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
response.Close()
}
})

response.OnClose(func() {
listener.Close()
})

runtime.Write(session.ID(), []byte(parser.MarshalJson(
getInvokeToolMap(session, PLUGIN_ACCESS_ACTION_INVOKE_TOOL, request)),
))

return response, nil
return genericInvokePlugin[requests.RequestInvokeTool, plugin_entities.ToolResponseChunk](
session,
request,
128,
PLUGIN_ACCESS_TYPE_TOOL,
PLUGIN_ACCESS_ACTION_INVOKE_TOOL,
)
}
10 changes: 5 additions & 5 deletions internal/core/plugin_manager/stdio_holder/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ func (s *stdioHolder) StartStdout() {
scanner := bufio.NewScanner(s.reader)
for scanner.Scan() {
data := scanner.Bytes()
if len(data) == 0 {
continue
}

event, err := parser.UnmarshalJsonBytes[plugin_entities.PluginUniversalEvent](data)
if err != nil {
// log.Error("unmarshal json failed: %s", err.Error())
Expand Down Expand Up @@ -101,11 +105,7 @@ func (s *stdioHolder) StartStdout() {
}
}
case plugin_entities.PLUGIN_EVENT_ERROR:
for listener_session_id, listener := range s.error_listener {
if listener_session_id == session_id {
listener(event.Data)
}
}
log.Error("plugin %s: %s", s.plugin_identity, event.Data)
case plugin_entities.PLUGIN_EVENT_HEARTBEAT:
s.last_active_at = time.Now()
}
Expand Down
3 changes: 2 additions & 1 deletion internal/service/invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
Expand All @@ -25,7 +26,7 @@ func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM]
session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
defer session.Close()

baseSSEService(r, func() (*stream.StreamResponse[plugin_entities.InvokeModelResponseChunk], error) {
baseSSEService(r, func() (*stream.StreamResponse[model_entities.LLMResultChunk], error) {
return plugin_daemon.InvokeLLM(session, &r.Data)
}, ctx)
}
4 changes: 0 additions & 4 deletions internal/types/entities/plugin_entities/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package plugin_entities

import (
"encoding/json"

"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
)

type PluginUniversalEvent struct {
Expand Down Expand Up @@ -51,8 +49,6 @@ type PluginResponseChunk struct {
Data json.RawMessage `json:"data"`
}

type InvokeModelResponseChunk = model_entities.LLMResultChunk

type ErrorResponse struct {
Error string `json:"error"`
}
Loading

0 comments on commit 5b96e61

Please sign in to comment.