policy_iteration.py
import numpy as np


def policy_evaluation(env, policy, gamma, theta, max_iterations):
    """Evaluate a deterministic policy by repeatedly applying the Bellman expectation backup:
    V(s) <- sum_{s'} p(s' | s, policy(s)) * (r(s', s, policy(s)) + gamma * V(s'))."""
    value = np.zeros(env.n_states, dtype=float)
    # Sweep over all states until the value function converges or the iteration budget runs out
    for _ in range(max_iterations):
        delta = 0
        for s in range(env.n_states):
            v = value[s]
            # Bellman expectation backup for the action prescribed by the policy
            value[s] = sum(
                env.p(next_s, s, policy[s]) *
                (env.r(next_s, s, policy[s]) + gamma * value[next_s])
                for next_s in range(env.n_states)
            )
            delta = max(delta, abs(v - value[s]))  # largest value change in this sweep
        # Stop once the value function has converged
        if delta < theta:
            break
    return value
def policy_improvement(env, value, gamma):
    """Build a policy that acts greedily with respect to the given value function."""
    policy = np.zeros(env.n_states, dtype=int)
    for s in range(env.n_states):
        # Pick, for each state, the action with the highest expected one-step return
        policy[s] = np.argmax(
            [
                sum(
                    env.p(next_s, s, a) *
                    (env.r(next_s, s, a) + gamma * value[next_s])
                    for next_s in range(env.n_states)
                )
                for a in range(env.n_actions)
            ]
        )
    return policy
def policy_iteration(env, gamma, theta, max_iterations, policy=None):
    """Alternate policy evaluation and greedy improvement until the policy is stable."""
    if policy is None:
        policy = np.zeros(env.n_states, dtype=int)
    else:
        policy = np.array(policy, dtype=int)
    while True:
        previous_policy = policy
        value = policy_evaluation(env, policy, gamma, theta, max_iterations)
        policy = policy_improvement(env, value, gamma)
        # Stop when the improvement step no longer changes the policy
        if np.array_equal(previous_policy, policy):
            break
    return policy, value
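

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, hedged example of how policy_iteration might be invoked. The
# environment interface (n_states, n_actions, p(next_state, state, action),
# r(next_state, state, action)) is inferred from the calls above; the toy
# two-state MDP below is hypothetical and exists only for illustration.
class ToyEnv:
    """Two states, two actions; action 1 moves to state 1, which pays reward 1."""

    n_states = 2
    n_actions = 2

    def p(self, next_s, s, a):
        # Deterministic transitions: action 0 leads to state 0, action 1 to state 1
        target = 0 if a == 0 else 1
        return 1.0 if next_s == target else 0.0

    def r(self, next_s, s, a):
        # Reward 1 for landing in state 1, otherwise 0
        return 1.0 if next_s == 1 else 0.0


if __name__ == "__main__":
    env = ToyEnv()
    policy, value = policy_iteration(env, gamma=0.9, theta=1e-6, max_iterations=100)
    print("Optimal policy:", policy)  # expected: [1 1]
    print("State values:", value)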