@@ -261,38 +261,56 @@ def execute(self, context: Context) -> None:
             file_obj=buf, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
         )
 
-    def _partition_dataframe(self, df: pd.DataFrame | pl.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
-        """Partition dataframe using pandas groupby() method."""
+    def _partition_dataframe(
+        self, df: pd.DataFrame | pl.DataFrame
+    ) -> Iterable[tuple[str, pd.DataFrame | pl.DataFrame]]:
+        """Partition dataframe using pandas or polars groupby() method."""
         try:
             import secrets
             import string
 
             import numpy as np
             import pandas as pd
             import polars as pl
         except ImportError:
             pass
 
-        if isinstance(df, pl.DataFrame):
-            df = df.to_pandas()
-
         # if max_rows_per_file argument is specified, a temporary column with a random unusual name will be
         # added to the dataframe. This column is used to dispatch the dataframe into smaller ones using groupby()
-        random_column_name = ""
+        random_column_name = None
         if self.max_rows_per_file and not self.groupby_kwargs:
             random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))
-            df[random_column_name] = np.arange(len(df)) // self.max_rows_per_file
             self.groupby_kwargs = {"by": random_column_name}
+
+        if random_column_name:
+            if isinstance(df, pd.DataFrame):
+                df[random_column_name] = np.arange(len(df)) // self.max_rows_per_file
+            elif isinstance(df, pl.DataFrame):
+                df = df.with_columns(
+                    (pl.int_range(pl.len()) // self.max_rows_per_file).alias(random_column_name)
+                )
+
         if not self.groupby_kwargs:
             yield "", df
             return
-        for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
-            yield (
-                cast("str", group_label),
-                grouped_df.get_group(group_label)
-                .drop(random_column_name, axis=1, errors="ignore")
-                .reset_index(drop=True),
-            )
+
+        if isinstance(df, pd.DataFrame):
+            for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
+                group_df = grouped_df.get_group(group_label)
+                if random_column_name:
+                    group_df = group_df.drop(random_column_name, axis=1, errors="ignore")
+                yield (
+                    cast("str", group_label[0] if isinstance(group_label, tuple) else group_label),
+                    group_df.reset_index(drop=True),
+                )
+        elif isinstance(df, pl.DataFrame):
+            for group_label, group_df in df.group_by(**self.groupby_kwargs):  # type: ignore[assignment]
+                if random_column_name:
+                    group_df = group_df.drop(random_column_name)
+                yield (
+                    cast("str", group_label[0] if isinstance(group_label, tuple) else group_label),
+                    group_df,
+                )
 
     def _get_hook(self) -> DbApiHook:
         self.log.debug("Get connection for %s", self.sql_conn_id)
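Aside (not part of the diff): the `group_label[0] if isinstance(group_label, tuple) else group_label` unwrapping above exists because polars' group_by() yields tuple keys even when grouping by a single column, and pandas does the same when `by` is a list of columns. A minimal standalone sketch, assuming recent pandas/polars and made-up column names:

    # Minimal sketch, not part of the diff: why tuple group labels get unwrapped.
    import pandas as pd
    import polars as pl

    data = {"team": ["A", "A", "B"], "x": [1, 2, 3]}

    # pandas with a scalar `by` yields scalar labels: "A", "B"
    for label in pd.DataFrame(data).groupby("team").groups:
        print(type(label).__name__, label)  # str A / str B

    # polars group_by() yields tuple keys even for one column: ("A",), ("B",)
    for label, _group in pl.DataFrame(data).group_by("team", maintain_order=True):
        print(type(label).__name__, label)  # tuple ('A',) / tuple ('B',)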
@@ -22,6 +22,7 @@
 from unittest import mock
 
 import pandas as pd
+import polars as pl
 import pytest
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -442,7 +443,10 @@ def test_partition_dataframe(self, df_type, input_df_creator):
 
         assert len(partitions) == 2
         for group_name, df in partitions:
-            assert isinstance(df, pd.DataFrame)
+            if df_type == "polars":
+                assert isinstance(df, pl.DataFrame)
+            else:
+                assert isinstance(df, pd.DataFrame)
             assert group_name in ["A", "B"]
 
     @pytest.mark.parametrize(
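For reference, a standalone sketch of the max_rows_per_file mechanism the diff implements for both backends (the `_bucket` column name is made up for illustration): an integer-division bucket column assigns the same group id to each consecutive chunk of `max_rows` rows, so grouping on it splits the frame into fixed-size pieces.

    # Standalone sketch, not part of the diff; "_bucket" is a made-up column name.
    import numpy as np
    import pandas as pd
    import polars as pl

    max_rows = 4

    # pandas path: integer-division bucket column, then groupby()
    pdf = pd.DataFrame({"x": range(10)})
    pdf["_bucket"] = np.arange(len(pdf)) // max_rows  # 0,0,0,0,1,1,1,1,2,2
    pd_parts = [g.drop(columns="_bucket") for _, g in pdf.groupby("_bucket")]
    print([len(p) for p in pd_parts])  # [4, 4, 2]

    # polars path: same bucket expression built with int_range/len
    pldf = pl.DataFrame({"x": range(10)})
    pldf = pldf.with_columns((pl.int_range(pl.len()) // max_rows).alias("_bucket"))
    pl_parts = [g.drop("_bucket") for _, g in pldf.group_by("_bucket", maintain_order=True)]
    print([len(p) for p in pl_parts])  # [4, 4, 2]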