From 9ab76e6b051f10fff9cdc885c0395b45f115d481 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 10 Jun 2024 15:56:01 -0400 Subject: [PATCH] Partial test for ESGPT stuff; not yet working --- src/aces/predicates.py | 130 +++++++++++++++++++++++++++++++---------- 1 file changed, 98 insertions(+), 32 deletions(-) diff --git a/src/aces/predicates.py b/src/aces/predicates.py index 27ca659a..0c346f76 100644 --- a/src/aces/predicates.py +++ b/src/aces/predicates.py @@ -231,8 +231,10 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f: ... data_path = Path(f.name) ... parquet_data.write_parquet(data_path) - ... generate_plain_predicates_from_meds(data_path, {"discharge": - ... PlainPredicateConfig("discharge")}) + ... generate_plain_predicates_from_meds( + ... data_path, + ... {"discharge": PlainPredicateConfig("discharge")} + ... ) shape: (3, 3) ┌────────────┬─────────────────────┬───────────┐ │ subject_id ┆ timestamp ┆ discharge │ @@ -268,42 +270,58 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl ) -def generate_plain_predicates_from_esgpt(data_path: Path, predicates: dict) -> pl.DataFrame: - """Generate plain predicate columns from an ESGPT dataset. - - To learn more about the ESGPT format, please visit https://eventstreamml.readthedocs.io/en/latest/ +def process_esgpt_data( + events_df: pl.DataFrame, + dynamic_measurements_df: pl.DataFrame, + value_columns: dict[str, str], + predicates: dict, +) -> pl.DataFrame: + """Process ESGPT data to generate plain predicate columns. Args: - data_path: The path to the ESGPT dataset directory. - predicates: The dictionary of plain predicate configurations. + events_df: The Polars DataFrame containing the events data. + dynamic_measurements_df: The Polars DataFrame containing the dynamic measurements data. Returns: The Polars DataFrame containing the extracted predicates per subject per timestamp across the entire ESGPT dataset. - """ - - try: - from EventStream.data.dataset_polars import Dataset - except ImportError as e: - raise ImportError( - "The 'EventStream' package is required to load ESGPT datasets. " - "If you mean to use a MEDS dataset, please specify the 'MEDS' standard. " - "Otherwise, please install the package from https://github.com/mmcdermott/EventStreamGPT and add " - "the package to your PYTHONPATH." - ) from e - - try: - ESD = Dataset.load(data_path) - except Exception as e: - raise ValueError( - f"Error loading data using ESGPT: {e}. " - "Please ensure the path provided is a valid ESGPT dataset directory. " - "If you mean to use a MEDS dataset, please specify the 'MEDS' standard." - ) from e - events_df = ESD.events_df - dynamic_measurements_df = ESD.dynamic_measurements_df - config = ESD.config + Examples: + >>> from datetime import datetime + >>> events_df = pl.DataFrame({ + ... "event_id": [1, 2, 3, 4], + ... "subject_id": [1, 1, 2, 2], + ... "timestamp": [ + ... datetime(2021, 1, 1, 0, 0), + ... datetime(2021, 1, 1, 12, 0), + ... datetime(2021, 1, 2, 0, 0), + ... datetime(2021, 1, 2, 12, 0), + ... ], + ... "event_type": ["adm", "dis", "adm", "death"], + ... "age": [30, 30, 40, 40], + ... }) + >>> dynamic_measurements_df = pl.DataFrame({ + ... "event_id": [1, 1, 1, 2, 2, 3, 4, 5], + ... "adm_loc": [], + ... "dis_loc": [], + ... "HR": [], + ... "lab": [], + ... "lab_val": [], + ... }) + >>> value_columns = { + ... "is_admission": None, + ... "is_discharge": None, + ... "high_HR": "HR", + ... "high_Potassium": "lab_val", + ... } + >>> predicates = { + ... "is_admission": PlainPredicateConfig(code="event_type//adm"), + ... "is_discharge": PlainPredicateConfig(code="event_type//dis"), + ... "high_HR": PlainPredicateConfig(code="HR//HR", value_min: 140), + ... "high_Potassium": PlainPredicateConfig(code="lab//Potassium", value_min: 5.0), + ... } + >>> process_esgpt_data(events_df, dynamic_measurements_df, value_columns, predicates) + """ logger.info("Generating plain predicate columns...") for name, plain_predicate in predicates.items(): @@ -312,7 +330,7 @@ def generate_plain_predicates_from_esgpt(data_path: Path, predicates: dict) -> p plain_predicate.ESGPT_eval_expr().cast(PRED_CNT_TYPE).alias(name) ) else: - values_column = config.measurement_configs[plain_predicate.code.split("//")[0]].values_column + values_column = value_columns[name] dynamic_measurements_df = dynamic_measurements_df.with_columns( plain_predicate.ESGPT_eval_expr(values_column).cast(PRED_CNT_TYPE).alias(name) ) @@ -341,6 +359,54 @@ def generate_plain_predicates_from_esgpt(data_path: Path, predicates: dict) -> p return data.select(["subject_id", "timestamp"] + predicate_cols) +def generate_plain_predicates_from_esgpt(data_path: Path, predicates: dict) -> pl.DataFrame: + """Generate plain predicate columns from an ESGPT dataset. + + To learn more about the ESGPT format, please visit https://eventstreamml.readthedocs.io/en/latest/ + + Args: + data_path: The path to the ESGPT dataset directory. + predicates: The dictionary of plain predicate configurations. + + Returns: + The Polars DataFrame containing the extracted predicates per subject per timestamp across the entire + ESGPT dataset. + """ + + try: + from EventStream.data.dataset_polars import Dataset + except ImportError as e: + raise ImportError( + "The 'EventStream' package is required to load ESGPT datasets. " + "If you mean to use a MEDS dataset, please specify the 'MEDS' standard. " + "Otherwise, please install the package from https://github.com/mmcdermott/EventStreamGPT and add " + "the package to your PYTHONPATH." + ) from e + + try: + ESD = Dataset.load(data_path) + except Exception as e: + raise ValueError( + f"Error loading data using ESGPT: {e}. " + "Please ensure the path provided is a valid ESGPT dataset directory. " + "If you mean to use a MEDS dataset, please specify the 'MEDS' standard." + ) from e + + events_df = ESD.events_df + dynamic_measurements_df = ESD.dynamic_measurements_df + config = ESD.config + + value_columns = {} + for name, plain_predicate in predicates.items(): + if "event_type" in plain_predicate.code: + value_columns[name] = None + else: + measurement_name = plain_predicate.code.split("//")[0] + value_columns[name] = config.measurement_configs[measurement_name].values_column + + return process_esgpt_data(events_df, dynamic_measurements_df, value_columns, predicates) + + def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.DataFrame: """Generate predicate columns based on the configuration.