From 74d563975a55a85252b7be3ee6a07887bd171846 Mon Sep 17 00:00:00 2001 From: Jacob Sznajdman Date: Thu, 27 Jul 2023 17:56:08 +0200 Subject: [PATCH 1/2] Expose graphsage training configuration Co-authored-by: Olga Razvenskaia --- graphdatascience/gnn/gnn_nc_runner.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 90ee4b56b..37f7125d9 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -6,6 +6,21 @@ class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): + def make_graph_sage_config(self, graph_sage_config): + GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5, + "hidden_channels": 256} + final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG + if graph_sage_config: + bad_keys = [] + for key in graph_sage_config: + if key not in GRAPH_SAGE_DEFAULT_CONFIG: + bad_keys.append(key) + if len(bad_keys) > 0: + raise Exception(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.") + + final_sage_config.update(graph_sage_config) + return final_sage_config + def train( self, graph_name: str, @@ -15,13 +30,15 @@ def train( relationship_types: List[str], target_node_label: str = None, node_labels: List[str] = None, + graph_sage_config = None ) -> "Series[Any]": # noqa: F821 mlConfigMap = { "featureProperties": feature_properties, "targetProperty": target_property, "job_type": "train", "nodeProperties": feature_properties + [target_property], - "relationshipTypes": relationship_types + "relationshipTypes": relationship_types, + "graph_sage_config": self.make_graph_sage_config(graph_sage_config) } if target_node_label: From fa544887ea83a68bb5005411e2b1dc4d43983be3 Mon Sep 17 00:00:00 2001 From: Olga Razvenskaia Date: Tue, 1 Aug 2023 11:19:15 +0100 Subject: [PATCH 2/2] Add learning rate Co-authored-by: Jacob Sznajdman --- graphdatascience/gnn/gnn_nc_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphdatascience/gnn/gnn_nc_runner.py b/graphdatascience/gnn/gnn_nc_runner.py index 37f7125d9..27aec8d63 100644 --- a/graphdatascience/gnn/gnn_nc_runner.py +++ b/graphdatascience/gnn/gnn_nc_runner.py @@ -8,7 +8,7 @@ class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): def make_graph_sage_config(self, graph_sage_config): GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5, - "hidden_channels": 256} + "hidden_channels": 256, "learning_rate": 0.003} final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG if graph_sage_config: bad_keys = []