Commit 0cd2e7a — Upload the code
Mingyue-Cheng committed Mar 28, 2024
Showing 18 changed files with 3,884 additions and 0 deletions.
51 changes: 51 additions & 0 deletions README.md
# InstructTime

## Project Overview

This project is open-sourced anonymously for scientific research and technological development, in support of the blind-review process, to keep evaluations fair and objective.

## Installation Instructions

```bash
# Install dependencies
pip install -r requirements.txt

# Train the TStokenizer
cd TStokenizer
python main.py \
--save_path $VQVAE_PATH \
--dataset $DATASET \
--data_path $DATA_PATH \
--device $DEVICE \
--d_model $D_MODEL \
--wave_length $WAVE_LENGTH \
--n_embed $NUM_TOKEN
cd ..  # return to the repository root

# Train InstructTime-Universal
python main.py \
--save_path $VQVAE_PATH \
--dataset $DATASET \
--model_path $MODEL_PATH \
--device $DEVICE \
--adapt False

# Train InstructTime-Adapt
python main.py \
--save_path $VQVAE_PATH \
--dataset $DATASET \
--model_path $MODEL_PATH \
--load_model_path $LOAD_MODEL_PATH \
--device $DEVICE \
--lr $LR \
--adapt True
```

## An Example InstructTime Prompt

```
You will be receiving electroencephalogram (EEG) related signals.
EEG: <BET><TS Tokens><EET>
The sleep patterns include waking up, rapid eye movement sleep, and sleep stages one through four, as well as periods of movement and unidentified stages.
Select one of the eight previously mentioned sleep patterns and report on the person's sleep using the provided information.
The person's sleep pattern is waking up
```
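
The `<BET>`/`<EET>` markers delimit the discretized time-series tokens produced by the TStokenizer. Below is a minimal sketch of how such a prompt could be assembled around tokenizer output; the `<ts_i>` rendering of each codebook index and the `build_eeg_prompt` helper are illustrative assumptions, not the project's actual API.

```python
def build_eeg_prompt(ts_token_ids, pattern_description):
    """Assemble an InstructTime-style prompt around discretized EEG tokens."""
    # Hypothetical rendering: one <ts_i> special token per codebook index.
    ts_tokens = "".join(f"<ts_{i}>" for i in ts_token_ids)
    return (
        "You will be receiving electroencephalogram (EEG) related signals.\n"
        f"EEG: <BET>{ts_tokens}<EET>\n"
        f"{pattern_description}\n"
        "Select one of the eight previously mentioned sleep patterns and "
        "report on the person's sleep using the provided information.\n"
    )

prompt = build_eeg_prompt(
    ts_token_ids=[17, 4, 930],  # dummy codebook indices
    pattern_description=(
        "The sleep patterns include waking up, rapid eye movement sleep, "
        "and sleep stages one through four, as well as periods of movement "
        "and unidentified stages."
    ),
)
print(prompt)
```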
83 changes: 83 additions & 0 deletions TStokenizer/args.py
import argparse
import os
import json
from datautils import load_all_data

parser = argparse.ArgumentParser()
# dataset and dataloader args
parser.add_argument('--save_path', type=str, default='./test')
parser.add_argument('--dataset', type=str, default='sleep', choices=['har', 'geo', 'sleep', 'dev', 'ecg', 'whale', 'ad', 'esr'])
parser.add_argument('--data_path', type=str, default=None)
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--train_batch_size', type=int, default=32)
parser.add_argument('--test_batch_size', type=int, default=64)

# model args
parser.add_argument('--d_model', type=int, default=64)
parser.add_argument('--dropout', type=float, default=0.2)
parser.add_argument('--eval_per_steps', type=int, default=300)
parser.add_argument('--enable_res_parameter', type=int, default=1)
parser.add_argument('--n_embed', type=int, default=2560)
parser.add_argument('--wave_length', type=int, default=25)
parser.add_argument('--mask_prob', type=float, default=0.2)
parser.add_argument('--pooling_type', type=str, default='mean', choices=['mean', 'max', 'last_token'])

# tcn args
parser.add_argument('--kernel_size', type=int, default=3)
parser.add_argument('--block_num', type=int, default=4)
parser.add_argument('--dilations', type=int, nargs='+', default=[1, 4])  # type=list would split a CLI string into characters

# train args
parser.add_argument('--lr', type=float, default=0.0005)
parser.add_argument('--lr_decay_rate', type=float, default=0.99)
parser.add_argument('--lr_decay_steps', type=int, default=300)
parser.add_argument('--weight_decay', type=float, default=1e-5)
parser.add_argument('--num_epoch', type=int, default=60)

args = parser.parse_args()
# Default data directories used when --data_path is not given; 'sleep' and
# 'ecg' fall through to the EEG directory.
DEFAULT_DATA_PATHS = {
    'har': '../har_no_big',
    'geo': '../ecg_no_big',
    'dev': '../device_no_big',
    'whale': '../whale_no_big',
    'ad': '../ad_no_big',
    'esr': '../esr_no_big',
}

if args.data_path is None:
    Train_data, Test_data = load_all_data(DEFAULT_DATA_PATHS.get(args.dataset, '../eeg_no_big'))
else:
    Train_data, Test_data = load_all_data(args.data_path)
print('data loaded')

if args.save_path == 'None':
    # args has no 'model' attribute, so build the run name from the known fields.
    args.save_path = f'D-{args.d_model}_Lr-{args.lr}_Dataset-{args.dataset}/'
if not os.path.exists(args.save_path):
    os.makedirs(args.save_path)

with open(os.path.join(args.save_path, 'args.json'), 'w') as config_file:
    json.dump(args.__dict__, config_file, indent=1)
print(args)
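
Because the full configuration is serialized to `args.json`, a later analysis script can rebuild the namespace from disk. A minimal sketch, assuming the file written above (the `load_run_config` helper is hypothetical):

```python
import argparse
import json
import os

def load_run_config(save_path):
    """Rebuild an argparse.Namespace from a run's saved args.json."""
    with open(os.path.join(save_path, 'args.json')) as f:
        return argparse.Namespace(**json.load(f))

config = load_run_config('./test')  # './test' is the default --save_path
print(config.d_model, config.n_embed, config.wave_length)
```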
23 changes: 23 additions & 0 deletions TStokenizer/dataset.py
import torch
import torch.utils.data as Data
from args import Train_data, Test_data

class Dataset(Data.Dataset):
    def __init__(self, device, mode, args):
        self.args = args
        # Train_data / Test_data are loaded once, at import time, in args.py.
        if mode == 'train':
            self.ecgs_images = Train_data
        else:
            self.ecgs_images = Test_data
        self.device = device
        self.mode = mode

    def __len__(self):
        return len(self.ecgs_images)

    def __getitem__(self, item):
        ecg_img = torch.tensor(self.ecgs_images[item]).to(self.device)
        return ecg_img * 2.5  # fixed amplitude rescaling applied to every sample

    def shape(self):
        # Shape of a single series; used to configure the model in main.py.
        return self.ecgs_images[0].shape
51 changes: 51 additions & 0 deletions TStokenizer/datautils.py
import os
import pickle
import numpy as np

def save_datasets(data, path, file_name):
    if not os.path.exists(path):
        os.makedirs(path)
    np.save(os.path.join(path, file_name), data)

def load_datasets(path, file_name):
    return np.load(os.path.join(path, file_name))

def load_all_data(Path, use_saved_datasets=True):
    # Prefer the cached .npy arrays if a previous run already saved them.
    if use_saved_datasets:
        try:
            tokenizer_train = load_datasets(Path, 'train.npy')
            tokenizer_test = load_datasets(Path, 'test.npy')
            print(len(tokenizer_train), len(tokenizer_test))
            return tokenizer_train, tokenizer_test
        except IOError:
            print("Saved datasets not found. Processing raw data.")

    train_path = os.path.join(Path, 'samples_train.pkl')
    test_path = os.path.join(Path, 'samples_test.pkl')

    samples_train, tokenizer_train = [], []
    samples_test, tokenizer_test = [], []
    if os.path.isfile(train_path) and os.path.isfile(test_path):
        with open(train_path, 'rb') as file:
            samples_train = pickle.load(file)
        with open(test_path, 'rb') as file:
            samples_test = pickle.load(file)

    # Each sample is a 3-tuple; only the middle element (the time series) is kept.
    for sample in samples_train:
        _, ecg, _ = sample
        tokenizer_train.append(ecg)
    for sample in samples_test:
        _, ecg, _ = sample
        tokenizer_test.append(ecg)

    tokenizer_train = np.array(tokenizer_train, dtype=np.float32)
    tokenizer_test = np.array(tokenizer_test, dtype=np.float32)

    # Cache the arrays so subsequent runs skip the pickle parsing.
    save_datasets(tokenizer_train, Path, 'train.npy')
    save_datasets(tokenizer_test, Path, 'test.npy')

    return tokenizer_train, tokenizer_test

if __name__ == "__main__":
    load_all_data('./esr_data')
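
For anyone preparing a new dataset, the sketch below shows the on-disk layout `load_all_data` expects: a `samples_train.pkl` / `samples_test.pkl` pair, each a list of 3-tuples whose middle element is the raw series. The label and text fields and the `(length, channels)` shape are assumptions for illustration, since both outer fields are discarded above.

```python
import os
import pickle
import numpy as np
from datautils import load_all_data

def write_dummy_split(path, file_name, n_samples=8, length=3000, channels=1):
    # Each record mirrors the (_, ecg, _) unpacking in load_all_data.
    samples = [
        ('dummy_label', np.random.randn(length, channels).astype(np.float32), 'dummy text')
        for _ in range(n_samples)
    ]
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, file_name), 'wb') as f:
        pickle.dump(samples, f)

write_dummy_split('./esr_data', 'samples_train.pkl')
write_dummy_split('./esr_data', 'samples_test.pkl')
train, test = load_all_data('./esr_data')  # parses the pickles, caches .npy files
print(train.shape, test.shape)  # (8, 3000, 1) each
```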
15 changes: 15 additions & 0 deletions TStokenizer/loss.py
import torch.nn as nn

class MSE:
    """VQ-VAE training objective: reconstruction MSE plus a weighted latent (commitment) loss."""

    def __init__(self, model, latent_loss_weight=0.25):
        self.model = model
        self.latent_loss_weight = latent_loss_weight
        self.mse = nn.MSELoss()

    def compute(self, batch):
        seqs = batch
        # The model returns the reconstruction, the latent loss, and token indices.
        out, latent_loss, _ = self.model(seqs)
        recon_loss = self.mse(out, seqs)
        latent_loss = latent_loss.mean()
        loss = recon_loss + self.latent_loss_weight * latent_loss
        return loss
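
For context, one optimization step with this wrapper might look as follows. The optimizer choice is illustrative (the actual training loop lives in process.py), and `model` / `train_loader` are assumed to come from main.py; the only interface relied on is the one visible above, i.e. `model(seqs)` returning a reconstruction, a latent loss, and token indices.

```python
import torch

criterion = MSE(model, latent_loss_weight=0.25)
# lr and weight_decay mirror the defaults in args.py; Adam is an assumption.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)

for batch in train_loader:  # batches of shape (batch, length, channels)
    optimizer.zero_grad()
    loss = criterion.compute(batch)  # reconstruction MSE + 0.25 * latent term
    loss.backward()
    optimizer.step()
```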
49 changes: 49 additions & 0 deletions TStokenizer/main.py
import os
import random
import warnings

import numpy as np
import torch
import torch.utils.data as Data

warnings.filterwarnings('ignore')
# Importing args parses the CLI and loads the data as a side effect.
from dataset import Dataset
from args import args
from process import Trainer
from model import VQVAE

def seed_everything(seed):
    # Fix every RNG we rely on so runs are reproducible.
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True

def main():
    seed_everything(seed=2023)

    train_dataset = Dataset(device=args.device, mode='train', args=args)
    train_loader = Data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
    args.data_shape = train_dataset.shape()
    test_dataset = Dataset(device=args.device, mode='test', args=args)
    test_loader = Data.DataLoader(test_dataset, batch_size=args.test_batch_size)
    print(args.data_shape)
    print('datasets initialized')

    model = VQVAE(data_shape=args.data_shape, hidden_dim=args.d_model, n_embed=args.n_embed,
                  block_num=args.block_num, wave_length=args.wave_length)
    print('model initialized')

    trainer = Trainer(args, model, train_loader, test_loader, verbose=True)
    print('trainer initialized')

    trainer.train()


if __name__ == '__main__':
    main()
