Cannot sample after loading a model with custom constraints: TypeError #984

Closed
npatki opened this issue Aug 30, 2022 · 0 comments · Fixed by #991
Labels
bug Something isn't working

npatki commented Aug 30, 2022

Environment Details

  • SDV version: 0.17.0.dev1

Error Description

In #944, we fixed the ability to save a model that uses custom constraints. Saving now works, but when I reload the model and try to sample from it, I get a TypeError (trace pasted below).

Steps to reproduce

  1. Follow the code from the Custom Constraints User Guide exactly to create a model with custom constraints, and verify that it samples synthetic data.
  2. Save and reload the model.
  3. Try to sample from the reloaded model.
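from sdv.tabular import GaussianCopula

# `model` is the GaussianCopula with custom constraints created in step 1
model.save('my_custom_constraint_model.pkl')

model = GaussianCopula.load('my_custom_constraint_model.pkl')
model.sample(num_rows=5)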
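The steps above amount to a save, reload, sample round trip. Below is a minimal sketch of a regression check for that round trip that does not use SDV at all. `AddOneConstraint`, `ToyModel` and `check_round_trip` are hypothetical stand-ins for a custom constraint, a model and a test helper, and the standard-library `pickle` module stands in for `model.save` / `GaussianCopula.load` (an assumption, not a description of how SDV serializes models).

```python
import pickle


class AddOneConstraint:
    """Hypothetical stand-in for a custom constraint (not SDV's API)."""

    def transform(self, data):
        # Forward direction: would be applied before fitting.
        return [value - 1 for value in data]

    def reverse_transform(self, data):
        # Reverse direction: applied to sampled rows.
        return [value + 1 for value in data]


class ToyModel:
    """Hypothetical stand-in for a model that reverses its constraints when sampling."""

    def __init__(self, constraints):
        self._constraints_to_reverse = list(constraints)

    def sample(self, num_rows):
        # Deterministic "sampling" so results can be compared after a reload.
        data = list(range(num_rows))
        # Mirrors the loop in the trace: constraints are reversed in reverse order.
        for constraint in reversed(self._constraints_to_reverse):
            data = constraint.reverse_transform(data)
        return data


def check_round_trip(model, num_rows=5):
    """Serialize and restore `model`, then confirm the restored copy samples the same rows."""
    restored = pickle.loads(pickle.dumps(model))
    return restored.sample(num_rows) == model.sample(num_rows)


model = ToyModel([AddOneConstraint()])
print(model.sample(5))          # prints [1, 2, 3, 4, 5]
print(check_round_trip(model))  # prints True
```

Against the real model, the equivalent check would call `sample` on the object returned by `GaussianCopula.load` and expect it to return rows instead of raising.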

Trace

Sampling rows:   0%|          | 0/5 [00:00<?, ?it/s]Error: Sampling terminated. Partial results are stored in a temporary file: .sample.csv.temp. This file will be overridden the next time you sample. Please rename the file if you wish to save these results.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-61-4fb96d2c64d8> in <module>
----> 1 model.sample(num_rows=5)

7 frames
/usr/local/lib/python3.7/dist-packages/sdv/tabular/base.py in sample(self, num_rows, randomize_samples, max_tries_per_batch, batch_size, output_file_path, conditions)
    550             output_file_path,
    551             conditions,
--> 552             show_progress_bar=show_progress_bar
    553         )
    554 

/usr/local/lib/python3.7/dist-packages/sdv/tabular/base.py in _sample_with_progress_bar(self, num_rows, randomize_samples, max_tries_per_batch, batch_size, output_file_path, conditions, show_progress_bar)
    506 
    507         except (Exception, KeyboardInterrupt) as error:
--> 508             handle_sampling_error(output_file_path == TMP_FILE_NAME, output_file_path, error)
    509 
    510         else:

/usr/local/lib/python3.7/dist-packages/sdv/tabular/utils.py in handle_sampling_error(is_tmp_file, output_file_path, sampling_error)
    163               f'Partial results are stored in {output_file_path}.')
    164 
--> 165     raise sampling_error
    166 
    167 

/usr/local/lib/python3.7/dist-packages/sdv/tabular/base.py in _sample_with_progress_bar(self, num_rows, randomize_samples, max_tries_per_batch, batch_size, output_file_path, conditions, show_progress_bar)
    502                     max_tries_per_batch=max_tries_per_batch,
    503                     progress_bar=progress_bar,
--> 504                     output_file_path=output_file_path
    505                 )
    506 

/usr/local/lib/python3.7/dist-packages/sdv/tabular/base.py in _sample_in_batches(self, num_rows, batch_size, max_tries_per_batch, conditions, transformed_conditions, float_rtol, progress_bar, output_file_path)
    383                 float_rtol=float_rtol,
    384                 progress_bar=progress_bar,
--> 385                 output_file_path=output_file_path,
    386             )
    387             sampled.append(sampled_rows)

/usr/local/lib/python3.7/dist-packages/sdv/tabular/base.py in _sample_batch(self, batch_size, max_tries, conditions, transformed_conditions, float_rtol, progress_bar, output_file_path)
    318             prev_num_valid = num_valid
    319             sampled, num_valid = self._sample_rows(
--> 320                 num_rows_to_sample, conditions, transformed_conditions, float_rtol, sampled,
    321             )
    322 

/usr/local/lib/python3.7/dist-packages/sdv/tabular/base.py in _sample_rows(self, num_rows, conditions, transformed_conditions, float_rtol, previous_rows)
    239                     sampled = self._sample(num_rows)
    240 
--> 241             sampled = self._metadata.reverse_transform(sampled)
    242 
    243             if previous_rows is not None:

/usr/local/lib/python3.7/dist-packages/sdv/metadata/table.py in reverse_transform(self, data)
    705 
    706         for constraint in reversed(self._constraints_to_reverse):
--> 707             reversed_data = constraint.reverse_transform(reversed_data)
    708 
    709         for name, field_metadata in self._fields_metadata.items():

TypeError: reverse_transform() missing 1 required positional argument: 'data'
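This is the message Python produces when `reverse_transform` is called on a class rather than on an instance: the single argument passed (`reversed_data`) binds to `self`, so `data` is left unfilled. A minimal sketch of that mechanism follows. `CustomConstraint` here is a hypothetical stand-in, and the idea that the reloaded model holds a constraint class instead of a constraint instance is an unconfirmed guess, not the diagnosis behind #991.

```python
class CustomConstraint:
    """Hypothetical stand-in for a custom constraint (not SDV's implementation)."""

    def reverse_transform(self, data):
        return data


reversed_data = [1, 2, 3]

# Called on an instance: Python binds the instance to `self`, so `data` receives reversed_data.
ok = CustomConstraint().reverse_transform(reversed_data)

# Called on the class: reversed_data binds to `self` and nothing is left for `data`.
try:
    CustomConstraint.reverse_transform(reversed_data)
except TypeError as error:
    message = str(error)
else:
    message = ""

print(ok)       # prints [1, 2, 3]
print(message)
```

On Python 3.7, the version in the trace, the second message matches the one above exactly; Python 3.10 and later prefix the method name with the class name.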