Add two transformer models via upload #508

Merged
merged 18 commits on Jul 22, 2021
Changes from 1 commit
100 changes: 61 additions & 39 deletions qlib/contrib/model/pytorch_localformer.py
@@ -8,6 +8,7 @@
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
import math
from ...utils import get_or_create_path
@@ -23,14 +24,15 @@
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
from torch.nn.modules.container import ModuleList
# qrun examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml


class LocalformerModel(Model):
def __init__(
self,
d_feat: int = 20,
d_model: int = 64,
batch_size: int = 8192,
batch_size: int = 2048,
nhead: int = 2,
num_layers: int = 2,
dropout: float = 0,
@@ -62,9 +64,7 @@ def __init__(
self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger = get_module_logger("TransformerModel")
self.logger.info(
"Improved Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device)
)
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device))

if self.seed is not None:
np.random.seed(self.seed)
@@ -106,36 +106,55 @@ def metric_fn(self, pred, label):

raise ValueError("unknown metric `%s`" % self.metric)

def train_epoch(self, data_loader):
def train_epoch(self, x_train, y_train):

x_train_values = x_train.values
y_train_values = np.squeeze(y_train.values)

self.model.train()

for data in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
indices = np.arange(len(x_train_values))
np.random.shuffle(indices)

for i in range(len(indices))[:: self.batch_size]:

if len(indices) - i < self.batch_size:
break

pred = self.model(feature.float()) # .float()
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)

pred = self.model(feature)
loss = self.loss_fn(pred, label)

self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
self.train_optimizer.step()

def test_epoch(self, data_loader):
def test_epoch(self, data_x, data_y):

# prepare training data
x_values = data_x.values
y_values = np.squeeze(data_y.values)

self.model.eval()

scores = []
losses = []

for data in data_loader:
indices = np.arange(len(x_values))

for i in range(len(indices))[:: self.batch_size]:

if len(indices) - i < self.batch_size:
break

feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
feature = torch.from_numpy(x_values[indices[i: i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_values[indices[i: i + self.batch_size]]).float().to(self.device)

with torch.no_grad():
pred = self.model(feature.float()) # .float()
pred = self.model(feature)
loss = self.loss_fn(pred, label)
losses.append(loss.item())

@@ -151,21 +170,16 @@ def fit(
save_path=None,
):

dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)

dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader

train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)

save_path = get_or_create_path(save_path)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
@@ -180,10 +194,10 @@
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(train_loader)
self.train_epoch(x_train, y_train)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(train_loader)
val_loss, val_score = self.test_epoch(valid_loader)
train_loss, train_score = self.test_epoch(x_train, y_train)
val_loss, val_score = self.test_epoch(x_valid, y_valid)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
@@ -206,25 +220,32 @@ def fit(
if self.use_gpu:
torch.cuda.empty_cache()

def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")

dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.model.eval()
x_values = x_test.values
sample_num = x_values.shape[0]
preds = []

for data in test_loader:
feature = data[:, :, 0:-1].to(self.device)
for begin in range(sample_num)[:: self.batch_size]:

if sample_num - begin < self.batch_size:
end = sample_num
else:
end = begin + self.batch_size

x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)

with torch.no_grad():
pred = self.model(feature.float()).detach().cpu().numpy()
pred = self.model(x_batch).detach().cpu().numpy()

preds.append(pred)

return pd.Series(np.concatenate(preds), index=dl_test.get_index())
return pd.Series(np.concatenate(preds), index=index)


class PositionalEncoding(nn.Module):
@@ -289,8 +310,9 @@ def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, devi
self.d_feat = d_feat

def forward(self, src):
# src [N, T, F], [512, 60, 6]
src = self.feature_layer(src) # [512, 60, 8]
# src [N, F*T] --> [N, T, F]
src = src.reshape(len(src), self.d_feat, -1).permute(0, 2, 1)
src = self.feature_layer(src)

# src [N, T, F] --> [T, N, F], [60, 512, 8]
src = src.transpose(1, 0) # not batch first
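A minimal sketch of the input layout the rewritten forward assumes: the model now receives the flattened Alpha360-style matrix of shape [N, F*T] (each feature's T lags stored contiguously) and recovers [N, T, F] with a reshape followed by a permute. The toy sizes below and the my_dataset / segment call pattern are illustrative assumptions, not taken verbatim from this PR.

import torch

# Toy dimensions: N samples, d_feat features, T time steps (illustrative only).
N, d_feat, T = 4, 6, 60

# Flattened input: d_feat contiguous blocks of length T per row, i.e. shape [N, F*T].
src = torch.randn(N, d_feat * T)

# The same transformation used in the new forward(): [N, F*T] -> [N, F, T] -> [N, T, F].
src = src.reshape(len(src), d_feat, -1).permute(0, 2, 1)
assert src.shape == (N, T, d_feat)

# Hypothetical call pattern for the revised interface (my_dataset is assumed to be a
# qlib DatasetH built on an Alpha360 handler; it is not defined in this diff):
# model = LocalformerModel(d_feat=20, d_model=64, batch_size=2048)
# model.fit(my_dataset)
# pred_valid = model.predict(my_dataset, segment="valid")  # new segment argument
# pred_test = model.predict(my_dataset)                     # segment defaults to "test"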