From bd8ca70b1deb18ed593a6353e48c7db88822161f Mon Sep 17 00:00:00 2001 From: alespour <42931850+alespour@users.noreply.github.com> Date: Fri, 12 Apr 2024 15:25:55 +0200 Subject: [PATCH] feat: add support for custom query/gRPC headers (#76) * fix: make sure Authorization header is set * feat: add WithHeader query option * feat: add custom headers to query/gRPC request * fix: database header not needed in query header * fix: optimized query metadata preparation * fix: remove unused methods * refactor: flight.Client is an interface * docs: update CHANGELOG --- CHANGELOG.md | 1 + influxdb3/client.go | 6 +-- influxdb3/example_query_test.go | 84 +++++++++++++++++++++++++++++ influxdb3/options.go | 16 ++++++ influxdb3/options_test.go | 20 +++++++ influxdb3/query.go | 23 ++++++-- influxdb3/query_test.go | 95 +++++++++++++++++++++++++++++++++ 7 files changed, 238 insertions(+), 7 deletions(-) create mode 100644 influxdb3/example_query_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 206c47a..0ac15a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ 1. [#68](https://github.com/InfluxCommunity/influxdb3-go/pull/68): Upgrade Go version to 1.22. 2. [#74](https://github.com/InfluxCommunity/influxdb3-go/pull/74): Use `log/slog` to print debug information instead of `fmt.Printf` +3. [#76](https://github.com/InfluxCommunity/influxdb3-go/pull/76): Add custom headers support for queries (gRPC requests) ## 0.6.0 [2024-03-01] diff --git a/influxdb3/client.go b/influxdb3/client.go index 2902aa0..e71a9e3 100644 --- a/influxdb3/client.go +++ b/influxdb3/client.go @@ -47,7 +47,7 @@ type Client struct { // Cached base server API URL. apiURL *url.URL // Flight client for executing queries - queryClient *flight.Client + queryClient flight.Client } // httpParams holds parameters for creating an HTTP request @@ -174,7 +174,7 @@ func (c *Client) makeAPICall(ctx context.Context, params httpParams) (*http.Resp } req.Header.Set("User-Agent", userAgent) if c.authorization != "" { - req.Header.Add("Authorization", c.authorization) + req.Header.Set("Authorization", c.authorization) } resp, err := c.config.HTTPClient.Do(req) @@ -238,6 +238,6 @@ func (c *Client) resolveHTTPError(r *http.Response) error { // Close closes all idle connections. func (c *Client) Close() error { c.config.HTTPClient.CloseIdleConnections() - err := (*c.queryClient).Close() + err := c.queryClient.Close() return err } diff --git a/influxdb3/example_query_test.go b/influxdb3/example_query_test.go new file mode 100644 index 0000000..14b3ebe --- /dev/null +++ b/influxdb3/example_query_test.go @@ -0,0 +1,84 @@ +/* + The MIT License + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. +*/ + +package influxdb3 + +import ( + "context" + "log" +) + +func ExampleClient_Query() { + client, err := NewFromEnv() + if err != nil { + log.Fatal(err) + } + defer client.Close() + + // query + iterator, err := client.Query(context.Background(), + "SELECT count(*) FROM weather WHERE time >= now() - interval '5 minutes'") + + for iterator.Next() { + // process the result + } + + // query with custom header + iterator, err = client.Query(context.Background(), + "SELECT count(*) FROM stat WHERE time >= now() - interval '5 minutes'", + WithHeader("X-trace-ID", "#0122")) + + for iterator.Next() { + // process the result + } +} + +func ExampleClient_QueryWithParameters() { + client, err := NewFromEnv() + if err != nil { + log.Fatal(err) + } + defer client.Close() + + // query + iterator, err := client.QueryWithParameters(context.Background(), + "SELECT count(*) FROM weather WHERE location = $location AND time >= now() - interval '5 minutes'", + QueryParameters{ + "location": "sun-valley-1", + }) + + for iterator.Next() { + // process the result + } + + // query with custom header + iterator, err = client.QueryWithParameters(context.Background(), + "SELECT count(*) FROM weather WHERE location = $location AND time >= now() - interval '5 minutes'", + QueryParameters{ + "location": "sun-valley-1", + }, + WithHeader("X-trace-ID", "#0122")) + + for iterator.Next() { + // process the result + } +} diff --git a/influxdb3/options.go b/influxdb3/options.go index 716c440..2149d38 100644 --- a/influxdb3/options.go +++ b/influxdb3/options.go @@ -23,6 +23,8 @@ package influxdb3 import ( + "net/http" + "github.com/influxdata/line-protocol/v2/lineprotocol" ) @@ -33,6 +35,9 @@ type QueryOptions struct { // Query type. QueryType QueryType + + // Headers to be included in requests. Use to add or override headers in `ClientConfig`. + Headers http.Header } // WriteOptions holds options for write @@ -110,6 +115,7 @@ type Option func(o *options) // Available options: // - WithDatabase // - WithQueryType +// - WithHeader type QueryOption = Option // WriteOption is a functional option type that can be passed to Client.Write methods. @@ -135,6 +141,16 @@ func WithQueryType(queryType QueryType) Option { } } +// WithHeader is used to add or override default header in Client.Query method. +func WithHeader(key, value string) Option { + return func(o *options) { + if o.Headers == nil { + o.Headers = make(http.Header, 0) + } + o.Headers[key] = []string{value} + } +} + // WithPrecision is used to override default precision in Client.Write methods. func WithPrecision(precision lineprotocol.Precision) Option { return func(o *options) { diff --git a/influxdb3/options_test.go b/influxdb3/options_test.go index c909625..f7b4c43 100644 --- a/influxdb3/options_test.go +++ b/influxdb3/options_test.go @@ -1,6 +1,7 @@ package influxdb3 import ( + "net/http" "testing" "github.com/google/go-cmp/cmp" @@ -40,6 +41,25 @@ func TestQueryOptions(t *testing.T) { QueryType: InfluxQL, }, }, + { + name: "add header", + opts: va(WithHeader("header-a", "value-a")), + want: &QueryOptions{ + Headers: http.Header{ + "header-a": {"value-a"}, + }, + }, + }, + { + name: "add headers", + opts: va(WithHeader("header-a", "value-a"), WithHeader("header-b", "value-b")), + want: &QueryOptions{ + Headers: http.Header{ + "header-a": {"value-a"}, + "header-b": {"value-b"}, + }, + }, + }, } for _, tc := range testCases { diff --git a/influxdb3/query.go b/influxdb3/query.go index 5b7404f..3de98d1 100644 --- a/influxdb3/query.go +++ b/influxdb3/query.go @@ -62,11 +62,15 @@ func (c *Client) initializeQueryClient() error { if err != nil { return fmt.Errorf("flight: %s", err) } - c.queryClient = &client + c.queryClient = client return nil } +func (c *Client) setQueryClient(flightClient flight.Client) { + c.queryClient = flightClient +} + // QueryParameters is a type for query parameters. type QueryParameters = map[string]any @@ -132,8 +136,19 @@ func (c *Client) query(ctx context.Context, query string, parameters QueryParame var queryType QueryType queryType = options.QueryType - ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+c.config.Token) - ctx = metadata.AppendToOutgoingContext(ctx, "database", database) + md := make(metadata.MD, 0) + for k, v := range c.config.Headers { + for _, value := range v { + md.Append(k, value) + } + } + for k, v := range options.Headers { + for _, value := range v { + md.Append(k, value) + } + } + md.Set("authorization", "Bearer "+c.config.Token) + ctx = metadata.NewOutgoingContext(ctx, md) ticketData := map[string]interface{}{ "database": database, @@ -151,7 +166,7 @@ func (c *Client) query(ctx context.Context, query string, parameters QueryParame } ticket := &flight.Ticket{Ticket: ticketJSON} - stream, err := (*c.queryClient).DoGet(ctx, ticket) + stream, err := c.queryClient.DoGet(ctx, ticket) if err != nil { return nil, fmt.Errorf("flight do get: %s", err) } diff --git a/influxdb3/query_test.go b/influxdb3/query_test.go index bbaf016..2d6d07a 100644 --- a/influxdb3/query_test.go +++ b/influxdb3/query_test.go @@ -24,10 +24,19 @@ package influxdb3 import ( "context" + "net/http" "testing" + "github.com/apache/arrow/go/v15/arrow" + "github.com/apache/arrow/go/v15/arrow/array" + "github.com/apache/arrow/go/v15/arrow/flight" + "github.com/apache/arrow/go/v15/arrow/ipc" + "github.com/apache/arrow/go/v15/arrow/memory" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" ) func TestQueryDatabaseNotSet(t *testing.T) { @@ -54,3 +63,89 @@ func TestQueryWithOptionsNotSet(t *testing.T) { assert.Error(t, err) assert.EqualError(t, err, "options not set") } + +func TestQueryWithCustomHeader(t *testing.T) { + s := flight.NewServerWithMiddleware(nil) + err := s.Init("localhost:18080") + require.NoError(t, err) + f := &flightServer{} + s.RegisterFlightService(f) + + go s.Serve() + defer s.Shutdown() + + middleware := &callHeadersMiddleware{} + fc, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, []flight.ClientMiddleware{ + flight.CreateClientMiddleware(middleware), + }, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer fc.Close() + + c, err := New(ClientConfig{ + Host: "http://localhost:80", + Token: "my-token", + Database: "my-database", + Headers: http.Header{ + "my-config-header": {"hdr-config-1"}, + }, + }) + require.NoError(t, err) + defer c.Close() + + c.setQueryClient(fc) + + _, err = c.Query(context.Background(), "SELECT * FROM nothing", WithHeader("my-call-header", "hdr-call-1")) + require.NoError(t, err, "DoGet success") + assert.True(t, middleware.outgoingMDOk, "context contains outgoing MD") + assert.NotNil(t, middleware.outgoingMD, "outgoing MD is not nil") + assert.Contains(t, middleware.outgoingMD, "authorization", "auth header present") + assert.Contains(t, middleware.outgoingMD, "my-config-header", "custom config header present") + assert.Equal(t, []string{"hdr-config-1"}, middleware.outgoingMD["my-config-header"], "custom config header value") + assert.Contains(t, middleware.outgoingMD, "my-call-header", "custom call header present") + assert.Equal(t, []string{"hdr-call-1"}, middleware.outgoingMD["my-call-header"],"custom call header value") +} + +// fake Flight server implementation + +type flightServer struct { + flight.BaseFlightServer +} + +func (f *flightServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error { + schema := arrow.NewSchema([]arrow.Field{ + {Name: "intField", Type: arrow.PrimitiveTypes.Int64, Nullable: false}, + {Name: "stringField", Type: arrow.BinaryTypes.String, Nullable: false}, + {Name: "floatField", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + }, nil) + builder := array.NewRecordBuilder(memory.DefaultAllocator, schema) + defer builder.Release() + builder.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3, 4, 5}, nil) + builder.Field(1).(*array.StringBuilder).AppendValues([]string{"a", "b", "c", "d", "e"}, nil) + builder.Field(2).(*array.Float64Builder).AppendValues([]float64{1, 0, 3, 0, 5}, []bool{true, false, true, false, true}) + rec0 := builder.NewRecord() + defer rec0.Release() + recs := []arrow.Record{rec0} + + w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema())) + for _, r := range recs { + w.Write(r) + } + + return nil +} + +type callHeadersMiddleware struct { + outgoingMDOk bool + outgoingMD metadata.MD +} + +func (c *callHeadersMiddleware) StartCall(ctx context.Context) context.Context { + c.outgoingMD, c.outgoingMDOk = metadata.FromOutgoingContext(ctx) + return ctx +} + +func (c *callHeadersMiddleware) CallCompleted(ctx context.Context, err error) { +} + +func (c *callHeadersMiddleware) HeadersReceived(ctx context.Context, md metadata.MD) { +}