From 0b478669e77d62a84693125ed828657cb7bcd055 Mon Sep 17 00:00:00 2001 From: Steven Dahdah Date: Mon, 21 Aug 2023 10:44:08 -0400 Subject: [PATCH] Add option to plot error --- pykoop/koopman_pipeline.py | 79 +++++++++++++++++++++++++++----------- 1 file changed, 56 insertions(+), 23 deletions(-) diff --git a/pykoop/koopman_pipeline.py b/pykoop/koopman_pipeline.py index f649cea..0048f85 100644 --- a/pykoop/koopman_pipeline.py +++ b/pykoop/koopman_pipeline.py @@ -2691,6 +2691,7 @@ def plot_predicted_trajectory( relift_state: bool = True, plot_lifted: bool = False, plot_input: bool = False, + plot_error: bool = False, episode_feature: Optional[bool] = None, plot_ground_truth: bool = True, episode_style: Optional[str] = None, @@ -2719,6 +2720,9 @@ def plot_predicted_trajectory( plot_input : bool If true, plot the input as well as the state. If false, plot only the original state (default). + plot_error : bool + If true, plot the prediction error instead of the state. If false, + plot the predicted state and ground truth (default). episode_feature : Optional[bool] True if first feature indicates which episode a timestep is from. If ``None``, ``self.episode_feature_`` is used. @@ -2797,33 +2801,62 @@ def plot_predicted_trajectory( # Plot results for row in range(n_row): for ep in range(n_eps): - if episode_style == 'overlay': - line_pred = ax[row, 0].plot( - eps[ep][1][:, row], - label=f'Ep. {int(eps[ep][0])} prediction', - **plot_args, - ) - if eps_gt is not None and row < n_states: - ax[row, 0].plot( - eps_gt[ep][1][:, row], - label=f'Ep. {int(eps[ep][0])} ground truth', - linestyle='--', - color=line_pred[0].get_color(), + if plot_error: + if episode_style == 'overlay': + if eps_gt is not None and row < n_states: + ax[row, 0].plot( + eps_gt[ep][1][:, row] - eps[ep][1][:, row], + label=f'Ep. {int(eps[ep][0])} error', + **plot_args, + ) + else: + ax[row, 0].plot( + eps[ep][1][:, row], + label=f'Ep. {int(eps[ep][0])} input', + **plot_args, + ) + else: + if eps_gt is not None and row < n_states: + ax[row, ep].plot( + eps_gt[ep][1][:, row] - eps[ep][1][:, row], + label=f'Prediction error', + linestyle='--', + **plot_args, + ) + else: + ax[row, ep].plot( + eps[ep][1][:, row], + label=f'Input', + **plot_args, + ) + else: + if episode_style == 'overlay': + line_pred = ax[row, 0].plot( + eps[ep][1][:, row], + label=f'Ep. {int(eps[ep][0])} prediction', **plot_args, ) - else: - line_pred = ax[row, ep].plot( - eps[ep][1][:, row], - label=f'Prediction', - **plot_args, - ) - if eps_gt is not None and row < n_states: - ax[row, ep].plot( - eps_gt[ep][1][:, row], - label=f'Ground truth', - linestyle='--', + if eps_gt is not None and row < n_states: + ax[row, 0].plot( + eps_gt[ep][1][:, row], + label=f'Ep. {int(eps[ep][0])} ground truth', + linestyle='--', + color=line_pred[0].get_color(), + **plot_args, + ) + else: + line_pred = ax[row, ep].plot( + eps[ep][1][:, row], + label=f'Prediction', **plot_args, ) + if eps_gt is not None and row < n_states: + ax[row, ep].plot( + eps_gt[ep][1][:, row], + label=f'Ground truth', + linestyle='--', + **plot_args, + ) # Set y labels if plot_lifted: names = self.get_feature_names_out(