Implement cotracker3 #1636

Open: wants to merge 1 commit into base: master
2 changes: 1 addition & 1 deletion README.md
@@ -680,7 +680,7 @@ If you would like to try on your computer:
| | Model | Reference | Exported From | Supported Ailia Version | Blog |
|:-----------|------------:|:------------:|:------------:|:------------:|:------------:|
| [<img src="optical_flow_estimation/raft/output.png" width=128px>](optical_flow_estimation/raft/) | [raft](/optical_flow_estimation/raft/) | [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://github.com/princeton-vl/RAFT) | Pytorch | 1.2.6 and later | [EN](https://medium.com/axinc-ai/raft-a-machine-learning-model-for-estimating-optical-flow-6ab6d077e178) [JP](https://medium.com/axinc/raft-optical-flow%E3%82%92%E6%8E%A8%E5%AE%9A%E3%81%99%E3%82%8B%E6%A9%9F%E6%A2%B0%E5%AD%A6%E7%BF%92%E3%83%A2%E3%83%87%E3%83%AB-bf898965de05) |
| [<img src="optical_flow_estimation/cotracker3/output.gif" width=128px>](optical_flow_estimation/cotracker3/) | [cotracker3](/optical_flow_estimation/cotracker3/) | [CoTracker3: Simpler and Better Point Tracking by Pseudo-Labelling Real Videos](https://github.com/facebookresearch/co-tracker) | Pytorch | 2.4 and later | |
## Point segmentation

| | Model | Reference | Exported From | Supported Ailia Version | Blog |
399 changes: 399 additions & 0 deletions optical_flow_estimation/cotracker3/LICENSE.md


51 changes: 51 additions & 0 deletions optical_flow_estimation/cotracker3/README.md
@@ -0,0 +1,51 @@
# CoTracker3: Simpler and Better Point Tracking by Pseudo-Labelling Real Videos

## Input

![Input](input.gif)

(Image from https://github.com/facebookresearch/co-tracker/blob/main/gradio_demo/videos/bear.mp4)

Shape : (1, 3, 854, 480)
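The sample script stacks the decoded frames and reorders them into a single batched float tensor of shape (1, T, 3, H, W) before inference. A minimal sketch of that preprocessing with synthetic frames (the frame count and resolution here are illustrative, not required by the model):

```python
import numpy as np

# Three synthetic 854x480 BGR frames standing in for decoded video frames.
frames = [np.zeros((480, 854, 3), dtype=np.uint8) for _ in range(3)]

video = np.stack(frames)                           # (T, H, W, 3)
video = np.transpose(video, (0, 3, 1, 2))          # (T, 3, H, W)
video = video[np.newaxis, ...].astype(np.float32)  # (1, T, 3, H, W)

print(video.shape)  # (1, 3, 3, 480, 854)
```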

## Output

![Output](output.gif)


### Usage
The script automatically downloads the ONNX and prototxt files on the first run.
An Internet connection is required while downloading.

To run with the sample video:
```bash
$ python3 cotracker3.py
```

If you want to specify the input video, put the video path after the `--input` option.
You can use the `--savepath` option to change the name of the output file.

```bash
$ python3 cotracker3.py --input VIDEO_PATH --savepath SAVE_VIDEO_PATH
```

By default, the ailia SDK is used. If you want to use ONNX Runtime instead, use the `--onnx` option.
```bash
$ python3 cotracker3.py --onnx
```
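Besides `--onnx`, the argument parser in `cotracker3.py` also defines grid-related options: `--grid_size` (default 10), `--grid_query_frame` (default 0), and `--backward_tracking`. A hypothetical invocation combining them (flag values here are illustrative):

```shell
# Track a denser 20x20 grid of points, starting from frame 0:
python3 cotracker3.py --grid_size 20 --grid_query_frame 0
```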

## Reference

- [CoTracker3: Simpler and Better Point Tracking by Pseudo-Labelling Real Videos](https://github.com/facebookresearch/co-tracker)

## Framework

Pytorch 2.4

## Model Format

ONNX opset=20

## Netron

[cotracker3.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/cotracker3/cotracker3.onnx.prototxt)
147 changes: 147 additions & 0 deletions optical_flow_estimation/cotracker3/cotracker3.py
@@ -0,0 +1,147 @@
import sys
import time

import cv2
import numpy as np

import ailia
import onnxruntime as ort

from vis import Visualizer

# import original modules
sys.path.append('../../util')
from arg_utils import get_base_parser, update_parser # noqa: E402
from model_utils import check_and_download_models # noqa: E402

# logger
from logging import getLogger # noqa: E402
logger = getLogger(__name__)


# ======================
# Parameters
# ======================
VIDEO_PATH = 'input.mp4'
SAVE_PATH = 'output.mp4'

# ======================
# Argument Parser Config
# ======================
parser = get_base_parser(
    'CoTracker3: Simpler and Better Point Tracking by Pseudo-Labelling Real Videos',
    VIDEO_PATH,
    SAVE_PATH,
)

parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
parser.add_argument(
    "--grid_query_frame",
    type=int,
    default=0,
    help="Compute dense and grid tracks starting from this frame",
)
parser.add_argument(
    "--backward_tracking",
    action="store_true",
    help="Compute tracks in both directions, not only forward",
)

parser.add_argument('--onnx', action='store_true', help='execute onnxruntime version.')

args = update_parser(parser)

# ==========================
# MODEL AND OTHER PARAMETERS
# ==========================
WEIGHT_PATH = 'cotracker3.onnx'
MODEL_PATH = 'cotracker3.onnx.prototxt'
REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/cotracker3/'

def read_video_from_path(path):
    # cv2.VideoCapture does not raise on failure, so check isOpened() instead
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        logger.error(f'Error opening video file: {path}')
        return None
    frames = []

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    cap.release()

    if not frames:
        logger.error(f'No frames could be read from: {path}')
        return None
    return np.stack(frames)


def compute(net, video):
    grid_size = np.array(args.grid_size, dtype=np.int64)
    grid_query_frame = np.array(args.grid_query_frame, dtype=np.int64)
    if not args.onnx:
        result = net.run((video, grid_size, grid_query_frame))
    else:
        input_names = [inp.name for inp in net.get_inputs()]
        result = net.run([], {
            input_names[0]: video,
            input_names[1]: grid_size,
            input_names[2]: grid_query_frame,
        })
    return result

# ======================
# Main functions
# ======================
def recognize_from_video():
    # net initialize
    if not args.onnx:
        memory_mode = ailia.get_memory_mode(
            reduce_constant=True, ignore_input_with_initializer=True,
            reduce_interstage=False, reuse_interstage=True)

        net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=args.env_id, memory_mode=memory_mode)
    else:
        net = ort.InferenceSession(WEIGHT_PATH)

    vis = Visualizer(pad_value=120, linewidth=3)

    for path in args.input:
        # load video and reshape (T, H, W, 3) -> (1, T, 3, H, W)
        video = read_video_from_path(path)
        if video is None:
            continue
        video = np.transpose(video, (0, 3, 1, 2))[np.newaxis, ...].astype(np.float32)

        # run point tracking
        logger.info('Start tracking...')
        if args.benchmark:
            logger.info('BENCHMARK mode')
            for i in range(args.benchmark_count):
                start = int(round(time.time() * 1000))
                result = compute(net, video)
                end = int(round(time.time() * 1000))
                logger.info(f'\tailia processing time {end - start} ms')
        else:
            result = compute(net, video)

        pred_tracks = np.array(result[0])
        pred_visibility = np.array(result[1])

        # save a video with predicted tracks
        vis.visualize(
            video,
            pred_tracks,
            pred_visibility,
            args.savepath,
        )
        logger.info(f'saved at : {args.savepath}')
    logger.info('Script finished successfully.')


def main():
    # model files check and download
    check_and_download_models(WEIGHT_PATH, MODEL_PATH, REMOTE_PATH)

    recognize_from_video()


if __name__ == '__main__':
    main()
Binary file added optical_flow_estimation/cotracker3/input.gif
Binary file added optical_flow_estimation/cotracker3/input.mp4
Binary file added optical_flow_estimation/cotracker3/output.gif