diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index 3087a9bf..0eef7585 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -48,6 +48,7 @@ def __init__( add_image_embedding_channel: bool = False, include_gsp_yield_history: bool = True, include_sun: bool = True, + include_time: bool = False, embedding_dim: Optional[int] = 16, forecast_minutes: int = 30, history_minutes: int = 60, @@ -97,6 +98,7 @@ def __init__( embedding of the GSP ID. include_gsp_yield_history: Include GSP yield data. include_sun: Include sun azimuth and altitude data. + include_time: Include sine and cosine of dates and times. embedding_dim: Number of embedding dimensions to use for GSP ID. Not included if set to `None`. forecast_minutes: The amount of minutes that should be forecasted. @@ -141,6 +143,7 @@ def __init__( self.include_nwp = nwp_encoders_dict is not None and len(nwp_encoders_dict) != 0 self.include_pv = pv_encoder is not None self.include_sun = include_sun + self.include_time = include_time self.include_wind = wind_encoder is not None self.include_sensor = sensor_encoder is not None self.embedding_dim = embedding_dim @@ -283,6 +286,16 @@ def __init__( # Update num features fusion_input_features += 16 + if self.include_time: + self.time_fc1 = nn.Linear( + in_features=4 + * (self.forecast_len + self.forecast_len_ignore + self.history_len + 1), + out_features=32, + ) + + # Update num features + fusion_input_features += 32 + if include_gsp_yield_history: # Update num features fusion_input_features += self.history_len