Skip to content

Commit 3a0cbb4

Browse files
committed
Add Dynamic Programming exercises
1 parent 02c60e8 commit 3a0cbb4

8 files changed

+1181
-16
lines changed

DP/Policy Evaluation Solution.ipynb

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 44,
6+
"metadata": {
7+
"collapsed": false
8+
},
9+
"outputs": [],
10+
"source": [
11+
"import numpy as np\n",
12+
"import pprint\n",
13+
"import sys\n",
14+
"if \"../\" not in sys.path:\n",
15+
" sys.path.append(\"../\") \n",
16+
"from lib.envs.gridworld import GridworldEnv"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": 45,
22+
"metadata": {
23+
"collapsed": true
24+
},
25+
"outputs": [],
26+
"source": [
27+
"pp = pprint.PrettyPrinter(indent=2)\n",
28+
"env = GridworldEnv()"
29+
]
30+
},
31+
{
32+
"cell_type": "code",
33+
"execution_count": 49,
34+
"metadata": {
35+
"collapsed": true
36+
},
37+
"outputs": [],
38+
"source": [
39+
"def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):\n",
40+
" \"\"\"\n",
41+
" Evaluate a policy given an environment and a full description of the environment's dynamics.\n",
42+
" \n",
43+
" Args:\n",
44+
" policy: [S, A] shaped matrix representing the policy.\n",
45+
" env: OpenAI env. env.P represents the transition probabilities of the environment.\n",
46+
" env.P[s][a] is a (prob, next_state, reward, done) tuple.\n",
47+
" theta: We stop evaluation one our value function change is less than theta for all states.\n",
48+
" discount_factor: lambda discount factor.\n",
49+
" \n",
50+
" Returns:\n",
51+
" Vector of length env.nS representing the value function.\n",
52+
" \"\"\"\n",
53+
" # Start with a random (all 0) value function\n",
54+
" V = np.zeros(env.nS)\n",
55+
" while True:\n",
56+
" delta = 0\n",
57+
" # For each state, perform a \"full backup\"\n",
58+
" for s in range(env.nS):\n",
59+
" v = 0\n",
60+
" # Look at the possible next actions\n",
61+
" for a, action_prob in enumerate(policy[s]):\n",
62+
" # For each action, look at the possible next states...\n",
63+
" for prob, next_state, reward, done in env.P[s][a]:\n",
64+
" # Calculate the expected value\n",
65+
" v += action_prob * prob * (reward + discount_factor * V[next_state])\n",
66+
" # How much our value function changed (across any states)\n",
67+
" delta = max(delta, np.abs(v - V[s]))\n",
68+
" V[s] = v\n",
69+
" # Stop evaluating once our value function change is below a threshold\n",
70+
" if delta < theta:\n",
71+
" break\n",
72+
" return np.array(V)"
73+
]
74+
},
75+
{
76+
"cell_type": "code",
77+
"execution_count": 50,
78+
"metadata": {
79+
"collapsed": false
80+
},
81+
"outputs": [
82+
{
83+
"name": "stdout",
84+
"output_type": "stream",
85+
"text": [
86+
"[ 0. -13.99993529 -19.99990698 -21.99989761 -13.99993529\n",
87+
" -17.9999206 -19.99991379 -19.99991477 -19.99990698 -19.99991379\n",
88+
" -17.99992725 -13.99994569 -21.99989761 -19.99991477 -13.99994569 0. ]\n"
89+
]
90+
}
91+
],
92+
"source": [
93+
"random_policy = np.ones([env.nS, env.nA]) / env.nA\n",
94+
"v = policy_eval(random_policy, env)"
95+
]
96+
},
97+
{
98+
"cell_type": "code",
99+
"execution_count": 52,
100+
"metadata": {
101+
"collapsed": false
102+
},
103+
"outputs": [
104+
{
105+
"name": "stdout",
106+
"output_type": "stream",
107+
"text": [
108+
"Value Function:\n",
109+
"[ 0. -13.99993529 -19.99990698 -21.99989761 -13.99993529\n",
110+
" -17.9999206 -19.99991379 -19.99991477 -19.99990698 -19.99991379\n",
111+
" -17.99992725 -13.99994569 -21.99989761 -19.99991477 -13.99994569 0. ]\n",
112+
"\n",
113+
"Reshaped Grid Value Function:\n",
114+
"[[ 0. -13.99993529 -19.99990698 -21.99989761]\n",
115+
" [-13.99993529 -17.9999206 -19.99991379 -19.99991477]\n",
116+
" [-19.99990698 -19.99991379 -17.99992725 -13.99994569]\n",
117+
" [-21.99989761 -19.99991477 -13.99994569 0. ]]\n",
118+
"\n"
119+
]
120+
}
121+
],
122+
"source": [
123+
"print(\"Value Function:\")\n",
124+
"print(v)\n",
125+
"print(\"\")\n",
126+
"\n",
127+
"print(\"Reshaped Grid Value Function:\")\n",
128+
"print(v.reshape(env.shape))\n",
129+
"print(\"\")"
130+
]
131+
},
132+
{
133+
"cell_type": "code",
134+
"execution_count": 51,
135+
"metadata": {
136+
"collapsed": false
137+
},
138+
"outputs": [],
139+
"source": [
140+
"# Test: Make sure the evaluated policy is what we expected\n",
141+
"expected_v = np.array([0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22, -20, -14, 0])\n",
142+
"np.testing.assert_array_almost_equal(v, expected_v, decimal=2)"
143+
]
144+
},
145+
{
146+
"cell_type": "code",
147+
"execution_count": null,
148+
"metadata": {
149+
"collapsed": true
150+
},
151+
"outputs": [],
152+
"source": []
153+
}
154+
],
155+
"metadata": {
156+
"kernelspec": {
157+
"display_name": "Python 3",
158+
"language": "python",
159+
"name": "python3"
160+
},
161+
"language_info": {
162+
"codemirror_mode": {
163+
"name": "ipython",
164+
"version": 3
165+
},
166+
"file_extension": ".py",
167+
"mimetype": "text/x-python",
168+
"name": "python",
169+
"nbconvert_exporter": "python",
170+
"pygments_lexer": "ipython3",
171+
"version": "3.5.1"
172+
}
173+
},
174+
"nbformat": 4,
175+
"nbformat_minor": 0
176+
}

DP/Policy Evaluation.ipynb

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {
7+
"collapsed": false
8+
},
9+
"outputs": [],
10+
"source": [
11+
"import numpy as np\n",
12+
"import pprint\n",
13+
"import sys\n",
14+
"if \"../\" not in sys.path:\n",
15+
" sys.path.append(\"../\") \n",
16+
"from lib.envs.gridworld import GridworldEnv"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": null,
22+
"metadata": {
23+
"collapsed": true
24+
},
25+
"outputs": [],
26+
"source": [
27+
"pp = pprint.PrettyPrinter(indent=2)\n",
28+
"env = GridworldEnv()"
29+
]
30+
},
31+
{
32+
"cell_type": "code",
33+
"execution_count": 13,
34+
"metadata": {
35+
"collapsed": false
36+
},
37+
"outputs": [],
38+
"source": [
39+
"def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):\n",
40+
" \"\"\"\n",
41+
" Evaluate a policy given an environment and a full description of the environment's dynamics.\n",
42+
" \n",
43+
" Args:\n",
44+
" policy: [S, A] shaped matrix representing the policy.\n",
45+
" env: OpenAI env. env.P represents the transition probabilities of the environment.\n",
46+
" env.P[s][a] is a (prob, next_state, reward, done) tuple.\n",
47+
" theta: We stop evaluation one our value function change is less than theta for all states.\n",
48+
" discount_factor: lambda discount factor.\n",
49+
" \n",
50+
" Returns:\n",
51+
" Vector of length env.nS representing the value function.\n",
52+
" \"\"\"\n",
53+
" # Start with a random (all 0) value function\n",
54+
" V = np.zeros(env.nS)\n",
55+
" while True:\n",
56+
" # TODO: Implement!\n",
57+
" break\n",
58+
" return np.array(V)"
59+
]
60+
},
61+
{
62+
"cell_type": "code",
63+
"execution_count": 16,
64+
"metadata": {
65+
"collapsed": false
66+
},
67+
"outputs": [],
68+
"source": [
69+
"random_policy = np.ones([env.nS, env.nA]) / env.nA\n",
70+
"v = policy_eval(random_policy, env)"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": 17,
76+
"metadata": {
77+
"collapsed": false
78+
},
79+
"outputs": [
80+
{
81+
"ename": "AssertionError",
82+
"evalue": "\nArrays are not almost equal to 2 decimals\n\n(mismatch 87.5%)\n x: array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0.])\n y: array([ 0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22,\n -20, -14, 0])",
83+
"output_type": "error",
84+
"traceback": [
85+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
86+
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
87+
"\u001b[0;32m<ipython-input-17-235f39fb115c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Test: Make sure the evaluated policy is what we expected\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mexpected_v\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m22\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m18\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m18\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m22\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtesting\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0massert_array_almost_equal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpected_v\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
88+
"\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/numpy/testing/utils.py\u001b[0m in \u001b[0;36massert_array_almost_equal\u001b[0;34m(x, y, decimal, err_msg, verbose)\u001b[0m\n\u001b[1;32m 914\u001b[0m assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,\n\u001b[1;32m 915\u001b[0m \u001b[0mheader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Arrays are not almost equal to %d decimals'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 916\u001b[0;31m precision=decimal)\n\u001b[0m\u001b[1;32m 917\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 918\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
89+
"\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/numpy/testing/utils.py\u001b[0m in \u001b[0;36massert_array_compare\u001b[0;34m(comparison, x, y, err_msg, verbose, header, precision)\u001b[0m\n\u001b[1;32m 735\u001b[0m names=('x', 'y'), precision=precision)\n\u001b[1;32m 736\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcond\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 737\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mAssertionError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 738\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 739\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtraceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
90+
"\u001b[0;31mAssertionError\u001b[0m: \nArrays are not almost equal to 2 decimals\n\n(mismatch 87.5%)\n x: array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0.])\n y: array([ 0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22,\n -20, -14, 0])"
91+
]
92+
}
93+
],
94+
"source": [
95+
"# Test: Make sure the evaluated policy is what we expected\n",
96+
"expected_v = np.array([0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22, -20, -14, 0])\n",
97+
"np.testing.assert_array_almost_equal(v, expected_v, decimal=2)"
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": null,
103+
"metadata": {
104+
"collapsed": true
105+
},
106+
"outputs": [],
107+
"source": []
108+
}
109+
],
110+
"metadata": {
111+
"kernelspec": {
112+
"display_name": "Python 3",
113+
"language": "python",
114+
"name": "python3"
115+
},
116+
"language_info": {
117+
"codemirror_mode": {
118+
"name": "ipython",
119+
"version": 3
120+
},
121+
"file_extension": ".py",
122+
"mimetype": "text/x-python",
123+
"name": "python",
124+
"nbconvert_exporter": "python",
125+
"pygments_lexer": "ipython3",
126+
"version": "3.5.1"
127+
}
128+
},
129+
"nbformat": 4,
130+
"nbformat_minor": 0
131+
}

0 commit comments

Comments
 (0)