Skip to content

Commit

Permalink
Partial test for ESGPT stuff; not yet working
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Jun 10, 2024
1 parent 6facc56 commit 9ab76e6
Showing 1 changed file with 98 additions and 32 deletions.
130 changes: 98 additions & 32 deletions src/aces/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,10 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl
>>> with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f:
... data_path = Path(f.name)
... parquet_data.write_parquet(data_path)
... generate_plain_predicates_from_meds(data_path, {"discharge":
... PlainPredicateConfig("discharge")})
... generate_plain_predicates_from_meds(
... data_path,
... {"discharge": PlainPredicateConfig("discharge")}
... )
shape: (3, 3)
┌────────────┬─────────────────────┬───────────┐
│ subject_id ┆ timestamp ┆ discharge │
Expand Down Expand Up @@ -268,42 +270,58 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl
)


def generate_plain_predicates_from_esgpt(data_path: Path, predicates: dict) -> pl.DataFrame:
"""Generate plain predicate columns from an ESGPT dataset.
To learn more about the ESGPT format, please visit https://eventstreamml.readthedocs.io/en/latest/
def process_esgpt_data(
events_df: pl.DataFrame,
dynamic_measurements_df: pl.DataFrame,
value_columns: dict[str, str],
predicates: dict,
) -> pl.DataFrame:
"""Process ESGPT data to generate plain predicate columns.
Args:
data_path: The path to the ESGPT dataset directory.
predicates: The dictionary of plain predicate configurations.
events_df: The Polars DataFrame containing the events data.
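dynamic_measurements_df: The Polars DataFrame containing the dynamic measurements data.
value_columns: A mapping from each predicate name to the name of its values column, or None for
predicates whose code contains "event_type".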
Returns:
The Polars DataFrame containing the extracted predicates per subject per timestamp across the entire
ESGPT dataset.
"""

try:
from EventStream.data.dataset_polars import Dataset
except ImportError as e:
raise ImportError(
"The 'EventStream' package is required to load ESGPT datasets. "
"If you mean to use a MEDS dataset, please specify the 'MEDS' standard. "
"Otherwise, please install the package from https://github.com/mmcdermott/EventStreamGPT and add "
"the package to your PYTHONPATH."
) from e

try:
ESD = Dataset.load(data_path)
except Exception as e:
raise ValueError(
f"Error loading data using ESGPT: {e}. "
"Please ensure the path provided is a valid ESGPT dataset directory. "
"If you mean to use a MEDS dataset, please specify the 'MEDS' standard."
) from e
events_df = ESD.events_df
dynamic_measurements_df = ESD.dynamic_measurements_df
config = ESD.config
Examples:
>>> from datetime import datetime
>>> events_df = pl.DataFrame({
... "event_id": [1, 2, 3, 4],
... "subject_id": [1, 1, 2, 2],
... "timestamp": [
... datetime(2021, 1, 1, 0, 0),
... datetime(2021, 1, 1, 12, 0),
... datetime(2021, 1, 2, 0, 0),
... datetime(2021, 1, 2, 12, 0),
... ],
... "event_type": ["adm", "dis", "adm", "death"],
... "age": [30, 30, 40, 40],
... })
>>> dynamic_measurements_df = pl.DataFrame({
... "event_id": [1, 1, 1, 2, 2, 3, 4, 5],
... "adm_loc": [],
... "dis_loc": [],
... "HR": [],
... "lab": [],
... "lab_val": [],
... })
>>> value_columns = {
... "is_admission": None,
... "is_discharge": None,
... "high_HR": "HR",
... "high_Potassium": "lab_val",
... }
>>> predicates = {
... "is_admission": PlainPredicateConfig(code="event_type//adm"),
... "is_discharge": PlainPredicateConfig(code="event_type//dis"),
... "high_HR": PlainPredicateConfig(code="HR//HR", value_min=140),
... "high_Potassium": PlainPredicateConfig(code="lab//Potassium", value_min=5.0),
... }
>>> process_esgpt_data(events_df, dynamic_measurements_df, value_columns, predicates)
"""

logger.info("Generating plain predicate columns...")
for name, plain_predicate in predicates.items():
Expand All @@ -312,7 +330,7 @@ def generate_plain_predicates_from_esgpt(data_path: Path, predicates: dict) -> p
plain_predicate.ESGPT_eval_expr().cast(PRED_CNT_TYPE).alias(name)
)
else:
values_column = config.measurement_configs[plain_predicate.code.split("//")[0]].values_column
values_column = value_columns[name]
dynamic_measurements_df = dynamic_measurements_df.with_columns(
plain_predicate.ESGPT_eval_expr(values_column).cast(PRED_CNT_TYPE).alias(name)
)
Expand Down Expand Up @@ -341,6 +359,54 @@ def generate_plain_predicates_from_esgpt(data_path: Path, predicates: dict) -> p
return data.select(["subject_id", "timestamp"] + predicate_cols)


def generate_plain_predicates_from_esgpt(data_path: Path, predicates: dict) -> pl.DataFrame:
    """Load an ESGPT dataset from disk and build its plain predicate columns.

    The dataset is loaded with ``EventStream``'s ``Dataset.load``; its events and dynamic measurements
    frames, together with a per-predicate values-column lookup, are then handed to
    ``process_esgpt_data``, whose result is returned unchanged.

    To learn more about the ESGPT format, please visit https://eventstreamml.readthedocs.io/en/latest/

    Args:
        data_path: The path to the ESGPT dataset directory.
        predicates: The dictionary of plain predicate configurations, keyed by predicate name.

    Returns:
        The Polars DataFrame containing the extracted predicates per subject per timestamp across the
        entire ESGPT dataset.

    Raises:
        ImportError: If the ``EventStream`` package cannot be imported.
        ValueError: If loading the dataset from ``data_path`` fails for any reason.
    """

    # Imported lazily so that the package is only required when an ESGPT dataset is actually used.
    try:
        from EventStream.data.dataset_polars import Dataset
    except ImportError as err:
        raise ImportError(
            "The 'EventStream' package is required to load ESGPT datasets. "
            "If you mean to use a MEDS dataset, please specify the 'MEDS' standard. "
            "Otherwise, please install the package from https://github.com/mmcdermott/EventStreamGPT and add "
            "the package to your PYTHONPATH."
        ) from err

    # Any loading failure is re-raised as a ValueError carrying a hint about the expected data standard.
    try:
        dataset = Dataset.load(data_path)
    except Exception as err:
        raise ValueError(
            f"Error loading data using ESGPT: {err}. "
            "Please ensure the path provided is a valid ESGPT dataset directory. "
            "If you mean to use a MEDS dataset, please specify the 'MEDS' standard."
        ) from err

    events_df = dataset.events_df
    dynamic_measurements_df = dataset.dynamic_measurements_df
    config = dataset.config

    # Predicates whose code mentions "event_type" get no values column (None); every other predicate
    # takes the values column configured for the measurement named before the "//" in its code.
    value_columns = {
        name: (
            None
            if "event_type" in predicate.code
            else config.measurement_configs[predicate.code.split("//")[0]].values_column
        )
        for name, predicate in predicates.items()
    }

    return process_esgpt_data(events_df, dynamic_measurements_df, value_columns, predicates)


def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.DataFrame:
"""Generate predicate columns based on the configuration.
Expand Down

0 comments on commit 9ab76e6

Please sign in to comment.