diff --git a/internal/fetcher/fetcher.go b/internal/fetcher/fetcher.go
index 478f260..cce84ad 100644
--- a/internal/fetcher/fetcher.go
+++ b/internal/fetcher/fetcher.go
@@ -2,7 +2,6 @@ package fetcher
 
 import (
 	"context"
-	"github.com/databricks/databricks-sql-go/internal/config"
 	"sync"
 
 	"github.com/databricks/databricks-sql-go/driverctx"
@@ -10,7 +9,7 @@ import (
 )
 
 type FetchableItems[OutputType any] interface {
-	Fetch(ctx context.Context, cfg *config.Config) ([]OutputType, error)
+	Fetch(ctx context.Context) ([]OutputType, error)
 }
 
 type Fetcher[OutputType any] interface {
@@ -24,7 +23,6 @@ type concurrentFetcher[I FetchableItems[O], O any] struct {
 	outChan  chan O
 	err      error
 	nWorkers int
-	cfg      *config.Config
 	mu       sync.Mutex
 	start    sync.Once
 	ctx      context.Context
@@ -100,10 +98,16 @@ func (f *concurrentFetcher[I, O]) logger() *dbsqllog.DBSQLLogger {
 	return f.DBSQLLogger
 }
 
-func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers int, cfg *config.Config, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) {
+func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers, maxItemsInMemory int, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) {
+	if nWorkers < 1 {
+		nWorkers = 1
+	}
+	if maxItemsInMemory < 1 {
+		maxItemsInMemory = 1
+	}
+
 	// channel for loaded items
-	// TODO: pass buffer size
-	outputChannel := make(chan O, 100)
+	outputChannel := make(chan O, maxItemsInMemory)
 
 	// channel to signal a cancel
 	stopChannel := make(chan bool)
@@ -118,7 +122,6 @@ func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWork
 		cancelChan: stopChannel,
 		ctx:        ctx,
 		nWorkers:   nWorkers,
-		cfg:        cfg,
 	}
 
 	return fetcher, nil
@@ -139,7 +142,7 @@ func work[I FetchableItems[O], O any](f *concurrentFetcher[I, O], workerIndex in
 		case input, ok := <-f.inputChan:
 			if ok {
 				f.logger().Debug().Msgf("concurrent fetcher worker %d loading item", workerIndex)
-				result, err := input.Fetch(f.ctx, f.cfg)
+				result, err := input.Fetch(f.ctx)
 				if err != nil {
 					f.logger().Debug().Msgf("concurrent fetcher worker %d received error", workerIndex)
 					f.setErr(err)
diff --git a/internal/fetcher/fetcher_test.go b/internal/fetcher/fetcher_test.go
index 78b5e27..76be59a 100644
--- a/internal/fetcher/fetcher_test.go
+++ b/internal/fetcher/fetcher_test.go
@@ -2,11 +2,11 @@ package fetcher
 
 import (
 	"context"
-	"github.com/databricks/databricks-sql-go/internal/config"
-	"github.com/pkg/errors"
 	"math"
 	"testing"
 	"time"
+
+	"github.com/pkg/errors"
 )
 
 // Create a mock struct for FetchableItems
@@ -20,7 +20,7 @@ type mockOutput struct {
 }
 
 // Implement the Fetch method
-func (m *mockFetchableItem) Fetch(ctx context.Context, cfg *config.Config) ([]*mockOutput, error) {
+func (m *mockFetchableItem) Fetch(ctx context.Context) ([]*mockOutput, error) {
 	time.Sleep(m.wait)
 	outputs := make([]*mockOutput, 5)
 	for i := range outputs {
@@ -35,7 +35,7 @@ var _ FetchableItems[*mockOutput] = (*mockFetchableItem)(nil)
 func TestConcurrentFetcher(t *testing.T) {
 	t.Run("Comprehensively tests the concurrent fetcher", func(t *testing.T) {
 		ctx := context.Background()
-		cfg := &config.Config{}
+
 		inputChan := make(chan FetchableItems[*mockOutput], 10)
 		for i := 0; i < 10; i++ {
 			item := mockFetchableItem{item: i, wait: 1 * time.Second}
@@ -44,7 +44,7 @@ func TestConcurrentFetcher(t *testing.T) {
 		close(inputChan)
 
 		// Create a fetcher
-		fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, cfg, inputChan)
+		fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, 3, inputChan)
 		if err != nil {
t.Fatalf("Error creating fetcher: %v", err) } @@ -95,7 +95,7 @@ func TestConcurrentFetcher(t *testing.T) { close(inputChan) // Create a new fetcher - fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, &config.Config{}, inputChan) + fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan) if err != nil { t.Fatalf("Error creating fetcher: %v", err) } diff --git a/internal/rows/arrowbased/arrowRows.go b/internal/rows/arrowbased/arrowRows.go index ec82f36..6504db5 100644 --- a/internal/rows/arrowbased/arrowRows.go +++ b/internal/rows/arrowbased/arrowRows.go @@ -27,13 +27,9 @@ type valueContainerMaker interface { } type sparkArrowBatch struct { - rowCount, startRow, endRow int64 - arrowRecordBytes []byte - hasSchema bool -} - -func (sab *sparkArrowBatch) contains(rowIndex int64) bool { - return sab != nil && sab.startRow <= rowIndex && sab.endRow >= rowIndex + rowscanner.Delimiter + arrowRecordBytes []byte + hasSchema bool } type timeStampFn func(arrow.Timestamp) time.Time @@ -46,6 +42,7 @@ type colInfo struct { // arrowRowScanner handles extracting values from arrow records type arrowRowScanner struct { + rowscanner.Delimiter recordReader valueContainerMaker @@ -61,9 +58,6 @@ type arrowRowScanner struct { // database types for the columns colInfo []colInfo - // number of rows in the current TRowSet - nRows int64 - // a TRowSet contains multiple arrow batches currentBatch *sparkArrowBatch @@ -140,12 +134,12 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp } rs := &arrowRowScanner{ + Delimiter: rowscanner.NewDelimiter(rowSet.StartRowOffset, rowscanner.CountRows(rowSet)), recordReader: sparkRecordReader{ ctx: ctx, }, valueContainerMaker: &arrowValueContainerMaker{}, ArrowConfig: arrowConfig, - nRows: countRows(rowSet), arrowSchemaBytes: schemaBytes, arrowSchema: arrowSchema, toTimestampFn: ttsf, @@ -172,7 +166,7 @@ func (ars *arrowRowScanner) Close() { // NRows returns the number of rows in the current set of batches func (ars *arrowRowScanner) NRows() int64 { if ars != nil { - return ars.nRows + return ars.Count() } return 0 @@ -203,7 +197,7 @@ func (ars *arrowRowScanner) ScanRow( return err } - var rowInBatchIndex int = int(rowIndex - ars.currentBatch.startRow) + var rowInBatchIndex int = int(rowIndex - ars.currentBatch.Start()) // if no location is provided default to UTC if ars.location == nil { @@ -248,33 +242,6 @@ func isIntervalType(typeId cli_service.TTypeId) bool { return ok } -// countRows returns the number of rows in the TRowSet -func countRows(rowSet *cli_service.TRowSet) int64 { - if rowSet == nil { - return 0 - } - - if rowSet.ArrowBatches != nil { - batches := rowSet.ArrowBatches - var n int64 - for i := range batches { - n += batches[i].RowCount - } - return n - } - - if rowSet.ResultLinks != nil { - links := rowSet.ResultLinks - var n int64 - for i := range links { - n += links[i].RowCount - } - return n - } - - return 0 -} - // loadBatchFor loads the batch containing the specified row if necessary func (ars *arrowRowScanner) loadBatchFor(rowIndex int64) dbsqlerr.DBError { @@ -282,7 +249,7 @@ func (ars *arrowRowScanner) loadBatchFor(rowIndex int64) dbsqlerr.DBError { return dbsqlerrint.NewDriverError(context.Background(), errArrowRowsNoArrowBatches, nil) } // if the batch already loaded we can just return - if ars.currentBatch != nil && ars.currentBatch.contains(rowIndex) && ars.columnValues != nil { + if ars.currentBatch != nil && ars.currentBatch.Contains(rowIndex) && ars.columnValues != nil { return nil } 
diff --git a/internal/rows/arrowbased/arrowRows_test.go b/internal/rows/arrowbased/arrowRows_test.go index 34f2654..ac11b8e 100644 --- a/internal/rows/arrowbased/arrowRows_test.go +++ b/internal/rows/arrowbased/arrowRows_test.go @@ -663,7 +663,7 @@ func TestArrowRowScanner(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, lastReadBatch) assert.Equal(t, 1, callCount) - assert.Equal(t, int64(0), lastReadBatch.startRow) + assert.Equal(t, int64(0), lastReadBatch.Start()) } for _, i := range []int64{5, 6, 7} { @@ -671,7 +671,7 @@ func TestArrowRowScanner(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, lastReadBatch) assert.Equal(t, 2, callCount) - assert.Equal(t, int64(5), lastReadBatch.startRow) + assert.Equal(t, int64(5), lastReadBatch.Start()) } for _, i := range []int64{8, 9, 10, 11, 12, 13, 14} { @@ -679,7 +679,7 @@ func TestArrowRowScanner(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, lastReadBatch) assert.Equal(t, 3, callCount) - assert.Equal(t, int64(8), lastReadBatch.startRow) + assert.Equal(t, int64(8), lastReadBatch.Start()) } err := ars.loadBatchFor(-1) @@ -983,13 +983,13 @@ func TestArrowRowScanner(t *testing.T) { if i%1000 == 0 { assert.NotNil(t, ars.currentBatch) - assert.Equal(t, int64(i), ars.currentBatch.startRow) + assert.Equal(t, int64(i), ars.currentBatch.Start()) if i < 53000 { - assert.Equal(t, int64(1000), ars.currentBatch.rowCount) + assert.Equal(t, int64(1000), ars.currentBatch.Count()) } else { - assert.Equal(t, int64(940), ars.currentBatch.rowCount) + assert.Equal(t, int64(940), ars.currentBatch.Count()) } - assert.Equal(t, ars.currentBatch.startRow+ars.currentBatch.rowCount-1, ars.currentBatch.endRow) + assert.Equal(t, ars.currentBatch.Start()+ars.currentBatch.Count()-1, ars.currentBatch.End()) } } diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 0981a46..b534ce9 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -4,11 +4,13 @@ import ( "bufio" "bytes" "context" + "io" + "time" + "github.com/databricks/databricks-sql-go/internal/config" + "github.com/databricks/databricks-sql-go/internal/rows/rowscanner" "github.com/pierrec/lz4/v4" "github.com/pkg/errors" - "io" - "time" "net/http" @@ -20,12 +22,120 @@ import ( "github.com/databricks/databricks-sql-go/internal/fetcher" ) +type BatchLoader interface { + GetBatchFor(recordNum int64) (*sparkArrowBatch, dbsqlerr.DBError) +} + +func NewCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowResultLink, cfg *config.Config) (*batchLoader[*cloudURL], dbsqlerr.DBError) { + + if cfg == nil { + cfg = config.WithDefaults() + } + + inputChan := make(chan fetcher.FetchableItems[*sparkArrowBatch], len(files)) + + for i := range files { + li := &cloudURL{ + TSparkArrowResultLink: files[i], + minTimeToExpiry: cfg.MinTimeToExpiry, + useLz4Compression: cfg.UseLz4Compression, + } + inputChan <- li + } + + // make sure to close input channel or fetcher will block waiting for more inputs + close(inputChan) + + f, _ := fetcher.NewConcurrentFetcher[*cloudURL](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan) + cbl := &batchLoader[*cloudURL]{ + Fetcher: f, + ctx: ctx, + } + + return cbl, nil +} + +func NewLocalBatchLoader(ctx context.Context, batches []*cli_service.TSparkArrowBatch, cfg *config.Config) (*batchLoader[*localBatch], dbsqlerr.DBError) { + + if cfg == nil { + cfg = config.WithDefaults() + } + + var startRow int64 + inputChan := make(chan fetcher.FetchableItems[*sparkArrowBatch], 
len(batches)) + for i := range batches { + b := batches[i] + if b != nil { + li := &localBatch{ + TSparkArrowBatch: b, + startRow: startRow, + useLz4Compression: cfg.UseLz4Compression, + } + inputChan <- li + startRow = startRow + b.RowCount + } + } + close(inputChan) + + f, _ := fetcher.NewConcurrentFetcher[*localBatch](ctx, cfg.MaxDownloadThreads, cfg.MaxFilesInMemory, inputChan) + cbl := &batchLoader[*localBatch]{ + Fetcher: f, + ctx: ctx, + } + + return cbl, nil +} + +type batchLoader[T interface { + Fetch(ctx context.Context) ([]*sparkArrowBatch, error) +}] struct { + fetcher.Fetcher[*sparkArrowBatch] + arrowBatches []*sparkArrowBatch + ctx context.Context +} + +var _ BatchLoader = (*batchLoader[*localBatch])(nil) + +func (cbl *batchLoader[T]) GetBatchFor(recordNum int64) (*sparkArrowBatch, dbsqlerr.DBError) { + + for i := range cbl.arrowBatches { + if cbl.arrowBatches[i].Start() <= recordNum && cbl.arrowBatches[i].End() >= recordNum { + return cbl.arrowBatches[i], nil + } + } + + batchChan, _, err := cbl.Start() + if err != nil { + return nil, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowIndex(recordNum), err) + } + + for { + batch, ok := <-batchChan + if !ok { + err := cbl.Err() + if err != nil { + return nil, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowIndex(recordNum), err) + } + break + } + + cbl.arrowBatches = append(cbl.arrowBatches, batch) + if batch.Contains(recordNum) { + return batch, nil + } + } + + return nil, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowIndex(recordNum), err) +} + type cloudURL struct { *cli_service.TSparkArrowResultLink + minTimeToExpiry time.Duration + useLz4Compression bool } -func (cu *cloudURL) Fetch(ctx context.Context, cfg *config.Config) ([]*sparkArrowBatch, error) { - if isLinkExpired(cu.ExpiryTime, cfg.MinTimeToExpiry) { +func (cu *cloudURL) Fetch(ctx context.Context) ([]*sparkArrowBatch, error) { + if isLinkExpired(cu.ExpiryTime, cu.minTimeToExpiry) { return nil, errors.New(dbsqlerr.ErrLinkExpired) } @@ -48,7 +158,7 @@ func (cu *cloudURL) Fetch(ctx context.Context, cfg *config.Config) ([]*sparkArro var arrowSchema *arrow.Schema var arrowBatches []*sparkArrowBatch - rdr, err := getArrowReader(res.Body, cfg.UseLz4Compression) + rdr, err := getArrowReader(res.Body, cu.useLz4Compression) if err != nil { return nil, err @@ -78,11 +188,9 @@ func (cu *cloudURL) Fetch(ctx context.Context, cfg *config.Config) ([]*sparkArro recordBytes := output.Bytes() arrowBatches = append(arrowBatches, &sparkArrowBatch{ + Delimiter: rowscanner.NewDelimiter(startRow, r.NumRows()), arrowRecordBytes: recordBytes, hasSchema: true, - rowCount: r.NumRows(), - startRow: startRow, - endRow: startRow + r.NumRows() - 1, }) startRow = startRow + r.NumRows() @@ -130,108 +238,21 @@ var _ fetcher.FetchableItems[*sparkArrowBatch] = (*cloudURL)(nil) type localBatch struct { *cli_service.TSparkArrowBatch - startRow int64 + startRow int64 + useLz4Compression bool } var _ fetcher.FetchableItems[*sparkArrowBatch] = (*localBatch)(nil) -func (lb *localBatch) Fetch(ctx context.Context, cfg *config.Config) ([]*sparkArrowBatch, error) { - arrowBatchBytes, err := getArrowBatch(cfg.UseLz4Compression, lb.Batch) +func (lb *localBatch) Fetch(ctx context.Context) ([]*sparkArrowBatch, error) { + arrowBatchBytes, err := getArrowBatch(lb.useLz4Compression, lb.Batch) if err != nil { return nil, err } batch := &sparkArrowBatch{ - rowCount: lb.RowCount, - startRow: lb.startRow, - endRow: lb.startRow + lb.RowCount - 1, + Delimiter: 
rowscanner.NewDelimiter(lb.startRow, lb.RowCount), arrowRecordBytes: arrowBatchBytes, } return []*sparkArrowBatch{batch}, nil } - -type BatchLoader interface { - GetBatchFor(recordNum int64) (*sparkArrowBatch, dbsqlerr.DBError) -} - -type batchLoader[T interface { - Fetch(ctx context.Context, cfg *config.Config) ([]*sparkArrowBatch, error) -}] struct { - fetcher.Fetcher[*sparkArrowBatch] - arrowBatches []*sparkArrowBatch - ctx context.Context -} - -func NewCloudBatchLoader(ctx context.Context, files []*cli_service.TSparkArrowResultLink, cfg *config.Config) (*batchLoader[*cloudURL], dbsqlerr.DBError) { - inputChan := make(chan fetcher.FetchableItems[*sparkArrowBatch], len(files)) - - for i := range files { - li := &cloudURL{TSparkArrowResultLink: files[i]} - inputChan <- li - } - - // make sure to close input channel or fetcher will block waiting for more inputs - close(inputChan) - - f, _ := fetcher.NewConcurrentFetcher[*cloudURL](ctx, 3, cfg, inputChan) - cbl := &batchLoader[*cloudURL]{ - Fetcher: f, - ctx: ctx, - } - - return cbl, nil -} - -func NewLocalBatchLoader(ctx context.Context, batches []*cli_service.TSparkArrowBatch, cfg *config.Config) (*batchLoader[*localBatch], dbsqlerr.DBError) { - var startRow int64 - inputChan := make(chan fetcher.FetchableItems[*sparkArrowBatch], len(batches)) - for i := range batches { - b := batches[i] - if b != nil { - li := &localBatch{TSparkArrowBatch: b, startRow: startRow} - inputChan <- li - startRow = startRow + b.RowCount - } - } - close(inputChan) - - f, _ := fetcher.NewConcurrentFetcher[*localBatch](ctx, 3, cfg, inputChan) - cbl := &batchLoader[*localBatch]{ - Fetcher: f, - ctx: ctx, - } - - return cbl, nil -} - -func (cbl *batchLoader[T]) GetBatchFor(recordNum int64) (*sparkArrowBatch, dbsqlerr.DBError) { - - for i := range cbl.arrowBatches { - if cbl.arrowBatches[i].startRow <= recordNum && cbl.arrowBatches[i].endRow >= recordNum { - return cbl.arrowBatches[i], nil - } - } - - batchChan, _, err := cbl.Start() - if err != nil { - return nil, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowIndex(recordNum), err) - } - - for { - batch, ok := <-batchChan - if !ok { - err := cbl.Err() - if err != nil { - return nil, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowIndex(recordNum), err) - } - break - } - - cbl.arrowBatches = append(cbl.arrowBatches, batch) - if batch.contains(recordNum) { - return batch, nil - } - } - - return nil, dbsqlerrint.NewDriverError(cbl.ctx, errArrowRowsInvalidRowIndex(recordNum), err) -} diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index c8b9f77..fe412fc 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -4,21 +4,22 @@ import ( "bytes" "context" "fmt" + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" + "github.com/apache/arrow/go/v12/arrow" "github.com/apache/arrow/go/v12/arrow/array" "github.com/apache/arrow/go/v12/arrow/ipc" "github.com/apache/arrow/go/v12/arrow/memory" dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" - "github.com/databricks/databricks-sql-go/internal/config" dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" + "github.com/databricks/databricks-sql-go/internal/rows/rowscanner" "github.com/pkg/errors" "github.com/stretchr/testify/assert" - "net/http" - "net/http/httptest" - "reflect" - "testing" - "time" ) func generateMockArrowBytes() []byte { @@ -76,11 
+77,9 @@ func TestBatchLoader(t *testing.T) { linkExpired: false, expectedResponse: []*sparkArrowBatch{ { + Delimiter: rowscanner.NewDelimiter(0, 3), arrowRecordBytes: generateMockArrowBytes(), hasSchema: true, - rowCount: 3, - startRow: 0, - endRow: 2, }, }, expectedErr: nil, @@ -129,9 +128,8 @@ func TestBatchLoader(t *testing.T) { } ctx := context.Background() - cfg := config.WithDefaults() - resp, err := cu.Fetch(ctx, cfg) + resp, err := cu.Fetch(ctx) if !reflect.DeepEqual(resp, tc.expectedResponse) { t.Errorf("expected (%v), got (%v)", tc.expectedResponse, resp) diff --git a/internal/rows/columnbased/columnRows.go b/internal/rows/columnbased/columnRows.go index 3bd3aaa..21ff7e4 100644 --- a/internal/rows/columnbased/columnRows.go +++ b/internal/rows/columnbased/columnRows.go @@ -22,11 +22,9 @@ type columnRowScanner struct { rowSet *cli_service.TRowSet schema *cli_service.TTableSchema - // number of rows in the current TRowSet - nRows int64 - location *time.Location ctx context.Context + rowscanner.Delimiter } var _ rowscanner.RowScanner = (*columnRowScanner)(nil) @@ -47,9 +45,9 @@ func NewColumnRowScanner(schema *cli_service.TTableSchema, rowSet *cli_service.T logger.Debug().Msg("databricks: creating column row scanner") rs := &columnRowScanner{ + Delimiter: rowscanner.NewDelimiter(rowSet.StartRowOffset, rowscanner.CountRows(rowSet)), schema: schema, rowSet: rowSet, - nRows: countRows(rowSet), DBSQLLogger: logger, location: location, ctx: ctx, @@ -66,7 +64,7 @@ func (crs *columnRowScanner) NRows() int64 { if crs == nil { return 0 } - return crs.nRows + return crs.Count() } // ScanRow is called to populate the provided slice with the @@ -137,39 +135,3 @@ func (crs *columnRowScanner) value(tColumn *cli_service.TColumn, tColumnDesc *cl return val, err } - -// countRows returns the number of rows in the TRowSet -func countRows(rowSet *cli_service.TRowSet) int64 { - if rowSet == nil || rowSet.Columns == nil { - return 0 - } - - // Find a column/values and return the number of values. 
- for _, col := range rowSet.Columns { - if col.BoolVal != nil { - return int64(len(col.BoolVal.Values)) - } - if col.ByteVal != nil { - return int64(len(col.ByteVal.Values)) - } - if col.I16Val != nil { - return int64(len(col.I16Val.Values)) - } - if col.I32Val != nil { - return int64(len(col.I32Val.Values)) - } - if col.I64Val != nil { - return int64(len(col.I64Val.Values)) - } - if col.StringVal != nil { - return int64(len(col.StringVal.Values)) - } - if col.DoubleVal != nil { - return int64(len(col.DoubleVal.Values)) - } - if col.BinaryVal != nil { - return int64(len(col.BinaryVal.Values)) - } - } - return 0 -} diff --git a/internal/rows/errors.go b/internal/rows/errors.go index aae3d5b..4e84d8f 100644 --- a/internal/rows/errors.go +++ b/internal/rows/errors.go @@ -2,18 +2,12 @@ package rows import "fmt" -var errRowsFetchPriorToStart = "databricks: unable to fetch row page prior to start of results" var errRowsNoClient = "databricks: instance of Rows missing client" var errRowsNilRows = "databricks: nil Rows instance" var errRowsUnknowRowType = "databricks: unknown rows representation" var errRowsCloseFailed = "databricks: Rows instance Close operation failed" var errRowsMetadataFetchFailed = "databricks: Rows instance failed to retrieve result set metadata" -var errRowsResultFetchFailed = "databricks: Rows instance failed to retrieve results" func errRowsInvalidColumnIndex(index int) string { return fmt.Sprintf("databricks: invalid column index: %d", index) } - -func errRowsUnandledFetchDirection(dir string) string { - return fmt.Sprintf("databricks: unhandled fetch direction %s", dir) -} diff --git a/internal/rows/rows.go b/internal/rows/rows.go index 8f248ad..797b7a0 100644 --- a/internal/rows/rows.go +++ b/internal/rows/rows.go @@ -4,7 +4,7 @@ import ( "context" "database/sql" "database/sql/driver" - "io" + "errors" "math" "reflect" "time" @@ -31,6 +31,7 @@ type rows struct { // The RowScanner is responsible for handling the different // formats in which the query results can be returned rowscanner.RowScanner + rowscanner.ResultPageIterator // Handle for the associated database operation. opHandle *cli_service.TOperationHandle @@ -55,15 +56,9 @@ type rows struct { connId string correlationId string - // Index in the current page of rows - nextRowIndex int64 - // Row number within the overall result set nextRowNumber int64 - // starting row number of the current results page - pageStartingRowNum int64 - // If the server returns an entire result set // in the direct results it may have already // closed the operation. 
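[Reviewer note] The struct hunk above removes the hand-maintained nextRowIndex and pageStartingRowNum fields; the Delimiter embedded in the RowScanner (and, below, in the ResultPageIterator) becomes the single source of truth for page position. A small sketch of the arithmetic the following hunks rely on; the values are illustrative only:

	package example // hypothetical; rowscanner is internal to the driver module

	import "github.com/databricks/databricks-sql-go/internal/rows/rowscanner"

	func pageArithmetic() (int64, bool) {
		// A result page holding absolute rows 10..14 (start=10, count=5).
		d := rowscanner.NewDelimiter(10, 5)

		nextRowNumber := int64(12)             // absolute row number in the result set
		rowInPage := nextRowNumber - d.Start() // 2: index of the row within this page
		inPage := d.Contains(nextRowNumber)    // true: no fetch needed for row 12

		return rowInPage, inPage
	}

This is exactly the calculation rows.nextRowIndex() and rows.isNextRowInPage() reduce to in the hunks below.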
@@ -243,17 +238,18 @@ func (r *rows) Next(dest []driver.Value) error { } // Put values into the destination slice - err = r.ScanRow(dest, r.nextRowIndex) + err = r.ScanRow(dest, r.nextRowIndex()) if err != nil { return err } - r.nextRowIndex++ r.nextRowNumber++ return nil } +func (r *rows) nextRowIndex() int64 { return r.nextRowNumber - r.RowScanner.Start() } + // ColumnTypeScanType returns column's native type func (r *rows) ColumnTypeScanType(index int) reflect.Type { err := isValidRows(r) @@ -415,13 +411,7 @@ func (r *rows) isNextRowInPage() (bool, dbsqlerr.DBError) { return false, nil } - nRowsInPage := r.NRows() - if nRowsInPage == 0 { - return false, nil - } - - startRowOffset := r.pageStartingRowNum - return r.nextRowNumber >= startRowOffset && r.nextRowNumber < (startRowOffset+nRowsInPage), nil + return r.RowScanner.Contains(r.nextRowNumber), nil } // getResultMetadata does a one time fetch of the result set schema @@ -458,79 +448,35 @@ func (r *rows) fetchResultPage() error { return err } - r.logger().Debug().Msgf("databricks: fetching result page for row %d", r.nextRowNumber) - - var b bool - var e dbsqlerr.DBError - for b, e = r.isNextRowInPage(); !b && e == nil; b, e = r.isNextRowInPage() { - - // determine the direction of page fetching. Currently we only handle - // TFetchOrientation_FETCH_PRIOR and TFetchOrientation_FETCH_NEXT - var direction cli_service.TFetchOrientation = r.getPageFetchDirection() - if direction == cli_service.TFetchOrientation_FETCH_PRIOR { - // can't fetch rows previous to the start - if r.pageStartingRowNum == 0 { - return dbsqlerr_int.NewDriverError(r.ctx, errRowsFetchPriorToStart, nil) - } - } else if direction == cli_service.TFetchOrientation_FETCH_NEXT { - // can't fetch past the end of the query results - if !r.hasMoreRows { - return io.EOF - } - } else { - r.logger().Error().Msgf(errRowsUnandledFetchDirection(direction.String())) - return dbsqlerr_int.NewDriverError(r.ctx, errRowsUnandledFetchDirection(direction.String()), nil) - } - - r.logger().Debug().Msgf("fetching next batch of up to %d rows, %s", r.maxPageSize, direction.String()) - - var includeResultSetMetadata = true - req := cli_service.TFetchResultsReq{ - OperationHandle: r.opHandle, - MaxRows: r.maxPageSize, - Orientation: direction, - IncludeResultSetMetadata: &includeResultSetMetadata, - } - ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId) - - fetchResult, err := r.client.FetchResults(ctx, &req) - if err != nil { - r.logger().Err(err).Msg("databricks: Rows instance failed to retrieve results") - return dbsqlerr_int.NewRequestError(r.ctx, errRowsResultFetchFailed, err) - } - - err1 := r.makeRowScanner(fetchResult) - if err1 != nil { - return err1 - } - - r.logger().Debug().Msgf("databricks: new result page startRow: %d, nRows: %v, hasMoreRows: %v", fetchResult.Results.StartRowOffset, r.NRows(), fetchResult.HasMoreRows) + if r.RowScanner != nil && r.RowScanner.Contains(r.nextRowNumber) { + return nil } - if e != nil { - return e + if r.RowScanner != nil && r.nextRowNumber < r.RowScanner.Start() { + //TODO + return errors.New("can't go backward") } - // don't assume the next row is the first row in the page - r.nextRowIndex = r.nextRowNumber - r.pageStartingRowNum + if r.ResultPageIterator == nil { + r.ResultPageIterator = makeResultPageIterator(r) + } - return nil -} + fetchResult, err1 := r.ResultPageIterator.Next() + if err1 != nil { + return err1 + } -// getPageFetchDirection returns the 
cli_service.TFetchOrientation -// necessary to fetch a result page containing the next row number. -// Note: if the next row number is in the current page TFetchOrientation_FETCH_NEXT -// is returned. Use rows.nextRowInPage to determine if a fetch is necessary -func (r *rows) getPageFetchDirection() cli_service.TFetchOrientation { - if r == nil { - return cli_service.TFetchOrientation_FETCH_NEXT + err1 = r.makeRowScanner(fetchResult) + if err1 != nil { + return err1 } - if r.nextRowNumber < r.pageStartingRowNum { - return cli_service.TFetchOrientation_FETCH_PRIOR + if !r.RowScanner.Contains(r.nextRowNumber) { + // TODO + return errors.New("Invalid row number state") } - return cli_service.TFetchOrientation_FETCH_NEXT + return nil } // makeRowScanner creates the embedded RowScanner instance based on the format @@ -549,6 +495,7 @@ func (r *rows) makeRowScanner(fetchResults *cli_service.TFetchResultsResp) dbsql var rs rowscanner.RowScanner var err dbsqlerr.DBError if fetchResults.Results != nil { + if fetchResults.Results.Columns != nil { rs, err = columnbased.NewColumnRowScanner(schema, fetchResults.Results, r.config, r.logger(), r.ctx) } else if fetchResults.Results.ArrowBatches != nil { @@ -560,7 +507,6 @@ func (r *rows) makeRowScanner(fetchResults *cli_service.TFetchResultsResp) dbsql err = dbsqlerr_int.NewDriverError(r.ctx, errRowsUnknowRowType, nil) } - r.pageStartingRowNum = fetchResults.Results.StartRowOffset } else { r.logger().Error().Msg(errRowsUnknowRowType) err = dbsqlerr_int.NewDriverError(r.ctx, errRowsUnknowRowType, nil) @@ -586,3 +532,26 @@ func (r *rows) logger() *dbsqllog.DBSQLLogger { } return r.logger_ } + +func makeResultPageIterator(r *rows) rowscanner.ResultPageIterator { + + var d rowscanner.Delimiter + if r.RowScanner != nil { + d = rowscanner.NewDelimiter(r.RowScanner.Start(), r.RowScanner.Count()) + } else { + d = rowscanner.NewDelimiter(0, 0) + } + + resultPageIterator := rowscanner.NewResultPageIterator( + d, + r.hasMoreRows, + r.maxPageSize, + r.opHandle, + r.client, + r.connId, + r.correlationId, + r.logger(), + ) + + return resultPageIterator +} diff --git a/internal/rows/rows_test.go b/internal/rows/rows_test.go index f44e81a..fc27506 100644 --- a/internal/rows/rows_test.go +++ b/internal/rows/rows_test.go @@ -102,110 +102,26 @@ func TestRowsNextRowInPage(t *testing.T) { assert.False(t, inPage, "next row after current page should return false") } -func TestRowsGetPageFetchDirection(t *testing.T) { - t.Parallel() - var rowSet *rows - - // nil rows instance - direction := rowSet.getPageFetchDirection() - assert.Equal(t, cli_service.TFetchOrientation_FETCH_NEXT, direction, "nil rows instance should return forward direction") - - // default rows instance - rowSet = &rows{schema: &cli_service.TTableSchema{}} - direction = rowSet.getPageFetchDirection() - assert.Equal(t, cli_service.TFetchOrientation_FETCH_NEXT, direction, "default rows instance should return forward direction") - - fetchResults := &cli_service.TFetchResultsResp{} - - // fetchResults has no TRowSet - err := rowSet.makeRowScanner(fetchResults) - assert.EqualError(t, err, "databricks: driver error: "+errRowsUnknowRowType) - direction = rowSet.getPageFetchDirection() - assert.Equal(t, cli_service.TFetchOrientation_FETCH_NEXT, direction, "fetchResults has no TRowSet should return forward direction") - - tRowSet := &cli_service.TRowSet{} - - // default TRowSet - fetchResults.Results = tRowSet - err = rowSet.makeRowScanner(fetchResults) - assert.EqualError(t, err, "databricks: driver error: 
"+errRowsUnknowRowType) - direction = rowSet.getPageFetchDirection() - assert.Equal(t, cli_service.TFetchOrientation_FETCH_NEXT, direction, "fetchResults has no TRowSet should return forward direction") - - // Set up a result page starting with row 10 and containing 5 rows - tRowSet.StartRowOffset = 10 - tColumn := &cli_service.TColumn{BoolVal: &cli_service.TBoolColumn{Values: []bool{true, false, true, false, true}}} - tRowSet.Columns = []*cli_service.TColumn{tColumn} - - err = rowSet.makeRowScanner(fetchResults) - assert.Nil(t, err) - - // next row number is prior to result page - rowSet.nextRowNumber = 0 - direction = rowSet.getPageFetchDirection() - assert.Equal(t, cli_service.TFetchOrientation_FETCH_PRIOR, direction, "next row number is prior to result page should return reverse direction") - - // next row number is immediately prior to result page - rowSet.nextRowNumber = 9 - direction = rowSet.getPageFetchDirection() - assert.Equal(t, cli_service.TFetchOrientation_FETCH_PRIOR, direction, "next row number is immediately prior to result page should return reverse direction") - - // next row number is first row in page - rowSet.nextRowNumber = 10 - direction = rowSet.getPageFetchDirection() - assert.Equal(t, cli_service.TFetchOrientation_FETCH_NEXT, direction, "fetchResults has no TRowSet should return forward direction") - - // next row is last row in page - rowSet.nextRowNumber = 14 - direction = rowSet.getPageFetchDirection() - assert.Equal(t, cli_service.TFetchOrientation_FETCH_NEXT, direction, "fetchResults has no TRowSet should return forward direction") - - // next row is in page - rowSet.nextRowNumber = 12 - direction = rowSet.getPageFetchDirection() - assert.Equal(t, cli_service.TFetchOrientation_FETCH_NEXT, direction, "fetchResults has no TRowSet should return forward direction") - - // next row immediately follows current page - rowSet.nextRowNumber = 15 - direction = rowSet.getPageFetchDirection() - assert.Equal(t, cli_service.TFetchOrientation_FETCH_NEXT, direction, "fetchResults has no TRowSet should return forward direction") - - // next row after current page - rowSet.nextRowNumber = 100 - direction = rowSet.getPageFetchDirection() - assert.Equal(t, cli_service.TFetchOrientation_FETCH_NEXT, direction, "fetchResults has no TRowSet should return forward direction") -} - func TestRowsGetPageStartRowNum(t *testing.T) { t.Parallel() - var noRows int64 = 0 var sevenRows int64 = 7 rowSet := &rows{schema: &cli_service.TTableSchema{}} - start := rowSet.pageStartingRowNum - assert.Equal(t, noRows, start, "rows with no page should return 0") - err := rowSet.makeRowScanner(&cli_service.TFetchResultsResp{}) assert.EqualError(t, err, "databricks: driver error: "+errRowsUnknowRowType) - start = rowSet.pageStartingRowNum - assert.Equal(t, noRows, start, "rows with no TRowSet should return 0") - err = rowSet.makeRowScanner(&cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{}}) assert.EqualError(t, err, "databricks: driver error: "+errRowsUnknowRowType) - start = rowSet.pageStartingRowNum - assert.Equal(t, noRows, start, "rows with default TRowSet should return 0") - err = rowSet.makeRowScanner(&cli_service.TFetchResultsResp{Results: &cli_service.TRowSet{StartRowOffset: 7, Columns: []*cli_service.TColumn{}}}) assert.Nil(t, err) - start = rowSet.pageStartingRowNum + start := rowSet.RowScanner.Start() assert.Equal(t, sevenRows, start, "rows with TRowSet should return TRowSet.StartRowOffset") } -func TestRowsFetchResultPageErrors(t *testing.T) { +func TestRowsFetchMakeRowScanner(t 
*testing.T) { t.Parallel() var rowSet *rows @@ -217,8 +133,6 @@ func TestRowsFetchResultPageErrors(t *testing.T) { client: &cli_service.TCLIServiceClient{}, schema: &cli_service.TTableSchema{}, } - err = rowSet.fetchResultPage() - assert.EqualError(t, err, "databricks: driver error: "+errRowsFetchPriorToStart, "negative row number should return error") err = rowSet.makeRowScanner(&cli_service.TFetchResultsResp{}) assert.EqualError(t, err, "databricks: driver error: "+errRowsUnknowRowType) @@ -237,8 +151,6 @@ func TestRowsFetchResultPageErrors(t *testing.T) { err = rowSet.makeRowScanner(&cli_service.TFetchResultsResp{Results: tRowSet, HasMoreRows: &noMoreRows}) assert.Nil(t, err) - err = rowSet.fetchResultPage() - assert.EqualError(t, err, io.EOF.Error(), "row number past end of result set should return EOF") } func TestGetResultMetadataNoDirectResults(t *testing.T) { @@ -335,54 +247,46 @@ func TestRowsFetchResultPageNoDirectResults(t *testing.T) { offset: int64(5), }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) - // next row number is two, should fetch previous result page + // next row number is two, can't fetch previous result page rowSet.nextRowNumber = 2 err = rowSet.fetchResultPage() + assert.EqualError(t, err, "can't go backward") + + // next row number is past end of next result page + rowSet.nextRowNumber = 15 + err = rowSet.fetchResultPage() + assert.EqualError(t, err, "Invalid row number state") + + rowSet.nextRowNumber = 12 + err = rowSet.fetchResultPage() + rowTestPagingResult{ getMetadataCount: 1, fetchResultsCount: 3, nextRowIndex: int64(2), - nextRowNumber: int64(2), - offset: i64Zero, + nextRowNumber: int64(12), + offset: int64(10), }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) - // next row number is past end of results, should fetch all result pages - // going forward and then return EOF rowSet.nextRowNumber = 15 err = rowSet.fetchResultPage() errMsg := io.EOF.Error() - rowTestPagingResult{ - getMetadataCount: 1, - fetchResultsCount: 5, - nextRowIndex: int64(2), - nextRowNumber: int64(15), - offset: int64(10), - errMessage: &errMsg, - }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) + assert.EqualError(t, err, errMsg) // next row number is before start of results, should fetch all result pages // going forward and then return EOF rowSet.nextRowNumber = -1 err = rowSet.fetchResultPage() - errMsg = "databricks: driver error: " + errRowsFetchPriorToStart - rowTestPagingResult{ - getMetadataCount: 1, - fetchResultsCount: 7, - nextRowIndex: int64(2), - nextRowNumber: int64(-1), - offset: i64Zero, - errMessage: &errMsg, - }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) + assert.EqualError(t, err, "can't go backward") // jump back to last page - rowSet.nextRowNumber = 12 + rowSet.nextRowNumber = 13 err = rowSet.fetchResultPage() - errMsg = "databricks: driver error: " + errRowsFetchPriorToStart rowTestPagingResult{ getMetadataCount: 1, - fetchResultsCount: 9, - nextRowIndex: int64(2), - nextRowNumber: int64(12), + fetchResultsCount: 3, + nextRowIndex: int64(3), + nextRowNumber: int64(13), offset: int64(10), }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) } @@ -441,49 +345,36 @@ func TestRowsFetchResultPageWithDirectResults(t *testing.T) { // next row number is two, should fetch previous result page rowSet.nextRowNumber = 2 err = rowSet.fetchResultPage() - rowTestPagingResult{ - getMetadataCount: 1, - fetchResultsCount: 3, - nextRowIndex: int64(2), - nextRowNumber: 
int64(2), - offset: i64Zero, - }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) + assert.EqualError(t, err, "can't go backward") // next row number is past end of results, should fetch all result pages // going forward and then return EOF rowSet.nextRowNumber = 15 err = rowSet.fetchResultPage() - errMsg := io.EOF.Error() + assert.EqualError(t, err, "Invalid row number state") + + rowSet.nextRowNumber = 10 + err = rowSet.fetchResultPage() rowTestPagingResult{ getMetadataCount: 1, - fetchResultsCount: 5, - nextRowIndex: int64(2), - nextRowNumber: int64(15), + fetchResultsCount: 3, + nextRowIndex: int64(0), + nextRowNumber: int64(10), offset: int64(10), - errMessage: &errMsg, }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) - // next row number is before start of results, should fetch all result pages - // going forward and then return EOF - rowSet.nextRowNumber = -1 + rowSet.nextRowNumber = 15 err = rowSet.fetchResultPage() - errMsg = "databricks: driver error: " + errRowsFetchPriorToStart - rowTestPagingResult{ - getMetadataCount: 1, - fetchResultsCount: 7, - nextRowIndex: int64(2), - nextRowNumber: int64(-1), - offset: i64Zero, - errMessage: &errMsg, - }.validatePaging(t, rowSet, err, fetchResultsCount, getMetadataCount) + errMsg := io.EOF.Error() + assert.EqualError(t, err, errMsg) // jump back to last page rowSet.nextRowNumber = 12 err = rowSet.fetchResultPage() - errMsg = "databricks: driver error: " + errRowsFetchPriorToStart + rowTestPagingResult{ getMetadataCount: 1, - fetchResultsCount: 9, + fetchResultsCount: 3, nextRowIndex: int64(2), nextRowNumber: int64(12), offset: int64(10), @@ -590,7 +481,7 @@ func TestNextNoDirectResults(t *testing.T) { assert.Nil(t, err) assert.Equal(t, row0, row) assert.Equal(t, int64(1), rowSet.nextRowNumber) - assert.Equal(t, int64(1), rowSet.nextRowIndex) + assert.Equal(t, int64(1), rowSet.nextRowIndex()) assert.Equal(t, 1, getMetadataCount) assert.Equal(t, 1, fetchResultsCount) } @@ -646,7 +537,7 @@ func TestNextWithDirectResults(t *testing.T) { assert.Nil(t, err) assert.Equal(t, row0, row) assert.Equal(t, int64(1), rowSet.nextRowNumber) - assert.Equal(t, int64(1), rowSet.nextRowIndex) + assert.Equal(t, int64(1), rowSet.nextRowIndex()) assert.Equal(t, 2, getMetadataCount) assert.Equal(t, 1, fetchResultsCount) } @@ -878,9 +769,9 @@ type rowTestPagingResult struct { func (rt rowTestPagingResult) validatePaging(t *testing.T, rowSet *rows, err error, fetchResultsCount, getMetadataCount int) { assert.Equal(t, rt.fetchResultsCount, fetchResultsCount) assert.Equal(t, rt.getMetadataCount, getMetadataCount) - assert.Equal(t, rt.nextRowIndex, rowSet.nextRowIndex) + assert.Equal(t, rt.nextRowIndex, rowSet.nextRowIndex()) assert.Equal(t, rt.nextRowNumber, rowSet.nextRowNumber) - assert.Equal(t, rt.offset, rowSet.pageStartingRowNum) + assert.Equal(t, rt.offset, rowSet.RowScanner.Start()) if rt.errMessage == nil { assert.Nil(t, err) } else { diff --git a/internal/rows/rowscanner/resultPageIterator.go b/internal/rows/rowscanner/resultPageIterator.go new file mode 100644 index 0000000..5016378 --- /dev/null +++ b/internal/rows/rowscanner/resultPageIterator.go @@ -0,0 +1,240 @@ +package rowscanner + +import ( + "context" + "fmt" + "io" + + "github.com/databricks/databricks-sql-go/driverctx" + "github.com/databricks/databricks-sql-go/internal/cli_service" + dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" + dbsqllog "github.com/databricks/databricks-sql-go/logger" +) + +var errRowsResultFetchFailed = 
"databricks: Rows instance failed to retrieve results" +var ErrRowsFetchPriorToStart = "databricks: unable to fetch row page prior to start of results" +var errRowsNilResultPageFetcher = "databricks: nil ResultPageFetcher instance" + +func errRowsUnandledFetchDirection(dir string) string { + return fmt.Sprintf("databricks: unhandled fetch direction %s", dir) +} + +// Interface for iterating over the pages in the result set of a query +type ResultPageIterator interface { + Next() (*cli_service.TFetchResultsResp, error) + HasNext() bool + Delimiter +} + +// Define directions for seeking in the pages of a query result +type Direction int + +const ( + DirUnknown Direction = iota + DirNone + DirForward + DirBack +) + +var directionNames []string = []string{"Unknown", "None", "Forward", "Back"} + +func (d Direction) String() string { + return directionNames[d] +} + +// Create a new result page iterator. +func NewResultPageIterator( + delimiter Delimiter, + hasMoreRows bool, + maxPageSize int64, + opHandle *cli_service.TOperationHandle, + client cli_service.TCLIService, + connectionId string, + correlationId string, + logger *dbsqllog.DBSQLLogger, +) ResultPageIterator { + // delimiter and hasMoreRows are used to set up the point in the paginated + // result set that this iterator starts from. + return &resultPageIterator{ + Delimiter: delimiter, + isFinished: !hasMoreRows, + maxPageSize: maxPageSize, + opHandle: opHandle, + client: client, + connectionId: connectionId, + correlationId: correlationId, + logger: logger, + } +} + +type resultPageIterator struct { + // Gives the parameters of the current result page + Delimiter + + // indicates whether there are any more pages in the result set + isFinished bool + + // max number of rows to fetch in a page + maxPageSize int64 + + // handle of the operation producing the result set + opHandle *cli_service.TOperationHandle + + // client for communicating with the server + client cli_service.TCLIService + + // connectionId to include in logging messages + connectionId string + + // user provided value to include in logging messages + correlationId string + + logger *dbsqllog.DBSQLLogger +} + +var _ ResultPageIterator = (*resultPageIterator)(nil) + +// Returns true if there are more pages in the result set. +func (rpf *resultPageIterator) HasNext() bool { return !rpf.isFinished } + +// Returns the next page of the result set. io.EOF will be returned if there are +// no more pages. +func (rpf *resultPageIterator) Next() (*cli_service.TFetchResultsResp, error) { + + if rpf == nil { + return nil, dbsqlerrint.NewDriverError(context.Background(), errRowsNilResultPageFetcher, nil) + } + + if rpf.isFinished { + return nil, io.EOF + } + + // Starting row number of next result pag. This is used to check that the returned page is + // the expected one. + nextPageStartRow := rpf.Start() + rpf.Count() + + rpf.logger.Debug().Msgf("databricks: fetching result page for row %d", nextPageStartRow) + ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), rpf.connectionId), rpf.correlationId) + + // Keep fetching in the appropriate direction until we have the expected page. 
+	var fetchResult *cli_service.TFetchResultsResp
+	var b bool
+	for b = rpf.Contains(nextPageStartRow); !b; b = rpf.Contains(nextPageStartRow) {
+
+		direction := rpf.Direction(nextPageStartRow)
+		err := rpf.checkDirectionValid(ctx, direction)
+		if err != nil {
+			return nil, err
+		}
+
+		rpf.logger.Debug().Msgf("fetching next batch of up to %d rows, %s", rpf.maxPageSize, direction.String())
+
+		var includeResultSetMetadata = true
+		req := cli_service.TFetchResultsReq{
+			OperationHandle:          rpf.opHandle,
+			MaxRows:                  rpf.maxPageSize,
+			Orientation:              directionToSparkDirection(direction),
+			IncludeResultSetMetadata: &includeResultSetMetadata,
+		}
+
+		fetchResult, err = rpf.client.FetchResults(ctx, &req)
+		if err != nil {
+			rpf.logger.Err(err).Msg("databricks: Rows instance failed to retrieve results")
+			return nil, dbsqlerrint.NewRequestError(ctx, errRowsResultFetchFailed, err)
+		}
+
+		rpf.Delimiter = NewDelimiter(fetchResult.Results.StartRowOffset, CountRows(fetchResult.Results))
+		if fetchResult.HasMoreRows != nil {
+			rpf.isFinished = !*fetchResult.HasMoreRows
+		} else {
+			rpf.isFinished = true
+		}
+		rpf.logger.Debug().Msgf("databricks: new result page startRow: %d, nRows: %v, hasMoreRows: %v", rpf.Start(), rpf.Count(), fetchResult.HasMoreRows)
+	}
+
+	return fetchResult, nil
+}
+
+// CountRows returns the number of rows in the TRowSet
+func CountRows(rowSet *cli_service.TRowSet) int64 {
+	if rowSet == nil {
+		return 0
+	}
+
+	if rowSet.ArrowBatches != nil {
+		batches := rowSet.ArrowBatches
+		var n int64
+		for i := range batches {
+			n += batches[i].RowCount
+		}
+		return n
+	}
+
+	if rowSet.ResultLinks != nil {
+		links := rowSet.ResultLinks
+		var n int64
+		for i := range links {
+			n += links[i].RowCount
+		}
+		return n
+	}
+
+	if rowSet.Columns != nil {
+		// Find a column/values and return the number of values.
+		for _, col := range rowSet.Columns {
+			if col.BoolVal != nil {
+				return int64(len(col.BoolVal.Values))
+			}
+			if col.ByteVal != nil {
+				return int64(len(col.ByteVal.Values))
+			}
+			if col.I16Val != nil {
+				return int64(len(col.I16Val.Values))
+			}
+			if col.I32Val != nil {
+				return int64(len(col.I32Val.Values))
+			}
+			if col.I64Val != nil {
+				return int64(len(col.I64Val.Values))
+			}
+			if col.StringVal != nil {
+				return int64(len(col.StringVal.Values))
+			}
+			if col.DoubleVal != nil {
+				return int64(len(col.DoubleVal.Values))
+			}
+			if col.BinaryVal != nil {
+				return int64(len(col.BinaryVal.Values))
+			}
+		}
+	}
+	return 0
+}
+
+// Check if trying to fetch in the specified direction creates an error condition.
+func (rpf *resultPageIterator) checkDirectionValid(ctx context.Context, direction Direction) error {
+	if direction == DirBack {
+		// can't fetch rows previous to the start
+		if rpf.Start() == 0 {
+			return dbsqlerrint.NewDriverError(ctx, ErrRowsFetchPriorToStart, nil)
+		}
+	} else if direction == DirForward {
+		// can't fetch past the end of the query results
+		if rpf.isFinished {
+			return io.EOF
+		}
+	} else {
+		rpf.logger.Error().Msgf(errRowsUnhandledFetchDirection(direction.String()))
+		return dbsqlerrint.NewDriverError(ctx, errRowsUnhandledFetchDirection(direction.String()), nil)
+	}
+	return nil
+}
+
+func directionToSparkDirection(d Direction) cli_service.TFetchOrientation {
+	switch d {
+	case DirBack:
+		return cli_service.TFetchOrientation_FETCH_PRIOR
+	default:
+		return cli_service.TFetchOrientation_FETCH_NEXT
+	}
+}
diff --git a/internal/rows/rowscanner/resultPageIterator_test.go b/internal/rows/rowscanner/resultPageIterator_test.go
new file mode 100644
index 0000000..4e1fe65
--- /dev/null
+++ b/internal/rows/rowscanner/resultPageIterator_test.go
@@ -0,0 +1,150 @@
+package rowscanner
+
+import (
+	"context"
+	"testing"
+
+	"github.com/databricks/databricks-sql-go/internal/cli_service"
+	"github.com/databricks/databricks-sql-go/internal/client"
+	dbsqllog "github.com/databricks/databricks-sql-go/logger"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestFetchResultPagination(t *testing.T) {
+	t.Parallel()
+
+	fetches := []fetch{}
+	pageSequence := []int{0, 3, 2, 0, 1, 2}
+	client := getSimpleClient(&fetches, pageSequence)
+
+	rf := &resultPageIterator{
+		Delimiter:     NewDelimiter(0, 0),
+		client:        client,
+		logger:        dbsqllog.WithContext("connId", "correlationId", ""),
+		connectionId:  "connId",
+		correlationId: "correlationId",
+	}
+
+	// next row number is zero so should fetch first result page
+	_, err := rf.Next()
+	assert.Nil(t, err)
+	assert.Len(t, fetches, 1)
+	assert.Equal(t, fetches[0].direction, cli_service.TFetchOrientation_FETCH_NEXT)
+
+	// The test client returns pages out of order; Next must keep correcting the fetch direction
+	_, err = rf.Next()
+	assert.Nil(t, err)
+	assert.Len(t, fetches, 5)
+	expected := []fetch{
+		{direction: cli_service.TFetchOrientation_FETCH_NEXT, resultStartRec: 0},
+		{direction: cli_service.TFetchOrientation_FETCH_NEXT, resultStartRec: 15},
+		{direction: cli_service.TFetchOrientation_FETCH_PRIOR, resultStartRec: 10},
+		{direction: cli_service.TFetchOrientation_FETCH_PRIOR, resultStartRec: 0},
+		{direction: cli_service.TFetchOrientation_FETCH_NEXT, resultStartRec: 5},
+	}
+	assert.Equal(t, expected, fetches)
+}
+
+type fetch struct {
+	direction      cli_service.TFetchOrientation
+	resultStartRec int
+}
+
+// Build a simple test client
+func getSimpleClient(fetches *[]fetch, pageSequence []int) cli_service.TCLIService {
+	// We are simulating the scenario where network errors and retry behaviour cause the fetch
+	// result request to be sent multiple times, resulting in jumping past the next/previous result
+	// page. Behaviour should be robust enough to handle this by changing the fetch orientation.
+
+	// Metadata for the different types is based on the results returned when querying a table with
+	// all the different types which was created in a test shard.
+ metadata := &cli_service.TGetResultSetMetadataResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + Schema: &cli_service.TTableSchema{ + Columns: []*cli_service.TColumnDesc{ + { + ColumnName: "bool_col", + TypeDesc: &cli_service.TTypeDesc{ + Types: []*cli_service.TTypeEntry{ + { + PrimitiveEntry: &cli_service.TPrimitiveTypeEntry{ + Type: cli_service.TTypeId_BOOLEAN_TYPE, + }, + }, + }, + }, + }, + }, + }, + } + + getMetadata := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (_r *cli_service.TGetResultSetMetadataResp, _err error) { + return metadata, nil + } + + moreRows := true + noMoreRows := false + colVals := []*cli_service.TColumn{{BoolVal: &cli_service.TBoolColumn{Values: []bool{true, false, true, false, true}}}} + + pages := []*cli_service.TFetchResultsResp{ + { + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + HasMoreRows: &moreRows, + Results: &cli_service.TRowSet{ + StartRowOffset: 0, + Columns: colVals, + }, + }, + { + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + HasMoreRows: &moreRows, + Results: &cli_service.TRowSet{ + StartRowOffset: 5, + Columns: colVals, + }, + }, + { + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + HasMoreRows: &noMoreRows, + Results: &cli_service.TRowSet{ + StartRowOffset: 10, + Columns: colVals, + }, + }, + { + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + HasMoreRows: &noMoreRows, + Results: &cli_service.TRowSet{ + StartRowOffset: 15, + Columns: []*cli_service.TColumn{}, + }, + }, + } + + pageIndex := -1 + + fetchResults := func(ctx context.Context, req *cli_service.TFetchResultsReq) (_r *cli_service.TFetchResultsResp, _err error) { + pageIndex++ + + p := pages[pageSequence[pageIndex]] + *fetches = append(*fetches, fetch{direction: req.Orientation, resultStartRec: int(p.Results.StartRowOffset)}) + return p, nil + } + + client := &client.TestClient{ + FnGetResultSetMetadata: getMetadata, + FnFetchResults: fetchResults, + } + + return client +} diff --git a/internal/rows/rowscanner/rowScanner.go b/internal/rows/rowscanner/rowScanner.go index 9fdf298..0e58b5b 100644 --- a/internal/rows/rowscanner/rowScanner.go +++ b/internal/rows/rowscanner/rowScanner.go @@ -13,6 +13,7 @@ import ( // RowScanner is an interface defining the behaviours that are specific to // the formats in which query results can be returned. type RowScanner interface { + Delimiter // ScanRow is called to populate the provided slice with the // content of the current row. The provided slice will be the same // size as the number of columns. 
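[Reviewer note] With RowScanner now embedding Delimiter, rows.go drives all paging through the iterator added above. A sketch of how a ResultPageIterator might be driven on its own; drainPages is a hypothetical helper, and opHandle, client, and logger are assumed to come from an open operation:

	package example // hypothetical; these internal packages only build inside the driver module

	import (
		"github.com/databricks/databricks-sql-go/internal/cli_service"
		"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
		dbsqllog "github.com/databricks/databricks-sql-go/logger"
	)

	func drainPages(
		opHandle *cli_service.TOperationHandle,
		client cli_service.TCLIService,
		logger *dbsqllog.DBSQLLogger,
	) error {
		it := rowscanner.NewResultPageIterator(
			rowscanner.NewDelimiter(0, 0), // positioned before the first page
			true,                          // hasMoreRows: expect at least one page
			10000,                         // maxPageSize
			opHandle,
			client,
			"connId",        // connection id for log messages
			"correlationId", // correlation id for log messages
			logger,
		)

		for it.HasNext() {
			page, err := it.Next() // *cli_service.TFetchResultsResp
			if err != nil {
				return err
			}
			_ = page // hand the page to a row scanner, etc.
		}
		return nil
	}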
@@ -44,6 +45,47 @@ func IsNull(nulls []byte, position int64) bool { return false } +type Delimiter interface { + Start() int64 + End() int64 + Count() int64 + Contains(int64) bool + Direction(int64) Direction +} + +func NewDelimiter(start, count int64) Delimiter { + return delimiter{ + start: start, + count: count, + end: start + count - 1, + } +} + +type delimiter struct { + start int64 + end int64 + count int64 +} + +func (d delimiter) Start() int64 { return d.start } +func (d delimiter) End() int64 { return d.end } +func (d delimiter) Count() int64 { return d.count } +func (d delimiter) Contains(i int64) bool { return d.count > 0 && i >= d.start && i <= d.end } +func (d delimiter) Direction(i int64) Direction { + + if d.Contains(i) { + return DirNone + } else if i < d.Start() { + return DirBack + } else if i > d.End() { + return DirForward + } else if d.Count() == 0 { + return DirForward + } else { + return DirUnknown + } +} + var ErrRowsParseValue = "databricks: unable to parse %s value '%v' from column %s" // handleDateTime will convert the passed val to a time.Time value if necessary diff --git a/internal/rows/rowscanner/rowscanner_test.go b/internal/rows/rowscanner/rowscanner_test.go index 7d21f2f..aabd5d0 100644 --- a/internal/rows/rowscanner/rowscanner_test.go +++ b/internal/rows/rowscanner/rowscanner_test.go @@ -2,10 +2,12 @@ package rowscanner import ( "fmt" + "io" "strings" "testing" "time" + dbsqllog "github.com/databricks/databricks-sql-go/logger" "github.com/stretchr/testify/assert" ) @@ -65,3 +67,52 @@ func TestHandlingDateTime(t *testing.T) { } }) } + +func TestRowsFetchResultPageErrors(t *testing.T) { + t.Parallel() + + var fetcher *resultPageIterator + + _, err := fetcher.Next() + assert.EqualError(t, err, "databricks: driver error: "+errRowsNilResultPageFetcher) + + fetcher = &resultPageIterator{ + Delimiter: NewDelimiter(0, -1), + logger: dbsqllog.WithContext("", "", ""), + } + + _, err = fetcher.Next() + assert.EqualError(t, err, "databricks: driver error: "+ErrRowsFetchPriorToStart, "negative row number should return error") + + fetcher = &resultPageIterator{ + Delimiter: NewDelimiter(0, 0), + isFinished: true, + logger: dbsqllog.WithContext("", "", ""), + } + + _, err = fetcher.Next() + assert.EqualError(t, err, io.EOF.Error(), "row number past end of result set should return EOF") +} + +func TestDelimiter(t *testing.T) { + t.Parallel() + + var d Delimiter = delimiter{} + + assert.False(t, d.Contains(0)) + assert.False(t, d.Contains(1)) + assert.False(t, d.Contains(-1)) + assert.Equal(t, DirForward, d.Direction(0)) + assert.Equal(t, DirForward, d.Direction(1)) + assert.Equal(t, DirBack, d.Direction(-1)) + + d = NewDelimiter(0, 5) + assert.True(t, d.Contains(0)) + assert.True(t, d.Contains(4)) + assert.False(t, d.Contains(-1)) + assert.False(t, d.Contains(5)) + assert.Equal(t, DirNone, d.Direction(0)) + assert.Equal(t, DirNone, d.Direction(4)) + assert.Equal(t, DirForward, d.Direction(5)) + assert.Equal(t, DirBack, d.Direction(-1)) +}
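[Reviewer note] For reference, a standalone sketch of the Delimiter semantics the tests above pin down (illustrative only; rowscanner is an internal package):

	package main

	import (
		"fmt"

		"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
	)

	func main() {
		d := rowscanner.NewDelimiter(5, 3) // delimits absolute rows 5 through 7

		fmt.Println(d.Start(), d.End(), d.Count()) // 5 7 3
		fmt.Println(d.Contains(4), d.Contains(7))  // false true

		// Direction reports which way a pager must fetch to reach a given row.
		fmt.Println(d.Direction(2)) // Back: the row precedes this page
		fmt.Println(d.Direction(6)) // None: the row is inside this page
		fmt.Println(d.Direction(9)) // Forward: the row follows this page
	}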