diff --git a/cli/Makefile b/cli/Makefile new file mode 100644 index 000000000..4ed21a6db --- /dev/null +++ b/cli/Makefile @@ -0,0 +1,30 @@ +VERSION ?= dev +GIT_COMMIT := $(shell git rev-parse --short HEAD || echo "unknown") +BUILD_DATE := $(shell date -u '+%Y-%m-%d') + +LDFLAGS := -X github.com/kagent-dev/kagent/cli/internal/cli.Version=$(VERSION) \ + -X github.com/kagent-dev/kagent/cli/internal/cli.GitCommit=$(GIT_COMMIT) \ + -X github.com/kagent-dev/kagent/cli/internal/cli.BuildDate=$(BUILD_DATE) + +.PHONY: build +build: + go build -ldflags "$(LDFLAGS)" -o bin/kagent ./cmd/kagent + +.PHONY: install +install: + go install -ldflags "$(LDFLAGS)" ./cmd/kagent + +.PHONY: clean +clean: + rm -rf bin/ + +.PHONY: test +test: + go test ./... + +.PHONY: deps +deps: + go mod download + go mod tidy + +.DEFAULT_GOAL := build \ No newline at end of file diff --git a/cli/bin/kagent b/cli/bin/kagent new file mode 100755 index 000000000..139977ba2 Binary files /dev/null and b/cli/bin/kagent differ diff --git a/cli/cmd/kagent/main.go b/cli/cmd/kagent/main.go new file mode 100644 index 000000000..3305b0fb9 --- /dev/null +++ b/cli/cmd/kagent/main.go @@ -0,0 +1,23 @@ +package main + +import ( + "fmt" + "os" + + "github.com/kagent-dev/kagent/cli/internal/cli" + "github.com/kagent-dev/kagent/cli/internal/config" +) + +func main() { + // Initialize config + if err := config.Init(); err != nil { + fmt.Fprintf(os.Stderr, "Error initializing config: %v\n", err) + os.Exit(1) + } + + // Create and execute root command + rootCmd := cli.NewRootCmd() + if err := rootCmd.Execute(); err != nil { + os.Exit(1) + } +} diff --git a/cli/go.mod b/cli/go.mod index c2fbe13c4..b9f7f9db7 100644 --- a/cli/go.mod +++ b/cli/go.mod @@ -1,3 +1,32 @@ module github.com/kagent-dev/kagent/cli go 1.23.5 + +require ( + github.com/gorilla/websocket v1.5.3 + github.com/spf13/cobra v1.8.1 + github.com/spf13/viper v1.19.0 + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 +) + +require ( + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.9.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/cli/go.sum b/cli/go.sum new file mode 100644 index 000000000..ffaabb32b --- /dev/null +++ b/cli/go.sum @@ -0,0 +1,79 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= +github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/cli/internal/api/client.go b/cli/internal/api/client.go new file mode 100644 index 000000000..f32d8bcfc --- /dev/null +++ b/cli/internal/api/client.go @@ -0,0 +1,116 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" +) + +type Client struct { + BaseURL string + WSURL string + HTTPClient *http.Client +} + +func (c *Client) GetVersion() (string, error) { + var result struct { + Status bool `json:"status"` + Message string `json:"message"` + Data struct { + Version string `json:"version"` + } `json:"data"` + } + + err := c.doRequest("GET", "/version", nil, &result) + if err != nil { + return "", err + } + + if !result.Status { + return "", fmt.Errorf("api error: %s", result.Message) + } + + return result.Data.Version, nil +} + +func NewClient(baseURL, wsURL string) *Client { + // Ensure baseURL doesn't end with a slash + baseURL = strings.TrimRight(baseURL, "/") + + return &Client{ + BaseURL: baseURL, + WSURL: wsURL, + HTTPClient: &http.Client{ + Timeout: time.Second * 30, + }, + } +} + +func (c *Client) doRequest(method, path string, body interface{}, result interface{}) error { + var bodyReader *bytes.Reader + if body != nil { + bodyBytes, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("error marshaling request body: %w", err) + } + bodyReader = bytes.NewReader(bodyBytes) + } + + // Ensure path starts with a slash + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + + url := c.BaseURL + path + + var req *http.Request + var err error + if bodyReader != nil { + req, err = http.NewRequest(method, url, bodyReader) + } else { + req, err = http.NewRequest(method, url, nil) + } + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return fmt.Errorf("error making request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + return fmt.Errorf("request failed with status: %s", resp.Status) + } + + // Decode into APIResponse first + var apiResp APIResponse + if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil { + return fmt.Errorf("error decoding response: %w", err) + } + + // Check response status + if !apiResp.Status { + return fmt.Errorf("api error: %s", apiResp.Message) + } + + // If caller wants the result, marshal the Data field into their result type + if result != nil { + dataBytes, err := json.Marshal(apiResp.Data) + if err != nil { + return fmt.Errorf("error re-marshaling data: %w", err) + } + + if err := json.Unmarshal(dataBytes, result); err != nil { + return fmt.Errorf("error unmarshaling into result: %w", err) + } + } + + return nil +} diff --git a/cli/internal/api/run.go b/cli/internal/api/run.go new file mode 100644 index 000000000..369952203 --- /dev/null +++ b/cli/internal/api/run.go @@ -0,0 +1,37 @@ +package api + +import "fmt" + +func (c *Client) CreateRun(req *CreateRunRequest) (*CreateRunResult, error) { + var run CreateRunResult + err := c.doRequest("POST", "/runs", req, &run) + return &run, err +} + +func (c *Client) GetRun(runID string) (*Run, error) { + var run Run + err := c.doRequest("GET", fmt.Sprintf("/runs/%s", runID), nil, &run) + return &run, err +} + +func (c *Client) ListRuns(userID string) ([]Run, error) { + // Go through all sessions and then retrieve all runs for each session + var sessions []Session + err := c.doRequest("GET", fmt.Sprintf("/sessions/?user_id=%s", userID), nil, &sessions) + if err != nil { + return nil, err + } + + // For each session, get the run information + var runs []Run + for _, session := range sessions { + var sessionRuns SessionRuns + err := c.doRequest("GET", fmt.Sprintf("/sessions/%d/runs/?user_id=%s", session.ID, userID), nil, &sessionRuns) + if err != nil { + fmt.Println("Error getting runs for session") + return nil, err + } + runs = append(runs, sessionRuns.Runs...) + } + return runs, nil +} diff --git a/cli/internal/api/session.go b/cli/internal/api/session.go new file mode 100644 index 000000000..d19eabcb0 --- /dev/null +++ b/cli/internal/api/session.go @@ -0,0 +1,25 @@ +package api + +import "fmt" + +func (c *Client) ListSessions(userID string) ([]Session, error) { + var sessions []Session + err := c.doRequest("GET", fmt.Sprintf("/sessions/?user_id=%s", userID), nil, &sessions) + return sessions, err +} + +func (c *Client) CreateSession(session *CreateSession) (*Session, error) { + var result Session + err := c.doRequest("POST", "/sessions/", session, &result) + return &result, err +} + +func (c *Client) GetSession(sessionID int, userID string) (*Session, error) { + var session Session + err := c.doRequest("GET", fmt.Sprintf("/sessions/%d?user_id=%s", sessionID, userID), nil, &session) + return &session, err +} + +func (c *Client) DeleteSession(sessionID int, userID string) error { + return c.doRequest("DELETE", fmt.Sprintf("/sessions/%d?user_id=%s", sessionID, userID), nil, nil) +} diff --git a/cli/internal/api/team.go b/cli/internal/api/team.go new file mode 100644 index 000000000..d2fa06cc9 --- /dev/null +++ b/cli/internal/api/team.go @@ -0,0 +1,32 @@ +package api + +import "fmt" + +func (c *Client) ListTeams(userID string) ([]TeamResponse, error) { + var teams []TeamResponse + err := c.doRequest("GET", fmt.Sprintf("/teams/?user_id=%s", userID), nil, &teams) + return teams, err +} + +func (c *Client) CreateTeam(team *TeamResponse) error { + return c.doRequest("POST", "/teams/", team, team) +} + +func (c *Client) GetTeam(teamLabel string, userID string) (*TeamResponse, error) { + allTeams, err := c.ListTeams(userID) + if err != nil { + return nil, err + } + + for _, team := range allTeams { + if team.Component.Label == teamLabel { + return &team, nil + } + } + + return nil, nil +} + +func (c *Client) DeleteTeam(teamID int, userID string) error { + return c.doRequest("DELETE", fmt.Sprintf("/teams/%d?user_id=%s", teamID, userID), nil, nil) +} diff --git a/cli/internal/api/types.go b/cli/internal/api/types.go new file mode 100644 index 000000000..99bb59b10 --- /dev/null +++ b/cli/internal/api/types.go @@ -0,0 +1,316 @@ +package api + +// APIResponse is the common response wrapper for all API responses +type APIResponse struct { + Status bool `json:"status"` + Message string `json:"message"` + Data interface{} `json:"data"` +} + +type Session struct { + ID int `json:"id"` + UserID string `json:"user_id"` + Version string `json:"version"` + TeamID int `json:"team_id"` + Name string `json:"name"` +} + +type CreateSession struct { + UserID string `json:"user_id"` + TeamID int `json:"team_id"` + Name string `json:"name"` +} + +// BaseComponent represents the common fields in all components +type BaseComponent struct { + Provider string `json:"provider"` + ComponentType string `json:"component_type"` + Version int `json:"version"` + ComponentVersion int `json:"component_version"` + Description *string `json:"description"` + Config interface{} `json:"config"` + Label *string `json:"label,omitempty"` +} + +// TeamResponseConfig represents the team component configuration +type TeamResponseConfig struct { + Participants []BaseComponent `json:"participants"` + TerminationCondition *BaseComponent `json:"termination_condition,omitempty"` +} + +// TeamComponent represents the component field in the Team response +type TeamComponent struct { + Provider string `json:"provider"` + ComponentType string `json:"component_type"` + Version int `json:"version"` + ComponentVersion int `json:"component_version"` + Description *string `json:"description"` + Component TeamResponseConfig `json:"component"` + Label string `json:"label"` +} + +// TeamResponse represents the full team response structure +type TeamResponse struct { + ID int `json:"id"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + UserID string `json:"user_id"` + Version string `json:"version"` + Component TeamComponent `json:"component"` +} + +// AgentConfig represents the configuration for an agent +type AgentResponseConfig struct { + Name string `json:"name"` + ModelClient *BaseComponent `json:"model_client"` + Tools []BaseComponent `json:"tools,omitempty"` + ModelContext *BaseComponent `json:"model_context,omitempty"` + Description string `json:"description"` + SystemMessage string `json:"system_message"` + ReflectOnToolUse bool `json:"reflect_on_tool_use"` + ToolCallSummaryFormat string `json:"tool_call_summary_format"` +} + +// ModelResponseConfig represents the configuration for a model +type ModelResponseConfig struct { + Model string `json:"model"` +} + +// TerminationResponseConfig represents the configuration for termination conditions +type TerminationResponseConfig struct { + MaxMessages int `json:"max_messages"` +} + +// HTTPToolConfig represents the configuration for HTTP tools +type HTTPToolConfig struct { + Name string `json:"name"` + Description string `json:"description"` + Scheme string `json:"scheme"` + Host string `json:"host"` + Port int `json:"port"` + Path string `json:"path"` + Method string `json:"method"` + Headers map[string]string `json:"headers"` + JSONSchema map[string]interface{} `json:"json_schema"` +} + +// BuiltInToolConfig represents the configuration for built-in tools +type BuiltInToolConfig struct { + FnName string `json:"fn_name"` +} + +// TeamConfig represents either a SelectorGroupChatConfig or RoundRobinGroupChatConfig +type TeamConfig struct { + // Shared fields between both configs + Participants []AgentComponent `json:"participants"` + TerminationCondition *TerminationComponent `json:"termination_condition,omitempty"` + MaxTurns *int `json:"max_turns,omitempty"` + + // SelectorGroupChat specific fields + ModelClient *ModelComponent `json:"model_client,omitempty"` + SelectorPrompt string `json:"selector_prompt,omitempty"` + AllowRepeatedSpeaker bool `json:"allow_repeated_speaker,omitempty"` +} + +// Component types +type AgentComponent struct { + Provider string `json:"provider"` + ComponentType string `json:"component_type"` + Version *int `json:"version,omitempty"` + Description *string `json:"description,omitempty"` + Component AgentConfig `json:"component"` + Label *string `json:"label,omitempty"` +} + +type ModelComponent struct { + Provider string `json:"provider"` + ComponentType string `json:"component_type"` + Version *int `json:"version,omitempty"` + Description *string `json:"description,omitempty"` + Component ModelConfig `json:"component"` + Label *string `json:"label,omitempty"` +} + +type TerminationComponent struct { + Provider string `json:"provider"` + ComponentType string `json:"component_type"` + Version *int `json:"version,omitempty"` + Description *string `json:"description,omitempty"` + Component TerminationConfig `json:"component"` + Label *string `json:"label,omitempty"` +} + +// Agent Configurations +type AgentConfig struct { + // MultimodalWebSurferConfig fields + Name string `json:"name"` + ModelClient *ModelComponent `json:"model_client,omitempty"` + DownloadsFolder *string `json:"downloads_folder,omitempty"` + Description string `json:"description"` + DebugDir *string `json:"debug_dir,omitempty"` + Headless *bool `json:"headless,omitempty"` + StartPage *string `json:"start_page,omitempty"` + AnimateActions *bool `json:"animate_actions,omitempty"` + ToSaveScreenshots *bool `json:"to_save_screenshots,omitempty"` + UseOCR *bool `json:"use_ocr,omitempty"` + BrowserChannel *string `json:"browser_channel,omitempty"` + BrowserDataDir *string `json:"browser_data_dir,omitempty"` + ToResizeViewport *bool `json:"to_resize_viewport,omitempty"` + + // AssistantAgentConfig fields + Tools []ToolComponent `json:"tools,omitempty"` + ModelContext *ChatCompletionContextComponent `json:"model_context,omitempty"` + SystemMessage *string `json:"system_message,omitempty"` + ReflectOnToolUse bool `json:"reflect_on_tool_use,omitempty"` + ToolCallSummaryFormat string `json:"tool_call_summary_format,omitempty"` +} + +// Model Configurations +type ModelInfo struct { + Vision bool `json:"vision"` + FunctionCalling bool `json:"function_calling"` + JSONOutput bool `json:"json_output"` + Family string `json:"family"` +} + +type CreateArgumentsConfig struct { + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]float64 `json:"logit_bias,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + N *int `json:"n,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + Stop interface{} `json:"stop,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + User *string `json:"user,omitempty"` +} + +type ModelConfig struct { + // Base OpenAI fields + Model string `json:"model"` + APIKey *string `json:"api_key,omitempty"` + Timeout *int `json:"timeout,omitempty"` + MaxRetries *int `json:"max_retries,omitempty"` + ModelCapabilities interface{} `json:"model_capabilities,omitempty"` + ModelInfo *ModelInfo `json:"model_info,omitempty"` + CreateArgumentsConfig + + // OpenAIClientConfig specific fields + Organization *string `json:"organization,omitempty"` + BaseURL *string `json:"base_url,omitempty"` + + // AzureOpenAIClientConfig specific fields + AzureEndpoint *string `json:"azure_endpoint,omitempty"` + AzureDeployment *string `json:"azure_deployment,omitempty"` + APIVersion *string `json:"api_version,omitempty"` + AzureADToken *string `json:"azure_ad_token,omitempty"` + AzureADTokenProvider interface{} `json:"azure_ad_token_provider,omitempty"` +} + +// Tool Configuration +type ToolComponent struct { + Provider string `json:"provider"` + ComponentType string `json:"component_type"` + Version *int `json:"version,omitempty"` + Description *string `json:"description,omitempty"` + Component ToolConfig `json:"component"` + Label *string `json:"label,omitempty"` +} + +type ToolConfig struct { + SourceCode string `json:"source_code"` + Name string `json:"name"` + Description string `json:"description"` + GlobalImports []interface{} `json:"global_imports"` + HasCancellationSupport bool `json:"has_cancellation_support"` +} + +// ChatCompletionContext Configuration +type ChatCompletionContextComponent struct { + Provider string `json:"provider"` + ComponentType string `json:"component_type"` + Version *int `json:"version,omitempty"` + Description *string `json:"description,omitempty"` + Component ChatCompletionContextConfig `json:"component"` + Label *string `json:"label,omitempty"` +} + +type ChatCompletionContextConfig struct { + // Empty as per the TypeScript definition +} + +// Termination Configurations +type TerminationConfig struct { + // OrTerminationConfig + Conditions []TerminationComponent `json:"conditions,omitempty"` + + // MaxMessageTerminationConfig + MaxMessages *int `json:"max_messages,omitempty"` + + // TextMentionTerminationConfig + Text *string `json:"text,omitempty"` +} +type ModelsUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` +} + +type TaskMessage struct { + Source string `json:"source"` + ModelsUsage *ModelsUsage `json:"models_usage"` + Content string `json:"content"` + Type string `json:"type"` +} + +type RunMessage struct { + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + Version string `json:"version"` + SessionID int `json:"session_id"` + MessageMeta map[string]interface{} `json:"message_meta"` + ID int `json:"id"` + UserID *string `json:"user_id"` + Component TaskMessage `json:"component"` + RunID string `json:"run_id"` +} + +type CreateRunRequest struct { + SessionID int `json:"session_id"` + UserID string `json:"user_id"` +} + +type CreateRunResult struct { + ID string `json:"run_id"` +} + +type SessionRuns struct { + Runs []Run `json:"runs"` +} + +type Run struct { + ID string `json:"id"` + CreatedAt string `json:"created_at"` + Status string `json:"status"` + Task Task `json:"task"` + TeamResult TeamResult `json:"team_result"` + Messages []RunMessage `json:"messages"` +} + +type Task struct { + Source string `json:"source"` + Content string `json:"content"` + MessageType string `json:"message_type"` +} + +type TeamResult struct { + TaskResult TaskResult `json:"task_result"` + Usage string `json:"usage"` + Duration float64 `json:"duration"` +} + +type TaskResult struct { + Messages []TaskMessage `json:"messages"` + StopReason string `json:"stop_reason"` +} diff --git a/cli/internal/cli/config.go b/cli/internal/cli/config.go new file mode 100644 index 000000000..b96040e9c --- /dev/null +++ b/cli/internal/cli/config.go @@ -0,0 +1,51 @@ +package cli + +import ( + "encoding/json" + "fmt" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +func newConfigCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "config", + Short: "Manage CLI configuration", + } + + cmd.AddCommand( + &cobra.Command{ + Use: "view", + Short: "View current configuration", + RunE: runConfigView, + }, + &cobra.Command{ + Use: "set-api-url URL", + Short: "Set API URL", + Args: cobra.ExactArgs(1), + RunE: runConfigSetBackendURL, + }, + ) + + return cmd +} + +func runConfigView(cmd *cobra.Command, args []string) error { + settings := viper.AllSettings() + output, err := json.MarshalIndent(settings, "", " ") + if err != nil { + return err + } + fmt.Println(string(output)) + return nil +} + +func runConfigSetBackendURL(cmd *cobra.Command, args []string) error { + viper.Set("api_url", args[0]) + if err := viper.WriteConfig(); err != nil { + return fmt.Errorf("error saving config: %w", err) + } + fmt.Printf("API URL set to: %s\n", args[0]) + return nil +} diff --git a/cli/internal/cli/format.go b/cli/internal/cli/format.go new file mode 100644 index 000000000..81488fb60 --- /dev/null +++ b/cli/internal/cli/format.go @@ -0,0 +1,75 @@ +package cli + +import ( + "encoding/json" + "fmt" + + "github.com/spf13/viper" +) + +type OutputFormat string + +const ( + OutputFormatJSON OutputFormat = "json" + OutputFormatTable OutputFormat = "table" +) + +// PrintOutput handles the output formatting based on the configured output format +func PrintOutput(data interface{}, tableHeaders []string, tableRows [][]string) error { + format := OutputFormat(viper.GetString("output_format")) + + switch format { + case OutputFormatJSON: + return printJSON(data) + case OutputFormatTable: + return printTable(tableHeaders, tableRows) + default: + return fmt.Errorf("unknown output format: %s", format) + } +} + +func printJSON(data interface{}) error { + output, err := json.MarshalIndent(data, "", " ") + if err != nil { + return fmt.Errorf("error formatting JSON: %w", err) + } + fmt.Println(string(output)) + return nil +} + +func printTable(headers []string, rows [][]string) error { + if len(rows) == 0 { + fmt.Println("No data found") + return nil + } + + // Calculate column widths + widths := make([]int, len(headers)) + for i, h := range headers { + widths[i] = len(h) + } + + for _, row := range rows { + for i, cell := range row { + if len(cell) > widths[i] { + widths[i] = len(cell) + } + } + } + + // Print headers + for i, h := range headers { + fmt.Printf("%-*s", widths[i]+2, h) + } + fmt.Println() + + // Print rows + for _, row := range rows { + for i, cell := range row { + fmt.Printf("%-*s", widths[i]+2, cell) + } + fmt.Println() + } + + return nil +} diff --git a/cli/internal/cli/install.go b/cli/internal/cli/install.go new file mode 100644 index 000000000..2a81448c2 --- /dev/null +++ b/cli/internal/cli/install.go @@ -0,0 +1,19 @@ +package cli + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +func newInstallCmd() *cobra.Command { + return &cobra.Command{ + Use: "install", + Short: "Install kagent", + RunE: runInstall, + } +} + +func runInstall(cmd *cobra.Command, args []string) error { + return fmt.Errorf("not implemented") +} diff --git a/cli/internal/cli/root.go b/cli/internal/cli/root.go new file mode 100644 index 000000000..ecf311f1a --- /dev/null +++ b/cli/internal/cli/root.go @@ -0,0 +1,39 @@ +package cli + +import ( + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +func NewRootCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "kagent", + Short: "kagent CLI", + Long: `A CLI tool to interact with kagent`, + } + + // Add global flags + cmd.PersistentFlags().String("api-url", "http://localhost:8081/api", "Backend API URL") + cmd.PersistentFlags().String("ws-url", "", "WebSocket URL (optional, derived from backend URL if not set)") + cmd.PersistentFlags().String("output", "table", "Output format (json or table)") + + // Bind flags to viper + viper.BindPFlag("api_url", cmd.PersistentFlags().Lookup("api-url")) + viper.BindPFlag("ws_url", cmd.PersistentFlags().Lookup("ws-url")) + viper.BindPFlag("output_format", cmd.PersistentFlags().Lookup("output")) + + // Set default values + viper.SetDefault("output_format", "table") + + // Add commands + cmd.AddCommand( + newInstallCmd(), + newConfigCmd(), + newSessionCmd(), + newTeamCmd(), + newRunCmd(), + newVersionCmd(), + ) + + return cmd +} diff --git a/cli/internal/cli/run.go b/cli/internal/cli/run.go new file mode 100644 index 000000000..6500593f7 --- /dev/null +++ b/cli/internal/cli/run.go @@ -0,0 +1,172 @@ +package cli + +import ( + "encoding/json" + "fmt" + "strconv" + + "github.com/kagent-dev/kagent/cli/internal/api" + "github.com/kagent-dev/kagent/cli/internal/config" + "github.com/kagent-dev/kagent/cli/internal/ws" + "github.com/spf13/cobra" + "golang.org/x/exp/rand" +) + +func newRunCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "run", + Short: "Manage runs", + } + + createCmd := &cobra.Command{ + Use: "create [team-name] [task]", + Short: "Create a new run", + Args: cobra.ExactArgs(2), + RunE: runRunCreate, + } + + listAllCmd := &cobra.Command{ + Use: "list", + Short: "Lists all runs", + RunE: runListAll, + } + + cmd.AddCommand( + createCmd, + listAllCmd, + &cobra.Command{ + Use: "get [run-id]", + Short: "Get run details", + Args: cobra.ExactArgs(1), + RunE: runRunGet, + }, + ) + + return cmd +} + +func runListAll(cmd *cobra.Command, args []string) error { + cfg, err := config.Get() + if err != nil { + return err + } + + client := api.NewClient(cfg.APIURL, cfg.WSURL) + runs, err := client.ListRuns(cfg.UserID) + if err != nil { + return fmt.Errorf("error listing runs: %w", err) + } + + if len(runs) == 0 { + fmt.Println("No runs found") + return nil + } + + headers := []string{"ID", "CONTENT", "MESSAGES", "STATUS", "CREATED"} + rows := make([][]string, len(runs)) + for i, run := range runs { + // Truncate task content to first 10 characters if possible + content := run.Task.Content + if len(content) > 10 { + content = content[:10] + "..." + } + + rows[i] = []string{ + run.ID, + content, + strconv.Itoa(len(run.Messages)), + run.Status, + run.CreatedAt, + } + } + + return PrintOutput(runs, headers, rows) + +} + +func generateRandomString(prefix string, length int) (string, error) { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, length) + + if _, err := rand.Read(b); err != nil { + return "", err + } + + for i := range b { + b[i] = charset[int(b[i])%len(charset)] + } + + return prefix + string(b), nil +} + +func runRunCreate(cmd *cobra.Command, args []string) error { + cfg, err := config.Get() + if err != nil { + return err + } + + client := api.NewClient(cfg.APIURL, cfg.WSURL) + // Get the team based on the input + userID + team, err := client.GetTeam(args[0], cfg.UserID) + if err != nil { + return err + } + fmt.Printf("Retrieved team %s with ID %d\n", team.Component.Label, team.ID) + + // Create a random session name + sessionName, err := generateRandomString("session-", 5) + if err != nil { + return err + } + + session, err := client.CreateSession(&api.CreateSession{ + UserID: cfg.UserID, + // This will probably be created on the apiserver side in the future + Name: sessionName, + TeamID: team.ID, + }) + if err != nil { + fmt.Printf("Failed to create session: %v\n", err) + return err + } + + fmt.Printf("Created session %s with ID %d\n", session.Name, session.ID) + + run, err := client.CreateRun(&api.CreateRunRequest{ + SessionID: session.ID, + UserID: session.UserID, + }) + if err != nil { + return err + } + + wsConfig := ws.DefaultConfig() + wsClient, err := ws.NewClient(cfg.WSURL, run.ID, wsConfig) + if err != nil { + return fmt.Errorf("failed to create WebSocket client: %v", err) + } + + // Starting interactive mode by default + return wsClient.StartInteractive(team.Component, args[1]) +} + +func runRunGet(cmd *cobra.Command, args []string) error { + cfg, err := config.Get() + if err != nil { + return err + } + + client := api.NewClient(cfg.APIURL, cfg.WSURL) + run, err := client.GetRun(args[0]) + if err != nil { + return err + } + + output, err := json.MarshalIndent(run, "", " ") + if err != nil { + return err + } + + fmt.Println(string(output)) + return nil +} diff --git a/cli/internal/cli/session.go b/cli/internal/cli/session.go new file mode 100644 index 000000000..b85f5432d --- /dev/null +++ b/cli/internal/cli/session.go @@ -0,0 +1,125 @@ +package cli + +import ( + "encoding/json" + "fmt" + "strconv" + + "github.com/kagent-dev/kagent/cli/internal/api" + "github.com/kagent-dev/kagent/cli/internal/config" + "github.com/spf13/cobra" +) + +func newSessionCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "session", + Short: "Manage sessions", + } + + cmd.AddCommand( + &cobra.Command{ + Use: "list", + Short: "List all sessions", + RunE: runSessionList, + }, + &cobra.Command{ + Use: "create", + Short: "Create a new session", + RunE: runSessionCreate, + }, + &cobra.Command{ + Use: "get [session-id]", + Short: "Get session details", + Args: cobra.ExactArgs(1), + RunE: runSessionGet, + }, + &cobra.Command{ + Use: "delete [session-id]", + Short: "Delete a session", + Args: cobra.ExactArgs(1), + RunE: runSessionDelete, + }, + ) + + return cmd +} + +func runSessionList(cmd *cobra.Command, args []string) error { + cfg, err := config.Get() + if err != nil { + return err + } + + client := api.NewClient(cfg.APIURL, cfg.WSURL) + sessions, err := client.ListSessions(cfg.UserID) + if err != nil { + return err + } + + if len(sessions) == 0 { + fmt.Println("No sessions found") + return nil + } + + headers := []string{"ID", "NAME", "TEAM"} + rows := make([][]string, len(sessions)) + for i, session := range sessions { + rows[i] = []string{ + strconv.Itoa(session.ID), + session.Name, + strconv.Itoa(session.TeamID), + } + } + + return PrintOutput(sessions, headers, rows) +} + +func runSessionCreate(cmd *cobra.Command, args []string) error { + return nil +} + +func runSessionGet(cmd *cobra.Command, args []string) error { + cfg, err := config.Get() + if err != nil { + return err + } + + sessionID, err := strconv.Atoi(args[0]) + if err != nil { + return fmt.Errorf("invalid session ID: %s", args[0]) + } + + client := api.NewClient(cfg.APIURL, cfg.WSURL) + session, err := client.GetSession(sessionID, cfg.UserID) + if err != nil { + return err + } + + output, err := json.MarshalIndent(session, "", " ") + if err != nil { + return err + } + + fmt.Println(string(output)) + return nil +} + +func runSessionDelete(cmd *cobra.Command, args []string) error { + cfg, err := config.Get() + if err != nil { + return err + } + + sessionID, err := strconv.Atoi(args[0]) + if err != nil { + return fmt.Errorf("invalid session ID: %s", args[0]) + } + + client := api.NewClient(cfg.APIURL, cfg.WSURL) + if err := client.DeleteSession(sessionID, cfg.UserID); err != nil { + return err + } + + fmt.Printf("Session %d deleted\n", sessionID) + return nil +} diff --git a/cli/internal/cli/team.go b/cli/internal/cli/team.go new file mode 100644 index 000000000..808bd3fcb --- /dev/null +++ b/cli/internal/cli/team.go @@ -0,0 +1,125 @@ +package cli + +import ( + "encoding/json" + "fmt" + "strconv" + + "github.com/kagent-dev/kagent/cli/internal/api" + "github.com/kagent-dev/kagent/cli/internal/config" + "github.com/spf13/cobra" +) + +func newTeamCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "team", + Short: "Manage teams", + } + + cmd.AddCommand( + &cobra.Command{ + Use: "list", + Short: "List all teams", + RunE: runTeamList, + }, + &cobra.Command{ + Use: "create", + Short: "Create a new team", + RunE: runTeamCreate, + }, + &cobra.Command{ + Use: "get [name]", + Short: "Get team details", + Args: cobra.ExactArgs(1), + RunE: runTeamGet, + }, + &cobra.Command{ + Use: "delete [team-id]", + Short: "Delete a team", + Args: cobra.ExactArgs(1), + RunE: runTeamDelete, + }, + ) + + return cmd +} + +func runTeamList(cmd *cobra.Command, args []string) error { + cfg, err := config.Get() + if err != nil { + return err + } + + client := api.NewClient(cfg.APIURL, cfg.WSURL) + teams, err := client.ListTeams(cfg.UserID) + if err != nil { + return err + } + if len(teams) == 0 { + fmt.Println("No teams found") + return nil + } + + // Prepare table data + headers := []string{"ID", "NAME", "CREATED"} + rows := make([][]string, len(teams)) + for i, team := range teams { + rows[i] = []string{ + fmt.Sprintf("%d", team.ID), + team.Component.Label, + team.CreatedAt, + } + } + + return PrintOutput(teams, headers, rows) +} + +func runTeamCreate(cmd *cobra.Command, args []string) error { + return fmt.Errorf("not implemented") +} + +func runTeamGet(cmd *cobra.Command, args []string) error { + cfg, err := config.Get() + if err != nil { + return err + } + + client := api.NewClient(cfg.APIURL, cfg.WSURL) + team, err := client.GetTeam(args[0], cfg.UserID) + if err != nil { + return err + } + + if team == nil { + fmt.Println("Team not found") + return nil + } + + output, err := json.MarshalIndent(team, "", " ") + if err != nil { + return err + } + + fmt.Println(string(output)) + return nil +} + +func runTeamDelete(cmd *cobra.Command, args []string) error { + cfg, err := config.Get() + if err != nil { + return err + } + + teamID, err := strconv.Atoi(args[0]) + if err != nil { + return fmt.Errorf("invalid team ID: %s", args[0]) + } + + client := api.NewClient(cfg.APIURL, cfg.WSURL) + if err := client.DeleteTeam(teamID, cfg.UserID); err != nil { + return err + } + + fmt.Printf("Team %d deleted\n", teamID) + return nil +} diff --git a/cli/internal/cli/version.go b/cli/internal/cli/version.go new file mode 100644 index 000000000..5bdc9bb7d --- /dev/null +++ b/cli/internal/cli/version.go @@ -0,0 +1,46 @@ +package cli + +import ( + "fmt" + + "github.com/kagent-dev/kagent/cli/internal/api" + "github.com/kagent-dev/kagent/cli/internal/config" + "github.com/spf13/cobra" +) + +var ( + // These variables should be set during build time using -ldflags + Version = "dev" + GitCommit = "none" + BuildDate = "unknown" +) + +func newVersionCmd() *cobra.Command { + return &cobra.Command{ + Use: "version", + Short: "Show version information", + RunE: runVersion, + } +} + +func runVersion(cmd *cobra.Command, args []string) error { + fmt.Printf("kagent version %s\n", Version) + fmt.Printf("git commit: %s\n", GitCommit) + fmt.Printf("build date: %s\n", BuildDate) + + // Get backend version + cfg, err := config.Get() + if err != nil { + return err + } + + client := api.NewClient(cfg.APIURL, cfg.WSURL) + version, err := client.GetVersion() + if err != nil { + fmt.Println("Warning: Could not fetch backend version") + } else { + fmt.Printf("backend version: %s\n", version) + } + + return nil +} diff --git a/cli/internal/config/config.go b/cli/internal/config/config.go new file mode 100644 index 000000000..c600d56aa --- /dev/null +++ b/cli/internal/config/config.go @@ -0,0 +1,57 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/spf13/viper" +) + +type Config struct { + APIURL string `mapstructure:"api_url"` + WSURL string `mapstructure:"ws_url"` + UserID string `mapstructure:"user_id"` +} + +func Init() error { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("error getting user home directory: %w", err) + } + + configDir := filepath.Join(home, ".kagent") + if err := os.MkdirAll(configDir, 0755); err != nil { + return fmt.Errorf("error creating config directory: %w", err) + } + + configFile := filepath.Join(configDir, "config.yaml") + + viper.SetConfigFile(configFile) + viper.SetConfigType("yaml") + + // Set default values + viper.SetDefault("api_url", "http://localhost:8081/api") + viper.SetDefault("ws_url", "ws://localhost:8081/api/ws") + viper.SetDefault("user_id", "guestuser@gmail.com") + + if err := viper.ReadInConfig(); err != nil { + // If config file doesn't exist, create it with defaults + if _, ok := err.(viper.ConfigFileNotFoundError); ok || os.IsNotExist(err) { + if err := viper.WriteConfigAs(configFile); err != nil { + return fmt.Errorf("error creating default config file: %w", err) + } + } else { + return fmt.Errorf("error reading config file: %w", err) + } + } + return nil +} + +func Get() (*Config, error) { + var config Config + if err := viper.Unmarshal(&config); err != nil { + return nil, fmt.Errorf("error unmarshaling config: %w", err) + } + return &config, nil +} diff --git a/cli/internal/ws/client.go b/cli/internal/ws/client.go new file mode 100644 index 000000000..1a451deda --- /dev/null +++ b/cli/internal/ws/client.go @@ -0,0 +1,185 @@ +package ws + +import ( + "bufio" + "encoding/json" + "fmt" + "net/http" + "os" + "os/signal" + "time" + + "github.com/gorilla/websocket" + "github.com/kagent-dev/kagent/cli/internal/api" +) + +// Config holds the WebSocket client configuration +type Config struct { + Origin string // WebSocket origin header +} + +// DefaultConfig returns the default configuration +func DefaultConfig() Config { + return Config{ + Origin: "http://localhost:8000", + } +} + +// Client handles the WebSocket connection and message processing +type Client struct { + conn *websocket.Conn + done chan struct{} + config Config +} + +// NewClient creates a new WebSocket client and establishes connection +func NewClient(wsURL string, runID string, config Config) (*Client, error) { + // Set the required headers for the WebSocket connection + headers := http.Header{} + headers.Add("Origin", config.Origin) + + // Create dialer with debug logging + dialer := websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + EnableCompression: true, + } + + conn, _, err := dialer.Dial(wsURL+"/runs/"+runID, headers) + if err != nil { + return nil, fmt.Errorf("websocket connection failed: %v", err) + } + + return &Client{ + conn: conn, + done: make(chan struct{}), + config: config, + }, nil +} + +// StartInteractive initiates the interactive session with the server +func (c *Client) StartInteractive(teamConfig api.TeamComponent, task string) error { + defer c.conn.Close() + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + inputTimeout := make(chan struct{}) + + // Send initial start message + startMsg := StartMessage{ + Type: MessageTypeStart, + Task: task, + TeamConfig: teamConfig, + } + + // type: "start", + // task: query, + // team_config: teamConfig, + + if err := c.conn.WriteJSON(startMsg); err != nil { + return fmt.Errorf("failed to send start message: %v", err) + } + + go c.handleMessages(inputTimeout) + + select { + case <-interrupt: + fmt.Println("\nReceived interrupt signal. Closing connection...") + stopMsg := StopMessage{ + Type: MessageTypeStop, + Reason: "Cancelled by user", + } + if err := c.conn.WriteJSON(stopMsg); err != nil { + fmt.Fprintf(os.Stderr, "Error sending stop message: %v\n", err) + } + select { + case <-c.done: + case <-time.After(time.Second): + } + return nil + + case <-inputTimeout: + fmt.Println("\nInput timeout exceeded. Stopping task...") + return fmt.Errorf("input timeout exceeded") + + case <-c.done: + return nil + } +} + +func (c *Client) handleMessages(inputTimeout chan struct{}) { + defer close(c.done) + + for { + var msg WebSocketMessage + err := c.conn.ReadJSON(&msg) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading message: %v\n", err) + return + } + + switch msg.Type { + case MessageTypeError: + fmt.Fprintf(os.Stderr, "Error: %s\n", msg.Error) + return + + case MessageTypeMessage: + var taskMessage api.TaskMessage + if err := json.Unmarshal(msg.Data, &taskMessage); err != nil { + fmt.Fprintf(os.Stderr, "Error parsing message data: %v\n", err) + continue + } + fmt.Printf("%s: %s\n", taskMessage.Source, taskMessage.Content) + + case MessageTypeInputRequest: + go c.handleInputTimeout(inputTimeout) + if err := c.handleUserInput(); err != nil { + fmt.Fprintf(os.Stderr, "Error handling input: %v\n", err) + return + } + + case MessageTypeResult, MessageTypeCompletion: + if msg.Status == "complete" { + if msg.Result != nil { + // Handle any specific TeamResult processing if needed + fmt.Printf("\nTask completed! Duration: %.2f seconds\n", msg.Result.Duration) + } else { + fmt.Println("\nTask completed successfully!") + } + return + } else if msg.Status == "error" { + fmt.Fprintf(os.Stderr, "\nTask failed: %s\n", msg.Error) + return + } + } + } +} + +func (c *Client) handleInputTimeout(inputTimeout chan struct{}) { + timer := time.NewTimer(InputTimeoutDuration) + select { + case <-timer.C: + close(inputTimeout) + stopMsg := StopMessage{ + Type: MessageTypeStop, + Reason: "Input timeout", + Code: "TIMEOUT", + } + c.conn.WriteJSON(stopMsg) + case <-c.done: + timer.Stop() + } +} + +func (c *Client) handleUserInput() error { + fmt.Print("\nInput required > ") + scanner := bufio.NewScanner(os.Stdin) + if scanner.Scan() { + response := scanner.Text() + inputMsg := InputResponseMessage{ + Type: MessageTypeInputResponse, + Response: response, + } + return c.conn.WriteJSON(inputMsg) + } + return fmt.Errorf("failed to read user input") +} diff --git a/cli/internal/ws/types.go b/cli/internal/ws/types.go new file mode 100644 index 000000000..ffeb531f9 --- /dev/null +++ b/cli/internal/ws/types.go @@ -0,0 +1,57 @@ +package ws + +import ( + "encoding/json" + "time" + + "github.com/kagent-dev/kagent/cli/internal/api" +) + +const ( + // InputTimeoutDuration defines how long to wait for user input + InputTimeoutDuration = 5 * time.Minute +) + +// MessageType represents the type of WebSocket message +type MessageType string + +const ( + MessageTypeStart MessageType = "start" + MessageTypeMessage MessageType = "message" + MessageTypeInputRequest MessageType = "input_request" + MessageTypeResult MessageType = "result" + MessageTypeCompletion MessageType = "completion" + MessageTypeError MessageType = "error" + MessageTypeStop MessageType = "stop" + MessageTypeInputResponse MessageType = "input_response" +) + +// WebSocketMessage represents the structure of messages received from the server +type WebSocketMessage struct { + Type MessageType `json:"type"` + Data json.RawMessage `json:"data,omitempty"` + Status string `json:"status,omitempty"` + Error string `json:"error,omitempty"` + Message string `json:"message,omitempty"` + Result *api.TeamResult `json:"result,omitempty"` +} + +// StartMessage represents the initial message sent to start a task +type StartMessage struct { + Type MessageType `json:"type"` + Task string `json:"task"` + TeamConfig api.TeamComponent `json:"team_config"` +} + +// StopMessage represents the message sent to stop a task +type StopMessage struct { + Type MessageType `json:"type"` + Reason string `json:"reason"` + Code string `json:"code,omitempty"` +} + +// InputResponseMessage represents the message sent in response to an input request +type InputResponseMessage struct { + Type MessageType `json:"type"` + Response string `json:"response"` +} diff --git a/python/pyproject.toml b/python/pyproject.toml index 6bebb802b..8e56e8e29 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -28,9 +28,9 @@ jupyter-executor = [ ] [tool.uv.sources] -autogenstudio = { git = "https://github.com/EItanya/autogen.git", subdirectory = "python/packages/autogen-studio", rev = "c573244206768908f415ea06b415c8af6d1e8d5b" } -autogen-ext = { git = "https://github.com/EItanya/autogen.git", subdirectory = "python/packages/autogen-ext", rev = "c573244206768908f415ea06b415c8af6d1e8d5b" } -autogen-core = { git = "https://github.com/EItanya/autogen.git", subdirectory = "python/packages/autogen-core", rev = "c573244206768908f415ea06b415c8af6d1e8d5b" } +autogenstudio = { git = "https://github.com/EItanya/autogen.git", subdirectory = "python/packages/autogen-studio", rev = "0d4e9bb02d7356feda4d0b383df4085ed6995166" } +autogen-ext = { git = "https://github.com/EItanya/autogen.git", subdirectory = "python/packages/autogen-ext", rev = "0d4e9bb02d7356feda4d0b383df4085ed6995166" } +autogen-core = { git = "https://github.com/EItanya/autogen.git", subdirectory = "python/packages/autogen-core", rev = "0d4e9bb02d7356feda4d0b383df4085ed6995166" } kagent = { workspace = true } [tool.ruff] diff --git a/python/uv.lock b/python/uv.lock index 9e1256250..e27b195dc 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -124,7 +124,7 @@ wheels = [ [[package]] name = "autogen-agentchat" version = "0.4.3" -source = { git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-agentchat&rev=c573244206768908f415ea06b415c8af6d1e8d5b#c573244206768908f415ea06b415c8af6d1e8d5b" } +source = { git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-agentchat&rev=0d4e9bb02d7356feda4d0b383df4085ed6995166#0d4e9bb02d7356feda4d0b383df4085ed6995166" } dependencies = [ { name = "autogen-core" }, ] @@ -132,7 +132,7 @@ dependencies = [ [[package]] name = "autogen-core" version = "0.4.3" -source = { git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-core&rev=c573244206768908f415ea06b415c8af6d1e8d5b#c573244206768908f415ea06b415c8af6d1e8d5b" } +source = { git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-core&rev=0d4e9bb02d7356feda4d0b383df4085ed6995166#0d4e9bb02d7356feda4d0b383df4085ed6995166" } dependencies = [ { name = "jsonref" }, { name = "opentelemetry-api" }, @@ -145,7 +145,7 @@ dependencies = [ [[package]] name = "autogen-ext" version = "0.4.3" -source = { git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-ext&rev=c573244206768908f415ea06b415c8af6d1e8d5b#c573244206768908f415ea06b415c8af6d1e8d5b" } +source = { git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-ext&rev=0d4e9bb02d7356feda4d0b383df4085ed6995166#0d4e9bb02d7356feda4d0b383df4085ed6995166" } dependencies = [ { name = "autogen-core" }, ] @@ -175,7 +175,7 @@ openai = [ [[package]] name = "autogenstudio" version = "0.4.0" -source = { git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-studio&rev=c573244206768908f415ea06b415c8af6d1e8d5b#c573244206768908f415ea06b415c8af6d1e8d5b" } +source = { git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-studio&rev=0d4e9bb02d7356feda4d0b383df4085ed6995166#0d4e9bb02d7356feda4d0b383df4085ed6995166" } dependencies = [ { name = "aiofiles" }, { name = "alembic" }, @@ -921,9 +921,9 @@ jupyter-executor = [ [package.metadata] requires-dist = [ { name = "autogen-agentchat", specifier = "==0.4.3" }, - { name = "autogen-core", git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-core&rev=c573244206768908f415ea06b415c8af6d1e8d5b" }, - { name = "autogen-ext", extras = ["http"], git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-ext&rev=c573244206768908f415ea06b415c8af6d1e8d5b" }, - { name = "autogenstudio", git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-studio&rev=c573244206768908f415ea06b415c8af6d1e8d5b" }, + { name = "autogen-core", git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-core&rev=0d4e9bb02d7356feda4d0b383df4085ed6995166" }, + { name = "autogen-ext", extras = ["http"], git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-ext&rev=0d4e9bb02d7356feda4d0b383df4085ed6995166" }, + { name = "autogenstudio", git = "https://github.com/EItanya/autogen.git?subdirectory=python%2Fpackages%2Fautogen-studio&rev=0d4e9bb02d7356feda4d0b383df4085ed6995166" }, { name = "ipykernel", marker = "extra == 'jupyter-executor'", specifier = ">=6.29.5" }, { name = "mcp", specifier = ">=1.2.0" }, { name = "nbclient", marker = "extra == 'jupyter-executor'", specifier = ">=0.10.2" },