diff --git a/python/cudf_polars/cudf_polars/containers/__init__.py b/python/cudf_polars/cudf_polars/containers/__init__.py index 06bb08953f1..3b1eff4a0d0 100644 --- a/python/cudf_polars/cudf_polars/containers/__init__.py +++ b/python/cudf_polars/cudf_polars/containers/__init__.py @@ -5,7 +5,7 @@ from __future__ import annotations -__all__: list[str] = ["DataFrame", "Column", "NamedColumn"] +__all__: list[str] = ["DataFrame", "Column"] -from cudf_polars.containers.column import Column, NamedColumn +from cudf_polars.containers.column import Column from cudf_polars.containers.dataframe import DataFrame diff --git a/python/cudf_polars/cudf_polars/containers/column.py b/python/cudf_polars/cudf_polars/containers/column.py index 3fe3e5557cb..6f5ce509ce6 100644 --- a/python/cudf_polars/cudf_polars/containers/column.py +++ b/python/cudf_polars/cudf_polars/containers/column.py @@ -15,7 +15,7 @@ import polars as pl -__all__: list[str] = ["Column", "NamedColumn"] +__all__: list[str] = ["Column"] class Column: @@ -217,58 +217,3 @@ def nan_count(self) -> int: ) ).as_py() return 0 - - -class NamedColumn(Column): - """A column with a name.""" - - name: str - - def __init__( - self, - column: plc.Column, - name: str, - *, - is_sorted: plc.types.Sorted = plc.types.Sorted.NO, - order: plc.types.Order = plc.types.Order.ASCENDING, - null_order: plc.types.NullOrder = plc.types.NullOrder.BEFORE, - ) -> None: - super().__init__( - column, is_sorted=is_sorted, order=order, null_order=null_order - ) - self.name = name - - def copy(self, *, new_name: str | None = None) -> Self: - """ - A shallow copy of the column. - - Parameters - ---------- - new_name - Optional new name for the copied column. - - Returns - ------- - New column sharing data with self. 
- """ - return type(self)( - self.obj, - self.name if new_name is None else new_name, - is_sorted=self.is_sorted, - order=self.order, - null_order=self.null_order, - ) - - def mask_nans(self) -> Self: - """Return a shallow copy of self with nans masked out.""" - # Annoying, the inheritance is not right (can't call the - # super-type mask_nans), but will sort that by refactoring - # later. - if plc.traits.is_floating_point(self.obj.type()): - old_count = self.obj.null_count() - mask, new_count = plc.transform.nans_to_nulls(self.obj) - result = type(self)(self.obj.with_mask(mask, new_count), self.name) - if old_count == new_count: - return result.sorted_like(self) - return result - return self.copy() diff --git a/python/cudf_polars/cudf_polars/containers/dataframe.py b/python/cudf_polars/cudf_polars/containers/dataframe.py index f3e3862d0cc..d2f5e000343 100644 --- a/python/cudf_polars/cudf_polars/containers/dataframe.py +++ b/python/cudf_polars/cudf_polars/containers/dataframe.py @@ -14,16 +14,14 @@ import polars as pl -from cudf_polars.containers.column import NamedColumn +from cudf_polars.containers import Column from cudf_polars.utils import dtypes if TYPE_CHECKING: - from collections.abc import Mapping, Sequence, Set + from collections.abc import Iterable, Mapping, Sequence, Set from typing_extensions import Self - from cudf_polars.containers import Column - __all__: list[str] = ["DataFrame"] @@ -31,17 +29,18 @@ class DataFrame: """A representation of a dataframe.""" - columns: list[NamedColumn] + column_map: dict[str, Column] table: plc.Table - def __init__(self, columns: Sequence[NamedColumn]) -> None: - self.columns = list(columns) - self._column_map = {c.name: c for c in self.columns} - self.table = plc.Table([c.obj for c in columns]) + def __init__( + self, columns: Iterable[tuple[str, Column]] | Mapping[str, Column] + ) -> None: + self.column_map = dict(columns) + self.table = plc.Table([c.obj for c in self.column_map.values()]) def copy(self) -> Self: 
"""Return a shallow copy of self.""" - return type(self)([c.copy() for c in self.columns]) + return type(self)((name, c.copy()) for name, c in self.column_map.items()) def to_polars(self) -> pl.DataFrame: """Convert to a polars DataFrame.""" @@ -51,7 +50,7 @@ def to_polars(self) -> pl.DataFrame: # https://github.com/pola-rs/polars/issues/11632 # To guarantee we produce correct names, we therefore # serialise with names we control and rename with that map. - name_map = {f"column_{i}": c.name for i, c in enumerate(self.columns)} + name_map = {f"column_{i}": name for i, name in enumerate(self.column_map)} table: pa.Table = plc.interop.to_arrow( self.table, [plc.interop.ColumnMetadata(name=name) for name in name_map], @@ -59,34 +58,39 @@ def to_polars(self) -> pl.DataFrame: df: pl.DataFrame = pl.from_arrow(table) return df.rename(name_map).with_columns( *( - pl.col(c.name).set_sorted( - descending=c.order == plc.types.Order.DESCENDING + pl.col(name).set_sorted( + descending=column.order == plc.types.Order.DESCENDING ) - if c.is_sorted - else pl.col(c.name) - for c in self.columns + if column.is_sorted + else pl.col(name) + for name, column in self.column_map.items() ) ) @cached_property def column_names_set(self) -> frozenset[str]: """Return the column names as a set.""" - return frozenset(c.name for c in self.columns) + return frozenset(self.column_map) @cached_property def column_names(self) -> list[str]: """Return a list of the column names.""" - return [c.name for c in self.columns] + return list(self.column_map) + + @cached_property + def columns(self) -> list[Column]: + """Return a list of the columns.""" + return list(self.column_map.values()) @cached_property def num_columns(self) -> int: """Number of columns.""" - return len(self.columns) + return len(self.column_map) @cached_property def num_rows(self) -> int: """Number of rows.""" - return 0 if len(self.columns) == 0 else self.table.num_rows() + return 0 if len(self.column_map) == 0 else 
self.table.num_rows() @classmethod def from_polars(cls, df: pl.DataFrame) -> Self: @@ -111,12 +115,8 @@ def from_polars(cls, df: pl.DataFrame) -> Self: # No-op if the schema is unchanged. d_table = plc.interop.from_arrow(table.cast(schema)) return cls( - [ - NamedColumn(column, h_col.name).copy_metadata(h_col) - for column, h_col in zip( - d_table.columns(), df.iter_columns(), strict=True - ) - ] + (h_col.name, Column(column).copy_metadata(h_col)) + for column, h_col in zip(d_table.columns(), df.iter_columns(), strict=True) ) @classmethod @@ -143,12 +143,7 @@ def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self: """ if table.num_columns() != len(names): raise ValueError("Mismatching name and table length.") - return cls( - [ - NamedColumn(c, name) - for c, name in zip(table.columns(), names, strict=True) - ] - ) + return cls(zip(names, map(Column, table.columns()), strict=True)) def sorted_like( self, like: DataFrame, /, *, subset: Set[str] | None = None @@ -175,13 +170,15 @@ def sorted_like( if like.column_names != self.column_names: raise ValueError("Can only copy from identically named frame") subset = self.column_names_set if subset is None else subset - self.columns = [ - c.sorted_like(other) if c.name in subset else c - for c, other in zip(self.columns, like.columns, strict=True) - ] + self.column_map = { + name: column.sorted_like(other) if name in subset else column + for (name, column), other in zip( + self.column_map.items(), like.column_map.values(), strict=True + ) + } return self - def with_columns(self, columns: Sequence[NamedColumn]) -> Self: + def with_columns(self, columns: Iterable[tuple[str, Column]]) -> Self: """ Return a new dataframe with extra columns. @@ -198,36 +195,40 @@ def with_columns(self, columns: Sequence[NamedColumn]) -> Self: ----- If column names overlap, newer names replace older ones. 
""" - columns = list( - {c.name: c for c in itertools.chain(self.columns, columns)}.values() - ) - return type(self)(columns) + return type(self)(itertools.chain(self.column_map.items(), columns)) def discard_columns(self, names: Set[str]) -> Self: """Drop columns by name.""" - return type(self)([c for c in self.columns if c.name not in names]) + return type(self)( + (name, column) + for name, column in self.column_map.items() + if name not in names + ) def select(self, names: Sequence[str]) -> Self: """Select columns by name returning DataFrame.""" want = set(names) if not want.issubset(self.column_names_set): raise ValueError("Can't select missing names") - return type(self)([self._column_map[name] for name in names]) + return type(self)((name, self.column_map[name]) for name in names) - def replace_columns(self, *columns: NamedColumn) -> Self: + def replace_columns(self, *columns: tuple[str, Column]) -> Self: """Return a new dataframe with columns replaced by name.""" - new = {c.name: c for c in columns} + new = dict(columns) if not set(new).issubset(self.column_names_set): raise ValueError("Cannot replace with non-existing names") - return type(self)([new.get(c.name, c) for c in self.columns]) + return type(self)(self.column_map | new) def rename_columns(self, mapping: Mapping[str, str]) -> Self: """Rename some columns.""" - return type(self)([c.copy(new_name=mapping.get(c.name)) for c in self.columns]) + return type(self)( + (mapping.get(name, name), column) + for name, column in self.column_map.items() + ) - def select_columns(self, names: Set[str]) -> list[NamedColumn]: + def select_columns(self, names: Set[str]) -> list[Column]: """Select columns by name.""" - return [c for c in self.columns if c.name in names] + return [c for name, c in self.column_map.items() if name in names] def filter(self, mask: Column) -> Self: """Return a filtered table given a mask.""" diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py 
b/python/cudf_polars/cudf_polars/dsl/expr.py index c401e5a2f17..c04651f6047 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -27,7 +27,7 @@ from polars.exceptions import InvalidOperationError from polars.polars import _expr_nodes as pl_expr -from cudf_polars.containers import Column, NamedColumn +from cudf_polars.containers import Column from cudf_polars.utils import dtypes, sorting if TYPE_CHECKING: @@ -313,7 +313,7 @@ def evaluate( *, context: ExecutionContext = ExecutionContext.FRAME, mapping: Mapping[Expr, Column] | None = None, - ) -> NamedColumn: + ) -> tuple[str, Column]: """ Evaluate this expression given a dataframe for context. @@ -328,21 +328,14 @@ def evaluate( Returns ------- - NamedColumn attaching a name to an evaluated Column + tuple of name and evaluated Column See Also -------- :meth:`Expr.evaluate` for details, this function just adds the name to a column produced from an expression. """ - obj = self.value.evaluate(df, context=context, mapping=mapping) - return NamedColumn( - obj.obj, - self.name, - is_sorted=obj.is_sorted, - order=obj.order, - null_order=obj.null_order, - ) + return self.name, self.value.evaluate(df, context=context, mapping=mapping) def collect_agg(self, *, depth: int) -> AggInfo: """Collect information about aggregations in groupbys.""" @@ -428,7 +421,7 @@ def do_evaluate( mapping: Mapping[Expr, Column] | None = None, ) -> Column: """Evaluate this expression given a dataframe for context.""" - return df._column_map[self.name] + return df.column_map[self.name] def collect_agg(self, *, depth: int) -> AggInfo: """Collect information about aggregations in groupbys.""" diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 1c61075be22..e49f4e03c7c 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -26,7 +26,7 @@ import polars as pl import cudf_polars.dsl.expr as expr -from 
cudf_polars.containers import DataFrame, NamedColumn +from cudf_polars.containers import Column, DataFrame from cudf_polars.utils import dtypes, sorting if TYPE_CHECKING: @@ -58,22 +58,23 @@ def broadcast( - *columns: NamedColumn, target_length: int | None = None -) -> list[NamedColumn]: + *named_columns: tuple[str, Column], target_length: int | None = None +) -> tuple[list[str], list[Column]]: """ Broadcast a sequence of columns to a common length. Parameters ---------- - columns - Columns to broadcast. + named_columns + Pairs of column names and Columns to broadcast target_length Optional length to broadcast to. If not provided, uses the non-unit length of existing columns. Returns ------- - List of broadcasted columns all of the same length. + Tuple of list of names and list of broadcasted columns all of the + same length. Raises ------ @@ -93,12 +94,13 @@ def broadcast( ``target_length`` is provided and not all columns are length-1 (i.e. ``n != 1``), then ``target_length`` must be equal to ``n``. 
""" - if len(columns) == 0: - return [] + if len(named_columns) == 0: + return [], [] + names, columns = zip(*named_columns, strict=True) lengths: set[int] = {column.obj.size() for column in columns} if lengths == {1}: if target_length is None: - return list(columns) + return list(names), list(columns) nrows = target_length else: try: @@ -109,12 +111,11 @@ def broadcast( raise RuntimeError( f"Cannot broadcast columns of length {nrows=} to {target_length=}" ) - return [ + return list(names), [ column if column.obj.size() != 1 - else NamedColumn( + else Column( plc.Column.from_scalar(column.obj_scalar, nrows), - column.name, is_sorted=plc.types.Sorted.YES, order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE, @@ -385,19 +386,22 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: init = plc.interop.from_arrow( pa.scalar(offset, type=plc.interop.to_arrow(dtype)) ) - index = NamedColumn( + index = Column( plc.filling.sequence(df.num_rows, init, step), - name, is_sorted=plc.types.Sorted.YES, order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.AFTER, ) - df = DataFrame([index, *df.columns]) - assert all(c.obj.type() == self.schema[c.name] for c in df.columns) + df = DataFrame([(name, index), *df.column_map.items()]) + assert all( + c.obj.type() == self.schema[name] for name, c in df.column_map.items() + ) if self.predicate is None: return df else: - (mask,) = broadcast(self.predicate.evaluate(df), target_length=df.num_rows) + _, (mask,) = broadcast( + self.predicate.evaluate(df), target_length=df.num_rows + ) return df.filter(mask) @@ -448,7 +452,9 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: for c, dtype in zip(df.columns, self.schema.values(), strict=True) ) if self.predicate is not None: - (mask,) = broadcast(self.predicate.evaluate(df), target_length=df.num_rows) + _, (mask,) = broadcast( + self.predicate.evaluate(df), target_length=df.num_rows + ) return df.filter(mask) else: 
return df @@ -471,7 +477,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: # Handle any broadcasting columns = [e.evaluate(df) for e in self.expr] if self.should_broadcast: - columns = broadcast(*columns) + columns = list(zip(*broadcast(*columns), strict=True)) return DataFrame(columns) @@ -493,9 +499,9 @@ def evaluate( ) -> DataFrame: # pragma: no cover; polars doesn't emit this node yet """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) - columns = broadcast(*(e.evaluate(df) for e in self.expr)) + names, columns = broadcast(*(e.evaluate(df) for e in self.expr)) assert all(column.obj.size() == 1 for column in columns) - return DataFrame(columns) + return DataFrame(zip(names, columns, strict=True)) @dataclasses.dataclass @@ -558,7 +564,7 @@ def __post_init__(self) -> None: def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) - keys = broadcast( + key_names, keys = broadcast( *(k.evaluate(df) for k in self.keys), target_length=df.num_rows ) sorted = ( @@ -588,21 +594,22 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: requests.append(plc.groupby.GroupByRequest(col, [req])) replacements.append(rep) group_keys, raw_tables = grouper.aggregate(requests) - # TODO: names - raw_columns: list[NamedColumn] = [] - for i, table in enumerate(raw_tables): + raw_columns: list[Column] = [] + for table in raw_tables: (column,) = table.columns() - raw_columns.append(NamedColumn(column, f"tmp{i}")) + raw_columns.append(Column(column)) mapping = dict(zip(replacements, raw_columns, strict=True)) result_keys = [ - NamedColumn(gk, k.name) - for gk, k in zip(group_keys.columns(), keys, strict=True) + (key_name, Column(key)) + for key_name, key in zip(key_names, group_keys.columns(), strict=True) ] - result_subs = DataFrame(raw_columns) + result_subs = DataFrame( + (f"tmp{i}", column) for i, column in 
enumerate(raw_columns) + ) results = [ req.evaluate(result_subs, mapping=mapping) for req in self.agg_requests ] - broadcasted = broadcast(*result_keys, *results) + names, broadcasted = broadcast(*result_keys, *results) # Handle order preservation of groups if self.maintain_order and not sorted: # The order we want @@ -638,13 +645,8 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: right_order, plc.copying.OutOfBoundsPolicy.DONT_CHECK, ) - broadcasted = [ - NamedColumn(reordered, b.name) - for reordered, b in zip( - ordered_table.columns(), broadcasted, strict=True - ) - ] - return DataFrame(broadcasted).slice(self.options.slice) + broadcasted = [Column(reordered) for reordered in ordered_table.columns()] + return DataFrame(zip(names, broadcasted, strict=True)).slice(self.options.slice) @dataclasses.dataclass @@ -787,26 +789,28 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: # result, not the gather maps columns = plc.join.cross_join(left.table, right.table).columns() left_cols = [ - NamedColumn(new, old.name).sorted_like(old) - for new, old in zip( - columns[: left.num_columns], left.columns, strict=True + (name, Column(new).sorted_like(old)) + for new, (name, old) in zip( + columns[: left.num_columns], left.column_map.items(), strict=True ) ] right_cols = [ - NamedColumn( - new, - old.name - if old.name not in left.column_names_set - else f"{old.name}{suffix}", + ( + name if name not in left.column_names_set else f"{name}{suffix}", + Column(new), ) - for new, old in zip( - columns[left.num_columns :], right.columns, strict=True + for new, name in zip( + columns[left.num_columns :], right.column_names, strict=True ) ] return DataFrame([*left_cols, *right_cols]) # TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184 - left_on = DataFrame(broadcast(*(e.evaluate(left) for e in self.left_on))) - right_on = DataFrame(broadcast(*(e.evaluate(right) for e in self.right_on))) + left_on = 
DataFrame( + zip(*broadcast(*(e.evaluate(left) for e in self.left_on)), strict=True) + ) + right_on = DataFrame( + zip(*broadcast(*(e.evaluate(right) for e in self.right_on)), strict=True) + ) null_equality = ( plc.types.NullEquality.EQUAL if join_nulls @@ -840,12 +844,18 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: if coalesce and how != "inner": left = left.replace_columns( *( - NamedColumn( - plc.replace.replace_nulls(left_col.obj, right_col.obj), - left_col.name, + ( + name, + Column( + plc.replace.replace_nulls(left_col.obj, right_col.obj) + ), ) - for left_col, right_col in zip( - left.select_columns(left_on.column_names_set), + for (name, left_col), right_col in zip( + ( + (name, column) + for name, column in left.column_map.items() + if name in left_on.column_names_set + ), right.select_columns(right_on.column_names_set), strict=True, ) @@ -862,7 +872,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: if name in left.column_names_set } ) - result = left.with_columns(right.columns) + result = left.with_columns(right.column_map.items()) return result.slice(zlice) @@ -882,7 +892,9 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: df = self.df.evaluate(cache=cache) columns = [c.evaluate(df) for c in self.columns] if self.should_broadcast: - columns = broadcast(*columns, target_length=df.num_rows) + columns = list( + zip(*broadcast(*columns, target_length=df.num_rows), strict=True) + ) else: # Polars ensures this is true, but let's make sure nothing # went wrong. 
In this case, the parent node is a @@ -931,9 +943,10 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: df = self.df.evaluate(cache=cache) if self.subset is None: indices = list(range(df.num_columns)) + keys_sorted = all(c.is_sorted for c in df.column_map.values()) else: indices = [i for i, k in enumerate(df.column_names) if k in self.subset] - keys_sorted = all(df.columns[i].is_sorted for i in indices) + keys_sorted = all(df.column_map[name].is_sorted for name in self.subset) if keys_sorted: table = plc.stream_compaction.unique( df.table, @@ -954,11 +967,12 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: plc.types.NullEquality.EQUAL, plc.types.NanEquality.ALL_EQUAL, ) + # TODO: Is this sortedness setting correct result = DataFrame( - [ - NamedColumn(c, old.name).sorted_like(old) - for c, old in zip(table.columns(), df.columns, strict=True) - ] + (name, Column(new).sorted_like(old)) + for new, (name, old) in zip( + table.columns(), df.column_map.items(), strict=True + ) ) if keys_sorted or self.stable: result = result.sorted_like(df) @@ -1005,33 +1019,31 @@ def __init__( def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) - sort_keys = broadcast( - *(k.evaluate(df) for k in self.by), target_length=df.num_rows - ) - names = {c.name: i for i, c in enumerate(df.columns)} + key_names, sort_keys = broadcast(*(k.evaluate(df) for k in self.by)) # TODO: More robust identification here. 
- keys_in_result = [ - i - for k in sort_keys - if (i := names.get(k.name)) is not None and k.obj is df.columns[i].obj - ] + keys_in_result = { + key_name: i + for i, (key_name, k) in enumerate(zip(key_names, sort_keys, strict=True)) + if key_name in df.column_map and k.obj is df.column_map[key_name].obj + } table = self.do_sort( df.table, plc.Table([k.obj for k in sort_keys]), self.order, self.null_order, ) - columns = [ - NamedColumn(c, old.name) - for c, old in zip(table.columns(), df.columns, strict=True) - ] - # If a sort key is in the result table, set the sortedness property - for k, i in enumerate(keys_in_result): - columns[i] = columns[i].set_sorted( - is_sorted=plc.types.Sorted.YES, - order=self.order[k], - null_order=self.null_order[k], - ) + columns: list[tuple[str, Column]] = [] + for name, c in zip(df.column_map, table.columns(), strict=True): + column = Column(c) + # If a sort key is in the result table, set the sortedness property + if name in keys_in_result: + i = keys_in_result[name] + column = column.set_sorted( + is_sorted=plc.types.Sorted.YES, + order=self.order[i], + null_order=self.null_order[i], + ) + columns.append((name, column)) return DataFrame(columns).slice(self.zlice) @@ -1064,7 +1076,7 @@ class Filter(IR): def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) - (mask,) = broadcast(self.mask.evaluate(df), target_length=df.num_rows) + _, (mask,) = broadcast(self.mask.evaluate(df), target_length=df.num_rows) return df.filter(mask) @@ -1079,10 +1091,11 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) # This can reorder things. 
- columns = broadcast( - *df.select(list(self.schema.keys())).columns, target_length=df.num_rows + names, columns = broadcast( + *(df.select(list(self.schema.keys())).column_map.items()), + target_length=df.num_rows, ) - return DataFrame(columns) + return DataFrame(zip(names, columns, strict=True)) @dataclasses.dataclass @@ -1125,7 +1138,7 @@ def __post_init__(self) -> None: old, new, _ = self.options # TODO: perhaps polars should validate renaming in the IR? if len(new) != len(set(new)) or ( - set(new) & (set(self.df.schema.keys() - set(old))) + set(new) & (set(self.df.schema.keys()) - set(old)) ): raise NotImplementedError("Duplicate new names in rename.") elif self.name == "unpivot": @@ -1170,7 +1183,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: npiv = len(pivotees) df = self.df.evaluate(cache=cache) index_columns = [ - NamedColumn(col, name) + (name, Column(col)) for col, name in zip( plc.reshape.tile(df.select(indices).table, npiv).columns(), indices, @@ -1191,13 +1204,16 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: df.num_rows, ).columns() value_column = plc.concatenate.concatenate( - [c.astype(self.schema[value_name]) for c in df.select(pivotees).columns] + [ + df.column_map[pivotee].astype(self.schema[value_name]) + for pivotee in pivotees + ] ) return DataFrame( [ *index_columns, - NamedColumn(variable_column, variable_name), - NamedColumn(value_column, value_name), + (variable_name, Column(variable_column)), + (value_name, Column(value_column)), ] ) else: @@ -1279,5 +1295,5 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: for df in dfs ] return DataFrame( - list(itertools.chain.from_iterable(df.columns for df in dfs)), + itertools.chain.from_iterable(df.column_map.items() for df in dfs) ) diff --git a/python/cudf_polars/docs/overview.md b/python/cudf_polars/docs/overview.md index bff44af1468..5c31ba75504 100644 --- a/python/cudf_polars/docs/overview.md +++ 
b/python/cudf_polars/docs/overview.md @@ -201,21 +201,21 @@ the logical plan in any case, so is reasonably natural. # Containers Containers should be constructed as relatively lightweight objects -around their pylibcudf counterparts. We have four (in +around their pylibcudf counterparts. We have three (in `cudf_polars/containers/`): 1. `Scalar` (a wrapper around a pylibcudf `Scalar`) 2. `Column` (a wrapper around a pylibcudf `Column`) -3. `NamedColumn` (a `Column` with an additional name) -4. `DataFrame` (a wrapper around a pylibcudf `Table`) +3. `DataFrame` (a wrapper around a pylibcudf `Table`) The interfaces offered by these are somewhat in flux, but broadly -speaking, a `DataFrame` is just a list of `NamedColumn`s which each -hold a `Column` plus a string `name`. `NamedColumn`s are only ever -constructed via `NamedExpr`s, which are the top-level expression node -that lives inside an `IR` node. This means that the expression -evaluator never has to concern itself with column names: columns are -only ever decorated with names when constructing a `DataFrame`. +speaking, a `DataFrame` is just a mapping from string `name`s to +`Column`s, and thus also holds a pylibcudf `Table`. Names are only +ever attached to `Column`s and hence inserted into `DataFrame`s via +`NamedExpr`s: these are the top-level expression nodes that live +inside an `IR` node. This means that the expression evaluator never +has to concern itself with column names: columns are only ever +decorated with names when constructing a `DataFrame`. The columns keep track of metadata (for example, whether or not they are sorted). 
We could imagine tracking more metadata, like minimum and diff --git a/python/cudf_polars/tests/containers/test_column.py b/python/cudf_polars/tests/containers/test_column.py index 19919877f84..1f26ab1af9f 100644 --- a/python/cudf_polars/tests/containers/test_column.py +++ b/python/cudf_polars/tests/containers/test_column.py @@ -3,13 +3,11 @@ from __future__ import annotations -from functools import partial - import pyarrow import pylibcudf as plc import pytest -from cudf_polars.containers import Column, NamedColumn +from cudf_polars.containers import Column def test_non_scalar_access_raises(): @@ -55,11 +53,10 @@ def test_shallow_copy(): @pytest.mark.parametrize("typeid", [plc.TypeId.INT8, plc.TypeId.FLOAT32]) -@pytest.mark.parametrize("constructor", [Column, partial(NamedColumn, name="name")]) -def test_mask_nans(typeid, constructor): +def test_mask_nans(typeid): dtype = plc.DataType(typeid) values = pyarrow.array([0, 0, 0], type=plc.interop.to_arrow(dtype)) - column = constructor(plc.interop.from_arrow(values)) + column = Column(plc.interop.from_arrow(values)) masked = column.mask_nans() assert column.obj.null_count() == masked.obj.null_count() diff --git a/python/cudf_polars/tests/containers/test_dataframe.py b/python/cudf_polars/tests/containers/test_dataframe.py index 39fb44d55a5..a5bddaa5ae6 100644 --- a/python/cudf_polars/tests/containers/test_dataframe.py +++ b/python/cudf_polars/tests/containers/test_dataframe.py @@ -8,18 +8,20 @@ import polars as pl -from cudf_polars.containers import DataFrame, NamedColumn +from cudf_polars.containers import Column, DataFrame from cudf_polars.testing.asserts import assert_gpu_result_equal def test_select_missing_raises(): df = DataFrame( [ - NamedColumn( - plc.column_factories.make_numeric_column( - plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID - ), + ( "a", + Column( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID + ) + ), ) ] ) @@ -30,17 +32,19 @@ def 
test_select_missing_raises(): def test_replace_missing_raises(): df = DataFrame( [ - NamedColumn( - plc.column_factories.make_numeric_column( - plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID - ), + ( "a", + Column( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID + ) + ), ) ] ) - replacement = df.columns[0].copy(new_name="b") + replacement = df.column_map["a"].copy() with pytest.raises(ValueError): - df.replace_columns(replacement) + df.replace_columns(("b", replacement)) def test_from_table_wrong_names(): @@ -58,11 +62,13 @@ def test_from_table_wrong_names(): def test_sorted_like_raises_mismatching_names(): df = DataFrame( [ - NamedColumn( - plc.column_factories.make_numeric_column( - plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID - ), + ( "a", + Column( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID + ) + ), ) ] ) @@ -72,26 +78,25 @@ def test_sorted_like_raises_mismatching_names(): def test_shallow_copy(): - column = NamedColumn( + column = Column( plc.column_factories.make_numeric_column( plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID - ), - "a", + ) ) column.set_sorted( is_sorted=plc.types.Sorted.YES, order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.AFTER, ) - df = DataFrame([column]) + df = DataFrame([("a", column)]) copy = df.copy() - copy.columns[0].set_sorted( + copy.column_map["a"].set_sorted( is_sorted=plc.types.Sorted.NO, order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.AFTER, ) - assert df.columns[0].is_sorted == plc.types.Sorted.YES - assert copy.columns[0].is_sorted == plc.types.Sorted.NO + assert df.column_map["a"].is_sorted == plc.types.Sorted.YES + assert copy.column_map["a"].is_sorted == plc.types.Sorted.NO def test_sorted_flags_preserved_empty(): @@ -100,7 +105,7 @@ def test_sorted_flags_preserved_empty(): gf = DataFrame.from_polars(df) - (a,) = gf.columns + a = 
gf.column_map["a"] assert a.is_sorted == plc.types.Sorted.YES diff --git a/python/cudf_polars/tests/expressions/test_sort.py b/python/cudf_polars/tests/expressions/test_sort.py index 76c7648813a..2a37683478b 100644 --- a/python/cudf_polars/tests/expressions/test_sort.py +++ b/python/cudf_polars/tests/expressions/test_sort.py @@ -69,7 +69,7 @@ def test_setsorted(descending, nulls_last, with_nulls): df = translate_ir(q._ldf.visit()).evaluate(cache={}) - (a,) = df.columns + a = df.column_map["a"] assert a.is_sorted == plc.types.Sorted.YES null_order = ( diff --git a/python/cudf_polars/tests/utils/test_broadcast.py b/python/cudf_polars/tests/utils/test_broadcast.py index 35aaef44e1f..54d3710ec93 100644 --- a/python/cudf_polars/tests/utils/test_broadcast.py +++ b/python/cudf_polars/tests/utils/test_broadcast.py @@ -6,34 +6,39 @@ import pylibcudf as plc import pytest -from cudf_polars.containers import NamedColumn +from cudf_polars.containers import Column from cudf_polars.dsl.ir import broadcast @pytest.mark.parametrize("target", [4, None]) def test_broadcast_all_scalar(target): columns = [ - NamedColumn( - plc.column_factories.make_numeric_column( - plc.DataType(plc.TypeId.INT8), 1, plc.MaskState.ALL_VALID - ), + ( f"col{i}", + Column( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 1, plc.MaskState.ALL_VALID + ) + ), ) for i in range(3) ] - result = broadcast(*columns, target_length=target) + names, result = broadcast(*columns, target_length=target) expected = 1 if target is None else target + assert names == [f"col{i}" for i in range(3)] assert all(column.obj.size() == expected for column in result) def test_invalid_target_length(): columns = [ - NamedColumn( - plc.column_factories.make_numeric_column( - plc.DataType(plc.TypeId.INT8), 4, plc.MaskState.ALL_VALID - ), + ( f"col{i}", + Column( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 4, plc.MaskState.ALL_VALID + ) + ), ) for i in range(3) ] @@ -43,11 
+48,13 @@ def test_invalid_target_length(): def test_broadcast_mismatching_column_lengths(): columns = [ - NamedColumn( - plc.column_factories.make_numeric_column( - plc.DataType(plc.TypeId.INT8), i + 1, plc.MaskState.ALL_VALID - ), + ( f"col{i}", + Column( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), i + 1, plc.MaskState.ALL_VALID + ) + ), ) for i in range(3) ] @@ -58,16 +65,19 @@ def test_broadcast_mismatching_column_lengths(): @pytest.mark.parametrize("nrows", [0, 5]) def test_broadcast_with_scalars(nrows): columns = [ - NamedColumn( - plc.column_factories.make_numeric_column( - plc.DataType(plc.TypeId.INT8), - nrows if i == 0 else 1, - plc.MaskState.ALL_VALID, - ), + ( f"col{i}", + Column( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), + nrows if i == 0 else 1, + plc.MaskState.ALL_VALID, + ) + ), ) for i in range(3) ] - result = broadcast(*columns) + names, result = broadcast(*columns) + assert names == [f"col{i}" for i in range(3)] assert all(column.obj.size() == nrows for column in result)