Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions arrow/cdata/cdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,9 @@ func (imp *cimporter) doImportChildren() error {
st := imp.dt.(*arrow.StructType)
for i, c := range children {
imp.children[i].dt = st.Field(i).Type
imp.children[i].importChild(imp, c)
if err := imp.children[i].importChild(imp, c); err != nil {
return err
}
}
case arrow.RUN_END_ENCODED: // import run-ends and values
st := imp.dt.(*arrow.RunEndEncodedType)
Expand All @@ -428,13 +430,17 @@ func (imp *cimporter) doImportChildren() error {
dt := imp.dt.(*arrow.DenseUnionType)
for i, c := range children {
imp.children[i].dt = dt.Fields()[i].Type
imp.children[i].importChild(imp, c)
if err := imp.children[i].importChild(imp, c); err != nil {
return err
}
}
case arrow.SPARSE_UNION:
dt := imp.dt.(*arrow.SparseUnionType)
for i, c := range children {
imp.children[i].dt = dt.Fields()[i].Type
imp.children[i].importChild(imp, c)
if err := imp.children[i].importChild(imp, c); err != nil {
return err
}
}
}

Expand All @@ -461,33 +467,28 @@ func (imp *cimporter) doImportArr(src *CArrowArray) error {
// and only null columns, then we can release the CArrowArray
// struct immediately after import, since we have no imported
// memory that we have to track the lifetime of.
// On error, we always release regardless of buffer count to avoid leaks.
var importErr error
defer func() {
if imp.alloc.bufCount.Load() == 0 {
C.ArrowArrayRelease(imp.arr)
C.free(unsafe.Pointer(imp.arr))
if importErr != nil || imp.alloc.bufCount.Load() == 0 {
imp.alloc.forceRelease()
}
}()

return imp.doImport()
importErr = imp.doImport()
return importErr
}

// doImport is called recursively as needed for importing an array and its children
// in order to generate array.Data objects
func (imp *cimporter) doImport() error {
// move the array from the src object passed in to the one referenced by
// this importer. That way we can set up a finalizer on the created
// arrow.ArrayData object so we clean up our Array's memory when garbage collected.
defer func(arr *CArrowArray) {
// this should only occur in the case of an error happening
// during import, at which point we need to clean up the
// ArrowArray struct we allocated.
if imp.data == nil {
C.free(unsafe.Pointer(arr))
}
}(imp.arr)

// import any children
if err := imp.doImportChildren(); err != nil {
for _, c := range imp.children {
if c.data != nil {
c.data.Release()
}
}
return err
}

Expand Down
106 changes: 104 additions & 2 deletions arrow/cdata/cdata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,8 +669,8 @@ func createTestDenseUnion() arrow.Array {

func createTestUnionArr(mode arrow.UnionMode) arrow.Array {
fields := []arrow.Field{
arrow.Field{Name: "u0", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
arrow.Field{Name: "u1", Type: arrow.PrimitiveTypes.Uint8, Nullable: true},
{Name: "u0", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
{Name: "u1", Type: arrow.PrimitiveTypes.Uint8, Nullable: true},
}
typeCodes := []arrow.UnionTypeCode{5, 10}
bld := array.NewBuilder(memory.DefaultAllocator, arrow.UnionOf(mode, fields, typeCodes)).(array.UnionBuilder)
Expand Down Expand Up @@ -785,6 +785,104 @@ func TestRecordBatch(t *testing.T) {
assert.True(t, array.RecordEqual(rb, rec))
}

// TestImportStructWithInvalidSchema exports the test struct array to a C
// array, pairs it with a struct schema built from the formats "+s", "c", "l",
// and expects ImportCRecordBatch to report an error. The deferred AssertSize
// call then checks the mallocator against a size of 0.
func TestImportStructWithInvalidSchema(t *testing.T) {
	alloc := mallocator.NewMallocator()
	// Deferred first so that it runs last, after all other cleanups below.
	defer alloc.AssertSize(t, 0)

	structArr := createTestStructArr()
	defer structArr.Release()

	cArr := createCArr(structArr, alloc)
	defer freeTestMallocatorArr(cArr, alloc)

	var (
		formats = []string{"+s", "c", "l"}
		names   = []string{"", "a", "b"}
		flags   = []int64{0, flagIsNullable, flagIsNullable}
	)
	schemas := testStruct(formats, names, flags)
	defer freeMallocedSchemas(schemas)

	// The first entry of the returned schema list is used as the top-level schema.
	root := (*[1]*CArrowSchema)(unsafe.Pointer(schemas))[0]
	_, err := ImportCRecordBatch(cArr, root)
	assert.Error(t, err)
}

// TestImportDenseUnionWithInvalidSchema wraps a dense union array in a
// single-field struct, exports it to a C array, and then tries to import it
// with a schema that declares the union field as "i". ImportCRecordBatch is
// expected to return an error, and the deferred AssertSize call checks the
// mallocator against a size of 0.
func TestImportDenseUnionWithInvalidSchema(t *testing.T) {
	alloc := mallocator.NewMallocator()
	// Deferred first so that it runs last, after all other cleanups below.
	defer alloc.AssertSize(t, 0)

	srcArr := createTestDenseUnion()
	defer srcArr.Release()

	structType := arrow.StructOf(
		arrow.Field{Name: "union_field", Type: srcArr.DataType(), Nullable: false},
	)
	structBld := array.NewStructBuilder(memory.DefaultAllocator, structType)
	defer structBld.Release()

	unionBld := structBld.FieldBuilder(0).(*array.DenseUnionBuilder)
	structBld.Append(true)

	// Copy every element of the source union into the struct's union child,
	// dispatching on the element's type code.
	src := srcArr.(*array.DenseUnion)
	for row := 0; row < src.Len(); row++ {
		code := src.TypeCode(row)
		unionBld.Append(code)
		switch code {
		case 5:
			child := unionBld.Child(0).(*array.Int32Builder)
			child.Append(src.Field(0).(*array.Int32).Value(int(src.ValueOffset(row))))
		default:
			child := unionBld.Child(1).(*array.Uint8Builder)
			child.Append(src.Field(1).(*array.Uint8).Value(int(src.ValueOffset(row))))
		}
	}

	wrapped := structBld.NewArray()
	defer wrapped.Release()

	cArr := createCArr(wrapped, alloc)
	defer freeTestMallocatorArr(cArr, alloc)

	// Create an invalid schema: wrong type for union field (using "i" instead of proper union schema)
	schemas := testStruct([]string{"+s", "i"}, []string{"", "union_field"}, []int64{0, flagIsNullable})
	defer freeMallocedSchemas(schemas)

	root := (*[1]*CArrowSchema)(unsafe.Pointer(schemas))[0]
	_, err := ImportCRecordBatch(cArr, root)
	assert.Error(t, err)
}

// TestImportSPARSEUnionWithInvalidSchema wraps a sparse union array in a
// single-field struct, exports it to a C array, and then tries to import it
// with a schema that declares the union field as "u". ImportCRecordBatch is
// expected to return an error, and the deferred AssertSize call checks the
// mallocator against a size of 0.
func TestImportSPARSEUnionWithInvalidSchema(t *testing.T) {
	alloc := mallocator.NewMallocator()
	// Deferred first so that it runs last, after all other cleanups below.
	defer alloc.AssertSize(t, 0)

	srcArr := createTestSparseUnion()
	defer srcArr.Release()

	structType := arrow.StructOf(
		arrow.Field{Name: "union_field", Type: srcArr.DataType(), Nullable: false},
	)
	structBld := array.NewStructBuilder(memory.DefaultAllocator, structType)
	defer structBld.Release()

	unionBld := structBld.FieldBuilder(0).(*array.SparseUnionBuilder)
	structBld.Append(true)

	// For each source element, append the value to the child selected by the
	// type code and a null to the other child.
	src := srcArr.(*array.SparseUnion)
	for row := 0; row < src.Len(); row++ {
		code := src.TypeCode(row)
		unionBld.Append(code)

		int32Bld := unionBld.Child(0).(*array.Int32Builder)
		uint8Bld := unionBld.Child(1).(*array.Uint8Builder)
		if code != 5 {
			int32Bld.AppendNull()
			uint8Bld.Append(src.Field(1).(*array.Uint8).Value(row))
			continue
		}
		int32Bld.Append(src.Field(0).(*array.Int32).Value(row))
		uint8Bld.AppendNull()
	}

	wrapped := structBld.NewArray()
	defer wrapped.Release()

	cArr := createCArr(wrapped, alloc)
	defer freeTestMallocatorArr(cArr, alloc)

	// Create an invalid schema: wrong type for union field (using "u" instead of proper union schema)
	schemas := testStruct([]string{"+s", "u"}, []string{"", "union_field"}, []int64{0, flagIsNullable})
	defer freeMallocedSchemas(schemas)

	root := (*[1]*CArrowSchema)(unsafe.Pointer(schemas))[0]
	_, err := ImportCRecordBatch(cArr, root)
	assert.Error(t, err)
}

func TestRecordReaderStream(t *testing.T) {
stream := arrayStreamTest()
defer releaseStreamTest(stream)
Expand Down Expand Up @@ -1006,17 +1104,21 @@ func (r *failingReader) Schema() *arrow.Schema {
}
return arrdata.Records["primitives"][0].Schema()
}

// Next decrements the reader's remaining operation count and reports whether
// the count is still positive afterwards.
func (r *failingReader) Next() bool {
	r.opCount--
	return r.opCount > 0
}

// RecordBatch returns the first "primitives" record from arrdata.Records,
// calling Retain on it before handing it back.
func (r *failingReader) RecordBatch() arrow.RecordBatch {
	rec := arrdata.Records["primitives"][0]
	rec.Retain()
	return rec
}

// Record returns the same value as RecordBatch, exposed under the
// arrow.Record return type.
func (r *failingReader) Record() arrow.Record {
	batch := r.RecordBatch()
	return batch
}

func (r *failingReader) Err() error {
if r.opCount == 0 {
return fmt.Errorf("Expected error message")
Expand Down
7 changes: 7 additions & 0 deletions arrow/cdata/import_allocator.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import "C"

// importAllocator holds the bookkeeping used to decide when the imported
// CArrowArray is released and freed.
type importAllocator struct {
// bufCount is the counter that Free decrements; when it reaches zero,
// forceRelease is invoked. Presumably it counts outstanding imported
// buffers — verify against the code that increments it.
bufCount atomic.Int64
// released is flipped from false to true by forceRelease via
// CompareAndSwap, so the release logic runs at most once.
released atomic.Bool

// arr is the C array that forceRelease passes to C.ArrowArrayRelease and
// then frees with C.free.
arr *CArrowArray
}
Expand All @@ -49,6 +50,12 @@ func (i *importAllocator) Free([]byte) {
debug.Assert(i.bufCount.Load() > 0, "too many releases")

if i.bufCount.Add(-1) == 0 {
i.forceRelease()
}
}

func (i *importAllocator) forceRelease() {
if i.released.CompareAndSwap(false, true) {
defer C.free(unsafe.Pointer(i.arr))
C.ArrowArrayRelease(i.arr)
if C.ArrowArrayIsReleased(i.arr) != 1 {
Expand Down