Skip to content

Commit

Permalink
refactor: validator
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Jul 19, 2024
1 parent be0a7e8 commit cbd3189
Show file tree
Hide file tree
Showing 16 changed files with 59 additions and 191 deletions.
9 changes: 1 addition & 8 deletions internal/core/dify_invocation/workflow_node_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package dify_invocation
type WorkflowNodeData interface {
FromMap(map[string]any) error

*KnowledgeRetrievalNodeData | *QuestionClassifierNodeData | *ParameterExtractorNodeData | *CodeNodeData
*KnowledgeRetrievalNodeData | *QuestionClassifierNodeData | *ParameterExtractorNodeData
}

type NodeType string
Expand Down Expand Up @@ -35,10 +35,3 @@ type ParameterExtractorNodeData struct {
func (r *ParameterExtractorNodeData) FromMap(data map[string]any) error {
return nil
}

type CodeNodeData struct {
}

func (r *CodeNodeData) FromMap(data map[string]any) error {
return nil
}
9 changes: 0 additions & 9 deletions internal/core/plugin_daemon/invoke_dify.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,6 @@ func invokeDify(
return fmt.Errorf("unmarshal parameter extractor node data failed: %s", err.Error())
}
submitNodeInvocationRequestTask(runtime, session, request_id, &d)
case dify_invocation.CODE:
d := dify_invocation.InvokeNodeRequest[*dify_invocation.CodeNodeData]{
NodeType: dify_invocation.CODE,
NodeData: &dify_invocation.CodeNodeData{},
}
if err := d.FromMap(node_data); err != nil {
return fmt.Errorf("unmarshal code node data failed: %s", err.Error())
}
submitNodeInvocationRequestTask(runtime, session, request_id, &d)
default:
return fmt.Errorf("unknown node type: %s", node_type)
}
Expand Down
59 changes: 0 additions & 59 deletions internal/types/entities/model_entities/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,6 @@ func (p *PromptMessage) UnmarshalJSON(data []byte) error {
}
}

// Validate the struct
if err := validators.GlobalEntitiesValidator.Struct(p); err != nil {
return err
}

// validate tool call id
if p.Role == PROMPT_MESSAGE_ROLE_TOOL && p.ToolCallId == "" {
return errors.New("tool call id is required")
Expand All @@ -175,49 +170,13 @@ type PromptMessageTool struct {
Parameters map[string]any `json:"parameters"`
}

func (p *PromptMessageTool) UnmarshalJSON(data []byte) error {
type Alias PromptMessageTool
aux := &struct {
*Alias
}{
Alias: (*Alias)(p),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}

if err := validators.GlobalEntitiesValidator.Struct(p); err != nil {
return err
}

return nil
}

type LLMResultChunk struct {
Model LLMModel `json:"model" validate:"required"`
PromptMessages []PromptMessage `json:"prompt_messages" validate:"required,dive"`
SystemFingerprint string `json:"system_fingerprint" validate:"omitempty"`
Delta LLMResultChunkDelta `json:"delta" validate:"required"`
}

func (l *LLMResultChunk) UnmarshalJSON(data []byte) error {
type Alias LLMResultChunk
aux := &struct {
*Alias
}{
Alias: (*Alias)(l),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}

if err := validators.GlobalEntitiesValidator.Struct(l); err != nil {
return err
}

return nil
}

type LLMUsage struct {
PromptTokens *int `json:"prompt_tokens" validate:"required"`
PromptUnitPrice decimal.Decimal `json:"prompt_unit_price" validate:"required"`
Expand All @@ -233,24 +192,6 @@ type LLMUsage struct {
Latency *float64 `json:"latency" validate:"required"`
}

func (l *LLMUsage) UnmarshalJSON(data []byte) error {
type Alias LLMUsage
aux := &struct {
*Alias
}{
Alias: (*Alias)(l),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}

if err := validators.GlobalEntitiesValidator.Struct(l); err != nil {
return err
}

return nil
}

type LLMResultChunkDelta struct {
Index *int `json:"index" validate:"required"`
Message PromptMessage `json:"message" validate:"required"`
Expand Down
43 changes: 14 additions & 29 deletions internal/types/entities/model_entities/llm_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package model_entities

import (
"encoding/json"
"testing"

"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
)

func TestFullFunctionPromptMessage(t *testing.T) {
Expand Down Expand Up @@ -42,33 +43,31 @@ func TestFullFunctionPromptMessage(t *testing.T) {
`
)

var prompt_message PromptMessage

err := json.Unmarshal([]byte(system_message), &prompt_message)
prompt_message, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(system_message))
if err != nil {
t.Error(err)
}
if prompt_message.Role != "system" {
t.Error("role is not system")
}

err = json.Unmarshal([]byte(user_message), &prompt_message)
prompt_message, err = parser.UnmarshalJsonBytes[PromptMessage]([]byte(user_message))
if err != nil {
t.Error(err)
}
if prompt_message.Role != "user" {
t.Error("role is not user")
}

err = json.Unmarshal([]byte(assistant_message), &prompt_message)
prompt_message, err = parser.UnmarshalJsonBytes[PromptMessage]([]byte(assistant_message))
if err != nil {
t.Error(err)
}
if prompt_message.Role != "assistant" {
t.Error("role is not assistant")
}

err = json.Unmarshal([]byte(image_message), &prompt_message)
prompt_message, err = parser.UnmarshalJsonBytes[PromptMessage]([]byte(image_message))
if err != nil {
t.Error(err)
}
Expand All @@ -79,7 +78,7 @@ func TestFullFunctionPromptMessage(t *testing.T) {
t.Error("type is not image")
}

err = json.Unmarshal([]byte(tool_message), &prompt_message)
prompt_message, err = parser.UnmarshalJsonBytes[PromptMessage]([]byte(tool_message))
if err != nil {
t.Error(err)
}
Expand All @@ -101,9 +100,7 @@ func TestWrongRole(t *testing.T) {
`
)

var prompt_message PromptMessage

err := json.Unmarshal([]byte(wrong_role), &prompt_message)
_, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_role))
if err == nil {
t.Error("error is nil")
}
Expand All @@ -119,9 +116,7 @@ func TestWrongContent(t *testing.T) {
`
)

var prompt_message PromptMessage

err := json.Unmarshal([]byte(wrong_content), &prompt_message)
_, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_content))
if err == nil {
t.Error("error is nil")
}
Expand All @@ -142,9 +137,7 @@ func TestWrongContentArray(t *testing.T) {
`
)

var prompt_message PromptMessage

err := json.Unmarshal([]byte(wrong_content_array), &prompt_message)
_, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_content_array))
if err == nil {
t.Error("error is nil")
}
Expand All @@ -164,9 +157,7 @@ func TestWrongContentArray2(t *testing.T) {
`
)

var prompt_message PromptMessage

err := json.Unmarshal([]byte(wrong_content_array2), &prompt_message)
_, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_content_array2))
if err == nil {
t.Error("error is nil")
}
Expand All @@ -191,9 +182,7 @@ func TestWrongContentArray3(t *testing.T) {
`
)

var prompt_message PromptMessage

err := json.Unmarshal([]byte(wrong_content_array3), &prompt_message)
_, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_content_array3))
if err == nil {
t.Error("error is nil")
}
Expand Down Expand Up @@ -241,9 +230,7 @@ func TestFullFunctionLLMResultChunk(t *testing.T) {
`
)

var c LLMResultChunk

err := json.Unmarshal([]byte(llm_result_chunk), &c)
_, err := parser.UnmarshalJsonBytes[LLMResultChunk]([]byte(llm_result_chunk))
if err != nil {
t.Error(err)
}
Expand All @@ -269,9 +256,7 @@ func TestZeroLLMUsage(t *testing.T) {
`
)

var u LLMUsage

err := json.Unmarshal([]byte(llm_usage), &u)
_, err := parser.UnmarshalJsonBytes[LLMUsage]([]byte(llm_usage))
if err != nil {
t.Error(err)
}
Expand Down
1 change: 1 addition & 0 deletions internal/types/entities/model_entities/moderation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package model_entities
12 changes: 12 additions & 0 deletions internal/types/entities/model_entities/rerank.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package model_entities

type RerankDocument struct {
Index *int `json:"index" validate:"required"`
Text *string `json:"text" validate:"required"`
Score *float64 `json:"score" validate:"required"`
}

type RerankResult struct {
Model string `json:"model" validate:"required"`
Docs []RerankDocument `json:"docs" validate:"required,dive"`
}
1 change: 1 addition & 0 deletions internal/types/entities/model_entities/speech2text.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package model_entities
1 change: 1 addition & 0 deletions internal/types/entities/model_entities/text_embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package model_entities
1 change: 1 addition & 0 deletions internal/types/entities/model_entities/tts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package model_entities
17 changes: 0 additions & 17 deletions internal/types/entities/plugin_entities/model_configuration.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package plugin_entities

import (
"encoding/json"

"github.com/go-playground/locales/en"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
Expand Down Expand Up @@ -278,18 +276,3 @@ func init() {

validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isGenericType)
}

func UnmarshalModelProviderConfiguration(data []byte) (*ModelProviderConfiguration, error) {
var modelProviderConfiguration ModelProviderConfiguration
err := json.Unmarshal(data, &modelProviderConfiguration)
if err != nil {
return nil, err
}

err = validators.GlobalEntitiesValidator.Struct(modelProviderConfiguration)
if err != nil {
return nil, err
}

return &modelProviderConfiguration, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"testing"

"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
"gopkg.in/yaml.v3"
)

Expand Down Expand Up @@ -156,7 +157,7 @@ func TestFullFunctionModelProvider_Validate(t *testing.T) {
t.Error(err)
}

_, err = UnmarshalModelProviderConfiguration(json_data)
_, err = parser.UnmarshalJsonBytes[ModelProviderConfiguration](json_data)
if err != nil {
t.Errorf("UnmarshalModelProviderConfiguration() error = %v", err)
}
Expand Down
19 changes: 0 additions & 19 deletions internal/types/entities/plugin_entities/tool_configuration.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package plugin_entities

import (
"encoding/json"
"fmt"

"github.com/go-playground/locales/en"
Expand Down Expand Up @@ -253,24 +252,6 @@ func init() {
validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isGenericType)
}

func (t *ToolProviderConfiguration) UnmarshalJSON(data []byte) error {
type Alias ToolProviderConfiguration
aux := &struct {
*Alias
}{
Alias: (*Alias)(t),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}

if err := validators.GlobalEntitiesValidator.Struct(t); err != nil {
return err
}

return nil
}

func UnmarshalToolProviderConfiguration(data []byte) (*ToolProviderConfiguration, error) {
obj, err := parser.UnmarshalJsonBytes[ToolProviderConfiguration](data)
if err != nil {
Expand Down
21 changes: 0 additions & 21 deletions internal/types/entities/requests/model.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package requests

import (
"encoding/json"

"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
)

type RequestInvokeLLM struct {
Expand All @@ -18,21 +15,3 @@ type RequestInvokeLLM struct {
Stream bool `json:"stream"`
Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
}

func (r *RequestInvokeLLM) UnmarshalJSON(data []byte) error {
type Alias RequestInvokeLLM
aux := &struct {
*Alias
}{
Alias: (*Alias)(r),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}

if err := validators.GlobalEntitiesValidator.Struct(r); err != nil {
return err
}

return nil
}
Loading

0 comments on commit cbd3189

Please sign in to comment.