Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Find value function crossings in dcegm.calcMultilineEnvelope #758

Merged
merged 9 commits into from
Jul 21, 2020
251 changes: 235 additions & 16 deletions HARK/dcegm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,110 @@
import numpy as np
from HARK.interpolation import LinearInterp

def calcCrossPoints(mGrid, condVs, optIdx):
'''
Given a grid of m values, a matrix of the conditional values of different
actions at every grid point, and a vector indicating the optimal action
at each grid point, this function computes the coordinates of the
crossing points that happen when the optimal action changes

Parameters
----------
mGrid : np.array
Market resources grid.
condVs : np.array must have as many rows as possible discrete actions, and
as many columns as m gridpoints there are.
Conditional value functions

optIdx : np.array of indices
Optimal decision at each grid point
Returns
-------
xing_points: [tuple]
List of crossing points, each as an (m,v) tuple.

segments: np.array with two columns and as many rows as xing points.
Each row represents a crossing point. The first column is the index
of the optimal action to the left, and the second, to the right.
'''
# Compute differences in the optimal index,
# to find positions of choice-changes
diff_max = np.insert(np.diff(optIdx),len(optIdx)-1,0)
idx_change = np.where(diff_max != 0)[0]

# If no crossings, return an empty list
if len(idx_change) == 0:
return []
else:

# To find the crossing points we need the extremes of the intervals in
# which they happen, and the two candidate segments evaluated in both
# extremes. switchMs[0] has the left points and switchMs[1] the right
# points of these intervals.
switchMs = np.stack([mGrid[idx_change], mGrid[idx_change + 1]],
axis = 1)

# Store the indices of the two segments involved in the changes, by
# looking at the argmax in the switching possitions.
# Columns are [0]: left extreme, [1]: right extreme,
# Rows are individual crossing points.
segments = np.stack([optIdx[idx_change], optIdx[idx_change + 1]],
axis = 1)

# Values of both segments on the left point
left_v = np.stack([condVs[idx_change, segments[:,0]],
condVs[idx_change, segments[:,1]]], axis = 1)
# and the right point
right_v = np.stack([condVs[idx_change+1, segments[:,0]],
condVs[idx_change+1, segments[:,1]]], axis = 1)

# A valid crossing must have both switching segments well defined at the
# encompassing gridpoints. Filter those that do not.
valid = np.logical_and(~np.isnan(left_v).any(axis = 1),
~np.isnan(right_v).any(axis = 1))

segments = segments[valid,:]
switchMs = switchMs[valid,:]
left_v = left_v[valid,:]
right_v = right_v[valid,:]

# Find crossing points. Returns a list (m,v) tuples. Keep m's only
xing_points = [calcLinearCrossing(switchMs[i,:],
left_v[i,:], right_v[i,:])
for i in range(segments.shape[0])]

return xing_points, segments

def calcPrimKink(mGrid, vTGrids, Choices):
'''
Parameters
----------
mGrid : np.array
Common m grid
vTGrids : [np.array], length = # choices, each element has length = len(mGrid)
value functions evaluated on the common m grid.
Choices : [np.array], length = # choices, each element has length = len(mGrid)
Optimal choices. In the form of choice probability vectors that must
be degenerate

Returns
-------
kinks: [(mCoord, vTCoor)]
list of kink points
segments: [(left, right)]
List of the same length as kinks, where each element is a tuple
indicating which segments are optimal on each side of the kink.
'''

# Construct a vector with the optimal choice at each m point
optChoice = np.empty(len(mGrid), dtype = int)
optChoice[:] = np.nan
for i in range(len(vTGrids)):
idx = Choices[i] == 1
optChoice[idx] = i

return calcCrossPoints(mGrid, vTGrids.T, optChoice)


def calcSegments(x, v):
"""
Expand Down Expand Up @@ -88,8 +192,44 @@ def calcSegments(x, v):
# think! nanargmax makes everythign super ugly because numpy changed the wraning
# in all nan slices to a valueerror...it's nans, aaarghgghg

def calcLinearCrossing(m,left_v, right_v):
"""
Computes the intersection between two line segments, defined by two common
x points, and the values of both segments at both x points

def calcMultilineEnvelope(M, C, V_T, commonM):
Parameters
----------
m : list or np.array, length 2
The two common x coordinates. m[0] < m[1] is assumed
left_v :list or np.array, length 2
y values of the two segments at m[0]
right_v : list or np.array, length 2
y values of the two segments at m[1]

Returns
-------
(m_int, v_int): a tuple with the corrdinates of the intercept.
if there is no intercept in the interval [m[0],m[1]], (None,None)

"""

# Find slopes of both segments
delta_m = m[1] - m[0]
s0 = (right_v[0] - left_v[0])/delta_m
s1 = (right_v[1] - left_v[1])/delta_m

if s1 == s0:
if left_v[0] == left_v[1]:
return (m[0],left_v[0])
else:
return (None, None)
else:
# Find h where intercept happens at m[0] + h
h = (left_v[0] - left_v[1])/(s1-s0)
if (h >= 0 and h <= (m[1]-m[0])):
return (m[0]+h, left_v[0] + h*s0)

def calcMultilineEnvelope(M, C, V_T, commonM, findXings = False):
"""
Do the envelope step of the DCEGM algorithm. Takes in market ressources,
consumption levels, and inverse values from the EGM step. These represent
Expand All @@ -106,7 +246,9 @@ def calcMultilineEnvelope(M, C, V_T, commonM):
transformed values at the EGM grid
commonM : np.array
common grid to do upper envelope calculations on

findXings: boolean
should the exact crossing points of segments be computed and added to
the grids?
Returns
-------

Expand All @@ -129,33 +271,53 @@ def calcMultilineEnvelope(M, C, V_T, commonM):
commonC[:] = np.nan

# Now, loop over all segments as defined by the "kinks" or the combination
# of "rise" and "fall" indeces. These (rise[j], fall[j]) pairs define regions
# of "rise" and "fall" indeces. These (rise[j], fall[j]) pairs define regions.

if findXings:
# We'll save V_T and C interpolating functions to aid crossing points later
vT_funcs = []
c_funcs = []

for j in range(num_kinks):
# Find points in the common grid that are in the range of the points in
# the interval defined by (rise[j], fall[j]).
below = M[rise[j]] >= commonM # boolean array of bad indeces below
above = M[fall[j]] <= commonM # boolen array of bad indeces above
in_range = below + above == 0 # pick out elements that are neither


# based in in_range, find the relevant ressource values to interpolate
m_eval = commonM[in_range]

# create range of indeces in the input arrays
idxs = range(rise[j], fall[j]+1)
# grab ressource values at the relevant indeces
m_idx_j = M[idxs]

# based in in_range, find the relevant ressource values to interpolate
m_eval = commonM[in_range]


# If we need to compute the xing points, create and store the
# interpolating functions. Else simply create them.
if findXings:
# Create and store interpolating functions
vT_funcs.append(LinearInterp(m_idx_j, V_T[idxs], lower_extrap=True))
c_funcs.append(LinearInterp(m_idx_j, C[idxs], lower_extrap=True))

vT_fun = vT_funcs[-1]
c_fun = c_funcs[-1]
else:
vT_fun = LinearInterp(m_idx_j, V_T[idxs], lower_extrap=True)
c_fun = LinearInterp(m_idx_j, C[idxs], lower_extrap=True)

# re-interpolate to common grid
commonV_T[in_range, j] = LinearInterp(m_idx_j, V_T[idxs], lower_extrap=True)(m_eval) # NOQA
commonC[in_range, j] = LinearInterp(m_idx_j, C[idxs], lower_extrap=True)(m_eval) # NOQA Interpolat econsumption also. May not be nesserary
commonV_T[in_range, j] = vT_fun(m_eval) # NOQA
commonC[in_range, j] = c_fun(m_eval) # NOQA Interpolat econsumption also. May not be nesserary

# for each row in the commonV_T matrix, see if all entries are np.nan. This
# would mean that we have no valid value here, so we want to use this boolean
# vector to filter out irrelevant entries of commonV_T.
row_all_nan = np.array([np.all(np.isnan(row)) for row in commonV_T])
# Now take the max of all these line segments.
idx_max = np.zeros(commonM.size, dtype=int)
idx_max[row_all_nan == False] = np.nanargmax(commonV_T[row_all_nan == False], axis=1)

# prefix with upper for variable that are "upper enveloped"
upperV_T = np.zeros(m_len)

Expand Down Expand Up @@ -184,9 +346,66 @@ def calcMultilineEnvelope(M, C, V_T, commonM):
idx_linear = np.unravel_index(rowidx+idx_max, commonC.shape)
upperC = commonC[idx_linear]
upperC[IsNaN] = LinearInterp(commonM[IsNaN == 0], upperC[IsNaN == 0])(commonM[IsNaN])

# TODO calculate cross points of line segments to get the true vertical drops


upperM = commonM.copy() # anticipate this TODO

return upperM, upperC, upperV_T

# If crossing points are requested, compute them. Else just return the
# envelope.
if not findXings:

return upperM, upperC, upperV_T

else:

xing_points, segments = calcCrossPoints(commonM, commonV_T, idx_max)
# keep only the m component of crossing points.
xing_points = [x[0] for x in xing_points]

if len(xing_points) > 0:

# Now construct a set of points that need to be added to the grid in
# order to handle the discontinuities. To points per discontinuity:
# one to the left and one to the right.
num_crosses = len(xing_points)
add_m = np.empty((num_crosses,2))
add_m[:] = np.nan
add_vT = np.empty((num_crosses,2))
add_vT[:] = np.nan
add_c = np.empty((num_crosses,2))
add_c[:] = np.nan

# Fill the list of points interpolating left and right (2 points per
# crossing)
for i in range(num_crosses):
# Left part of the discontinuity
ml = xing_points[i]
add_m[i,0] = ml
add_vT[i,0] = vT_funcs[segments[i,0]](ml)
add_c[i,0] = c_funcs[segments[i,0]](ml)

# Right part of the discontinuity
mr = np.nextafter(ml, np.inf)
add_m[i,1] = mr
add_vT[i,1] = vT_funcs[segments[i,1]](mr)
add_c[i,1] = c_funcs[segments[i,1]](mr)

# Flatten arrays
add_m = add_m.flatten()
add_vT = add_vT.flatten()
add_c = add_c.flatten()

# Filter any points already in the grid
idxIncluded = np.isin(add_m, upperM)
add_m = add_m[~idxIncluded]
add_vT = add_vT[~idxIncluded]
add_c = add_c[~idxIncluded]

# Find positions at which new points must go
insertIdx = np.searchsorted(upperM, add_m)

# Insert
upperM = np.insert(upperM, insertIdx, add_m)
upperC = np.insert(upperC, insertIdx, add_c)
upperV_T = np.insert(upperV_T, insertIdx, add_vT)

return upperM, upperC, upperV_T, xing_points