
Format inconsistency between graphs constructed by pdb_code and pdb_path? #142

Closed

johnnytam100 opened this issue Mar 19, 2022 · 6 comments

johnnytam100 commented Mar 19, 2022

Hi Arian! It seems there is a format inconsistency between graphs constructed via pdb_code and pdb_path:

from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph
from graphein.protein.visualisation import plot_protein_structure_graph
from graphein.protein.edges.distance import (add_peptide_bonds,
                                             add_hydrogen_bond_interactions,
                                             add_disulfide_interactions,
                                             add_ionic_interactions,
                                             add_aromatic_interactions,
                                             add_aromatic_sulphur_interactions,
                                             add_cation_pi_interactions
                                            )
from graphein.protein.features.sequence.embeddings import esm_sequence_embedding
import pickle
import matplotlib.pyplot as plt
import os
import glob
from graphein.protein.subgraphs import extract_subgraph_from_chains


new_funcs = {"edge_construction_functions": [add_peptide_bonds,
                                              add_hydrogen_bond_interactions,
                                              add_disulfide_interactions,
                                              add_ionic_interactions,
                                              add_aromatic_interactions,
                                              add_aromatic_sulphur_interactions,
                                              add_cation_pi_interactions],
              # "graph_metadata_functions": [esm_sequence_embedding]
            }

config = ProteinGraphConfig(**new_funcs)

# construct graph
g_pdbcode = construct_graph(config=config, pdb_code="1ema")
g_pdbpath = construct_graph(config=config, pdb_path="1ema.pdb")

# same?
print(g_pdbcode == g_pdbcode)
print(g_pdbpath == g_pdbpath)
print(g_pdbcode == g_pdbpath)
DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe
DEBUG:graphein.protein.graphs:Detected 221 total nodes
INFO:graphein.protein.edges.distance:Found 78 hbond interactions.
INFO:graphein.protein.edges.distance:Found 2 hbond interactions.
INFO:graphein.protein.edges.distance:Found 2 disulfide interactions.
INFO:graphein.protein.edges.distance:Found 1645 ionic interactions.
INFO:graphein.protein.edges.distance:Found: 36 aromatic-aromatic interactions
DEBUG:graphein.protein.graphs:Deprotonating protein. This removes H atoms from the pdb_df dataframe
DEBUG:graphein.protein.graphs:Detected 221 total nodes
INFO:graphein.protein.edges.distance:Found 78 hbond interactions.
INFO:graphein.protein.edges.distance:Found 2 hbond interactions.
INFO:graphein.protein.edges.distance:Found 2 disulfide interactions.
INFO:graphein.protein.edges.distance:Found 1645 ionic interactions.
INFO:graphein.protein.edges.distance:Found: 36 aromatic-aromatic interactions
True
True
False
a-r-j (Owner) commented Mar 19, 2022

Hi @johnnytam100 I don't believe this is a bug. Checking for equality with == isn't supported by NetworkX graphs: since no __eq__ is defined, == falls back to object identity, and the two graphs are different Python objects. We can see this with the following example:

class Test:
    val = 1

a = Test()
b = Test()
a == b
# Output: False

However, we can define an __eq__() method to determine equality. For example:

class Test:
    def __init__(self):
        self.val = 1

    def __eq__(self, other):
        return self.val == other.val

a = Test()
b = Test()
a == b
# Output: True

We can see this is the case here when we make two copies of the 'same' graph:

g_pdbcode_1 = construct_graph(config=config, pdb_code="1ema")
g_pdbcode_2 = construct_graph(config=config, pdb_code="1ema")

g_pdbcode_1 == g_pdbcode_2
# Output: False

Instead, we should check whether the graphs themselves are the same, rather than whether the Python objects are identical.

You can do this with nx.is_isomorphic():

import networkx as nx

nx.is_isomorphic(g_pdbcode, g_pdbpath)
# Output: True

(I tried this on the example you provided.)

nx.is_isomorphic() also lets you supply custom node_match and edge_match functions to decide whether node and edge attributes are equal between the two graphs.

A fuller and more robust test (which checks node and edge attributes) would be:

import numpy as np

def equal_dictionaries(dic1, dic2):
    # Compare all keys and values; np.array_equal also handles array-valued attributes.
    if dic1.keys() != dic2.keys():
        return False
    return all(np.array_equal(dic1[key], dic2[key]) for key in dic1)

nx.is_isomorphic(g_pdbcode, g_pdbpath, node_match=equal_dictionaries, edge_match=equal_dictionaries)

I can add this to Graphein as I think it would be a useful feature.

johnnytam100 (Author) commented Mar 20, 2022

Arian, sorry for the silly test using ==...

From your tests, I see you've confirmed that the graph built from pdb_code is identical to the one built from pdb_path.

I raised this question because when I constructed graphs from pdb_code and fed them into the machine learning example you provided, everything worked.

However, when I used graphs constructed from pdb_path, a TypeError appeared (possibly something expected a tensor but was given a list):

import pickle
import networkx as nx
import os
import glob
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from graphein.ml.conversion import GraphFormatConvertor
from tqdm.notebook import tqdm
import numpy as np
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, global_add_pool
from torch.nn.functional import mse_loss, nll_loss, relu, softmax, cross_entropy
from torch.nn import functional as F
from torchmetrics.functional import accuracy
from pytorch_lightning.callbacks import ModelCheckpoint
import os
import random
import pytorch_lightning as pl
import torch.nn as nn
import pandas as pd

# Load
fp_df = pd.read_csv("./target_mapping.csv")

# Label col
label_col = "states.0.em_max"

# MAE loss
mae_loss = nn.L1Loss()

# collect graphs
path_list = []
graph_list = []

for path in glob.iglob('./selected_grep-ATOM/*.p'):
  path_list.append(path)

path_list.sort()

for path in path_list:
    with open(path, 'rb') as f:  # notice the r instead of w
        graph = pickle.load(f)
        graph_list.append(graph)

# nx2pyg
format_convertor = GraphFormatConvertor('nx', 'pyg',
                                        verbose = 'gnn',
                                        columns = None)

pyg_list = [format_convertor(graph) for graph in tqdm(graph_list)]

# assign target
for (idx, g), p in zip(enumerate(pyg_list), path_list):

    if not fp_df.loc[fp_df["uuid"]==p[21:26]][label_col].isnull().values[0]:

        # g.y = y_list[idx]             # original
        g.y = int(fp_df.loc[fp_df["uuid"]==p[21:26]][label_col].values[0])  # regression
        g.coords = torch.FloatTensor(g.coords[0])

# other formatting (?)
for i in pyg_list:
    if i.coords[0].shape[0] == len(i.node_id):
        pass
    else:
        print(i)
        pyg_list.remove(i)

# train, val, test split
np.random.seed(42)
idx_all = np.arange(len(pyg_list))
np.random.shuffle(idx_all)

train_idx, valid_idx, test_idx = np.split(idx_all, [int(.8*len(pyg_list)), int(.9*len(pyg_list))])
train, valid, test = [pyg_list[i] for i in train_idx], [pyg_list[i] for i in valid_idx], [pyg_list[i] for i in test_idx]

# compile model
config_default = dict(
    n_hid = 8,
    n_out = 8,
    batch_size = 4,
    dropout = 0.5,
    lr = 0.005,
    num_heads = 32,
    num_att_dim = 64,
    model_name = 'GAT'
)

class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        
config = Struct(**config_default)

global model_name
model_name = config.model_name

class GraphNets(pl.LightningModule):
    def __init__(self):
        super().__init__()
        
        if model_name == 'GCN':
            self.layer1 = GCNConv(in_channels=3, out_channels=config.n_hid)
            self.layer2 = GCNConv(in_channels=config.n_hid, out_channels=config.n_out)

        elif model_name == 'GAT':
            self.layer1 = GATConv(3, config.num_att_dim, heads=config.num_heads, dropout=config.dropout)
            self.layer2 = GATConv(config.num_att_dim * config.num_heads, out_channels = config.n_out, heads=1, concat=False,
                                 dropout=config.dropout)
            
        elif model_name == 'GraphSAGE':
            self.layer1 = SAGEConv(3, config.n_hid)
            self.layer2 = SAGEConv(config.n_hid, config.n_out)  
            
        self.decoder = nn.Linear(config.n_out, 1)
        
    def forward(self, g):
        x = g.coords
        x = F.dropout(x, p=config.dropout, training=self.training)
        x = F.elu(self.layer1(x, g.edge_index))
        x = F.dropout(x, p=config.dropout, training=self.training)
        x = self.layer2(x, g.edge_index)
        x = global_add_pool(x, batch=g.batch)
        x = self.decoder(x)
        # return softmax(x)     # original
        return x

    def training_step(self, batch, batch_idx):
        x = batch   
        y = x.y
        y_hat = self(x)
        # loss = cross_entropy(y_hat, y)    # original
        loss = mae_loss(y_hat, y.float())
        # acc = accuracy(y_hat, y)            # original

        self.log("train_loss", loss)
        # self.log("train_acc", acc)          # original
        return loss
    
    def validation_step(self, batch, batch_idx):
        x = batch   
        y = x.y
        y_hat = self(x)
        # loss = cross_entropy(y_hat, y)    # original
        loss = mae_loss(y_hat, y.float())
        # acc = accuracy(y_hat, y)          # original
        self.log("valid_loss", loss)
        # self.log("valid_acc", acc)        # original

    def test_step(self, batch, batch_idx):
        x = batch   
        y = x.y
        y_hat = self(x)
        # loss = cross_entropy(y_hat, y)    # original
        loss = mae_loss(y_hat, y.float())
        # acc = accuracy(y_hat, y)            # original

        # y_pred_softmax = torch.log_softmax(y_hat, dim = 1)    # original
        # y_pred_tags = torch.argmax(y_pred_softmax, dim = 1)    # original
        # f1 = f1_score(y.detach().cpu().numpy(), y_pred_tags.detach().cpu().numpy(), average = 'weighted')   # original

        self.log("test_loss", loss)
        # self.log("test_acc", acc)         # original
        # self.log("test_f1", f1)           # original

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        return optimizer

GraphNets()

file_path = './graphein_model'
if not os.path.exists(file_path):
    os.mkdir(file_path)

checkpoint_callback = ModelCheckpoint(
    monitor="valid_loss",
    dirpath=file_path,
    filename="model-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    mode="min",
)

# data loader
train_loader = DataLoader(train, batch_size=config.batch_size, shuffle = True, drop_last = True)
valid_loader = DataLoader(valid, batch_size=32)
test_loader = DataLoader(test, batch_size=32)

# train model
model = GraphNets()
trainer = pl.Trainer(max_epochs=400, gpus=-1, callbacks=[checkpoint_callback])
trainer.fit(model, train_loader, valid_loader)

# evaluate on the model with the best validation set
best_model = GraphNets.load_from_checkpoint(checkpoint_callback.best_model_path)
out_best_test = trainer.test(best_model, test_loader)[0]
100% 104/104 [00:00<00:00, 256.05it/s]
/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead
  warnings.warn(out)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Data(edge_index=[2, 312], node_id=[226], coords=[226, 3], name=[1], dist_mat=[1], num_nodes=226, y=511)
Data(edge_index=[2, 304], node_id=[231], coords=[231, 3], name=[1], dist_mat=[1], num_nodes=231, y=538)
Data(edge_index=[2, 313], node_id=[226], coords=[226, 3], name=[1], dist_mat=[1], num_nodes=226, y=516)
[45 similar Data(...) lines omitted]

  | Name    | Type    | Params
------------------------------------
0 | layer1  | GATConv | 12.3 K
1 | layer2  | GATConv | 16.4 K
2 | decoder | Linear  | 9     
------------------------------------
28.7 K    Trainable params
0         Non-trainable params
28.7 K    Total params
0.115     Total estimated model params size (MB)
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:631: UserWarning: Checkpoint directory /content/google_drive/MyDrive/Colab Notebooks/AlphaFold_graph/graphein_model exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/data_loading.py:133: UserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-7-a9f77b2f40bf> in <module>()
    191 model = GraphNets()
    192 trainer = pl.Trainer(max_epochs=400, gpus=-1, callbacks=[checkpoint_callback])
--> 193 trainer.fit(model, train_loader, valid_loader)
    194 
    195 # evaluate on the model with the best validation set

19 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in dropout(input, p, training, inplace)
   1167     if p < 0.0 or p > 1.0:
   1168         raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
-> 1169     return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
   1170 
   1171 

TypeError: dropout(): argument 'input' (position 1) must be Tensor, not list

Do you know what might cause this difference between the input graphs built from pdb_code and from pdb_path?

johnnytam100 (Author) commented Mar 20, 2022

In particular, I want to understand this part: what does it do?

# other formatting (?)
for i in pyg_list:
    if i.coords[0].shape[0] == len(i.node_id):
        pass
    else:
        print(i)
        pyg_list.remove(i)

a-r-j (Owner) commented Mar 20, 2022

Hey @johnnytam100, no problem at all. I’ll take a closer look at your code later today.

With respect to the code block you quoted:

# other formatting (?)
for i in pyg_list:
    if i.coords[0].shape[0] == len(i.node_id):
        pass
    else:
        print(i)
        pyg_list.remove(i) 

This loops over the list of converted graphs and checks whether the shape of the coordinate array matches the number of nodes in the graph, i.e. do we have a coordinate for each node and a node for each coordinate? If they don't match, we remove the graph from the list. This can throw off indexing against your labels, so be careful when using it; a safer pattern is sketched below.
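
A safer version of that check (a minimal sketch reusing the names from the snippet above; it assumes coords is still the nested list produced by the converter and that path_list is kept in step with pyg_list) builds new lists rather than calling remove() on the list being iterated:

kept = [
    (g, p)
    for g, p in zip(pyg_list, path_list)
    if g.coords[0].shape[0] == len(g.node_id)
]
pyg_list = [g for g, _ in kept]   # graphs that passed the check
path_list = [p for _, p in kept]  # paths kept aligned for label lookups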

a-r-j (Owner) commented Mar 20, 2022

So, I'm not sure why this would differ between graphs created from PDB files and from PDB codes, but I think the problem is here:

for (idx, g), p in zip(enumerate(pyg_list), path_list):

    if not fp_df.loc[fp_df["uuid"]==p[21:26]][label_col].isnull().values[0]:

        # g.y = y_list[idx]             # original
        g.y = int(fp_df.loc[fp_df["uuid"]==p[21:26]][label_col].values[0])  # regression
        g.coords = torch.FloatTensor(g.coords[0])

From what I understand, you're checking whether the dataset has a label for a particular example. If it does, you assign the label to g.y and convert the coordinates to a FloatTensor. The problem is that your list of graphs (pyg_list) still contains graphs that don't have a label: you've converted the labelled graphs correctly, but you haven't removed the unlabelled ones, so some of them still have a plain list for g.coords. That is what raises the error when the model calls F.dropout().

I think the correct way to do this is a very simple fix:

for (idx, g), p in zip(enumerate(pyg_list), path_list):

    if not fp_df.loc[fp_df["uuid"]==p[21:26]][label_col].isnull().values[0]:

        # g.y = y_list[idx]             # original
        g.y = int(fp_df.loc[fp_df["uuid"]==p[21:26]][label_col].values[0])  # regression
        g.coords = torch.FloatTensor(g.coords[0])
    else:
        pyg_list.remove(g)
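
One caveat with the snippet above: pyg_list.remove(g) mutates the list that enumerate(pyg_list) is iterating over, so the loop skips the element that follows each removal and some unlabelled graphs can slip through. A single-pass sketch that avoids mutating during iteration (reusing the names above, untested against your data):

labelled = []
for g, p in zip(pyg_list, path_list):
    label = fp_df.loc[fp_df["uuid"] == p[21:26]][label_col]
    if not label.isnull().values[0]:
        g.y = int(label.values[0])  # regression target
        g.coords = torch.FloatTensor(g.coords[0])
        labelled.append(g)
pyg_list = labelled  # only labelled graphs, coords already tensors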

johnnytam100 (Author) commented Mar 20, 2022

@a-r-j Thank you so much for helping out! Yes, you've described exactly what I was trying to do. I wrote something very similar before, but then hit another error:

100% 104/104 [00:00<00:00, 338.44it/s]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-87-49aa837a2ea3> in <module>()
     64 # other formatting (?)
     65 for i in pyg_list:
---> 66     if i.coords.shape[0] == len(i.node_id):        # <---- here
     67         pass
     68     else:

AttributeError: 'list' object has no attribute 'shape'

However, I now have a fix that works for reasons I don't understand: writing two consecutive loops to remove the graphs without y.

# Load
fp_df = pd.read_csv("./20220307_fpbase_all.csv")

# Label col
label_col = "states.0.em_max"

# MAE loss
mae_loss = nn.L1Loss()

# graphs from pdb_path
path_list = []
graph_list = []

for path in glob.iglob('./selected_grep-ATOM/*.p'):
  path_list.append(path)

path_list.sort()

for path in path_list:
    with open(path, 'rb') as f:  # notice the r instead of w
        graph = pickle.load(f)
        graph_list.append(graph)

# nx2pyg
format_convertor = GraphFormatConvertor('nx', 'pyg',
                                        verbose = 'gnn',
                                        columns = None)

pyg_list = [format_convertor(graph) for graph in tqdm(graph_list)]

# assign target
for (idx, g), p in zip(enumerate(pyg_list), path_list):

    if not fp_df.loc[fp_df["uuid"]==p[21:26]][label_col].isnull().values[0]:

        # g.y = y_list[idx]             # original
        g.y = int(fp_df.loc[fp_df["uuid"]==p[21:26]][label_col].values[0]) 
    
    g.coords = torch.FloatTensor(g.coords[0])

# other formatting (?)
for i in pyg_list:
    if i.coords.shape[0] == len(i.node_id): 
        pass
    else:
        print(i)
        pyg_list.remove(i)

for i in pyg_list:
    if i.y == None:
        print(i)
        pyg_list.remove(i)

# one graph still has no y after the first pass, I don't know why

for i in pyg_list:
    if i.y == None:
        print(i)
        pyg_list.remove(i)

# now all graphs have y
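
The reason two passes are needed: list.remove() shifts the remaining elements left, so the for loop skips the element immediately after each removal, and a single pass can leave stragglers behind. A one-pass sketch that sidesteps this entirely:

# Filter into a new list instead of removing from pyg_list mid-iteration.
pyg_list = [g for g in pyg_list if g.y is not None]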

a-r-j added a commit that referenced this issue Mar 22, 2022
* add initial graph equality functions

* add equality function

* Add equality testing utils & tests

* update changelog
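
For reference, an equality helper along those lines (a hypothetical sketch with made-up names, combining the pieces from earlier in the thread; see the commit above for Graphein's actual implementation) might look like:

import networkx as nx
import numpy as np

def attrs_equal(d1, d2):
    # Attribute dicts match if they share keys and every value compares equal;
    # np.array_equal also covers array-valued attributes such as coordinates.
    return d1.keys() == d2.keys() and all(
        np.array_equal(d1[k], d2[k]) for k in d1
    )

def graphs_equal(g1, g2):
    return nx.is_isomorphic(g1, g2, node_match=attrs_equal, edge_match=attrs_equal)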
a-r-j closed this as completed Mar 24, 2022