Skip to content

Commit

Permalink
Merge pull request #198 from google-research/timesfm_2d0
Browse files Browse the repository at this point in the history
Nan handling
  • Loading branch information
siriuz42 authored Dec 20, 2024
2 parents 27f4037 + faf1c76 commit 1542481
Showing 1 changed file with 63 additions and 2 deletions.
65 changes: 63 additions & 2 deletions src/timesfm/timesfm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,65 @@ def freq_map(freq: str):
"""Returns the frequency map for the given frequency string."""
freq = str.upper(freq)
if (freq.endswith("H") or freq.endswith("T") or freq.endswith("MIN") or
freq.endswith("D") or freq.endswith("B") or freq.endswith("U")):
freq.endswith("D") or freq.endswith("B") or freq.endswith("U") or
freq.endswith("S")):
return 0
elif freq.endswith(("W", "M", "MS")):
return 1
elif freq.endswith("Y") or freq.endswith("Q"):
elif freq.endswith("Y") or freq.endswith("Q") or freq.endswith("A"):
return 2
else:
raise ValueError(f"Invalid frequency: {freq}")

def strip_leading_nans(arr):
  """Removes contiguous NaN values from the beginning of a NumPy array.

  Args:
    arr: The input NumPy array. A 1D array is expected.

  Returns:
    The slice of `arr` that starts at its first non-NaN element; NaNs that
    appear after that element are kept. An empty input is returned as is.
    An all-NaN input is also returned unchanged (nothing is stripped), so
    that the caller can still see its length and decide how to fill it.
  """
  isnan = np.isnan(arr)
  if isnan.size == 0:
    # np.argmax raises ValueError on an empty array; nothing to strip anyway.
    return arr

  # argmax over a boolean array gives the index of the first True. When there
  # is no True at all (all-NaN input) it gives 0, which keeps the whole array.
  first_valid_index = np.argmax(~isnan)
  return arr[first_valid_index:]

def linear_interpolation(arr):
  """Fills NaN values in a 1D numpy array using linear interpolation.

  NaNs located before the first or after the last valid value take the value
  of the nearest valid element (the edge behaviour of `np.interp`).

  Args:
    arr: The 1D numpy array that may contain NaN values. It is never
      modified.

  Returns:
    `arr` itself when it contains no NaNs. Otherwise a new array of the same
    shape and dtype with every NaN replaced by its interpolated value. When
    every element is NaN there is nothing to interpolate from, and an array
    of zeros is returned.
  """
  nans = np.isnan(arr)
  if not np.any(nans):
    return arr

  valid = ~nans
  if not np.any(valid):
    # All-NaN input: np.interp has no sample points, so fall back to zeros.
    return np.zeros_like(arr)

  # Work on a copy so the caller's array is left untouched.
  filled = arr.copy()
  filled[nans] = np.interp(nans.nonzero()[0], valid.nonzero()[0], arr[valid])
  return filled


# Per time series normalization: forward.
def _normalize(batch):
Expand Down Expand Up @@ -313,6 +363,17 @@ def forecast(
ValueError: If the checkpoint is not properly loaded.
"""
stats = None

tmp_inputs = []
for each_input in inputs:
arr = np.array(each_input)
if not np.isfinite(arr).all():
arr = np.where(np.isfinite(arr), arr, np.nan)
arr = strip_leading_nans(arr)
arr = linear_interpolation(arr)
tmp_inputs.append(arr)

inputs = tmp_inputs
if normalize:
inputs, stats = _normalize(inputs)
mean_forecast, quantile_forecast = self._forecast(
Expand Down

0 comments on commit 1542481

Please sign in to comment.