@@ -81,10 +81,17 @@ def main(
8181 # Log configuration to W&B
8282 wandb_logger .experiment .config .update (config )
8383
84- # Define graph representation
84+ # Define graph/data representation, here the KNNGraph is used.
85+ # The KNNGraph is a graph representation, which uses the
86+ # KNNEdges edge definition with 8 neighbours as default.
87+ # The graph representation is defined by the detector,
88+ # in this case the Prometheus detector.
89+ # The standard node definition is used, which is NodesAsPulses.
8590 graph_definition = KNNGraph (detector = Prometheus ())
8691
87- # Use GraphNetDataModule to load in data
92+ # Use GraphNetDataModule to load in data and create dataloaders
93+ # The input here depends on the dataset being used,
94+ # in this case the Prometheus dataset.
8895 dm = GraphNeTDataModule (
8996 dataset_reference = config ["dataset_reference" ],
9097 dataset_args = {
@@ -110,17 +117,28 @@ def main(
110117
111118 # Building model
112119
120+ # Define architecture of the backbone, in this example
121+ # the DynEdge architecture is used.
122+ # https://iopscience.iop.org/article/10.1088/1748-0221/17/11/P11003
113123 backbone = DynEdge (
114124 nb_inputs = graph_definition .nb_outputs ,
115125 global_pooling_schemes = ["min" , "max" , "mean" , "sum" ],
116126 )
127+ # Define the task.
128+ # Here an energy reconstruction, with a LogCoshLoss function.
129+ # The target and prediction are transformed using the log10 function.
130+ # When infering the prediction is transformed back to the
131+ # original scale using 10^x.
117132 task = EnergyReconstruction (
118133 hidden_size = backbone .nb_outputs ,
119134 target_labels = config ["target" ],
120135 loss_function = LogCoshLoss (),
121136 transform_prediction_and_target = lambda x : torch .log10 (x ),
122137 transform_inference = lambda x : torch .pow (10 , x ),
123138 )
139+ # Define the full model, which includes the backbone, task(s),
140+ # along with typical machine learning options such as
141+ # learning rate optimizers and schedulers.
124142 model = StandardModel (
125143 graph_definition = graph_definition ,
126144 backbone = backbone ,
0 commit comments