Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include trigonometric time features #347

Merged
merged 18 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ocf_datapipes/batch/batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ class BatchKey(Enum):
gsp_x_osgb_fourier = auto()
gsp_time_utc_fourier = auto() # (batch_size, time, n_fourier_features)

# -------------- TIME -------------------------------------------
# Sine and cosine of date of year and time of day at every timestep.
# shape = (batch_size, n_timesteps)
# This is calculated for wind only inside datapipes.
wind_date_sin = auto()
wind_date_cos = auto()
wind_time_sin = auto()
wind_time_cos = auto()

# -------------- SUN --------------------------------------------
# Solar position at every timestep. shape = (batch_size, n_timesteps)
# The solar position data comes from two alternative sources: either the Sun pre-prepared
Expand Down
8 changes: 6 additions & 2 deletions ocf_datapipes/training/windnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,15 @@ def __iter__(self):
numpy_modalities.append(datapipes_dict["wind"].convert_wind_to_numpy_batch())

logger.debug("Combine all the data sources")
combined_datapipe = MergeNumpyModalities(numpy_modalities).add_sun_position(
logger.debug("Adding trigonometric date and time")
combined_datapipe = MergeNumpyModalities(numpy_modalities).add_trigonometric_date_time(
modality_name="wind"
)
# combined_datapipe = MergeNumpyModalities(numpy_modalities).add_sun_position(
# modality_name="wind"
# )

logger.info("Filtering out samples with no data")
# logger.info("Filtering out samples with no data")
# if self.check_satellite_no_zeros:
# in production we don't want any nans in the satellite data
# combined_datapipe = combined_datapipe.map(check_nans_in_satellite_data)
Expand Down
1 change: 1 addition & 0 deletions ocf_datapipes/transform/numpy_batch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

from .add_fourier_space_time import AddFourierSpaceTimeIterDataPipe as AddFourierSpaceTime
from .add_topographic_data import AddTopographicDataIterDataPipe as AddTopographicData
from .datetime_features import AddTrigonometricDateTimeIterDataPipe as AddTrigonometricDateTime
from .sun_position import AddSunPositionIterDataPipe as AddSunPosition
60 changes: 60 additions & 0 deletions ocf_datapipes/transform/numpy_batch/datetime_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Datapipes to add trigonometric date and time features to NumpyBatch"""

import numpy as np
from numpy.typing import NDArray
from torch.utils.data import IterDataPipe, functional_datapipe

from ocf_datapipes.batch import BatchKey


def _get_date_time_in_pi(
dt: NDArray[np.datetime64],
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
day_of_year = (dt - dt.astype("datetime64[Y]")).astype(int)
minute_of_day = (dt - dt.astype("datetime64[D]")).astype(int)

# converting into positions on sin-cos circle
time_in_pi = (2 * np.pi) * (minute_of_day / (24 * 3600))
date_in_pi = (2 * np.pi) * (day_of_year / (365 * 24 * 3600))

return date_in_pi, time_in_pi


@functional_datapipe("add_trigonometric_date_time")
class AddTrigonometricDateTimeIterDataPipe(IterDataPipe):
    """Adds the trigonometric encodings of date of year, time of day to the NumpyBatch"""

    def __init__(self, source_datapipe: IterDataPipe, modality_name: str):
        """
        Adds the sine and cosine of time to the NumpyBatch

        Args:
            source_datapipe: Datapipe of NumpyBatch
            modality_name: Modality to add the time for. Only "wind" is
                currently accepted.
        """
        self.source_datapipe = source_datapipe
        self.modality_name = modality_name
        assert self.modality_name in [
            "wind",
        ], f"Trigonometric time not implemented for {self.modality_name}"

    def __iter__(self):
        """Yield each batch from the source with the four trigonometric keys added.

        The batch is modified in place: `<modality>_date_sin`,
        `<modality>_date_cos`, `<modality>_time_sin` and `<modality>_time_cos`
        are computed from `<modality>_time_utc` and stored on the same object
        that is then yielded.
        """
        # Resolve every key from the modality name once, outside the per-batch
        # loop, so the input key follows the same naming rule as the outputs.
        prefix = self.modality_name
        time_utc_batch_key = BatchKey[prefix + "_time_utc"]
        date_sin_batch_key = BatchKey[prefix + "_date_sin"]
        date_cos_batch_key = BatchKey[prefix + "_date_cos"]
        time_sin_batch_key = BatchKey[prefix + "_time_sin"]
        time_cos_batch_key = BatchKey[prefix + "_time_cos"]

        for np_batch in self.source_datapipe:
            time_utc = np_batch[time_utc_batch_key]

            # Reinterpret the stored timestamps as second-resolution datetimes.
            # NOTE(review): presumably these are seconds since the Unix epoch -
            # confirm against the code that fills `<modality>_time_utc`.
            times: NDArray[np.datetime64] = time_utc.astype("datetime64[s]")

            date_in_pi, time_in_pi = _get_date_time_in_pi(times)

            # Store
            np_batch[date_sin_batch_key] = np.sin(date_in_pi)
            np_batch[date_cos_batch_key] = np.cos(date_in_pi)
            np_batch[time_sin_batch_key] = np.sin(time_in_pi)
            np_batch[time_cos_batch_key] = np.cos(time_in_pi)

            yield np_batch
32 changes: 32 additions & 0 deletions tests/transform/numpy_batch/test_datetime_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np
from datetime import datetime

from ocf_datapipes.transform.numpy_batch import AddTrigonometricDateTime

from ocf_datapipes.transform.numpy_batch.datetime_features import _get_date_time_in_pi


def test_get_date_time_in_pi():
    """Check the yearly and daily angles at five points of the cycle in two years.

    Each year contributes timestamps near 0, 1/4, 1/2, 3/4 and 1 of both the
    day and the year, so one list of expected angles serves both outputs.
    The date checks use looser tolerances than the time checks.
    """
    iso_stamps = [
        "2020-01-01T00:00:01",
        "2020-04-01T06:00:00",
        "2020-07-01T12:00:00",
        "2020-09-30T18:00:00",
        "2020-12-31T23:59:59",
        "2021-01-01T00:00:01",
        "2021-04-02T06:00:00",
        "2021-07-02T12:00:00",
        "2021-10-01T18:00:00",
        "2021-12-31T23:59:59",
    ]
    stamps = np.array(
        list(map(datetime.fromisoformat, iso_stamps)),
        dtype="datetime64[s]",
    )

    # The same quarter-turn pattern repeats for the second year.
    one_cycle = [0, 0.5 * np.pi, np.pi, 1.5 * np.pi, 2 * np.pi]
    expected_angles = one_cycle + one_cycle

    date_angle, time_angle = _get_date_time_in_pi(stamps)

    # Compare on the unit circle so that 0 and 2 * pi count as the same point.
    assert np.allclose(np.cos(time_angle), np.cos(expected_angles), atol=1e-04)
    assert np.allclose(np.sin(time_angle), np.sin(expected_angles), atol=1e-04)
    assert np.allclose(np.cos(date_angle), np.cos(expected_angles), atol=0.01)
    assert np.allclose(np.sin(date_angle), np.sin(expected_angles), atol=0.02)
Loading