From e8ada347451b9b243e18b1a259b6751330e8a3dc Mon Sep 17 00:00:00 2001 From: William Edwards Date: Tue, 22 Mar 2022 05:40:18 -0500 Subject: [PATCH] Added API to view optimized trajectory --- autompc/controller.py | 6 ++++++ autompc/optim/ilqr.py | 7 +++++++ autompc/optim/optimizer.py | 6 ++++++ autompc/trajectory.py | 3 +++ 4 files changed, 22 insertions(+) diff --git a/autompc/controller.py b/autompc/controller.py index fccda25..0380e75 100644 --- a/autompc/controller.py +++ b/autompc/controller.py @@ -520,6 +520,12 @@ def set_state(self, state): self.last_control = state["last_control"] self.optimizer.set_state(state["optimizer_state"]) + def get_optimized_traj(self): + """ + Returns the last optimized trajectory, if available. + """ + return self.optimizer.get_traj() + class AutoSelectController(Controller): """ A version of the controller which comes with a default selection diff --git a/autompc/optim/ilqr.py b/autompc/optim/ilqr.py index 157e8a1..d91cf54 100644 --- a/autompc/optim/ilqr.py +++ b/autompc/optim/ilqr.py @@ -10,6 +10,7 @@ # Internal libary includes from .optimizer import Optimizer +from ..trajectory import Trajectory class IterativeLQR(Optimizer): """ @@ -42,6 +43,7 @@ def set_config(self, config): def reset(self): self._guess = None + self._traj = None def set_ocp(self, ocp): super().set_ocp(ocp) @@ -221,4 +223,9 @@ def step(self, state, silent=True): converged, states, ctrls, Ks, ks = self.compute_ilqr(state, self._guess, silent=silent) self._guess = np.concatenate((ctrls[1:], np.zeros((1, self.system.ctrl_dim))), axis=0) + self._traj = Trajectory(self.system, states.shape[0], states, + np.vstack([ctrls, np.zeros(self.system.ctrl_dim)])) return ctrls[0] + + def get_traj(self): + return self._traj.clone() diff --git a/autompc/optim/optimizer.py b/autompc/optim/optimizer.py index 1af72c6..9fdb4bf 100644 --- a/autompc/optim/optimizer.py +++ b/autompc/optim/optimizer.py @@ -114,6 +114,12 @@ def get_state(self): """ raise NotImplementedError + def get_traj(self): + """ + Returns the last optimized trajectory, if available. + """ + raise NotImplementedError + @abstractmethod def set_state(self, state): """ diff --git a/autompc/trajectory.py b/autompc/trajectory.py index 4ad3ebe..eff0cd3 100644 --- a/autompc/trajectory.py +++ b/autompc/trajectory.py @@ -199,3 +199,6 @@ def ctrls(self, ctrls): if ctrls.shape != (self._size, self._system.ctrl_dim): raise ValueError("ctrls is wrong shape") self._ctrls = ctrls[:] + + def clone(self): + return Trajectory(self.system, self.size, np.copy(self.obs), np.copy(self.ctrls))