-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 9cbe34c
Showing
11 changed files
with
1,012 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# AVSC | ||
### 环境配置 | ||
并没有用到额外的包,只需要将该任务原本配好的环境重命名为common就行 | ||
|
||
或者运行如下脚本 | ||
```bash | ||
conda env create -f environment.yml | ||
conda activate common | ||
``` | ||
### 特征提取 | ||
使用 openl3 提取的 audio 和 visual 特征,此步骤过程较长,已预先提取好放于 `/dssg/home/acct-stu/stu464/ai3611/av_scene_classify/data/feature`,基于进行实验 | ||
|
||
### 实验运行 | ||
#### Mean Model | ||
所有的Mean Model(一开始在时间轴上进行均值操作)都在./mean_model中,可以通过以下脚本运行 | ||
```angular2html | ||
conda activate common | ||
python train_mean.py --config_file configs/name.yaml --cuda 0 | ||
# evaluation | ||
python evaluate.py --experiment_path experiments/name | ||
``` | ||
其中name为模型名字 | ||
#### 复现最优性能 | ||
有2个模型都能达到最优性能 | ||
|
||
1. Mid Fusion的调参版 | ||
```angular2html | ||
conda activate common | ||
python train_mean.py --config_file configs/baseline.yaml --cuda 0 | ||
# evaluation | ||
python evaluate.py --experiment_path experiments/mid | ||
``` | ||
|
||
2. 划窗改进 | ||
```angular2html | ||
conda activate common | ||
python train_mean.py --config_file configs/com_dt_t.yaml --cuda 0 | ||
# evaluation | ||
python evaluate.py --experiment_path experiments/com_dt_t | ||
``` | ||
#### Conv Model | ||
所有的Mean Model(卷积模型)都在./conv_model中,可以通过以下脚本运行 | ||
```angular2html | ||
conda activate common | ||
python train_conv.py --config_file configs/name.yaml --cuda 0 | ||
# evaluation | ||
python evaluate.py --experiment_path experiments/name | ||
``` | ||
其中name为模型名字,运行时需要把evaluate.py中的 | ||
```angular2html | ||
from mean_model import load_model | ||
``` | ||
改为 | ||
```angular2html | ||
from conv_model import load_model | ||
``` | ||
#### Mean Model中各模型名字解释 | ||
所有的Mean Model都在./mean_model中,通过__init__.py中定义的load_model函数调用。其中: | ||
|
||
1. audio.py 指 Audio Only | ||
2. video.py 指 Video Only | ||
3. early.py 指 Early Fusion | ||
4. mid.py 指 Mid Fusion,也即为Baseeline模型 | ||
5. decision.py 指 Decision Level Fusion | ||
6. audio_vattn.py 指 Audio + Vedio Attention | ||
7. video_aattn.py 指 Video + Audio Attention | ||
8. video_divide_t.py 指 Video Only + 划窗操作 | ||
9. audio_divide_t.py 指 Audio Only + 划窗操作 | ||
10. com_dt_t.py 指 Baseline + 划窗操作 | ||
11. decision_midattn 指 Decision Level Fusion + AV Attention | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import numpy as np | ||
import torch | ||
from torch.utils import data | ||
import h5py | ||
|
||
|
||
class SceneDataset(data.Dataset): | ||
|
||
def __init__(self, | ||
audio_feature, | ||
video_feature, | ||
audio_transform=None, | ||
video_transform=None): | ||
super().__init__() | ||
self.audio_feature = audio_feature | ||
self.video_feature = video_feature | ||
self.audio_transform = audio_transform | ||
self.video_transform = video_transform | ||
# a/v_transform: lambda x: (x - mean_audio) / std_audio | ||
self.audio_hf = None | ||
self.video_hf = None | ||
|
||
self.all_files = [] | ||
|
||
def traverse(name, obj): | ||
if isinstance(obj, h5py.Dataset): | ||
self.all_files.append(name) | ||
|
||
hf = h5py.File(self.audio_feature, 'r') | ||
hf.visititems(traverse) | ||
hf.close() | ||
print("Finish loading indexes") | ||
|
||
def __len__(self): | ||
return len(self.all_files) | ||
|
||
def __getitem__(self, index): | ||
if self.audio_hf is None: | ||
self.audio_hf = h5py.File(self.audio_feature, 'r') | ||
if self.video_hf is None: | ||
self.video_hf = h5py.File(self.video_feature, 'r') | ||
|
||
audio_feat = [] | ||
aid = self.all_files[index] | ||
audio_feat = self.audio_hf[aid][:96, :] | ||
# import pdb; pdb.set_trace() | ||
if self.audio_transform: | ||
audio_feat = self.audio_transform(audio_feat) | ||
|
||
vid = aid.replace("audio", "video") | ||
video_feat = self.video_hf[vid][:96, :] | ||
if self.video_transform: | ||
video_feat = self.video_transform(video_feat) | ||
|
||
target = int(aid.split('/')[0]) | ||
|
||
audio_feat = torch.as_tensor(audio_feat).float() | ||
video_feat = torch.as_tensor(video_feat).float() | ||
target = torch.as_tensor(target).long() | ||
# print(audio_feat.shape, video_feat.shape) | ||
return { | ||
"aid": aid.split("/")[-1], | ||
"audio_feat": audio_feat, | ||
"video_feat": video_feat, | ||
"target": target | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import argparse | ||
from pathlib import Path | ||
|
||
from sklearn.metrics import accuracy_score, log_loss, classification_report | ||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("prediction", type=str) | ||
parser.add_argument("label", type=str, help="path to fold1_evaluate.csv") | ||
|
||
|
||
keys = ['airport', | ||
'bus', | ||
'metro', | ||
'metro_station', | ||
'park', | ||
'public_square', | ||
'shopping_mall', | ||
'street_pedestrian', | ||
'street_traffic', | ||
'tram'] | ||
|
||
scene_to_idx = { scene: idx for idx, scene in enumerate(keys) } | ||
|
||
args = parser.parse_args() | ||
label_df = pd.read_csv(args.label, sep="\t") | ||
|
||
label_df["aid"] = label_df["filename_audio"].apply(lambda x: Path(x).stem) | ||
|
||
aid_to_label = dict(zip(label_df["aid"], label_df["scene_label"])) | ||
|
||
targets = [] | ||
probs = [] | ||
preds = [] | ||
|
||
pred_df = pd.read_csv(args.prediction, sep="\t") | ||
for idx, row in pred_df.iterrows(): | ||
aid = row["aid"] | ||
pred = row["scene_pred"] | ||
targets.append(scene_to_idx[aid_to_label[aid]]) | ||
preds.append(scene_to_idx[pred]) | ||
|
||
targets = np.array(targets) | ||
preds = np.array(preds) | ||
|
||
for key in keys: | ||
probs.append(pred_df[key].values) | ||
|
||
probs = np.stack(probs, axis=1) | ||
print(classification_report(targets, preds, target_names=keys)) | ||
|
||
acc = accuracy_score(targets, preds) | ||
print(' ') | ||
print(f'accuracy: {acc:.3f}') | ||
logloss = log_loss(targets, probs) | ||
print(f'overall log loss: {logloss:.3f}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import os | ||
import argparse | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
from torch.utils.data import DataLoader | ||
import yaml | ||
from sklearn.metrics import accuracy_score, confusion_matrix, log_loss, classification_report | ||
from tqdm import tqdm | ||
import matplotlib.pyplot as plt | ||
import seaborn as sn | ||
|
||
from dataset import SceneDataset | ||
from mean_model import load_model | ||
torch.multiprocessing.set_sharing_strategy('file_system') | ||
|
||
parser = argparse.ArgumentParser(description='evaluation') | ||
parser.add_argument('--experiment_path', type=str, required=True) | ||
parser.add_argument('--cuda', type=int, default=0, required=False, | ||
help='set the cuda device') | ||
args = parser.parse_args() | ||
|
||
with open(os.path.join(args.experiment_path, "config.yaml"), "r") as reader: | ||
config = yaml.load(reader, Loader=yaml.FullLoader) | ||
|
||
mean_std_audio = np.load(config["data"]["audio_norm"]) | ||
mean_std_video = np.load(config["data"]["video_norm"]) | ||
mean_audio = mean_std_audio["global_mean"] | ||
std_audio = mean_std_audio["global_std"] | ||
mean_video = mean_std_video["global_mean"] | ||
std_video = mean_std_video["global_std"] | ||
|
||
audio_transform = lambda x: (x - mean_audio) / std_audio | ||
video_transform = lambda x: (x - mean_video) / std_video | ||
|
||
tt_ds = SceneDataset(config["data"]["test"]["audio_feature"], | ||
config["data"]["test"]["video_feature"], | ||
audio_transform, | ||
video_transform) | ||
config["data"]["dataloader_args"]["batch_size"] = 1 | ||
tt_dataloader = DataLoader(tt_ds, shuffle=False, **config["data"]["dataloader_args"]) | ||
|
||
model_cfg = config['model'] | ||
model = load_model(config['model_name'])(**model_cfg) | ||
|
||
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
device = "cuda:{}".format(args.cuda) | ||
|
||
model.load_state_dict(torch.load( | ||
os.path.join(args.experiment_path, "best_model.pt"), "cpu") | ||
) | ||
|
||
model = model.to(device).eval() | ||
|
||
targets = [] | ||
probs = [] | ||
preds = [] | ||
aids = [] | ||
|
||
with torch.no_grad(): | ||
tt_dataloader = tqdm(tt_dataloader) | ||
for batch_idx, batch in enumerate(tt_dataloader): | ||
audio_feat = batch["audio_feat"].to(device) | ||
video_feat = batch["video_feat"].to(device) | ||
target = batch["target"].to(device) | ||
logit = model(audio_feat, video_feat) | ||
pred = torch.argmax(logit, 1) | ||
targets.append(target.cpu().numpy()) | ||
probs.append(torch.softmax(logit, 1).cpu().numpy()) | ||
preds.append(pred.cpu().numpy()) | ||
aids.append(np.array(batch["aid"])) | ||
|
||
|
||
targets = np.concatenate(targets, axis=0) | ||
preds = np.concatenate(preds, axis=0) | ||
probs = np.concatenate(probs, axis=0) | ||
aids = np.concatenate(aids, axis=0) | ||
|
||
writer = open(os.path.join(args.experiment_path, "result.txt"), "w") | ||
cm = confusion_matrix(targets, preds) | ||
keys = ['airport', | ||
'bus', | ||
'metro', | ||
'metro_station', | ||
'park', | ||
'public_square', | ||
'shopping_mall', | ||
'street_pedestrian', | ||
'street_traffic', | ||
'tram'] | ||
|
||
scenes_pred = [keys[pred] for pred in preds] | ||
scenes_label = [keys[target] for target in targets] | ||
pred_dict = {"aid": aids, "scene_pred": scenes_pred, "scene_label": scenes_label} | ||
for idx, key in enumerate(keys): | ||
pred_dict[key] = probs[:, idx] | ||
pd.DataFrame(pred_dict).to_csv(os.path.join(args.experiment_path, "prediction.csv"), | ||
index=False, | ||
sep="\t", | ||
float_format="%.3f") | ||
|
||
|
||
print(classification_report(targets, preds, target_names=keys), file=writer) | ||
|
||
df_cm = pd.DataFrame(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], | ||
index=keys, columns=keys) | ||
plt.figure(figsize=(15, 12)) | ||
sn.heatmap(df_cm, annot=True) | ||
plt.savefig(os.path.join(args.experiment_path, 'cm.png')) | ||
|
||
acc = accuracy_score(targets, preds) | ||
print(' ', file=writer) | ||
print(f'accuracy: {acc:.3f}', file=writer) | ||
logloss = log_loss(targets, probs) | ||
print(f'overall log loss: {logloss:.3f}', file=writer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#!/bin/bash | ||
|
||
#SBATCH -p a100 | ||
#SBATCH -N 1 | ||
#SBATCH -n 1 | ||
#SBATCH --gres=gpu:1 | ||
#SBATCH --output=slurm_logs/%j.out | ||
#SBATCH --error=slurm_logs/%j.err | ||
|
||
module load miniconda3 | ||
source activate | ||
conda activate common | ||
|
||
export XDG_RUNTIME_DIR=/dssg/home/acct-stu/stu513/ai3611/av_scene_classify/plt_img | ||
|
||
# evaluation | ||
python evaluate.py --experiment_path experiments/baseline |
Oops, something went wrong.