-
Notifications
You must be signed in to change notification settings - Fork 0
/
pg_bandits.py
60 lines (49 loc) · 1.56 KB
/
pg_bandits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import time
from pathlib import Path
import numpy as np
import torch
from utils import (
AvgRewardAndTrueOptTracker,
SoftmaxPGEnv,
get_normal_bandits,
plot_testbed,
run_algo,
)
if __name__ == "__main__":
    # Reproducibility: seed both RNG backends used downstream (numpy for
    # bandit sampling, torch for the policy-gradient updates).
    np.random.seed(0)
    torch.manual_seed(0)

    # Experiment config: 2000 ten-armed bandit problems, 1000 steps each,
    # comparing 4 policy-gradient variants (a 2x2 grid of lr x baseline).
    NUM_BANDITS, NUM_ARMS, NUM_STEPS = 2000, 10, 1000
    NUM_ALGOS = 4

    SAVE_PATH = Path(__file__).absolute().parent.parent.parent / "assets/imgs"
    # exist_ok=True makes directory creation race-free (no separate
    # exists() check needed) and idempotent across reruns.
    SAVE_PATH.mkdir(parents=True, exist_ok=True)
    file_name = SAVE_PATH / "pg-bandits.png"

    # Sample the testbed: 2000 bandit problems with 10 arms each;
    # presumably arm reward means are drawn around 4.0 — see
    # get_normal_bandits in utils for the exact distribution.
    bandits = get_normal_bandits(NUM_BANDITS, NUM_ARMS, mean=4.0)
    # Index of the truly optimal arm per bandit problem (ground truth
    # for the %-optimal-action metric).
    true_optimal = bandits.mean.argmax(-1).numpy()

    # Tracks per-algorithm average reward and proportion of true-optimal
    # picks over the NUM_STEPS horizon.
    metric_tracker = AvgRewardAndTrueOptTracker(
        NUM_ALGOS, NUM_STEPS, true_optimal
    )

    labels = []
    for idx in range(NUM_ALGOS):
        # Action preferences (softmax logits) start at zero everywhere.
        preferences = torch.zeros((NUM_BANDITS, NUM_ARMS))
        # 2x2 grid: even idx -> lr=0.1, odd idx -> lr=0.4;
        # idx 0,1 use a reward baseline, idx 2,3 do not.
        bandit_env = SoftmaxPGEnv(
            preferences, lr=0.1 + (idx % 2) * 0.3, with_baseline=idx < 2
        )
        labels.append(
            f"PG, lr={bandit_env.lr:.2f}, baseline={bandit_env.with_baseline}"
        )
        # perf_counter is monotonic and high-resolution — the right clock
        # for measuring elapsed wall time (time.time can jump on NTP sync).
        start = time.perf_counter()
        run_algo(idx, bandits, bandit_env, NUM_STEPS, metric_tracker)
        print(f"run {idx} took: {time.perf_counter() - start:.2f} secs")

    # Render the classic testbed figure (avg reward + %-optimal curves).
    plot_testbed(
        file_name,
        labels,
        metric_tracker.avg_rewards,
        metric_tracker.prop_true_optimal,
    )