diff --git a/references/optical_flow/README.md b/references/optical_flow/README.md index 660ee1b0c38..f722b70ae41 100644 --- a/references/optical_flow/README.md +++ b/references/optical_flow/README.md @@ -48,10 +48,18 @@ torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-siz ``` This should give an epe of about 1.3822 on the clean pass and 2.7161 on the -final pass of Sintel. Results may vary slightly depending on the batch size and -the number of GPUs. For the most accurate resuts use 1 GPU and `--batch-size 1`: +final pass of Sintel-train. Results may vary slightly depending on the batch +size and the number of GPUs. For the most accurate resuts use 1 GPU and +`--batch-size 1`: ``` Sintel val clean epe: 1.3822 1px: 0.9028 3px: 0.9573 5px: 0.9697 per_image_epe: 1.3822 f1: 4.0248 Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: 2.7161 f1: 7.5964 ``` + +You can also evaluate on Kitti train: + +``` +torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained +Kitti val epe: 4.7968 1px: 0.6388 3px: 0.8197 5px: 0.8661 per_image_epe: 4.5118 f1: 16.0679 +``` \ No newline at end of file diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index dda68d73721..5e3388d66b2 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -22,8 +22,7 @@ _MODELS_URLS = { "raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", - # TODO: change to V2 once we upload our own weights - "raft_small": "https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", + "raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", } @@ -591,7 +590,7 @@ def raft_large(*, pretrained=False, progress=True, **kwargs): `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. Args: - pretrained (bool): TODO not implemented yet + pretrained (bool): Whether to use pretrained weights. progress (bool): If True, displays a progress bar of the download to stderr kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class to override any default. @@ -636,7 +635,7 @@ def raft_small(*, pretrained=False, progress=True, **kwargs): `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. Args: - pretrained (bool): TODO not implemented yet + pretrained (bool): Whether to use pretrained weights. progress (bool): If True, displays a progress bar of the download to stderr kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class to override any default. diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index b1b5fcbe911..ca4ae90927e 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -34,6 +34,8 @@ class Raft_Large_Weights(WeightsEnum): "recipe": "https://github.com/princeton-vl/RAFT", "sintel_train_cleanpass_epe": 1.4411, "sintel_train_finalpass_epe": 2.7894, + "kitti_train_per_image_epe": 5.0172, + "kitti_train_f1-all": 17.4506, }, ) @@ -46,6 +48,8 @@ class Raft_Large_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", "sintel_train_cleanpass_epe": 1.3822, "sintel_train_finalpass_epe": 2.7161, + "kitti_train_per_image_epe": 4.5118, + "kitti_train_f1-all": 16.0679, }, ) @@ -87,10 +91,25 @@ class Raft_Small_Weights(WeightsEnum): "recipe": "https://github.com/princeton-vl/RAFT", "sintel_train_cleanpass_epe": 2.1231, "sintel_train_finalpass_epe": 3.2790, + "kitti_train_per_image_epe": 7.6557, + "kitti_train_f1-all": 25.2801, + }, + ) + C_T_V2 = Weights( + # Chairs + Things + url="https://github.com/pytorch/vision/tree/main/references/optical_flow", + transforms=RaftEval, + meta={ + **_COMMON_META, + "recipe": "https://github.com/princeton-vl/RAFT", + "sintel_train_cleanpass_epe": 1.9901, + "sintel_train_finalpass_epe": 3.2831, + "kitti_train_per_image_epe": 7.5978, + "kitti_train_f1-all": 25.2369, }, ) - default = C_T_V1 # TODO: Change to V2 once we upload our own weights + default = C_T_V2 @handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_V2)) @@ -143,14 +162,13 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, * return model -# TODO: change to V2 once we upload our own weights -@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V1)) +@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2)) def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs): """RAFT "small" model from `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. Args: - weights(Raft_Small_weights, optinal): TODO not implemented yet + weights(Raft_Small_weights, optional): pretrained weights to use. progress (bool): If True, displays a progress bar of the download to stderr kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class to override any default.