Skip to content

Commit

Permalink
Change how dtype handling is done for parquet reads
Browse files Browse the repository at this point in the history
  • Loading branch information
zschira committed Jan 30, 2024
1 parent f2c9f25 commit 41eb6d3
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/pudl/io_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def load_input(self, context: InputContext) -> pd.DataFrame:
parquet_path = PudlPaths().parquet_path(table_name)
res = Resource.from_id(table_name)
df = pq.read_table(source=parquet_path, schema=res.to_pyarrow()).to_pandas()
return res.enforce_schema(df)
return res.enforce_schema(df, use_pyarrow_dtypes=True)


class PudlSQLiteIOManager(SQLiteIOManager):
Expand Down
63 changes: 35 additions & 28 deletions src/pudl/metadata/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,7 +1407,12 @@ def match_primary_key(self, names: Iterable[str]) -> dict[str, str] | None:
matches = {key: key for key in keys if key in names}
return matches if len(matches) == len(keys) else None

def format_df(self, df: pd.DataFrame | None = None, **kwargs: Any) -> pd.DataFrame:
def format_df(
self,
df: pd.DataFrame | None = None,
use_pyarrow_dtypes: bool = False,
**kwargs: Any,
) -> pd.DataFrame:
"""Format a dataframe according to the resource's table schema.
* DataFrame columns not in the schema are dropped.
Expand Down Expand Up @@ -1436,40 +1441,42 @@ def format_df(self, df: pd.DataFrame | None = None, **kwargs: Any) -> pd.DataFra
df = df.copy()
# Rename periodic key columns (if any) to the requested period
df = df.rename(columns=matches)
# Cast integer year fields to datetime
for field in self.schema.fields:
if (
field.type == "year"
and field.name in df
and pd.api.types.is_integer_dtype(df[field.name])
):
df[field.name] = pd.to_datetime(df[field.name], format="%Y")
if isinstance(dtypes[field.name], pd.CategoricalDtype):
uncategorized = [
value
for value in df[field.name].dropna().unique()
if value not in dtypes[field.name].categories
]
if uncategorized:
logger.warning(
f"Values in {field.name} column are not included in "
"categorical values in field enum constraint "
f"and will be converted to nulls ({uncategorized})."
)
df = (
# Reorder columns and insert missing columns
df.reindex(columns=dtypes.keys(), copy=False)

df = df.reindex(columns=dtypes.keys(), copy=False)
# Handle dtypes if not using pyarrow types
if not use_pyarrow_dtypes:
for field in self.schema.fields:
# Cast integer year fields to datetime
if (
field.type == "year"
and field.name in df
and pd.api.types.is_integer_dtype(df[field.name])
):
df[field.name] = pd.to_datetime(df[field.name], format="%Y")
if isinstance(dtypes[field.name], pd.CategoricalDtype):
uncategorized = [
value
for value in df[field.name].dropna().unique()
if value not in dtypes[field.name].categories
]
if uncategorized:
logger.warning(
f"Values in {field.name} column are not included in "
"categorical values in field enum constraint "
f"and will be converted to nulls ({uncategorized})."
)
# Coerce columns to correct data type
.astype(dtypes, copy=False)
)
df = df.astype(dtypes, copy=False)
# Convert periodic key columns to the requested period
for df_key, key in matches.items():
_, period = split_period(key)
if period and df_key != key:
df[key] = PERIODS[period](df[key])
return df

def enforce_schema(self, df: pd.DataFrame) -> pd.DataFrame:
def enforce_schema(
self, df: pd.DataFrame, use_pyarrow_dtypes: bool = False
) -> pd.DataFrame:
"""Drop columns not in the DB schema and enforce specified types."""
expected_cols = pd.Index(self.get_field_names())
missing_cols = list(expected_cols.difference(df.columns))
Expand All @@ -1479,7 +1486,7 @@ def enforce_schema(self, df: pd.DataFrame) -> pd.DataFrame:
f"schema: {missing_cols}"
)

df = self.format_df(df)
df = self.format_df(df, use_pyarrow_dtypes=use_pyarrow_dtypes)
pk = self.schema.primary_key
if pk and not df[df.duplicated(subset=pk)].empty:
raise ValueError(
Expand Down

0 comments on commit 41eb6d3

Please sign in to comment.