diff --git a/taxoniumtools/src/taxoniumtools/usher_to_taxonium.py b/taxoniumtools/src/taxoniumtools/usher_to_taxonium.py index 5b09ec0d..18677a97 100644 --- a/taxoniumtools/src/taxoniumtools/usher_to_taxonium.py +++ b/taxoniumtools/src/taxoniumtools/usher_to_taxonium.py @@ -28,6 +28,7 @@ def do_processing(input_file, chronumental_date_output=None, chronumental_tree_output=None, chronumental_reference_node=None, + chronumental_add_inferred_date=None, config_file=None, title=None, overlay_html=None, @@ -82,7 +83,10 @@ def do_processing(input_file, metadata_file=metadata_file, chronumental_steps=chronumental_steps, chronumental_date_output=chronumental_date_output, - chronumental_tree_output=chronumental_tree_output) + chronumental_tree_output=chronumental_tree_output, + chronumental_add_inferred_date=chronumental_add_inferred_date, + metadata_dict=metadata_dict, + metadata_cols=metadata_cols) print("Ladderizing tree..") mat.tree.ladderize(ascending=False) @@ -236,6 +240,12 @@ def get_parser(): help= "A reference node to be used for Chronumental. This should be earlier in the outbreak and have a good defined date. If not set the oldest sample will be automatically picked by Chronumental.", default=None) + parser.add_argument( + "--chronumental_add_inferred_date", + type=str, + help= + "A new metadata-column-like name to be added for display with the value of Chronumental's inferred date for each sample.", + default=None) parser.add_argument( '-j', "--config_json", @@ -304,26 +314,28 @@ def main(): parser = get_parser() args = parser.parse_args() - do_processing(args.input, - args.output, - metadata_file=args.metadata, - genbank_file=args.genbank, - chronumental_enabled=args.chronumental, - chronumental_steps=args.chronumental_steps, - columns=args.columns, - chronumental_date_output=args.chronumental_date_output, - chronumental_tree_output=args.chronumental_tree_output, - chronumental_reference_node=args.chronumental_reference_node, - config_file=args.config_json, - title=args.title, - overlay_html=args.overlay_html, - remove_after_pipe=args.remove_after_pipe, - clade_types=args.clade_types, - name_internal_nodes=args.name_internal_nodes, - shear=args.shear, - shear_threshold=args.shear_threshold, - only_variable_sites=args.only_variable_sites, - key_column=args.key_column) + do_processing( + args.input, + args.output, + metadata_file=args.metadata, + genbank_file=args.genbank, + chronumental_enabled=args.chronumental, + chronumental_steps=args.chronumental_steps, + columns=args.columns, + chronumental_date_output=args.chronumental_date_output, + chronumental_tree_output=args.chronumental_tree_output, + chronumental_reference_node=args.chronumental_reference_node, + chronumental_add_inferred_date=args.chronumental_add_inferred_date, + config_file=args.config_json, + title=args.title, + overlay_html=args.overlay_html, + remove_after_pipe=args.remove_after_pipe, + clade_types=args.clade_types, + name_internal_nodes=args.name_internal_nodes, + shear=args.shear, + shear_threshold=args.shear_threshold, + only_variable_sites=args.only_variable_sites, + key_column=args.key_column) if __name__ == "__main__": diff --git a/taxoniumtools/src/taxoniumtools/utils.py b/taxoniumtools/src/taxoniumtools/utils.py index 7e12e5f5..10f3a3dd 100644 --- a/taxoniumtools/src/taxoniumtools/utils.py +++ b/taxoniumtools/src/taxoniumtools/utils.py @@ -1,7 +1,7 @@ from alive_progress import alive_it, alive_bar import pandas as pd import warnings -import os, tempfile, sys +import os, tempfile, sys, errno import treeswift import shutil from . import ushertools @@ -34,7 +34,7 @@ def read_metadata(metadata_file, columns, key_column): ) metadata_dict = metadata.to_dict("index") - metadata_cols = metadata.columns + metadata_cols = metadata.columns.tolist() del metadata print("Metadata loaded") return metadata_dict, metadata_cols @@ -46,7 +46,8 @@ def read_metadata(metadata_file, columns, key_column): def do_chronumental(mat, chronumental_reference_node, metadata_file, chronumental_steps, chronumental_date_output, - chronumental_tree_output): + chronumental_tree_output, chronumental_add_inferred_date, + metadata_dict, metadata_cols): chronumental_is_available = os.system( "which chronumental > /dev/null") == 0 if not chronumental_is_available: @@ -93,6 +94,42 @@ def do_chronumental(mat, chronumental_reference_node, metadata_file, del time_tree del time_tree_iter + if chronumental_add_inferred_date: + print( + f"Adding chronumental inferred date as metadata-like item {chronumental_add_inferred_date}" + ) + if chronumental_date_output: + date_output_filename = chronumental_date_output + else: + metadata_dir = os.path.dirname(metadata_file) + metadata_base = os.path.basename(metadata_file) + date_output_filename = f"chronumental_dates_{metadata_base}.tsv" + if metadata_dir: + date_output_filename = metadata_dir + "/" + date_output_filename + if not os.path.exists(date_output_filename): + raise FileNotFoundError( + errno.ENOENT, + f"Can't find default date output file in the expected location (try specifying a file name with --chronumental_date_output)", + date_output_filename) + metadata_cols.append(chronumental_add_inferred_date) + inferred_dates = pd.read_csv( + date_output_filename, + sep="\t" if metadata_file.endswith(".tsv") + or metadata_file.endswith(".tsv.gz") else ",", + usecols=['strain', 'predicted_date']) + for idx, row in inferred_dates.iterrows(): + node_name = row['strain'] + inferred_date = row['predicted_date'] + # Add inferred_date even if a node (e.g. internal node) has no metadata + if node_name not in metadata_dict: + metadata_dict[node_name] = { + col: "" + for col in metadata_cols + } + metadata_dict[node_name][ + chronumental_add_inferred_date] = inferred_date + del inferred_dates + def set_x_coords(root, chronumental_enabled): """ Set x coordinates for the tree"""