-
-
Notifications
You must be signed in to change notification settings - Fork 24
Migrate AMSS lecture from NumPy/Numba to JAX with automatic differentiation #244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 6 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
25ca89d
Initial plan
Copilot d5e578f
Create JAX-based utility functions and interpolation utilities for AM…
Copilot 58fb11c
Complete JAX migration of AMSS lecture with working examples
Copilot 9f0ac03
Fix function naming and test complete JAX AMSS workflow
Copilot de35bb0
add jax install
HumphreyYang 02dc2cc
Fix JAX dependency in environment.yml and missing AMSSParams import
Copilot d84ff87
Fix import issue in jax_amss_simple.py for Jupyter Book execution con…
Copilot 524fb2c
Fix NameError: government_spending not defined in AMSS lecture
Copilot 1ddcd4e
Fix jax_amss_simple.py execution issue in Jupyter Book context
Copilot 2343d07
Fix undefined variables in AMSS lecture - solution and simulation models
Copilot File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,2 +1,5 @@ | ||
| _build/ | ||
| .DS_Store | ||
| .DS_Store | ||
| __pycache__/ | ||
| *.pyc | ||
| *.pyo |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,4 +18,6 @@ dependencies: | |
| - sphinx-togglebutton==0.3.2 | ||
| # Docker Requirements | ||
| - pytz | ||
| # JAX for lecture content | ||
| - jax | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,371 @@ | ||
| """ | ||
| JAX-based AMSS model implementation. | ||
| Converted from NumPy/Numba classes to JAX pure functions and NamedTuple structures. | ||
| """ | ||
|
|
||
| import jax.numpy as jnp | ||
| from jax import jit, grad, vmap | ||
| import jax | ||
| from scipy.optimize import minimize # Use scipy for now | ||
| from typing import NamedTuple, Callable | ||
| try: | ||
| from .jax_utilities import UtilityFunctions | ||
| from .jax_interpolation import nodes_from_grid, eval_linear_jax | ||
| except ImportError: | ||
| from jax_utilities import UtilityFunctions | ||
| from jax_interpolation import nodes_from_grid, eval_linear_jax | ||
|
|
||
|
|
||
| class AMSSState(NamedTuple): | ||
| """State variables for AMSS model.""" | ||
| s: int # Current Markov state | ||
| x: float # Continuation value state variable | ||
|
|
||
|
|
||
| class AMSSParams(NamedTuple): | ||
| """Parameters for AMSS model.""" | ||
| β: float # Discount factor | ||
| Π: jnp.ndarray # Markov transition matrix | ||
| g: jnp.ndarray # Government spending by state | ||
| x_grid: tuple # Grid parameters (x_min, x_max, x_num) | ||
| bounds_v: jnp.ndarray # Bounds for optimization | ||
| utility: UtilityFunctions # Utility functions | ||
|
|
||
|
|
||
| class AMSSPolicies(NamedTuple): | ||
| """Policy functions for AMSS model.""" | ||
| V: jnp.ndarray # Value function | ||
| σ_v_star: jnp.ndarray # Policy function for time t >= 1 | ||
| W: jnp.ndarray # Value function for time 0 | ||
| σ_w_star: jnp.ndarray # Policy function for time 0 | ||
|
|
||
|
|
||
| @jit | ||
| def compute_consumption_leisure(l, g): | ||
| """Compute consumption given leisure and government spending.""" | ||
| return (1 - l) - g | ||
|
|
||
|
|
||
| @jit | ||
| def objective_V(σ, state, V, params: AMSSParams): | ||
| """ | ||
| Objective function for time t >= 1 value function iteration. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| σ : array | ||
| Policy variables [l_1, ..., l_S, T_1, ..., T_S] | ||
| state : tuple | ||
| Current state (s_, x_) | ||
| V : array | ||
| Current value function | ||
| params : AMSSParams | ||
| Model parameters | ||
|
|
||
| Returns | ||
| ------- | ||
| float | ||
| Negative of expected value (for minimization) | ||
| """ | ||
| s_, x_ = state | ||
| S = len(params.Π) | ||
|
|
||
| l = σ[:S] | ||
| T = σ[S:] | ||
|
|
||
| c = compute_consumption_leisure(l, params.g) | ||
| u_c = vmap(params.utility.Uc)(c, l) | ||
| Eu_c = params.Π[s_] @ u_c | ||
|
|
||
| x = (u_c * x_ / (params.β * Eu_c) - | ||
| u_c * (c - T) + | ||
| vmap(params.utility.Ul)(c, l) * (1 - l)) | ||
|
|
||
| # Interpolate next period value function | ||
| x_nodes = nodes_from_grid(params.x_grid) | ||
| V_next = jnp.array([eval_linear_jax(params.x_grid, V[s], jnp.array([x[s]]))[0] | ||
| for s in range(S)]) | ||
|
|
||
| expected_value = params.Π[s_] @ (vmap(params.utility.U)(c, l) + params.β * V_next) | ||
|
|
||
| return -expected_value # Negative for minimization | ||
|
|
||
|
|
||
| @jit | ||
| def objective_W(σ, state, V, params: AMSSParams): | ||
| """ | ||
| Objective function for time 0 problem. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| σ : array | ||
| Policy variables [l, T] | ||
| state : tuple | ||
| Current state (s_, b_0) | ||
| V : array | ||
| Value function | ||
| params : AMSSParams | ||
| Model parameters | ||
|
|
||
| Returns | ||
| ------- | ||
| float | ||
| Negative of value (for minimization) | ||
| """ | ||
| s_, b_0 = state | ||
| l, T = σ | ||
|
|
||
| c = compute_consumption_leisure(l, params.g[s_]) | ||
| x = (-params.utility.Uc(c, l) * (c - T - b_0) + | ||
| params.utility.Ul(c, l) * (1 - l)) | ||
|
|
||
| V_next = eval_linear_jax(params.x_grid, V[s_], jnp.array([x]))[0] | ||
| value = params.utility.U(c, l) + params.β * V_next | ||
|
|
||
| return -value # Negative for minimization | ||
|
|
||
|
|
||
| def solve_bellman_iteration(V, σ_v_star, params: AMSSParams, | ||
| tol=1e-7, max_iter=1000, print_freq=10): | ||
| """ | ||
| Solve the Bellman equation using value function iteration. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| V : array | ||
| Initial value function guess | ||
| σ_v_star : array | ||
| Initial policy function guess | ||
| params : AMSSParams | ||
| Model parameters | ||
| tol : float | ||
| Convergence tolerance | ||
| max_iter : int | ||
| Maximum iterations | ||
| print_freq : int | ||
| Print frequency | ||
|
|
||
| Returns | ||
| ------- | ||
| tuple | ||
| Updated (V, σ_v_star) | ||
| """ | ||
| S = len(params.Π) | ||
| x_nodes = nodes_from_grid(params.x_grid) | ||
| n_x = len(x_nodes) | ||
|
|
||
| V_new = jnp.zeros_like(V) | ||
|
|
||
| for iteration in range(max_iter): | ||
| V_updated = jnp.zeros_like(V) | ||
| σ_updated = jnp.zeros_like(σ_v_star) | ||
|
|
||
| # Loop over states and grid points | ||
| for s_ in range(S): | ||
| for x_i in range(n_x): | ||
| state = (s_, x_nodes[x_i]) | ||
| x0 = σ_v_star[s_, x_i] | ||
|
|
||
| # Optimize using JAX | ||
| bounds = [(params.bounds_v[i, 0], params.bounds_v[i, 1]) | ||
| for i in range(len(params.bounds_v))] | ||
|
|
||
| # Simple optimization using scipy-like interface | ||
| result = minimize( | ||
| lambda σ: objective_V(σ, state, V, params), | ||
| x0, | ||
| method='L-BFGS-B', | ||
| bounds=bounds | ||
| ) | ||
|
|
||
| if result.success: | ||
| V_updated = V_updated.at[s_, x_i].set(-result.fun) | ||
| σ_updated = σ_updated.at[s_, x_i].set(result.x) | ||
| else: | ||
| print(f"Optimization failed at state {s_}, grid point {x_i}") | ||
| V_updated = V_updated.at[s_, x_i].set(V[s_, x_i]) | ||
| σ_updated = σ_updated.at[s_, x_i].set(σ_v_star[s_, x_i]) | ||
|
|
||
| # Check convergence | ||
| error = jnp.max(jnp.abs(V_updated - V)) | ||
|
|
||
| if error < tol: | ||
| print(f'Successfully completed VFI after {iteration + 1} iterations') | ||
| return V_updated, σ_updated | ||
|
|
||
| if (iteration + 1) % print_freq == 0: | ||
| print(f'Error at iteration {iteration + 1}: {error}') | ||
|
|
||
| V = V_updated | ||
| σ_v_star = σ_updated | ||
|
|
||
| print(f'VFI did not converge after {max_iter} iterations') | ||
| return V, σ_v_star | ||
|
|
||
|
|
||
| def solve_time_zero_problem(b_0, V, params: AMSSParams): | ||
| """ | ||
| Solve the time 0 problem. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| b_0 : float | ||
| Initial debt | ||
| V : array | ||
| Value function from time 1 problem | ||
| params : AMSSParams | ||
| Model parameters | ||
|
|
||
| Returns | ||
| ------- | ||
| tuple | ||
| (W, σ_w_star) where W is time 0 values and σ_w_star is time 0 policies | ||
| """ | ||
| S = len(params.Π) | ||
| W = jnp.zeros(S) | ||
| σ_w_star = jnp.zeros((S, 2)) | ||
|
|
||
| bounds_w = [(-9.0, 1.0), (0.0, 10.0)] | ||
|
|
||
| for s_ in range(S): | ||
| state = (s_, b_0) | ||
| x0 = jnp.array([-0.05, 0.5]) # Initial guess | ||
|
|
||
| result = minimize( | ||
| lambda σ: objective_W(σ, state, V, params), | ||
| x0, | ||
| method='L-BFGS-B', | ||
| bounds=bounds_w | ||
| ) | ||
|
|
||
| W = W.at[s_].set(-result.fun) | ||
| σ_w_star = σ_w_star.at[s_].set(result.x) | ||
|
|
||
| print('Successfully solved the time 0 problem.') | ||
| return W, σ_w_star | ||
|
|
||
|
|
||
| @jit | ||
| def simulate_amss(s_hist, b_0, policies: AMSSPolicies, params: AMSSParams): | ||
| """ | ||
| Simulate AMSS model given state history and initial debt. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| s_hist : array | ||
| History of Markov states | ||
| b_0 : float | ||
| Initial debt level | ||
| policies : AMSSPolicies | ||
| Solved policy functions | ||
| params : AMSSParams | ||
| Model parameters | ||
|
|
||
| Returns | ||
| ------- | ||
| dict | ||
| Simulation results with arrays for c, n, b, τ, g | ||
| """ | ||
| T = len(s_hist) | ||
| S = len(params.Π) | ||
| x_nodes = nodes_from_grid(params.x_grid) | ||
|
|
||
| # Pre-allocate arrays | ||
| n_hist = jnp.zeros(T) | ||
| x_hist = jnp.zeros(T) | ||
| c_hist = jnp.zeros(T) | ||
| τ_hist = jnp.zeros(T) | ||
| b_hist = jnp.zeros(T) | ||
| g_hist = jnp.zeros(T) | ||
|
|
||
| # Time 0 | ||
| s_0 = s_hist[0] | ||
| l_0, T_0 = policies.σ_w_star[s_0] | ||
| c_0 = compute_consumption_leisure(l_0, params.g[s_0]) | ||
| x_0 = (-params.utility.Uc(c_0, l_0) * (c_0 - T_0 - b_0) + | ||
| params.utility.Ul(c_0, l_0) * (1 - l_0)) | ||
|
|
||
| n_hist = n_hist.at[0].set(1 - l_0) | ||
| x_hist = x_hist.at[0].set(x_0) | ||
| c_hist = c_hist.at[0].set(c_0) | ||
| τ_hist = τ_hist.at[0].set(1 - params.utility.Ul(c_0, l_0) / params.utility.Uc(c_0, l_0)) | ||
| b_hist = b_hist.at[0].set(b_0) | ||
| g_hist = g_hist.at[0].set(params.g[s_0]) | ||
|
|
||
| # Time t > 0 | ||
| for t in range(T - 1): | ||
| x_ = x_hist[t] | ||
| s_ = s_hist[t] | ||
|
|
||
| # Interpolate policies for all states | ||
| l = jnp.array([eval_linear_jax(params.x_grid, policies.σ_v_star[s_, :, s], | ||
| jnp.array([x_]))[0] for s in range(S)]) | ||
| T_vals = jnp.array([eval_linear_jax(params.x_grid, policies.σ_v_star[s_, :, S+s], | ||
| jnp.array([x_]))[0] for s in range(S)]) | ||
|
|
||
| c = compute_consumption_leisure(l, params.g) | ||
| u_c = vmap(params.utility.Uc)(c, l) | ||
| Eu_c = params.Π[s_] @ u_c | ||
|
|
||
| x = (u_c * x_ / (params.β * Eu_c) - | ||
| u_c * (c - T_vals) + | ||
| vmap(params.utility.Ul)(c, l) * (1 - l)) | ||
|
|
||
| s_next = s_hist[t+1] | ||
| c_next = c[s_next] | ||
| l_next = l[s_next] | ||
|
|
||
| x_hist = x_hist.at[t+1].set(x[s_next]) | ||
| n_hist = n_hist.at[t+1].set(1 - l_next) | ||
| c_hist = c_hist.at[t+1].set(c_next) | ||
| τ_hist = τ_hist.at[t+1].set(1 - params.utility.Ul(c_next, l_next) / params.utility.Uc(c_next, l_next)) | ||
| b_hist = b_hist.at[t+1].set(x_ / (params.β * Eu_c)) | ||
| g_hist = g_hist.at[t+1].set(params.g[s_next]) | ||
|
|
||
| return { | ||
| 'c': c_hist, | ||
| 'n': n_hist, | ||
| 'b': b_hist, | ||
| 'τ': τ_hist, | ||
| 'g': g_hist | ||
| } | ||
|
|
||
|
|
||
| def solve_amss_model(params: AMSSParams, V_init, σ_v_init, b_0, | ||
| W_init=None, σ_w_init=None, **kwargs): | ||
| """ | ||
| Solve the complete AMSS model. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| params : AMSSParams | ||
| Model parameters | ||
| V_init : array | ||
| Initial value function guess | ||
| σ_v_init : array | ||
| Initial policy function guess | ||
| b_0 : float | ||
| Initial debt level | ||
| W_init : array, optional | ||
| Initial time 0 value function | ||
| σ_w_init : array, optional | ||
| Initial time 0 policy function | ||
| **kwargs | ||
| Additional arguments for solver | ||
|
|
||
| Returns | ||
| ------- | ||
| AMSSPolicies | ||
| Solved policy functions | ||
| """ | ||
| print("===============") | ||
| print("Solve time 1 problem") | ||
| print("===============") | ||
| V, σ_v_star = solve_bellman_iteration(V_init, σ_v_init, params, **kwargs) | ||
|
|
||
| print("===============") | ||
| print("Solve time 0 problem") | ||
| print("===============") | ||
| W, σ_w_star = solve_time_zero_problem(b_0, V, params) | ||
|
|
||
| return AMSSPolicies(V=V, σ_v_star=σ_v_star, W=W, σ_w_star=σ_w_star) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@HumphreyYang I may remove this later on. I saw you added
jaxin the lecture itself.