Skip to content

[feat] Provide explicit col dtypes from user side #396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
21 changes: 8 additions & 13 deletions autoPyTorch/data/base_feature_validator.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
import logging
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import numpy as np

import pandas as pd

from scipy.sparse import spmatrix

from sklearn.base import BaseEstimator

from autoPyTorch.data.utils import SupportedFeatTypes, list_to_pandas
from autoPyTorch.utils.logging_ import PicklableClientLogger


SupportedFeatTypes = Union[List, pd.DataFrame, np.ndarray, spmatrix]


class BaseFeatureValidator(BaseEstimator):
"""
A class to pre-process features. In this regard, the format of the data is checked,
Expand All @@ -27,8 +21,8 @@ class BaseFeatureValidator(BaseEstimator):
column_transformer (Optional[BaseEstimator])
Hosts an encoder object if the data requires transformation (for example,
if provided a categorical column in a pandas DataFrame)
transformed_columns (List[str])
List of columns that were encoded.
enc_columns (Optional[List[str]]):
The list of column names that should be encoded.
"""
def __init__(
self,
Expand All @@ -37,11 +31,11 @@ def __init__(
# Register types to detect unsupported data format changes
self.feat_type: Optional[List[str]] = None
self.data_type: Optional[type] = None
self.dtypes: List[str] = []
self.dtypes: Dict[str, str] = {}
self.column_order: List[str] = []

self.column_transformer: Optional[BaseEstimator] = None
self.transformed_columns: List[str] = []
self.enc_columns: List[str] = []

self.logger: Union[
PicklableClientLogger, logging.Logger
Expand Down Expand Up @@ -75,7 +69,8 @@ def fit(

# If a list was provided, it will be converted to pandas
if isinstance(X_train, list):
X_train, X_test = self.list_to_dataframe(X_train, X_test)
X_train = list_to_pandas(X_train, self.logger)
X_test = list_to_pandas(X_test, self.logger) if X_test is not None else None

self._check_data(X_train)

Expand Down
8 changes: 2 additions & 6 deletions autoPyTorch/data/base_target_validator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import logging
from typing import List, Optional, Union, cast
from typing import Optional, Union, cast

import numpy as np

import pandas as pd

from scipy.sparse import spmatrix

from sklearn.base import BaseEstimator

from autoPyTorch.utils.logging_ import PicklableClientLogger


SupportedTargetTypes = Union[List, pd.Series, pd.DataFrame, np.ndarray, spmatrix]
from autoPyTorch.data.utils import SupportedTargetTypes


class BaseTargetValidator(BaseEstimator):
Expand Down
Loading