diff --git a/.gitignore b/.gitignore index 9401776..7b5cbe4 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,6 @@ nebula-httpd .idea .vscode/ vendor/ - # Dependency directories (remove the comment below to include it) tmp/ diff --git a/common/common.go b/common/common.go index 4638a78..af7b746 100644 --- a/common/common.go +++ b/common/common.go @@ -1,3 +1,6 @@ package common type Any interface{} + +type ParameterList []string +type ParameterMap map[string]interface{} diff --git a/controllers/db.go b/controllers/db.go index 3140419..4cfd806 100644 --- a/controllers/db.go +++ b/controllers/db.go @@ -26,7 +26,8 @@ type Request struct { } type ExecuteRequest struct { - Gql string `json:"gql"` + Gql string `json:"gql"` + ParamList common.ParameterList `json:"paramList"` } type Data map[string]interface{} @@ -84,7 +85,7 @@ func (this *DatabaseController) Execute() { res.Message = "connection refused for lack of session" } else { json.Unmarshal(this.Ctx.Input.RequestBody, ¶ms) - result, err := dao.Execute(nsid.(string), params.Gql) + result, err := dao.Execute(nsid.(string), params.Gql, params.ParamList) if err == nil { res.Code = 0 res.Data = &result diff --git a/service/dao/dao.go b/service/dao/dao.go index d311edf..11f122b 100644 --- a/service/dao/dao.go +++ b/service/dao/dao.go @@ -8,13 +8,14 @@ import ( "github.com/vesoft-inc/nebula-http-gateway/service/pool" nebula "github.com/vesoft-inc/nebula-go/v2" - nebulaType "github.com/vesoft-inc/nebula-go/v2/nebula" + nebulatype "github.com/vesoft-inc/nebula-go/v2/nebula" ) type ExecuteResult struct { - Headers []string `json:"headers"` - Tables []map[string]common.Any `json:"tables"` - TimeCost int64 `json:"timeCost"` + Headers []string `json:"headers"` + Tables []map[string]common.Any `json:"tables"` + TimeCost int64 `json:"timeCost"` + LocalParams common.ParameterMap `json:"localParams"` } type list []common.Any @@ -44,21 +45,21 @@ func getBasicValue(valWarp *nebula.ValueWrapper) (common.Any, error) { if valType == "null" { value, err := valWarp.AsNull() switch value { - case nebulaType.NullType___NULL__: + case nebulatype.NullType___NULL__: return "NULL", err - case nebulaType.NullType_NaN: + case nebulatype.NullType_NaN: return "NaN", err - case nebulaType.NullType_BAD_DATA: + case nebulatype.NullType_BAD_DATA: return "BAD_DATA", err - case nebulaType.NullType_BAD_TYPE: + case nebulatype.NullType_BAD_TYPE: return "BAD_TYPE", err - case nebulaType.NullType_OUT_OF_RANGE: + case nebulatype.NullType_OUT_OF_RANGE: return "OUT_OF_RANGE", err - case nebulaType.NullType_DIV_BY_ZERO: + case nebulatype.NullType_DIV_BY_ZERO: return "DIV_BY_ZERO", err - case nebulaType.NullType_UNKNOWN_PROP: + case nebulatype.NullType_UNKNOWN_PROP: return "UNKNOWN_PROP", err - case nebulaType.NullType_ERR_OVERFLOW: + case nebulatype.NullType_ERR_OVERFLOW: return "ERR_OVERFLOW", err } return "NULL", err @@ -287,26 +288,34 @@ func Disconnect(nsid string) { pool.Disconnect(nsid) } -func Execute(nsid string, gql string) (result ExecuteResult, err error) { +func Execute(nsid string, gql string, paramList common.ParameterList) (result ExecuteResult, err error) { result = ExecuteResult{ - Headers: make([]string, 0), - Tables: make([]map[string]common.Any, 0), + Headers: make([]string, 0), + Tables: make([]map[string]common.Any, 0), + LocalParams: nil, } connection, err := pool.GetConnection(nsid) if err != nil { return result, err } - responseChannel := make(chan pool.ChannelResponse) connection.RequestChannel <- pool.ChannelRequest{ Gql: gql, ResponseChannel: responseChannel, + ParamList: paramList, } response := <-responseChannel + paramsMap := response.Params + if len(paramsMap) > 0 { + result.LocalParams = paramsMap + } if response.Error != nil { return result, response.Error } resp := response.Result + if response.Result == nil { + return result, nil + } if resp.IsSetPlanDesc() { format := string(resp.GetPlanDesc().GetFormat()) if format == "row" { @@ -334,7 +343,6 @@ func Execute(nsid string, gql string) (result ExecuteResult, err error) { return result, err } } - if !resp.IsSucceed() { logs.Info("ErrorCode: %v, ErrorMsg: %s", resp.GetErrorCode(), resp.GetErrorMsg()) return result, errors.New(string(resp.GetErrorMsg())) diff --git a/service/pool/pool.go b/service/pool/pool.go index bb8bc5c..92ac19c 100644 --- a/service/pool/pool.go +++ b/service/pool/pool.go @@ -1,7 +1,10 @@ package pool import ( + "encoding/json" "errors" + "fmt" + "regexp" "strings" "sync" "time" @@ -12,11 +15,20 @@ import ( uuid "github.com/satori/go.uuid" nebula "github.com/vesoft-inc/nebula-go/v2" + nebulatype "github.com/vesoft-inc/nebula-go/v2/nebula" ) var ( ConnectionClosedError = errors.New("an existing connection was forcibly closed, please check your network") SessionLostError = errors.New("the connection session was lost, please connect again") + InterruptError = errors.New("Other statements was not executed due to this error.") +) + +// Console side commands +const ( + Unknown = -1 + Param = 1 + Params = 2 ) type Account struct { @@ -26,18 +38,21 @@ type Account struct { type ChannelResponse struct { Result *nebula.ResultSet + Params common.ParameterMap Error error } type ChannelRequest struct { Gql string ResponseChannel chan ChannelResponse + ParamList common.ParameterList } type Connection struct { RequestChannel chan ChannelRequest CloseChannel chan bool updateTime int64 + parameterMap common.ParameterMap account *Account session *nebula.Session } @@ -76,6 +91,190 @@ func isThriftTransportError(err error) bool { return false } +// construct Slice to nebula.NList +func Slice2Nlist(list []interface{}) (*nebulatype.NList, error) { + sv := []*nebulatype.Value{} + var ret nebulatype.NList + for _, item := range list { + nv, er := Base2Value(item) + if er != nil { + return nil, er + } + sv = append(sv, nv) + } + ret.Values = sv + return &ret, nil +} + +// construct map to nebula.NMap +func Map2Nmap(m map[string]interface{}) (*nebulatype.NMap, error) { + var ret nebulatype.NMap + kvs := map[string]*nebulatype.Value{} + for k, v := range m { + nv, err := Base2Value(v) + if err != nil { + return nil, err + } + kvs[k] = nv + } + ret.Kvs = kvs + return &ret, nil +} + +// construct go-type to nebula.Value +func Base2Value(any interface{}) (value *nebulatype.Value, err error) { + value = nebulatype.NewValue() + if v, ok := any.(bool); ok { + value.BVal = &v + } else if v, ok := any.(int); ok { + ival := int64(v) + value.IVal = &ival + } else if v, ok := any.(float64); ok { + if v == float64(int64(v)) { + iv := int64(v) + value.IVal = &iv + } else { + value.FVal = &v + } + } else if v, ok := any.(float32); ok { + if v == float32(int64(v)) { + iv := int64(v) + value.IVal = &iv + } else { + fval := float64(v) + value.FVal = &fval + } + } else if v, ok := any.(string); ok { + value.SVal = []byte(v) + } else if any == nil { + nval := nebulatype.NullType___NULL__ + value.NVal = &nval + } else if v, ok := any.([]interface{}); ok { + nv, er := Slice2Nlist([]interface{}(v)) + if er != nil { + err = er + } + value.LVal = nv + } else if v, ok := any.(map[string]interface{}); ok { + nv, er := Map2Nmap(map[string]interface{}(v)) + if er != nil { + err = er + } + value.MVal = nv + } else { + // unsupport other Value type, use this function carefully + err = fmt.Errorf("Only support convert boolean/float/int/string/map/list to nebula.Value but %T", any) + } + return +} + +func isCmd(query string) (isLocal bool, localCmd int, args []string) { + isLocal = false + localCmd = Unknown + plain := strings.TrimSpace(query) + if len(plain) < 1 || plain[0] != ':' { + return + } + isLocal = true + words := strings.Fields(plain[1:]) + localCmdName := words[0] + switch strings.ToLower(localCmdName) { + case "param": + localCmd = Param + args = []string{plain} + case "params": + localCmd = Params + args = []string{plain} + } + return +} + +func executeCmd(parameterList common.ParameterList, parameterMap *common.ParameterMap) (showMap common.ParameterMap, err error) { + tempMap := make(common.ParameterMap) + for _, v := range parameterList { + // convert interface{} to nebula.Value + if isLocal, cmd, args := isCmd(v); isLocal { + switch cmd { + case Param: + if len(args) == 1 { + err = defineParams(args[0], parameterMap) + } + if err != nil { + return nil, err + } + case Params: + if len(args) == 1 { + err = ListParams(args[0], &tempMap, parameterMap) + } + if err != nil { + return nil, err + } + } + } + } + return tempMap, nil +} + +func defineParams(args string, parameterMap *common.ParameterMap) (err error) { + argsRewritten := strings.Replace(args, "'", "\"", -1) + reg := regexp.MustCompile(`(?i)^\s*:param\s+(\S+)\s*=>(.*)$`) + matchResult := reg.FindAllStringSubmatch(argsRewritten, -1) + if len(matchResult) != 1 || len(matchResult[0]) != 3 { + err = errors.New("Set params failed. Wrong local command format (" + reg.String() + ") ") + return + } + /* + * :param p1=> -> [":param p1=>",":p1",""] + * :param p2=>3 -> [":param p2=>3",":p2","3"] + */ + paramKey := matchResult[0][1] + paramValue := matchResult[0][2] + if len(paramValue) == 0 { + delete((*parameterMap), paramKey) + } else { + paramsWithGoType := make(common.ParameterMap) + param := "{\"" + paramKey + "\"" + ":" + paramValue + "}" + err = json.Unmarshal([]byte(param), ¶msWithGoType) + if err != nil { + return + } + for k, v := range paramsWithGoType { + (*parameterMap)[k] = v + } + } + return nil +} + +func ListParams(args string, tmpParameter *common.ParameterMap, sessionMap *common.ParameterMap) (err error) { + reg := regexp.MustCompile(`(?i)^\s*:params\s*(\S*)\s*$`) + matchResult := reg.FindAllStringSubmatch(args, -1) + if len(matchResult) != 1 { + err = errors.New("Set params failed. Wrong local command format " + reg.String() + ") ") + return + } + res := matchResult[0] + /* + * :params -> [":params",""] + * :params p1 -> ["params","p1"] + */ + if len(res) != 2 { + return + } + paramKey := matchResult[0][1] + if len(paramKey) == 0 { + for k, v := range *sessionMap { + (*tmpParameter)[k] = v + } + } else { + if paramValue, ok := (*sessionMap)[paramKey]; ok { + (*tmpParameter)[paramKey] = paramValue + } else { + err = errors.New("Unknown parameter: " + paramKey) + } + } + return nil +} + func NewConnection(address string, port int, username string, password string) (nsid string, err error) { connectLock.Lock() defer connectLock.Unlock() @@ -103,6 +302,7 @@ func NewConnection(address string, port int, username string, password string) ( CloseChannel: make(chan bool), updateTime: time.Now().Unix(), session: session, + parameterMap: make(common.ParameterMap), account: &Account{ username: username, password: password, @@ -126,13 +326,46 @@ func NewConnection(address string, port int, username string, password string) ( } } }() - response, err := connection.session.Execute(request.Gql) - if err != nil && (isThriftProtoError(err) || isThriftTransportError(err)) { - err = ConnectionClosedError + showMap := make(common.ParameterMap) + if len(request.ParamList) > 0 { + showMap, err = executeCmd(request.ParamList, &connection.parameterMap) + if err != nil { + if len(request.Gql) > 0 { + err = errors.New(err.Error() + InterruptError.Error()) + } + request.ResponseChannel <- ChannelResponse{ + Result: nil, + Params: showMap, + Error: err, + } + return + } } - request.ResponseChannel <- ChannelResponse{ - Result: response, - Error: err, + + if len(request.Gql) > 0 { + params := make(map[string]*nebulatype.Value) + for k, v := range connection.parameterMap { + value, paramError := Base2Value(v) + if paramError != nil { + err = paramError + } + params[k] = value + } + response, err := connection.session.ExecuteWithParameter(request.Gql, params) + if err != nil && (isThriftProtoError(err) || isThriftTransportError(err)) { + err = ConnectionClosedError + } + request.ResponseChannel <- ChannelResponse{ + Result: response, + Params: showMap, + Error: err, + } + } else { + request.ResponseChannel <- ChannelResponse{ + Result: nil, + Params: showMap, + Error: nil, + } } }() case <-connection.CloseChannel: