Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RunWithOptions method that supports returning file output as io.ReadCloser #77

Merged
merged 4 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,79 @@ func TestAutomaticallyRetryPostRequests(t *testing.T) {
assert.ErrorContains(t, err, http.StatusText(http.StatusInternalServerError))
}

func TestRunWithOptions(t *testing.T) {
var mockServer *httptest.Server
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/predictions":
assert.Equal(t, http.MethodPost, r.Method)
prediction := replicate.Prediction{
ID: "gtsllfynndufawqhdngldkdrkq",
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
Status: replicate.Starting,
}
json.NewEncoder(w).Encode(prediction)
case "/predictions/gtsllfynndufawqhdngldkdrkq":
assert.Equal(t, http.MethodGet, r.Method)
prediction := replicate.Prediction{
ID: "gtsllfynndufawqhdngldkdrkq",
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
Status: replicate.Succeeded,
Output: map[string]interface{}{
"image": mockServer.URL + "/output.png",
"text": "Hello, world!",
},
}
json.NewEncoder(w).Encode(prediction)
case "/output.png":
w.Header().Set("Content-Type", "image/png")
w.Write([]byte("mock image data"))
default:
t.Fatalf("Unexpected request to %s", r.URL.Path)
}
}))
defer mockServer.Close()

client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

ctx := context.Background()
input := replicate.PredictionInput{"prompt": "A test image"}

// Test with WithFileOutput option
output, err := client.RunWithOptions(ctx, "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil, replicate.WithFileOutput())

require.NoError(t, err)
assert.NotNil(t, output)

// Check if the image output is transformed to io.ReadCloser
imageOutput, ok := output.(map[string]interface{})["image"].(io.ReadCloser)
require.True(t, ok, "Expected image output to be io.ReadCloser")

imageData, err := io.ReadAll(imageOutput)
require.NoError(t, err)
assert.Equal(t, []byte("mock image data"), imageData)

// Check if the text output remains unchanged
textOutput, ok := output.(map[string]interface{})["text"].(string)
require.True(t, ok, "Expected text output to be string")
assert.Equal(t, "Hello, world!", textOutput)

// Test without WithFileOutput option
outputWithoutFileOption, err := client.RunWithOptions(ctx, "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil)

require.NoError(t, err)
assert.NotNil(t, outputWithoutFileOption)

// Check if the image output remains a URL string
imageOutputURL, ok := outputWithoutFileOption.(map[string]interface{})["image"].(string)
require.True(t, ok, "Expected image output to be string")
assert.Equal(t, mockServer.URL+"/output.png", imageOutputURL)
}

func TestStream(t *testing.T) {
tokens := []string{"Alpha", "Bravo", "Charlie", "Delta", "Echo"}

Expand Down
127 changes: 126 additions & 1 deletion run.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,45 @@
package replicate

import (
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)

func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (PredictionOutput, error) {
// RunOption is a function that modifies RunOptions
type RunOption func(*runOptions)

// runOptions represents options for running a model
type runOptions struct {
useFileOutput bool
}

// FileOutput is a custom type that implements io.ReadCloser and includes a URL field
type FileOutput struct {
io.ReadCloser
URL string
}

// WithFileOutput sets the UseFileOutput option to true
func WithFileOutput() RunOption {
return func(o *runOptions) {
o.useFileOutput = true
}
}

// RunWithOptions runs a model with specified options
func (r *Client) RunWithOptions(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook, opts ...RunOption) (PredictionOutput, error) {
options := runOptions{}
for _, opt := range opts {
opt(&options)
}

id, err := ParseIdentifier(identifier)
if err != nil {
return nil, err
Expand All @@ -29,5 +63,96 @@ func (r *Client) Run(ctx context.Context, identifier string, input PredictionInp
return nil, &ModelError{Prediction: prediction}
}

if options.useFileOutput {
return transformOutput(ctx, prediction.Output, r)
}

return prediction.Output, nil
}

// Run runs a model and returns the output
func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (PredictionOutput, error) {
return r.RunWithOptions(ctx, identifier, input, webhook)
}

func transformOutput(ctx context.Context, value interface{}, client *Client) (interface{}, error) {
var err error
switch v := value.(type) {
case map[string]interface{}:
for k, val := range v {
v[k], err = transformOutput(ctx, val, client)
if err != nil {
return nil, err
}
}
return v, nil
case []interface{}:
for i, val := range v {
v[i], err = transformOutput(ctx, val, client)
if err != nil {
return nil, err
}
}
return v, nil
case string:
if strings.HasPrefix(v, "data:") {
return readDataURI(v)
}
if strings.HasPrefix(v, "https:") || strings.HasPrefix(v, "http:") {
return readHTTP(ctx, v, client)
}
return v, nil
}
return value, nil
}

func readDataURI(uri string) (*FileOutput, error) {
u, err := url.Parse(uri)
if err != nil {
return nil, err
}
if u.Scheme != "data" {
return nil, errors.New("not a data URI")
}
mediatype, data, found := strings.Cut(u.Opaque, ",")
if !found {
return nil, errors.New("invalid data URI format")
}
var reader io.Reader
if strings.HasSuffix(mediatype, ";base64") {
decoded, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return nil, err
}
reader = bytes.NewReader(decoded)
} else {
reader = strings.NewReader(data)
}
return &FileOutput{
ReadCloser: io.NopCloser(reader),
URL: uri,
}, nil
}

func readHTTP(ctx context.Context, url string, client *Client) (*FileOutput, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
resp, err := client.c.Do(req)
if err != nil {
return nil, err
}
if resp == nil || resp.Body == nil {
return nil, errors.New("HTTP request failed to get a response")
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, fmt.Errorf("HTTP request failed with status code %d", resp.StatusCode)
}

return &FileOutput{
ReadCloser: resp.Body,
URL: url,
}, nil
}
Loading