# generate_data.py
# Forked from kaist-silab/equity-transformer.
import argparse
import os
import numpy as np
from utils.data_utils import check_extension, save_dataset
def generate_mtsp_data(dataset_size, tsp_size):
    """Sample a batch of random mTSP instances.

    Draws an array of shape ``(dataset_size, tsp_size, 2)`` from
    ``np.random.uniform`` with its default bounds and returns it as nested
    Python lists (one list of 2-element entries per instance). The global
    NumPy random state is used, so seed it beforehand for reproducibility.
    """
    shape = (dataset_size, tsp_size, 2)
    samples = np.random.uniform(size=shape)
    return samples.tolist()
def generate_pdp_data(dataset_size, pdp_size):
    """Sample a batch of random PDP instances.

    Each instance is a ``(depot, nodes)`` tuple where ``depot`` is a
    2-element list (the depot location) and ``nodes`` is a list of
    ``pdp_size`` 2-element lists. All values come from ``np.random.uniform``
    with its default bounds, using the global NumPy random state; depots for
    the whole batch are drawn first, then the node entries.
    """
    # Sampling order matters for reproducibility under a fixed seed.
    depot_batch = np.random.uniform(size=(dataset_size, 2)).tolist()  # Depot location
    node_batch = np.random.uniform(size=(dataset_size, pdp_size, 2)).tolist()
    return [(depot, nodes) for depot, nodes in zip(depot_batch, node_batch)]
if __name__ == "__main__":
    # Command-line entry point: generate one dataset per
    # (problem, distribution, graph size) combination and save each one with
    # save_dataset. Every dataset is generated after re-seeding NumPy with
    # --seed, so each file is reproducible on its own.
    parser = argparse.ArgumentParser()
    parser.add_argument("--filename", help="Filename of the dataset to create (ignores datadir)")
    parser.add_argument("--data_dir", default='data', help="Create datasets in data_dir/problem (default 'data')")
    parser.add_argument("--name", type=str, required=True, help="Name to identify dataset")
    parser.add_argument("--problem", type=str, default='all',
                        help="Problem, 'mtsp', 'mpdp', or 'all' to generate all")
    parser.add_argument('--data_distribution', type=str, default='all',
                        help="Distributions to generate for problem, default 'all'.")
    parser.add_argument("--dataset_size", type=int, default=10000, help="Size of the dataset")
    parser.add_argument('--graph_sizes', type=int, nargs='+', default=[20, 50, 100],
                        help="Sizes of problem instances (default 20, 50, 100)")
    parser.add_argument("-f", action='store_true', help="Set true to overwrite")
    parser.add_argument('--seed', type=int, default=1234, help="Random seed")

    opts = parser.parse_args()

    # Known problems and the distributions available for each; ``None`` means
    # the problem has a single, unnamed distribution.
    distributions_per_problem = {
        'mtsp': [None],
        'mpdp': [None],
    }

    if opts.problem == 'all':
        problems = distributions_per_problem
    elif opts.problem in distributions_per_problem:
        problems = {
            opts.problem:
                distributions_per_problem[opts.problem]
                if opts.data_distribution == 'all'
                else [opts.data_distribution]
        }
    else:
        # Report an unsupported problem as a usage error instead of letting
        # the dictionary lookup fail with a bare KeyError.
        parser.error("Unknown problem: {}".format(opts.problem))

    # An explicit --filename can only name one output file, so make sure
    # exactly one dataset will be generated. This is checked against the
    # resolved ``problems`` mapping: argparse only defines ``opts.problem``
    # (singular), there is no ``opts.problems`` attribute to inspect.
    num_datasets = len(opts.graph_sizes) * sum(
        len(distributions or [None]) for distributions in problems.values()
    )
    assert opts.filename is None or num_datasets == 1, \
        "Can only specify filename when generating a single dataset"

    for problem, distributions in problems.items():
        for distribution in distributions or [None]:
            for graph_size in opts.graph_sizes:

                datadir = os.path.join(opts.data_dir, problem)
                os.makedirs(datadir, exist_ok=True)

                if opts.filename is None:
                    # Default name: <problem>[_<distribution>]<size>_<name>_seed<seed>.pkl
                    filename = os.path.join(datadir, "{}{}{}_{}_seed{}.pkl".format(
                        problem,
                        "_{}".format(distribution) if distribution is not None else "",
                        graph_size, opts.name, opts.seed))
                else:
                    filename = check_extension(opts.filename)

                # Refuse to clobber an existing dataset unless -f was given.
                assert opts.f or not os.path.isfile(check_extension(filename)), \
                    "File already exists! Try running with -f option to overwrite."

                # Re-seed before every dataset so each file depends only on
                # --seed and its own parameters, not on generation order.
                np.random.seed(opts.seed)
                if problem == 'mtsp':
                    dataset = generate_mtsp_data(opts.dataset_size, graph_size)
                elif problem == 'mpdp':
                    dataset = generate_pdp_data(opts.dataset_size, graph_size)
                else:
                    assert False, "Unknown problem: {}".format(problem)

                save_dataset(dataset, filename)