Skip to content

Commit

Permalink
Refactor transform to avoid unnecessary copying
Browse files Browse the repository at this point in the history
  • Loading branch information
Tamar Grey committed Oct 3, 2022
1 parent f791cfd commit dd5c075
Showing 1 changed file with 14 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -217,22 +217,24 @@ def transform(self, X, y=None):
"""
X = infer_feature_types(X)

X_copy = X.ww.copy()
X = X.ww.drop(columns=self.features_to_encode)
if not self.features_to_encode:
# If there are no features to encode, X needs no transformation, so return a copy
return X.ww.copy()

X_orig = X.ww.drop(columns=self.features_to_encode)

# Call sklearn's transform on only the ordinal columns
if len(self.features_to_encode) > 0:
X_ord = pd.DataFrame(
self._encoder.transform(X_copy[self.features_to_encode]),
index=X_copy.index,
)
X_ord.columns = self._get_feature_names()
X_ord.ww.init(logical_types={c: "Double" for c in X_ord.columns})
self._feature_names = X_ord.columns
X_t = pd.DataFrame(
self._encoder.transform(X[self.features_to_encode]),
index=X.index,
)
X_t.columns = self._get_feature_names()
X_t.ww.init(logical_types={c: "Double" for c in X_t.columns})
self._feature_names = X_t.columns

X = ww.utils.concat_columns([X, X_ord])
X_t = ww.utils.concat_columns([X_orig, X_t])

return X
return X_t

def _get_feature_names(self):
"""Return feature names for the ordinal features after fitting.
Expand Down

0 comments on commit dd5c075

Please sign in to comment.