Skip to content

Commit 89cef52

Browse files
authored
Merge pull request graphnet-team#756 from Aske-Rosted/train_example_comments
Train example comments
2 parents 8903e35 + af85b89 commit 89cef52

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

examples/04_training/01_train_dynedge.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,17 @@ def main(
8181
# Log configuration to W&B
8282
wandb_logger.experiment.config.update(config)
8383

84-
# Define graph representation
84+
# Define graph/data representation, here the KNNGraph is used.
85+
# The KNNGraph is a graph representation, which uses the
86+
# KNNEdges edge definition with 8 neighbours as default.
87+
# The graph representation is defined by the detector,
88+
# in this case the Prometheus detector.
89+
# The standard node definition is used, which is NodesAsPulses.
8590
graph_definition = KNNGraph(detector=Prometheus())
8691

87-
# Use GraphNetDataModule to load in data
92+
# Use GraphNetDataModule to load in data and create dataloaders
93+
# The input here depends on the dataset being used,
94+
# in this case the Prometheus dataset.
8895
dm = GraphNeTDataModule(
8996
dataset_reference=config["dataset_reference"],
9097
dataset_args={
@@ -110,17 +117,28 @@ def main(
110117

111118
# Building model
112119

120+
# Define architecture of the backbone, in this example
121+
# the DynEdge architecture is used.
122+
# https://iopscience.iop.org/article/10.1088/1748-0221/17/11/P11003
113123
backbone = DynEdge(
114124
nb_inputs=graph_definition.nb_outputs,
115125
global_pooling_schemes=["min", "max", "mean", "sum"],
116126
)
127+
# Define the task.
128+
# Here an energy reconstruction, with a LogCoshLoss function.
129+
# The target and prediction are transformed using the log10 function.
130+
# When infering the prediction is transformed back to the
131+
# original scale using 10^x.
117132
task = EnergyReconstruction(
118133
hidden_size=backbone.nb_outputs,
119134
target_labels=config["target"],
120135
loss_function=LogCoshLoss(),
121136
transform_prediction_and_target=lambda x: torch.log10(x),
122137
transform_inference=lambda x: torch.pow(10, x),
123138
)
139+
# Define the full model, which includes the backbone, task(s),
140+
# along with typical machine learning options such as
141+
# learning rate optimizers and schedulers.
124142
model = StandardModel(
125143
graph_definition=graph_definition,
126144
backbone=backbone,

0 commit comments

Comments
 (0)