Adapt to IR changes in polars 1.4 (#16494)
## Description

Adapts to IR changes in polars 1.4 and handles nrows/skiprows a little
more correctly.
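
For context, a minimal sketch of the kind of query whose row-limit handling this touches; the file path and limits are hypothetical, not taken from the PR:

```python
import polars as pl

# Hypothetical CSV scan: skip_rows/n_rows are carried on the Scan IR node
# and forwarded to the cudf readers by the changes below.
q = pl.scan_csv("data.csv", skip_rows=1, n_rows=3)
print(q.collect())
```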

## Checklist
- [ ] I am familiar with the [Contributing
Guidelines](https://github.com/rapidsai/cudf/blob/HEAD/CONTRIBUTING.md).
- [ ] New or existing tests cover these changes.
- [ ] The documentation is up to date with these changes.

---------

Co-authored-by: Lawrence Mitchell <lmitchell@nvidia.com>
lithomas1 and wence- authored Aug 5, 2024
1 parent 62a5dbd commit 7d0c7ad
Showing 3 changed files with 97 additions and 26 deletions.
46 changes: 29 additions & 17 deletions python/cudf_polars/cudf_polars/dsl/ir.py
@@ -190,15 +190,14 @@ class Scan(IR):
"""Cloud-related authentication options, currently ignored."""
paths: list[str]
"""List of paths to read from."""
file_options: Any
"""Options for reading the file.
Attributes are:
- ``with_columns: list[str]`` of projected columns to return.
- ``n_rows: int``: Number of rows to read.
- ``row_index: tuple[name, offset] | None``: Add an integer index
column with given name.
"""
with_columns: list[str]
"""Projected columns to return."""
skip_rows: int
"""Rows to skip at the start when reading."""
n_rows: int
"""Number of rows to read after skipping."""
row_index: tuple[str, int] | None
"""If not None add an integer index column of the given name."""
predicate: expr.NamedExpr | None
"""Mask to apply to the read dataframe."""

@@ -208,8 +207,16 @@ def __post_init__(self) -> None:
# This line is unhittable ATM since IPC/Anonymous scan raise
# on the polars side
raise NotImplementedError(f"Unhandled scan type: {self.typ}")
if self.typ == "ndjson" and self.file_options.n_rows is not None:
raise NotImplementedError("row limit in scan")
if self.typ == "ndjson" and (self.n_rows != -1 or self.skip_rows != 0):
raise NotImplementedError("row limit in scan for json reader")
if self.skip_rows < 0:
# TODO: polars has this implemented for parquet,
# maybe we can do this too?
raise NotImplementedError("slice pushdown for negative slices")
if self.typ == "csv" and self.skip_rows != 0: # pragma: no cover
# This comes from slice pushdown, but that
# optimization doesn't happen right now
raise NotImplementedError("skipping rows in CSV reader")
if self.cloud_options is not None and any(
self.cloud_options.get(k) is not None for k in ("aws", "azure", "gcp")
):
@@ -246,10 +253,9 @@ def __post_init__(self) -> None:

def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
options = self.file_options
with_columns = options.with_columns
row_index = options.row_index
nrows = self.file_options.n_rows if self.file_options.n_rows is not None else -1
with_columns = self.with_columns
row_index = self.row_index
n_rows = self.n_rows
if self.typ == "csv":
parse_options = self.reader_options["parse_options"]
sep = chr(parse_options["separator"])
@@ -283,6 +289,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:

# polars skips blank lines at the beginning of the file
pieces = []
read_partial = n_rows != -1
for p in self.paths:
skiprows = self.reader_options["skip_rows"]
path = Path(p)
@@ -304,9 +311,13 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
comment=comment,
decimal=decimal,
dtypes=self.schema,
nrows=nrows,
nrows=n_rows,
)
pieces.append(tbl_w_meta)
if read_partial:
n_rows -= tbl_w_meta.tbl.num_rows()
if n_rows <= 0:
break
tables, colnames = zip(
*(
(piece.tbl, piece.column_names(include_children=False))
@@ -321,7 +332,8 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
tbl_w_meta = plc.io.parquet.read_parquet(
plc.io.SourceInfo(self.paths),
columns=with_columns,
num_rows=nrows,
num_rows=n_rows,
skip_rows=self.skip_rows,
)
df = DataFrame.from_table(
tbl_w_meta.tbl,
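
The multi-file CSV branch above caps the total rows read across files by decrementing `n_rows` after each piece and stopping early. A standalone sketch of that bookkeeping, with a hypothetical `read_csv` callable standing in for the libcudf reader:

```python
def read_limited(paths, read_csv, n_rows=-1):
    """Read a piece per path, stopping once n_rows rows have been collected."""
    pieces = []
    read_partial = n_rows != -1
    for path in paths:
        piece = read_csv(path, n_rows)
        pieces.append(piece)
        if read_partial:
            n_rows -= len(piece)
            if n_rows <= 0:
                break
    return pieces


# In-memory stand-in for the reader (the real code calls plc.io.csv.read_csv).
data = {"a.csv": [1, 2, 3], "b.csv": [4, 5]}
reader = lambda path, n: data[path] if n == -1 else data[path][:n]
assert read_limited(["a.csv", "b.csv"], reader, n_rows=4) == [[1, 2, 3], [4]]
assert read_limited(["a.csv", "b.csv"], reader) == [[1, 2, 3], [4, 5]]
```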
25 changes: 22 additions & 3 deletions python/cudf_polars/cudf_polars/dsl/translate.py
@@ -76,7 +76,7 @@ def _translate_ir(
def _(
node: pl_ir.PythonScan, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
if visitor.version()[0] == 1: # pragma: no cover
if visitor.version()[0] == 1:
# https://github.com/pola-rs/polars/pull/17939
# Versioning can be dropped once polars 1.4 is lowest
# supported version.
@@ -87,7 +87,7 @@ def _(
if predicate is not None
else None
)
else:
else: # pragma: no cover; CI tests 1.4
# version == 0
options = node.options
predicate = (
@@ -108,13 +108,32 @@ def _(
cloud_options = None
else:
reader_options, cloud_options = map(json.loads, options)
file_options = node.file_options
with_columns = file_options.with_columns
n_rows = file_options.n_rows
if n_rows is None:
n_rows = -1 # All rows
skip_rows = 0 # Don't skip
else:
if visitor.version() >= (1, 0):
# Polars 1.4 n_rows property is (skip, nrows)
skip_rows, n_rows = n_rows
else: # pragma: no cover; CI tests 1.4
# Polars 1.3 n_rows property is integer, skip rows was
# always zero because it was not pushed down to reader.
skip_rows = 0

row_index = file_options.row_index
return ir.Scan(
schema,
typ,
reader_options,
cloud_options,
node.paths,
node.file_options,
with_columns,
skip_rows,
n_rows,
row_index,
translate_named_expr(visitor, n=node.predicate)
if node.predicate is not None
else None,
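
The translation above normalizes the polars-version differences into a `(skip_rows, n_rows)` pair, with `-1` meaning "read all rows". A self-contained sketch of that logic, assuming `version` is the tuple returned by `visitor.version()` and `n_rows` is the raw `file_options.n_rows` value:

```python
def normalize_row_limit(version: tuple[int, int], n_rows):
    """Return (skip_rows, n_rows); n_rows == -1 means read all rows."""
    if n_rows is None:
        return 0, -1  # no limit, no skip
    if version >= (1, 0):
        # polars 1.4: file_options.n_rows is a (skip, n_rows) tuple
        skip_rows, limit = n_rows
        return skip_rows, limit
    # polars 1.3: plain integer limit; skipping was not pushed down to the reader
    return 0, n_rows


assert normalize_row_limit((1, 0), None) == (0, -1)
assert normalize_row_limit((1, 0), (2, 5)) == (2, 5)
assert normalize_row_limit((0, 9), 5) == (0, 5)  # hypothetical pre-1.4 IR version
```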
52 changes: 46 additions & 6 deletions python/cudf_polars/tests/test_scan.py
@@ -57,6 +57,22 @@ def mask(request):
return request.param


@pytest.fixture(
params=[
None,
(1, 1),
],
ids=[
"no-slice",
"slice-second",
],
)
def slice(request):
# For use in testing that we handle
# polars slice pushdown correctly
return request.param


def make_source(df, path, format):
"""
Writes the passed polars df to a file of
@@ -78,7 +94,9 @@ def make_source(df, path, format):
("parquet", pl.scan_parquet),
],
)
def test_scan(tmp_path, df, format, scan_fn, row_index, n_rows, columns, mask, request):
def test_scan(
tmp_path, df, format, scan_fn, row_index, n_rows, columns, mask, slice, request
):
name, offset = row_index
make_source(df, tmp_path / "file", format)
request.applymarker(
@@ -93,13 +111,25 @@ def test_scan(tmp_path, df, format, scan_fn, row_index, n_rows, columns, mask, r
row_index_offset=offset,
n_rows=n_rows,
)
if slice is not None:
q = q.slice(*slice)
if mask is not None:
q = q.filter(mask)
if columns is not None:
q = q.select(*columns)
assert_gpu_result_equal(q)


def test_negative_slice_pushdown_raises(tmp_path):
df = pl.DataFrame({"a": [1, 2, 3]})

df.write_parquet(tmp_path / "df.parquet")
q = pl.scan_parquet(tmp_path / "df.parquet")
# Take the last row
q = q.slice(-1, 1)
assert_ir_translation_raises(q, NotImplementedError)


def test_scan_unsupported_raises(tmp_path):
df = pl.DataFrame({"a": [1, 2, 3]})

@@ -154,15 +184,25 @@ def test_scan_csv_column_renames_projection_schema(tmp_path):
("test*.csv", False),
],
)
def test_scan_csv_multi(tmp_path, filename, glob):
@pytest.mark.parametrize(
"nrows_skiprows",
[
(None, 0),
(1, 1),
(3, 0),
(4, 2),
],
)
def test_scan_csv_multi(tmp_path, filename, glob, nrows_skiprows):
n_rows, skiprows = nrows_skiprows
with (tmp_path / "test1.csv").open("w") as f:
f.write("""foo,bar,baz\n1,2\n3,4,5""")
f.write("""foo,bar,baz\n1,2,3\n3,4,5""")
with (tmp_path / "test2.csv").open("w") as f:
f.write("""foo,bar,baz\n1,2\n3,4,5""")
f.write("""foo,bar,baz\n1,2,3\n3,4,5""")
with (tmp_path / "test*.csv").open("w") as f:
f.write("""foo,bar,baz\n1,2\n3,4,5""")
f.write("""foo,bar,baz\n1,2,3\n3,4,5""")
os.chdir(tmp_path)
q = pl.scan_csv(filename, glob=glob)
q = pl.scan_csv(filename, glob=glob, n_rows=n_rows, skip_rows=skiprows)

assert_gpu_result_equal(q)

