Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add t-SNE embeddings and clusters #88

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,105 @@ rule cleavage_site:
--cleavage_site_sequence {output.cleavage_site_sequences}
"""

# List the strain names present in one segment's alignment, one name per
# line, sorted so downstream set operations are reproducible.
rule get_strains_in_alignment:
    input:
        alignment = "results/{subtype}/{segment}/{time}/aligned.fasta",
    output:
        alignment_strains = "results/{subtype}/{segment}/{time}/aligned_strains.txt",
    shell:
        """
        seqkit fx2tab -n -i {input.alignment} | sort -k 1,1 > {output.alignment_strains}
        """

# Intersect the per-segment strain lists to find strains that appear in the
# alignment of every configured segment; embeddings require the same strain
# set across all segments.
rule get_shared_strains_in_alignments:
    input:
        # One strain list per segment; subtype/time stay as wildcards.
        alignment_strains = expand("results/{{subtype}}/{segment}/{{time}}/aligned_strains.txt", segment=config["segments"]),
    output:
        shared_strains = "results/{subtype}/all/{time}/shared_strains_in_alignment.txt",
    shell:
        """
        python3 scripts/intersect_items.py \
            --items {input.alignment_strains:q} \
            --output {output.shared_strains}
        """

# Subset each segment's alignment to the strains shared across all segments
# and sort records by name, so every per-segment file has the same strains
# in the same order.
rule select_shared_strains_from_alignment_and_sort:
    input:
        shared_strains = "results/{subtype}/all/{time}/shared_strains_in_alignment.txt",
        alignment = "results/{subtype}/{segment}/{time}/aligned.fasta",
    output:
        alignment = "results/{subtype}/{segment}/{time}/aligned.sorted.fasta",
    shell:
        """
        seqkit grep -f {input.shared_strains} {input.alignment} \
            | seqkit sort -n > {output.alignment}
        """

# Compute pairwise genetic distances between all strains in a segment's
# sorted alignment; consumed by the t-SNE embedding rule below.
rule calculate_pairwise_distances:
    input:
        alignment = "results/{subtype}/{segment}/{time}/aligned.sorted.fasta",
    output:
        distances = "results/{subtype}/{segment}/{time}/distances.csv",
    benchmark:
        "benchmarks/calculate_pairwise_distances_{subtype}_{segment}_{time}.txt"
    shell:
        """
        pathogen-distance \
            --alignment {input.alignment} \
            --output {output.distances}
        """

# Produce a single t-SNE embedding per subtype/time from all per-segment
# alignments and distance matrices.
rule embed_with_tsne:
    input:
        # Per-segment inputs must contain the same strains in the same order
        # (guaranteed by select_shared_strains_from_alignment_and_sort).
        alignments = expand("results/{{subtype}}/{segment}/{{time}}/aligned.sorted.fasta", segment=config["segments"]),
        distances = expand("results/{{subtype}}/{segment}/{{time}}/distances.csv", segment=config["segments"]),
    output:
        embedding = "results/{subtype}/all/{time}/embed_tsne.csv",
    params:
        # Tunable via config["embedding"]["perplexity"]; defaults to 200.
        perplexity=config.get("embedding", {}).get("perplexity", 200),
    benchmark:
        "benchmarks/embed_with_tsne_{subtype}_{time}.txt"
    shell:
        # NOTE(review): multiple files are passed to a single --alignment /
        # --distance-matrix flag — confirm the pathogen-embed CLI accepts
        # space-separated lists for these options.
        """
        pathogen-embed \
            --alignment {input.alignments} \
            --distance-matrix {input.distances} \
            --output-dataframe {output.embedding} \
            t-sne \
            --perplexity {params.perplexity}
        """

# Assign cluster labels to strains based on their t-SNE embedding
# coordinates using pathogen-cluster.
rule cluster_tsne_embedding:
    input:
        embedding = "results/{subtype}/all/{time}/embed_tsne.csv",
    output:
        clusters = "results/{subtype}/all/{time}/cluster_embed_tsne.csv",
    params:
        # Column name for the cluster labels in the output table; must match
        # the coloring key used in the auspice configs.
        label_attribute="tsne_cluster",
        # Tunable via config["embedding"]["distance_threshold"], mirroring
        # how the t-SNE perplexity is configured; defaults to the previous
        # hard-coded value of 1.0.
        distance_threshold=config.get("embedding", {}).get("distance_threshold", 1.0),
    benchmark:
        "benchmarks/cluster_tsne_embedding_{subtype}_{time}.txt"
    shell:
        """
        pathogen-cluster \
            --embedding {input.embedding} \
            --label-attribute {params.label_attribute:q} \
            --distance-threshold {params.distance_threshold} \
            --output-dataframe {output.clusters}
        """

# Convert the cluster table (CSV) into an Augur node data JSON so the
# embedding coordinates and cluster labels can be exported to auspice.
rule convert_embedding_clusters_to_node_data:
    input:
        clusters = "results/{subtype}/all/{time}/cluster_embed_tsne.csv",
    output:
        node_data = "results/{subtype}/all/{time}/cluster_embed_tsne.json",
    shell:
        """
        python3 scripts/table_to_node_data.py \
            --table {input.clusters} \
            --output {output.node_data}
        """

def export_node_data_files(wildcards):
nd = [
rules.refine.output.node_data,
Expand All @@ -502,6 +601,10 @@ def export_node_data_files(wildcards):

if wildcards.subtype=="h5n1-cattle-outbreak" and wildcards.segment!='genome':
nd.append(rules.prune_tree.output.node_data)

if config.get("embedding"):
nd.append(rules.convert_embedding_clusters_to_node_data.output.node_data)

return nd


Expand Down
5 changes: 4 additions & 1 deletion config/gisaid.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ local_ingest: false

#### Parameters which control large overarching aspects of the build
target_sequences_per_tree: 3000
same_strains_per_segment: false
same_strains_per_segment: true


#### Config files ####
Expand Down Expand Up @@ -159,6 +159,9 @@ traits:
confidence:
FALLBACK: true

embedding:
perplexity: 200

export:
title:
FALLBACK: false # use the title in the auspice JSON
Expand Down
20 changes: 18 additions & 2 deletions config/h5n1/auspice_config_h5n1.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
"key": "division",
"title": "Admin Division",
"type": "categorical"
},
},
{
"key": "host",
"title": "Host",
Expand Down Expand Up @@ -89,6 +89,21 @@
"key": "submitting_lab",
"title": "Submitting Lab",
"type": "categorical"
},
{
"key": "tsne_x",
"title": "t-SNE 1",
"type": "continuous"
},
{
"key": "tsne_y",
"title": "t-SNE 2",
"type": "continuous"
},
{
"key": "tsne_cluster",
"title": "t-SNE cluster",
"type": "categorical"
}
],
"geo_resolutions": [
Expand All @@ -111,6 +126,7 @@
"gisaid_clade",
"authors",
"originating_lab",
"submitting_lab"
"submitting_lab",
"tsne_cluster"
]
}
18 changes: 17 additions & 1 deletion config/h5nx/auspice_config_h5nx.json
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,21 @@
"key": "submitting_lab",
"title": "Submitting Lab",
"type": "categorical"
},
{
"key": "tsne_x",
"title": "t-SNE 1",
"type": "continuous"
},
{
"key": "tsne_y",
"title": "t-SNE 2",
"type": "continuous"
},
{
"key": "tsne_cluster",
"title": "t-SNE cluster",
"type": "categorical"
}
],
"geo_resolutions": [
Expand All @@ -111,6 +126,7 @@
"gisaid_clade",
"authors",
"originating_lab",
"submitting_lab"
"submitting_lab",
"tsne_cluster"
]
}
18 changes: 17 additions & 1 deletion config/h7n9/auspice_config_h7n9.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@
"key": "submitting_lab",
"title": "Submitting Lab",
"type": "categorical"
},
{
"key": "tsne_x",
"title": "t-SNE 1",
"type": "continuous"
},
{
"key": "tsne_y",
"title": "t-SNE 2",
"type": "continuous"
},
{
"key": "tsne_cluster",
"title": "t-SNE cluster",
"type": "categorical"
}
],
"geo_resolutions": [
Expand All @@ -70,6 +85,7 @@
"country",
"division",
"originating_lab",
"submitting_lab"
"submitting_lab",
"tsne_cluster"
]
}
18 changes: 17 additions & 1 deletion config/h9n2/auspice_config_h9n2.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@
"key": "submitting_lab",
"title": "Submitting Lab",
"type": "categorical"
},
{
"key": "tsne_x",
"title": "t-SNE 1",
"type": "continuous"
},
{
"key": "tsne_y",
"title": "t-SNE 2",
"type": "continuous"
},
{
"key": "tsne_cluster",
"title": "t-SNE cluster",
"type": "categorical"
}
],
"geo_resolutions": [
Expand All @@ -69,6 +84,7 @@
"region",
"country",
"originating_lab",
"submitting_lab"
"submitting_lab",
"tsne_cluster"
]
}
23 changes: 23 additions & 0 deletions scripts/intersect_items.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/usr/bin/env python3
"""Compute the set intersection of one or more line-delimited item lists."""
import argparse


def read_items(path):
    """Return the set of non-empty, whitespace-stripped lines from *path*.

    Blank lines are dropped so a trailing newline or stray whitespace in an
    input file cannot introduce an empty item into the intersection (the
    original implementation could emit an empty output line in that case).
    """
    with open(path, "r", encoding="utf-8") as fh:
        return {line.strip() for line in fh if line.strip()}


def intersect_files(paths):
    """Return the sorted list of items present in every file in *paths*."""
    shared_items = read_items(paths[0])
    for item_file in paths[1:]:
        shared_items &= read_items(item_file)

    return sorted(shared_items)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--items", nargs="+", required=True, help="one or more files containing a list of items")
    parser.add_argument("--output", required=True, help="list of items shared by all input files (the intersection)")

    args = parser.parse_args()

    with open(args.output, "w", encoding="utf-8") as oh:
        for item in intersect_files(args.items):
            print(item, file=oh)
32 changes: 32 additions & 0 deletions scripts/table_to_node_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Create Augur-compatible node data JSON from a pandas data frame.
"""
import argparse
import pandas as pd
from augur.utils import write_json


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--table", help="table to convert to a node data JSON")
parser.add_argument("--index-column", default="strain", help="name of the column to use as an index")
parser.add_argument("--delimiter", default=",", help="separator between columns in the given table")
parser.add_argument("--node-name", default="nodes", help="name of the node data attribute in the JSON output")
parser.add_argument("--output", help="node data JSON file")

args = parser.parse_args()

if args.output is not None:
table = pd.read_csv(
args.table,
sep=args.delimiter,
index_col=args.index_column,
dtype=str,
)

# # Convert columns that aren't strain names or labels to floats.
# for column in table.columns:
# if column != "strain" and not "label" in column:
# table[column] = table[column].astype(float)

table_dict = table.transpose().to_dict()
write_json({args.node_name: table_dict}, args.output)
Loading