From 33a2428a7e1b4e5f74af204085ac09094c74df56 Mon Sep 17 00:00:00 2001 From: Fred Park Date: Fri, 6 Dec 2024 13:15:54 -0800 Subject: [PATCH] go sdk upgrade to v2 --- .gitignore | 5 + configuration.go | 5 +- correctness/correctness_test.go | 5 +- docker-compose.yml | 13 + errors/errors.go | 3 +- go.mod | 20 +- go.sum | 39 +- integration/integration_test.go | 129 +++--- integration/integration_test_framework.go | 24 +- integration/tls/tls_test.go | 46 +- main.go | 199 ++++++--- main_test.go | 151 ++++--- timestream/client.go | 270 +++++++----- timestream/client_test.go | 510 ++++++++++++++-------- 14 files changed, 895 insertions(+), 524 deletions(-) create mode 100644 docker-compose.yml diff --git a/.gitignore b/.gitignore index 22df5f7..ab61a2d 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,8 @@ integration/tls/cert resources timestream-prometheus-connector + +# Build artifacts +darwin +linux +windows diff --git a/configuration.go b/configuration.go index 3e49131..e18f2a1 100644 --- a/configuration.go +++ b/configuration.go @@ -16,8 +16,9 @@ and limitations under the License. package main import ( - awsClient "github.com/aws/aws-sdk-go/aws/client" "strconv" + + "github.com/aws/aws-sdk-go-v2/aws/retry" ) type configuration struct { @@ -29,7 +30,7 @@ type configuration struct { var ( enableLogConfig = &configuration{flag: "enable-logging", envFlag: "enable_logging", defaultValue: "true"} regionConfig = &configuration{flag: "region", envFlag: "region", defaultValue: "us-east-1"} - maxRetriesConfig = &configuration{flag: "max-retries", envFlag: "max_retries", defaultValue: strconv.Itoa(awsClient.DefaultRetryerMaxNumRetries)} + maxRetriesConfig = &configuration{flag: "max-retries", envFlag: "max_retries", defaultValue: strconv.Itoa(retry.DefaultMaxAttempts)} defaultDatabaseConfig = &configuration{flag: "default-database", envFlag: "default_database", defaultValue: ""} defaultTableConfig = &configuration{flag: "default-table", envFlag: "default_table", defaultValue: ""} enableSigV4AuthConfig = &configuration{flag: "enable-sigv4-auth", envFlag: "enable_sigv4_auth", defaultValue: "true"} diff --git a/correctness/correctness_test.go b/correctness/correctness_test.go index 2c46d64..849991a 100644 --- a/correctness/correctness_test.go +++ b/correctness/correctness_test.go @@ -23,8 +23,6 @@ import ( "bufio" "context" "encoding/csv" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "io" "os" "path/filepath" @@ -33,6 +31,9 @@ import ( "timestream-prometheus-connector/integration" "timestream-prometheus-connector/timestream" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/docker/docker/client" ) diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..bdae10b --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,13 @@ +services: + timestream-prometheus-connector: + container_name: connector + build: . + ports: + - "9201:9201" + volumes: + - .:/home + command: + - --default-database=PrometheusDatabase + - --default-table=PrometheusMetricsTable + - --region=us-east-1 + - --log.level=debug diff --git a/errors/errors.go b/errors/errors.go index 459b50c..37d2f2d 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -16,8 +16,9 @@ package errors import ( "fmt" - "github.com/prometheus/prometheus/prompb" "net/http" + + "github.com/prometheus/prometheus/prompb" ) type baseConnectorError struct { diff --git a/go.mod b/go.mod index 1f2285c..c5468bd 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,12 @@ go 1.22.3 require ( github.com/alecthomas/kingpin/v2 v2.4.0 github.com/aws/aws-lambda-go v1.46.0 - github.com/aws/aws-sdk-go v1.52.5 + github.com/aws/aws-sdk-go-v2 v1.32.6 + github.com/aws/aws-sdk-go-v2/config v1.28.6 + github.com/aws/aws-sdk-go-v2/credentials v1.17.47 + github.com/aws/aws-sdk-go-v2/service/timestreamquery v1.29.1 + github.com/aws/aws-sdk-go-v2/service/timestreamwrite v1.29.8 + github.com/aws/smithy-go v1.22.1 github.com/docker/docker v25.0.6+incompatible github.com/docker/go-connections v0.4.0 github.com/go-kit/log v0.2.1 @@ -17,12 +22,21 @@ require ( github.com/prometheus/common v0.48.0 github.com/prometheus/prometheus v2.5.0+incompatible github.com/stretchr/testify v1.9.0 - golang.org/x/net v0.26.0 ) require ( github.com/Microsoft/go-winio v0.6.1 // indirect github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/log v0.1.0 // indirect @@ -35,7 +49,6 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect @@ -52,6 +65,7 @@ require ( go.opentelemetry.io/otel/sdk v1.28.0 // indirect go.opentelemetry.io/otel/trace v1.28.0 // indirect golang.org/x/mod v0.17.0 // indirect + golang.org/x/net v0.26.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.21.0 // indirect golang.org/x/text v0.16.0 // indirect diff --git a/go.sum b/go.sum index e2a36da..81bda88 100644 --- a/go.sum +++ b/go.sum @@ -12,8 +12,38 @@ github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8V github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/aws/aws-lambda-go v1.46.0 h1:UWVnvh2h2gecOlFhHQfIPQcD8pL/f7pVCutmFl+oXU8= github.com/aws/aws-lambda-go v1.46.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A= -github.com/aws/aws-sdk-go v1.52.5 h1:m2lty5v9sHm1J3lhA43hJql+yKZudF09qzab0Ag9chM= -github.com/aws/aws-sdk-go v1.52.5/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4= +github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2/config v1.28.6 h1:D89IKtGrs/I3QXOLNTH93NJYtDhm8SYa9Q5CsPShmyo= +github.com/aws/aws-sdk-go-v2/config v1.28.6/go.mod h1:GDzxJ5wyyFSCoLkS+UhGB0dArhb9mI+Co4dHtoTxbko= +github.com/aws/aws-sdk-go-v2/credentials v1.17.47 h1:48bA+3/fCdi2yAwVt+3COvmatZ6jUDNkDTIsqDiMUdw= +github.com/aws/aws-sdk-go-v2/credentials v1.17.47/go.mod h1:+KdckOejLW3Ks3b0E3b5rHsr2f9yuORBum0WPnE5o5w= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 h1:AmoU1pziydclFT/xRV+xXE/Vb8fttJCLRPv8oAkprc0= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21/go.mod h1:AjUdLYe4Tgs6kpH4Bv7uMZo7pottoyHMn4eTcIcneaY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 h1:s/fF4+yDQDoElYhfIVvSNyeCydfbuTKzhxSXDXCPasU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25/go.mod h1:IgPfDv5jqFIzQSNbUEMoitNooSMXjRSDkhXv8jiROvU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 h1:ZntTCl5EsYnhN/IygQEUugpdwbhdkom9uHcbCftiGgA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25/go.mod h1:DBdPrgeocww+CSl1C8cEV8PN1mHMBhuCDLpXezyvWkE= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.6 h1:nbmKXZzXPJn41CcD4HsHsGWqvKjLKz9kWu6XxvLmf1s= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.6/go.mod h1:SJhcisfKfAawsdNQoZMBEjg+vyN2lH6rO6fP+T94z5Y= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 h1:50+XsN70RS7dwJ2CkVNXzj7U2L1HKP8nqTd3XWEXBN4= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6/go.mod h1:WqgLmwY7so32kG01zD8CPTJWVWM+TzJoOVHwTg4aPug= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 h1:rLnYAfXQ3YAccocshIH5mzNNwZBkBo+bP6EhIxak6Hw= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.7/go.mod h1:ZHtuQJ6t9A/+YDuxOLnbryAmITtr8UysSny3qcyvJTc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 h1:JnhTZR3PiYDNKlXy50/pNeix9aGMo6lLpXwJ1mw8MD4= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6/go.mod h1:URronUEGfXZN1VpdktPSD1EkAL9mfrV+2F4sjH38qOY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 h1:s4074ZO1Hk8qv65GqNXqDjmkf4HSQqJukaLuuW0TpDA= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.2/go.mod h1:mVggCnIWoM09jP71Wh+ea7+5gAp53q+49wDFs1SW5z8= +github.com/aws/aws-sdk-go-v2/service/timestreamquery v1.29.1 h1:nfS8q82YuHG8pks28bGSAqy9R44XBLM72jcKDqRG7ak= +github.com/aws/aws-sdk-go-v2/service/timestreamquery v1.29.1/go.mod h1:PJ9MdxcmYoM5bLKzp92fdGooNWHTDMhuC4TGJ3peY7c= +github.com/aws/aws-sdk-go-v2/service/timestreamwrite v1.29.8 h1:chzp64fl/hknlRR9jlstQDB4bYaf848v7KmzUB13omA= +github.com/aws/aws-sdk-go-v2/service/timestreamwrite v1.29.8/go.mod h1:6r72p62vXJL+0VTgk9rVV7i9+C0qTcx+HuL56XT9Pus= +github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro= +github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= @@ -74,10 +104,6 @@ github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k= -github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= -github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= -github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -223,7 +249,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/integration/integration_test.go b/integration/integration_test.go index 93fbcda..f7b67d8 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -17,21 +17,23 @@ and limitations under the License. package integration import ( - "github.com/aws/aws-sdk-go/aws" - awsClient "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/timestreamwrite" - "github.com/go-kit/log" - "github.com/google/go-cmp/cmp" - "github.com/prometheus/common/model" - "github.com/prometheus/prometheus/prompb" - "github.com/stretchr/testify/assert" + "context" "math/rand" "os" "testing" "time" + "timestream-prometheus-connector/timestream" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/timestreamwrite" + "github.com/go-kit/log" + "github.com/google/go-cmp/cmp" + "github.com/prometheus/common/model" + "github.com/prometheus/prometheus/prompb" + "github.com/stretchr/testify/assert" ) var ( @@ -39,24 +41,33 @@ var ( nowUnix = time.Now().UnixNano() / (int64(time.Millisecond) / int64(time.Nanosecond)) endUnix = nowUnix + 30000 destinations = map[string][]string{database: {table}, database2: {table2}} - writeClient = timestreamwrite.New(session.Must(session.NewSession()), aws.NewConfig().WithRegion(region)) - awsCredentials = writeClient.Config.Credentials - emptyCredentials = credentials.NewStaticCredentials("", "", "") - invalidCredentials = credentials.NewStaticCredentials("accessKey", "secretKey", "") + writeClient *timestreamwrite.Client + awsCredentials aws.CredentialsProvider + emptyCredentials aws.CredentialsProvider = credentials.NewStaticCredentialsProvider("", "", "") + invalidCredentials aws.CredentialsProvider = credentials.NewStaticCredentialsProvider("accessKey", "secretKey", "") ) func TestMain(m *testing.M) { - if err := Setup(writeClient, destinations); err != nil { + ctx := context.Background() + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + panic(err) + } + awsCredentials = cfg.Credentials + + writeClient = timestreamwrite.NewFromConfig(cfg) + if err := Setup(ctx, writeClient, destinations); err != nil { panic(err) } code := m.Run() - if err := Shutdown(writeClient, destinations); err != nil { + if err := Shutdown(ctx, writeClient, destinations); err != nil { panic(err) } os.Exit(code) } func TestWriteClient(t *testing.T) { + ctx := context.Background() req := &prompb.WriteRequest{Timeseries: []*prompb.TimeSeries{ createTimeSeriesTemplate(), }} @@ -70,7 +81,7 @@ func TestWriteClient(t *testing.T) { tsLongLabel := createTimeSeriesTemplate() tsLongLabel.Labels[1].Name = "a_very_long_long_long_long_long_label_name_that_will_be_over_sixty_bytes" reqLongLabel := &prompb.WriteRequest{Timeseries: []*prompb.TimeSeries{ - tsLongMetric, + tsLongLabel, }} var timeSeriesBatch []*prompb.TimeSeries @@ -90,45 +101,56 @@ func TestWriteClient(t *testing.T) { timeSeriesBatchFail = append(timeSeriesBatchFail, createTimeSeriesTemplate()) reqBatchFail := &prompb.WriteRequest{Timeseries: timeSeriesBatchFail} - awsConfigs := &aws.Config{Region: aws.String(region)} - clientEnableFailOnLongLabelName := createClient(t, logger, database, table, awsConfigs, true, false) - clientDisableFailOnLongLabelName := createClient(t, logger, database, table, awsConfigs, false, false) + clientEnableFailOnLongLabelName := createClient(t, logger, database, table, awsCredentials, true, false) + clientDisableFailOnLongLabelName := createClient(t, logger, database, table, awsCredentials, false, false) type testCase []struct { - testName string - request *prompb.WriteRequest - credentials *credentials.Credentials + testName string + request *prompb.WriteRequest } successTestCase := testCase{ - {"write normal request", req, awsCredentials}, - {"write request with long metric name", reqLongMetric, awsCredentials}, - {"write request with long label value", reqLongLabel, awsCredentials}, - {"write request with 100 samples per request", reqBatch, awsCredentials}, - {"write request with more than 100 samples per request", largeReqBatch, awsCredentials}, + {"write normal request", req}, + {"write request with long metric name", reqLongMetric}, + {"write request with long label value", reqLongLabel}, + {"write request with 100 samples per request", reqBatch}, + {"write request with more than 100 samples per request", largeReqBatch}, } for _, test := range successTestCase { t.Run(test.testName, func(t *testing.T) { - err := clientDisableFailOnLongLabelName.WriteClient().Write(test.request, test.credentials) + err := clientDisableFailOnLongLabelName.WriteClient().Write(ctx, test.request, awsCredentials) assert.Nil(t, err) }) } - - invalidTestCase := testCase{ - {"write request with failing long metric name", reqLongMetric, awsCredentials}, - {"write request with failing long label value", reqLongLabel, awsCredentials}, - {"write request with no AWS credentials", reqBatchFail, emptyCredentials}, - {"write request with invalid AWS credentials", reqBatchFail, invalidCredentials}, + invalidTestCases := []struct { + name string + request *prompb.WriteRequest + creds aws.CredentialsProvider + allowLongLabel bool + }{ + {"write request with failing long metric name", reqLongMetric, invalidCredentials, false}, + {"write request with failing long label value", reqLongLabel, invalidCredentials, false}, + {"write request with no AWS credentials", reqBatchFail, emptyCredentials, true}, + {"write request with invalid AWS credentials", reqBatchFail, invalidCredentials, true}, } - for _, test := range invalidTestCase { - t.Run(test.testName, func(t *testing.T) { - err := clientEnableFailOnLongLabelName.WriteClient().Write(test.request, test.credentials) + + for _, tc := range invalidTestCases { + t.Run(tc.name, func(t *testing.T) { + var client *timestream.Client + if tc.allowLongLabel { + client = createClient(t, logger, database, table, tc.creds, true, false) + } else { + client = clientEnableFailOnLongLabelName + } + err := client.WriteClient().Write(ctx, tc.request, invalidCredentials) assert.NotNil(t, err) }) } + } func TestQueryClient(t *testing.T) { + ctx := context.Background() writeReq := createWriteRequest() request, expectedResponse := createValidReadRequest() @@ -159,16 +181,15 @@ func TestQueryClient(t *testing.T) { }, } - awsConfigs := &aws.Config{Region: aws.String(region)} - clientDisableFailOnLongLabelName := createClient(t, logger, database, table, awsConfigs, false, false) + clientDisableFailOnLongLabelName := createClient(t, logger, database, table, awsCredentials, false, false) - err := clientDisableFailOnLongLabelName.WriteClient().Write(writeReq, awsCredentials) + err := clientDisableFailOnLongLabelName.WriteClient().Write(ctx, writeReq, awsCredentials) assert.Nil(t, err) invalidTestCase := []struct { - testName string - request *prompb.ReadRequest - credentials *credentials.Credentials + testName string + request *prompb.ReadRequest + credentialsProvider aws.CredentialsProvider }{ {"read with invalid regex", requestWithInvalidRegex, awsCredentials}, {"read with invalid matcher", requestWithInvalidMatcher, awsCredentials}, @@ -178,14 +199,14 @@ func TestQueryClient(t *testing.T) { for _, test := range invalidTestCase { t.Run(test.testName, func(t *testing.T) { - response, err := clientDisableFailOnLongLabelName.QueryClient().Read(test.request, test.credentials) + response, err := clientDisableFailOnLongLabelName.QueryClient().Read(context.Background(), test.request, test.credentialsProvider) assert.NotNil(t, err) assert.Nil(t, response) }) } t.Run("read normal request", func(t *testing.T) { - response, err := clientDisableFailOnLongLabelName.QueryClient().Read(request, awsCredentials) + response, err := clientDisableFailOnLongLabelName.QueryClient().Read(ctx, request, awsCredentials) assert.Nil(t, err) assert.NotNil(t, response) assert.True(t, cmp.Equal(expectedResponse, response), "Actual response does not match expected response.") @@ -247,12 +268,18 @@ func createReadHints() *prompb.ReadHints { } // createClient creates a new Timestream client containing a Timestream query client and a Timestream write client. -func createClient(t *testing.T, logger log.Logger, database, table string, configs *aws.Config, failOnLongMetricLabelName bool, failOnInvalidSample bool) *timestream.Client { - client := timestream.NewBaseClient(database, table) - client.NewQueryClient(logger, configs) +func createClient(t *testing.T, logger log.Logger, database, table string, credentials aws.CredentialsProvider, failOnLongMetricLabelName bool, failOnInvalidSample bool) *timestream.Client { + cfg, err := config.LoadDefaultConfig(context.TODO(), + config.WithRegion(region), + config.WithCredentialsProvider(credentials), + ) + if err != nil { + t.Fatalf("failed to load AWS config: %v", err) + } - configs.MaxRetries = aws.Int(awsClient.DefaultRetryerMaxNumRetries) - client.NewWriteClient(logger, configs, failOnLongMetricLabelName, failOnInvalidSample) + client := timestream.NewBaseClient(database, table) + client.NewQueryClient(logger, cfg) + client.NewWriteClient(logger, cfg, failOnLongMetricLabelName, failOnInvalidSample) return client } diff --git a/integration/integration_test_framework.go b/integration/integration_test_framework.go index b070977..cad7e74 100644 --- a/integration/integration_test_framework.go +++ b/integration/integration_test_framework.go @@ -24,8 +24,8 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/timestreamwrite" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/timestreamwrite" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" @@ -129,28 +129,28 @@ func StopContainer(t *testing.T, cli *client.Client, ctx context.Context, contai } // Setup creates new databases and tables for integration tests. -func Setup(writeClient *timestreamwrite.TimestreamWrite, destinations map[string][]string) error { +func Setup(ctx context.Context, writeClient *timestreamwrite.Client, destinations map[string][]string) error { for database, tables := range destinations { databaseName := aws.String(database) for _, table := range tables { tableName := aws.String(table) - if _, err := writeClient.DescribeTable(×treamwrite.DescribeTableInput{DatabaseName: databaseName, TableName: tableName}); err == nil { - if _, err = writeClient.DeleteTable(×treamwrite.DeleteTableInput{DatabaseName: databaseName, TableName: tableName}); err != nil { + if _, err := writeClient.DescribeTable(ctx, ×treamwrite.DescribeTableInput{DatabaseName: databaseName, TableName: tableName}); err == nil { + if _, err = writeClient.DeleteTable(ctx, ×treamwrite.DeleteTableInput{DatabaseName: databaseName, TableName: tableName}); err != nil { return err } } } - if _, err := writeClient.DescribeDatabase(×treamwrite.DescribeDatabaseInput{DatabaseName: databaseName}); err == nil { - if _, err = writeClient.DeleteDatabase(×treamwrite.DeleteDatabaseInput{DatabaseName: databaseName}); err != nil { + if _, err := writeClient.DescribeDatabase(ctx, ×treamwrite.DescribeDatabaseInput{DatabaseName: databaseName}); err == nil { + if _, err = writeClient.DeleteDatabase(ctx, ×treamwrite.DeleteDatabaseInput{DatabaseName: databaseName}); err != nil { return err } } - if _, err := writeClient.CreateDatabase(×treamwrite.CreateDatabaseInput{DatabaseName: databaseName}); err != nil { + if _, err := writeClient.CreateDatabase(ctx, ×treamwrite.CreateDatabaseInput{DatabaseName: databaseName}); err != nil { return err } for _, table := range tables { - if _, err := writeClient.CreateTable(×treamwrite.CreateTableInput{DatabaseName: databaseName, TableName: aws.String(table)}); err != nil { + if _, err := writeClient.CreateTable(ctx, ×treamwrite.CreateTableInput{DatabaseName: databaseName, TableName: aws.String(table)}); err != nil { return err } } @@ -159,15 +159,15 @@ func Setup(writeClient *timestreamwrite.TimestreamWrite, destinations map[string } // Shutdown removes the databases and tables created for integration tests. -func Shutdown(writeClient *timestreamwrite.TimestreamWrite, destinations map[string][]string) error { +func Shutdown(ctx context.Context, writeClient *timestreamwrite.Client, destinations map[string][]string) error { for database, tables := range destinations { databaseName := aws.String(database) for _, table := range tables { - if _, err := writeClient.DeleteTable(×treamwrite.DeleteTableInput{DatabaseName: databaseName, TableName: aws.String(table)}); err != nil { + if _, err := writeClient.DeleteTable(ctx, ×treamwrite.DeleteTableInput{DatabaseName: databaseName, TableName: aws.String(table)}); err != nil { return err } } - if _, err := writeClient.DeleteDatabase(×treamwrite.DeleteDatabaseInput{DatabaseName: databaseName}); err != nil { + if _, err := writeClient.DeleteDatabase(ctx, ×treamwrite.DeleteDatabaseInput{DatabaseName: databaseName}); err != nil { return err } } diff --git a/integration/tls/tls_test.go b/integration/tls/tls_test.go index 1d5ad94..53f621e 100644 --- a/integration/tls/tls_test.go +++ b/integration/tls/tls_test.go @@ -17,16 +17,8 @@ and limitations under the License. package tls import ( + "context" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/timestreamquery" - "github.com/aws/aws-sdk-go/service/timestreamwrite" - "github.com/docker/docker/api/types" - "github.com/docker/docker/client" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/net/context" "os" "path/filepath" "strconv" @@ -34,6 +26,15 @@ import ( "time" "timestream-prometheus-connector/integration" "timestream-prometheus-connector/timestream" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/timestreamquery" + "github.com/aws/aws-sdk-go-v2/service/timestreamwrite" + "github.com/docker/docker/api/types" + "github.com/docker/docker/client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -68,13 +69,18 @@ var ( ) func TestMain(m *testing.M) { - testSession := session.Must(session.NewSession()) - writeClient := timestreamwrite.New(testSession, aws.NewConfig().WithRegion(region)) - if err := integration.Setup(writeClient, destinations); err != nil { + ctx := context.Background() + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + panic(err) + } + + writeClient := timestreamwrite.NewFromConfig(cfg) + if err := integration.Setup(ctx, writeClient, destinations); err != nil { panic(err) } code := m.Run() - if err := integration.Shutdown(writeClient, destinations); err != nil { + if err := integration.Shutdown(ctx, writeClient, destinations); err != nil { panic(err) } os.Exit(code) @@ -217,12 +223,16 @@ func connectorStatusCheck(t *testing.T, dockerClient *client.Client, ctx context // getDatabaseRowCount gets the number of rows in a specific table. func getDatabaseRowCount(t *testing.T, database string, table string) int { - queryInput := ×treamquery.QueryInput{QueryString: aws.String(fmt.Sprintf("SELECT count(*) from %s.%s", database, table))} - - sess, err := session.NewSession(&aws.Config{Region: aws.String(region)}) + ctx := context.Background() + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) require.NoError(t, err) - querySvc := timestreamquery.New(sess) - out, err := querySvc.Query(queryInput) + + querySvc := timestreamquery.NewFromConfig(cfg) + queryInput := ×treamquery.QueryInput{ + QueryString: aws.String(fmt.Sprintf("SELECT count(*) from %s.%s", database, table)), + } + + out, err := querySvc.Query(ctx, queryInput) require.NoError(t, err) count, err := strconv.Atoi(*out.Rows[0].Data[0].ScalarValue) diff --git a/main.go b/main.go index a9b4116..5b3a279 100644 --- a/main.go +++ b/main.go @@ -17,23 +17,20 @@ and limitations under the License. package main import ( + "context" "encoding/base64" + goErrors "errors" "fmt" + "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambda" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/go-kit/log" - "github.com/gogo/protobuf/proto" - "github.com/golang/snappy" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/prometheus/common/promlog" - "github.com/prometheus/common/promlog/flag" - "github.com/prometheus/prometheus/prompb" - "github.com/alecthomas/kingpin/v2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + wtypes "github.com/aws/aws-sdk-go-v2/service/timestreamwrite/types" + "github.com/aws/smithy-go" + "io" "net/http" "os" @@ -42,6 +39,16 @@ import ( "strings" "timestream-prometheus-connector/errors" "timestream-prometheus-connector/timestream" + + "github.com/alecthomas/kingpin/v2" + "github.com/go-kit/log" + "github.com/gogo/protobuf/proto" + "github.com/golang/snappy" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/prometheus/common/promlog" + "github.com/prometheus/common/promlog/flag" + "github.com/prometheus/prometheus/prompb" ) const ( @@ -53,25 +60,29 @@ const ( var ( // Store the initialization function calls and client retrieval calls to allow unit tests to mock the creation of real clients. - createWriteClient = func(timestreamClient *timestream.Client, logger log.Logger, configs *aws.Config, failOnLongMetricLabelName bool, failOnInvalidSample bool) { - timestreamClient.NewWriteClient(logger, configs, failOnLongMetricLabelName, failOnInvalidSample) + createWriteClient = func(timestreamClient *timestream.Client, logger log.Logger, cfg aws.Config, failOnLongMetricLabelName bool, failOnInvalidSample bool) { + timestreamClient.NewWriteClient(logger, cfg, failOnLongMetricLabelName, failOnInvalidSample) + } + createQueryClient = func(timestreamClient *timestream.Client, logger log.Logger, cfg aws.Config) { + timestreamClient.NewQueryClient(logger, cfg) + } + + getWriteClient = func(timestreamClient *timestream.Client) writer { + return timestreamClient.WriteClient() } - createQueryClient = func(timestreamClient *timestream.Client, logger log.Logger, configs *aws.Config, maxRetries int) { - configs.MaxRetries = aws.Int(maxRetries) - timestreamClient.NewQueryClient(logger, configs) + getQueryClient = func(timestreamClient *timestream.Client) reader { + return timestreamClient.QueryClient() } - getWriteClient = func(timestreamClient *timestream.Client) writer { return timestreamClient.WriteClient() } - getQueryClient = func(timestreamClient *timestream.Client) reader { return timestreamClient.QueryClient() } - halt = os.Exit + halt = os.Exit ) type writer interface { - Write(req *prompb.WriteRequest, credentials *credentials.Credentials) error + Write(ctx context.Context, req *prompb.WriteRequest, credentialsProvider aws.CredentialsProvider) error Name() string } type reader interface { - Read(req *prompb.ReadRequest, credentials *credentials.Credentials) (*prompb.ReadResponse, error) + Read(ctx context.Context, req *prompb.ReadRequest, credentialsProvider aws.CredentialsProvider) (*prompb.ReadResponse, error) Name() string } @@ -108,15 +119,22 @@ func main() { http.Handle(cfg.telemetryPath, promhttp.Handler()) logger := cfg.createLogger() - awsQueryConfigs := cfg.buildAWSConfig() - awsWriteConfigs := cfg.buildAWSConfig() - timestreamClient := timestream.NewBaseClient(cfg.defaultDatabase, cfg.defaultTable) + ctx := context.Background() + awsQueryConfigs, err := cfg.buildAWSConfig(ctx, cfg.maxRetries) + if err != nil { + timestream.LogError(logger, "Failed to build AWS configuration for query", err) + os.Exit(1) + } - awsQueryConfigs.MaxRetries = aws.Int(cfg.maxRetries) - timestreamClient.NewQueryClient(logger, awsQueryConfigs) + awsWriteConfigs, err := cfg.buildAWSConfig(ctx, writeClientMaxRetries) + if err != nil { + timestream.LogError(logger, "Failed to build AWS configuration for write", err) + os.Exit(1) + } - awsWriteConfigs.MaxRetries = aws.Int(writeClientMaxRetries) + timestreamClient := timestream.NewBaseClient(cfg.defaultDatabase, cfg.defaultTable) + timestreamClient.NewQueryClient(logger, awsQueryConfigs) timestreamClient.NewWriteClient(logger, awsWriteConfigs, cfg.failOnLongMetricLabelName, cfg.failOnInvalidSample) timestream.LogInfo(logger, fmt.Sprintf("Timestream connection is initialized (Database: %s, Table: %s, Region: %s)", cfg.defaultDatabase, cfg.defaultTable, cfg.clientConfig.region)) @@ -136,7 +154,7 @@ func main() { // lambdaHandler receives Prometheus read or write requests sent by API Gateway. func lambdaHandler(req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { - if (len(os.Getenv(defaultDatabaseConfig.envFlag)) == 0 || len(os.Getenv(defaultTableConfig.envFlag)) == 0) { + if len(os.Getenv(defaultDatabaseConfig.envFlag)) == 0 || len(os.Getenv(defaultTableConfig.envFlag)) == 0 { return createErrorResponse(errors.NewMissingDestinationError().(*errors.MissingDestinationError).Message()) } @@ -147,23 +165,41 @@ func lambdaHandler(req events.APIGatewayProxyRequest) (events.APIGatewayProxyRes logger := cfg.createLogger() - var awsCredentials *credentials.Credentials + ctx := context.Background() + var awsCredentials aws.CredentialsProvider var ok bool // If SigV4 authentication has been enabled, such as when write requests originate // from the OpenTelemetry collector, credentials will be taken from the local environment. // Otherwise, basic auth is used for AWS credentials if cfg.enableSigV4Auth { - sess := session.Must(session.NewSession()) - awsCredentials = sess.Config.Credentials + awsConfig, err := config.LoadDefaultConfig(ctx) + if err != nil { + return createErrorResponse("Error loading AWS config: " + err.Error()) + } + awsCredentials = awsConfig.Credentials } else { awsCredentials, ok = parseBasicAuth(req.Headers[basicAuthHeader]) if !ok { return createErrorResponse(errors.NewParseBasicAuthHeaderError().(*errors.ParseBasicAuthHeaderError).Message()) } } + awsQueryConfigs, err := cfg.buildAWSConfig(ctx, cfg.maxRetries) + if err != nil { + timestream.LogError(logger, "Failed to build AWS configuration for query", err) + os.Exit(1) + } - awsConfigs := cfg.buildAWSConfig() + awsWriteConfigs, err := cfg.buildAWSConfig(ctx, writeClientMaxRetries) + if err != nil { + timestream.LogError(logger, "Failed to build AWS configuration for write", err) + os.Exit(1) + } + + if err != nil { + timestream.LogError(logger, "Failed to build AWS configuration", err) + os.Exit(1) + } timestreamClient := timestream.NewBaseClient(cfg.defaultDatabase, cfg.defaultTable) requestBody, err := base64.StdEncoding.DecodeString(req.Body) @@ -177,16 +213,16 @@ func lambdaHandler(req events.APIGatewayProxyRequest) (events.APIGatewayProxyRes } if len(req.Headers[writeHeader]) != 0 { - return handleWriteRequest(reqBuf, timestreamClient, awsConfigs, cfg, logger, awsCredentials) + return handleWriteRequest(reqBuf, timestreamClient, awsWriteConfigs, cfg, logger, awsCredentials) } else if len(req.Headers[readHeader]) != 0 { - return handleReadRequest(reqBuf, timestreamClient, awsConfigs, cfg, logger, awsCredentials) + return handleReadRequest(reqBuf, timestreamClient, awsQueryConfigs, cfg, logger, awsCredentials) } return createErrorResponse(errors.NewMissingHeaderError(readHeader, writeHeader).(*errors.MissingHeaderError).Message()) } // handleWriteRequest handles a Prometheus write request. -func handleWriteRequest(reqBuf []byte, timestreamClient *timestream.Client, awsConfigs *aws.Config, cfg *connectionConfig, logger log.Logger, credentials *credentials.Credentials) (events.APIGatewayProxyResponse, error) { +func handleWriteRequest(reqBuf []byte, timestreamClient *timestream.Client, awsConfigs aws.Config, cfg *connectionConfig, logger log.Logger, credentialsProvider aws.CredentialsProvider) (events.APIGatewayProxyResponse, error) { var writeRequest prompb.WriteRequest if err := proto.Unmarshal(reqBuf, &writeRequest); err != nil { return events.APIGatewayProxyResponse{ @@ -198,13 +234,8 @@ func handleWriteRequest(reqBuf []byte, timestreamClient *timestream.Client, awsC createWriteClient(timestreamClient, logger, awsConfigs, cfg.failOnLongMetricLabelName, cfg.failOnInvalidSample) timestream.LogInfo(logger, fmt.Sprintf("Timestream write connection is initialized (Database: %s, Table: %s, Region: %s)", cfg.defaultDatabase, cfg.defaultTable, cfg.clientConfig.region)) - if err := getWriteClient(timestreamClient).Write(&writeRequest, credentials); err != nil { + if err := getWriteClient(timestreamClient).Write(context.Background(), &writeRequest, credentialsProvider); err != nil { errorCode := http.StatusBadRequest - - if requestError, ok := err.(awserr.RequestFailure); ok { - errorCode = requestError.StatusCode() - } - return events.APIGatewayProxyResponse{ StatusCode: errorCode, Body: err.Error(), @@ -217,27 +248,20 @@ func handleWriteRequest(reqBuf []byte, timestreamClient *timestream.Client, awsC } // handleReadRequest handles a Prometheus read request. -func handleReadRequest(reqBuf []byte, timestreamClient *timestream.Client, awsConfigs *aws.Config, cfg *connectionConfig, logger log.Logger, credentials *credentials.Credentials) (events.APIGatewayProxyResponse, error) { +func handleReadRequest(reqBuf []byte, timestreamClient *timestream.Client, awsConfigs aws.Config, cfg *connectionConfig, logger log.Logger, credentialsProvider aws.CredentialsProvider) (events.APIGatewayProxyResponse, error) { var readRequest prompb.ReadRequest if err := proto.Unmarshal(reqBuf, &readRequest); err != nil { timestream.LogError(logger, "Error occurred while unmarshalling the decoded read request from Prometheus.", err) return createErrorResponse(err.Error()) } - createQueryClient(timestreamClient, logger, awsConfigs, cfg.maxRetries) + createQueryClient(timestreamClient, logger, awsConfigs) timestream.LogInfo(logger, fmt.Sprintf("Timestream query connection is initialized (Database: %s, Table: %s, Region: %s)", cfg.defaultDatabase, cfg.defaultTable, cfg.clientConfig.region)) - response, err := getQueryClient(timestreamClient).Read(&readRequest, credentials) + response, err := getQueryClient(timestreamClient).Read(context.Background(), &readRequest, credentialsProvider) if err != nil { timestream.LogError(logger, "Error occurred while reading the data back from Timestream.", err) - if requestError, ok := err.(awserr.RequestFailure); ok { - return events.APIGatewayProxyResponse{ - StatusCode: requestError.StatusCode(), - Body: err.Error(), - }, nil - } - return createErrorResponse(err.Error()) } @@ -263,7 +287,7 @@ func handleReadRequest(reqBuf []byte, timestreamClient *timestream.Client, awsCo } // parseBasicAuth parses the encoded HTTP Basic Authentication Header. -func parseBasicAuth(encoded string) (awsCredentials *credentials.Credentials, ok bool) { +func parseBasicAuth(encoded string) (aws.CredentialsProvider, bool) { auth := strings.SplitN(encoded, " ", 2) if len(auth) != 2 || auth[0] != "Basic" { return nil, false @@ -277,7 +301,16 @@ func parseBasicAuth(encoded string) (awsCredentials *credentials.Credentials, ok if len(credentialsSlice) != 2 { return nil, false } - return credentials.NewStaticCredentials(credentialsSlice[0], credentialsSlice[1], ""), true + staticCredentials := aws.NewCredentialsCache( + credentials.StaticCredentialsProvider{ + Value: aws.Credentials{ + AccessKeyID: credentialsSlice[0], + SecretAccessKey: credentialsSlice[1], + Source: "BasicAuthHeader", + }, + }, + ) + return staticCredentials, true } // createLogger creates a new logger for the clients. @@ -421,12 +454,19 @@ func parseFlags() *connectionConfig { } // buildAWSConfig builds a aws.Config and return the pointer of the config. -func (cfg *connectionConfig) buildAWSConfig() *aws.Config { - clientConfig := cfg.clientConfig - awsConfig := &aws.Config{ - Region: aws.String(clientConfig.region), +func (cfg *connectionConfig) buildAWSConfig(ctx context.Context, maxRetries int) (aws.Config, error) { + awsConfig, err := config.LoadDefaultConfig(ctx, + config.WithRegion(cfg.clientConfig.region), + config.WithRetryer(func() aws.Retryer { + return retry.NewStandard(func(o *retry.StandardOptions) { + o.MaxAttempts = maxRetries + }) + }), + ) + if err != nil { + return aws.Config{}, fmt.Errorf("failed to build AWS config: %w", err) } - return awsConfig + return awsConfig, nil } // serve listens for requests and remote writes and reads to Timestream. @@ -476,11 +516,17 @@ func createWriteHandler(logger log.Logger, writers []writer) func(w http.Respons http.Error(w, err.Error(), http.StatusBadRequest) return } - - if err := writers[0].Write(&req, awsCredentials); err != nil { + if err := writers[0].Write(context.Background(), &req, awsCredentials); err != nil { switch err := err.(type) { - case awserr.RequestFailure: - http.Error(w, err.Error(), err.StatusCode()) + case *wtypes.RejectedRecordsException: + http.Error(w, err.Error(), http.StatusConflict) + case *smithy.OperationError: + var apiError *smithy.GenericAPIError + if goErrors.As(err, &apiError) { + http.Error(w, apiError.ErrorMessage(), getHTTPStatusFromSmithyError(apiError)) + return + } + http.Error(w, "An unknown service error occurred", http.StatusInternalServerError) case *errors.SDKNonRequestError: http.Error(w, err.Error(), http.StatusBadRequest) case *errors.MissingDatabaseWithWriteError: @@ -488,10 +534,23 @@ func createWriteHandler(logger log.Logger, writers []writer) func(w http.Respons case *errors.MissingTableWithWriteError: http.Error(w, err.Error(), http.StatusBadRequest) default: - // Others will halt the program. halt(1) } } + + } +} + +func getHTTPStatusFromSmithyError(err *smithy.GenericAPIError) int { + switch err.ErrorCode() { + case "ThrottlingException": + return http.StatusTooManyRequests + case "ResourceNotFoundException": + return http.StatusNotFound + case "AccessDeniedException": + return http.StatusForbidden + default: + return http.StatusInternalServerError } } @@ -507,7 +566,6 @@ func createReadHandler(logger log.Logger, readers []reader) func(w http.Response } compressed, err := io.ReadAll(r.Body) - if err != nil { timestream.LogError(logger, "Error occurred while reading the read request sent by Prometheus.", err) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -527,15 +585,14 @@ func createReadHandler(logger log.Logger, readers []reader) func(w http.Response http.Error(w, err.Error(), http.StatusBadRequest) return } - - response, err := readers[0].Read(&req, awsCredentials) + response, err := readers[0].Read(context.Background(), &req, awsCredentials) if err != nil { timestream.LogError(logger, "Error occurred while reading the data back from Timestream.", err) - if requestError, ok := err.(awserr.RequestFailure); ok { - http.Error(w, err.Error(), requestError.StatusCode()) + var rejectedRecordsErr *wtypes.RejectedRecordsException + if goErrors.As(err, &rejectedRecordsErr) { + http.Error(w, err.Error(), http.StatusBadRequest) return } - http.Error(w, err.Error(), http.StatusBadRequest) return } diff --git a/main_test.go b/main_test.go index 86f4c28..94f5bf9 100644 --- a/main_test.go +++ b/main_test.go @@ -14,15 +14,25 @@ and limitations under the License. package main import ( + "context" "encoding/base64" goErrors "errors" "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "strings" + "testing" + "time" + "timestream-prometheus-connector/errors" + "timestream-prometheus-connector/timestream" + "github.com/aws/aws-lambda-go/events" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/private/protocol" - "github.com/aws/aws-sdk-go/service/timestreamquery" - "github.com/aws/aws-sdk-go/service/timestreamwrite" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" + wtypes "github.com/aws/aws-sdk-go-v2/service/timestreamwrite/types" "github.com/go-kit/log" "github.com/gogo/protobuf/proto" "github.com/golang/snappy" @@ -33,16 +43,6 @@ import ( "github.com/prometheus/prometheus/prompb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "io" - "net/http" - "net/http/httptest" - "os" - "os/exec" - "strings" - "testing" - "time" - "timestream-prometheus-connector/errors" - "timestream-prometheus-connector/timestream" ) const ( @@ -55,7 +55,7 @@ const ( assertResponseMessage = "Error must not occur while reading the response body from the test output." writeRequestType = "*prompb.WriteRequest" readRequestType = "*prompb.ReadRequest" - awsCredentialsType = "*credentials.Credentials" + awsCredentialsType = "*aws.CredentialsCache" ) var ( @@ -147,8 +147,8 @@ type requestTestCase struct { expectedStatusCode int } -func (m *mockWriter) Write(req *prompb.WriteRequest, credentials *credentials.Credentials) error { - args := m.Called(req, credentials) +func (m *mockWriter) Write(ctx context.Context, req *prompb.WriteRequest, credentialsProvider aws.CredentialsProvider) error { + args := m.Called(ctx, req, credentialsProvider) return args.Error(0) } @@ -157,8 +157,8 @@ type mockReader struct { reader } -func (m *mockReader) Read(req *prompb.ReadRequest, credentials *credentials.Credentials) (*prompb.ReadResponse, error) { - args := m.Called(req, credentials) +func (m *mockReader) Read(ctx context.Context, req *prompb.ReadRequest, credentialsProvider aws.CredentialsProvider) (*prompb.ReadResponse, error) { + args := m.Called(ctx, req, credentialsProvider) return args.Get(0).(*prompb.ReadResponse), args.Error(1) } @@ -170,8 +170,8 @@ func setUp() ([]string, *connectionConfig) { promLogLevel.Set("info") return []string{"cmd", "--default-database=foo", "--default-table=bar"}, &connectionConfig{ - clientConfig: &clientConfig{region: "us-east-1"}, - promlogConfig: promlog.Config{Format: promLogFormat, Level: promLogLevel}, + clientConfig: &clientConfig{region: "us-east-1"}, + promlogConfig: promlog.Config{Format: promLogFormat, Level: promLogLevel}, defaultDatabase: "foo", defaultTable: "bar", enableLogging: true, @@ -252,19 +252,22 @@ func TestMainParseFlags(t *testing.T) { cleanUp() }) } - func TestParseBasicAuth(t *testing.T) { tests := []struct { name string encodedCreds string - expectedCredentials *credentials.Credentials + expectedCredentials *aws.Credentials expectedAuthOk bool }{ { - name: "valid basic auth header", - encodedCreds: encodedBasicAuth, - expectedCredentials: credentials.NewStaticCredentials("fakeUser", "fakePassword", ""), - expectedAuthOk: true, + name: "valid basic auth header", + encodedCreds: encodedBasicAuth, + expectedCredentials: &aws.Credentials{ + AccessKeyID: "fakeUser", + SecretAccessKey: "fakePassword", + Source: "BasicAuthHeader", + }, + expectedAuthOk: true, }, { name: "empty basic auth header", @@ -279,14 +282,22 @@ func TestParseBasicAuth(t *testing.T) { expectedAuthOk: false, }, } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - awsCredentials, authOk := parseBasicAuth(test.encodedCreds) + awsCredentialsProvider, authOk := parseBasicAuth(test.encodedCreds) + assert.Equal(t, test.expectedAuthOk, authOk) - assert.Equal(t, test.expectedCredentials, awsCredentials) + + if test.expectedCredentials == nil { + assert.Nil(t, awsCredentialsProvider) + } else { + creds, err := awsCredentialsProvider.Retrieve(context.Background()) + assert.NoError(t, err) + assert.Equal(t, test.expectedCredentials, &creds) + } }) } - } func TestLambdaHandlerPrepareRequest(t *testing.T) { @@ -471,9 +482,13 @@ func TestLambdaHandlerWriteRequest(t *testing.T) { {key: defaultTableConfig.envFlag, value: tableValue}, {key: defaultDatabaseConfig.envFlag, value: databaseValue}, }, - inputRequest: events.APIGatewayProxyRequest{IsBase64Encoded: true, Body: string(validWriteRequestBody), Headers: validWriteHeader}, - mockSDKError: ×treamwrite.RejectedRecordsException{}, - expectedStatusCode: (×treamwrite.RejectedRecordsException{}).StatusCode(), + inputRequest: events.APIGatewayProxyRequest{ + IsBase64Encoded: true, + Body: string(validWriteRequestBody), + Headers: validWriteHeader, + }, + mockSDKError: &wtypes.RejectedRecordsException{}, + expectedStatusCode: http.StatusBadRequest, }, { name: "Missing database name from write", @@ -502,8 +517,10 @@ func TestLambdaHandlerWriteRequest(t *testing.T) { mockTimestreamWriter := new(mockWriter) mockTimestreamWriter.On( "Write", + mock.Anything, mock.AnythingOfType(writeRequestType), - mock.AnythingOfType(awsCredentialsType)).Return(test.mockSDKError) + mock.AnythingOfType(awsCredentialsType), + ).Return(test.mockSDKError) getWriteClient = func(timestreamClient *timestream.Client) writer { return mockTimestreamWriter @@ -564,9 +581,13 @@ func TestLambdaHandlerReadRequest(t *testing.T) { {key: defaultTableConfig.envFlag, value: tableValue}, {key: defaultDatabaseConfig.envFlag, value: databaseValue}, }, - inputRequest: events.APIGatewayProxyRequest{IsBase64Encoded: true, Body: string(validReadRequestBody), Headers: validReadHeader}, - mockSDKError: ×treamquery.ValidationException{}, - expectedStatusCode: (×treamquery.ValidationException{}).StatusCode(), + inputRequest: events.APIGatewayProxyRequest{ + IsBase64Encoded: true, + Body: string(validReadRequestBody), + Headers: validReadHeader, + }, + mockSDKError: &wtypes.RejectedRecordsException{}, + expectedStatusCode: http.StatusBadRequest, }, { name: "Missing database name from read", @@ -595,6 +616,7 @@ func TestLambdaHandlerReadRequest(t *testing.T) { mockTimestreamReader := new(mockReader) mockTimestreamReader.On( "Read", + mock.Anything, mock.AnythingOfType(readRequestType), mock.AnythingOfType(awsCredentialsType)).Return(&prompb.ReadResponse{}, test.mockSDKError) @@ -631,17 +653,34 @@ func TestCreateLogger(t *testing.T) { assert.NotEqual(t, nopLogger, logger, "Actual logger must not equal to log.NewNopLogger.") }) } - func TestBuildAWSConfig(t *testing.T) { t.Run("success", func(t *testing.T) { - expectedAWSConfig := &aws.Config{ - Region: aws.String("region"), + expectedRegion := "region" + expectedMaxRetries := 3 + + input := &connectionConfig{ + clientConfig: &clientConfig{ + region: expectedRegion, + }, + maxRetries: expectedMaxRetries, } - input := &connectionConfig{clientConfig: &clientConfig{region: "region"}} - actualOutput := input.buildAWSConfig() + actualOutput, err := input.buildAWSConfig(context.Background(), expectedMaxRetries) + + assert.Nil(t, err) + assert.NotNil(t, actualOutput) + + assert.Equal(t, expectedRegion, actualOutput.Region) - assert.Equal(t, expectedAWSConfig, actualOutput) + retryer := actualOutput.Retryer() + assert.NotNil(t, retryer) + + standardRetryer, ok := retryer.(*retry.Standard) + assert.True(t, ok, "expected retryer to be of type *retry.Standard") + + if ok { + assert.Equal(t, expectedMaxRetries, standardRetryer.MaxAttempts()) + } }) } @@ -780,15 +819,13 @@ func TestWriteHandler(t *testing.T) { expectedStatusCode: http.StatusBadRequest, }, { - name: "SDK error from write", - request: validWriteRequest, - returnError: ×treamwrite.RejectedRecordsException{ - RespMetadata: protocol.ResponseMetadata{StatusCode: 419}, - }, + name: "SDK error from write", + request: validWriteRequest, + returnError: &wtypes.RejectedRecordsException{}, getWriteRequestReader: getReaderHelper, basicAuthHeader: basicAuthHeader, encodedBasicAuth: encodedBasicAuth, - expectedStatusCode: 419, + expectedStatusCode: http.StatusConflict, }, { name: "unknown SDK error from write", @@ -824,6 +861,7 @@ func TestWriteHandler(t *testing.T) { mockTimestreamWriter := new(mockWriter) mockTimestreamWriter.On( "Write", + mock.Anything, mock.AnythingOfType(writeRequestType), mock.AnythingOfType(awsCredentialsType)).Return(test.returnError) @@ -861,6 +899,7 @@ func TestWriteHandler(t *testing.T) { mockTimestreamWriter := new(mockWriter) mockTimestreamWriter.On( "Write", + mock.Anything, mock.AnythingOfType(writeRequestType), mock.AnythingOfType(awsCredentialsType)).Return(errors.NewLongLabelNameError("", 0)) getWriteRequestClient := func(t *testing.T) io.Reader { @@ -947,16 +986,14 @@ func TestReadHandler(t *testing.T) { expectedStatusCode: http.StatusBadRequest, }, { - name: "SDK error from read", - request: validReadRequest, - returnError: ×treamwrite.RejectedRecordsException{ - RespMetadata: protocol.ResponseMetadata{StatusCode: http.StatusConflict}, - }, + name: "SDK error from read", + request: validReadRequest, + returnError: &wtypes.RejectedRecordsException{}, returnResponse: nil, getReadRequestReader: getReaderHelper, basicAuthHeader: basicAuthHeader, encodedBasicAuth: encodedBasicAuth, - expectedStatusCode: http.StatusConflict, + expectedStatusCode: http.StatusBadRequest, }, { name: "error from read", @@ -995,6 +1032,7 @@ func TestReadHandler(t *testing.T) { mockTimestreamReader := new(mockReader) mockTimestreamReader.On( "Read", + mock.Anything, mock.AnythingOfType(readRequestType), mock.AnythingOfType(awsCredentialsType)).Return(test.returnResponse, test.returnError) @@ -1023,7 +1061,6 @@ func TestReadHandler(t *testing.T) { // Decode and unmarshall the returned response body. actualResponse, err := io.ReadAll(resp.Body) assert.Nil(t, err, assertResponseMessage) - reqBuf, err := snappy.Decode(nil, actualResponse) assert.Nil(t, err, assertResponseMessage) var req prompb.ReadResponse diff --git a/timestream/client.go b/timestream/client.go index bbed550..5a41ab1 100644 --- a/timestream/client.go +++ b/timestream/client.go @@ -18,68 +18,71 @@ and limitations under the License. package timestream import ( + "context" + goErrors "errors" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/timestreamquery" - "github.com/aws/aws-sdk-go/service/timestreamquery/timestreamqueryiface" - "github.com/aws/aws-sdk-go/service/timestreamwrite" - "github.com/aws/aws-sdk-go/service/timestreamwrite/timestreamwriteiface" - "github.com/go-kit/log" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/common/model" - "github.com/prometheus/prometheus/prompb" "math" "strconv" "strings" "time" - "timestream-prometheus-connector/errors" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/service/timestreamquery" + qtypes "github.com/aws/aws-sdk-go-v2/service/timestreamquery/types" + "github.com/aws/aws-sdk-go-v2/service/timestreamwrite" + wtypes "github.com/aws/aws-sdk-go-v2/service/timestreamwrite/types" + "github.com/aws/smithy-go" + "github.com/aws/smithy-go/transport/http" + + "github.com/go-kit/log" + "github.com/prometheus/client_golang/prometheus" prometheusClientModel "github.com/prometheus/client_model/go" + "github.com/prometheus/common/model" + "github.com/prometheus/prometheus/prompb" + + "timestream-prometheus-connector/errors" ) type labelOperation string type longMetricsOperation func(measureValueName string) (labelOperation, error) -var addUserAgent = request.NamedHandler { - Name: "UserAgentHandler", - Fn: request.MakeAddToUserAgentHandler("Prometheus Connector", Version), -} +var addUserAgentMiddleware = middleware.AddUserAgentKey("Prometheus Connector/" + Version) // Store the initialization function calls to allow unit tests to mock the creation of real clients. -var initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { - sess, err := session.NewSession(config) - if err != nil { - return nil, err - } - sess.Handlers.Build.PushFrontNamed(addUserAgent) - return timestreamwrite.New(sess), nil +var initWriteClient = func(cfg aws.Config) (TimestreamWriteClient, error) { + client := timestreamwrite.NewFromConfig(cfg, func(o *timestreamwrite.Options) { + o.APIOptions = append(o.APIOptions, addUserAgentMiddleware) + }) + return client, nil } -var initQueryClient = func(config *aws.Config) (timestreamqueryiface.TimestreamQueryAPI, error) { - sess, err := session.NewSession(config) - if err != nil { - return nil, err + +var initQueryClient = func(cfg aws.Config) (*timestreamquery.Client, error) { + client := timestreamquery.NewFromConfig(cfg, func(o *timestreamquery.Options) { + o.APIOptions = append(o.APIOptions, addUserAgentMiddleware) + }) + return client, nil +} + +var initPaginatorFactory = func(timestreamQuery *timestreamquery.Client, queryInput *timestreamquery.QueryInput) Paginator { + return &TimestreamPaginator{ + paginator: timestreamquery.NewQueryPaginator(timestreamQuery, queryInput), } - sess.Handlers.Build.PushFrontNamed(addUserAgent) - return timestreamquery.New(sess), nil } // recordDestinationMap is a nested map that stores slices of Records based on the ingestion destination. // Below is an example of the map structure: -// records := map[string]map[string][]*timestreamwrite.Record{ -// "database1": map[string][]*timestreamwrite.Record{ -// "table1":[]*timestreamwrite.Record{record1, record2}, -// "table2":[]*timestreamwrite.Record{record3}, -// }, -// "database2": map[string]string{ -// "table3":[]*timestreamwrite.Record{record4, record5}, -// "table4":[]*timestreamwrite.Record{record6}, -// }, -// } -type recordDestinationMap map[string]map[string][]*timestreamwrite.Record +// +// records := map[string]map[string][]wtypes.Record{ +// "database1": map[string][]wtypes.Record{ +// "table1":[]wtypes.Record{record1, record2}, +// "table2":[]wtypes.Record{record3}, +// }, +// "database2": map[string]string{ +// "table3":[]wtypes.Record{record4, record5}, +// "table4":[]wtypes.Record{record6}, +// }, +type recordDestinationMap map[string]map[string][]wtypes.Record const ( maxWriteBatchLength int = 100 @@ -97,22 +100,22 @@ const ( type QueryClient struct { client *Client - config *aws.Config + config aws.Config logger log.Logger readExecutionTime prometheus.Histogram readRequests prometheus.Counter - timestreamQuery timestreamqueryiface.TimestreamQueryAPI + timestreamQuery *timestreamquery.Client } type WriteClient struct { client *Client - config *aws.Config + config aws.Config logger log.Logger ignoredSamples prometheus.Counter receivedSamples prometheus.Counter writeRequests prometheus.Counter writeExecutionTime prometheus.Histogram - timestreamWrite timestreamwriteiface.TimestreamWriteAPI + timestreamWrite TimestreamWriteClient failOnLongMetricLabelName bool failOnInvalidSample bool } @@ -124,6 +127,31 @@ type Client struct { defaultTable string } +type TimestreamWriteClient interface { + WriteRecords(ctx context.Context, input *timestreamwrite.WriteRecordsInput, optFns ...func(*timestreamwrite.Options)) (*timestreamwrite.WriteRecordsOutput, error) +} + +// Paginator defines the interface for Timestream pagination +type Paginator interface { + HasMorePages() bool + NextPage(ctx context.Context) (*timestreamquery.QueryOutput, error) +} + +// TimestreamPaginator wraps the actual Timestream paginator to support mocking in unit tests +type TimestreamPaginator struct { + paginator *timestreamquery.QueryPaginator +} + +func (tp *TimestreamPaginator) HasMorePages() bool { + return tp.paginator.HasMorePages() +} + +func (tp *TimestreamPaginator) NextPage(ctx context.Context) (*timestreamquery.QueryOutput, error) { + return tp.paginator.NextPage(ctx) +} + +type PaginatorFactory func(queryInput *timestreamquery.QueryInput) Paginator + // NewBaseClient creates a Timestream Client object with the ingestion destination labels. func NewBaseClient(defaultDataBase, defaultTable string) *Client { client := &Client{ @@ -135,7 +163,7 @@ func NewBaseClient(defaultDataBase, defaultTable string) *Client { } // NewQueryClient creates a new Timestream query client with the given set of configuration. -func (c *Client) NewQueryClient(logger log.Logger, configs *aws.Config) { +func (c *Client) NewQueryClient(logger log.Logger, configs aws.Config) { c.queryClient = &QueryClient{ client: c, logger: logger, @@ -157,7 +185,7 @@ func (c *Client) NewQueryClient(logger log.Logger, configs *aws.Config) { } // NewWriteClient creates a new Timestream write client with a given set of configurations. -func (c *Client) NewWriteClient(logger log.Logger, configs *aws.Config, failOnLongMetricLabelName bool, failOnInvalidSample bool) { +func (c *Client) NewWriteClient(logger log.Logger, configs aws.Config, failOnLongMetricLabelName bool, failOnInvalidSample bool) { c.writeClient = &WriteClient{ client: c, logger: logger, @@ -191,17 +219,17 @@ func (c *Client) NewWriteClient(logger log.Logger, configs *aws.Config, failOnLo ), } } - -// Write sends the prompb.WriteRequest to timestreamwriteiface.TimestreamWriteAPI -func (wc *WriteClient) Write(req *prompb.WriteRequest, credentials *credentials.Credentials) error { - wc.config.Credentials = credentials +func (wc *WriteClient) Write(ctx context.Context, req *prompb.WriteRequest, credentialsProvider aws.CredentialsProvider) error { + wc.config.Credentials = credentialsProvider var err error wc.timestreamWrite, err = initWriteClient(wc.config) if err != nil { LogError(wc.logger, "Unable to construct a new session with the given credentials.", err) return err } + LogInfo(wc.logger, fmt.Sprintf("%d records requested for ingestion from Prometheus.", len(req.Timeseries))) + recordMap := make(recordDestinationMap) recordMap, err = wc.convertToRecords(req.Timeseries, recordMap) if err != nil { @@ -212,31 +240,38 @@ func (wc *WriteClient) Write(req *prompb.WriteRequest, credentials *credentials. var sdkErr error for database, tableMap := range recordMap { for table, records := range tableMap { + recordLen := len(records) // Timestream will return an error if more than 100 records are sent in a batch. // Therefore, records should be chunked if there are more than 100 of them - var chunkEndIndex int - for chunkStartIndex := 0; chunkStartIndex < len(records); chunkStartIndex += maxWriteBatchLength { - chunkEndIndex += maxWriteBatchLength - if chunkEndIndex > len(records) { - chunkEndIndex = len(records) + for chunkStartIndex := 0; chunkStartIndex < recordLen; chunkStartIndex += maxWriteBatchLength { + chunkEndIndex := chunkStartIndex + maxWriteBatchLength + if chunkEndIndex > recordLen { + chunkEndIndex = recordLen } + + currentChunkSize := chunkEndIndex - chunkStartIndex + writeRecordsInput := ×treamwrite.WriteRecordsInput{ DatabaseName: aws.String(database), TableName: aws.String(table), Records: records[chunkStartIndex:chunkEndIndex], } + begin := time.Now() - _, err = wc.timestreamWrite.WriteRecords(writeRecordsInput) + _, err = wc.timestreamWrite.WriteRecords(ctx, writeRecordsInput) duration := time.Since(begin).Seconds() + if err != nil { sdkErr = wc.handleSDKErr(req, err, sdkErr) } else { - LogInfo(wc.logger, fmt.Sprintf("Successfully wrote %d records to database: %s table: %s", len(writeRecordsInput.Records), database, table)) + LogInfo(wc.logger, fmt.Sprintf("Successfully wrote %d records to Database: %s, Table: %s", currentChunkSize, database, table)) + recordsIgnored := getCounterValue(wc.ignoredSamples) - if (recordsIgnored > 0) { - LogInfo(wc.logger, fmt.Sprintf("%d number of records were rejected for ingestion to Timestream. See Troubleshooting in the README for why these may be rejected, or turn on debug logging for additional info.", recordsIgnored)) + if recordsIgnored > 0 { + LogInfo(wc.logger, fmt.Sprintf("%d records were rejected for ingestion to Timestream. See Troubleshooting in the README for possible reasons, or enable debug logging for more details.", recordsIgnored)) } } + wc.writeExecutionTime.Observe(duration) wc.writeRequests.Inc() } @@ -248,15 +283,18 @@ func (wc *WriteClient) Write(req *prompb.WriteRequest, credentials *credentials. // Read converts the Prometheus prompb.ReadRequest into Timestream queries and return // the result set as Prometheus prompb.ReadResponse. -func (qc *QueryClient) Read(req *prompb.ReadRequest, credentials *credentials.Credentials) (*prompb.ReadResponse, error) { - qc.config.Credentials = credentials +func (qc *QueryClient) Read( + ctx context.Context, + req *prompb.ReadRequest, + credentialsProvider aws.CredentialsProvider, +) (*prompb.ReadResponse, error) { + qc.config.Credentials = credentialsProvider var err error qc.timestreamQuery, err = initQueryClient(qc.config) if err != nil { LogError(qc.logger, "Unable to construct a new session with the given credentials", err) return nil, err } - queryInputs, isRelatedToRegex, err := qc.buildCommands(req.Queries) if err != nil { LogError(qc.logger, "Error occurred while translating Prometheus query.", err) @@ -268,33 +306,38 @@ func (qc *QueryClient) Read(req *prompb.ReadRequest, credentials *credentials.Cr begin := time.Now() var queryPageError error + for _, queryInput := range queryInputs { - queryPageError = qc.timestreamQuery.QueryPages(queryInput, - func(page *timestreamquery.QueryOutput, lastPage bool) bool { - var convertError error - resultSet, convertError = qc.convertToResult(resultSet, page) - qc.readRequests.Inc() - if convertError != nil { - LogError(qc.logger, "Error occurred while converting the Timestream query results to Prometheus QueryResults", err) - return false - } - LogInfo(qc.logger, fmt.Sprintf("Successfully read %d records from database: %s table: %s", len(page.Rows), qc.client.defaultDataBase, qc.client.defaultTable)) - return true - }) - if queryPageError != nil { - if requestError, ok := queryPageError.(awserr.RequestFailure); ok && (requestError.StatusCode()/100 == 4) { - LogDebug(qc.logger, "The read request failed while retrieving data back from Timestream.", "request", req) + paginator := initPaginatorFactory(qc.timestreamQuery, queryInput) + for paginator.HasMorePages() { + page, err := paginator.NextPage(ctx) + if err != nil { + queryPageError = err + LogError(qc.logger, "Error occurred while fetching the next page of results.", err) + break + } + + resultSet, err = qc.convertToResult(resultSet, page) + qc.readRequests.Inc() + if err != nil { + LogError(qc.logger, "Error occurred while converting the Timestream query results to Prometheus QueryResults", err) + return nil, err } - if _, ok := queryPageError.(*timestreamquery.ValidationException); ok && isRelatedToRegex { + } + + if queryPageError != nil { + var apiError *smithy.GenericAPIError + if goErrors.As(queryPageError, &apiError) && apiError.Code == "ValidationException" && isRelatedToRegex { LogError(qc.logger, "Error occurred due to unsupported query. Please validate the regular expression used in the query. Check the documentation for unsupported RE2 syntax.", queryPageError) return nil, queryPageError } - LogError(qc.logger, "Error occurred while querying Timestream pages.", err) + LogError(qc.logger, "Error occurred while querying Timestream pages.", queryPageError) return nil, queryPageError } } + duration := time.Since(begin).Seconds() qc.readExecutionTime.Observe(duration) @@ -305,26 +348,31 @@ func (qc *QueryClient) Read(req *prompb.ReadRequest, credentials *credentials.Cr // handleSDKErr parses and logs the error from SDK (if any) func (wc *WriteClient) handleSDKErr(req *prompb.WriteRequest, currErr error, errToReturn error) error { - requestError, ok := currErr.(awserr.RequestFailure) - if !ok { + var responseError *http.ResponseError + if !goErrors.As(currErr, &responseError) { LogError(wc.logger, fmt.Sprintf("Error occurred while ingesting Timestream Records. %d records failed to be written", len(req.Timeseries)), currErr) - return errors.NewSDKNonRequestError(currErr) + return currErr } if errToReturn == nil { - errToReturn = requestError + errToReturn = currErr } - switch requestError.StatusCode() / 100 { + + statusCode := responseError.HTTPStatusCode() + switch statusCode / 100 { case 4: - LogDebug(wc.logger, "Error occurred while ingesting data due to invalid write request. Some Prometheus Samples were not ingested into Timestream, please review the write request and check the documentation for troubleshooting.", "request", req) + LogDebug(wc.logger, "Error occurred while ingesting data due to invalid write request. "+ + "Some Prometheus Samples were not ingested into Timestream, please review the write request and check the documentation for troubleshooting.", + "request", req) case 5: - errToReturn = requestError + errToReturn = currErr LogDebug(wc.logger, "Internal server error occurred. Samples will be retried by Prometheus", "request", req) } + return errToReturn } -// convertToRecords converts a slice of *prompb.TimeSeries to a slice of *timestreamwrite.Record +// convertToRecords converts a slice of *prompb.TimeSeries to a slice of wtypes.Record func (wc *WriteClient) convertToRecords(series []*prompb.TimeSeries, recordMap recordDestinationMap) (recordDestinationMap, error) { var operationOnLongMetrics longMetricsOperation if wc.failOnLongMetricLabelName { @@ -350,10 +398,10 @@ func (wc *WriteClient) convertToRecords(series []*prompb.TimeSeries, recordMap r return processTimeSeries(wc, operationOnLongMetrics, series, recordMap) } -// processTimeSeries processes a slice of *prompb.TimeSeries to a slice of *timestreamwrite.Record +// processTimeSeries processes a slice of *prompb.TimeSeries to a slice of wtypes.Record func processTimeSeries(wc *WriteClient, operationOnLongMetrics longMetricsOperation, series []*prompb.TimeSeries, recordMap recordDestinationMap) (recordDestinationMap, error) { for _, timeSeries := range series { - var dimensions []*timestreamwrite.Dimension + var dimensions []wtypes.Dimension var err error var operation labelOperation var databaseName string @@ -395,7 +443,7 @@ func processTimeSeries(wc *WriteClient, operationOnLongMetrics longMetricsOperat recordMap[databaseName] = getOrCreateRecordMapEntry(recordMap, databaseName) - var records []*timestreamwrite.Record + var records []wtypes.Record if recordMap[databaseName][tableName] != nil { records = recordMap[databaseName][tableName] @@ -417,13 +465,13 @@ func processTimeSeries(wc *WriteClient, operationOnLongMetrics longMetricsOperat return recordMap, nil } -// processMetricLabels processes metricLabels to a *timestreamwrite.Record -func processMetricLabels(metricLabels map[string]string, operationOnLongMetrics longMetricsOperation) ([]*timestreamwrite.Dimension, labelOperation, error) { +// processMetricLabels processes metricLabels to a wtypes.Record +func processMetricLabels(metricLabels map[string]string, operationOnLongMetrics longMetricsOperation) ([]wtypes.Dimension, labelOperation, error) { var operation labelOperation - var dimensions []*timestreamwrite.Dimension + var dimensions []wtypes.Dimension var err error for name, value := range metricLabels { - // Each label in the metricLabels map contains a characteristic/dimension of the metric, which maps to timestreamwrite.Dimension + // Each label in the metricLabels map contains a characteristic/dimension of the metric, which maps to wtypes.Dimension operation, err = operationOnLongMetrics(name) switch operation { case failed: @@ -431,7 +479,7 @@ func processMetricLabels(metricLabels map[string]string, operationOnLongMetrics case ignored: return nil, operation, nil default: - dimensions = append(dimensions, ×treamwrite.Dimension{ + dimensions = append(dimensions, wtypes.Dimension{ Name: aws.String(name), Value: aws.String(value), }) @@ -441,16 +489,16 @@ func processMetricLabels(metricLabels map[string]string, operationOnLongMetrics } // getOrCreateRecordMapEntry gets record map entry -func getOrCreateRecordMapEntry(recordMap recordDestinationMap, databaseName string) map[string][]*timestreamwrite.Record { +func getOrCreateRecordMapEntry(recordMap recordDestinationMap, databaseName string) map[string][]wtypes.Record { if recordMap[databaseName] == nil { - recordMap[databaseName] = make(map[string][]*timestreamwrite.Record) + recordMap[databaseName] = make(map[string][]wtypes.Record) } return recordMap[databaseName] } // convertToMap converts the slice of Labels to a Map and retrieves the measure value name. func convertToMap(labels []*prompb.Label) (map[string]string, string) { - // measureValueName is the Prometheus metric name that maps to MeasureName of a timestreamwrite.Record + // measureValueName is the Prometheus metric name that maps to MeasureName of a wtypes.Record var measureValueName string metric := make(map[string]string, len(labels)) @@ -464,7 +512,7 @@ func convertToMap(labels []*prompb.Label) (map[string]string, string) { } // appendRecords converts each valid Prometheus Sample to a Timestream Record and append the Record to the given slice of records. -func (wc *WriteClient) appendRecords(records []*timestreamwrite.Record, timeSeries *prompb.TimeSeries, dimensions []*timestreamwrite.Dimension, measureValueName string) ([]*timestreamwrite.Record, error) { +func (wc *WriteClient) appendRecords(records []wtypes.Record, timeSeries *prompb.TimeSeries, dimensions []wtypes.Dimension, measureValueName string) ([]wtypes.Record, error) { var operationOnInvalidSample func(timeSeriesValue float64) (labelOperation, error) if wc.failOnInvalidSample { operationOnInvalidSample = func(timeSeriesValue float64) (labelOperation, error) { @@ -489,7 +537,7 @@ func (wc *WriteClient) appendRecords(records []*timestreamwrite.Record, timeSeri } for _, sample := range timeSeries.Samples { - // sample.Value is the measured value of a metric which maps to the MeasureValue in timestreamwrite.Record + // sample.Value is the measured value of a metric which maps to the MeasureValue in wtypes.Record timeSeriesValue := sample.Value operation, err := operationOnInvalidSample(timeSeriesValue) @@ -501,13 +549,13 @@ func (wc *WriteClient) appendRecords(records []*timestreamwrite.Record, timeSeri default: } - records = append(records, ×treamwrite.Record{ + records = append(records, wtypes.Record{ Dimensions: dimensions, MeasureName: aws.String(measureValueName), MeasureValue: aws.String(strconv.FormatFloat(timeSeriesValue, 'f', 6, 64)), - MeasureValueType: aws.String(timestreamwrite.MeasureValueTypeDouble), + MeasureValueType: wtypes.MeasureValueTypeDouble, Time: aws.String(strconv.FormatInt(sample.Timestamp, 10)), - TimeUnit: aws.String(timestreamwrite.TimeUnitMilliseconds), + TimeUnit: wtypes.TimeUnitMilliseconds, }) } @@ -584,6 +632,7 @@ func (qc *QueryClient) convertToResult(results *prompb.QueryResult, page *timest } for _, row := range rows { + labels, samples, err := qc.constructLabels(row.Data, page.ColumnInfo) if err != nil { LogDebug(qc.logger, "Error occurred when constructing Prometheus Labels from Timestream QueryOutput with Row", "row", row) @@ -597,34 +646,39 @@ func (qc *QueryClient) convertToResult(results *prompb.QueryResult, page *timest } // constructLabels converts the given row to the corresponding Prometheus Label and Sample. -func (qc *QueryClient) constructLabels(row []*timestreamquery.Datum, metadata []*timestreamquery.ColumnInfo) ([]*prompb.Label, prompb.Sample, error) { +func (qc *QueryClient) constructLabels(row []qtypes.Datum, metadata []qtypes.ColumnInfo) ([]*prompb.Label, prompb.Sample, error) { var labels []*prompb.Label var sample prompb.Sample + for i, datum := range row { + if datum.NullValue == nil { column := metadata[i] switch *column.Name { case timeColumnName: timestamp, err := time.Parse(timestampLayout, *datum.ScalarValue) if err != nil { - err := fmt.Errorf("error occured while parsing '%d' as a timestamp", datum.ScalarValue) + err := fmt.Errorf("error occurred while parsing '%s' as a timestamp", *datum.ScalarValue) LogError(qc.logger, "Invalid datum type retrieved from Timestream", err) return labels, sample, err } sample.Timestamp = timestamp.UnixNano() / nanosToMillisConversionRate + case measureValueColumnName: val, err := strconv.ParseFloat(*datum.ScalarValue, 64) if err != nil { - err := fmt.Errorf("error occured while parsing '%d' as a float", datum.ScalarValue) + err := fmt.Errorf("error occurred while parsing '%s' as a float", *datum.ScalarValue) LogError(qc.logger, "Invalid datum type retrieved from Timestream", err) return labels, sample, err } sample.Value = val + case measureNameColumnName: labels = append(labels, &prompb.Label{ Name: model.MetricNameLabel, Value: *datum.ScalarValue, }) + default: labels = append(labels, &prompb.Label{ Name: *column.Name, diff --git a/timestream/client_test.go b/timestream/client_test.go index 5a44149..b4c836b 100644 --- a/timestream/client_test.go +++ b/timestream/client_test.go @@ -15,15 +15,22 @@ and limitations under the License. package timestream import ( + "context" goErrors "errors" "fmt" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/private/protocol" - "github.com/aws/aws-sdk-go/service/timestreamquery" - "github.com/aws/aws-sdk-go/service/timestreamquery/timestreamqueryiface" - "github.com/aws/aws-sdk-go/service/timestreamwrite" - "github.com/aws/aws-sdk-go/service/timestreamwrite/timestreamwriteiface" + "math" + "reflect" + "sort" + "strconv" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/timestreamquery" + qtypes "github.com/aws/aws-sdk-go-v2/service/timestreamquery/types" + "github.com/aws/aws-sdk-go-v2/service/timestreamwrite" + wtypes "github.com/aws/aws-sdk-go-v2/service/timestreamwrite/types" "github.com/go-kit/log" "github.com/google/go-cmp/cmp" "github.com/prometheus/client_golang/prometheus" @@ -31,74 +38,97 @@ import ( "github.com/prometheus/prometheus/prompb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "math" - "reflect" - "sort" - "strconv" - "testing" - "time" + "timestream-prometheus-connector/errors" ) var ( - mockLogger = log.NewNopLogger() - mockUnixTime = time.Now().UnixNano() / (int64(time.Millisecond) / int64(time.Nanosecond)) - mockCounter = prometheus.NewCounter(prometheus.CounterOpts{}) - mockHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{}) - mockEndUnixTime = mockUnixTime + 30000 - mockAwsConfigs = &aws.Config{} - mockCredentials = credentials.AnonymousCredentials + mockLogger = log.NewNopLogger() + mockUnixTime = time.Now().UnixNano() / (int64(time.Millisecond) / int64(time.Nanosecond)) + mockCounter = prometheus.NewCounter(prometheus.CounterOpts{}) + mockHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{}) + mockEndUnixTime = mockUnixTime + 30000 + mockCredentials = aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider("mockAccessKey", "mockSecretKey", "mockSessionToken")) + mockAwsConfigs = aws.Config{ + Credentials: mockCredentials, + Region: "us-west-2", + } startUnixInSeconds = mockUnixTime / millisToSecConversionRate endUnixInSeconds = mockEndUnixTime / millisToSecConversionRate ) const ( - mockTableName = "prom" - mockDatabaseName = "promDB" - mockRegion = "us-east-1" - mockLongMetric = "prometheus_remote_storage_queue_highest_sent_timestamp_seconds" - instance = "localhost:9090" - metricName = "go_gc_duration_seconds" - job = "prometheus" - measureValueStr = "0.001995" - invalidValue = "invalidValue" - invalidTime = "invalidTime" - timestamp1 = "2020-10-01 15:02:02.000000000" - timestamp2 = "2020-10-01 20:00:00.000000000" - quantile = "0.5" - instanceRegex = "9090*" - jobRegex = "pro*" - invalidRegex = "(?P\\w+)" - unixTime1 = 1601564522000 - unixTime2 = 1601582400000 - measureValue = 0.001995 - invalidMatcher = 10 - functionType = "func(*timestreamquery.QueryOutput, bool) bool" + mockTableName = "prom" + mockDatabaseName = "promDB" + mockRegion = "us-east-1" + mockLongMetric = "prometheus_remote_storage_queue_highest_sent_timestamp_seconds" + instance = "localhost:9090" + metricName = "go_gc_duration_seconds" + job = "prometheus" + measureValueStr = "0.001995" + invalidValue = "invalidValue" + invalidTime = "invalidTime" + timestamp1 = "2020-10-01 15:02:02.000000000" + timestamp2 = "2020-10-01 20:00:00.000000000" + quantile = "0.5" + instanceRegex = "9090*" + jobRegex = "pro*" + invalidRegex = "(?P\\w+)" + unixTime1 = 1601564522000 + unixTime2 = 1601582400000 + measureValue = 0.001995 + invalidMatcher = 10 + functionType = "func(*timestreamquery.QueryOutput, bool) bool" ) -type mockTimestreamWriteClient struct { +type mockPaginator struct { mock.Mock - timestreamwriteiface.TimestreamWriteAPI } -func (m *mockTimestreamWriteClient) WriteRecords(input *timestreamwrite.WriteRecordsInput) (*timestreamwrite.WriteRecordsOutput, error) { - args := m.Called(input) - return args.Get(0).(*timestreamwrite.WriteRecordsOutput), args.Error(1) +func newMockPaginator(timestreamQuery *timestreamquery.Client, queryInput *timestreamquery.QueryInput) *mockPaginator { + return &mockPaginator{} } -type mockTimestreamQueryClient struct { +func (m *mockPaginator) HasMorePages() bool { + args := m.Called() + if result := args.Get(0); result != nil { + return result.(bool) + } + return false +} + +func (m *mockPaginator) NextPage(ctx context.Context) (*timestreamquery.QueryOutput, error) { + args := m.Called(ctx) + if result := args.Get(0); result != nil { + return result.(*timestreamquery.QueryOutput), args.Error(1) + } + return nil, args.Error(1) +} + +type mockTimestreamWriteClient struct { mock.Mock - timestreamqueryiface.TimestreamQueryAPI } -func (m *mockTimestreamQueryClient) QueryPages(input *timestreamquery.QueryInput, f func(page *timestreamquery.QueryOutput, lastPage bool) bool) error { - args := m.Called(input, f) - return args.Error(0) +func (m *mockTimestreamWriteClient) WriteRecords( + ctx context.Context, + input *timestreamwrite.WriteRecordsInput, + optFns ...func(*timestreamwrite.Options), +) (*timestreamwrite.WriteRecordsOutput, error) { + args := m.Called(ctx, input, optFns) + if result := args.Get(0); result != nil { + return result.(*timestreamwrite.WriteRecordsOutput), args.Error(1) + } + return nil, args.Error(1) } -func TestClientNewClient(t *testing.T) { +type mockTimestreamQueryClient struct { + mock.Mock + *timestreamquery.Client +} + +func TestClientNewWriteClient(t *testing.T) { client := NewBaseClient(mockDatabaseName, mockTableName) - client.NewWriteClient(mockLogger, &aws.Config{Region: aws.String(mockRegion)}, true, true) + client.NewWriteClient(mockLogger, aws.Config{Region: mockRegion}, true, true) assert.NotNil(t, client.writeClient) assert.Equal(t, mockLogger, client.writeClient.logger) @@ -108,14 +138,8 @@ func TestClientNewClient(t *testing.T) { } func TestClientNewQueryClient(t *testing.T) { - // Mock the instantiation of query client newClients does not create a real query client. - queryInput := ×treamquery.QueryInput{QueryString: aws.String("SELECT 1")} - mockTimestreamQueryClient := new(mockTimestreamQueryClient) - mockTimestreamQueryClient.On("QueryPages", queryInput, - mock.AnythingOfType(functionType)).Return(nil) - client := NewBaseClient(mockDatabaseName, mockTableName) - client.NewQueryClient(mockLogger, &aws.Config{Region: aws.String(mockRegion)}) + client.NewQueryClient(mockLogger, aws.Config{Region: mockRegion}) assert.NotNil(t, client.queryClient) assert.Equal(t, mockLogger, client.queryClient.logger) @@ -148,7 +172,7 @@ func TestQueryClientRead(t *testing.T) { ColumnInfo: createColumnInfo(), NextToken: aws.String("nextToken"), QueryId: aws.String("QueryID"), - Rows: []*timestreamquery.Row{ + Rows: []qtypes.Row{ { Data: createDatumWithInstance( true, @@ -174,7 +198,7 @@ func TestQueryClientRead(t *testing.T) { timestamp1), }, { - Data: []*timestreamquery.Datum{ + Data: []qtypes.Datum{ {ScalarValue: aws.String(instance)}, {ScalarValue: aws.String(job)}, {ScalarValue: aws.String(measureValueStr)}, @@ -187,7 +211,7 @@ func TestQueryClientRead(t *testing.T) { queryOutputWithInvalidMeasureValue := ×treamquery.QueryOutput{ ColumnInfo: createColumnInfo(), - Rows: []*timestreamquery.Row{ + Rows: []qtypes.Row{ { Data: createDatumWithInstance( true, @@ -201,7 +225,7 @@ func TestQueryClientRead(t *testing.T) { queryOutputWithInvalidTime := ×treamquery.QueryOutput{ ColumnInfo: createColumnInfo(), - Rows: []*timestreamquery.Row{ + Rows: []qtypes.Row{ { Data: createDatumWithInstance( true, @@ -270,18 +294,31 @@ func TestQueryClientRead(t *testing.T) { } queryInputWithInvalidRegex := ×treamquery.QueryInput{ - QueryString: aws.String(fmt.Sprintf("SELECT * FROM %s.%s WHERE %s = '%s' AND REGEXP_LIKE(job, '%s') AND %s BETWEEN FROM_UNIXTIME(%d) AND FROM_UNIXTIME(%d)", - mockDatabaseName, mockTableName, measureNameColumnName, metricName, invalidRegex, timeColumnName, startUnixInSeconds, endUnixInSeconds)), + QueryString: aws.String(fmt.Sprintf( + "SELECT * FROM %s.%s WHERE %s = '%s' AND REGEXP_LIKE(job, '%s') AND %s BETWEEN FROM_UNIXTIME(%d) AND FROM_UNIXTIME(%d)", + mockDatabaseName, + mockTableName, + measureNameColumnName, + metricName, + invalidRegex, + timeColumnName, + startUnixInSeconds, + endUnixInSeconds, + )), } t.Run("success", func(t *testing.T) { mockTimestreamQueryClient := new(mockTimestreamQueryClient) - mockTimestreamQueryClient.On("QueryPages", queryInput, - mock.AnythingOfType(functionType)).Return(nil) - initQueryClient = func(config *aws.Config) (timestreamqueryiface.TimestreamQueryAPI, error) { - return mockTimestreamQueryClient, nil + initQueryClient = func(config aws.Config) (*timestreamquery.Client, error) { + return mockTimestreamQueryClient.Client, nil } + mockPaginator := newMockPaginator(mockTimestreamQueryClient.Client, queryInput) + mockPaginator.On("HasMorePages").Return(false, nil) + mockPaginator.On("NextPage", mock.Anything).Return(nil, nil) + initPaginatorFactory = func(timestreamQuery *timestreamquery.Client, queryInput *timestreamquery.QueryInput) Paginator { + return mockPaginator + } c := &Client{ writeClient: nil, defaultDataBase: mockDatabaseName, @@ -289,7 +326,7 @@ func TestQueryClientRead(t *testing.T) { } c.queryClient = createNewQueryClientTemplate(c) - readResponse, err := c.queryClient.Read(request, mockCredentials) + readResponse, err := c.queryClient.Read(context.Background(), request, mockCredentials) assert.Nil(t, err) assert.Equal(t, response, readResponse) @@ -298,10 +335,14 @@ func TestQueryClientRead(t *testing.T) { t.Run("success without mapping", func(t *testing.T) { mockTimestreamQueryClient := new(mockTimestreamQueryClient) - mockTimestreamQueryClient.On("QueryPages", queryInput, - mock.AnythingOfType(functionType)).Return(nil) - initQueryClient = func(config *aws.Config) (timestreamqueryiface.TimestreamQueryAPI, error) { - return mockTimestreamQueryClient, nil + initQueryClient = func(config aws.Config) (*timestreamquery.Client, error) { + return mockTimestreamQueryClient.Client, nil + } + mockPaginator := newMockPaginator(mockTimestreamQueryClient.Client, queryInput) + mockPaginator.On("HasMorePages").Return(false, nil) + mockPaginator.On("NextPage", mock.Anything).Return(nil, nil) + initPaginatorFactory = func(timestreamQuery *timestreamquery.Client, queryInput *timestreamquery.QueryInput) Paginator { + return mockPaginator } c := &Client{ @@ -311,7 +352,7 @@ func TestQueryClientRead(t *testing.T) { } c.queryClient = createNewQueryClientTemplate(c) - readResponse, err := c.queryClient.Read(request, mockCredentials) + readResponse, err := c.queryClient.Read(context.Background(), request, mockCredentials) assert.Nil(t, err) assert.Equal(t, response, readResponse) @@ -319,42 +360,62 @@ func TestQueryClientRead(t *testing.T) { }) t.Run("error from buildCommands with missing database name in request", func(t *testing.T) { - initQueryClient = func(config *aws.Config) (timestreamqueryiface.TimestreamQueryAPI, error) { - return new(mockTimestreamQueryClient), nil + mockTimestreamQueryClient := new(mockTimestreamQueryClient) + initQueryClient = func(config aws.Config) (*timestreamquery.Client, error) { + return mockTimestreamQueryClient.Client, nil + } + + mockPaginator := newMockPaginator(mockTimestreamQueryClient.Client, queryInput) + mockPaginator.On("HasMorePages").Return(false, nil) + mockPaginator.On("NextPage", mock.Anything).Return(nil, nil) + initPaginatorFactory = func(timestreamQuery *timestreamquery.Client, queryInput *timestreamquery.QueryInput) Paginator { + return mockPaginator } c := &Client{ - writeClient: nil, + writeClient: nil, } c.queryClient = createNewQueryClientTemplate(c) - _, err := c.queryClient.Read(request, mockCredentials) + _, err := c.queryClient.Read(context.Background(), request, mockCredentials) assert.IsType(t, &errors.MissingDatabaseError{}, err) }) - t.Run("error from buildCommands with missing table name in request", func(t *testing.T) { - initQueryClient = func(config *aws.Config) (timestreamqueryiface.TimestreamQueryAPI, error) { - return new(mockTimestreamQueryClient), nil + t.Run("success from NextPage() using data helpers", func(t *testing.T) { + mockTimestreamQueryClient := new(mockTimestreamQueryClient) + initQueryClient = func(config aws.Config) (*timestreamquery.Client, error) { + return mockTimestreamQueryClient.Client, nil + } + + mockPaginator := newMockPaginator(mockTimestreamQueryClient.Client, queryInput) + mockPaginator.On("HasMorePages").Return(false, nil) + mockPaginator.On("NextPage", mock.Anything).Return(nil, nil) + initPaginatorFactory = func(timestreamQuery *timestreamquery.Client, queryInput *timestreamquery.QueryInput) Paginator { + return mockPaginator } c := &Client{ - writeClient: nil, + writeClient: nil, defaultDataBase: mockDatabaseName, + defaultTable: mockTableName, } c.queryClient = createNewQueryClientTemplate(c) - _, err := c.queryClient.Read(request, mockCredentials) - assert.IsType(t, &errors.MissingTableError{}, err) + readResponse, err := c.queryClient.Read(context.Background(), request, mockCredentials) + + assert.NoError(t, err) + assert.NotNil(t, readResponse) + mockTimestreamQueryClient.AssertExpectations(t) }) - t.Run("error from QueryPages()", func(t *testing.T) { - mockTimestreamQueryClient := new(mockTimestreamQueryClient) - serverError := ×treamquery.InternalServerException{} - mockTimestreamQueryClient.On("QueryPages", queryInput, - mock.AnythingOfType(functionType)).Return(serverError) + t.Run("error from NextPage()", func(t *testing.T) { + serverError := &qtypes.InternalServerException{Message: aws.String("Server error")} - initQueryClient = func(config *aws.Config) (timestreamqueryiface.TimestreamQueryAPI, error) { - return mockTimestreamQueryClient, nil + mockPaginator := new(mockPaginator) + mockPaginator.On("HasMorePages").Return(true, nil) + mockPaginator.On("NextPage", mock.Anything).Return(nil, serverError) + initPaginatorFactory = func(timestreamQuery *timestreamquery.Client, queryInput *timestreamquery.QueryInput) Paginator { + return mockPaginator } c := &Client{ @@ -364,9 +425,37 @@ func TestQueryClientRead(t *testing.T) { } c.queryClient = createNewQueryClientTemplate(c) - _, err := c.queryClient.Read(request, mockCredentials) + _, err := c.queryClient.Read(context.Background(), request, mockCredentials) assert.Equal(t, serverError, err) + mockPaginator.AssertExpectations(t) + }) + + t.Run("error from NextPage() with invalid regex", func(t *testing.T) { + validationError := &wtypes.ValidationException{Message: aws.String("Validation error occurred")} + mockTimestreamQueryClient := new(mockTimestreamQueryClient) + + mockPaginator := newMockPaginator(mockTimestreamQueryClient.Client, queryInputWithInvalidRegex) + mockPaginator.On("HasMorePages").Return(true, nil) + mockPaginator.On("NextPage", mock.Anything).Return(nil, validationError) + initPaginatorFactory = func(timestreamQuery *timestreamquery.Client, queryInput *timestreamquery.QueryInput) Paginator { + return mockPaginator + } + + initQueryClient = func(config aws.Config) (*timestreamquery.Client, error) { + return mockTimestreamQueryClient.Client, nil + } + + c := &Client{ + writeClient: nil, + defaultDataBase: mockDatabaseName, + defaultTable: mockTableName, + } + c.queryClient = createNewQueryClientTemplate(c) + + _, err := c.queryClient.Read(context.Background(), requestWithInvalidRegex, mockCredentials) + assert.Equal(t, validationError, err) + mockTimestreamQueryClient.AssertExpectations(t) }) @@ -439,6 +528,13 @@ func TestQueryClientRead(t *testing.T) { }) t.Run("error from buildCommand with unknown matcher type", func(t *testing.T) { + mockPaginator := new(mockPaginator) + mockPaginator.On("HasMorePages").Return(false, nil) + mockPaginator.On("NextPage", mock.Anything).Return(nil, nil) + initPaginatorFactory = func(timestreamQuery *timestreamquery.Client, queryInput *timestreamquery.QueryInput) Paginator { + return mockPaginator + } + c := &Client{ writeClient: nil, defaultDataBase: mockDatabaseName, @@ -446,51 +542,53 @@ func TestQueryClientRead(t *testing.T) { } c.queryClient = createNewQueryClientTemplate(c) - _, err := c.queryClient.Read(requestWithInvalidMatcher, mockCredentials) + _, err := c.queryClient.Read(context.Background(), requestWithInvalidMatcher, mockCredentials) assert.IsType(t, &errors.UnknownMatcherError{}, err) }) - t.Run("error from queryPages with invalid regex", func(t *testing.T) { - validationError := ×treamquery.ValidationException{ - RespMetadata: protocol.ResponseMetadata{StatusCode: 400}, - } + t.Run("error from buildCommands with missing table name in request", func(t *testing.T) { mockTimestreamQueryClient := new(mockTimestreamQueryClient) - mockTimestreamQueryClient.On("QueryPages", queryInputWithInvalidRegex, - mock.AnythingOfType(functionType)).Return(validationError) + initQueryClient = func(config aws.Config) (*timestreamquery.Client, error) { + return mockTimestreamQueryClient.Client, nil + } - initQueryClient = func(config *aws.Config) (timestreamqueryiface.TimestreamQueryAPI, error) { - return mockTimestreamQueryClient, nil + mockPaginator := newMockPaginator(mockTimestreamQueryClient.Client, queryInput) + mockPaginator.On("HasMorePages").Return(false, nil) + mockPaginator.On("NextPage", mock.Anything).Return(nil, nil) + initPaginatorFactory = func(timestreamQuery *timestreamquery.Client, queryInput *timestreamquery.QueryInput) Paginator { + return mockPaginator } c := &Client{ writeClient: nil, defaultDataBase: mockDatabaseName, - defaultTable: mockTableName, } c.queryClient = createNewQueryClientTemplate(c) - _, err := c.queryClient.Read(requestWithInvalidRegex, mockCredentials) - assert.Equal(t, validationError, err) - - mockTimestreamQueryClient.AssertExpectations(t) + _, err := c.queryClient.Read(context.Background(), request, mockCredentials) + assert.IsType(t, &errors.MissingTableError{}, err) }) } func TestWriteClientWrite(t *testing.T) { t.Run("success", func(t *testing.T) { + mockTimestreamWriteClient := new(mockTimestreamWriteClient) expectedInput := createNewWriteRecordsInputTemplate() + mockTimestreamWriteClient.On( "WriteRecords", + mock.Anything, mock.MatchedBy(func(writeInput *timestreamwrite.WriteRecordsInput) bool { // Sort the records in the WriteRecordsInput by their time, and sort the Dimension by dimension names. sortRecords(writeInput) sortRecords(expectedInput) - return reflect.DeepEqual(writeInput, expectedInput) - })).Return(×treamwrite.WriteRecordsOutput{}, nil) + }), + mock.Anything, + ).Return(×treamwrite.WriteRecordsOutput{}, nil) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } @@ -501,9 +599,10 @@ func TestWriteClientWrite(t *testing.T) { } c.writeClient = createNewWriteClientTemplate(c) - err := c.writeClient.Write(createNewRequestTemplate(), mockCredentials) + err := c.writeClient.Write(context.Background(), createNewRequestTemplate(), mockCredentials) assert.Nil(t, err) + mockTimestreamWriteClient.AssertCalled(t, "WriteRecords", mock.Anything, expectedInput, mock.Anything) mockTimestreamWriteClient.AssertExpectations(t) }) @@ -515,15 +614,16 @@ func TestWriteClientWrite(t *testing.T) { mockTimestreamWriteClient.On( "WriteRecords", + mock.Anything, mock.MatchedBy(func(writeInput *timestreamwrite.WriteRecordsInput) bool { - // Sort the records in the WriteRecordsInput by their time, and sort the Dimension by dimension names. sortRecords(writeInput) sortRecords(expectedInput) - return reflect.DeepEqual(writeInput, expectedInput) - })).Return(×treamwrite.WriteRecordsOutput{}, nil) + }), + mock.Anything, + ).Return(×treamwrite.WriteRecordsOutput{}, nil) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } @@ -540,7 +640,7 @@ func TestWriteClientWrite(t *testing.T) { Value: measureValue, }) - err := c.writeClient.Write(req, mockCredentials) + err := c.writeClient.Write(context.Background(), req, mockCredentials) assert.Nil(t, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 1) @@ -555,18 +655,18 @@ func TestWriteClientWrite(t *testing.T) { mockTimestreamWriteClient.On( "WriteRecords", + mock.Anything, mock.MatchedBy(func(writeInput *timestreamwrite.WriteRecordsInput) bool { - // Sort the records in the WriteRecordsInput by their time, and sort the Dimension by dimension names. sortRecords(writeInput) sortRecords(expectedInput) - return reflect.DeepEqual(writeInput, expectedInput) - })).Return(×treamwrite.WriteRecordsOutput{}, nil) + }), + mock.Anything, + ).Return(×treamwrite.WriteRecordsOutput{}, nil) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } - c := &Client{ queryClient: nil, defaultDataBase: mockDatabaseName, @@ -580,7 +680,7 @@ func TestWriteClientWrite(t *testing.T) { Value: measureValue, }) - errWm := c.writeClient.Write(reqWithoutMapping, mockCredentials) + errWm := c.writeClient.Write(context.Background(), reqWithoutMapping, mockCredentials) assert.Nil(t, errWm) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 1) @@ -595,13 +695,17 @@ func TestWriteClientWrite(t *testing.T) { mockTimestreamWriteClient.On( "WriteRecords", + mock.Anything, mock.MatchedBy(func(writeInput *timestreamwrite.WriteRecordsInput) bool { sortRecords(writeInput) sortRecords(expectedInput) + return reflect.DeepEqual(writeInput, expectedInput) - })).Return(×treamwrite.WriteRecordsOutput{}, nil) + }), + mock.Anything, + ).Return(×treamwrite.WriteRecordsOutput{}, nil) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } @@ -614,7 +718,7 @@ func TestWriteClientWrite(t *testing.T) { req := createNewRequestTemplate() req.Timeseries = append(req.Timeseries, createTimeSeriesTemplate()) - err := c.writeClient.Write(req, mockCredentials) + err := c.writeClient.Write(context.Background(), req, mockCredentials) assert.Nil(t, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 1) @@ -629,24 +733,28 @@ func TestWriteClientWrite(t *testing.T) { mockTimestreamWriteClient.On( "WriteRecords", + mock.Anything, mock.MatchedBy(func(writeInput *timestreamwrite.WriteRecordsInput) bool { sortRecords(writeInput) sortRecords(expectedInput) + return reflect.DeepEqual(writeInput, expectedInput) - })).Return(×treamwrite.WriteRecordsOutput{}, nil) + }), + mock.Anything, + ).Return(×treamwrite.WriteRecordsOutput{}, nil) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } c := &Client{ - queryClient: nil, + queryClient: nil, } c.writeClient = createNewWriteClientTemplate(c) req := createNewRequestTemplate() req.Timeseries = append(req.Timeseries, createTimeSeriesTemplate()) - err := c.writeClient.Write(req, mockCredentials) + err := c.writeClient.Write(context.Background(), req, mockCredentials) expectedErr := errors.NewMissingDatabaseWithWriteError("", createTimeSeriesTemplate()) assert.Equal(t, err, expectedErr) }) @@ -659,25 +767,29 @@ func TestWriteClientWrite(t *testing.T) { mockTimestreamWriteClient.On( "WriteRecords", + mock.Anything, mock.MatchedBy(func(writeInput *timestreamwrite.WriteRecordsInput) bool { sortRecords(writeInput) sortRecords(expectedInput) + return reflect.DeepEqual(writeInput, expectedInput) - })).Return(×treamwrite.WriteRecordsOutput{}, nil) + }), + mock.Anything, + ).Return(×treamwrite.WriteRecordsOutput{}, nil) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } c := &Client{ - queryClient: nil, + queryClient: nil, defaultDataBase: mockDatabaseName, } c.writeClient = createNewWriteClientTemplate(c) req := createNewRequestTemplate() req.Timeseries = append(req.Timeseries, createTimeSeriesTemplate()) - err := c.writeClient.Write(req, mockCredentials) + err := c.writeClient.Write(context.Background(), req, mockCredentials) expectedErr := errors.NewMissingTableWithWriteError("", createTimeSeriesTemplate()) assert.Equal(t, err, expectedErr) }) @@ -685,12 +797,12 @@ func TestWriteClientWrite(t *testing.T) { t.Run("error from convertToRecords due to missing ingestion database destination", func(t *testing.T) { mockTimestreamWriteClient := new(mockTimestreamWriteClient) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } c := &Client{ - queryClient: nil, + queryClient: nil, } c.writeClient = createNewWriteClientTemplate(c) @@ -706,7 +818,7 @@ func TestWriteClientWrite(t *testing.T) { }, } - err := c.WriteClient().Write(input, mockCredentials) + err := c.WriteClient().Write(context.Background(), input, mockCredentials) assert.IsType(t, &errors.MissingDatabaseWithWriteError{}, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 0) @@ -715,12 +827,12 @@ func TestWriteClientWrite(t *testing.T) { t.Run("error from convertToRecords due to missing ingestion table destination", func(t *testing.T) { mockTimestreamWriteClient := new(mockTimestreamWriteClient) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } c := &Client{ - queryClient: nil, + queryClient: nil, defaultDataBase: mockDatabaseName, } c.writeClient = createNewWriteClientTemplate(c) @@ -737,7 +849,7 @@ func TestWriteClientWrite(t *testing.T) { }, } - err := c.WriteClient().Write(input, mockCredentials) + err := c.WriteClient().Write(context.Background(), input, mockCredentials) assert.IsType(t, &errors.MissingTableWithWriteError{}, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 0) @@ -746,19 +858,22 @@ func TestWriteClientWrite(t *testing.T) { t.Run("error from WriteRecords()", func(t *testing.T) { mockTimestreamWriteClient := new(mockTimestreamWriteClient) expectedInput := createNewWriteRecordsInputTemplate() - requestError := ×treamwrite.ValidationException{ - RespMetadata: protocol.ResponseMetadata{StatusCode: 404}, + requestError := &wtypes.ValidationException{ + Message: aws.String("Validation error occurred"), } mockTimestreamWriteClient.On( "WriteRecords", + mock.Anything, mock.MatchedBy(func(writeInput *timestreamwrite.WriteRecordsInput) bool { sortRecords(writeInput) sortRecords(expectedInput) return reflect.DeepEqual(writeInput, expectedInput) - })).Return(×treamwrite.WriteRecordsOutput{}, requestError) + }), + mock.Anything, + ).Return(×treamwrite.WriteRecordsOutput{}, requestError) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } @@ -769,7 +884,7 @@ func TestWriteClientWrite(t *testing.T) { } c.writeClient = createNewWriteClientTemplate(c) - err := c.WriteClient().Write(createNewRequestTemplate(), mockCredentials) + err := c.WriteClient().Write(context.Background(), createNewRequestTemplate(), mockCredentials) assert.Equal(t, requestError, err) mockTimestreamWriteClient.AssertExpectations(t) @@ -780,13 +895,16 @@ func TestWriteClientWrite(t *testing.T) { expectedInput := createNewWriteRecordsInputTemplate() mockTimestreamWriteClient.On( "WriteRecords", + mock.Anything, mock.MatchedBy(func(writeInput *timestreamwrite.WriteRecordsInput) bool { sortRecords(writeInput) sortRecords(expectedInput) return reflect.DeepEqual(writeInput, expectedInput) - })).Return(×treamwrite.WriteRecordsOutput{}, nil) + }), + mock.Anything, + ).Return(×treamwrite.WriteRecordsOutput{}, nil) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } @@ -799,7 +917,7 @@ func TestWriteClientWrite(t *testing.T) { c.writeClient.failOnInvalidSample = true req := createNewRequestTemplate() - err := c.WriteClient().Write(req, mockCredentials) + err := c.WriteClient().Write(context.Background(), req, mockCredentials) assert.Nil(t, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 1) @@ -808,7 +926,7 @@ func TestWriteClientWrite(t *testing.T) { t.Run("NaN timeSeries with fail-fast enabled", func(t *testing.T) { mockTimestreamWriteClient := new(mockTimestreamWriteClient) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } @@ -822,7 +940,7 @@ func TestWriteClientWrite(t *testing.T) { req := createNewRequestTemplate() req.Timeseries[0].Samples[0].Value = math.NaN() - err := c.WriteClient().Write(req, mockCredentials) + err := c.WriteClient().Write(context.Background(), req, mockCredentials) assert.IsType(t, &errors.InvalidSampleValueError{}, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 0) @@ -831,7 +949,7 @@ func TestWriteClientWrite(t *testing.T) { t.Run("NaN timeSeries with fail-fast disabled", func(t *testing.T) { mockTimestreamWriteClient := new(mockTimestreamWriteClient) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } @@ -845,7 +963,7 @@ func TestWriteClientWrite(t *testing.T) { req := createNewRequestTemplate() req.Timeseries[0].Samples[0].Value = math.NaN() - err := c.WriteClient().Write(req, mockCredentials) + err := c.WriteClient().Write(context.Background(), req, mockCredentials) assert.Nil(t, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 0) @@ -854,9 +972,10 @@ func TestWriteClientWrite(t *testing.T) { t.Run("Inf timeSeries with fail-fast enabled", func(t *testing.T) { mockTimestreamWriteClient := new(mockTimestreamWriteClient) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } + ctx := context.Background() c := &Client{ queryClient: nil, @@ -868,11 +987,11 @@ func TestWriteClientWrite(t *testing.T) { req := createNewRequestTemplate() req.Timeseries[0].Samples[0].Value = math.Inf(1) - err := c.WriteClient().Write(req, mockCredentials) + err := c.WriteClient().Write(ctx, req, mockCredentials) assert.NotNil(t, err) req.Timeseries[0].Samples[0].Value = math.Inf(-1) - err = c.WriteClient().Write(req, mockCredentials) + err = c.WriteClient().Write(ctx, req, mockCredentials) assert.IsType(t, &errors.InvalidSampleValueError{}, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 0) @@ -881,9 +1000,10 @@ func TestWriteClientWrite(t *testing.T) { t.Run("Inf timeSeries with fail-fast disabled", func(t *testing.T) { mockTimestreamWriteClient := new(mockTimestreamWriteClient) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } + ctx := context.Background() c := &Client{ queryClient: nil, @@ -895,11 +1015,11 @@ func TestWriteClientWrite(t *testing.T) { req := createNewRequestTemplate() req.Timeseries[0].Samples[0].Value = math.Inf(1) - err := c.WriteClient().Write(req, mockCredentials) + err := c.WriteClient().Write(ctx, req, mockCredentials) assert.Nil(t, err) req.Timeseries[0].Samples[0].Value = math.Inf(-1) - err = c.WriteClient().Write(req, mockCredentials) + err = c.WriteClient().Write(ctx, req, mockCredentials) assert.Nil(t, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 0) @@ -908,7 +1028,7 @@ func TestWriteClientWrite(t *testing.T) { t.Run("long metric name with fail-fast enabled", func(t *testing.T) { mockTimestreamWriteClient := new(mockTimestreamWriteClient) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } @@ -922,7 +1042,7 @@ func TestWriteClientWrite(t *testing.T) { req := createNewRequestTemplate() req.Timeseries[0].Labels[0].Value = mockLongMetric - err := c.WriteClient().Write(req, mockCredentials) + err := c.WriteClient().Write(context.Background(), req, mockCredentials) assert.IsType(t, &errors.LongLabelNameError{}, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 0) @@ -931,7 +1051,7 @@ func TestWriteClientWrite(t *testing.T) { t.Run("long metric name with fail-fast disabled", func(t *testing.T) { mockTimestreamWriteClient := new(mockTimestreamWriteClient) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } @@ -945,7 +1065,7 @@ func TestWriteClientWrite(t *testing.T) { req := createNewRequestTemplate() req.Timeseries[0].Labels[0].Value = mockLongMetric - err := c.WriteClient().Write(req, mockCredentials) + err := c.WriteClient().Write(context.Background(), req, mockCredentials) assert.Nil(t, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 0) @@ -954,7 +1074,7 @@ func TestWriteClientWrite(t *testing.T) { t.Run("long dimension name with fail-fast enabled", func(t *testing.T) { mockTimestreamWriteClient := new(mockTimestreamWriteClient) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } @@ -968,7 +1088,7 @@ func TestWriteClientWrite(t *testing.T) { req := createNewRequestTemplate() req.Timeseries[0].Labels[1].Name = mockLongMetric - err := c.WriteClient().Write(req, mockCredentials) + err := c.WriteClient().Write(context.Background(), req, mockCredentials) assert.IsType(t, &errors.LongLabelNameError{}, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 0) @@ -977,7 +1097,7 @@ func TestWriteClientWrite(t *testing.T) { t.Run("long dimension name with fail-fast disabled", func(t *testing.T) { mockTimestreamWriteClient := new(mockTimestreamWriteClient) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } @@ -991,7 +1111,7 @@ func TestWriteClientWrite(t *testing.T) { req := createNewRequestTemplate() req.Timeseries[0].Labels[1].Name = mockLongMetric - err := c.WriteClient().Write(req, mockCredentials) + err := c.WriteClient().Write(context.Background(), req, mockCredentials) assert.Nil(t, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 0) @@ -1000,9 +1120,15 @@ func TestWriteClientWrite(t *testing.T) { t.Run("unknown SDK error", func(t *testing.T) { mockTimestreamWriteClient := new(mockTimestreamWriteClient) unknownSDKErr := errors.NewSDKNonRequestError(goErrors.New("")) - mockTimestreamWriteClient.On("WriteRecords", createNewWriteRecordsInputTemplate()).Return(×treamwrite.WriteRecordsOutput{}, unknownSDKErr) + mockTimestreamWriteClient.On( + "WriteRecords", + mock.Anything, + createNewWriteRecordsInputTemplate(), + mock.Anything, + ).Return(×treamwrite.WriteRecordsOutput{}, + unknownSDKErr) - initWriteClient = func(config *aws.Config) (timestreamwriteiface.TimestreamWriteAPI, error) { + initWriteClient = func(config aws.Config) (TimestreamWriteClient, error) { return mockTimestreamWriteClient, nil } @@ -1014,7 +1140,7 @@ func TestWriteClientWrite(t *testing.T) { c.writeClient = createNewWriteClientTemplate(c) req := createNewRequestTemplate() - err := c.WriteClient().Write(req, mockCredentials) + err := c.WriteClient().Write(context.Background(), req, mockCredentials) assert.Equal(t, unknownSDKErr, err) mockTimestreamWriteClient.AssertNumberOfCalls(t, "WriteRecords", 1) @@ -1071,18 +1197,18 @@ func createNewRequestTemplateWithoutMapping() *prompb.WriteRequest { } // createNewRecordTemplate creates a template of timestreamwrite.Record pointer for unit tests. -func createNewRecordTemplate() *timestreamwrite.Record { - return ×treamwrite.Record{ - Dimensions: []*timestreamwrite.Dimension{ - &(timestreamwrite.Dimension{ +func createNewRecordTemplate() wtypes.Record { + return wtypes.Record{ + Dimensions: []wtypes.Dimension{ + (wtypes.Dimension{ Name: aws.String("label_1"), Value: aws.String("value_1")}), }, MeasureName: aws.String(metricName), MeasureValue: aws.String(measureValueStr), - MeasureValueType: aws.String(timestreamquery.ScalarTypeDouble), + MeasureValueType: wtypes.MeasureValueTypeDouble, Time: aws.String(strconv.FormatInt(mockUnixTime, 10)), - TimeUnit: aws.String(timestreamwrite.TimeUnitMilliseconds), + TimeUnit: wtypes.TimeUnitMilliseconds, } } @@ -1091,7 +1217,7 @@ func createNewWriteRecordsInputTemplate() *timestreamwrite.WriteRecordsInput { input := ×treamwrite.WriteRecordsInput{ DatabaseName: aws.String(mockDatabaseName), TableName: aws.String(mockTableName), - Records: []*timestreamwrite.Record{createNewRecordTemplate()}, + Records: []wtypes.Record{createNewRecordTemplate()}, } return input } @@ -1121,44 +1247,44 @@ func createNewQueryClientTemplate(c *Client) *QueryClient { } // createColumnInfo creates a Timestream ColumnInfo for constructing QueryOutput. -func createColumnInfo() []*timestreamquery.ColumnInfo { - return []*timestreamquery.ColumnInfo{ +func createColumnInfo() []qtypes.ColumnInfo { + return []qtypes.ColumnInfo{ { Name: aws.String(model.InstanceLabel), - Type: ×treamquery.Type{ - ScalarType: aws.String(timestreamquery.ScalarTypeVarchar), + Type: &qtypes.Type{ + ScalarType: qtypes.ScalarTypeVarchar, }, }, { Name: aws.String(model.JobLabel), - Type: ×treamquery.Type{ - ScalarType: aws.String(timestreamquery.ScalarTypeVarchar), + Type: &qtypes.Type{ + ScalarType: qtypes.ScalarTypeVarchar, }, }, { Name: aws.String(measureValueColumnName), - Type: ×treamquery.Type{ - ScalarType: aws.String(timestreamquery.ScalarTypeDouble), + Type: &qtypes.Type{ + ScalarType: qtypes.ScalarTypeDouble, }, }, { Name: aws.String(measureNameColumnName), - Type: ×treamquery.Type{ - ScalarType: aws.String(timestreamquery.ScalarTypeVarchar), + Type: &qtypes.Type{ + ScalarType: qtypes.ScalarTypeVarchar, }, }, { Name: aws.String(timeColumnName), - Type: ×treamquery.Type{ - ScalarType: aws.String(timestreamquery.ScalarTypeTimestamp), + Type: &qtypes.Type{ + ScalarType: qtypes.ScalarTypeTimestamp, }, }, } } // createDatumWithInstance creates a Timestream Datum object with instance. -func createDatumWithInstance(isNullValue bool, instance string, measureValue string, measureName string, time string) []*timestreamquery.Datum { - return []*timestreamquery.Datum{ +func createDatumWithInstance(isNullValue bool, instance string, measureValue string, measureName string, time string) []qtypes.Datum { + return []qtypes.Datum{ {ScalarValue: aws.String(instance)}, {NullValue: aws.Bool(isNullValue)}, {ScalarValue: aws.String(measureValue)}, @@ -1168,8 +1294,8 @@ func createDatumWithInstance(isNullValue bool, instance string, measureValue str } // createDatumWithJob creates a Timestream Datum object with job. -func createDatumWithJob(isNullValue bool, job string, measureValue string, measureName string, time string) []*timestreamquery.Datum { - return []*timestreamquery.Datum{ +func createDatumWithJob(isNullValue bool, job string, measureValue string, measureName string, time string) []qtypes.Datum { + return []qtypes.Datum{ {NullValue: aws.Bool(isNullValue)}, {ScalarValue: aws.String(job)}, {ScalarValue: aws.String(measureValue)},