This repository has been archived by the owner on Dec 27, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 11
/
model.py
134 lines (124 loc) · 5.08 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from typing import Mapping
from typing import Tuple, Optional, Dict
import dgl
import dgl.nn.pytorch as graph_nn
import pytorch_lightning as pl
import pytorch_lightning.metrics as metrics
import torch
import torch.nn.functional as F
from dgl.nn import Sequential
from pytorch_lightning.metrics import Metric
from torch import nn
class MalwareDetector(pl.LightningModule):
    """Binary malware classifier over DGL call graphs.

    Stacks up to three graph-convolution layers (64 -> 32 -> 16 output
    dimensions), mean-pools node embeddings into a single graph embedding,
    and applies a linear head producing one logit per graph, trained with
    ``BCEWithLogitsLoss``.
    """

    def __init__(
            self,
            input_dimension: int,
            convolution_algorithm: str,
            convolution_count: int,
    ):
        """
        Args:
            input_dimension: Size of each node's input feature vector.
            convolution_algorithm: One of ``'GraphConv'``, ``'SAGEConv'``,
                ``'TAGConv'``, ``'DotGatConv'``.
            convolution_count: Number of convolution layers to stack (0-3).

        Raises:
            ValueError: If ``convolution_algorithm`` is not supported.
        """
        super().__init__()
        supported_algorithms = ['GraphConv', 'SAGEConv', 'TAGConv', 'DotGatConv']
        if convolution_algorithm not in supported_algorithms:
            raise ValueError(
                f"{convolution_algorithm} is not supported. Supported algorithms are {supported_algorithms}")
        self.save_hyperparameters()
        # Build the layer stack in a plain list first, then wrap it in dgl's
        # Sequential so the layers are registered as sub-modules (picked up
        # by optimizers / .to(device)).
        convolution_layers = []
        convolution_dimensions = [64, 32, 16]
        for dimension in convolution_dimensions[:convolution_count]:
            convolution_layers.append(self._get_convolution_layer(
                name=convolution_algorithm,
                input_dimension=input_dimension,
                output_dimension=dimension
            ))
            # Each layer's output feeds the next layer's input.
            input_dimension = dimension
        self.convolution_layers = Sequential(*convolution_layers)
        self.last_dimension = input_dimension
        # Single-logit head for binary classification.
        self.classify = nn.Linear(input_dimension, 1)
        # Loss and per-stage metric collections.
        self.loss_func = nn.BCEWithLogitsLoss()
        self.train_metrics = self._get_metric_dict('train')
        self.val_metrics = self._get_metric_dict('val')
        self.test_metrics = self._get_metric_dict('test')
        # Accumulating test-time outputs; curve metrics only make sense over
        # the full test set, hence compute_on_step=False.
        self.test_outputs = nn.ModuleDict({
            'confusion_matrix': metrics.ConfusionMatrix(num_classes=2),
            'prc': metrics.PrecisionRecallCurve(compute_on_step=False),
            'roc': metrics.ROC(compute_on_step=False)
        })

    @staticmethod
    def _get_convolution_layer(
            name: str,
            input_dimension: int,
            output_dimension: int
    ) -> Optional[nn.Module]:
        """Instantiate the named DGL convolution layer, or None if unknown.

        ``__init__`` validates the name first, so ``None`` is not expected
        in practice.
        """
        return {
            "GraphConv": graph_nn.GraphConv(
                input_dimension,
                output_dimension,
                activation=F.relu
            ),
            "SAGEConv": graph_nn.SAGEConv(
                input_dimension,
                output_dimension,
                activation=F.relu,
                aggregator_type='mean',
                norm=F.normalize
            ),
            "DotGatConv": graph_nn.DotGatConv(
                input_dimension,
                output_dimension,
                num_heads=1
            ),
            "TAGConv": graph_nn.TAGConv(
                input_dimension,
                output_dimension,
                k=4
            )
        }.get(name, None)

    @staticmethod
    def _get_metric_dict(stage: str) -> Mapping[str, Metric]:
        """Return stage-prefixed scalar metrics as an ``nn.ModuleDict``."""
        return nn.ModuleDict({
            f'{stage}_accuracy': metrics.Accuracy(),
            f'{stage}_precision': metrics.Precision(num_classes=1),
            f'{stage}_recall': metrics.Recall(num_classes=1),
            f'{stage}_f1': metrics.FBeta(num_classes=1)
        })

    def forward(self, g: dgl.DGLGraph) -> torch.Tensor:
        """Compute one raw logit per graph in the (possibly batched) input."""
        with g.local_scope():
            h = g.ndata['features']
            h = self.convolution_layers(g, h)
            # NOTE(review): the h[0] fallback for an empty layer stack selects
            # the first node's features, which looks suspect — confirm intent.
            g.ndata['h'] = h if len(self.convolution_layers) > 0 else h[0]
            # Graph embedding = mean of its node embeddings.
            hg = dgl.mean_nodes(g, 'h')
            # squeeze(-1), not squeeze(): a batch of one must still yield a
            # 1-D tensor matching the label shape for BCEWithLogitsLoss.
            return self.classify(hg).squeeze(-1)

    def _step(
            self,
            batch: Tuple[dgl.DGLGraph, torch.Tensor],
            stage_metrics: Mapping[str, Metric]
    ) -> torch.Tensor:
        """Shared train/val/test logic: loss plus metric updates.

        Returns the BCE-with-logits loss for the batch; updates each metric
        in ``stage_metrics`` with the sigmoid probabilities.
        """
        bg, label = batch
        logits = self.forward(bg)
        loss = self.loss_func(logits, label)
        prediction = torch.sigmoid(logits)
        for metric in stage_metrics.values():
            metric.update(prediction, label)
        return loss

    def training_step(self, batch: Tuple[dgl.DGLGraph, torch.Tensor], batch_idx: int) -> torch.Tensor:
        loss = self._step(batch, self.train_metrics)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch: Tuple[dgl.DGLGraph, torch.Tensor], batch_idx: int):
        loss = self._step(batch, self.val_metrics)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch: Tuple[dgl.DGLGraph, torch.Tensor], batch_idx: int):
        loss = self._step(batch, self.test_metrics)
        # Test stage additionally accumulates curve/matrix outputs.
        bg, label = batch
        prediction = torch.sigmoid(self.forward(bg))
        for metric in self.test_outputs.values():
            metric.update(prediction, label)
        self.log('test_loss', loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self) -> torch.optim.Adam:
        """Plain Adam with a fixed 1e-3 learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer