Skip to content

Commit

Permalink
splitting reduce_table function
Browse files Browse the repository at this point in the history
  • Loading branch information
svittoz committed May 7, 2024
1 parent f95b5fa commit 8e618ba
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 29 deletions.
32 changes: 16 additions & 16 deletions eds_scikit/plot/omop_teva.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,22 @@ def generate_omop_teva(
"visit_occurrence_id"
].astype(str)

#visit_occurrence_count = reduce_table(
# visit_occurrence,
# category_columns=config["category_columns"],
# date_column=config["date_column"],
# start_date=start_date,
# end_date=end_date,
# mapper=config["mapper"],
#)
#visit_occurrence_count = visit_occurrence_count[
# ~(visit_occurrence_count == 0).any(axis=1)
#]
#chart = visualize_table(
# visit_occurrence_count, title="visit_occurrence table dashboard"
#)
#save_pickle(f"{output_dir}/visit_occurrence_count", visit_occurrence_count)
#chart.save(f"{output_dir}/visit_occurrence_chart.html")
visit_occurrence_count = reduce_table(
visit_occurrence,
category_columns=config["category_columns"],
date_column=config["date_column"],
start_date=start_date,
end_date=end_date,
mapper=config["mapper"],
)
visit_occurrence_count = visit_occurrence_count[
~(visit_occurrence_count == 0).any(axis=1)
]
chart = visualize_table(
visit_occurrence_count, title="visit_occurrence table dashboard"
)
save_pickle(f"{output_dir}/visit_occurrence_count", visit_occurrence_count)
chart.save(f"{output_dir}/visit_occurrence_chart.html")
#
#logger.info("visit_occurrence processing done.")
#
Expand Down
80 changes: 67 additions & 13 deletions eds_scikit/plot/table_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def map_column(
] = target
return table


def reduce_table(
def preprocess_table(
table: DataFrame,
category_columns: List[str],
date_column: str,
Expand All @@ -58,8 +57,6 @@ def reduce_table(
mapper: dict = None,
) -> DataFrame:
"""
Reduce the input table by counting each combination of values (col1, col2, ..., coln) across the columns in category_columns, for each date.
Each column must have at most 50 distinct values. Use mapper to reduce this number.
Parameters
----------
Expand All @@ -79,41 +76,98 @@ def reduce_table(
Returns
-------
DataFrame
Reduced DataFrame with columns category_columns, date_column and count.
Formatted and preprocessed table
Raises
------
ValueError
"""

check_columns(
table,
required_columns=[date_column],
)


# Check and format to string category columns

remove_colums = []

for col in category_columns:
if not (col in table.columns):
logger.info(f"Column {col} not in table.")
remove_colums += [col]
else:
table[col] = table[col].astype(str)

for col in remove_colums:
category_columns.remove(col)

if category_columns == []:
raise Exception("No columns from category_columns in input table.")

category_columns = [*category_columns, date_column]

table = table[category_columns]

# Filter table on dates

framework = get_framework(table)

table = table[(table[date_column] >= start_date) & (table[date_column] <= end_date)]
table["datetime"] = framework.to_datetime(table[date_column].dt.strftime("%Y-%m"))
category_columns = [*category_columns, "datetime"]
table = table.drop(columns=[date_column])

# Map category columns

if mapper:
for col, mapping in mapper.items():
table = map_column(table, mapping, col, col)

return table

table = table[category_columns]
def reduce_table(
table: DataFrame,
category_columns: List[str],
date_column: str,
start_date: str,
end_date: str,
mapper: dict = None,
) -> DataFrame:
"""
Reduce the input table by counting each combination of values (col1, col2, ..., coln) across the columns in category_columns, for each date.
Each column must have at most 50 distinct values. Use mapper to reduce this number.
Parameters
----------
table : DataFrame
Input dataframe to be reduced.
category_columns : List[str]
Columns to perform reduction on.
date_column : str
Date column.
start_date : str
start date
end_date : str
end date
mapper : dict
**EXAMPLE**: `{"column 1" : {"CR" : r"^CR", "CRH" : r"^CRH"}, "column 2" : {"code a" : r"^A", "code b" : r"^B"}}`
Returns
-------
DataFrame
Reduced DataFrame with columns category_columns, date_column and count.
Raises
------
ValueError
"""

check_columns(
table,
required_columns=[date_column],
)

table = preprocess_table(table, category_columns, date_column, start_date, end_date, mapper)

# to prevent computation issues
shape = table.shape # noqa

print(shape)

nunique = table.nunique()
oversized_columns = nunique[(nunique.index != "datetime") & (nunique > 50)].tolist()
Expand Down

0 comments on commit 8e618ba

Please sign in to comment.