Skip to content

Commit

Permalink
Merge pull request #172 from icey-yu/fix-gen
Browse files Browse the repository at this point in the history
Fix gen
  • Loading branch information
icey-yu authored Dec 16, 2024
2 parents cc43ff1 + 2d9768a commit d68f86c
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 15 deletions.
89 changes: 74 additions & 15 deletions magefile.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ import (

var Default = RpcCaller

var (
projectRoot string
//packageRegex = regexp.MustCompile(`^\s*package\s+([a-zA-Z0-9_.]+);`)
importRegex = regexp.MustCompile(`^\s*import\s+"([a-zA-Z0-9_./]+)";`)
goPackageRegex = regexp.MustCompile(`^\s*option\s+go_package\s*=\s*"([^"]+)";`)
serviceRegex = regexp.MustCompile(`^\s*service\s+([A-Za-z0-9_]+)\s*{`)
rpcRegex = regexp.MustCompile(`^\s*rpc\s+([A-Za-z0-9_]+)\s*\(\s*([A-Za-z0-9_]+\.?[A-Za-z0-9_]*)\s*\)\s+returns\s*\(\s*([A-Za-z0-9_]+\.?[A-Za-z0-9_]*)\s*\);`)
)

func RpcCaller() {
if err := Generate(); err != nil {
fmt.Println(err)
Expand All @@ -39,13 +48,14 @@ type Service struct {
ServiceName string
GoPackage string
Methods []ServiceMethod
Imports []string
}

// Generate is the Mage target to generate Go code from proto files
func Generate() error {
fmt.Println("Generating rpc_caller...")

projectRoot, err := os.Getwd()
var err error
projectRoot, err = os.Getwd()
if err != nil {
return errors.New("get root directory failed")
}
Expand Down Expand Up @@ -119,24 +129,22 @@ func Generate() error {
func parseProtoFile(protoFilePath string) (Service, error) {
file, err := os.Open(protoFilePath)
if err != nil {
return Service{}, fmt.Errorf("打开 proto 文件失败: %v", err)
return Service{}, fmt.Errorf("open proto file failed: %v", err)
}
defer file.Close()

filePath := filepath.Dir(protoFilePath)
filePath_ := filepath.Dir(protoFilePath)
fileName := strings.TrimSuffix(filepath.Base(protoFilePath), filepath.Ext(protoFilePath))
scanner := bufio.NewScanner(file)
var (
goPackage string
serviceName string
methods []ServiceMethod
goPackage string
serviceName string
methods []ServiceMethod
imports []string
alreadyImports = make(map[string]struct{})
allImports = make(map[string]string)
)

//packageRegex := regexp.MustCompile(`^\s*package\s+([a-zA-Z0-9_.]+);`)
goPackageRegex := regexp.MustCompile(`^\s*option\s+go_package\s*=\s*"([^"]+)";`)
serviceRegex := regexp.MustCompile(`^\s*service\s+([A-Za-z0-9_]+)\s*{`)
rpcRegex := regexp.MustCompile(`^\s*rpc\s+([A-Za-z0-9_]+)\s*\(\s*([A-Za-z0-9_]+)\s*\)\s+returns\s*\(\s*([A-Za-z0-9_]+)\s*\);`)

inService := false

for scanner.Scan() {
Expand All @@ -147,6 +155,12 @@ func parseProtoFile(protoFilePath string) (Service, error) {
// continue
//}

if matches := importRegex.FindStringSubmatch(line); matches != nil {
pkg := strings.TrimSuffix(filepath.Base(matches[1]), filepath.Ext(matches[1]))
allImports[pkg] = matches[1]
continue
}

if matches := goPackageRegex.FindStringSubmatch(line); matches != nil {
goPackage = matches[1]

Expand All @@ -172,6 +186,21 @@ func parseProtoFile(protoFilePath string) (Service, error) {
FullMethodName: fullMethodName,
}
methods = append(methods, method)

if strings.Contains(requestType, ".") {
imp := strings.Split(requestType, ".")[0]
if f, ok := allImports[imp]; ok {
if _, ok = alreadyImports[imp]; !ok {
alreadyImports[imp] = struct{}{}
gopkg, err := getGoPackage(filepath.Join(projectRoot, f))
if err != nil {
fmt.Printf("get go package failed: %v", err)
} else {
imports = append(imports, gopkg)
}
}
}
}
continue
}

Expand All @@ -195,11 +224,12 @@ func parseProtoFile(protoFilePath string) (Service, error) {

sp := strings.Split(goPackage, "/")
service := Service{
FilePath: filePath,
FilePath: filePath_,
FileName: fileName,
ServiceName: serviceName,
GoPackage: sp[len(sp)-1],
Methods: methods,
Imports: imports,
}

return service, nil
Expand All @@ -213,6 +243,9 @@ func generateGoFile(service Service) error {
tmpl := `package {{.GoPackage}}
import (
{{- range .Imports }}
"{{ . }}"
{{- end }}
"github.com/openimsdk/protocol/rpccall"
"google.golang.org/grpc"
)
Expand All @@ -234,10 +267,12 @@ var (
ServiceNameCamel string
GoPackage string
Methods []ServiceMethod
Imports []string
}{
ServiceNameCamel: toCamelCase(service.ServiceName),
GoPackage: service.GoPackage,
Methods: service.Methods,
Imports: service.Imports,
}

t, err := template.New("goFile").Parse(tmpl)
Expand All @@ -261,14 +296,38 @@ var (
return nil
}

func getGoPackage(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", fmt.Errorf("open proto file failed: %v", err)
}
defer file.Close()

scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if matches := goPackageRegex.FindStringSubmatch(line); matches != nil {
return matches[1], nil
}
}
return "", errors.New("goPackage not found")
}

func toCamelCase(s string) string {
parts := strings.Split(s, "_")
pkg := ""
ss := s
if strings.Contains(ss, ".") {
sl := strings.Split(ss, ".")
ss = sl[1]
pkg = sl[0] + "."
}
parts := strings.Split(ss, "_")
for i, p := range parts {
if len(p) > 0 {
parts[i] = strings.ToUpper(p[:1]) + p[1:]
}
}
return strings.Join(parts, "")
return pkg + strings.Join(parts, "")
}

func formatFile(filePath string) {
Expand Down
5 changes: 5 additions & 0 deletions msg/msg_caller.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@ package msg

import (
"github.com/openimsdk/protocol/rpccall"
"github.com/openimsdk/protocol/sdkws"
"google.golang.org/grpc"
)

func InitMsg(conn *grpc.ClientConn) {
GetMaxSeqCaller.SetConn(conn)
GetMaxSeqsCaller.SetConn(conn)
GetHasReadSeqsCaller.SetConn(conn)
GetMsgByConversationIDsCaller.SetConn(conn)
GetConversationMaxSeqCaller.SetConn(conn)
PullMessageBySeqsCaller.SetConn(conn)
GetSeqMessageCaller.SetConn(conn)
SearchMessageCaller.SetConn(conn)
SendMsgCaller.SetConn(conn)
Expand Down Expand Up @@ -39,10 +42,12 @@ func InitMsg(conn *grpc.ClientConn) {
}

var (
GetMaxSeqCaller = rpccall.NewRpcCaller[sdkws.GetMaxSeqReq, sdkws.GetMaxSeqResp](Msg_GetMaxSeq_FullMethodName)
GetMaxSeqsCaller = rpccall.NewRpcCaller[GetMaxSeqsReq, SeqsInfoResp](Msg_GetMaxSeqs_FullMethodName)
GetHasReadSeqsCaller = rpccall.NewRpcCaller[GetHasReadSeqsReq, SeqsInfoResp](Msg_GetHasReadSeqs_FullMethodName)
GetMsgByConversationIDsCaller = rpccall.NewRpcCaller[GetMsgByConversationIDsReq, GetMsgByConversationIDsResp](Msg_GetMsgByConversationIDs_FullMethodName)
GetConversationMaxSeqCaller = rpccall.NewRpcCaller[GetConversationMaxSeqReq, GetConversationMaxSeqResp](Msg_GetConversationMaxSeq_FullMethodName)
PullMessageBySeqsCaller = rpccall.NewRpcCaller[sdkws.PullMessageBySeqsReq, sdkws.PullMessageBySeqsResp](Msg_PullMessageBySeqs_FullMethodName)
GetSeqMessageCaller = rpccall.NewRpcCaller[GetSeqMessageReq, GetSeqMessageResp](Msg_GetSeqMessage_FullMethodName)
SearchMessageCaller = rpccall.NewRpcCaller[SearchMessageReq, SearchMessageResp](Msg_SearchMessage_FullMethodName)
SendMsgCaller = rpccall.NewRpcCaller[SendMsgReq, SendMsgResp](Msg_SendMsg_FullMethodName)
Expand Down

0 comments on commit d68f86c

Please sign in to comment.