Skip to content

Commit

Permalink
added unit tests for instrument package and addressed review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
unmarshall committed Feb 16, 2024
1 parent e5cb64a commit b850d3a
Show file tree
Hide file tree
Showing 16 changed files with 912 additions and 22 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/gardener/machine-controller-manager v0.52.0
github.com/onsi/ginkgo/v2 v2.13.0
github.com/onsi/gomega v1.29.0
github.com/prometheus/client_golang v1.16.0 // indirect
github.com/prometheus/client_golang v1.16.0
github.com/spf13/pflag v1.0.5
golang.org/x/crypto v0.14.0
golang.org/x/exp v0.0.0-20230905200255-921286631fa9
Expand Down
2 changes: 1 addition & 1 deletion pkg/azure/access/helpers/disk.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const (
// DeleteDisk deletes disk for passed in resourceGroup and diskName.
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func DeleteDisk(ctx context.Context, client *armcompute.DisksClient, resourceGroup, diskName string) (err error) {
defer instrument.APIMetricRecorderFn(diskDeleteServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(diskDeleteServiceLabel, &err)()
var poller *runtime.Poller[armcompute.DisksClientDeleteResponse]
poller, err = client.BeginDelete(ctx, resourceGroup, diskName, nil)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/azure/access/helpers/marketplaceagreement.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ const (
// GetAgreementTerms fetches the agreement terms for the purchase plan.
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func GetAgreementTerms(ctx context.Context, mktPlaceAgreementAccess *armmarketplaceordering.MarketplaceAgreementsClient, purchasePlan armcompute.PurchasePlan) (agreementTerms *armmarketplaceordering.AgreementTerms, err error) {
defer instrument.APIMetricRecorderFn(mktPlaceAgreementGetServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(mktPlaceAgreementGetServiceLabel, &err)()
resp, err := mktPlaceAgreementAccess.Get(ctx, armmarketplaceordering.OfferTypeVirtualmachine, *purchasePlan.Publisher, *purchasePlan.Product, *purchasePlan.Name, nil)
if err != nil {
errors.LogAzAPIError(err, "Failed to get marketplace agreement for PurchasePlan: %+v", purchasePlan)
Expand All @@ -45,7 +45,7 @@ func GetAgreementTerms(ctx context.Context, mktPlaceAgreementAccess *armmarketpl
// AcceptAgreement updates the agreementTerms as accepted.
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func AcceptAgreement(ctx context.Context, mktPlaceAgreementAccess *armmarketplaceordering.MarketplaceAgreementsClient, purchasePlan armcompute.PurchasePlan, existingAgreement armmarketplaceordering.AgreementTerms) (err error) {
defer instrument.APIMetricRecorderFn(mktPlaceAgreementCreateServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(mktPlaceAgreementCreateServiceLabel, &err)()
updatedAgreement := existingAgreement
updatedAgreement.Properties.Accepted = to.Ptr(true)
_, err = mktPlaceAgreementAccess.Create(ctx, armmarketplaceordering.OfferTypeVirtualmachine, *purchasePlan.Publisher, *purchasePlan.Product, *purchasePlan.Name, updatedAgreement, nil)
Expand Down
6 changes: 3 additions & 3 deletions pkg/azure/access/helpers/nic.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ const (
// DeleteNIC deletes the NIC identified by a resourceGroup and nicName.
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func DeleteNIC(ctx context.Context, client *armnetwork.InterfacesClient, resourceGroup, nicName string) (err error) {
defer instrument.APIMetricRecorderFn(nicDeleteServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(nicDeleteServiceLabel, &err)()

var poller *runtime.Poller[armnetwork.InterfacesClientDeleteResponse]
delCtx, cancelFn := context.WithTimeout(ctx, defaultDeleteNICTimeout)
Expand All @@ -61,7 +61,7 @@ func DeleteNIC(ctx context.Context, client *armnetwork.InterfacesClient, resourc
// GetNIC fetches a NIC identified by resourceGroup and nic name.
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func GetNIC(ctx context.Context, client *armnetwork.InterfacesClient, resourceGroup, nicName string) (nic *armnetwork.Interface, err error) {
defer instrument.APIMetricRecorderFn(nicGetServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(nicGetServiceLabel, &err)()

resp, err := client.Get(ctx, resourceGroup, nicName, nil)
if err != nil {
Expand All @@ -77,7 +77,7 @@ func GetNIC(ctx context.Context, client *armnetwork.InterfacesClient, resourceGr
// CreateNIC creates a NIC given the resourceGroup, nic name and NIC creation parameters.
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func CreateNIC(ctx context.Context, nicAccess *armnetwork.InterfacesClient, resourceGroup string, nicParams armnetwork.Interface, nicName string) (nic *armnetwork.Interface, err error) {
defer instrument.APIMetricRecorderFn(nicCreateServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(nicCreateServiceLabel, &err)()

var (
poller *runtime.Poller[armnetwork.InterfacesClientCreateOrUpdateResponse]
Expand Down
2 changes: 1 addition & 1 deletion pkg/azure/access/helpers/resourcegraph.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type MapperFn[T any] func(map[string]interface{}) *T
// The result of the query are then mapped using a mapperFn and the result or an error is returned.
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func QueryAndMap[T any](ctx context.Context, client *armresourcegraph.Client, subscriptionID string, mapperFn MapperFn[T], queryTemplate string, templateArgs ...any) (results []T, err error) {
defer instrument.APIMetricRecorderFn(resourceGraphQueryServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(resourceGraphQueryServiceLabel, &err)()

query := fmt.Sprintf(queryTemplate, templateArgs...)
resources, err := client.Resources(ctx,
Expand Down
2 changes: 1 addition & 1 deletion pkg/azure/access/helpers/resourcegroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ const (
// ResourceGroupExists checks if the given resourceGroup exists.
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func ResourceGroupExists(ctx context.Context, client *armresources.ResourceGroupsClient, resourceGroup string) (exists bool, err error) {
defer instrument.APIMetricRecorderFn(resourceGroupExistsServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(resourceGroupExistsServiceLabel, &err)()

resp, err := client.CheckExistence(ctx, resourceGroup, nil)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/azure/access/helpers/subnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const subnetGetServiceLabel = "subnet_get"
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func GetSubnet(ctx context.Context, subnetAccess *armnetwork.SubnetsClient, resourceGroup, virtualNetworkName, subnetName string) (subnet *armnetwork.Subnet, err error) {
var subnetResp armnetwork.SubnetsClientGetResponse
defer instrument.APIMetricRecorderFn(subnetGetServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(subnetGetServiceLabel, &err)()

subnetResp, err = subnetAccess.Get(ctx, resourceGroup, virtualNetworkName, subnetName, nil)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions pkg/azure/access/helpers/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ const (
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func GetVirtualMachine(ctx context.Context, vmClient *armcompute.VirtualMachinesClient, resourceGroup, vmName string) (vm *armcompute.VirtualMachine, err error) {
var getResp armcompute.VirtualMachinesClientGetResponse
defer instrument.APIMetricRecorderFn(vmGetServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(vmGetServiceLabel, &err)()

getResp, err = vmClient.Get(ctx, resourceGroup, vmName, nil)
if err != nil {
Expand All @@ -62,7 +62,7 @@ func GetVirtualMachine(ctx context.Context, vmClient *armcompute.VirtualMachines
// If cascade delete is set for associated NICs and Disks then these resources will also be deleted along with the VM.
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func DeleteVirtualMachine(ctx context.Context, vmAccess *armcompute.VirtualMachinesClient, resourceGroup, vmName string) (err error) {
defer instrument.APIMetricRecorderFn(vmDeleteServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(vmDeleteServiceLabel, &err)()

delCtx, cancelFn := context.WithTimeout(ctx, defaultDeleteVMTimeout)
defer cancelFn()
Expand All @@ -82,7 +82,7 @@ func DeleteVirtualMachine(ctx context.Context, vmAccess *armcompute.VirtualMachi
// CreateVirtualMachine creates a Virtual Machine given a resourceGroup and virtual machine creation parameters.
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func CreateVirtualMachine(ctx context.Context, vmAccess *armcompute.VirtualMachinesClient, resourceGroup string, vmCreationParams armcompute.VirtualMachine) (vm *armcompute.VirtualMachine, err error) {
defer instrument.APIMetricRecorderFn(vmCreateServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(vmCreateServiceLabel, &err)()

createCtx, cancelFn := context.WithTimeout(ctx, defaultCreateVMTimeout)
defer cancelFn()
Expand All @@ -104,7 +104,7 @@ func CreateVirtualMachine(ctx context.Context, vmAccess *armcompute.VirtualMachi
// SetCascadeDeleteForNICsAndDisks sets cascade deletion for NICs and Disks (OSDisk and DataDisks) associated to passed virtual machine.
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func SetCascadeDeleteForNICsAndDisks(ctx context.Context, vmClient *armcompute.VirtualMachinesClient, resourceGroup string, vmName string, vmUpdateParams *armcompute.VirtualMachineUpdate) (err error) {
defer instrument.APIMetricRecorderFn(vmUpdateServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(vmUpdateServiceLabel, &err)()

updCtx, cancelFn := context.WithTimeout(ctx, defaultUpdateVMTimeout)
defer cancelFn()
Expand Down
2 changes: 1 addition & 1 deletion pkg/azure/access/helpers/vmimage.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const vmImageGetServiceLabel = "virtual_machine_image_get"
// GetVMImage fetches the VM Image given a location and image reference.
// NOTE: All calls to this Azure API are instrumented as prometheus metric.
func GetVMImage(ctx context.Context, vmImagesAccess *armcompute.VirtualMachineImagesClient, location string, imageRef armcompute.ImageReference) (vmImage *armcompute.VirtualMachineImage, err error) {
defer instrument.APIMetricRecorderFn(vmImageGetServiceLabel, &err)
defer instrument.AZAPIMetricRecorderFn(vmImageGetServiceLabel, &err)()

resp, err := vmImagesAccess.Get(ctx, location, *imageRef.Publisher, *imageRef.Offer, *imageRef.SKU, *imageRef.Version, nil)
if err != nil {
Expand Down
7 changes: 5 additions & 2 deletions pkg/azure/instrument/instrument.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"strconv"
"time"

"github.com/gardener/machine-controller-manager/pkg/util/provider/machinecodes/codes"
"github.com/gardener/machine-controller-manager/pkg/util/provider/machinecodes/status"
"github.com/gardener/machine-controller-manager/pkg/util/provider/metrics"
)
Expand Down Expand Up @@ -67,6 +68,8 @@ func RecordDriverAPIMetric(err error, operation string, invocationTime time.Time
)
if errors.As(err, &statusErr) {
labels = append(labels, strconv.Itoa(int(statusErr.Code())))
} else {
labels = append(labels, strconv.Itoa(int(codes.Internal)))
}
metrics.DriverFailedAPIRequests.
WithLabelValues(labels...).
Expand All @@ -81,9 +84,9 @@ func RecordDriverAPIMetric(err error, operation string, invocationTime time.Time
).Observe(elapsed.Seconds())
}

// APIMetricRecorderFn returns a function that can be used to record a prometheus metric for Azure API calls.
// AZAPIMetricRecorderFn returns a function that can be used to record a prometheus metric for Azure API calls.
// NOTE: a pointer to an error (which itself is a fat interface pointer) is necessary to enable the callers of this function to enclose this call into a `defer` statement.
func APIMetricRecorderFn(azServiceName string, err *error) func() {
func AZAPIMetricRecorderFn(azServiceName string, err *error) func() {
invocationTime := time.Now()
return func() {
RecordAzAPIMetric(*err, azServiceName, invocationTime)
Expand Down
111 changes: 111 additions & 0 deletions pkg/azure/instrument/instrument_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package instrument

import (
"errors"
"strconv"
"testing"

"github.com/gardener/machine-controller-manager/pkg/util/provider/machinecodes/codes"
"github.com/gardener/machine-controller-manager/pkg/util/provider/machinecodes/status"
"github.com/gardener/machine-controller-manager/pkg/util/provider/metrics"
. "github.com/onsi/gomega"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
)

var (
testErr = errors.New("test-error")
defaultErrorCode = strconv.Itoa(int(codes.Internal))
testStatusErr = status.New(codes.InvalidArgument, "test-status-error")
)

const serviceName = "test-service"

func TestAPIMetricRecorderFn(t *testing.T) {
testCases := []struct {
name string
err error
}{
{"assert that function captures failed API request count when the error is not nil", testErr},
{"assert that function captures successful API request count when the error is nil", nil},
}
g := NewWithT(t)
reg := prometheus.NewRegistry()
g.Expect(reg.Register(metrics.APIRequestCount)).To(Succeed())
g.Expect(reg.Register(metrics.APIFailedRequestCount)).To(Succeed())
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
defer metrics.APIRequestCount.Reset()
defer metrics.APIFailedRequestCount.Reset()
_ = deferredMetricsRecorderInvoker(tc.err != nil, false, AZAPIMetricRecorderFn)
if tc.err != nil {
g.Expect(testutil.CollectAndCount(metrics.APIFailedRequestCount)).To(Equal(1))
g.Expect(testutil.ToFloat64(metrics.APIFailedRequestCount.WithLabelValues(prometheusProviderLabelValue, serviceName))).To(Equal(float64(1)))
} else {
g.Expect(testutil.CollectAndCount(metrics.APIRequestCount)).To(Equal(1))
g.Expect(testutil.ToFloat64(metrics.APIRequestCount.WithLabelValues(prometheusProviderLabelValue, serviceName))).To(Equal(float64(1)))
}
})
}
}

func TestDriverAPIMetricRecorderFn(t *testing.T) {
testCases := []struct {
name string
err error
}{
{"assert that function captures failed driver API request with default error code for internal error when there is an error", testErr},
{"assert that function captures failed driver API request with error code from status.Status on error", testStatusErr},
{"assert that function captures successful driver API request count when the error is nil", nil},
}
g := NewWithT(t)
reg := prometheus.NewRegistry()
g.Expect(reg.Register(metrics.DriverFailedAPIRequests)).To(Succeed())
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
defer metrics.DriverFailedAPIRequests.Reset()
_ = deferredMetricsRecorderInvoker(tc.err != nil, isStatusErr(tc.err), DriverAPIMetricRecorderFn)
if tc.err != nil {
expectedErrCode := getExpectedErrorCode(tc.err)
g.Expect(testutil.CollectAndCount(metrics.DriverFailedAPIRequests)).To(Equal(1))
g.Expect(testutil.ToFloat64(metrics.DriverFailedAPIRequests.WithLabelValues(prometheusProviderLabelValue, serviceName, expectedErrCode))).To(Equal(float64(1)))
} else {
g.Expect(testutil.CollectAndCount(metrics.DriverFailedAPIRequests)).To(Equal(0))
}
})
}
}

func isStatusErr(err error) bool {
if err == nil {
return false
}
var statusErr *status.Status
return errors.As(err, &statusErr)
}

func getExpectedErrorCode(err error) string {
if err == nil {
return ""
}
var statusErr *status.Status
if errors.As(err, &statusErr) {
return strconv.Itoa(int(statusErr.Code()))
} else {
return defaultErrorCode
}
}

type recorderFn func(string, *error) func()

func deferredMetricsRecorderInvoker(shouldReturnErr bool, isStatusErr bool, fn recorderFn) (err error) {
defer fn(serviceName, &err)()
if shouldReturnErr {
if isStatusErr {
err = testStatusErr
} else {
err = testErr
}
}
return
}
10 changes: 5 additions & 5 deletions pkg/azure/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func NewDefaultDriver(accessFactory access.Factory) driver.Driver {
}

func (d defaultDriver) ListMachines(ctx context.Context, req *driver.ListMachinesRequest) (resp *driver.ListMachinesResponse, err error) {
defer instrument.DriverAPIMetricRecorderFn(listMachinesOperationLabel, &err)
defer instrument.DriverAPIMetricRecorderFn(listMachinesOperationLabel, &err)()
providerSpec, connectConfig, err := helpers.ExtractProviderSpecAndConnectConfig(req.MachineClass, req.Secret)
if err != nil {
return
Expand All @@ -66,7 +66,7 @@ func (d defaultDriver) ListMachines(ctx context.Context, req *driver.ListMachine
}

func (d defaultDriver) CreateMachine(ctx context.Context, req *driver.CreateMachineRequest) (resp *driver.CreateMachineResponse, err error) {
defer instrument.DriverAPIMetricRecorderFn(createMachineOperationLabel, &err)
defer instrument.DriverAPIMetricRecorderFn(createMachineOperationLabel, &err)()

providerSpec, connectConfig, err := helpers.ExtractProviderSpecAndConnectConfig(req.MachineClass, req.Secret)
if err != nil {
Expand Down Expand Up @@ -99,7 +99,7 @@ func (d defaultDriver) CreateMachine(ctx context.Context, req *driver.CreateMach
}

func (d defaultDriver) DeleteMachine(ctx context.Context, req *driver.DeleteMachineRequest) (resp *driver.DeleteMachineResponse, err error) {
defer instrument.DriverAPIMetricRecorderFn(deleteMachineOperationLabel, &err)
defer instrument.DriverAPIMetricRecorderFn(deleteMachineOperationLabel, &err)()

providerSpec, connectConfig, err := helpers.ExtractProviderSpecAndConnectConfig(req.MachineClass, req.Secret)
if err != nil {
Expand Down Expand Up @@ -165,7 +165,7 @@ func (d defaultDriver) DeleteMachine(ctx context.Context, req *driver.DeleteMach
}

func (d defaultDriver) GetMachineStatus(ctx context.Context, req *driver.GetMachineStatusRequest) (resp *driver.GetMachineStatusResponse, err error) {
defer instrument.DriverAPIMetricRecorderFn(getMachineStatusOperationLabel, &err)
defer instrument.DriverAPIMetricRecorderFn(getMachineStatusOperationLabel, &err)()

providerSpec, connectConfig, err := helpers.ExtractProviderSpecAndConnectConfig(req.MachineClass, req.Secret)
if err != nil {
Expand Down Expand Up @@ -197,7 +197,7 @@ func (d defaultDriver) GetMachineStatus(ctx context.Context, req *driver.GetMach
}

func (d defaultDriver) GetVolumeIDs(_ context.Context, request *driver.GetVolumeIDsRequest) (resp *driver.GetVolumeIDsResponse, err error) {
defer instrument.DriverAPIMetricRecorderFn(getVolumeIDsOperationLabel, &err)
defer instrument.DriverAPIMetricRecorderFn(getVolumeIDsOperationLabel, &err)()

var volumeIDs []string

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit b850d3a

Please sign in to comment.