From 362d65f95243297833c428980bbc8ef928d79ad3 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 20 Aug 2024 22:26:08 -0400 Subject: [PATCH] test(go/adbc/driver/flightsql): more testing of header parameter --- .../driver/flightsql/flightsql_adbc_test.go | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go index 5f34d3410d..927c67b039 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go @@ -36,6 +36,7 @@ import ( "os" "path/filepath" "runtime" + "slices" "strings" "testing" "time" @@ -306,6 +307,78 @@ func TestADBCFlightSQL(t *testing.T) { suite.Run(t, &DomainSocketTests{db: db}) } +// Run the test suite, but validating that a header set on the database is ALWAYS passed + +type FlightSQLWithHeaderQuirks struct { + FlightSQLQuirks +} + +func (s *FlightSQLWithHeaderQuirks) SetupDriver(t *testing.T) adbc.Driver { + var err error + s.mem = memory.NewCheckedAllocator(memory.DefaultAllocator) + // Enforce that a particular header is present on ALL requests + s.s = flight.NewServerWithMiddleware([]flight.ServerMiddleware{ + { + Unary: func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if md, ok := metadata.FromIncomingContext(ctx); ok { + vals := md.Get("x-expected") + if slices.Contains(vals, "open sesame") { + return handler(ctx, req) + } + } + return nil, fmt.Errorf("missing expected header") + }, + Stream: func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + ctx := stream.Context() + if md, ok := metadata.FromIncomingContext(ctx); ok { + vals := md.Get("x-expected") + if slices.Contains(vals, "open sesame") { + return handler(srv, stream) + } + } + return fmt.Errorf("missing expected header") + }, + }, + }, s.opts...) + require.NoError(t, err) + s.srv, err = example.NewSQLiteFlightSQLServer(s.db) + require.NoError(t, err) + s.srv.Alloc = s.mem + + s.s.RegisterFlightService(flightsql.NewFlightServer(s.srv)) + require.NoError(t, s.s.Init("localhost:0")) + s.s.SetShutdownOnSignals(os.Interrupt, os.Kill) + s.done = make(chan bool) + go func() { + defer close(s.done) + _ = s.s.Serve() + }() + + return driver.NewDriver(s.mem) +} + +func (s *FlightSQLWithHeaderQuirks) DatabaseOptions() map[string]string { + return map[string]string{ + adbc.OptionKeyURI: "grpc+tcp://" + s.s.Addr().String(), + driver.OptionRPCCallHeaderPrefix + "x-expected": "open sesame", + } +} + +func TestADBCFlightSQLWithHeader(t *testing.T) { + // XXX: arrow-go uses a shared DB so CreateDB can't be called more than once in a process + db, err := sql.Open("sqlite", "file:adbcwithheader?mode=memory&cache=private") + require.NoError(t, err) + defer db.Close() + + q := &FlightSQLWithHeaderQuirks{FlightSQLQuirks{db: db}} + suite.Run(t, &validation.DatabaseTests{Quirks: q}) + suite.Run(t, &validation.ConnectionTests{Quirks: q}) + suite.Run(t, &validation.StatementTests{Quirks: q}) + suite.Run(t, &OptionTests{Quirks: q}) + suite.Run(t, &PartitionTests{Quirks: q}) + suite.Run(t, &StatementTests{Quirks: q}) +} + // Driver-specific tests type DefaultDialOptionsTests struct {