Skip to content

Commit

Permalink
test(go/adbc/driver/flightsql): more testing of header parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Aug 21, 2024
1 parent 1c84391 commit 362d65f
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -306,6 +307,78 @@ func TestADBCFlightSQL(t *testing.T) {
suite.Run(t, &DomainSocketTests{db: db})
}

// Run the test suite, but validating that a header set on the database is ALWAYS passed

type FlightSQLWithHeaderQuirks struct {
FlightSQLQuirks
}

func (s *FlightSQLWithHeaderQuirks) SetupDriver(t *testing.T) adbc.Driver {
var err error
s.mem = memory.NewCheckedAllocator(memory.DefaultAllocator)
// Enforce that a particular header is present on ALL requests
s.s = flight.NewServerWithMiddleware([]flight.ServerMiddleware{
{
Unary: func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if md, ok := metadata.FromIncomingContext(ctx); ok {
vals := md.Get("x-expected")
if slices.Contains(vals, "open sesame") {
return handler(ctx, req)
}
}
return nil, fmt.Errorf("missing expected header")
},
Stream: func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
ctx := stream.Context()
if md, ok := metadata.FromIncomingContext(ctx); ok {
vals := md.Get("x-expected")
if slices.Contains(vals, "open sesame") {
return handler(srv, stream)
}
}
return fmt.Errorf("missing expected header")
},
},
}, s.opts...)
require.NoError(t, err)
s.srv, err = example.NewSQLiteFlightSQLServer(s.db)
require.NoError(t, err)
s.srv.Alloc = s.mem

s.s.RegisterFlightService(flightsql.NewFlightServer(s.srv))
require.NoError(t, s.s.Init("localhost:0"))
s.s.SetShutdownOnSignals(os.Interrupt, os.Kill)
s.done = make(chan bool)
go func() {
defer close(s.done)
_ = s.s.Serve()
}()

return driver.NewDriver(s.mem)
}

func (s *FlightSQLWithHeaderQuirks) DatabaseOptions() map[string]string {
return map[string]string{
adbc.OptionKeyURI: "grpc+tcp://" + s.s.Addr().String(),
driver.OptionRPCCallHeaderPrefix + "x-expected": "open sesame",
}
}

func TestADBCFlightSQLWithHeader(t *testing.T) {
// XXX: arrow-go uses a shared DB so CreateDB can't be called more than once in a process
db, err := sql.Open("sqlite", "file:adbcwithheader?mode=memory&cache=private")
require.NoError(t, err)
defer db.Close()

q := &FlightSQLWithHeaderQuirks{FlightSQLQuirks{db: db}}
suite.Run(t, &validation.DatabaseTests{Quirks: q})
suite.Run(t, &validation.ConnectionTests{Quirks: q})
suite.Run(t, &validation.StatementTests{Quirks: q})
suite.Run(t, &OptionTests{Quirks: q})
suite.Run(t, &PartitionTests{Quirks: q})
suite.Run(t, &StatementTests{Quirks: q})
}

// Driver-specific tests

type DefaultDialOptionsTests struct {
Expand Down

0 comments on commit 362d65f

Please sign in to comment.