Skip to content

Commit

Permalink
Adapt to polars upstream changes and turn on CI testing (rapidsai#16081)
Browse files Browse the repository at this point in the history
They changed the semantics of join keys when those keys are expressions to more closely match SQL.

Dtype inference is also tighter, so update tests to adapt to those changes, and some other small deprecation warnings.

Finish the final missing coverage piece and turn on testing in CI (failing if we don't hit 100% coverage as well).

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Thomas Li (https://github.com/lithomas1)
  - James Lamb (https://github.com/jameslamb)

URL: rapidsai#16081
  • Loading branch information
wence- authored Jun 27, 2024
1 parent e98d456 commit fa8284d
Show file tree
Hide file tree
Showing 13 changed files with 234 additions and 112 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
- docs-build
- wheel-build-cudf
- wheel-tests-cudf
- test-cudf-polars
- wheel-build-dask-cudf
- wheel-tests-dask-cudf
- devcontainer
Expand Down Expand Up @@ -132,6 +133,17 @@ jobs:
with:
build_type: pull-request
script: ci/test_wheel_cudf.sh
test-cudf-polars:
needs: wheel-build-cudf
secrets: inherit
uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.08
with:
# This selects "ARCH=amd64 + the latest supported Python + CUDA".
matrix_filter: map(select(.ARCH == "amd64")) | group_by(.CUDA_VER|split(".")|map(tonumber)|.[0]) | map(max_by([(.PY_VER|split(".")|map(tonumber)), (.CUDA_VER|split(".")|map(tonumber))]))
build_type: pull-request
# This always runs, but only fails if this PR touches code in
# pylibcudf or cudf_polars
script: "ci/test_cudf_polars.sh"
wheel-build-dask-cudf:
needs: wheel-build-cudf
secrets: inherit
Expand Down
68 changes: 68 additions & 0 deletions ci/test_cudf_polars.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/bin/bash
# Copyright (c) 2024, NVIDIA CORPORATION.

set -eou pipefail

# We will only fail these tests if the PR touches code in pylibcudf
# or cudf_polars itself.
# Note, the three dots mean we are doing diff between the merge-base
# of upstream and HEAD. So this is asking, "does _this branch_ touch
# files in cudf_polars/pylibcudf", rather than "are there changes
# between upstream and this branch which touch cudf_polars/pylibcudf"
# TODO: is the target branch exposed anywhere in an environment variable?
if [ -n "$(git diff --name-only origin/branch-24.08...HEAD -- python/cudf_polars/ python/cudf/cudf/_lib/pylibcudf/)" ];
then
HAS_CHANGES=1
else
HAS_CHANGES=0
fi

RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})"
RAPIDS_PY_WHEEL_NAME="cudf_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./dist

RESULTS_DIR=${RAPIDS_TESTS_DIR:-"$(mktemp -d)"}
RAPIDS_TESTS_DIR=${RAPIDS_TESTS_DIR:-"${RESULTS_DIR}/test-results"}/
mkdir -p "${RAPIDS_TESTS_DIR}"

rapids-logger "Install cudf wheel"
# echo to expand wildcard before adding `[extra]` requires for pip
python -m pip install $(echo ./dist/cudf*.whl)[test]

rapids-logger "Install polars (allow pre-release versions)"
python -m pip install 'polars>=1.0.0a0'

rapids-logger "Install cudf_polars"
python -m pip install --no-deps python/cudf_polars

rapids-logger "Run cudf_polars tests"

function set_exitcode()
{
EXITCODE=$?
}
EXITCODE=0
trap set_exitcode ERR
set +e

python -m pytest \
--cache-clear \
--cov cudf_polars \
--cov-fail-under=100 \
--cov-config=python/cudf_polars/pyproject.toml \
--junitxml="${RAPIDS_TESTS_DIR}/junit-cudf_polars.xml" \
python/cudf_polars/tests

trap ERR
set -e

if [ ${EXITCODE} != 0 ]; then
rapids-logger "Testing FAILED: exitcode ${EXITCODE}"
else
rapids-logger "Testing PASSED"
fi

if [ ${HAS_CHANGES} == 1 ]; then
exit ${EXITCODE}
else
exit 0
fi
70 changes: 37 additions & 33 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def broadcast(
]


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class IR:
"""Abstract plan node, representing an unevaluated dataframe."""

Expand Down Expand Up @@ -157,7 +157,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
) # pragma: no cover


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class PythonScan(IR):
"""Representation of input from a python function."""

Expand All @@ -171,7 +171,7 @@ def __post_init__(self):
raise NotImplementedError("PythonScan not implemented")


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class Scan(IR):
"""Input from files."""

Expand Down Expand Up @@ -248,7 +248,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return df.filter(mask)


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class Cache(IR):
"""
Return a cached plan node.
Expand All @@ -269,7 +269,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return cache.setdefault(self.key, self.value.evaluate(cache=cache))


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class DataFrameScan(IR):
"""
Input from an existing polars DataFrame.
Expand Down Expand Up @@ -315,7 +315,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return df


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class Select(IR):
"""Produce a new dataframe selecting given expressions from an input."""

Expand All @@ -336,7 +336,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return DataFrame(columns)


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class Reduce(IR):
"""
Produce a new dataframe selecting given expressions from an input.
Expand Down Expand Up @@ -389,7 +389,7 @@ def placeholder_column(n: int) -> plc.Column:
)


@dataclasses.dataclass(slots=False)
@dataclasses.dataclass
class GroupBy(IR):
"""Perform a groupby."""

Expand Down Expand Up @@ -490,7 +490,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return DataFrame([*result_keys, *results]).slice(self.options.slice)


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class Join(IR):
"""A join of two dataframes."""

Expand Down Expand Up @@ -518,8 +518,16 @@ class Join(IR):
- coalesce: should key columns be coalesced (only makes sense for outer joins)
"""

@cache
def __post_init__(self) -> None:
"""Validate preconditions."""
if any(
isinstance(e.value, expr.Literal)
for e in itertools.chain(self.left_on, self.right_on)
):
raise NotImplementedError("Join with literal as join key.")

@staticmethod
@cache
def _joiners(
how: Literal["inner", "left", "full", "leftsemi", "leftanti"],
) -> tuple[
Expand Down Expand Up @@ -582,17 +590,9 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
for new, old in zip(columns[left.num_columns :], right.columns)
]
return DataFrame([*left_cols, *right_cols])
left_on = DataFrame(
broadcast(
*(e.evaluate(left) for e in self.left_on), target_length=left.num_rows
)
)
right_on = DataFrame(
broadcast(
*(e.evaluate(right) for e in self.right_on),
target_length=right.num_rows,
)
)
# TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
left_on = DataFrame(broadcast(*(e.evaluate(left) for e in self.left_on)))
right_on = DataFrame(broadcast(*(e.evaluate(right) for e in self.right_on)))
null_equality = (
plc.types.NullEquality.EQUAL
if join_nulls
Expand All @@ -602,13 +602,10 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
if right_policy is None:
# Semi join
lg = join_fn(left_on.table, right_on.table, null_equality)
left = left.replace_columns(*left_on.columns)
table = plc.copying.gather(left.table, lg, left_policy)
result = DataFrame.from_table(table, left.column_names)
else:
lg, rg = join_fn(left_on.table, right_on.table, null_equality)
left = left.replace_columns(*left_on.columns)
right = right.replace_columns(*right_on.columns)
if coalesce and how == "inner":
right = right.discard_columns(right_on.column_names_set)
left = DataFrame.from_table(
Expand Down Expand Up @@ -642,7 +639,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return result.slice(zlice)


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class HStack(IR):
"""Add new columns to a dataframe."""

Expand Down Expand Up @@ -671,7 +668,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return df.with_columns(columns)


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class Distinct(IR):
"""Produce a new dataframe with distinct rows."""

Expand Down Expand Up @@ -741,7 +738,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return result.slice(self.zlice)


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class Sort(IR):
"""Sort a dataframe."""

Expand Down Expand Up @@ -810,7 +807,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return DataFrame(columns).slice(self.zlice)


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class Slice(IR):
"""Slice a dataframe."""

Expand All @@ -827,7 +824,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return df.slice((self.offset, self.length))


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class Filter(IR):
"""Filter a dataframe with a boolean mask."""

Expand All @@ -843,7 +840,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return df.filter(mask)


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class Projection(IR):
"""Select a subset of columns from a dataframe."""

Expand All @@ -860,7 +857,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
return DataFrame(columns)


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class MapFunction(IR):
"""Apply some function to a dataframe."""

Expand Down Expand Up @@ -894,6 +891,13 @@ def __post_init__(self) -> None:
# polars requires that all to-explode columns have the
# same sub-shapes
raise NotImplementedError("Explode with more than one column")
elif self.name == "rename":
old, new, _ = self.options
# TODO: perhaps polars should validate renaming in the IR?
if len(new) != len(set(new)) or (
set(new) & (set(self.df.schema.keys() - set(old)))
):
raise NotImplementedError("Duplicate new names in rename.")

def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
Expand All @@ -919,7 +923,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
raise AssertionError("Should never be reached") # pragma: no cover


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class Union(IR):
"""Concatenate dataframes vertically."""

Expand All @@ -943,7 +947,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
).slice(self.zlice)


@dataclasses.dataclass(slots=True)
@dataclasses.dataclass
class HConcat(IR):
"""Concatenate dataframes horizontally."""

Expand Down
Loading

0 comments on commit fa8284d

Please sign in to comment.