Skip to content

Commit

Permalink
isolate plot module
Browse files Browse the repository at this point in the history
  • Loading branch information
haozhu233 committed Apr 16, 2024
1 parent 061b0a8 commit 9fd86ca
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 117 deletions.
Binary file removed .DS_Store
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ mid-range 12-core CPU.

Citation
--------
Please consider cite our work if you found it useful for your work:
Please consider cite our work if you found it useful for your work.



Expand Down
261 changes: 145 additions & 116 deletions regdiffusion/grn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from collections import deque, Counter
from scipy.sparse import csr_matrix
import h5py
import pyvis
from typing import List, Dict, Union
import concurrent.futures
from .plot import plot_pyvis

class GRN:
"""
Expand Down Expand Up @@ -74,26 +75,40 @@ def __init__(self, adj_matrix: np.ndarray,
adj_matrix[np.abs(adj_matrix) < self.cutoff_threshold] = 0

self.adj_matrix = adj_matrix
# TODO: make it tf safe.
if self.n_tfs == self.n_genes:
self.adj_matrix_2way = np.concatenate([
adj_matrix, adj_matrix.transpose()
], axis=0)
else:
self.adj_matrix_2way = None

self.gene_indices = {g: i for i, g in enumerate(gene_names)}
self.tf_indices = {g: i for i, g in enumerate(tf_names)}

def generate_adj_list(self) -> pd.DataFrame:
self.calculated_neighbors = {}

def get_edgelist(self, k: int = 20, workers: int = 2) -> pd.DataFrame:
"""
Simply generate a dataframe to hold the adjacency list.
Simply generate a dataframe to hold the edge list.
The dataframe will have 3 columns: `source`, `target`, `weight`.
Args:
k (int): Top-k edges to inspect on each node. If k=-1, export all.
workers (int): Number of concurrent workers. Default is 2.
"""
all_edges = []
for r in tqdm(range(self.n_tfs)):
for c in range(self.n_genes):
if self.adj_matrix[r, c] != 0:
all_edges.append({
'source': self.tf_names[r],
'target': self.gene_names[c],
'weight': self.adj_matrix[r, c]
})
return pd.DataFrame(all_edges)
with concurrent.futures.ThreadPoolExecutor(
max_workers=workers) as executor:
futures = [
executor.submit(
self.extract_node_neighbors, g, k
) for g in list(self.gene_names)]

all_edges = [
future.result() for future in
concurrent.futures.as_completed(futures)]
return pd.concat(all_edges).reset_index(drop=True)


def extract_node_sources_as_indices(self, gene: str, k: int = 20) -> Dict:
Expand All @@ -103,8 +118,7 @@ def extract_node_sources_as_indices(self, gene: str, k: int = 20) -> Dict:
Args:
genes (str, List(str)): A single gene or a list of genes to inspect.
k (int): Top-k edges to inspect on each node.
node_size (int): The size of nodes in the visualization.
k (int): Top-k edges to inspect on each node. If k=-1, export all
"""
gene_idx = self.gene_indices[gene]
node_neighbors = self.adj_matrix[:, gene_idx]
Expand All @@ -129,8 +143,7 @@ def extract_node_sources(self, gene: str, k: int = 20) -> pd.DataFrame:
Args:
genes (str, List(str)): A single gene or a list of genes to inspect.
k (int): Top-k edges to inspect on each node.
node_size (int): The size of nodes in the visualization.
k (int): Top-k edges to inspect on each node. If k=-1, export all
"""
source_indices = self.extract_node_sources_as_indices(gene, k)
top_gene_names = [
Expand All @@ -150,8 +163,7 @@ def extract_node_targets_as_indices(self, gene: str, k: int = 20) -> Dict:
Args:
genes (str, List(str)): A single gene or a list of genes to inspect.
k (int): Top-k edges to inspect on each node.
node_size (int): The size of nodes in the visualization.
k (int): Top-k edges to inspect on each node. If k=-1, export all
"""
gene_idx = self.tf_indices[gene]
node_neighbors = self.adj_matrix[gene_idx, :]
Expand All @@ -176,8 +188,7 @@ def extract_node_targets(self, gene: str, k: int = 20) -> pd.DataFrame:
Args:
genes (str, List(str)): A single gene or a list of genes to inspect.
k (int): Top-k edges to inspect on each node.
node_size (int): The size of nodes in the visualization.
k (int): Top-k edges to inspect on each node. If k=-1, export all
"""
target_indices = self.extract_node_targets_as_indices(gene, k)
top_gene_names = [
Expand All @@ -197,27 +208,29 @@ def extract_node_neighbors_as_indices(self, gene: str, k: int = 20) -> Dict:
Args:
genes (str, List(str)): A single gene or a list of genes to inspect.
k (int): Top-k edges to inspect on each node.
node_size (int): The size of nodes in the visualization.
k (int): Top-k edges to inspect on each node. If k=-1, export all
"""
sources = self.extract_node_sources_as_indices(gene, k)
targets = self.extract_node_targets_as_indices(gene, k)
all_weights = np.abs(np.concatenate([
sources['weights'], targets['weights']
]))
cutoff = np.partition(all_weights, -k)[-k]
source_crit = (sources['weights'] >= cutoff)
target_crit = (targets['weights'] >= cutoff)
gene_idx = self.gene_indices[gene]
if (gene_idx, k) in self.calculated_neighbors:
return self.calculated_neighbors[(gene_idx, k)]
node_neighbors = self.adj_matrix_2way[:, gene_idx]
node_neighbors_abs = np.abs(node_neighbors)
if k!=-1:
top_indices = np.argpartition(node_neighbors_abs, -k)[-k:]
else:
top_indices = np.where(node_neighbors_abs!=0)[0]
top_edge_weights = node_neighbors[top_indices]
top_source_indices = top_indices[top_indices < self.n_tfs]
top_source_weights = top_edge_weights[top_indices < self.n_tfs]
top_target_indices = top_indices[top_indices >= self.n_tfs] - self.n_tfs
top_target_weights = top_edge_weights[top_indices >= self.n_tfs]
output = {
'sources': {
'tf_indices': sources['tf_indices'][source_crit],
'weights': sources['weights'][source_crit]
},
'targets': {
'gene_indices': targets['gene_indices'][target_crit],
'weights': targets['weights'][target_crit]
}
'source_indices': top_source_indices,
'source_weights': top_source_weights,
'target_indices': top_target_indices,
'target_weights': top_target_weights
}
self.calculated_neighbors[(gene_idx, k)] = output
return output

def extract_node_neighbors(self, gene: str, k: int = 20) -> pd.DataFrame:
Expand All @@ -229,114 +242,130 @@ def extract_node_neighbors(self, gene: str, k: int = 20) -> pd.DataFrame:
Args:
genes (str, List(str)): A single gene or a list of genes to inspect.
k (int): Top-k edges to inspect on each node.
node_size (int): The size of nodes in the visualization.
k (int): Top-k edges to inspect on each node. If k=-1, export all
"""
top_sources = self.extract_node_sources(gene, k)
top_targets = self.extract_node_targets(gene, k)

output = pd.concat([top_sources, top_targets])
output['abs_weight'] = output.weight.abs()
output = output.sort_values(
'abs_weight', ascending=False
).head(k).reset_index(drop=True)
del output['abs_weight']
return output
neighbor_indices = self.extract_node_neighbors_as_indices(gene, k)
source_gene_names = [
self.tf_names[i] for i in neighbor_indices['source_indices']
]
source_tbl = pd.DataFrame({
'source': source_gene_names,
'target': gene,
'weight': neighbor_indices['source_weights']
})
target_gene_names = [
self.gene_names[i] for i in neighbor_indices['target_indices']
]
target_tbl = pd.DataFrame({
'source': gene,
'target': target_gene_names,
'weight': neighbor_indices['target_weights']
})
return pd.concat([source_tbl, target_tbl])

def extract_node_2hop_neighborhood(
self, genes: Union[str, List[str]], k: int = 20
def extract_local_neighborhood(
self, genes: Union[str, List[str]], k: int = 20, hops: str = "2.5"
) -> pd.DataFrame:
"""
Generate a pandas dataframe for the local neighborhood (2-hop) around
selected gene(s).
Generate a pandas dataframe for the 2.5 or 1.5 hop local neighborhood
around selected gene(s). "2.5 hop local neighborhood" includes all the
nodes and edges reachable by a 2-hop search from the selected genes and
the edges connecting all the 2-hop nodes. "1.5 hop local neighborhood"
is defined in a similar way but smaller.
Args:
genes (str, List(str)): A single gene or a list of genes to inspect.
k (int): Top-k edges to inspect on each node.
node_size (int): The size of nodes in the visualization.
k (int): Top-k edges to inspect on each node. If k=-1, export all
hops (str): Number of hops to explore. We can either do a "2.5" or
"1.5" hop travesal around selected genes. Default is "2.5".
"""
if isinstance(genes, str):
genes = [genes]
hop0_genes = set(genes)

# Hop 1
hop1 = pd.concat([
self.extract_node_neighbors(g, k=k) for g in hop0_genes
])
hop1['weight'] = 0
hop1['hop'] = 0
hop1_genes = set()
for g in hop1.source:
if g not in hop0_genes:
hop1_genes.add(g)
for g in hop1.target:
if g not in hop0_genes:
hop1_genes.add(g)
hop2s = pd.concat([
self.extract_node_neighbors(g, k=k) for g in hop1_genes
])
hop2s['weight'] = 1
hop2_genes = set()
for g in hop2s.source:
if g not in hop0_genes and g not in hop1_genes:
hop2_genes.add(g)
for g in hop2s.target:
if g not in hop0_genes and g not in hop1_genes:
hop2_genes.add(g)
hop3s = []
for g in hop2_genes:
hop3 = self.extract_node_neighbors(g, k=k)
hop3 = hop3[
hop3.source.isin(hop2_genes) & hop3.target.isin(hop2_genes)
]
hop3s.append(hop3)
hop3s = pd.concat(hop3s)
hop3s['weight'] = 2
adj_table = pd.concat([hop1, hop2s, hop3s]).reset_index(drop=True)
return adj_table

# Hop 2
if hops == "1.5":
hop2s = []
for g in hop1_genes:
hop2 = self.extract_node_neighbors(g, k=k)
hop2 = hop2[
hop2.source.isin(hop1_genes) & hop2.target.isin(hop1_genes)
]
hop2s.append(hop2)
hop2s = pd.concat(hop2s)
hop2s['hop'] = 1
adj_table = pd.concat([hop1, hop2s]).reset_index(drop=True)
return adj_table
elif hops == "2.5":
hop2s = pd.concat([
self.extract_node_neighbors(g, k=k) for g in hop1_genes
])
hop2s['hop'] = 1
hop2_genes = set()
for g in hop2s.source:
if g not in hop0_genes and g not in hop1_genes:
hop2_genes.add(g)
for g in hop2s.target:
if g not in hop0_genes and g not in hop1_genes:
hop2_genes.add(g)

# Hop 2.5
hop3s = []
for g in hop2_genes:
hop3 = self.extract_node_neighbors(g, k=k)
hop3 = hop3[
hop3.source.isin(hop2_genes) & hop3.target.isin(hop2_genes)
]
hop3s.append(hop3)
hop3s = pd.concat(hop3s)
hop3s['hop'] = 2
adj_table = pd.concat([hop1, hop2s, hop3s]).reset_index(drop=True)
return adj_table

def visualize_local_neighborhood(
self, genes: Union[str, List[str]], k: int = 20,
node_size: int = 8, edge_widths: List[int] = [2, 1, 0.5],
font_size: int = 30,
node_group_dict: Dict = None,
cdn_resources: str = 'remote', notebook: bool = True):
self, genes: Union[str, List[str]], k: int = 20, hops: str = "2.5",
edge_widths: List[int] = [2, 1, 0.5],
plot_engine: str = 'pyvis', *args, **kwargs):
"""
Generate a vis.js network visualization of the local neighborhood
(2-hop) around selected gene(s).
Args:
genes (str, List(str)): A single gene or a list of genes to inspect.
k (int): Top-k edges to inspect on each node.
node_size (int): The size of nodes in the visualization.
edge_widths (List): The widths for edges (1st, 2nd, between 2nd
hops).
font_size (int): The font size for nodes labels.
node_group_dict (dict): A dictionary with keys being the names of
genes and values being the groups. Genes from the same group
will be colored using the same color.
cdn_resources (str): Where to load vis.js resources. Default is
'remote'.
notebook (bool): Boolean value indicating whether the visualization
happens in a jupyter notebook.
k (int): Top-k edges to inspect on each node. If k=-1, export all.
hops (str): Number of hops of the neighborhood to explore. Default
is "2.5".
edge_widths (List): The widths for edges for different edge width
levels.
plot_engine (str): Choose which network plot engine to use. Default
is "pyvis".
**kwargs: Keyword arguments to be passed to ``plot_pyvis``.
"""
local_adj_table = self.extract_node_2hop_neighborhood(genes, k)
local_adj_table.weight = local_adj_table.weight.map(
if isinstance(genes, str):
genes = [genes]
local_adj_table = self.extract_local_neighborhood(genes, k, hops)
local_adj_table['edge_width'] = local_adj_table.hop.map(
lambda x: edge_widths[x])

g = pyvis.network.Network(
cdn_resources=cdn_resources,
notebook=notebook
)

for node in set(local_adj_table['source']) | set(local_adj_table['target']):
node_shape = 'star' if node in genes else 'dot'
node_group = None if node_group_dict is None else node_group_dict[node]
g.add_node(node, label=node, size=node_size,
shape=node_shape, group=node_group,
font={"size": font_size})

for _, row in local_adj_table.iterrows():
g.add_edge(row['source'], row['target'], width=row['weight'])

g.repulsion()

if plot_engine == 'pyvis':
g = plot_pyvis(
pandas_edgelist = local_adj_table,
star_genes = genes, *args, **kwargs)
else:
raise Exception("Not implemented yet")
return g

def to_hdf5(self, file_path: str, as_sparse: bool = False):
Expand Down
1 change: 1 addition & 0 deletions regdiffusion/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .pyvis import plot_pyvis
Loading

0 comments on commit 9fd86ca

Please sign in to comment.