def unnormalize(self, data, metadata):
    """Map normalized values back to their original scale.

    Args:
        data: Array-like of normalized values.
        metadata: Mapping of normalization statistics. Reads ``"mean"`` and
            ``"std"`` when ``self.normalization_type == "normal"``, or
            ``"min"`` and ``"max"`` when it is ``"bounds"``. An optional
            boolean ``"mask"`` selects which entries are transformed;
            entries where the mask is false are returned unchanged.

    Returns:
        The result of ``np.where`` over the mask, the rescaled values and
        ``data``.

    Raises:
        ValueError: If ``self.normalization_type`` is neither ``"normal"``
            nor ``"bounds"``.
    """
    # A scalar ``True`` broadcasts inside np.where, so a missing mask means
    # "transform everything" without requiring metadata["mean"] to exist
    # (statistics for "bounds" only need "min"/"max").
    mask = metadata.get("mask", True)
    if self.normalization_type == "normal":
        # NOTE(review): normalize() divides by ``std + 1e-8`` while this
        # multiplies by ``std``, so the two are inverses only up to that
        # epsilon.
        return np.where(
            mask,
            (data * metadata["std"]) + metadata["mean"],
            data,
        )
    elif self.normalization_type == "bounds":
        # Inverse of the [-1, 1] mapping used by normalize(): shift to
        # [0, 1], scale by the (epsilon-padded) range, then offset by min.
        return np.where(
            mask,
            ((data + 1) / 2 * (metadata["max"] - metadata["min"] + 1e-8))
            + metadata["min"],
            data,
        )
    else:
        raise ValueError(
            f"Unknown action/proprio normalization type: {self.normalization_type}"
        )


def normalize(self, data, metadata):
    """Normalize raw values using the supplied statistics.

    Args:
        data: Array-like of raw values.
        metadata: Mapping of normalization statistics. Reads ``"mean"`` and
            ``"std"`` when ``self.normalization_type == "normal"``, or
            ``"min"`` and ``"max"`` when it is ``"bounds"``. An optional
            boolean ``"mask"`` selects which entries are transformed;
            entries where the mask is false are returned unchanged.

    Returns:
        For ``"normal"``: ``(data - mean) / (std + 1e-8)``. For
        ``"bounds"``: values mapped linearly so ``min -> -1`` and
        ``max -> 1`` (up to the 1e-8 padding), clipped to ``[-1, 1]``.
        In both cases masked-out entries keep their input value.

    Raises:
        ValueError: If ``self.normalization_type`` is neither ``"normal"``
            nor ``"bounds"``.
    """
    # See unnormalize(): a scalar default avoids depending on "mean" when
    # only bounds statistics are provided.
    mask = metadata.get("mask", True)
    if self.normalization_type == "normal":
        return np.where(
            mask,
            # The epsilon guards against division by a zero std.
            (data - metadata["mean"]) / (metadata["std"] + 1e-8),
            data,
        )
    elif self.normalization_type == "bounds":
        return np.where(
            mask,
            np.clip(
                2
                * (data - metadata["min"])
                / (metadata["max"] - metadata["min"] + 1e-8)
                - 1,
                -1,
                1,
            ),
            data,
        )
    else:
        raise ValueError(
            f"Unknown action/proprio normalization type: {self.normalization_type}"
        )