Skip to content

Commit

Permalink
train c3d on ucf101-nori
Browse files Browse the repository at this point in the history
  • Loading branch information
lianghao02 committed Mar 24, 2024
1 parent c0f6120 commit d8886a6
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 7 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,3 @@ accelerate launch --config_file=tools/single_acc.yml --num_processes=8 main.py -
## Thanks

部分代码和组织结构参考[BasicSR](https://github.com/XPixelGroup/BasicSR), [MMAction2](https://github.com/open-mmlab/mmaction2)以及其他卓越的工作。

4 changes: 0 additions & 4 deletions mmaction/datasets/transforms/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,8 +1406,6 @@ def transform(self, results: dict) -> dict:

act_name, video_name = directory.split('/')[-2:]

print(act_name, video_name)

if self.file_client is None:
self.file_client = FileClient(self.io_backend, **self.kwargs)

Expand All @@ -1430,12 +1428,10 @@ def transform(self, results: dict) -> dict:
frame_idx += offset
if modality == 'RGB':
filepath = osp.join(directory, filename_tmpl.format(frame_idx))
# print(directory)
img_bytes = self.file_client.get(filepath)
# Get frame with channel order RGB directly.
cur_frame = mmcv.imfrombytes(img_bytes, channel_order='rgb')
imgs.append(cur_frame)
print(cur_frame.shape, cur_frame.dtype)
elif modality == 'Flow':
x_filepath = osp.join(directory,
filename_tmpl.format('x', frame_idx))
Expand Down
2 changes: 1 addition & 1 deletion mmaction/datasets/transforms/nori_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self,
self.file_client = None

with open(nori_file, 'r') as file:
self.nids = json.load(nori_file)
self.nids = json.load(file)

def transform(self, results):
mmcv.use_backend(self.decoding_backend)
Expand Down
63 changes: 62 additions & 1 deletion usage/nori2_video.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ train_pipeline = [

```python
file_client_args = dict(
io_backend='oss',
io_backend='disk',
nori_file = 'data/ucf101/ucf101_train_split_1_nid.json',
dtype = 'uint8',
retry = 60
Expand All @@ -101,4 +101,65 @@ train_pipeline = [
dict(type='FormatShape', input_format='NCTHW'),
dict(type='PackActionInputs')
]
```

### C3D-UCF101-Nori2

```shell
bash tools/dist_train.sh work_dir/c3d/c3d_pretrained_ncf101_nori.py 1 --work-dir results/c3d_ucf101_nori --seed=0
```
相关的配置文件在[work_dir/c3d](../work_dir/c3d/)目录下,下面展示了本地抽帧和nori抽帧配置文件的差异:

```diff
--- c3d_pretrained_ncf101_nori.py 2024-03-24 15:46:20.950437836 +0800
+++ c3d_pretrained_ucf101_rgb.py 2024-02-23 10:22:25.396394506 +0800
@@ -21,19 +21,7 @@
ann_file_test = f'/data/dataset/ucf101/ucf101_val_split_{split}_rawframes.txt'
ann_file_val = f'/data/dataset/ucf101/ucf101_val_split_{split}_rawframes.txt'

-file_client_args_train = dict(
- io_backend='disk',
- nori_file = 'data/ucf101/ucf101_train_split_1_nid.json',
- dtype = 'uint8',
- retry = 60
-)
-
-file_client_args_eval = dict(
- io_backend='disk',
- nori_file = 'data/ucf101/ucf101_val_split_1_nid.json',
- dtype = 'uint8',
- retry = 60
-)
+file_client_args = dict(io_backend='disk')

# dataset pipeline
train_pipeline = [
@@ -41,7 +29,7 @@
# dict(type='DecordInit', **file_client_args),
dict(type='SampleFrames', clip_len=16, frame_interval=1, num_clips=1),
# dict(type='DecordDecode'),
- dict(type='RawFrameDecodeNoir2', **file_client_args_train),
+ dict(type='RawFrameDecode', **file_client_args),
dict(type='Resize', scale=(-1, 128)),
dict(type='RandomCrop', size=112),
dict(type='Flip', flip_ratio=0.5),
@@ -58,7 +46,7 @@
num_clips=1,
test_mode=True),
# dict(type='DecordDecode'),
- dict(type='RawFrameDecodeNoir2', **file_client_args_eval),
+ dict(type='RawFrameDecode', **file_client_args),
dict(type='Resize', scale=(-1, 128)),
dict(type='CenterCrop', crop_size=112),
dict(type='FormatShape', input_format='NCTHW'),
@@ -74,7 +62,7 @@
num_clips=10,
test_mode=True),
# dict(type='DecordDecode'),
- dict(type='RawFrameDecodeNoir2', **file_client_args_eval),
+ dict(type='RawFrameDecode', **file_client_args),
dict(type='Resize', scale=(-1, 128)),
dict(type='CenterCrop', crop_size=112),
dict(type='FormatShape', input_format='NCTHW'),

```

0 comments on commit d8886a6

Please sign in to comment.