diff --git a/netam/multihit.py b/netam/multihit.py index 482cb6ec..529f1917 100644 --- a/netam/multihit.py +++ b/netam/multihit.py @@ -14,7 +14,7 @@ from torch.utils.data import Dataset from tqdm import tqdm import pandas as pd -from typing import Sequence, List +from typing import Sequence, List, Tuple from netam.molevol import ( reshape_for_codons, @@ -452,7 +452,7 @@ def hit_class_dataset_from_pcp_df( def train_test_datasets_of_pcp_df( pcp_df: pd.DataFrame, train_frac: float = 0.8, branch_length_multiplier: float = 1.0 -) -> tuple[HitClassDataset, HitClassDataset]: +) -> Tuple[HitClassDataset, HitClassDataset]: """Splits a pcp_df prepared by `prepare_pcp_df` into a training and testing HitClassDataset.""" nt_parents = pcp_df["parent"].reset_index(drop=True)