Skip to content

Commit

Permalink
Consistency check, allow partial update.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasper Zschiegner committed May 26, 2023
1 parent 177f753 commit feec770
Showing 1 changed file with 39 additions and 11 deletions.
50 changes: 39 additions & 11 deletions src/gluonts/zebras/_time_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def update(self, other: TimeFrame, default=np.nan) -> TimeFrame:
The new frame spans both input frames and inserts default values if
there is a gap between the two frames. If both frames overlap, the
second can overwrite the values of the first frame.
second overwrites the values of the first frame.
Static columns and metadata is also updated, and the second frames
value take precedence.
Expand All @@ -188,12 +188,32 @@ def update(self, other: TimeFrame, default=np.nan) -> TimeFrame:
Note: ``update`` will reset the padding.
"""

assert self.index is not None and other.index is not None
assert self.index.freq == other.index.freq
if self.index is not None or other.index is not None:
raise ValueError("Both time frames need to have an index.")

if self.index.freq != other.index.freq:
raise ValueError("frequency mismatch on index.")

# ensure tdims match
for name, left, right in join_items(self.tdims, other.tdims, "inner"):
if left != right:
raise ValueError(
f"tdims of {name} don't match {left} != {right}"
)

# ensure column shapes match
for name, left, right in join_items(
self.columns, other.columns, "inner"
):
tdim = self.tdims[name]

if replace(left.shape, tdim, 0) != replace(right.shape, tdim, 0):
raise ValueError(f"Incompatible shapes of columns {name}")

start = min(self.index.start, other.index.start)
end = max(self.index.end, other.index.end)

# create a new index that spans the new range
index = Periods(
np.arange(
start.to_numpy(),
Expand All @@ -203,21 +223,29 @@ def update(self, other: TimeFrame, default=np.nan) -> TimeFrame:
),
start.freq,
)
# get position of self and other relative to new index
# (one of them will be zero)
self_idx0 = index.index_of(self.index.start)
other_idx0 = index.index_of(other.index.start)

# create new columns, by first filling them with default values and
# then writing the values of self and other to them
columns = {}
for name, (self_col, other_col, tdim) in rows_to_columns(
(self.columns, other.columns, self.tdims)
).items():
new_shape = list(self_col.shape)
new_shape[tdim] = len(index)
for name, self_col, other_col in join_items(
self.columns, other.columns, "outer"
):
tdim = self.tdims[name]

values = np.full(new_shape, default)
values = np.full(
replace(self_col.shape, tdim, len(index)),
default,
)
view = AxisView(values, tdim)

view[self_idx0 : self_idx0 + len(self)] = self_col
view[other_idx0 : other_idx0 + len(other)] = other_col
if self_col is not None:
view[self_idx0 : self_idx0 + len(self)] = self_col
if other_col is not None:
view[other_idx0 : other_idx0 + len(other)] = other_col

columns[name] = values

Expand Down

0 comments on commit feec770

Please sign in to comment.