Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pre and post health check methods #147

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ type SQLDatasource struct {
CustomRoutes map[string]func(http.ResponseWriter, *http.Request)
metrics Metrics
EnableMultipleConnections bool
// PreCheckHealth (optional). Performs custom health check before the Connect method
PreCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult
// PostCheckHealth (optional).Performs custom health check after the Connect method
PostCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult
}

// NewDatasource creates a new `SQLDatasource`.
Expand Down Expand Up @@ -252,8 +256,10 @@ func (ds *SQLDatasource) CheckHealth(ctx context.Context, req *backend.CheckHeal
ctx, req = checkHealthMutator.MutateCheckHealth(ctx, req)
}
healthChecker := &HealthChecker{
Connector: ds.connector,
Metrics: ds.metrics.WithEndpoint(EndpointHealth),
Connector: ds.connector,
Metrics: ds.metrics.WithEndpoint(EndpointHealth),
PreCheckHealth: ds.PreCheckHealth,
PostCheckHealth: ds.PostCheckHealth,
}
return healthChecker.Check(ctx, req)
}
Expand Down
7 changes: 6 additions & 1 deletion driver-mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"os"
"path/filepath"
"strings"
Expand All @@ -17,7 +18,8 @@ import (

// SQLMock connects to a local folder with csv files
type SQLMock struct {
folder string
folder string
ShouldFailToConnect bool
}

func (h *SQLMock) Settings(_ context.Context, _ backend.DataSourceInstanceSettings) DriverSettings {
Expand All @@ -31,6 +33,9 @@ func (h *SQLMock) Settings(_ context.Context, _ backend.DataSourceInstanceSettin

// Connect opens a sql.DB connection using datasource settings
func (h *SQLMock) Connect(_ context.Context, _ backend.DataSourceInstanceSettings, msg json.RawMessage) (*sql.DB, error) {
if h.ShouldFailToConnect {
return nil, errors.New("failed to create mock")
}
backend.Logger.Debug("connecting to mock data")
folder := h.folder
if folder == "" {
Expand Down
33 changes: 19 additions & 14 deletions health.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,30 @@ import (
)

type HealthChecker struct {
Connector *Connector
Metrics Metrics
Connector *Connector
Metrics Metrics
PreCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult
PostCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult
}

func (hc *HealthChecker) Check(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) {
start := time.Now()

_, err := hc.Connector.Connect(ctx, req.GetHTTPHeaders())
if err != nil {
if hc.PreCheckHealth != nil {
if res := hc.PreCheckHealth(ctx, req); res != nil && res.Status == backend.HealthStatusError {
hc.Metrics.CollectDuration(SourceDownstream, StatusError, time.Since(start).Seconds())
return res, nil
}
}
if _, err := hc.Connector.Connect(ctx, req.GetHTTPHeaders()); err != nil {
hc.Metrics.CollectDuration(SourceDownstream, StatusError, time.Since(start).Seconds())
return &backend.CheckHealthResult{
Status: backend.HealthStatusError,
Message: err.Error(),
}, DownstreamError(err)
return &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: err.Error()}, nil
}
if hc.PostCheckHealth != nil {
if res := hc.PostCheckHealth(ctx, req); res != nil && res.Status == backend.HealthStatusError {
hc.Metrics.CollectDuration(SourceDownstream, StatusError, time.Since(start).Seconds())
return res, nil
}
}
hc.Metrics.CollectDuration(SourceDownstream, StatusOK, time.Since(start).Seconds())

return &backend.CheckHealthResult{
Status: backend.HealthStatusOk,
Message: "Data source is working",
}, nil
return &backend.CheckHealthResult{Status: backend.HealthStatusOk, Message: "Data source is working"}, nil
}
109 changes: 109 additions & 0 deletions health_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package sqlds_test

import (
"context"
"testing"

"github.com/grafana/grafana-plugin-sdk-go/backend"
sqlds "github.com/grafana/sqlds/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func getFakeConnector(t *testing.T, shouldFail bool) *sqlds.Connector {
t.Helper()
c, _ := sqlds.NewConnector(context.TODO(), &sqlds.SQLMock{ShouldFailToConnect: shouldFail}, backend.DataSourceInstanceSettings{}, false)
return c
}

func TestHealthChecker_Check(t *testing.T) {
tests := []struct {
name string
Connector *sqlds.Connector
Metrics sqlds.Metrics
PreCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult
PostCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult
ctx context.Context
req *backend.CheckHealthRequest
want *backend.CheckHealthResult
wantErr error
}{
{
name: "default health check should return valid result",
Connector: getFakeConnector(t, false),
},
{
name: "should not error when pre check succeed",
Connector: getFakeConnector(t, false),
PreCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult {
return &backend.CheckHealthResult{Status: backend.HealthStatusOk}
},
},
{
name: "should error when pre check failed",
Connector: getFakeConnector(t, false),
PreCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult {
return &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unknown error"}
},
want: &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unknown error"},
},
{
name: "should return actual error when pre and post health check succeed but actual connect failed",
Connector: getFakeConnector(t, true),
PreCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult {
return &backend.CheckHealthResult{Status: backend.HealthStatusOk}
},
PostCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult {
return &backend.CheckHealthResult{Status: backend.HealthStatusOk}
},
want: &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unable to get default db connection"},
},
{
name: "should not error when post check succeed",
Connector: getFakeConnector(t, false),
PostCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult {
return &backend.CheckHealthResult{Status: backend.HealthStatusOk}
},
},
{
name: "should error when post check failed",
Connector: getFakeConnector(t, false),
PostCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult {
return &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unknown error"}
},
want: &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unknown error"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
connector := tt.Connector
if connector == nil {
connector = &sqlds.Connector{}
}
req := tt.req
if req == nil {
req = &backend.CheckHealthRequest{}
}
want := tt.want
if want == nil {
want = &backend.CheckHealthResult{Status: backend.HealthStatusOk, Message: "Data source is working"}
}
hc := &sqlds.HealthChecker{
Connector: connector,
Metrics: tt.Metrics,
PreCheckHealth: tt.PreCheckHealth,
PostCheckHealth: tt.PostCheckHealth,
}
got, err := hc.Check(tt.ctx, req)
if tt.wantErr != nil {
require.NotNil(t, err)
assert.Equal(t, tt.wantErr.Error(), err.Error())
return
}
require.Nil(t, err)
require.NotNil(t, got)
assert.Equal(t, want.Message, got.Message)
assert.Equal(t, want.Status, got.Status)
})
}
}
Loading