From 20aedbec6bf005e5948b8145abee50d417b2d26f Mon Sep 17 00:00:00 2001 From: yesoreyeram <153843+yesoreyeram@users.noreply.github.com> Date: Thu, 12 Dec 2024 00:13:08 +0000 Subject: [PATCH] added pre and post health check methods --- datasource.go | 10 ++++- driver-mock.go | 7 +++- health.go | 33 ++++++++------- health_test.go | 109 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 142 insertions(+), 17 deletions(-) create mode 100644 health_test.go diff --git a/datasource.go b/datasource.go index 1b33208..59306a2 100644 --- a/datasource.go +++ b/datasource.go @@ -50,6 +50,10 @@ type SQLDatasource struct { CustomRoutes map[string]func(http.ResponseWriter, *http.Request) metrics Metrics EnableMultipleConnections bool + // PreCheckHealth (optional). Performs custom health check before the Connect method + PreCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult + // PostCheckHealth (optional).Performs custom health check after the Connect method + PostCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult } // NewDatasource creates a new `SQLDatasource`. @@ -252,8 +256,10 @@ func (ds *SQLDatasource) CheckHealth(ctx context.Context, req *backend.CheckHeal ctx, req = checkHealthMutator.MutateCheckHealth(ctx, req) } healthChecker := &HealthChecker{ - Connector: ds.connector, - Metrics: ds.metrics.WithEndpoint(EndpointHealth), + Connector: ds.connector, + Metrics: ds.metrics.WithEndpoint(EndpointHealth), + PreCheckHealth: ds.PreCheckHealth, + PostCheckHealth: ds.PostCheckHealth, } return healthChecker.Check(ctx, req) } diff --git a/driver-mock.go b/driver-mock.go index 3fd4aa6..2719251 100644 --- a/driver-mock.go +++ b/driver-mock.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "os" "path/filepath" "strings" @@ -17,7 +18,8 @@ import ( // SQLMock connects to a local folder with csv files type SQLMock struct { - folder string + folder string + ShouldFailToConnect bool } func (h *SQLMock) Settings(_ context.Context, _ backend.DataSourceInstanceSettings) DriverSettings { @@ -31,6 +33,9 @@ func (h *SQLMock) Settings(_ context.Context, _ backend.DataSourceInstanceSettin // Connect opens a sql.DB connection using datasource settings func (h *SQLMock) Connect(_ context.Context, _ backend.DataSourceInstanceSettings, msg json.RawMessage) (*sql.DB, error) { + if h.ShouldFailToConnect { + return nil, errors.New("failed to create mock") + } backend.Logger.Debug("connecting to mock data") folder := h.folder if folder == "" { diff --git a/health.go b/health.go index 88771a7..a1a8a94 100644 --- a/health.go +++ b/health.go @@ -8,25 +8,30 @@ import ( ) type HealthChecker struct { - Connector *Connector - Metrics Metrics + Connector *Connector + Metrics Metrics + PreCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult + PostCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult } func (hc *HealthChecker) Check(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { start := time.Now() - - _, err := hc.Connector.Connect(ctx, req.GetHTTPHeaders()) - if err != nil { + if hc.PreCheckHealth != nil { + if res := hc.PreCheckHealth(ctx, req); res != nil && res.Status == backend.HealthStatusError { + hc.Metrics.CollectDuration(SourceDownstream, StatusError, time.Since(start).Seconds()) + return res, nil + } + } + if _, err := hc.Connector.Connect(ctx, req.GetHTTPHeaders()); err != nil { hc.Metrics.CollectDuration(SourceDownstream, StatusError, time.Since(start).Seconds()) - return &backend.CheckHealthResult{ - Status: backend.HealthStatusError, - Message: err.Error(), - }, DownstreamError(err) + return &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: err.Error()}, nil + } + if hc.PostCheckHealth != nil { + if res := hc.PostCheckHealth(ctx, req); res != nil && res.Status == backend.HealthStatusError { + hc.Metrics.CollectDuration(SourceDownstream, StatusError, time.Since(start).Seconds()) + return res, nil + } } hc.Metrics.CollectDuration(SourceDownstream, StatusOK, time.Since(start).Seconds()) - - return &backend.CheckHealthResult{ - Status: backend.HealthStatusOk, - Message: "Data source is working", - }, nil + return &backend.CheckHealthResult{Status: backend.HealthStatusOk, Message: "Data source is working"}, nil } diff --git a/health_test.go b/health_test.go new file mode 100644 index 0000000..5e2fa5b --- /dev/null +++ b/health_test.go @@ -0,0 +1,109 @@ +package sqlds_test + +import ( + "context" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + sqlds "github.com/grafana/sqlds/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func getFakeConnector(t *testing.T, shouldFail bool) *sqlds.Connector { + t.Helper() + c, _ := sqlds.NewConnector(context.TODO(), &sqlds.SQLMock{ShouldFailToConnect: shouldFail}, backend.DataSourceInstanceSettings{}, false) + return c +} + +func TestHealthChecker_Check(t *testing.T) { + tests := []struct { + name string + Connector *sqlds.Connector + Metrics sqlds.Metrics + PreCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult + PostCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult + ctx context.Context + req *backend.CheckHealthRequest + want *backend.CheckHealthResult + wantErr error + }{ + { + name: "default health check should return valid result", + Connector: getFakeConnector(t, false), + }, + { + name: "should not error when pre check succeed", + Connector: getFakeConnector(t, false), + PreCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult { + return &backend.CheckHealthResult{Status: backend.HealthStatusOk} + }, + }, + { + name: "should error when pre check failed", + Connector: getFakeConnector(t, false), + PreCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult { + return &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unknown error"} + }, + want: &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unknown error"}, + }, + { + name: "should return actual error when pre and post health check succeed but actual connect failed", + Connector: getFakeConnector(t, true), + PreCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult { + return &backend.CheckHealthResult{Status: backend.HealthStatusOk} + }, + PostCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult { + return &backend.CheckHealthResult{Status: backend.HealthStatusOk} + }, + want: &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unable to get default db connection"}, + }, + { + name: "should not error when post check succeed", + Connector: getFakeConnector(t, false), + PostCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult { + return &backend.CheckHealthResult{Status: backend.HealthStatusOk} + }, + }, + { + name: "should error when post check failed", + Connector: getFakeConnector(t, false), + PostCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult { + return &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unknown error"} + }, + want: &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unknown error"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + connector := tt.Connector + if connector == nil { + connector = &sqlds.Connector{} + } + req := tt.req + if req == nil { + req = &backend.CheckHealthRequest{} + } + want := tt.want + if want == nil { + want = &backend.CheckHealthResult{Status: backend.HealthStatusOk, Message: "Data source is working"} + } + hc := &sqlds.HealthChecker{ + Connector: connector, + Metrics: tt.Metrics, + PreCheckHealth: tt.PreCheckHealth, + PostCheckHealth: tt.PostCheckHealth, + } + got, err := hc.Check(tt.ctx, req) + if tt.wantErr != nil { + require.NotNil(t, err) + assert.Equal(t, tt.wantErr.Error(), err.Error()) + return + } + require.Nil(t, err) + require.NotNil(t, got) + assert.Equal(t, want.Message, got.Message) + assert.Equal(t, want.Status, got.Status) + }) + } +}