diff --git a/go.mod b/go.mod index 59b75a8..ca828f2 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,6 @@ require ( github.com/c-bata/go-prompt v0.2.6 github.com/jedib0t/go-pretty/v6 v6.0.5 github.com/jievince/liner v1.2.3 - github.com/vesoft-inc/nebula-go/v2 v2.5.2-0.20211210024917-9461e07cdca2 + github.com/vesoft-inc/nebula-go/v2 v2.5.2-0.20211221081231-40030d441885 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 // indirect ) diff --git a/go.sum b/go.sum index 0fd0a38..498d81a 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/vesoft-inc/nebula-go/v2 v2.5.2-0.20211210024917-9461e07cdca2 h1:ivN82oSpM/kNCZVdzoCSj1IT4ZDW7Xi2juhAZvh7eF8= -github.com/vesoft-inc/nebula-go/v2 v2.5.2-0.20211210024917-9461e07cdca2/go.mod h1:fehDUs97/mpmxXi9WezhznX0Dg7hmQRUoOWgDZv9zG0= +github.com/vesoft-inc/nebula-go/v2 v2.5.2-0.20211221081231-40030d441885 h1:mf9MRXEKOmUQVffO9UcYXDMtVtXfeTK1kPxie/2fK5s= +github.com/vesoft-inc/nebula-go/v2 v2.5.2-0.20211221081231-40030d441885/go.mod h1:fehDUs97/mpmxXi9WezhznX0Dg7hmQRUoOWgDZv9zG0= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180816055513-1c9583448a9c/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/main.go b/main.go index 6d1ac95..9773942 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ package main import ( "crypto/tls" "crypto/x509" + "encoding/json" "flag" "fmt" "io/ioutil" @@ -15,6 +16,7 @@ import ( "os" "path" "path/filepath" + "regexp" "strconv" "strings" "time" @@ -22,7 +24,8 @@ import ( "github.com/vesoft-inc/nebula-console/box" "github.com/vesoft-inc/nebula-console/cli" "github.com/vesoft-inc/nebula-console/printer" - nebula "github.com/vesoft-inc/nebula-go/v2" + nebulago "github.com/vesoft-inc/nebula-go/v2" + nebula "github.com/vesoft-inc/nebula-go/v2/nebula" ) // Console side commands @@ -34,8 +37,14 @@ const ( ExportCsv = 3 ExportDot = 4 Repeat = 5 + Param = 6 + Params = 7 ) +type ParameterMap map[string]interface{} + +var parameterMap ParameterMap + var dataSetPrinter = printer.NewDataSetPrinter() var planDescPrinter = printer.NewPlanDescPrinter() @@ -106,7 +115,7 @@ func isConsoleCmd(cmd string) (isLocal bool, localCmd int, args []string) { return } - plain := strings.TrimSpace(strings.ToLower(cmd)) + plain := strings.TrimSpace(cmd) if len(plain) < 1 || plain[0] != ':' { return } @@ -117,9 +126,11 @@ func isConsoleCmd(cmd string) (isLocal bool, localCmd int, args []string) { } words := strings.Fields(plain[1:]) localCmdName := words[0] - switch localCmdName { + switch strings.ToLower(localCmdName) { case "exit", "quit": - localCmd = Quit + { + localCmd = Quit + } case "sleep": { localCmd = Sleep @@ -145,6 +156,16 @@ func isConsoleCmd(cmd string) (isLocal bool, localCmd int, args []string) { localCmd = ExportDot args = []string{words[1]} } + case "param": + { + localCmd = Param + args = []string{plain} + } + case "params": + { + localCmd = Params + args = []string{plain} + } } return } @@ -177,12 +198,41 @@ func executeConsoleCmd(c cli.Cli, cmd int, args []string) { printConsoleResp("Error: invald integer, repeats should be greater than 1") } g_repeats = i + case Param: + reg := regexp.MustCompile(`^\s*:param\s+(.+?)\s*=>\s*(.+?)\s*$`) + if reg == nil { + fmt.Println("regexp err") + return + } + if len(args) != 1 { + return + } + res := reg.FindAllStringSubmatch(args[0], -1) + if len(res) != 1 || len(res[0]) != 3 { + return + } + + tmp := make(ParameterMap) + param := "{\"" + res[0][1] + "\"" + ":" + res[0][2] + "}" + err := json.Unmarshal([]byte(param), &tmp) + if err != nil { + printConsoleResp("Error: parameter parsing failed") + return + } + for k, v := range tmp { + parameterMap[k] = v + } + + case Params: + for k := range parameterMap { + fmt.Println(k, " => ", parameterMap[k]) + } default: printConsoleResp("Error: this local command not exists!") } } -func printResultSet(res *nebula.ResultSet, startTime time.Time) (duration time.Duration) { +func printResultSet(res *nebulago.ResultSet, startTime time.Time) (duration time.Duration) { if !res.IsSucceed() && !res.IsPartialSucceed() { fmt.Printf("[ERROR (%d)]: %s", res.GetErrorCode(), res.GetErrorMsg()) fmt.Println() @@ -254,7 +304,18 @@ func loop(c cli.Cli) error { var t2 int64 = 0 for i := 0; i < g_repeats; i++ { start := time.Now() - res, err := session.Execute(line) + // convert interface{} to nebula.Value + params := make(map[string]*nebula.Value) + for k, v := range parameterMap { + value, err := Base2Value(v) + if err != nil { + printConsoleResp(err.Error()) + return err + } + params[k] = value + } + + res, err := session.ExecuteWithParameter(line, params) if err != nil { return err } @@ -388,12 +449,87 @@ func validateFlags() { } } -var pool *nebula.ConnectionPool +// construct Slice to nebula.NList +func Slice2Nlist(list []interface{}) (*nebula.NList, error) { + sv := []*nebula.Value{} + var ret nebula.NList + for _, item := range list { + nv, er := Base2Value(item) + if er != nil { + return nil, er + } + sv = append(sv, nv) + } + ret.Values = sv + return &ret, nil +} + +// construct map to nebula.NMap +func Map2Nmap(m map[string]interface{}) (*nebula.NMap, error) { + var ret nebula.NMap + kvs := map[string]*nebula.Value{} + for k, v := range m { + nv, err := Base2Value(v) + if err != nil { + return nil, err + } + kvs[k] = nv + } + ret.Kvs = kvs + return &ret, nil +} + +// construct go-type to nebula.Value +func Base2Value(any interface{}) (value *nebula.Value, err error) { + value = nebula.NewValue() + if v, ok := any.(bool); ok { + value.BVal = &v + } else if v, ok := any.(int); ok { + ival := int64(v) + value.IVal = &ival + } else if v, ok := any.(float64); ok { + if v == float64(int64(v)) { + iv := int64(v) + value.IVal = &iv + } else { + value.FVal = &v + } + } else if v, ok := any.(float32); ok { + if v == float32(int64(v)) { + iv := int64(v) + value.IVal = &iv + } else { + fval := float64(v) + value.FVal = &fval + } + } else if v, ok := any.(string); ok { + value.SVal = []byte(v) + } else if v, ok := any.([]interface{}); ok { + nv, er := Slice2Nlist([]interface{}(v)) + if er != nil { + err = er + } + value.LVal = nv + } else if v, ok := any.(map[string]interface{}); ok { + nv, er := Map2Nmap(map[string]interface{}(v)) + if er != nil { + err = er + } + value.MVal = nv + } else { + // unsupport other Value type, use this function carefully + err = fmt.Errorf("Only support convert boolean/float/int/string/map/list to nebula.Value but %T", any) + } + return +} + +var pool *nebulago.ConnectionPool -var session *nebula.Session +var session *nebulago.Session func main() { flag.Parse() + parameterMap = make(ParameterMap) if flag.NFlag() == 1 && *version { fmt.Printf("nebula-console version Git: %s, Build Time: %s\n", gitCommit, buildDate) @@ -414,9 +550,9 @@ func main() { historyHome = filepath.Dir(ex) // Set to executable folder } - hostAddress := nebula.HostAddress{Host: *address, Port: *port} - hostList := []nebula.HostAddress{hostAddress} - poolConfig := nebula.PoolConfig{ + hostAddress := nebulago.HostAddress{Host: *address, Port: *port} + hostList := []nebulago.HostAddress{hostAddress} + poolConfig := nebulago.PoolConfig{ TimeOut: time.Duration(*timeout) * time.Millisecond, IdleTime: 0 * time.Millisecond, MaxConnPoolSize: 2, @@ -428,9 +564,9 @@ func main() { if err2 != nil { log.Panicf(fmt.Sprintf("Fail to generate the ssl config, ssl_root_ca_path: %s, ssl_cert_path: %s, ssl_private_key_path: %s, %s", *sslRootCAPath, *sslCertPath, *sslPrivateKeyPath, err2.Error())) } - pool, err = nebula.NewSslConnectionPool(hostList, poolConfig, sslConfig, nebula.DefaultLogger{}) + pool, err = nebulago.NewSslConnectionPool(hostList, poolConfig, sslConfig, nebulago.DefaultLogger{}) } else { - pool, err = nebula.NewConnectionPool(hostList, poolConfig, nebula.DefaultLogger{}) + pool, err = nebulago.NewConnectionPool(hostList, poolConfig, nebulago.DefaultLogger{}) } if err != nil { log.Panicf(fmt.Sprintf("Fail to initialize the connection pool, host: %s, port: %d, %s", *address, *port, err.Error()))