Skip to content

Commit

Permalink
hotfix feature encoder (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tianzhang Cai authored Apr 27, 2023
1 parent 03ec2ee commit 26b68d9
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/synthcity/plugins/core/models/feature_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def fit(self, x: pd.Series, y: Any = None, **kwargs: Any) -> FeatureEncoder:
output = self._fit(input, **kwargs)._transform(input)
self._out_shape = (-1, *output.shape[1:]) # for inverse_transform
output = validate_shape(output, self.n_dim_out)
self.n_features_out = output.shape[1]
if self.n_dim_out == 1:
self.n_features_out = 1
else:
self.n_features_out = output.shape[1]
self.feature_names_out = self.get_feature_names_out()
self.feature_types_out = self.get_feature_types_out(output)
return self
Expand Down Expand Up @@ -105,6 +108,8 @@ def _get_feature_type(self, x: Any) -> str:
return "discrete"
elif np.issubdtype(x.dtype, np.floating):
return "continuous"
elif np.issubdtype(x.dtype, np.datetime64):
return "datetime"
else:
return "discrete"

Expand Down

0 comments on commit 26b68d9

Please sign in to comment.