-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Second release (Action Segmentation)
- Loading branch information
Showing
19 changed files
with
543 additions
and
119 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
exp_name: SPELL_AS_default | ||
model_name: SPELL | ||
graph_name: ASFORMER_10_10 | ||
loss_name: ce_ref | ||
use_spf: False | ||
use_ref: True | ||
w_ref: 5 | ||
num_modality: 1 | ||
channel1: 64 | ||
channel2: 64 | ||
final_dim: 19 | ||
num_att_heads: 4 | ||
dropout: 0.2 | ||
lr: 0.0005 | ||
wd: 0 | ||
batch_size: 1 | ||
sch_param: 5 | ||
num_epoch: 50 | ||
sample_rate: 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import os | ||
import glob | ||
import torch | ||
import argparse | ||
import numpy as np | ||
from functools import partial | ||
from multiprocessing import Pool | ||
from torch_geometric.data import Data | ||
|
||
|
||
def generate_temporal_graph(data_file, args, path_graphs, actions, train_ids, all_ids): | ||
""" | ||
Generate temporal graphs of a single video | ||
""" | ||
|
||
video_id = os.path.splitext(os.path.basename(data_file))[0] | ||
feature = np.transpose(np.load(data_file)) | ||
num_frame = feature.shape[0] | ||
skip = args.skip_factor | ||
|
||
# Get a list of ground-truth action labels | ||
with open(os.path.join(args.root_data, f'annotations/{args.dataset}/groundTruth/{video_id}.txt')) as f: | ||
label = [actions[line.strip()] for line in f] | ||
|
||
# Get a list of the edge information: these are for edge_index and edge_attr | ||
node_source = [] | ||
node_target = [] | ||
edge_attr = [] | ||
for i in range(num_frame): | ||
for j in range(num_frame): | ||
# Frame difference between the i-th and j-th nodes | ||
frame_diff = i - j | ||
|
||
# The edge ij connects the i-th node and j-th node | ||
# Positive edge_attr indicates that the edge ij is backward (negative: forward) | ||
if abs(frame_diff) <= args.tauf: | ||
node_source.append(i) | ||
node_target.append(j) | ||
edge_attr.append(np.sign(frame_diff)) | ||
|
||
# Make additional connections between non-adjacent nodes | ||
# This can help reduce over-segmentation of predictions in some cases | ||
elif skip: | ||
if (frame_diff % skip == 0) and (abs(frame_diff) <= skip*args.tauf): | ||
node_source.append(i) | ||
node_target.append(j) | ||
edge_attr.append(np.sign(frame_diff)) | ||
|
||
# x: features | ||
# g: global_id | ||
# edge_index: information on how the graph nodes are connected | ||
# edge_attr: information about whether the edge is spatial (0) or temporal (positive: backward, negative: forward) | ||
# y: labels | ||
graphs = Data(x = torch.tensor(np.array(feature, dtype=np.float32), dtype=torch.float32), | ||
g = all_ids.index(video_id), | ||
edge_index = torch.tensor(np.array([node_source, node_target], dtype=np.int64), dtype=torch.long), | ||
edge_attr = torch.tensor(edge_attr, dtype=torch.float32), | ||
y = torch.tensor(np.array(label, dtype=np.int64)[::args.sample_rate], dtype=torch.long)) | ||
|
||
if video_id in train_ids: | ||
torch.save(graphs, os.path.join(path_graphs, 'train', f'{video_id}.pt')) | ||
else: | ||
torch.save(graphs, os.path.join(path_graphs, 'val', f'{video_id}.pt')) | ||
|
||
|
||
if __name__ == "__main__": | ||
""" | ||
Generate temporal graphs from the extracted features | ||
""" | ||
|
||
parser = argparse.ArgumentParser() | ||
# Default paths for the training process | ||
parser.add_argument('--root_data', type=str, help='Root directory to the data', default='./data') | ||
parser.add_argument('--dataset', type=str, help='Name of the dataset', default='50salads') | ||
parser.add_argument('--features', type=str, help='Name of the features', required=True) | ||
|
||
# Hyperparameters for the graph generation | ||
parser.add_argument('--tauf', type=int, help='Maximum frame difference between neighboring nodes', required=True) | ||
parser.add_argument('--skip_factor', type=int, help='Make additional connections between non-adjacent nodes', default=10) | ||
parser.add_argument('--sample_rate', type=int, help='Downsampling rate for the input', default=2) | ||
|
||
args = parser.parse_args() | ||
|
||
# Build a mapping from action classes to action ids | ||
actions = {} | ||
with open(os.path.join(args.root_data, f'annotations/{args.dataset}/mapping.txt')) as f: | ||
for line in f: | ||
aid, cls = line.strip().split(' ') | ||
actions[cls] = int(aid) | ||
|
||
# Get a list of all video ids | ||
all_ids = sorted([os.path.splitext(v)[0] for v in os.listdir(os.path.join(args.root_data, f'annotations/{args.dataset}/groundTruth'))]) | ||
|
||
# Iterate over different splits | ||
print ('This process might take a few minutes') | ||
|
||
list_splits = sorted(os.listdir(os.path.join(args.root_data, f'features/{args.features}'))) | ||
for split in list_splits: | ||
# Get a list of training video ids | ||
with open(os.path.join(args.root_data, f'annotations/{args.dataset}/splits/train.{split}.bundle')) as f: | ||
train_ids = [os.path.splitext(line.strip())[0] for line in f] | ||
|
||
path_graphs = os.path.join(args.root_data, f'graphs/{args.features}_{args.tauf}_{args.skip_factor}/{split}') | ||
os.makedirs(os.path.join(path_graphs, 'train'), exist_ok=True) | ||
os.makedirs(os.path.join(path_graphs, 'val'), exist_ok=True) | ||
|
||
list_data_files = sorted(glob.glob(os.path.join(args.root_data, f'features/{args.features}/{split}/*.npy'))) | ||
|
||
with Pool(processes=20) as pool: | ||
pool.map(partial(generate_temporal_graph, args=args, path_graphs=path_graphs, actions=actions, train_ids=train_ids, all_ids=all_ids), list_data_files) | ||
|
||
print (f'Graph generation for {split} is finished') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
## Getting Started (Action Segmentation) | ||
### Annotations | ||
We suggest using the same set of annotations used by [MS-TCN++](https://github.com/sj-li/MS-TCN2) and [ASFormer](https://github.com/ChinaYi/ASFormer). Download the 50Salads dataset from the links provided by either of the two repositories. | ||
|
||
### Features | ||
We suggest extracting the features using [ASFormer](https://github.com/ChinaYi/ASFormer). Please use their repository and the pre-trained model checkpoints ([link](https://github.com/ChinaYi/ASFormer/tree/main#reproduce-our-results)) to extract the frame-wise features for each split of the dataset. Please extract the features from each of the four refinement layers and concatenate them. To be more specific, you can concatenate the 64-dimensional features from this [line](https://github.com/ChinaYi/ASFormer/blob/main/model.py#L315), which will give you 256-dimensional (frame-wise) features. Similarly, you can also extract MS-TCN++ features from this [line](https://github.com/sj-li/MS-TCN2/blob/master/model.py#L23). | ||
> We use the features from the thirdparty repositories. | ||
### Directory Structure | ||
The data directories should look as follows: | ||
``` | ||
|-- data | ||
|-- annotations | ||
|-- 50salads | ||
|-- groundTruth | ||
|-- splits | ||
|-- mapping.txt | ||
|-- features | ||
|-- ASFORMER | ||
|-- split1 | ||
|-- split2 | ||
|-- split3 | ||
|-- split4 | ||
|-- split5 | ||
``` | ||
|
||
### Experiments | ||
We can perform the experiments on action segmentation with the default configuration by following the three steps below. | ||
|
||
#### Step 1: Graph Generation | ||
Run the following command to generate temporal graphs from the features: | ||
``` | ||
python data/generate_temporal_graphs.py --features ASFORMER --tauf 10 | ||
``` | ||
The generated graphs will be saved under `data/graphs`. Each graph captures long temporal context information in a video. | ||
|
||
#### Step 2: Training | ||
Next, run the training script by passing the default configuration file. You also need to specify which split to perform the experiments on: | ||
``` | ||
python tools/train_context_reasoning.py --cfg configs/action-segmentation/50salads/SPELL_default.yaml --split 2 | ||
``` | ||
The results and logs will be saved under `results`. | ||
|
||
#### Step 3: Evaluation | ||
Now, we can evaluate the trained model's performance: | ||
``` | ||
python tools/evaluate.py --exp_name SPELL_AS_default --eval_type AS | ||
``` | ||
This will print the evaluation scores. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,4 @@ | |
try: | ||
__version__ = get_distribution('gravit').version | ||
except: | ||
__version__ = '1.0.0' | ||
__version__ = '1.1.0' |
Oops, something went wrong.