Added Mamba models implementation for forecasting tasks #378

Merged · 5 commits · Apr 4, 2024
3 changes: 3 additions & 0 deletions .gitignore
@@ -157,3 +157,6 @@ data_loader_all.py
/scripts/imputation/tmp/
/utils/self_tools.py
/scripts/exp_scripts/

/checkpoints/
/results/
4 changes: 3 additions & 1 deletion exp/exp_basic.py
@@ -2,7 +2,7 @@
import torch
from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \
Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \
Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN
Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, MambaSimple, Mamba


class Exp_Basic(object):
@@ -28,6 +28,8 @@ def __init__(self, args):
'Koopa': Koopa,
'TiDE': TiDE,
'FreTS': FreTS,
'MambaSimple': MambaSimple,
'Mamba': Mamba,
'TimeMixer': TimeMixer,
'TSMixer': TSMixer,
'SegRNN': SegRNN
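For context, the two entries added above are keyed by the string passed via --model. A minimal sketch of the assumed dispatch (not the repo's exact code):

# Assumed dispatch: the --model value selects a module from the registry,
# and that module's Model class is built from the parsed arguments.
from models import Mamba, MambaSimple

model_dict = {'Mamba': Mamba, 'MambaSimple': MambaSimple}

def build_model(args):
    return model_dict[args.model].Model(args)

So --model Mamba selects the mamba_ssm-backed wrapper in models/Mamba.py, while --model MambaSimple selects the pure-PyTorch reimplementation further down.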
50 changes: 50 additions & 0 deletions models/Mamba.py
@@ -0,0 +1,50 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from mamba_ssm import Mamba

from layers.Embed import DataEmbedding


class Model(nn.Module):

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len

        self.d_inner = configs.d_model * configs.expand
        self.dt_rank = math.ceil(configs.d_model / 16)  # TODO implement "auto"

        self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)

        self.mamba = Mamba(
            d_model = configs.d_model,
            d_state = configs.d_ff,
            d_conv = configs.d_conv,
            expand = configs.expand,
        )

        self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)

    def forecast(self, x_enc, x_mark_enc):
        mean_enc = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc = x_enc / std_enc

        x = self.embedding(x_enc, x_mark_enc)
        x = self.mamba(x)
        x_out = self.out_layer(x)

        x_out = x_out * std_enc + mean_enc
        return x_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name in ['short_term_forecast', 'long_term_forecast']:
            x_out = self.forecast(x_enc, x_mark_enc)
            return x_out[:, -self.pred_len:, :]

        # other tasks not implemented
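A minimal smoke test for the wrapper above (a sketch, assuming mamba_ssm is installed, a CUDA device is available, and DataEmbedding takes (c_in, d_model, embed, freq, dropout) as in this repo; values mirror the scripts further down):

from types import SimpleNamespace
import torch
from models.Mamba import Model

configs = SimpleNamespace(
    task_name='long_term_forecast', pred_len=96,
    enc_in=7, c_out=7, d_model=128, d_ff=16,
    expand=2, d_conv=4, embed='timeF', freq='h', dropout=0.1,
)
model = Model(configs).cuda()                      # mamba_ssm's fused kernels expect CUDA
x_enc = torch.randn(4, 96, configs.enc_in).cuda()  # [batch, seq_len, enc_in]
x_mark_enc = torch.randn(4, 96, 4).cuda()          # 4 time features for freq='h' with embed='timeF'
out = model(x_enc, x_mark_enc, None, None)         # x_dec / x_mark_dec are unused by this model
print(out.shape)                                   # expected: torch.Size([4, 96, 7])

Since forward returns the last pred_len steps of a length-seq_len output, pred_len cannot exceed seq_len; the scripts below set seq_len equal to pred_len for exactly that reason.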
175 changes: 175 additions & 0 deletions models/MambaSimple.py
@@ -0,0 +1,175 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, einsum

from layers.Embed import DataEmbedding


class Model(nn.Module):
"""
Mamba, linear-time sequence modeling with selective state spaces O(L)
Paper link: https://arxiv.org/abs/2312.00752
Implementation refernce: https://github.com/johnma2006/mamba-minimal/
"""

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len

        self.d_inner = configs.d_model * configs.expand
        self.dt_rank = math.ceil(configs.d_model / 16)

        self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)

        self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)])
        self.norm = RMSNorm(configs.d_model)

        self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)

    # def short_term_forecast(self, x_enc, x_mark_enc):
    def forecast(self, x_enc, x_mark_enc):
        mean_enc = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc = x_enc / std_enc

        x = self.embedding(x_enc, x_mark_enc)
        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        x_out = self.out_layer(x)

        x_out = x_out * std_enc + mean_enc
        return x_out

    # def long_term_forecast(self, x_enc, x_mark_enc):
    #     x = self.embedding(x_enc, x_mark_enc)
    #     for layer in self.layers:
    #         x = layer(x)

    #     x = self.norm(x)
    #     x_out = self.out_layer(x)
    #     return x_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name in ['short_term_forecast', 'long_term_forecast']:
            x_out = self.forecast(x_enc, x_mark_enc)
            return x_out[:, -self.pred_len:, :]

        # other tasks not implemented


class ResidualBlock(nn.Module):
    def __init__(self, configs, d_inner, dt_rank):
        super(ResidualBlock, self).__init__()

        self.mixer = MambaBlock(configs, d_inner, dt_rank)
        self.norm = RMSNorm(configs.d_model)

    def forward(self, x):
        output = self.mixer(self.norm(x)) + x
        return output


class MambaBlock(nn.Module):
    def __init__(self, configs, d_inner, dt_rank):
        super(MambaBlock, self).__init__()
        self.d_inner = d_inner
        self.dt_rank = dt_rank

        self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False)

        self.conv1d = nn.Conv1d(
            in_channels = self.d_inner,
            out_channels = self.d_inner,
            bias = True,
            kernel_size = configs.d_conv,
            padding = configs.d_conv - 1,
            groups = self.d_inner
        )

        # takes in x and outputs the input-specific delta, B, C
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False)

        # projects delta
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False)

    def forward(self, x):
        """
        Figure 3 in Section 3.4 in the paper
        """
        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # [B, L, 2 * d_inner]
        (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)

        x = rearrange(x, "b l d -> b d l")
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, "b d l -> b l d")

        x = F.silu(x)

        y = self.ssm(x)
        y = y * F.silu(res)

        output = self.out_proj(y)
        return output


    def ssm(self, x):
        """
        Algorithm 2 in Section 3.2 in the paper
        """

        (d_in, n) = self.A_log.shape

        A = -torch.exp(self.A_log.float())  # [d_in, n]
        D = self.D.float()  # [d_in]

        x_dbl = self.x_proj(x)  # [B, L, dt_rank + 2 * d_ff]
        (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)  # delta: [B, L, dt_rank]; B, C: [B, L, n]
        delta = F.softplus(self.dt_proj(delta))  # [B, L, d_in]
        y = self.selective_scan(x, delta, A, B, C, D)

        return y

    def selective_scan(self, u, delta, A, B, C, D):
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # A is discretized using zero-order hold (ZOH) discretization
        deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n"))
        # B is discretized using a simplified Euler discretization instead of ZOH.
        # From a discussion with the authors: "A is the more important term and the
        # performance doesn't change much with the simplification on B"
        deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n")

        # selective scan, sequential instead of parallel
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], "b d n, b n -> b d")
            ys.append(y)

        y = torch.stack(ys, dim=1)  # [B, L, d_in]
        y = y + u * D

        return y


class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        return output
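As a reading aid for selective_scan above: the loop implements the discretized selective-SSM recurrence from the paper, with A discretized by zero-order hold and B by the simplified Euler step noted in the comments (a restatement of what the code computes, elementwise over channels d and states n):

$$\bar{A}_t = \exp(\Delta_t \otimes A), \qquad \overline{Bu}_t = \Delta_t \otimes B_t \otimes u_t$$

$$h_t = \bar{A}_t \odot h_{t-1} + \overline{Bu}_t, \qquad y_t = C_t \cdot h_t + D \odot u_t$$

The Python loop evaluates this recurrence sequentially over the L positions, which is why MambaSimple is expected to be slower than models/Mamba.py, where the same scan is delegated to mamba_ssm's fused CUDA kernel.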
15 changes: 12 additions & 3 deletions run.py
@@ -51,6 +51,8 @@
parser.add_argument('--anomaly_ratio', type=float, default=0.25, help='prior anomaly ratio (%)')

# model define
parser.add_argument('--expand', type=int, default=2, help='expansion factor for Mamba')
parser.add_argument('--d_conv', type=int, default=4, help='conv kernel size for Mamba')
parser.add_argument('--top_k', type=int, default=5, help='for TimesBlock')
parser.add_argument('--num_kernels', type=int, default=6, help='for Inception')
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
@@ -107,7 +109,10 @@
parser.add_argument('--p_hidden_layers', type=int, default=2, help='number of hidden layers in projector')

args = parser.parse_args()
args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
# args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
args.use_gpu = True if torch.cuda.is_available() else False

print(torch.cuda.is_available())

if args.use_gpu and args.use_multi_gpu:
args.devices = args.devices.replace(' ', '')
@@ -135,7 +140,7 @@
for ii in range(args.itr):
# setting record of experiments
exp = Exp(args) # set experiments
setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format(
setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_expand{}_dc{}_fc{}_eb{}_dt{}_{}_{}'.format(
args.task_name,
args.model_id,
args.model,
@@ -149,6 +154,8 @@
args.e_layers,
args.d_layers,
args.d_ff,
args.expand,
args.d_conv,
args.factor,
args.embed,
args.distil,
@@ -162,7 +169,7 @@
torch.cuda.empty_cache()
else:
ii = 0
setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format(
setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_expand{}_dc{}_fc{}_eb{}_dt{}_{}_{}'.format(
args.task_name,
args.model_id,
args.model,
@@ -176,6 +183,8 @@
args.e_layers,
args.d_layers,
args.d_ff,
args.expand,
args.d_conv,
args.factor,
args.embed,
args.distil,
30 changes: 30 additions & 0 deletions scripts/long_term_forecast/ECL_script/Mamba.sh
@@ -0,0 +1,30 @@
model_name=Mamba

for pred_len in 96 192 336 720
# for pred_len in 336 720
do

python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/electricity/ \
--data_path electricity.csv \
--model_id ECL_$pred_len'_'$pred_len \
--model $model_name \
--data custom \
--features M \
--seq_len $pred_len \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 321 \
--expand 2 \
--d_ff 16 \
--d_conv 4 \
--c_out 321 \
--d_model 128 \
--des 'Exp' \
--itr 1 \

done
29 changes: 29 additions & 0 deletions scripts/long_term_forecast/ETT_script/MambaSimple_ETTh1.sh
@@ -0,0 +1,29 @@
model_name=MambaSimple

for pred_len in 96 192 336 720
do

python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_$pred_len'_'$pred_len \
--model $model_name \
--data ETTh1 \
--features M \
--seq_len $pred_len \
--label_len 48 \
--pred_len $pred_len \
--e_layers 2 \
--d_layers 1 \
--enc_in 7 \
--expand 2 \
--d_ff 16 \
--d_conv 4 \
--c_out 7 \
--d_model 128 \
--des 'Exp' \
--itr 1 \

done
4 changes: 4 additions & 0 deletions scripts/long_term_forecast/ETT_script/Mamba_ETT_all.sh
@@ -0,0 +1,4 @@
./scripts/long_term_forecast/ETT_script/Mamba_ETTh1.sh | tee mamba_ett.txt
./scripts/long_term_forecast/ETT_script/Mamba_ETTh2.sh | tee mamba_ett.txt -a
./scripts/long_term_forecast/ETT_script/Mamba_ETTm1.sh | tee mamba_ett.txt -a
./scripts/long_term_forecast/ETT_script/Mamba_ETTm2.sh | tee mamba_ett.txt -a