Skip to content

Commit

Permalink
fix(go/adbc/driver/flightsql): propagate headers in GetObjects
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Jul 9, 2024
1 parent 66ecd33 commit e2f3675
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 3 deletions.
173 changes: 172 additions & 1 deletion go/adbc/driver/flightsql/cmd/testserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,36 @@ import (
"github.com/apache/arrow/go/v17/arrow/array"
"github.com/apache/arrow/go/v17/arrow/flight"
"github.com/apache/arrow/go/v17/arrow/flight/flightsql"
"github.com/apache/arrow/go/v17/arrow/flight/flightsql/schema_ref"
"github.com/apache/arrow/go/v17/arrow/memory"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/wrapperspb"
)

// RecordedHeader is a single gRPC metadata entry captured by the test
// server: the name of the server method that received the call, the header
// key, and one of the values sent for that key.
type RecordedHeader struct {
	method, header, value string
}

// ExampleServer is the Flight SQL test server. It embeds
// flightsql.BaseServer and overrides individual handlers below.
type ExampleServer struct {
	flightsql.BaseServer

	// mu is held by recordHeaders while appending to headers and by the
	// "recorded_headers" query while draining them.
	// NOTE(review): presumably it also guards pollingStatus — confirm
	// against the polling handlers.
	mu            sync.Mutex
	pollingStatus map[string]int
	// headers accumulates every incoming gRPC header seen by handlers that
	// call recordHeaders; the "recorded_headers" query returns and resets it.
	headers []RecordedHeader
}

// recordedHeadersSchema is the result schema of the "recorded_headers"
// query: three non-nullable string columns mirroring the fields of
// RecordedHeader (method, header, value).
var recordedHeadersSchema = arrow.NewSchema([]arrow.Field{
	{Name: "method", Type: arrow.BinaryTypes.String, Nullable: false},
	{Name: "header", Type: arrow.BinaryTypes.String, Nullable: false},
	{Name: "value", Type: arrow.BinaryTypes.String, Nullable: false},
}, nil)

func StatusWithDetail(code codes.Code, message string, details ...proto.Message) error {
p := status.New(code, message).Proto()
// Have to do this by hand because gRPC uses deprecated proto import
Expand All @@ -64,11 +79,31 @@ func StatusWithDetail(code codes.Code, message string, details ...proto.Message)
return status.FromProto(p).Err()
}

// recordHeaders logs every incoming gRPC metadata entry attached to ctx and
// appends it to srv.headers, tagged with the name of the calling method.
//
// It panics if ctx carries no incoming metadata, since that means it was
// called outside of a request handler. The append is done under srv.mu.
func (srv *ExampleServer) recordHeaders(ctx context.Context, method string) {
	incoming, found := metadata.FromIncomingContext(ctx)
	if !found {
		panic("Misuse of recordHeaders")
	}

	srv.mu.Lock()
	defer srv.mu.Unlock()
	for name, values := range incoming {
		// A header key may carry several values; record each one separately.
		for i := range values {
			log.Printf("Header: %s: %s = %s\n", method, name, values[i])
			entry := RecordedHeader{method: method, header: name, value: values[i]}
			srv.headers = append(srv.headers, entry)
		}
	}
}

// ClosePreparedStatement records the incoming headers under the method name
// "ClosePreparedStatement" and otherwise does nothing; the request itself is
// ignored and the call always succeeds.
func (srv *ExampleServer) ClosePreparedStatement(ctx context.Context, request flightsql.ActionClosePreparedStatementRequest) error {
	srv.recordHeaders(ctx, "ClosePreparedStatement")
	return nil
}

func (srv *ExampleServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) {
srv.recordHeaders(ctx, "CreatePreparedStatement")
switch req.GetQuery() {
case "error_create_prepared_statement":
err = status.Error(codes.InvalidArgument, "expected error (DoAction)")
Expand All @@ -83,7 +118,8 @@ func (srv *ExampleServer) CreatePreparedStatement(ctx context.Context, req fligh
return
}

func (srv *ExampleServer) GetFlightInfoPreparedStatement(_ context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
func (srv *ExampleServer) GetFlightInfoPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
srv.recordHeaders(ctx, "GetFlightInfoPreparedStatement")
switch string(cmd.GetPreparedStatementHandle()) {
case "error_do_get", "error_do_get_stream", "error_do_get_detail", "error_do_get_stream_detail", "forever":
schema := arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
Expand Down Expand Up @@ -111,6 +147,7 @@ func (srv *ExampleServer) GetFlightInfoPreparedStatement(_ context.Context, cmd
}

func (srv *ExampleServer) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
srv.recordHeaders(ctx, "GetFlightInfoStatement")
ticket, err := flightsql.CreateStatementQueryTicket(desc.Cmd)
if err != nil {
return nil, err
Expand Down Expand Up @@ -239,6 +276,7 @@ func (srv *ExampleServer) PollFlightInfoPreparedStatement(ctx context.Context, q
}

func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) {
srv.recordHeaders(ctx, "DoGetPreparedStatement")
log.Printf("DoGetPreparedStatement: %v", cmd.GetPreparedStatementHandle())
switch string(cmd.GetPreparedStatementHandle()) {
case "error_do_get":
Expand Down Expand Up @@ -271,6 +309,45 @@ func (srv *ExampleServer) DoGetPreparedStatement(ctx context.Context, cmd flight
case "stateless_prepared_statement":
err = status.Error(codes.InvalidArgument, "client didn't use the updated handle")
return
case "recorded_headers":
schema = recordedHeadersSchema
ch := make(chan flight.StreamChunk)

methods := array.NewStringBuilder(srv.Alloc)
headers := array.NewStringBuilder(srv.Alloc)
values := array.NewStringBuilder(srv.Alloc)
defer methods.Release()
defer headers.Release()
defer values.Release()

srv.mu.Lock()
defer srv.mu.Unlock()

count := int64(0)
for _, recorded := range srv.headers {
count++
methods.AppendString(recorded.method)
headers.AppendString(recorded.header)
values.AppendString(recorded.value)
}
srv.headers = make([]RecordedHeader, 0)

rec := array.NewRecord(recordedHeadersSchema, []arrow.Array{
methods.NewArray(),
headers.NewArray(),
values.NewArray(),
}, count)

go func() {
defer close(ch)
ch <- flight.StreamChunk{
Data: rec,
Desc: nil,
Err: nil,
}
}()
out = ch
return
}

schema = arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
Expand Down Expand Up @@ -341,6 +418,100 @@ func (srv *ExampleServer) DoPutPreparedStatementUpdate(context.Context, flightsq
return 0, status.Error(codes.Unimplemented, "DoPutPreparedStatementUpdate not implemented")
}

// GetFlightInfoCatalogs records the incoming headers and returns a
// FlightInfo with a single endpoint whose ticket is the descriptor's command
// bytes. The advertised schema is schema_ref.Catalogs; record and byte
// totals are reported as unknown (-1).
func (srv *ExampleServer) GetFlightInfoCatalogs(ctx context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	srv.recordHeaders(ctx, "GetFlightInfoCatalogs")

	endpoint := &flight.FlightEndpoint{
		Ticket: &flight.Ticket{Ticket: desc.Cmd},
	}
	info := &flight.FlightInfo{
		Endpoint:         []*flight.FlightEndpoint{endpoint},
		FlightDescriptor: desc,
		Schema:           flight.SerializeSchema(schema_ref.Catalogs, srv.Alloc),
		TotalRecords:     -1,
		TotalBytes:       -1,
	}
	return info, nil
}

// DoGetCatalogs records the incoming headers and streams one dummy batch
// containing a single catalog named "catalog", using schema_ref.Catalogs.
// The returned channel is already closed after that batch is queued.
func (srv *ExampleServer) DoGetCatalogs(ctx context.Context) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	srv.recordHeaders(ctx, "DoGetCatalogs")

	names, _, err := array.FromJSON(srv.Alloc, arrow.BinaryTypes.String, strings.NewReader(`["catalog"]`))
	if err != nil {
		return nil, nil, err
	}
	// The record built below is constructed before this deferred release
	// runs at function return.
	defer names.Release()

	// Capacity 1 lets the only batch be queued without a receiver.
	out := make(chan flight.StreamChunk, 1)
	out <- flight.StreamChunk{
		Data: array.NewRecord(schema_ref.Catalogs, []arrow.Array{names}, 1),
	}
	close(out)
	return schema_ref.Catalogs, out, nil
}

// GetFlightInfoSchemas records the incoming headers (under the name
// "GetFlightInfoDBSchemas") and returns a FlightInfo with a single endpoint
// whose ticket is the descriptor's command bytes. The advertised schema is
// schema_ref.DBSchemas; record and byte totals are reported as unknown (-1).
// The req argument is not inspected.
func (srv *ExampleServer) GetFlightInfoSchemas(ctx context.Context, req flightsql.GetDBSchemas, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	srv.recordHeaders(ctx, "GetFlightInfoDBSchemas")

	endpoint := &flight.FlightEndpoint{
		Ticket: &flight.Ticket{Ticket: desc.Cmd},
	}
	info := &flight.FlightInfo{
		Endpoint:         []*flight.FlightEndpoint{endpoint},
		FlightDescriptor: desc,
		Schema:           flight.SerializeSchema(schema_ref.DBSchemas, srv.Alloc),
		TotalRecords:     -1,
		TotalBytes:       -1,
	}
	return info, nil
}

// DoGetDBSchemas records the incoming headers and streams dummy schema data
// using schema_ref.DBSchemas.
//
// A single row ("main", "") is emitted when the request's schema filter
// pattern is absent, empty, or exactly "main"; for any other pattern the
// stream is empty. This is a literal comparison, not real pattern matching.
// The returned channel is already closed.
func (srv *ExampleServer) DoGetDBSchemas(ctx context.Context, req flightsql.GetDBSchemas) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	srv.recordHeaders(ctx, "DoGetDBSchemas")

	// Capacity 1 lets the (at most one) batch be queued without a receiver.
	out := make(chan flight.StreamChunk, 1)

	// Not really a proper match, but good enough
	pattern := req.GetDBSchemaFilterPattern()
	if pattern != nil && *pattern != "" && *pattern != "main" {
		close(out)
		return schema_ref.DBSchemas, out, nil
	}

	firstColumn, _, err := array.FromJSON(srv.Alloc, arrow.BinaryTypes.String, strings.NewReader(`["main"]`))
	if err != nil {
		return nil, nil, err
	}
	defer firstColumn.Release()

	secondColumn, _, err := array.FromJSON(srv.Alloc, arrow.BinaryTypes.String, strings.NewReader(`[""]`))
	if err != nil {
		return nil, nil, err
	}
	defer secondColumn.Release()

	out <- flight.StreamChunk{
		Data: array.NewRecord(schema_ref.DBSchemas, []arrow.Array{firstColumn, secondColumn}, 1),
	}
	close(out)
	return schema_ref.DBSchemas, out, nil
}

// GetFlightInfoTables records the incoming headers and returns a FlightInfo
// with a single endpoint whose ticket is the descriptor's command bytes.
// The advertised schema is schema_ref.TablesWithIncludedSchema when the
// request asks for table schemas to be included, and schema_ref.Tables
// otherwise; record and byte totals are reported as unknown (-1).
func (srv *ExampleServer) GetFlightInfoTables(ctx context.Context, req flightsql.GetTables, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
	srv.recordHeaders(ctx, "GetFlightInfoTables")

	tablesSchema := schema_ref.TablesWithIncludedSchema
	if !req.GetIncludeSchema() {
		tablesSchema = schema_ref.Tables
	}

	endpoint := &flight.FlightEndpoint{
		Ticket: &flight.Ticket{Ticket: desc.Cmd},
	}
	info := &flight.FlightInfo{
		Endpoint:         []*flight.FlightEndpoint{endpoint},
		FlightDescriptor: desc,
		Schema:           flight.SerializeSchema(tablesSchema, srv.Alloc),
		TotalRecords:     -1,
		TotalBytes:       -1,
	}
	return info, nil
}

// DoGetTables records the incoming headers and returns an empty, already
// closed stream. The schema is schema_ref.TablesWithIncludedSchema when the
// request asks for table schemas to be included, and schema_ref.Tables
// otherwise.
func (srv *ExampleServer) DoGetTables(ctx context.Context, req flightsql.GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) {
	srv.recordHeaders(ctx, "DoGetTables")

	// Just return some dummy data
	tablesSchema := schema_ref.TablesWithIncludedSchema
	if !req.GetIncludeSchema() {
		tablesSchema = schema_ref.Tables
	}

	// No batches are sent; closing immediately signals end of stream.
	empty := make(chan flight.StreamChunk, 1)
	close(empty)
	return tablesSchema, empty, nil
}

func main() {
var (
host = flag.String("host", "localhost", "hostname to bind to")
Expand Down
5 changes: 4 additions & 1 deletion go/adbc/driver/flightsql/flightsql_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func (c *connectionImpl) setSessionOptions(ctx context.Context, key string, val
var header, trailer metadata.MD
errors, err := c.cl.SetSessionOptions(ctx, &req, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts)
if err != nil {
return adbcFromFlightStatusWithDetails(err, header, trailer, "GetSessionOptions")
return adbcFromFlightStatusWithDetails(err, header, trailer, "SetSessionOptions")
}
if len(errors.Errors) > 0 {
msg := strings.Builder{}
Expand Down Expand Up @@ -635,6 +635,7 @@ func (c *connectionImpl) GetObjectsCatalogs(ctx context.Context, catalog *string
header, trailer metadata.MD
numCatalogs int64
)
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
// To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response.
info, err := c.cl.GetCatalogs(ctx, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts)
if err != nil {
Expand Down Expand Up @@ -675,6 +676,7 @@ func (c *connectionImpl) GetObjectsDbSchemas(ctx context.Context, depth adbc.Obj
if depth == adbc.ObjectDepthCatalogs {
return
}
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
result = make(map[string][]string)
var header, trailer metadata.MD
// Pre-populate the map of which schemas are in which catalogs
Expand Down Expand Up @@ -716,6 +718,7 @@ func (c *connectionImpl) GetObjectsTables(ctx context.Context, depth adbc.Object
if depth == adbc.ObjectDepthCatalogs || depth == adbc.ObjectDepthDBSchemas {
return
}
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
result = make(map[internal.CatalogAndSchema][]internal.TableInfo)

// Pre-populate the map of which schemas are in which catalogs
Expand Down
6 changes: 5 additions & 1 deletion python/adbc_driver_flightsql/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def dremio_dbapi(dremio_uri, dremio_user, dremio_pass):
adbc_driver_manager.DatabaseOptions.USERNAME.value: dremio_user,
adbc_driver_manager.DatabaseOptions.PASSWORD.value: dremio_pass,
},
autocommit=True,
) as conn:
yield conn

Expand All @@ -79,5 +80,8 @@ def test_dbapi():
if not uri:
pytest.skip("Set ADBC_TEST_FLIGHTSQL_URI to run tests")

with adbc_driver_flightsql.dbapi.connect(uri) as conn:
with adbc_driver_flightsql.dbapi.connect(
uri,
autocommit=True,
) as conn:
yield conn
29 changes: 29 additions & 0 deletions python/adbc_driver_flightsql/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import re
import secrets
import threading
import time

Expand Down Expand Up @@ -190,3 +191,31 @@ def test_stateless_prepared_statement(test_dbapi) -> None:
with test_dbapi.cursor() as cur:
cur.adbc_prepare("stateless_prepared_statement")
cur.execute("stateless_prepared_statement", parameters=[(1,)])


def test_header_propagation(test_dbapi) -> None:
    """Check that a per-call header set on the connection reaches every
    RPC issued by ``adbc_get_objects``.

    The test server records incoming headers and returns (and resets) them
    when the ``recorded_headers`` query is executed.
    """
    header_name = "x-trace"
    # Random value so a match cannot come from anything but this test.
    token = secrets.token_hex(16)
    option_key = f"adbc.flight.sql.rpc.call_header.{header_name}"
    test_dbapi.adbc_connection.set_options(**{option_key: token})

    # Drain whatever the server recorded before the call under test.
    with test_dbapi.cursor() as cur:
        cur.execute("recorded_headers")
        cur.fetchall()

    with test_dbapi.adbc_get_objects():
        pass

    with test_dbapi.cursor() as cur:
        cur.execute("recorded_headers")
        recorded = [row for row in cur.fetchall() if row[1] == header_name]

    expected_methods = (
        "GetFlightInfoCatalogs",
        "DoGetCatalogs",
        "GetFlightInfoDBSchemas",
        "DoGetDBSchemas",
        "GetFlightInfoTables",
        "DoGetTables",
    )
    for method in expected_methods:
        assert (method, header_name, token) in recorded

0 comments on commit e2f3675

Please sign in to comment.