Skip to content

Commit

Permalink
feat: invoke all models
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Jul 19, 2024
1 parent cbd3189 commit dfe1d2e
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 7 deletions.
68 changes: 68 additions & 0 deletions internal/types/entities/model_entities/rerank_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package model_entities

import (
"testing"

"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
)

func TestRerankFullFunction(t *testing.T) {
const (
rerank = `
{
"model": "rerank",
"docs": [
{
"index": 1,
"text": "text",
"score": 0.1
}
]
}`
)

_, err := parser.UnmarshalJsonBytes[RerankResult]([]byte(rerank))
if err != nil {
t.Error(err)
}
}

func TestRerankWrongDocs(t *testing.T) {
const (
rerank = `
{
"model": "rerank",
"docs": [
{
"index": 1,
"text": "text"
}
]
}`
)

_, err := parser.UnmarshalJsonBytes[RerankResult]([]byte(rerank))
if err == nil {
t.Error("should have error")
}
}

func TestRerankWrongDocIndex(t *testing.T) {
const (
rerank = `
{
"model": "rerank",
"docs": [
{
"text": "text",
"score": 0.1
}
]
}`
)

_, err := parser.UnmarshalJsonBytes[RerankResult]([]byte(rerank))
if err == nil {
t.Error("should have error")
}
}
18 changes: 18 additions & 0 deletions internal/types/entities/model_entities/text_embedding.go
Original file line number Diff line number Diff line change
@@ -1 +1,19 @@
package model_entities

import "github.com/shopspring/decimal"

type EmbeddingUsage struct {
Tokens *int `json:"tokens" validate:"required"`
TotalTokens *int `json:"total_tokens" validate:"required"`
UnitPrice decimal.Decimal `json:"unit_price" validate:"required"`
PriceUnit decimal.Decimal `json:"price_unit" validate:"required"`
TotalPrice decimal.Decimal `json:"total_price" validate:"required"`
Currency *string `json:"currency" validate:"required"`
Latency *float64 `json:"latency" validate:"required"`
}

type TextEmbeddingResult struct {
Model string `json:"model" validate:"required"`
Embeddings [][]float64 `json:"embeddings" validate:"required,dive"`
Usage EmbeddingUsage `json:"usage" validate:"required"`
}
84 changes: 84 additions & 0 deletions internal/types/entities/model_entities/text_embedding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package model_entities

import (
"testing"

"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
)

func TestTextEmbeddingFullFunction(t *testing.T) {
const (
text_embedding = `
{
"model": "text_embedding",
"embeddings": [[
0.1, 0.2, 0.3
]],
"usage" : {
"tokens": 3,
"total_tokens": 100,
"unit_price": 0.1,
"price_unit": 1,
"total_price": 10,
"currency": "usd",
"latency": 0.1
}
}`
)

_, err := parser.UnmarshalJsonBytes[TextEmbeddingResult]([]byte(text_embedding))
if err != nil {
t.Error(err)
}
}

func TestTextEmbeddingWrongUsage(t *testing.T) {
const (
text_embedding = `
{
"model": "text_embedding",
"embeddings": [[
0.1, 0.2, 0.3
]],
"usage" : {
"tokens": 3,
"total_tokens": 100,
"unit_price": 0.1,
"price_unit": 1,
"total_price": 10,
"currency": "usd"
}
}`
)

_, err := parser.UnmarshalJsonBytes[TextEmbeddingResult]([]byte(text_embedding))
if err == nil {
t.Error("should have error")
}
}

func TestTextEmbeddingWrongEmbeddings(t *testing.T) {
const (
text_embedding = `
{
"model": "text_embedding",
"embeddings": [
0.1, 0.2, 0.3
],
"usage" : {
"tokens": 3,
"total_tokens": 100,
"unit_price": 0.1,
"price_unit": 1,
"total_price": 10,
"currency": "usd",
"latency": 0.1
}
}`
)

_, err := parser.UnmarshalJsonBytes[TextEmbeddingResult]([]byte(text_embedding))
if err == nil {
t.Error("should have error")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type I18nObject struct {
PtBr string `json:"pt_BR" validate:"lt=1024"`
}

func isGenericType(fl validator.FieldLevel) bool {
func isBasicType(fl validator.FieldLevel) bool {
// allowed int, string, bool, float64
switch fl.Field().Kind() {
case reflect.Int, reflect.String, reflect.Bool, reflect.Float64:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,5 +274,5 @@ func init() {

validators.GlobalEntitiesValidator.RegisterValidation("parameter_rule", isParameterRule)

validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isGenericType)
validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isBasicType)
}
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func init() {
},
)

validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isGenericType)
validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isBasicType)
}

func UnmarshalToolProviderConfiguration(data []byte) (*ToolProviderConfiguration, error) {
Expand Down
41 changes: 37 additions & 4 deletions internal/types/entities/requests/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,47 @@ import (
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
)

type BaseRequestInvokeModel struct {
Provider string `json:"provider"`
ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type"`
Model string `json:"model"`
Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
}

type RequestInvokeLLM struct {
Provider string `json:"provider"`
ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type"`
Model string `json:"model"`
BaseRequestInvokeModel

ModelParameters map[string]any `json:"model_parameters" validate:"omitempty,dive,is_basic_type"`
PromptMessages []model_entities.PromptMessage `json:"prompt_messages" validate:"omitempty,dive"`
Tools []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"`
Stop []string `json:"stop" validate:"omitempty"`
Stream bool `json:"stream"`
Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
}

type RequestInvokeTextEmbedding struct {
BaseRequestInvokeModel

Texts []string `json:"texts" validate:"required,dive"`
}

type RequestInvokeRerank struct {
BaseRequestInvokeModel

Query string `json:"query" validate:"required"`
Docs []string `json:"docs" validate:"required,dive"`
ScoreThreshold float64 `json:"score_threshold"`
TopN int `json:"top_n"`
}

type RequestInvokeTTS struct {
BaseRequestInvokeModel

ContentText string `json:"content_text" validate:"required"`
Voice string `json:"voice" validate:"required"`
}

type RequestInvokeSpeech2Text struct {
BaseRequestInvokeModel

File string `json:"file" validate:"required"` // base64 encoded voice file
}

0 comments on commit dfe1d2e

Please sign in to comment.