Skip to content

Commit

Permalink
feat: add support for custom query/gRPC headers (#76)
Browse files Browse the repository at this point in the history
* fix: make sure Authorization header is set

* feat: add WithHeader query option

* feat: add custom headers to query/gRPC request

* fix: database header not needed in query header

* fix: optimized query metadata preparation

* fix: remove unused methods

* refactor: flight.Client is an interface

* docs: update CHANGELOG
  • Loading branch information
alespour authored Apr 12, 2024
1 parent 567467e commit bd8ca70
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

1. [#68](https://github.com/InfluxCommunity/influxdb3-go/pull/68): Upgrade Go version to 1.22.
2. [#74](https://github.com/InfluxCommunity/influxdb3-go/pull/74): Use `log/slog` to print debug information instead of `fmt.Printf`
3. [#76](https://github.com/InfluxCommunity/influxdb3-go/pull/76): Add custom headers support for queries (gRPC requests)

## 0.6.0 [2024-03-01]

Expand Down
6 changes: 3 additions & 3 deletions influxdb3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type Client struct {
// Cached base server API URL.
apiURL *url.URL
// Flight client for executing queries
queryClient *flight.Client
queryClient flight.Client
}

// httpParams holds parameters for creating an HTTP request
Expand Down Expand Up @@ -174,7 +174,7 @@ func (c *Client) makeAPICall(ctx context.Context, params httpParams) (*http.Resp
}
req.Header.Set("User-Agent", userAgent)
if c.authorization != "" {
req.Header.Add("Authorization", c.authorization)
req.Header.Set("Authorization", c.authorization)
}

resp, err := c.config.HTTPClient.Do(req)
Expand Down Expand Up @@ -238,6 +238,6 @@ func (c *Client) resolveHTTPError(r *http.Response) error {
// Close closes all idle connections.
func (c *Client) Close() error {
c.config.HTTPClient.CloseIdleConnections()
err := (*c.queryClient).Close()
err := c.queryClient.Close()
return err
}
84 changes: 84 additions & 0 deletions influxdb3/example_query_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
The MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

package influxdb3

import (
"context"
"log"
)

func ExampleClient_Query() {
client, err := NewFromEnv()
if err != nil {
log.Fatal(err)
}
defer client.Close()

// query
iterator, err := client.Query(context.Background(),
"SELECT count(*) FROM weather WHERE time >= now() - interval '5 minutes'")

for iterator.Next() {
// process the result
}

// query with custom header
iterator, err = client.Query(context.Background(),
"SELECT count(*) FROM stat WHERE time >= now() - interval '5 minutes'",
WithHeader("X-trace-ID", "#0122"))

for iterator.Next() {
// process the result
}
}

func ExampleClient_QueryWithParameters() {
client, err := NewFromEnv()
if err != nil {
log.Fatal(err)
}
defer client.Close()

// query
iterator, err := client.QueryWithParameters(context.Background(),
"SELECT count(*) FROM weather WHERE location = $location AND time >= now() - interval '5 minutes'",
QueryParameters{
"location": "sun-valley-1",
})

for iterator.Next() {
// process the result
}

// query with custom header
iterator, err = client.QueryWithParameters(context.Background(),
"SELECT count(*) FROM weather WHERE location = $location AND time >= now() - interval '5 minutes'",
QueryParameters{
"location": "sun-valley-1",
},
WithHeader("X-trace-ID", "#0122"))

for iterator.Next() {
// process the result
}
}
16 changes: 16 additions & 0 deletions influxdb3/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
package influxdb3

import (
"net/http"

"github.com/influxdata/line-protocol/v2/lineprotocol"
)

Expand All @@ -33,6 +35,9 @@ type QueryOptions struct {

// Query type.
QueryType QueryType

// Headers to be included in requests. Use to add or override headers in `ClientConfig`.
Headers http.Header
}

// WriteOptions holds options for write
Expand Down Expand Up @@ -110,6 +115,7 @@ type Option func(o *options)
// Available options:
// - WithDatabase
// - WithQueryType
// - WithHeader
type QueryOption = Option

// WriteOption is a functional option type that can be passed to Client.Write methods.
Expand All @@ -135,6 +141,16 @@ func WithQueryType(queryType QueryType) Option {
}
}

// WithHeader is used to add or override default header in Client.Query method.
func WithHeader(key, value string) Option {
return func(o *options) {
if o.Headers == nil {
o.Headers = make(http.Header, 0)
}
o.Headers[key] = []string{value}
}
}

// WithPrecision is used to override default precision in Client.Write methods.
func WithPrecision(precision lineprotocol.Precision) Option {
return func(o *options) {
Expand Down
20 changes: 20 additions & 0 deletions influxdb3/options_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package influxdb3

import (
"net/http"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -40,6 +41,25 @@ func TestQueryOptions(t *testing.T) {
QueryType: InfluxQL,
},
},
{
name: "add header",
opts: va(WithHeader("header-a", "value-a")),
want: &QueryOptions{
Headers: http.Header{
"header-a": {"value-a"},
},
},
},
{
name: "add headers",
opts: va(WithHeader("header-a", "value-a"), WithHeader("header-b", "value-b")),
want: &QueryOptions{
Headers: http.Header{
"header-a": {"value-a"},
"header-b": {"value-b"},
},
},
},
}

for _, tc := range testCases {
Expand Down
23 changes: 19 additions & 4 deletions influxdb3/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,15 @@ func (c *Client) initializeQueryClient() error {
if err != nil {
return fmt.Errorf("flight: %s", err)
}
c.queryClient = &client
c.queryClient = client

return nil
}

func (c *Client) setQueryClient(flightClient flight.Client) {
c.queryClient = flightClient
}

// QueryParameters is a type for query parameters.
type QueryParameters = map[string]any

Expand Down Expand Up @@ -132,8 +136,19 @@ func (c *Client) query(ctx context.Context, query string, parameters QueryParame
var queryType QueryType
queryType = options.QueryType

ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+c.config.Token)
ctx = metadata.AppendToOutgoingContext(ctx, "database", database)
md := make(metadata.MD, 0)
for k, v := range c.config.Headers {
for _, value := range v {
md.Append(k, value)
}
}
for k, v := range options.Headers {
for _, value := range v {
md.Append(k, value)
}
}
md.Set("authorization", "Bearer "+c.config.Token)
ctx = metadata.NewOutgoingContext(ctx, md)

ticketData := map[string]interface{}{
"database": database,
Expand All @@ -151,7 +166,7 @@ func (c *Client) query(ctx context.Context, query string, parameters QueryParame
}

ticket := &flight.Ticket{Ticket: ticketJSON}
stream, err := (*c.queryClient).DoGet(ctx, ticket)
stream, err := c.queryClient.DoGet(ctx, ticket)
if err != nil {
return nil, fmt.Errorf("flight do get: %s", err)
}
Expand Down
95 changes: 95 additions & 0 deletions influxdb3/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,19 @@ package influxdb3

import (
"context"
"net/http"
"testing"

"github.com/apache/arrow/go/v15/arrow"
"github.com/apache/arrow/go/v15/arrow/array"
"github.com/apache/arrow/go/v15/arrow/flight"
"github.com/apache/arrow/go/v15/arrow/ipc"
"github.com/apache/arrow/go/v15/arrow/memory"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
)

func TestQueryDatabaseNotSet(t *testing.T) {
Expand All @@ -54,3 +63,89 @@ func TestQueryWithOptionsNotSet(t *testing.T) {
assert.Error(t, err)
assert.EqualError(t, err, "options not set")
}

func TestQueryWithCustomHeader(t *testing.T) {
s := flight.NewServerWithMiddleware(nil)
err := s.Init("localhost:18080")
require.NoError(t, err)
f := &flightServer{}
s.RegisterFlightService(f)

go s.Serve()
defer s.Shutdown()

middleware := &callHeadersMiddleware{}
fc, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, []flight.ClientMiddleware{
flight.CreateClientMiddleware(middleware),
}, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer fc.Close()

c, err := New(ClientConfig{
Host: "http://localhost:80",
Token: "my-token",
Database: "my-database",
Headers: http.Header{
"my-config-header": {"hdr-config-1"},
},
})
require.NoError(t, err)
defer c.Close()

c.setQueryClient(fc)

_, err = c.Query(context.Background(), "SELECT * FROM nothing", WithHeader("my-call-header", "hdr-call-1"))
require.NoError(t, err, "DoGet success")
assert.True(t, middleware.outgoingMDOk, "context contains outgoing MD")
assert.NotNil(t, middleware.outgoingMD, "outgoing MD is not nil")
assert.Contains(t, middleware.outgoingMD, "authorization", "auth header present")
assert.Contains(t, middleware.outgoingMD, "my-config-header", "custom config header present")
assert.Equal(t, []string{"hdr-config-1"}, middleware.outgoingMD["my-config-header"], "custom config header value")
assert.Contains(t, middleware.outgoingMD, "my-call-header", "custom call header present")
assert.Equal(t, []string{"hdr-call-1"}, middleware.outgoingMD["my-call-header"],"custom call header value")
}

// fake Flight server implementation

type flightServer struct {
flight.BaseFlightServer
}

func (f *flightServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error {
schema := arrow.NewSchema([]arrow.Field{
{Name: "intField", Type: arrow.PrimitiveTypes.Int64, Nullable: false},
{Name: "stringField", Type: arrow.BinaryTypes.String, Nullable: false},
{Name: "floatField", Type: arrow.PrimitiveTypes.Float64, Nullable: true},
}, nil)
builder := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer builder.Release()
builder.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3, 4, 5}, nil)
builder.Field(1).(*array.StringBuilder).AppendValues([]string{"a", "b", "c", "d", "e"}, nil)
builder.Field(2).(*array.Float64Builder).AppendValues([]float64{1, 0, 3, 0, 5}, []bool{true, false, true, false, true})
rec0 := builder.NewRecord()
defer rec0.Release()
recs := []arrow.Record{rec0}

w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema()))
for _, r := range recs {
w.Write(r)
}

return nil
}

type callHeadersMiddleware struct {
outgoingMDOk bool
outgoingMD metadata.MD
}

func (c *callHeadersMiddleware) StartCall(ctx context.Context) context.Context {
c.outgoingMD, c.outgoingMDOk = metadata.FromOutgoingContext(ctx)
return ctx
}

func (c *callHeadersMiddleware) CallCompleted(ctx context.Context, err error) {
}

func (c *callHeadersMiddleware) HeadersReceived(ctx context.Context, md metadata.MD) {
}

0 comments on commit bd8ca70

Please sign in to comment.