diff --git a/plugins/plugins.go b/plugins/plugins.go index 5a755ee263..a29c1e642f 100644 --- a/plugins/plugins.go +++ b/plugins/plugins.go @@ -160,6 +160,8 @@ type Manager struct { registeredCacheTriggers []func(*cache.Config) logger logging.Logger consoleLogger logging.Logger + serverInitialized chan struct{} + serverInitializedOnce sync.Once } type managerContextKey string @@ -286,6 +288,7 @@ func New(raw []byte, id string, store storage.Store, opts ...func(*Manager)) (*M pluginStatusListeners: map[string]StatusListener{}, maxErrors: -1, interQueryBuiltinCacheConfig: interQueryBuiltinCacheConfig, + serverInitialized: make(chan struct{}), } if m.logger == nil { @@ -759,6 +762,22 @@ func (m *Manager) ConsoleLogger() logging.Logger { return m.consoleLogger } +// ServerInitialized signals a channel indicating that the OPA +// server has finished initialization. +func (m *Manager) ServerInitialized() { + m.serverInitializedOnce.Do(func() { close(m.serverInitialized) }) +} + +// ServerInitializedChannel returns a receive-only channel that +// is closed when the OPA server has finished initialization. +// Be aware that the socket of the server listener may not be +// open by the time this channel is closed. There is a very +// small window where the socket may still be closed, due to +// a race condition. +func (m *Manager) ServerInitializedChannel() <-chan struct{} { + return m.serverInitialized +} + // RegisterCacheTrigger accepts a func that receives new inter-query cache config generated by // a reconfigure of the plugin manager, so that it can be propagated to existing inter-query caches. func (m *Manager) RegisterCacheTrigger(trigger func(*cache.Config)) { diff --git a/plugins/plugins_test.go b/plugins/plugins_test.go index 935eabf2d4..7824547515 100644 --- a/plugins/plugins_test.go +++ b/plugins/plugins_test.go @@ -340,6 +340,39 @@ func TestPluginManagerConsoleLogger(t *testing.T) { } } +func TestPluginManagerServerInitialized(t *testing.T) { + // Verify that ServerInitializedChannel is closed when + // ServerInitialized is called. + m1, err := New([]byte{}, "test1", inmem.New()) + if err != nil { + t.Fatal(err) + } + initChannel1 := m1.ServerInitializedChannel() + m1.ServerInitialized() + // Verify that ServerInitialized is idempotent and will not panic + m1.ServerInitialized() + select { + case <-initChannel1: + break + default: + t.Fatal("expected ServerInitializedChannel to be closed") + } + + // Verify that ServerInitializedChannel is open when + // ServerInitialized is not called. + m2, err := New([]byte{}, "test2", inmem.New()) + if err != nil { + t.Fatal(err) + } + initChannel2 := m2.ServerInitializedChannel() + select { + case <-initChannel2: + t.Fatal("expected ServerInitializedChannel to be open and have no messages") + default: + break + } +} + type myAuthPluginMock struct{} func (m *myAuthPluginMock) NewClient(c rest.Config) (*http.Client, error) { diff --git a/runtime/runtime.go b/runtime/runtime.go index 2b1fe9d2f1..65d0817d22 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -438,9 +438,13 @@ func (rt *Runtime) Serve(ctx context.Context) error { signalc := make(chan os.Signal, 1) signal.Notify(signalc, syscall.SIGINT, syscall.SIGTERM) + // Note that there is a small chance the socket of the server listener is still + // closed by the time this block is executed, due to the serverLoop above + // executing in a goroutine. rt.serverInitMtx.Lock() rt.serverInitialized = true rt.serverInitMtx.Unlock() + rt.Manager.ServerInitialized() logrus.Debug("Server initialized.") diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index 6c29ed3a9d..7f8dfde0b0 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -337,6 +337,36 @@ func TestCheckAuthIneffective(t *testing.T) { } +func TestServerInitialized(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Millisecond) + defer cancel() // NOTE(sr): The timeout will have been reached by the time `done` is closed. + var output bytes.Buffer + + params := NewParams() + params.Output = &output + params.Addrs = &[]string{":0"} + params.GracefulShutdownPeriod = 1 + rt, err := NewRuntime(ctx, params) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + logrus.SetOutput(rt.Params.Output) + + initChannel := rt.Manager.ServerInitializedChannel() + done := make(chan struct{}) + go func() { + rt.StartServer(ctx) + close(done) + }() + <-done + select { + case <-initChannel: + return + default: + t.Fatal("expected ServerInitializedChannel to be closed") + } +} + func getTestServer(update interface{}, statusCode int) (baseURL string, teardownFn func()) { mux := http.NewServeMux() ts := httptest.NewServer(mux)