From adec0c175b44784749acc54c71363554ae79348a Mon Sep 17 00:00:00 2001 From: allaffa Date: Sun, 5 Nov 2023 14:45:34 -0500 Subject: [PATCH] abstractrawdataset updated --- hydragnn/utils/abstractrawdataset.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/hydragnn/utils/abstractrawdataset.py b/hydragnn/utils/abstractrawdataset.py index 125757bf7..657863ea4 100644 --- a/hydragnn/utils/abstractrawdataset.py +++ b/hydragnn/utils/abstractrawdataset.py @@ -62,6 +62,11 @@ def __init__(self, config, dist=False, sampling=None): """ # self.serial_data_name_list = [] + self.normalize_features = ( + config["Dataset"]["normalize_features"] + if config["Dataset"]["normalize_features"] is not None + else False + ) self.node_feature_name = ( config["Dataset"]["node_features"]["name"] if config["Dataset"]["node_features"]["name"] is not None @@ -205,7 +210,8 @@ def __load_raw_data(self): # scaled features by number of nodes self.__scale_features_by_num_nodes() - self.__normalize_dataset() + if self.normalize_features: + self.__normalize_dataset() def __normalize_dataset(self): @@ -393,17 +399,6 @@ def __build_edge(self): self.edge_feature_transform(data) for data in self.dataset ] - for data in self.dataset: - update_predicted_values( - self.variables_type, - self.output_index, - self.graph_feature_dim, - self.node_feature_dim, - data, - ) - - update_atom_features(self.input_node_features, data) - if "subsample_percentage" in self.variables.keys(): self.subsample_percentage = self.variables["subsample_percentage"] sampled = stratified_sampling(