
Commit b624539

Authored by jstac, claude, and mmcky
Refine optimal savings lectures with improved clarity and code organization (#754)
* Refine optimal savings lecture series with improved clarity and code organization

  - Enhance code formatting and comments for better readability across all OS lectures
  - Improve mathematical notation and explanations in stochastic optimal savings
  - Restructure function definitions in os_egm_jax for better logical flow
  - Simplify utility function and Bellman operator implementations
  - Add clearer documentation of marginal utility approximations in EGM
  - Remove redundant code and improve variable naming throughout

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

* Simplify function names and documentation in IFP transient shocks lecture

  - Rename compute_expectation_k to expected_mu for clarity and brevity
  - Streamline inline comments to improve code readability

* fix: missing ,

* Fix variable naming and type hints in stochastic optimal savings lecture

  - Rename 'grid' to 'x_grid' throughout for consistency and clarity
  - Unify parameter name from 'og' to 'model' in solve_model function
  - Add type hints and inline comments to solve_model parameters
  - Fix unpacking error in B function to match Model NamedTuple fields
  - Resolves notebook execution error in remote build

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Matt McKay <mmcky@users.noreply.github.com>
1 parent 45eef26 commit b624539

File tree

5 files changed: +176 −156 lines changed


lectures/ifp_egm_transient_shocks.md

Lines changed: 6 additions & 11 deletions
@@ -456,17 +456,13 @@ def K(
         return jnp.exp(a_y * η + z * b_y)
 
     def compute_c(i, j):
-        " Function to compute consumption for one (i, j) pair where i >= 1. "
+        " Compute c_ij when i >= 1 (interior choice). "
 
-        def compute_expectation_k(k):
-            """
-            For each k, approximate the integral
-
-                ∫ u'(σ(R s_i + y(z_k, η'), z_k)) φ(η') dη'
-            """
+        def expected_mu(k):
+            " Approximate ∫ u'(σ(R s_i + y(z_k, η'), z_k)) φ(η') dη' "
 
             def compute_mu_at_eta(η):
-                " For each η draw, compute u'(σ(R * s_i + y(z_k, η), z_k)) "
+                " Compute u'(σ(R * s_i + y(z_k, η), z_k)) "
                 next_a = R * s[i] + y(z_grid[k], η)
                 # Interpolate to get σ(R * s_i + y(z_k, η), z_k)
                 next_c = jnp.interp(next_a, a_in[:, k], c_in[:, k])
@@ -479,10 +475,9 @@ def K(
             return jnp.mean(all_draws)
 
         # Compute expectation: Σ_k [∫ u'(σ(...)) φ(η) dη] * Π[j, k]
-        expectations = jax.vmap(compute_expectation_k)(jnp.arange(n_z))
+        expectations = jax.vmap(expected_mu)(jnp.arange(n_z))
         expectation = jnp.sum(expectations * Π[j, :])
-
-        # Invert to get consumption c_{ij} at (s_i, z_j)
+        # Invert to get consumption c_ij at (s_i, z_j)
         return u_prime_inv(β * R * expectation)
 
     # Set up index grids for vmap computation of all c_{ij}
lectures/os_egm.md

Lines changed: 4 additions & 2 deletions
@@ -241,10 +241,12 @@ def K(
     # Allocate memory for new consumption array
     c_out = np.empty_like(s_grid)
 
-    # Solve for updated consumption value
     for i, s in enumerate(s_grid):
+        # Approximate marginal utility ∫ u'(σ(f(s, α)z)) f'(s, α) z ϕ(z)dz
         vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
-        c_out[i] = u_prime_inv(β * np.mean(vals))
+        mu = np.mean(vals)
+        # Compute consumption
+        c_out[i] = u_prime_inv(β * mu)
 
     # Determine corresponding endogenous grid
     x_out = s_grid + c_out  # x_i = s_i + c_i
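The updated loop first forms the sample-mean marginal utility `mu` and then inverts it. A self-contained sketch of the same EGM step under stated assumptions: log utility, Cobb-Douglas production, and a placeholder policy guess `σ` (none of the parameter values are the lecture's calibration):

```python
import numpy as np

# Assumed primitives: log utility and Cobb-Douglas f (illustrative only)
u_prime = lambda c: 1 / c
u_prime_inv = lambda x: 1 / x
f = lambda k, α: k**α
f_prime = lambda k, α: α * k**(α - 1)

β, α = 0.96, 0.4
s_grid = np.linspace(0.1, 4.0, 5)
shocks = np.exp(0.1 * np.random.default_rng(0).standard_normal(50))
σ = lambda x: 0.9 * x        # placeholder current policy guess

c_out = np.empty_like(s_grid)
for i, s in enumerate(s_grid):
    # Approximate ∫ u'(σ(f(s, α)z)) f'(s, α) z ϕ(z) dz by a sample mean
    vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
    mu = np.mean(vals)
    c_out[i] = u_prime_inv(β * mu)

# Each savings level pairs with its implied cash on hand: x_i = s_i + c_i
x_out = s_grid + c_out
```

The point of EGM is visible here: no root-finding is needed, because `u_prime_inv` recovers consumption directly from the expected marginal utility.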

lectures/os_egm_jax.md

Lines changed: 33 additions & 27 deletions
@@ -93,14 +93,16 @@ class Model(NamedTuple):
     α: float          # production function parameter
 
 
-def create_model(β: float = 0.96,
-                 μ: float = 0.0,
-                 s: float = 0.1,
-                 grid_max: float = 4.0,
-                 grid_size: int = 120,
-                 shock_size: int = 250,
-                 seed: int = 1234,
-                 α: float = 0.4) -> Model:
+def create_model(
+        β: float = 0.96,
+        μ: float = 0.0,
+        s: float = 0.1,
+        grid_max: float = 4.0,
+        grid_size: int = 120,
+        shock_size: int = 250,
+        seed: int = 1234,
+        α: float = 0.4
+    ) -> Model:
     """
     Creates an instance of the optimal savings model.
     """
@@ -114,6 +116,17 @@ def create_model(β: float = 0.96,
     return Model(β=β, μ=μ, s=s, s_grid=s_grid, shocks=shocks, α=α)
 ```
 
+
+We define utility and production functions globally.
+
+```{code-cell} python3
+# Define utility and production functions with derivatives
+u = lambda c: jnp.log(c)
+u_prime = lambda c: 1 / c
+u_prime_inv = lambda x: 1 / x
+f = lambda k, α: k**α
+f_prime = lambda k, α: α * k**(α - 1)
+```
 Here's the Coleman-Reffett operator using EGM.
 
 The key JAX feature here is `vmap`, which vectorizes the computation over the grid points.
@@ -138,10 +151,13 @@ def K(
 
     # Define function to compute consumption at a single grid point
     def compute_c(s):
+        # Approximate marginal utility ∫ u'(σ(f(s, α)z)) f'(s, α) z ϕ(z)dz
         vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
-        return u_prime_inv(β * jnp.mean(vals))
+        mu = jnp.mean(vals)
+        # Calculate consumption
+        return u_prime_inv(β * mu)
 
-    # Vectorize over grid using vmap
+    # Vectorize and calculate on all exogenous grid points
     compute_c_vectorized = jax.vmap(compute_c)
     c_out = compute_c_vectorized(s_grid)
 
@@ -151,18 +167,6 @@ def K(
     return c_out, x_out
 ```
 
-We define utility and production functions globally.
-
-Note that `f` and `f_prime` take `α` as an explicit argument, allowing them to work with JAX's functional programming model.
-
-```{code-cell} python3
-# Define utility and production functions with derivatives
-u = lambda c: jnp.log(c)
-u_prime = lambda c: 1 / c
-u_prime_inv = lambda x: 1 / x
-f = lambda k, α: k**α
-f_prime = lambda k, α: α * k**(α - 1)
-```
 
 Now we create a model instance.
 
@@ -175,11 +179,13 @@ The solver uses JAX's `jax.lax.while_loop` for the iteration and is JIT-compiled
 
 ```{code-cell} python3
 @jax.jit
-def solve_model_time_iter(model: Model,
-                          c_init: jnp.ndarray,
-                          x_init: jnp.ndarray,
-                          tol: float = 1e-5,
-                          max_iter: int = 1000):
+def solve_model_time_iter(
+        model: Model,
+        c_init: jnp.ndarray,
+        x_init: jnp.ndarray,
+        tol: float = 1e-5,
+        max_iter: int = 1000
+    ):
     """
     Solve the model using time iteration with EGM.
     """
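The reformatted `solve_model_time_iter` keeps the standard JAX solver pattern: carry the current guess, error, and counter through `jax.lax.while_loop` until the tolerance or iteration cap is hit. A minimal sketch of that control flow, with a toy contraction standing in for the lecture's Coleman-Reffett operator `K`:

```python
import jax
import jax.numpy as jnp

def toy_K(c):
    # Toy contraction standing in for the lecture's operator;
    # its fixed point is c = 2 everywhere
    return 0.5 * c + 1.0

@jax.jit
def solve_fixed_point(c_init, tol=1e-5, max_iter=1000):
    def cond(state):
        _, error, i = state
        return (error > tol) & (i < max_iter)

    def body(state):
        c, _, i = state
        c_new = toy_K(c)
        return c_new, jnp.max(jnp.abs(c_new - c)), i + 1

    # Carry must keep consistent dtypes across iterations
    init = (c_init, jnp.array(jnp.inf), jnp.array(0))
    c, error, i = jax.lax.while_loop(cond, body, init)
    return c, i

c_star, n_iter = solve_fixed_point(jnp.zeros(5))
```

Unlike a Python `while` loop, `jax.lax.while_loop` stays inside the compiled computation, so the whole solver can be wrapped in `@jax.jit` as in the lecture.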

lectures/os_numerical.md

Lines changed: 15 additions & 16 deletions
@@ -92,14 +92,14 @@ This is a form of **successive approximation**, and was discussed in our {doc}`l
 The basic idea is:
 
 1. Take an arbitrary initial guess of $v$.
-1. Obtain an update $w$ defined by
+1. Obtain an update $\hat v$ defined by
 
    $$
-   w(x) = \max_{0\leq c \leq x} \{u(c) + \beta v(x-c)\}
+   \hat v(x) = \max_{0\leq c \leq x} \{u(c) + \beta v(x-c)\}
    $$
 
-1. Stop if $w$ is approximately equal to $v$, otherwise set
-   $v=w$ and go back to step 2.
+1. Stop if $\hat v$ is approximately equal to $v$, otherwise set
+   $v=\hat v$ and go back to step 2.
 
 Let's write this a bit more mathematically.
 
@@ -109,7 +109,7 @@ We introduce the **Bellman operator** $T$ that takes a function v as an
 argument and returns a new function $Tv$ defined by
 
 $$
-Tv(x) = \max_{0 \leq c \leq x} \{u(c) + \beta v(x - c)\}
+Tv(x) = \max_{0 \leq c \leq x} \{u(c) + \beta v(x - c)\}
 $$
 
 From $v$ we get $Tv$, and applying $T$ to this yields
@@ -206,13 +206,7 @@ Here's the CRRA utility function.
 
 ```{code-cell} python3
 def u(c, γ):
-    """
-    Utility function.
-    """
-    if γ == 1:
-        return np.log(c)
-    else:
-        return (c ** (1 - γ)) / (1 - γ)
+    return (c ** (1 - γ)) / (1 - γ)
 ```
 
 To work with the Bellman equation, let's write it as
@@ -240,8 +234,8 @@ def B(
     Right hand side of the Bellman equation given x and c.
 
     """
-    # Unpack
-    β, γ, x_grid = model.β, model.γ, model.x_grid
+    # Unpack (simplify names)
+    β, γ, x_grid = model
 
     # Convert array v into a function by linear interpolation
     vf = lambda x: np.interp(x, x_grid, v)
@@ -250,7 +244,12 @@ def B(
     return u(c, γ) + β * vf(x - c)
 ```
 
-We now define the Bellman operation:
+We now define the Bellman operator acting on grid points:
+
+$$
+Tv(x_i) = \max_{0 \leq c \leq x_i} B(x_i, c, v)
+\qquad \text{for all } i
+$$
 
 ```{code-cell} python3
 def T(
@@ -280,7 +279,7 @@ model = create_cake_eating_model()
 β, γ, x_grid = model
 ```
 
-Now let's see the iteration of the value function in action.
+Now let's see iteration of the value function in action.
 
 We start from guess $v$ given by $v(x) = u(x)$ for every
 $x$ grid point.
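The display added in this hunk, $Tv(x_i) = \max_{0 \leq c \leq x_i} B(x_i, c, v)$, translates directly into a loop over grid points. A minimal sketch, maximizing over a discrete consumption grid rather than whatever continuous optimizer the lecture's `T` uses, with illustrative β, γ, and grid values:

```python
import numpy as np

# Illustrative parameters (not the lecture's calibration)
β, γ = 0.96, 1.5
x_grid = np.linspace(1e-3, 2.5, 120)

def u(c, γ):
    return (c ** (1 - γ)) / (1 - γ)

def B(x, c, v):
    # Right-hand side of the Bellman equation, with v interpolated on x_grid
    vf = lambda y: np.interp(y, x_grid, v)
    return u(c, γ) + β * vf(x - c)

def T(v):
    # Apply Tv(x_i) = max over feasible c of B(x_i, c, v) at each grid point
    v_new = np.empty_like(v)
    for i, x in enumerate(x_grid):
        c_choices = np.linspace(1e-10, x, 100)
        v_new[i] = np.max(B(x, c_choices, v))
    return v_new

# Successive approximation from the guess v(x) = u(x)
v = u(x_grid, γ)
for _ in range(25):
    v = T(v)
```

Because $T$ is a contraction of modulus $\beta$, these iterates converge to the value function regardless of the initial guess.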
