From 6dff62f42b4cb32e7a0b0c350ce950bf4af11f0c Mon Sep 17 00:00:00 2001 From: ShohamBit Date: Mon, 6 Jan 2025 19:36:39 +0000 Subject: [PATCH] added error handling to server --- pkg/cmd/flags/server/server.go | 51 ++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/pkg/cmd/flags/server/server.go b/pkg/cmd/flags/server/server.go index 84b8e87b131a..76dcfa667f37 100644 --- a/pkg/cmd/flags/server/server.go +++ b/pkg/cmd/flags/server/server.go @@ -1,7 +1,10 @@ package server import ( + "errors" "fmt" + "net" + "net/url" "os" "strings" @@ -44,7 +47,7 @@ func PrepareServer(serverSlice []string) (*Server, error) { // split flag http.address or grpc.address for example serverParts := strings.SplitN(endpoint, ".", 2) if len(serverParts) < 2 { - return nil, fmt.Errorf("cannot process http or grpc alone") + return nil, fmt.Errorf("cannot process the flag: try grpc.Xxx or http.Xxx instead") } switch serverParts[0] { //flag http.Xxx @@ -52,7 +55,11 @@ func PrepareServer(serverSlice []string) (*Server, error) { httpParts := strings.SplitN(serverParts[1], "=", 2) switch httpParts[0] { case ListenEndpointFlag: - server.HTTPServer = http.New(httpParts[1]) + if isValidAddr(httpParts[1]) { + server.HTTPServer = http.New(httpParts[1]) + } else { + return nil, errors.New("invalid http address") + } case MetricsEndpointFlag: if strings.Compare(httpParts[1], "true") == 0 { enableMetrics = true @@ -69,10 +76,12 @@ func PrepareServer(serverSlice []string) (*Server, error) { if strings.Compare(httpParts[1], "true") == 0 { enablePyroscope = true } + default: + return nil, errors.New("invalid http flag, consider using one of the following commands: address, metrics, healthz, pprof, pyroscope") } //flag grpc.Xxx case GRPCServer: - grpcParts := strings.SplitN(serverParts[1], "=", 1) + grpcParts := strings.SplitN(serverParts[1], "=", 2) switch grpcParts[0] { case ListenEndpointFlag: addressParts := strings.SplitN(grpcParts[1], ":", 2) @@ -92,19 +101,15 @@ func PrepareServer(serverSlice []string) (*Server, error) { addressParts = append(addressParts, "4466") } default: - return nil, errfmt.Errorf("grpc supported protocols are tcp or unix. eg: tcp:4466, unix:/tmp/tracee.sock") + return nil, errfmt.Errorf("grpc supported protocols are tcp or unix. eg: tcp:4466, unix:/var/run/tracee.sock") } server.GRPCServer = grpc.New(addressParts[0], addressParts[1]) default: - if _, err = os.Stat("/var/run/tracee.sock"); err == nil { - err = os.Remove("/var/run/tracee.sock") - if err != nil { - return nil, errfmt.Errorf("failed to cleanup gRPC listening address (%s): %v", "/var/run/tracee.sock", err) - } - } - server.GRPCServer = grpc.New("unix", "/var/run/tracee.sock") + return nil, errors.New("invalid grpc flag, consider using address") } + default: + return nil, fmt.Errorf("cannot process the flag: try grpc.Xxx or http.Xxx instead") } } @@ -131,3 +136,27 @@ func PrepareServer(serverSlice []string) (*Server, error) { return &server, nil } + +func isValidAddr(addr string) bool { + _, err := url.ParseRequestURI("http://" + addr) + if err != nil { + return false + } + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return false + } + + ip := net.ParseIP(host) + if ip == nil && host != "localhost" && host != "0.0.0.0" { + return false + } + + _, err = net.LookupPort("tcp", port) + if err != nil { + return false + } + + return true +}