Fix categorical column after sequence_index column issue #357
sdv/timeseries/deepecho.py
Outdated
@@ -67,7 +67,8 @@ def _fit(self, timeseries_data):

         data_types = list()
         context_types = list()
-        for field, meta in self._metadata.get_fields().items():
+        for field in self._entity_columns + self._data_columns:
+            meta = self._metadata.get_fields()[field]
I would capture the `fields_metadata` in a variable before the loop to avoid having to call `self._metadata.get_fields()` at each iteration.
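A minimal sketch of this suggestion, outside of the real class: `FakeMetadata` and `collect_field_types` are hypothetical stand-ins invented for illustration (they are not part of SDV), and the call counter only exists to show that the lookup happens once.

```python
class FakeMetadata:
    """Hypothetical stand-in for the metadata object; counts get_fields() calls."""

    def __init__(self, fields):
        self._fields = fields
        self.calls = 0

    def get_fields(self):
        self.calls += 1
        return dict(self._fields)


def collect_field_types(metadata, columns):
    """Return the metadata type of each column, looking the fields up only once."""
    # Capture the fields in a variable before the loop instead of
    # calling metadata.get_fields() on every iteration.
    fields_metadata = metadata.get_fields()
    return [fields_metadata[field]['type'] for field in columns]


metadata = FakeMetadata({
    'date': {'type': 'datetime'},
    'col': {'type': 'categorical'},
})
types = collect_field_types(metadata, ['date', 'col'])
```

With this shape, `get_fields()` is called a single time regardless of how many columns are iterated.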
sdv/timeseries/deepecho.py
Outdated
@@ -67,7 +67,8 @@ def _fit(self, timeseries_data):

         data_types = list()
         context_types = list()
-        for field, meta in self._metadata.get_fields().items():
+        for field in self._entity_columns + self._data_columns:
I think this may not work, because the order of the columns will be altered and we would also be missing the `context_columns`.
If the order of the key/value pairs returned by `self._metadata.get_fields()` is the problem, one possibility would be to iterate over `self._output_columns` (which is the list of columns from the input data) instead.
Then, to solve the `sequence_index` problem, we could change line 74 (from the old code):
    if field == self._sequence_index:
        data_types.append('continuous')
to
    if field == self._sequence_index:
        data_types.extend(['continuous', 'continuous'])
and then just remove line 82 (from the old code).
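To make the proposal concrete, here is a rough standalone sketch of the loop, assuming it iterates over the output columns in input order. The `DATA_TYPES` mapping, the `build_types` helper, and the column names are hypothetical and chosen only for illustration; the real `_fit` method may differ.

```python
# Hypothetical mapping from metadata field types to model data types.
DATA_TYPES = {
    'numerical': 'continuous',
    'categorical': 'categorical',
    'boolean': 'categorical',
    'datetime': 'continuous',
}


def build_types(output_columns, fields_metadata, sequence_index,
                entity_columns, context_columns):
    """Build data and context type lists following the input column order."""
    data_types = []
    context_types = []
    for field in output_columns:
        if field == sequence_index:
            # As proposed in the comment: add both entries here, so no
            # extra append is needed after the loop.
            data_types.extend(['continuous', 'continuous'])
            continue

        data_type = DATA_TYPES[fields_metadata[field]['type']]
        if field in context_columns:
            context_types.append(data_type)
        elif field not in entity_columns:
            data_types.append(data_type)

    return data_types, context_types


fields_metadata = {
    'store': {'type': 'categorical'},
    'region': {'type': 'categorical'},
    'date': {'type': 'datetime'},
    'sales': {'type': 'numerical'},
    'promo': {'type': 'categorical'},
}
data_types, context_types = build_types(
    output_columns=['store', 'region', 'date', 'sales', 'promo'],
    fields_metadata=fields_metadata,
    sequence_index='date',
    entity_columns=['store'],
    context_columns=['region'],
)
```

Because the loop follows the input column order, a categorical column placed after the sequence index (`promo` here) keeps its position relative to the other data columns.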
def test_column_after_date():
    """Test that adding columns after the `sequence_index` column works."""
    date = datetime.datetime.strptime('2020-01-01', '%Y-%m-%d')
    daily_timeseries = pd.DataFrame({
I think it would be worth making this test slightly more complex, so that it covers multiple data types and both entity columns and context columns.
    })

    model = PAR(entity_columns=['col'], sequence_index='date', epochs=1)
    model.fit(daily_timeseries)
Should we be validating a bit more? For example, we could validate that the output types are actually right.
Codecov Report
@@ Coverage Diff @@
## master #357 +/- ##
=======================================
Coverage 65.01% 65.01%
=======================================
Files 34 34
Lines 2590 2590
=======================================
Hits 1684 1684
Misses 906 906
Resolve #314.