Skip to content

Commit

Permalink
PECO-1054 Expose Arrow batches to users (#160)
Browse files Browse the repository at this point in the history
Step one of exposing arrow batches directly to users. 
Moved the logic for iterating over the pages in a result set into
ResultPageIterator.
Rows is now composed with a ResultPageIterator.
Introduced the Delimiter type. A Delimiter tracks a start/end point and
provides functions for determining whether a point is within the delimiter
range and, if the point is outside the range, the direction in which it
lies.
Updated sparkArrowBatch, arrowRowScanner, columnRows, rows to use
Delimiter.
Updated the Fetch logic for cloudURL and localBatch so that the
concurrentFetcher doesn't need to hold or pass through a Config
instance.
  • Loading branch information
rcypher-databricks authored Sep 19, 2023
2 parents 73073d2 + 69bfdef commit f7c0286
Show file tree
Hide file tree
Showing 14 changed files with 735 additions and 446 deletions.
18 changes: 11 additions & 7 deletions internal/fetcher/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ package fetcher

import (
"context"
"github.com/databricks/databricks-sql-go/internal/config"
"sync"

"github.com/databricks/databricks-sql-go/driverctx"
dbsqllog "github.com/databricks/databricks-sql-go/logger"
)

type FetchableItems[OutputType any] interface {
Fetch(ctx context.Context, cfg *config.Config) ([]OutputType, error)
Fetch(ctx context.Context) ([]OutputType, error)
}

type Fetcher[OutputType any] interface {
Expand All @@ -24,7 +23,6 @@ type concurrentFetcher[I FetchableItems[O], O any] struct {
outChan chan O
err error
nWorkers int
cfg *config.Config
mu sync.Mutex
start sync.Once
ctx context.Context
Expand Down Expand Up @@ -100,10 +98,17 @@ func (f *concurrentFetcher[I, O]) logger() *dbsqllog.DBSQLLogger {
return f.DBSQLLogger
}

func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers int, cfg *config.Config, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) {
func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers, maxItemsInMemory int, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) {
if nWorkers < 1 {
nWorkers = 1
}
if maxItemsInMemory < 1 {
maxItemsInMemory = 1
}

// channel for loaded items
// TODO: pass buffer size
outputChannel := make(chan O, 100)
outputChannel := make(chan O, maxItemsInMemory)

// channel to signal a cancel
stopChannel := make(chan bool)
Expand All @@ -118,7 +123,6 @@ func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWork
cancelChan: stopChannel,
ctx: ctx,
nWorkers: nWorkers,
cfg: cfg,
}

return fetcher, nil
Expand All @@ -139,7 +143,7 @@ func work[I FetchableItems[O], O any](f *concurrentFetcher[I, O], workerIndex in
case input, ok := <-f.inputChan:
if ok {
f.logger().Debug().Msgf("concurrent fetcher worker %d loading item", workerIndex)
result, err := input.Fetch(f.ctx, f.cfg)
result, err := input.Fetch(f.ctx)
if err != nil {
f.logger().Debug().Msgf("concurrent fetcher worker %d received error", workerIndex)
f.setErr(err)
Expand Down
12 changes: 6 additions & 6 deletions internal/fetcher/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package fetcher

import (
"context"
"github.com/databricks/databricks-sql-go/internal/config"
"github.com/pkg/errors"
"math"
"testing"
"time"

"github.com/pkg/errors"
)

// Create a mock struct for FetchableItems
Expand All @@ -20,7 +20,7 @@ type mockOutput struct {
}

// Implement the Fetch method
func (m *mockFetchableItem) Fetch(ctx context.Context, cfg *config.Config) ([]*mockOutput, error) {
func (m *mockFetchableItem) Fetch(ctx context.Context) ([]*mockOutput, error) {
time.Sleep(m.wait)
outputs := make([]*mockOutput, 5)
for i := range outputs {
Expand All @@ -35,7 +35,7 @@ var _ FetchableItems[*mockOutput] = (*mockFetchableItem)(nil)
func TestConcurrentFetcher(t *testing.T) {
t.Run("Comprehensively tests the concurrent fetcher", func(t *testing.T) {
ctx := context.Background()
cfg := &config.Config{}

inputChan := make(chan FetchableItems[*mockOutput], 10)
for i := 0; i < 10; i++ {
item := mockFetchableItem{item: i, wait: 1 * time.Second}
Expand All @@ -44,7 +44,7 @@ func TestConcurrentFetcher(t *testing.T) {
close(inputChan)

// Create a fetcher
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, cfg, inputChan)
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, 3, inputChan)
if err != nil {
t.Fatalf("Error creating fetcher: %v", err)
}
Expand Down Expand Up @@ -95,7 +95,7 @@ func TestConcurrentFetcher(t *testing.T) {
close(inputChan)

// Create a new fetcher
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, &config.Config{}, inputChan)
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan)
if err != nil {
t.Fatalf("Error creating fetcher: %v", err)
}
Expand Down
49 changes: 8 additions & 41 deletions internal/rows/arrowbased/arrowRows.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,9 @@ type valueContainerMaker interface {
}

type sparkArrowBatch struct {
rowCount, startRow, endRow int64
arrowRecordBytes []byte
hasSchema bool
}

func (sab *sparkArrowBatch) contains(rowIndex int64) bool {
return sab != nil && sab.startRow <= rowIndex && sab.endRow >= rowIndex
rowscanner.Delimiter
arrowRecordBytes []byte
hasSchema bool
}

type timeStampFn func(arrow.Timestamp) time.Time
Expand All @@ -46,6 +42,7 @@ type colInfo struct {

// arrowRowScanner handles extracting values from arrow records
type arrowRowScanner struct {
rowscanner.Delimiter
recordReader
valueContainerMaker

Expand All @@ -61,9 +58,6 @@ type arrowRowScanner struct {
// database types for the columns
colInfo []colInfo

// number of rows in the current TRowSet
nRows int64

// a TRowSet contains multiple arrow batches
currentBatch *sparkArrowBatch

Expand Down Expand Up @@ -140,12 +134,12 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp
}

rs := &arrowRowScanner{
Delimiter: rowscanner.NewDelimiter(rowSet.StartRowOffset, rowscanner.CountRows(rowSet)),
recordReader: sparkRecordReader{
ctx: ctx,
},
valueContainerMaker: &arrowValueContainerMaker{},
ArrowConfig: arrowConfig,
nRows: countRows(rowSet),
arrowSchemaBytes: schemaBytes,
arrowSchema: arrowSchema,
toTimestampFn: ttsf,
Expand All @@ -172,7 +166,7 @@ func (ars *arrowRowScanner) Close() {
// NRows returns the number of rows in the current set of batches
func (ars *arrowRowScanner) NRows() int64 {
if ars != nil {
return ars.nRows
return ars.Count()
}

return 0
Expand Down Expand Up @@ -203,7 +197,7 @@ func (ars *arrowRowScanner) ScanRow(
return err
}

var rowInBatchIndex int = int(rowIndex - ars.currentBatch.startRow)
var rowInBatchIndex int = int(rowIndex - ars.currentBatch.Start())

// if no location is provided default to UTC
if ars.location == nil {
Expand Down Expand Up @@ -248,41 +242,14 @@ func isIntervalType(typeId cli_service.TTypeId) bool {
return ok
}

// countRows reports the total number of rows described by a TRowSet.
//
// The count is the sum of the RowCount values of the row set's arrow
// batches when ArrowBatches is non-nil; otherwise it is the sum of the
// RowCount values of its result links when ResultLinks is non-nil.
// A nil row set, or one with neither field set, yields 0.
func countRows(rowSet *cli_service.TRowSet) int64 {
	var total int64

	switch {
	case rowSet == nil:
		// nothing to count
	case rowSet.ArrowBatches != nil:
		// ArrowBatches takes precedence over ResultLinks whenever it is set,
		// even if it holds no entries.
		for _, batch := range rowSet.ArrowBatches {
			total += batch.RowCount
		}
	case rowSet.ResultLinks != nil:
		for _, link := range rowSet.ResultLinks {
			total += link.RowCount
		}
	}

	return total
}

// loadBatchFor loads the batch containing the specified row if necessary
func (ars *arrowRowScanner) loadBatchFor(rowIndex int64) dbsqlerr.DBError {

if ars == nil || ars.BatchLoader == nil {
return dbsqlerrint.NewDriverError(context.Background(), errArrowRowsNoArrowBatches, nil)
}
// if the batch already loaded we can just return
if ars.currentBatch != nil && ars.currentBatch.contains(rowIndex) && ars.columnValues != nil {
if ars.currentBatch != nil && ars.currentBatch.Contains(rowIndex) && ars.columnValues != nil {
return nil
}

Expand Down
14 changes: 7 additions & 7 deletions internal/rows/arrowbased/arrowRows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,23 +663,23 @@ func TestArrowRowScanner(t *testing.T) {
assert.Nil(t, err)
assert.NotNil(t, lastReadBatch)
assert.Equal(t, 1, callCount)
assert.Equal(t, int64(0), lastReadBatch.startRow)
assert.Equal(t, int64(0), lastReadBatch.Start())
}

for _, i := range []int64{5, 6, 7} {
err := ars.loadBatchFor(i)
assert.Nil(t, err)
assert.NotNil(t, lastReadBatch)
assert.Equal(t, 2, callCount)
assert.Equal(t, int64(5), lastReadBatch.startRow)
assert.Equal(t, int64(5), lastReadBatch.Start())
}

for _, i := range []int64{8, 9, 10, 11, 12, 13, 14} {
err := ars.loadBatchFor(i)
assert.Nil(t, err)
assert.NotNil(t, lastReadBatch)
assert.Equal(t, 3, callCount)
assert.Equal(t, int64(8), lastReadBatch.startRow)
assert.Equal(t, int64(8), lastReadBatch.Start())
}

err := ars.loadBatchFor(-1)
Expand Down Expand Up @@ -983,13 +983,13 @@ func TestArrowRowScanner(t *testing.T) {

if i%1000 == 0 {
assert.NotNil(t, ars.currentBatch)
assert.Equal(t, int64(i), ars.currentBatch.startRow)
assert.Equal(t, int64(i), ars.currentBatch.Start())
if i < 53000 {
assert.Equal(t, int64(1000), ars.currentBatch.rowCount)
assert.Equal(t, int64(1000), ars.currentBatch.Count())
} else {
assert.Equal(t, int64(940), ars.currentBatch.rowCount)
assert.Equal(t, int64(940), ars.currentBatch.Count())
}
assert.Equal(t, ars.currentBatch.startRow+ars.currentBatch.rowCount-1, ars.currentBatch.endRow)
assert.Equal(t, ars.currentBatch.Start()+ars.currentBatch.Count()-1, ars.currentBatch.End())
}
}

Expand Down
Loading

0 comments on commit f7c0286

Please sign in to comment.