Skip to content

Commit

Permalink
Change method name to be consistent with MPPI controller
Browse files Browse the repository at this point in the history
  • Loading branch information
LemonPi committed Jan 31, 2020
1 parent 345e2c7 commit 05ed475
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions pytorch_cem/cem.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def pytorch_cov(x, rowvar=False, bias=False, ddof=None, aweights=None):


class CEM():
""" Cross Entropy Method control
"""
Cross Entropy Method control
This implementation batch samples the trajectories and so scales well with the number of samples K.
"""

Expand Down Expand Up @@ -110,7 +111,10 @@ def __init__(self, dynamics, running_cost, nx, nu, num_samples=100, num_iteratio
# regularize covariance
self.cov_reg = torch.eye(self.T * self.nu, device=self.d, dtype=self.dtype) * init_cov_diag * 1e-5

def reset_distribution(self):
def reset(self):
"""
Clear controller state after finishing a trial
"""
# action distribution, initialized as N(0,I)
# we do Hp x 1 instead of H x p because covariance will be Hp x Hp matrix instead of some higher dim tensor
self.mean = torch.zeros(self.T * self.nu, device=self.d, dtype=self.dtype)
Expand Down Expand Up @@ -157,7 +161,7 @@ def command(self, state, choose_best=False):
state = torch.tensor(state)
state = state.to(dtype=self.dtype, device=self.d)

self.reset_distribution()
self.reset()

for m in range(self.M):
top_samples = self._sample_top_trajectories(state, self.num_elite)
Expand Down

0 comments on commit 05ed475

Please sign in to comment.