Skip to content

Commit

Permalink
change prefix + columns_to_sdtype
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Sep 8, 2023
1 parent afbf35d commit acfd2c8
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 118 deletions.
89 changes: 42 additions & 47 deletions rdt/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,15 +488,15 @@ class BaseMultiColumnTransformer(BaseTransformer):
in order to create a new multi column transformer.
Attributes:
ordered_columns (tuple):
columns_to_sdtype (dict):
Dictionary mapping each column to be used by the transformer to its sdtype. The order of the keys is the order of the columns.
prefixes (dict):
Dictionary mapping each output column to its prefix.
"""

def __init__(self):
super().__init__()
self.ordered_columns = None
self.columns_to_sdtype = {}
self.prefixes = {}

def get_input_column(self):
Expand All @@ -519,87 +519,82 @@ def get_input_columns(self):
"""
return self.columns

def _get_output_to_property(self, property_):
result = {
f'{self.prefixes[output_column]}.{output_column}': properties[property_]
for output_column, properties in self.output_properties.items()
}

return result

def _validate_ordered_columns(self, data, ordered_columns):
"""Check that all the columns in ``ordered_columns`` are present in the data."""
missing = set(ordered_columns) - set(data.columns)
if missing:
missing_to_print = ', '.join(missing)
raise KeyError(f'Columns ({missing_to_print}) are not present in the data.')

def _generate_prefixes(self, ordered_columns):
"""Generate prefixes for the output columns to specify which column they come from.
def _get_prefix(self):
"""Return the prefix of the output columns.
Returns:
dict:
Dictionary mapping each output column to its prefix.
The key is the output column name and the value is the prefix.
str:
Prefix of the output columns.
"""
raise NotImplementedError()

def _validate_prefixes(self, ordered_columns):
"""Check that the prefixes are valid.
def _get_output_to_property(self, property_):
self.prefixes = self._get_prefix()
is_prefix_dict = isinstance(self.prefixes, dict)
output = {}
for output_column, properties in self.output_properties.items():
# if the requested property is not in the column's properties dict, ignore the column
if property_ not in properties:
continue

if is_prefix_dict:
prefix = self.prefixes[output_column]
else:
prefix = self.prefixes

Every prefix must include the name of at least one column in the data.
"""
for prefix in self.prefixes.values():
if not any(column in prefix for column in ordered_columns):
raise ValueError(
f"The prefix '{prefix}' does not include the name of any column in the data."
)
if prefix is None:
output[f'{output_column}'] = properties[property_]
else:
output[f'{prefix}.{output_column}'] = properties[property_]

return output

def _fit(self, data, ordered_columns):
def _validate_columns_to_sdtype(self, data, columns_to_sdtype):
"""Check that all the columns in ``columns_to_sdtype`` are present in the data."""
missing = set(columns_to_sdtype.keys()) - set(data.columns)
if missing:
missing_to_print = ', '.join(missing)
raise KeyError(f'Columns ({missing_to_print}) are not present in the data.')

def _fit(self, data):
"""Fit the transformer to the data.
Args:
data (pandas.DataFrame):
Data to transform.
ordered_columns (tuple):
Order of the columns to be used for the transformer.
"""
raise NotImplementedError()

@random_state
def fit(self, data, ordered_columns):
def fit(self, data, columns_to_sdtype):
"""Fit the transformer to a ``column`` of the ``data``.
Args:
data (pandas.DataFrame):
The entire table.
ordered_columns (tuple):
columns_to_sdtype (dict):
Dictionary mapping each column to be used by the transformer to its sdtype. The order of the keys is the order of the columns.
"""
self._validate_ordered_columns(data, ordered_columns)
self.ordered_columns = ordered_columns
self._store_columns(ordered_columns, data)
self._validate_columns_to_sdtype(data, columns_to_sdtype)
self.columns_to_sdtype = columns_to_sdtype
self._store_columns(list(self.columns_to_sdtype.keys()), data)
self._set_seed(data)

columns_data = self._get_columns_data(data, self.columns)
self._fit(columns_data, ordered_columns)

self.prefixes = self._generate_prefixes(ordered_columns)
self._validate_prefixes(ordered_columns)
self._fit(columns_data)
self._build_output_columns(data)

def fit_transform(self, data, ordered_columns):
def fit_transform(self, data, columns_to_sdtype):
"""Fit the transformer to a `column` of the `data` and then transform it.
Args:
data (pandas.DataFrame):
The entire table.
ordered_columns (tuple):
columns_to_sdtype (dict):
Dictionary mapping each column to be used by the transformer to its sdtype. The order of the keys is the order of the columns.
Returns:
pd.DataFrame:
The entire table, containing the transformed data.
"""
self.fit(data, ordered_columns)
self.fit(data, columns_to_sdtype)
return self.transform(data)
59 changes: 35 additions & 24 deletions tests/integration/transformers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,15 @@ def test_multi_column_transformer_same_number_of_columns_input_output():
# Setup
class AdditionTransformer(BaseMultiColumnTransformer):
"""This transformer takes 3 columns and returns the cumulative sum of each row."""
def _fit(self, columns_data, ordered_columns):
def _fit(self, columns_data):
self.output_properties = {
column: {'sdtype': 'numerical'} for column in self.columns
f'{self.columns[0]}': {'sdtype': 'numerical'},
f'{self.columns[0]}+{self.columns[1]}': {'sdtype': 'numerical'},
f'{self.columns[0]}+{self.columns[1]}+{self.columns[2]}': {'sdtype': 'numerical'}
}

def _generate_prefixes(self, ordered_columns):
prefixes = {}
for idx, column in enumerate(self.output_properties):
prefixes[column] = '#'.join(ordered_columns[:idx + 1])

return prefixes
def _get_prefix(self):
return None

def _transform(self, data):
return data.cumsum(axis=1)
Expand All @@ -163,18 +161,22 @@ def _reverse_transform(self, data):
'col_3': [100, 200, 300]
})

order_columns = ('col_1', 'col_2', 'col_3')
columns_to_sdtype = {
'col_1': 'numerical',
'col_2': 'numerical',
'col_3': 'numerical'
}
transformer = AdditionTransformer()

# Run
transformed = transformer.fit_transform(data_test, order_columns)
transformed = transformer.fit_transform(data_test, columns_to_sdtype)
reverse = transformer.reverse_transform(transformed)

# Assert
expected_transform = pd.DataFrame({
'col_1.col_1': [1, 2, 3],
'col_1#col_2.col_2': [11, 22, 33],
'col_1#col_2#col_3.col_3': [111, 222, 333]
'col_1': [1, 2, 3],
'col_1+col_2': [11, 22, 33],
'col_1+col_2+col_3': [111, 222, 333]
})
pd.testing.assert_frame_equal(expected_transform, transformed)
pd.testing.assert_frame_equal(reverse, data_test)
Expand All @@ -186,18 +188,19 @@ class ConcatenateTransformer(BaseMultiColumnTransformer):
"""This transformer takes 4 columns and concatenates them into 2 columns.
The two first and last columns are concatenated together.
"""
def _fit(self, columns_data, ordered_columns):
def _fit(self, columns_data):
self.name_1 = self.columns[0] + '#' + self.columns[1]
self.name_2 = self.columns[2] + '#' + self.columns[3]
self.output_properties = {
'concatenate_1': {'sdtype': 'categorical'},
'concatenate_2': {'sdtype': 'categorical'}
}

def _generate_prefixes(self, ordered_columns):
prefixes = {}
for idx, column in enumerate(self.output_properties):
prefixes[column] = self.name_1 if idx == 0 else self.name_2
def _get_prefix(self):
prefixes = {
'concatenate_1': self.name_1,
'concatenate_2': self.name_2
}

return prefixes

Expand Down Expand Up @@ -226,11 +229,16 @@ def _reverse_transform(self, data):
'col_4': ['J', 'K', 'L']
})

ordered_columns = ('col_1', 'col_2', 'col_3', 'col_4')
columns_to_sdtype = {
'col_1': 'categorical',
'col_2': 'categorical',
'col_3': 'categorical',
'col_4': 'categorical'
}
transformer = ConcatenateTransformer()

# Run
transformer.fit(data_test, ordered_columns)
transformer.fit(data_test, columns_to_sdtype)
transformed = transformer.transform(data_test)
reverse = transformer.reverse_transform(transformed)

Expand All @@ -247,15 +255,15 @@ def test_multi_column_transformer_more_output_than_input_columns():
"""Test a multi-column transformer when the output has more columns than the input."""
class ExpandTransformer(BaseMultiColumnTransformer):

def _fit(self, columns_data, ordered_columns):
def _fit(self, columns_data):
self.output_properties = {
'first_part_1': {'sdtype': 'categorical'},
'second_part_1': {'sdtype': 'categorical'},
'first_part_2': {'sdtype': 'categorical'},
'second_part_2': {'sdtype': 'categorical'}
}

def _generate_prefixes(self, ordered_columns):
def _get_prefix(self):
list_prefixes = [
self.columns[0], self.columns[0],
self.columns[1], self.columns[1]
Expand Down Expand Up @@ -288,11 +296,14 @@ def _reverse_transform(self, data):
'col_2': ['GH', 'IJ', 'KL'],
})

ordered_columns = ('col_1', 'col_2')
columns_to_sdtype = {
'col_1': 'categorical',
'col_2': 'categorical'
}
transformer = ExpandTransformer()

# Run
transformer.fit(data_test, ordered_columns)
transformer.fit(data_test, columns_to_sdtype)
transformed = transformer.transform(data_test)
reverse = transformer.reverse_transform(transformed)

Expand Down
Loading

0 comments on commit acfd2c8

Please sign in to comment.