Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gen #172

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 74 additions & 15 deletions magefile.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ import (

var Default = RpcCaller

var (
projectRoot string
//packageRegex = regexp.MustCompile(`^\s*package\s+([a-zA-Z0-9_.]+);`)
importRegex = regexp.MustCompile(`^\s*import\s+"([a-zA-Z0-9_./]+)";`)
goPackageRegex = regexp.MustCompile(`^\s*option\s+go_package\s*=\s*"([^"]+)";`)
serviceRegex = regexp.MustCompile(`^\s*service\s+([A-Za-z0-9_]+)\s*{`)
rpcRegex = regexp.MustCompile(`^\s*rpc\s+([A-Za-z0-9_]+)\s*\(\s*([A-Za-z0-9_]+\.?[A-Za-z0-9_]*)\s*\)\s+returns\s*\(\s*([A-Za-z0-9_]+\.?[A-Za-z0-9_]*)\s*\);`)
)

func RpcCaller() {
if err := Generate(); err != nil {
fmt.Println(err)
Expand All @@ -39,13 +48,14 @@ type Service struct {
ServiceName string
GoPackage string
Methods []ServiceMethod
Imports []string
}

// Generate is the Mage target to generate Go code from proto files
func Generate() error {
fmt.Println("Generating rpc_caller...")

projectRoot, err := os.Getwd()
var err error
projectRoot, err = os.Getwd()
if err != nil {
return errors.New("get root directory failed")
}
Expand Down Expand Up @@ -119,24 +129,22 @@ func Generate() error {
func parseProtoFile(protoFilePath string) (Service, error) {
file, err := os.Open(protoFilePath)
if err != nil {
return Service{}, fmt.Errorf("打开 proto 文件失败: %v", err)
return Service{}, fmt.Errorf("open proto file failed: %v", err)
}
defer file.Close()

filePath := filepath.Dir(protoFilePath)
filePath_ := filepath.Dir(protoFilePath)
fileName := strings.TrimSuffix(filepath.Base(protoFilePath), filepath.Ext(protoFilePath))
scanner := bufio.NewScanner(file)
var (
goPackage string
serviceName string
methods []ServiceMethod
goPackage string
serviceName string
methods []ServiceMethod
imports []string
alreadyImports = make(map[string]struct{})
allImports = make(map[string]string)
)

//packageRegex := regexp.MustCompile(`^\s*package\s+([a-zA-Z0-9_.]+);`)
goPackageRegex := regexp.MustCompile(`^\s*option\s+go_package\s*=\s*"([^"]+)";`)
serviceRegex := regexp.MustCompile(`^\s*service\s+([A-Za-z0-9_]+)\s*{`)
rpcRegex := regexp.MustCompile(`^\s*rpc\s+([A-Za-z0-9_]+)\s*\(\s*([A-Za-z0-9_]+)\s*\)\s+returns\s*\(\s*([A-Za-z0-9_]+)\s*\);`)

inService := false

for scanner.Scan() {
Expand All @@ -147,6 +155,12 @@ func parseProtoFile(protoFilePath string) (Service, error) {
// continue
//}

if matches := importRegex.FindStringSubmatch(line); matches != nil {
pkg := strings.TrimSuffix(filepath.Base(matches[1]), filepath.Ext(matches[1]))
allImports[pkg] = matches[1]
continue
}

if matches := goPackageRegex.FindStringSubmatch(line); matches != nil {
goPackage = matches[1]

Expand All @@ -172,6 +186,21 @@ func parseProtoFile(protoFilePath string) (Service, error) {
FullMethodName: fullMethodName,
}
methods = append(methods, method)

if strings.Contains(requestType, ".") {
imp := strings.Split(requestType, ".")[0]
if f, ok := allImports[imp]; ok {
if _, ok = alreadyImports[imp]; !ok {
alreadyImports[imp] = struct{}{}
gopkg, err := getGoPackage(filepath.Join(projectRoot, f))
if err != nil {
fmt.Printf("get go package failed: %v", err)
} else {
imports = append(imports, gopkg)
}
}
}
}
continue
}

Expand All @@ -195,11 +224,12 @@ func parseProtoFile(protoFilePath string) (Service, error) {

sp := strings.Split(goPackage, "/")
service := Service{
FilePath: filePath,
FilePath: filePath_,
FileName: fileName,
ServiceName: serviceName,
GoPackage: sp[len(sp)-1],
Methods: methods,
Imports: imports,
}

return service, nil
Expand All @@ -213,6 +243,9 @@ func generateGoFile(service Service) error {
tmpl := `package {{.GoPackage}}

import (
{{- range .Imports }}
"{{ . }}"
{{- end }}
"github.com/openimsdk/protocol/rpccall"
"google.golang.org/grpc"
)
Expand All @@ -234,10 +267,12 @@ var (
ServiceNameCamel string
GoPackage string
Methods []ServiceMethod
Imports []string
}{
ServiceNameCamel: toCamelCase(service.ServiceName),
GoPackage: service.GoPackage,
Methods: service.Methods,
Imports: service.Imports,
}

t, err := template.New("goFile").Parse(tmpl)
Expand All @@ -261,14 +296,38 @@ var (
return nil
}

func getGoPackage(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", fmt.Errorf("open proto file failed: %v", err)
}
defer file.Close()

scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if matches := goPackageRegex.FindStringSubmatch(line); matches != nil {
return matches[1], nil
}
}
return "", errors.New("goPackage not found")
}

func toCamelCase(s string) string {
parts := strings.Split(s, "_")
pkg := ""
ss := s
if strings.Contains(ss, ".") {
sl := strings.Split(ss, ".")
ss = sl[1]
pkg = sl[0] + "."
}
parts := strings.Split(ss, "_")
for i, p := range parts {
if len(p) > 0 {
parts[i] = strings.ToUpper(p[:1]) + p[1:]
}
}
return strings.Join(parts, "")
return pkg + strings.Join(parts, "")
}

func formatFile(filePath string) {
Expand Down
5 changes: 5 additions & 0 deletions msg/msg_caller.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@ package msg

import (
"github.com/openimsdk/protocol/rpccall"
"github.com/openimsdk/protocol/sdkws"
"google.golang.org/grpc"
)

func InitMsg(conn *grpc.ClientConn) {
GetMaxSeqCaller.SetConn(conn)
GetMaxSeqsCaller.SetConn(conn)
GetHasReadSeqsCaller.SetConn(conn)
GetMsgByConversationIDsCaller.SetConn(conn)
GetConversationMaxSeqCaller.SetConn(conn)
PullMessageBySeqsCaller.SetConn(conn)
GetSeqMessageCaller.SetConn(conn)
SearchMessageCaller.SetConn(conn)
SendMsgCaller.SetConn(conn)
Expand Down Expand Up @@ -39,10 +42,12 @@ func InitMsg(conn *grpc.ClientConn) {
}

var (
GetMaxSeqCaller = rpccall.NewRpcCaller[sdkws.GetMaxSeqReq, sdkws.GetMaxSeqResp](Msg_GetMaxSeq_FullMethodName)
GetMaxSeqsCaller = rpccall.NewRpcCaller[GetMaxSeqsReq, SeqsInfoResp](Msg_GetMaxSeqs_FullMethodName)
GetHasReadSeqsCaller = rpccall.NewRpcCaller[GetHasReadSeqsReq, SeqsInfoResp](Msg_GetHasReadSeqs_FullMethodName)
GetMsgByConversationIDsCaller = rpccall.NewRpcCaller[GetMsgByConversationIDsReq, GetMsgByConversationIDsResp](Msg_GetMsgByConversationIDs_FullMethodName)
GetConversationMaxSeqCaller = rpccall.NewRpcCaller[GetConversationMaxSeqReq, GetConversationMaxSeqResp](Msg_GetConversationMaxSeq_FullMethodName)
PullMessageBySeqsCaller = rpccall.NewRpcCaller[sdkws.PullMessageBySeqsReq, sdkws.PullMessageBySeqsResp](Msg_PullMessageBySeqs_FullMethodName)
GetSeqMessageCaller = rpccall.NewRpcCaller[GetSeqMessageReq, GetSeqMessageResp](Msg_GetSeqMessage_FullMethodName)
SearchMessageCaller = rpccall.NewRpcCaller[SearchMessageReq, SearchMessageResp](Msg_SearchMessage_FullMethodName)
SendMsgCaller = rpccall.NewRpcCaller[SendMsgReq, SendMsgResp](Msg_SendMsg_FullMethodName)
Expand Down