-
Notifications
You must be signed in to change notification settings - Fork 35
/
train_hypergrid_simple_ls.py
117 lines (102 loc) · 3.63 KB
/
train_hypergrid_simple_ls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python
import argparse
import torch
from tqdm import tqdm
from gfn.gflownet import TBGFlowNet
from gfn.gym import HyperGrid
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import LocalSearchSampler
from gfn.utils.common import set_seed
from gfn.utils.modules import MLP
def main(args):
    """Train a Trajectory Balance GFlowNet on HyperGrid using local-search sampling.

    Args:
        args: Parsed command-line namespace. This function reads ``seed``,
            ``no_cuda``, ``ndim``, ``height``, ``lr``, ``lr_logz``,
            ``n_iterations``, ``batch_size``, ``epsilon``,
            ``n_local_search_loops`` and ``back_ratio``.

    Returns:
        None. Progress (the current loss) is reported through a tqdm bar.
    """
    set_seed(args.seed)
    device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

    # Setup the Environment.
    env = HyperGrid(ndim=args.ndim, height=args.height, device_str=device_str)

    # Build the GFlowNet. The backward module reuses the forward module's
    # trunk and has one fewer output (env.n_actions - 1); presumably the
    # missing action is the exit action -- confirm against HyperGrid.
    module_PF = MLP(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions,
    )
    module_PB = MLP(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions - 1,
        trunk=module_PF.trunk,
    )
    pf_estimator = DiscretePolicyEstimator(
        module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor
    )
    pb_estimator = DiscretePolicyEstimator(
        module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor
    )
    gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, logZ=0.0)

    # The local-search sampler is given both the forward and the backward
    # estimator.
    sampler = LocalSearchSampler(pf_estimator=pf_estimator, pb_estimator=pb_estimator)

    # Move the gflownet to the selected device (CUDA or CPU). The sampler holds
    # the same estimator objects; assumes they are moved as submodules of
    # gflownet -- TODO confirm.
    gflownet = gflownet.to(device_str)

    # Policy parameters have their own LR. Log Z gets a dedicated learning rate
    # (typically higher).
    optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr)
    optimizer.add_param_group(
        {"params": gflownet.logz_parameters(), "lr": args.lr_logz}
    )

    # Number of initial trajectories requested per iteration. The batch budget
    # is split across the local-search loops; clamp both the divisor and the
    # result so n_local_search_loops=0 does not divide by zero and a small
    # batch_size never requests zero trajectories.
    n_initial = max(1, args.batch_size // max(1, args.n_local_search_loops))

    for _ in (pbar := tqdm(range(args.n_iterations), dynamic_ncols=True)):
        trajectories = sampler.sample_trajectories(
            env,
            n=n_initial,
            # Log-probs are not stored while sampling; presumably the loss
            # recomputes them -- confirm in TBGFlowNet.loss.
            save_logprobs=False,
            save_estimator_outputs=False,
            epsilon=args.epsilon,
            n_local_search_loops=args.n_local_search_loops,
            # Honour the --back_ratio CLI flag (previously hard-coded to 0.5).
            back_ratio=args.back_ratio,
            use_metropolis_hastings=False,
        )
        optimizer.zero_grad()
        loss = gflownet.loss(env, trajectories)
        loss.backward()
        optimizer.step()
        pbar.set_postfix({"loss": loss.item()})
if __name__ == "__main__":
    # Command-line interface covering the environment, optimizer and sampler.
    cli = argparse.ArgumentParser()
    cli.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage")

    # (flag, type, default, help) for every value-taking option, listed in the
    # order they should appear in --help.
    value_options = (
        ("--ndim", int, 2, "Number of dimensions in the environment"),
        ("--height", int, 16, "Height of the environment"),
        ("--seed", int, 0, "Random seed"),
        ("--lr", float, 1e-3, "Learning rate for the estimators' modules"),
        ("--lr_logz", float, 1e-1, "Learning rate for the logZ parameter"),
        ("--n_iterations", int, 1000, "Number of iterations"),
        ("--batch_size", int, 16, "Batch size"),
        ("--epsilon", float, 0.1, "Epsilon for the sampler"),
        # Local search parameters.
        ("--n_local_search_loops", int, 4, "Number of local search loops"),
        (
            "--back_ratio",
            float,
            0.5,
            "The ratio of the number of backward steps to the length of the trajectory",
        ),
    )
    for flag, kind, default_value, description in value_options:
        cli.add_argument(flag, type=kind, default=default_value, help=description)

    main(cli.parse_args())