Skip to content

Commit

Permalink
refactor: types
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Jul 23, 2024
1 parent f62f3af commit 82df72e
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 184 deletions.
30 changes: 24 additions & 6 deletions internal/core/dify_invocation/http_request.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dify_invocation

import (
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
"github.com/langgenius/dify-plugin-daemon/internal/utils/http_requests"
"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
)
Expand Down Expand Up @@ -29,14 +31,30 @@ func StreamResponse[T any](method string, path string, options ...http_requests.
return http_requests.RequestAndParseStream[T](client, difyPath(path), method, options...)
}

func InvokeModel(payload *InvokeModelRequest) (*stream.StreamResponse[InvokeModelResponseChunk], error) {
return StreamResponse[InvokeModelResponseChunk]("POST", "invoke/model", http_requests.HttpPayloadJson(payload))
func InvokeLLM(payload *InvokeLLMRequest) (*stream.StreamResponse[model_entities.LLMResultChunk], error) {
return StreamResponse[model_entities.LLMResultChunk]("POST", "invoke/llm", http_requests.HttpPayloadJson(payload))
}

func InvokeTool(payload *InvokeToolRequest) (*stream.StreamResponse[InvokeToolResponseChunk], error) {
return StreamResponse[InvokeToolResponseChunk]("POST", "invoke/tool", http_requests.HttpPayloadJson(payload))
func InvokeTextEmbedding(payload *InvokeTextEmbeddingRequest) (*model_entities.TextEmbeddingResult, error) {
return Request[model_entities.TextEmbeddingResult]("POST", "invoke/text_embedding", http_requests.HttpPayloadJson(payload))
}

func InvokeNode[T WorkflowNodeData](payload *InvokeNodeRequest[T]) (*InvokeNodeResponse, error) {
return Request[InvokeNodeResponse]("POST", "invoke/node", http_requests.HttpPayloadJson(payload))
func InvokeRerank(payload *InvokeRerankRequest) (*model_entities.RerankResult, error) {
return Request[model_entities.RerankResult]("POST", "invoke/rerank", http_requests.HttpPayloadJson(payload))
}

func InvokeTTS(payload *InvokeTTSRequest) (*stream.StreamResponse[model_entities.TTSResult], error) {
return StreamResponse[model_entities.TTSResult]("POST", "invoke/tts", http_requests.HttpPayloadJson(payload))
}

func InvokeSpeech2Text(payload *InvokeSpeech2TextRequest) (*model_entities.Speech2TextResult, error) {
return Request[model_entities.Speech2TextResult]("POST", "invoke/speech2text", http_requests.HttpPayloadJson(payload))
}

func InvokeModeration(payload *InvokeModerationRequest) (*model_entities.ModerationResult, error) {
return Request[model_entities.ModerationResult]("POST", "invoke/moderation", http_requests.HttpPayloadJson(payload))
}

func InvokeTool(payload *InvokeToolRequest) (*stream.StreamResponse[tool_entities.ToolResponseChunk], error) {
return StreamResponse[tool_entities.ToolResponseChunk]("POST", "invoke/tool", http_requests.HttpPayloadJson(payload))
}
96 changes: 51 additions & 45 deletions internal/core/dify_invocation/types.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package dify_invocation

import (
"encoding/json"

"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
)

type BaseInvokeDifyRequest struct {
Expand All @@ -15,66 +13,74 @@ type BaseInvokeDifyRequest struct {
type InvokeType string

const (
INVOKE_TYPE_MODEL InvokeType = "model"
INVOKE_TYPE_TOOL InvokeType = "tool"
INVOKE_TYPE_NODE InvokeType = "node"
INVOKE_TYPE_LLM InvokeType = "LLM"
INVOKE_TYPE_TEXT_EMBEDDING InvokeType = "text_embedding"
INVOKE_TYPE_RERANK InvokeType = "rerank"
INVOKE_TYPE_TTS InvokeType = "tts"
INVOKE_TYPE_SPEECH2TEXT InvokeType = "speech2text"
INVOKE_TYPE_MODERATION InvokeType = "moderation"
INVOKE_TYPE_TOOL InvokeType = "tool"
INVOKE_TYPE_NODE InvokeType = "node"
)

type InvokeModelRequest struct {
type InvokeLLMRequest struct {
BaseInvokeDifyRequest
Provider string `json:"provider"`
Model string `json:"model"`
ModelType model_entities.ModelType `json:"model_type"`
Parameters map[string]any `json:"parameters"`
}

func (r InvokeModelRequest) MarshalJSON() ([]byte, error) {
flattened := make(map[string]any)
flattened["tenant_id"] = r.TenantId
flattened["user_id"] = r.UserId
flattened["provider"] = r.Provider
flattened["model"] = r.Model
flattened["parameters"] = r.Parameters
return json.Marshal(flattened)
Data struct {
requests.BaseRequestInvokeModel
requests.InvokeLLMSchema
} `json:"data" validate:"required"`
}

type InvokeModelResponseChunk struct {
type InvokeTextEmbeddingRequest struct {
BaseInvokeDifyRequest
Data struct {
requests.BaseRequestInvokeModel
requests.InvokeTextEmbeddingSchema
} `json:"data" validate:"required"`
}

type InvokeToolRequest struct {
type InvokeRerankRequest struct {
BaseInvokeDifyRequest
Provider string `json:"provider"`
Tool string `json:"tool"`
Parameters map[string]any `json:"parameters"`
Data struct {
requests.BaseRequestInvokeModel
requests.InvokeRerankSchema
} `json:"data" validate:"required"`
}

func (r InvokeToolRequest) MarshalJSON() ([]byte, error) {
flattened := make(map[string]any)
flattened["tenant_id"] = r.TenantId
flattened["user_id"] = r.UserId
flattened["provider"] = r.Provider
flattened["tool"] = r.Tool
flattened["parameters"] = r.Parameters
return json.Marshal(flattened)
type InvokeTTSRequest struct {
BaseInvokeDifyRequest
Data struct {
requests.BaseRequestInvokeModel
requests.InvokeTTSSchema
} `json:"data" validate:"required"`
}

type InvokeToolResponseChunk struct {
type InvokeSpeech2TextRequest struct {
BaseInvokeDifyRequest
Data struct {
requests.BaseRequestInvokeModel
requests.InvokeSpeech2TextSchema
} `json:"data" validate:"required"`
}

type InvokeNodeRequest[T WorkflowNodeData] struct {
type InvokeModerationRequest struct {
BaseInvokeDifyRequest
NodeType NodeType `json:"node_type"`
NodeData T `json:"node_data"`
Data struct {
requests.BaseRequestInvokeModel
requests.InvokeModerationSchema
} `json:"data" validate:"required"`
}

func (r InvokeNodeRequest[T]) MarshalJSON() ([]byte, error) {
flattened := make(map[string]any)
flattened["tenant_id"] = r.TenantId
flattened["user_id"] = r.UserId
flattened["node_type"] = r.NodeType
flattened["node_data"] = r.NodeData
return json.Marshal(flattened)
type InvokeToolRequest struct {
BaseInvokeDifyRequest
Data struct {
requests.RequestInvokeTool
} `json:"data" validate:"required"`
}

type InvokeNodeResponse struct {
ProcessData map[string]any `json:"process_data"`
Output map[string]any `json:"output"`
Input map[string]any `json:"input"`
EdgeSourceHandle []string `json:"edge_source_handle"`
}
114 changes: 2 additions & 112 deletions internal/core/plugin_daemon/invoke_dify.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import (
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
)

func invokeDify(
Expand Down Expand Up @@ -69,126 +67,18 @@ func prepareDifyInvocationArguments(session *session_manager.Session, request ma
func dispatchDifyInvocationTask(handle *backwards_invocation.BackwardsInvocation) {
switch handle.Type() {
case dify_invocation.INVOKE_TYPE_TOOL:
r, err := parser.MapToStruct[dify_invocation.InvokeToolRequest](handle.RequestData())
_, err := parser.MapToStruct[dify_invocation.InvokeToolRequest](handle.RequestData())
if err != nil {
handle.WriteError(fmt.Errorf("unmarshal invoke tool request failed: %s", err.Error()))
return
}

submitToolTask(runtime, session, backwards_request_id, &r)
case dify_invocation.INVOKE_TYPE_MODEL:
r, err := parser.MapToStruct[dify_invocation.InvokeModelRequest](handle.RequestData())
if err != nil {
handle.WriteError(fmt.Errorf("unmarshal invoke model request failed: %s", err.Error()))
return
}

submitModelTask(runtime, session, backwards_request_id, &r)
case dify_invocation.INVOKE_TYPE_NODE:
node_type, ok := detailed_request["node_type"].(dify_invocation.NodeType)
if !ok {
return fmt.Errorf("invoke request missing node_type: %s", data)
}
node_data, ok := detailed_request["data"].(map[string]any)
if !ok {
return fmt.Errorf("invoke request missing data: %s", data)
}
switch node_type {
case dify_invocation.QUESTION_CLASSIFIER:
d := dify_invocation.InvokeNodeRequest[dify_invocation.QuestionClassifierNodeData]{
NodeType: dify_invocation.QUESTION_CLASSIFIER,
}
if err := d.FromMap(node_data); err != nil {
return fmt.Errorf("unmarshal question classifier node data failed: %s", err.Error())
}
submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d)
case dify_invocation.KNOWLEDGE_RETRIEVAL:
d := dify_invocation.InvokeNodeRequest[dify_invocation.KnowledgeRetrievalNodeData]{
NodeType: dify_invocation.KNOWLEDGE_RETRIEVAL,
}
if err := d.FromMap(node_data); err != nil {
return fmt.Errorf("unmarshal knowledge retrieval node data failed: %s", err.Error())
}
submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d)
case dify_invocation.PARAMETER_EXTRACTOR:
d := dify_invocation.InvokeNodeRequest[dify_invocation.ParameterExtractorNodeData]{
NodeType: dify_invocation.PARAMETER_EXTRACTOR,
}
if err := d.FromMap(node_data); err != nil {
return fmt.Errorf("unmarshal parameter extractor node data failed: %s", err.Error())
}
submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d)
default:
return fmt.Errorf("unknown node type: %s", node_type)
}
default:
return fmt.Errorf("unknown invoke type: %s", typ)
handle.WriteError(fmt.Errorf("unsupported invoke type: %s", handle.Type()))
}
}

func setTaskContext(session *session_manager.Session, r *dify_invocation.BaseInvokeDifyRequest) {
r.TenantId = session.TenantID()
r.UserId = session.UserID()
}

func submitModelTask(
runtime entities.PluginRuntimeInterface,
session *session_manager.Session,
request_id string,
t *dify_invocation.InvokeModelRequest,
) {
setTaskContext(session, &t.BaseInvokeDifyRequest)
routine.Submit(func() {
response, err := dify_invocation.InvokeModel(t)
if err != nil {
log.Error("invoke model failed: %s", err.Error())
return
}
defer response.Close()

for response.Next() {
chunk, _ := response.Read()
fmt.Println(chunk)
}
})
}

func submitToolTask(
runtime entities.PluginRuntimeInterface,
session *session_manager.Session,
request_id string,
t *dify_invocation.InvokeToolRequest,
) {
setTaskContext(session, &t.BaseInvokeDifyRequest)
routine.Submit(func() {
response, err := dify_invocation.InvokeTool(t)
if err != nil {
log.Error("invoke tool failed: %s", err.Error())
return
}
defer response.Close()

for response.Next() {
chunk, _ := response.Read()
fmt.Println(chunk)
}
})
}

func submitNodeInvocationRequestTask[W dify_invocation.WorkflowNodeData](
runtime entities.PluginRuntimeInterface,
session *session_manager.Session,
request_id string,
t *dify_invocation.InvokeNodeRequest[W],
) {
setTaskContext(session, &t.BaseInvokeDifyRequest)
routine.Submit(func() {
response, err := dify_invocation.InvokeNode(t)
if err != nil {
log.Error("invoke node failed: %s", err.Error())
return
}

fmt.Println(response)
})
}
Loading

0 comments on commit 82df72e

Please sign in to comment.