diff --git a/appender.go b/appender.go index 96f1d268..0b3d5b77 100644 --- a/appender.go +++ b/appender.go @@ -139,7 +139,7 @@ func (a *Appender) AppendRow(args ...driver.Value) error { func (a *Appender) addDataChunk() error { var chunk DataChunk - if err := chunk.initFromTypes(a.ptr, a.types); err != nil { + if err := chunk.initFromTypes(a.ptr, a.types, true); err != nil { return err } a.chunks = append(a.chunks, chunk) diff --git a/data_chunk.go b/data_chunk.go index 49311bc2..03d3cdc3 100644 --- a/data_chunk.go +++ b/data_chunk.go @@ -18,6 +18,8 @@ type DataChunk struct { columns []vector // columnNames holds the column names, if known. columnNames []string + // size caches the size after initialization. + size int } // GetDataChunkCapacity returns the capacity of a data chunk. @@ -27,7 +29,8 @@ func GetDataChunkCapacity() int { // GetSize returns the internal size of the data chunk. func (chunk *DataChunk) GetSize() int { - return int(C.duckdb_data_chunk_get_size(chunk.data)) + chunk.size = int(C.duckdb_data_chunk_get_size(chunk.data)) + return chunk.size } // SetSize sets the internal size of the data chunk. Cannot exceed GetCapacity(). @@ -71,7 +74,7 @@ func (chunk *DataChunk) SetValue(colIdx int, rowIdx int, val any) error { return nil } -func (chunk *DataChunk) initFromTypes(ptr unsafe.Pointer, types []C.duckdb_logical_type) error { +func (chunk *DataChunk) initFromTypes(ptr unsafe.Pointer, types []C.duckdb_logical_type, writable bool) error { // NOTE: initFromTypes does not initialize the column names. columnCount := len(types) @@ -93,14 +96,13 @@ func (chunk *DataChunk) initFromTypes(ptr unsafe.Pointer, types []C.duckdb_logic // Initialize the vectors and their child vectors. for i := 0; i < columnCount; i++ { - duckdbVector := C.duckdb_data_chunk_get_vector(chunk.data, C.idx_t(i)) - chunk.columns[i].duckdbVector = duckdbVector - chunk.columns[i].getChildVectors(duckdbVector) + v := C.duckdb_data_chunk_get_vector(chunk.data, C.idx_t(i)) + chunk.columns[i].initVectors(v, writable) } return nil } -func (chunk *DataChunk) initFromDuckDataChunk(data C.duckdb_data_chunk) error { +func (chunk *DataChunk) initFromDuckDataChunk(data C.duckdb_data_chunk, writable bool) error { columnCount := int(C.duckdb_data_chunk_get_column_count(data)) chunk.columns = make([]vector, columnCount) chunk.data = data @@ -117,10 +119,11 @@ func (chunk *DataChunk) initFromDuckDataChunk(data C.duckdb_data_chunk) error { break } - // Initialize the vectors and their child vectors. - chunk.columns[i].duckdbVector = duckdbVector - chunk.columns[i].getChildVectors(duckdbVector) + // Initialize the vector and its child vectors. + chunk.columns[i].initVectors(duckdbVector, writable) } + + chunk.GetSize() return err } diff --git a/errors_test.go b/errors_test.go index 52810782..be1e8811 100644 --- a/errors_test.go +++ b/errors_test.go @@ -407,6 +407,7 @@ type wrappedDuckDBError struct { func (w *wrappedDuckDBError) Error() string { return w.e.Error() } + func (w *wrappedDuckDBError) Unwrap() error { return w.e } diff --git a/rows.go b/rows.go index 0fbdac7a..a2b3d1a4 100644 --- a/rows.go +++ b/rows.go @@ -55,13 +55,13 @@ func (r *rows) Columns() []string { } func (r *rows) Next(dst []driver.Value) error { - for r.rowCount == r.chunk.GetSize() { + for r.rowCount == r.chunk.size { r.chunk.close() if r.chunkIdx == r.chunkCount { return io.EOF } data := C.duckdb_result_get_chunk(r.res, r.chunkIdx) - if err := r.chunk.initFromDuckDataChunk(data); err != nil { + if err := r.chunk.initFromDuckDataChunk(data, false); err != nil { return getError(err, nil) } diff --git a/vector.go b/vector.go index ac05eb2a..b3499332 100644 --- a/vector.go +++ b/vector.go @@ -18,6 +18,10 @@ import ( type vector struct { // The underlying DuckDB vector. duckdbVector C.duckdb_vector + // The underlying data ptr. + ptr unsafe.Pointer + // The vector's validity mask. + mask *C.uint64_t // A callback function to get a value from this vector. getFn fnGetVectorValue // A callback function to write to this vector. @@ -311,19 +315,27 @@ func (vec *vector) init(logicalType C.duckdb_logical_type, colIdx int) error { return nil } -func (vec *vector) getChildVectors(vector C.duckdb_vector) { +func (vec *vector) initVectors(v C.duckdb_vector, writable bool) { + vec.duckdbVector = v + vec.ptr = C.duckdb_vector_get_data(v) + if writable { + C.duckdb_vector_ensure_validity_writable(v) + } + vec.mask = C.duckdb_vector_get_validity(v) + vec.getChildVectors(v, writable) +} + +func (vec *vector) getChildVectors(v C.duckdb_vector, writable bool) { switch vec.duckdbType { case C.DUCKDB_TYPE_LIST, C.DUCKDB_TYPE_MAP: - child := C.duckdb_list_vector_get_child(vector) - vec.childVectors[0].duckdbVector = child - vec.childVectors[0].getChildVectors(child) + child := C.duckdb_list_vector_get_child(v) + vec.childVectors[0].initVectors(child, writable) case C.DUCKDB_TYPE_STRUCT: for i := 0; i < len(vec.childVectors); i++ { - child := C.duckdb_struct_vector_get_child(vector, C.idx_t(i)) - vec.childVectors[i].duckdbVector = child - vec.childVectors[i].getChildVectors(child) + child := C.duckdb_struct_vector_get_child(v, C.idx_t(i)) + vec.childVectors[i].initVectors(child, writable) } } } diff --git a/vector_getters.go b/vector_getters.go index b97a42c3..2a753f53 100644 --- a/vector_getters.go +++ b/vector_getters.go @@ -16,13 +16,20 @@ import ( type fnGetVectorValue func(vec *vector, rowIdx C.idx_t) any func (vec *vector) getNull(rowIdx C.idx_t) bool { - mask := C.duckdb_vector_get_validity(vec.duckdbVector) - return !bool(C.duckdb_validity_row_is_valid(mask, rowIdx)) + mask := unsafe.Pointer(vec.mask) + if mask == nil { + return false + } + + entryIdx := rowIdx / 64 + idxInEntry := rowIdx % 64 + maskPtr := (*[1 << 31]C.uint64_t)(mask) + isValid := maskPtr[entryIdx] & (C.uint64_t(1) << idxInEntry) + return uint64(isValid) == 0 } func getPrimitive[T any](vec *vector, rowIdx C.idx_t) T { - ptr := C.duckdb_vector_get_data(vec.duckdbVector) - xs := (*[1 << 31]T)(ptr) + xs := (*[1 << 31]T)(vec.ptr) return xs[rowIdx] } diff --git a/vector_setters.go b/vector_setters.go index 440c1bd5..cd090a32 100644 --- a/vector_setters.go +++ b/vector_setters.go @@ -19,9 +19,7 @@ const secondsPerDay = 24 * 60 * 60 type fnSetVectorValue func(vec *vector, rowIdx C.idx_t, val any) func (vec *vector) setNull(rowIdx C.idx_t) { - C.duckdb_vector_ensure_validity_writable(vec.duckdbVector) - mask := C.duckdb_vector_get_validity(vec.duckdbVector) - C.duckdb_validity_set_row_invalid(mask, rowIdx) + C.duckdb_validity_set_row_invalid(vec.mask, rowIdx) if vec.duckdbType == C.DUCKDB_TYPE_STRUCT { for i := 0; i < len(vec.childVectors); i++ { @@ -31,8 +29,7 @@ func (vec *vector) setNull(rowIdx C.idx_t) { } func setPrimitive[T any](vec *vector, rowIdx C.idx_t, v T) { - ptr := C.duckdb_vector_get_data(vec.duckdbVector) - xs := (*[1 << 31]T)(ptr) + xs := (*[1 << 31]T)(vec.ptr) xs[rowIdx] = v }