From 73d721b4f2b952dfcd057dbb22f1b938aaca50e8 Mon Sep 17 00:00:00 2001 From: Ignasi Date: Tue, 8 Apr 2025 17:48:25 +0200 Subject: [PATCH 1/4] Parse chunk messages --- commands/pull.go | 4 ++-- desktop/api.go | 6 ++++++ desktop/desktop.go | 27 ++++++++++++++++++++++----- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/commands/pull.go b/commands/pull.go index 8f37e386..f99a535b 100644 --- a/commands/pull.go +++ b/commands/pull.go @@ -35,6 +35,6 @@ func newPullCmd(desktopClient *desktop.Client) *cobra.Command { return c } -func TUIProgress(line string) { - fmt.Print("\r\033[K", line) +func TUIProgress(message string) { + fmt.Print("\r\033[K", message) } diff --git a/desktop/api.go b/desktop/api.go index b7064167..46f8de12 100644 --- a/desktop/api.go +++ b/desktop/api.go @@ -1,5 +1,11 @@ package desktop +// ProgressMessage represents a message sent during model pull operations +type ProgressMessage struct { + Type string `json:"type"` // "progress", "success", or "error" + Message string `json:"message"` // Human-readable message +} + type OpenAIChatMessage struct { Role string `json:"role"` Content string `json:"content"` diff --git a/desktop/desktop.go b/desktop/desktop.go index 148f7ba2..453ef9b3 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -104,14 +104,31 @@ func (c *Client) Pull(model string, progress func(string)) (string, error) { scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { progressLine := scanner.Text() - if progressLine != "" { - progress(progressLine) + if progressLine == "" { + continue + } + + // Parse the progress message + var progressMsg ProgressMessage + if err := json.Unmarshal([]byte(progressLine), &progressMsg); err != nil { + return "", fmt.Errorf("error parsing progress message: %w", err) } - } - fmt.Println() + // Handle different message types + switch progressMsg.Type { + case "progress": + progress(progressMsg.Message) + case "error": + return "", fmt.Errorf("error pulling model: %s", progressMsg.Message) + case "success": + return progressMsg.Message, nil + default: + return "", fmt.Errorf("unknown message type: %s", progressMsg.Type) + } + } - return fmt.Sprintf("Model %s pulled successfully", model), nil + // If we get here, something went wrong + return "", fmt.Errorf("unexpected end of stream while pulling model %s", model) } func (c *Client) List(jsonFormat, openai bool, model string) (string, error) { From 72b22901f8279db2c58d40f77defde44cf9477fe Mon Sep 17 00:00:00 2001 From: Ignasi Date: Wed, 9 Apr 2025 15:37:08 +0200 Subject: [PATCH 2/4] Unescape html + newline when progress has been shown --- commands/pull.go | 16 +++++++++++++++- desktop/desktop.go | 3 ++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/commands/pull.go b/commands/pull.go index f99a535b..2c1663b8 100644 --- a/commands/pull.go +++ b/commands/pull.go @@ -23,11 +23,25 @@ func newPullCmd(desktopClient *desktop.Client) *cobra.Command { }, RunE: func(cmd *cobra.Command, args []string) error { model := args[0] - response, err := desktopClient.Pull(model, TUIProgress) + + // Track if progress was shown + progressShown := false + progressTracker := func(message string) { + progressShown = true + TUIProgress(message) + } + + response, err := desktopClient.Pull(model, progressTracker) if err != nil { err = handleClientError(err, "Failed to pull model") return handleNotRunningError(err) } + + // Add a newline before the success message only if progress was shown + if progressShown { + fmt.Println() + } + cmd.Println(response) return nil }, diff --git a/desktop/desktop.go b/desktop/desktop.go index 453ef9b3..fb61cf59 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -5,6 +5,7 @@ import ( "bytes" "encoding/json" "fmt" + "html" "io" "net/http" "os" @@ -110,7 +111,7 @@ func (c *Client) Pull(model string, progress func(string)) (string, error) { // Parse the progress message var progressMsg ProgressMessage - if err := json.Unmarshal([]byte(progressLine), &progressMsg); err != nil { + if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil { return "", fmt.Errorf("error parsing progress message: %w", err) } From 6a61c6c1815fe694ebdb3b4328c794f10db857e8 Mon Sep 17 00:00:00 2001 From: Ignasi Date: Wed, 9 Apr 2025 16:42:31 +0200 Subject: [PATCH 3/4] Show proper error --- commands/pull.go | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/commands/pull.go b/commands/pull.go index 2c1663b8..e151d921 100644 --- a/commands/pull.go +++ b/commands/pull.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/docker/model-cli/desktop" + "github.com/pkg/errors" "github.com/spf13/cobra" ) @@ -32,16 +33,27 @@ func newPullCmd(desktopClient *desktop.Client) *cobra.Command { } response, err := desktopClient.Pull(model, progressTracker) - if err != nil { - err = handleClientError(err, "Failed to pull model") - return handleNotRunningError(err) - } - // Add a newline before the success message only if progress was shown + // Add a newline before any output (success or error) if progress was shown if progressShown { fmt.Println() } + if err != nil { + err = handleClientError(err, "Failed to pull model") + + // Check if it's a "not running" error + if errors.Is(err, notRunningErr) { + // For "not running" errors, return the error to display the usage + return handleNotRunningError(err) + } + + // For other errors, print the error message and return nil + // to prevent Cobra from displaying the usage + fmt.Fprintln(cmd.ErrOrStderr(), err) + return nil + } + cmd.Println(response) return nil }, From 84db1567a2dd3b4088c93388100435bfa98fc93e Mon Sep 17 00:00:00 2001 From: Dorin Geman Date: Thu, 10 Apr 2025 13:30:05 +0300 Subject: [PATCH 4/4] pull: Return whether progress was shown Signed-off-by: Dorin Geman --- commands/compose.go | 2 +- commands/pull.go | 49 +++++++++++++++------------------------------ commands/run.go | 6 ++---- desktop/desktop.go | 21 ++++++++++--------- 4 files changed, 31 insertions(+), 47 deletions(-) diff --git a/commands/compose.go b/commands/compose.go index 4c5e45cf..efa78032 100644 --- a/commands/compose.go +++ b/commands/compose.go @@ -33,7 +33,7 @@ func newUpCommand(desktopClient *desktop.Client) *cobra.Command { return err } - _, err := desktopClient.Pull(model, func(s string) { + _, _, err := desktopClient.Pull(model, func(s string) { sendInfo(s) }) if err != nil { diff --git a/commands/pull.go b/commands/pull.go index e151d921..9f1da718 100644 --- a/commands/pull.go +++ b/commands/pull.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/docker/model-cli/desktop" - "github.com/pkg/errors" "github.com/spf13/cobra" ) @@ -23,42 +22,26 @@ func newPullCmd(desktopClient *desktop.Client) *cobra.Command { return nil }, RunE: func(cmd *cobra.Command, args []string) error { - model := args[0] - - // Track if progress was shown - progressShown := false - progressTracker := func(message string) { - progressShown = true - TUIProgress(message) - } - - response, err := desktopClient.Pull(model, progressTracker) - - // Add a newline before any output (success or error) if progress was shown - if progressShown { - fmt.Println() - } - - if err != nil { - err = handleClientError(err, "Failed to pull model") + return pullModel(cmd, desktopClient, args[0]) + }, + } + return c +} - // Check if it's a "not running" error - if errors.Is(err, notRunningErr) { - // For "not running" errors, return the error to display the usage - return handleNotRunningError(err) - } +func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error { + response, progressShown, err := desktopClient.Pull(model, TUIProgress) - // For other errors, print the error message and return nil - // to prevent Cobra from displaying the usage - fmt.Fprintln(cmd.ErrOrStderr(), err) - return nil - } + // Add a newline before any output (success or error) if progress was shown. + if progressShown { + cmd.Println() + } - cmd.Println(response) - return nil - }, + if err != nil { + return handleNotRunningError(handleClientError(err, "Failed to pull model")) } - return c + + cmd.Println(response) + return nil } func TUIProgress(message string) { diff --git a/commands/run.go b/commands/run.go index 2e0d03d4..ebe46099 100644 --- a/commands/run.go +++ b/commands/run.go @@ -37,11 +37,9 @@ func newRunCmd(desktopClient *desktop.Client) *cobra.Command { return handleNotRunningError(handleClientError(err, "Failed to list models")) } cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.") - response, err := desktopClient.Pull(model, TUIProgress) - if err != nil { - return handleNotRunningError(handleClientError(err, "Failed to pull model")) + if err := pullModel(cmd, desktopClient, model); err != nil { + return err } - cmd.Println(response) } if prompt != "" { diff --git a/desktop/desktop.go b/desktop/desktop.go index fb61cf59..1477e9e1 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -80,10 +80,10 @@ func (c *Client) Status() Status { } } -func (c *Client) Pull(model string, progress func(string)) (string, error) { +func (c *Client) Pull(model string, progress func(string)) (string, bool, error) { jsonData, err := json.Marshal(models.ModelCreateRequest{From: model}) if err != nil { - return "", fmt.Errorf("error marshaling request: %w", err) + return "", false, fmt.Errorf("error marshaling request: %w", err) } createPath := inference.ModelsPrefix + "/create" @@ -93,15 +93,17 @@ func (c *Client) Pull(model string, progress func(string)) (string, error) { bytes.NewReader(jsonData), ) if err != nil { - return "", c.handleQueryError(err, createPath) + return "", false, c.handleQueryError(err, createPath) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body)) + return "", false, fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body)) } + progressShown := false + scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { progressLine := scanner.Text() @@ -112,24 +114,25 @@ func (c *Client) Pull(model string, progress func(string)) (string, error) { // Parse the progress message var progressMsg ProgressMessage if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil { - return "", fmt.Errorf("error parsing progress message: %w", err) + return "", progressShown, fmt.Errorf("error parsing progress message: %w", err) } // Handle different message types switch progressMsg.Type { case "progress": progress(progressMsg.Message) + progressShown = true case "error": - return "", fmt.Errorf("error pulling model: %s", progressMsg.Message) + return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message) case "success": - return progressMsg.Message, nil + return progressMsg.Message, progressShown, nil default: - return "", fmt.Errorf("unknown message type: %s", progressMsg.Type) + return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type) } } // If we get here, something went wrong - return "", fmt.Errorf("unexpected end of stream while pulling model %s", model) + return "", progressShown, fmt.Errorf("unexpected end of stream while pulling model %s", model) } func (c *Client) List(jsonFormat, openai bool, model string) (string, error) {