diff --git a/controllers/db.go b/controllers/db.go index 338e804..69f5843 100644 --- a/controllers/db.go +++ b/controllers/db.go @@ -13,10 +13,10 @@ type DatabaseController struct { } type Response struct { - Code int `json:"code"` - Data common.Any `json:"data"` - Message string `json:"message"` - CmdResult string `json:"cmdResult"` + Code int `json:"code"` + Data common.Any `json:"data"` + Message string `json:"message"` + Params common.ParameterMap `json:"params"` } type Request struct { @@ -86,7 +86,8 @@ func (this *DatabaseController) Execute() { res.Message = "connection refused for lack of session" } else { json.Unmarshal(this.Ctx.Input.RequestBody, ¶ms) - result, cmdResult, err := dao.Execute(nsid.(string), params.Gql, params.ParamList) + result, paramsMap, err := dao.Execute(nsid.(string), params.Gql, params.ParamList) + if err == nil { res.Code = 0 res.Data = &result @@ -94,8 +95,10 @@ func (this *DatabaseController) Execute() { res.Code = -1 res.Message = err.Error() } - if cmdResult != "" { - res.CmdResult = cmdResult + if len(paramsMap) == 0 { + res.Params = nil + } else { + res.Params = paramsMap } } this.Data["json"] = &res diff --git a/service/dao/dao.go b/service/dao/dao.go index 22ec900..e32e719 100644 --- a/service/dao/dao.go +++ b/service/dao/dao.go @@ -287,14 +287,14 @@ func Disconnect(nsid string) { pool.Disconnect(nsid) } -func Execute(nsid string, gql string, paramList common.ParameterList) (result ExecuteResult, cmdResult string, err error) { +func Execute(nsid string, gql string, paramList common.ParameterList) (result ExecuteResult, params common.ParameterMap, err error) { result = ExecuteResult{ Headers: make([]string, 0), Tables: make([]map[string]common.Any, 0), } connection, err := pool.GetConnection(nsid) if err != nil { - return result, "", err + return result, nil, err } responseChannel := make(chan pool.ChannelResponse) connection.RequestChannel <- pool.ChannelRequest{ @@ -303,12 +303,13 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex ParamList: paramList, } response := <-responseChannel + paramsMap := response.Params if response.Error != nil { - return result, response.CmdResult, response.Error + return result, paramsMap, response.Error } resp := response.Result if response.Result == nil { - return result, response.CmdResult, nil + return result, paramsMap, nil } if resp.IsSetPlanDesc() { format := string(resp.GetPlanDesc().GetFormat()) @@ -324,7 +325,7 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex rowValue["operator info"] = rows[i][4] result.Tables = append(result.Tables, rowValue) } - return result, response.CmdResult, err + return result, paramsMap, err } else { var rowValue = make(map[string]common.Any) result.Headers = append(result.Headers, "format") @@ -334,12 +335,12 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex rowValue["format"] = resp.MakeDotGraphByStruct() } result.Tables = append(result.Tables, rowValue) - return result, response.CmdResult, err + return result, paramsMap, err } } if !resp.IsSucceed() { logs.Info("ErrorCode: %v, ErrorMsg: %s", resp.GetErrorCode(), resp.GetErrorMsg()) - return result, response.CmdResult, errors.New(string(resp.GetErrorMsg())) + return result, paramsMap, errors.New(string(resp.GetErrorMsg())) } if !resp.IsEmpty() { rowSize := resp.GetRowSize() @@ -353,16 +354,16 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex var _edgesParsedList = make(list, 0) var _pathsParsedList = make(list, 0) if err != nil { - return result, response.CmdResult, err + return result, paramsMap, err } for j := 0; j < colSize; j++ { rowData, err := record.GetValueByIndex(j) if err != nil { - return result, response.CmdResult, err + return result, paramsMap, err } value, err := getValue(rowData) if err != nil { - return result, response.CmdResult, err + return result, paramsMap, err } rowValue[result.Headers[j]] = value valueType := rowData.GetType() @@ -398,12 +399,12 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex rowValue["_pathsParsedList"] = _pathsParsedList } if err != nil { - return result, response.CmdResult, err + return result, paramsMap, err } } result.Tables = append(result.Tables, rowValue) } } result.TimeCost = resp.GetLatency() - return result, response.CmdResult, nil + return result, paramsMap, nil } diff --git a/service/pool/pool.go b/service/pool/pool.go index 7f332a3..0bf0348 100644 --- a/service/pool/pool.go +++ b/service/pool/pool.go @@ -21,7 +21,14 @@ import ( var ( ConnectionClosedError = errors.New("an existing connection was forcibly closed, please check your network") SessionLostError = errors.New("the connection session was lost, please connect again") - SetParamsSuccess = "Set patameter successfully!" + InterruptError = errors.New("Other statements was not executed due to this error.") +) + +// Console side commands +const ( + Unknown = -1 + Param = 1 + Params = 2 ) type Account struct { @@ -30,9 +37,9 @@ type Account struct { } type ChannelResponse struct { - Result *nebula.ResultSet - CmdResult string - Error error + Result *nebula.ResultSet + Params common.ParameterMap + Error error } type ChannelRequest struct { @@ -50,8 +57,6 @@ type Connection struct { session *nebula.Session } -// var parameterMap common.ParameterMap - var connectionPool = make(map[string]*Connection) var currentConnectionNum = 0 var connectLock sync.Mutex @@ -86,17 +91,6 @@ func isThriftTransportError(err error) bool { return false } -func transferString(s string) string { - if len(s) < 2 { - return s - } - - if s[0] == '\'' && s[len(s)-1] == '\'' { - return "\"" + strings.ReplaceAll(s[1:len(s)-1], "\"", "\\\"") + "\"" - } - return s -} - // construct Slice to nebula.NList func Slice2Nlist(list []interface{}) (*nebulaType.NList, error) { sv := []*nebulaType.Value{} @@ -171,61 +165,124 @@ func Base2Value(any interface{}) (value *nebulaType.Value, err error) { return } -func executeCmd(parameterList common.ParameterList, parameterMap *common.ParameterMap) (err error) { - tmp := make(common.ParameterMap) +func isCmd(query string) (isLocal bool, localCmd int, args []string) { + isLocal = false + localCmd = Unknown + plain := strings.TrimSpace(query) + if len(plain) < 1 || plain[0] != ':' { + return + } + isLocal = true + words := strings.Fields(plain[1:]) + localCmdName := words[0] + switch strings.ToLower(localCmdName) { + case "param": + { + localCmd = Param + args = []string{plain} + } + case "params": + { + localCmd = Params + args = []string{plain} + } + } + return +} + +func executeCmd(parameterList common.ParameterList, parameterMap *common.ParameterMap) (showMap common.ParameterMap, err error) { for _, v := range parameterList { // convert interface{} to nebula.Value - plain := strings.TrimSpace(v) - if len(plain) < 1 || plain[0] != ':' { - return - } - words := strings.Fields(plain[1:]) - localCmdName := words[0] - switch strings.ToLower(localCmdName) { - case "param": - err = parseParam(v, &tmp, parameterMap) - return err - case "params": - err = parseParams(v, &tmp, parameterMap) - return err + if isLocal, cmd, args := isCmd(v); isLocal { + switch cmd { + case Param: + if len(args) != 1 { + return nil, nil + } + err = defineParams(args[0], parameterMap) + return nil, err + case Params: + if len(args) != 1 { + return nil, nil + } + showMap, err = ListParams(args[0], parameterMap) + return showMap, err + } } } - return nil + return nil, nil } -func parseParam(gql string, tmp *common.ParameterMap, parameterMap *common.ParameterMap) (err error) { - reg := regexp.MustCompile(`(?i)^\s*:param\s+(.+?)\s*=>\s*(.+?)\s*$`) - res := reg.FindAllStringSubmatch(gql, -1) - if len(res) != 1 || len(res[0]) != 3 { - return errors.New("Grammar reular check error: " + reg.String()) +func defineParams(args string, parameterMap *common.ParameterMap) (err error) { + argsRewritten := strings.Replace(args, "'", "\"", -1) + reg := regexp.MustCompile(`^\s*:param\s+(\S+)\s*=>(.*)$`) + if reg == nil { + err = errors.New("invalid regular expression") + return } - param := "{\"" + res[0][1] + "\"" + ":" + transferString(res[0][2]) + "}" - err = json.Unmarshal([]byte(param), &tmp) - if err != nil { - return err + matchResult := reg.FindAllStringSubmatch(argsRewritten, -1) + if len(matchResult) != 1 || len(matchResult[0]) != 3 { + err = errors.New("Set params failed. Wrong local command format " + reg.String()) + return } - for k, v := range *tmp { - (*parameterMap)[k] = v + /* + * :param p1=> -> [":param p1=>",":p1",""] + * :param p2=>3 -> [":param p2=>3",":p2","3"] + */ + paramKey := matchResult[0][1] + paramValue := matchResult[0][2] + if len(paramValue) == 0 { + delete((*parameterMap), paramKey) + } else { + paramsWithGoType := make(common.ParameterMap) + param := "{\"" + paramKey + "\"" + ":" + paramValue + "}" + err = json.Unmarshal([]byte(param), ¶msWithGoType) + if err != nil { + return + } + for k, v := range paramsWithGoType { + (*parameterMap)[k] = v + } } return nil } -func parseParams(gql string, tmp *common.ParameterMap, parameterMap *common.ParameterMap) (err error) { - reg := regexp.MustCompile(`(?i)^\s*:params\s+{\s*(\"[0-9a-z]+\"\s*:\s*.+)?\s*}\s*$`) - res := reg.FindAllStringSubmatch(gql, -1) - if len(res) != 1 || len(res[0]) != 2 { - return errors.New("Grammar reular check error: " + reg.String()) +func ListParams(args string, parameterMap *common.ParameterMap) (showMap common.ParameterMap, err error) { + reg := regexp.MustCompile(`^\s*:params\s*(\S*)\s*$`) + paramsWithGoType := make(common.ParameterMap) + if reg == nil { + err = errors.New("invalid regular expression") + return } - param := "{" + res[0][1] + "}" - err = json.Unmarshal([]byte(param), &tmp) - if err != nil { - return err + matchResult := reg.FindAllStringSubmatch(args, -1) + if len(matchResult) != 1 { + err = errors.New("Set params failed. Wrong local command format " + reg.String()) + return } - for k, v := range *tmp { - (*parameterMap)[k] = v + res := matchResult[0] + /* + * :params -> [":params",""] + * :params p1 -> ["params","p1"] + */ + if len(res) != 2 { + return + } else { + paramKey := matchResult[0][1] + if len(paramKey) == 0 { + for k, v := range *parameterMap { + paramsWithGoType[k] = v + } + } else { + if paramValue, ok := (*parameterMap)[paramKey]; ok { + paramsWithGoType[paramKey] = paramValue + } else { + err = errors.New("Unknown parameter: " + paramKey) + } + } } - return nil + return paramsWithGoType, nil } + func NewConnection(address string, port int, username string, password string) (nsid string, err error) { connectLock.Lock() defer connectLock.Unlock() @@ -277,13 +334,18 @@ func NewConnection(address string, port int, username string, password string) ( } } }() - cmdResult := "" + showMap := make(common.ParameterMap) if len(request.ParamList) > 0 { - err = executeCmd(request.ParamList, &connection.parameterMap) + showMap, err = executeCmd(request.ParamList, &connection.parameterMap) if err != nil { - cmdResult = err.Error() - } else { - cmdResult = SetParamsSuccess + if len(request.Gql) > 0 { + err = errors.New(err.Error() + InterruptError.Error()) + } + request.ResponseChannel <- ChannelResponse{ + Result: nil, + Params: showMap, + Error: err, + } } } @@ -301,15 +363,15 @@ func NewConnection(address string, port int, username string, password string) ( err = ConnectionClosedError } request.ResponseChannel <- ChannelResponse{ - Result: response, - CmdResult: cmdResult, - Error: err, + Result: response, + Params: showMap, + Error: err, } } else { request.ResponseChannel <- ChannelResponse{ - Result: nil, - CmdResult: cmdResult, - Error: nil, + Result: nil, + Params: showMap, + Error: nil, } } }()