Skip to content

Commit

Permalink
add a test
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Nov 19, 2024
1 parent db110c2 commit 9f20820
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 18 deletions.
15 changes: 6 additions & 9 deletions python/cudf/cudf/_lib/orc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ from libc.stdint cimport int64_t
from libcpp cimport bool, int
from libcpp.map cimport map
from libcpp.string cimport string
from libcpp.utility cimport move
from libcpp.vector cimport vector
import itertools
from collections import OrderedDict
Expand Down Expand Up @@ -236,10 +235,8 @@ def write_orc(
--------
cudf.read_orc
"""
cdef map[string, string] user_data
user_data[str.encode("pandas")] = str.encode(generate_pandas_metadata(
table, index)
)
user_data = {}
user_data["pandas"] = generate_pandas_metadata(table, index)
if index is True or (
index is None and not isinstance(table._index, cudf.RangeIndex)
):
Expand Down Expand Up @@ -287,7 +284,7 @@ def write_orc(
plc.io.SinkInfo([path_or_buf]), plc_table
)
.metadata(tbl_meta)
.key_value_metadata(move(user_data))
.key_value_metadata(user_data)
.compression(_get_comp_type(compression))
.enable_statistics(_get_orc_stat_freq(statistics))
.build()
Expand Down Expand Up @@ -438,14 +435,14 @@ cdef class ORCWriter:
and (name in self.cols_as_map_type),
)

cdef map[string, string] user_data
user_data = {}
pandas_metadata = generate_pandas_metadata(table, self.index)
user_data[str.encode("pandas")] = str.encode(pandas_metadata)
user_data["pandas"] = pandas_metadata

options = (
plc.io.orc.ChunkedOrcWriterOptions.builder(self.sink)
.metadata(self.tbl_meta)
.key_value_metadata(move(user_data))
.key_value_metadata(user_data)
.compression(_get_comp_type(self.compression))
.enable_statistics(_get_orc_stat_freq(self.statistics))
.build()
Expand Down
4 changes: 2 additions & 2 deletions python/pylibcudf/pylibcudf/io/orc.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ cdef class OrcWriterOptionsBuilder:
cdef SinkInfo sink
cpdef OrcWriterOptionsBuilder compression(self, compression_type comp)
cpdef OrcWriterOptionsBuilder enable_statistics(self, statistics_freq val)
cpdef OrcWriterOptionsBuilder key_value_metadata(self, map[string, string] kvm)
cpdef OrcWriterOptionsBuilder key_value_metadata(self, dict kvm)
cpdef OrcWriterOptionsBuilder metadata(self, TableInputMetadata meta)
cpdef OrcWriterOptions build(self)

Expand All @@ -105,7 +105,7 @@ cdef class ChunkedOrcWriterOptionsBuilder:
cpdef ChunkedOrcWriterOptionsBuilder compression(self, compression_type comp)
cpdef ChunkedOrcWriterOptionsBuilder enable_statistics(self, statistics_freq val)
cpdef ChunkedOrcWriterOptionsBuilder key_value_metadata(
self, map[string, string] kvm
self, dict kvm
)
cpdef ChunkedOrcWriterOptionsBuilder metadata(self, TableInputMetadata meta)
cpdef ChunkedOrcWriterOptions build(self)
12 changes: 8 additions & 4 deletions python/pylibcudf/pylibcudf/io/orc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,10 @@ cdef class OrcWriterOptionsBuilder:
self.c_obj.enable_statistics(val)
return self

cpdef OrcWriterOptionsBuilder key_value_metadata(self, map[string, string] kvm):
self.c_obj.key_value_metadata(kvm)
cpdef OrcWriterOptionsBuilder key_value_metadata(self, dict kvm):
self.c_obj.key_value_metadata(
{key.encode(): value.encode() for key, value in kvm.items()}
)
return self

cpdef OrcWriterOptionsBuilder metadata(self, TableInputMetadata meta):
Expand Down Expand Up @@ -426,9 +428,11 @@ cdef class ChunkedOrcWriterOptionsBuilder:

cpdef ChunkedOrcWriterOptionsBuilder key_value_metadata(
self,
map[string, string] kvm
dict kvm
):
self.c_obj.key_value_metadata(kvm)
self.c_obj.key_value_metadata(
{key.encode(): value.encode() for key, value in kvm.items()}
)
return self

cpdef ChunkedOrcWriterOptionsBuilder metadata(self, TableInputMetadata meta):
Expand Down
2 changes: 1 addition & 1 deletion python/pylibcudf/pylibcudf/io/types.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -512,5 +512,5 @@ cdef class TableInputMetadata:
table : Table
The Table to construct metadata for
"""
def __cinit__(self, Table table):
def __init__(self, Table table):
self.c_obj = table_input_metadata(table.view())
55 changes: 53 additions & 2 deletions python/pylibcudf/pylibcudf/tests/io/test_orc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
import io

import pyarrow as pa
import pytest
from utils import _convert_types, assert_table_and_meta_eq, make_source
Expand Down Expand Up @@ -54,5 +56,54 @@ def test_read_orc_basic(
assert_table_and_meta_eq(pa_table, res, check_field_nullability=False)


def test_write_orc():
pass
@pytest.mark.parametrize(
"compression",
[
plc.io.types.CompressionType.NONE,
plc.io.types.CompressionType.SNAPPY,
],
)
@pytest.mark.parametrize(
"statistics",
[
plc.io.types.StatisticsFreq.STATISTICS_NONE,
plc.io.types.StatisticsFreq.STATISTICS_COLUMN,
],
)
@pytest.mark.parametrize("stripe_size_bytes", [None, 65536])
@pytest.mark.parametrize("stripe_size_rows", [None, 512])
@pytest.mark.parametrize("row_index_stride", [None, 512])
def test_write_orc(
compression,
statistics,
stripe_size_bytes,
stripe_size_rows,
row_index_stride,
):
names = ["a", "b"]
pa_table = pa.Table.from_arrays(
[pa.array([1.0, 2.0, None]), pa.array([True, None, False])],
names=names,
)
plc_table = plc.interop.from_arrow(pa_table)
tbl_meta = plc.io.types.TableInputMetadata(plc_table)
sink = plc.io.SinkInfo([io.BytesIO()])
user_data = {"foo": "{'bar': 'baz'}"}
options = (
plc.io.orc.OrcWriterOptions.builder(sink, plc_table)
.metadata(tbl_meta)
.key_value_metadata(user_data)
.compression(compression)
.enable_statistics(statistics)
.build()
)
if stripe_size_bytes is not None:
options.set_stripe_size_bytes(stripe_size_bytes)
if stripe_size_rows is not None:
options.set_stripe_size_rows(stripe_size_rows)
if row_index_stride is not None:
options.set_row_index_stride(row_index_stride)

plc.io.orc.write_orc(options)

# pd.read_orc(...)

0 comments on commit 9f20820

Please sign in to comment.