From 4f31eb83d3be90e82cb2548b4d022cd15e975980 Mon Sep 17 00:00:00 2001 From: John Huddleston Date: Wed, 28 Aug 2024 15:19:17 -0700 Subject: [PATCH] Add t-SNE embeddings and clusters Adds rules, scripts, and config to produce joint t-SNE embeddings per build from all gene segments and find clusters from the resulting embeddings. When the user defines the `embedding` key in their build config, the workflow produces pairwise distances per gene segment, runs t-SNE on those distances, finds clusters with HDBSCAN, and exports the embedding coordinates and clusters in the Auspice JSON. --- Snakefile | 103 +++++++++++++++++++++++++++ config/gisaid.yaml | 5 +- config/h5n1/auspice_config_h5n1.json | 20 +++++- config/h5nx/auspice_config_h5nx.json | 18 ++++- config/h7n9/auspice_config_h7n9.json | 18 ++++- config/h9n2/auspice_config_h9n2.json | 18 ++++- scripts/intersect_items.py | 23 ++++++ scripts/table_to_node_data.py | 32 +++++++++ 8 files changed, 231 insertions(+), 6 deletions(-) create mode 100644 scripts/intersect_items.py create mode 100644 scripts/table_to_node_data.py diff --git a/Snakefile b/Snakefile index bcc91af..c3760c7 100755 --- a/Snakefile +++ b/Snakefile @@ -490,6 +490,105 @@ rule cleavage_site: --cleavage_site_sequence {output.cleavage_site_sequences} """ +rule get_strains_in_alignment: + input: + alignment = "results/{subtype}/{segment}/{time}/aligned.fasta", + output: + alignment_strains = "results/{subtype}/{segment}/{time}/aligned_strains.txt", + shell: + """ + seqkit fx2tab -n -i {input.alignment} | sort -k 1,1 > {output.alignment_strains} + """ + +rule get_shared_strains_in_alignments: + input: + alignment_strains = expand("results/{{subtype}}/{segment}/{{time}}/aligned_strains.txt", segment=config["segments"]), + output: + shared_strains = "results/{subtype}/all/{time}/shared_strains_in_alignment.txt", + shell: + """ + python3 scripts/intersect_items.py \ + --items {input.alignment_strains:q} \ + --output {output.shared_strains} + """ + +rule 
select_shared_strains_from_alignment_and_sort: + input: + shared_strains = "results/{subtype}/all/{time}/shared_strains_in_alignment.txt", + alignment = "results/{subtype}/{segment}/{time}/aligned.fasta", + output: + alignment = "results/{subtype}/{segment}/{time}/aligned.sorted.fasta", + shell: + """ + seqkit grep -f {input.shared_strains} {input.alignment} \ + | seqkit sort -n > {output.alignment} + """ + +rule calculate_pairwise_distances: + input: + alignment = "results/{subtype}/{segment}/{time}/aligned.sorted.fasta", + output: + distances = "results/{subtype}/{segment}/{time}/distances.csv", + benchmark: + "benchmarks/calculate_pairwise_distances_{subtype}_{segment}_{time}.txt" + shell: + """ + pathogen-distance \ + --alignment {input.alignment} \ + --output {output.distances} + """ + +rule embed_with_tsne: + input: + alignments = expand("results/{{subtype}}/{segment}/{{time}}/aligned.sorted.fasta", segment=config["segments"]), + distances = expand("results/{{subtype}}/{segment}/{{time}}/distances.csv", segment=config["segments"]), + output: + embedding = "results/{subtype}/all/{time}/embed_tsne.csv", + params: + perplexity=config.get("embedding", {}).get("perplexity", 200), + benchmark: + "benchmarks/embed_with_tsne_{subtype}_{time}.txt" + shell: + """ + pathogen-embed \ + --alignment {input.alignments} \ + --distance-matrix {input.distances} \ + --output-dataframe {output.embedding} \ + t-sne \ + --perplexity {params.perplexity} + """ + +rule cluster_tsne_embedding: + input: + embedding = "results/{subtype}/all/{time}/embed_tsne.csv", + output: + clusters = "results/{subtype}/all/{time}/cluster_embed_tsne.csv", + params: + label_attribute="tsne_cluster", + distance_threshold=1.0, + benchmark: + "benchmarks/cluster_tsne_embedding_{subtype}_{time}.txt" + shell: + """ + pathogen-cluster \ + --embedding {input.embedding} \ + --label-attribute {params.label_attribute:q} \ + --distance-threshold {params.distance_threshold} \ + --output-dataframe {output.clusters} 
+ """ + +rule convert_embedding_clusters_to_node_data: + input: + clusters = "results/{subtype}/all/{time}/cluster_embed_tsne.csv", + output: + node_data = "results/{subtype}/all/{time}/cluster_embed_tsne.json", + shell: + """ + python3 scripts/table_to_node_data.py \ + --table {input.clusters} \ + --output {output.node_data} + """ + def export_node_data_files(wildcards): nd = [ rules.refine.output.node_data, @@ -502,6 +601,10 @@ def export_node_data_files(wildcards): if wildcards.subtype=="h5n1-cattle-outbreak" and wildcards.segment!='genome': nd.append(rules.prune_tree.output.node_data) + + if config.get("embedding"): + nd.append(rules.convert_embedding_clusters_to_node_data.output.node_data) + return nd diff --git a/config/gisaid.yaml b/config/gisaid.yaml index 284d36e..3565eb2 100644 --- a/config/gisaid.yaml +++ b/config/gisaid.yaml @@ -32,7 +32,7 @@ local_ingest: false #### Parameters which control large overarching aspects of the build target_sequences_per_tree: 3000 -same_strains_per_segment: false +same_strains_per_segment: true #### Config files #### @@ -159,6 +159,9 @@ traits: confidence: FALLBACK: true +embedding: + perplexity: 200 + export: title: FALLBACK: false # use the title in the auspice JSON diff --git a/config/h5n1/auspice_config_h5n1.json b/config/h5n1/auspice_config_h5n1.json index 264d44c..ef28c0f 100755 --- a/config/h5n1/auspice_config_h5n1.json +++ b/config/h5n1/auspice_config_h5n1.json @@ -39,7 +39,7 @@ "key": "division", "title": "Admin Division", "type": "categorical" - }, + }, { "key": "host", "title": "Host", @@ -89,6 +89,21 @@ "key": "submitting_lab", "title": "Submitting Lab", "type": "categorical" + }, + { + "key": "tsne_x", + "title": "t-SNE 1", + "type": "continuous" + }, + { + "key": "tsne_y", + "title": "t-SNE 2", + "type": "continuous" + }, + { + "key": "tsne_cluster", + "title": "t-SNE cluster", + "type": "categorical" } ], "geo_resolutions": [ @@ -111,6 +126,7 @@ "gisaid_clade", "authors", "originating_lab", - 
"submitting_lab" + "submitting_lab", + "tsne_cluster" ] } diff --git a/config/h5nx/auspice_config_h5nx.json b/config/h5nx/auspice_config_h5nx.json index dcda973..b141da4 100755 --- a/config/h5nx/auspice_config_h5nx.json +++ b/config/h5nx/auspice_config_h5nx.json @@ -89,6 +89,21 @@ "key": "submitting_lab", "title": "Submitting Lab", "type": "categorical" + }, + { + "key": "tsne_x", + "title": "t-SNE 1", + "type": "continuous" + }, + { + "key": "tsne_y", + "title": "t-SNE 2", + "type": "continuous" + }, + { + "key": "tsne_cluster", + "title": "t-SNE cluster", + "type": "categorical" } ], "geo_resolutions": [ @@ -111,6 +126,7 @@ "gisaid_clade", "authors", "originating_lab", - "submitting_lab" + "submitting_lab", + "tsne_cluster" ] } diff --git a/config/h7n9/auspice_config_h7n9.json b/config/h7n9/auspice_config_h7n9.json index 915880c..6ae3bea 100755 --- a/config/h7n9/auspice_config_h7n9.json +++ b/config/h7n9/auspice_config_h7n9.json @@ -54,6 +54,21 @@ "key": "submitting_lab", "title": "Submitting Lab", "type": "categorical" + }, + { + "key": "tsne_x", + "title": "t-SNE 1", + "type": "continuous" + }, + { + "key": "tsne_y", + "title": "t-SNE 2", + "type": "continuous" + }, + { + "key": "tsne_cluster", + "title": "t-SNE cluster", + "type": "categorical" } ], "geo_resolutions": [ @@ -70,6 +85,7 @@ "country", "division", "originating_lab", - "submitting_lab" + "submitting_lab", + "tsne_cluster" ] } diff --git a/config/h9n2/auspice_config_h9n2.json b/config/h9n2/auspice_config_h9n2.json index 48c274b..280ac9b 100755 --- a/config/h9n2/auspice_config_h9n2.json +++ b/config/h9n2/auspice_config_h9n2.json @@ -54,6 +54,21 @@ "key": "submitting_lab", "title": "Submitting Lab", "type": "categorical" + }, + { + "key": "tsne_x", + "title": "t-SNE 1", + "type": "continuous" + }, + { + "key": "tsne_y", + "title": "t-SNE 2", + "type": "continuous" + }, + { + "key": "tsne_cluster", + "title": "t-SNE cluster", + "type": "categorical" } ], "geo_resolutions": [ @@ -69,6 +84,7 @@ "region", 
"country", "originating_lab", - "submitting_lab" + "submitting_lab", + "tsne_cluster" ] } diff --git a/scripts/intersect_items.py b/scripts/intersect_items.py new file mode 100644 index 0000000..0fba101 --- /dev/null +++ b/scripts/intersect_items.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +import argparse + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--items", nargs="+", required=True, help="one or more files containing a list of items") + parser.add_argument("--output", required=True, help="list of items shared by all input files (the intersection)") + + args = parser.parse_args() + + with open(args.items[0], "r", encoding="utf-8") as fh: + shared_items = {line.strip() for line in fh} + + for item_file in args.items[1:]: + with open(item_file, "r", encoding="utf-8") as fh: + items = {line.strip() for line in fh} + + shared_items = shared_items & items + + with open(args.output, "w", encoding="utf-8") as oh: + for item in sorted(shared_items): + print(item, file=oh) diff --git a/scripts/table_to_node_data.py b/scripts/table_to_node_data.py new file mode 100644 index 0000000..fa2d156 --- /dev/null +++ b/scripts/table_to_node_data.py @@ -0,0 +1,32 @@ +"""Create Augur-compatible node data JSON from a pandas data frame. 
+"""
+import argparse
+import pandas as pd
+from augur.utils import write_json
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--table", required=True, help="table to convert to a node data JSON")
+    parser.add_argument("--index-column", default="strain", help="name of the column to use as an index")
+    parser.add_argument("--delimiter", default=",", help="separator between columns in the given table")
+    parser.add_argument("--node-name", default="nodes", help="name of the node data attribute in the JSON output")
+    parser.add_argument("--output", required=True, help="node data JSON file")
+
+    args = parser.parse_args()
+
+    # Read every column as text (`dtype=str`) so pandas does not coerce any
+    # value, such as a cluster label, to a number before it is exported.
+    table = pd.read_csv(
+        args.table,
+        sep=args.delimiter,
+        index_col=args.index_column,
+        dtype=str,
+    )
+
+    # Transposing makes each table row a column, so `to_dict` returns one
+    # {column name: value} mapping per index value (one entry per row).
+    table_dict = table.transpose().to_dict()
+
+    # Nest the per-row mappings under the requested top-level key.
+    write_json({args.node_name: table_dict}, args.output)