Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PECO-1054 Expose Arrow batches to users #160

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions internal/fetcher/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ package fetcher

import (
"context"
"github.com/databricks/databricks-sql-go/internal/config"
"sync"

"github.com/databricks/databricks-sql-go/driverctx"
dbsqllog "github.com/databricks/databricks-sql-go/logger"
)

type FetchableItems[OutputType any] interface {
Fetch(ctx context.Context, cfg *config.Config) ([]OutputType, error)
Fetch(ctx context.Context) ([]OutputType, error)
}

type Fetcher[OutputType any] interface {
Expand All @@ -24,7 +23,6 @@ type concurrentFetcher[I FetchableItems[O], O any] struct {
outChan chan O
err error
nWorkers int
cfg *config.Config
mu sync.Mutex
start sync.Once
ctx context.Context
Expand Down Expand Up @@ -100,10 +98,17 @@ func (f *concurrentFetcher[I, O]) logger() *dbsqllog.DBSQLLogger {
return f.DBSQLLogger
}

func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers int, cfg *config.Config, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) {
func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers, maxItemsInMemory int, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) {
if nWorkers < 1 {
nWorkers = 1
}
if maxItemsInMemory < 1 {
maxItemsInMemory = 1
}

// channel for loaded items
// TODO: pass buffer size
outputChannel := make(chan O, 100)
outputChannel := make(chan O, maxItemsInMemory)

// channel to signal a cancel
stopChannel := make(chan bool)
Expand All @@ -118,7 +123,6 @@ func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWork
cancelChan: stopChannel,
ctx: ctx,
nWorkers: nWorkers,
cfg: cfg,
}

return fetcher, nil
Expand All @@ -139,7 +143,7 @@ func work[I FetchableItems[O], O any](f *concurrentFetcher[I, O], workerIndex in
case input, ok := <-f.inputChan:
if ok {
f.logger().Debug().Msgf("concurrent fetcher worker %d loading item", workerIndex)
result, err := input.Fetch(f.ctx, f.cfg)
result, err := input.Fetch(f.ctx)
if err != nil {
f.logger().Debug().Msgf("concurrent fetcher worker %d received error", workerIndex)
f.setErr(err)
Expand Down
12 changes: 6 additions & 6 deletions internal/fetcher/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package fetcher

import (
"context"
"github.com/databricks/databricks-sql-go/internal/config"
"github.com/pkg/errors"
"math"
"testing"
"time"

"github.com/pkg/errors"
)

// Create a mock struct for FetchableItems
Expand All @@ -20,7 +20,7 @@ type mockOutput struct {
}

// Implement the Fetch method
func (m *mockFetchableItem) Fetch(ctx context.Context, cfg *config.Config) ([]*mockOutput, error) {
func (m *mockFetchableItem) Fetch(ctx context.Context) ([]*mockOutput, error) {
time.Sleep(m.wait)
outputs := make([]*mockOutput, 5)
for i := range outputs {
Expand All @@ -35,7 +35,7 @@ var _ FetchableItems[*mockOutput] = (*mockFetchableItem)(nil)
func TestConcurrentFetcher(t *testing.T) {
t.Run("Comprehensively tests the concurrent fetcher", func(t *testing.T) {
ctx := context.Background()
cfg := &config.Config{}

inputChan := make(chan FetchableItems[*mockOutput], 10)
for i := 0; i < 10; i++ {
item := mockFetchableItem{item: i, wait: 1 * time.Second}
Expand All @@ -44,7 +44,7 @@ func TestConcurrentFetcher(t *testing.T) {
close(inputChan)

// Create a fetcher
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, cfg, inputChan)
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, 3, inputChan)
if err != nil {
t.Fatalf("Error creating fetcher: %v", err)
}
Expand Down Expand Up @@ -95,7 +95,7 @@ func TestConcurrentFetcher(t *testing.T) {
close(inputChan)

// Create a new fetcher
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, &config.Config{}, inputChan)
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan)
if err != nil {
t.Fatalf("Error creating fetcher: %v", err)
}
Expand Down
49 changes: 8 additions & 41 deletions internal/rows/arrowbased/arrowRows.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,9 @@ type valueContainerMaker interface {
}

type sparkArrowBatch struct {
rowCount, startRow, endRow int64
arrowRecordBytes []byte
hasSchema bool
}

func (sab *sparkArrowBatch) contains(rowIndex int64) bool {
return sab != nil && sab.startRow <= rowIndex && sab.endRow >= rowIndex
rowscanner.Delimiter
arrowRecordBytes []byte
hasSchema bool
}

type timeStampFn func(arrow.Timestamp) time.Time
Expand All @@ -46,6 +42,7 @@ type colInfo struct {

// arrowRowScanner handles extracting values from arrow records
type arrowRowScanner struct {
rowscanner.Delimiter
recordReader
valueContainerMaker

Expand All @@ -61,9 +58,6 @@ type arrowRowScanner struct {
// database types for the columns
colInfo []colInfo

// number of rows in the current TRowSet
nRows int64

// a TRowSet contains multiple arrow batches
currentBatch *sparkArrowBatch

Expand Down Expand Up @@ -140,12 +134,12 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp
}

rs := &arrowRowScanner{
Delimiter: rowscanner.NewDelimiter(rowSet.StartRowOffset, rowscanner.CountRows(rowSet)),
recordReader: sparkRecordReader{
ctx: ctx,
},
valueContainerMaker: &arrowValueContainerMaker{},
ArrowConfig: arrowConfig,
nRows: countRows(rowSet),
arrowSchemaBytes: schemaBytes,
arrowSchema: arrowSchema,
toTimestampFn: ttsf,
Expand All @@ -172,7 +166,7 @@ func (ars *arrowRowScanner) Close() {
// NRows returns the number of rows in the current set of batches
func (ars *arrowRowScanner) NRows() int64 {
if ars != nil {
return ars.nRows
return ars.Count()
}

return 0
Expand Down Expand Up @@ -203,7 +197,7 @@ func (ars *arrowRowScanner) ScanRow(
return err
}

var rowInBatchIndex int = int(rowIndex - ars.currentBatch.startRow)
var rowInBatchIndex int = int(rowIndex - ars.currentBatch.Start())

// if no location is provided default to UTC
if ars.location == nil {
Expand Down Expand Up @@ -248,41 +242,14 @@ func isIntervalType(typeId cli_service.TTypeId) bool {
return ok
}

// countRows reports the total number of rows carried by rowSet.
//
// A nil rowSet yields 0. When ArrowBatches is non-nil the result is the sum of
// the RowCount of every arrow batch; ResultLinks is only consulted when
// ArrowBatches is nil, in which case the RowCount of every link is summed.
// If neither is set the result is 0.
func countRows(rowSet *cli_service.TRowSet) int64 {
	var total int64

	switch {
	case rowSet == nil:
		// nothing to count
	case rowSet.ArrowBatches != nil:
		// arrow batches take precedence over result links, even when empty
		for _, batch := range rowSet.ArrowBatches {
			total += batch.RowCount
		}
	case rowSet.ResultLinks != nil:
		for _, link := range rowSet.ResultLinks {
			total += link.RowCount
		}
	}

	return total
}

// loadBatchFor loads the batch containing the specified row if necessary
func (ars *arrowRowScanner) loadBatchFor(rowIndex int64) dbsqlerr.DBError {

if ars == nil || ars.BatchLoader == nil {
return dbsqlerrint.NewDriverError(context.Background(), errArrowRowsNoArrowBatches, nil)
}
// if the batch is already loaded we can just return
if ars.currentBatch != nil && ars.currentBatch.contains(rowIndex) && ars.columnValues != nil {
if ars.currentBatch != nil && ars.currentBatch.Contains(rowIndex) && ars.columnValues != nil {
return nil
}

Expand Down
14 changes: 7 additions & 7 deletions internal/rows/arrowbased/arrowRows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,23 +663,23 @@ func TestArrowRowScanner(t *testing.T) {
assert.Nil(t, err)
assert.NotNil(t, lastReadBatch)
assert.Equal(t, 1, callCount)
assert.Equal(t, int64(0), lastReadBatch.startRow)
assert.Equal(t, int64(0), lastReadBatch.Start())
}

for _, i := range []int64{5, 6, 7} {
err := ars.loadBatchFor(i)
assert.Nil(t, err)
assert.NotNil(t, lastReadBatch)
assert.Equal(t, 2, callCount)
assert.Equal(t, int64(5), lastReadBatch.startRow)
assert.Equal(t, int64(5), lastReadBatch.Start())
}

for _, i := range []int64{8, 9, 10, 11, 12, 13, 14} {
err := ars.loadBatchFor(i)
assert.Nil(t, err)
assert.NotNil(t, lastReadBatch)
assert.Equal(t, 3, callCount)
assert.Equal(t, int64(8), lastReadBatch.startRow)
assert.Equal(t, int64(8), lastReadBatch.Start())
}

err := ars.loadBatchFor(-1)
Expand Down Expand Up @@ -983,13 +983,13 @@ func TestArrowRowScanner(t *testing.T) {

if i%1000 == 0 {
assert.NotNil(t, ars.currentBatch)
assert.Equal(t, int64(i), ars.currentBatch.startRow)
assert.Equal(t, int64(i), ars.currentBatch.Start())
if i < 53000 {
assert.Equal(t, int64(1000), ars.currentBatch.rowCount)
assert.Equal(t, int64(1000), ars.currentBatch.Count())
} else {
assert.Equal(t, int64(940), ars.currentBatch.rowCount)
assert.Equal(t, int64(940), ars.currentBatch.Count())
}
assert.Equal(t, ars.currentBatch.startRow+ars.currentBatch.rowCount-1, ars.currentBatch.endRow)
assert.Equal(t, ars.currentBatch.Start()+ars.currentBatch.Count()-1, ars.currentBatch.End())
}
}

Expand Down
Loading
Loading