diff --git a/internal/core/dify_invocation/http_request.go b/internal/core/dify_invocation/http_request.go index b5853aa..a49e05e 100644 --- a/internal/core/dify_invocation/http_request.go +++ b/internal/core/dify_invocation/http_request.go @@ -1,6 +1,8 @@ package dify_invocation import ( + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities" "github.com/langgenius/dify-plugin-daemon/internal/utils/http_requests" "github.com/langgenius/dify-plugin-daemon/internal/utils/stream" ) @@ -29,14 +31,30 @@ func StreamResponse[T any](method string, path string, options ...http_requests. return http_requests.RequestAndParseStream[T](client, difyPath(path), method, options...) } -func InvokeModel(payload *InvokeModelRequest) (*stream.StreamResponse[InvokeModelResponseChunk], error) { - return StreamResponse[InvokeModelResponseChunk]("POST", "invoke/model", http_requests.HttpPayloadJson(payload)) +func InvokeLLM(payload *InvokeLLMRequest) (*stream.StreamResponse[model_entities.LLMResultChunk], error) { + return StreamResponse[model_entities.LLMResultChunk]("POST", "invoke/llm", http_requests.HttpPayloadJson(payload)) } -func InvokeTool(payload *InvokeToolRequest) (*stream.StreamResponse[InvokeToolResponseChunk], error) { - return StreamResponse[InvokeToolResponseChunk]("POST", "invoke/tool", http_requests.HttpPayloadJson(payload)) +func InvokeTextEmbedding(payload *InvokeTextEmbeddingRequest) (*model_entities.TextEmbeddingResult, error) { + return Request[model_entities.TextEmbeddingResult]("POST", "invoke/text_embedding", http_requests.HttpPayloadJson(payload)) } -func InvokeNode[T WorkflowNodeData](payload *InvokeNodeRequest[T]) (*InvokeNodeResponse, error) { - return Request[InvokeNodeResponse]("POST", "invoke/node", http_requests.HttpPayloadJson(payload)) +func InvokeRerank(payload *InvokeRerankRequest) (*model_entities.RerankResult, error) { + return Request[model_entities.RerankResult]("POST", "invoke/rerank", http_requests.HttpPayloadJson(payload)) +} + +func InvokeTTS(payload *InvokeTTSRequest) (*stream.StreamResponse[model_entities.TTSResult], error) { + return StreamResponse[model_entities.TTSResult]("POST", "invoke/tts", http_requests.HttpPayloadJson(payload)) +} + +func InvokeSpeech2Text(payload *InvokeSpeech2TextRequest) (*model_entities.Speech2TextResult, error) { + return Request[model_entities.Speech2TextResult]("POST", "invoke/speech2text", http_requests.HttpPayloadJson(payload)) +} + +func InvokeModeration(payload *InvokeModerationRequest) (*model_entities.ModerationResult, error) { + return Request[model_entities.ModerationResult]("POST", "invoke/moderation", http_requests.HttpPayloadJson(payload)) +} + +func InvokeTool(payload *InvokeToolRequest) (*stream.StreamResponse[tool_entities.ToolResponseChunk], error) { + return StreamResponse[tool_entities.ToolResponseChunk]("POST", "invoke/tool", http_requests.HttpPayloadJson(payload)) } diff --git a/internal/core/dify_invocation/types.go b/internal/core/dify_invocation/types.go index f6a8cea..a75e89b 100644 --- a/internal/core/dify_invocation/types.go +++ b/internal/core/dify_invocation/types.go @@ -1,9 +1,7 @@ package dify_invocation import ( - "encoding/json" - - "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests" ) type BaseInvokeDifyRequest struct { @@ -15,66 +13,74 @@ type BaseInvokeDifyRequest struct { type InvokeType string const ( - INVOKE_TYPE_MODEL InvokeType = "model" - INVOKE_TYPE_TOOL InvokeType = "tool" - INVOKE_TYPE_NODE InvokeType = "node" + INVOKE_TYPE_LLM InvokeType = "LLM" + INVOKE_TYPE_TEXT_EMBEDDING InvokeType = "text_embedding" + INVOKE_TYPE_RERANK InvokeType = "rerank" + INVOKE_TYPE_TTS InvokeType = "tts" + INVOKE_TYPE_SPEECH2TEXT InvokeType = "speech2text" + INVOKE_TYPE_MODERATION InvokeType = "moderation" + INVOKE_TYPE_TOOL InvokeType = "tool" + INVOKE_TYPE_NODE InvokeType = "node" ) -type InvokeModelRequest struct { +type InvokeLLMRequest struct { BaseInvokeDifyRequest - Provider string `json:"provider"` - Model string `json:"model"` - ModelType model_entities.ModelType `json:"model_type"` - Parameters map[string]any `json:"parameters"` -} - -func (r InvokeModelRequest) MarshalJSON() ([]byte, error) { - flattened := make(map[string]any) - flattened["tenant_id"] = r.TenantId - flattened["user_id"] = r.UserId - flattened["provider"] = r.Provider - flattened["model"] = r.Model - flattened["parameters"] = r.Parameters - return json.Marshal(flattened) + Data struct { + requests.BaseRequestInvokeModel + requests.InvokeLLMSchema + } `json:"data" validate:"required"` } -type InvokeModelResponseChunk struct { +type InvokeTextEmbeddingRequest struct { + BaseInvokeDifyRequest + Data struct { + requests.BaseRequestInvokeModel + requests.InvokeTextEmbeddingSchema + } `json:"data" validate:"required"` } -type InvokeToolRequest struct { +type InvokeRerankRequest struct { BaseInvokeDifyRequest - Provider string `json:"provider"` - Tool string `json:"tool"` - Parameters map[string]any `json:"parameters"` + Data struct { + requests.BaseRequestInvokeModel + requests.InvokeRerankSchema + } `json:"data" validate:"required"` } -func (r InvokeToolRequest) MarshalJSON() ([]byte, error) { - flattened := make(map[string]any) - flattened["tenant_id"] = r.TenantId - flattened["user_id"] = r.UserId - flattened["provider"] = r.Provider - flattened["tool"] = r.Tool - flattened["parameters"] = r.Parameters - return json.Marshal(flattened) +type InvokeTTSRequest struct { + BaseInvokeDifyRequest + Data struct { + requests.BaseRequestInvokeModel + requests.InvokeTTSSchema + } `json:"data" validate:"required"` } -type InvokeToolResponseChunk struct { +type InvokeSpeech2TextRequest struct { + BaseInvokeDifyRequest + Data struct { + requests.BaseRequestInvokeModel + requests.InvokeSpeech2TextSchema + } `json:"data" validate:"required"` } -type InvokeNodeRequest[T WorkflowNodeData] struct { +type InvokeModerationRequest struct { BaseInvokeDifyRequest - NodeType NodeType `json:"node_type"` - NodeData T `json:"node_data"` + Data struct { + requests.BaseRequestInvokeModel + requests.InvokeModerationSchema + } `json:"data" validate:"required"` } -func (r InvokeNodeRequest[T]) MarshalJSON() ([]byte, error) { - flattened := make(map[string]any) - flattened["tenant_id"] = r.TenantId - flattened["user_id"] = r.UserId - flattened["node_type"] = r.NodeType - flattened["node_data"] = r.NodeData - return json.Marshal(flattened) +type InvokeToolRequest struct { + BaseInvokeDifyRequest + Data struct { + requests.RequestInvokeTool + } `json:"data" validate:"required"` } type InvokeNodeResponse struct { + ProcessData map[string]any `json:"process_data"` + Output map[string]any `json:"output"` + Input map[string]any `json:"input"` + EdgeSourceHandle []string `json:"edge_source_handle"` } diff --git a/internal/core/plugin_daemon/invoke_dify.go b/internal/core/plugin_daemon/invoke_dify.go index 1a6cae6..b1cd613 100644 --- a/internal/core/plugin_daemon/invoke_dify.go +++ b/internal/core/plugin_daemon/invoke_dify.go @@ -7,9 +7,7 @@ import ( "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation" "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager" "github.com/langgenius/dify-plugin-daemon/internal/types/entities" - "github.com/langgenius/dify-plugin-daemon/internal/utils/log" "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" - "github.com/langgenius/dify-plugin-daemon/internal/utils/routine" ) func invokeDify( @@ -69,60 +67,14 @@ func prepareDifyInvocationArguments(session *session_manager.Session, request ma func dispatchDifyInvocationTask(handle *backwards_invocation.BackwardsInvocation) { switch handle.Type() { case dify_invocation.INVOKE_TYPE_TOOL: - r, err := parser.MapToStruct[dify_invocation.InvokeToolRequest](handle.RequestData()) + _, err := parser.MapToStruct[dify_invocation.InvokeToolRequest](handle.RequestData()) if err != nil { handle.WriteError(fmt.Errorf("unmarshal invoke tool request failed: %s", err.Error())) return } - submitToolTask(runtime, session, backwards_request_id, &r) - case dify_invocation.INVOKE_TYPE_MODEL: - r, err := parser.MapToStruct[dify_invocation.InvokeModelRequest](handle.RequestData()) - if err != nil { - handle.WriteError(fmt.Errorf("unmarshal invoke model request failed: %s", err.Error())) - return - } - - submitModelTask(runtime, session, backwards_request_id, &r) - case dify_invocation.INVOKE_TYPE_NODE: - node_type, ok := detailed_request["node_type"].(dify_invocation.NodeType) - if !ok { - return fmt.Errorf("invoke request missing node_type: %s", data) - } - node_data, ok := detailed_request["data"].(map[string]any) - if !ok { - return fmt.Errorf("invoke request missing data: %s", data) - } - switch node_type { - case dify_invocation.QUESTION_CLASSIFIER: - d := dify_invocation.InvokeNodeRequest[dify_invocation.QuestionClassifierNodeData]{ - NodeType: dify_invocation.QUESTION_CLASSIFIER, - } - if err := d.FromMap(node_data); err != nil { - return fmt.Errorf("unmarshal question classifier node data failed: %s", err.Error()) - } - submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d) - case dify_invocation.KNOWLEDGE_RETRIEVAL: - d := dify_invocation.InvokeNodeRequest[dify_invocation.KnowledgeRetrievalNodeData]{ - NodeType: dify_invocation.KNOWLEDGE_RETRIEVAL, - } - if err := d.FromMap(node_data); err != nil { - return fmt.Errorf("unmarshal knowledge retrieval node data failed: %s", err.Error()) - } - submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d) - case dify_invocation.PARAMETER_EXTRACTOR: - d := dify_invocation.InvokeNodeRequest[dify_invocation.ParameterExtractorNodeData]{ - NodeType: dify_invocation.PARAMETER_EXTRACTOR, - } - if err := d.FromMap(node_data); err != nil { - return fmt.Errorf("unmarshal parameter extractor node data failed: %s", err.Error()) - } - submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d) - default: - return fmt.Errorf("unknown node type: %s", node_type) - } default: - return fmt.Errorf("unknown invoke type: %s", typ) + handle.WriteError(fmt.Errorf("unsupported invoke type: %s", handle.Type())) } } @@ -130,65 +82,3 @@ func setTaskContext(session *session_manager.Session, r *dify_invocation.BaseInv r.TenantId = session.TenantID() r.UserId = session.UserID() } - -func submitModelTask( - runtime entities.PluginRuntimeInterface, - session *session_manager.Session, - request_id string, - t *dify_invocation.InvokeModelRequest, -) { - setTaskContext(session, &t.BaseInvokeDifyRequest) - routine.Submit(func() { - response, err := dify_invocation.InvokeModel(t) - if err != nil { - log.Error("invoke model failed: %s", err.Error()) - return - } - defer response.Close() - - for response.Next() { - chunk, _ := response.Read() - fmt.Println(chunk) - } - }) -} - -func submitToolTask( - runtime entities.PluginRuntimeInterface, - session *session_manager.Session, - request_id string, - t *dify_invocation.InvokeToolRequest, -) { - setTaskContext(session, &t.BaseInvokeDifyRequest) - routine.Submit(func() { - response, err := dify_invocation.InvokeTool(t) - if err != nil { - log.Error("invoke tool failed: %s", err.Error()) - return - } - defer response.Close() - - for response.Next() { - chunk, _ := response.Read() - fmt.Println(chunk) - } - }) -} - -func submitNodeInvocationRequestTask[W dify_invocation.WorkflowNodeData]( - runtime entities.PluginRuntimeInterface, - session *session_manager.Session, - request_id string, - t *dify_invocation.InvokeNodeRequest[W], -) { - setTaskContext(session, &t.BaseInvokeDifyRequest) - routine.Submit(func() { - response, err := dify_invocation.InvokeNode(t) - if err != nil { - log.Error("invoke node failed: %s", err.Error()) - return - } - - fmt.Println(response) - }) -} diff --git a/internal/types/entities/requests/model.go b/internal/types/entities/requests/model.go index f7d6cd5..758f447 100644 --- a/internal/types/entities/requests/model.go +++ b/internal/types/entities/requests/model.go @@ -4,16 +4,16 @@ import ( "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities" ) -type BaseRequestInvokeModel struct { - Provider string `json:"provider" validate:"required"` - Model string `json:"model" validate:"required"` +type Credentials struct { Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"` } -type RequestInvokeLLM struct { - BaseRequestInvokeModel +type BaseRequestInvokeModel struct { + Provider string `json:"provider" validate:"required"` + Model string `json:"model" validate:"required"` +} - ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type,eq=llm"` +type InvokeLLMSchema struct { ModelParameters map[string]any `json:"model_parameters" validate:"omitempty,dive,is_basic_type"` PromptMessages []model_entities.PromptMessage `json:"prompt_messages" validate:"omitempty,dive"` Tools []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"` @@ -21,43 +21,75 @@ type RequestInvokeLLM struct { Stream bool `json:"stream" ` } +type RequestInvokeLLM struct { + BaseRequestInvokeModel + Credentials + InvokeLLMSchema + + ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type,eq=llm"` +} + +type InvokeTextEmbeddingSchema struct { + Texts []string `json:"texts" validate:"required,dive"` +} + type RequestInvokeTextEmbedding struct { BaseRequestInvokeModel + Credentials + InvokeTextEmbeddingSchema ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type,eq=text-embedding"` - Texts []string `json:"texts" validate:"required,dive"` +} + +type InvokeRerankSchema struct { + Query string `json:"query" validate:"required"` + Docs []string `json:"docs" validate:"required,dive"` + ScoreThreshold float64 `json:"score_threshold" ` + TopN int `json:"top_n" ` } type RequestInvokeRerank struct { BaseRequestInvokeModel + Credentials + InvokeRerankSchema + + ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type,eq=rerank"` +} - ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type,eq=rerank"` - Query string `json:"query" validate:"required"` - Docs []string `json:"docs" validate:"required,dive"` - ScoreThreshold float64 `json:"score_threshold" ` - TopN int `json:"top_n" ` +type InvokeTTSSchema struct { + ContentText string `json:"content_text" validate:"required"` + Voice string `json:"voice" validate:"required"` } type RequestInvokeTTS struct { BaseRequestInvokeModel + Credentials + + ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type,eq=tts"` +} - ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type,eq=tts"` - ContentText string `json:"content_text" validate:"required"` - Voice string `json:"voice" validate:"required"` +type InvokeSpeech2TextSchema struct { + File string `json:"file" validate:"required"` // hexing encoded voice file } type RequestInvokeSpeech2Text struct { BaseRequestInvokeModel + Credentials + InvokeSpeech2TextSchema ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type,eq=speech2text"` - File string `json:"file" validate:"required"` // hexing encoded voice file +} + +type InvokeModerationSchema struct { + Text string `json:"text" validate:"required"` } type RequestInvokeModeration struct { BaseRequestInvokeModel + Credentials + InvokeModerationSchema ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type,eq=moderation"` - Text string `json:"text" validate:"required"` } type RequestValidateProviderCredentials struct { diff --git a/internal/core/dify_invocation/workflow_node_data.go b/internal/types/entities/requests/node.go similarity index 75% rename from internal/core/dify_invocation/workflow_node_data.go rename to internal/types/entities/requests/node.go index 11f078b..b5e35b7 100644 --- a/internal/core/dify_invocation/workflow_node_data.go +++ b/internal/types/entities/requests/node.go @@ -1,4 +1,4 @@ -package dify_invocation +package requests type WorkflowNodeData interface { KnowledgeRetrievalNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData @@ -10,7 +10,6 @@ const ( KNOWLEDGE_RETRIEVAL NodeType = "knowledge_retrieval" QUESTION_CLASSIFIER NodeType = "question_classifier" PARAMETER_EXTRACTOR NodeType = "parameter_extractor" - CODE NodeType = "code" ) type KnowledgeRetrievalNodeData struct { @@ -21,3 +20,8 @@ type QuestionClassifierNodeData struct { type ParameterExtractorNodeData struct { } + +type InvokeNodeRequest[T WorkflowNodeData] struct { + NodeType NodeType `json:"node_type"` + NodeData T `json:"node_data"` +} diff --git a/internal/types/entities/requests/tool.go b/internal/types/entities/requests/tool.go index 2982b76..d0cd0e3 100644 --- a/internal/types/entities/requests/tool.go +++ b/internal/types/entities/requests/tool.go @@ -1,10 +1,14 @@ package requests -type RequestInvokeTool struct { +type InvokeToolSchema struct { Provider string `json:"provider" validate:"required"` Tool string `json:"tool" validate:"required"` ToolParameters map[string]any `json:"tool_parameters" validate:"omitempty,dive,is_basic_type"` - Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"` +} + +type RequestInvokeTool struct { + InvokeToolSchema + Credentials } type RequestValidateToolCredentials struct {