Skip to content

Commit

Permalink
Add method for determining prediction progress from logs (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt authored Aug 11, 2023
1 parent 73d97bb commit 4ca60d8
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
38 changes: 38 additions & 0 deletions prediction.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package replicate
import (
"context"
"fmt"
"regexp"
"strings"
)

type Source string
Expand Down Expand Up @@ -32,6 +34,42 @@ type Prediction struct {
CompletedAt *string `json:"completed_at,omitempty"`
}

type PredictionProgress struct {
Percentage float64
Current int
Total int
}

func (p Prediction) Progress() *PredictionProgress {
if p.Logs == nil || *p.Logs == "" {
return nil
}

pattern := `^\s*(?P<percentage>\d+)%\s*\|.+?\|\s*(?P<current>\d+)\/(?P<total>\d+)`
re := regexp.MustCompile(pattern)

lines := strings.Split(*p.Logs, "\n")
for i := len(lines) - 1; i >= 0; i-- {
line := strings.TrimSpace(lines[i])
if re.MatchString(line) {
matches := re.FindStringSubmatch(lines[i])
if len(matches) == 4 {
var percentage, current, total int
fmt.Sscanf(matches[1], "%d", &percentage)
fmt.Sscanf(matches[2], "%d", &current)
fmt.Sscanf(matches[3], "%d", &total)
return &PredictionProgress{
Percentage: float64(percentage) / float64(100),
Current: current,
Total: total,
}
}
}
}

return nil
}

type PredictionInput map[string]interface{}
type PredictionOutput interface{}

Expand Down
74 changes: 74 additions & 0 deletions replicate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,80 @@ func TestCreatePrediction(t *testing.T) {
assert.Equal(t, "https://streaming.api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", prediction.URLs["stream"])
}

func TestPredictionProgress(t *testing.T) {
prediction := replicate.Prediction{
ID: "ufawqhfynnddngldkgtslldrkq",
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
Status: "starting",
Input: map[string]interface{}{"text": "Alice"},
Output: nil,
Error: nil,
Logs: nil,
Metrics: nil,
CreatedAt: "2022-04-26T22:13:06.224088Z",
URLs: map[string]string{
"get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
"cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel",
"stream": "https://streaming.api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
},
}

lines := []string{
"Using seed: 12345",
"0%| | 0/5 [00:00<?, ?it/s]",
"20%|██ | 1/5 [00:00<00:01, 21.38it/s]",
"40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]",
"60%|████▍ | 3/5 [00:01<00:01, 22.46it/s]",
"80%|████████ | 4/5 [00:01<00:00, 22.86it/s]",
"100%|██████████| 5/5 [00:02<00:00, 22.26it/s]",
}
logs := ""

for i, line := range lines {
logs = logs + "\n" + line
prediction.Logs = &logs

progress := prediction.Progress()

switch i {
case 0:
prediction.Status = replicate.Processing
assert.Nil(t, progress)
case 1:
assert.NotNil(t, progress)
assert.Equal(t, 0, progress.Current)
assert.Equal(t, 5, progress.Total)
assert.Equal(t, 0.0, progress.Percentage)
case 2:
assert.NotNil(t, progress)
assert.Equal(t, 1, progress.Current)
assert.Equal(t, 5, progress.Total)
assert.Equal(t, 0.2, progress.Percentage)
case 3:
assert.NotNil(t, progress)
assert.Equal(t, 2, progress.Current)
assert.Equal(t, 5, progress.Total)
assert.Equal(t, 0.4, progress.Percentage)
case 4:
assert.NotNil(t, progress)
assert.Equal(t, 3, progress.Current)
assert.Equal(t, 5, progress.Total)
assert.Equal(t, 0.6, progress.Percentage)
case 5:
assert.NotNil(t, progress)
assert.Equal(t, 4, progress.Current)
assert.Equal(t, 5, progress.Total)
assert.Equal(t, 0.8, progress.Percentage)
case 6:
assert.NotNil(t, progress)
prediction.Status = replicate.Succeeded
assert.Equal(t, 5, progress.Current)
assert.Equal(t, 5, progress.Total)
assert.Equal(t, 1.0, progress.Percentage)
}
}
}

func TestListPredictions(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/predictions", r.URL.Path)
Expand Down

0 comments on commit 4ca60d8

Please sign in to comment.