Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Toker committed Sep 28, 2023
1 parent 2d62175 commit 8a98e7b
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions storey/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,11 @@ def __init__(
windows = {}
for aggregation_metadata in aggregates:
for meta in aggregation_metadata.aggregations:
for aggr, is_hidden in get_all_raw_aggregates_with_hidden([meta]).items():
meta_aggregates = get_all_raw_aggregates_with_hidden([meta]).items()
if not any(ag[0] == "count" for ag in meta_aggregates):
meta_aggregates = list(meta_aggregates)
meta_aggregates.append(("count", True))
for aggr, is_hidden in meta_aggregates:
if (
aggregation_metadata.name,
aggr,
Expand Down Expand Up @@ -629,8 +633,13 @@ def aggregate(self, data, timestamp):
def get_features(self, timestamp):
result = {}
for aggregation_bucket in self.aggregation_buckets.values():
if isinstance(aggregation_bucket, VirtualAggregationBuckets) or aggregation_bucket.explicit_windows:
if isinstance(aggregation_bucket, VirtualAggregationBuckets):
result.update(aggregation_bucket.get_features(timestamp))
elif aggregation_bucket.explicit_windows:
count_features = self.aggregation_buckets[f"{aggregation_bucket.name}_count"].get_features(
timestamp, aggregation_bucket.explicit_windows.windows
)
result.update(aggregation_bucket.get_features(timestamp, count_features=count_features))

return result

Expand Down Expand Up @@ -832,7 +841,7 @@ def get_aggregation_for_aggregation(self):
return "sum"
return self.aggregation

def get_features(self, timestamp, windows=None):
def get_features(self, timestamp, windows=None, count_features=None):
result = {}
if not windows:
if self.explicit_windows:
Expand All @@ -852,7 +861,11 @@ def get_features(self, timestamp, windows=None):
# In case our pre aggregates already have the answer
for win in windows:
result[f"{self.name}_{self.aggregation}_{win[1]}"] = self._current_aggregate_values[win].value

if (
self.aggregation != "count"
and (count_features.get(f"{self.name}_count_{win[1]}", 0) if count_features else 1) == 0
):
result[f"{self.name}_{self.aggregation}_{win[1]}"] = math.nan
return result

def calculate_features(self, timestamp, windows):
Expand Down Expand Up @@ -1088,7 +1101,7 @@ def __init__(self, name, aggregation, window, base_time, args):
def aggregate(self, timestamp, value):
pass

def get_features(self, timestamp):
def get_features(self, timestamp, count_features=None):
result = {}

args_results = [list(bucket.get_features(timestamp, self.window.windows).values()) for bucket in self.args]
Expand Down

0 comments on commit 8a98e7b

Please sign in to comment.