Skip to content

Commit

Permalink
Added API to view optimized trajectory
Browse files Browse the repository at this point in the history
  • Loading branch information
williamedwards committed Mar 22, 2022
1 parent 625da0e commit e8ada34
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 0 deletions.
6 changes: 6 additions & 0 deletions autompc/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,12 @@ def set_state(self, state):
self.last_control = state["last_control"]
self.optimizer.set_state(state["optimizer_state"])

def get_optimized_traj(self):
"""
Returns the last optimized trajectory, if available.
"""
return self.optimizer.get_traj()

class AutoSelectController(Controller):
"""
A version of the controller which comes with a default selection
Expand Down
7 changes: 7 additions & 0 deletions autompc/optim/ilqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

# Internal libary includes
from .optimizer import Optimizer
from ..trajectory import Trajectory

class IterativeLQR(Optimizer):
"""
Expand Down Expand Up @@ -42,6 +43,7 @@ def set_config(self, config):

def reset(self):
self._guess = None
self._traj = None

def set_ocp(self, ocp):
super().set_ocp(ocp)
Expand Down Expand Up @@ -221,4 +223,9 @@ def step(self, state, silent=True):
converged, states, ctrls, Ks, ks = self.compute_ilqr(state, self._guess,
silent=silent)
self._guess = np.concatenate((ctrls[1:], np.zeros((1, self.system.ctrl_dim))), axis=0)
self._traj = Trajectory(self.system, states.shape[0], states,
np.vstack([ctrls, np.zeros(self.system.ctrl_dim)]))
return ctrls[0]

def get_traj(self):
return self._traj.clone()
6 changes: 6 additions & 0 deletions autompc/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ def get_state(self):
"""
raise NotImplementedError

def get_traj(self):
"""
Returns the last optimized trajectory, if available.
"""
raise NotImplementedError

@abstractmethod
def set_state(self, state):
"""
Expand Down
3 changes: 3 additions & 0 deletions autompc/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,6 @@ def ctrls(self, ctrls):
if ctrls.shape != (self._size, self._system.ctrl_dim):
raise ValueError("ctrls is wrong shape")
self._ctrls = ctrls[:]

def clone(self):
return Trajectory(self.system, self.size, np.copy(self.obs), np.copy(self.ctrls))

0 comments on commit e8ada34

Please sign in to comment.