diff --git a/data/small1.profile b/data/small1.profile index 3b18fa0..1f1eced 100644 --- a/data/small1.profile +++ b/data/small1.profile @@ -5,9 +5,9 @@ @TaxonomyID:ncbi-taxonomy_DATE @__program__:unknown @@TAXID RANK TAXPATH TAXPATHSN PERCENTAGE -4 superkingdom 4 Archaea 0.029528 -5 superkingdom 5 Viruses 0.004334 -3 phylum 4|3 Archaea|Crenarchaeota 0.004226 -2 class 4||2 Archaea||ZZZ 0.004226 -1 class 4|3|1 Archaea|Crenarchaeota|Thermoprotei 0.004226 -0 strain 5||||0 Viruses||||Viruses strain 0.3908 Sample6_49 p0 \ No newline at end of file +4 superkingdom 4 Archaea 80 +5 superkingdom 5 Viruses 20 +3 phylum 4|3 Archaea|Crenarchaeota 50 +2 class 4||2 Archaea||ZZZ 20 +1 class 4|3|1 Archaea|Crenarchaeota|Thermoprotei 50 +0 strain 5||||0 Viruses||||Viruses strain 20 Sample6_49 p0 \ No newline at end of file diff --git a/data/small2.profile b/data/small2.profile index 73955a6..025622c 100644 --- a/data/small2.profile +++ b/data/small2.profile @@ -5,9 +5,9 @@ @TaxonomyID:ncbi-taxonomy_DATE @__program__:unknown @@TAXID RANK TAXPATH TAXPATHSN PERCENTAGE -4 superkingdom 4 Archaea 0.029528 -5 superkingdom 5 Viruses 0.004334 -3 phylum 4|3 Archaea|Crenarchaeota 0.004226 -2 class 4||2 Archaea||ZZZ 0.004226 -1 class 4|3|1 Archaea|Crenarchaeota|Thermoprotei 0.004226 -0 strain 5|6|0|0|0 Viruses||||Viruses strain 0.3908 Sample6_49 p0 \ No newline at end of file +4 superkingdom 4 Archaea 80 +5 superkingdom 5 Viruses 10 +3 phylum 4|3 Archaea|Crenarchaeota 50 +2 class 4||2 Archaea||ZZZ 20 +1 class 4|3|1 Archaea|Crenarchaeota|Thermoprotei 50 +0 strain 5|6|0|0|0 Viruses||||Viruses strain 10 Sample6_49 p0 \ No newline at end of file diff --git a/src/utils/ProfilingTools.py b/src/utils/ProfilingTools.py index 1bf715b..625589b 100644 --- a/src/utils/ProfilingTools.py +++ b/src/utils/ProfilingTools.py @@ -8,7 +8,7 @@ # TODO: make sure that I'm not deleting the root "-1" that way Unifrac picks up on the missing superkingdoms class Profile(object): - def __init__(self, sample_metadata=None, profile=None): + def __init__(self, sample_metadata=None, profile=None, branch_length_fun=lambda x: 1/float(x)): self.sample_metadata = sample_metadata self.profile = profile self._data = dict() @@ -29,7 +29,7 @@ def __init__(self, sample_metadata=None, profile=None): self._all_keys = ["-1"] self._merged_flag = False self.root_len = 1 # the length you want between the "root" of "-1" and the superkingdom level (eg. Bacteria) - self.branch_len_func = lambda x: 1#1/float(x) # Given a node n at depth d in the tree, branch_len_func(d) + self.branch_len_func = branch_length_fun # Given a node n at depth d in the tree, branch_len_func(d) # is how long you want the branch length between n and ancestor(n) to be self._data["-1"]["branch_length"] = self.root_len self.parse_file() # TODO: this sets all the branch lengths to 1 currently @@ -40,209 +40,83 @@ def parse_file(self): _header = self._header for k, v in self.sample_metadata.items(): _header.append('{}:{}'.format(k, v)) + + # populate all the correct keys + for prediction in self.profile: + _all_keys.append(prediction.taxid.strip()) + + # crawl over all profiles tax_path and create the ancestors and descendants list for prediction in self.profile: tax_id = prediction.taxid.strip() - _all_keys.append(prediction.taxid) tax_path = prediction.taxpath.strip().split("|") # this will be a list, join up late - - if tax_id in _data: # If this tax_id is already present, add the abundance. 
NOT CHECKING FOR CONSISTENCY WITH PATH - _data[tax_id]["abundance"] += prediction.percentage - else: + if tax_id not in _data: _data[tax_id] = dict() - _data[tax_id]["abundance"] = prediction.percentage + else: + raise Exception(f"Improperly formatted profile: row starting with {tax_id} shows up more than once") + _data[tax_id]["tax_path"] = tax_path + + # populate abundance + _data[tax_id]["abundance"] = prediction.percentage + # populate tax path sn if not (prediction.taxpathsn is None): # might not be present _data[tax_id]["tax_path_sn"] = prediction.taxpathsn.strip().split("|") # this will be a list, join up later - _data[tax_id]["tax_path"] = tax_path + # populate the rank _data[tax_id]["rank"] = prediction.rank.strip() - # Find the ancestor + # populate the branch length + _data[tax_id]["branch_length"] = self.tax_path_to_branch_len(tax_path, self.branch_len_func, self.root_len) + + # Find the ancestors if len(tax_path) <= 1: # note, due to the format, we will never run into the case tax_path == [] _data[tax_id]["ancestor"] = "-1" # no ancestor, it's a root - _data[tax_id]["branch_length"] = self.tax_path_to_branch_len(tax_path, self.branch_len_func, self.root_len) - ancestor = "-1" - else: - ancestor = tax_path[-2] - _data[tax_id]["branch_length"] = self.tax_path_to_branch_len(tax_path, self.branch_len_func, self.root_len) - i = -3 # started at tax_path[-2], so start at tax_path[-3] i.e. tax_path[i] for i = -3 - while ancestor is "" or ancestor == tax_id: # if it's a blank or repeated, go up until finding ancestor - ancestor = tax_path[i] - #_data[tax_id]["branch_length"] += 1 # don't need due to tax_path_to_branch_len - i -= 1 - if i + len(tax_path) + 1 < 0: # no ancestor available (eg. ["","","123"], manually set to -1 (the root) - ancestor = "-1" + else: # go from the bottom up, looking for an ancestor that is an acceptable key + ancestor = "-1" # this is the default + tax_path_rev = tax_path[::-1] + for potential_ancestor in tax_path_rev: + if potential_ancestor != tax_id and potential_ancestor in _all_keys: + ancestor = potential_ancestor + break # you found the ancestor, so can quit looking _data[tax_id]["ancestor"] = ancestor # Create a placeholder descendant key initialized to [], just so each tax_id has a descendant key associated to it if "descendants" not in _data[tax_id]: # if this tax_id doesn't have a descendant list, _data[tax_id]["descendants"] = list() # initialize to empty list - # add the descendants - if ancestor in _data: # see if the ancestor is in the data so we can add this entry as a descendant - if "descendants" not in _data[ancestor]: # if it's not present, create the descendant list - _data[ancestor]["descendants"] = list() - _data[ancestor]["descendants"].append( - tax_id) # since ancestor is an ancestor, add this descendant to it - else: # if it's not already in the data, create the entry - _data[ancestor] = dict() - _data[ancestor]["descendants"] = list() - _data[ancestor]["descendants"].append(tax_id) - _data[ancestor]["abundance"] = 0.0 - # only fix if need be - # if all_tax_ids.intersection(_all_keys): + + self._add_descendants() self._delete_missing() # make sure there aren't any missing internal nodes - def _delete_missing(self): - # This is really only useful for MetaPhlAn profiles, which due to the infrequently updated taxonomy, contain - # many missing internal nodes, just delete the tax_id (make them "") and adjust branch lengths accordingly + def _add_descendants(self): + """ + Idea here is to look at all the ancestors of each key, and make 
the key the descendant of that ancestor + Returns + ------- + None: modifies Profile in place + """ _data = self._data - _good_keys = self._all_keys - current_keys = list(_data) # .keys() - bad_keys = [] - # Remove the bad keys from the dictionary - for key in current_keys: # go through the current keys - if key not in _good_keys: # if it's not a good key - bad_keys.append(key) # store the bad key - del _data[key] # delete the entry - # Get the bad keys in the intermediate places - for key in _good_keys: - tax_path = _data[key]["tax_path"] - for tax_id in tax_path: # look in full taxpath - if tax_id not in _good_keys: # if it's a bad key - if tax_id not in bad_keys: # and it's not already in the list - bad_keys.append(tax_id) # add it to the list - # Don't regard the blank tax_id - if '' in bad_keys: - bad_keys.pop(bad_keys.index('')) - for key in _good_keys: - # branch_length = _data[key]["branch_length"] - tax_path = _data[key]["tax_path"] - # tax_path_sn = _data[key]["tax_path_sn"] - descendants = _data[key]["descendants"] - # ancestor = _data[key]["ancestor"] - for descendant in descendants: - if descendant in bad_keys: - descendants.pop(descendants.index(descendant)) # get rid of the bad descendants - bad_indicies = [] # find the indicies of the bad ones in the tax path - for tax_id in tax_path: - if tax_id in bad_keys: - index = tax_path.index(tax_id) - bad_indicies.append(index) # store all the bad indicies - bad_indicies.sort(reverse=True) # in reverse order - # branch_length += len(bad_indicies) # increment the branch_lengths accordingly (might not work for good|bad|good|bad - for index in bad_indicies: - tax_path[index] = '' # remove the bad tax_ids - # fix the branch lengths and find the ancestors - if len(tax_path) >= 2: - # FIXME: these branch lengths will need to be updated as well - # here, I will need to sum up the ancestor branch lengths until it connects back to a non "" taxID - ancestor = tax_path[-2] - #_data[key]["branch_length"] = 1 # don't need due to tax_path_to_branch_len # FIXME: old method - i = -3 - while ancestor is "" or ancestor == key: # if it's a blank or repeated, go up until finding ancestor - if i < -len(tax_path): # Path is all the way full with bad tax_ids, connect to root - #_data[key]["branch_length"] += 1 # don't need due to tax_path_to_branch_len - ancestor = "-1" - #_data["-1"]["descendants"].append(key) # note this is now a descendant of the root # FIXME: old method - break - else: - ancestor = tax_path[i] - #_data[key]["branch_length"] += 1 # don't need due to tax_path_to_branch_len # FIXME: old method - i -= 1 - # now adjust the branch_length by summing up the appropriate branch lengths for the "" taxIDs - if ancestor == "-1": # went all the way back to the root, must start at the beginning - first_good_position = 0 - #new_branch_length = self.root_len # need to start with the initial branch length to the root "-1" since it's not included in the tax_path - else: - first_good_position = tax_path.index(ancestor) # otherwise start at the first good node/taxID - #new_branch_length = 0 - # FIXME: sometimes people make mistakes in their formatting of the profiles, so we need to work around it - try: - last_good_position = tax_path.index(key) - except ValueError: - # FIXME: try to fix their formatting mistakes, this is a hacky workaround and will need to be addressed. 
- tax_path.append(key) - last_good_position = tax_path.index(key) - # add up the branch lengths of the edges connecting "" to their ancestors, - # put in an edge between key and ancestor with the sum of these branch lengths - #for intermediate_index in range(first_good_position, last_good_position): - #new_branch_length += self.tax_path_to_branch_len(tax_path[0:intermediate_index], self.branch_len_func, self.root_len) - - #_data[key]['branch_length'] = new_branch_length - #if i != -3: - #if new_branch_length != _data[key]['branch_length']: - # print(f"ancestor: {ancestor}") - # print(f"tax_path: {tax_path}") - # print(f"i: {i}") - # print(f"branch length: {_data[key]['branch_length']}") - # print(f"new branch length: {new_branch_length}") - _data[key]["ancestor"] = ancestor - if ancestor in _data: - if key not in _data[ancestor]["descendants"]: - _data[ancestor]["descendants"].append(key) # Note this is now a descendant of it's ancestor - return + _all_keys = self._all_keys + for prediction in self.profile: + tax_id = prediction.taxid.strip() # the tax ID we are looking at + ancestor = _data[tax_id]['ancestor'] # the tax ID's ancestor + if tax_id not in _data[ancestor]['descendants']: + _data[ancestor]['descendants'].append(tax_id) # so make the tax ID we're looking at the descendant of the ancestor - def _populate_missing_dont_use(self): - # Unfortunately, some of the profile files may be missing intermediate ranks, - # so we should manually populate them here - # This will only really fix one missing intermediate rank.... - # INSTEAD!!! Let's just delete the missing intermediate ranks, if the tax_id isn't a key, delete it from the tax_id_path - # This is really only useful for MetaPhlAn profiles, which due to the infrequently updated taxonomy, contain - # many missing internal nodes - _data = self._data - for key in _data.keys(): - if "abundance" not in _data[key]: # this is a missing intermediate rank - # all the descendants *should* be in there, so leverage this info - if "descendants" not in _data[key]: - print("You're screwed, malformed profile file with rank %s" % key) - raise Exception + def _delete_missing(self): + """ + Deletes from the descendants all those taxids that aren't keys in the profile (i.e. 
there is no line that starts with that taxID) + Returns + ------- + none: modifies Profile in place + """ + for key in self._data: + clean_descendants = [] + for descendant in self._data[key]["descendants"]: + if descendant in self._all_keys: # if it's one of the taxids that the line starts with, add it + clean_descendants.append(descendant) else: - descendant = _data[key]["descendants"][0] - to_populate_key = descendant # just need the first one since the higher up path will be the same - to_populate = copy.deepcopy(_data[to_populate_key]) - tax_path = to_populate["tax_path"] - tax_path_sn = to_populate["tax_path_sn"] - descendant_pos = tax_path.index(descendant) - for i in range(len(tax_path) - 1, descendant_pos - 1, -1): - tax_path.pop(i) - tax_path_sn.pop(i) - to_populate["branch_length"] = 1 - if "rank" in to_populate: - rank = to_populate["rank"] - if rank == "strain": - rank = "species" - elif rank == "species": - rank = "genus" - elif rank == "genus": - rank = "family" - elif rank == "family": - rank = "order" - elif rank == "order": - rank = "class" - elif rank == "class": - rank = "phylum" - elif rank == "phylum": - rank = "superkingdom" - else: - print('Invalid rank') - raise Exception - else: - print('Missing rank') - raise Exception - to_populate["ancestor"] = tax_path[-2] - to_populate["rank"] = rank - # Now go through and sum up the abundance for which this guy is the ancestor - to_populate["abundance"] = 0 - to_populate["descendants"] = [] - for temp_key in _data.keys(): - if "ancestor" in _data[temp_key]: - if _data[temp_key]["ancestor"] == key: - to_populate["abundance"] += _data[temp_key]["abundance"] - to_populate["descendants"].append(temp_key) - # Make sure this guy is listed as a descendant to his ancestor - if key not in _data[to_populate["ancestor"]]["descendants"]: - _data[to_populate["ancestor"]]["descendants"].append(key) - _data[key] = to_populate + pass # don't include the taxids that aren't actually in the final tax tree + self._data[key]["descendants"] = clean_descendants return def write_file(self, out_file_name=None): @@ -593,14 +467,6 @@ def make_unifrac_input_no_normalize(self, other): def test_normalize(): - profile = Profile('/home/dkoslicki/Dropbox/Repositories/CAMIProfilingTools/src/test1.profile') - profile.write_file('/home/dkoslicki/Dropbox/Repositories/CAMIProfilingTools/src/test1.profile.import') - profile.normalize() - profile.write_file('/home/dkoslicki/Dropbox/Repositories/CAMIProfilingTools/src/test1.profile.normalize') - return profile - - -def test_normalize2(): import EMDUnifrac as EMDU from load_data import open_profile_from_tsv import os @@ -650,138 +516,172 @@ def test_normalize2(): assert weighted_norm != unweighted_no_norm return -def test_branch_lengths(): + +def test_branch_lengths_all_1(): from load_data import open_profile_from_tsv import os # test file file_path1 = os.path.dirname(os.path.abspath(__file__)) + "/../../data/small1.profile" profile_list = open_profile_from_tsv(file_path1, False) - name1, metadata1, profile1 = profile_list[0] - profile1 = Profile(sample_metadata=metadata1, profile=profile1) + name1, metadata1, profile_fernando = profile_list[0] + + # Test with branch lengths of 1 + profile1 = Profile(sample_metadata=metadata1, profile=profile_fernando, branch_length_fun=lambda x: 1) Tint, lint, nodes_in_order, nodes_to_index, P, Q = profile1.make_unifrac_input_and_normalize(profile1) index_to_nodes = dict() for key, val in nodes_to_index.items(): index_to_nodes[val] = key Tint_new = dict() lint_new = dict() - for 
key,val in Tint.items(): + for key, val in Tint.items(): Tint_new[index_to_nodes[key]] = index_to_nodes[val] - for key,val in lint.items(): - lint_new[(index_to_nodes[key[0]],index_to_nodes[key[1]])] = val - #print(f"Tint: {Tint}") - #print(f"lint: {lint}") - print(f"Tint_new: {Tint_new}") - print(f"lint_new: {lint_new}") - # NOTE: Should get (with branch length func x) - # Tint_new: {'0': '5', '1': '3', '2': '4', '3': '4', '4': '-1', '5': '-1'} - # lint_new: {('0', '5'): 5, ('1', '3'): 3, ('2', '4'): 3, ('3', '4'): 2, ('4', '-1'): 1, ('5', '-1'): 1} - # when using the profile - # # Taxonomic Profiling Output - # @SampleID:CAMI_LOW_S001 - # @Version:0.9.1 - # @Ranks:superkingdom|phylum|class|order|family|genus|species|strain - # @TaxonomyID:ncbi-taxonomy_DATE - # @__program__:unknown - # @@TAXID RANK TAXPATH TAXPATHSN PERCENTAGE - # 4 superkingdom 4 Archaea 0.029528 - # 5 superkingdom 5 Viruses 0.004334 - # 3 phylum 4|3 Archaea|Crenarchaeota 0.004226 - # 2 class 4||2 Archaea||ZZZ 0.004226 - # 1 class 4|3|1 Archaea|Crenarchaeota|Thermoprotei 0.004226 - # 0 strain 5||||0 Viruses||||Viruses strain 0.3908 Sample6_49 p0 - - -def test_branch_lengths2(): + for key, val in lint.items(): + lint_new[(index_to_nodes[key[0]], index_to_nodes[key[1]])] = val + #print(f"Tint_new: {Tint_new}") + #print(f"lint_new: {lint_new}") + #print(f"P: {P}") + assert Tint_new['5'] == '-1' + assert Tint_new['4'] == '-1' + assert Tint_new['3'] == '4' + assert Tint_new['1'] == '3' + assert Tint_new['2'] == '4' + assert Tint_new['0'] == '5' + assert set(Tint_new.keys()) == {'0', '1', '2', '3', '4', '5'} + for val in lint.values(): + assert val == 1 + correct_vals = {'0': 0.20, '1': 0.50, '2': 0.20, '3': 0.0, '4': 0.10, '5': 0.00} + for key, val in correct_vals.items(): + assert P[nodes_to_index[key]] == val + + # test with branch lengths of x + profile1 = Profile(sample_metadata=metadata1, profile=profile_fernando, branch_length_fun=lambda x: x) + Tint, lint, nodes_in_order, nodes_to_index, P, Q = profile1.make_unifrac_input_and_normalize(profile1) + index_to_nodes = dict() + for key, val in nodes_to_index.items(): + index_to_nodes[val] = key + Tint_new = dict() + lint_new = dict() + for key, val in Tint.items(): + Tint_new[index_to_nodes[key]] = index_to_nodes[val] + for key, val in lint.items(): + lint_new[(index_to_nodes[key[0]], index_to_nodes[key[1]])] = val + assert Tint_new['5'] == '-1' + assert Tint_new['4'] == '-1' + assert Tint_new['3'] == '4' + assert Tint_new['1'] == '3' + assert Tint_new['2'] == '4' + assert Tint_new['0'] == '5' + assert set(Tint_new.keys()) == {'0', '1', '2', '3', '4', '5'} + correct_lints = {('1', '3'): 3, ('3', '4'): 2, ('4', '-1'): 1, ('2', '4'): 3, ('0', '5'): 5, ('5', '-1'): 1} + for key, val in correct_lints.items(): + assert lint_new[key] == correct_lints[key] + correct_vals = {'0': 0.20, '1': 0.50, '2': 0.20, '3': 0.0, '4': 0.10, '5': 0.00} + for key, val in correct_vals.items(): + assert P[nodes_to_index[key]] == val + + +def test_branch_lengths_all_2(): from load_data import open_profile_from_tsv import os # test file file_path1 = os.path.dirname(os.path.abspath(__file__)) + "/../../data/small2.profile" profile_list = open_profile_from_tsv(file_path1, False) - name1, metadata1, profile1 = profile_list[0] - profile1 = Profile(sample_metadata=metadata1, profile=profile1) + name1, metadata1, profile_fernando = profile_list[0] + + # Test with branch lengths of 1 + profile1 = Profile(sample_metadata=metadata1, profile=profile_fernando, branch_length_fun=lambda x: 1) Tint, lint, 
nodes_in_order, nodes_to_index, P, Q = profile1.make_unifrac_input_and_normalize(profile1) index_to_nodes = dict() for key, val in nodes_to_index.items(): index_to_nodes[val] = key Tint_new = dict() lint_new = dict() - for key,val in Tint.items(): + for key, val in Tint.items(): Tint_new[index_to_nodes[key]] = index_to_nodes[val] - for key,val in lint.items(): - lint_new[(index_to_nodes[key[0]],index_to_nodes[key[1]])] = val - #print(f"Tint: {Tint}") - #print(f"lint: {lint}") - print(f"Tint_new: {Tint_new}") - print(f"lint_new: {lint_new}") - # NOTE: Should get (with branch length func x) - # Tint_new: {'0': '5', '1': '3', '2': '4', '3': '4', '4': '-1', '5': '-1'} - # lint_new: {('0', '5'): 5, ('1', '3'): 3, ('2', '4'): 3, ('3', '4'): 2, ('4', '-1'): 1, ('5', '-1'): 1} - # when using the profile - # # Taxonomic Profiling Output - # @SampleID:CAMI_LOW_S001 - # @Version:0.9.1 - # @Ranks:superkingdom|phylum|class|order|family|genus|species|strain - # @TaxonomyID:ncbi-taxonomy_DATE - # @__program__:unknown - # @@TAXID RANK TAXPATH TAXPATHSN PERCENTAGE - # 4 superkingdom 4 Archaea 0.029528 - # 5 superkingdom 5 Viruses 0.004334 - # 3 phylum 4|3 Archaea|Crenarchaeota 0.004226 - # 2 class 4||2 Archaea||ZZZ 0.004226 - # 1 class 4|3|1 Archaea|Crenarchaeota|Thermoprotei 0.004226 - # 0 strain 5||||0 Viruses||||Viruses strain 0.3908 Sample6_49 p0 - - -def test_unifrac(): - sys.path.append('/home/dkoslicki/Dropbox/Repositories/EMDUnifrac/src') - import EMDUnifrac as EMDU - profile1 = Profile('/home/dkoslicki/Dropbox/Repositories/EMDUnifrac/data/test1.profile') - profile2 = Profile('/home/dkoslicki/Dropbox/Repositories/EMDUnifrac/data/test2.profile') - t0 = timeit.default_timer() - (Tint, lint, nodes_in_order, nodes_to_index, P, Q) = profile1.make_unifrac_input_and_normalize(profile2) - t1 = timeit.default_timer() - # print(t1-t0) - (Z, diffab) = EMDU.EMDUnifrac_weighted(Tint, lint, nodes_in_order, P, Q) - print(Z) - (Tint, lint, nodes_in_order, nodes_to_index, P, Q) = profile1.make_unifrac_input_and_normalize(profile2) - (Z2, diffab2) = EMDU.EMDUnifrac_weighted(Tint, lint, nodes_in_order, P, Q) - print(Z2) - profile1.normalize() - profile2.normalize() - (Tint, lint, nodes_in_order, nodes_to_index, P, Q) = profile1.make_unifrac_input_and_normalize(profile2) - (Z3, diffab3) = EMDU.EMDUnifrac_weighted(Tint, lint, nodes_in_order, P, Q) - print(Z3) - - print(diffab) - print(diffab2) - print(diffab3) - - -def test_real_data(): - profile1 = Profile( - '/home/dkoslicki/Dropbox/Repositories/CAMIProfilingTools/src/lane4-s041-indexN722-S502-ATGCGCAG-ATAGAGAG-41_M5-2_S41_L004_R1_001.fa.gz.metaphlan.profile') - profile2 = Profile( - '/home/dkoslicki/Dropbox/Repositories/CAMIProfilingTools/src/lane8-s092-indexN729-S505-TCGACGTC-CTCCTTAC-91_Z0299_S92_L008_R1_001.fa.gz.metaphlan.profile') - sys.path.append('/home/dkoslicki/Dropbox/Repositories/EMDUnifrac/src') - import EMDUnifrac as EMDU - (Tint, lint, nodes_in_order, nodes_to_index, P, Q) = profile1.make_unifrac_input_and_normalize(profile2) - (Z, diffab) = EMDU.EMDUnifrac_weighted(Tint, lint, nodes_in_order, P, Q) - print(Z) - (Tint, lint, nodes_in_order, nodes_to_index, P, Q) = profile1.make_unifrac_input_and_normalize(profile2) - (Z2, diffab2) = EMDU.EMDUnifrac_weighted(Tint, lint, nodes_in_order, P, Q) - print(Z2) - profile1.normalize() - profile2.normalize() - (Tint, lint, nodes_in_order, nodes_to_index, P, Q) = profile1.make_unifrac_input_and_normalize(profile2) - (Z3, diffab3) = EMDU.EMDUnifrac_weighted(Tint, lint, nodes_in_order, P, Q) - print(Z3) - - 
print(diffab) - print(diffab2) - print(diffab3) - -# return profile1, profile2 + for key, val in lint.items(): + lint_new[(index_to_nodes[key[0]], index_to_nodes[key[1]])] = val + #print(f"Tint_new: {Tint_new}") + #print(f"lint_new: {lint_new}") + #print(f"P: {P}") + assert Tint_new['5'] == '-1' + assert Tint_new['4'] == '-1' + assert Tint_new['3'] == '4' + assert Tint_new['1'] == '3' + assert Tint_new['2'] == '4' + assert Tint_new['0'] == '5' + assert set(Tint_new.keys()) == {'0', '1', '2', '3', '4', '5'} + for val in lint.values(): + assert val == 1 + correct_vals = {'0': 1/9., '1': 5/9., '2': 2/9., '3': 0.0, '4': 1/9., '5': 0.00} + for key, val in correct_vals.items(): + assert P[nodes_to_index[key]] == val + + # test with branch lengths of x + profile1 = Profile(sample_metadata=metadata1, profile=profile_fernando, branch_length_fun=lambda x: x) + Tint, lint, nodes_in_order, nodes_to_index, P, Q = profile1.make_unifrac_input_and_normalize(profile1) + index_to_nodes = dict() + for key, val in nodes_to_index.items(): + index_to_nodes[val] = key + Tint_new = dict() + lint_new = dict() + for key, val in Tint.items(): + Tint_new[index_to_nodes[key]] = index_to_nodes[val] + for key, val in lint.items(): + lint_new[(index_to_nodes[key[0]], index_to_nodes[key[1]])] = val + assert Tint_new['5'] == '-1' + assert Tint_new['4'] == '-1' + assert Tint_new['3'] == '4' + assert Tint_new['1'] == '3' + assert Tint_new['2'] == '4' + assert Tint_new['0'] == '5' + assert set(Tint_new.keys()) == {'0', '1', '2', '3', '4', '5'} + correct_lints = {('1', '3'): 3, ('3', '4'): 2, ('4', '-1'): 1, ('2', '4'): 3, ('0', '5'): 5, ('5', '-1'): 1} + for key, val in correct_lints.items(): + assert lint_new[key] == correct_lints[key] + correct_vals = {'0': 1/9., '1': 5/9., '2': 2/9., '3': 0.0, '4': 1/9., '5': 0.00} + for key, val in correct_vals.items(): + assert P[nodes_to_index[key]] == val + + +def test_no_normalize(): + from load_data import open_profile_from_tsv + import os + # test file + file_path1 = os.path.dirname(os.path.abspath(__file__)) + "/../../data/small2.profile" + profile_list = open_profile_from_tsv(file_path1, False) + name1, metadata1, profile_fernando = profile_list[0] + + # Test with branch lengths of 1 + profile1 = Profile(sample_metadata=metadata1, profile=profile_fernando, branch_length_fun=lambda x: 1) + Tint, lint, nodes_in_order, nodes_to_index, P, Q = profile1.make_unifrac_input_no_normalize(profile1) + index_to_nodes = dict() + for key, val in nodes_to_index.items(): + index_to_nodes[val] = key + Tint_new = dict() + lint_new = dict() + for key, val in Tint.items(): + Tint_new[index_to_nodes[key]] = index_to_nodes[val] + for key, val in lint.items(): + lint_new[(index_to_nodes[key[0]], index_to_nodes[key[1]])] = val + #print(f"Tint_new: {Tint_new}") + #print(f"lint_new: {lint_new}") + #print(f"P: {P}") + assert Tint_new['5'] == '-1' + assert Tint_new['4'] == '-1' + assert Tint_new['3'] == '4' + assert Tint_new['1'] == '3' + assert Tint_new['2'] == '4' + assert Tint_new['0'] == '5' + assert set(Tint_new.keys()) == {'0', '1', '2', '3', '4', '5'} + for val in lint.values(): + assert val == 1 + correct_vals = {'0': 0.10, '1': 0.50, '2': 0.20, '3': 0.0, '4': 0.10, '5': 0.00} + for key, val in correct_vals.items(): + assert P[nodes_to_index[key]] == val + if __name__ == "__main__": - #test_branch_lengths() - test_branch_lengths2() + test_normalize() + test_branch_lengths_all_1() + test_branch_lengths_all_2() + test_no_normalize()
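
For context on the rewritten parse_file: the old nested while-loop ancestor search is replaced by a bottom-up scan of the TAXPATH that only accepts entries which actually appear as keys in the profile. A minimal standalone sketch of that lookup follows; the helper name find_ancestor is hypothetical and not part of ProfilingTools, it just mirrors the loop added in the patch above.

    # Sketch only: mirrors the bottom-up ancestor search in the new parse_file.
    # Walk the TAXPATH from the leaf upward and return the first entry that is
    # a real key in the profile (skipping blank ranks and the node itself);
    # fall back to the root "-1" when nothing qualifies.
    def find_ancestor(tax_id, tax_path, all_keys):
        for potential_ancestor in reversed(tax_path):
            if potential_ancestor != tax_id and potential_ancestor in all_keys:
                return potential_ancestor
        return "-1"

    all_keys = ["-1", "4", "5", "3", "2", "1", "0"]  # taxids from data/small1.profile
    assert find_ancestor("2", ["4", "", "2"], all_keys) == "4"          # blank rank is skipped
    assert find_ancestor("0", ["5", "", "", "", "0"], all_keys) == "5"
    assert find_ancestor("4", ["4"], all_keys) == "-1"                  # top-level node hangs off the root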
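
On the new branch_length_fun constructor argument: the tests above pin down its semantics. A node's depth is its position in the TAXPATH counting blank entries, and the edge from that node to its nearest present ancestor gets length branch_length_fun(depth), with the root edge fixed at root_len. A minimal sketch assuming that convention; branch_len below is a hypothetical stand-in, not the module's tax_path_to_branch_len.

    # Sketch only: edge-length convention implied by test_branch_lengths_all_1/2.
    def branch_len(tax_path, branch_length_fun, root_len=1):
        depth = len(tax_path)                  # blanks in the TAXPATH still count toward depth
        return root_len if depth <= 1 else branch_length_fun(depth)

    assert branch_len(["4"], lambda d: d) == 1                     # ('4', '-1'): 1
    assert branch_len(["4", "3"], lambda d: d) == 2                # ('3', '4'): 2
    assert branch_len(["4", "", "2"], lambda d: d) == 3            # ('2', '4'): 3, blank rank included
    assert branch_len(["5", "", "", "", "0"], lambda d: d) == 5    # ('0', '5'): 5
    assert branch_len(["5", "", "", "", "0"], lambda d: 1) == 1    # constant lengths with lambda x: 1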