Skip to content

Commit

Permalink
add perplexity prompter
Browse files Browse the repository at this point in the history
  • Loading branch information
Southclaws committed Dec 23, 2024
1 parent eb54157 commit c77df16
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 0 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ require (
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/r3labs/sse/v2 v2.10.0 // indirect
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
Expand Down Expand Up @@ -164,6 +165,7 @@ require (
golang.org/x/tools v0.25.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241113202542-65e8d215514f // indirect
google.golang.org/grpc v1.67.1 // indirect
gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
modernc.org/libc v1.61.0 // indirect
modernc.org/mathutil v1.6.0 // indirect
Expand Down
5 changes: 5 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,8 @@ github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:Om
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/r3labs/sse/v2 v2.10.0 h1:hFEkLLFY4LDifoHdiCN/LlGBAdVJYsANaLqNYa1l/v0=
github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEmkNJ7I=
github.com/rabbitmq/amqp091-go v1.10.0 h1:STpn5XsHlHGcecLmMFCtg7mqq0RnD+zFr4uzukfVhBw=
github.com/rabbitmq/amqp091-go v1.10.0/go.mod h1:Hy4jKW5kQART1u+JkDTF9YYOQUHXqMuhrgxOEeS7G4o=
github.com/redis/rueidis v1.0.49 h1:uhjMcQ663R8st3saoo85VV9Ce37zfvRXiveZcBrS3YQ=
Expand Down Expand Up @@ -595,6 +597,7 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
Expand Down Expand Up @@ -742,6 +745,8 @@ google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFyt
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io=
google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y=
gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UDAkHu8BrjI=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
3 changes: 3 additions & 0 deletions internal/infrastructure/ai/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ func New(cfg config.Config) (Prompter, error) {
case "openai":
return newOpenAI(cfg)

case "perplexity":
return newPerplexity(cfg)

case "mock":
return newMock()

Expand Down
181 changes: 181 additions & 0 deletions internal/infrastructure/ai/perplexity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package ai

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"

"github.com/Southclaws/fault"
"github.com/Southclaws/fault/fctx"
"github.com/Southclaws/storyden/internal/config"
"github.com/r3labs/sse/v2"
)

const (
DefaultEndpoint = "https://api.perplexity.ai/chat/completions"
DefautTimeout = 10 * time.Second
)

const (
Llama_3_1SonarSmall_128kChat = "llama-3.1-sonar-small-128k-chat"
Llama_3_1SonarLarge_128kChat = "llama-3.1-sonar-large-128k-chat"
Llama_3_1SonarSmall_128kOnline = "llama-3.1-sonar-small-128k-online"
Llama_3_1SonarLarge_128kOnline = "llama-3.1-sonar-large-128k-online"
Llama_3_1_8bInstruct = "llama-3.1-8b-instruct"
Llama_3_1_70bInstruct = "llama-3.1-70b-instruct"
)

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}

type CompletionRequest struct {
Messages []Message `json:"messages"`
Model string `json:"model"`
}

type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

type Choice struct {
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
Message Message `json:"message"`
Delta Message `json:"delta"`
}

type CompletionResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Created int `json:"created"`
Usage Usage `json:"usage"`
Citations []string `json:"citations"`
Object string `json:"object"`
Choices []Choice `json:"choices"`
}

type Perplexity struct {
endpoint string
apiKey string
model string
httpClient *http.Client
httpTimeout time.Duration
}

func newPerplexity(cfg config.Config) (*Perplexity, error) {
s := &Perplexity{
apiKey: cfg.OpenAIKey,
endpoint: DefaultEndpoint,
model: Llama_3_1SonarSmall_128kChat,
httpClient: &http.Client{},
httpTimeout: DefautTimeout,
}
return s, nil
}

func (s *Perplexity) Prompt(ctx context.Context, input string) (*Result, error) {
r := &CompletionResponse{}

reqBody := CompletionRequest{
Messages: []Message{{Role: "user", Content: input}},
Model: s.model,
}

requestBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}

ctx, cancel := context.WithDeadline(ctx, time.Now().Add(s.httpTimeout))
defer cancel()

req, err := http.NewRequestWithContext(ctx, "POST", s.endpoint, bytes.NewBuffer(requestBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

req.Header.Set("Authorization", "Bearer "+s.apiKey)
req.Header.Set("Content-Type", "application/json")

resp, err := s.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusUnauthorized {
return nil, fmt.Errorf("unauthorized: check your API key")
}
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}

err = json.Unmarshal(body, r)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal response body: %w - body response=%s", err, string(body))
}

return &Result{
Answer: r.Choices[0].Message.Content,
}, nil
}

func (s *Perplexity) PromptStream(ctx context.Context, input string) (chan string, chan error) {
outch := make(chan string)
errch := make(chan error)

client := sse.NewClient(DefaultEndpoint)

eventch := make(chan *sse.Event)

go func() {
err := client.SubscribeChan("completions", eventch)
if err != nil {
errch <- fault.Wrap(err, fctx.With(ctx))
return
}

client.Unsubscribe(eventch)

for e := range eventch {
var cr CompletionResponse

if err := json.Unmarshal(e.Data, &cr); err != nil {
errch <- fault.Wrap(err, fctx.With(ctx))
return
}

fmt.Println(cr.Citations)

outch <- cr.Choices[0].Delta.Content

if cr.Choices[0].FinishReason == "stop" {
client.Unsubscribe(eventch)
break
}
}

close(outch)
close(errch)
close(eventch)
}()

return outch, errch
}

func (o *Perplexity) EmbeddingFunc() func(ctx context.Context, text string) ([]float32, error) {
panic("not implemented")
}

0 comments on commit c77df16

Please sign in to comment.