Skip to content

Commit

Permalink
Add cohere embedding for ai-cache (#1572)
Browse files Browse the repository at this point in the history
  • Loading branch information
ayanami-desu authored Dec 27, 2024
1 parent 6dc4d43 commit 2d74c48
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 41 deletions.
158 changes: 158 additions & 0 deletions plugins/wasm-go/extensions/ai-cache/embedding/cohere.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package embedding

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"

"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)

const (
COHERE_DOMAIN = "api.cohere.com"
COHERE_PORT = 443
COHERE_DEFAULT_MODEL_NAME = "embed-english-v2.0"
COHERE_ENDPOINT = "/v2/embed"
)

type cohereProviderInitializer struct {
}

var cohereConfig cohereProviderConfig

type cohereProviderConfig struct {
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
}

func (c *cohereProviderInitializer) InitConfig(json gjson.Result) {
cohereConfig.apiKey = json.Get("apiKey").String()
}
func (c *cohereProviderInitializer) ValidateConfig() error {
if cohereConfig.apiKey == "" {
return errors.New("[Cohere] apiKey is required")
}
return nil
}

func (t *cohereProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
if c.servicePort == 0 {
c.servicePort = COHERE_PORT
}
if c.serviceHost == "" {
c.serviceHost = COHERE_DOMAIN
}
return &CohereProvider{
config: c,
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: c.serviceName,
Host: c.serviceHost,
Port: int64(c.servicePort),
}),
}, nil
}

type cohereResponse struct {
Embeddings cohereEmbeddings `json:"embeddings"`
}

type cohereEmbeddings struct {
FloatTypeEebedding [][]float64 `json:"float"`
}

type cohereEmbeddingRequest struct {
Texts []string `json:"texts"`
Model string `json:"model"`
InputType string `json:"input_type"`
EmbeddingTypes []string `json:"embedding_types"`
}

type CohereProvider struct {
config ProviderConfig
client wrapper.HttpClient
}

func (t *CohereProvider) GetProviderType() string {
return PROVIDER_TYPE_COHERE
}
func (t *CohereProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
model := t.config.model

if model == "" {
model = COHERE_DEFAULT_MODEL_NAME
}
data := cohereEmbeddingRequest{
Texts: texts,
Model: model,
InputType: "search_document",
EmbeddingTypes: []string{"float"},
}

requestBody, err := json.Marshal(data)
if err != nil {
log.Errorf("failed to marshal request data: %v", err)
return "", nil, nil, err
}

headers := [][2]string{
{"Authorization", fmt.Sprintf("BEARER %s", cohereConfig.apiKey)},
{"Content-Type", "application/json"},
}

return COHERE_ENDPOINT, headers, requestBody, nil
}

func (t *CohereProvider) parseTextEmbedding(responseBody []byte) (*cohereResponse, error) {
var resp cohereResponse
err := json.Unmarshal(responseBody, &resp)
if err != nil {
return nil, err
}
return &resp, nil
}

func (t *CohereProvider) GetEmbedding(
queryString string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(emb []float64, err error)) error {
embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}, log)
if err != nil {
log.Errorf("failed to construct parameters: %v", err)
return err
}

var resp *cohereResponse
err = t.client.Post(embUrl, embHeaders, embRequestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {

if statusCode != http.StatusOK {
err = errors.New("failed to get embedding due to status code: " + strconv.Itoa(statusCode))
callback(nil, err)
return
}

log.Debugf("get embedding response: %d, %s", statusCode, responseBody)

resp, err = t.parseTextEmbedding(responseBody)
if err != nil {
err = fmt.Errorf("failed to parse response: %v", err)
callback(nil, err)
return
}

if len(resp.Embeddings.FloatTypeEebedding) == 0 {
err = errors.New("no embedding found in response")
callback(nil, err)
return
}

callback(resp.Embeddings.FloatTypeEebedding[0], nil)

}, t.config.timeout)
return err
}
20 changes: 16 additions & 4 deletions plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strconv"

"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)

const (
Expand All @@ -17,11 +18,22 @@ const (
DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
)

var dashScopeConfig dashScopeProviderConfig

type dashScopeProviderInitializer struct {
}
type dashScopeProviderConfig struct {
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
}

func (c *dashScopeProviderInitializer) InitConfig(json gjson.Result) {
dashScopeConfig.apiKey = json.Get("apiKey").String()
}

func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiKey == "" {
func (c *dashScopeProviderInitializer) ValidateConfig() error {
if dashScopeConfig.apiKey == "" {
return errors.New("[DashScope] apiKey is required")
}
return nil
Expand Down Expand Up @@ -114,14 +126,14 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin
return "", nil, nil, err
}

if d.config.apiKey == "" {
if dashScopeConfig.apiKey == "" {
err := errors.New("dashScopeKey is empty")
log.Errorf("failed to construct headers: %v", err)
return "", nil, nil, err
}

headers := [][2]string{
{"Authorization", "Bearer " + d.config.apiKey},
{"Authorization", "Bearer " + dashScopeConfig.apiKey},
{"Content-Type", "application/json"},
}

Expand Down
24 changes: 19 additions & 5 deletions plugins/wasm-go/extensions/ai-cache/embedding/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"net/http"

"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)

const (
Expand All @@ -18,9 +20,21 @@ const (
type openAIProviderInitializer struct {
}

func (t *openAIProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiKey == "" {
return errors.New("[OpenAI] embedding service ApiKey is required")
var openAIConfig openAIProviderConfig

type openAIProviderConfig struct {
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
}

func (c *openAIProviderInitializer) InitConfig(json gjson.Result) {
openAIConfig.apiKey = json.Get("apiKey").String()
}

func (c *openAIProviderInitializer) ValidateConfig() error {
if openAIConfig.apiKey == "" {
return errors.New("[openAI] apiKey is required")
}
return nil
}
Expand Down Expand Up @@ -97,7 +111,7 @@ func (t *OpenAIProvider) constructParameters(text string, log wrapper.Log) (stri
}

headers := [][2]string{
{"Authorization", fmt.Sprintf("Bearer %s", t.config.apiKey)},
{"Authorization", fmt.Sprintf("Bearer %s", openAIConfig.apiKey)},
{"Content-Type", "application/json"},
}

Expand Down
33 changes: 13 additions & 20 deletions plugins/wasm-go/extensions/ai-cache/embedding/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,21 @@ import (
const (
PROVIDER_TYPE_DASHSCOPE = "dashscope"
PROVIDER_TYPE_TEXTIN = "textin"
PROVIDER_TYPE_COHERE = "cohere"
PROVIDER_TYPE_OPENAI = "openai"
)

type providerInitializer interface {
ValidateConfig(ProviderConfig) error
InitConfig(json gjson.Result)
ValidateConfig() error
CreateProvider(ProviderConfig) (Provider, error)
}

var (
providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
PROVIDER_TYPE_TEXTIN: &textInProviderInitializer{},
PROVIDER_TYPE_COHERE: &cohereProviderInitializer{},
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
}
)
Expand All @@ -39,35 +42,26 @@ type ProviderConfig struct {
// @Title zh-CN 文本特征提取服务端口
// @Description zh-CN 文本特征提取服务端口
servicePort int64
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
//@Title zh-CN TextIn x-ti-app-id
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
textinAppId string
//@Title zh-CN TextIn x-ti-secret-code
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
textinSecretCode string
//@Title zh-CN TextIn request matryoshka_dim
// @Description zh-CN 仅适用于 TextIn 服务, 指定返回的向量维度。参考 https://www.textin.com/document/acge_text_embedding
textinMatryoshkaDim int
// @Title zh-CN 文本特征提取服务超时时间
// @Description zh-CN 文本特征提取服务超时时间
timeout uint32
// @Title zh-CN 文本特征提取服务使用的模型
// @Description zh-CN 用于文本特征提取的模型名称, 在 DashScope 中默认为 "text-embedding-v1"
model string

initializer providerInitializer
}

func (c *ProviderConfig) FromJson(json gjson.Result) {
c.typ = json.Get("type").String()
i, has := providerInitializers[c.typ]
if has {
i.InitConfig(json)
c.initializer = i
}
c.serviceName = json.Get("serviceName").String()
c.serviceHost = json.Get("serviceHost").String()
c.servicePort = json.Get("servicePort").Int()
c.apiKey = json.Get("apiKey").String()
c.textinAppId = json.Get("textinAppId").String()
c.textinSecretCode = json.Get("textinSecretCode").String()
c.textinMatryoshkaDim = int(json.Get("textinMatryoshkaDim").Int())
c.timeout = uint32(json.Get("timeout").Int())
c.model = json.Get("model").String()
if c.timeout == 0 {
Expand All @@ -82,11 +76,10 @@ func (c *ProviderConfig) Validate() error {
if c.typ == "" {
return errors.New("embedding service type is required")
}
initializer, has := providerInitializers[c.typ]
if !has {
if c.initializer == nil {
return errors.New("unknown embedding service provider type: " + c.typ)
}
if err := initializer.ValidateConfig(*c); err != nil {
if err := c.initializer.ValidateConfig(); err != nil {
return err
}
return nil
Expand Down
Loading

0 comments on commit 2d74c48

Please sign in to comment.