diff --git a/cmd/candy/cmd/launch_darwin.go b/cmd/candy/cmd/launch_darwin.go index 00b840e6..41a4425a 100644 --- a/cmd/candy/cmd/launch_darwin.go +++ b/cmd/candy/cmd/launch_darwin.go @@ -3,6 +3,7 @@ package cmd import ( + "context" "fmt" "os" "path/filepath" @@ -49,6 +50,7 @@ func launchRunE(c *cobra.Command, args []string) error { } for _, proxy := range proxies { + proxy := proxy g.Add(func() error { return proxy.Run() }, func(err error) { @@ -58,9 +60,11 @@ func launchRunE(c *cobra.Command, args []string) error { } { + ctx, cancel := context.WithCancel(context.Background()) g.Add(func() error { - return startServer(c) + return startServer(c, ctx) }, func(err error) { + cancel() }) } diff --git a/cmd/candy/cmd/root.go b/cmd/candy/cmd/root.go index c143ac98..302fea91 100644 --- a/cmd/candy/cmd/root.go +++ b/cmd/candy/cmd/root.go @@ -25,34 +25,33 @@ var ( ) func init() { - var err error - homeDir, err = userHomeDir() - if err != nil { - candy.Log().Fatal("error getting user home directory", zap.Error(err)) - } - - rootCmd.PersistentFlags().StringVar(&flagRootCfgFile, "config", filepath.Join(homeDir, ".candyconfig"), "Config file") + rootCmd.PersistentFlags().StringVar(&flagRootCfgFile, "config", filepath.Join(userHomeDir(), ".candyconfig"), "Config file") } -func userHomeDir() (string, error) { +func userHomeDir() string { + if homeDir != "" { + return homeDir + } + var ( sudo = os.Getenv("SUDO_USER") euid = os.Geteuid() + err error ) if sudo != "" && euid == 0 { u, err := user.Lookup(sudo) if err != nil { - return "", nil + candy.Log().Fatal("error looking up sudo user", zap.String("user", sudo), zap.Error(err)) } - return u.HomeDir, nil + return u.HomeDir } - homeDir, err := os.UserHomeDir() + homeDir, err = os.UserHomeDir() if err != nil { - return "", err + candy.Log().Fatal("error getting user home directory", zap.Error(err)) } - return homeDir, nil + return homeDir } diff --git a/cmd/candy/cmd/run.go b/cmd/candy/cmd/run.go index 75b91a33..9322972f 100644 --- a/cmd/candy/cmd/run.go +++ b/cmd/candy/cmd/run.go @@ -32,7 +32,7 @@ func init() { } func addServerFlags(cmd *cobra.Command) { - cmd.Flags().String("host-root", filepath.Join(homeDir, ".candy"), "Path to the directory containing applications that will be served by Candy") + cmd.Flags().String("host-root", filepath.Join(userHomeDir(), ".candy"), "Path to the directory containing applications that will be served by Candy") cmd.Flags().StringSlice("domain", defaultDomains, "The top-level domains for which Candy will respond to DNS queries") cmd.Flags().String("http-addr", ":28080", "The Proxy server HTTP address") cmd.Flags().String("https-addr", ":28443", "The Proxy server HTTPS address") @@ -42,10 +42,10 @@ func addServerFlags(cmd *cobra.Command) { } func runRunE(c *cobra.Command, args []string) error { - return startServer(c) + return startServer(c, context.Background()) } -func startServer(c *cobra.Command) error { +func startServer(c *cobra.Command, ctx context.Context) error { var cfg server.Config if err := candy.LoadConfig( flagRootCfgFile, @@ -66,10 +66,10 @@ func startServer(c *cobra.Command) error { candy.Log().Info("using config", zap.Any("cfg", cfg)) if err := os.MkdirAll(cfg.HostRoot, 0o0755); err != nil { - return fmt.Errorf("failed to create host directory: %w", err) + return fmt.Errorf("failed to create host directory %s: %w", cfg.HostRoot, err) } svr := server.New(cfg) - return svr.Run(context.Background()) + return svr.Run(ctx) }