From 5bd39e8cf7ae24545a071b84f95242fe4176a492 Mon Sep 17 00:00:00 2001 From: Karl Pertsch Date: Wed, 20 Dec 2023 00:23:30 -0500 Subject: [PATCH 1/2] fix norm gym wrapper --- octo/utils/gym_wrappers.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/octo/utils/gym_wrappers.py b/octo/utils/gym_wrappers.py index fe9411e7..63ac987d 100644 --- a/octo/utils/gym_wrappers.py +++ b/octo/utils/gym_wrappers.py @@ -274,7 +274,9 @@ def unnormalize(self, data, metadata): if self.normalization_type == "normal": return (data * metadata["std"]) + metadata["mean"] elif self.normalization_type == "bounds": - return (data * (metadata["max"] - metadata["min"])) + metadata["min"] + return ( + (data + 1) / 2 * (metadata["max"] - metadata["min"] + 1e-8) + ) + metadata["min"] else: raise ValueError( f"Unknown action/proprio normalization type: {self.normalization_type}" @@ -282,11 +284,14 @@ def unnormalize(self, data, metadata): def normalize(self, data, metadata): if self.normalization_type == "normal": - return (data / (metadata["std"] + 1e-8)) - metadata["mean"] + return (data - metadata["mean"]) / (metadata["std"] + 1e-8) elif self.normalization_type == "bounds": return ( - (data + 1) / (2 * (metadata["max"] - metadata["min"] + 1e-8)) - ) + metadata["min"] + 2 + * (data - metadata["min"]) + / (metadata["max"] - metadata["min"] + 1e-8) + - 1 + ) else: raise ValueError( f"Unknown action/proprio normalization type: {self.normalization_type}" From 1477edbf65d5745d46ba62c90d3ea5c17958bc26 Mon Sep 17 00:00:00 2001 From: Karl Pertsch Date: Thu, 21 Dec 2023 00:00:09 -0500 Subject: [PATCH 2/2] add mask and value clipping to normalization wrapper --- octo/utils/gym_wrappers.py | 39 ++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/octo/utils/gym_wrappers.py b/octo/utils/gym_wrappers.py index 63ac987d..84a499dc 100644 --- a/octo/utils/gym_wrappers.py +++ b/octo/utils/gym_wrappers.py @@ -271,26 +271,45 @@ def __init__( super().__init__(env) def unnormalize(self, data, metadata): + mask = metadata.get("mask", np.ones_like(metadata["mean"], dtype=bool)) if self.normalization_type == "normal": - return (data * metadata["std"]) + metadata["mean"] + return np.where( + mask, + (data * metadata["std"]) + metadata["mean"], + data, + ) elif self.normalization_type == "bounds": - return ( - (data + 1) / 2 * (metadata["max"] - metadata["min"] + 1e-8) - ) + metadata["min"] + return np.where( + mask, + ((data + 1) / 2 * (metadata["max"] - metadata["min"] + 1e-8)) + + metadata["min"], + data, + ) else: raise ValueError( f"Unknown action/proprio normalization type: {self.normalization_type}" ) def normalize(self, data, metadata): + mask = metadata.get("mask", np.ones_like(metadata["mean"], dtype=bool)) if self.normalization_type == "normal": - return (data - metadata["mean"]) / (metadata["std"] + 1e-8) + return np.where( + mask, + (data - metadata["mean"]) / (metadata["std"] + 1e-8), + data, + ) elif self.normalization_type == "bounds": - return ( - 2 - * (data - metadata["min"]) - / (metadata["max"] - metadata["min"] + 1e-8) - - 1 + return np.where( + mask, + np.clip( + 2 + * (data - metadata["min"]) + / (metadata["max"] - metadata["min"] + 1e-8) + - 1, + -1, + 1, + ), + data, ) else: raise ValueError(