Calling sampling with conditions and without num_rows crashes #614

Closed
csala opened this issue Oct 28, 2021 · 0 comments
Labels
bug Something isn't working

csala commented Oct 28, 2021

Environment Details

  • SDV version: 0.12.1

Error Description

When calling the sample method of a tabular model without passing any arguments, the model produces
as many rows as it saw during the fit phase:

In [1]: import sdv
   ...: 
   ...: data = sdv.demo.load_tabular_demo('student_placements')
   ...: 
   ...: gc = sdv.tabular.GaussianCopula()
   ...: gc.fit(data.head(10))
   ...: 
   ...: gc.sample()
Out[1]: 
   student_id gender  second_perc  high_perc high_spec  degree_perc  ... mba_perc   salary  placed  start_date   end_date  duration
0       17270      M        59.18      60.63  Commerce        78.04  ...    51.58      NaN   False         NaT        NaT       NaN
1       17266      M        57.90      49.20   Science        69.00  ...    52.84      NaN   False         NaT        NaT       NaN
2       17266      M        58.67      64.73   Science        71.14  ...    62.49  22400.0    True  2020-01-12 2020-05-18       4.0
3       17268      M        79.76      58.37   Science        70.44  ...    56.47  33400.0    True  2020-03-23 2020-07-22       3.0
4       17269      M        76.12      67.72      Arts        69.53  ...    57.44  28800.0    True  2020-02-02 2020-07-17       4.0
5       17270      M        74.14      68.20  Commerce        60.42  ...    55.82  27000.0    True  2020-03-11 2020-08-14       6.0
6       17268      F        57.94      64.93  Commerce        68.47  ...    51.58      NaN   False         NaT        NaT       NaN
7       17268      M        53.20      54.04   Science        71.75  ...    51.58      NaN   False         NaT        NaT       NaN
8       17271      F        49.14      58.24  Commerce        76.70  ...    56.37      NaN   False         NaT        NaT       NaN
9       17268      F        68.63      49.20   Science        74.70  ...    51.58      NaN   False         NaT        NaT       NaN

[10 rows x 17 columns]

When sampling conditionally, passing both a conditions dict AND num_rows, everything works as expected:

In [2]: gc.sample(num_rows=5, conditions={'gender': 'M'})
Out[2]: 
   student_id gender  second_perc  high_perc high_spec  degree_perc  ... mba_perc   salary  placed  start_date   end_date  duration
0       17267      M        62.39      81.37      Arts        75.29  ...    58.84  24000.0    True  2020-01-30 2020-07-11       5.0
1       17270      M        76.24      64.03   Science        77.89  ...    52.53  29100.0    True  2020-01-28 2020-07-04       4.0
2       17268      M        70.18      87.62  Commerce        62.72  ...    62.80  26100.0    True  2020-03-02 2020-08-23       4.0
3       17267      M        75.90      72.41   Science        74.46  ...    58.08  27800.0    True  2020-02-05 2020-07-14       3.0
4       17269      M        71.36      58.67   Science        63.50  ...    62.55  26000.0    True  2020-02-01 2020-05-20       3.0

[5 rows x 17 columns]

However, when sampling with a conditions dict WITHOUT passing num_rows, the sample call crashes:

In [3]: gc.sample(conditions={'gender': 'M'})
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/home/ubuntu/.virtualenvs/sdv-test/lib/python3.7/site-packages/sdv/tabular/base.py in _make_conditions_df(self, conditions, num_rows)
    345             try:
--> 346                 conditions = pd.DataFrame(conditions)
    347             except ValueError:

/home/ubuntu/.virtualenvs/sdv-test/lib/python3.7/site-packages/pandas/core/frame.py in __init__(self, data, index, columns, dtype, copy)
    467         elif isinstance(data, dict):
--> 468             mgr = init_dict(data, index, columns, dtype=dtype)
    469         elif isinstance(data, ma.MaskedArray):

/home/ubuntu/.virtualenvs/sdv-test/lib/python3.7/site-packages/pandas/core/internals/construction.py in init_dict(data, index, columns, dtype)
    282         ]
--> 283     return arrays_to_mgr(arrays, data_names, index, columns, dtype=dtype)
    284 

/home/ubuntu/.virtualenvs/sdv-test/lib/python3.7/site-packages/pandas/core/internals/construction.py in arrays_to_mgr(arrays, arr_names, index, columns, dtype, verify_integrity)
     77         if index is None:
---> 78             index = extract_index(arrays)
     79         else:

/home/ubuntu/.virtualenvs/sdv-test/lib/python3.7/site-packages/pandas/core/internals/construction.py in extract_index(data)
    386         if not indexes and not raw_lengths:
--> 387             raise ValueError("If using all scalar values, you must pass an index")
    388 

ValueError: If using all scalar values, you must pass an index

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-4-d253bdf9697f> in <module>
----> 1 gc.sample(conditions={'gender': 'M'})

/home/ubuntu/.virtualenvs/sdv-test/lib/python3.7/site-packages/sdv/tabular/base.py in sample(self, num_rows, max_retries, max_rows_multiplier, conditions, float_rtol, graceful_reject_sampling)
    443 
    444         # convert conditions to dataframe
--> 445         conditions = self._make_conditions_df(conditions, num_rows)
    446 
    447         # validate columns

/home/ubuntu/.virtualenvs/sdv-test/lib/python3.7/site-packages/sdv/tabular/base.py in _make_conditions_df(self, conditions, num_rows)
    346                 conditions = pd.DataFrame(conditions)
    347             except ValueError:
--> 348                 conditions = pd.DataFrame([conditions] * num_rows)
    349 
    350         elif not isinstance(conditions, pd.DataFrame):

TypeError: can't multiply sequence by non-int of type 'NoneType'
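
Reading the traceback: building a DataFrame from a dict of scalars raises the ValueError, and the except branch then multiplies a list by num_rows, which is still None. The second step can be reproduced without SDV or pandas. This is only a minimal sketch of the failing expression; the variable names are illustrative and not taken from the SDV source.

```python
# Minimal reproduction of the failing expression from the except branch.
conditions = {'gender': 'M'}
num_rows = None  # the value when sample() is called without num_rows

try:
    # Same shape as `[conditions] * num_rows` in the traceback above.
    repeated = [conditions] * num_rows
except TypeError as error:
    message = str(error)
    print(message)
```

With an integer num_rows the same expression simply repeats the dict, which is why the call with num_rows=5 works.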

Expected Behavior

Since sampling without any arguments produces the same number of rows as seen in the input data, one would expect that passing a conditions dict without num_rows would achieve the same result.
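
One possible way to get that behavior, sketched in plain Python so it runs without SDV or pandas: fall back to the row count seen during fit when num_rows is not given. The helper name `expand_conditions` and the `fitted_num_rows` argument are hypothetical, and how the model actually stores the fitted row count is an assumption here, not something confirmed by this report.

```python
def expand_conditions(conditions, num_rows=None, fitted_num_rows=None):
    """Repeat a dict of scalar conditions once per row to sample.

    Hypothetical helper: `fitted_num_rows` stands in for whatever row
    count the model remembered from fit (an assumption, not SDV's API).
    """
    if num_rows is None:
        # Mirror the behavior of sample() without arguments.
        num_rows = fitted_num_rows
    if num_rows is None:
        raise ValueError('num_rows is required when no fitted row count is available')

    # Copy the dict per row so later edits to one row do not affect the others.
    return [dict(conditions) for _ in range(num_rows)]


rows = expand_conditions({'gender': 'M'}, fitted_num_rows=10)
print(len(rows))
```

A list of dicts like this can be passed to pd.DataFrame to get one condition row per sample, and an explicit num_rows still takes precedence over the fallback.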
