From 0cd2e7a83e0b41b3d51e8d5cd98aae6ff6aed76d Mon Sep 17 00:00:00 2001
From: Mingyue-Cheng
Date: Thu, 28 Mar 2024 10:49:14 +0800
Subject: [PATCH] Upload the code

---
 README.md | 51 ++
 TStokenizer/args.py | 83 ++
 TStokenizer/dataset.py | 23 +
 TStokenizer/datautils.py | 51 ++
 TStokenizer/loss.py | 15 +
 TStokenizer/main.py | 49 +
 TStokenizer/main_eval.py | 200 +++++
 TStokenizer/model.py | 267 ++++++
 TStokenizer/process.py | 100 +++
 args.py | 24 +
 challeng_score.py | 143 +++
 metrics.py | 267 ++++++
 multidataset.py | 132 +++
 multimodel.py | 132 +++
 preprocess.py | 1831 ++++++++++++++++++++++++++++++++++++++
 requirements.txt | 47 +
 run_truth_loss.py | 421 +++
 utils.py | 48 +
 18 files changed, 3884 insertions(+)
 create mode 100644 README.md
 create mode 100644 TStokenizer/args.py
 create mode 100644 TStokenizer/dataset.py
 create mode 100644 TStokenizer/datautils.py
 create mode 100644 TStokenizer/loss.py
 create mode 100644 TStokenizer/main.py
 create mode 100644 TStokenizer/main_eval.py
 create mode 100644 TStokenizer/model.py
 create mode 100644 TStokenizer/process.py
 create mode 100644 args.py
 create mode 100644 challeng_score.py
 create mode 100644 metrics.py
 create mode 100644 multidataset.py
 create mode 100644 multimodel.py
 create mode 100644 preprocess.py
 create mode 100644 requirements.txt
 create mode 100644 run_truth_loss.py
 create mode 100644 utils.py

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..1866f4c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,51 @@
+# InstructTime
+
+## Project Overview
+
+This project is open-sourced anonymously for scientific research and technological development, in support of the blind review process and to ensure fairness and objectivity in evaluation.
+
+## Installation Instructions
+
+```bash
+# Install dependencies
+pip install -r requirements.txt
+
+# Run to train the TStokenizer
+cd TStokenizer
+python main.py \
+--save_path $VQVAE_PATH \
+--dataset $DATASET \
+--data_path $DATA_PATH \
+--device $DEVICE \
+--d_model $D_MODEL \
+--wave_length $WAVE_LENGTH \
+--n_embed $NUM_TOKEN
+
+# Run to train InstructTime-Universal
+python main.py \
+--save_path $VQVAE_PATH \
+--dataset $DATASET \
+--model_path $DATA_PATH \
+--device $DEVICE \
+--adapt False
+
+# Run to train InstructTime-Adapt
+python main.py \
+--save_path $VQVAE_PATH \
+--dataset $DATASET \
+--model_path $DATA_PATH \
+--load_model_path $DATA_PATH \
+--device $DEVICE \
+--lr $lr \
+--adapt True
+```
+
+## One of InstructTime's Prompts
+
+```
+You will be receiving electroencephalogram(EEG) related signals.
+EEG:
+The sleep patterns include waking up, rapid eye movement sleep, and sleep stages one through four, as well as periods of movement and unidentified stages.
+Select one of the eight previously mentioned sleep patterns and report on the person's sleep using the provided information.
+The person's sleep pattern is waking up +``` \ No newline at end of file diff --git a/TStokenizer/args.py b/TStokenizer/args.py new file mode 100644 index 0000000..144954c --- /dev/null +++ b/TStokenizer/args.py @@ -0,0 +1,83 @@ +import argparse +import os +import json +from datautils import load_all_data + +parser = argparse.ArgumentParser() +# dataset and dataloader args +parser.add_argument('--save_path', type=str, default='./test') +parser.add_argument('--dataset', type=str, default='sleep', choices=['har', 'geo', 'sleep', 'dev', 'ecg', 'whale', 'ad', 'esr']) +parser.add_argument('--data_path', type=str, default=None) +parser.add_argument('--device', type=str, default='cuda:0') +parser.add_argument('--train_batch_size', type=int, default=32) +parser.add_argument('--test_batch_size', type=int, default=64) + +# model args +parser.add_argument('--d_model', type=int, default=64) +parser.add_argument('--dropout', type=float, default=0.2) +parser.add_argument('--eval_per_steps', type=int, default=300) +parser.add_argument('--enable_res_parameter', type=int, default=1) +parser.add_argument('--n_embed', type=int, default=2560) +parser.add_argument('--wave_length', type=int, default=25) +parser.add_argument('--mask_prob', type=float, default=0.2) +parser.add_argument('--pooling_type', type=str, default='mean', choices=['mean', 'max', 'last_token']) + +# tcn args +parser.add_argument('--kernel_size', type=int, default=3) +parser.add_argument('--block_num', type=int, default=4) +parser.add_argument('--dilations', type=list, default=[1, 4]) + +# train args +parser.add_argument('--lr', type=float, default=0.0005) +parser.add_argument('--lr_decay_rate', type=float, default=0.99) +parser.add_argument('--lr_decay_steps', type=int, default=300) +parser.add_argument('--weight_decay', type=float, default=1e-5) +parser.add_argument('--num_epoch', type=int, default=60) + +args = parser.parse_args() +if args.data_path is None: + if args.dataset == 'har': + Train_data, Test_data = load_all_data('../har_no_big') + elif args.dataset == 'geo': + Train_data, Test_data = load_all_data('../ecg_no_big') + elif args.dataset == 'dev': + Train_data, Test_data = load_all_data('../device_no_big') + elif args.dataset == 'whale': + Train_data, Test_data = load_all_data('../whale_no_big') + elif args.dataset == 'ad': + Train_data, Test_data = load_all_data('../ad_no_big') + elif args.dataset == 'esr': + Train_data, Test_data = load_all_data('../esr_no_big') + else: + Train_data, Test_data = load_all_data('../eeg_no_big') +else: + path = args.data_path + if args.dataset == 'har': + Train_data, Test_data = load_all_data(path) + elif args.dataset == 'geo': + Train_data, Test_data = load_all_data(path) + elif args.dataset == 'dev': + Train_data, Test_data = load_all_data(path) + elif args.dataset == 'ecg': + Train_data, Test_data = load_all_data(path) + elif args.dataset == 'whale': + Train_data, Test_data = load_all_data(path) + elif args.dataset == 'ad': + Train_data, Test_data = load_all_data(path) + elif args.dataset == 'esr': + Train_data, Test_data = load_all_data(path) + else: + Train_data, Test_data = load_all_data(path) +print('data loaded') + +if args.save_path == 'None': + path_str = 'D-' + str(args.d_model) + '_Model-' + args.model + '_Lr-' + str(args.lr) + '_Dataset-' + args.dataset + '/' + args.save_path = path_str +if not os.path.exists(args.save_path): + os.mkdir(args.save_path) + +config_file = open(args.save_path + '/args.json', 'w') +tmp = args.__dict__ +json.dump(tmp, config_file, indent=1) +print(args) 
+config_file.close() diff --git a/TStokenizer/dataset.py b/TStokenizer/dataset.py new file mode 100644 index 0000000..1f1f9fc --- /dev/null +++ b/TStokenizer/dataset.py @@ -0,0 +1,23 @@ +import torch +import torch.utils.data as Data +from args import Train_data, Test_data + +class Dataset(Data.Dataset): + def __init__(self, device, mode, args): + self.args = args + if mode == 'train': + self.ecgs_images = Train_data + else: + self.ecgs_images = Test_data + self.device = device + self.mode = mode + + def __len__(self): + return len(self.ecgs_images) + + def __getitem__(self, item): + ecg_img = torch.tensor(self.ecgs_images[item]).to(self.device) + return ecg_img * 2.5 + + def shape(self): + return self.ecgs_images[0].shape diff --git a/TStokenizer/datautils.py b/TStokenizer/datautils.py new file mode 100644 index 0000000..8997c2d --- /dev/null +++ b/TStokenizer/datautils.py @@ -0,0 +1,51 @@ +import os +import pickle +import numpy as np + +def save_datasets(data, path, file_name): + if not os.path.exists(path): + os.makedirs(path) + np.save(os.path.join(path, file_name), data) + +def load_datasets(path, file_name): + return np.load(os.path.join(path, file_name)) + +def load_all_data(Path, use_saved_datasets=True): + if use_saved_datasets: + try: + tokenizer_train = load_datasets(Path, 'train.npy') + tokenizer_test = load_datasets(Path, 'test.npy') + print(len(tokenizer_train), len(tokenizer_test)) + + return tokenizer_train, tokenizer_test + except IOError: + print("Saved datasets not found. Processing raw data.") + + train_path = os.path.join(Path, 'samples_train.pkl') + test_path = os.path.join(Path, 'samples_test.pkl') + + samples_train, tokenizer_train = [], [] + samples_test, tokenizer_test = [], [] + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test = pickle.load(file) + + for sample in samples_train: + _, ecg, _ = sample + tokenizer_train.append(ecg) + for sample in samples_test: + _, ecg, _ = sample + tokenizer_test.append(ecg) + + tokenizer_train = np.array(tokenizer_train, dtype=np.float32) + tokenizer_test = np.array(tokenizer_test, dtype=np.float32) + + save_datasets(tokenizer_train, Path, 'train.npy') + save_datasets(tokenizer_test, Path, 'test.npy') + + return tokenizer_train, tokenizer_test + +if __name__ == "__main__": + load_all_data('./esr_data') diff --git a/TStokenizer/loss.py b/TStokenizer/loss.py new file mode 100644 index 0000000..aade649 --- /dev/null +++ b/TStokenizer/loss.py @@ -0,0 +1,15 @@ +import torch.nn as nn + +class MSE: + def __init__(self, model, latent_loss_weight=0.25): + self.model = model + self.latent_loss_weight = latent_loss_weight + self.mse = nn.MSELoss() + + def compute(self, batch): + seqs = batch + out, latent_loss, _ = self.model(seqs) + recon_loss = self.mse(out, seqs) + latent_loss = latent_loss.mean() + loss = recon_loss + self.latent_loss_weight * latent_loss + return loss \ No newline at end of file diff --git a/TStokenizer/main.py b/TStokenizer/main.py new file mode 100644 index 0000000..636a6df --- /dev/null +++ b/TStokenizer/main.py @@ -0,0 +1,49 @@ +import os +import torch +import random +import os +import numpy as np +import warnings + +warnings.filterwarnings('ignore') +from dataset import Dataset +from args import args +from process import Trainer +from model import VQVAE +import torch.utils.data as Data + +def seed_everything(seed): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) 
+ np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = True + +def main(): + seed_everything(seed=2023) + + train_dataset = Dataset(device=args.device, mode='train', args=args) + train_loader = Data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True) + args.data_shape = train_dataset.shape() + test_dataset = Dataset(device=args.device, mode='test', args=args) + test_loader = Data.DataLoader(test_dataset, batch_size=args.test_batch_size) + print(args.data_shape) + print('dataset initial ends') + + model = VQVAE(data_shape=args.data_shape, hidden_dim=args.d_model, n_embed=args.n_embed, block_num=args.block_num, + wave_length=args.wave_length) + print('model initial ends') + + trainer = Trainer(args, model, train_loader, test_loader, verbose=True) + print('trainer initial ends') + + trainer.train() + + +if __name__ == '__main__': + main() + diff --git a/TStokenizer/main_eval.py b/TStokenizer/main_eval.py new file mode 100644 index 0000000..9a29005 --- /dev/null +++ b/TStokenizer/main_eval.py @@ -0,0 +1,200 @@ +import torch +import random +import os +import pickle +import seaborn as sns +import torch.nn as nn +import matplotlib.pyplot as plt +import numpy as np +import warnings +import scipy.stats + +warnings.filterwarnings('ignore') +from collections import Counter +from dataset import Dataset +from args import args +import torch.utils.data as Data + +from matplotlib.ticker import FuncFormatter + +def seed_everything(seed): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def calc_ent(model, data_loader): + id_counts = Counter() + with torch.no_grad(): + for batch in data_loader: + for seq in batch: + tensor = torch.Tensor(seq).unsqueeze(0).to(args.device) + try: + _, _, id = model(tensor) + id_counts.update(id.squeeze().cpu().numpy()) + except Exception as e: + print(f"error in entroy: {e}") + + counts = np.array(list(id_counts.values())) + probs = counts / counts.sum() + entropy = scipy.stats.entropy(probs, base=2) + + return entropy, id_counts + +def only_recon_loss(model, data_loader): + mse = nn.MSELoss() + total_recon_loss = 0.0 + total_batches = 0 + + with torch.no_grad(): + for batch in data_loader: + seqs = batch + out, _, _ = model(seqs) + recon_loss = mse(out, seqs) + + total_recon_loss += recon_loss.item() + total_batches += 1 + + avg_recon_loss = total_recon_loss / total_batches + return avg_recon_loss + +def save_output_to_file(output, save_path): + with open(save_path, "w") as file: + file.write(output) + +def case_analysis(model, data, save_path): + for i in range(20): + label, sample_data, _ = data[i] + sample_data = sample_data * 2.5 + with torch.no_grad(): + sample_data_tensor = torch.Tensor(sample_data).unsqueeze(0).to(args.device) + reconstructed, _, _ = model(sample_data_tensor) + + original_data = sample_data + reconstructed_data = reconstructed.squeeze(0).cpu().numpy() + dx_index = label.find("information.\n") + label = label[dx_index + 13:] + + ecg_index = label.rfind("include(s) ") + sleep_index = label.rfind("pattern is ") + har_index = label.rfind("engaged in ") + esr_index = label.rfind("state of ") + + if ecg_index != -1: + label = label[ecg_index + 11:] + elif sleep_index != -1: + label = label[sleep_index + 11:] + elif esr_index != -1: 
+ label = label[esr_index + 9:] + else: + label = label[har_index + 11:] + + dim_name = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'] + num_dims = original_data.shape[1] + fig, axes = plt.subplots(num_dims, 1, figsize=(12, num_dims * 2)) + + plt.rc('xtick', labelsize=14) + plt.rc('ytick', labelsize=14) + if num_dims == 1: + plt.plot(original_data[:], label='Original Data') + plt.plot(reconstructed_data[:], label='Reconstructed Data') + plt.title(f"Data Comparison - {dim_name[0]} - label {label}") + plt.legend(loc='upper left') + else: + for j in range(num_dims): + axes[j].plot(original_data[:, j], label='Original Data') + axes[j].plot(reconstructed_data[:, j], label='Reconstructed Data') + axes[j].set_title(f"Data Comparison - {dim_name[j]} - label {label}") + axes[j].legend(loc='upper left') + + def format_yaxis_label(value, pos): + return '{:.1f}'.format(value) + + plt.gca().yaxis.set_major_formatter(FuncFormatter(format_yaxis_label)) + + plt.tight_layout() + plt.savefig(os.path.join(save_path, f"dim_comparison_case{i}.svg"), format='svg') + plt.close() + +def plot_square_like_heatmap(id_counts, save_path, shape=(8, 16), font_size=24): + heatmap_data = np.zeros(shape) + + for i, count in enumerate(id_counts.values()): + if i >= np.prod(shape): + break + row = i // shape[1] + column = i % shape[1] + heatmap_data[row, column] = np.log1p(count) + + plt.figure(figsize=(20, 10)) + + sns.set(font_scale=2.4) + sns.heatmap(heatmap_data, cmap='Oranges', linecolor='white', linewidths=3, robust=True) + + plt.xticks([]) + plt.yticks([]) + + plt.xlabel('Index Column', fontsize=font_size) + plt.ylabel('Index Row', fontsize=font_size) + plt.title('Heatmap of Frequency Distribution for Time Series Tokens', fontsize=font_size) + + plt.savefig(os.path.join(save_path, "square_like_heatmap.svg"), format='svg', dpi=300) + plt.close() + +def main(): + seed_everything(seed=2023) + model_load_path = './test_ecg_64_128_40' + output = [] + + train_dataset = Dataset(device=args.device, mode='train', args=args) + train_loader = Data.DataLoader(train_dataset, batch_size=args.train_batch_size) + args.data_shape = train_dataset.shape() + test_dataset = Dataset(device=args.device, mode='test', args=args) + test_loader = Data.DataLoader(test_dataset, batch_size=args.test_batch_size) + print(args.data_shape) + print('dataset initial ends') + + model = VQVAE(data_shape=args.data_shape, hidden_dim=args.d_model, n_embed=args.n_embed, + wave_length=args.wave_length) + print('model initial ends') + + state_dict = torch.load(os.path.join(model_load_path, 'model.pkl'), map_location='cpu') + model = model.to(args.device) + model.load_state_dict(state_dict) + model.eval() + + entropy, id_counts = calc_ent(model, train_loader) + print('train:') + print(entropy, len(id_counts)) + output.append('train:\n') + output.append(f"{entropy} {len(id_counts)} {len(id_counts) / args.n_embed}\n") + + entropy, id_counts = calc_ent(model, test_loader) + print('test:') + print(entropy, len(id_counts)) + output.append('test:\n') + output.append(f"{entropy} {len(id_counts)} {len(id_counts) / args.n_embed}\n") + + recon_loss = only_recon_loss(model, test_loader) + print('recon loss(mse): {}\n'.format(recon_loss)) + output.append(f'recon loss(mse): {recon_loss}\n') + + plot_square_like_heatmap(id_counts, model_load_path) + print("Images have been saved.") + + save_output_to_file(''.join(output), os.path.join(model_load_path, 'model_output.txt')) + print("Texts have been saved.") + + file_path = '../ecg_no_big' + 
test_path = os.path.join(file_path, 'samples_test.pkl') + samples_test = [] + if os.path.isfile(test_path): + with open(test_path, 'rb') as file: + samples_test = pickle.load(file) + random_samples = random.sample(samples_test, 20) + case_analysis(model, random_samples, model_load_path) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/TStokenizer/model.py b/TStokenizer/model.py new file mode 100644 index 0000000..8ae9dc7 --- /dev/null +++ b/TStokenizer/model.py @@ -0,0 +1,267 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.init import xavier_normal_, constant_ + +""" +TCN based on RecBole's implementation +################################################ + +Reference code: + - https://github.com/fajieyuan/nextitnet + - https://github.com/initlisk/nextitnet_pytorch + +""" + +class TCN(nn.Module): + def __init__(self, args=None, **kwargs): + super(TCN, self).__init__() + + # load parameters info + if args is not None: + d_model = args.d_model + self.embedding_size = args.d_model + self.residual_channels = args.d_model + self.block_num = args.block_num + self.dilations = args.dilations * self.block_num + self.kernel_size = args.kernel_size + self.enabel_res_parameter = args.enable_res_parameter + self.dropout = args.dropout + self.device = args.device + self.data_shape = args.data_shape + else: + d_model = kwargs['d_model'] + self.embedding_size = kwargs['d_model'] + self.residual_channels = kwargs['d_model'] + self.block_num = kwargs['block_num'] + self.dilations = kwargs['dilations'] * self.block_num + self.data_shape = kwargs['data_shape'] + self.kernel_size = 3 + self.enabel_res_parameter = 1 + self.dropout = 0.1 + + self.max_len = self.data_shape[0] + print(self.max_len) + + # residual blocks dilations in blocks:[1,2,4,8,1,2,4,8,...] 
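+        # Added note: with the defaults used here (dilations=[1, 4], block_num=4, both in
+        # args.py and in the Encoder/Decoder kwargs), self.dilations expands to
+        # [1, 4, 1, 4, 1, 4, 1, 4]; each ResidualBlock_b additionally doubles its dilation
+        # for the second convolution (see conv2 below).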
+ rb = [ + ResidualBlock_b( + self.residual_channels, self.residual_channels, kernel_size=self.kernel_size, dilation=dilation, + enable_res_parameter=self.enabel_res_parameter, dropout=self.dropout + ) for dilation in self.dilations + ] + self.residual_blocks = nn.Sequential(*rb) + + # fully-connected layer + # self.output = nn.Linear(self.residual_channels, self.num_class) + self.output = nn.Linear(d_model, d_model) + self.broadcast_head = nn.Linear(d_model, self.data_shape[1]) + + # parameters initialization + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + xavier_normal_(module.weight.data) + if module.bias is not None: + constant_(module.bias.data, 0.1) + + def forward(self, x): + # Residual locks + # x in shape of [(B*T)*L*D] + dilate_outputs = self.residual_blocks(x) + x = dilate_outputs + return self.output(x) + + +class ResidualBlock_b(nn.Module): + r""" + Residual block (b) in the paper + """ + + def __init__(self, in_channel, out_channel, kernel_size=10, dilation=None, enable_res_parameter=False, dropout=0): + super(ResidualBlock_b, self).__init__() + + self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=(1, kernel_size), padding=0, dilation=dilation) + self.dropout1 = nn.Dropout(dropout) + self.ln1 = nn.LayerNorm(out_channel, eps=1e-8) + self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=(1, kernel_size), padding=0, dilation=dilation * 2) + self.dropout2 = nn.Dropout(dropout) + self.ln2 = nn.LayerNorm(out_channel, eps=1e-8) + + self.dilation = dilation + self.kernel_size = kernel_size + + self.enable = enable_res_parameter + self.a = nn.Parameter(torch.tensor(1e-8)) + + def forward(self, x): # x: [batch_size, seq_len, embed_size] + x_pad = self.conv_pad(x, self.dilation) # [batch_size, embed_size, 1, seq_len+(self.kernel_size-1)*dilations] + out = self.dropout1(self.conv1(x_pad).squeeze(2).permute(0, 2, 1)) + # [batch_size, seq_len+(self.kernel_size-1)*dilations-kernel_size+1, embed_size] + out = F.relu(self.ln1(out)) + out_pad = self.conv_pad(out, self.dilation * 2) + out2 = self.dropout2(self.conv2(out_pad).squeeze(2).permute(0, 2, 1)) + out2 = F.relu(self.ln2(out2)) + + if self.enable: + x = self.a * out2 + x + else: + x = out2 + x + + return x + # return self.skipconnect(x, self.ffn) + + def conv_pad(self, x, dilation): + """ Dropout-mask: To avoid the future information leakage problem, this paper proposed a masking-based dropout + trick for the 1D dilated convolution to prevent the network from seeing the future items. + Also the One-dimensional transformation is completed in this function. + """ + inputs_pad = x.permute(0, 2, 1) + inputs_pad = inputs_pad.unsqueeze(2) + pad = nn.ZeroPad2d(((self.kernel_size - 1) * dilation, 0, 0, 0)) + inputs_pad = pad(inputs_pad) + return inputs_pad + +# Copyright 2018 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + + +# Borrowed from https://github.com/deepmind/sonnet and ported it to PyTorch + +class Quantize(nn.Module): + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, beta=0.25): + super().__init__() + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + self.beta = beta + + embed = torch.randn(dim, n_embed) + torch.nn.init.kaiming_uniform_(embed) + self.register_buffer("embed", embed) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", embed.clone()) + + def forward(self, input): + flatten = input.reshape(-1, self.dim) + dist = ( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ self.embed + + self.embed.pow(2).sum(0, keepdim=True) + ) + _, embed_ind = (-dist).max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = self.embed_code(embed_ind) + + if self.training: + embed_onehot_sum = embed_onehot.sum(0) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + + self.cluster_size.data.mul_(self.decay).add_( + embed_onehot_sum, alpha=1 - self.decay + ) + self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) + n = self.cluster_size.sum() + cluster_size = ( + (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + self.embed.data.copy_(embed_normalized) + + diff = (quantize.detach() - input).pow(2).mean() + commit_loss = (quantize - input.detach()).pow(2).mean() + diff += commit_loss * self.beta + quantize = input + (quantize - input).detach() + + return quantize, diff, embed_ind # new_x, mse with input, index + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.transpose(0, 1)) + +class Encoder(nn.Module): + def __init__(self, feat_num, hidden_dim, block_num, data_shape, dilations=[1, 4]): + super().__init__() + self.input_projection = nn.Linear(feat_num, hidden_dim) + self.blocks = TCN(args=None, d_model=hidden_dim, block_num=block_num, data_shape=data_shape, + dilations=dilations) + + def forward(self, input): + return self.blocks(self.input_projection(input)) + + +class Decoder(nn.Module): + def __init__(self, feat_num, hidden_dim, block_num, data_shape, dilations=[1, 4]): + super().__init__() + self.output_projection = nn.Linear(hidden_dim, feat_num) + self.blocks = TCN(args=None, d_model=hidden_dim, block_num=block_num, data_shape=data_shape, + dilations=dilations) + + def forward(self, input): + return self.output_projection(self.blocks(input)) + + +class TStokenizer(nn.Module): + def __init__( + self, + data_shape=(5000, 12), + hidden_dim=64, + n_embed=1024, + block_num=4, + wave_length=32, + ): + super().__init__() + + self.enc = Encoder(data_shape[1], hidden_dim, block_num, data_shape) + self.wave_patch = (wave_length, hidden_dim) + self.quantize_input = nn.Conv2d(1, hidden_dim, kernel_size=self.wave_patch, stride=self.wave_patch) + self.quantize = Quantize(hidden_dim, n_embed) + self.quantize_output = nn.Conv1d(int(data_shape[0] / wave_length), data_shape[0], kernel_size=1) + self.dec = Decoder(data_shape[1], hidden_dim, block_num, data_shape) + self.n_embed = n_embed + self.hidden_dim = hidden_dim + + def get_name(self): + return 'vqvae' + + def forward(self, input): + enc = self.enc(input) + enc = enc.unsqueeze(1) + quant = self.quantize_input(enc).squeeze(-1).transpose(1, 2) + quant, diff, id = self.quantize(quant) 
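+        # At this point:
+        #   quant: (B, L/wave_length, hidden_dim) embeddings of the selected codebook entries,
+        #   diff:  scalar latent loss returned by Quantize.forward (used as latent_loss in loss.py),
+        #   id:    (B, L/wave_length) integer code indices, i.e. the discrete time-series tokens.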
+ quant = self.quantize_output(quant) # 2*100*64 -> 2*5000*64 + dec = self.dec(quant) + # codes above are need, i.e. return dec_t, diff_t here will form our forward function + + return dec, diff, id + + def get_embedding(self, id): + return self.quantize.embed_code(id) + + def decode_ids(self, id): + quant = self.get_embedding(id) + quant = self.quantize_output(quant) # 2*100*64 -> 2*5000*64 + dec = self.dec(quant) + + return dec + +if __name__ == '__main__': + model = VQVAE() + a = torch.randn(2, 5000, 8) + tmp = model(a) + print(1) diff --git a/TStokenizer/process.py b/TStokenizer/process.py new file mode 100644 index 0000000..9bb94ed --- /dev/null +++ b/TStokenizer/process.py @@ -0,0 +1,100 @@ +import time +import torch +from tqdm import tqdm +from loss import MSE +from torch.optim.lr_scheduler import LambdaLR + +class Trainer(): + def __init__(self, args, model, train_loader, test_loader, verbose=False): + self.args = args + self.verbose = verbose + self.device = args.device + self.print_process(self.device) + self.model = model.to(torch.device(self.device)) + + self.train_loader = train_loader + self.test_loader = test_loader + self.lr_decay = args.lr_decay_rate + self.lr_decay_steps = args.lr_decay_steps + self.weight_decay = args.weight_decay + self.model_name = self.model.get_name() + self.print_process(self.model_name) + + self.cr = MSE(self.model) + + self.num_epoch = args.num_epoch + self.eval_per_steps = args.eval_per_steps + self.save_path = args.save_path + if self.num_epoch: + self.result_file = open(self.save_path + '/result.txt', 'w') + self.result_file.close() + + self.step = 0 + self.best_metric = -1e9 + self.metric = 'mse' + + def train(self): + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.weight_decay) + self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambda step: self.lr_decay ** step, verbose=self.verbose) + for epoch in range(self.num_epoch): + loss_epoch, time_cost = self._train_one_epoch() + self.result_file = open(self.save_path + '/result.txt', 'a+') + self.print_process( + 'Basic Model train epoch:{0},loss:{1},training_time:{2}'.format(epoch + 1, loss_epoch, time_cost)) + print('Basic Model train epoch:{0},loss:{1},training_time:{2}'.format(epoch + 1, loss_epoch, time_cost), + file=self.result_file) + self.result_file.close() + self.print_process(self.best_metric) + return self.best_metric + + def _train_one_epoch(self): + t0 = time.perf_counter() + self.model.train() + tqdm_dataloader = tqdm(self.train_loader) if self.verbose else self.train_loader + + loss_sum = 0 + for idx, batch in enumerate(tqdm_dataloader): + self.optimizer.zero_grad() + loss = self.cr.compute(batch) + loss_sum += loss.item() + + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5) + self.optimizer.step() + + self.step += 1 + if self.step % self.lr_decay_steps == 0: + self.scheduler.step() + if self.step % self.eval_per_steps == 0: + metric = self.eval_model_vqvae() + self.print_process(metric) + self.result_file = open(self.save_path + '/result.txt', 'a+') + print('step{0}'.format(self.step), file=self.result_file) + print(metric, file=self.result_file) + self.result_file.close() + if metric[self.metric] >= self.best_metric: + self.model.eval() + torch.save(self.model.state_dict(), self.save_path + '/model.pkl') + self.result_file = open(self.save_path + '/result.txt', 'a+') + print('saving model of step{0}'.format(self.step), file=self.result_file) + self.result_file.close() + self.best_metric = 
metric[self.metric] + self.model.train() + + return loss_sum / idx, time.perf_counter() - t0 + + def eval_model_vqvae(self): + self.model.eval() + tqdm_data_loader = tqdm(self.test_loader) if self.verbose else self.test_loader + metrics = {'mse': 0} + + with torch.no_grad(): + for idx, batch in enumerate(tqdm_data_loader): + mse = self.cr.compute(batch) + metrics['mse'] -= mse + metrics['mse'] /= idx + return metrics + + def print_process(self, *x): + if self.verbose: + print(*x) diff --git a/args.py b/args.py new file mode 100644 index 0000000..30a3831 --- /dev/null +++ b/args.py @@ -0,0 +1,24 @@ +import argparse + +def get_hyperparams(): + parser = argparse.ArgumentParser(description="Input hyperparams.") + + parser.add_argument("--seed", type=int, default=2024) + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--model_path", type=str, default="./gptmodel", help="the path to save model.") + parser.add_argument("--load_model_path", type=str, default="./gptmodel", help="the path to load pretrained model.") + parser.add_argument('--dataset', type=str, default='mix', choices=['har', 'geo', 'sleep', 'mix', 'esr', 'ad', 'dev', 'whale']) + + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--per_max_token", type=int, default=32, help="The maximum number of tokens for a label.") + parser.add_argument("--encoder_max_length", type=int, default=230, help="Maximum length of language model input.") + parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") + parser.add_argument("--warm_up_ratio", type=float, default=0.05, help="Warm up step for schduler.") + parser.add_argument("--epochs", type=int, default=15, help="Training epochs.") + parser.add_argument("--adapt", type=bool, default=False, help="If finetune on pretrained model") + + parser.add_argument("--num_beams", type=int, default=1, help="Number of generation beams.") + parser.add_argument("--num_return_sequences", type=int, default=1) + + args = parser.parse_args() + return args \ No newline at end of file diff --git a/challeng_score.py b/challeng_score.py new file mode 100644 index 0000000..9edd90e --- /dev/null +++ b/challeng_score.py @@ -0,0 +1,143 @@ +import numpy as np + +def is_number(x): + try: + float(x) + return True + except (ValueError, TypeError): + return False + +def is_finite_number(x): + if is_number(x): + return np.isfinite(float(x)) + else: + return False + +# Load a table with row and column names. +def load_table(table_file): + # The table should have the following form: + # + # , a, b, c + # a, 1.2, 2.3, 3.4 + # b, 4.5, 5.6, 6.7 + # c, 7.8, 8.9, 9.0 + # + table = list() + with open(table_file, 'r') as f: + for i, l in enumerate(f): + arrs = [arr.strip() for arr in l.split(',')] + table.append(arrs) + + # Define the numbers of rows and columns and check for errors. + num_rows = len(table)-1 + if num_rows<1: + raise Exception('The table {} is empty.'.format(table_file)) + row_lengths = set(len(table[i])-1 for i in range(num_rows)) + if len(row_lengths)!=1: + raise Exception('The table {} has rows with different lengths.'.format(table_file)) + num_cols = min(row_lengths) + if num_cols<1: + raise Exception('The table {} is empty.'.format(table_file)) + + # Find the row and column labels. + rows = [table[0][j+1] for j in range(num_rows)] + cols = [table[i+1][0] for i in range(num_cols)] + + # Find the entries of the table. 
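+    # Entries that are missing or non-finite are stored as NaN; they are effectively
+    # ignored later because compute_challenge_metric accumulates scores with np.nansum.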
+ values = np.zeros((num_rows, num_cols), dtype=np.float64) + for i in range(num_rows): + for j in range(num_cols): + value = table[i+1][j+1] + if is_finite_number(value): + values[i, j] = float(value) + else: + values[i, j] = float('nan') + + return rows, cols, values + +# Load weights. +def load_weights(weight_file): + # Load the table with the weight matrix. + rows, cols, values = load_table(weight_file) + + # Split the equivalent classes. + rows = [set(row.split('|')) for row in rows] + cols = [set(col.split('|')) for col in cols] + assert(rows == cols) + + # Identify the classes and the weight matrix. + classes = rows + weights = values + + return classes, weights + +# Compute a modified confusion matrix for multi-class, multi-label tasks. +def compute_modified_confusion_matrix(labels, outputs): + # Compute a binary multi-class, multi-label confusion matrix, where the rows + # are the labels and the columns are the outputs. + num_recordings, num_classes = np.shape(labels) + A = np.zeros((num_classes, num_classes)) + + # Iterate over all of the recordings. + for i in range(num_recordings): + # Calculate the number of positive labels and/or outputs. + normalization = float(max(np.sum(np.any((labels[i, :], outputs[i, :]), axis=0)), 1)) + # Iterate over all of the classes. + for j in range(num_classes): + # Assign full and/or partial credit for each positive class. + if labels[i, j]: + for k in range(num_classes): + if outputs[i, k]: + A[j, k] += 1.0/normalization + + return A + +# Compute the evaluation metric for the Challenge. +def compute_challenge_metric(weights, labels, outputs, classes, sinus_rhythm): + num_recordings, num_classes = np.shape(labels) + if sinus_rhythm in classes: + sinus_rhythm_index = classes.index(sinus_rhythm) + else: + raise ValueError('The sinus rhythm class is not available.') + + # Compute the observed score. + A = compute_modified_confusion_matrix(labels, outputs) + observed_score = np.nansum(weights * A) + + # Compute the score for the model that always chooses the correct label(s). + correct_outputs = labels + A = compute_modified_confusion_matrix(labels, correct_outputs) + correct_score = np.nansum(weights * A) + + # Compute the score for the model that always chooses the sinus rhythm class. + inactive_outputs = np.zeros((num_recordings, num_classes), dtype=np.bool_) + inactive_outputs[:, sinus_rhythm_index] = 1 + A = compute_modified_confusion_matrix(labels, inactive_outputs) + inactive_score = np.nansum(weights * A) + + if correct_score != inactive_score: + normalized_score = float(observed_score - inactive_score) / float(correct_score - inactive_score) + else: + normalized_score = 0.0 + + return normalized_score + +def evaluate_model(labels, binary_outputs): + # Identify the weights and the SNOMED CT code for the sinus rhythm class. + weights_file = 'weights.csv' + sinus_rhythm = set(['426783006']) + + # Load the scored classes and the weights for the Challenge metric. + print('Loading weights...') + classes, weights = load_weights(weights_file) + + # Evaluate the model by comparing the labels and outputs. + print('Evaluating model...') + + print('- Challenge metric...') + challenge_metric = compute_challenge_metric(weights, labels, binary_outputs, classes, sinus_rhythm) + + print('Done.') + + # Return the results. 
+ return classes, challenge_metric \ No newline at end of file diff --git a/metrics.py b/metrics.py new file mode 100644 index 0000000..b519123 --- /dev/null +++ b/metrics.py @@ -0,0 +1,267 @@ +import numpy as np +import pandas as pd +from challeng_score import evaluate_model +from sklearn.metrics import f1_score, hamming_loss + +def get_dict(Path): + mapping_file = Path + mapping_data = pd.read_csv(mapping_file) + + annotation_to_condition = {} + for index, row in mapping_data.iterrows(): + annotation_to_condition[row['Full Name']] = index + + return annotation_to_condition + +def encode_labels(label_dict, label_str, delimiter=','): + labels = label_str.split(delimiter) + encoded = [0] * len(label_dict) + for label in labels: + label = label.strip() + if label not in label_dict: + continue + encoded[label_dict[label]] = 1 + + return encoded + +def metric_ecg(preds, labels, logger, delimiter=','): + diction = get_dict(Path='essy.csv') + + print(preds[0]) + print(labels[0]) + + encoded_preds = np.array([encode_labels(diction, p, delimiter) for p in preds]) + encoded_labels = np.array([encode_labels(diction, l, delimiter) for l in labels]) + + zero_preds = [] + zero_labels = [] + count = 0 + for i, encoded_pred in enumerate(encoded_preds): + if np.all(encoded_pred == 0): + zero_preds.append(preds[i]) + zero_labels.append(labels[i]) + count += 1 + + print(count / len(preds)) + + print(encoded_preds[0]) + print(encoded_labels[0]) + + hit1 = np.mean(np.all(encoded_preds == encoded_labels, axis=1)) + total_f1 = f1_score(encoded_labels, encoded_preds, average='samples', zero_division=0) + hloss = hamming_loss(encoded_labels, encoded_preds) + _, score = evaluate_model(encoded_labels, encoded_preds) + + logger.info( + "Evaluation result:\naccuracy: {}\nTotal F1: {}\nHmloss: {}\nScore: {}\n".format( + hit1, + total_f1, + hloss, + score + ) + ) + + print( + "accuracy: {}\nTotal F1: {}\nHmloss: {}\nScore: {}\n".format( + hit1, + total_f1, + hloss, + score + ) + ) + return hit1, zero_preds, zero_labels + +def metric_eeg(preds_eeg, labels_eeg, logger): + sleep_stages = { + 'waking up': 0, + 'rapid eye movement sleep': 1, + 'sleep stage one': 2, + 'sleep stage two': 3, + 'sleep stage three': 4, + 'sleep stage four': 5, + 'period of movement': 6, + 'unidentified stage': 7 + } + + print(preds_eeg[0]) + print(labels_eeg[0]) + + preds_mapped = np.array([sleep_stages.get(stage, -1) for stage in preds_eeg]) + labels_mapped = np.array([sleep_stages.get(stage, -1) for stage in labels_eeg]) + + zero_preds = [] + zero_labels = [] + count = 0 + for i, encoded_pred in enumerate(preds_mapped): + if encoded_pred == -1: + zero_preds.append(preds_eeg[i]) + zero_labels.append(labels_eeg[i]) + count += 1 + + print(count / len(preds_eeg)) + + print(preds_mapped[0]) + print(labels_mapped[0]) + + hit2 = np.mean(preds_mapped == labels_mapped) + sleep_f1 = f1_score(labels_mapped, preds_mapped, average='macro', zero_division=0) + + logger.info( + "Sleep Evaluation result:\naccuracy: {}\nTotal F1 sleep: {}\n".format( + hit2, + sleep_f1 + ) + ) + + print( + "accuracy: {}\nTotal F1 sleep: {}\n".format( + hit2, + sleep_f1 + ) + ) + + return hit2, zero_preds, zero_labels + +def metric_har(preds, labels, logger): + sleep_stages = { + 'walking': 0, + 'ascending stairs': 1, + 'descending stairs': 2, + 'sitting': 3, + 'standing': 4, + 'lying down': 5 + } + + print(preds[0]) + print(labels[0]) + + preds_mapped = np.array([sleep_stages.get(stage, -1) for stage in preds]) + labels_mapped = np.array([sleep_stages.get(stage, -1) for stage in 
labels]) + + zero_preds = [] + zero_labels = [] + count = 0 + for i, encoded_pred in enumerate(preds_mapped): + if encoded_pred == -1: + zero_preds.append(preds[i]) + zero_labels.append(labels[i]) + count += 1 + + print(count / len(preds)) + + print(preds_mapped[0]) + print(labels_mapped[0]) + + hit2 = np.mean(preds_mapped == labels_mapped) + sleep_f1 = f1_score(labels_mapped, preds_mapped, average='macro', zero_division=0) + + logger.info( + "HAR Evaluation result:\naccuracy: {}\nTotal F1 HAR: {}\n".format( + hit2, + sleep_f1 + ) + ) + + print( + "accuracy: {}\nTotal F1 HAR: {}\n".format( + hit2, + sleep_f1 + ) + ) + + return hit2, zero_preds, zero_labels + +def metric_fd(preds, labels, logger): + sleep_stages = { + 'not damaged': 0, + 'inner damaged': 1, + 'outer damaged': 2, + } + + print(preds[0]) + print(labels[0]) + + preds_mapped = np.array([sleep_stages.get(stage, -1) for stage in preds]) + labels_mapped = np.array([sleep_stages.get(stage, -1) for stage in labels]) + + zero_preds = [] + zero_labels = [] + count = 0 + for i, encoded_pred in enumerate(preds_mapped): + if encoded_pred == -1: + zero_preds.append(preds[i]) + zero_labels.append(labels[i]) + count += 1 + + print(count / len(preds)) + + print(preds_mapped[0]) + print(labels_mapped[0]) + + hit2 = np.mean(preds_mapped == labels_mapped) + sleep_f1 = f1_score(labels_mapped, preds_mapped, average='macro', zero_division=0) + + logger.info( + "FD Evaluation result:\naccuracy: {}\nTotal F1 FD: {}\n".format( + hit2, + sleep_f1 + ) + ) + + print( + "accuracy: {}\nTotal F1 FD: {}\n".format( + hit2, + sleep_f1 + ) + ) + + return hit2, zero_preds, zero_labels + +def metric_rwc(preds, labels, logger): + sleep_stages = { + 'the right whale': 0, + 'unknown creature': 1, + } + + print(preds[0]) + print(labels[0]) + + preds_mapped = np.array([sleep_stages.get(stage, -1) for stage in preds]) + labels_mapped = np.array([sleep_stages.get(stage, -1) for stage in labels]) + + zero_preds = [] + zero_labels = [] + count = 0 + for i, encoded_pred in enumerate(preds_mapped): + if encoded_pred == -1: + zero_preds.append(preds[i]) + zero_labels.append(labels[i]) + count += 1 + + print(count / len(preds)) + + print(preds_mapped[0]) + print(labels_mapped[0]) + + valid_indices = preds_mapped != -1 + valid_preds = preds_mapped[valid_indices] + valid_labels = labels_mapped[valid_indices] + + hit2 = np.mean(preds_mapped == labels_mapped) + sleep_f1 = f1_score(valid_labels, valid_preds, average='macro', zero_division=0) + + logger.info( + "RWC Evaluation result:\naccuracy: {}\nTotal F1 RWC: {}\n".format( + hit2, + sleep_f1 + ) + ) + + print( + "accuracy: {}\nTotal F1 RWC: {}\n".format( + hit2, + sleep_f1 + ) + ) + + return hit2, zero_preds, zero_labels \ No newline at end of file diff --git a/multidataset.py b/multidataset.py new file mode 100644 index 0000000..04f1dc4 --- /dev/null +++ b/multidataset.py @@ -0,0 +1,132 @@ +import torch +from torch.utils.data import Dataset + +class MultiDataset(Dataset): + r""" + A Dataset Class for building Dataloader of ECG or other datasets. 
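+    Each element of ``samples`` is expected to be a (text, signal, extra) tuple, such as
+    those produced by preprocess.py; only the text and signal are used here, and
+    ``template`` stitches the tokenized signal to the text prompt.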
+ """ + + def __init__( + self, + samples, + tokenizer, + mode: str, + multi: str, + encoder_max_length=256, + prefix_text="", + ) -> None: + assert mode in ["train", "test"] + super().__init__() + self.samples = samples + self.tokenizer = tokenizer + self.mode = mode + self.max_length = encoder_max_length + self.multi = multi + self.prefix_tokens = self.tokenizer.encode(prefix_text) if prefix_text else [] + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + text, ecg, _ = self.samples[idx] + + dx_index = text.find("information.\n") + if dx_index != -1: + label = text[dx_index + 13:] + text = text[:dx_index + 13] + else: + label = '' + label_ids = self.tokenizer.encode(label) + + if self.mode == "train": + text = text + label + else: + text = text + + input_ids = self.template(ecg * 2.5, text) + label_ids = [-100] * (len(input_ids) - len(label_ids)) + label_ids + + attn_masks = [1] * len(input_ids) + input_ids, attn_masks = self.padding(input_ids, attn_masks) + label_ids, _ = self.padding(label_ids, attn_masks) + + if self.mode == "train": + return { + "input_ids": torch.LongTensor(input_ids), + "attn_masks": torch.FloatTensor(attn_masks), + "label_ids": torch.LongTensor(label_ids), + } + + elif self.mode == "test": + return { + "input_ids": torch.LongTensor(input_ids), + "attn_masks": torch.FloatTensor(attn_masks), + "label": label, + } + + def template(self, ecg, text): + r""" + The contents of the items are stitched together according to a template to construct the input. + """ + input_ids = self.prefix_tokens.copy() + if self.multi == 'mix': + if ecg.shape == (5000, 12): + bet_ids = self.tokenizer.encode('Electrocardiogram signals: ') + ecg_ids = self.tokenizer.encode(torch.Tensor(ecg).unsqueeze(0), model_id=0) + elif ecg.shape == (3000, 2): + bet_ids = self.tokenizer.encode('Electroencephalogram signals: ') + ecg_ids = self.tokenizer.encode(torch.Tensor(ecg).unsqueeze(0), model_id=1) + elif ecg.shape == (5120, 1): + bet_ids = self.tokenizer.encode('Industrial equipment signals: ') + ecg_ids = self.tokenizer.encode(torch.Tensor(ecg).unsqueeze(0), model_id=2) + elif ecg.shape == (128, 9): + bet_ids = self.tokenizer.encode('Human physical activities signals: ') + ecg_ids = self.tokenizer.encode(torch.Tensor(ecg).unsqueeze(0), model_id=3) + elif ecg.shape == (4000, 1): + bet_ids = self.tokenizer.encode('Whale sound signals: ') + ecg_ids = self.tokenizer.encode(torch.Tensor(ecg).unsqueeze(0), model_id=4) + elif ecg.shape == (93, 13): + bet_ids = self.tokenizer.encode('Electroencephalogram signals: ') + ecg_ids = self.tokenizer.encode(torch.Tensor(ecg).unsqueeze(0), model_id=5) + else: + if self.multi == 'geo': + bet_ids = self.tokenizer.encode('Electrocardiogram signals: ') + elif self.multi == 'sleep': + bet_ids = self.tokenizer.encode('Electroencephalogram signals: ') + elif self.multi == 'esr': + bet_ids = self.tokenizer.encode('Electroencephalogram signals: ') + else: + bet_ids = self.tokenizer.encode('Human physical activities signals: ') + ecg_ids = self.tokenizer.encode(torch.Tensor(ecg).unsqueeze(0)) + text_ids = self.tokenizer.encode(' \n' + text) + + ecg_ids = ecg_ids.tolist() + ecg_ids = ecg_ids[0] + + input_ids.extend(bet_ids + ecg_ids + text_ids) + + if len(input_ids) > self.max_length: + input_ids = input_ids[0 : self.max_length] + + return input_ids + + def padding(self, input_ids: list, attn_masks: list): + r""" + Padding the inputs for GPT model. + + For training, we pad the right side, + For testing, we pad the left side. 
+ """ + assert len(input_ids) <= self.max_length + + if self.mode == "train": + input_ids = input_ids + [self.tokenizer.pad_token_id] * ( + self.max_length - len(input_ids) + ) + attn_masks = attn_masks + [0] * (self.max_length - len(attn_masks)) + elif self.mode == "dev" or self.mode == "test": + input_ids = [self.tokenizer.pad_token_id] * ( + self.max_length - len(input_ids) + ) + input_ids + attn_masks = [0] * (self.max_length - len(attn_masks)) + attn_masks + return input_ids, attn_masks \ No newline at end of file diff --git a/multimodel.py b/multimodel.py new file mode 100644 index 0000000..bf9b13f --- /dev/null +++ b/multimodel.py @@ -0,0 +1,132 @@ +import torch +import copy +import torch.nn as nn +import torch.nn.functional as F +from transformers import GPT2LMHeadModel, GPT2Tokenizer + +local_model_path = "./gpt2-model" +local_tokenizer_path = "./gpt2-tokenizer" + +class MLP(nn.Module): + def __init__(self, input_dim, hidden_dims, output_dim): + super(MLP, self).__init__() + + all_dims = [input_dim] + hidden_dims + [output_dim] + + self.linear_layers = nn.ModuleList() + for i in range(len(all_dims) - 1): + self.linear_layers.append(nn.Linear(all_dims[i], all_dims[i + 1])) + + def forward(self, x): + for i, layer in enumerate(self.linear_layers): + x = layer(x) + if i < len(self.linear_layers) - 1: + x = F.gelu(x) + return x + +class InstructTime(GPT2LMHeadModel): + def __init__(self, config, ecgTokenizers, text_embedding=50258): + super().__init__(config) + self.ecgTokenizers = ecgTokenizers + + embed_vector = torch.empty(0, self.ecgTokenizers[0].hidden_dim) + for tokenizer in self.ecgTokenizers: + tokenizer_embed_vector = copy.deepcopy(tokenizer.quantize.embed).transpose(-1, 0) + embed_vector = torch.cat([embed_vector, tokenizer_embed_vector], dim=0) + self.embed_layer = nn.Embedding.from_pretrained(embed_vector) + + self.text_embedding = text_embedding + self.embed = config.n_embd + self.config.pad_token_id = self.config.eos_token_id if self.config.pad_token_id is None else self.config.pad_token_id + + self.projection_layers = nn.ModuleList() + for _ in ecgTokenizers: + mlp = MLP(self.ecgTokenizers[0].hidden_dim, [64, 128, 256, 512], self.embed) + mlp.apply(self.init_weights_kaiming) + self.projection_layers.append(mlp) + + self.offsets = [self.text_embedding] + for tokenizer in self.ecgTokenizers: + self.offsets.append(self.offsets[-1] + tokenizer.n_embed) + + @staticmethod + def init_weights_kaiming(m): + if type(m) == nn.Linear: + nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu') + m.bias.data.fill_(0.01) + + def forward(self, *args, **kwargs): + input_ids = kwargs["input_ids"] + + text_mask = torch.lt(input_ids, self.text_embedding) + ecg_mask = ~text_mask + + text_ids = input_ids.clone() + text_ids[ecg_mask] = self.config.pad_token_id + + text_embeddings = self.transformer.wte(text_ids) + text_embeddings.mul_(text_mask.float().unsqueeze(-1)) + + masked_ids = input_ids.clone() + masked_ids[text_mask] = 0 + masked_ids[ecg_mask] -= self.text_embedding + + ecg_embeddings = torch.zeros_like(text_embeddings) + for i, _ in enumerate(self.ecgTokenizers): + tokenizer_mask = (input_ids >= self.offsets[i]) & (input_ids < self.offsets[i + 1]) + tokenizer_ids = input_ids.clone() + tokenizer_ids[~tokenizer_mask] = 0 + tokenizer_ids[tokenizer_mask] -= self.offsets[i] + + tokenizer_embeddings = self.embed_layer(tokenizer_ids) + tokenizer_embeddings = self.projection_layers[i](tokenizer_embeddings) + tokenizer_embeddings.mul_(tokenizer_mask.float().unsqueeze(-1)) + 
ecg_embeddings.add_(tokenizer_embeddings) + + kwargs["input_ids"] = None + kwargs["inputs_embeds"] = ecg_embeddings + text_embeddings + + outputs = super().forward(*args, **kwargs) + return outputs + +class MultiTokenizer: + def __init__(self, ecgTokenizers) -> None: + self.textTokenizer = GPT2Tokenizer.from_pretrained(local_tokenizer_path) + new_special_tokens = ["", ""] + self.textTokenizer.add_special_tokens({"additional_special_tokens": new_special_tokens}) + self.text_vocab_size = len(self.textTokenizer) + + self.ecgTokenizers = ecgTokenizers + + self.pad_token_id = self.textTokenizer.eos_token_id + self.eos_token_id = self.textTokenizer.eos_token_id + + self.offsets = self._calculate_offsets() + + def _calculate_offsets(self): + offsets = [] + current_offset = self.text_vocab_size + for tokenizer in self.ecgTokenizers: + offsets.append(current_offset) + current_offset += tokenizer.n_embed + return offsets + + def vocabSize_all(self): + return self.text_vocab_size + sum(tokenizer.n_embed for tokenizer in self.ecgTokenizers) + + def encode(self, input, model_id=1): + if isinstance(input, str): + return self.textTokenizer(input)["input_ids"] + elif isinstance(input, torch.Tensor): + input = input.to('cpu') + if model_id < len(self.ecgTokenizers): + tokenizer_index = model_id + _, _, indices = self.ecgTokenizers[tokenizer_index](input) + return indices + self.offsets[tokenizer_index] + else: + raise ValueError(f"Invalid model_id. Please provide a number between 0 and {len(self.ecgTokenizers)}.") + else: + raise ValueError("Unsupported input type. Please provide either a string or a torch.Tensor.") + + def decode(self, input, skip_special_tokens=True): + return self.textTokenizer.decode(input, skip_special_tokens=skip_special_tokens) \ No newline at end of file diff --git a/preprocess.py b/preprocess.py new file mode 100644 index 0000000..d5d1400 --- /dev/null +++ b/preprocess.py @@ -0,0 +1,1831 @@ +import os, re +import ast,math, sys +import pywt, wfdb +import numpy as np +import matplotlib.pyplot as plt +import scipy.signal as signal +import pandas as pd +import torch +import pickle +import pandas as pd +import random +from collections import Counter +import neurokit2 as nk + +def setup_seed(seed=2023): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + +def normalize(ecg): + min_val = np.min(ecg) + max_val = np.max(ecg) + + if (max_val - min_val) == 0: + return ecg + + # 进行最大-最小归一化 + normalized_ecg = (ecg - min_val) / (max_val - min_val) + + return (normalized_ecg - 0.5) * 2 + +def denoising(data): + # 初始化一个空矩阵来存储处理后的数据 + ecg_cleaned = np.zeros_like(data) + + # 循环处理每个通道 + for i in range(data.shape[1]): + channel_data = data[:, i] + ecg_cleaned[:, i] = nk.ecg_clean(channel_data, sampling_rate=500) + + return ecg_cleaned + +def padding_varying_length(data): + for i in range(data.shape[0]): + for j in range(data.shape[1]): + data[i, j, :][np.isnan(data[i, j, :])] = 0 + return data + +def pro_ecg(ecg): + # filtered_data = denoising(ecg) + normalize_data = normalize(ecg) + + return normalize_data + +def trans_code(label_list): + code_mapping = { + 733534002: 164909002, + 713427006: 59118001, + 284470004: 63593006, + 427172004: 17338001 + } + + # 使用列表推导式提高效率 + # 如果编码不在映射中,保持不变 + new_labelist = [code_mapping.get(code, code) for code in label_list] + return new_labelist + +def get_dict_ptb(Path='./data/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1/scp_statements.csv', few=False): + mapping_file = Path + mapping_data = 
pd.read_csv(mapping_file) + + annotation_to_condition = {} + for index, row in mapping_data.iterrows(): + if few: + annotation_to_condition[row['diagnostic_class']] = index + else: + annotation_to_condition[row['description']] = index + + return annotation_to_condition + +# def pro_text(comments, diction, only_label): +# ann = comments[2] +# snomed_ct_matches = re.findall(r'\d+', ann) + +# prefix = "Please provide a description of potential health problems, symptoms that the patient might be facing based on the provided information.\n" +# new_snomed_ct_codes = '' + +# label_text = '[' +# for full_name in diction.values(): +# label_text += full_name + ',' +# if label_text.endswith(','): +# label_text = label_text[:-1] +# label_text += ']\nPlease select some symptoms from the table above for description.\n' + +# for snomed_ct_code in snomed_ct_matches: +# label = diction.get(int(snomed_ct_code)) +# if label: +# new_snomed_ct_codes += label + ',' +# else: +# continue + +# if new_snomed_ct_codes.endswith(','): +# new_snomed_ct_codes = new_snomed_ct_codes[:-1] + +# if new_snomed_ct_codes == '': +# return '' + +# suffix = "This patient's symptoms include " + +# comments[2] = prefix + label_text + suffix + new_snomed_ct_codes + +# colon_index = comments[0].find(":") +# if colon_index != -1: +# comments[0] = "Age:" + comments[0][colon_index + 1:] +# else: +# return '' + +# colon_index = comments[1].find(":") +# if colon_index != -1: +# comments[1] = "Gender:" + comments[1][colon_index + 1:] +# else: +# return '' + +# ann = comments[0] +# age_matches = re.findall(r'\d+', ann) + +# for age_code in age_matches: +# age_code_int = int(age_code) +# if age_code_int <= 6: +# ann = ann.replace(age_code, 'adolescence') +# elif age_code_int <= 17: +# ann = ann.replace(age_code, 'juvenile') +# elif age_code_int <= 40: +# ann = ann.replace(age_code, 'youths') +# elif age_code_int <= 65: +# ann = ann.replace(age_code, 'middle-age') +# else: +# ann = ann.replace(age_code, 'the elderly') +# comments[0] = ann + +# if only_label: +# text = comments[2] +# else: +# text = comments[0] + '\n' + comments[1] + '\n' + comments[2] + +# return text + +def remove_before_colon(input_string): + index = input_string.find(": ") + if index != -1: + return input_string[index+2:] + else: + return input_string + +def get_dictionaries(path='C:/Users/ROG/Desktop/ConditionNames_SNOMED-CT.csv'): + try: + # 仅加载需要的列以减少内存占用 + mapping_data = pd.read_csv(path, usecols=['Snomed_CT', 'Full Name']) + + # 使用 pandas 的 to_dict 方法进行快速转换 + dict_snomed_to_name = mapping_data.set_index('Snomed_CT')['Full Name'].to_dict() + dict_snomed_to_index = dict_snomed_to_index = mapping_data.reset_index().set_index('Snomed_CT')['index'].to_dict() + + except FileNotFoundError: + raise FileNotFoundError(f"Unable to find the file at the specified path: {path}") + except Exception as e: + raise Exception(f"An error occurred while processing the file: {e}") + + return dict_snomed_to_name, dict_snomed_to_index + +def process_12_lead_shot(data_folder='./data/12-lead/WFDBRecords/', only_label=False): + diction1, diction2 = get_dictionaries('./data/12-lead.csv') + label_frequency = {} # Step 1: Create a label frequency dictionary + num_classes = len(diction1) + + # Step 2: First pass to fill the frequency dictionary + for file_out in os.scandir(data_folder): + if file_out.is_dir(): + path_in = file_out.path + for file_in in os.scandir(path_in): + if file_in.is_dir(): + data_path = file_in.path + for entry in os.scandir(data_path): + if entry.is_file() and 
entry.name.endswith('.hea'): + header = wfdb.rdheader(entry.path[:-4]) + label = header.comments[2] + label = remove_before_colon(label) + label_frequency[label] = label_frequency.get(label, 0) + 1 + + samples = [] + samples_fewshot = [] + samples_oneshot = [] + samples_zeroshot = [] + + for file_out in os.scandir(data_folder): + if file_out.is_dir(): + path_in = file_out.path + for file_in in os.scandir(path_in): + if file_in.is_dir(): + data_path = file_in.path + for entry in os.scandir(data_path): + if entry.is_file() and entry.name.endswith('.hea'): + label_file = entry.path + + record_name = label_file[:label_file.rfind('.')] + signals, _ = wfdb.rdsamp(record_name) + header = wfdb.rdheader(record_name) + + if np.isnan(signals).any(): + continue # Skip NaN values + + label = header.comments[2] + label = remove_before_colon(label) + label_list = label.split(',') + + label_indices = [diction2[int(snomed_ct_code)] for snomed_ct_code in label_list if int(snomed_ct_code) in diction2] + label_vector = [1 if i in label_indices else 0 for i in range(num_classes)] + + if not any(label_vector): + continue + + text = pro_text(header.comments, diction1, only_label) + ecg = pro_ecg(signals) + + if label_frequency[label] > 20: + samples.append((text, ecg, label_vector)) + elif label_frequency[label] <= 20 and label_frequency[label] > 10: + samples_fewshot.append((text, ecg, label_vector)) + elif label_frequency[label] <= 10 and label_frequency[label] > 1: + samples_oneshot.append((text, ecg, label_vector)) + elif label_frequency[label] == 1: + samples_zeroshot.append((text, ecg, label_vector)) + + label_to_samples_map_fewshot = {} + for sample in samples_fewshot: + text, ecg, label_vector = sample + label_index = label_vector.index(1) + if label_index not in label_to_samples_map_fewshot: + label_to_samples_map_fewshot[label_index] = [] + label_to_samples_map_fewshot[label_index].append(sample) + + test_fewshot = [] + for label_index, samples_list in label_to_samples_map_fewshot.items(): + # 从每个类别中选择七个样本并保留它们作为测试样本 + selected_samples = samples_list[-7:] + test_fewshot.extend(selected_samples) + + train_fewshot = [sample for samples_list in label_to_samples_map_fewshot.values() for sample in samples_list[:-7]] + np.random.shuffle(train_fewshot) + + label_to_samples_map_oneshot = {} + for sample in samples_oneshot: + text, ecg, label_vector = sample + label_index = label_vector.index(1) + if label_index not in label_to_samples_map_oneshot: + label_to_samples_map_oneshot[label_index] = [] + label_to_samples_map_oneshot[label_index].append(sample) + + train_oneshot = [] + for label_index, samples_fewshot in label_to_samples_map_oneshot.items(): + # 从每个类别中选择一个样本并保留它作为训练样本 + selected_sample = samples_fewshot.pop() + train_oneshot.append(selected_sample) + + test_oneshot = [sample for samples_list in label_to_samples_map_oneshot.values() for sample in samples_list] + np.random.shuffle(train_oneshot) + + # Shuffle and split the samples as before + index = [i for i in range(len(samples))] + np.random.shuffle(index) + split = int(0.9 * len(samples)) + samples_train = [samples[i] for i in index[:split]] + samples_test = [samples[i] for i in index[split:]] + + samples_train = samples_train + train_oneshot + train_fewshot + + print(len(samples_train + samples_test)) + print(len(test_fewshot)) + print(len(test_oneshot)) + print(len(samples_zeroshot)) + + # Save the filtered samples as before + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as 
file: + pickle.dump(samples_test, file) + with open('test_fewshot.pkl', 'wb') as file: + pickle.dump(test_fewshot, file) + with open('test_oneshot.pkl', 'wb') as file: + pickle.dump(test_oneshot, file) + with open('samples_test_zeroshot.pkl', 'wb') as file: + pickle.dump(samples_zeroshot, file) + + return samples_train, samples_test + +def process_12_lead(data_folder='./data/12-lead/WFDBRecords/', only_label=False): + diction1, diction2 = get_dictionaries('./data/12-lead.csv') + num_classes = len(diction1) + samples = [] + + for file_out in os.scandir(data_folder): + if file_out.is_dir(): + path_in = file_out.path + for file_in in os.scandir(path_in): + if file_in.is_dir(): + data_path = file_in.path + for entry in os.scandir(data_path): + if entry.is_file() and entry.name.endswith('.hea'): + label_file = entry.path + + record_name = label_file[:label_file.rfind('.')] + signals, _ = wfdb.rdsamp(record_name) + header = wfdb.rdheader(record_name) + + if np.isnan(signals).any(): + continue # Skip NaN values + + label = header.comments[2] + label = remove_before_colon(label) + label_list = label.split(',') + + label_indices = [diction2[int(snomed_ct_code)] for snomed_ct_code in label_list if int(snomed_ct_code) in diction2] + + label_vector = [1 if i in label_indices else 0 for i in range(num_classes)] + if not any(label_vector): + continue + + text = pro_text(header.comments, diction1, only_label) + ecg = pro_ecg(signals) + + if label_indices: + samples.append((text, ecg, label_vector)) + + # Shuffle and split the samples as before + index = [i for i in range(len(samples))] + np.random.shuffle(index) + split = int(0.9 * len(samples)) + samples_train = [samples[i] for i in index[:split]] + samples_test = [samples[i] for i in index[split:]] + + print(len(samples_train + samples_test)) + + # Save the filtered samples as before + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +def pro_pd(label, label_dict, few=False): + prefix = "Please provide a description of potential health problems, symptoms, and conditions that the patient might be facing based on the provided information.\nThis patient's symptoms include " + if few: + if not label['diagnostic_superclass']: + return '', 0 + prefix += label['diagnostic_superclass'][0] + else: + if not label['description']: + return '', 0 + prefix += label['description'][0] + + gender = 'Gender: ' + if label['sex'] == 1: + gender += 'male\n' + else: + gender += 'female\n' + + age = 'Age: ' + if label['age'] <= 6: + age += 'adolescence\n' + elif label['age'] <= 17: + age += 'juvenile\n' + elif label['age'] <= 40: + age += 'youths\n' + elif label['age'] <= 65: + age += 'middle-age\n' + else: + age += 'the elderly\n' + + height = 'Height: ' + if label['height'] is not None: + height += str(label['height']) + height += ' cm\n' + else: + height += 'unknown\n' + + weight = 'Weight: ' + if label['weight'] is not None: + weight += str(label['weight']) + weight += ' kg\n' + else: + weight += 'unknown\n' + + text = gender + age + height + weight + prefix + + if few: + if label['diagnostic_superclass'][0] == 'NORM': + vector = 0 + elif label['diagnostic_superclass'][0] == 'STTC': + vector = 1 + elif label['diagnostic_superclass'][0] == 'MI': + vector = 2 + elif label['diagnostic_superclass'][0] == 'CD': + vector = 3 + elif label['diagnostic_superclass'][0] == 'HYP': + vector = 4 + else: + vector = 
label_dict[label['description'][0]] + + return text, vector + +def process_ptbxl(data_folder='./data/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1/', few=True): + def aggregate_description(y_dic): + tmp = [] + for key in y_dic.keys(): + if key in agg_df.index: + tmp.append(agg_df.loc[key].description) + return list(set(tmp)) + def aggregate_diagnostic(y_dic): + tmp = [] + for key in y_dic.keys(): + if key in agg_df.index: + tmp.append(agg_df.loc[key].diagnostic_class) + return list(set(tmp)) + + samples = [] + diction = get_dict_ptb(few=few) + + # Load and convert annotation data + Y = pd.read_csv(data_folder + 'ptbxl_database.csv', index_col='ecg_id') + Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x)) + + # data = [wfdb.rdsamp(data_folder + f) for f in Y.filename_hr] + + # data = np.array([signal for signal, meta in data]).astype(np.float32) + + # Load scp_statements.csv for diagnostic aggregation + agg_df = pd.read_csv(data_folder + 'scp_statements.csv', index_col=0) + agg_df = agg_df[agg_df.diagnostic == 1] + + # Apply diagnostic superclass + Y['description'] = Y.scp_codes.apply(aggregate_description) + Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic) + + for _, item in Y.iterrows(): + signals, _ = wfdb.rdsamp(data_folder + item.filename_hr) + # header = wfdb.rdheader(data_folder + f) + + if np.isnan(signals).any(): + continue # Skip NaN values + + text, vector = pro_pd(item, diction, few=few) + ecg = pro_ecg(signals) + + if ecg.shape == (5000, 12) and text != '': + samples.append((text, ecg, vector)) + + # Shuffle and split the samples as before + index = [i for i in range(len(samples))] + np.random.shuffle(index) + split = int(0.9 * len(samples)) + samples_train = [samples[i] for i in index[:split]] + samples_test = [samples[i] for i in index[split:]] + + # Save the filtered samples as before + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +def process_Georgia(data_folder='./data/georgia/', only_label=False): + if only_label: + diction1, diction2 = get_dictionaries('./data/geo-score.csv') + else: + diction1, diction2 = get_dictionaries('./data/essy.csv') + num_classes = len(diction1) + print(num_classes) + samples = [] + + for file_in in os.scandir(data_folder): + if file_in.is_dir(): + data_path = file_in.path + for entry in os.scandir(data_path): + if entry.is_file() and entry.name.endswith('.hea'): + label_file = entry.path + header = wfdb.rdheader(label_file[:label_file.rfind('.')]) + label = remove_before_colon(header.comments[2]) + + signals, _ = wfdb.rdsamp(label_file[:label_file.rfind('.')]) + if np.isnan(signals).any(): + continue # Skip NaN values + + label_list = label.split(',') + label_list = [int(item) for item in label_list] + # print(label_list) + + # if only_label: + # label_list = trans_code(label_list) + label_indices = [diction2[code] for code in label_list if code in diction2] + # print(diction2) + label_vector = [1 if i in label_indices else 0 for i in range(num_classes)] + # print(label_indices) + if not any(label_vector): + continue + + text = pro_text(header.comments, diction1) + # print(text) + # break + ecg = pro_ecg(signals) + if ecg.shape == (5000, 12): + samples.append((text, ecg, label_vector)) + + np.random.shuffle(samples) + split = int(0.9 * len(samples)) + samples_train = samples[:split] + samples_test = samples[split:] + + with 
open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +def process_ecg(data_folder1='./data/georgia', data_folder2='./data/cpsc_2018', only_label=False): + if only_label: + diction1, diction2 = get_dictionaries('./data/geo-score.csv') + else: + diction1, diction2 = get_dictionaries('./data/essy.csv') + num_classes = len(diction1) + samples = [] + + for file_in in os.scandir(data_folder1): + if file_in.is_dir(): + data_path = file_in.path + for entry in os.scandir(data_path): + if entry.is_file() and entry.name.endswith('.hea'): + label_file = entry.path + header = wfdb.rdheader(label_file[:label_file.rfind('.')]) + label = remove_before_colon(header.comments[2]) + + signals, _ = wfdb.rdsamp(label_file[:label_file.rfind('.')]) + if np.isnan(signals).any(): + continue # Skip NaN values + + label_list = label.split(',') + label_list = [int(item) for item in label_list] + # print(label_list) + + # if only_label: + # label_list = trans_code(label_list) + label_indices = [diction2[code] for code in label_list if code in diction2] + # print(diction2) + label_vector = [1 if i in label_indices else 0 for i in range(num_classes)] + # print(label_indices) + if not any(label_vector): + continue + + text = pro_text(header.comments, diction1) + # print(text) + # break + ecg = pro_ecg(signals) + if ecg.shape == (5000, 12): + samples.append((text, ecg, label_vector)) + + for file_out in os.scandir(data_folder2): + if file_out.is_dir(): + path_in = file_out.path + for entry in os.scandir(path_in): + if entry.is_file() and entry.name.endswith('.hea'): + label_file = entry.path + + record_name = label_file[:label_file.rfind('.')] + signals, _ = wfdb.rdsamp(record_name) + header = wfdb.rdheader(record_name) + + if np.isnan(signals).any(): + continue # Skip NaN values + + label = header.comments[2] + label = remove_before_colon(label) + label_list = label.split(',') + label_list = [int(item) for item in label_list] + label_indices = [diction2[snomed_ct_code] for snomed_ct_code in label_list if snomed_ct_code in diction2] + + label_vector = [1 if i in label_indices else 0 for i in range(num_classes)] + if not any(label_vector): + continue + + text = pro_text(header.comments, diction1) + ecg = pro_ecg(signals) + + if ecg.shape == (5000, 12): + samples.append((text, ecg, label_vector)) + + np.random.shuffle(samples) + split = int(0.9 * len(samples)) + samples_train = samples[:split] + samples_test = samples[split:] + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +def process_ptb_xl_mul(data_folder='./data/ptb-xl-multi/', only_label=False): + if only_label: + diction1, diction2 = get_dictionaries('./data/ptbxl-score.csv') + else: + diction1, diction2 = get_dictionaries('./data/ptbxl.csv') + num_classes = len(diction1) + samples = [] + + for file_out in os.scandir(data_folder): + if file_out.is_dir(): + path_in = file_out.path + for entry in os.scandir(path_in): + if entry.is_file() and entry.name.endswith('.hea'): + record_name = entry.path[:-4] + header = wfdb.rdheader(record_name) + label = remove_before_colon(header.comments[2]) + + signals, _ = wfdb.rdsamp(record_name) + if np.isnan(signals).any(): + continue + + label_list = label.split(',') + label_indices = [diction2[int(code)] for code in label_list if 
int(code) in diction2] + label_vector = [1 if i in label_indices else 0 for i in range(num_classes)] + if not any(label_vector): + continue + + text = pro_text(header.comments, diction1) + ecg = pro_ecg(signals) + if ecg.shape == (5000, 12): + samples.append((text, ecg, label_vector)) + + np.random.shuffle(samples) + split = int(0.9 * len(samples)) + samples_train = samples[:split] + samples_test = samples[split:] + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +def process_cpsc_mul(data_folder='./data/cpsc_2018/', only_label=False): + if only_label: + diction1, diction2 = get_dictionaries('./data/cpsc-score.csv') + else: + diction1, diction2 = get_dictionaries('./data/essy.csv') + num_classes = len(diction1) + samples = [] + + for file_out in os.scandir(data_folder): + if file_out.is_dir(): + path_in = file_out.path + for entry in os.scandir(path_in): + if entry.is_file() and entry.name.endswith('.hea'): + label_file = entry.path + + record_name = label_file[:label_file.rfind('.')] + signals, _ = wfdb.rdsamp(record_name) + header = wfdb.rdheader(record_name) + + if np.isnan(signals).any(): + continue # Skip NaN values + + label = header.comments[2] + label = remove_before_colon(label) + label_list = label.split(',') + label_list = [int(item) for item in label_list] + label_indices = [diction2[snomed_ct_code] for snomed_ct_code in label_list if snomed_ct_code in diction2] + + label_vector = [1 if i in label_indices else 0 for i in range(num_classes)] + if not any(label_vector): + continue + + text = pro_text(header.comments, diction1) + ecg = pro_ecg(signals) + + if ecg.shape == (5000, 12): + samples.append((text, ecg, label_vector)) + + # Shuffle and split the samples as before + index = [i for i in range(len(samples))] + np.random.shuffle(index) + split = int(0.9 * len(samples)) + samples_train = [samples[i] for i in index[:split]] + samples_test = [samples[i] for i in index[split:]] + + # Save the filtered samples as before + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + with open('zeroshot.pkl', 'wb') as file: + pickle.dump(samples, file) + + print(len(samples)) + + return samples_train, samples_test + +def pro_text(comments, diction, only_label=False): + ann = comments[2] + snomed_ct_matches = re.findall(r'\d+', ann) + + prefix = "Describe the potential health issue(s) and associated symptom(s) the patient may be experiencing based on the provided information.\nThe symptom(s) exhibited by this patient include(s) " + new_snomed_ct_codes = '' + + for snomed_ct_code in snomed_ct_matches: + label = diction.get(int(snomed_ct_code)) + if label: + new_snomed_ct_codes += label + ',' + else: + continue + + if new_snomed_ct_codes.endswith(','): + new_snomed_ct_codes = new_snomed_ct_codes[:-1] + + if new_snomed_ct_codes == '': + return '' + + comments[2] = prefix + new_snomed_ct_codes + + colon_index = comments[0].find(":") + if colon_index != -1: + comments[0] = "Age:" + comments[0][colon_index + 1:] + else: + return '' + + colon_index = comments[1].find(":") + if colon_index != -1: + comments[1] = "Gender:" + comments[1][colon_index + 1:] + else: + return '' + + ann = comments[0] + age_matches = re.findall(r'\d+', ann) + + for age_code in age_matches: + age_code_int = int(age_code) + if age_code_int <= 6: + ann = 
ann.replace(age_code, 'adolescence') + elif age_code_int <= 17: + ann = ann.replace(age_code, 'juvenile') + elif age_code_int <= 40: + ann = ann.replace(age_code, 'youths') + elif age_code_int <= 65: + ann = ann.replace(age_code, 'middle-age') + else: + ann = ann.replace(age_code, 'the elderly') + comments[0] = ann + + if only_label: + text = comments[2] + else: + text = comments[0] + '\n' + comments[1] + '\n' + comments[2] + + return text + +def process_ecg_data(data_folder, dict_path, multi_folder=False, only_label=False): + diction1, diction2 = get_dictionaries(dict_path) + label_frequency = {} + num_classes = len(diction1) + samples = [] + + def process_entry(entry): + if entry.is_file() and entry.name.endswith('.hea'): + record_name = entry.path[:-4] + header = wfdb.rdheader(record_name) + label = remove_before_colon(header.comments[2]) + label_frequency[label] = label_frequency.get(label, 0) + 1 + + signals, _ = wfdb.rdsamp(record_name) + if np.isnan(signals).any(): + return None + + label_list = label.split(',') + label_indices = [diction2[int(code)] for code in label_list if int(code) in diction2] + label_vector = [1 if i in label_indices else 0 for i in range(num_classes)] + if not any(label_vector): + return None + + text = pro_text(header.comments, diction1, only_label) + ecg = pro_ecg(signals) + if ecg.shape == (5000, 12): + return (text, ecg.transpose(-1, 0), label_vector) + else: + return None + + for file_out in os.scandir(data_folder): + if multi_folder and file_out.is_dir(): + path_in = file_out.path + for entry in os.scandir(path_in): + result = process_entry(entry) + if result: + samples.append(result) + elif not multi_folder: + result = process_entry(file_out) + if result: + samples.append(result) + + np.random.shuffle(samples) + split = int(0.9 * len(samples)) + samples_train = samples[:split] + samples_test = samples[split:] + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +def process_eeg(data_folder='./data/sleep-edf-database-1.0.0'): + samples = [] + + labels = np.load(os.path.join(data_folder, 'label.npy')) + ecgs = np.load(os.path.join(data_folder, 'data.npy')) + + # 遍历数组并规范化 + prefix = "Select a previously mentioned sleep pattern and report on the person's sleep using the provided information.\nThe person's sleep pattern is " + midfix = 'The sleep patterns include waking up, rapid eye movement sleep, and sleep stages one through four, as well as periods of movement and unidentified stages.\n' + for ecg, label in zip(ecgs, labels): + # 检查是否为 NaN 值 + if np.isnan(ecg).any(): + continue # 丢弃包含 NaN 值的 ECG 数据 + + # 执行规范化操作,并将结果添加到列表中 + ecg = normalize(ecg) # 使用您的规范化函数 + # print(label) + if int(label) == 0: + text = 'waking up' + elif int(label) == 1: + text = 'rapid eye movement sleep' + elif int(label) == 2: + text = 'sleep stage one' + elif int(label) == 3: + text = 'sleep stage two' + elif int(label) == 4: + text = 'sleep stage three' + elif int(label) == 5: + text = 'sleep stage four' + elif int(label) == 6: + text = 'period of movement' + elif int(label) == 7: + text = 'unidentified stage' + else: + text = '' + + if text == '': + continue + text = midfix + prefix + text + label_vector = label + + samples.append((text, ecg, label_vector)) + + # samples = samples[:10000] + np.random.shuffle(samples) + split = int(0.9 * len(samples)) + samples_train = samples[:split] + samples_test = samples[split:] + + with 
open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +def process_har(data_folder='./data/HAR'): + normalized_samples_train = [] + normalized_samples_test = [] + + data_val = torch.load('./data/HAR/val.pt') + data_train = torch.load('./data/HAR/train.pt') + data_test = torch.load('./data/HAR/test.pt') + + samples_train = data_train['samples'].numpy() + labels_train = data_train['labels'].numpy() + samples_test = data_test['samples'].numpy() + labels_test = data_test['labels'].numpy() + samples_val = data_val['samples'].numpy() + labels_val = data_val['labels'].numpy() + + samples = np.concatenate([samples_train, samples_val], axis=0) + labels = np.concatenate([labels_train, labels_val], axis=0) + + # 遍历数组并规范化 + prefix = "Please choose one activity from the previously mentioned six options and analyze the individual's physical activity based on the provided information.\nThe individual is currently engaged in " + midfix = 'Physical activities such as walking, ascending stairs, descending stairs, sitting, standing, and lying down are recorded using mobile phone sensors.\n' + + for ecg, label in zip(samples, labels): + # 检查是否为 NaN 值 + if np.isnan(ecg).any(): + continue # 丢弃包含 NaN 值的 ECG 数据 + + # 执行规范化操作,并将结果添加到列表中 + ecg = normalize(ecg.astype(np.float32)) # 使用您的规范化函数 + ecg = ecg.transpose() + + if int(label) == 0: + text = 'walking' + elif int(label) == 1: + text = 'ascending stairs' + elif int(label) == 2: + text = 'descending stairs' + elif int(label) == 3: + text = 'sitting' + elif int(label) == 4: + text = 'standing' + elif int(label) == 5: + text = 'lying down' + else: + text = '' + + if text == '': + continue + text = midfix + prefix + text + label_vector = label + if ecg.shape == (128, 9): + normalized_samples_train.append((text, ecg, label_vector)) + + for ecg, label in zip(samples_test, labels_test): + # 检查是否为 NaN 值 + if np.isnan(ecg).any(): + continue # 丢弃包含 NaN 值的 ECG 数据 + + # 执行规范化操作,并将结果添加到列表中 + ecg = normalize(ecg.astype(np.float32)) # 使用您的规范化函数 + ecg = ecg.transpose() + + if int(label) == 0: + text = 'walking' + elif int(label) == 1: + text = 'ascending stairs' + elif int(label) == 2: + text = 'descending stairs' + elif int(label) == 3: + text = 'sitting' + elif int(label) == 4: + text = 'standing' + elif int(label) == 5: + text = 'lying down' + else: + text = '' + + if text == '': + continue + text = midfix + prefix + text + label_vector = label + if ecg.shape == (128, 9): + normalized_samples_test.append((text, ecg, label_vector)) + + # np.random.shuffle(normalized_samples) + # split = int(0.8 * len(normalized_samples)) + # samples_train = normalized_samples[:split] + # samples_test = normalized_samples[split:] + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(normalized_samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(normalized_samples_test, file) + + return normalized_samples_train, normalized_samples_test + +def padding_varying_length(data): + data[np.isnan(data)] = 0 + return data + +def process_ad(data_folder='./datas/AD_data'): + normalized_samples_train = [] + normalized_samples_test = [] + + data_train = torch.load('./datas/AD_data/train.pt') + data_test = torch.load('./datas/AD_data/test.pt') + + samples_train = data_train['samples'].numpy() + labels_train = data_train['labels'].numpy() + samples_test = data_test['samples'].numpy() + labels_test = data_test['labels'].numpy() + 
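+    # All process_* helpers in this file emit the same pickle format: a list of
+    # (prompt_text, series, label) tuples, where prompt_text is the task description
+    # (midfix) followed by the answer stem (prefix) and the ground-truth label words,
+    # series is the normalized float array, and label is the numeric class id
+    # (a multi-hot vector for the multi-label ECG corpora).
+    # Illustrative sample for this handwriting dataset (values are made up):
+    #   text   -> "...The person is currently writing the digit three"
+    #   series -> float32 array of shape (93, 13), scaled to [-1, 1] by normalize()
+    #   label  -> 3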
+ # 遍历数组并规范化 + prefix = "Please select one activity from the previously mentioned ten digits and analyze the individual's handwriting based on the provided information.\nThe person is currently writing the digit " + midfix = 'Physical activities that specifically involve using a pen to write digits, which range from one to ten.\n' + + for ecg, label in zip(samples_train, labels_train): + # # 检查是否为 NaN 值 + # padding_varying_length(ecg) + + # 执行规范化操作,并将结果添加到列表中 + ecg = normalize(ecg.astype(np.float32)) # 使用您的规范化函数 + + if int(label) == 0: + text = 'zero' + elif int(label) == 1: + text = 'one' + elif int(label) == 2: + text = 'two' + elif int(label) == 3: + text = 'three' + elif int(label) == 4: + text = 'four' + elif int(label) == 5: + text = 'five' + elif int(label) == 6: + text = 'six' + elif int(label) == 7: + text = 'seven' + elif int(label) == 8: + text = 'eight' + elif int(label) == 9: + text = 'nine' + else: + text = '' + + if text == '': + continue + text = midfix + prefix + text + label_vector = label + if ecg.shape == (93, 13): + normalized_samples_train.append((text, ecg, label_vector)) + + for ecg, label in zip(samples_test, labels_test): + # # 检查是否为 NaN 值 + # padding_varying_length(ecg) + + # 执行规范化操作,并将结果添加到列表中 + ecg = normalize(ecg.astype(np.float32)) # 使用您的规范化函数 + + if int(label) == 0: + text = 'zero' + elif int(label) == 1: + text = 'one' + elif int(label) == 2: + text = 'two' + elif int(label) == 3: + text = 'three' + elif int(label) == 4: + text = 'four' + elif int(label) == 5: + text = 'five' + elif int(label) == 6: + text = 'six' + elif int(label) == 7: + text = 'seven' + elif int(label) == 8: + text = 'eight' + elif int(label) == 9: + text = 'nine' + else: + text = '' + + if text == '': + continue + text = midfix + prefix + text + label_vector = label + if ecg.shape == (93, 13): + normalized_samples_test.append((text, ecg, label_vector)) + + # np.random.shuffle(normalized_samples) + # split = int(0.8 * len(normalized_samples)) + # samples_train = normalized_samples[:split] + # samples_test = normalized_samples[split:] + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(normalized_samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(normalized_samples_test, file) + + return normalized_samples_train, normalized_samples_test + +def process_esr(data_folder='./data/HAR'): + normalized_samples_train = [] + normalized_samples_test = [] + + data_val = torch.load('./datas/esr/val.pt') + data_train = torch.load('./datas/esr/train.pt') + data_test = torch.load('./datas/esr/test.pt') + + samples_train = data_train['samples'].numpy() + labels_train = data_train['labels'].numpy() + samples_test = data_test['samples'].numpy() + labels_test = data_test['labels'].numpy() + samples_val = data_val['samples'].numpy() + labels_val = data_val['labels'].numpy() + + samples = np.concatenate([samples_train, samples_val], axis=0) + labels = np.concatenate([labels_train, labels_val], axis=0) + + # 遍历数组并规范化 + prefix = "Please choose one of the two previously mentioned labels and analyze the individual's condition for a possible epilepsy diagnosis based on the provided information.\nAt this moment, the individual is existing in a particular state of " + midfix = "In the clinical evaluation, two labels are used to denote the patient's state: 'no abnormalities' for normal conditions and 'epileptic seizure' for seizure activity.\n" + + for ecg, label in zip(samples, labels): + if np.isnan(ecg).any(): + continue + + ecg = normalize(ecg.astype(np.float32)) + ecg = 
ecg.transpose() + + if int(label) == 0: + text = 'no abnormalities' + elif int(label) == 1: + text = 'epileptic seizure' + else: + text = '' + + if text == '': + continue + text = midfix + prefix + text + label_vector = label + if ecg.shape == (178, 1): + normalized_samples_train.append((text, ecg, label_vector)) + + for ecg, label in zip(samples_test, labels_test): + if np.isnan(ecg).any(): + continue + + ecg = normalize(ecg.astype(np.float32)) + ecg = ecg.transpose() + + if int(label) == 0: + text = 'no abnormalities' + elif int(label) == 1: + text = 'epileptic seizure' + else: + text = '' + + if text == '': + continue + text = midfix + prefix + text + label_vector = label + if ecg.shape == (178, 1): + normalized_samples_test.append((text, ecg, label_vector)) + + # np.random.shuffle(normalized_samples) + # split = int(0.8 * len(normalized_samples)) + # samples_train = normalized_samples[:split] + # samples_test = normalized_samples[split:] + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(normalized_samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(normalized_samples_test, file) + + return normalized_samples_train, normalized_samples_test + +def process_UWG(data_folder='./data/UWG'): + normalized_samples_train = [] + normalized_samples_test = [] + + samples_train = np.load('./datas/UWG/train_d.npy') + labels_train = np.load('./datas/UWG/train_l.npy') + samples_test = np.load('./datas/UWG/test_d.npy') + labels_test = np.load('./datas/UWG/test_l.npy') + + print(labels_test) + + # 遍历数组并规范化 + prefix = "Please choose one of the eight labels to analysis people's gesture based on the provided information.\nThe man is currently showing the gesture of " + midfix = "A set of eight simple gestures (musical corner, kicking, directing, reversing, ascending and descending arrow, clockwise and countercolockwise loop) generated from accelerometers.\n" + + for ecg, label in zip(samples_train, labels_train): + if np.isnan(ecg).any(): + continue + + ecg = normalize(ecg.astype(np.float32)) + + if int(label) == 0: + text = 'kicking arrow' + elif int(label) == 1: + text = 'musical corner' + elif int(label) == 2: + text = 'directing arrow' + elif int(label) == 3: + text = 'reversing arrow' + elif int(label) == 4: + text = 'ascending arrow' + elif int(label) == 5: + text = 'descending arrow' + elif int(label) == 6: + text = 'clockwise loop' + elif int(label) == 7: + text = 'countercolockwise loop' + else: + text = '' + + if text == '': + continue + text = midfix + prefix + text + label_vector = label + if ecg.shape == (315, 3): + normalized_samples_train.append((text, ecg, label_vector)) + + for ecg, label in zip(samples_test, labels_test): + if np.isnan(ecg).any(): + continue + + ecg = normalize(ecg.astype(np.float32)) + + if int(label) == 0: + text = 'kicking arrow' + elif int(label) == 1: + text = 'musical corner' + elif int(label) == 2: + text = 'directing arrow' + elif int(label) == 3: + text = 'reversing arrow' + elif int(label) == 4: + text = 'ascending arrow' + elif int(label) == 5: + text = 'descending arrow' + elif int(label) == 6: + text = 'clockwise loop' + elif int(label) == 7: + text = 'countercolockwise loop' + else: + text = '' + + if text == '': + continue + text = midfix + prefix + text + label_vector = label + if ecg.shape == (315, 3): + normalized_samples_test.append((text, ecg, label_vector)) + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(normalized_samples_train, file) + with open('samples_test.pkl', 'wb') as file: + 
pickle.dump(normalized_samples_test, file) + + return normalized_samples_train, normalized_samples_test + +def process_PS(data_folder='./data/PS'): + normalized_samples_train = [] + normalized_samples_test = [] + + samples_train = np.load('./datas/PS/train_d.npy') + labels_train = np.load('./datas/PS/train_l.npy') + samples_test = np.load('./datas/PS/test_d.npy') + labels_test = np.load('./datas/PS/test_l.npy') + + # 遍历数组并规范化 + prefix = "Please choose one of the thirty-nine labels to analyze the spectrogram characteristics of the phoneme based on the provided information.\nthe man is currently articulating the phoneme " + midfix = "This data set is a multivaritate representation of a subset of the data used in the paper Dual-domain Hierarchical Classification of Phonetic Time Series.\n" + + for ecg, label in zip(samples_train, labels_train): + if np.isnan(ecg).any(): + continue + + ecg = normalize(ecg.astype(np.float32)) + + if int(label) in range(39): + labels_text = [ + 'AH', 'N', 'T', 'L', 'S', 'R', 'IH', 'K', 'IY', + 'D', 'M', 'ER', 'EH', 'P', 'AE', 'B', 'AA', 'EY', 'F', + 'AY', 'OW', 'SH', 'V', 'G', 'AO', 'Z', 'UW', + 'NG', 'W', 'JH', 'HH', 'Y', 'CH', 'TH', 'AW', + 'UH', 'OY', 'DH', 'ZH' + ] + + text = labels_text[int(label)] + else: + text = '' + + if text == '': + continue + text = midfix + prefix + text + label_vector = label + if ecg.shape == (217, 11): + normalized_samples_train.append((text, ecg, label_vector)) + + for ecg, label in zip(samples_test, labels_test): + if np.isnan(ecg).any(): + continue + + ecg = normalize(ecg.astype(np.float32)) + + if int(label) in range(39): + labels_text = [ + 'AH', 'N', 'T', 'L', 'S', 'R', 'IH', 'K', 'IY', + 'D', 'M', 'ER', 'EH', 'P', 'AE', 'B', 'AA', 'EY', 'F', + 'AY', 'OW', 'SH', 'V', 'G', 'AO', 'Z', 'UW', + 'NG', 'W', 'JH', 'HH', 'Y', 'CH', 'TH', 'AW', + 'UH', 'OY', 'DH', 'ZH' + ] + + text = labels_text[int(label)] + else: + text = '' + + if text == '': + continue + text = midfix + prefix + text + label_vector = label + if ecg.shape == (217, 11): + normalized_samples_test.append((text, ecg, label_vector)) + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(normalized_samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(normalized_samples_test, file) + + return normalized_samples_train, normalized_samples_test + +def parse_data_line(line): + # 将数据行分割为数值部分和标签部分 + data_str, label_str = line.split(':') + + # 解析数据部分 + data = list(map(float, data_str.split(','))) + + # 解析标签部分 + label = int(label_str) + + return data, label + +def read_data_and_labels_from_file(file_path): + data_points = [] + labels = [] + + try: + with open(file_path, 'r') as file: + for line in file: + line = line.strip() + + # 跳过以 '@' 开头的行 + if line.startswith('@'): + continue + + if line: + # 解析数据行 + data, label = parse_data_line(line) + + # 追加到数据和标签列表 + data_points.append(data) + labels.append(label) + + except FileNotFoundError: + print(f"错误:未找到文件 '{file_path}'.") + + return data_points, labels + +def process_device(data_folder='./data/device'): + normalized_samples_train = [] + normalized_samples_test = [] + + data_val_path = './data/device/val.ts' + data_train_path = './data/device/FaultDetectionA_TRAIN.ts' + data_test_path = './data/device/FaultDetectionA_TEST.ts' + + samples_train, labels_train = read_data_and_labels_from_file(data_train_path) + samples_test, labels_test = read_data_and_labels_from_file(data_test_path) + samples_val, labels_val = read_data_and_labels_from_file(data_val_path) + + samples_train = 
np.concatenate([samples_train, samples_val], axis=0)
+    labels_train = np.concatenate([labels_train, labels_val], axis=0)
+
+    # iterate over the arrays and normalize each series
+    prefix = "Selecting one from the above three status, please conduct an analysis of the machine's damage condition in accordance with the provided information.\nThe machine is probably participating in the subsequent damage conditions: "
+    midfix = 'These damage conditions include not damaged, inner damaged, and outer damaged.\n'
+    for ecg, label in zip(samples_train, labels_train):
+        # check for NaN values
+        if np.isnan(ecg).any():
+            continue  # drop signals that contain NaN values
+
+        # normalize and collect the result
+        ecg = normalize(np.array(ecg).astype(np.float32)).reshape(-1, 1)  # uses the normalize() helper above
+
+        if int(label) == 0:
+            text = 'not damaged'
+        elif int(label) == 1:
+            text = 'inner damaged'
+        elif int(label) == 2:
+            text = 'outer damaged'
+        else:
+            text = ''
+
+        if text == '':
+            continue
+        text = midfix + prefix + text
+        label_vector = label
+        if ecg.shape == (5120, 1):
+            normalized_samples_train.append((text, ecg, label_vector))
+
+    for ecg, label in zip(samples_test, labels_test):
+        # check for NaN values
+        if np.isnan(ecg).any():
+            continue  # drop signals that contain NaN values
+
+        # normalize and collect the result
+        ecg = normalize(np.array(ecg).astype(np.float32)).reshape(-1, 1)  # uses the normalize() helper above
+
+        if int(label) == 0:
+            text = 'not damaged'
+        elif int(label) == 1:
+            text = 'inner damaged'
+        elif int(label) == 2:
+            text = 'outer damaged'
+        else:
+            text = ''
+
+        if text == '':
+            continue
+        text = midfix + prefix + text
+        label_vector = label
+        if ecg.shape == (5120, 1):
+            normalized_samples_test.append((text, ecg, label_vector))
+
+    with open('samples_train.pkl', 'wb') as file:
+        pickle.dump(normalized_samples_train, file)
+    with open('samples_test.pkl', 'wb') as file:
+        pickle.dump(normalized_samples_test, file)
+
+    return normalized_samples_train, normalized_samples_test
+
+def modify_text(original_text):
+    prefix = "Selecting a previous label based on the provided information.\nThe person's electrocardiogram pattern is indicative of "
+    midfix = 'This task focuses on the detection of abnormalities, and it involves classifying the ECGs into two distinct categories: normal and abnormal.\n'
+
+    lines = original_text.split('\n')
+    side_info = '\n'.join(lines[:2]) + '\n'
+
+    if "include(s)" in original_text:
+        index = original_text.find("include(s) ")
+        original_text = original_text[index + len("include(s) "):]
+
+    if original_text == 'sinus rhythm':
+        text = 'normal ecg'
+        label = 1
+    else:
+        text = 'abnormal ecg'
+        label = 0
+
+    modified_text = side_info + midfix + prefix + text
+    return modified_text, label
+
+def process_ecg_bi(Path='./ecg_new'):
+    train_path = os.path.join(Path, 'samples_train.pkl')
+    test_path = os.path.join(Path, 'samples_test.pkl')
+
+    samples_train = []
+    samples_test = []
+    if os.path.isfile(train_path) and os.path.isfile(test_path):
+        with open(train_path, 'rb') as file:
+            samples_train = pickle.load(file)
+        with open(test_path, 'rb') as file:
+            samples_test = pickle.load(file)
+
+    for i, sample in enumerate(samples_train):
+        text, ecg, _ = sample
+        modified_text, label = modify_text(text)
+        samples_train[i] = (modified_text, ecg, label)
+
+    # apply the same transformation to the test set
+    for i, sample in enumerate(samples_test):
+        text, ecg, _ = sample
+        modified_text, label = modify_text(text)
+        samples_test[i] = (modified_text, ecg, label)
+
+    with open('samples_train.pkl', 'wb') as file:
+        pickle.dump(samples_train, file)
+    with open('samples_test.pkl', 'wb') as file:
+        pickle.dump(samples_test, file)
+
+    return samples_train,
samples_test + +def modify_side_text(original_text): + lines = original_text.split('\n') + modified_text = '\n'.join(lines[2:]) + return modified_text + +def process_side_info(Path='./ecg_no_big'): + train_path = os.path.join(Path, 'samples_train.pkl') + test_path = os.path.join(Path, 'samples_test.pkl') + + samples_train = [] + samples_test = [] + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test = pickle.load(file) + + for i, sample in enumerate(samples_train): + text, ecg, label = sample + modified_text = modify_side_text(text) + samples_train[i] = (modified_text, ecg, label) + + # 对测试数据集执行相同的操作 + for i, sample in enumerate(samples_test): + text, ecg, label = sample + modified_text = modify_side_text(text) + samples_test[i] = (modified_text, ecg, label) + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +def remove_line(original_text, line_number): + lines = original_text.split('\n') + if 0 <= line_number < len(lines): + del lines[line_number] + modified_text = '\n'.join(lines) + return modified_text + +def process_label_info(Path='./whale_no_big'): + train_path = os.path.join(Path, 'samples_train.pkl') + test_path = os.path.join(Path, 'samples_test.pkl') + + samples_train = [] + samples_test = [] + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test = pickle.load(file) + + for i, sample in enumerate(samples_train): + text, ecg, label = sample + modified_text = remove_line(text, 0) + samples_train[i] = (modified_text, ecg, label) + + # 对测试数据集执行相同的操作 + for i, sample in enumerate(samples_test): + text, ecg, label = sample + modified_text = remove_line(text, 0) + samples_test[i] = (modified_text, ecg, label) + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +def process_label_info(Path='./whale_no_big'): + train_path = os.path.join(Path, 'samples_train.pkl') + test_path = os.path.join(Path, 'samples_test.pkl') + + samples_train = [] + samples_test = [] + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test = pickle.load(file) + + for i, sample in enumerate(samples_train): + text, ecg, label = sample + modified_text = remove_line(text, 0) + samples_train[i] = (modified_text, ecg, label) + + # 对测试数据集执行相同的操作 + for i, sample in enumerate(samples_test): + text, ecg, label = sample + modified_text = remove_line(text, 0) + samples_test[i] = (modified_text, ecg, label) + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +def extract_from_text(text, keyword): + index = text.find(keyword) + if index != -1: + return text[index + len(keyword):] + return "" + +def extract_all_information(text): + # 合并搜索,一次性提取所有信息 + diagnosis = stage = har = dev = whale = "" + if "include(s)" in text: + diagnosis = extract_from_text(text, "include(s) ") + 
elif "pattern is" in text: + stage = extract_from_text(text, "pattern is ") + elif "engaged in" in text: + har = extract_from_text(text, "engaged in ") + elif "conditions:" in text: + dev = extract_from_text(text, "conditions: ") + elif "originates from" in text: + whale = extract_from_text(text, "originates from ") + return diagnosis, stage, har, dev, whale + +def read_and_modify_last_line(input_string, modification_function): + # 将输入字符串拆分成行 + lines = input_string.split('\n') + + # 如果字符串为空或只有一行,直接替换为修改后的行 + if len(lines) <= 1: + return modification_function(lines[0]) + + # 读取最后一行并进行修改 + last_line = lines[-1] + diagnosis, stage, har, dev, whale = extract_all_information(last_line) + + if diagnosis: + lines[-1] = diagnosis + elif stage: + lines[-1] = stage + elif har: + lines[-1] = har + elif dev: + lines[-1] = dev + elif whale: + lines[-1] = whale + + # 重新组合所有行并返回 + modified_string = '\n'.join(lines) + return modified_string + +def process_word_info(Path='./ecg_no_big'): + train_path = os.path.join(Path, 'samples_train.pkl') + test_path = os.path.join(Path, 'samples_test.pkl') + + samples_train = [] + samples_test = [] + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test = pickle.load(file) + + for i, sample in enumerate(samples_train): + text, ecg, label = sample + modified_text = read_and_modify_last_line(text, 0) + samples_train[i] = (modified_text, ecg, label) + + # 对测试数据集执行相同的操作 + for i, sample in enumerate(samples_test): + text, ecg, label = sample + modified_text = read_and_modify_last_line(text, 0) + samples_test[i] = (modified_text, ecg, label) + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +import csv + +def process_zero_shot_info(Path='./ecg_no_big'): + train_path = os.path.join(Path, 'samples_train.pkl') + test_path = os.path.join(Path, 'samples_test.pkl') + + samples_train = [] + samples_test = [] + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test = pickle.load(file) + + include = 'include(s) ' # Replace with the actual substring + label_list = ['atrial flutter', 'bradycardia', 'complete right bundle branch block', 'premature ventricular contractions', 'right axis deviation'] + + # Extract and process labels from training data + extracted_parts_train = [] + for i, sample in enumerate(samples_train): + label = sample[0] + index = label.find(include) + extracted_part = label[index + len(include):] + extracted_parts_train.append((extracted_part, i)) + + # Filter samples to be transferred based on label_list + transfer_indices = [i for label, i in extracted_parts_train if label in label_list] + transfer_samples = [samples_train[i] for i in transfer_indices] + samples_train = [sample for i, sample in enumerate(samples_train) if i not in transfer_indices] + + # Add the filtered samples to the testing set + samples_test.extend(transfer_samples) + + extracted_parts_train = [] + for i, sample in enumerate(samples_test): + label = sample[0] + index = label.find(include) + extracted_part = label[index + len(include):] + extracted_parts_train.append((extracted_part, i)) + + # Filter samples to be transferred based on label_list + transfer_indices = [i for 
label, i in extracted_parts_train if label in label_list] + transfer_samples = [samples_test[i] for i in transfer_indices] + print(len(transfer_samples)) + + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + with open('zero_shot.pkl', 'wb') as file: + pickle.dump(transfer_samples, file) + + return samples_train, samples_test + +def process_zero_shot(Path='./ecg_no_big'): + train_path = os.path.join(Path, 'samples_train.pkl') + test_path = os.path.join(Path, 'samples_test.pkl') + + samples_train = [] + samples_test = [] + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test = pickle.load(file) + + # # 统计训练数据集和测试数据集中的标签数量 + # train_labels = [sample[2] for sample in samples_train] + # test_labels = [sample[2] for sample in samples_test] + + # train_label_counts = Counter(train_labels) + # test_label_counts = Counter(test_labels) + + # print("训练数据集中的标签数量统计:") + # for label, count in train_label_counts.items(): + # print(f"标签 {label}: {count} 个样本") + + # print("\n测试数据集中的标签数量统计:") + # for label, count in test_label_counts.items(): + # print(f"标签 {label}: {count} 个样本") + + # 删除训练集中所有标签为1的样本 + samples_train = [sample for sample in samples_train if sample[2] != 2.0] + + # 可以选择保存修改后的训练数据集 + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +def process_info(input_string, modification_function): + # 将输入字符串拆分成行 + lines = input_string.split('\n') + + # 如果字符串为空或只有一行,直接替换为修改后的行 + if len(lines) <= 1: + return modification_function(lines[0]) + + # 读取最后一行并进行修改 + last_line = lines[-1] + diagnosis, stage, har, dev, whale = extract_all_information(last_line) + + if diagnosis: + lines[-1] = diagnosis + elif stage: + lines[-1] = stage + elif har: + lines[-1] = har + elif dev: + lines[-1] = dev + elif whale: + lines[-1] = whale + + # 重新组合所有行并返回 + modified_string = '\n'.join(lines) + return modified_string + +def process_gender(Path='./ecg_no_big'): + train_path = os.path.join(Path, 'samples_train.pkl') + test_path = os.path.join(Path, 'samples_test.pkl') + + samples_train = [] + samples_test = [] + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test = pickle.load(file) + + # # 统计训练数据集和测试数据集中的标签数量 + # train_labels = [sample[2] for sample in samples_train] + # test_labels = [sample[2] for sample in samples_test] + + # train_label_counts = Counter(train_labels) + # test_label_counts = Counter(test_labels) + + # print("训练数据集中的标签数量统计:") + # for label, count in train_label_counts.items(): + # print(f"标签 {label}: {count} 个样本") + + # print("\n测试数据集中的标签数量统计:") + # for label, count in test_label_counts.items(): + # print(f"标签 {label}: {count} 个样本") + + # 删除训练集中所有标签为1的样本 + samples_train = [sample for sample in samples_train if sample[2] != 2.0] + + # 可以选择保存修改后的训练数据集 + with open('samples_train.pkl', 'wb') as file: + pickle.dump(samples_train, file) + with open('samples_test.pkl', 'wb') as file: + pickle.dump(samples_test, file) + + return samples_train, samples_test + +if __name__ == "__main__": + setup_seed() + + samples1, labels1 = process_esr() + text1, _, _ = samples1[0] + + # samples2, 
labels2 = process_cpsc_mul() + # text2, _, _ = samples2[0] + + print(text1) + print(len(samples1), len(labels1)) + # print(text2) + # print(len(samples1 + samples2), len(labels1 + labels2)) + # process_ptbxl(few=True) + # process_12_lead() + + exit(0) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..95af4d6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,47 @@ +accelerate==0.25.0 +annotated-types==0.6.0 +blis==0.7.11 +catalogue==2.0.10 +click==8.1.7 +cloudpathlib==0.16.0 +confection==0.1.4 +cymem==2.0.8 +datasets==2.14.6 +einops==0.7.0 +fastai==2.7.13 +fastcore==1.5.29 +fastdownload==0.0.7 +fastprogress==1.0.3 +fonttools==4.25.0 +huggingface-hub==0.20.1 +imbalanced-learn==0.11.0 +langcodes==3.3.0 +llvmlite==0.41.1 +mkl-service==2.4.0 +munkres==1.1.4 +murmurhash==1.0.10 +neurokit2==0.2.7 +numba==0.58.1 +peft==0.7.1 +preshed==3.0.9 +pyarrow==11.0.0 +pydantic==2.5.3 +pydantic_core==2.14.6 +pyts==0.13.0 +pywin32==305.1 +scipy==1.11.4 +smart-open==6.4.0 +soundfile==0.12.1 +spacy==3.7.2 +spacy-legacy==3.0.12 +spacy-loggers==1.0.5 +srsly==2.4.8 +thinc==8.2.2 +torch==2.1.2 +torchaudio==2.1.2 +torchvision==0.16.2 +tsai==0.3.8 +typer==0.9.0 +wasabi==1.1.2 +weasel==0.3.4 +wfdb==4.1.2 diff --git a/run_truth_loss.py b/run_truth_loss.py new file mode 100644 index 0000000..cda3a30 --- /dev/null +++ b/run_truth_loss.py @@ -0,0 +1,421 @@ +import os +import torch +import random +import logging +from logging.handlers import RotatingFileHandler +import pickle +import transformers +import numpy as np +import torch.nn as nn +from tqdm import tqdm +from torch.utils.data import DataLoader +from torch.cuda.amp import autocast, GradScaler +from transformers import ( + GPT2LMHeadModel, + GPT2Config, +) + +from multimodel import InstructTime, MultiTokenizer +from multidataset import MultiDataset +from args import get_hyperparams +from metrics import metric_ecg, metric_eeg, metric_har, metric_fd, metric_rwc +from utils import extract_all_information, load_TStokenizer + +local_model_path = "./gpt2-model" +vqvae_path1 = "./ecg_tokenizer/test_ecg_64_128_40" +vqvae_path2 = "./ecg_tokenizer/test_eeg_64_256_25" +vqvae_path3 = "./ecg_tokenizer/test_fd_64_512_40" +vqvae_path4 = "./ecg_tokenizer/test_har_64_256_1" +vqvae_path5 = "./ecg_tokenizer/test_rwc_64_384_32" + +def seed_everything(seed): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = True + +def collate_fn_train(batch): + input_ids = [x["input_ids"] for x in batch] + attention_mask = [x["attn_masks"] for x in batch] + label_ids = [x["label_ids"] for x in batch] + return { + "input_ids": torch.stack(input_ids), + "attention_mask": torch.stack(attention_mask), + "label_ids": torch.stack(label_ids), + } + +def collate_fn_test(batch): + input_ids = [x["input_ids"] for x in batch] + attention_mask = [x["attn_masks"] for x in batch] + labels = [x["label"] for x in batch] + return { + "input_ids": torch.stack(input_ids), + "attention_mask": torch.stack(attention_mask), + "labels": labels, + } + +def test(model, TestDataLoader, args, logger, out=False): + model.eval() + + with torch.no_grad(): + pred_ids, pred_eeg, pred_har, pred_fd, pred_rwc = [], [], [], [], [] + labels, labels_eeg, labels_har, labels_fd, labels_rwc = [], [], [], [], [] + + all_extracted_info = [] + 
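+        # Sketch of the evaluation flow below (keywords taken from extract_all_information):
+        #   1. generate() can emit time-series code ids (>= tokenizer.text_vocab_size);
+        #      those are masked to the pad id so that only text tokens are decoded.
+        #   2. the first encoder_max_length positions hold the prompt and are stripped off.
+        #   3. each decoded answer is routed to a per-domain metric by keyword:
+        #      "include(s)" -> metric_ecg, "pattern is" -> metric_eeg, "engaged in" -> metric_har,
+        #      "conditions:" -> metric_fd, "originates from" -> metric_rwc.
+        # e.g. extract_all_information("The person's sleep pattern is waking up")
+        #      returns ("", "waking up", "", "", ""), so that answer is scored by metric_eeg.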
all_sig_labels = [] + if out: + print_labels = [] + print_preds = [] + for data in tqdm(TestDataLoader, desc="Eval", ncols=120): + input_ids = data["input_ids"].to(args.device) + bt_labels = data["labels"] + + outputs = model.generate( + input_ids=input_ids, + pad_token_id=tokenizer.pad_token_id, + num_beams=args.num_beams, + num_return_sequences=args.num_return_sequences, + do_sample=False, + max_new_tokens=args.per_max_token, + ) + + mask = outputs >= tokenizer.text_vocab_size + outputs[mask] = tokenizer.pad_token_id + outputs = outputs[:, args.encoder_max_length:] + decoded_texts = [tokenizer.decode(o, skip_special_tokens=True) for o in outputs] + all_extracted_info.extend([extract_all_information(dt) for dt in decoded_texts]) + all_sig_labels.extend([extract_all_information(label) for label in bt_labels]) + if out: + print_labels.extend(bt_labels) + print_preds.extend(decoded_texts) + + for decoded_info, sig_label_info in zip(all_extracted_info, all_sig_labels): + diagnosis_text, stage_text, har_text, fd_text, rwc_text = decoded_info + diagnosis_label, stage_label, har_label, fd_label, rwc_label = sig_label_info + + if diagnosis_label: + pred_ids.append(diagnosis_text) + labels.append(diagnosis_label) + + elif stage_label: + pred_eeg.append(stage_text) + labels_eeg.append(stage_label) + + elif har_label: + pred_har.append(har_text) + labels_har.append(har_label) + + elif fd_label: + pred_fd.append(fd_text) + labels_fd.append(fd_label) + + elif rwc_label: + pred_rwc.append(rwc_text) + labels_rwc.append(rwc_label) + + res1, res2, res3, res4, res5 = 0, 0, 0, 0, 0 + if args.dataset == 'mix': + res1, _, _ = metric_ecg(pred_ids, labels, logger) + res2, _, _ = metric_eeg(pred_eeg, labels_eeg, logger) + res3, _, _ = metric_har(pred_har, labels_har, logger) + res4, _, _ = metric_fd(pred_fd, labels_fd, logger) + res5, _, _ = metric_rwc(pred_rwc, labels_rwc, logger) + elif args.dataset == 'geo': + res1, _, _ = metric_ecg(pred_ids, labels, logger) + elif args.dataset == 'eeg': + res2, _, _ = metric_eeg(pred_eeg, labels_eeg, logger) + elif args.dataset == 'fd': + res3, _, _ = metric_fd(pred_fd, labels_fd, logger) + elif args.dataset == 'rwc': + res5, _, _ = metric_rwc(pred_rwc, labels_rwc, logger) + else: + res4, _, _ = metric_har(pred_har, labels_har, logger) + + if out: + return print_preds, print_labels + else: + return res1 + res2 + res3 + res4 + res5 + +def setup_logging(run_path): + """ + logger + """ + log_file = os.path.join(run_path, "log.log") + + + open(log_file, 'w').close() + logger = logging.getLogger('training_log') + logger.setLevel(logging.INFO) + + file_handler = RotatingFileHandler(log_file, maxBytes=1024*1024*5, backupCount=2) + formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") + file_handler.setFormatter(formatter) + + logger.addHandler(file_handler) + + return logger + +def initialize_model(args, tokenizer, TStokenizers): + config = GPT2Config.from_pretrained(local_model_path) + model = InstructTime(config, TStokenizers, text_embedding=len(tokenizer.textTokenizer)).to(args.device) + + pretrained_gpt2_model = GPT2LMHeadModel.from_pretrained(local_model_path) + model.load_state_dict(pretrained_gpt2_model.state_dict(), strict=False) + + model.resize_token_embeddings(len(tokenizer.textTokenizer)) + current_output = model.get_output_embeddings() + new_output = nn.Linear(config.n_embd, tokenizer.vocabSize_all(), bias=False).to(args.device) + new_output.weight.data[:len(tokenizer.textTokenizer)] = current_output.weight.data + 
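+    # Layout of the enlarged output head (a rough sketch; the codebook sizes below are
+    # illustrative, not the shipped configurations):
+    #   rows [0, len(tokenizer.textTokenizer))          -> GPT-2 text tokens, pretrained weights copied above
+    #   rows [text vocab, text vocab + sum of n_embed)  -> discrete time-series codes, offset consecutively
+    #                                                      in the same order as MultiTokenizer._calculate_offsets
+    # Assuming the stock 50257-token GPT-2 vocabulary plus the two added special tokens (50259 rows)
+    # and codebooks of [128, 256, 512, 256, 384] entries, the offsets would be
+    # [50259, 50387, 50643, 51155, 51411], so code 7 of the third tokenizer maps to row 50643 + 7 = 50650.
+    # The time-series rows keep their fresh initialization and are learned during fine-tuning.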
model.set_output_embeddings(new_output) + + sub_path = "no_frozen" + + return model, sub_path + +def train_model(model, args, TrainDataLoader, TestDataLoader, optimizer, scheduler, scaler, logger, run_path): + best = 0.0 + + for epoch in range(args.epochs): + step, train_losses = 0, 0.0 + tqdm_iter = tqdm(TrainDataLoader, desc=f"GPT Epoch {epoch+1}", ncols=120) + + model.train() + for data in tqdm_iter: + + input_ids = data["input_ids"].to(args.device) + attention_mask = data["attention_mask"].to(args.device) + label_ids = data["label_ids"].to(args.device) + + with autocast(): + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=label_ids + ) + + scaler.scale(outputs.loss).backward() + scaler.step(optimizer) + scaler.update() + scheduler.step() + optimizer.zero_grad() + + loss_value = outputs.loss.cpu().item() + train_losses += loss_value + step += 1 + tqdm_iter.set_postfix({"loss": format(train_losses / step, ".4f")}) + + final_loss = format(train_losses / step, ".4f") + logger.info(f"Epoch {epoch+1}\nLoss: {final_loss}") + + res = test(model, TestDataLoader, args, logger, out=False) + print(res) + if res > best: + MODEL_STORED_PATH = run_path + "/best_model" + best = res + model.save_pretrained(MODEL_STORED_PATH) + +if __name__ == "__main__": + args = get_hyperparams() + seed_everything(args.seed) + + if args.dataset == 'mix' or args.dataset == 'geo': + file_path = 'ecg_no_big' + train_path = os.path.join(file_path, 'samples_train.pkl') + test_path = os.path.join(file_path, 'samples_test.pkl') + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test = pickle.load(file) + text1, ecg, _ = samples_train[0] + print(len(samples_train) + len(samples_test), len(samples_train), len(samples_test)) + print(text1) + + if args.dataset == 'mix' or args.dataset == 'eeg': + file_path = 'eeg_no_big' + train_path = os.path.join(file_path, 'samples_train.pkl') + test_path = os.path.join(file_path, 'samples_test.pkl') + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train_eeg = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test_eeg = pickle.load(file) + text2, eeg, _ = samples_train_eeg[0] + print(len(samples_train_eeg) + len(samples_test_eeg), len(samples_train_eeg), len(samples_test_eeg)) + print(text2) + + if args.dataset == 'mix' or args.dataset == 'fd': + file_path = 'device_no_big' + train_path = os.path.join(file_path, 'samples_train.pkl') + test_path = os.path.join(file_path, 'samples_test.pkl') + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train_fd = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test_fd = pickle.load(file) + text3, fd, _ = samples_train_fd[0] + print(len(samples_train_fd) + len(samples_test_fd), len(samples_train_fd), len(samples_test_fd)) + print(text3) + + if args.dataset == 'mix' or args.dataset == 'har': + file_path = 'har_no_big' + train_path = os.path.join(file_path, 'samples_train.pkl') + test_path = os.path.join(file_path, 'samples_test.pkl') + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train_har = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test_har = pickle.load(file) + text4, har, _ = samples_train_har[0] + print(len(samples_train_har) + 
len(samples_test_har), len(samples_train_har), len(samples_test_har)) + print(text4) + + if args.dataset == 'mix' or args.dataset == 'rwc': + file_path = 'rwc_no_big' + train_path = os.path.join(file_path, 'samples_train.pkl') + test_path = os.path.join(file_path, 'samples_test.pkl') + if os.path.isfile(train_path) and os.path.isfile(test_path): + with open(train_path, 'rb') as file: + samples_train_rwc = pickle.load(file) + with open(test_path, 'rb') as file: + samples_test_rwc = pickle.load(file) + text7, rwc, _ = samples_train_rwc[0] + print(len(samples_train_rwc) + len(samples_test_rwc), len(samples_train_rwc), len(samples_test_rwc)) + print(text7) + + print('preprocess done') + + if args.dataset == 'mix': + samples_train_combined = samples_train + samples_train_eeg + samples_train_har + samples_train_fd + samples_train_rwc + samples_test_combined = samples_test + samples_test_eeg + samples_test_har + samples_test_fd + samples_test_rwc + np.random.shuffle(samples_train_combined) + np.random.shuffle(samples_test_combined) + PREFIX_TEXT = "You will be receiving signals from five domains: electrocardiogram, electroencephalogram, industrial equipment, sound and physical activities.\n" + elif args.dataset == 'geo': + samples_train_combined = samples_train + samples_test_combined = samples_test + PREFIX_TEXT = "You will be receiving electrocardiogram(ECG) related signals.\n" + elif args.dataset == 'eeg': + samples_train_combined = samples_train_eeg + samples_test_combined = samples_test_eeg + PREFIX_TEXT = "You will be receiving electroencephalogram(EEG) related signals.\n" + elif args.dataset == 'fd': + samples_train_combined = samples_train_fd + samples_test_combined = samples_test_fd + PREFIX_TEXT = "You will be receiving industrial equipment related signals.\n" + elif args.dataset == 'rwc': + samples_train_combined = samples_train_rwc + samples_test_combined = samples_test_rwc + PREFIX_TEXT = "You will be receiving sound related signals.\n" + else: + samples_train_combined = samples_train_har + samples_test_combined = samples_test_har + PREFIX_TEXT = "You will be receiving human physical activities related signals.\n" + + TStokenizer1 = load_TStokenizer(vqvae_path1, ecg.shape, 'cpu') + TStokenizer2 = load_TStokenizer(vqvae_path2, eeg.shape, 'cpu') + TStokenizer3 = load_TStokenizer(vqvae_path3, fd.shape, 'cpu') + TStokenizer4 = load_TStokenizer(vqvae_path4, har.shape, 'cpu') + TStokenizer5 = load_TStokenizer(vqvae_path5, rwc.shape, 'cpu') + TStokenizers = [TStokenizer1, TStokenizer2, TStokenizer3, TStokenizer4, TStokenizer5] + tokenizer = MultiTokenizer(TStokenizers) + + TrainDataset = MultiDataset( + samples_train_combined, + tokenizer, + mode="train", + encoder_max_length=args.encoder_max_length, + multi=args.dataset, + prefix_text=PREFIX_TEXT, + ) + TrainDataLoader = DataLoader( + TrainDataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=4, + collate_fn=collate_fn_train, + ) + TestDataset = MultiDataset( + samples_test_combined, + tokenizer, + mode="test", + encoder_max_length=args.encoder_max_length, + multi=args.dataset, + prefix_text=PREFIX_TEXT, + ) + TestDataLoader = DataLoader( + TestDataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=4, + collate_fn=collate_fn_test, + ) + + num = 1 + for run in range(num): + model, sub_path = initialize_model(args, tokenizer, TStokenizers) + model_subpath = os.path.join(args.model_path, sub_path) + print(args.model_path, model_subpath) + + os.makedirs(model_subpath, exist_ok=True) + run_path = 
os.path.join(model_subpath, f"run_{run}") + os.makedirs(run_path, exist_ok=True) + logger = setup_logging(run_path) + + if args.adapt: + best_model_path = os.path.join(run_path, 'best_model') + model_state_dict = torch.load(os.path.join(args.load_model_path, 'pytorch_model.bin'), map_location=args.device) + model.load_state_dict(model_state_dict, strict=False) + + for param in model.parameters(): + param.requires_grad = True + + param_dict = [{"params": model.parameters(), "lr": args.lr}] + optimizer = torch.optim.Adam(param_dict, weight_decay=1e-5) + scheduler = transformers.optimization.get_cosine_schedule_with_warmup( + optimizer, num_warmup_steps=args.epochs * len(TrainDataLoader) * args.warm_up_ratio, num_training_steps=args.epochs * len(TrainDataLoader) + ) + scaler = GradScaler() + + logger.info(f"Begin training for run {run}") + train_model(model, args, TrainDataLoader, TestDataLoader, optimizer, scheduler, scaler, logger, run_path) + + model, _ = initialize_model(args, tokenizer, TStokenizers) + best_model_path = os.path.join(run_path, 'best_model') + model_state_dict = torch.load(os.path.join(best_model_path, 'pytorch_model.bin'), map_location=args.device) + model.load_state_dict(model_state_dict) + + logger.info(f"Test best model for run {run}") + print_preds, print_labels = test(model, TestDataLoader, args, logger, out=True) + + save_path = os.path.join(run_path, 'output.txt') + with open(save_path, 'w', encoding='utf-8') as file: + if args.dataset == 'mix' or args.dataset == 'geo': + file.write("Input Sequence: \n{}\n".format(PREFIX_TEXT + text1)) + file.write('\n') + if args.dataset == 'mix' or args.dataset == 'eeg': + file.write("Input Sequence: \n{}\n".format(PREFIX_TEXT + text2)) + file.write('\n') + if args.dataset == 'mix' or args.dataset == 'fd': + file.write("Input Sequence: \n{}\n".format(PREFIX_TEXT + text3)) + file.write('\n') + if args.dataset == 'mix' or args.dataset == 'har': + file.write("Input Sequence: \n{}\n".format(PREFIX_TEXT + text4)) + file.write('\n') + if args.dataset == 'mix' or args.dataset == 'rwc': + file.write("Input Sequence: \n{}\n".format(PREFIX_TEXT + text7)) + file.write('\n') + + for i in range(min(500, len(print_labels))): + j = i * args.num_return_sequences + for k in range(args.num_return_sequences): + file.write("Generated Text: {}\n".format(print_preds[j + k])) + file.write("Actual Label: {}\n".format(print_labels[i])) + file.write('\n') + + logger.handlers.clear() \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..36a31e0 --- /dev/null +++ b/utils.py @@ -0,0 +1,48 @@ +import os +import torch +import json +import random +from TStokenizer.model import TStokenizer + +def get_fixed_order_choice(labels): + shuffled_labels = labels[:] + shuffled_labels = list(shuffled_labels) + random.shuffle(shuffled_labels) + return shuffled_labels + +def extract_all_information(text): + diagnosis = stage = har = dev = whale = "" + if "include(s)" in text: + diagnosis = extract_from_text(text, "include(s) ") + elif "pattern is" in text: + stage = extract_from_text(text, "pattern is ") + elif "engaged in" in text: + har = extract_from_text(text, "engaged in ") + elif "conditions:" in text: + dev = extract_from_text(text, "conditions: ") + elif "originates from" in text: + whale = extract_from_text(text, "originates from ") + return diagnosis, stage, har, dev, whale + +def extract_from_text(text, keyword): + index = text.find(keyword) + if index != -1: + return text[index + len(keyword):] + return "" + +def load_params_from_json(json_file_path): + 
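"""Read the hyperparameters (d_model, n_embed, wave_length) stored in a TStokenizer checkpoint's args.json.""" + 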
with open(json_file_path, 'r') as file: + params = json.load(file) + return params + +def load_TStokenizer(dir_path, data_shape, device): + json_params_path = os.path.join(dir_path, "args.json") + model_path = os.path.join(dir_path, "model.pkl") + + params = load_params_from_json(json_params_path) + + vqvae_model = TStokenizer(data_shape=data_shape, hidden_dim=params['d_model'], n_embed=params['n_embed'], wave_length=params['wave_length']) + vqvae_model.load_state_dict(torch.load(model_path, map_location=device)) + vqvae_model.eval() + + return vqvae_model
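
A minimal usage sketch for `load_TStokenizer` (illustrative only, not part of the patch): the checkpoint directory `./ts_tokenizer_ckpt` and the pickle path are placeholders, and the sample layout follows the `(text, signal, label)` tuples unpacked in the training script above.

```python
# Sketch only: the paths below are assumptions, not files shipped with this patch.
import pickle

from utils import load_TStokenizer

# The checkpoint directory is expected to hold args.json and model.pkl.
with open('ecg_no_big/samples_test.pkl', 'rb') as f:
    samples_test = pickle.load(f)

_, signal, _ = samples_test[0]            # each sample is a (text, signal, label) tuple
ts_tok = load_TStokenizer('./ts_tokenizer_ckpt', signal.shape, 'cpu')
print(type(ts_tok).__name__)              # TStokenizer, returned in eval mode
```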