diff --git a/goatools/godag/go_tasks.py b/goatools/godag/go_tasks.py index 6981c6c..ad425e8 100644 --- a/goatools/godag/go_tasks.py +++ b/goatools/godag/go_tasks.py @@ -1,15 +1,20 @@ """item-DAG tasks.""" -__copyright__ = "Copyright (C) 2010-present, DV Klopfenstein, H Tang, All rights reserved." +__copyright__ = ( + "Copyright (C) 2010-present, DV Klopfenstein, H Tang, All rights reserved." +) __author__ = "DV Klopfenstein" -from goatools.godag.consts import RELATIONSHIP_SET +from ..godag.consts import RELATIONSHIP_SET -# ------------------------------------------------------------------------------------ def get_go2parents(go2obj, relationships): """Get set of parents GO IDs, including parents through user-specfied relationships""" - if go2obj and not hasattr(next(iter(go2obj.values())), 'relationship') or not relationships: + if ( + go2obj + and not hasattr(next(iter(go2obj.values())), "relationship") + or not relationships + ): return get_go2parents_isa(go2obj) go2parents = {} for goid_main, goterm in go2obj.items(): @@ -21,10 +26,14 @@ def get_go2parents(go2obj, relationships): go2parents[goid_main] = parents_goids return go2parents -# ------------------------------------------------------------------------------------ + def get_go2children(go2obj, relationships): """Get set of children GO IDs, including children through user-specfied relationships""" - if go2obj and not hasattr(next(iter(go2obj.values())), 'relationship') or not relationships: + if ( + go2obj + and not hasattr(next(iter(go2obj.values())), "relationship") + or not relationships + ): return get_go2children_isa(go2obj) go2children = {} for goid_main, goterm in go2obj.items(): @@ -36,7 +45,7 @@ def get_go2children(go2obj, relationships): go2children[goid_main] = children_goids return go2children -# ------------------------------------------------------------------------------------ + def get_go2parents_isa(go2obj): """Get set of immediate parents GO IDs""" go2parents = {} @@ -46,7 +55,7 @@ def get_go2parents_isa(go2obj): go2parents[goid_main] = parents_goids return go2parents -# ------------------------------------------------------------------------------------ + def get_go2children_isa(go2obj): """Get set of immediate children GO IDs""" go2children = {} @@ -56,84 +65,96 @@ def get_go2children_isa(go2obj): go2children[goid_main] = children_goids return go2children -# ------------------------------------------------------------------------------------ + def get_go2ancestors(terms, relationships, prt=None): """Get GO-to- ancestors (all parents)""" if not relationships: if prt is not None: - prt.write('up: is_a\n') + prt.write("up: is_a\n") return get_id2parents(terms) if relationships == RELATIONSHIP_SET or relationships is True: if prt is not None: - prt.write('up: is_a and {Rs}\n'.format( - Rs=' '.join(sorted(RELATIONSHIP_SET)))) + prt.write( + "up: is_a and {Rs}\n".format(Rs=" ".join(sorted(RELATIONSHIP_SET))) + ) return get_id2upper(terms) if prt is not None: - prt.write('up: is_a and {Rs}\n'.format( - Rs=' '.join(sorted(relationships)))) + prt.write("up: is_a and {Rs}\n".format(Rs=" ".join(sorted(relationships)))) return get_id2upperselect(terms, relationships) + def get_go2descendants(terms, relationships, prt=None): """Get GO-to- descendants""" if not relationships: if prt is not None: - prt.write('down: is_a\n') + prt.write("down: is_a\n") return get_id2children(terms) if relationships == RELATIONSHIP_SET or relationships is True: if prt is not None: - prt.write('down: is_a and {Rs}\n'.format( - Rs=' '.join(sorted(RELATIONSHIP_SET)))) + prt.write( + "down: is_a and {Rs}\n".format(Rs=" ".join(sorted(RELATIONSHIP_SET))) + ) return get_id2lower(terms) if prt is not None: - prt.write('down: is_a and {Rs}\n'.format( - Rs=' '.join(sorted(relationships)))) + prt.write("down: is_a and {Rs}\n".format(Rs=" ".join(sorted(relationships)))) return get_id2lowerselect(terms, relationships) -# ------------------------------------------------------------------------------------ + def get_go2depth(goobjs, relationships): """Get depth of each object""" if not relationships: - return {o.item_id:o.depth for o in goobjs} + return {o.item_id: o.depth for o in goobjs} from goatools.godag.reldepth import get_go2reldepth + return get_go2reldepth(goobjs, relationships) -# ------------------------------------------------------------------------------------ + def get_id2parents(objs): """Get all parent IDs up the hierarchy""" id2parents = {} for obj in objs: _get_id2parents(id2parents, obj.item_id, obj) - return {e:es for e, es in id2parents.items() if es} + return {e: es for e, es in id2parents.items() if es} + def get_id2children(objs): """Get all child IDs down the hierarchy""" id2children = {} for obj in objs: _get_id2children(id2children, obj.item_id, obj) - return {e:es for e, es in id2children.items() if es} + return {e: es for e, es in id2children.items() if es} + def get_id2upper(objs): """Get all ancestor IDs, including all parents and IDs up all relationships""" id2upper = {} for obj in objs: _get_id2upper(id2upper, obj.item_id, obj) - return {e:es for e, es in id2upper.items() if es} + return {e: es for e, es in id2upper.items() if es} + def get_id2lower(objs): """Get all descendant IDs, including all children and IDs down all relationships""" id2lower = {} + cache = set() for obj in objs: - _get_id2lower(id2lower, obj.item_id, obj) - return {e:es for e, es in id2lower.items() if es} + item_id = obj.item_id + if item_id in cache: + continue + _get_id2lower(id2lower, obj.item_id, obj, cache) + return {e: es for e, es in id2lower.items() if es} + def get_id2upperselect(objs, relationship_set): """Get all ancestor IDs, including all parents and IDs up selected relationships""" return IdToUpperSelect(objs, relationship_set).id2upperselect + def get_id2lowerselect(objs, relationship_set): """Get all descendant IDs, including all children and IDs down selected relationships""" return IdToLowerSelect(objs, relationship_set).id2lowerselect + def get_relationship_targets(item_ids, relationships, id2rec): """Get item ID set of item IDs in a relationship target set""" # Requirements to use this function: @@ -148,7 +169,7 @@ def get_relationship_targets(item_ids, relationships, id2rec): reltgt_objs_all.update(reltgt_objs_cur) return reltgt_objs_all -# ------------------------------------------------------------------------------------ + # pylint: disable=too-few-public-methods class IdToUpperSelect: """Get all ancestor IDs, including all parents and IDs up selected relationships""" @@ -178,6 +199,7 @@ def _get_id2upperselect(self, item_id, item_obj): id2upperselect[item_id] = parent_ids return parent_ids + class IdToLowerSelect: """Get all descendant IDs, including all children and IDs down selected relationships""" @@ -206,7 +228,6 @@ def _get_id2lowerselect(self, item_id, item_obj): id2lowerselect[item_id] = child_ids return child_ids -# ------------------------------------------------------------------------------------ def _get_id2parents(id2parents, item_id, item_obj): """Add the parent item IDs for one item object and their parents.""" @@ -220,6 +241,7 @@ def _get_id2parents(id2parents, item_id, item_obj): id2parents[item_id] = parent_ids return parent_ids + def _get_id2children(id2children, item_id, item_obj): """Add the child item IDs for one item object and their children.""" if item_id in id2children: @@ -232,6 +254,7 @@ def _get_id2children(id2children, item_id, item_obj): id2children[item_id] = child_ids return child_ids + def _get_id2upper(id2upper, item_id, item_obj): """Add the parent item IDs for one item object and their upper.""" if item_id in id2upper: @@ -244,19 +267,23 @@ def _get_id2upper(id2upper, item_id, item_obj): id2upper[item_id] = upper_ids return upper_ids -def _get_id2lower(id2lower, item_id, item_obj): + +def _get_id2lower(id2lower, item_id, item_obj, cache: set): """Add the lower item IDs for one item object and the objects below them.""" if item_id in id2lower: return id2lower[item_id] lower_ids = set() + cache.add(item_id) for lower_obj in item_obj.get_goterms_lower(): lower_id = lower_obj.item_id lower_ids.add(lower_id) - lower_ids |= _get_id2lower(id2lower, lower_id, lower_obj) + if lower_id in cache: + continue + lower_ids |= _get_id2lower(id2lower, lower_id, lower_obj, cache) id2lower[item_id] = lower_ids return lower_ids -# ------------------------------------------------------------------------------------ + class CurNHigher: """Fill id2obj with item IDs in relationships.""" diff --git a/goatools/nt_utils.py b/goatools/nt_utils.py index 9cfe332..e25045a 100644 --- a/goatools/nt_utils.py +++ b/goatools/nt_utils.py @@ -7,11 +7,13 @@ import datetime import collections as cx + def get_dict_w_id2nts(ids, id2nts, flds, dflt_null=""): """Return a new dict of namedtuples by combining "dicts" of namedtuples or objects.""" assert len(ids) == len(set(ids)), "NOT ALL IDs ARE UNIQUE: {IDs}".format(IDs=ids) assert len(flds) == len(set(flds)), "DUPLICATE FIELDS: {IDs}".format( - IDs=cx.Counter(flds).most_common()) + IDs=cx.Counter(flds).most_common() + ) usr_id_nt = [] # 1. Instantiate namedtuple object ntobj = cx.namedtuple("Nt", " ".join(flds)) @@ -23,6 +25,7 @@ def get_dict_w_id2nts(ids, id2nts, flds, dflt_null=""): usr_id_nt.append((item_id, ntobj._make(vals))) return cx.OrderedDict(usr_id_nt) + def get_list_w_id2nts(ids, id2nts, flds, dflt_null=""): """Return a new list of namedtuples by combining "dicts" of namedtuples or objects.""" combined_nt_list = [] @@ -36,41 +39,53 @@ def get_list_w_id2nts(ids, id2nts, flds, dflt_null=""): combined_nt_list.append(ntobj._make(vals)) return combined_nt_list + def combine_nt_lists(lists, flds, dflt_null=""): """Return a new list of namedtuples by zipping "lists" of namedtuples or objects.""" combined_nt_list = [] # Check that all lists are the same length lens = [len(lst) for lst in lists] - assert len(set(lens)) == 1, \ - "LIST LENGTHS MUST BE EQUAL: {Ls}".format(Ls=" ".join(str(l) for l in lens)) + assert len(set(lens)) == 1, "LIST LENGTHS MUST BE EQUAL: {Ls}".format( + Ls=" ".join(str(l) for l in lens) + ) # 1. Instantiate namedtuple object ntobj = cx.namedtuple("Nt", " ".join(flds)) # 2. Loop through zipped list for lst0_lstn in zip(*lists): # 2a. Combine various namedtuples into a single namedtuple - combined_nt_list.append(ntobj._make(_combine_nt_vals(lst0_lstn, flds, dflt_null))) + combined_nt_list.append( + ntobj._make(_combine_nt_vals(lst0_lstn, flds, dflt_null)) + ) return combined_nt_list + def wr_py_nts(fout_py, nts, docstring=None, varname="nts"): """Save namedtuples into a Python module.""" if nts: - with open(fout_py, 'w') as prt: + with open(fout_py, "w") as prt: prt.write('"""{DOCSTRING}"""\n\n'.format(DOCSTRING=docstring)) prt.write("# Created: {DATE}\n".format(DATE=str(datetime.date.today()))) prt_nts(prt, nts, varname) - sys.stdout.write(" {N:7,} items WROTE: {PY}\n".format(N=len(nts), PY=fout_py)) + sys.stdout.write( + " {N:7,} items WROTE: {PY}\n".format(N=len(nts), PY=fout_py) + ) -def prt_nts(prt, nts, varname, spc=' '): + +def prt_nts(prt, nts, varname, spc=" "): """Print namedtuples into a Python module.""" first_nt = nts[0] nt_name = type(first_nt).__name__ prt.write("import collections as cx\n\n") + prt.write("import numpy as np\n\n") prt.write("NT_FIELDS = [\n") for fld in first_nt._fields: prt.write('{SPC}"{F}",\n'.format(SPC=spc, F=fld)) prt.write("]\n\n") - prt.write('{NtName} = cx.namedtuple("{NtName}", " ".join(NT_FIELDS))\n\n'.format( - NtName=nt_name)) + prt.write( + '{NtName} = cx.namedtuple("{NtName}", " ".join(NT_FIELDS))\n\n'.format( + NtName=nt_name + ) + ) prt.write("# {N:,} items\n".format(N=len(nts))) prt.write("# pylint: disable=line-too-long\n") prt.write("{VARNAME} = [\n".format(VARNAME=varname)) @@ -78,6 +93,7 @@ def prt_nts(prt, nts, varname, spc=' '): prt.write("{SPC}{NT},\n".format(SPC=spc, NT=ntup)) prt.write("]\n") + def get_unique_fields(fld_lists): """Get unique namedtuple fields, despite potential duplicates in lists of fields.""" flds = [] @@ -93,6 +109,7 @@ def get_unique_fields(fld_lists): assert len(flds) == len(fld_set) return flds + # -- Internal methods ---------------------------------------------------------------- def _combine_nt_vals(lst0_lstn, flds, dflt_null): """Given a list of lists of nts, return a single namedtuple.""" @@ -110,4 +127,5 @@ def _combine_nt_vals(lst0_lstn, flds, dflt_null): vals.append(dflt_null) return vals + # Copyright (C) 2016-2018, DV Klopfenstein, H Tang. All rights reserved. diff --git a/tests/test_dcnt_r01.py b/tests/test_dcnt_r01.py index 95cd428..ff247ba 100755 --- a/tests/test_dcnt_r01.py +++ b/tests/test_dcnt_r01.py @@ -5,6 +5,8 @@ import sys import timeit import numpy as np +import pytest + from numpy.random import shuffle from scipy import stats @@ -14,6 +16,7 @@ from goatools.obo_parser import GODag +@pytest.mark.skip(reason="Latest obo (`releases/2024-06-10`) is not DAG") def test_go_pools(): """Print a comparison of GO terms from different species in two different comparisons.""" objr = _Run()