Skip to content

Commit

Permalink
feat(go/adbc/driver/flightsql): expose FlightInfo during polling (apa…
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm authored Mar 1, 2024
1 parent bcbc161 commit d962961
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 4 deletions.
25 changes: 21 additions & 4 deletions go/adbc/driver/flightsql/cmd/testserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,32 @@ func (srv *ExampleServer) PollFlightInfo(ctx context.Context, desc *flight.Fligh
return nil, err
}

srv.pollingStatus[val.Value]--
progress := srv.pollingStatus[val.Value]

ticket, err := flightsql.CreateStatementQueryTicket([]byte(val.Value))
if err != nil {
return nil, err
}

endpoints := make([]*flight.FlightEndpoint, 5-progress)
if val.Value == "forever" {
srv.pollingStatus[val.Value]++
return &flight.PollInfo{
Info: &flight.FlightInfo{
Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{{Name: "ints", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil), srv.Alloc),
Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticket}}},
FlightDescriptor: desc,
TotalRecords: -1,
TotalBytes: -1,
AppMetadata: []byte("app metadata"),
},
FlightDescriptor: desc,
Progress: proto.Float64(float64(srv.pollingStatus[val.Value]) / 100.0),
}, nil
}

srv.pollingStatus[val.Value]--
progress := srv.pollingStatus[val.Value]

numEndpoints := 5 - progress
endpoints := make([]*flight.FlightEndpoint, numEndpoints)
for i := range endpoints {
endpoints[i] = &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: ticket}}
}
Expand Down
94 changes: 94 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,32 @@ func (srv *IncrementalPollTestServer) PollFlightInfo(ctx context.Context, desc *
return nil, status.Errorf(codes.NotFound, "Query ID not found")
}

if query.query == "infinite" {
query.nextIndex++

descriptor, err := proto.Marshal(&wrapperspb.StringValue{Value: queryId})
if err != nil {
return nil, err
}
return &flight.PollInfo{
Info: &flight.FlightInfo{
Schema: nil,
Endpoint: []*flight.FlightEndpoint{{
Ticket: &flight.Ticket{
Ticket: []byte{},
},
}},
AppMetadata: []byte("app metadata"),
},
FlightDescriptor: &flight.FlightDescriptor{
Type: flight.DescriptorCMD,
Cmd: descriptor,
},
// always makes a bit of progress, never gets anywhere
Progress: proto.Float64(float64(query.nextIndex) / 100.0),
}, nil
}

testCase, ok := srv.testCases[query.query]
if !ok {
if query.query == "unavailable" {
Expand Down Expand Up @@ -581,6 +607,32 @@ func (srv *IncrementalPollTestServer) PollFlightInfoStatement(ctx context.Contex
}

return srv.MakePollInfo(&unavailableCase, srv.queries[queryId], queryId)
} else if query.GetQuery() == "infinite" {
srv.queries[queryId] = &IncrementalQuery{
query: query.GetQuery(),
nextIndex: 0,
}

descriptor, err := proto.Marshal(&wrapperspb.StringValue{Value: queryId})
if err != nil {
return nil, err
}
return &flight.PollInfo{
Info: &flight.FlightInfo{
Schema: nil,
Endpoint: []*flight.FlightEndpoint{{
Ticket: &flight.Ticket{
Ticket: []byte{},
},
}},
AppMetadata: []byte("app metadata"),
},
FlightDescriptor: &flight.FlightDescriptor{
Type: flight.DescriptorCMD,
Cmd: descriptor,
},
Progress: proto.Float64(0),
}, nil
}

testCase, ok := srv.testCases[query.GetQuery()]
Expand Down Expand Up @@ -790,6 +842,48 @@ func (ts *IncrementalPollTests) TestOptionValue() {
ts.Equal(adbc.StatusInvalidArgument, adbcErr.Code)
}

func (ts *IncrementalPollTests) TestAppMetadata() {
ctx, cancel := context.WithCancel(context.Background())
stmt, err := ts.cnxn.NewStatement()
ts.NoError(err)
defer stmt.Close()

ts.NoError(stmt.SetOption(adbc.OptionKeyIncremental, adbc.OptionValueEnabled))

ts.NoError(stmt.SetSqlQuery("infinite"))
_, partitions, _, err := stmt.ExecutePartitions(ctx)
ts.NoError(err)
ts.Equalf(uint64(1), partitions.NumPartitions, "%#v", partitions)

progress := 0.0
go func() {
var err error
var info []byte
for {
// While the below is stuck, we should be able to get the app metadata and progress
progress, err = stmt.(adbc.GetSetOptions).GetOptionDouble(adbc.OptionKeyProgress)
ts.NoError(err)

info, err = stmt.(adbc.GetSetOptions).GetOptionBytes(driver.OptionLastFlightInfo)
ts.NoError(err)
var flightInfo flight.FlightInfo
ts.NoError(proto.Unmarshal(info, &flightInfo))
ts.Equal([]byte("app metadata"), flightInfo.AppMetadata)

if progress > 0.03 {
break
}
}
cancel()
}()

// will get stuck forever, but will "make progress"
_, _, _, err = stmt.ExecutePartitions(ctx)
var adbcErr adbc.Error
ts.ErrorAs(err, &adbcErr)
ts.Equal(adbc.StatusCancelled, adbcErr.Code)
}

func (ts *IncrementalPollTests) TestUnavailable() {
// An error from the server should not tear down all the state. We
// should be able to retry the request.
Expand Down
1 change: 1 addition & 0 deletions go/adbc/driver/flightsql/flightsql_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ const (
OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update"
OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header."
OptionCookieMiddleware = "adbc.flight.sql.rpc.with_cookie_middleware"
OptionLastFlightInfo = "adbc.flight.sql.statement.exec.last_flight_info"
infoDriverName = "ADBC Flight SQL Driver - Go"
)

Expand Down
21 changes: 21 additions & 0 deletions go/adbc/driver/flightsql/flightsql_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ type statement struct {
timeouts timeoutOption
incrementalState *incrementalState
progress float64
// may seem redundant, but incrementalState isn't locked
lastInfo atomic.Pointer[flight.FlightInfo]
}

func (s *statement) closePreparedStatement() error {
Expand All @@ -184,6 +186,7 @@ func (s *statement) clearIncrementalQuery() error {
}
}
s.incrementalState = &incrementalState{}
s.lastInfo.Store(nil)
}
return nil
}
Expand Down Expand Up @@ -249,6 +252,21 @@ func (s *statement) GetOption(key string) (string, error) {
}
}
func (s *statement) GetOptionBytes(key string) ([]byte, error) {
switch key {
case OptionLastFlightInfo:
info := s.lastInfo.Load()
if info == nil {
return []byte{}, nil
}
serialized, err := proto.Marshal(info)
if err != nil {
return nil, adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Could not serialize result for '%s': %s", key, err.Error()),
Code: adbc.StatusInternal,
}
}
return serialized, nil
}
return nil, adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", key),
Code: adbc.StatusNotFound,
Expand Down Expand Up @@ -594,6 +612,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
// Reset the statement for reuse
s.incrementalState = &incrementalState{}
atomicStoreFloat64(&s.progress, 0.0)
s.lastInfo.Store(nil)
return schema, adbc.Partitions{}, totalRecords, nil
}

Expand Down Expand Up @@ -628,6 +647,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
s.incrementalState.previousInfo = poll.GetInfo()
s.incrementalState.retryDescriptor = poll.GetFlightDescriptor()
atomicStoreFloat64(&s.progress, poll.GetProgress())
s.lastInfo.Store(poll.GetInfo())

if s.incrementalState.retryDescriptor == nil {
// Query is finished
Expand All @@ -651,6 +671,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
if s.incrementalState.complete && len(info.Endpoint) == 0 {
s.incrementalState = &incrementalState{}
atomicStoreFloat64(&s.progress, 0.0)
s.lastInfo.Store(nil)
}
} else if s.prepared != nil {
info, err = s.prepared.Execute(ctx, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ class ConnectionOptions(enum.Enum):
class StatementOptions(enum.Enum):
"""Statement options specific to the Flight SQL driver."""

#: The latest FlightInfo value.
#:
#: Thread-safe. Mostly useful when using incremental execution, where an
#: advanced client may want to inspect the latest FlightInfo from the
#: service, but without waiting for execute_partitions to return. (The
#: service may send an updated FlightInfo with progress/app_metadata
#: values, but execute_partitions will only return if there are new
#: endpoints.)
LAST_FLIGHT_INFO = "adbc.flight.sql.statement.exec.last_flight_info"
#: The number of batches to queue per partition. Defaults to 5.
#:
#: This controls how much we read ahead on result sets.
Expand Down
51 changes: 51 additions & 0 deletions python/adbc_driver_flightsql/tests/test_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
# under the License.

import re
import threading

import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
import pyarrow
import pyarrow.flight
import pytest

import adbc_driver_manager
from adbc_driver_flightsql import StatementOptions as FlightSqlStatementOptions
from adbc_driver_manager import StatementOptions

SCHEMA = pyarrow.schema([("ints", "int32")])
Expand Down Expand Up @@ -106,6 +109,54 @@ def test_incremental_error_poll(test_dbapi) -> None:
assert partitions == []


def test_incremental_cancel(test_dbapi) -> None:
with test_dbapi.cursor() as cur:
assert (
cur.adbc_statement.get_option_bytes(
FlightSqlStatementOptions.LAST_FLIGHT_INFO.value
)
== b""
)

cur.adbc_statement.set_options(
**{
StatementOptions.INCREMENTAL.value: "true",
}
)
partitions, schema = cur.adbc_execute_partitions("forever")
assert len(partitions) == 1

passed = False

def _bg():
nonlocal passed
while True:
progress = cur.adbc_statement.get_option_float(
StatementOptions.PROGRESS.value
)
# XXX: upstream PyArrow never bothered exposing app_metadata
raw_info = cur.adbc_statement.get_option_bytes(
FlightSqlStatementOptions.LAST_FLIGHT_INFO.value
)

# check that it's a valid info
pyarrow.flight.FlightInfo.deserialize(raw_info)
passed = b"app metadata" in raw_info

if progress > 0.07:
break
cur.adbc_cancel()

t = threading.Thread(target=_bg, daemon=True)
t.start()

with pytest.raises(test_dbapi.OperationalError, match="(?i)cancelled"):
cur.adbc_execute_partitions("forever")

t.join()
assert passed


def test_incremental_immediately(test_dbapi) -> None:
with test_dbapi.cursor() as cur:
cur.adbc_statement.set_options(
Expand Down

0 comments on commit d962961

Please sign in to comment.