diff --git a/main.go b/main.go index 8d581e1..0bfc435 100644 --- a/main.go +++ b/main.go @@ -66,11 +66,9 @@ type Config struct { } func main() { - // ******************************************************************************** - // setup context to catch signals - // ******************************************************************************** - ctx, cancel := notifyContext() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // ******************************************************************************** // setup logging // ******************************************************************************** @@ -85,7 +83,9 @@ func main() { if err := debug.Self(); err != nil { log.FromContext(ctx).Infof("%s", err) } + starttime := time.Now() + // enumerating phases log.FromContext(ctx).Infof("there are 5 phases which will be executed followed by a success message:") log.FromContext(ctx).Infof("the phases include:") @@ -95,10 +95,12 @@ func main() { log.FromContext(ctx).Infof("4: create network service client") log.FromContext(ctx).Infof("5: connect to all passed services") log.FromContext(ctx).Infof("a final success message with start time duration") + // ******************************************************************************** log.FromContext(ctx).Infof("executing phase 1: get config from environment (time since start: %s)", time.Since(starttime)) // ******************************************************************************** now := time.Now() + config := &Config{} if err := envconfig.Usage("nsm", config); err != nil { logrus.Fatal(err) @@ -107,18 +109,29 @@ func main() { logrus.Fatalf("error processing config from env: %+v", err) } log.FromContext(ctx).Infof("Config: %#v", config) + log.FromContext(ctx).WithField("duration", time.Since(now)).Infof("completed phase 1: get config from environment") + // ******************************************************************************** log.FromContext(ctx).Infof("executing phase 2: run vpp and get a connection to it (time since start: %s)", time.Since(starttime)) // ******************************************************************************** now = time.Now() + vppConn, vppErrCh := vpphelper.StartAndDialContext(ctx) exitOnErrCh(ctx, cancel, vppErrCh) + + defer func() { + cancel() + <-vppErrCh + }() + log.FromContext(ctx).WithField("duration", time.Since(now)).Info("completed phase 2: run vpp and get a connection to it") + // ******************************************************************************** log.FromContext(ctx).Infof("executing phase 3: retrieving svid, check spire agent logs if this is the last line you see (time since start: %s)", time.Since(starttime)) // ******************************************************************************** now = time.Now() + source, err := workloadapi.NewX509Source(ctx) if err != nil { logrus.Fatalf("error getting x509 source: %+v", err) @@ -128,6 +141,7 @@ func main() { logrus.Fatalf("error getting x509 svid: %+v", err) } logrus.Infof("SVID: %q", svid.ID) + log.FromContext(ctx).WithField("duration", time.Since(now)).Info("completed phase 3: retrieving svid") // ******************************************************************************** @@ -170,6 +184,9 @@ func main() { log.FromContext(ctx).Infof("executing phase 5: connect to all passed services (time since start: %s)", time.Since(starttime)) // ******************************************************************************** + signalCtx, cancelSignalCtx := notifyContext(ctx) + defer cancelSignalCtx() + for i := 0; i < len(config.NetworkServices); i++ { u := nsurl.NSURL(config.NetworkServices[i]) mech := u.Mechanism() @@ -184,8 +201,8 @@ func main() { }, } - requestCtx, cancel := context.WithTimeout(ctx, config.RequestTimeout) - defer cancel() + requestCtx, cancelRequest := context.WithTimeout(signalCtx, config.RequestTimeout) + defer cancelRequest() resp, err := c.Request(requestCtx, request) if err != nil { @@ -193,15 +210,13 @@ func main() { } defer func() { - closeCtx, cancel := context.WithTimeout(context.Background(), config.RequestTimeout) - closeCtx = log.WithFields(closeCtx, log.Fields(ctx)) - defer cancel() + closeCtx, cancelClose := context.WithTimeout(ctx, config.RequestTimeout) + defer cancelClose() _, _ = c.Close(closeCtx, resp) }() } - <-ctx.Done() - <-vppErrCh + <-signalCtx.Done() } func exitOnErrCh(ctx context.Context, cancel context.CancelFunc, errCh <-chan error) { @@ -219,9 +234,9 @@ func exitOnErrCh(ctx context.Context, cancel context.CancelFunc, errCh <-chan er }(ctx, errCh) } -func notifyContext() (context.Context, context.CancelFunc) { +func notifyContext(ctx context.Context) (context.Context, context.CancelFunc) { return signal.NotifyContext( - context.Background(), + ctx, os.Interrupt, // More Linux signals here syscall.SIGHUP,