-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Mingyue-Cheng
committed
Mar 28, 2024
0 parents
commit 0cd2e7a
Showing
18 changed files
with
3,884 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# InstructTime | ||
|
||
## Project Overview | ||
|
||
This is an anonymously open-sourced project designed for scientific research and technological development, supporting the blind review process to ensure fairness and objectivity in evaluations. | ||
|
||
## Installation Instructions | ||
|
||
```bash | ||
# Install dependencies | ||
pip install -r requirements.txt | ||
|
||
# Run to train TStokenizer | ||
cd TStokenizer | ||
python main.py \ | ||
--save_path $VQVAE_PATH \ | ||
--dataset $DATASET \ | ||
--data_path $DATA_PATH \ | ||
--device $DEVICE \ | ||
--d_model $D_MODEL \ | ||
--wave_length $WAVE_LENGTH \ | ||
--n_embed $NUM_TOKEN | ||
|
||
# Run to train InstructTime-Universal | ||
python main.py \ | ||
--save_path $VQVAE_PATH \ | ||
--dataset $DATASET \ | ||
--model_path $DATA_PATH \ | ||
--device $DEVICE \ | ||
--adapt False | ||
|
||
# Run to train InstructTime-Adapt | ||
python main.py \ | ||
--save_path $VQVAE_PATH \ | ||
--dataset $DATASET \ | ||
--model_path $DATA_PATH \ | ||
--load_model_path $DATA_PATH \ | ||
--device $DEVICE \ | ||
--lr $lr \ | ||
--adapt True | ||
``` | ||
|
||
## One of InstructTime's Prompts | ||
|
||
``` | ||
You will be receiving electroencephalogram(EEG) related signals. | ||
EEG: <BET><TS Tokens><EET> | ||
The sleep patterns include waking up, rapid eye movement sleep, and sleep stages one through four, as well as periods of movement and unidentified stages. | ||
Select one of the eight previously mentioned sleep patterns and report on the person's sleep using the provided information. | ||
The person's sleep pattern is waking up | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
"""Argument parsing and data loading for TStokenizer training.

Importing this module exposes three names used elsewhere in the project:
``args`` (the parsed configuration), and the ``Train_data`` / ``Test_data``
splits for the selected dataset.  It also creates ``args.save_path`` and
dumps the full configuration to ``<save_path>/args.json``.
"""
import argparse
import os
import json
from datautils import load_all_data

parser = argparse.ArgumentParser()
# dataset and dataloader args
parser.add_argument('--save_path', type=str, default='./test')
parser.add_argument('--dataset', type=str, default='sleep', choices=['har', 'geo', 'sleep', 'dev', 'ecg', 'whale', 'ad', 'esr'])
parser.add_argument('--data_path', type=str, default=None)
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--train_batch_size', type=int, default=32)
parser.add_argument('--test_batch_size', type=int, default=64)

# model args
parser.add_argument('--d_model', type=int, default=64)
parser.add_argument('--dropout', type=float, default=0.2)
parser.add_argument('--eval_per_steps', type=int, default=300)
parser.add_argument('--enable_res_parameter', type=int, default=1)
parser.add_argument('--n_embed', type=int, default=2560)
parser.add_argument('--wave_length', type=int, default=25)
parser.add_argument('--mask_prob', type=float, default=0.2)
parser.add_argument('--pooling_type', type=str, default='mean', choices=['mean', 'max', 'last_token'])

# tcn args
parser.add_argument('--kernel_size', type=int, default=3)
parser.add_argument('--block_num', type=int, default=4)
# BUG FIX: the original used ``type=list``, which splits the argument string
# into single characters (e.g. "1,4" -> ['1', ',', '4']); accept a
# comma-separated list of ints instead.  Default is unchanged.
parser.add_argument('--dilations', type=lambda s: [int(v) for v in s.split(',')], default=[1, 4])

# train args
parser.add_argument('--lr', type=float, default=0.0005)
parser.add_argument('--lr_decay_rate', type=float, default=0.99)
parser.add_argument('--lr_decay_steps', type=int, default=300)
parser.add_argument('--weight_decay', type=float, default=1e-5)
parser.add_argument('--num_epoch', type=int, default=60)

args = parser.parse_args()

if args.data_path is None:
    # Default on-disk location per dataset; names without an entry ('sleep',
    # 'ecg') fall back to the EEG directory, matching the original chain.
    _DEFAULT_PATHS = {
        'har': '../har_no_big',
        'geo': '../ecg_no_big',
        'dev': '../device_no_big',
        'whale': '../whale_no_big',
        'ad': '../ad_no_big',
        'esr': '../esr_no_big',
    }
    _data_path = _DEFAULT_PATHS.get(args.dataset, '../eeg_no_big')
else:
    # An explicit --data_path overrides the per-dataset defaults (the
    # original branched on dataset here but used the same path every time).
    _data_path = args.data_path
Train_data, Test_data = load_all_data(_data_path)
print('data loaded')

if args.save_path == 'None':
    # BUG FIX: the original referenced the undefined ``args.model`` here,
    # which raised AttributeError; build the run name from existing args.
    args.save_path = 'D-' + str(args.d_model) + '_Lr-' + str(args.lr) + '_Dataset-' + args.dataset + '/'
# exist_ok avoids the check-then-create race of the original exists()+mkdir.
os.makedirs(args.save_path, exist_ok=True)

# Persist the full configuration next to the checkpoints.
with open(os.path.join(args.save_path, 'args.json'), 'w') as config_file:
    json.dump(args.__dict__, config_file, indent=1)
print(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import torch | ||
import torch.utils.data as Data | ||
from args import Train_data, Test_data | ||
|
||
class Dataset(Data.Dataset):
    """Torch dataset serving pre-loaded time-series samples on a device.

    Samples come from the module-level ``Train_data`` / ``Test_data`` splits
    imported from ``args``; every item is scaled by a fixed factor of 2.5.
    """

    def __init__(self, device, mode, args):
        self.args = args
        # 'train' selects the training split; anything else the test split.
        self.ecgs_images = Train_data if mode == 'train' else Test_data
        self.device = device
        self.mode = mode

    def __len__(self):
        return len(self.ecgs_images)

    def __getitem__(self, item):
        # Move the raw sample to the target device, then apply the fixed
        # amplitude scaling expected by the tokenizer.
        sample = torch.tensor(self.ecgs_images[item]).to(self.device)
        return sample * 2.5

    def shape(self):
        # Shape of a single raw sample (before batching).
        return self.ecgs_images[0].shape
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os | ||
import pickle | ||
import numpy as np | ||
|
||
def save_datasets(data, path, file_name):
    """Save *data* as ``<path>/<file_name>`` in NumPy ``.npy`` format.

    The target directory is created on demand.  ``exist_ok=True`` replaces
    the original ``os.path.exists`` guard, which was a check-then-create
    race when two processes saved into the same directory.
    """
    os.makedirs(path, exist_ok=True)
    np.save(os.path.join(path, file_name), data)
|
||
def load_datasets(path, file_name):
    """Load and return the ``.npy`` array stored at ``<path>/<file_name>``."""
    full_path = os.path.join(path, file_name)
    return np.load(full_path)
|
||
def load_all_data(Path, use_saved_datasets=True):
    """Return ``(train, test)`` float32 arrays of raw series under *Path*.

    If *use_saved_datasets* is true and cached ``train.npy`` / ``test.npy``
    exist, they are returned directly.  Otherwise ``samples_train.pkl`` /
    ``samples_test.pkl`` are parsed — each pickled sample is assumed to be a
    3-tuple whose middle element is the series (TODO confirm against the
    data-preparation script) — cached as ``.npy`` and returned.
    """
    if use_saved_datasets:
        try:
            tokenizer_train = load_datasets(Path, 'train.npy')
            tokenizer_test = load_datasets(Path, 'test.npy')
            print(len(tokenizer_train), len(tokenizer_test))

            return tokenizer_train, tokenizer_test
        except IOError:
            print("Saved datasets not found. Processing raw data.")

    train_path = os.path.join(Path, 'samples_train.pkl')
    test_path = os.path.join(Path, 'samples_test.pkl')

    tokenizer_train, tokenizer_test = [], []
    raw_found = os.path.isfile(train_path) and os.path.isfile(test_path)
    if raw_found:
        with open(train_path, 'rb') as file:
            samples_train = pickle.load(file)
        with open(test_path, 'rb') as file:
            samples_test = pickle.load(file)

        # Keep only the series (middle element) from each sample tuple.
        for _, ecg, _ in samples_train:
            tokenizer_train.append(ecg)
        for _, ecg, _ in samples_test:
            tokenizer_test.append(ecg)

    tokenizer_train = np.array(tokenizer_train, dtype=np.float32)
    tokenizer_test = np.array(tokenizer_test, dtype=np.float32)

    # BUG FIX: only write the cache when raw pickles were actually found.
    # The original unconditionally saved empty train.npy/test.npy here, so
    # every later call silently returned empty datasets.
    if raw_found:
        save_datasets(tokenizer_train, Path, 'train.npy')
        save_datasets(tokenizer_test, Path, 'test.npy')

    return tokenizer_train, tokenizer_test
|
||
# Manual smoke test: process (or load the cached splits of) ./esr_data.
if __name__ == "__main__":
    load_all_data('./esr_data')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import torch.nn as nn | ||
|
||
class MSE:
    """VQ-VAE training objective: reconstruction MSE plus weighted latent loss."""

    def __init__(self, model, latent_loss_weight=0.25):
        self.model = model
        self.latent_loss_weight = latent_loss_weight
        self.mse = nn.MSELoss()

    def compute(self, batch):
        """Run *batch* through the model and return the combined scalar loss."""
        reconstruction, latent_loss, _ = self.model(batch)
        reconstruction_error = self.mse(reconstruction, batch)
        # The latent (commitment) term is averaged before weighting.
        return reconstruction_error + self.latent_loss_weight * latent_loss.mean()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import os | ||
import torch | ||
import random | ||
import os | ||
import numpy as np | ||
import warnings | ||
|
||
warnings.filterwarnings('ignore') | ||
from dataset import Dataset | ||
from args import args | ||
from process import Trainer | ||
from model import VQVAE | ||
import torch.utils.data as Data | ||
|
||
def seed_everything(seed):
    """Seed every RNG in use (hash, python, numpy, torch CPU/GPU) for reproducibility."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # cuDNN: force deterministic kernels and disable autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True
|
||
def main():
    """Train the TS tokenizer: seed RNGs, build loaders, fit the VQ-VAE."""
    seed_everything(seed=2023)

    # Data pipeline; the raw sample shape is recorded on args for the model.
    ds_train = Dataset(device=args.device, mode='train', args=args)
    loader_train = Data.DataLoader(ds_train, batch_size=args.train_batch_size, shuffle=True)
    args.data_shape = ds_train.shape()
    ds_test = Dataset(device=args.device, mode='test', args=args)
    loader_test = Data.DataLoader(ds_test, batch_size=args.test_batch_size)
    print(args.data_shape)
    print('dataset initial ends')

    vqvae = VQVAE(data_shape=args.data_shape, hidden_dim=args.d_model,
                  n_embed=args.n_embed, block_num=args.block_num,
                  wave_length=args.wave_length)
    print('model initial ends')

    trainer = Trainer(args, vqvae, loader_train, loader_test, verbose=True)
    print('trainer initial ends')

    trainer.train()


if __name__ == '__main__':
    main()
|
Oops, something went wrong.