Skip to content

Commit

Permalink
malnet tiny (#3472)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Nov 10, 2021
1 parent ec3e54f commit 53ffe03
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 4 deletions.
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from .twitch import Twitch
from .airports import Airports
from .ba_shapes import BAShapes
from .malnet_tiny import MalNetTiny

__all__ = [
'KarateClub',
Expand Down Expand Up @@ -134,6 +135,7 @@
'Twitch',
'Airports',
'BAShapes',
'MalNetTiny',
]

classes = __all__
10 changes: 6 additions & 4 deletions torch_geometric/datasets/hgb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ class HGBDataset(InMemoryDataset):

def __init__(self, root: str, name: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None):
pre_transform: Optional[Callable] = None):
self.name = name.lower()
assert self.name in set(self.names.keys())
super().__init__(root, transform, pre_transform)
Expand Down Expand Up @@ -114,7 +113,7 @@ def process(self):
src, dst = n_types[int(src)], n_types[int(dst)]
rel = rel.split('-')[1]
e_types[key] = (src, rel, dst)
else:
else: # Link prediction:
raise NotImplementedError

# Extract node information:
Expand Down Expand Up @@ -180,9 +179,12 @@ def process(self):
n_id, n_type = mapping_dict[int(y[0])], n_types[int(y[2])]
data[n_type].test_mask[n_id] = True

else:
else: # Link prediction:
raise NotImplementedError

if self.pre_transform is not None:
data = self.pre_transform(data)

torch.save(self.collate([data]), self.processed_paths[0])

def __repr__(self) -> str:
Expand Down
79 changes: 79 additions & 0 deletions torch_geometric/datasets/malnet_tiny.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Optional, Callable, List

import os
import glob
import os.path as osp

import torch
from torch_geometric.data import (InMemoryDataset, Data, download_url,
extract_tar)


class MalNetTiny(InMemoryDataset):
r"""The MalNet Tiny dataset from the
`"A Large-Scale Database for Graph Representation Learning"
<https://openreview.net/pdf?id=1xDTDk3XPW>`_ paper.
:class:`MalNetTiny` contains 5,000 malicious and benign software function
call graphs across 5 different types.
Args:
root (string): Root directory where the dataset should be saved.
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
(default: :obj:`None`)
pre_transform (callable, optional): A function/transform that takes in
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
pre_filter (callable, optional): A function that takes in an
:obj:`torch_geometric.data.Data` object and returns a boolean
value, indicating whether the data object should be included in the
final dataset. (default: :obj:`None`)
"""

url = 'http://malnet.cc.gatech.edu/graph-data/malnet-graphs-tiny.tar.gz'

def __init__(self, root: str, transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None):
super().__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self) -> List[str]:
folders = ['addisplay', 'adware', 'benign', 'downloader', 'trojan']
return [osp.join('malnet-graphs-tiny', folder) for folder in folders]

@property
def processed_file_names(self) -> str:
return 'data.pt'

def download(self):
path = download_url(self.url, self.raw_dir)
extract_tar(path, self.raw_dir)
os.unlink(path)

def process(self):
data_list = []

for y, raw_path in enumerate(self.raw_paths):
raw_path = osp.join(raw_path, os.listdir(raw_path)[0])
filenames = glob.glob(osp.join(raw_path, '*.edgelist'))

for filename in filenames:
with open(filename, 'r') as f:
edges = f.read().split('\n')[5:-1]
edge_index = [[int(edge[0]), int(edge[-1])] for edge in edges]
edge_index = torch.tensor(edge_index).t().contiguous()
num_nodes = int(edge_index.max()) + 1
data = Data(edge_index=edge_index, y=y, num_nodes=num_nodes)
data_list.append(data)

if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]

if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]

torch.save(self.collate(data_list), self.processed_paths[0])

0 comments on commit 53ffe03

Please sign in to comment.