Skip to content

Commit

Permalink
Merge pull request #236 from QuantEcon/ddp-dev
Browse files (browse the repository at this point in the history)
Drop `num_actions` from DiscreteDP
  • Loading branch information
mmcky committed Mar 17, 2016
2 parents 14c0caf + fd3fd1a commit b0cb89a
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions quantecon/markov/ddp.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -180,9 +180,6 @@ class DiscreteDP(object):
num_states : scalar(int)
Number of states.
num_actions : scalar(int)
Number of actions.
num_sa_pairs : scalar(int)
Number of feasible state-action pairs (or those that yield
finite rewards).
Expand Down Expand Up @@ -322,7 +319,6 @@ def __init__(self, R, Q, beta, s_indices=None, a_indices=None):

self.s_indices = np.asarray(s_indices)
self.a_indices = np.asarray(a_indices)
self.num_actions = self.a_indices.max() + 1

if _has_sorted_sa_indices(self.s_indices, self.a_indices):
a_indptr = np.empty(self.num_states+1, dtype=int)
Expand Down Expand Up @@ -372,22 +368,21 @@ def s_wise_max(vals, out=None, out_argmax=None):
else: # Not self._sa_pair
if self.R.ndim != 2:
raise ValueError(msg_dimension)
self.num_states, self.num_actions = self.R.shape
n, m = self.R.shape

if self.Q.shape != \
(self.num_states, self.num_actions, self.num_states):
if self.Q.shape != (n, m, n):
raise ValueError(msg_shape)

self.num_states = n
self.s_indices, self.a_indices = None, None
self.num_sa_pairs = (self.R > -np.inf).sum()

# Define state-wise maximization
def s_wise_max(vals, out=None, out_argmax=None):
"""
Return the vector max_a vals(s, a), where vals is represented
by a 2-dimensional ndarray of shape (self.num_states,
self.num_actions). Stored in out, which must be of length
self.num_states.
by a 2-dimensional ndarray of shape (n, m). Stored in out,
which must be of length self.num_states.
out and out_argmax must be of length self.num_states; dtype of
out_argmax must be int.
Expand Down Expand Up @@ -780,14 +775,15 @@ def midrange(z):
u = np.empty(self.num_states)
sigma = np.empty(self.num_states, dtype=int)

try:
tol = epsilon * (1-self.beta) / self.beta
except ZeroDivisionError: # Raised if beta = 0
tol = np.inf

for i in range(max_iter):
# Policy improvement
self.bellman_operator(v, Tv=u, sigma=sigma)
diff = u - v
try:
tol = epsilon * (1-self.beta) / self.beta
except ZeroDivisionError: # Raised if beta = 0
tol = np.inf
if span(diff) < tol:
v[:] = u + midrange(diff) * self.beta / (1 - self.beta)
break
Expand Down

0 comments on commit b0cb89a

Please sign in to comment.