From de8a4dba1dd22c38c186963f103f1872ecca4597 Mon Sep 17 00:00:00 2001 From: Fred Park Date: Fri, 13 Dec 2024 09:19:13 -0800 Subject: [PATCH] addressing comments, added more cases --- README.md | 17 +++-- configuration.go | 3 +- errors/errors.go | 21 ++++-- integration/integration_test.go | 10 +-- integration/tls/tls_test.go | 2 +- main.go | 33 ++++++---- main_test.go | 80 +++++++++++++++-------- timestream/client.go | 1 - timestream/client_test.go | 112 ++++++++++++++++++++++++++++++++ 9 files changed, 219 insertions(+), 60 deletions(-) diff --git a/README.md b/README.md index 5b8d3ff..78392ce 100644 --- a/README.md +++ b/README.md @@ -190,6 +190,12 @@ docker load < timestream-prometheus-connector-docker-image-.tar.gz --default-table=prometheusMetricsTable ``` +If you have `docker compose` installed, you can bring up the containers with: + + ```shell + docker compose up -d + ``` + It is recommended to secure the Prometheus requests with HTTPS with TLS encryption. To enable TLS encryption: 1. Mount the volume containing the server certificate and the server private key to a volume on the Docker container, then specify the path to the certificate and the key through the `tls-certificate` and `tls-key` configuration options. Note that the path specified must be with respect to the Docker container. @@ -571,7 +577,8 @@ The Prometheus Connector exposes the query SDK's retry configurations for users. | Standalone OptionOption | Lambda Option | Description | Is Required | Default Value | |--------|-------------|------------|---------|---------| -| `max-retries` | `max_retries` | The maximum number of times the read request will be retried for failures. | No | 3 | +| `max-read-retries` | `max_read_retries` | The maximum number of times the read request will be retried for failures. | No | 3 | +| `max-write-retries` | `max_write_retries` | The maximum number of times the write request will be retried for failures. | No | 10 | #### Configuration Examples @@ -579,8 +586,8 @@ Configure the Prometheus Connector to retry up to 10 times upon recoverable erro | Runtime | Command | | -------------------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Precompiled Binaries | `./bootstrap --default-database=PrometheusDatabase --default-table=PrometheusMetricsTable --max-retries=10` | -| AWS Lambda Function | `aws lambda update-function-configuration --function-name PrometheusConnector --environment "Variables={default_database=prometheusDatabase,default_table=prometheusMetricsTable,max_retries=10}"` | +| Precompiled Binaries | `./bootstrap --default-database=PrometheusDatabase --default-table=PrometheusMetricsTable --max-read-retries=10 --max-write-retries=10` | +| AWS Lambda Function | `aws lambda update-function-configuration --function-name PrometheusConnector --environment "Variables={default_database=prometheusDatabase,default_table=prometheusMetricsTable,max_read_retries=10,max_write_retries=10}"` | ### Logger Configuration Options @@ -925,11 +932,11 @@ All connector-specific errors can be found in [`errors/errors.go`](./errors/erro 12. **Error**: `ParseRetriesError` - **Description**: This error will occur when the `max-retries` option has an invalid value. + **Description**: This error will occur when the `max-read-retries` or `max-write-retries` option has an invalid value. 
**Solution** - See the [Retry Configuration Options](#retry-configuration-options) section for acceptable formats for the `max-retries` option. + See the [Retry Configuration Options](#retry-configuration-options) section for acceptable formats for the `max-read-retries` or `max-write-retries` option. 13. **Error**: `UnknownMatcherError` diff --git a/configuration.go b/configuration.go index e18f2a1..297fba5 100644 --- a/configuration.go +++ b/configuration.go @@ -30,7 +30,8 @@ type configuration struct { var ( enableLogConfig = &configuration{flag: "enable-logging", envFlag: "enable_logging", defaultValue: "true"} regionConfig = &configuration{flag: "region", envFlag: "region", defaultValue: "us-east-1"} - maxRetriesConfig = &configuration{flag: "max-retries", envFlag: "max_retries", defaultValue: strconv.Itoa(retry.DefaultMaxAttempts)} + maxReadRetriesConfig = &configuration{flag: "max-read-retries", envFlag: "max_read_retries", defaultValue: strconv.Itoa(retry.DefaultMaxAttempts)} + maxWriteRetriesConfig = &configuration{flag: "max-write-retries", envFlag: "max_write_retries", defaultValue: strconv.Itoa(10)} defaultDatabaseConfig = &configuration{flag: "default-database", envFlag: "default_database", defaultValue: ""} defaultTableConfig = &configuration{flag: "default-table", envFlag: "default_table", defaultValue: ""} enableSigV4AuthConfig = &configuration{flag: "enable-sigv4-auth", envFlag: "enable_sigv4_auth", defaultValue: "true"} diff --git a/errors/errors.go b/errors/errors.go index 37d2f2d..95d584d 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -96,13 +96,20 @@ type ParseRetriesError struct { baseConnectorError } -func NewParseRetriesError(retries string) error { - return &ParseRetriesError{baseConnectorError: baseConnectorError{ - statusCode: http.StatusBadRequest, - errorMsg: fmt.Sprintf("error occurred while parsing max-retries, expected an integer, but received '%s'", retries), - message: "The value specified in the max-retries option is not one of the accepted values. " + - acceptedValueErrorMessage, - }} +func NewParseRetriesError(retries string, operation string) error { + return &ParseRetriesError{ + baseConnectorError: baseConnectorError{ + statusCode: http.StatusBadRequest, + errorMsg: fmt.Sprintf( + "error occurred while parsing max-%s-retries, expected an integer, but received '%s'", + operation, retries, + ), + message: fmt.Sprintf( + "The value specified in the max-%s-retries option is not one of the accepted values. 
%s", + operation, acceptedValueErrorMessage, + ), + }, + } } type ParseBasicAuthHeaderError struct { diff --git a/integration/integration_test.go b/integration/integration_test.go index f7b67d8..e7e3eb8 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -134,15 +134,15 @@ func TestWriteClient(t *testing.T) { {"write request with invalid AWS credentials", reqBatchFail, invalidCredentials, true}, } - for _, tc := range invalidTestCases { - t.Run(tc.name, func(t *testing.T) { + for _, test := range invalidTestCases { + t.Run(test.name, func(t *testing.T) { var client *timestream.Client - if tc.allowLongLabel { - client = createClient(t, logger, database, table, tc.creds, true, false) + if test.allowLongLabel { + client = createClient(t, logger, database, table, test.creds, true, false) } else { client = clientEnableFailOnLongLabelName } - err := client.WriteClient().Write(ctx, tc.request, invalidCredentials) + err := client.WriteClient().Write(ctx, test.request, invalidCredentials) assert.NotNil(t, err) }) } diff --git a/integration/tls/tls_test.go b/integration/tls/tls_test.go index 53f621e..b1e59b5 100644 --- a/integration/tls/tls_test.go +++ b/integration/tls/tls_test.go @@ -229,7 +229,7 @@ func getDatabaseRowCount(t *testing.T, database string, table string) int { querySvc := timestreamquery.NewFromConfig(cfg) queryInput := ×treamquery.QueryInput{ - QueryString: aws.String(fmt.Sprintf("SELECT count(*) from %s.%s", database, table)), + QueryString: aws.String(fmt.Sprintf("SELECT count(*) from \"%s\".\"%s\"", database, table)), } out, err := querySvc.Query(ctx, queryInput) diff --git a/main.go b/main.go index 3fff5c0..cc98f88 100644 --- a/main.go +++ b/main.go @@ -52,10 +52,9 @@ import ( ) const ( - readHeader = "x-prometheus-remote-read-version" - writeHeader = "x-prometheus-remote-write-version" - basicAuthHeader = "authorization" - writeClientMaxRetries = 10 + readHeader = "x-prometheus-remote-read-version" + writeHeader = "x-prometheus-remote-write-version" + basicAuthHeader = "authorization" ) var ( @@ -101,7 +100,8 @@ type connectionConfig struct { listenAddr string promlogConfig promlog.Config telemetryPath string - maxRetries int + maxReadRetries int + maxWriteRetries int certificate string key string } @@ -121,13 +121,13 @@ func main() { logger := cfg.createLogger() ctx := context.Background() - awsQueryConfigs, err := cfg.buildAWSConfig(ctx, cfg.maxRetries) + awsQueryConfigs, err := cfg.buildAWSConfig(ctx, cfg.maxReadRetries) if err != nil { timestream.LogError(logger, "Failed to build AWS configuration for query", err) os.Exit(1) } - awsWriteConfigs, err := cfg.buildAWSConfig(ctx, writeClientMaxRetries) + awsWriteConfigs, err := cfg.buildAWSConfig(ctx, cfg.maxWriteRetries) if err != nil { timestream.LogError(logger, "Failed to build AWS configuration for write", err) os.Exit(1) @@ -184,12 +184,12 @@ func lambdaHandler(req events.APIGatewayProxyRequest) (events.APIGatewayProxyRes return createErrorResponse(errors.NewParseBasicAuthHeaderError().(*errors.ParseBasicAuthHeaderError).Message()) } } - awsQueryConfigs, err := cfg.buildAWSConfig(ctx, cfg.maxRetries) + awsQueryConfigs, err := cfg.buildAWSConfig(ctx, cfg.maxReadRetries) if err != nil { timestream.LogError(logger, "Failed to build AWS configuration for query", err) os.Exit(1) } - awsWriteConfigs, err := cfg.buildAWSConfig(ctx, writeClientMaxRetries) + awsWriteConfigs, err := cfg.buildAWSConfig(ctx, cfg.maxWriteRetries) if err != nil { timestream.LogError(logger, "Failed to build AWS 
configuration for write", err) os.Exit(1) @@ -381,10 +381,16 @@ func parseEnvironmentVariables() (*connectionConfig, error) { return nil, err } - retries := getOrDefault(maxRetriesConfig) - cfg.maxRetries, err = strconv.Atoi(retries) + readRetries := getOrDefault(maxReadRetriesConfig) + cfg.maxReadRetries, err = strconv.Atoi(readRetries) if err != nil { - return nil, errors.NewParseRetriesError(retries) + return nil, errors.NewParseRetriesError(readRetries, "read") + } + + writeRetries := getOrDefault(maxWriteRetriesConfig) + cfg.maxWriteRetries, err = strconv.Atoi(writeRetries) + if err != nil { + return nil, errors.NewParseRetriesError(writeRetries, "write") } cfg.promlogConfig = promlog.Config{Level: &promlog.AllowedLevel{}, Format: &promlog.AllowedFormat{}} @@ -411,7 +417,8 @@ func parseFlags() *connectionConfig { a.Flag(enableLogConfig.flag, "Enables or disables logging in the connector. Default to 'true'.").Default(enableLogConfig.defaultValue).StringVar(&enableLogging) a.Flag(regionConfig.flag, "The signing region for the Timestream service. Default to 'us-east-1'.").Default(regionConfig.defaultValue).StringVar(&cfg.clientConfig.region) - a.Flag(maxRetriesConfig.flag, "The maximum number of times the read request will be retried for failures. Default to 3.").Default(maxRetriesConfig.defaultValue).IntVar(&cfg.maxRetries) + a.Flag(maxReadRetriesConfig.flag, "The maximum number of times the read request will be retried for failures. Default to 3.").Default(maxReadRetriesConfig.defaultValue).IntVar(&cfg.maxReadRetries) + a.Flag(maxWriteRetriesConfig.flag, "The maximum number of times the write request will be retried for failures. Default to 10.").Default(maxWriteRetriesConfig.defaultValue).IntVar(&cfg.maxWriteRetries) a.Flag(defaultDatabaseConfig.flag, "The Prometheus label containing the database name for data ingestion.").Default(defaultDatabaseConfig.defaultValue).StringVar(&cfg.defaultDatabase) a.Flag(defaultTableConfig.flag, "The Prometheus label containing the table name for data ingestion.").Default(defaultTableConfig.defaultValue).StringVar(&cfg.defaultTable) a.Flag(listenAddrConfig.flag, "Address to listen on for web endpoints.").Default(listenAddrConfig.defaultValue).StringVar(&cfg.listenAddr) diff --git a/main_test.go b/main_test.go index 4d4736c..ebda8d6 100644 --- a/main_test.go +++ b/main_test.go @@ -178,7 +178,8 @@ func setUp() ([]string, *connectionConfig) { enableLogging: true, enableSigV4Auth: true, listenAddr: ":9201", - maxRetries: 3, + maxReadRetries: 3, + maxWriteRetries: 10, telemetryPath: "/metrics", } } @@ -654,35 +655,53 @@ func TestCreateLogger(t *testing.T) { assert.NotEqual(t, nopLogger, logger, "Actual logger must not equal to log.NewNopLogger.") }) } -func TestBuildAWSConfig(t *testing.T) { - t.Run("success", func(t *testing.T) { - expectedRegion := "region" - expectedMaxRetries := 3 - input := &connectionConfig{ - clientConfig: &clientConfig{ - region: expectedRegion, - }, - maxRetries: expectedMaxRetries, - } +func TestBuildAWSConfig(t *testing.T) { + testCases := []struct { + name string + maxRetries int + expectedMaxAttempts int + }{ + { + name: "read config", + maxRetries: 10, + expectedMaxAttempts: 10, + }, + { + name: "write config", + maxRetries: 3, + expectedMaxAttempts: 3, + }, + } - actualOutput, err := input.buildAWSConfig(context.Background(), expectedMaxRetries) + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + expectedRegion := "region" + input := &connectionConfig{ + clientConfig: &clientConfig{ + region: 
expectedRegion, + }, + maxReadRetries: test.expectedMaxAttempts, + maxWriteRetries: test.expectedMaxAttempts, + } - assert.Nil(t, err) - assert.NotNil(t, actualOutput) + actualConfig, err := input.buildAWSConfig(context.Background(), test.maxRetries) - assert.Equal(t, expectedRegion, actualOutput.Region) + assert.Nil(t, err) + assert.NotNil(t, actualConfig) + assert.Equal(t, expectedRegion, actualConfig.Region) - retryer := actualOutput.Retryer() - assert.NotNil(t, retryer) + retryer := actualConfig.Retryer() + assert.NotNil(t, retryer) - standardRetryer, ok := retryer.(*retry.Standard) - assert.True(t, ok, "expected retryer to be of type *retry.Standard") + standardRetryer, ok := retryer.(*retry.Standard) + assert.True(t, ok, "expected retryer to be of type *retry.Standard") - if ok { - assert.Equal(t, expectedMaxRetries, standardRetryer.MaxAttempts()) - } - }) + if ok { + assert.Equal(t, test.expectedMaxAttempts, standardRetryer.MaxAttempts()) + } + }) + } } func TestParseEnvironmentVariables(t *testing.T) { @@ -704,7 +723,8 @@ func TestParseEnvironmentVariables(t *testing.T) { enableSigV4Auth: true, failOnInvalidSample: false, failOnLongMetricLabelName: false, - maxRetries: 3, + maxReadRetries: 3, + maxWriteRetries: 10, }, expectedError: nil, }, @@ -727,10 +747,16 @@ func TestParseEnvironmentVariables(t *testing.T) { expectedError: errors.NewParseSampleOptionError("foo"), }, { - name: "error invalid max_retries option", - lambdaOptions: []lambdaEnvOptions{{key: maxRetriesConfig.envFlag, value: "foo"}}, + name: "error invalid max_read_retries option", + lambdaOptions: []lambdaEnvOptions{{key: maxReadRetriesConfig.envFlag, value: "foo"}}, + expectedConfig: nil, + expectedError: errors.NewParseRetriesError("foo", "read"), + }, + { + name: "error invalid max_write_retries option", + lambdaOptions: []lambdaEnvOptions{{key: maxWriteRetriesConfig.envFlag, value: "foo"}}, expectedConfig: nil, - expectedError: errors.NewParseRetriesError("foo"), + expectedError: errors.NewParseRetriesError("foo", "write"), }, } diff --git a/timestream/client.go b/timestream/client.go index f12a8c8..dc02073 100644 --- a/timestream/client.go +++ b/timestream/client.go @@ -229,7 +229,6 @@ func (wc *WriteClient) Write(ctx context.Context, req *prompb.WriteRequest, cred LogError(wc.logger, "Unable to construct a new session with the given credentials.", err) return err } - LogInfo(wc.logger, fmt.Sprintf("%d records requested for ingestion from Prometheus.", len(req.Timeseries))) recordMap := make(recordDestinationMap) diff --git a/timestream/client_test.go b/timestream/client_test.go index bc82dac..7157689 100644 --- a/timestream/client_test.go +++ b/timestream/client_test.go @@ -19,12 +19,15 @@ import ( goErrors "errors" "fmt" "math" + "net/http" "reflect" "sort" "strconv" "testing" "time" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/timestreamquery" @@ -1145,6 +1148,115 @@ func TestWriteClientWrite(t *testing.T) { mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 1) }) + + t.Run("invalid credentials provider", func(t *testing.T) { + mockTimestreamWriteClient := new(mockTimestreamWriteClient) + + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { + return nil, goErrors.New("invalid credentials") + } + + c := &Client{ + queryClient: nil, + defaultDataBase: mockDatabaseName, + defaultTable: mockTableName, + } + c.writeClient = 
createNewWriteClientTemplate(c) + + err := c.WriteClient().Write(context.Background(), createNewRequestTemplate(), nil) + assert.Error(t, err) + assert.Equal(t, "invalid credentials", err.Error()) + + mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 0) + }) + + t.Run("handle 4xx SDK error", func(t *testing.T) { + mockTimestreamWriteClient := new(mockTimestreamWriteClient) + expectedInput := createNewWriteRecordsInputTemplate() + + responseError := &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{ + &http.Response{ + StatusCode: 400, + Header: http.Header{}, + }, + }, + Err: goErrors.New("InvalidParameterException"), + } + + mockTimestreamWriteClient.On( + "WriteRecords", + mock.Anything, + mock.MatchedBy(func(writeInput *timestreamwrite.WriteRecordsInput) bool { + sortRecords(writeInput) + sortRecords(expectedInput) + return reflect.DeepEqual(writeInput, expectedInput) + }), + mock.Anything, + ).Return(×treamwrite.WriteRecordsOutput{}, responseError) + + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { + return mockTimestreamWriteClient, nil + } + + c := &Client{ + queryClient: nil, + defaultDataBase: mockDatabaseName, + defaultTable: mockTableName, + } + c.writeClient = createNewWriteClientTemplate(c) + + req := createNewRequestTemplate() + err := c.writeClient.Write(context.Background(), req, mockCredentials) + assert.Equal(t, responseError, err) + + mockTimestreamWriteClient.AssertCalled(t, "WriteRecords", mock.Anything, expectedInput, mock.Anything) + mockTimestreamWriteClient.AssertExpectations(t) + }) + + t.Run("handle 5xx SDK error", func(t *testing.T) { + mockTimestreamWriteClient := new(mockTimestreamWriteClient) + expectedInput := createNewWriteRecordsInputTemplate() + + responseError := &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{ + &http.Response{ + StatusCode: 500, + Header: http.Header{}, + }, + }, + Err: goErrors.New("InternalServerError"), + } + + mockTimestreamWriteClient.On( + "WriteRecords", + mock.Anything, + mock.MatchedBy(func(writeInput *timestreamwrite.WriteRecordsInput) bool { + sortRecords(writeInput) + sortRecords(expectedInput) + return reflect.DeepEqual(writeInput, expectedInput) + }), + mock.Anything, + ).Return(×treamwrite.WriteRecordsOutput{}, responseError) + + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { + return mockTimestreamWriteClient, nil + } + + c := &Client{ + queryClient: nil, + defaultDataBase: mockDatabaseName, + defaultTable: mockTableName, + } + c.writeClient = createNewWriteClientTemplate(c) + + req := createNewRequestTemplate() + err := c.writeClient.Write(context.Background(), req, mockCredentials) + assert.Equal(t, responseError, err) + + mockTimestreamWriteClient.AssertCalled(t, "WriteRecords", mock.Anything, expectedInput, mock.Anything) + mockTimestreamWriteClient.AssertExpectations(t) + }) } // sortRecords sorts the slice of Record in the WriteRecordsInput by time, and sorts the slice of Dimension by dimension names.
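Note on the retry wiring exercised above: `buildAWSConfig` itself is not modified in this patch, but `TestBuildAWSConfig` asserts that the returned `aws.Config` carries a `*retry.Standard` retryer whose `MaxAttempts()` equals the value passed in. A minimal sketch of that pattern with the AWS SDK for Go v2 is shown below, assuming the connector relies on `config.WithRetryer` with the standard retryer; the helper name `loadConfigWithRetries` is illustrative and not part of the connector.

```go
// Sketch only — buildAWSConfig's body is not shown in this patch. This is the
// usual AWS SDK for Go v2 pattern for capping retry attempts, and it matches
// what TestBuildAWSConfig asserts (a *retry.Standard retryer with the given
// MaxAttempts). The helper name below is hypothetical.
package main

import (
	"context"
	"fmt"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/aws/retry"
	"github.com/aws/aws-sdk-go-v2/config"
)

// loadConfigWithRetries builds an aws.Config for the given region whose
// standard retryer is capped at maxRetries attempts.
func loadConfigWithRetries(ctx context.Context, region string, maxRetries int) (aws.Config, error) {
	return config.LoadDefaultConfig(ctx,
		config.WithRegion(region),
		config.WithRetryer(func() aws.Retryer {
			return retry.NewStandard(func(o *retry.StandardOptions) {
				o.MaxAttempts = maxRetries
			})
		}),
	)
}

func main() {
	// For example, the write path with the new default of 10 attempts.
	cfg, err := loadConfigWithRetries(context.Background(), "us-east-1", 10)
	if err != nil {
		panic(err)
	}
	// Mirrors the assertion in TestBuildAWSConfig.
	fmt.Println(cfg.Retryer().(*retry.Standard).MaxAttempts()) // 10
}
```

Splitting the cap into `max-read-retries` and `max-write-retries` lets query and ingestion behavior be tuned independently, which is why `main.go` now builds two separate `aws.Config` values, one passed `cfg.maxReadRetries` and the other `cfg.maxWriteRetries`.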