Skip to content

Commit

Permalink
update node naming to including element symbol
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-j committed Apr 19, 2022
1 parent 7f25b92 commit 7a0d1a5
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 7 deletions.
4 changes: 2 additions & 2 deletions graphein/molecule/edges/atomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def add_atom_bonds(G: nx.Graph) -> nx.Graph:
G.graph["rdmol"].GetAtoms()[n1].GetSymbol(),
G.graph["rdmol"].GetAtoms()[n2].GetSymbol(),
)
n1 = sym1 + str(n1)
n2 = sym2 + str(n2)
n1 = f"{sym1}:{str(n1)}"
n2 = f"{sym2}:{str(n2)}"
if G.has_edge(n1, n2):
G.edges[n1, n2]["kind"].add("bond")
G.edges[n1, n2]["bond"] = bond
Expand Down
9 changes: 6 additions & 3 deletions graphein/molecule/edges/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def add_distance_threshold(G: nx.Graph, threshold: float = 5.0):

dist_mat = compute_distmat(G.graph["coords"])
interacting_nodes = get_interacting_atoms(threshold, distmat=dist_mat)
interacting_nodes = list(zip(interacting_nodes[0], interacting_nodes[1]))
outgoing = [list(G.nodes())[i] for i in interacting_nodes[0]]
incoming = [list(G.nodes())[i] for i in interacting_nodes[1]]
interacting_nodes = list(zip(outgoing, incoming))

log.info(
f"Found: {len(interacting_nodes)} distance edges for radius {threshold}"
Expand All @@ -84,7 +86,7 @@ def add_fully_connected_edges(
"""
length = len(G.graph["coords"])

for n1, n2 in itertools.product(range(length), range(length)):
for n1, n2 in itertools.product(G.nodes(), G.nodes()):
if G.has_edge(n1, n2):
G.edges[n1, n2]["kind"].add("fully_connected")
else:
Expand Down Expand Up @@ -135,7 +137,8 @@ def add_k_nn_edges(

# Create iterable of node indices
outgoing = np.repeat(np.array(range(len(G.graph["coords"]))), k)
incoming = nn.indices
outgoing = [list(G.nodes())[i] for i in outgoing]
incoming = [list(G.nodes())[i] for i in nn.indices]
interacting_nodes = list(zip(outgoing, incoming))
log.info(f"Found: {len(interacting_nodes)} KNN edges")
for n1, n2 in interacting_nodes:
Expand Down
2 changes: 1 addition & 1 deletion graphein/molecule/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def add_nodes_to_graph(
G.graph["coords"][i] if G.graph["coords"] is not None else None
)
G.add_node(
atom.GetSymbol() + str(atom.GetIdx()),
f"{atom.GetSymbol()}:{str(atom.GetIdx())}",
atomic_num=atom.GetAtomicNum(),
element=atom.GetSymbol(),
rdmol_atom=atom,
Expand Down
2 changes: 1 addition & 1 deletion tests/molecule/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_generate_graph_sdf():
for n, d in g.nodes(data=True):
assert isinstance(
d["atomic_num"], int
), f"{n} atomic_num is not an int"
), f"{n} atomic_num {d['atomic_num']} is not an int"
assert isinstance(d["element"], str), f"{n} element is not a string"
assert isinstance(
d["rdmol_atom"], rdkit.Chem.rdchem.Atom
Expand Down

0 comments on commit 7a0d1a5

Please sign in to comment.