diff --git a/src/aces/extract_subtree.py b/src/aces/extract_subtree.py index 7663436d..d1dcb8a1 100644 --- a/src/aces/extract_subtree.py +++ b/src/aces/extract_subtree.py @@ -54,6 +54,7 @@ def extract_subtree( Examples: >>> from bigtree import Node >>> from datetime import datetime + >>> from .types import ToEventWindowBounds, TemporalWindowBounds >>> # We'll use an example for in-hospital mortality prediction. Our root event of the tree will be >>> # an admission event. >>> root = Node("admission") @@ -63,7 +64,7 @@ def extract_subtree( >>> # Node 1 will represent our gap window. We say that in the 24 hours after the admission, there >>> # should be no discharges, deaths, or covid events. >>> gap_node = Node("gap") # This sets the node's name. - >>> gap_node.endpoint_expr = (True, timedelta(days=2), True) + >>> gap_node.endpoint_expr = TemporalWindowBounds(True, timedelta(days=2), True) >>> gap_node.constraints = { ... "is_discharge": (None, 0), "is_death": (None, 0), "is_covid_dx": (None, 0) ... } @@ -71,7 +72,7 @@ def extract_subtree( >>> # Node 2 will start our target window and span until the next discharge or death event. >>> # There should be no covid events. >>> target_node = Node("target") # This sets the node's name. - >>> target_node.endpoint_expr = (True, "is_discharge", True) + >>> target_node.endpoint_expr = ToEventWindowBounds(True, "is_discharge", True) >>> target_node.constraints = {"is_covid_dx": (None, 0)} >>> target_node.parent = gap_node >>> # @@ -79,11 +80,11 @@ def extract_subtree( >>> # Finally, for our second branch, we will impose no constraints but track the input time range, >>> # which will span from the beginning of the record to 24 hours after admission. >>> input_end_node = Node("input_end") - >>> input_end_node.endpoint_expr = (True, timedelta(days=1), True) + >>> input_end_node.endpoint_expr = TemporalWindowBounds(True, timedelta(days=1), True) >>> input_end_node.constraints = {} >>> input_end_node.parent = root >>> input_start_node = Node("input_start") - >>> input_start_node.endpoint_expr = (True, "-_RECORD_START", True) + >>> input_start_node.endpoint_expr = ToEventWindowBounds(True, "-_RECORD_START", True) >>> input_start_node.constraints = {} >>> input_start_node.parent = root >>> # @@ -93,11 +94,11 @@ def extract_subtree( >>> # This will be expressed through two windows, one spanning back a year, and the other looking >>> # prior to that year. >>> pre_node_1yr = Node("pre_node_1yr") - >>> pre_node_1yr.endpoint_expr = (False, timedelta(days=-365), False) + >>> pre_node_1yr.endpoint_expr = TemporalWindowBounds(False, timedelta(days=-365), False) >>> pre_node_1yr.constraints = {} >>> pre_node_1yr.parent = root >>> pre_node_total = Node("pre_node_total") - >>> pre_node_total.endpoint_expr = (False, "-_RECORD_START", False) + >>> pre_node_total.endpoint_expr = ToEventWindowBounds(False, "-_RECORD_START", False) >>> pre_node_total.constraints = {"*": (1, None)} >>> pre_node_total.parent = pre_node_1yr >>> # @@ -272,11 +273,16 @@ def extract_subtree( # In an event bound case, the child root will be a proper extant event, so it will be the # anchor as well, and thus the child root offset should be zero. child_root_offset = timedelta(days=0) + if endpoint_expr.end_event.startswith("-"): + child_anchor_time = "timestamp_at_start" + else: + child_anchor_time = "timestamp_at_end" + window_summary_df = ( aggregate_event_bound_window(predicates_df, endpoint_expr) .with_columns( pl.col("timestamp").alias("subtree_anchor_timestamp"), - pl.col("timestamp_at_end").alias("child_anchor_timestamp"), + pl.col(child_anchor_time).alias("child_anchor_timestamp"), ) .drop("timestamp") )