Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multispans #414

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions libcity/config/executor/MultiSPANSExecutor.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"gpu": true,
"gpu_id": 0,
"max_epoch": 100,
"train_loss": "masked_mae",
"epoch": 0,
"learner": "adam",
"learning_rate": 0.01,
"weight_decay": 0,
"lr_epsilon": 1e-8,
"lr_beta1": 0.9,
"lr_beta2": 0.999,
"lr_alpha": 0.99,
"lr_momentum": 0,
"lr_decay": false,
"lr_scheduler": "multisteplr",
"lr_decay_ratio": 0.1,
"steps": [5, 20, 40, 70],
"step_size": 10,
"lr_T_max": 30,
"lr_eta_min": 0,
"lr_patience": 10,
"lr_threshold": 1e-4,
"clip_grad_norm": false,
"max_grad_norm": 1.0,
"use_early_stop": false,
"patience": 50,
"log_level": "INFO",
"log_every": 1,
"saved_model": true,
"load_best_epoch": true,
"hyper_tune": false,
"pred_channel_idx": [0],
"outfeat_dim": 1
}
26 changes: 26 additions & 0 deletions libcity/config/model/traffic_state_pred/MultiSPANS.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"embed_dim": 64,
"skip_conv_flag": false,
"residual_conv_flag": false,
"skip_dim": 64,
"num_layers": 3,
"num_heads": 8,

"conv_kernels": [1,2,3,6],
"conv_stride": 1,
"conv_if_gc": true,
"norm_type": "BatchNorm",

"gconv_hop_num": 3,
"gconv_alpha": 0,

"att_dropout": 0.1,
"ffn_dropout": 0.1,
"Satt_pe_type": "laplacian",
"Spe_learnable": false,
"Tatt_pe_type": "sincos",
"Tpe_learnable": false,
"Smask_flag": true,
"block_forward_mode": 0,
"sstore_attn": false
}
7 changes: 6 additions & 1 deletion libcity/config/task_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
"DMVSTNet", "ATDM", "GMAN", "GTS", "STDN", "HGCN", "STSGCN", "STAGGCN", "STNN", "ResLSTM", "DGCN",
"MultiSTGCnet", "STMGAT", "CRANN", "STTN", "CONVGCNCommon", "DSAN", "DKFN", "CCRNN", "MultiSTGCnetCommon",
"GEML", "FNN", "GSNet", "CSTN", "D2STGNN", "STID","STGODE", "STNorm", "DMSTGCN", "ESG", "SSTBAN", "STTSNet",
"FOGS", "RGSL", "DSTAGNN", "STPGCN", "HIEST", "STAEformer", "TESTAM"
"FOGS", "RGSL", "DSTAGNN", "STPGCN", "HIEST", "STAEformer", "TESTAM", "MultiSPANS"
],
"allowed_dataset": [
"METR_LA", "PEMS_BAY", "PEMSD3", "PEMSD4", "PEMSD7", "PEMSD8", "PEMSD7(M)",
Expand All @@ -116,6 +116,11 @@
"executor": "TrafficStateExecutor",
"evaluator": "TrafficStateEvaluator"
},
"MultiSPANS": {
"dataset_class": "TrafficStatePointDataset",
"executor": "MultiSPANSExecutor",
"evaluator": "TrafficStateEvaluator"
},
"STPGCN": {
"dataset_class": "STPGCNDataset",
"executor": "TrafficStateExecutor",
Expand Down
5 changes: 4 additions & 1 deletion libcity/data/dataset/traffic_state_datatset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from libcity.data.dataset import AbstractDataset
from libcity.data.utils import generate_dataloader
from libcity.utils import StandardScaler, NormalScaler, NoneScaler, \
MinMax01Scaler, MinMax11Scaler, LogScaler, ensure_dir
MinMax01Scaler, MinMax11Scaler, LogScaler, ensure_dir, StandardIndependCScaler


class TrafficStateDataset(AbstractDataset):
Expand Down Expand Up @@ -903,6 +903,9 @@ def _get_scalar(self, scaler_type, x_train, y_train):
elif scaler_type == "standard":
scaler = StandardScaler(mean=x_train.mean(), std=x_train.std())
self._logger.info('StandardScaler mean: ' + str(scaler.mean) + ', std: ' + str(scaler.std))
elif scaler_type == "standardindependc":
scaler = StandardIndependCScaler(x_train=x_train)
self._logger.info('StandardIndependCScaler dim: ' + str(scaler.dim))
elif scaler_type == "minmax01":
scaler = MinMax01Scaler(
maxx=max(x_train.max(), y_train.max()), minn=min(x_train.min(), y_train.min()))
Expand Down
4 changes: 3 additions & 1 deletion libcity/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from libcity.executor.eta_executor import ETAExecutor
from libcity.executor.gensim_executor import GensimExecutor
from libcity.executor.sstban_executor import SSTBANExecutor
from libcity.executor.multispans_executor import MultiSPANSExecutor
from libcity.executor.testam_executor import TESTAMExecutor


Expand All @@ -34,5 +35,6 @@
"SSTBANExecutor",
"STTSNetExecutor",
"FOGSExecutor",
"TESTAMExecutor"
"TESTAMExecutor",
"MultiSPANSExecutor",
]
166 changes: 166 additions & 0 deletions libcity/executor/multispans_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import os
import time
from functools import partial

import numpy as np
import torch

from libcity.executor.traffic_state_executor import TrafficStateExecutor
from libcity.model import loss


class MultiSPANSExecutor(TrafficStateExecutor):
def __init__(self, config, model, data_feature):
super().__init__(config, model, data_feature)
self.pred_channel_idx = self.config.get("pred_channel_idx", None)

def _build_train_loss(self):
"""
根据全局参数`train_loss`选择训练过程的loss函数
如果该参数为none,则需要使用模型自定义的loss函数
注意,loss函数应该接收`Batch`对象作为输入,返回对应的loss(torch.tensor)
"""
if self.train_loss.lower() == 'none':
self._logger.warning('Received none train loss func and will use the loss func defined in the model.')
return None
if self.train_loss.lower() not in ['mae', 'mse', 'rmse', 'mape', 'logcosh', 'huber', 'quantile', 'masked_mae',
'masked_mse', 'masked_rmse', 'masked_mape', 'r2', 'evar']:
self._logger.warning('Received unrecognized train loss function, set default mae loss func.')
else:
self._logger.info('You select `{}` as train loss function.'.format(self.train_loss.lower()))

def func(batch, channel_index):
y_true = batch['y']
y_predicted = self.model.predict(batch)
y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim],
channel_idx=channel_index)
if channel_index is not None:
y_true = y_true[..., channel_index]
assert (y_true.shape[-1] == y_predicted.shape[-1]), 'Uncompatiable prediction & label channel!'

if self.train_loss.lower() == 'mae':
lf = loss.masked_mae_torch
elif self.train_loss.lower() == 'mse':
lf = loss.masked_mse_torch
elif self.train_loss.lower() == 'rmse':
lf = loss.masked_rmse_torch
elif self.train_loss.lower() == 'mape':
lf = loss.masked_mape_torch
elif self.train_loss.lower() == 'logcosh':
lf = loss.log_cosh_loss
elif self.train_loss.lower() == 'huber':
lf = loss.huber_loss
elif self.train_loss.lower() == 'quantile':
lf = loss.quantile_loss
elif self.train_loss.lower() == 'masked_mae':
lf = partial(loss.masked_mae_torch, null_val=0)
elif self.train_loss.lower() == 'masked_mse':
lf = partial(loss.masked_mse_torch, null_val=0)
elif self.train_loss.lower() == 'masked_rmse':
lf = partial(loss.masked_rmse_torch, null_val=0)
elif self.train_loss.lower() == 'masked_mape':
lf = partial(loss.masked_mape_torch, null_val=0)
elif self.train_loss.lower() == 'r2':
lf = loss.r2_score_torch
elif self.train_loss.lower() == 'evar':
lf = loss.explained_variance_score_torch
else:
lf = loss.masked_mae_torch
return lf(y_predicted, y_true)

return func

def evaluate(self, test_dataloader):
"""
use model to test data

Args:
test_dataloader(torch.Dataloader): Dataloader
"""
self._logger.info('Start evaluating ...')
with torch.no_grad():
self.model.eval()
y_truths = []
y_preds = []
for batch in test_dataloader:
batch.to_tensor(self.device)
output = self.model.predict(batch)
y_true = batch['y']
y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
y_pred = self._scaler.inverse_transform(output[..., :self.output_dim],
channel_idx=self.pred_channel_idx)
if self.pred_channel_idx is not None:
y_true = y_true[..., self.pred_channel_idx]
assert (
y_true.shape[-1] == output.shape[-1]
), 'Uncompatiable prediction & label channel!'

y_truths.append(y_true.cpu().numpy())
y_preds.append(y_pred.cpu().numpy())
# evaluate_input = {'y_true': y_true, 'y_pred': y_pred}
# self.evaluator.collect(evaluate_input)
# self.evaluator.save_result(self.evaluate_res_dir)
y_preds = np.concatenate(y_preds, axis=0)
y_truths = np.concatenate(y_truths, axis=0) # concatenate on batch
outputs = {'prediction': y_preds, 'truth': y_truths}
filename = \
time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime(time.time())) + '_' \
+ self.config['model'] + '_' + self.config['dataset'] + '_predictions.npz'
np.savez_compressed(os.path.join(self.evaluate_res_dir, filename), **outputs)
self.evaluator.clear()
self.evaluator.collect({'y_true': torch.tensor(y_truths), 'y_pred': torch.tensor(y_preds)})
test_result = self.evaluator.save_result(self.evaluate_res_dir)
return test_result

def _train_epoch(self, train_dataloader, epoch_idx, loss_func=None):
"""
完成模型一个轮次的训练

Args:
train_dataloader: 训练数据
epoch_idx: 轮次数
loss_func: 损失函数

Returns:
list: 每个batch的损失的数组
"""
self.model.train()
loss_func = loss_func if loss_func is not None else self.model.calculate_loss
losses = []
for batch in train_dataloader:
self.optimizer.zero_grad()
batch.to_tensor(self.device)
loss = loss_func(batch, self.pred_channel_idx)
self._logger.debug(loss.item())
losses.append(loss.item())
loss.backward()
if self.clip_grad_norm:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optimizer.step()
return losses

def _valid_epoch(self, eval_dataloader, epoch_idx, loss_func=None):
"""
完成模型一个轮次的评估

Args:
eval_dataloader: 评估数据
epoch_idx: 轮次数
loss_func: 损失函数

Returns:
float: 评估数据的平均损失值
"""
with torch.no_grad():
self.model.eval()
loss_func = loss_func if loss_func is not None else self.model.calculate_loss
losses = []
for batch in eval_dataloader:
batch.to_tensor(self.device)
loss = loss_func(batch, self.pred_channel_idx)
self._logger.debug(loss.item())
losses.append(loss.item())
mean_loss = np.mean(losses)
self._writer.add_scalar('eval loss', mean_loss, epoch_idx)
return mean_loss
Loading