Skip to content

Commit 6232e18

Browse files
committed
update
1 parent 62daf27 commit 6232e18

File tree

1 file changed

+65
-62
lines changed

1 file changed

+65
-62
lines changed

lectures/career.md

Lines changed: 65 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ jupytext:
44
extension: .md
55
format_name: myst
66
format_version: 0.13
7-
jupytext_version: 1.17.2
7+
jupytext_version: 1.17.1
88
kernelspec:
99
name: python3
1010
display_name: Python 3 (ipykernel)
@@ -37,18 +37,6 @@ In addition to what's in Anaconda, this lecture will need the following librarie
3737
!pip install quantecon
3838
```
3939

40-
```{admonition} GPU acceleration
41-
:class: warning
42-
43-
This lecture uses JAX for hardware acceleration and automatic differentiation.
44-
45-
For faster execution, consider running this lecture on a GPU.
46-
47-
You can access free GPUs on [Google Colab](https://colab.research.google.com/) by selecting "Runtime → Change runtime type → Hardware accelerator → GPU" from the menu.
48-
49-
To install JAX with GPU support locally, please consult the [JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html).
50-
```
51-
5240
## Overview
5341

5442
Next, we study a computational problem concerning career and job choices.
@@ -63,7 +51,7 @@ We begin with some imports:
6351
import matplotlib.pyplot as plt
6452
import jax.numpy as jnp
6553
import jax
66-
import quantecon as qe
54+
import jax.random as jr
6755
from typing import NamedTuple
6856
from quantecon.distributions import BetaBinomial
6957
from scipy.special import binom, beta
@@ -391,7 +379,7 @@ In particular, modulo randomness, reproduce the following figure (where the hori
391379

392380
```{hint}
393381
:class: dropdown
394-
To generate the draws from the distributions $F$ and $G$, use `quantecon.random.draw()`.
382+
To generate the draws from the distributions $F$ and $G$, use `quantecon.jr.draw()`.
395383
```
396384

397385
```{exercise-end}
@@ -410,41 +398,51 @@ $(\theta_i, \epsilon_j)$ = either 1, 2 or 3; meaning 'stay put',
410398

411399
```{code-cell} ipython3
412400
model = create_career_worker_problem()
413-
F = np.array(jnp.cumsum(model.F_probs))
414-
G = np.array(jnp.cumsum(model.G_probs))
401+
F = jnp.cumsum(jnp.asarray(model.F_probs))
402+
G = jnp.cumsum(jnp.asarray(model.G_probs))
415403
v_star = solve_model(model)
416-
greedy_star = get_greedy_policy(model, v_star)
417-
418-
def gen_path(optimal_policy, F, G, model, t=20):
419-
i = j = 0
420-
θ_index = []
421-
ε_index = []
422-
for t in range(t):
423-
if optimal_policy[i, j] == 1: # Stay put
424-
pass
425-
426-
elif optimal_policy[i, j] == 2: # New job
427-
j = qe.random.draw(G)
428-
429-
else: # New life
430-
i, j = qe.random.draw(F), qe.random.draw(G)
431-
θ_index.append(i)
432-
ε_index.append(j)
433-
434-
# Convert lists to JAX arrays for indexing
435-
θ_indices = jnp.array(θ_index)
436-
ε_indices = jnp.array(ε_index)
437-
return model.θ[θ_indices], model.ε[ε_indices]
404+
greedy_star = jnp.asarray(get_greedy_policy(model, v_star))
438405
406+
def draw_from_cdf(key, cdf):
407+
u = jr.uniform(key)
408+
return jnp.searchsorted(cdf, u, side="left")
439409
410+
def gen_path(optimal_policy, F, G, model, t=20, key=None):
411+
if key is None:
412+
key = jr.PRNGKey(0)
413+
i = 0
414+
j = 0
415+
theta_idx = []
416+
eps_idx = []
417+
for _ in range(t):
418+
a = optimal_policy[i, j]
419+
key, k1, k2 = jr.split(key, 3)
420+
if a == 1: # Stay put
421+
pass
422+
elif a == 2: # New job
423+
j = draw_from_cdf(k1, G)
424+
else: # New life
425+
i = draw_from_cdf(k1, F)
426+
j = draw_from_cdf(k2, G)
427+
theta_idx.append(i)
428+
eps_idx.append(j)
429+
430+
theta_idx = jnp.array(theta_idx, dtype=jnp.int32)
431+
eps_idx = jnp.array(eps_idx, dtype=jnp.int32)
432+
return model.θ[theta_idx], model.ε[eps_idx], key
433+
434+
key = jr.PRNGKey(42)
440435
fig, axes = plt.subplots(2, 1, figsize=(10, 8))
436+
441437
for ax in axes:
442-
θ_path, ε_path = gen_path(greedy_star, F, G, model)
438+
key, subkey = jr.split(key)
439+
θ_path, ε_path, _ = gen_path(greedy_star, F, G, model, key=subkey)
443440
ax.plot(ε_path, label='ε')
444441
ax.plot(θ_path, label='θ')
445442
ax.set_ylim(0, 6)
443+
ax.legend(loc='upper right')
446444
447-
plt.legend()
445+
plt.tight_layout()
448446
plt.show()
449447
```
450448

@@ -486,28 +484,33 @@ The median for the original parameterization can be computed as follows
486484

487485
```{code-cell} ipython3
488486
model = create_career_worker_problem()
489-
F = np.array(jnp.cumsum(model.F_probs))
490-
G = np.array(jnp.cumsum(model.G_probs))
487+
F = jnp.cumsum(jnp.asarray(model.F_probs))
488+
G = jnp.cumsum(jnp.asarray(model.G_probs))
491489
v_star = solve_model(model)
492-
greedy_star = get_greedy_policy(model, v_star)
493-
494-
def passage_time(optimal_policy, F, G):
495-
t = 0
496-
i = j = 0
497-
while True:
498-
if optimal_policy[i, j] == 1: # Stay put
499-
return t
500-
elif optimal_policy[i, j] == 2: # New job
501-
j = qe.random.draw(G)
502-
else: # New life
503-
i, j = qe.random.draw(F), qe.random.draw(G)
504-
t += 1
505-
506-
def median_time(optimal_policy, F, G, M=25000):
507-
samples = []
508-
for i in range(M):
509-
samples.append(passage_time(optimal_policy, F, G))
510-
return jnp.median(jnp.array(samples))
490+
greedy_star = jnp.asarray(get_greedy_policy(model, v_star))
491+
492+
def passage_time(optimal_policy, F, G, key):
493+
def cond(state):
494+
i, j, t, key = state
495+
return optimal_policy[i, j] != 1
496+
497+
def body(state):
498+
i, j, t, key = state
499+
a = optimal_policy[i, j]
500+
key, k1, k2 = jr.split(key, 3)
501+
new_j = draw_from_cdf(k1, G)
502+
new_i = draw_from_cdf(k2, F)
503+
i = jnp.where(a == 3, new_i, i)
504+
j = jnp.where((a == 2) | (a == 3), new_j, j)
505+
return i, j, t + 1, key
506+
507+
i, j, t, _ = jax.lax.while_loop(cond, body, (0, 0, 0, key))
508+
return t
509+
510+
def median_time(optimal_policy, F, G, M=25000, seed=0):
511+
keys = jr.split(jr.PRNGKey(seed), M)
512+
times = jax.vmap(lambda k: passage_time(optimal_policy, F, G, k))(keys)
513+
return jnp.median(times)
511514
512515
median_time(greedy_star, F, G)
513516
```

0 commit comments

Comments
 (0)