From 4ca60d8db8b84ca78b9111ca0aa67034305e10e7 Mon Sep 17 00:00:00 2001 From: Mattt Date: Fri, 11 Aug 2023 01:20:02 -0700 Subject: [PATCH] Add method for determining prediction progress from logs (#13) --- prediction.go | 38 ++++++++++++++++++++++++ replicate_test.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/prediction.go b/prediction.go index 064b09e..65cd082 100644 --- a/prediction.go +++ b/prediction.go @@ -3,6 +3,8 @@ package replicate import ( "context" "fmt" + "regexp" + "strings" ) type Source string @@ -32,6 +34,42 @@ type Prediction struct { CompletedAt *string `json:"completed_at,omitempty"` } +type PredictionProgress struct { + Percentage float64 + Current int + Total int +} + +func (p Prediction) Progress() *PredictionProgress { + if p.Logs == nil || *p.Logs == "" { + return nil + } + + pattern := `^\s*(?P\d+)%\s*\|.+?\|\s*(?P\d+)\/(?P\d+)` + re := regexp.MustCompile(pattern) + + lines := strings.Split(*p.Logs, "\n") + for i := len(lines) - 1; i >= 0; i-- { + line := strings.TrimSpace(lines[i]) + if re.MatchString(line) { + matches := re.FindStringSubmatch(lines[i]) + if len(matches) == 4 { + var percentage, current, total int + fmt.Sscanf(matches[1], "%d", &percentage) + fmt.Sscanf(matches[2], "%d", ¤t) + fmt.Sscanf(matches[3], "%d", &total) + return &PredictionProgress{ + Percentage: float64(percentage) / float64(100), + Current: current, + Total: total, + } + } + } + } + + return nil +} + type PredictionInput map[string]interface{} type PredictionOutput interface{} diff --git a/replicate_test.go b/replicate_test.go index 9b635ea..33cad85 100644 --- a/replicate_test.go +++ b/replicate_test.go @@ -315,6 +315,80 @@ func TestCreatePrediction(t *testing.T) { assert.Equal(t, "https://streaming.api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", prediction.URLs["stream"]) } +func TestPredictionProgress(t *testing.T) { + prediction := replicate.Prediction{ + ID: "ufawqhfynnddngldkgtslldrkq", + Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + Status: "starting", + Input: map[string]interface{}{"text": "Alice"}, + Output: nil, + Error: nil, + Logs: nil, + Metrics: nil, + CreatedAt: "2022-04-26T22:13:06.224088Z", + URLs: map[string]string{ + "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", + "cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + "stream": "https://streaming.api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", + }, + } + + lines := []string{ + "Using seed: 12345", + "0%| | 0/5 [00:00