diff --git a/dingo/illustrations.py b/dingo/illustrations.py index 0e9a120f..fd37b651 100644 --- a/dingo/illustrations.py +++ b/dingo/illustrations.py @@ -11,6 +11,8 @@ import plotly.io as pio import plotly.express as px from dingo.utils import compute_copula +import plotly.figure_factory as ff +from scipy.cluster import hierarchy def plot_copula(data_flux1, data_flux2, n = 5, width = 900 , height = 600, export_format = "svg"): """A Python function to plot the copula between two fluxes @@ -129,3 +131,77 @@ def plot_corr_matrix(corr_matrix, reactions, removed_reactions=[], format="svg") fig_name = "CorrelationMatrix." + format pio.write_image(fig, fig_name, scale=2) + + + +def plot_dendrogram(dissimilarity_matrix, reactions , plot_labels=False, t=2.0, linkage="ward"): + """A Python function to plot the dendrogram of a dissimilarity matrix. + + Keyword arguments: + dissimilarity_matrix -- A matrix produced from the "cluster_corr_reactions" function + reactions -- A list with the model's reactions + plot_labels -- A boolean variable that if True plots the reactions labels in the dendrogram + t -- A threshold that defines a threshold that cuts the dendrogram + at a specific height and colors the occuring clusters accordingly + linkage -- linkage defines the type of linkage. + Available linkage types are: single, average, complete, ward. + """ + + fig = ff.create_dendrogram(dissimilarity_matrix, + labels=reactions, + linkagefun=lambda x: hierarchy.linkage(x, linkage), + color_threshold=t) + fig.update_layout(width=800, height=800) + + if plot_labels == False: + fig.update_layout( + xaxis=dict( + showticklabels=False, + ticks="") ) + else: + fig.update_layout( + xaxis=dict( + title_font=dict(size=10), + tickfont=dict(size=8) ), + yaxis=dict( + title_font=dict(size=10), + tickfont=dict(size=8) ) ) + + fig.show() + + + +def plot_graph(G, pos): + """A Python function to plot a graph created from a correlation matrix. + + Keyword arguments: + G -- A graph produced from the "graph_corr_matrix" function. + pos -- A layout for the corresponding graph. + """ + + fig = go.Figure() + + for u, v, data in G.edges(data=True): + x0, y0 = pos[u] + x1, y1 = pos[v] + + edge_color = 'blue' if data['weight'] > 0 else 'red' + + fig.add_trace(go.Scatter(x=[x0, x1], y=[y0, y1], mode='lines', + line=dict(width=abs(data['weight']) * 1, + color=edge_color), hoverinfo='none', + showlegend=False)) + + for node in G.nodes(): + x, y = pos[node] + node_name = G.nodes[node].get('name', f'Node {node}') + + fig.add_trace(go.Scatter(x=[x], y=[y], mode='markers', + marker=dict(size=10), + text=[node_name], + textposition='top center', + name = node_name, + showlegend=False)) + + fig.update_layout(width=800, height=800) + fig.show() \ No newline at end of file diff --git a/dingo/utils.py b/dingo/utils.py index 3fb0888c..303840cb 100644 --- a/dingo/utils.py +++ b/dingo/utils.py @@ -11,6 +11,9 @@ from scipy.sparse import diags from dingo.scaling import gmscale from dingo.nullspace import nullspace_dense, nullspace_sparse +from scipy.cluster import hierarchy +from networkx.algorithms.components import connected_components +import networkx as nx def compute_copula(flux1, flux2, n): """A Python function to estimate the copula between two fluxes @@ -340,4 +343,96 @@ def correlated_reactions(steady_states, reactions=[], pearson_cutoff = 0.90, ind else: np.fill_diagonal(filtered_corr_matrix, 1) return filtered_corr_matrix, indicator_dict - \ No newline at end of file + + + +def cluster_corr_reactions(correlation_matrix, reactions, linkage="ward", + t = 4.0, correction=True): + """A Python function for hierarchical clustering of the correlation matrix + + Keyword arguments: + correlation_matrix -- A numpy 2D array of a correlation matrix + reactions -- A list with the model's reactions + linkage -- linkage defines the type of linkage. + Available linkage types are: single, average, complete, ward. + t -- A threshold that defines a threshold that cuts the dendrogram + at a specific height and produces clusters + correction -- A boolean variable that if True converts the values of the + the correlation matrix to absolute values. + """ + + # function to return a nested list with grouped reactions based on clustering + def clusters_list(reactions, labels): + clusters = [] + unique_labels = np.unique(labels) + for label in unique_labels: + cluster = [] + label_where = np.where(labels == label)[0] + for where in label_where: + cluster.append(reactions[where]) + clusters.append(cluster) + return clusters + + if correction == True: + dissimilarity_matrix = 1 - abs(correlation_matrix) + else: + dissimilarity_matrix = 1 - correlation_matrix + + Z = hierarchy.linkage(dissimilarity_matrix, linkage) + labels = hierarchy.fcluster(Z, t, criterion='distance') + + clusters = clusters_list(reactions, labels) + return dissimilarity_matrix, labels, clusters + + + +def graph_corr_matrix(correlation_matrix, reactions, correction=True, + clusters=[], subgraph_nodes = 5): + """A Python function that creates the main graph and its subgraphs + from a correlation matrix. + + Keyword arguments: + correlation_matrix -- A numpy 2D array of a correlation matrix. + reactions -- A list with the model's reactions. + correction -- A boolean variable that if True converts the values of the + the correlation matrix to absolute values. + clusters -- A nested list with clustered reactions created from the "" function. + subgraph_nodes -- A variable that specifies a cutoff for a graph's nodes. + It filters subgraphs with low number of nodes.. + """ + + graph_matrix = correlation_matrix.copy() + np.fill_diagonal(graph_matrix, 0) + + if correction == True: + graph_matrix = abs(graph_matrix) + + G = nx.from_numpy_array(graph_matrix) + G = nx.relabel_nodes(G, lambda x: reactions[x]) + + pos = nx.spring_layout(G) + unconnected_nodes = list(nx.isolates(G)) + G.remove_nodes_from(unconnected_nodes) + G_nodes = G.nodes() + + graph_list = [] + layout_list = [] + + graph_list.append(G) + layout_list.append(pos) + + subgraphs = [G.subgraph(c) for c in connected_components(G)] + H_nodes_list = [] + + for i in range(len(subgraphs)): + if len(subgraphs[i].nodes()) > subgraph_nodes and len(subgraphs[i].nodes()) != len(G_nodes): + H = G.subgraph(subgraphs[i].nodes()) + for cluster in clusters: + if H.has_node(cluster[0]) and H.nodes() not in H_nodes_list: + H_nodes_list.append(H.nodes()) + + pos = nx.spring_layout(H) + graph_list.append(H) + layout_list.append(pos) + + return graph_list, layout_list \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 683e537b..635b3761 100644 --- a/poetry.lock +++ b/poetry.lock @@ -852,6 +852,25 @@ docs = ["sphinx"] gmpy = ["gmpy2 (>=2.1.0a4)"] tests = ["pytest (>=4.6)"] +[[package]] +name = "networkx" +version = "3.1" +description = "Python package for creating and manipulating graphs and networks" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"}, + {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"}, +] + +[package.extras] +default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"] +developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"] +doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] +test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] + [[package]] name = "numpy" version = "1.23.5" @@ -1749,4 +1768,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "39d92730e306a8be6d5c7e13d037893c020efae8f2669fe94d74d2a04831fafb" +content-hash = "1f3ee3d9ab943dfbe26cc9e3a421396ea1022552386d419e53c55ac614c5f1a5" diff --git a/pyproject.toml b/pyproject.toml index e9584d06..d0cdf087 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ cobra = "^0.26.0" plotly = "^5.11.0" kaleido = "0.2.1" pyoptinterface = {version = "^0.2.7", extras = ["highs"]} +networkx = "3.1" [tool.poetry.dev-dependencies]