diff --git a/rdt/transformers/numerical.py b/rdt/transformers/numerical.py index d0edb2984..6f05302bf 100644 --- a/rdt/transformers/numerical.py +++ b/rdt/transformers/numerical.py @@ -463,7 +463,6 @@ def _fit(self, data): n_components=self.max_clusters, weight_concentration_prior_type='dirichlet_process', weight_concentration_prior=0.001, - n_init=1, random_state=self._get_current_random_seed() ) @@ -494,10 +493,12 @@ def _transform(self, data): data = data.reshape((len(data), 1)) means = self._bgm_transformer.means_.reshape((1, self.max_clusters)) - + means = means[:, self.valid_component_indicator] stds = np.sqrt(self._bgm_transformer.covariances_).reshape((1, self.max_clusters)) + stds = stds[:, self.valid_component_indicator] + + # Multiply stds by 4 so that a value will be in the range [-1,1] with 99.99% probability normalized_values = (data - means) / (self.STD_MULTIPLIER * stds) - normalized_values = normalized_values[:, self.valid_component_indicator] component_probs = self._bgm_transformer.predict_proba(data) component_probs = component_probs[:, self.valid_component_indicator] @@ -524,7 +525,8 @@ def _reverse_transform_helper(self, data): normalized = np.clip(data[:, 0], -1, 1) means = self._bgm_transformer.means_.reshape([-1]) stds = np.sqrt(self._bgm_transformer.covariances_).reshape([-1]) - selected_component = data[:, 1].astype(int) # maybe round instead? + selected_component = data[:, 1].round().astype(int) + selected_component = selected_component.clip(0, self.valid_component_indicator.sum() - 1) std_t = stds[self.valid_component_indicator][selected_component] mean_t = means[self.valid_component_indicator][selected_component] reversed_data = normalized * self.STD_MULTIPLIER * std_t + mean_t @@ -546,8 +548,6 @@ def _reverse_transform(self, data): recovered_data = self._reverse_transform_helper(data) if self.null_transformer and self.null_transformer.models_missing_values(): - data = np.stack([recovered_data, data[:, -1]], axis=1) # noqa: PD013 - else: - data = recovered_data + recovered_data = np.stack([recovered_data, data[:, -1]], axis=1) # noqa: PD013 - return super()._reverse_transform(data) + return super()._reverse_transform(recovered_data) diff --git a/tests/unit/transformers/test_numerical.py b/tests/unit/transformers/test_numerical.py index 9b35c1135..d7c6bea95 100644 --- a/tests/unit/transformers/test_numerical.py +++ b/tests/unit/transformers/test_numerical.py @@ -1213,7 +1213,6 @@ def test__fit(self, mock_bgm): n_components=10, weight_concentration_prior_type='dirichlet_process', weight_concentration_prior=0.001, - n_init=1, random_state=0 )