Skip to content

Commit

Permalink
23.06 reobust-ish fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dcolinmorgan committed Dec 31, 2023
1 parent 88ebfdd commit 11d814d
Showing 1 changed file with 29 additions and 23 deletions.
52 changes: 29 additions & 23 deletions cu_cat/_table_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ def _auto_cast(self, X: pd.DataFrame) -> pd.DataFrame:
for i in obj_col:
X[i]=X[i].replace('nan',np.nan).fillna('0o0o0')
X[i]=X[i].str.rjust(4,'0')
X[i]=X[i].str.replace('.', 'dot', regex=False)
for col in X.columns:
# Convert pandas' NaN value (pd.NA) to numpy NaN value (np.nan)
# because the former tends to raise all kind of issues when dealing
Expand Down Expand Up @@ -799,21 +800,24 @@ def get_feature_names_out(self, input_features=None) -> List[str]:
typing.List[str]
Feature names.
"""
# if 'cudf' not in self.Xt_ and not deps.cudf:
# if parse_version(sklearn_version) < parse_version("1.0"):
# ct_feature_names = super().get_feature_names()
# else:
# ct_feature_names = super().get_feature_names_out()
# else:
# if parse_version(sklearn_version) > parse_version("1.0"):
try:
ct_feature_names = super().get_feature_names_out()
except:
pass
try:
ct_feature_names = super().get_feature_names()
except:
pass
if 'cudf' not in self.Xt_ and not deps.cudf:
if parse_version(sklearn_version) > parse_version("1.0"):
ct_feature_names = super().get_feature_names()
else:
ct_feature_names = super().get_feature_names_out()
else:
if parse_version(sklearn_version) < parse_version("1.0"):
ct_feature_names = super().get_feature_names_out()
else:
ct_feature_names = super().get_feature_names()
# try:
# ct_feature_names = super().get_feature_names_out()
# except:
# pass
# try:
# ct_feature_names = super().get_feature_names()
# except:
# pass
all_trans_feature_names = []

for name, trans, cols, _ in self._iter(fitted=True):
Expand All @@ -824,14 +828,16 @@ def get_feature_names_out(self, input_features=None) -> List[str]:
cols = self.columns_.to_list()
all_trans_feature_names.extend(cols)
continue
try:
trans_feature_names = super().get_feature_names_out()
except:
pass
try:
trans_feature_names = super().get_feature_names()
except:
pass
if 'cudf' not in self.Xt_ and not deps.cudf:
if parse_version(sklearn_version) > parse_version("1.0"):
trans_feature_names = super().get_feature_names()
else:
trans_feature_names = super().get_feature_names_out()
else:
if parse_version(sklearn_version) < parse_version("1.0"):
trans_feature_names = super().get_feature_names_out()
else:
trans_feature_names = super().get_feature_names()
all_trans_feature_names.extend(trans_feature_names)

if len(ct_feature_names) != len(all_trans_feature_names):
Expand Down

0 comments on commit 11d814d

Please sign in to comment.