Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for custom query/gRPC headers #76

Merged
merged 17 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

1. [#68](https://github.com/InfluxCommunity/influxdb3-go/pull/68): Upgrade Go version to 1.22.
2. [#74](https://github.com/InfluxCommunity/influxdb3-go/pull/74): Use `log/slog` to print debug information instead of `fmt.Printf`
3. [#76](https://github.com/InfluxCommunity/influxdb3-go/pull/76): Add custom headers support for queries (gRPC requests)

## 0.6.0 [2024-03-01]

Expand Down
6 changes: 3 additions & 3 deletions influxdb3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type Client struct {
// Cached base server API URL.
apiURL *url.URL
// Flight client for executing queries
queryClient *flight.Client
queryClient flight.Client
}

// httpParams holds parameters for creating an HTTP request
Expand Down Expand Up @@ -174,7 +174,7 @@ func (c *Client) makeAPICall(ctx context.Context, params httpParams) (*http.Resp
}
req.Header.Set("User-Agent", userAgent)
if c.authorization != "" {
req.Header.Add("Authorization", c.authorization)
req.Header.Set("Authorization", c.authorization)
}

resp, err := c.config.HTTPClient.Do(req)
Expand Down Expand Up @@ -238,6 +238,6 @@ func (c *Client) resolveHTTPError(r *http.Response) error {
// Close closes all idle connections.
func (c *Client) Close() error {
c.config.HTTPClient.CloseIdleConnections()
err := (*c.queryClient).Close()
err := c.queryClient.Close()
return err
}
84 changes: 84 additions & 0 deletions influxdb3/example_query_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
The MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

package influxdb3

import (
"context"
"log"
)

func ExampleClient_Query() {
client, err := NewFromEnv()
if err != nil {
log.Fatal(err)
}
defer client.Close()

// query
iterator, err := client.Query(context.Background(),
"SELECT count(*) FROM weather WHERE time >= now() - interval '5 minutes'")

for iterator.Next() {
// process the result
}

// query with custom header
iterator, err = client.Query(context.Background(),
"SELECT count(*) FROM stat WHERE time >= now() - interval '5 minutes'",
WithHeader("X-trace-ID", "#0122"))

for iterator.Next() {
// process the result
}
}

func ExampleClient_QueryWithParameters() {
client, err := NewFromEnv()
if err != nil {
log.Fatal(err)
}
defer client.Close()

// query
iterator, err := client.QueryWithParameters(context.Background(),
"SELECT count(*) FROM weather WHERE location = $location AND time >= now() - interval '5 minutes'",
QueryParameters{
"location": "sun-valley-1",
})

for iterator.Next() {
// process the result
}

// query with custom header
iterator, err = client.QueryWithParameters(context.Background(),
"SELECT count(*) FROM weather WHERE location = $location AND time >= now() - interval '5 minutes'",
QueryParameters{
"location": "sun-valley-1",
},
WithHeader("X-trace-ID", "#0122"))

for iterator.Next() {
// process the result
}
}
16 changes: 16 additions & 0 deletions influxdb3/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
package influxdb3

import (
"net/http"

"github.com/influxdata/line-protocol/v2/lineprotocol"
)

Expand All @@ -33,6 +35,9 @@ type QueryOptions struct {

// Query type.
QueryType QueryType

// Headers to be included in requests. Use to add or override headers in `ClientConfig`.
Headers http.Header
}

// WriteOptions holds options for write
Expand Down Expand Up @@ -110,6 +115,7 @@ type Option func(o *options)
// Available options:
// - WithDatabase
// - WithQueryType
// - WithHeader
type QueryOption = Option

// WriteOption is a functional option type that can be passed to Client.Write methods.
Expand All @@ -135,6 +141,16 @@ func WithQueryType(queryType QueryType) Option {
}
}

// WithHeader is used to add or override default header in Client.Query method.
func WithHeader(key, value string) Option {
return func(o *options) {
if o.Headers == nil {
o.Headers = make(http.Header, 0)
}
o.Headers[key] = []string{value}
}
}

// WithPrecision is used to override default precision in Client.Write methods.
func WithPrecision(precision lineprotocol.Precision) Option {
return func(o *options) {
Expand Down
20 changes: 20 additions & 0 deletions influxdb3/options_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package influxdb3

import (
"net/http"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -40,6 +41,25 @@ func TestQueryOptions(t *testing.T) {
QueryType: InfluxQL,
},
},
{
name: "add header",
opts: va(WithHeader("header-a", "value-a")),
want: &QueryOptions{
Headers: http.Header{
"header-a": {"value-a"},
},
},
},
{
name: "add headers",
opts: va(WithHeader("header-a", "value-a"), WithHeader("header-b", "value-b")),
want: &QueryOptions{
Headers: http.Header{
"header-a": {"value-a"},
"header-b": {"value-b"},
},
},
},
}

for _, tc := range testCases {
Expand Down
23 changes: 19 additions & 4 deletions influxdb3/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,15 @@ func (c *Client) initializeQueryClient() error {
if err != nil {
return fmt.Errorf("flight: %s", err)
}
c.queryClient = &client
c.queryClient = client

return nil
}

func (c *Client) setQueryClient(flightClient flight.Client) {
c.queryClient = flightClient
}

// QueryParameters is a type for query parameters.
type QueryParameters = map[string]any

Expand Down Expand Up @@ -132,8 +136,19 @@ func (c *Client) query(ctx context.Context, query string, parameters QueryParame
var queryType QueryType
queryType = options.QueryType

ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+c.config.Token)
ctx = metadata.AppendToOutgoingContext(ctx, "database", database)
md := make(metadata.MD, 0)
for k, v := range c.config.Headers {
for _, value := range v {
md.Append(k, value)
}
}
for k, v := range options.Headers {
for _, value := range v {
md.Append(k, value)
}
}
md.Set("authorization", "Bearer "+c.config.Token)
ctx = metadata.NewOutgoingContext(ctx, md)

ticketData := map[string]interface{}{
"database": database,
Expand All @@ -151,7 +166,7 @@ func (c *Client) query(ctx context.Context, query string, parameters QueryParame
}

ticket := &flight.Ticket{Ticket: ticketJSON}
stream, err := (*c.queryClient).DoGet(ctx, ticket)
stream, err := c.queryClient.DoGet(ctx, ticket)
if err != nil {
return nil, fmt.Errorf("flight do get: %s", err)
}
Expand Down
95 changes: 95 additions & 0 deletions influxdb3/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,19 @@ package influxdb3

import (
"context"
"net/http"
"testing"

"github.com/apache/arrow/go/v15/arrow"
"github.com/apache/arrow/go/v15/arrow/array"
"github.com/apache/arrow/go/v15/arrow/flight"
"github.com/apache/arrow/go/v15/arrow/ipc"
"github.com/apache/arrow/go/v15/arrow/memory"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
)

func TestQueryDatabaseNotSet(t *testing.T) {
Expand All @@ -54,3 +63,89 @@ func TestQueryWithOptionsNotSet(t *testing.T) {
assert.Error(t, err)
assert.EqualError(t, err, "options not set")
}

func TestQueryWithCustomHeader(t *testing.T) {
s := flight.NewServerWithMiddleware(nil)
err := s.Init("localhost:18080")
require.NoError(t, err)
f := &flightServer{}
s.RegisterFlightService(f)

go s.Serve()
defer s.Shutdown()

middleware := &callHeadersMiddleware{}
fc, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, []flight.ClientMiddleware{
flight.CreateClientMiddleware(middleware),
}, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer fc.Close()

c, err := New(ClientConfig{
Host: "http://localhost:80",
Token: "my-token",
Database: "my-database",
Headers: http.Header{
"my-config-header": {"hdr-config-1"},
},
})
require.NoError(t, err)
defer c.Close()

c.setQueryClient(fc)

_, err = c.Query(context.Background(), "SELECT * FROM nothing", WithHeader("my-call-header", "hdr-call-1"))
require.NoError(t, err, "DoGet success")
assert.True(t, middleware.outgoingMDOk, "context contains outgoing MD")
assert.NotNil(t, middleware.outgoingMD, "outgoing MD is not nil")
assert.Contains(t, middleware.outgoingMD, "authorization", "auth header present")
assert.Contains(t, middleware.outgoingMD, "my-config-header", "custom config header present")
assert.Equal(t, []string{"hdr-config-1"}, middleware.outgoingMD["my-config-header"], "custom config header value")
assert.Contains(t, middleware.outgoingMD, "my-call-header", "custom call header present")
assert.Equal(t, []string{"hdr-call-1"}, middleware.outgoingMD["my-call-header"],"custom call header value")
}

// fake Flight server implementation

type flightServer struct {
flight.BaseFlightServer
}

func (f *flightServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error {
schema := arrow.NewSchema([]arrow.Field{
{Name: "intField", Type: arrow.PrimitiveTypes.Int64, Nullable: false},
{Name: "stringField", Type: arrow.BinaryTypes.String, Nullable: false},
{Name: "floatField", Type: arrow.PrimitiveTypes.Float64, Nullable: true},
}, nil)
builder := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer builder.Release()
builder.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3, 4, 5}, nil)
builder.Field(1).(*array.StringBuilder).AppendValues([]string{"a", "b", "c", "d", "e"}, nil)
builder.Field(2).(*array.Float64Builder).AppendValues([]float64{1, 0, 3, 0, 5}, []bool{true, false, true, false, true})
rec0 := builder.NewRecord()
defer rec0.Release()
recs := []arrow.Record{rec0}

w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema()))
for _, r := range recs {
w.Write(r)
}

return nil
}

type callHeadersMiddleware struct {
outgoingMDOk bool
outgoingMD metadata.MD
}

func (c *callHeadersMiddleware) StartCall(ctx context.Context) context.Context {
c.outgoingMD, c.outgoingMDOk = metadata.FromOutgoingContext(ctx)
return ctx
}

func (c *callHeadersMiddleware) CallCompleted(ctx context.Context, err error) {
}

func (c *callHeadersMiddleware) HeadersReceived(ctx context.Context, md metadata.MD) {
}