diff --git a/commands/completion/functions.go b/commands/completion/functions.go new file mode 100644 index 00000000..a5f66909 --- /dev/null +++ b/commands/completion/functions.go @@ -0,0 +1,34 @@ +package completion + +import ( + "encoding/json" + + "github.com/docker/model-cli/desktop" + "github.com/spf13/cobra" +) + +func NoComplete(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { + return nil, cobra.ShellCompDirectiveNoFileComp +} + +// ModelNames offers completion for models present within the local store. +func ModelNames(desktopClient *desktop.Client, limit int) cobra.CompletionFunc { + return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + if limit > 0 && len(args) >= limit { + return nil, cobra.ShellCompDirectiveNoFileComp + } + modelsString, err := desktopClient.List(true, false, false, "") + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + var models []desktop.Model + if err := json.Unmarshal([]byte(modelsString), &models); err != nil { + return nil, cobra.ShellCompDirectiveError + } + var names []string + for _, m := range models { + names = append(names, m.Tags...) + } + return names, cobra.ShellCompDirectiveNoFileComp + } +} diff --git a/commands/inspect.go b/commands/inspect.go index 8a13dc3a..d67342a7 100644 --- a/commands/inspect.go +++ b/commands/inspect.go @@ -3,6 +3,7 @@ package commands import ( "fmt" + "github.com/docker/model-cli/commands/completion" "github.com/docker/model-cli/desktop" "github.com/spf13/cobra" ) @@ -32,6 +33,7 @@ func newInspectCmd(desktopClient *desktop.Client) *cobra.Command { cmd.Println(model) return nil }, + ValidArgsFunction: completion.ModelNames(desktopClient, 1), } c.Flags().BoolVar(&openai, "openai", false, "List model in an OpenAI format") return c diff --git a/commands/list.go b/commands/list.go index fe5f4d50..81dd74ab 100644 --- a/commands/list.go +++ b/commands/list.go @@ -1,6 +1,7 @@ package commands import ( + "github.com/docker/model-cli/commands/completion" "github.com/docker/model-cli/desktop" "github.com/spf13/cobra" ) @@ -20,6 +21,7 @@ func newListCmd(desktopClient *desktop.Client) *cobra.Command { cmd.Print(models) return nil }, + ValidArgsFunction: completion.NoComplete, } c.Flags().BoolVar(&jsonFormat, "json", false, "List models in a JSON format") c.Flags().BoolVar(&openai, "openai", false, "List models in an OpenAI format") diff --git a/commands/pull.go b/commands/pull.go index 9f1da718..c594d60a 100644 --- a/commands/pull.go +++ b/commands/pull.go @@ -3,6 +3,7 @@ package commands import ( "fmt" + "github.com/docker/model-cli/commands/completion" "github.com/docker/model-cli/desktop" "github.com/spf13/cobra" ) @@ -24,6 +25,7 @@ func newPullCmd(desktopClient *desktop.Client) *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { return pullModel(cmd, desktopClient, args[0]) }, + ValidArgsFunction: completion.NoComplete, } return c } diff --git a/commands/rm.go b/commands/rm.go index ff9630e3..ac4b433b 100644 --- a/commands/rm.go +++ b/commands/rm.go @@ -3,6 +3,7 @@ package commands import ( "fmt" + "github.com/docker/model-cli/commands/completion" "github.com/docker/model-cli/desktop" "github.com/spf13/cobra" ) @@ -32,6 +33,7 @@ func newRemoveCmd(desktopClient *desktop.Client) *cobra.Command { } return nil }, + ValidArgsFunction: completion.ModelNames(desktopClient, -1), } return c } diff --git a/commands/run.go b/commands/run.go index 766daf7e..d1488344 100644 --- a/commands/run.go +++ b/commands/run.go @@ -7,6 +7,7 @@ import ( "os" "strings" + "github.com/docker/model-cli/commands/completion" "github.com/docker/model-cli/desktop" "github.com/spf13/cobra" ) @@ -81,6 +82,7 @@ func newRunCmd(desktopClient *desktop.Client) *cobra.Command { } return nil }, + ValidArgsFunction: completion.ModelNames(desktopClient, 1), } c.Args = func(cmd *cobra.Command, args []string) error { if len(args) < 1 { @@ -90,7 +92,7 @@ func newRunCmd(desktopClient *desktop.Client) *cobra.Command { "See 'docker model run --help' for more information", ) } - if len(args) > 2 { + if len(args) >= 2 { return fmt.Errorf("too many arguments, expected " + cmdArgs) } return nil diff --git a/commands/status.go b/commands/status.go index 17f7f7e0..ec2d0a78 100644 --- a/commands/status.go +++ b/commands/status.go @@ -4,6 +4,7 @@ import ( "os" "github.com/docker/cli/cli-plugins/hooks" + "github.com/docker/model-cli/commands/completion" "github.com/docker/model-cli/desktop" "github.com/spf13/cobra" ) @@ -27,6 +28,7 @@ func newStatusCmd(desktopClient *desktop.Client) *cobra.Command { return nil }, + ValidArgsFunction: completion.NoComplete, } return c } diff --git a/commands/version.go b/commands/version.go index 0dc42526..3448dde7 100644 --- a/commands/version.go +++ b/commands/version.go @@ -1,6 +1,9 @@ package commands -import "github.com/spf13/cobra" +import ( + "github.com/docker/model-cli/commands/completion" + "github.com/spf13/cobra" +) var Version = "dev" @@ -11,6 +14,7 @@ func newVersionCmd() *cobra.Command { Run: func(cmd *cobra.Command, args []string) { cmd.Printf("Docker Model Runner version %s\n", Version) }, + ValidArgsFunction: completion.NoComplete, } return c }