Skip to content

Commit

Permalink
Custom HTTP request to list openrouter models, so pricing information…
Browse files Browse the repository at this point in the history
… can be extracted from the API response

Part of #210
  • Loading branch information
ruiAzevedo19 authored and Munsio committed Jun 27, 2024
1 parent ffbc218 commit 20a75d3
Showing 1 changed file with 64 additions and 8 deletions.
72 changes: 64 additions & 8 deletions provider/openrouter/openrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package openrouter

import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"time"

Expand Down Expand Up @@ -51,18 +54,72 @@ func (p *Provider) ID() (id string) {
return "openrouter"
}

// ModelsList holds a list of models.
type ModelsList struct {
Models []Model `json:"data"`
}

// Model holds a model.
type Model struct {
// ID holds the model id.
ID string `json:"id"`
// Pricing holds the pricing information of a model.
Pricing Pricing `json:"pricing"`
}

// Pricing holds the pricing information of a model.
type Pricing struct {
// Prompt holds the price for a prompt in dollars per token.
Prompt string `json:"prompt"`
// Completion holds the price for a completion in dollars per token.
Completion string `json:"completion"`
// Request holds the price for a request in dollars per request.
Request string `json:"request"`
// Image holds the price for an image in dollars per token.
Image string `json:"image"`
}

// Models returns which models are available to be queried via this provider.
func (p *Provider) Models() (models []model.Model, err error) {
client := p.client()
responseModels, err := providerModels(p.baseURL + "/models")
if err != nil {
return nil, err
}

var responseModels openai.ModelsList
models = make([]model.Model, len(responseModels.Models))
for i, model := range responseModels.Models {
models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+model.ID)
}

return models, nil
}

// providerModels returns the provider's list of models given the URL to fetch the models.
func providerModels(url string) (models ModelsList, err error) {
request, err := http.NewRequest("GET", url, nil)
if err != nil {
return ModelsList{}, pkgerrors.WithStack(err)
}
request.Header.Set("Accept", "application/json")

client := &http.Client{}
var responseBody []byte
if err := retry.Do( // Query available models with a retry logic cause "openrouter.ai" has failed us in the past.
func() error {
ms, err := client.ListModels(context.Background())
response, err := client.Do(request)
if err != nil {
return pkgerrors.WithStack(err)
}
defer response.Body.Close()

if response.StatusCode != http.StatusOK {
return pkgerrors.Errorf("received status code %d when querying provider models", response.StatusCode)
}

responseBody, err = io.ReadAll(response.Body)
if err != nil {
return pkgerrors.WithStack(err)
}
responseModels = ms

return nil
},
Expand All @@ -71,12 +128,11 @@ func (p *Provider) Models() (models []model.Model, err error) {
retry.DelayType(retry.BackOffDelay),
retry.LastErrorOnly(true),
); err != nil {
return nil, err
return ModelsList{}, err
}

models = make([]model.Model, len(responseModels.Models))
for i, model := range responseModels.Models {
models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+model.ID)
if err = json.Unmarshal(responseBody, &models); err != nil {
return ModelsList{}, pkgerrors.WithStack(err)
}

return models, nil
Expand Down

0 comments on commit 20a75d3

Please sign in to comment.