From 437ab7e6fb74722ac12c96070b94d2593d8bdf4d Mon Sep 17 00:00:00 2001 From: Atsushi Watanabe Date: Fri, 8 May 2020 16:24:23 +0900 Subject: [PATCH] tunnel: add server option to separate tunnel and API port (#88) --- tunnel/cmd/secure-tunnel-server/main.go | 49 ++++++++++++++++++++----- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/tunnel/cmd/secure-tunnel-server/main.go b/tunnel/cmd/secure-tunnel-server/main.go index e60dc009..61c8b816 100644 --- a/tunnel/cmd/secure-tunnel-server/main.go +++ b/tunnel/cmd/secure-tunnel-server/main.go @@ -7,6 +7,7 @@ import ( "log" "math/rand" "net/http" + "sync" "time" mqtt "github.com/at-wat/mqtt-go" @@ -17,6 +18,8 @@ import ( var ( mqttEndpoint = flag.String("mqtt-endpoint", "", "AWS IoT endpoint") + apiAddr = flag.String("api-addr", ":80", "Address and port of API endpoint") + tunnelAddr = flag.String("tunnel-addr", ":80", "Address and port of proxy WebSocket endpoint") ) func init() { @@ -57,20 +60,46 @@ func main() { tunnelHandler := server.NewTunnelHandler() apiHandler := server.NewAPIHandler(tunnelHandler, notifier) - mux := http.NewServeMux() - mux.Handle("/", apiHandler) - mux.Handle("/tunnel", tunnelHandler) - s := &http.Server{ - Addr: ":80", - Handler: mux, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, + servers := map[string]*http.Server{ + *apiAddr: { + Addr: *apiAddr, + Handler: http.NewServeMux(), + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + }, + } + if *apiAddr != *tunnelAddr { + servers[*tunnelAddr] = &http.Server{ + Addr: *tunnelAddr, + Handler: http.NewServeMux(), + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } } - switch err := s.ListenAndServe(); err { + servers[*apiAddr].Handler.(*http.ServeMux).Handle("/", apiHandler) + servers[*tunnelAddr].Handler.(*http.ServeMux).Handle("/tunnel", tunnelHandler) + + var wg sync.WaitGroup + chErr := make(chan error, len(servers)) + for _, s := range servers { + wg.Add(1) + go func(s *http.Server) { + chErr <- s.ListenAndServe() + wg.Done() + }(s) + } + + switch err := <-chErr; err { case http.ErrServerClosed, nil: default: - log.Fatal(err) + log.Print(err) + } + for _, s := range servers { + if err := s.Close(); err != nil { + log.Print(err) + } } + wg.Wait() }