From 362d65f95243297833c428980bbc8ef928d79ad3 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 20 Aug 2024 22:26:08 -0400
Subject: [PATCH] test(go/adbc/driver/flightsql): more testing of header
parameter
---
.../driver/flightsql/flightsql_adbc_test.go | 73 +++++++++++++++++++
1 file changed, 73 insertions(+)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 5f34d3410d..927c67b039 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -36,6 +36,7 @@ import (
"os"
"path/filepath"
"runtime"
+ "slices"
"strings"
"testing"
"time"
@@ -306,6 +307,78 @@ func TestADBCFlightSQL(t *testing.T) {
suite.Run(t, &DomainSocketTests{db: db})
}
+// Run the test suite, but validating that a header set on the database is ALWAYS passed
+
+type FlightSQLWithHeaderQuirks struct {
+ FlightSQLQuirks
+}
+
+func (s *FlightSQLWithHeaderQuirks) SetupDriver(t *testing.T) adbc.Driver {
+ var err error
+ s.mem = memory.NewCheckedAllocator(memory.DefaultAllocator)
+ // Enforce that a particular header is present on ALL requests
+ s.s = flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+ {
+ Unary: func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+ if md, ok := metadata.FromIncomingContext(ctx); ok {
+ vals := md.Get("x-expected")
+ if slices.Contains(vals, "open sesame") {
+ return handler(ctx, req)
+ }
+ }
+ return nil, fmt.Errorf("missing expected header")
+ },
+ Stream: func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ ctx := stream.Context()
+ if md, ok := metadata.FromIncomingContext(ctx); ok {
+ vals := md.Get("x-expected")
+ if slices.Contains(vals, "open sesame") {
+ return handler(srv, stream)
+ }
+ }
+ return fmt.Errorf("missing expected header")
+ },
+ },
+ }, s.opts...)
+ require.NoError(t, err)
+ s.srv, err = example.NewSQLiteFlightSQLServer(s.db)
+ require.NoError(t, err)
+ s.srv.Alloc = s.mem
+
+ s.s.RegisterFlightService(flightsql.NewFlightServer(s.srv))
+ require.NoError(t, s.s.Init("localhost:0"))
+ s.s.SetShutdownOnSignals(os.Interrupt, os.Kill)
+ s.done = make(chan bool)
+ go func() {
+ defer close(s.done)
+ _ = s.s.Serve()
+ }()
+
+ return driver.NewDriver(s.mem)
+}
+
+func (s *FlightSQLWithHeaderQuirks) DatabaseOptions() map[string]string {
+ return map[string]string{
+ adbc.OptionKeyURI: "grpc+tcp://" + s.s.Addr().String(),
+ driver.OptionRPCCallHeaderPrefix + "x-expected": "open sesame",
+ }
+}
+
+func TestADBCFlightSQLWithHeader(t *testing.T) {
+ // XXX: arrow-go uses a shared DB so CreateDB can't be called more than once in a process
+ db, err := sql.Open("sqlite", "file:adbcwithheader?mode=memory&cache=private")
+ require.NoError(t, err)
+ defer db.Close()
+
+ q := &FlightSQLWithHeaderQuirks{FlightSQLQuirks{db: db}}
+ suite.Run(t, &validation.DatabaseTests{Quirks: q})
+ suite.Run(t, &validation.ConnectionTests{Quirks: q})
+ suite.Run(t, &validation.StatementTests{Quirks: q})
+ suite.Run(t, &OptionTests{Quirks: q})
+ suite.Run(t, &PartitionTests{Quirks: q})
+ suite.Run(t, &StatementTests{Quirks: q})
+}
+
// Driver-specific tests
type DefaultDialOptionsTests struct {