
Commit 222e7c0

jstac and claude committed
Fix variable naming and type hints in stochastic optimal savings lecture
- Rename 'grid' to 'x_grid' throughout for consistency and clarity
- Unify parameter name from 'og' to 'model' in solve_model function
- Add type hints and inline comments to solve_model parameters
- Fix unpacking error in B function to match Model NamedTuple fields
- Resolves notebook execution error in remote build

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 99aa1e2 commit 222e7c0


lectures/os_stochastic.md

Lines changed: 33 additions & 32 deletions
@@ -484,7 +484,7 @@ class Model(NamedTuple):
     β: float              # discount factor
     μ: float              # shock location parameter
     ν: float              # shock scale parameter
-    grid: np.ndarray      # state grid
+    x_grid: np.ndarray    # state grid
     shocks: np.ndarray    # shock draws

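The "unpacking error" named in the commit message follows from how NamedTuple iteration works: unpacking a NamedTuple yields every field, so the number of assignment targets must equal the number of fields. A minimal sketch, with placeholder argument values and field names matching the diff above:

```python
from typing import Callable, NamedTuple
import numpy as np

# Minimal sketch, not the lecture's full Model: iterating a NamedTuple
# yields all of its fields, so partial unpacking raises ValueError.
class Model(NamedTuple):
    u: Callable            # utility function
    f: Callable            # production function
    β: float               # discount factor
    μ: float               # shock location parameter
    ν: float               # shock scale parameter
    x_grid: np.ndarray     # state grid
    shocks: np.ndarray     # shock draws

model = Model(np.log, lambda k: k**0.4, 0.96, 0.0, 0.1,
              np.linspace(1e-4, 4.0, 10), np.ones(5))  # illustrative values

# u, f, β, shocks = model                # ValueError: too many values to unpack
u, f, β, μ, ν, x_grid, shocks = model    # works: one target per field
```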
@@ -503,13 +503,13 @@ def create_model(
     Creates an instance of the optimal savings model.
     """
     # Set up grid
-    grid = np.linspace(1e-4, grid_max, grid_size)
+    x_grid = np.linspace(1e-4, grid_max, grid_size)

     # Store shocks (with a seed, so results are reproducible)
     np.random.seed(seed)
     shocks = np.exp(μ + ν * np.random.randn(shock_size))

-    return Model(u, f, β, μ, ν, grid, shocks)
+    return Model(u, f, β, μ, ν, x_grid, shocks)
 ```

 We set up the right-hand side of the Bellman equation
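An aside on the shock draws, untouched by this commit: create_model seeds NumPy's legacy global RNG. A sketch of the currently recommended local-Generator pattern, with illustrative seed and parameter values, would be:

```python
import numpy as np

# Aside, not part of the commit: a local Generator is reproducible
# without mutating global RNG state.
rng = np.random.default_rng(seed=1234)                   # seed illustrative
shocks = np.exp(0.0 + 0.1 * rng.standard_normal(250))    # μ=0, ν=0.1 assumed
```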
@@ -524,14 +524,13 @@ def B(
     x: float,
     c: float,
     v_array: np.ndarray,
-    model: Model,
+    model: Model
 ) -> float:
     """
     Right hand side of the Bellman equation.
     """
-    u, f, β, shocks = model
-    grid = model.grid
-    v = interp1d(grid, v_array)
+    u, f, β, μ, ν, x_grid, shocks = model
+    v = interp1d(x_grid, v_array)

     return u(c) + β * np.mean(v(f(x - c) * shocks))
 ```
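The fix keeps full unpacking, so B must list every Model field. An alternative that would not have broken when fields were added is attribute access; a hypothetical variant, assuming the same interp1d interpolation the lecture uses:

```python
import numpy as np
from scipy.interpolate import interp1d

def B_alt(x, c, v_array, model):
    # Hypothetical rewrite of B using attribute access instead of full
    # unpacking; adding fields to Model would not break this version.
    v = interp1d(model.x_grid, v_array)
    return model.u(c) + model.β * np.mean(v(model.f(x - c) * model.shocks))
```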
@@ -565,11 +564,11 @@ def T(v: np.ndarray, model: Model) -> tuple[np.ndarray, np.ndarray]:
     * v is an array representing a guess of the value function

     """
-    grid = model.grid
+    x_grid = model.x_grid
     v_new = np.empty_like(v)

-    for i in range(len(grid)):
-        x = grid[i]
+    for i in range(len(x_grid)):
+        x = x_grid[i]
         c_star, v_max = maximize(lambda c: B(x, c, v, model), x)
         v_new[i] = v_max

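T leans on a maximize helper that is defined elsewhere in the lecture and does not appear in this diff. A plausible stand-in consistent with the call `maximize(lambda c: B(x, c, v, model), x)`, returning the maximizer and the maximum:

```python
from scipy.optimize import minimize_scalar

def maximize(g, upper):
    # Sketch of a maximize helper (the lecture's own version is not in
    # this diff): maximize g on (0, upper) by minimizing -g with a
    # bounded scalar solver, returning (argmax, max value).
    result = minimize_scalar(lambda c: -g(c), bounds=(1e-10, upper),
                             method='bounded')
    return result.x, -result.fun
```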
@@ -581,7 +580,7 @@ Here's the function:
 ```{code-cell} python3
 def get_greedy(
     v: np.ndarray,   # current guess of the value function
-    model: Model     # instance of cake eating model
+    model: Model     # instance of optimal savings model
 ):
     " Compute the v-greedy policy on x_grid."

@@ -669,15 +668,15 @@ In theory, since $v^*$ is a fixed point, the resulting function should again be
 In practice, we expect some small numerical error.

 ```{code-cell} python3
-grid = model.grid
+x_grid = model.x_grid

-v_init = v_star(grid, α, model.β, model.μ)    # Start at the solution
+v_init = v_star(x_grid, α, model.β, model.μ)  # Start at the solution
 v = T(v_init, model)                          # Apply T once

 fig, ax = plt.subplots()
 ax.set_ylim(-35, -24)
-ax.plot(grid, v, lw=2, alpha=0.6, label='$Tv^*$')
-ax.plot(grid, v_init, lw=2, alpha=0.6, label='$v^*$')
+ax.plot(x_grid, v, lw=2, alpha=0.6, label='$Tv^*$')
+ax.plot(x_grid, v_init, lw=2, alpha=0.6, label='$v^*$')
 ax.legend()
 plt.show()
 ```
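The "small numerical error" the lecture mentions can be made concrete. A short follow-up, not part of the diff, reusing the v and v_init arrays from the cell above:

```python
# Measure the deviation of Tv* from v* in the sup norm, the metric
# under which the Bellman operator is a contraction.
error = np.max(np.abs(v - v_init))
print(f"sup-norm distance between Tv* and v*: {error:.2e}")
```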
@@ -690,23 +689,23 @@ from an arbitrary initial condition.
 The initial condition we'll start with is, somewhat arbitrarily, $v(x) = 5 \ln (x)$.

 ```{code-cell} python3
-v = 5 * np.log(grid)  # An initial condition
+v = 5 * np.log(x_grid)  # An initial condition
 n = 35

 fig, ax = plt.subplots()

-ax.plot(grid, v, color=plt.cm.jet(0),
+ax.plot(x_grid, v, color=plt.cm.jet(0),
         lw=2, alpha=0.6, label='Initial condition')

 for i in range(n):
     v = T(v, model)  # Apply the Bellman operator
-    ax.plot(grid, v, color=plt.cm.jet(i / n), lw=2, alpha=0.6)
+    ax.plot(x_grid, v, color=plt.cm.jet(i / n), lw=2, alpha=0.6)

-ax.plot(grid, v_star(grid, α, model.β, model.μ), 'k-', lw=2,
+ax.plot(x_grid, v_star(x_grid, α, model.β, model.μ), 'k-', lw=2,
         alpha=0.8, label='True value function')

 ax.legend()
-ax.set(ylim=(-40, 10), xlim=(np.min(grid), np.max(grid)))
+ax.set(ylim=(-40, 10), xlim=(np.min(x_grid), np.max(x_grid)))
 plt.show()
 ```

@@ -725,23 +724,25 @@ We can write a function that iterates until the difference is below a particular
 tolerance level.

 ```{code-cell} python3
-def solve_model(og,
-                tol=1e-4,
-                max_iter=1000,
-                verbose=True,
-                print_skip=25):
+def solve_model(
+    model: Model,           # instance of optimal savings model
+    tol: float = 1e-4,      # convergence tolerance
+    max_iter: int = 1000,   # maximum iterations
+    verbose: bool = True,   # print iteration info
+    print_skip: int = 25    # iterations between prints
+):
     """
     Solve model by iterating with the Bellman operator.

     """

     # Set up loop
-    v = og.u(og.grid)  # Initial condition
+    v = model.u(model.x_grid)  # Initial condition
     i = 0
     error = tol + 1

     while i < max_iter and error > tol:
-        v_new = T(v, og)
+        v_new = T(v, model)
         error = np.max(np.abs(v - v_new))
         i += 1
         if verbose and i % print_skip == 0:
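With the parameter renamed from og to model, a hypothetical call site (this diff does not show one) would look as follows, assuming create_model has usable defaults and solve_model returns the converged value function array:

```python
# Hypothetical usage of the renamed API: solve the model, then recover
# the v-greedy policy from the converged value function.
model = create_model()                     # default arguments assumed
v_solution = solve_model(model)
v_greedy = get_greedy(v_solution, model)
```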
@@ -768,10 +769,10 @@ Now we check our result by plotting it against the true value:
 ```{code-cell} python3
 fig, ax = plt.subplots()

-ax.plot(grid, v_solution, lw=2, alpha=0.6,
+ax.plot(x_grid, v_solution, lw=2, alpha=0.6,
         label='Approximate value function')

-ax.plot(grid, v_star(grid, α, model.β, model.μ), lw=2,
+ax.plot(x_grid, v_star(x_grid, α, model.β, model.μ), lw=2,
         alpha=0.6, label='True value function')

 ax.legend()
@@ -794,10 +795,10 @@ above, is $\sigma(x) = (1 - \alpha \beta) x$
 ```{code-cell} python3
 fig, ax = plt.subplots()

-ax.plot(grid, v_greedy, lw=2,
+ax.plot(x_grid, v_greedy, lw=2,
         alpha=0.6, label='approximate policy function')

-ax.plot(grid, σ_star(grid, α, model.β), '--',
+ax.plot(x_grid, σ_star(x_grid, α, model.β), '--',
         lw=2, alpha=0.6, label='true policy function')

 ax.legend()
@@ -854,7 +855,7 @@ Let's plot the policy function just to see what it looks like:

 ```{code-cell} python3
 fig, ax = plt.subplots()
-ax.plot(grid, v_greedy, lw=2,
+ax.plot(x_grid, v_greedy, lw=2,
         alpha=0.6, label='Approximate optimal policy')
 ax.legend()
 plt.show()
