The primary objective of this repository is to provide an implementation of a stateful LSTM regression model using PyTorch. During model development and training for the 2024 Green Battery Hackathon, I came across the concept of the stateful LSTM. While TensorFlow offers an easy implementation by simply setting `stateful=True`, as a PyTorch user I found limited resources, documentation, and implementations for creating stateful LSTM models. Consequently, after conducting thorough research, I decided to create my own implementation of a stateful LSTM in PyTorch and share the result of my work.
- Clone the project repository to your local machine
- Move to the repository:

  ```shell
  cd lstm-forecast
  ```
- This project is managed using Poetry. If Poetry isn't installed on your machine, run the appropriate command below to install it:

  Linux, macOS, Windows (WSL):

  ```shell
  curl -sSL https://install.python-poetry.org | python3 -
  ```

  Windows (PowerShell):

  ```shell
  (Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | py -
  ```

  Check out the instructions in Poetry's documentation for further information.
- Install the dependencies:

  ```shell
  poetry install
  ```
Train

Train the model by executing `train_runner.py`:

```shell
python train_runner.py
```
Test

All unit tests are located in `test/`. To run them, execute:

```shell
python -m pytest
```
Stateful LSTM is a variant of the Long Short-Term Memory (LSTM) model. In a standard LSTM model, the hidden state and cell state are reset after each sequence or batch of data is processed. In a stateful LSTM, however, the hidden state and cell state are preserved between batches or sequences: the final hidden and cell states from one batch or sequence become the initial hidden and cell states for the next.

This preservation of state allows the model to maintain memory beyond batch or sequence boundaries. By retaining information from previous batches or sequences, the stateful LSTM can capture longer-term dependencies in the data.

Stateful LSTMs are particularly useful when dealing with data that has a clear temporal structure, such as time series data or sequences of text. They can improve the model's ability to learn patterns and relationships in the data over time, leading to more accurate predictions or classifications.
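As a minimal sketch of how this idea can be expressed in PyTorch (this is an illustration, not necessarily the exact module used in this repo; the class and parameter names are assumptions), the hidden and cell states can be stored on the module and reused on the next forward call:

```python
import torch
import torch.nn as nn


class StatefulLSTM(nn.Module):
    """Illustrative stateful LSTM regressor: the hidden and cell states
    survive between forward calls instead of being reset each batch."""

    def __init__(self, input_size, hidden_size, output_size=1, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.state = None  # (h, c), carried across batches

    def reset_state(self):
        # Call at the start of each epoch so states begin at zero
        # (passing None makes nn.LSTM zero-initialise the states).
        self.state = None

    def forward(self, x):
        out, state = self.lstm(x, self.state)
        # Detach so gradients do not flow across batch boundaries,
        # while the state values still carry over to the next batch.
        self.state = tuple(s.detach() for s in state)
        return self.fc(out[:, -1, :])
```

The `detach()` call is the key design choice: without it, backpropagation would try to traverse the graph of every previous batch, growing memory without bound.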
The TensorFlow documentation defines stateful as:
Boolean (default: False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch.
This can be illustrated in the figure below:
The hidden and cell states are preserved across batches within each epoch; at the start of each epoch, they are re-initialised to 0.
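A hedged sketch of what such a training loop can look like in PyTorch (all names, shapes, and the random data here are illustrative, not the repo's actual code): the state is set to `None` at each epoch start so it is zero-initialised, carried between batches, and detached so gradients stop at batch boundaries.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative model pieces and synthetic batches.
lstm = nn.LSTM(input_size=2, hidden_size=4, batch_first=True)
head = nn.Linear(4, 1)
opt = torch.optim.Adam(list(lstm.parameters()) + list(head.parameters()))
loss_fn = nn.MSELoss()
batches = [(torch.randn(4, 4, 2), torch.randn(4, 1)) for _ in range(3)]

for epoch in range(2):
    state = None  # zero-initialised hidden/cell states at each epoch start
    for x, y in batches:
        out, state = lstm(x, state)
        # Detach: state values flow into the next batch,
        # but backprop stops at the batch boundary.
        state = tuple(s.detach() for s in state)
        loss = loss_fn(head(out[:, -1, :]), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
```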
To align the data in the correct order, there are some key points to keep in mind:
- Preserve the data order: `shuffle = False`
- The batch size should equal the time-series window size: `batch_size = time_series_window_size`
- DataLoader `drop_last = True`: this ensures consistent batch sizes throughout the training loop.
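With PyTorch's `DataLoader`, the three points above might look like the following (the window size and the synthetic tensors are illustrative assumptions):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

WINDOW = 8  # illustrative time-series window size

# Synthetic stand-in for the real feature/target tensors.
features = torch.randn(100, 3)
targets = torch.randn(100, 1)
dataset = TensorDataset(features, targets)

loader = DataLoader(
    dataset,
    batch_size=WINDOW,  # batch size equals the window size
    shuffle=False,      # preserve temporal order
    drop_last=True,     # drop the short final batch so sizes stay consistent
)
```

With 100 rows and a window of 8, the final 4 rows are dropped, so every batch the model sees has exactly the same shape and the carried-over states always line up.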
The datasets were collected and provided by the 2024 Green Battery Hackathon.
- `timestamp`: The time this row was recorded.
- `price`: The electricity spot price for South Australia in $/MWh. Provided by opennem.
- `demand`: The total electricity demand for South Australia for large stable consumers (e.g. factories, mines, etc.) in MW. Provided by the legends at opennem.
- `temp_air`: The air temperature in degrees Celsius in an indicative location in South Australia.
- `pv_power`: The rate of power generated by the simulated solar panel in kW. Provided by solcast.
- `pv_power_forecast_1h`: A forecast of what the `pv_power` will be in 1 hour from the `timestamp`. This data is missing for the first 2 years of training data. Provided by solcast.
- `pv_power_forecast_2h`: A forecast of what the `pv_power` will be in 2 hours from the `timestamp`. This data is missing for the first 2 years of training data. Provided by solcast.
- `pv_power_forecast_24h`: A forecast of what the `pv_power` will be in 24 hours from the `timestamp`. This data is missing for the first 2 years of training data. Provided by solcast.
- `pv_power_basic`: An estimate of the rate of solar power currently being generated across the whole state of South Australia. This data is missing for the first 2 years of training data. Provided by solcast.
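Since the forecast columns are missing for the first two years, they need some handling before training. A small pandas sketch of the idea, using a tiny inline frame (the column names are taken from the descriptions above; the values and the fill strategy are illustrative assumptions, not the repo's actual preprocessing):

```python
import pandas as pd

# Tiny illustrative frame with a subset of the columns described above.
df = pd.DataFrame({
    "timestamp": pd.date_range("2022-01-01", periods=4, freq="5min"),
    "price": [45.2, 50.1, 48.9, 52.3],
    "demand": [1200.0, 1180.0, 1210.0, 1195.0],
    "pv_power": [0.0, 0.4, 0.9, 1.2],
    # Forecast columns contain NaNs for early training data.
    "pv_power_forecast_1h": [None, None, 1.1, 1.3],
})

# Stateful training depends on rows being in time order.
df = df.sort_values("timestamp")

# One simple option: fill missing forecasts with 0.0 (an assumption;
# interpolation or dropping those rows are equally valid choices).
df["pv_power_forecast_1h"] = df["pv_power_forecast_1h"].fillna(0.0)
```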