5 changes: 4 additions & 1 deletion .gitignore
@@ -1,2 +1,5 @@
 _build/
-.DS_Store
+.DS_Store
+__pycache__/
+*.pyc
+*.pyo
2 changes: 2 additions & 0 deletions environment.yml
@@ -18,4 +18,6 @@ dependencies:
   - sphinx-togglebutton==0.3.2
   # Docker Requirements
   - pytz
+  # JAX for lecture content
+  - jax
Comment on lines +21 to +22

@HumphreyYang I may remove this later on. I saw you added jax in the lecture itself.
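If jax does end up being installed from within the lecture itself rather than via environment.yml, the usual alternative is an install cell at the top of the lecture source; a minimal sketch of that option:

!pip install jax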
371 changes: 371 additions & 0 deletions lectures/_static/lecture_specific/amss/jax_amss.py
@@ -0,0 +1,371 @@
"""
JAX-based AMSS model implementation.
Converted from NumPy/Numba classes to JAX pure functions and NamedTuple structures.
"""

import jax.numpy as jnp
from jax import jit, grad, vmap
import jax
from scipy.optimize import minimize # Use scipy for now
from typing import NamedTuple, Callable
try:
from .jax_utilities import UtilityFunctions
from .jax_interpolation import nodes_from_grid, eval_linear_jax
except ImportError:
from jax_utilities import UtilityFunctions
from jax_interpolation import nodes_from_grid, eval_linear_jax


class AMSSState(NamedTuple):
"""State variables for AMSS model."""
s: int # Current Markov state
x: float # Continuation value state variable


class AMSSParams(NamedTuple):
"""Parameters for AMSS model."""
β: float # Discount factor
Π: jnp.ndarray # Markov transition matrix
g: jnp.ndarray # Government spending by state
x_grid: tuple # Grid parameters (x_min, x_max, x_num)
bounds_v: jnp.ndarray # Bounds for optimization
utility: UtilityFunctions # Utility functions


class AMSSPolicies(NamedTuple):
"""Policy functions for AMSS model."""
V: jnp.ndarray # Value function
σ_v_star: jnp.ndarray # Policy function for time t >= 1
W: jnp.ndarray # Value function for time 0
σ_w_star: jnp.ndarray # Policy function for time 0


@jit
def compute_consumption_leisure(l, g):
"""Compute consumption given leisure and government spending."""
return (1 - l) - g


@jit
def objective_V(σ, state, V, params: AMSSParams):
    """
    Objective function for time t >= 1 value function iteration.

    Parameters
    ----------
    σ : array
        Policy variables [l_1, ..., l_S, T_1, ..., T_S]
    state : tuple
        Current state (s_, x_)
    V : array
        Current value function
    params : AMSSParams
        Model parameters

    Returns
    -------
    float
        Negative of expected value (for minimization)
    """
    s_, x_ = state
    S = len(params.Π)

    l = σ[:S]
    T = σ[S:]

    c = compute_consumption_leisure(l, params.g)
    u_c = vmap(params.utility.Uc)(c, l)
    Eu_c = params.Π[s_] @ u_c

    x = (u_c * x_ / (params.β * Eu_c) -
         u_c * (c - T) +
         vmap(params.utility.Ul)(c, l) * (1 - l))

    # Interpolate next period value function
    x_nodes = nodes_from_grid(params.x_grid)
    V_next = jnp.array([eval_linear_jax(params.x_grid, V[s], jnp.array([x[s]]))[0]
                        for s in range(S)])

    expected_value = params.Π[s_] @ (vmap(params.utility.U)(c, l) + params.β * V_next)

    return -expected_value  # Negative for minimization


@jit
def objective_W(σ, state, V, params: AMSSParams):
    """
    Objective function for time 0 problem.

    Parameters
    ----------
    σ : array
        Policy variables [l, T]
    state : tuple
        Current state (s_, b_0)
    V : array
        Value function
    params : AMSSParams
        Model parameters

    Returns
    -------
    float
        Negative of value (for minimization)
    """
    s_, b_0 = state
    l, T = σ

    c = compute_consumption_leisure(l, params.g[s_])
    x = (-params.utility.Uc(c, l) * (c - T - b_0) +
         params.utility.Ul(c, l) * (1 - l))

    V_next = eval_linear_jax(params.x_grid, V[s_], jnp.array([x]))[0]
    value = params.utility.U(c, l) + params.β * V_next

    return -value  # Negative for minimization


def solve_bellman_iteration(V, σ_v_star, params: AMSSParams,
                            tol=1e-7, max_iter=1000, print_freq=10):
    """
    Solve the Bellman equation using value function iteration.

    Parameters
    ----------
    V : array
        Initial value function guess
    σ_v_star : array
        Initial policy function guess
    params : AMSSParams
        Model parameters
    tol : float
        Convergence tolerance
    max_iter : int
        Maximum iterations
    print_freq : int
        Print frequency

    Returns
    -------
    tuple
        Updated (V, σ_v_star)
    """
    S = len(params.Π)
    x_nodes = nodes_from_grid(params.x_grid)
    n_x = len(x_nodes)

    V_new = jnp.zeros_like(V)

    for iteration in range(max_iter):
        V_updated = jnp.zeros_like(V)
        σ_updated = jnp.zeros_like(σ_v_star)

        # Loop over states and grid points
        for s_ in range(S):
            for x_i in range(n_x):
                state = (s_, x_nodes[x_i])
                x0 = σ_v_star[s_, x_i]

                # Bounds for each policy variable
                bounds = [(params.bounds_v[i, 0], params.bounds_v[i, 1])
                          for i in range(len(params.bounds_v))]

                # Minimize the JAX-compiled objective with scipy's L-BFGS-B
                result = minimize(
                    lambda σ: objective_V(σ, state, V, params),
                    x0,
                    method='L-BFGS-B',
                    bounds=bounds
                )

                if result.success:
                    V_updated = V_updated.at[s_, x_i].set(-result.fun)
                    σ_updated = σ_updated.at[s_, x_i].set(result.x)
                else:
                    print(f"Optimization failed at state {s_}, grid point {x_i}")
                    V_updated = V_updated.at[s_, x_i].set(V[s_, x_i])
                    σ_updated = σ_updated.at[s_, x_i].set(σ_v_star[s_, x_i])

        # Check convergence
        error = jnp.max(jnp.abs(V_updated - V))

        if error < tol:
            print(f'Successfully completed VFI after {iteration + 1} iterations')
            return V_updated, σ_updated

        if (iteration + 1) % print_freq == 0:
            print(f'Error at iteration {iteration + 1}: {error}')

        V = V_updated
        σ_v_star = σ_updated

    print(f'VFI did not converge after {max_iter} iterations')
    return V, σ_v_star
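
# Note on shapes (inferred from the code above, not stated explicitly in the PR):
# with S Markov states and n_x grid points, V / V_init are (S, n_x) arrays and
# σ_v_star / σ_v_init are (S, n_x, 2 * S) arrays, stacking S leisure choices
# followed by S transfers at each grid point (the σ[:S] / σ[S:] split in objective_V).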


def solve_time_zero_problem(b_0, V, params: AMSSParams):
    """
    Solve the time 0 problem.

    Parameters
    ----------
    b_0 : float
        Initial debt
    V : array
        Value function from time 1 problem
    params : AMSSParams
        Model parameters

    Returns
    -------
    tuple
        (W, σ_w_star) where W is time 0 values and σ_w_star is time 0 policies
    """
    S = len(params.Π)
    W = jnp.zeros(S)
    σ_w_star = jnp.zeros((S, 2))

    bounds_w = [(-9.0, 1.0), (0.0, 10.0)]

    for s_ in range(S):
        state = (s_, b_0)
        x0 = jnp.array([-0.05, 0.5])  # Initial guess

        result = minimize(
            lambda σ: objective_W(σ, state, V, params),
            x0,
            method='L-BFGS-B',
            bounds=bounds_w
        )

        W = W.at[s_].set(-result.fun)
        σ_w_star = σ_w_star.at[s_].set(result.x)

    print('Successfully solved the time 0 problem.')
    return W, σ_w_star


@jit
def simulate_amss(s_hist, b_0, policies: AMSSPolicies, params: AMSSParams):
    """
    Simulate AMSS model given state history and initial debt.

    Parameters
    ----------
    s_hist : array
        History of Markov states
    b_0 : float
        Initial debt level
    policies : AMSSPolicies
        Solved policy functions
    params : AMSSParams
        Model parameters

    Returns
    -------
    dict
        Simulation results with arrays for c, n, b, τ, g
    """
    T = len(s_hist)
    S = len(params.Π)
    x_nodes = nodes_from_grid(params.x_grid)

    # Pre-allocate arrays
    n_hist = jnp.zeros(T)
    x_hist = jnp.zeros(T)
    c_hist = jnp.zeros(T)
    τ_hist = jnp.zeros(T)
    b_hist = jnp.zeros(T)
    g_hist = jnp.zeros(T)

    # Time 0
    s_0 = s_hist[0]
    l_0, T_0 = policies.σ_w_star[s_0]
    c_0 = compute_consumption_leisure(l_0, params.g[s_0])
    x_0 = (-params.utility.Uc(c_0, l_0) * (c_0 - T_0 - b_0) +
           params.utility.Ul(c_0, l_0) * (1 - l_0))

    n_hist = n_hist.at[0].set(1 - l_0)
    x_hist = x_hist.at[0].set(x_0)
    c_hist = c_hist.at[0].set(c_0)
    τ_hist = τ_hist.at[0].set(1 - params.utility.Ul(c_0, l_0) / params.utility.Uc(c_0, l_0))
    b_hist = b_hist.at[0].set(b_0)
    g_hist = g_hist.at[0].set(params.g[s_0])

    # Time t > 0
    for t in range(T - 1):
        x_ = x_hist[t]
        s_ = s_hist[t]

        # Interpolate policies for all states
        l = jnp.array([eval_linear_jax(params.x_grid, policies.σ_v_star[s_, :, s],
                                       jnp.array([x_]))[0] for s in range(S)])
        T_vals = jnp.array([eval_linear_jax(params.x_grid, policies.σ_v_star[s_, :, S+s],
                                            jnp.array([x_]))[0] for s in range(S)])

        c = compute_consumption_leisure(l, params.g)
        u_c = vmap(params.utility.Uc)(c, l)
        Eu_c = params.Π[s_] @ u_c

        x = (u_c * x_ / (params.β * Eu_c) -
             u_c * (c - T_vals) +
             vmap(params.utility.Ul)(c, l) * (1 - l))

        s_next = s_hist[t+1]
        c_next = c[s_next]
        l_next = l[s_next]

        x_hist = x_hist.at[t+1].set(x[s_next])
        n_hist = n_hist.at[t+1].set(1 - l_next)
        c_hist = c_hist.at[t+1].set(c_next)
        τ_hist = τ_hist.at[t+1].set(1 - params.utility.Ul(c_next, l_next) / params.utility.Uc(c_next, l_next))
        b_hist = b_hist.at[t+1].set(x_ / (params.β * Eu_c))
        g_hist = g_hist.at[t+1].set(params.g[s_next])

    return {
        'c': c_hist,
        'n': n_hist,
        'b': b_hist,
        'τ': τ_hist,
        'g': g_hist
    }


def solve_amss_model(params: AMSSParams, V_init, σ_v_init, b_0,
                     W_init=None, σ_w_init=None, **kwargs):
    """
    Solve the complete AMSS model.

    Parameters
    ----------
    params : AMSSParams
        Model parameters
    V_init : array
        Initial value function guess
    σ_v_init : array
        Initial policy function guess
    b_0 : float
        Initial debt level
    W_init : array, optional
        Initial time 0 value function
    σ_w_init : array, optional
        Initial time 0 policy function
    **kwargs
        Additional arguments for solver

    Returns
    -------
    AMSSPolicies
        Solved policy functions
    """
    print("===============")
    print("Solve time 1 problem")
    print("===============")
    V, σ_v_star = solve_bellman_iteration(V_init, σ_v_init, params, **kwargs)

    print("===============")
    print("Solve time 0 problem")
    print("===============")
    W, σ_w_star = solve_time_zero_problem(b_0, V, params)

    return AMSSPolicies(V=V, σ_v_star=σ_v_star, W=W, σ_w_star=σ_w_star)
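
For reviewers, a minimal usage sketch of the new module follows. Everything in it is illustrative: the parameter values, bounds, and initial guesses are made up, the UtilityFunctions constructor call assumes a NamedTuple with U, Uc and Ul fields (only those attribute names are used in jax_amss.py), and the companion jax_utilities / jax_interpolation modules are assumed to be on the import path.

import jax.numpy as jnp

from jax_amss import AMSSParams, solve_amss_model, simulate_amss
from jax_utilities import UtilityFunctions

S = 2                          # number of Markov states
β = 0.9                        # discount factor
Π = jnp.full((S, S), 0.5)      # symmetric transition matrix
g = jnp.array([0.1, 0.2])      # government spending by state
x_grid = (-1.5, 1.5, 100)      # (x_min, x_max, x_num)
n_x = x_grid[2]

# Hypothetical log utility; only the U, Uc, Ul fields are referenced in jax_amss.py
utility = UtilityFunctions(
    U=lambda c, l: jnp.log(c) + jnp.log(l),
    Uc=lambda c, l: 1.0 / c,
    Ul=lambda c, l: 1.0 / l,
)

# Bounds keep c = 1 - l - g and l strictly positive under log utility;
# σ stacks S leisure choices followed by S transfers (see objective_V)
bounds_v = jnp.array([[0.1, 0.7]] * S + [[0.0, 1.0]] * S)
params = AMSSParams(β=β, Π=Π, g=g, x_grid=x_grid,
                    bounds_v=bounds_v, utility=utility)

V_init = jnp.zeros((S, n_x))               # value function guess, shape (S, n_x)
σ_v_init = jnp.full((S, n_x, 2 * S), 0.4)  # policy guess, shape (S, n_x, 2S)

policies = solve_amss_model(params, V_init, σ_v_init, b_0=0.5)
sim = simulate_amss(jnp.array([0, 1, 1, 0, 1]), 0.5, policies, params)
print(sim['τ'])                            # simulated tax rates

Because the inner optimization runs through scipy.optimize.minimize, the solve loops execute eagerly in Python; only the objective functions and the simulator are decorated with @jit.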