From 22175f643f123320368978ad56bed7baa32feb01 Mon Sep 17 00:00:00 2001 From: frecklebars Date: Thu, 22 Feb 2024 11:20:07 +0000 Subject: [PATCH 1/4] added mamba model https://arxiv.org/abs/2312.00752 --- .gitignore | 2 + exp/exp_basic.py | 5 +- models/Mamba.py | 165 ++++++++++++++++++++++++ run.py | 15 ++- scripts/short_term_forecast/Mamba_M4.sh | 135 +++++++++++++++++++ 5 files changed, 317 insertions(+), 5 deletions(-) create mode 100644 models/Mamba.py create mode 100644 scripts/short_term_forecast/Mamba_M4.sh diff --git a/.gitignore b/.gitignore index 7aefd369..e8590d70 100644 --- a/.gitignore +++ b/.gitignore @@ -157,3 +157,5 @@ data_loader_all.py /scripts/imputation/tmp/ /utils/self_tools.py /scripts/exp_scripts/ + +/checkpoints/ \ No newline at end of file diff --git a/exp/exp_basic.py b/exp/exp_basic.py index bf077f06..b5da9e76 100644 --- a/exp/exp_basic.py +++ b/exp/exp_basic.py @@ -2,7 +2,7 @@ import torch from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \ Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \ - Koopa, TiDE, FreTS + Koopa, TiDE, FreTS, Mamba class Exp_Basic(object): @@ -27,7 +27,8 @@ def __init__(self, args): 'iTransformer': iTransformer, 'Koopa': Koopa, 'TiDE': TiDE, - 'FreTS': FreTS + 'FreTS': FreTS, + 'Mamba': Mamba, } self.device = self._acquire_device() self.model = self._build_model().to(self.device) diff --git a/models/Mamba.py b/models/Mamba.py new file mode 100644 index 00000000..5ab12dd6 --- /dev/null +++ b/models/Mamba.py @@ -0,0 +1,165 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat, einsum + +from layers.Embed import DataEmbedding + + +class Model(nn.Module): + """ + Mamba, linear-time sequence modeling with selective state spaces O(L) + Paper link: https://arxiv.org/abs/2312.00752 + Implementation refernce: https://github.com/johnma2006/mamba-minimal/ + """ + + def __init__(self, configs): + super(Model, self).__init__() + self.task_name = configs.task_name + self.pred_len = configs.pred_len + + self.d_inner = configs.d_model * configs.expand + self.dt_rank = math.ceil(configs.d_model / 16) # TODO implement "auto" + + self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout) + + self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)]) + self.norm = RMSNorm(configs.d_model) + + self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False) + + def short_term_forecast(self, x_enc, x_mark_enc): + mean_enc = x_enc.mean(1, keepdim=True).detach() + x_enc = x_enc - mean_enc + std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() + x_enc = x_enc / std_enc + + x = self.embedding(x_enc, x_mark_enc) + for layer in self.layers: + x = layer(x) + + x = self.norm(x) + x_out = self.out_layer(x) + + x_out = x_out * std_enc + mean_enc + return x_out + + def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): + if self.task_name == 'short_term_forecast': + x_out = self.short_term_forecast(x_enc, x_mark_enc) + # print(f"MAMBA FORECAST SIZE: {x_out.shape}") + return x_out[:, -self.pred_len:, :] + + # other tasks not implemented + + +class ResidualBlock(nn.Module): + def __init__(self, configs, d_inner, dt_rank): + super(ResidualBlock, self).__init__() + + self.mixer = MambaBlock(configs, d_inner, dt_rank) + self.norm = 
RMSNorm(configs.d_model) + + def forward(self, x): + output = self.mixer(self.norm(x)) + x + return output + +class MambaBlock(nn.Module): + def __init__(self, configs, d_inner, dt_rank): + super(MambaBlock, self).__init__() + self.d_inner = d_inner + self.dt_rank = dt_rank + + self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False) + + self.conv1d = nn.Conv1d( + in_channels = self.d_inner, + out_channels = self.d_inner, + bias = True, + kernel_size = configs.d_conv, + padding = configs.d_conv - 1, # TODO dont understand this; come back and do kernel = 3 padding = 1 instead if it doesnt work? + groups = self.d_inner + ) + + # takes in x and outputs the input-specific delta, B, C + self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False) + + # projects delta + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True) + + A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner) + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(self.d_inner)) + + self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False) + + def forward(self, x): + """ + Figure 3 in Section 3.4 in the paper + """ + (b, l, d) = x.shape + + x_and_res = self.in_proj(x) # [B, L, 2 * d_inner] + (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1) + + x = rearrange(x, "b l d -> b d l") + x = self.conv1d(x)[:, :, :l] + x = rearrange(x, "b d l -> b l d") + + x = F.silu(x) + + y = self.ssm(x) + y = y * F.silu(res) + + output = self.out_proj(y) + return output + + + def ssm(self, x): + """ + Algorithm 2 in Section 3.2 in the paper + """ + + (d_in, n) = self.A_log.shape + + A = -torch.exp(self.A_log.float()) # [d_in, n] + D = self.D.float() # [d_in] + + x_dbl = self.x_proj(x) # [B, L, d_rank + 2 * d_ff] + (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1) # delta: [B, L, d_rank]; B, C: [B, L, n] + delta = F.softplus(self.dt_proj(delta)) # [B, L, d_in] + y = self.selective_scan(x, delta, A, B, C, D) + + return y + + def selective_scan(self, u, delta, A, B, C, D): + (b, l, d_in) = u.shape + n = A.shape[1] + + deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n")) # A is discretized using zero-order hold (ZOH) discretization + deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n") # B is discretized using a simplified Euler discretization instead of ZOH. 
From a discussion with authors: "A is the more important term and the performance doesn't change much with the simplification on B" + + # selective scan, sequential instead of parallel + x = torch.zeros((b, d_in, n), device=deltaA.device) + ys = [] + for i in range(l): + x = deltaA[:, i] * x + deltaB_u[:, i] + y = einsum(x, C[:, i, :], "b d n, b n -> b d") + ys.append(y) + + y = torch.stack(ys, dim=1) # [B, L, d_in] + y = y + u * D + + return y + +class RMSNorm(nn.Module): + def __init__(self, d_model, eps=1e-5): + super(RMSNorm, self).__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(d_model)) + + def forward(self, x): + output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight + return output \ No newline at end of file diff --git a/run.py b/run.py index 7952b1c1..71edf169 100644 --- a/run.py +++ b/run.py @@ -51,6 +51,8 @@ parser.add_argument('--anomaly_ratio', type=float, default=0.25, help='prior anomaly ratio (%)') # model define + parser.add_argument('--expand', type=int, default=2, help='expansion factor for Mamba') + parser.add_argument('--d_conv', type=int, default=4, help='conv kernel size for Mamba') parser.add_argument('--top_k', type=int, default=5, help='for TimesBlock') parser.add_argument('--num_kernels', type=int, default=6, help='for Inception') parser.add_argument('--enc_in', type=int, default=7, help='encoder input size') @@ -98,7 +100,10 @@ args = parser.parse_args() - args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False + # args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False + args.use_gpu = True if torch.cuda.is_available() else False + + print(torch.cuda.is_available()) if args.use_gpu and args.use_multi_gpu: args.devices = args.devices.replace(' ', '') @@ -126,7 +131,7 @@ for ii in range(args.itr): # setting record of experiments exp = Exp(args) # set experiments - setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format( + setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_expand{}_dc{}_fc{}_eb{}_dt{}_{}_{}'.format( args.task_name, args.model_id, args.model, @@ -140,6 +145,8 @@ args.e_layers, args.d_layers, args.d_ff, + args.expand, + args.d_conv, args.factor, args.embed, args.distil, @@ -153,7 +160,7 @@ torch.cuda.empty_cache() else: ii = 0 - setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format( + setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_expand{}_dc{}_fc{}_eb{}_dt{}_{}_{}'.format( args.task_name, args.model_id, args.model, @@ -167,6 +174,8 @@ args.e_layers, args.d_layers, args.d_ff, + args.expand, + args.d_conv, args.factor, args.embed, args.distil, diff --git a/scripts/short_term_forecast/Mamba_M4.sh b/scripts/short_term_forecast/Mamba_M4.sh new file mode 100644 index 00000000..417a6c34 --- /dev/null +++ b/scripts/short_term_forecast/Mamba_M4.sh @@ -0,0 +1,135 @@ +# export CUDA_VISIBLE_DEVICES=1 + +model_name=Mamba + +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Monthly' \ + --model_id m4_Monthly \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' + +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + 
--seasonal_patterns 'Yearly' \ + --model_id m4_Yearly \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' + +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Quarterly' \ + --model_id m4_Quarterly \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' + +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Weekly' \ + --model_id m4_Weekly \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' + +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Daily' \ + --model_id m4_Daily \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' + +python -u run.py \ + --task_name short_term_forecast \ + --is_training 1 \ + --root_path ./dataset/m4 \ + --seasonal_patterns 'Hourly' \ + --model_id m4_Hourly \ + --model $model_name \ + --data m4 \ + --features M \ + --e_layers 2 \ + --enc_in 1 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 1 \ + --batch_size 16 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + --learning_rate 0.001 \ + --loss 'SMAPE' \ No newline at end of file From 69ffc6c280ed80771e354edefc0ac082c08269da Mon Sep 17 00:00:00 2001 From: frecklebars Date: Sat, 30 Mar 2024 17:16:13 +0000 Subject: [PATCH 2/4] Hardware Aware Mamba + Long Term Forecasting --- .gitignore | 3 +- exp/exp_basic.py | 3 +- models/Mamba.py | 145 ++------------- models/MambaSimple.py | 175 ++++++++++++++++++ .../long_term_forecast/ECL_script/Mamba.sh | 30 +++ .../ETT_script/MambaSimple_ETTh1.sh | 29 +++ .../ETT_script/Mamba_ETT_all.sh | 4 + .../ETT_script/Mamba_ETTh1.sh | 28 +++ .../ETT_script/Mamba_ETTh2.sh | 28 +++ .../ETT_script/Mamba_ETTm1.sh | 28 +++ .../ETT_script/Mamba_ETTm2.sh | 28 +++ .../Exchange_script/Mamba.sh | 28 +++ scripts/long_term_forecast/Mamba_all.sh | 4 + .../Traffic_script/Mamba.sh | 29 +++ .../Weather_script/Mamba.sh | 29 +++ 15 files changed, 459 insertions(+), 132 deletions(-) create mode 100644 models/MambaSimple.py create mode 100644 scripts/long_term_forecast/ECL_script/Mamba.sh create mode 100644 scripts/long_term_forecast/ETT_script/MambaSimple_ETTh1.sh create mode 100644 scripts/long_term_forecast/ETT_script/Mamba_ETT_all.sh create mode 100644 scripts/long_term_forecast/ETT_script/Mamba_ETTh1.sh create mode 100644 scripts/long_term_forecast/ETT_script/Mamba_ETTh2.sh create mode 100644 scripts/long_term_forecast/ETT_script/Mamba_ETTm1.sh create mode 100644 scripts/long_term_forecast/ETT_script/Mamba_ETTm2.sh create mode 100644 scripts/long_term_forecast/Exchange_script/Mamba.sh create mode 100644 scripts/long_term_forecast/Mamba_all.sh create mode 100644 
scripts/long_term_forecast/Traffic_script/Mamba.sh create mode 100644 scripts/long_term_forecast/Weather_script/Mamba.sh diff --git a/.gitignore b/.gitignore index e8590d70..e51275f9 100644 --- a/.gitignore +++ b/.gitignore @@ -158,4 +158,5 @@ data_loader_all.py /utils/self_tools.py /scripts/exp_scripts/ -/checkpoints/ \ No newline at end of file +/checkpoints/ +/results/ \ No newline at end of file diff --git a/exp/exp_basic.py b/exp/exp_basic.py index b5da9e76..1bb8059e 100644 --- a/exp/exp_basic.py +++ b/exp/exp_basic.py @@ -2,7 +2,7 @@ import torch from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \ Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \ - Koopa, TiDE, FreTS, Mamba + Koopa, TiDE, FreTS, MambaSimple, Mamba class Exp_Basic(object): @@ -28,6 +28,7 @@ def __init__(self, args): 'Koopa': Koopa, 'TiDE': TiDE, 'FreTS': FreTS, + 'MambaSimple': MambaSimple, 'Mamba': Mamba, } self.device = self._acquire_device() diff --git a/models/Mamba.py b/models/Mamba.py index 5ab12dd6..edece42a 100644 --- a/models/Mamba.py +++ b/models/Mamba.py @@ -3,18 +3,13 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange, repeat, einsum -from layers.Embed import DataEmbedding +from mamba_ssm import Mamba +from layers.Embed import DataEmbedding class Model(nn.Module): - """ - Mamba, linear-time sequence modeling with selective state spaces O(L) - Paper link: https://arxiv.org/abs/2312.00752 - Implementation refernce: https://github.com/johnma2006/mamba-minimal/ - """ - + def __init__(self, configs): super(Model, self).__init__() self.task_name = configs.task_name @@ -22,144 +17,34 @@ def __init__(self, configs): self.d_inner = configs.d_model * configs.expand self.dt_rank = math.ceil(configs.d_model / 16) # TODO implement "auto" - + self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout) - self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)]) - self.norm = RMSNorm(configs.d_model) + self.mamba = Mamba( + d_model = configs.d_model, + d_state = configs.d_ff, + d_conv = configs.d_conv, + expand = configs.expand, + ) self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False) - def short_term_forecast(self, x_enc, x_mark_enc): + def forecast(self, x_enc, x_mark_enc): mean_enc = x_enc.mean(1, keepdim=True).detach() x_enc = x_enc - mean_enc std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() x_enc = x_enc / std_enc x = self.embedding(x_enc, x_mark_enc) - for layer in self.layers: - x = layer(x) - - x = self.norm(x) + x = self.mamba(x) x_out = self.out_layer(x) x_out = x_out * std_enc + mean_enc return x_out def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): - if self.task_name == 'short_term_forecast': - x_out = self.short_term_forecast(x_enc, x_mark_enc) - # print(f"MAMBA FORECAST SIZE: {x_out.shape}") + if self.task_name in ['short_term_forecast', 'long_term_forecast']: + x_out = self.forecast(x_enc, x_mark_enc) return x_out[:, -self.pred_len:, :] - - # other tasks not implemented - - -class ResidualBlock(nn.Module): - def __init__(self, configs, d_inner, dt_rank): - super(ResidualBlock, self).__init__() - - self.mixer = MambaBlock(configs, d_inner, dt_rank) - self.norm = RMSNorm(configs.d_model) - - def forward(self, x): - output = self.mixer(self.norm(x)) + x - return output - 
-class MambaBlock(nn.Module): - def __init__(self, configs, d_inner, dt_rank): - super(MambaBlock, self).__init__() - self.d_inner = d_inner - self.dt_rank = dt_rank - - self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False) - - self.conv1d = nn.Conv1d( - in_channels = self.d_inner, - out_channels = self.d_inner, - bias = True, - kernel_size = configs.d_conv, - padding = configs.d_conv - 1, # TODO dont understand this; come back and do kernel = 3 padding = 1 instead if it doesnt work? - groups = self.d_inner - ) - - # takes in x and outputs the input-specific delta, B, C - self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False) - - # projects delta - self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True) - - A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner) - self.A_log = nn.Parameter(torch.log(A)) - self.D = nn.Parameter(torch.ones(self.d_inner)) - - self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False) - - def forward(self, x): - """ - Figure 3 in Section 3.4 in the paper - """ - (b, l, d) = x.shape - - x_and_res = self.in_proj(x) # [B, L, 2 * d_inner] - (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1) - - x = rearrange(x, "b l d -> b d l") - x = self.conv1d(x)[:, :, :l] - x = rearrange(x, "b d l -> b l d") - - x = F.silu(x) - - y = self.ssm(x) - y = y * F.silu(res) - - output = self.out_proj(y) - return output - - - def ssm(self, x): - """ - Algorithm 2 in Section 3.2 in the paper - """ - - (d_in, n) = self.A_log.shape - - A = -torch.exp(self.A_log.float()) # [d_in, n] - D = self.D.float() # [d_in] - - x_dbl = self.x_proj(x) # [B, L, d_rank + 2 * d_ff] - (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1) # delta: [B, L, d_rank]; B, C: [B, L, n] - delta = F.softplus(self.dt_proj(delta)) # [B, L, d_in] - y = self.selective_scan(x, delta, A, B, C, D) - - return y - - def selective_scan(self, u, delta, A, B, C, D): - (b, l, d_in) = u.shape - n = A.shape[1] - - deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n")) # A is discretized using zero-order hold (ZOH) discretization - deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n") # B is discretized using a simplified Euler discretization instead of ZOH. 
From a discussion with authors: "A is the more important term and the performance doesn't change much with the simplification on B" - - # selective scan, sequential instead of parallel - x = torch.zeros((b, d_in, n), device=deltaA.device) - ys = [] - for i in range(l): - x = deltaA[:, i] * x + deltaB_u[:, i] - y = einsum(x, C[:, i, :], "b d n, b n -> b d") - ys.append(y) - - y = torch.stack(ys, dim=1) # [B, L, d_in] - y = y + u * D - - return y - -class RMSNorm(nn.Module): - def __init__(self, d_model, eps=1e-5): - super(RMSNorm, self).__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(d_model)) - def forward(self, x): - output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight - return output \ No newline at end of file + # other tasks not implemented \ No newline at end of file diff --git a/models/MambaSimple.py b/models/MambaSimple.py new file mode 100644 index 00000000..948cb5fb --- /dev/null +++ b/models/MambaSimple.py @@ -0,0 +1,175 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat, einsum + +from layers.Embed import DataEmbedding + + +class Model(nn.Module): + """ + Mamba, linear-time sequence modeling with selective state spaces O(L) + Paper link: https://arxiv.org/abs/2312.00752 + Implementation refernce: https://github.com/johnma2006/mamba-minimal/ + """ + + def __init__(self, configs): + super(Model, self).__init__() + self.task_name = configs.task_name + self.pred_len = configs.pred_len + + self.d_inner = configs.d_model * configs.expand + self.dt_rank = math.ceil(configs.d_model / 16) # TODO implement "auto" + + self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout) + + self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)]) + self.norm = RMSNorm(configs.d_model) + + self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False) + + # def short_term_forecast(self, x_enc, x_mark_enc): + def forecast(self, x_enc, x_mark_enc): + mean_enc = x_enc.mean(1, keepdim=True).detach() + x_enc = x_enc - mean_enc + std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() + x_enc = x_enc / std_enc + + x = self.embedding(x_enc, x_mark_enc) + for layer in self.layers: + x = layer(x) + + x = self.norm(x) + x_out = self.out_layer(x) + + x_out = x_out * std_enc + mean_enc + return x_out + + # def long_term_forecast(self, x_enc, x_mark_enc): + # x = self.embedding(x_enc, x_mark_enc) + # for layer in self.layers: + # x = layer(x) + + # x = self.norm(x) + # x_out = self.out_layer(x) + # return x_out + + def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): + if self.task_name in ['short_term_forecast', 'long_term_forecast']: + x_out = self.forecast(x_enc, x_mark_enc) + return x_out[:, -self.pred_len:, :] + + + # other tasks not implemented + + +class ResidualBlock(nn.Module): + def __init__(self, configs, d_inner, dt_rank): + super(ResidualBlock, self).__init__() + + self.mixer = MambaBlock(configs, d_inner, dt_rank) + self.norm = RMSNorm(configs.d_model) + + def forward(self, x): + output = self.mixer(self.norm(x)) + x + return output + +class MambaBlock(nn.Module): + def __init__(self, configs, d_inner, dt_rank): + super(MambaBlock, self).__init__() + self.d_inner = d_inner + self.dt_rank = dt_rank + + self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False) + + self.conv1d = nn.Conv1d( + 
in_channels = self.d_inner, + out_channels = self.d_inner, + bias = True, + kernel_size = configs.d_conv, + padding = configs.d_conv - 1, # TODO dont understand this; come back and do kernel = 3 padding = 1 instead if it doesnt work? + groups = self.d_inner + ) + + # takes in x and outputs the input-specific delta, B, C + self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False) + + # projects delta + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True) + + A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner) + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(self.d_inner)) + + self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False) + + def forward(self, x): + """ + Figure 3 in Section 3.4 in the paper + """ + (b, l, d) = x.shape + + x_and_res = self.in_proj(x) # [B, L, 2 * d_inner] + (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1) + + x = rearrange(x, "b l d -> b d l") + x = self.conv1d(x)[:, :, :l] + x = rearrange(x, "b d l -> b l d") + + x = F.silu(x) + + y = self.ssm(x) + y = y * F.silu(res) + + output = self.out_proj(y) + return output + + + def ssm(self, x): + """ + Algorithm 2 in Section 3.2 in the paper + """ + + (d_in, n) = self.A_log.shape + + A = -torch.exp(self.A_log.float()) # [d_in, n] + D = self.D.float() # [d_in] + + x_dbl = self.x_proj(x) # [B, L, d_rank + 2 * d_ff] + (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1) # delta: [B, L, d_rank]; B, C: [B, L, n] + delta = F.softplus(self.dt_proj(delta)) # [B, L, d_in] + y = self.selective_scan(x, delta, A, B, C, D) + + return y + + def selective_scan(self, u, delta, A, B, C, D): + (b, l, d_in) = u.shape + n = A.shape[1] + + deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n")) # A is discretized using zero-order hold (ZOH) discretization + deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n") # B is discretized using a simplified Euler discretization instead of ZOH. 
From a discussion with authors: "A is the more important term and the performance doesn't change much with the simplification on B" + + # selective scan, sequential instead of parallel + x = torch.zeros((b, d_in, n), device=deltaA.device) + ys = [] + for i in range(l): + x = deltaA[:, i] * x + deltaB_u[:, i] + y = einsum(x, C[:, i, :], "b d n, b n -> b d") + ys.append(y) + + y = torch.stack(ys, dim=1) # [B, L, d_in] + y = y + u * D + + return y + +class RMSNorm(nn.Module): + def __init__(self, d_model, eps=1e-5): + super(RMSNorm, self).__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(d_model)) + + def forward(self, x): + output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight + return output \ No newline at end of file diff --git a/scripts/long_term_forecast/ECL_script/Mamba.sh b/scripts/long_term_forecast/ECL_script/Mamba.sh new file mode 100644 index 00000000..931a1b74 --- /dev/null +++ b/scripts/long_term_forecast/ECL_script/Mamba.sh @@ -0,0 +1,30 @@ +model_name=Mamba + +for pred_len in 96 192 336 720 +# for pred_len in 336 720 +do + +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/electricity/ \ + --data_path electricity.csv \ + --model_id ECL_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len $pred_len \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 321 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 321 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + +done diff --git a/scripts/long_term_forecast/ETT_script/MambaSimple_ETTh1.sh b/scripts/long_term_forecast/ETT_script/MambaSimple_ETTh1.sh new file mode 100644 index 00000000..5e6606a3 --- /dev/null +++ b/scripts/long_term_forecast/ETT_script/MambaSimple_ETTh1.sh @@ -0,0 +1,29 @@ +model_name=MambaSimple + +for pred_len in 96 192 336 720 +do + +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTh1.csv \ + --model_id ETTh1_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTh1 \ + --features M \ + --seq_len $pred_len \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 7 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 7 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + +done \ No newline at end of file diff --git a/scripts/long_term_forecast/ETT_script/Mamba_ETT_all.sh b/scripts/long_term_forecast/ETT_script/Mamba_ETT_all.sh new file mode 100644 index 00000000..18558d6d --- /dev/null +++ b/scripts/long_term_forecast/ETT_script/Mamba_ETT_all.sh @@ -0,0 +1,4 @@ +./scripts/long_term_forecast/ETT_script/Mamba_ETTh1.sh | tee mamba_ett.txt +./scripts/long_term_forecast/ETT_script/Mamba_ETTh2.sh | tee mamba_ett.txt -a +./scripts/long_term_forecast/ETT_script/Mamba_ETTm1.sh | tee mamba_ett.txt -a +./scripts/long_term_forecast/ETT_script/Mamba_ETTm2.sh | tee mamba_ett.txt -a \ No newline at end of file diff --git a/scripts/long_term_forecast/ETT_script/Mamba_ETTh1.sh b/scripts/long_term_forecast/ETT_script/Mamba_ETTh1.sh new file mode 100644 index 00000000..9f29ac38 --- /dev/null +++ b/scripts/long_term_forecast/ETT_script/Mamba_ETTh1.sh @@ -0,0 +1,28 @@ +model_name=Mamba +for pred_len in 96 192 336 720 +do + +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTh1.csv \ + --model_id ETTh1_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTh1 \ + 
--features M \ + --seq_len $pred_len \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 7 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 7 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + +done \ No newline at end of file diff --git a/scripts/long_term_forecast/ETT_script/Mamba_ETTh2.sh b/scripts/long_term_forecast/ETT_script/Mamba_ETTh2.sh new file mode 100644 index 00000000..4c61ce74 --- /dev/null +++ b/scripts/long_term_forecast/ETT_script/Mamba_ETTh2.sh @@ -0,0 +1,28 @@ +model_name=Mamba + +for pred_len in 96 192 336 720 +do + +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTh2.csv \ + --model_id ETTh2_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTh2 \ + --features M \ + --seq_len $pred_len \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --enc_in 7 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 7 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + +done \ No newline at end of file diff --git a/scripts/long_term_forecast/ETT_script/Mamba_ETTm1.sh b/scripts/long_term_forecast/ETT_script/Mamba_ETTm1.sh new file mode 100644 index 00000000..eefff6fe --- /dev/null +++ b/scripts/long_term_forecast/ETT_script/Mamba_ETTm1.sh @@ -0,0 +1,28 @@ +model_name=Mamba + +for pred_len in 96 192 336 720 +do + +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTm1.csv \ + --model_id ETTm1_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTm1 \ + --features M \ + --seq_len $pred_len \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --enc_in 7 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 7 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + +done \ No newline at end of file diff --git a/scripts/long_term_forecast/ETT_script/Mamba_ETTm2.sh b/scripts/long_term_forecast/ETT_script/Mamba_ETTm2.sh new file mode 100644 index 00000000..2a4458cf --- /dev/null +++ b/scripts/long_term_forecast/ETT_script/Mamba_ETTm2.sh @@ -0,0 +1,28 @@ +model_name=Mamba + +for pred_len in 96 192 336 720 +do + +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/ETT-small/ \ + --data_path ETTm2.csv \ + --model_id ETTm2_$pred_len'_'$pred_len \ + --model $model_name \ + --data ETTm2 \ + --features M \ + --seq_len $pred_len \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --enc_in 7 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 7 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + +done \ No newline at end of file diff --git a/scripts/long_term_forecast/Exchange_script/Mamba.sh b/scripts/long_term_forecast/Exchange_script/Mamba.sh new file mode 100644 index 00000000..5a72e3fd --- /dev/null +++ b/scripts/long_term_forecast/Exchange_script/Mamba.sh @@ -0,0 +1,28 @@ +model_name=Mamba +for pred_len in 96 192 336 720 +do + +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/exchange_rate/ \ + --data_path exchange_rate.csv \ + --model_id Exchange_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len $pred_len \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 8 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 8 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + +done \ No newline at end of file diff --git a/scripts/long_term_forecast/Mamba_all.sh 
b/scripts/long_term_forecast/Mamba_all.sh new file mode 100644 index 00000000..af01e3f8 --- /dev/null +++ b/scripts/long_term_forecast/Mamba_all.sh @@ -0,0 +1,4 @@ +./scripts/long_term_forecast/ECL_script/Mamba.sh | tee mamba_all.txt +./scripts/long_term_forecast/Traffic_script/Mamba.sh | tee mamba_all.txt -a +./scripts/long_term_forecast/Exchange_script/Mamba.sh | tee mamba_all.txt -a +./scripts/long_term_forecast/Weather_script/Mamba.sh | tee mamba_all.txt -a diff --git a/scripts/long_term_forecast/Traffic_script/Mamba.sh b/scripts/long_term_forecast/Traffic_script/Mamba.sh new file mode 100644 index 00000000..f531e19a --- /dev/null +++ b/scripts/long_term_forecast/Traffic_script/Mamba.sh @@ -0,0 +1,29 @@ +model_name=Mamba + +for pred_len in 96 192 336 720 +do + +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/traffic/ \ + --data_path traffic.csv \ + --model_id traffic_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len $pred_len \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 862 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 862 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + +done \ No newline at end of file diff --git a/scripts/long_term_forecast/Weather_script/Mamba.sh b/scripts/long_term_forecast/Weather_script/Mamba.sh new file mode 100644 index 00000000..a9598bbd --- /dev/null +++ b/scripts/long_term_forecast/Weather_script/Mamba.sh @@ -0,0 +1,29 @@ +model_name=Mamba + +for pred_len in 96 192 336 720 +do + +python -u run.py \ + --task_name long_term_forecast \ + --is_training 1 \ + --root_path ./dataset/weather/ \ + --data_path weather.csv \ + --model_id weather_$pred_len'_'$pred_len \ + --model $model_name \ + --data custom \ + --features M \ + --seq_len $pred_len \ + --label_len 48 \ + --pred_len $pred_len \ + --e_layers 2 \ + --d_layers 1 \ + --enc_in 21 \ + --expand 2 \ + --d_ff 16 \ + --d_conv 4 \ + --c_out 21 \ + --d_model 128 \ + --des 'Exp' \ + --itr 1 \ + +done \ No newline at end of file From 21ebe5b40c809bc1ba79b81059d9dc14e407df98 Mon Sep 17 00:00:00 2001 From: frecklebars Date: Thu, 4 Apr 2024 14:15:58 +0100 Subject: [PATCH 3/4] removed tee from mamba all script --- scripts/long_term_forecast/Mamba_all.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/long_term_forecast/Mamba_all.sh b/scripts/long_term_forecast/Mamba_all.sh index af01e3f8..9e34eab3 100644 --- a/scripts/long_term_forecast/Mamba_all.sh +++ b/scripts/long_term_forecast/Mamba_all.sh @@ -1,4 +1,4 @@ -./scripts/long_term_forecast/ECL_script/Mamba.sh | tee mamba_all.txt -./scripts/long_term_forecast/Traffic_script/Mamba.sh | tee mamba_all.txt -a -./scripts/long_term_forecast/Exchange_script/Mamba.sh | tee mamba_all.txt -a -./scripts/long_term_forecast/Weather_script/Mamba.sh | tee mamba_all.txt -a +./scripts/long_term_forecast/ECL_script/Mamba.sh +./scripts/long_term_forecast/Traffic_script/Mamba.sh +./scripts/long_term_forecast/Exchange_script/Mamba.sh +./scripts/long_term_forecast/Weather_script/Mamba.sh From 094f7a51539ba26160e30bf269110ba2b1aea677 Mon Sep 17 00:00:00 2001 From: frecklebars Date: Thu, 4 Apr 2024 14:23:18 +0100 Subject: [PATCH 4/4] comments cleanup --- models/MambaSimple.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/MambaSimple.py b/models/MambaSimple.py index 948cb5fb..5cfc5d16 100644 --- a/models/MambaSimple.py +++ b/models/MambaSimple.py @@ -21,7 +21,7 @@ def __init__(self, 
configs): self.pred_len = configs.pred_len self.d_inner = configs.d_model * configs.expand - self.dt_rank = math.ceil(configs.d_model / 16) # TODO implement "auto" + self.dt_rank = math.ceil(configs.d_model / 16) self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout) @@ -89,7 +89,7 @@ def __init__(self, configs, d_inner, dt_rank): out_channels = self.d_inner, bias = True, kernel_size = configs.d_conv, - padding = configs.d_conv - 1, # TODO dont understand this; come back and do kernel = 3 padding = 1 instead if it doesnt work? + padding = configs.d_conv - 1, groups = self.d_inner )
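
For reference, the update that selective_scan in models/Mamba.py (PATCH 1/4) and models/MambaSimple.py computes can be written out explicitly. This is a reading of the code and of its comments (zero-order hold for A, a simplified Euler step for B), not an independent derivation; the symbols follow the tensor names in the code:

    A_bar_t  = exp(Delta_t * A)                  # ZOH discretization of A (deltaA in the code)
    Bu_bar_t ~= Delta_t * B_t * u_t              # simplified Euler step in place of exact ZOH for B (deltaB_u)
    x_t      = A_bar_t * x_{t-1} + Bu_bar_t      # the sequential scan over the L dimension
    y_t      = C_t . x_t + D * u_t               # readout, plus the skip term added after the loop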
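
The padding = configs.d_conv - 1 choice that the TODO removed in PATCH 4/4 asks about is the usual causal-convolution construction: left-pad a depthwise Conv1d by kernel_size - 1 and truncate the output back to length L (the [:, :, :l] slice in MambaBlock.forward), so position t never sees inputs beyond t. A minimal sketch that checks this property; the sizes are illustrative and not taken from the configs:

    import torch
    import torch.nn as nn

    d_inner, d_conv, L = 8, 4, 10
    conv = nn.Conv1d(d_inner, d_inner, kernel_size=d_conv,
                     padding=d_conv - 1, groups=d_inner, bias=True)

    x = torch.randn(1, d_inner, L)
    y = conv(x)[:, :, :L]              # same truncation as MambaBlock.forward

    x_perturbed = x.clone()
    x_perturbed[:, :, -1] += 1.0       # change only the last time step
    y_perturbed = conv(x_perturbed)[:, :, :L]

    # every output before the last position is unchanged, i.e. the conv is causal
    assert torch.allclose(y[:, :, :-1], y_perturbed[:, :, :-1])

Swapping in kernel_size=3, padding=1 as the removed comment suggests would keep the output shape but make the convolution non-causal, since each position would then see one future step.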
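
PATCH 2/4 replaces the hand-written scan with the fused implementation from the external mamba_ssm package. A minimal sketch of that wrapper in isolation, assuming mamba_ssm is installed and a CUDA device is available (this sketch assumes the package's fused kernels are CUDA-only); the shapes mirror what models/Mamba.py produces after DataEmbedding:

    import torch
    from mamba_ssm import Mamba

    batch, seq_len, d_model = 16, 96, 128
    x = torch.randn(batch, seq_len, d_model, device="cuda")

    mamba = Mamba(
        d_model=d_model,   # embedding width (configs.d_model)
        d_state=16,        # SSM state size; the scripts route this through --d_ff
        d_conv=4,          # local convolution width (--d_conv)
        expand=2,          # inner expansion factor (--expand)
    ).to("cuda")

    y = mamba(x)           # [batch, seq_len, d_model]; same shape in and out

Note that configs.d_ff is repurposed here as the SSM state size d_state rather than a feed-forward width, which is why every script in this series sets --d_ff 16.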