Skip to content

Commit

Permalink
Merge pull request #98 from HDI-Project/93_model_amount_childs
Browse files: browse the repository at this point in the history
Issue 93: Model amount of children
  • Loading branch information
csala authored May 15, 2019
2 parents 196198b + 96d266f commit 941e4ef
Show file tree
Hide file tree
Showing 5 changed files with 321 additions and 259 deletions.
40 changes: 24 additions & 16 deletions sdv/modeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ class Modeler:
DEFAULT_PRIMARY_KEY = 'GENERATED_PRIMARY_KEY'

def __init__(self, data_navigator, model=DEFAULT_MODEL, distribution=None, model_kwargs=None):
"""Instantiates a modeler object.
"""
"""Instantiates a modeler object."""
self.tables = {}
self.models = {}
self.child_locs = {} # maps table->{child: col #}
Expand All @@ -48,7 +46,7 @@ def __init__(self, data_navigator, model=DEFAULT_MODEL, distribution=None, model
raise ValueError(
'`distribution` argument is only suported for `GaussianMultivariate` model.')

if distribution:
if distribution is not None:
distribution = get_qualified_name(distribution)
else:
distribution = get_qualified_name(DEFAULT_DISTRIBUTION)
Expand Down Expand Up @@ -145,16 +143,18 @@ def _flatten_dict(cls, nested, prefix=''):

return result

def flatten_model(self, model, name=''):
"""Flatten a model's parameters into an array.
def _get_model_dict(self, data):
"""Fit and serialize a model and flatten its parameters into an array.
Args:
model(self.model): Instance of model.
name (str): Prefix to the parameter name.
data(pandas.DataFrame): Dataset to fit the model to.
Returns:
pd.Series: parameters for model
dict: Flattened parameters for model.
"""
model = self.fit_model(data)

if self.model == DEFAULT_MODEL:
values = []
triangle = np.tril(model.covariance)
Expand All @@ -173,7 +173,7 @@ def flatten_model(self, model, name=''):
column = pd.DataFrame({'field': [distribution.std]})
distribution.std = transformer.reverse_transform(column).loc[0, 'field']

return pd.Series(self._flatten_dict(model.to_dict(), name))
return self._flatten_dict(model.to_dict())

def get_foreign_key(self, fields, primary):
"""Get foreign key from primary key.
Expand Down Expand Up @@ -254,16 +254,24 @@ def _create_extension(self, foreign, transformed_child_table, table_info):

foreign_key, child_name = table_info
try:
conditional_data = transformed_child_table.loc[foreign.index].copy()
if foreign_key in conditional_data:
conditional_data = conditional_data.drop(foreign_key, axis=1)
child_rows = transformed_child_table.loc[foreign.index].copy()
if foreign_key in child_rows:
child_rows = child_rows.drop(foreign_key, axis=1)

except KeyError:
return None

if len(conditional_data):
clean_df = self.impute_table(conditional_data)
return self.flatten_model(self.fit_model(clean_df), child_name)
num_child_rows = len(child_rows)

if num_child_rows:
clean_df = self.impute_table(child_rows)
extension = self._get_model_dict(clean_df)
extension['child_rows'] = num_child_rows

extension = pd.Series(extension)
extension.index = child_name + '__' + extension.index

return extension

return None

Expand Down
Loading

0 comments on commit 941e4ef

Please sign in to comment.